neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit faf62d62b27c8ba53bd41b077af6913f3fddc94c
Author: Steven Atkinson <[email protected]>
Date:   Wed,  2 Jan 2019 20:00:30 -0600

FC model, training script, Reaper JS distortion plugin data

Diffstat:
A README.md | 20 ++++++++++++++++++++
A architectures/fc_32_32_32.json | 9 +++++++++
A data/reaper-js-distortion/train.npy | 0
A data/reaper-js-distortion/validate.npy | 0
A models.py | 155 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A requirements.txt | 2 ++
A train.py | 198 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
7 files changed, 384 insertions(+), 0 deletions(-)

diff --git a/README.md b/README.md @@ -0,0 +1,20 @@ +# Guitar amp + +Let's make a data-driven model (neural network) of a guitar amp! + +## Autoregressive models + +* Fully-connected (n-to-1) + +## To do + +* Other models +* Real-time (eventually) +* Model stereo outputs +* Data classes for .wav data, keeping sample rate, etc +* spek.cc? + +## Other interesting tidbits? + +* Audio analysis plugins: vamp-plugins.org +* Spectrogram: sonicvisualiser.org diff --git a/architectures/fc_32_32_32.json b/architectures/fc_32_32_32.json @@ -0,0 +1,8 @@ +{ + "type": "FullyConnected", + "input_length": 32, + "layer_sizes": [ + 32, + 32 + ] +} +\ No newline at end of file diff --git a/data/reaper-js-distortion/train.npy b/data/reaper-js-distortion/train.npy Binary files differ. diff --git a/data/reaper-js-distortion/validate.npy b/data/reaper-js-distortion/validate.npy Binary files differ. diff --git a/models.py b/models.py @@ -0,0 +1,155 @@ +# File: models.py +# File Created: Sunday, 30th December 2018 9:42:29 pm +# Author: Steven Atkinson ([email protected]) + +import numpy as np +import tensorflow as tf +import abc +from tempfile import mkdtemp +import os +import json + + +def from_json(f, n_train, checkpoint_path=None): + if isinstance(f, str): + with open(f, "r") as json_file: + f = json.load(json_file) + + if f["type"] == "FullyConnected": + return FullyConnected(n_train, f["input_length"], + layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path) + else: + raise NotImplementedError("Model type {} unrecognized".format( + f["type"])) + + +class Model(object): + """ + Model parent class + """ + def __init__(self, n_train, sess=None, checkpoint_path=None): + """ + Make sure child classes call _build() after this! 
+ """ + if sess is None: + sess = tf.get_default_session() + self.sess = sess + + if checkpoint_path is None: + checkpoint_path = os.path.join(mkdtemp(), "model.ckpt") + if not os.path.isdir(os.path.dirname(checkpoint_path)): + os.makedirs(os.path.dirname(checkpoint_path)) + self.checkpoint_path = checkpoint_path + + # self._batch_size = batch_size + self._n_train = n_train + self.target = tf.placeholder(tf.float32, shape=(None, 1)) + self.prediction = None + self.total_prediction_loss = None + self.rmse = None + self.loss = None # Includes any regularization + + # @property + # def batch_size(self): + # return self._batch_size + + @property + def n_train(self): + return self._n_train + + def load(self, checkpoint_path=None): + checkpoint_path = checkpoint_path or self.checkpoint_path + + try: + ckpt = tf.train.get_checkpoint_state(checkpoint_path) + print("Loading model: {}".format(ckpt.model_checkpoint_path)) + self.saver.restore(self.sess, ckpt.model_checkpoint_path) + except Exception as e: + print("Error while attempting to load model: {}".format(e)) + + @abc.abstractclassmethod + def predict(self, x): + """ + A nice function for prediction. + :param x: input array (length=n) + :type x: array-like + :return: (array-like) corresponding predicted outputs (length=n) + """ + raise NotImplementedError("Implement predict()") + + def save(self, iter, checkpoint_path=None): + checkpoint_path = checkpoint_path or self.checkpoint_path + self.saver.save(self.sess, checkpoint_path, global_step=iter) + + def _build(self): + self.prediction = self._build_prediction() + self.loss = self._build_loss() + + # Launch the session + self.sess.run(tf.global_variables_initializer()) + self.saver = tf.train.Saver(tf.global_variables()) + + def _build_loss(self): + self.total_prediction_loss = tf.losses.mean_squared_error(self.target, + self.prediction, weights=self.n_train) + + # Don't count this as a loss! 
+ self.rmse = tf.sqrt( + self.total_prediction_loss / self.n_train) + + return tf.losses.get_total_loss() + + @abc.abstractclassmethod + def _build_prediction(self): + raise NotImplementedError('Implement prediction for model') + + +class Autoregressive(Model): + """ + Autoregressive models that take in a few of the most recent input samples + and predict the output at the last time point. + """ + def __init__(self, n_train, input_length, sess=None, + checkpoint_path=None): + super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path) + self._input_length = input_length + self.x = tf.placeholder(tf.float32, shape=(None, self.input_length)) + + @property + def input_length(self): + return self._input_length + + def predict(self, x): + """ + Return 1D array of predictions same length as x + """ + n = x.size + # Pad x with leading zeros: + x = np.concatenate((np.zeros(self.input_length - 1), x)) + # Reshape into a batch: + x_mtx = np.stack([x[i: i + self.input_length] for i in range(n)]) + # Predict and flatten. + return self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \ + .flatten() + + +class FullyConnected(Autoregressive): + """ + Autoregressive model taking in a sequence of the most recent inputs, putting + them through a series of FC layers, and outputting the single output at the + last time step. 
+ """ + def __init__(self, n_train, input_length, layer_sizes=(512,), + sess=None, checkpoint_path=None): + super().__init__(n_train, input_length, sess=sess, + checkpoint_path=checkpoint_path) + self._layer_sizes = layer_sizes + self._build() + + def _build_prediction(self): + h = self.x + for m in self._layer_sizes: + h = tf.contrib.layers.fully_connected(h, m) + y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1, + activation_fn=tf.nn.sigmoid) + return y diff --git a/requirements.txt b/requirements.txt @@ -0,0 +1,2 @@ +tensorflow +wavio diff --git a/train.py b/train.py @@ -0,0 +1,198 @@ +# File: train.py +# # File Created: Sunday, 30th December 2018 2:08:54 pm +# Author: Steven Atkinson ([email protected]) + +""" +Here's a script for training new models. +""" + +from argparse import ArgumentParser +import numpy as np +import abc +import tensorflow as tf +import matplotlib.pyplot as plt +from time import time +import wavio +import json +import os + +import models + +# Parameters for training +check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5]) + for pwr in np.arange(0, 7)]) +plot_kwargs = { + "window": (30000, 40000) +} +wav_kwargs = {"window": (0, 5 * 44100)} + + +class Data(object): + """ + Object for holding data and spitting out minibatches for training/segments + for testing. 
+ """ + def __init__(self, fname, input_length, batch_size=None): + xy = np.load(fname) + self.x = xy[0] + self.y = xy[1] + self._n = self.x.size - input_length + 1 + self.input_length = input_length + self.batch_size = batch_size + + @property + def n(self): + return self._n + + def minibatch(self, n=None, ilist=None): + """ + Pull a random minibatch out of the data set + """ + if ilist is None: + n = n or self.batch_size + ilist = np.random.randint(0, self.n, size=(n,)) + x = np.stack([self.x[i: i + self.input_length] for i in ilist]) + y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \ + [:, np.newaxis] + return x, y + + def sequence(self, start, end): + end += self.input_length - 1 + return self.x[start: end], self.y[start + self.input_length - 1: end] + + +def train(model, train_data, batch_size=None, n_minibatches=10, + validation_data=None, plot_at=(), wav_at=(), plot_kwargs={}, + wav_kwargs={}, save_every=100, validate_every=100): + save_dir = os.path.dirname(model.checkpoint_path) + sess = model.sess + opt = tf.train.AdamOptimizer().minimize(model.loss) + sess.run(tf.global_variables_initializer()) # For opt + t_loss_list, v_loss_list = [], [] + t0 = time() + for i in range(n_minibatches): + x, y = train_data.minibatch(batch_size) + t_loss, _ = sess.run((model.rmse, opt), + feed_dict={model.x: x, model.target: y}) + t_loss_list.append([i, t_loss]) + print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0), + i + 1, n_minibatches, t_loss)) + + # Callbacks, basically... 
+ if i + 1 in plot_at: + plot_predictions(model, + validation_data if validation_data is not None else train_data, + title="Minibatch {}".format(i + 1), + fname="{}/mb_{}.png".format(save_dir, i + 1), + **plot_kwargs) + if i + 1 in wav_at: + print("Making wav for mb {}".format(i + 1)) + predict(model, validation_data, + save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1), + **wav_kwargs) + if (i + 1) % save_every == 0: + model.save(iter=i + 1) + if i == 0 or (i + 1) % validate_every == 0: + v_loss, _ = sess.run((model.rmse, opt), + feed_dict={model.x: x, model.target: y}) + print("VLoss={:8}".format(v_loss)) + v_loss_list.append([i, v_loss]) + + # After training loop... + if validation_data is not None: + x, y = validation_data.minibatch(train_data.batch_size) + v_loss = sess.run(model.rmse, + feed_dict={model.x: x, model.target: y}) + print("Validation loss={}".format(v_loss)) + return np.array(t_loss_list).T, np.array(v_loss_list).T + + +def plot_predictions(model, data, title=None, fname=None, window=None): + x, y, t = predict(model, data, window=window) + plt.figure(figsize=(12, 4)) + plt.plot(x) + plt.plot(t) + plt.plot(y) + plt.legend(('Input', 'Target', 'Prediction')) + if title is not None: + plt.title(title) + if fname is not None: + print("Saving to {}...".format(fname)) + plt.savefig(fname) + plt.close() + else: + plt.show() + + +def plot_loss(t_loss, v_loss, fname): + plt.figure() + plt.loglog(t_loss[0], t_loss[1]) + plt.loglog(v_loss[0], v_loss[1]) + plt.xlabel("Minibatch") + plt.ylabel("RMSE") + plt.legend(("Training", "Validation")) + plt.savefig(fname) + plt.close() + + +def predict(model, data, window=None, save_wav_file=None): + x, t = data.x, data.y + if window is not None: + x, t = x[window[0]: window[1]], t[window[0]: window[1]] + y = model.predict(x).flatten() + + if save_wav_file is not None: + rate = 44100 # TODO from data + sampwidth = 3 # 24-bit + wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none", + sampwidth=sampwidth) + + 
return x, y, t + + +def _get_input_length(archfile): + return json.load(open(archfile, "r"))["input_length"] + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("model_arch", type=str, + help="JSON containing model architecture") + parser.add_argument("train_data", type=str, + help="Filename for training data") + parser.add_argument("validation_data", type=str, + help="Filename for validation data") + parser.add_argument("--save_dir", type=str, default=None, + help="Where to save the run data (checkpoints, prediction...)") + parser.add_argument("--batch_size", type=str, default=4096, + help="Number of data per minibatch") + parser.add_argument("--minibatches", type=int, default=10, + help="Number of minibatches to train for") + args = parser.parse_args() + input_length = _get_input_length(args.model_arch) # Ugh, kludge + + # Load the data + train_data = Data(args.train_data, input_length, + batch_size=args.batch_size) + validate_data = Data(args.validation_data, input_length) + + # Training + with tf.Session() as sess: + model = models.from_json(args.model_arch, train_data.n, + checkpoint_path=os.path.join(args.save_dir, "model.ckpt")) + t_loss_list, v_loss_list = train( + model, + train_data, + validation_data=validate_data, + n_minibatches=args.minibatches, + plot_at=check_training_at, + wav_at=check_training_at, + plot_kwargs=plot_kwargs, + wav_kwargs=wav_kwargs) + plot_predictions(model, validate_data, window=(0, 44100)) + print("Predict the full output") + predict(model, validate_data, + save_wav_file="{}/predict.wav".format( + os.path.dirname(model.checkpoint_path))) + + plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))