commit faf62d62b27c8ba53bd41b077af6913f3fddc94c
Author: Steven Atkinson <[email protected]>
Date: Wed, 2 Jan 2019 20:00:30 -0600
FC model, training script, Reaper JS distortion plugin data
Diffstat:
7 files changed, 384 insertions(+), 0 deletions(-)
diff --git a/README.md b/README.md
@@ -0,0 +1,20 @@
+# Guitar amp
+
+Let's make a data-driven model (neural network) of a guitar amp!
+
+## Autoregressive models
+
+* Fully-connected (n-to-1): the previous n input samples are fed through fully-connected layers to predict a single output sample (see the example below)
+
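+Training is driven by `train.py` (`python train.py <architecture.json> <train.npy> <validate.npy> --save_dir <dir>`). As a rough sketch, here is the same flow through the module API (the paths are the files in this commit; the minibatch count is an arbitrary placeholder):
+
+```python
+import tensorflow as tf
+
+import models
+from train import Data, train
+
+input_length = 32  # must match "input_length" in the architecture JSON
+train_data = Data("data/reaper-js-distortion/train.npy", input_length,
+                  batch_size=4096)
+validate_data = Data("data/reaper-js-distortion/validate.npy", input_length)
+with tf.Session():
+    model = models.from_json("architectures/fc_32_32_32.json", train_data.n)
+    train(model, train_data, validation_data=validate_data,
+          n_minibatches=1000)
+```
+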
+## To do
+
+* Other models
+* Real-time (eventually)
+* Model stereo outputs
+* Data classes for .wav data that keep track of sample rate, etc.
+* spek.cc?
+
+## Other interesting tidbits?
+
+* Audio analysis plugins: vamp-plugins.org
+* Spectrogram: sonicvisualiser.org
diff --git a/architectures/fc_32_32_32.json b/architectures/fc_32_32_32.json
@@ -0,0 +1,8 @@
+{
+ "type": "FullyConnected",
+ "input_length": 32,
+ "layer_sizes": [
+ 32,
+ 32
+ ]
+}
\ No newline at end of file
diff --git a/data/reaper-js-distortion/train.npy b/data/reaper-js-distortion/train.npy
Binary files differ.
diff --git a/data/reaper-js-distortion/validate.npy b/data/reaper-js-distortion/validate.npy
Binary files differ.
diff --git a/models.py b/models.py
@@ -0,0 +1,155 @@
+# File: models.py
+# File Created: Sunday, 30th December 2018 9:42:29 pm
+# Author: Steven Atkinson ([email protected])
+
+import numpy as np
+import tensorflow as tf
+import abc
+from tempfile import mkdtemp
+import os
+import json
+
+
+def from_json(f, n_train, checkpoint_path=None):
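+    """
+    Build a model from an architecture spec.
+
+    :param f: dict or path to a JSON file (see architectures/fc_32_32_32.json)
+        with keys "type", "input_length", and "layer_sizes".
+    :param n_train: number of training data (used to scale the loss).
+    """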
+ if isinstance(f, str):
+ with open(f, "r") as json_file:
+ f = json.load(json_file)
+
+ if f["type"] == "FullyConnected":
+ return FullyConnected(n_train, f["input_length"],
+ layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
+ else:
+ raise NotImplementedError("Model type {} unrecognized".format(
+ f["type"]))
+
+
+class Model(abc.ABC):
+ """
+ Model parent class
+ """
+ def __init__(self, n_train, sess=None, checkpoint_path=None):
+ """
+ Make sure child classes call _build() after this!
+ """
+ if sess is None:
+ sess = tf.get_default_session()
+ self.sess = sess
+
+ if checkpoint_path is None:
+ checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
+ if not os.path.isdir(os.path.dirname(checkpoint_path)):
+ os.makedirs(os.path.dirname(checkpoint_path))
+ self.checkpoint_path = checkpoint_path
+
+ # self._batch_size = batch_size
+ self._n_train = n_train
+ self.target = tf.placeholder(tf.float32, shape=(None, 1))
+ self.prediction = None
+ self.total_prediction_loss = None
+ self.rmse = None
+ self.loss = None # Includes any regularization
+
+ # @property
+ # def batch_size(self):
+ # return self._batch_size
+
+ @property
+ def n_train(self):
+ return self._n_train
+
+ def load(self, checkpoint_path=None):
+ checkpoint_path = checkpoint_path or self.checkpoint_path
+
+ try:
+            ckpt = tf.train.get_checkpoint_state(
+                os.path.dirname(checkpoint_path))
+ print("Loading model: {}".format(ckpt.model_checkpoint_path))
+ self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+ except Exception as e:
+ print("Error while attempting to load model: {}".format(e))
+
+    @abc.abstractmethod
+ def predict(self, x):
+ """
+        Predict outputs for a 1D input array.
+ :param x: input array (length=n)
+ :type x: array-like
+ :return: (array-like) corresponding predicted outputs (length=n)
+ """
+ raise NotImplementedError("Implement predict()")
+
+ def save(self, iter, checkpoint_path=None):
+ checkpoint_path = checkpoint_path or self.checkpoint_path
+ self.saver.save(self.sess, checkpoint_path, global_step=iter)
+
+ def _build(self):
+ self.prediction = self._build_prediction()
+ self.loss = self._build_loss()
+
+        # Initialize the variables and set up the saver
+ self.sess.run(tf.global_variables_initializer())
+ self.saver = tf.train.Saver(tf.global_variables())
+
+ def _build_loss(self):
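+        # Scaling the minibatch MSE by n_train makes this term (an estimate
+        # of) the total squared error over the training set, presumably so
+        # that any regularizers collected by tf.losses stay on a comparable
+        # scale.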
+ self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
+ self.prediction, weights=self.n_train)
+
+ # Don't count this as a loss!
+ self.rmse = tf.sqrt(
+ self.total_prediction_loss / self.n_train)
+
+ return tf.losses.get_total_loss()
+
+    @abc.abstractmethod
+ def _build_prediction(self):
+ raise NotImplementedError('Implement prediction for model')
+
+
+class Autoregressive(Model):
+ """
+ Autoregressive models that take in a few of the most recent input samples
+ and predict the output at the last time point.
+ """
+ def __init__(self, n_train, input_length, sess=None,
+ checkpoint_path=None):
+ super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
+ self._input_length = input_length
+ self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
+
+ @property
+ def input_length(self):
+ return self._input_length
+
+ def predict(self, x):
+ """
+ Return 1D array of predictions same length as x
+ """
+ n = x.size
+ # Pad x with leading zeros:
+ x = np.concatenate((np.zeros(self.input_length - 1), x))
+ # Reshape into a batch:
+ x_mtx = np.stack([x[i: i + self.input_length] for i in range(n)])
+ # Predict and flatten.
+ return self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
+ .flatten()
+
+
+class FullyConnected(Autoregressive):
+ """
+ Autoregressive model taking in a sequence of the most recent inputs, putting
+ them through a series of FC layers, and outputting the single output at the
+ last time step.
+ """
+ def __init__(self, n_train, input_length, layer_sizes=(512,),
+ sess=None, checkpoint_path=None):
+ super().__init__(n_train, input_length, sess=sess,
+ checkpoint_path=checkpoint_path)
+ self._layer_sizes = layer_sizes
+ self._build()
+
+ def _build_prediction(self):
+ h = self.x
+ for m in self._layer_sizes:
+ h = tf.contrib.layers.fully_connected(h, m)
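+        # Affinely map the final sigmoid output from (0, 1) to (-1, 1), the
+        # assumed range of the audio samples.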
+ y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
+ activation_fn=tf.nn.sigmoid)
+ return y
diff --git a/requirements.txt b/requirements.txt
@@ -0,0 +1,2 @@
+tensorflow
+wavio
+numpy
+matplotlib
diff --git a/train.py b/train.py
@@ -0,0 +1,198 @@
+# File: train.py
+# File Created: Sunday, 30th December 2018 2:08:54 pm
+# Author: Steven Atkinson ([email protected])
+
+"""
+Here's a script for training new models.
+"""
+
+from argparse import ArgumentParser
+import numpy as np
+import abc
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from time import time
+import wavio
+import json
+import os
+
+import models
+
+# Parameters for training
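+# Minibatch counts (1, 2, 5, 10, 20, 50, ... up to 5e6) at which predictions
+# are plotted and written out as .wav files.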
+check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5])
+ for pwr in np.arange(0, 7)])
+plot_kwargs = {
+ "window": (30000, 40000)
+}
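+# Write the first 5 seconds (at 44.1 kHz) when dumping .wav predictions.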
+wav_kwargs = {"window": (0, 5 * 44100)}
+
+
+class Data(object):
+ """
+ Object for holding data and spitting out minibatches for training/segments
+ for testing.
+ """
+ def __init__(self, fname, input_length, batch_size=None):
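+        # The .npy file is expected to hold a (2, N) array: row 0 is the
+        # input signal, row 1 is the target output signal.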
+ xy = np.load(fname)
+ self.x = xy[0]
+ self.y = xy[1]
+ self._n = self.x.size - input_length + 1
+ self.input_length = input_length
+ self.batch_size = batch_size
+
+ @property
+ def n(self):
+ return self._n
+
+ def minibatch(self, n=None, ilist=None):
+ """
+ Pull a random minibatch out of the data set
+ """
+ if ilist is None:
+ n = n or self.batch_size
+ ilist = np.random.randint(0, self.n, size=(n,))
+ x = np.stack([self.x[i: i + self.input_length] for i in ilist])
+ y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \
+ [:, np.newaxis]
+ return x, y
+
+ def sequence(self, start, end):
+ end += self.input_length - 1
+ return self.x[start: end], self.y[start + self.input_length - 1: end]
+
+
+def train(model, train_data, batch_size=None, n_minibatches=10,
+ validation_data=None, plot_at=(), wav_at=(), plot_kwargs={},
+ wav_kwargs={}, save_every=100, validate_every=100):
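+    """
+    Train the model on random minibatches from train_data, periodically
+    checkpointing, validating, and (optionally) saving prediction plots and
+    .wav files into the directory of model.checkpoint_path.
+
+    :return: (training loss history, validation loss history), each as a
+        2-row array of (minibatch index, RMSE).
+    """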
+ save_dir = os.path.dirname(model.checkpoint_path)
+ sess = model.sess
+ opt = tf.train.AdamOptimizer().minimize(model.loss)
+ sess.run(tf.global_variables_initializer()) # For opt
+ t_loss_list, v_loss_list = [], []
+ t0 = time()
+ for i in range(n_minibatches):
+ x, y = train_data.minibatch(batch_size)
+ t_loss, _ = sess.run((model.rmse, opt),
+ feed_dict={model.x: x, model.target: y})
+ t_loss_list.append([i, t_loss])
+ print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0),
+ i + 1, n_minibatches, t_loss))
+
+ # Callbacks, basically...
+ if i + 1 in plot_at:
+ plot_predictions(model,
+ validation_data if validation_data is not None else train_data,
+ title="Minibatch {}".format(i + 1),
+ fname="{}/mb_{}.png".format(save_dir, i + 1),
+ **plot_kwargs)
+ if i + 1 in wav_at:
+ print("Making wav for mb {}".format(i + 1))
+ predict(model, validation_data,
+ save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
+ **wav_kwargs)
+ if (i + 1) % save_every == 0:
+ model.save(iter=i + 1)
+        if validation_data is not None and (
+                i == 0 or (i + 1) % validate_every == 0):
+            # Evaluate (don't train on) a validation minibatch.
+            vx, vy = validation_data.minibatch(
+                batch_size or train_data.batch_size)
+            v_loss = sess.run(model.rmse,
+                feed_dict={model.x: vx, model.target: vy})
+            print("VLoss={:8}".format(v_loss))
+            v_loss_list.append([i, v_loss])
+
+ # After training loop...
+ if validation_data is not None:
+ x, y = validation_data.minibatch(train_data.batch_size)
+ v_loss = sess.run(model.rmse,
+ feed_dict={model.x: x, model.target: y})
+ print("Validation loss={}".format(v_loss))
+ return np.array(t_loss_list).T, np.array(v_loss_list).T
+
+
+def plot_predictions(model, data, title=None, fname=None, window=None):
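+    """
+    Plot input, target, and prediction over a window of the data; save to
+    fname if given, otherwise show interactively.
+    """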
+ x, y, t = predict(model, data, window=window)
+ plt.figure(figsize=(12, 4))
+ plt.plot(x)
+ plt.plot(t)
+ plt.plot(y)
+ plt.legend(('Input', 'Target', 'Prediction'))
+ if title is not None:
+ plt.title(title)
+ if fname is not None:
+ print("Saving to {}...".format(fname))
+ plt.savefig(fname)
+ plt.close()
+ else:
+ plt.show()
+
+
+def plot_loss(t_loss, v_loss, fname):
+ plt.figure()
+ plt.loglog(t_loss[0], t_loss[1])
+ plt.loglog(v_loss[0], v_loss[1])
+ plt.xlabel("Minibatch")
+ plt.ylabel("RMSE")
+ plt.legend(("Training", "Validation"))
+ plt.savefig(fname)
+ plt.close()
+
+
+def predict(model, data, window=None, save_wav_file=None):
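+    """
+    Run the model over (a window of) the data, optionally saving the
+    prediction as a 24-bit .wav file.
+
+    :return: (input, prediction, target) arrays
+    """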
+ x, t = data.x, data.y
+ if window is not None:
+ x, t = x[window[0]: window[1]], t[window[0]: window[1]]
+ y = model.predict(x).flatten()
+
+ if save_wav_file is not None:
+ rate = 44100 # TODO from data
+ sampwidth = 3 # 24-bit
+ wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none",
+ sampwidth=sampwidth)
+
+ return x, y, t
+
+
+def _get_input_length(archfile):
+    with open(archfile, "r") as f:
+        return json.load(f)["input_length"]
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("model_arch", type=str,
+ help="JSON containing model architecture")
+ parser.add_argument("train_data", type=str,
+ help="Filename for training data")
+ parser.add_argument("validation_data", type=str,
+ help="Filename for validation data")
+ parser.add_argument("--save_dir", type=str, default=None,
+ help="Where to save the run data (checkpoints, prediction...)")
+    parser.add_argument("--batch_size", type=int, default=4096,
+ help="Number of data per minibatch")
+ parser.add_argument("--minibatches", type=int, default=10,
+ help="Number of minibatches to train for")
+ args = parser.parse_args()
+ input_length = _get_input_length(args.model_arch) # Ugh, kludge
+
+ # Load the data
+ train_data = Data(args.train_data, input_length,
+ batch_size=args.batch_size)
+ validate_data = Data(args.validation_data, input_length)
+
+ # Training
+ with tf.Session() as sess:
+ model = models.from_json(args.model_arch, train_data.n,
+            checkpoint_path=os.path.join(args.save_dir, "model.ckpt")
+                if args.save_dir is not None else None)
+ t_loss_list, v_loss_list = train(
+ model,
+ train_data,
+ validation_data=validate_data,
+ n_minibatches=args.minibatches,
+ plot_at=check_training_at,
+ wav_at=check_training_at,
+ plot_kwargs=plot_kwargs,
+ wav_kwargs=wav_kwargs)
+ plot_predictions(model, validate_data, window=(0, 44100))
+ print("Predict the full output")
+ predict(model, validate_data,
+ save_wav_file="{}/predict.wav".format(
+ os.path.dirname(model.checkpoint_path)))
+
+    plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(
+        os.path.dirname(model.checkpoint_path)))