neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 9bc94d6c185be755135e1ed24919ddd92a78cfe4
parent aa489d720868354495f84e1f9f035e3869dd4dea
Author: Steven Atkinson <[email protected]>
Date:   Sun, 21 Feb 2021 15:48:37 -0500

Merge pull request #2 from sdatkinson/pytorch

PyTorch
Diffstat:
 M README.md        |   4 +++-
 M models.py        | 158 +++++++++++++++++++++++++++++--------------------------------------------------
 M reamp.py         |  79 +++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
 M requirements.txt |  11 ++++++++++-
 M train.py         | 270 +++++++++++++++++++++++++++++++++++++++++++++++++------------------------------
5 files changed, 286 insertions(+), 236 deletions(-)

diff --git a/README.md b/README.md
@@ -4,7 +4,9 @@
 Let's use deep learning to make a model of a guitar amp!
 
 ## Setup
 
-Use `pip install -r requirements.txt` as well as intalling the TensorFlow version of your choice (e.g. `pip install tensorflow-gpu`).
+```bash
+`pip install -r requirements.txt`
+```
 
 ## Autoregressive models
diff --git a/models.py b/models.py
@@ -2,84 +2,35 @@
 # File Created: Sunday, 30th December 2018 9:42:29 pm
 # Author: Steven Atkinson ([email protected])
 
-import numpy as np
-import tensorflow as tf
 import abc
-from tempfile import mkdtemp
-import os
 import json
+import os
+from tempfile import mkdtemp
+
+import numpy as np
+import torch
+import torch.nn as nn
 
 
-def from_json(f, n_train=1, checkpoint_path=None):
+def from_json(f):
     if isinstance(f, str):
         with open(f, "r") as json_file:
             f = json.load(json_file)
-    
+
     if f["type"] == "FullyConnected":
-        return FullyConnected(n_train, f["input_length"],
-            layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
+        return mlp(f["input_length"], 1, layer_sizes=f["layer_sizes"])
     else:
-        raise NotImplementedError("Model type {} unrecognized".format(
-            f["type"]))
+        raise NotImplementedError("Model type {} unrecognized".format(f["type"]))
 
 
-class Model(object):
+class Model(nn.Module):
     """
     Model parent class
     """
-    def __init__(self, n_train, sess=None, checkpoint_path=None):
-        """
-        Make sure child classes call _build() after this!
-        """
-        if sess is None:
-            sess = tf.get_default_session()
-        self.sess = sess
-
-        if checkpoint_path is None:
-            checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
-        if not os.path.isdir(os.path.dirname(checkpoint_path)):
-            os.makedirs(os.path.dirname(checkpoint_path))
-        self.checkpoint_path = checkpoint_path
-
-        # self._batch_size = batch_size
-        self._n_train = n_train
-        self.target = tf.placeholder(tf.float32, shape=(None, 1))
-        self.prediction = None
-        self.total_prediction_loss = None
-        self.rmse = None
-        self.loss = None  # Includes any regularization
-
-    # @property
-    # def batch_size(self):
-    #     return self._batch_size
-
-    @property
-    def n_train(self):
-        return self._n_train
-
-    def load(self, checkpoint_path=None):
-        checkpoint_path = checkpoint_path or self.checkpoint_path
-
-        try:
-            ckpt = tf.train.get_checkpoint_state(checkpoint_path)
-            print("Loading model: {}".format(ckpt.model_checkpoint_path))
-            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-        except Exception as e:
-            print("Error while attempting to load model: {}".format(e))
-
-    @abc.abstractclassmethod
-    def predict(self, x):
-        """
-        A nice function for prediction.
-        :param x: input array (length=n)
-        :type x: array-like
-        :return: (array-like) corresponding predicted outputs (length=n)
-        """
-        raise NotImplementedError("Implement predict()")
 
-    def save(self, iter, checkpoint_path=None):
-        checkpoint_path = checkpoint_path or self.checkpoint_path
-        self.saver.save(self.sess, checkpoint_path, global_step=iter)
+    @abc.abstractmethod
+    def predict_sequence(self, x):
+        raise NotImplementedError()
 
     def _build(self):
         self.prediction = self._build_prediction()
@@ -90,18 +41,18 @@ class Model(object):
         self.saver = tf.train.Saver(tf.global_variables())
 
     def _build_loss(self):
-        self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
-            self.prediction, weights=self.n_train)
-
+        self.total_prediction_loss = tf.losses.mean_squared_error(
+            self.target, self.prediction, weights=self.n_train
+        )
+
         # Don't count this as a loss!
-        self.rmse = tf.sqrt(
-            self.total_prediction_loss / self.n_train)
+        self.rmse = tf.sqrt(self.total_prediction_loss / self.n_train)
 
         return tf.losses.get_total_loss()
 
     @abc.abstractclassmethod
     def _build_prediction(self):
-        raise NotImplementedError('Implement prediction for model')
+        raise NotImplementedError("Implement prediction for model")
 
 
 class Autoregressive(Model):
@@ -109,57 +60,62 @@
     Autoregressive models that take in a few of the most recent input samples
     and predict the output at the last time point.
     """
-    def __init__(self, n_train, input_length, sess=None,
-            checkpoint_path=None):
-        super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
-        self._input_length = input_length
-        self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
-
-    @property
+
+    @abc.abstractproperty
     def input_length(self):
-        return self._input_length
+        raise NotImplementedError()
 
-    def predict(self, x, batch_size=None, verbose=False):
+    def predict_sequence(self, x: torch.Tensor, batch_size=None, verbose=False):
        """
        Return 1D array of predictions same length as x
        """
-        n = x.size
+        n = x.numel()
         batch_size = batch_size or n
 
         # Pad x with leading zeros:
-        x = np.concatenate((np.zeros(self.input_length - 1), x))
+        x = torch.cat((torch.zeros(self.input_length - 1).to(x.device), x))
         i = 0
         y = []
         while i < n:
-            if verbose:
-                print("model.predict {}/{}".format(i, n))
             this_batch_size = np.min([batch_size, n - i])
             # Reshape into a batch:
-            x_mtx = np.stack([x[j: j + self.input_length]
-                for j in range(i, i + this_batch_size)])
+            x_mtx = torch.stack(
+                [x[j : j + self.input_length] for j in range(i, i + this_batch_size)]
+            )
             # Predict and flatten.
-            y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
-                .flatten())
+            y.append(self(x_mtx).squeeze())
             i += this_batch_size
-        return np.concatenate(y)
+        return torch.cat(y)
 
 
 class FullyConnected(Autoregressive):
     """
     Autoregressive model taking in a sequence of the most recent inputs, putting
-    them through a series of FC layers, and outputting the single output at the 
+    them through a series of FC layers, and outputting the single output at the
     last time step.
     """
-    def __init__(self, n_train, input_length, layer_sizes=(512,),
-            sess=None, checkpoint_path=None):
-        super().__init__(n_train, input_length, sess=sess,
-            checkpoint_path=checkpoint_path)
-        self._layer_sizes = layer_sizes
-        self._build()
 
-    def _build_prediction(self):
-        h = self.x
-        for m in self._layer_sizes:
-            h = tf.contrib.layers.fully_connected(h, m)
-        y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
-            activation_fn=tf.nn.sigmoid)
-        return y
+    def __init__(self, net):
+        super().__init__()
+        self._net = net
+
+    @property
+    def input_length(self):
+        return self._net[0][0].weight.data.shape[1]
+
+    def forward(self, inputs):
+        return self._net(inputs)
+
+
+def mlp(dx, dy, layer_sizes=None):
+    def block(dx, dy, Activation=nn.ReLU):
+        return nn.Sequential(nn.Linear(dx, dy), Activation())
+
+    layer_sizes = [256, 256] if layer_sizes is None else layer_sizes
+
+    net = nn.Sequential()
+    in_features = dx
+    for i, out_features in enumerate(layer_sizes):
+        net.add_module("layer_%i" % i, block(in_features, out_features))
+        in_features = out_features
+    net.add_module("head", block(in_features, dy, Activation=nn.Tanh))
+    return FullyConnected(net)
" + + "Reduce if there are out-of-memory issues.", + ) + parser.add_argument( + "--target_file", + type=str, + default=None, + help=".wav file of the true output (if you want to compare)", + ) args = parser.parse_args() - if args.output_file is None: - args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav" - + output_file = ( + args.output_file + if args.output_file is not None + else args.input_file.rstrip(".wav") + "_reamped.wav" + ) + if os.path.isfile(args.output_file): print("Output file exists; skip") exit(1) - + x = wavio.read(args.input_file) rate, sampwidth = x.rate, x.sampwidth bits = _sampwidth_to_bits(sampwidth) - x_data = x.data.flatten() / 2 ** (bits - 1) - - with tf.Session() as sess: - model = models.from_json(args.architecture, - checkpoint_path=args.checkpoint_dir) - model.load() - y = model.predict(x_data, batch_size=args.batch_size, verbose=True) - wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none", - sampwidth=sampwidth) - + x_data = torch.Tensor(x.data.flatten() / 2 ** (bits - 1)) + + model = models.from_json(args.architecture) + model.load_state_dict(torch.load(args.params)) + with torch.no_grad(): + y = model.predict_sequence( + x_data, batch_size=args.batch_size, verbose=True + ).numpy() + wavio.write( + output_file, y * 2 ** (bits - 1), rate, scale="none", sampwidth=sampwidth + ) + if args.target_file is not None and os.path.isfile(args.target_file): t = wavio.read(args.target_file) t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1) @@ -63,4 +80,3 @@ if __name__ == "__main__": plt.plot(y) plt.legend(["Input", "Target", "Prediction"]) plt.show() - -\ No newline at end of file diff --git a/requirements.txt b/requirements.txt @@ -1,2 +1,11 @@ -tensorflow +# File: requirements.txt +# Created Date: 2019-01-02 +# Author: Steven Atkinson ([email protected]) + +black +flake8 +matplotlib +numpy +pytest +torch>=1 wavio diff --git a/train.py b/train.py @@ -4,42 +4,62 @@ """ Here's a script for training new models. + +TODO +* Device +* Lightning? """ -from argparse import ArgumentParser -import numpy as np import abc -import tensorflow as tf -import matplotlib.pyplot as plt -from time import time -import wavio import json import os +from argparse import ArgumentParser +from time import time + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn as nn +import wavio import models # Parameters for training -check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5]) - for pwr in np.arange(0, 7)]) -plot_kwargs = { - "window": (30000, 40000) -} +check_training_at = np.concatenate( + [10 ** pwr * np.array([1, 2, 5]) for pwr in np.arange(0, 7)] +) +plot_kwargs = {"window": (30000, 40000)} wav_kwargs = {"window": (0, 5 * 44100)} +torch.manual_seed(0) +_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def ensure_save_dir(args): + save_dir = ( + os.path.join(os.path.dirname(__file__), "output") + if args.save_dir is None + else args.save_dir + ) + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + return save_dir + class Data(object): """ - Object for holding data and spitting out minibatches for training/segments + Object for holding data and spitting out minibatches for training/segments for testing. 
""" + def __init__(self, fname, input_length, batch_size=None): xy = np.load(fname) - self.x = xy[0] - self.y = xy[1] - self._n = self.x.size - input_length + 1 + self.x = torch.Tensor(xy[0]) + self.y = torch.Tensor(xy[1]) + self._n = self.x.numel() - input_length + 1 self.input_length = input_length self.batch_size = batch_size - + @property def n(self): return self._n @@ -50,70 +70,101 @@ class Data(object): """ if ilist is None: n = n or self.batch_size - ilist = np.random.randint(0, self.n, size=(n,)) - x = np.stack([self.x[i: i + self.input_length] for i in ilist]) - y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \ - [:, np.newaxis] - return x, y + ilist = torch.randint(0, self.n, size=(n,)) + x = torch.stack([self.x[i : i + self.input_length] for i in ilist]) + y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None] + return x.to(_device), y.to(_device) def sequence(self, start, end): end += self.input_length - 1 - return self.x[start: end], self.y[start + self.input_length - 1: end] + return self.x[start:end], self.y[start + self.input_length - 1 : end] + + +def train_step(model, optimizer, batch): + model.train() + model.zero_grad() + inputs, targets = batch + preds = model(inputs) + loss = nn.MSELoss()(preds, targets) + loss.backward() + optimizer.step() + return loss.item() + + +def validation_step(model, batch): + with torch.no_grad(): + model.eval() + inputs, targets = batch + return nn.MSELoss()(model(inputs), targets).item() -def train(model, train_data, batch_size=None, n_minibatches=10, - validation_data=None, plot_at=(), wav_at=(), plot_kwargs={}, - wav_kwargs={}, save_every=100, validate_every=100): - save_dir = os.path.dirname(model.checkpoint_path) - sess = model.sess - opt = tf.train.AdamOptimizer().minimize(model.loss) - sess.run(tf.global_variables_initializer()) # For opt +def train( + model, + train_data, + batch_size=None, + n_minibatches=10, + validation_data=None, + plot_at=(), + wav_at=(), + plot_kwargs={}, + wav_kwargs={}, + save_every=100, + validate_every=100, +): + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) t_loss_list, v_loss_list = [], [] t0 = time() for i in range(n_minibatches): - x, y = train_data.minibatch(batch_size) - t_loss, _ = sess.run((model.rmse, opt), - feed_dict={model.x: x, model.target: y}) + batch = train_data.minibatch(batch_size) + t_loss = train_step(model, optimizer, batch) t_loss_list.append([i, t_loss]) - print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0), - i + 1, n_minibatches, t_loss)) - + print( + "t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format( + int(time() - t0), i + 1, n_minibatches, t_loss + ) + ) + # Callbacks, basically... 
if i + 1 in plot_at: - plot_predictions(model, - validation_data if validation_data is not None else train_data, - title="Minibatch {}".format(i + 1), - fname="{}/mb_{}.png".format(save_dir, i + 1), - **plot_kwargs) + plot_predictions( + model, + validation_data if validation_data is not None else train_data, + title="Minibatch {}".format(i + 1), + fname="{}/mb_{}.png".format(save_dir, i + 1), + **plot_kwargs + ) if i + 1 in wav_at: print("Making wav for mb {}".format(i + 1)) - predict(model, validation_data, - save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1), - **wav_kwargs) + predict( + model, + validation_data, + save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1), + **wav_kwargs + ) if (i + 1) % save_every == 0: - model.save(iter=i + 1) + torch.save(model.state_dict(), os.path.join(save_dir, "model_%i.pt" % i)) if i == 0 or (i + 1) % validate_every == 0: - v_loss, _ = sess.run((model.rmse, opt), - feed_dict={model.x: x, model.target: y}) + v_loss = validation_step(model, batch) print("VLoss={:8}".format(v_loss)) v_loss_list.append([i, v_loss]) - + # After training loop... if validation_data is not None: - x, y = validation_data.minibatch(train_data.batch_size) - v_loss = sess.run(model.rmse, - feed_dict={model.x: x, model.target: y}) + batch = validation_data.minibatch(train_data.batch_size) + v_loss = validation_step(model, batch) print("Validation loss={}".format(v_loss)) + torch.save(model.state_dict(), os.path.join(save_dir, "model.pt")) return np.array(t_loss_list).T, np.array(v_loss_list).T def plot_predictions(model, data, title=None, fname=None, window=None): - x, y, t = predict(model, data, window=window) - plt.figure(figsize=(12, 4)) - plt.plot(x) - plt.plot(t) - plt.plot(y) - plt.legend(('Input', 'Target', 'Prediction')) + with torch.no_grad(): + x, y, t = predict(model, data, window=window) + plt.figure(figsize=(12, 4)) + plt.plot(x.cpu()) + plt.plot(t.cpu()) + plt.plot(y.cpu()) + plt.legend(("Input", "Target", "Prediction")) if title is not None: plt.title(title) if fname is not None: @@ -133,21 +184,27 @@ def plot_loss(t_loss, v_loss, fname): plt.legend(("Training", "Validation")) plt.savefig(fname) plt.close() - + def predict(model, data, window=None, save_wav_file=None): - x, t = data.x, data.y - if window is not None: - x, t = x[window[0]: window[1]], t[window[0]: window[1]] - y = model.predict(x).flatten() + with torch.no_grad(): + x, t = data.x.to(_device), data.y.to(_device) + if window is not None: + x, t = x[window[0] : window[1]], t[window[0] : window[1]] + y = model.predict_sequence(x).squeeze() - if save_wav_file is not None: - rate = 44100 # TODO from data - sampwidth = 3 # 24-bit - wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none", - sampwidth=sampwidth) + if save_wav_file is not None: + rate = 44100 # TODO from data + sampwidth = 3 # 24-bit + wavio.write( + save_wav_file, + y.cpu().numpy() * 2 ** 23, + rate, + scale="none", + sampwidth=sampwidth, + ) - return x, y, t + return x, y, t def _get_input_length(archfile): @@ -156,43 +213,54 @@ def _get_input_length(archfile): if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("model_arch", type=str, - help="JSON containing model architecture") - parser.add_argument("train_data", type=str, - help="Filename for training data") - parser.add_argument("validation_data", type=str, - help="Filename for validation data") - parser.add_argument("--save_dir", type=str, default=None, - help="Where to save the run data (checkpoints, prediction...)") - 
parser.add_argument("--batch_size", type=str, default=4096, - help="Number of data per minibatch") - parser.add_argument("--minibatches", type=int, default=10, - help="Number of minibatches to train for") + parser.add_argument( + "model_arch", type=str, help="JSON containing model architecture" + ) + parser.add_argument("train_data", type=str, help="Filename for training data") + parser.add_argument( + "validation_data", type=str, help="Filename for validation data" + ) + parser.add_argument( + "--save_dir", + type=str, + default=None, + help="Where to save the run data (checkpoints, prediction...)", + ) + parser.add_argument( + "--batch_size", type=str, default=4096, help="Number of data per minibatch" + ) + parser.add_argument( + "--minibatches", type=int, default=10, help="Number of minibatches to train for" + ) args = parser.parse_args() + save_dir = ensure_save_dir(args) input_length = _get_input_length(args.model_arch) # Ugh, kludge # Load the data - train_data = Data(args.train_data, input_length, - batch_size=args.batch_size) + train_data = Data(args.train_data, input_length, batch_size=args.batch_size) validate_data = Data(args.validation_data, input_length) - + # Training - with tf.Session() as sess: - model = models.from_json(args.model_arch, train_data.n, - checkpoint_path=os.path.join(args.save_dir, "model.ckpt")) - t_loss_list, v_loss_list = train( - model, - train_data, - validation_data=validate_data, - n_minibatches=args.minibatches, - plot_at=check_training_at, - wav_at=check_training_at, - plot_kwargs=plot_kwargs, - wav_kwargs=wav_kwargs) - plot_predictions(model, validate_data, window=(0, 44100)) - print("Predict the full output") - predict(model, validate_data, - save_wav_file="{}/predict.wav".format( - os.path.dirname(model.checkpoint_path))) - - plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir)) + model = models.from_json(args.model_arch).to(_device) + t_loss_list, v_loss_list = train( + model, + train_data, + validation_data=validate_data, + n_minibatches=args.minibatches, + plot_at=check_training_at, + wav_at=check_training_at, + plot_kwargs=plot_kwargs, + wav_kwargs=wav_kwargs, + ) + plot_predictions(model, validate_data, window=(0, 44100)) + + print("Predict the full output") + _device = "cpu" + model.to(_device) + predict( + model, + validate_data, + save_wav_file="{}/predict.wav".format(save_dir), + ) + + plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(save_dir))