commit 9bc94d6c185be755135e1ed24919ddd92a78cfe4
parent aa489d720868354495f84e1f9f035e3869dd4dea
Author: Steven Atkinson <[email protected]>
Date: Sun, 21 Feb 2021 15:48:37 -0500
Merge pull request #2 from sdatkinson/pytorch
PyTorch
Diffstat:
 M README.md        |   4 +++-
 M models.py        | 158 +++++++++++++++++++++++++++++--------------------------------------------------
 M reamp.py         |  79 +++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
 M requirements.txt |  11 ++++++++++-
 M train.py         | 270 +++++++++++++++++++++++++++++++++++++++++++++++++------------------------------
 5 files changed, 286 insertions(+), 236 deletions(-)
diff --git a/README.md b/README.md
@@ -4,7 +4,9 @@ Let's use deep learning to make a model of a guitar amp!
## Setup
-Use `pip install -r requirements.txt` as well as intalling the TensorFlow version of your choice (e.g. `pip install tensorflow-gpu`).
+```bash
+pip install -r requirements.txt
+```
## Autoregressive models
diff --git a/models.py b/models.py
@@ -2,84 +2,35 @@
# File Created: Sunday, 30th December 2018 9:42:29 pm
# Author: Steven Atkinson ([email protected])
-import numpy as np
-import tensorflow as tf
import abc
-from tempfile import mkdtemp
-import os
import json
+import os
+from tempfile import mkdtemp
+
+import numpy as np
+import torch
+import torch.nn as nn
-def from_json(f, n_train=1, checkpoint_path=None):
+def from_json(f):
if isinstance(f, str):
with open(f, "r") as json_file:
f = json.load(json_file)
-
+
if f["type"] == "FullyConnected":
- return FullyConnected(n_train, f["input_length"],
- layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
+ return mlp(f["input_length"], 1, layer_sizes=f["layer_sizes"])
else:
- raise NotImplementedError("Model type {} unrecognized".format(
- f["type"]))
+ raise NotImplementedError("Model type {} unrecognized".format(f["type"]))
-class Model(object):
+class Model(nn.Module):
"""
Model parent class
"""
- def __init__(self, n_train, sess=None, checkpoint_path=None):
- """
- Make sure child classes call _build() after this!
- """
- if sess is None:
- sess = tf.get_default_session()
- self.sess = sess
-
- if checkpoint_path is None:
- checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
- if not os.path.isdir(os.path.dirname(checkpoint_path)):
- os.makedirs(os.path.dirname(checkpoint_path))
- self.checkpoint_path = checkpoint_path
-
- # self._batch_size = batch_size
- self._n_train = n_train
- self.target = tf.placeholder(tf.float32, shape=(None, 1))
- self.prediction = None
- self.total_prediction_loss = None
- self.rmse = None
- self.loss = None # Includes any regularization
-
- # @property
- # def batch_size(self):
- # return self._batch_size
-
- @property
- def n_train(self):
- return self._n_train
-
- def load(self, checkpoint_path=None):
- checkpoint_path = checkpoint_path or self.checkpoint_path
-
- try:
- ckpt = tf.train.get_checkpoint_state(checkpoint_path)
- print("Loading model: {}".format(ckpt.model_checkpoint_path))
- self.saver.restore(self.sess, ckpt.model_checkpoint_path)
- except Exception as e:
- print("Error while attempting to load model: {}".format(e))
-
- @abc.abstractclassmethod
- def predict(self, x):
- """
- A nice function for prediction.
- :param x: input array (length=n)
- :type x: array-like
- :return: (array-like) corresponding predicted outputs (length=n)
- """
- raise NotImplementedError("Implement predict()")
- def save(self, iter, checkpoint_path=None):
- checkpoint_path = checkpoint_path or self.checkpoint_path
- self.saver.save(self.sess, checkpoint_path, global_step=iter)
+ @abc.abstractmethod
+ def predict_sequence(self, x):
+ raise NotImplementedError()
def _build(self):
self.prediction = self._build_prediction()
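Note on the removed checkpointing: `tf.train.Saver` has no analogue in the new `Model`; persistence moves to PyTorch's `state_dict` mechanism, which reamp.py and train.py use further down. A minimal sketch of the new load pattern (filenames are illustrative; `model.pt` matches what train.py writes at the end of training):

```python
import torch

import models

# Rebuild the network from its architecture JSON, then restore trained weights.
model = models.from_json("architecture.json")
model.load_state_dict(torch.load("model.pt"))
model.eval()  # switch off training-mode behavior before inference
```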
@@ -90,18 +41,18 @@ class Model(object):
self.saver = tf.train.Saver(tf.global_variables())
def _build_loss(self):
- self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
- self.prediction, weights=self.n_train)
-
+ self.total_prediction_loss = tf.losses.mean_squared_error(
+ self.target, self.prediction, weights=self.n_train
+ )
+
# Don't count this as a loss!
- self.rmse = tf.sqrt(
- self.total_prediction_loss / self.n_train)
+ self.rmse = tf.sqrt(self.total_prediction_loss / self.n_train)
return tf.losses.get_total_loss()
@abc.abstractclassmethod
def _build_prediction(self):
- raise NotImplementedError('Implement prediction for model')
+ raise NotImplementedError("Implement prediction for model")
class Autoregressive(Model):
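Note on the removed loss graph: the PyTorch side builds no loss into the model; train.py below computes the MSE per step with `nn.MSELoss()`, and the RMSE the old `_build_loss` reported can be recovered the same way. A sketch with stand-in tensors:

```python
import torch
import torch.nn as nn

preds = torch.randn(8, 1)    # stand-in for a batch of model predictions
targets = torch.randn(8, 1)  # stand-in for the matching targets

mse = nn.MSELoss()(preds, targets)  # what train.py's train_step minimizes
rmse = torch.sqrt(mse)              # what the old TF _build_loss logged
```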
@@ -109,57 +60,62 @@ class Autoregressive(Model):
Autoregressive models that take in a few of the most recent input samples
and predict the output at the last time point.
"""
- def __init__(self, n_train, input_length, sess=None,
- checkpoint_path=None):
- super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
- self._input_length = input_length
- self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
-
- @property
+
+ @abc.abstractproperty
def input_length(self):
- return self._input_length
+ raise NotImplementedError()
- def predict(self, x, batch_size=None, verbose=False):
+ def predict_sequence(self, x: torch.Tensor, batch_size=None, verbose=False):
"""
Return 1D array of predictions same length as x
"""
- n = x.size
+ n = x.numel()
batch_size = batch_size or n
# Pad x with leading zeros:
- x = np.concatenate((np.zeros(self.input_length - 1), x))
+ x = torch.cat((torch.zeros(self.input_length - 1).to(x.device), x))
i = 0
y = []
while i < n:
- if verbose:
- print("model.predict {}/{}".format(i, n))
this_batch_size = np.min([batch_size, n - i])
# Reshape into a batch:
- x_mtx = np.stack([x[j: j + self.input_length]
- for j in range(i, i + this_batch_size)])
+ x_mtx = torch.stack(
+ [x[j : j + self.input_length] for j in range(i, i + this_batch_size)]
+ )
# Predict and flatten.
- y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
- .flatten())
+ y.append(self(x_mtx).squeeze())
i += this_batch_size
- return np.concatenate(y)
+ return torch.cat(y)
class FullyConnected(Autoregressive):
"""
Autoregressive model taking in a sequence of the most recent inputs, putting
- them through a series of FC layers, and outputting the single output at the
+ them through a series of FC layers, and outputting the single output at the
last time step.
"""
- def __init__(self, n_train, input_length, layer_sizes=(512,),
- sess=None, checkpoint_path=None):
- super().__init__(n_train, input_length, sess=sess,
- checkpoint_path=checkpoint_path)
- self._layer_sizes = layer_sizes
- self._build()
- def _build_prediction(self):
- h = self.x
- for m in self._layer_sizes:
- h = tf.contrib.layers.fully_connected(h, m)
- y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
- activation_fn=tf.nn.sigmoid)
- return y
+ def __init__(self, net):
+ super().__init__()
+ self._net = net
+
+ @property
+ def input_length(self):
+ return self._net[0][0].weight.data.shape[1]
+
+ def forward(self, inputs):
+ return self._net(inputs)
+
+
+def mlp(dx, dy, layer_sizes=None):
+ def block(dx, dy, Activation=nn.ReLU):
+ return nn.Sequential(nn.Linear(dx, dy), Activation())
+
+ layer_sizes = [256, 256] if layer_sizes is None else layer_sizes
+
+ net = nn.Sequential()
+ in_features = dx
+ for i, out_features in enumerate(layer_sizes):
+ net.add_module("layer_%i" % i, block(in_features, out_features))
+ in_features = out_features
+ net.add_module("head", block(in_features, dy, Activation=nn.Tanh))
+ return FullyConnected(net)
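With `from_json` now constructing the network directly, an architecture file needs only the three keys read above. A sketch of a compatible spec and its use (the sizes are illustrative; `from_json` also accepts an already-parsed dict, as shown):

```python
import models

# Equivalent architecture.json:
# {"type": "FullyConnected", "input_length": 64, "layer_sizes": [256, 256]}
model = models.from_json(
    {"type": "FullyConnected", "input_length": 64, "layer_sizes": [256, 256]}
)
print(model.input_length)  # 64, read back from the first Linear layer's weights
```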
diff --git a/reamp.py b/reamp.py
@@ -4,11 +4,13 @@ Reamp a .wav file
Assumes 24-bit WAV files
"""
-from argparse import ArgumentParser
-import tensorflow as tf
+
import os
-import wavio
+from argparse import ArgumentParser
+
import matplotlib.pyplot as plt
+import torch
+import wavio
import models
@@ -19,41 +21,56 @@ def _sampwidth_to_bits(x):
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument("architecture", type=str,
- help="JSON filename containing NN architecture")
- parser.add_argument("checkpoint_dir", type=str,
- help="directory holding model checkpoint to use")
- parser.add_argument("input_file", type=str,
- help="Input .wav file to convert")
- parser.add_argument("--output_file", type=str, default=None,
- help="Where to save the output")
- parser.add_argument("--batch_size", type=int, default=8192,
- help="How many samples to process at a time. " +
- "Reduce if there are out-of-memory issues.")
- parser.add_argument("--target_file", type=str, default=None,
- help=".wav file of the true output (if you want to compare)")
+ parser.add_argument(
+ "architecture", type=str, help="JSON filename containing NN architecture"
+ )
+ parser.add_argument(
+ "params", type=str, help="directory holding model checkpoint to use"
+ )
+ parser.add_argument("input_file", type=str, help="Input .wav file to convert")
+ parser.add_argument(
+ "--output_file", type=str, default=None, help="Where to save the output"
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=8192,
+ help="How many samples to process at a time. "
+ + "Reduce if there are out-of-memory issues.",
+ )
+ parser.add_argument(
+ "--target_file",
+ type=str,
+ default=None,
+ help=".wav file of the true output (if you want to compare)",
+ )
args = parser.parse_args()
- if args.output_file is None:
- args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav"
-
+ output_file = (
+ args.output_file
+ if args.output_file is not None
+ else args.input_file[: -len(".wav")] + "_reamped.wav"
+ )
+
- if os.path.isfile(args.output_file):
+ if os.path.isfile(output_file):
print("Output file exists; skip")
exit(1)
-
+
x = wavio.read(args.input_file)
rate, sampwidth = x.rate, x.sampwidth
bits = _sampwidth_to_bits(sampwidth)
- x_data = x.data.flatten() / 2 ** (bits - 1)
-
- with tf.Session() as sess:
- model = models.from_json(args.architecture,
- checkpoint_path=args.checkpoint_dir)
- model.load()
- y = model.predict(x_data, batch_size=args.batch_size, verbose=True)
- wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none",
- sampwidth=sampwidth)
-
+ x_data = torch.Tensor(x.data.flatten() / 2 ** (bits - 1))
+
+ model = models.from_json(args.architecture)
+ model.load_state_dict(torch.load(args.params))
+ with torch.no_grad():
+ y = model.predict_sequence(
+ x_data, batch_size=args.batch_size, verbose=True
+ ).numpy()
+ wavio.write(
+ output_file, y * 2 ** (bits - 1), rate, scale="none", sampwidth=sampwidth
+ )
+
if args.target_file is not None and os.path.isfile(args.target_file):
t = wavio.read(args.target_file)
t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1)
@@ -63,4 +80,3 @@ if __name__ == "__main__":
plt.plot(y)
plt.legend(["Input", "Target", "Prediction"])
plt.show()
-
\ No newline at end of file
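For reference, the scaling that reamp.py applies around the model: integer WAV samples are mapped into floats near [-1, 1) before inference and scaled back before `wavio.write`. The round trip in isolation (assuming 24-bit audio, i.e. `sampwidth=3`, per the module docstring):

```python
import numpy as np

bits = 24                # sampwidth of 3 bytes -> 24 bits
scale = 2 ** (bits - 1)  # 8388608

ints = np.array([-scale, 0, scale - 1])  # full-scale 24-bit samples
floats = ints / scale                    # roughly [-1.0, 1.0), fed to the model
restored = floats * scale                # written back with scale="none"
assert np.array_equal(restored.astype(np.int32), ints)
```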
diff --git a/requirements.txt b/requirements.txt
@@ -1,2 +1,11 @@
-tensorflow
+# File: requirements.txt
+# Created Date: 2019-01-02
+# Author: Steven Atkinson ([email protected])
+
+black
+flake8
+matplotlib
+numpy
+pytest
+torch>=1
wavio
diff --git a/train.py b/train.py
@@ -4,42 +4,62 @@
"""
Here's a script for training new models.
+
+TODO
+* Device
+* Lightning?
"""
-from argparse import ArgumentParser
-import numpy as np
import abc
-import tensorflow as tf
-import matplotlib.pyplot as plt
-from time import time
-import wavio
import json
import os
+from argparse import ArgumentParser
+from time import time
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import wavio
import models
# Parameters for training
-check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5])
- for pwr in np.arange(0, 7)])
-plot_kwargs = {
- "window": (30000, 40000)
-}
+check_training_at = np.concatenate(
+ [10 ** pwr * np.array([1, 2, 5]) for pwr in np.arange(0, 7)]
+)
+plot_kwargs = {"window": (30000, 40000)}
wav_kwargs = {"window": (0, 5 * 44100)}
+torch.manual_seed(0)
+_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def ensure_save_dir(args):
+ save_dir = (
+ os.path.join(os.path.dirname(__file__), "output")
+ if args.save_dir is None
+ else args.save_dir
+ )
+ if not os.path.isdir(save_dir):
+ os.makedirs(save_dir)
+ return save_dir
+
class Data(object):
"""
- Object for holding data and spitting out minibatches for training/segments 
+ Object for holding data and spitting out minibatches for training/segments
+ Object for holding data and spitting out minibatches for training/segments
for testing.
"""
+
def __init__(self, fname, input_length, batch_size=None):
xy = np.load(fname)
- self.x = xy[0]
- self.y = xy[1]
- self._n = self.x.size - input_length + 1
+ self.x = torch.Tensor(xy[0])
+ self.y = torch.Tensor(xy[1])
+ self._n = self.x.numel() - input_length + 1
self.input_length = input_length
self.batch_size = batch_size
-
+
@property
def n(self):
return self._n
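`Data` reads a single .npy file and takes `xy[0]` as the input signal and `xy[1]` as the target, so a training file is just the two aligned signals stacked. A sketch of producing one (signal contents and filename are illustrative):

```python
import numpy as np

n = 44100               # one second at 44.1 kHz
x = np.random.randn(n)  # stand-in for the DI (input) signal
y = np.tanh(x)          # stand-in for the amp's (target) output

np.save("train.npy", np.stack([x, y]))  # shape (2, n): xy[0] is x, xy[1] is y
```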
@@ -50,70 +70,101 @@ class Data(object):
"""
if ilist is None:
n = n or self.batch_size
- ilist = np.random.randint(0, self.n, size=(n,))
- x = np.stack([self.x[i: i + self.input_length] for i in ilist])
- y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \
- [:, np.newaxis]
- return x, y
+ ilist = torch.randint(0, self.n, size=(n,))
+ x = torch.stack([self.x[i : i + self.input_length] for i in ilist])
+ y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None]
+ return x.to(_device), y.to(_device)
def sequence(self, start, end):
end += self.input_length - 1
- return self.x[start: end], self.y[start + self.input_length - 1: end]
+ return self.x[start:end], self.y[start + self.input_length - 1 : end]
+
+
+def train_step(model, optimizer, batch):
+ model.train()
+ model.zero_grad()
+ inputs, targets = batch
+ preds = model(inputs)
+ loss = nn.MSELoss()(preds, targets)
+ loss.backward()
+ optimizer.step()
+ return loss.item()
+
+
+def validation_step(model, batch):
+ with torch.no_grad():
+ model.eval()
+ inputs, targets = batch
+ return nn.MSELoss()(model(inputs), targets).item()
-def train(model, train_data, batch_size=None, n_minibatches=10,
- validation_data=None, plot_at=(), wav_at=(), plot_kwargs={},
- wav_kwargs={}, save_every=100, validate_every=100):
- save_dir = os.path.dirname(model.checkpoint_path)
- sess = model.sess
- opt = tf.train.AdamOptimizer().minimize(model.loss)
- sess.run(tf.global_variables_initializer()) # For opt
+def train(
+ model,
+ train_data,
+ batch_size=None,
+ n_minibatches=10,
+ validation_data=None,
+ plot_at=(),
+ wav_at=(),
+ plot_kwargs={},
+ wav_kwargs={},
+ save_every=100,
+ validate_every=100,
+):
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
t_loss_list, v_loss_list = [], []
t0 = time()
for i in range(n_minibatches):
- x, y = train_data.minibatch(batch_size)
- t_loss, _ = sess.run((model.rmse, opt),
- feed_dict={model.x: x, model.target: y})
+ batch = train_data.minibatch(batch_size)
+ t_loss = train_step(model, optimizer, batch)
t_loss_list.append([i, t_loss])
- print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0),
- i + 1, n_minibatches, t_loss))
-
+ print(
+ "t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(
+ int(time() - t0), i + 1, n_minibatches, t_loss
+ )
+ )
+
# Callbacks, basically...
if i + 1 in plot_at:
- plot_predictions(model,
- validation_data if validation_data is not None else train_data,
- title="Minibatch {}".format(i + 1),
- fname="{}/mb_{}.png".format(save_dir, i + 1),
- **plot_kwargs)
+ plot_predictions(
+ model,
+ validation_data if validation_data is not None else train_data,
+ title="Minibatch {}".format(i + 1),
+ fname="{}/mb_{}.png".format(save_dir, i + 1),
+ **plot_kwargs
+ )
if i + 1 in wav_at:
print("Making wav for mb {}".format(i + 1))
- predict(model, validation_data,
- save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
- **wav_kwargs)
+ predict(
+ model,
+ validation_data,
+ save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
+ **wav_kwargs
+ )
if (i + 1) % save_every == 0:
- model.save(iter=i + 1)
+ torch.save(model.state_dict(), os.path.join(save_dir, "model_%i.pt" % (i + 1)))
if i == 0 or (i + 1) % validate_every == 0:
- v_loss, _ = sess.run((model.rmse, opt),
- feed_dict={model.x: x, model.target: y})
+ v_batch = validation_data.minibatch(train_data.batch_size) if validation_data is not None else batch
+ v_loss = validation_step(model, v_batch)
print("VLoss={:8}".format(v_loss))
v_loss_list.append([i, v_loss])
-
+
# After training loop...
if validation_data is not None:
- x, y = validation_data.minibatch(train_data.batch_size)
- v_loss = sess.run(model.rmse,
- feed_dict={model.x: x, model.target: y})
+ batch = validation_data.minibatch(train_data.batch_size)
+ v_loss = validation_step(model, batch)
print("Validation loss={}".format(v_loss))
+ torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
return np.array(t_loss_list).T, np.array(v_loss_list).T
def plot_predictions(model, data, title=None, fname=None, window=None):
- x, y, t = predict(model, data, window=window)
- plt.figure(figsize=(12, 4))
- plt.plot(x)
- plt.plot(t)
- plt.plot(y)
- plt.legend(('Input', 'Target', 'Prediction'))
+ with torch.no_grad():
+ x, y, t = predict(model, data, window=window)
+ plt.figure(figsize=(12, 4))
+ plt.plot(x.cpu())
+ plt.plot(t.cpu())
+ plt.plot(y.cpu())
+ plt.legend(("Input", "Target", "Prediction"))
if title is not None:
plt.title(title)
if fname is not None:
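Since `train_step` and `validation_step` now carry all the optimizer and train/eval-mode bookkeeping, a reduced loop outside this script only needs the pieces above. A sketch, assuming train.py imports cleanly as a module (shapes and sizes are illustrative):

```python
import torch

import models
from train import train_step, validation_step

model = models.mlp(64, 1, layer_sizes=[256, 256])
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

x = torch.randn(32, 64)    # a minibatch of 64-sample input windows
y = torch.tanh(x[:, -1:])  # stand-in targets, one per window

for _ in range(10):
    t_loss = train_step(model, optimizer, (x, y))
v_loss = validation_step(model, (x, y))
```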
@@ -133,21 +184,27 @@ def plot_loss(t_loss, v_loss, fname):
plt.legend(("Training", "Validation"))
plt.savefig(fname)
plt.close()
-
+
def predict(model, data, window=None, save_wav_file=None):
- x, t = data.x, data.y
- if window is not None:
- x, t = x[window[0]: window[1]], t[window[0]: window[1]]
- y = model.predict(x).flatten()
+ with torch.no_grad():
+ x, t = data.x.to(_device), data.y.to(_device)
+ if window is not None:
+ x, t = x[window[0] : window[1]], t[window[0] : window[1]]
+ y = model.predict_sequence(x).squeeze()
- if save_wav_file is not None:
- rate = 44100 # TODO from data
- sampwidth = 3 # 24-bit
- wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none",
- sampwidth=sampwidth)
+ if save_wav_file is not None:
+ rate = 44100 # TODO from data
+ sampwidth = 3 # 24-bit
+ wavio.write(
+ save_wav_file,
+ y.cpu().numpy() * 2 ** 23,
+ rate,
+ scale="none",
+ sampwidth=sampwidth,
+ )
- return x, y, t
+ return x, y, t
def _get_input_length(archfile):
@@ -156,43 +213,54 @@ def _get_input_length(archfile):
if __name__ == "__main__":
parser = ArgumentParser()
- parser.add_argument("model_arch", type=str,
- help="JSON containing model architecture")
- parser.add_argument("train_data", type=str,
- help="Filename for training data")
- parser.add_argument("validation_data", type=str,
- help="Filename for validation data")
- parser.add_argument("--save_dir", type=str, default=None,
- help="Where to save the run data (checkpoints, prediction...)")
- parser.add_argument("--batch_size", type=str, default=4096,
- help="Number of data per minibatch")
- parser.add_argument("--minibatches", type=int, default=10,
- help="Number of minibatches to train for")
+ parser.add_argument(
+ "model_arch", type=str, help="JSON containing model architecture"
+ )
+ parser.add_argument("train_data", type=str, help="Filename for training data")
+ parser.add_argument(
+ "validation_data", type=str, help="Filename for validation data"
+ )
+ parser.add_argument(
+ "--save_dir",
+ type=str,
+ default=None,
+ help="Where to save the run data (checkpoints, prediction...)",
+ )
+ parser.add_argument(
+ "--batch_size", type=str, default=4096, help="Number of data per minibatch"
+ )
+ parser.add_argument(
+ "--minibatches", type=int, default=10, help="Number of minibatches to train for"
+ )
args = parser.parse_args()
+ save_dir = ensure_save_dir(args)
input_length = _get_input_length(args.model_arch) # Ugh, kludge
# Load the data
- train_data = Data(args.train_data, input_length,
- batch_size=args.batch_size)
+ train_data = Data(args.train_data, input_length, batch_size=args.batch_size)
validate_data = Data(args.validation_data, input_length)
-
+
# Training
- with tf.Session() as sess:
- model = models.from_json(args.model_arch, train_data.n,
- checkpoint_path=os.path.join(args.save_dir, "model.ckpt"))
- t_loss_list, v_loss_list = train(
- model,
- train_data,
- validation_data=validate_data,
- n_minibatches=args.minibatches,
- plot_at=check_training_at,
- wav_at=check_training_at,
- plot_kwargs=plot_kwargs,
- wav_kwargs=wav_kwargs)
- plot_predictions(model, validate_data, window=(0, 44100))
- print("Predict the full output")
- predict(model, validate_data,
- save_wav_file="{}/predict.wav".format(
- os.path.dirname(model.checkpoint_path)))
-
- plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))
+ model = models.from_json(args.model_arch).to(_device)
+ t_loss_list, v_loss_list = train(
+ model,
+ train_data,
+ validation_data=validate_data,
+ n_minibatches=args.minibatches,
+ plot_at=check_training_at,
+ wav_at=check_training_at,
+ plot_kwargs=plot_kwargs,
+ wav_kwargs=wav_kwargs,
+ )
+ plot_predictions(model, validate_data, window=(0, 44100))
+
+ print("Predict the full output")
+ _device = "cpu"
+ model.to(_device)
+ predict(
+ model,
+ validate_data,
+ save_wav_file="{}/predict.wav".format(save_dir),
+ )
+
+ plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(save_dir))
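The `_device = "cpu"` switch before the final `predict` call moves the full-sequence pass off the GPU, presumably because predicting the whole validation signal in one go can exhaust GPU memory. The same pattern applies when reloading a saved run later; a sketch (filenames are illustrative, and `train._device` is overridden just as the script does):

```python
import torch

import models
import train

train._device = "cpu"  # run the full sequence on CPU, as the script does

model = models.from_json("architecture.json")
model.load_state_dict(torch.load("output/model.pt", map_location="cpu"))
model.eval()

data = train.Data("validation.npy", model.input_length)
x, y, t = train.predict(model, data, save_wav_file="output/predict.wav")
```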