neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 2efabe4a5d04658bb93dc2679f31a75d4c17527f
parent 1c2565b173a3b49150c890e15c0d20f7f21367be
Author: Steven Atkinson <[email protected]>
Date:   Wed, 20 Apr 2022 23:21:05 -0700

Version 0.2.0

Diffstat:
M.gitignore | 34+++++++++++++++++++++++++++++++++-
A.pre-commit-config.yaml | 8++++++++
MLICENSE | 2+-
MREADME.md | 86+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
Darchitectures/fc_32_32_32.json | 9---------
Abin/export.py | 41+++++++++++++++++++++++++++++++++++++++++
Abin/run.py | 33+++++++++++++++++++++++++++++++++
Abin/train/inputs/config_data_single_pair.json | 19+++++++++++++++++++
Abin/train/inputs/config_data_two_pairs.json | 17+++++++++++++++++
Abin/train/inputs/config_learning.json | 18++++++++++++++++++
Abin/train/inputs/config_model.json | 26++++++++++++++++++++++++++
Abin/train/main.py | 138+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Ddata/emissary/train.npy | 0
Ddata/emissary/validate.npy | 0
Ddata/reaper-js-distortion/train.npy | 0
Ddata/reaper-js-distortion/validate.npy | 0
Aenvironment.yml | 25+++++++++++++++++++++++++
Dmodels.py | 165-------------------------------------------------------------------------------
Anam/__init__.py | 9+++++++++
Anam/_core.py | 13+++++++++++++
Anam/_version.py | 1+
Anam/data.py | 221+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Anam/models/__init__.py | 9+++++++++
Anam/models/_base.py | 58++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Anam/models/_exportable.py | 33+++++++++++++++++++++++++++++++++
Anam/models/base.py | 97+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Anam/models/conv_net.py | 290+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Anam/models/linear.py | 70++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Dreamp.py | 67-------------------------------------------------------------------
Mrequirements.txt | 18+++++++++++++++++-
Asetup.py | 31+++++++++++++++++++++++++++++++
Atests/test_install.py | 13+++++++++++++
Atests/test_nam/__init__.py | 0
Atests/test_nam/test_importable.py | 13+++++++++++++
Dtrain.py | 198-------------------------------------------------------------------------------
35 files changed, 1307 insertions(+), 455 deletions(-)

diff --git a/.gitignore b/.gitignore
@@ -20,6 +20,8 @@ parts/
 sdist/
 var/
 wheels/
+pip-wheel-metadata/
+share/python-wheels/
 *.egg-info/
 .installed.cfg
 *.egg
@@ -38,12 +40,14 @@ pip-delete-this-directory.txt
 # Unit test / coverage reports
 htmlcov/
 .tox/
+.nox/
 .coverage
 .coverage.*
 .cache
 nosetests.xml
 coverage.xml
 *.cover
+*.py,cover
 .hypothesis/
 .pytest_cache/
 
@@ -55,6 +59,7 @@ coverage.xml
 *.log
 local_settings.py
 db.sqlite3
+db.sqlite3-journal
 
 # Flask stuff:
 instance/
@@ -72,11 +77,26 @@ target/
 # Jupyter Notebook
 .ipynb_checkpoints
 
+# IPython
+profile_default/
+ipython_config.py
+
 # pyenv
 .python-version
 
-# celery beat schedule file
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
 celerybeat-schedule
+celerybeat.pid
 
 # SageMath parsed files
 *.sage.py
@@ -89,6 +109,8 @@ venv/
 ENV/
 env.bak/
 venv.bak/
+venv-*/
+tmp/
 
 # Spyder project settings
 .spyderproject
@@ -102,3 +124,13 @@ venv.bak/
 
 # mypy
 .mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# etc
+data/
+.vscode/
+lightning_logs/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+  - repo: https://github.com/psf/black
+    rev: stable
+    hooks:
+      - id: black
+        language_version: python3.7
+
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2019 Steven Atkinson
+Copyright (c) 2022 Steven Atkinson
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
@@ -1,24 +1,84 @@
 # NAM: neural amp modeler
 
-Let's use deep learning to make a model of a guitar amp!
+This is the training part of NAM.
+For the code to create the plugin with a trained model, see my
+[iPlug2 fork](https://github.com/sdatkinson/iPlug2).
 
-## Setup
+## How to use
 
-Use `pip install -r requirements.txt` as well as intalling the TensorFlow version of your choice (e.g. `pip install tensorflow-gpu`).
+### Train a model
 
-## Autoregressive models
+You'll need at least two mono wav files: the input (DI) and the amped sound (without the cab).
+You can either record enough to have a training and validation set in the same file and
+split the file, or you can use 4 files (input/output for train/test).
+Also, you can provide _multiple_ file pairs for training (or validation).
 
-* Fully-connected (n-to-1)
+For the first option, modify `bin/train/inputs/config_data_single_pair.json` to point at the audio files, and set the
+start/stop to the point (in samples) where the training segment ends and the validation
+starts.
+For the second option, modify and use `bin/train/inputs/config_data_two_pairs.json`.
 
-## To do
+Then run:
 
-* Other models
-* Real-time (eventually)
-* Model stereo outputs
-* Data classes for .wav data, keeping sample rate, etc
-* spek.cc?
+```bash
+python bin/train/main.py \
+bin/train/inputs/config_data.json \
+bin/train/inputs/config_model.json \
+bin/train/inputs/config_learning.json \
+bin/train/outputs/MyAmp
+```
+
+### Run a model on an input signal ("reamping")
+
+Handy if you want to just check it out without going through the trouble of building the
+plugin.
+
+```bash
+python bin/run.py \
+path/to/source.wav \
+path/to/config_model.json \
+path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
+path/to/output.wav
+```
+
+### Export a model (to use with [the plugin](https://github.com/sdatkinson/iPlug2))
+
+Let's get ready to rock!
+
+```bash
+python bin/export.py \
+path/to/config_model.json \
+path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
+path/to/exported_models/MyAmp
+```
+
+You'll want the `HardCodedModel.h` to paste over into the plugin source.
+
+## Advanced usage
+
+The model architecture in `config_model.json` should work plenty good. However, if you
+want to, you can increase the number of channels and the model will generally fit
+better (though it'll get closer to the threshold of real-time. 20 works for a "large"
+model and is still about 6x real time on my desktop).
+
+If you want to mess with the model architecture and end up with a different receptive
+field (e.g. by messing with the dilation pattern), then you need to make sure that `nx`
+is changed accordingly in the data setup.
+The default architecture has a receptive field of 8191 samples, so `nx` is `8191`.
+Generally, for the conv net architecture the receptive field is one more than the sum of the `dilations`.
+
+You can train for shorter or longer.
+1000 gives pretty great results, but if you're impatient you can sometimes get away with
+comparable results after 500 epochs, and you might not even be able to tell the
+difference with far fewer (maybe 200?...100?)
+
+- Other models
+- Real-time (eventually)
+- Model stereo outputs
+- Data classes for .wav data, keeping sample rate, etc
+- spek.cc?
 
 ## Other interesting tidbits?
 
-* Audio analysis plugins: vamp-plugins.org -* Spectrogram: sonicvisualiser.org +- Audio analysis plugins: vamp-plugins.org +- Spectrogram: sonicvisualiser.org diff --git a/architectures/fc_32_32_32.json b/architectures/fc_32_32_32.json @@ -1,8 +0,0 @@ -{ - "type": "FullyConnected", - "input_length": 32, - "layer_sizes": [ - 32, - 32 - ] -} -\ No newline at end of file diff --git a/bin/export.py b/bin/export.py @@ -0,0 +1,41 @@ +# File: export.py +# Created Date: Sunday February 6th 2022 +# Author: Steven Atkinson ([email protected]) + +""" +Export a model to TorchScript +""" + +import json +from argparse import ArgumentParser +from pathlib import Path + +import torch + +from nam.models import Model + + +class Dummy(torch.nn.Module): + def forward(self, x): + return x[8191:] + + +def main(args): + with open(args.model_config_path, "r") as fp: + net = Model.load_from_checkpoint( + args.checkpoint, **Model.parse_config(json.load(fp)) + ).net + net.eval() + outdir = Path(args.outdir) + outdir.mkdir(parents=True, exist_ok=True) + net.export(outdir) + net.export_cpp_header(Path(outdir, "HardCodedModel.h")) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("model_config_path", type=str) + parser.add_argument("checkpoint", type=str) + parser.add_argument("outdir") + parser.add_argument("--source-length", type=int, default=16384) + main(parser.parse_args()) diff --git a/bin/run.py b/bin/run.py @@ -0,0 +1,33 @@ +# File: run.py +# Created Date: Sunday February 6th 2022 +# Author: Steven Atkinson ([email protected]) + +""" +Load up a model, process a WAV, and save. +""" + +import json +from argparse import ArgumentParser + +from nam.data import wav_to_tensor, tensor_to_wav +from nam.models import Model + + +def main(args): + source = wav_to_tensor(args.source_path) + with open(args.model_config_path, "r") as fp: + model = Model.load_from_checkpoint( + args.checkpoint, **Model.parse_config(json.load(fp)) + ) + model.eval() + output = model(source) + tensor_to_wav(output, args.outfile) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("source_path", type=str) + parser.add_argument("model_config_path", type=str) + parser.add_argument("checkpoint", type=str) + parser.add_argument("outfile") + main(parser.parse_args()) diff --git a/bin/train/inputs/config_data_single_pair.json b/bin/train/inputs/config_data_single_pair.json @@ -0,0 +1,18 @@ +{ + "train": { + "start": null, + "stop": 36576000, + "ny": 1024 + }, + "validation": { + "start": 36576000, + "stop": null, + "ny": null + }, + "common": { + "x_path": "C:\\path\\to\\source.wav", + "y_path": "C:\\path\\to\\target.wav", + "nx": 8191, + "delay": 0 + } +} +\ No newline at end of file diff --git a/bin/train/inputs/config_data_two_pairs.json b/bin/train/inputs/config_data_two_pairs.json @@ -0,0 +1,16 @@ +{ + "train": { + "x_path": "C:\\path\\to\\train\\source.wav", + "y_path": "C:\\path\\to\\train\\target.wav", + "ny": 1024 + }, + "validation": { + "x_path": "C:\\path\\to\\validation\\source.wav", + "y_path": "C:\\path\\to\\validation\\target.wav", + "ny": null + }, + "common": { + "nx": 8191, + "delay": 0 + } +} +\ No newline at end of file diff --git a/bin/train/inputs/config_learning.json b/bin/train/inputs/config_learning.json @@ -0,0 +1,17 @@ +{ + "train_dataloader": { + "batch_size": 32, + "shuffle": true, + "pin_memory": true, + "drop_last": true, + "num_workers": 4 + }, + "val_dataloader": { + }, + "trainer": { + "gpus": 1, + "max_epochs": 1000 + }, + "trainer_fit_kwargs": { + } +} +\ 
No newline at end of file diff --git a/bin/train/inputs/config_model.json b/bin/train/inputs/config_model.json @@ -0,0 +1,25 @@ +{ + "net": { + "name": "WaveNet", + "config": { + "channels": 16, + "dilations": [1,2,4,8,16,32,64,128,256,512,1024,2048,1,2,4,8,16,32,64,128,256,512,1024,2048], + "batchnorm": true, + "activation": "Tanh" + } + }, + "optimizer": { + "lr": 0.003 + }, + "lr_scheduler": { + "class": "ReduceLROnPlateau", + "kwargs": { + "factor": 0.5, + "patience": 50, + "cooldown": 50, + "min_lr": 1.0e-5, + "verbose": true + }, + "monitor": "val_loss" + } +} +\ No newline at end of file diff --git a/bin/train/main.py b/bin/train/main.py @@ -0,0 +1,138 @@ +# File: train.py +# Created Date: Saturday February 5th 2022 +# Author: Steven Atkinson ([email protected]) + +import json +from argparse import ArgumentParser +from datetime import datetime +from pathlib import Path +from time import time +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + +from nam.data import Split, init_dataset +from nam.models import Model + +torch.manual_seed(0) + + +def timestamp() -> str: + t = datetime.now() + return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}" + + +def ensure_outdir(outdir: str) -> Path: + outdir = Path(outdir, timestamp()) + outdir.mkdir(parents=True, exist_ok=False) + return outdir + + +def _rms(x: Union[np.ndarray, torch.Tensor]) -> float: + if isinstance(x, np.ndarray): + return np.sqrt(np.mean(np.square(x))) + elif isinstance(x, torch.Tensor): + return torch.sqrt(torch.mean(torch.square(x))).item() + else: + raise TypeError(type(x)) + + +def plot( + model, + ds, + savefig=None, + show=True, + window_start: Optional[int] = None, + window_end: Optional[int] = None, +): + with torch.no_grad(): + tx = len(ds.x) / 48_000 + print(f"Run (t={tx})") + t0 = time() + output = model(ds.x).flatten().cpu().numpy() + t1 = time() + print(f"Took {t1 - t0} ({tx / (t1 - t0):.2f}x)") + + plt.figure(figsize=(16, 5)) + plt.plot(ds.x[window_start:window_end], label="Input") + plt.plot(output[window_start:window_end], label="Output") + plt.plot(ds.y[window_start:window_end], label="Target") + plt.title(f"NRMSE={_rms(torch.Tensor(output) - ds.y) / _rms(ds.y)}") + plt.legend() + if savefig is not None: + plt.savefig(savefig) + if show: + plt.show() + + +def main(args): + outdir = ensure_outdir(args.outdir) + # Read + with open(args.data_config_path, "r") as fp: + data_config = json.load(fp) + with open(args.model_config_path, "r") as fp: + model_config = json.load(fp) + with open(args.learning_config_path, "r") as fp: + learning_config = json.load(fp) + # Write + for basename, config in ( + ("data", data_config), + ("model", model_config), + ("learning", learning_config), + ): + with open(Path(outdir, f"config_{basename}.json"), "w") as fp: + json.dump(config, fp, indent=4) + + model = Model.init_from_config(model_config) + + dataset_train = init_dataset(data_config, Split.TRAIN) + dataset_validation = init_dataset(data_config, Split.VALIDATION) + train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + + # ckpt_path = Path(outdir, "checkpoints") + # ckpt_path.mkdir() + trainer = pl.Trainer( + callbacks=[ + pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="{epoch}_{val_loss:.6f}", + save_top_k=3, + monitor="val_loss", + 
every_n_epochs=1, + ), + pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_last_{epoch:04d}", every_n_epochs=1 + ), + ], + default_root_dir=outdir, + **learning_config["trainer"], + ) + trainer.fit( + model, + train_dataloader, + val_dataloader, + **learning_config.get("trainer_fit_kwargs", {}), + ) + model.eval() + plot( + model, + dataset_validation, + savefig=Path(outdir, "comparison.png"), + window_start=100_000, + window_end=110_000, + show=False, + ) + plot(model, dataset_validation) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("data_config_path", type=str) + parser.add_argument("model_config_path", type=str) + parser.add_argument("learning_config_path", type=str) + parser.add_argument("outdir") + main(parser.parse_args()) diff --git a/data/emissary/train.npy b/data/emissary/train.npy Binary files differ. diff --git a/data/emissary/validate.npy b/data/emissary/validate.npy Binary files differ. diff --git a/data/reaper-js-distortion/train.npy b/data/reaper-js-distortion/train.npy Binary files differ. diff --git a/data/reaper-js-distortion/validate.npy b/data/reaper-js-distortion/validate.npy Binary files differ. diff --git a/environment.yml b/environment.yml @@ -0,0 +1,25 @@ +# File: environment.yml +# Created Date: Saturday February 13th 2021 +# Author: Steven Atkinson ([email protected]) + +name: nam2 +channels: + - pytorch +dependencies: + - black + - flake8 + - h5py + - jupyter + - matplotlib + - numpy + - pip + - pytest + - pytorch + - scipy + - tqdm + - wheel + - pip: + - pre-commit + - pytorch_lightning + - sounddevice + - wavio diff --git a/models.py b/models.py @@ -1,165 +0,0 @@ -# File: models.py -# File Created: Sunday, 30th December 2018 9:42:29 pm -# Author: Steven Atkinson ([email protected]) - -import numpy as np -import tensorflow as tf -import abc -from tempfile import mkdtemp -import os -import json - - -def from_json(f, n_train=1, checkpoint_path=None): - if isinstance(f, str): - with open(f, "r") as json_file: - f = json.load(json_file) - - if f["type"] == "FullyConnected": - return FullyConnected(n_train, f["input_length"], - layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path) - else: - raise NotImplementedError("Model type {} unrecognized".format( - f["type"])) - - -class Model(object): - """ - Model parent class - """ - def __init__(self, n_train, sess=None, checkpoint_path=None): - """ - Make sure child classes call _build() after this! 
- """ - if sess is None: - sess = tf.get_default_session() - self.sess = sess - - if checkpoint_path is None: - checkpoint_path = os.path.join(mkdtemp(), "model.ckpt") - if not os.path.isdir(os.path.dirname(checkpoint_path)): - os.makedirs(os.path.dirname(checkpoint_path)) - self.checkpoint_path = checkpoint_path - - # self._batch_size = batch_size - self._n_train = n_train - self.target = tf.placeholder(tf.float32, shape=(None, 1)) - self.prediction = None - self.total_prediction_loss = None - self.rmse = None - self.loss = None # Includes any regularization - - # @property - # def batch_size(self): - # return self._batch_size - - @property - def n_train(self): - return self._n_train - - def load(self, checkpoint_path=None): - checkpoint_path = checkpoint_path or self.checkpoint_path - - try: - ckpt = tf.train.get_checkpoint_state(checkpoint_path) - print("Loading model: {}".format(ckpt.model_checkpoint_path)) - self.saver.restore(self.sess, ckpt.model_checkpoint_path) - except Exception as e: - print("Error while attempting to load model: {}".format(e)) - - @abc.abstractclassmethod - def predict(self, x): - """ - A nice function for prediction. - :param x: input array (length=n) - :type x: array-like - :return: (array-like) corresponding predicted outputs (length=n) - """ - raise NotImplementedError("Implement predict()") - - def save(self, iter, checkpoint_path=None): - checkpoint_path = checkpoint_path or self.checkpoint_path - self.saver.save(self.sess, checkpoint_path, global_step=iter) - - def _build(self): - self.prediction = self._build_prediction() - self.loss = self._build_loss() - - # Launch the session - self.sess.run(tf.global_variables_initializer()) - self.saver = tf.train.Saver(tf.global_variables()) - - def _build_loss(self): - self.total_prediction_loss = tf.losses.mean_squared_error(self.target, - self.prediction, weights=self.n_train) - - # Don't count this as a loss! - self.rmse = tf.sqrt( - self.total_prediction_loss / self.n_train) - - return tf.losses.get_total_loss() - - @abc.abstractclassmethod - def _build_prediction(self): - raise NotImplementedError('Implement prediction for model') - - -class Autoregressive(Model): - """ - Autoregressive models that take in a few of the most recent input samples - and predict the output at the last time point. - """ - def __init__(self, n_train, input_length, sess=None, - checkpoint_path=None): - super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path) - self._input_length = input_length - self.x = tf.placeholder(tf.float32, shape=(None, self.input_length)) - - @property - def input_length(self): - return self._input_length - - def predict(self, x, batch_size=None, verbose=False): - """ - Return 1D array of predictions same length as x - """ - n = x.size - batch_size = batch_size or n - # Pad x with leading zeros: - x = np.concatenate((np.zeros(self.input_length - 1), x)) - i = 0 - y = [] - while i < n: - if verbose: - print("model.predict {}/{}".format(i, n)) - this_batch_size = np.min([batch_size, n - i]) - # Reshape into a batch: - x_mtx = np.stack([x[j: j + self.input_length] - for j in range(i, i + this_batch_size)]) - # Predict and flatten. - y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \ - .flatten()) - i += this_batch_size - return np.concatenate(y) - - -class FullyConnected(Autoregressive): - """ - Autoregressive model taking in a sequence of the most recent inputs, putting - them through a series of FC layers, and outputting the single output at the - last time step. 
- """ - def __init__(self, n_train, input_length, layer_sizes=(512,), - sess=None, checkpoint_path=None): - super().__init__(n_train, input_length, sess=sess, - checkpoint_path=checkpoint_path) - self._layer_sizes = layer_sizes - self._build() - - def _build_prediction(self): - h = self.x - for m in self._layer_sizes: - h = tf.contrib.layers.fully_connected(h, m) - y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1, - activation_fn=tf.nn.sigmoid) - return y diff --git a/nam/__init__.py b/nam/__init__.py @@ -0,0 +1,9 @@ +# File: __init__.py +# File Created: Tuesday, 2nd February 2021 9:42:50 pm +# Author: Steven Atkinson ([email protected]) + +from ._version import __version__ # Must be before models or else circular + +from . import _core # noqa F401 +from . import data # noqa F401 +from . import models # noqa F401 diff --git a/nam/_core.py b/nam/_core.py @@ -0,0 +1,13 @@ +# File: core.py +# Created Date: Saturday February 5th 2022 +# Author: Steven Atkinson ([email protected]) + + +class InitializableFromConfig(object): + @classmethod + def init_from_config(cls, config): + return cls(**cls.parse_config(config)) + + @classmethod + def parse_config(cls, config): + return config diff --git a/nam/_version.py b/nam/_version.py @@ -0,0 +1 @@ +__version__ = "0.2.0" diff --git a/nam/data.py b/nam/data.py @@ -0,0 +1,221 @@ +# File: data.py +# Created Date: Saturday February 5th 2022 +# Author: Steven Atkinson ([email protected]) + +import abc +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Optional, Sequence, Tuple, Union + +import numpy as np +import torch +import wavio +from torch.utils.data import Dataset as _Dataset + +from ._core import InitializableFromConfig + +_REQUIRED_SAMPWIDTH = 3 +_REQUIRED_RATE = 48_000 +_REQUIRED_CHANNELS = 1 # Mono + + +class Split(Enum): + TRAIN = "train" + VALIDATION = "validation" + + +@dataclass +class WavInfo: + sampwidth: int + rate: int + + +def wav_to_np( + filename: Union[str, Path], + require_match: Optional[Union[str, Path]] = None, + required_shape: Optional[Tuple[int]] = None, + required_wavinfo: Optional[WavInfo] = None, + preroll: Optional[int] = None, + info: bool = False, +) -> Union[np.ndarray, Tuple[np.ndarray, WavInfo]]: + """ + :param preroll: Drop this many samples off the front + """ + x_wav = wavio.read(str(filename)) + assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono" + assert x_wav.sampwidth == _REQUIRED_SAMPWIDTH, "24-bit" + assert x_wav.rate == _REQUIRED_RATE, "48 kHz" + + if require_match is not None: + assert required_shape is None + assert required_wavinfo is None + y_wav = wavio.read(str(require_match)) + required_shape = y_wav.data.shape + required_wavinfo = WavInfo(y_wav.sampwidth, y_wav.rate) + if required_wavinfo is not None: + if x_wav.rate != required_wavinfo.rate: + raise ValueError( + f"Mismatched rates {x_wav.rate} versus {required_wavinfo.rate}" + ) + arr_premono = x_wav.data[preroll:] / (2.0 ** (8 * x_wav.sampwidth - 1)) + if required_shape is not None: + if arr_premono.shape != required_shape: + raise ValueError( + f"Mismatched shapes {arr_premono.shape} versus {required_shape}" + ) + # sampwidth fine--we're just casting to 32-bit float anyways + arr = arr_premono[:, 0] + return arr if not info else (arr, WavInfo(x_wav.sampwidth, x_wav.rate)) + + +def wav_to_tensor( + *args, info: bool = False, **kwargs +) -> Union[torch.Tensor, Tuple[torch.Tensor, WavInfo]]: + out = wav_to_np(*args, info=info, **kwargs) + if info: + arr, info = out + return 
torch.Tensor(arr), info + else: + arr = out + return torch.Tensor(arr) + + +def tensor_to_wav( + x: torch.Tensor, + filename: Union[str, Path], + rate: int = 48_000, + sampwidth: int = 3, + scale="none", +): + wavio.write( + filename, + (torch.clamp(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))) + .detach() + .cpu() + .numpy() + .astype(np.int32), + rate, + scale=scale, + sampwidth=sampwidth, + ) + + +class AbstractDataset(_Dataset, abc.ABC): + @abc.abstractmethod + def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]: + pass + + +class Dataset(AbstractDataset, InitializableFromConfig): + """ + Take a pair of matched audio files and serve input + output pairs + """ + + def __init__( + self, + x: torch.Tensor, + y: torch.Tensor, + nx: int, + ny: Optional[int], + start: Optional[int] = None, + stop: Optional[int] = None, + delay: Optional[int] = None, + ): + """ + :param start: In samples + :param stop: In samples + :param delay: In samples. Positive means we get rid of the start of x, end of y. + """ + x, y = [z[start:stop] for z in (x, y)] + if delay is not None: + if delay > 0: + x = x[:-delay] + y = y[delay:] + else: + x = x[-delay:] + y = y[:delay] + self._validate_inputs(x, y, nx, ny) + self._x = x + self._y = y + self._nx = nx + self._ny = ny if ny is not None else len(x) - nx + 1 + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + if idx >= len(self): + raise IndexError(f"Attempted to access datum {idx}, but len is {len(self)}") + i = idx * self._ny + j = i + self.y_offset + return self.x[i : i + self._nx + self._ny - 1], self.y[j : j + self._ny] + + def __len__(self) -> int: + n = len(self.x) + # If ny were 1 + single_pairs = n - self._nx + 1 + return single_pairs // self._ny + + @property + def x(self): + return self._x + + @property + def y(self): + return self._y + + @property + def y_offset(self) -> int: + return self._nx - 1 + + @classmethod + def parse_config(cls, config): + x, x_wavinfo = wav_to_tensor(config["x_path"], info=True) + y = wav_to_tensor( + config["y_path"], + preroll=config.get("y_preroll"), + required_shape=(len(x), 1), + required_wavinfo=x_wavinfo, + ) + return { + "x": x, + "y": y, + "nx": config["nx"], + "ny": config["ny"], + "start": config.get("start"), + "stop": config.get("stop"), + } + + def _validate_inputs(self, x, y, nx, ny): + assert x.ndim == 1 + assert y.ndim == 1 + assert len(x) == len(y) + assert nx <= len(x) + if ny is not None: + assert ny <= len(y) - nx + 1 + + +class ConcatDataset(AbstractDataset, InitializableFromConfig): + def __init__(self, datasets: Sequence[Dataset]): + self._datasets = datasets + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + for d in self._datasets: + if idx < len(d): + return d[idx] + else: + idx = idx - len(d) + + def __len__(self) -> int: + return sum(len(d) for d in self._datasets) + + @classmethod + def parse_config(cls, config): + return {"datasets": tuple(Dataset.init_from_config(c) for c in config)} + + +def init_dataset(config, split: Split) -> AbstractDataset: + base_config = config[split.value] + common = config.get("common", {}) + if isinstance(base_config, dict): + return Dataset.init_from_config({**common, **base_config}) + elif isinstance(base_config, list): + return ConcatDataset.init_from_config([{**common, **c} for c in base_config]) diff --git a/nam/models/__init__.py b/nam/models/__init__.py @@ -0,0 +1,9 @@ +# File: __init__.py +# Created Date: Saturday February 5th 2022 +# Author: Steven Atkinson ([email protected]) + +from . 
import _base # noqa F401
+from . import _exportable # noqa F401
+from .base import Model # noqa F401
+from .linear import Linear # noqa F401
+from .conv_net import ConvNet # noqa F401
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -0,0 +1,58 @@
+# File: _base.py
+# Created Date: Tuesday February 8th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .._core import InitializableFromConfig
+from ._exportable import Exportable
+
+
+class BaseNet(nn.Module, InitializableFromConfig, Exportable):
+    @abc.abstractproperty
+    def pad_start_default(self) -> bool:
+        pass
+
+    @abc.abstractproperty
+    def receptive_field(self) -> int:
+        """
+        Receptive field of the model
+        """
+        pass
+
+    def forward(self, x: torch.Tensor, pad_start: bool = None):
+        pad_start = self.pad_start_default if pad_start is None else pad_start
+        scalar = x.ndim == 1
+        if scalar:
+            x = x[None]
+        if pad_start:
+            x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
+        y = self._forward(x)
+        if scalar:
+            y = y[0]
+        return y
+
+    @abc.abstractmethod
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        The true forward method.
+
+        :param x: (N,L1)
+        :return: (N,L1-RF+1)
+        """
+        pass
+
+    def _test_signal(
+        self, seed=0, extra_length=13
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        x = torch.Tensor(
+            np.random.default_rng(seed).normal(
+                size=(self.receptive_field + extra_length,)
+            )
+        )
+        return x, self(x, pad_start=False)
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -0,0 +1,33 @@
+# File: _exportable.py
+# Created Date: Tuesday February 8th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+from pathlib import Path
+
+
+class Exportable(abc.ABC):
+    """
+    Interface for my custom export format for use in the plugin.
+    """
+
+    @abc.abstractmethod
+    def export(self, outdir: Path):
+        """
+        Interface for exporting.
+        You should create at least a `config.json` containing the following fields:
+        * "version" (str)
+        * "architecture" (str)
+        * "config": (dict w/ other necessary data like tensor shapes etc)
+
+        :param outdir: Assumed to exist. Can be edited inside at will.
+        """
+        pass
+
+    @abc.abstractmethod
+    def export_cpp_header(self, filename: Path):
+        """
+        Export a .h file to compile into the plugin with the weights written right out
+        as text
+        """
+        pass
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -0,0 +1,97 @@
+# File: base.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Lightning stuff
+"""
+
+from typing import Optional
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+from .._core import InitializableFromConfig
+from .linear import Linear
+from .conv_net import ConvNet
+
+
+class Model(pl.LightningModule, InitializableFromConfig):
+    def __init__(
+        self,
+        net,
+        optimizer_config: Optional[dict] = None,
+        scheduler_config: Optional[dict] = None,
+    ):
+        """
+        :param scheduler_config: contains
+        Required:
+        * "class"
+        * "kwargs"
+        Optional (defaults to Lightning defaults):
+        * "interval" ("epoch" or "step")
+        * "frequency" (int)
+        * "monitor" (str)
+        """
+        super().__init__()
+        self._net = net
+        self._optimizer_config = {} if optimizer_config is None else optimizer_config
+        self._scheduler_config = scheduler_config
+
+    @classmethod
+    def init_from_config(cls, config):
+        checkpoint_path = config.get("checkpoint_path")
+        config = cls.parse_config(config)
+        return (
+            cls(**config)
+            if checkpoint_path is None
+            else cls.load_from_checkpoint(checkpoint_path, **config)
+        )
+
+    @classmethod
+    def parse_config(cls, config):
+        config = super().parse_config(config)
+        net_config = config["net"]
+        net = {"Linear": Linear.init_from_config, "ConvNet": ConvNet.init_from_config}[
+            net_config["name"]
+        ](net_config["config"])
+        return {
+            "net": net,
+            "optimizer_config": config["optimizer"],
+            "scheduler_config": config["lr_scheduler"],
+        }
+
+    @property
+    def net(self) -> nn.Module:
+        return self._net
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
+        if self._scheduler_config is None:
+            return optimizer
+        else:
+            lr_scheduler = getattr(
+                torch.optim.lr_scheduler, self._scheduler_config["class"]
+            )(optimizer, **self._scheduler_config["kwargs"])
+            lr_scheduler_config = {"scheduler": lr_scheduler}
+            for key in ("interval", "frequency", "monitor"):
+                if key in self._scheduler_config:
+                    lr_scheduler_config[key] = self._scheduler_config[key]
+            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+    def forward(self, *args, **kwargs):
+        return self.net(*args, **kwargs)
+
+    def _shared_step(self, batch):
+        sources, targets = batch
+        preds = self(sources, pad_start=False)
+        return nn.MSELoss()(preds, targets)
+
+    def training_step(self, batch, batch_idx):
+        return self._shared_step(batch)
+
+    def validation_step(self, batch, batch_idx):
+        val_loss = self._shared_step(batch)
+        self.log_dict({"val_loss": val_loss})
+        return val_loss
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -0,0 +1,290 @@
+# File: conv_net.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+import json
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from ..
import __version__ +from ..data import wav_to_tensor +from ._base import BaseNet + +_CONV_NAME = "conv" +_BATCHNORM_NAME = "batchnorm" +_ACTIVATION_NAME = "activation" + + +class TrainStrategy(Enum): + STRIDE = "stride" + DILATE = "dilate" + + +default_train_strategy = TrainStrategy.DILATE + + +class _Functional(nn.Module): + """ + Define a layer by a function w/ no params + """ + + def __init__(self, op): + super().__init__() + self._op = op + + def forward(self, *args, **kwargs): + return self._op(*args, **kwargs) + + +class _IR(nn.Module): + def __init__(self, filename: Union[str, Path]): + super().__init__() + self.register_buffer("_weight", reversed(wav_to_tensor(filename))[None, None]) + + @property + def length(self) -> int: + return self._weight.shape[-1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + :param x: (N,D) + :return: (N,D-length+1) + """ + return F.conv1d(x[:, None], self._weight)[:, 0] + + +def _conv_net( + channels: int = 32, + dilations: Sequence[int] = None, + batchnorm: bool = False, + activation: str = "Tanh", +) -> nn.Sequential: + def block(cin, cout, dilation): + net = nn.Sequential() + net.add_module( + _CONV_NAME, nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm) + ) + if batchnorm: + net.add_module(_BATCHNORM_NAME, nn.BatchNorm1d(cout)) + net.add_module(_ACTIVATION_NAME, getattr(nn, activation)()) + return net + + def check_and_expand(n, x): + if x.shape[1] < n: + raise ValueError( + f"Input of length {x.shape[1]} is shorter than model receptive field ({n})" + ) + return x[:, None, :] + + dilations = [1, 2, 4, 8] if dilations is None else dilations + receptive_field = sum(dilations) + 1 + net = nn.Sequential() + net.add_module("expand", _Functional(partial(check_and_expand, receptive_field))) + cin = 1 + cout = channels + for i, dilation in enumerate(dilations): + net.add_module(f"block_{i}", block(cin, cout, dilation)) + cin = cout + net.add_module("head", nn.Conv1d(channels, 1, 1)) + net.add_module("flatten", nn.Flatten()) + return net + + +class ConvNet(BaseNet): + """ + A straightforward convolutional neural network. + + Works surprisingly well! 
+ """ + + def __init__( + self, + *args, + train_strategy: TrainStrategy = default_train_strategy, + ir: Optional[_IR] = None, + **kwargs, + ): + super().__init__() + self._net = _conv_net(*args, **kwargs) + assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported" + self._train_strategy = train_strategy + self._num_blocks = self._get_num_blocks(self._net) + self._pad_start_default = True + self._ir = ir + + @classmethod + def parse_config(cls, config): + config = super().parse_config(config) + config["train_strategy"] = TrainStrategy( + config.get("train_strategy", default_train_strategy.value) + ) + config["ir"] = ( + None if "ir_filename" not in config else _IR(config.pop("ir_filename")) + ) + return config + + @property + def pad_start_default(self) -> bool: + return self._pad_start_default + + @property + def receptive_field(self) -> int: + net_rf = 1 + sum( + self._net._modules[f"block_{i}"]._modules["conv"].dilation[0] + for i in range(self._num_blocks) + ) + # Minus 1 because it composes w/ the net + ir_rf = 0 if self._ir is None else self._ir.length - 1 + return net_rf + ir_rf + + @property + def _activation(self): + return ( + self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__ + ) + + @property + def _channels(self) -> int: + return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0] + + @property + def _num_layers(self) -> int: + return self._num_blocks + + @property + def _batchnorm(self) -> bool: + return _BATCHNORM_NAME in self._net._modules["block_0"]._modules + + def export(self, outdir: Path): + """ + Files created: + * config.json + * weights.npy + * input.npy + * output.npy + + weights are serialized to weights.npy in the following order: + * (expand: no params) + * loop blocks 0,...,L-1 + * conv: + * weight (Cout, Cin, K) + * bias (if no batchnorm) (Cout) + * BN + * running mean + * running_var + * weight (Cout) + * bias (Cout) + * eps () + * head + * weight (C, 1, 1) + * bias (1, 1) + * (flatten: no params) + + A test input & output are also provided, input.npy and output.npy + """ + training = self.training + self.eval() + with open(Path(outdir, "config.json"), "w") as fp: + json.dump( + { + "version": __version__, + "architecture": "ConvNet", + "config": { + "channels": self._channels, + "dilations": self._get_dilations(), + "batchnorm": self._batchnorm, + "activation": self._activation, + }, + }, + fp, + indent=4, + ) + + params = [] + for i in range(self._num_layers): + block_name = f"block_{i}" + block = self._net._modules[block_name] + conv = block._modules[_CONV_NAME] + params.append(conv.weight.flatten()) + if conv.bias is not None: + params.append(conv.bias.flatten()) + if self._batchnorm: + bn = block._modules[_BATCHNORM_NAME] + params.append(bn.running_mean.flatten()) + params.append(bn.running_var.flatten()) + params.append(bn.weight.flatten()) + params.append(bn.bias.flatten()) + params.append(torch.Tensor([bn.eps]).to(bn.weight.device)) + head = self._net._modules["head"] + params.append(head.weight.flatten()) + params.append(head.bias.flatten()) + params = torch.cat(params).detach().cpu().numpy() + # Hope I don't regret using np.save... 
+ np.save(Path(outdir, "weights.npy"), params) + + # And an input/output to verify correct computation: + x, y = self._test_signal() + np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy()) + np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy()) + + # And resume training state + self.train(training) + + def export_cpp_header(self, filename: Path): + with TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + self.export(Path(tmpdir)) + with open(Path(tmpdir, "config.json"), "r") as fp: + _c = json.load(fp) + version = _c["version"] + config = _c["config"] + with open(filename, "w") as f: + f.writelines( + ( + "#pragma once\n", + "// Automatically-generated model file\n", + "#include <vector>\n", + f'#define PYTHON_MODEL_VERSION "{version}"\n', + f"const int CHANNELS = {config['channels']};\n", + f"const bool BATCHNORM = {'true' if config['batchnorm'] else 'false'};\n", + "std::vector<int> DILATIONS{" + + ",".join([str(d) for d in config["dilations"]]) + + "};\n", + f"const std::string ACTIVATION = \"{config['activation']}\";\n", + "std::vector<float> PARAMS{" + + ",".join( + [f"{w:.16f}" for w in np.load(Path(tmpdir, "weights.npy"))] + ) + + "};\n", + ) + ) + + def _forward(self, x): + y = self._net(x) + if self._ir is not None: + y = self._ir(y) + return y + + def _get_dilations(self) -> Tuple[int]: + return tuple( + self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0] + for i in range(self._num_blocks) + ) + + def _get_num_blocks(self, net: nn.Sequential): + i = 0 + while True: + if f"block_{i}" not in net._modules: + break + else: + i += 1 + return i diff --git a/nam/models/linear.py b/nam/models/linear.py @@ -0,0 +1,70 @@ +# File: linear.py +# Created Date: Tuesday February 8th 2022 +# Author: Steven Atkinson ([email protected]) + +""" +Linear model +""" + +import json +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn + +from .._version import __version__ +from ._base import BaseNet + + +class Linear(BaseNet): + def __init__(self, receptive_field: int, bias: bool = False): + super().__init__() + self._net = nn.Conv1d(1, 1, receptive_field, bias=bias) + + @property + def pad_start_default(self) -> bool: + return True + + @property + def receptive_field(self) -> int: + return self._net.weight.shape[2] + + def export(self, outdir: Path): + training = self.training + self.eval() + with open(Path(outdir, "config.json"), "w") as fp: + json.dump( + { + "version": __version__, + "architecture": self.__class__.__name__, + "config": { + "receptive_field": self.receptive_field, + "bias": self._bias, + }, + }, + fp, + indent=4, + ) + + params = [self._net.weight.flatten()] + if self._bias: + params.append(self._net.bias.flatten()) + params = torch.cat(params).detach().cpu().numpy() + # Hope I don't regret using np.save... 
+ np.save(Path(outdir, "weights.npy"), params) + + # And an input/output to verify correct computation: + x, y = self._test_signal() + np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy()) + np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy()) + + # And resume training state + self.train(training) + + @property + def _bias(self) -> bool: + return self._net.bias is not None + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + return self._net(x[:, None])[:, 0] diff --git a/reamp.py b/reamp.py @@ -1,66 +0,0 @@ -""" -Reamp a .wav file - -Assumes 24-bit WAV files -""" - -from argparse import ArgumentParser -import tensorflow as tf -import os -import wavio -import matplotlib.pyplot as plt - -import models - - -def _sampwidth_to_bits(x): - return {2: 16, 3: 24, 4: 32}[x] - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("architecture", type=str, - help="JSON filename containing NN architecture") - parser.add_argument("checkpoint_dir", type=str, - help="directory holding model checkpoint to use") - parser.add_argument("input_file", type=str, - help="Input .wav file to convert") - parser.add_argument("--output_file", type=str, default=None, - help="Where to save the output") - parser.add_argument("--batch_size", type=int, default=8192, - help="How many samples to process at a time. " + - "Reduce if there are out-of-memory issues.") - parser.add_argument("--target_file", type=str, default=None, - help=".wav file of the true output (if you want to compare)") - args = parser.parse_args() - - if args.output_file is None: - args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav" - - if os.path.isfile(args.output_file): - print("Output file exists; skip") - exit(1) - - x = wavio.read(args.input_file) - rate, sampwidth = x.rate, x.sampwidth - bits = _sampwidth_to_bits(sampwidth) - x_data = x.data.flatten() / 2 ** (bits - 1) - - with tf.Session() as sess: - model = models.from_json(args.architecture, - checkpoint_path=args.checkpoint_dir) - model.load() - y = model.predict(x_data, batch_size=args.batch_size, verbose=True) - wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none", - sampwidth=sampwidth) - - if args.target_file is not None and os.path.isfile(args.target_file): - t = wavio.read(args.target_file) - t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1) - plt.figure() - plt.plot(x_data) - plt.plot(t_data) - plt.plot(y) - plt.legend(["Input", "Target", "Prediction"]) - plt.show() - -\ No newline at end of file diff --git a/requirements.txt b/requirements.txt @@ -1,2 +1,18 @@ -tensorflow +# File: requirements.txt +# Created Date: 2021-01-24 +# Author: Steven Atkinson ([email protected]) + +black +flake8 +matplotlib +numpy +pip +pre-commit +pytest +pytorch_lightning +scipy +sounddevice +torch +tqdm wavio +wheel diff --git a/setup.py b/setup.py @@ -0,0 +1,31 @@ +# File: setup.py +# Created Date: 2020-04-08 +# Author: Steven Atkinson ([email protected]) + +from distutils.util import convert_path +from setuptools import setup, find_packages + +main_ns = {} +ver_path = convert_path("nam/_version.py") +with open(ver_path) as ver_file: + exec(ver_file.read(), main_ns) + +requirements = [] # torch... + +try: + import torch # noqa F401 +except ImportError as e: + raise ImportError( + f"PyTorch not found. 
Please install it as needed.\nOriginal error: {e}" + ) + +setup( + name="nam", + version=main_ns["__version__"], + description="Neural amp modeler", + author="Steven Atkinson", + author_email="[email protected]", + url="https://github.com/sdatkinson/", + install_requires=requirements, + packages=find_packages(), +) diff --git a/tests/test_install.py b/tests/test_install.py @@ -0,0 +1,13 @@ +# File: test_install.py +# File Created: Tuesday, 2nd February 2021 9:46:01 pm +# Author: Steven Atkinson ([email protected]) + +import pytest + + +def test_torch(): + import torch # noqa F401 + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_nam/__init__.py b/tests/test_nam/__init__.py diff --git a/tests/test_nam/test_importable.py b/tests/test_nam/test_importable.py @@ -0,0 +1,13 @@ +# File: test_importable.py +# Created Date: Sunday December 12th 2021 +# Author: Steven Atkinson ([email protected]) + +import pytest + + +def test_importable(): + import nam # noqa F401 + + +if __name__ == "__main__": + pytest.main() diff --git a/train.py b/train.py @@ -1,198 +0,0 @@ -# File: train.py -# # File Created: Sunday, 30th December 2018 2:08:54 pm -# Author: Steven Atkinson ([email protected]) - -""" -Here's a script for training new models. -""" - -from argparse import ArgumentParser -import numpy as np -import abc -import tensorflow as tf -import matplotlib.pyplot as plt -from time import time -import wavio -import json -import os - -import models - -# Parameters for training -check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5]) - for pwr in np.arange(0, 7)]) -plot_kwargs = { - "window": (30000, 40000) -} -wav_kwargs = {"window": (0, 5 * 44100)} - - -class Data(object): - """ - Object for holding data and spitting out minibatches for training/segments - for testing. - """ - def __init__(self, fname, input_length, batch_size=None): - xy = np.load(fname) - self.x = xy[0] - self.y = xy[1] - self._n = self.x.size - input_length + 1 - self.input_length = input_length - self.batch_size = batch_size - - @property - def n(self): - return self._n - - def minibatch(self, n=None, ilist=None): - """ - Pull a random minibatch out of the data set - """ - if ilist is None: - n = n or self.batch_size - ilist = np.random.randint(0, self.n, size=(n,)) - x = np.stack([self.x[i: i + self.input_length] for i in ilist]) - y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \ - [:, np.newaxis] - return x, y - - def sequence(self, start, end): - end += self.input_length - 1 - return self.x[start: end], self.y[start + self.input_length - 1: end] - - -def train(model, train_data, batch_size=None, n_minibatches=10, - validation_data=None, plot_at=(), wav_at=(), plot_kwargs={}, - wav_kwargs={}, save_every=100, validate_every=100): - save_dir = os.path.dirname(model.checkpoint_path) - sess = model.sess - opt = tf.train.AdamOptimizer().minimize(model.loss) - sess.run(tf.global_variables_initializer()) # For opt - t_loss_list, v_loss_list = [], [] - t0 = time() - for i in range(n_minibatches): - x, y = train_data.minibatch(batch_size) - t_loss, _ = sess.run((model.rmse, opt), - feed_dict={model.x: x, model.target: y}) - t_loss_list.append([i, t_loss]) - print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0), - i + 1, n_minibatches, t_loss)) - - # Callbacks, basically... 
- if i + 1 in plot_at: - plot_predictions(model, - validation_data if validation_data is not None else train_data, - title="Minibatch {}".format(i + 1), - fname="{}/mb_{}.png".format(save_dir, i + 1), - **plot_kwargs) - if i + 1 in wav_at: - print("Making wav for mb {}".format(i + 1)) - predict(model, validation_data, - save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1), - **wav_kwargs) - if (i + 1) % save_every == 0: - model.save(iter=i + 1) - if i == 0 or (i + 1) % validate_every == 0: - v_loss, _ = sess.run((model.rmse, opt), - feed_dict={model.x: x, model.target: y}) - print("VLoss={:8}".format(v_loss)) - v_loss_list.append([i, v_loss]) - - # After training loop... - if validation_data is not None: - x, y = validation_data.minibatch(train_data.batch_size) - v_loss = sess.run(model.rmse, - feed_dict={model.x: x, model.target: y}) - print("Validation loss={}".format(v_loss)) - return np.array(t_loss_list).T, np.array(v_loss_list).T - - -def plot_predictions(model, data, title=None, fname=None, window=None): - x, y, t = predict(model, data, window=window) - plt.figure(figsize=(12, 4)) - plt.plot(x) - plt.plot(t) - plt.plot(y) - plt.legend(('Input', 'Target', 'Prediction')) - if title is not None: - plt.title(title) - if fname is not None: - print("Saving to {}...".format(fname)) - plt.savefig(fname) - plt.close() - else: - plt.show() - - -def plot_loss(t_loss, v_loss, fname): - plt.figure() - plt.loglog(t_loss[0], t_loss[1]) - plt.loglog(v_loss[0], v_loss[1]) - plt.xlabel("Minibatch") - plt.ylabel("RMSE") - plt.legend(("Training", "Validation")) - plt.savefig(fname) - plt.close() - - -def predict(model, data, window=None, save_wav_file=None): - x, t = data.x, data.y - if window is not None: - x, t = x[window[0]: window[1]], t[window[0]: window[1]] - y = model.predict(x).flatten() - - if save_wav_file is not None: - rate = 44100 # TODO from data - sampwidth = 3 # 24-bit - wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none", - sampwidth=sampwidth) - - return x, y, t - - -def _get_input_length(archfile): - return json.load(open(archfile, "r"))["input_length"] - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("model_arch", type=str, - help="JSON containing model architecture") - parser.add_argument("train_data", type=str, - help="Filename for training data") - parser.add_argument("validation_data", type=str, - help="Filename for validation data") - parser.add_argument("--save_dir", type=str, default=None, - help="Where to save the run data (checkpoints, prediction...)") - parser.add_argument("--batch_size", type=str, default=4096, - help="Number of data per minibatch") - parser.add_argument("--minibatches", type=int, default=10, - help="Number of minibatches to train for") - args = parser.parse_args() - input_length = _get_input_length(args.model_arch) # Ugh, kludge - - # Load the data - train_data = Data(args.train_data, input_length, - batch_size=args.batch_size) - validate_data = Data(args.validation_data, input_length) - - # Training - with tf.Session() as sess: - model = models.from_json(args.model_arch, train_data.n, - checkpoint_path=os.path.join(args.save_dir, "model.ckpt")) - t_loss_list, v_loss_list = train( - model, - train_data, - validation_data=validate_data, - n_minibatches=args.minibatches, - plot_at=check_training_at, - wav_at=check_training_at, - plot_kwargs=plot_kwargs, - wav_kwargs=wav_kwargs) - plot_predictions(model, validate_data, window=(0, 44100)) - print("Predict the full output") - predict(model, validate_data, - 
save_wav_file="{}/predict.wav".format( - os.path.dirname(model.checkpoint_path))) - - plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))