neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 0939c40af90b9ce5ce7b734d12783cf99c6f06bf
parent 9bc94d6c185be755135e1ed24919ddd92a78cfe4
Author: Steven Atkinson <[email protected]>
Date:   Sun, 21 Feb 2021 15:49:39 -0500

Revert "PyTorch"

Diffstat:
M README.md | 4 +---
M models.py | 158 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
M reamp.py | 79 ++++++++++++++++++++++++++++++++++++++++++---------------------------------------
M requirements.txt | 11 +----------
M train.py | 270 ++++++++++++++++++++++++++++++-------------------------------------------------------------------
5 files changed, 236 insertions(+), 286 deletions(-)

diff --git a/README.md b/README.md
@@ -4,9 +4,7 @@ Let's use deep learning to make a model of a guitar amp!
 
 ## Setup
 
-```bash
-`pip install -r requirements.txt`
-```
+Use `pip install -r requirements.txt` as well as installing the TensorFlow version of your choice (e.g. `pip install tensorflow-gpu`).
 
 ## Autoregressive models
diff --git a/models.py b/models.py
@@ -2,35 +2,84 @@
 # File Created: Sunday, 30th December 2018 9:42:29 pm
 # Author: Steven Atkinson ([email protected])
 
+import numpy as np
+import tensorflow as tf
 import abc
-import json
-import os
 from tempfile import mkdtemp
-
-import numpy as np
-import torch
-import torch.nn as nn
+import os
+import json
 
 
-def from_json(f):
+def from_json(f, n_train=1, checkpoint_path=None):
     if isinstance(f, str):
         with open(f, "r") as json_file:
             f = json.load(json_file)
-
+    
     if f["type"] == "FullyConnected":
-        return mlp(f["input_length"], 1, layer_sizes=f["layer_sizes"])
+        return FullyConnected(n_train, f["input_length"],
+            layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
     else:
-        raise NotImplementedError("Model type {} unrecognized".format(f["type"]))
+        raise NotImplementedError("Model type {} unrecognized".format(
+            f["type"]))
 
 
-class Model(nn.Module):
+class Model(object):
     """
     Model parent class
     """
+    def __init__(self, n_train, sess=None, checkpoint_path=None):
+        """
+        Make sure child classes call _build() after this!
+        """
+        if sess is None:
+            sess = tf.get_default_session()
+        self.sess = sess
+
+        if checkpoint_path is None:
+            checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
+        if not os.path.isdir(os.path.dirname(checkpoint_path)):
+            os.makedirs(os.path.dirname(checkpoint_path))
+        self.checkpoint_path = checkpoint_path
+
+        # self._batch_size = batch_size
+        self._n_train = n_train
+        self.target = tf.placeholder(tf.float32, shape=(None, 1))
+        self.prediction = None
+        self.total_prediction_loss = None
+        self.rmse = None
+        self.loss = None  # Includes any regularization
+
+    # @property
+    # def batch_size(self):
+    #     return self._batch_size
+
+    @property
+    def n_train(self):
+        return self._n_train
+
+    def load(self, checkpoint_path=None):
+        checkpoint_path = checkpoint_path or self.checkpoint_path
+
+        try:
+            ckpt = tf.train.get_checkpoint_state(checkpoint_path)
+            print("Loading model: {}".format(ckpt.model_checkpoint_path))
+            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+        except Exception as e:
+            print("Error while attempting to load model: {}".format(e))
+
+    @abc.abstractclassmethod
+    def predict(self, x):
+        """
+        A nice function for prediction.
+        :param x: input array (length=n)
+        :type x: array-like
+        :return: (array-like) corresponding predicted outputs (length=n)
+        """
+        raise NotImplementedError("Implement predict()")
 
-    @abc.abstractmethod
-    def predict_sequence(self, x):
-        raise NotImplementedError()
+    def save(self, iter, checkpoint_path=None):
+        checkpoint_path = checkpoint_path or self.checkpoint_path
+        self.saver.save(self.sess, checkpoint_path, global_step=iter)
 
     def _build(self):
         self.prediction = self._build_prediction()
@@ -41,18 +90,18 @@ class Model(nn.Module):
         self.saver = tf.train.Saver(tf.global_variables())
 
     def _build_loss(self):
-        self.total_prediction_loss = tf.losses.mean_squared_error(
-            self.target, self.prediction, weights=self.n_train
-        )
-
+        self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
+            self.prediction, weights=self.n_train)
+        
         # Don't count this as a loss!
-        self.rmse = tf.sqrt(self.total_prediction_loss / self.n_train)
+        self.rmse = tf.sqrt(
+            self.total_prediction_loss / self.n_train)
 
         return tf.losses.get_total_loss()
 
     @abc.abstractclassmethod
     def _build_prediction(self):
-        raise NotImplementedError("Implement prediction for model")
+        raise NotImplementedError('Implement prediction for model')
 
 
 class Autoregressive(Model):
@@ -60,62 +109,57 @@ class Autoregressive(Model):
     Autoregressive models that take in a few of the most recent input samples
     and predict the output at the last time point.
     """
-
-    @abc.abstractproperty
+    def __init__(self, n_train, input_length, sess=None,
+            checkpoint_path=None):
+        super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
+        self._input_length = input_length
+        self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
+
+    @property
     def input_length(self):
-        raise NotImplementedError()
+        return self._input_length
 
-    def predict_sequence(self, x: torch.Tensor, batch_size=None, verbose=False):
+    def predict(self, x, batch_size=None, verbose=False):
         """
         Return 1D array of predictions same length as x
         """
-        n = x.numel()
+        n = x.size
         batch_size = batch_size or n
         # Pad x with leading zeros:
-        x = torch.cat((torch.zeros(self.input_length - 1).to(x.device), x))
+        x = np.concatenate((np.zeros(self.input_length - 1), x))
         i = 0
         y = []
         while i < n:
+            if verbose:
+                print("model.predict {}/{}".format(i, n))
             this_batch_size = np.min([batch_size, n - i])
             # Reshape into a batch:
-            x_mtx = torch.stack(
-                [x[j : j + self.input_length] for j in range(i, i + this_batch_size)]
-            )
+            x_mtx = np.stack([x[j: j + self.input_length]
+                for j in range(i, i + this_batch_size)])
             # Predict and flatten.
-            y.append(self(x_mtx).squeeze())
+            y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
+                .flatten())
             i += this_batch_size
-        return torch.cat(y)
+        return np.concatenate(y)
 
 
 class FullyConnected(Autoregressive):
     """
     Autoregressive model taking in a sequence of the most recent inputs, putting
-    them through a series of FC layers, and outputting the single output at the
+    them through a series of FC layers, and outputting the single output at the 
     last time step.
     """
+    def __init__(self, n_train, input_length, layer_sizes=(512,),
+            sess=None, checkpoint_path=None):
+        super().__init__(n_train, input_length, sess=sess,
+            checkpoint_path=checkpoint_path)
+        self._layer_sizes = layer_sizes
+        self._build()
 
-    def __init__(self, net):
-        super().__init__()
-        self._net = net
-
-    @property
-    def input_length(self):
-        return self._net[0][0].weight.data.shape[1]
-
-    def forward(self, inputs):
-        return self._net(inputs)
-
-
-def mlp(dx, dy, layer_sizes=None):
-    def block(dx, dy, Activation=nn.ReLU):
-        return nn.Sequential(nn.Linear(dx, dy), Activation())
-
-    layer_sizes = [256, 256] if layer_sizes is None else layer_sizes
-
-    net = nn.Sequential()
-    in_features = dx
-    for i, out_features in enumerate(layer_sizes):
-        net.add_module("layer_%i" % i, block(in_features, out_features))
-        in_features = out_features
-    net.add_module("head", block(in_features, dy, Activation=nn.Tanh))
-    return FullyConnected(net)
+    def _build_prediction(self):
+        h = self.x
+        for m in self._layer_sizes:
+            h = tf.contrib.layers.fully_connected(h, m)
+        y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
+            activation_fn=tf.nn.sigmoid)
+        return y
diff --git a/reamp.py b/reamp.py
@@ -4,13 +4,11 @@ Reamp a .wav file
 
 Assumes 24-bit WAV files
 """
-
-import os
 from argparse import ArgumentParser
-
-import matplotlib.pyplot as plt
-import torch
+import tensorflow as tf
+import os
 import wavio
+import matplotlib.pyplot as plt
 
 import models
 
@@ -21,56 +19,41 @@ def _sampwidth_to_bits(x):
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument(
-        "architecture", type=str, help="JSON filename containing NN architecture"
-    )
-    parser.add_argument(
-        "params", type=str, help="directory holding model checkpoint to use"
-    )
-    parser.add_argument("input_file", type=str, help="Input .wav file to convert")
-    parser.add_argument(
-        "--output_file", type=str, default=None, help="Where to save the output"
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=8192,
-        help="How many samples to process at a time. "
-        + "Reduce if there are out-of-memory issues.",
-    )
-    parser.add_argument(
-        "--target_file",
-        type=str,
-        default=None,
-        help=".wav file of the true output (if you want to compare)",
-    )
+    parser.add_argument("architecture", type=str,
+        help="JSON filename containing NN architecture")
+    parser.add_argument("checkpoint_dir", type=str,
+        help="directory holding model checkpoint to use")
+    parser.add_argument("input_file", type=str,
+        help="Input .wav file to convert")
+    parser.add_argument("--output_file", type=str, default=None,
+        help="Where to save the output")
+    parser.add_argument("--batch_size", type=int, default=8192,
+        help="How many samples to process at a time. " +
+        "Reduce if there are out-of-memory issues.")
+    parser.add_argument("--target_file", type=str, default=None,
+        help=".wav file of the true output (if you want to compare)")
     args = parser.parse_args()
 
-    output_file = (
-        args.output_file
-        if args.output_file is not None
-        else args.input_file.rstrip(".wav") + "_reamped.wav"
-    )
-
+    if args.output_file is None:
+        args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav"
+    
     if os.path.isfile(args.output_file):
         print("Output file exists; skip")
         exit(1)
-
+    
     x = wavio.read(args.input_file)
     rate, sampwidth = x.rate, x.sampwidth
     bits = _sampwidth_to_bits(sampwidth)
-    x_data = torch.Tensor(x.data.flatten() / 2 ** (bits - 1))
-
-    model = models.from_json(args.architecture)
-    model.load_state_dict(torch.load(args.params))
-    with torch.no_grad():
-        y = model.predict_sequence(
-            x_data, batch_size=args.batch_size, verbose=True
-        ).numpy()
-    wavio.write(
-        output_file, y * 2 ** (bits - 1), rate, scale="none", sampwidth=sampwidth
-    )
-
+    x_data = x.data.flatten() / 2 ** (bits - 1)
+    
+    with tf.Session() as sess:
+        model = models.from_json(args.architecture,
+            checkpoint_path=args.checkpoint_dir)
+        model.load()
+        y = model.predict(x_data, batch_size=args.batch_size, verbose=True)
+        wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none",
+            sampwidth=sampwidth)
+    
     if args.target_file is not None and os.path.isfile(args.target_file):
         t = wavio.read(args.target_file)
         t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1)
@@ -80,3 +63,4 @@ if __name__ == "__main__":
         plt.plot(y)
         plt.legend(["Input", "Target", "Prediction"])
         plt.show()
+
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
@@ -1,11 +1,2 @@
-# File: requirements.txt
-# Created Date: 2019-01-02
-# Author: Steven Atkinson ([email protected])
-
-black
-flake8
-matplotlib
-numpy
-pytest
-torch>=1
+tensorflow
 wavio
diff --git a/train.py b/train.py
@@ -4,62 +4,42 @@
 
 """
 Here's a script for training new models.
-
-TODO
-* Device
-* Lightning?
 """
 
-import abc
-import json
-import os
 from argparse import ArgumentParser
-from time import time
-
-import matplotlib.pyplot as plt
 import numpy as np
-import torch
-import torch.nn as nn
+import abc
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from time import time
 import wavio
+import json
+import os
 
 import models
 
 # Parameters for training
-check_training_at = np.concatenate(
-    [10 ** pwr * np.array([1, 2, 5]) for pwr in np.arange(0, 7)]
-)
-plot_kwargs = {"window": (30000, 40000)}
+check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5])
+    for pwr in np.arange(0, 7)])
+plot_kwargs = {
+    "window": (30000, 40000)
+}
 wav_kwargs = {"window": (0, 5 * 44100)}
 
-torch.manual_seed(0)
-_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-def ensure_save_dir(args):
-    save_dir = (
-        os.path.join(os.path.dirname(__file__), "output")
-        if args.save_dir is None
-        else args.save_dir
-    )
-    if not os.path.isdir(save_dir):
-        os.makedirs(save_dir)
-    return save_dir
-
 
 class Data(object):
     """
-    Object for holding data and spitting out minibatches for training/segments
+    Object for holding data and spitting out minibatches for training/segments 
     for testing.
     """
-
     def __init__(self, fname, input_length, batch_size=None):
         xy = np.load(fname)
-        self.x = torch.Tensor(xy[0])
-        self.y = torch.Tensor(xy[1])
-        self._n = self.x.numel() - input_length + 1
+        self.x = xy[0]
+        self.y = xy[1]
+        self._n = self.x.size - input_length + 1
         self.input_length = input_length
         self.batch_size = batch_size
-
+    
     @property
     def n(self):
        return self._n
@@ -70,101 +50,70 @@ class Data(object):
         """
         if ilist is None:
             n = n or self.batch_size
-            ilist = torch.randint(0, self.n, size=(n,))
-        x = torch.stack([self.x[i : i + self.input_length] for i in ilist])
-        y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None]
-        return x.to(_device), y.to(_device)
+            ilist = np.random.randint(0, self.n, size=(n,))
+        x = np.stack([self.x[i: i + self.input_length] for i in ilist])
+        y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \
+            [:, np.newaxis]
+        return x, y
 
     def sequence(self, start, end):
         end += self.input_length - 1
-        return self.x[start:end], self.y[start + self.input_length - 1 : end]
-
-
-def train_step(model, optimizer, batch):
-    model.train()
-    model.zero_grad()
-    inputs, targets = batch
-    preds = model(inputs)
-    loss = nn.MSELoss()(preds, targets)
-    loss.backward()
-    optimizer.step()
-    return loss.item()
-
-
-def validation_step(model, batch):
-    with torch.no_grad():
-        model.eval()
-        inputs, targets = batch
-        return nn.MSELoss()(model(inputs), targets).item()
+        return self.x[start: end], self.y[start + self.input_length - 1: end]
 
 
-def train(
-    model,
-    train_data,
-    batch_size=None,
-    n_minibatches=10,
-    validation_data=None,
-    plot_at=(),
-    wav_at=(),
-    plot_kwargs={},
-    wav_kwargs={},
-    save_every=100,
-    validate_every=100,
-):
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+def train(model, train_data, batch_size=None, n_minibatches=10,
+        validation_data=None, plot_at=(), wav_at=(), plot_kwargs={},
+        wav_kwargs={}, save_every=100, validate_every=100):
+    save_dir = os.path.dirname(model.checkpoint_path)
+    sess = model.sess
+    opt = tf.train.AdamOptimizer().minimize(model.loss)
+    sess.run(tf.global_variables_initializer())  # For opt
     t_loss_list, v_loss_list = [], []
     t0 = time()
     for i in range(n_minibatches):
-        batch = train_data.minibatch(batch_size)
-        t_loss = train_step(model, optimizer, batch)
+        x, y = train_data.minibatch(batch_size)
+        t_loss, _ = sess.run((model.rmse, opt),
+            feed_dict={model.x: x, model.target: y})
         t_loss_list.append([i, t_loss])
-        print(
-            "t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(
-                int(time() - t0), i + 1, n_minibatches, t_loss
-            )
-        )
-
+        print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0),
+            i + 1, n_minibatches, t_loss))
+        
         # Callbacks, basically...
         if i + 1 in plot_at:
-            plot_predictions(
-                model,
-                validation_data if validation_data is not None else train_data,
-                title="Minibatch {}".format(i + 1),
-                fname="{}/mb_{}.png".format(save_dir, i + 1),
-                **plot_kwargs
-            )
+            plot_predictions(model,
+                validation_data if validation_data is not None else train_data,
+                title="Minibatch {}".format(i + 1),
+                fname="{}/mb_{}.png".format(save_dir, i + 1),
+                **plot_kwargs)
         if i + 1 in wav_at:
             print("Making wav for mb {}".format(i + 1))
-            predict(
-                model,
-                validation_data,
-                save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
-                **wav_kwargs
-            )
+            predict(model, validation_data,
+                save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
+                **wav_kwargs)
         if (i + 1) % save_every == 0:
-            torch.save(model.state_dict(), os.path.join(save_dir, "model_%i.pt" % i))
+            model.save(iter=i + 1)
         if i == 0 or (i + 1) % validate_every == 0:
-            v_loss = validation_step(model, batch)
+            v_loss, _ = sess.run((model.rmse, opt),
+                feed_dict={model.x: x, model.target: y})
             print("VLoss={:8}".format(v_loss))
             v_loss_list.append([i, v_loss])
-
+    
     # After training loop...
     if validation_data is not None:
-        batch = validation_data.minibatch(train_data.batch_size)
-        v_loss = validation_step(model, batch)
+        x, y = validation_data.minibatch(train_data.batch_size)
+        v_loss = sess.run(model.rmse,
+            feed_dict={model.x: x, model.target: y})
         print("Validation loss={}".format(v_loss))
-    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
 
     return np.array(t_loss_list).T, np.array(v_loss_list).T
 
 
 def plot_predictions(model, data, title=None, fname=None, window=None):
-    with torch.no_grad():
-        x, y, t = predict(model, data, window=window)
-        plt.figure(figsize=(12, 4))
-        plt.plot(x.cpu())
-        plt.plot(t.cpu())
-        plt.plot(y.cpu())
-        plt.legend(("Input", "Target", "Prediction"))
+    x, y, t = predict(model, data, window=window)
+    plt.figure(figsize=(12, 4))
+    plt.plot(x)
+    plt.plot(t)
+    plt.plot(y)
+    plt.legend(('Input', 'Target', 'Prediction'))
     if title is not None:
         plt.title(title)
     if fname is not None:
@@ -184,27 +133,21 @@ def plot_loss(t_loss, v_loss, fname):
     plt.legend(("Training", "Validation"))
     plt.savefig(fname)
     plt.close()
-
+    
 
 def predict(model, data, window=None, save_wav_file=None):
-    with torch.no_grad():
-        x, t = data.x.to(_device), data.y.to(_device)
-        if window is not None:
-            x, t = x[window[0] : window[1]], t[window[0] : window[1]]
-        y = model.predict_sequence(x).squeeze()
+    x, t = data.x, data.y
+    if window is not None:
+        x, t = x[window[0]: window[1]], t[window[0]: window[1]]
+    y = model.predict(x).flatten()
 
-        if save_wav_file is not None:
-            rate = 44100  # TODO from data
-            sampwidth = 3  # 24-bit
-            wavio.write(
-                save_wav_file,
-                y.cpu().numpy() * 2 ** 23,
-                rate,
-                scale="none",
-                sampwidth=sampwidth,
-            )
+    if save_wav_file is not None:
+        rate = 44100  # TODO from data
+        sampwidth = 3  # 24-bit
+        wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none",
+            sampwidth=sampwidth)
 
-        return x, y, t
+    return x, y, t
 
 
 def _get_input_length(archfile):
@@ -213,54 +156,43 @@
 
 if __name__ == "__main__":
     parser = ArgumentParser()
-    parser.add_argument(
-        "model_arch", type=str, help="JSON containing model architecture"
-    )
-    parser.add_argument("train_data", type=str, help="Filename for training data")
-    parser.add_argument(
-        "validation_data", type=str, help="Filename for validation data"
-    )
-    parser.add_argument(
-        "--save_dir",
-        type=str,
-        default=None,
-        help="Where to save the run data (checkpoints, prediction...)",
-    )
-    parser.add_argument(
-        "--batch_size", type=str, default=4096, help="Number of data per minibatch"
-    )
-    parser.add_argument(
-        "--minibatches", type=int, default=10, help="Number of minibatches to train for"
-    )
+    parser.add_argument("model_arch", type=str,
+        help="JSON containing model architecture")
+    parser.add_argument("train_data", type=str,
+        help="Filename for training data")
+    parser.add_argument("validation_data", type=str,
+        help="Filename for validation data")
+    parser.add_argument("--save_dir", type=str, default=None,
+        help="Where to save the run data (checkpoints, prediction...)")
+    parser.add_argument("--batch_size", type=str, default=4096,
+        help="Number of data per minibatch")
+    parser.add_argument("--minibatches", type=int, default=10,
+        help="Number of minibatches to train for")
     args = parser.parse_args()
 
-    save_dir = ensure_save_dir(args)
     input_length = _get_input_length(args.model_arch)  # Ugh, kludge
 
     # Load the data
-    train_data = Data(args.train_data, input_length, batch_size=args.batch_size)
+    train_data = Data(args.train_data, input_length,
+        batch_size=args.batch_size)
     validate_data = Data(args.validation_data, input_length)
-
+    
     # Training
-    model = models.from_json(args.model_arch).to(_device)
-    t_loss_list, v_loss_list = train(
-        model,
-        train_data,
-        validation_data=validate_data,
-        n_minibatches=args.minibatches,
-        plot_at=check_training_at,
-        wav_at=check_training_at,
-        plot_kwargs=plot_kwargs,
-        wav_kwargs=wav_kwargs,
-    )
-    plot_predictions(model, validate_data, window=(0, 44100))
-
-    print("Predict the full output")
-    _device = "cpu"
-    model.to(_device)
-    predict(
-        model,
-        validate_data,
-        save_wav_file="{}/predict.wav".format(save_dir),
-    )
-
-    plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(save_dir))
+    with tf.Session() as sess:
+        model = models.from_json(args.model_arch, train_data.n,
+            checkpoint_path=os.path.join(args.save_dir, "model.ckpt"))
+        t_loss_list, v_loss_list = train(
+            model,
+            train_data,
+            validation_data=validate_data,
+            n_minibatches=args.minibatches,
+            plot_at=check_training_at,
+            wav_at=check_training_at,
+            plot_kwargs=plot_kwargs,
+            wav_kwargs=wav_kwargs)
+        plot_predictions(model, validate_data, window=(0, 44100))
+        print("Predict the full output")
+        predict(model, validate_data,
+            save_wav_file="{}/predict.wav".format(
+                os.path.dirname(model.checkpoint_path)))
+    
    plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))
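For reference, the reverted TF 1.x API restored by this commit can be exercised end to end in a few lines. The sketch below is illustrative, not part of the commit: it assumes a TensorFlow 1.x install with tf.contrib available, and the arch.json values (input_length of 64, one 512-unit layer) are made-up examples for the keys that from_json() actually reads.

# Illustrative sketch (not part of this commit): drive the reverted
# TF 1.x API in models.py. Assumes TensorFlow 1.x with tf.contrib;
# the architecture values below are hypothetical.
import json

import numpy as np
import tensorflow as tf

import models

# from_json() reads the "type", "input_length", and "layer_sizes" keys.
with open("arch.json", "w") as fp:
    json.dump({"type": "FullyConnected", "input_length": 64,
               "layer_sizes": [512]}, fp)

with tf.Session() as sess:  # becomes the default session Model.__init__ grabs
    model = models.from_json("arch.json")  # n_train=1, temp checkpoint dir
    sess.run(tf.global_variables_initializer())
    x = np.random.uniform(-1.0, 1.0, size=44100)  # 1 s of fake audio in [-1, 1)
    y = model.predict(x, batch_size=8192)  # zero-padded, batched inference
    print(y.shape)  # (44100,)

Note the output head in _build_prediction(): -1.0 + 2.0 * sigmoid pins predictions to (-1, 1), which matches the normalized-audio convention the scripts rely on when rescaling by 2 ** (bits - 1) in reamp.py and 2 ** 23 in predict().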