commit 1c2565b173a3b49150c890e15c0d20f7f21367be
parent 9bc94d6c185be755135e1ed24919ddd92a78cfe4
Author: Steven Atkinson <[email protected]>
Date: Sun, 21 Feb 2021 15:50:00 -0500
Merge pull request #3 from sdatkinson/revert-2-pytorch
Revert "PyTorch"
Diffstat:
M | README.md | | | 4 | +--- |
M | models.py | | | 158 | ++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------- |
M | reamp.py | | | 79 | ++++++++++++++++++++++++++++++++----------------------------------------------- |
M | requirements.txt | | | 11 | +---------- |
M | train.py | | | 270 | ++++++++++++++++++++++++++++++------------------------------------------------- |
5 files changed, 236 insertions(+), 286 deletions(-)
diff --git a/README.md b/README.md
@@ -4,9 +4,7 @@ Let's use deep learning to make a model of a guitar amp!
## Setup
-```bash
-`pip install -r requirements.txt`
-```
+Use `pip install -r requirements.txt` as well as installing the TensorFlow version of your choice (e.g. `pip install tensorflow-gpu`).
## Autoregressive models
diff --git a/models.py b/models.py
@@ -2,35 +2,84 @@
# File Created: Sunday, 30th December 2018 9:42:29 pm
# Author: Steven Atkinson ([email protected])
+import numpy as np
+import tensorflow as tf
import abc
-import json
-import os
from tempfile import mkdtemp
-
-import numpy as np
-import torch
-import torch.nn as nn
+import os
+import json
-def from_json(f):
+def from_json(f, n_train=1, checkpoint_path=None):
if isinstance(f, str):
with open(f, "r") as json_file:
f = json.load(json_file)
-
+
if f["type"] == "FullyConnected":
- return mlp(f["input_length"], 1, layer_sizes=f["layer_sizes"])
+ return FullyConnected(n_train, f["input_length"],
+ layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
else:
- raise NotImplementedError("Model type {} unrecognized".format(f["type"]))
+ raise NotImplementedError("Model type {} unrecognized".format(
+ f["type"]))
-class Model(nn.Module):
+class Model(object):
"""
Model parent class
"""
+ def __init__(self, n_train, sess=None, checkpoint_path=None):
+ """
+ Make sure child classes call _build() after this!
+ """
+ if sess is None:
+ sess = tf.get_default_session()
+ self.sess = sess
+
+ if checkpoint_path is None:
+ checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
+ if not os.path.isdir(os.path.dirname(checkpoint_path)):
+ os.makedirs(os.path.dirname(checkpoint_path))
+ self.checkpoint_path = checkpoint_path
+
+ # self._batch_size = batch_size
+ self._n_train = n_train
+ self.target = tf.placeholder(tf.float32, shape=(None, 1))
+ self.prediction = None
+ self.total_prediction_loss = None
+ self.rmse = None
+ self.loss = None # Includes any regularization
+
+ # @property
+ # def batch_size(self):
+ # return self._batch_size
+
+ @property
+ def n_train(self):
+ return self._n_train
+
+ def load(self, checkpoint_path=None):
+ checkpoint_path = checkpoint_path or self.checkpoint_path
+
+ try:
+ ckpt = tf.train.get_checkpoint_state(checkpoint_path)
+ print("Loading model: {}".format(ckpt.model_checkpoint_path))
+ self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+ except Exception as e:
+ print("Error while attempting to load model: {}".format(e))
+
+ @abc.abstractclassmethod
+ def predict(self, x):
+ """
+ A nice function for prediction.
+ :param x: input array (length=n)
+ :type x: array-like
+ :return: (array-like) corresponding predicted outputs (length=n)
+ """
+ raise NotImplementedError("Implement predict()")
- @abc.abstractmethod
- def predict_sequence(self, x):
- raise NotImplementedError()
+ def save(self, iter, checkpoint_path=None):
+ checkpoint_path = checkpoint_path or self.checkpoint_path
+ self.saver.save(self.sess, checkpoint_path, global_step=iter)
def _build(self):
self.prediction = self._build_prediction()
@@ -41,18 +90,18 @@ class Model(nn.Module):
self.saver = tf.train.Saver(tf.global_variables())
def _build_loss(self):
- self.total_prediction_loss = tf.losses.mean_squared_error(
- self.target, self.prediction, weights=self.n_train
- )
-
+ self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
+ self.prediction, weights=self.n_train)
+
# Don't count this as a loss!
- self.rmse = tf.sqrt(self.total_prediction_loss / self.n_train)
+ self.rmse = tf.sqrt(
+ self.total_prediction_loss / self.n_train)
return tf.losses.get_total_loss()
@abc.abstractclassmethod
def _build_prediction(self):
- raise NotImplementedError("Implement prediction for model")
+ raise NotImplementedError('Implement prediction for model')
class Autoregressive(Model):
@@ -60,62 +109,57 @@ class Autoregressive(Model):
Autoregressive models that take in a few of the most recent input samples
and predict the output at the last time point.
"""
-
- @abc.abstractproperty
+ def __init__(self, n_train, input_length, sess=None,
+ checkpoint_path=None):
+ super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
+ self._input_length = input_length
+ self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
+
+ @property
def input_length(self):
- raise NotImplementedError()
+ return self._input_length
- def predict_sequence(self, x: torch.Tensor, batch_size=None, verbose=False):
+ def predict(self, x, batch_size=None, verbose=False):
"""
Return 1D array of predictions same length as x
"""
- n = x.numel()
+ n = x.size
batch_size = batch_size or n
# Pad x with leading zeros:
- x = torch.cat((torch.zeros(self.input_length - 1).to(x.device), x))
+ x = np.concatenate((np.zeros(self.input_length - 1), x))
i = 0
y = []
while i < n:
+ if verbose:
+ print("model.predict {}/{}".format(i, n))
this_batch_size = np.min([batch_size, n - i])
# Reshape into a batch:
- x_mtx = torch.stack(
- [x[j : j + self.input_length] for j in range(i, i + this_batch_size)]
- )
+ x_mtx = np.stack([x[j: j + self.input_length]
+ for j in range(i, i + this_batch_size)])
# Predict and flatten.
- y.append(self(x_mtx).squeeze())
+ y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
+ .flatten())
i += this_batch_size
- return torch.cat(y)
+ return np.concatenate(y)
class FullyConnected(Autoregressive):
"""
Autoregressive model taking in a sequence of the most recent inputs, putting
- them through a series of FC layers, and outputting the single output at the
+ them through a series of FC layers, and outputting the single output at the
last time step.
"""
+ def __init__(self, n_train, input_length, layer_sizes=(512,),
+ sess=None, checkpoint_path=None):
+ super().__init__(n_train, input_length, sess=sess,
+ checkpoint_path=checkpoint_path)
+ self._layer_sizes = layer_sizes
+ self._build()
- def __init__(self, net):
- super().__init__()
- self._net = net
-
- @property
- def input_length(self):
- return self._net[0][0].weight.data.shape[1]
-
- def forward(self, inputs):
- return self._net(inputs)
-
-
-def mlp(dx, dy, layer_sizes=None):
- def block(dx, dy, Activation=nn.ReLU):
- return nn.Sequential(nn.Linear(dx, dy), Activation())
-
- layer_sizes = [256, 256] if layer_sizes is None else layer_sizes
-
- net = nn.Sequential()
- in_features = dx
- for i, out_features in enumerate(layer_sizes):
- net.add_module("layer_%i" % i, block(in_features, out_features))
- in_features = out_features
- net.add_module("head", block(in_features, dy, Activation=nn.Tanh))
- return FullyConnected(net)
+ def _build_prediction(self):
+ h = self.x
+ for m in self._layer_sizes:
+ h = tf.contrib.layers.fully_connected(h, m)
+ y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
+ activation_fn=tf.nn.sigmoid)
+ return y
diff --git a/reamp.py b/reamp.py
@@ -4,13 +4,11 @@ Reamp a .wav file
Assumes 24-bit WAV files
"""
-
-import os
from argparse import ArgumentParser
-
-import matplotlib.pyplot as plt
-import torch
+import tensorflow as tf
+import os
import wavio
+import matplotlib.pyplot as plt
import models
@@ -21,56 +19,41 @@ def _sampwidth_to_bits(x):
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument(
- "architecture", type=str, help="JSON filename containing NN architecture"
- )
- parser.add_argument(
- "params", type=str, help="directory holding model checkpoint to use"
- )
- parser.add_argument("input_file", type=str, help="Input .wav file to convert")
- parser.add_argument(
- "--output_file", type=str, default=None, help="Where to save the output"
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=8192,
- help="How many samples to process at a time. "
- + "Reduce if there are out-of-memory issues.",
- )
- parser.add_argument(
- "--target_file",
- type=str,
- default=None,
- help=".wav file of the true output (if you want to compare)",
- )
+ parser.add_argument("architecture", type=str,
+ help="JSON filename containing NN architecture")
+ parser.add_argument("checkpoint_dir", type=str,
+ help="directory holding model checkpoint to use")
+ parser.add_argument("input_file", type=str,
+ help="Input .wav file to convert")
+ parser.add_argument("--output_file", type=str, default=None,
+ help="Where to save the output")
+ parser.add_argument("--batch_size", type=int, default=8192,
+ help="How many samples to process at a time. " +
+ "Reduce if there are out-of-memory issues.")
+ parser.add_argument("--target_file", type=str, default=None,
+ help=".wav file of the true output (if you want to compare)")
args = parser.parse_args()
- output_file = (
- args.output_file
- if args.output_file is not None
- else args.input_file.rstrip(".wav") + "_reamped.wav"
- )
-
+ if args.output_file is None:
+ args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav"
+
if os.path.isfile(args.output_file):
print("Output file exists; skip")
exit(1)
-
+
x = wavio.read(args.input_file)
rate, sampwidth = x.rate, x.sampwidth
bits = _sampwidth_to_bits(sampwidth)
- x_data = torch.Tensor(x.data.flatten() / 2 ** (bits - 1))
-
- model = models.from_json(args.architecture)
- model.load_state_dict(torch.load(args.params))
- with torch.no_grad():
- y = model.predict_sequence(
- x_data, batch_size=args.batch_size, verbose=True
- ).numpy()
- wavio.write(
- output_file, y * 2 ** (bits - 1), rate, scale="none", sampwidth=sampwidth
- )
-
+ x_data = x.data.flatten() / 2 ** (bits - 1)
+
+ with tf.Session() as sess:
+ model = models.from_json(args.architecture,
+ checkpoint_path=args.checkpoint_dir)
+ model.load()
+ y = model.predict(x_data, batch_size=args.batch_size, verbose=True)
+ wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none",
+ sampwidth=sampwidth)
+
if args.target_file is not None and os.path.isfile(args.target_file):
t = wavio.read(args.target_file)
t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1)
@@ -80,3 +63,4 @@ if __name__ == "__main__":
plt.plot(y)
plt.legend(["Input", "Target", "Prediction"])
plt.show()
+
+\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
@@ -1,11 +1,2 @@
-# File: requirements.txt
-# Created Date: 2019-01-02
-# Author: Steven Atkinson ([email protected])
-
-black
-flake8
-matplotlib
-numpy
-pytest
-torch>=1
+tensorflow
wavio
diff --git a/train.py b/train.py
@@ -4,62 +4,42 @@
"""
Here's a script for training new models.
-
-TODO
-* Device
-* Lightning?
"""
-import abc
-import json
-import os
from argparse import ArgumentParser
-from time import time
-
-import matplotlib.pyplot as plt
import numpy as np
-import torch
-import torch.nn as nn
+import abc
+import tensorflow as tf
+import matplotlib.pyplot as plt
+from time import time
import wavio
+import json
+import os
import models
# Parameters for training
-check_training_at = np.concatenate(
- [10 ** pwr * np.array([1, 2, 5]) for pwr in np.arange(0, 7)]
-)
-plot_kwargs = {"window": (30000, 40000)}
+check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5])
+ for pwr in np.arange(0, 7)])
+plot_kwargs = {
+ "window": (30000, 40000)
+}
wav_kwargs = {"window": (0, 5 * 44100)}
-torch.manual_seed(0)
-_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-
-def ensure_save_dir(args):
- save_dir = (
- os.path.join(os.path.dirname(__file__), "output")
- if args.save_dir is None
- else args.save_dir
- )
- if not os.path.isdir(save_dir):
- os.makedirs(save_dir)
- return save_dir
-
class Data(object):
"""
- Object for holding data and spitting out minibatches for training/segments
+ Object for holding data and spitting out minibatches for training/segments
for testing.
"""
-
def __init__(self, fname, input_length, batch_size=None):
xy = np.load(fname)
- self.x = torch.Tensor(xy[0])
- self.y = torch.Tensor(xy[1])
- self._n = self.x.numel() - input_length + 1
+ self.x = xy[0]
+ self.y = xy[1]
+ self._n = self.x.size - input_length + 1
self.input_length = input_length
self.batch_size = batch_size
-
+
@property
def n(self):
return self._n
@@ -70,101 +50,70 @@ class Data(object):
"""
if ilist is None:
n = n or self.batch_size
- ilist = torch.randint(0, self.n, size=(n,))
- x = torch.stack([self.x[i : i + self.input_length] for i in ilist])
- y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None]
- return x.to(_device), y.to(_device)
+ ilist = np.random.randint(0, self.n, size=(n,))
+ x = np.stack([self.x[i: i + self.input_length] for i in ilist])
+ y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \
+ [:, np.newaxis]
+ return x, y
def sequence(self, start, end):
end += self.input_length - 1
- return self.x[start:end], self.y[start + self.input_length - 1 : end]
-
-
-def train_step(model, optimizer, batch):
- model.train()
- model.zero_grad()
- inputs, targets = batch
- preds = model(inputs)
- loss = nn.MSELoss()(preds, targets)
- loss.backward()
- optimizer.step()
- return loss.item()
-
-
-def validation_step(model, batch):
- with torch.no_grad():
- model.eval()
- inputs, targets = batch
- return nn.MSELoss()(model(inputs), targets).item()
+ return self.x[start: end], self.y[start + self.input_length - 1: end]
-def train(
- model,
- train_data,
- batch_size=None,
- n_minibatches=10,
- validation_data=None,
- plot_at=(),
- wav_at=(),
- plot_kwargs={},
- wav_kwargs={},
- save_every=100,
- validate_every=100,
-):
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+def train(model, train_data, batch_size=None, n_minibatches=10,
+ validation_data=None, plot_at=(), wav_at=(), plot_kwargs={},
+ wav_kwargs={}, save_every=100, validate_every=100):
+ save_dir = os.path.dirname(model.checkpoint_path)
+ sess = model.sess
+ opt = tf.train.AdamOptimizer().minimize(model.loss)
+ sess.run(tf.global_variables_initializer()) # For opt
t_loss_list, v_loss_list = [], []
t0 = time()
for i in range(n_minibatches):
- batch = train_data.minibatch(batch_size)
- t_loss = train_step(model, optimizer, batch)
+ x, y = train_data.minibatch(batch_size)
+ t_loss, _ = sess.run((model.rmse, opt),
+ feed_dict={model.x: x, model.target: y})
t_loss_list.append([i, t_loss])
- print(
- "t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(
- int(time() - t0), i + 1, n_minibatches, t_loss
- )
- )
-
+ print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0),
+ i + 1, n_minibatches, t_loss))
+
# Callbacks, basically...
if i + 1 in plot_at:
- plot_predictions(
- model,
- validation_data if validation_data is not None else train_data,
- title="Minibatch {}".format(i + 1),
- fname="{}/mb_{}.png".format(save_dir, i + 1),
- **plot_kwargs
- )
+ plot_predictions(model,
+ validation_data if validation_data is not None else train_data,
+ title="Minibatch {}".format(i + 1),
+ fname="{}/mb_{}.png".format(save_dir, i + 1),
+ **plot_kwargs)
if i + 1 in wav_at:
print("Making wav for mb {}".format(i + 1))
- predict(
- model,
- validation_data,
- save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
- **wav_kwargs
- )
+ predict(model, validation_data,
+ save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
+ **wav_kwargs)
if (i + 1) % save_every == 0:
- torch.save(model.state_dict(), os.path.join(save_dir, "model_%i.pt" % i))
+ model.save(iter=i + 1)
if i == 0 or (i + 1) % validate_every == 0:
- v_loss = validation_step(model, batch)
+ v_loss, _ = sess.run((model.rmse, opt),
+ feed_dict={model.x: x, model.target: y})
print("VLoss={:8}".format(v_loss))
v_loss_list.append([i, v_loss])
-
+
# After training loop...
if validation_data is not None:
- batch = validation_data.minibatch(train_data.batch_size)
- v_loss = validation_step(model, batch)
+ x, y = validation_data.minibatch(train_data.batch_size)
+ v_loss = sess.run(model.rmse,
+ feed_dict={model.x: x, model.target: y})
print("Validation loss={}".format(v_loss))
- torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
return np.array(t_loss_list).T, np.array(v_loss_list).T
def plot_predictions(model, data, title=None, fname=None, window=None):
- with torch.no_grad():
- x, y, t = predict(model, data, window=window)
- plt.figure(figsize=(12, 4))
- plt.plot(x.cpu())
- plt.plot(t.cpu())
- plt.plot(y.cpu())
- plt.legend(("Input", "Target", "Prediction"))
+ x, y, t = predict(model, data, window=window)
+ plt.figure(figsize=(12, 4))
+ plt.plot(x)
+ plt.plot(t)
+ plt.plot(y)
+ plt.legend(('Input', 'Target', 'Prediction'))
if title is not None:
plt.title(title)
if fname is not None:
@@ -184,27 +133,21 @@ def plot_loss(t_loss, v_loss, fname):
plt.legend(("Training", "Validation"))
plt.savefig(fname)
plt.close()
-
+
def predict(model, data, window=None, save_wav_file=None):
- with torch.no_grad():
- x, t = data.x.to(_device), data.y.to(_device)
- if window is not None:
- x, t = x[window[0] : window[1]], t[window[0] : window[1]]
- y = model.predict_sequence(x).squeeze()
+ x, t = data.x, data.y
+ if window is not None:
+ x, t = x[window[0]: window[1]], t[window[0]: window[1]]
+ y = model.predict(x).flatten()
- if save_wav_file is not None:
- rate = 44100 # TODO from data
- sampwidth = 3 # 24-bit
- wavio.write(
- save_wav_file,
- y.cpu().numpy() * 2 ** 23,
- rate,
- scale="none",
- sampwidth=sampwidth,
- )
+ if save_wav_file is not None:
+ rate = 44100 # TODO from data
+ sampwidth = 3 # 24-bit
+ wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none",
+ sampwidth=sampwidth)
- return x, y, t
+ return x, y, t
def _get_input_length(archfile):
@@ -213,54 +156,43 @@ def _get_input_length(archfile):
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument(
- "model_arch", type=str, help="JSON containing model architecture"
- )
- parser.add_argument("train_data", type=str, help="Filename for training data")
- parser.add_argument(
- "validation_data", type=str, help="Filename for validation data"
- )
- parser.add_argument(
- "--save_dir",
- type=str,
- default=None,
- help="Where to save the run data (checkpoints, prediction...)",
- )
- parser.add_argument(
- "--batch_size", type=str, default=4096, help="Number of data per minibatch"
- )
- parser.add_argument(
- "--minibatches", type=int, default=10, help="Number of minibatches to train for"
- )
+ parser.add_argument("model_arch", type=str,
+ help="JSON containing model architecture")
+ parser.add_argument("train_data", type=str,
+ help="Filename for training data")
+ parser.add_argument("validation_data", type=str,
+ help="Filename for validation data")
+ parser.add_argument("--save_dir", type=str, default=None,
+ help="Where to save the run data (checkpoints, prediction...)")
+ parser.add_argument("--batch_size", type=str, default=4096,
+ help="Number of data per minibatch")
+ parser.add_argument("--minibatches", type=int, default=10,
+ help="Number of minibatches to train for")
args = parser.parse_args()
- save_dir = ensure_save_dir(args)
input_length = _get_input_length(args.model_arch) # Ugh, kludge
# Load the data
- train_data = Data(args.train_data, input_length, batch_size=args.batch_size)
+ train_data = Data(args.train_data, input_length,
+ batch_size=args.batch_size)
validate_data = Data(args.validation_data, input_length)
-
+
# Training
- model = models.from_json(args.model_arch).to(_device)
- t_loss_list, v_loss_list = train(
- model,
- train_data,
- validation_data=validate_data,
- n_minibatches=args.minibatches,
- plot_at=check_training_at,
- wav_at=check_training_at,
- plot_kwargs=plot_kwargs,
- wav_kwargs=wav_kwargs,
- )
- plot_predictions(model, validate_data, window=(0, 44100))
-
- print("Predict the full output")
- _device = "cpu"
- model.to(_device)
- predict(
- model,
- validate_data,
- save_wav_file="{}/predict.wav".format(save_dir),
- )
-
- plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(save_dir))
+ with tf.Session() as sess:
+ model = models.from_json(args.model_arch, train_data.n,
+ checkpoint_path=os.path.join(args.save_dir, "model.ckpt"))
+ t_loss_list, v_loss_list = train(
+ model,
+ train_data,
+ validation_data=validate_data,
+ n_minibatches=args.minibatches,
+ plot_at=check_training_at,
+ wav_at=check_training_at,
+ plot_kwargs=plot_kwargs,
+ wav_kwargs=wav_kwargs)
+ plot_predictions(model, validate_data, window=(0, 44100))
+ print("Predict the full output")
+ predict(model, validate_data,
+ save_wav_file="{}/predict.wav".format(
+ os.path.dirname(model.checkpoint_path)))
+
+ plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))