commit 2efabe4a5d04658bb93dc2679f31a75d4c17527f
parent 1c2565b173a3b49150c890e15c0d20f7f21367be
Author: Steven Atkinson <[email protected]>
Date: Wed, 20 Apr 2022 23:21:05 -0700
Version 0.2.0
Diffstat:
35 files changed, 1307 insertions(+), 455 deletions(-)
diff --git a/.gitignore b/.gitignore
@@ -20,6 +20,8 @@ parts/
sdist/
var/
wheels/
+pip-wheel-metadata/
+share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
@@ -38,12 +40,14 @@ pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
+.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
+*.py,cover
.hypothesis/
.pytest_cache/
@@ -55,6 +59,7 @@ coverage.xml
*.log
local_settings.py
db.sqlite3
+db.sqlite3-journal
# Flask stuff:
instance/
@@ -72,11 +77,26 @@ target/
# Jupyter Notebook
.ipynb_checkpoints
+# IPython
+profile_default/
+ipython_config.py
+
# pyenv
.python-version
-# celery beat schedule file
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
celerybeat-schedule
+celerybeat.pid
# SageMath parsed files
*.sage.py
@@ -89,6 +109,8 @@ venv/
ENV/
env.bak/
venv.bak/
+venv-*/
+tmp/
# Spyder project settings
.spyderproject
@@ -102,3 +124,13 @@ venv.bak/
# mypy
.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# etc
+data/
+.vscode/
+lightning_logs/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+ - repo: https://github.com/psf/black
+ rev: stable
+ hooks:
+ - id: black
+ language_version: python3.7
+
+\ No newline at end of file
diff --git a/LICENSE b/LICENSE
@@ -1,6 +1,6 @@
MIT License
-Copyright (c) 2019 Steven Atkinson
+Copyright (c) 2022 Steven Atkinson
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
@@ -1,24 +1,84 @@
# NAM: neural amp modeler
-Let's use deep learning to make a model of a guitar amp!
+This is the training part of NAM.
+For the code to create the plugin with a trained model, see my
+[iPlug2 fork](https://github.com/sdatkinson/iPlug2).
-## Setup
+## How to use
-Use `pip install -r requirements.txt` as well as intalling the TensorFlow version of your choice (e.g. `pip install tensorflow-gpu`).
+### Train a model
-## Autoregressive models
+You'll need at least two mono wav files: the input (DI) and the amped sound (without the cab).
+You can either record enough to have a training and validation set in the same file and
+split the file, or you can use 4 files (input/output for train/test).
+Also, you can provide _multiple_ file pairs for training (or validation).
-* Fully-connected (n-to-1)
+For the first option, modify `bin/train/inputs/config_data_single_pair.json` to point at the
+audio files, and set the start/stop to the point (in samples) where the training segment
+ends and the validation segment starts.
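+For instance, if the validation segment starts 12 minutes and 42 seconds into the file,
+the split point in samples (at the required 48 kHz rate) is:
+
+```python
+split = (12 * 60 + 42) * 48_000 # 36_576_000, as in the example config
+```
+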
+For the second option, modify and use `bin/train/inputs/config_data_two_pairs.json`.
-## To do
+Then run:
-* Other models
-* Real-time (eventually)
-* Model stereo outputs
-* Data classes for .wav data, keeping sample rate, etc
-* spek.cc?
+```bash
+python bin/train/main.py \
+bin/train/inputs/config_data_single_pair.json \
+bin/train/inputs/config_model.json \
+bin/train/inputs/config_learning.json \
+bin/train/outputs/MyAmp
+```
+
+### Run a model on an input signal ("reamping")
+
+Handy if you want to just check it out without going through the trouble of building the
+plugin.
+
+```bash
+python bin/run.py \
+path/to/source.wav \
+path/to/config_model.json \
+path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
+path/to/output.wav
+```
+
+### Export a model (to use with [the plugin](https://github.com/sdatkinson/iPlug2))
+
+Let's get ready to rock!
+
+```bash
+python bin/export.py \
+path/to/config_model.json \
+path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
+path/to/exported_models/MyAmp
+```
+
+You'll want the generated `HardCodedModel.h`; paste it into the plugin source.
+
+## Advanced usage
+
+The model architecture in `config_model.json` should work well. However, if you want,
+you can increase the number of channels and the model will generally fit better (though
+it'll get closer to the threshold of real-time; 20 channels works for a "large" model
+and is still about 6x real time on my desktop).
+
+If you want to mess with the model architecture and end up with a different receptive
+field (e.g. by messing with the dilation pattern), then you need to make sure that `nx`
+is changed accordingly in the data setup.
+The default architecture has a receptive field of 8191 samples, so `nx` is `8191`.
+Generally, for the conv net architecture, the receptive field is one more than the sum of the `dilations`.
+
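+For example, a quick sketch (using the receptive-field rule above) to compute the `nx`
+that matches a given dilation pattern:
+
+```python
+dilations = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] * 2 # default pattern
+nx = sum(dilations) + 1 # receptive field; 8191 for this default
+print(nx)
+```
+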
+You can train for shorter or longer.
+1000 epochs gives pretty great results, but if you're impatient you can sometimes get
+away with comparable results after 500 epochs, and you might not even be able to tell
+the difference with far fewer (maybe 200?...100?).
+
+## To do
+
+- Other models
+- Real-time (eventually)
+- Model stereo outputs
+- Data classes for .wav data, keeping sample rate, etc
+- spek.cc?
## Other interesting tidbits?
-* Audio analysis plugins: vamp-plugins.org
-* Spectrogram: sonicvisualiser.org
+- Audio analysis plugins: vamp-plugins.org
+- Spectrogram: sonicvisualiser.org
diff --git a/architectures/fc_32_32_32.json b/architectures/fc_32_32_32.json
@@ -1,8 +0,0 @@
-{
- "type": "FullyConnected",
- "input_length": 32,
- "layer_sizes": [
- 32,
- 32
- ]
-}
-\ No newline at end of file
diff --git a/bin/export.py b/bin/export.py
@@ -0,0 +1,41 @@
+# File: export.py
+# Created Date: Sunday February 6th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Export a model to the plugin-compatible format (config.json, weights.npy, and a HardCodedModel.h header)
+"""
+
+import json
+from argparse import ArgumentParser
+from pathlib import Path
+
+import torch
+
+from nam.models import Model
+
+
+class Dummy(torch.nn.Module):
+ def forward(self, x):
+ return x[8191:]
+
+
+def main(args):
+ with open(args.model_config_path, "r") as fp:
+ net = Model.load_from_checkpoint(
+ args.checkpoint, **Model.parse_config(json.load(fp))
+ ).net
+ net.eval()
+ outdir = Path(args.outdir)
+ outdir.mkdir(parents=True, exist_ok=True)
+ net.export(outdir)
+ net.export_cpp_header(Path(outdir, "HardCodedModel.h"))
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("model_config_path", type=str)
+ parser.add_argument("checkpoint", type=str)
+ parser.add_argument("outdir")
+ parser.add_argument("--source-length", type=int, default=16384)
+ main(parser.parse_args())
diff --git a/bin/run.py b/bin/run.py
@@ -0,0 +1,33 @@
+# File: run.py
+# Created Date: Sunday February 6th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Load up a model, process a WAV, and save.
+"""
+
+import json
+from argparse import ArgumentParser
+
+from nam.data import wav_to_tensor, tensor_to_wav
+from nam.models import Model
+
+
+def main(args):
+ source = wav_to_tensor(args.source_path)
+ with open(args.model_config_path, "r") as fp:
+ model = Model.load_from_checkpoint(
+ args.checkpoint, **Model.parse_config(json.load(fp))
+ )
+ model.eval()
+ output = model(source)
+ tensor_to_wav(output, args.outfile)
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("source_path", type=str)
+ parser.add_argument("model_config_path", type=str)
+ parser.add_argument("checkpoint", type=str)
+ parser.add_argument("outfile")
+ main(parser.parse_args())
diff --git a/bin/train/inputs/config_data_single_pair.json b/bin/train/inputs/config_data_single_pair.json
@@ -0,0 +1,18 @@
+{
+ "train": {
+ "start": null,
+ "stop": 36576000,
+ "ny": 1024
+ },
+ "validation": {
+ "start": 36576000,
+ "stop": null,
+ "ny": null
+ },
+ "common": {
+ "x_path": "C:\\path\\to\\source.wav",
+ "y_path": "C:\\path\\to\\target.wav",
+ "nx": 8191,
+ "delay": 0
+ }
+}
+\ No newline at end of file
diff --git a/bin/train/inputs/config_data_two_pairs.json b/bin/train/inputs/config_data_two_pairs.json
@@ -0,0 +1,16 @@
+{
+ "train": {
+ "x_path": "C:\\path\\to\\train\\source.wav",
+ "y_path": "C:\\path\\to\\train\\target.wav",
+ "ny": 1024
+ },
+ "validation": {
+ "x_path": "C:\\path\\to\\validation\\source.wav",
+ "y_path": "C:\\path\\to\\validation\\target.wav",
+ "ny": null
+ },
+ "common": {
+ "nx": 8191,
+ "delay": 0
+ }
+}
+\ No newline at end of file
diff --git a/bin/train/inputs/config_learning.json b/bin/train/inputs/config_learning.json
@@ -0,0 +1,17 @@
+{
+ "train_dataloader": {
+ "batch_size": 32,
+ "shuffle": true,
+ "pin_memory": true,
+ "drop_last": true,
+ "num_workers": 4
+ },
+ "val_dataloader": {
+ },
+ "trainer": {
+ "gpus": 1,
+ "max_epochs": 1000
+ },
+ "trainer_fit_kwargs": {
+ }
+}
+\ No newline at end of file
diff --git a/bin/train/inputs/config_model.json b/bin/train/inputs/config_model.json
@@ -0,0 +1,25 @@
+{
+ "net": {
+ "name": "WaveNet",
+ "config": {
+ "channels": 16,
+ "dilations": [1,2,4,8,16,32,64,128,256,512,1024,2048,1,2,4,8,16,32,64,128,256,512,1024,2048],
+ "batchnorm": true,
+ "activation": "Tanh"
+ }
+ },
+ "optimizer": {
+ "lr": 0.003
+ },
+ "lr_scheduler": {
+ "class": "ReduceLROnPlateau",
+ "kwargs": {
+ "factor": 0.5,
+ "patience": 50,
+ "cooldown": 50,
+ "min_lr": 1.0e-5,
+ "verbose": true
+ },
+ "monitor": "val_loss"
+ }
+}
+\ No newline at end of file
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -0,0 +1,138 @@
+# File: train.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+import json
+from argparse import ArgumentParser
+from datetime import datetime
+from pathlib import Path
+from time import time
+from typing import Optional, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import DataLoader
+
+from nam.data import Split, init_dataset
+from nam.models import Model
+
+torch.manual_seed(0)
+
+
+def timestamp() -> str:
+ t = datetime.now()
+ return f"{t.year:04d}-{t.month:02d}-{t.day:02d}-{t.hour:02d}-{t.minute:02d}-{t.second:02d}"
+
+
+def ensure_outdir(outdir: str) -> Path:
+ outdir = Path(outdir, timestamp())
+ outdir.mkdir(parents=True, exist_ok=False)
+ return outdir
+
+
+def _rms(x: Union[np.ndarray, torch.Tensor]) -> float:
+ if isinstance(x, np.ndarray):
+ return np.sqrt(np.mean(np.square(x)))
+ elif isinstance(x, torch.Tensor):
+ return torch.sqrt(torch.mean(torch.square(x))).item()
+ else:
+ raise TypeError(type(x))
+
+
+def plot(
+ model,
+ ds,
+ savefig=None,
+ show=True,
+ window_start: Optional[int] = None,
+ window_end: Optional[int] = None,
+):
+ with torch.no_grad():
+ tx = len(ds.x) / 48_000
+ print(f"Run (t={tx})")
+ t0 = time()
+ output = model(ds.x).flatten().cpu().numpy()
+ t1 = time()
+ print(f"Took {t1 - t0} ({tx / (t1 - t0):.2f}x)")
+
+ plt.figure(figsize=(16, 5))
+ plt.plot(ds.x[window_start:window_end], label="Input")
+ plt.plot(output[window_start:window_end], label="Output")
+ plt.plot(ds.y[window_start:window_end], label="Target")
+ plt.title(f"NRMSE={_rms(torch.Tensor(output) - ds.y) / _rms(ds.y)}")
+ plt.legend()
+ if savefig is not None:
+ plt.savefig(savefig)
+ if show:
+ plt.show()
+
+
+def main(args):
+ outdir = ensure_outdir(args.outdir)
+ # Read
+ with open(args.data_config_path, "r") as fp:
+ data_config = json.load(fp)
+ with open(args.model_config_path, "r") as fp:
+ model_config = json.load(fp)
+ with open(args.learning_config_path, "r") as fp:
+ learning_config = json.load(fp)
+ # Write
+ for basename, config in (
+ ("data", data_config),
+ ("model", model_config),
+ ("learning", learning_config),
+ ):
+ with open(Path(outdir, f"config_{basename}.json"), "w") as fp:
+ json.dump(config, fp, indent=4)
+
+ model = Model.init_from_config(model_config)
+
+ dataset_train = init_dataset(data_config, Split.TRAIN)
+ dataset_validation = init_dataset(data_config, Split.VALIDATION)
+ train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
+ val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+
+ # ckpt_path = Path(outdir, "checkpoints")
+ # ckpt_path.mkdir()
+ trainer = pl.Trainer(
+ callbacks=[
+ pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="{epoch}_{val_loss:.6f}",
+ save_top_k=3,
+ monitor="val_loss",
+ every_n_epochs=1,
+ ),
+ pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_last_{epoch:04d}", every_n_epochs=1
+ ),
+ ],
+ default_root_dir=outdir,
+ **learning_config["trainer"],
+ )
+ trainer.fit(
+ model,
+ train_dataloader,
+ val_dataloader,
+ **learning_config.get("trainer_fit_kwargs", {}),
+ )
+ model.eval()
+ plot(
+ model,
+ dataset_validation,
+ savefig=Path(outdir, "comparison.png"),
+ window_start=100_000,
+ window_end=110_000,
+ show=False,
+ )
+ plot(model, dataset_validation)
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("data_config_path", type=str)
+ parser.add_argument("model_config_path", type=str)
+ parser.add_argument("learning_config_path", type=str)
+ parser.add_argument("outdir")
+ main(parser.parse_args())
diff --git a/data/emissary/train.npy b/data/emissary/train.npy
Binary files differ.
diff --git a/data/emissary/validate.npy b/data/emissary/validate.npy
Binary files differ.
diff --git a/data/reaper-js-distortion/train.npy b/data/reaper-js-distortion/train.npy
Binary files differ.
diff --git a/data/reaper-js-distortion/validate.npy b/data/reaper-js-distortion/validate.npy
Binary files differ.
diff --git a/environment.yml b/environment.yml
@@ -0,0 +1,25 @@
+# File: environment.yml
+# Created Date: Saturday February 13th 2021
+# Author: Steven Atkinson ([email protected])
+
+name: nam2
+channels:
+ - pytorch
+dependencies:
+ - black
+ - flake8
+ - h5py
+ - jupyter
+ - matplotlib
+ - numpy
+ - pip
+ - pytest
+ - pytorch
+ - scipy
+ - tqdm
+ - wheel
+ - pip:
+ - pre-commit
+ - pytorch_lightning
+ - sounddevice
+ - wavio
diff --git a/models.py b/models.py
@@ -1,165 +0,0 @@
-# File: models.py
-# File Created: Sunday, 30th December 2018 9:42:29 pm
-# Author: Steven Atkinson ([email protected])
-
-import numpy as np
-import tensorflow as tf
-import abc
-from tempfile import mkdtemp
-import os
-import json
-
-
-def from_json(f, n_train=1, checkpoint_path=None):
- if isinstance(f, str):
- with open(f, "r") as json_file:
- f = json.load(json_file)
-
- if f["type"] == "FullyConnected":
- return FullyConnected(n_train, f["input_length"],
- layer_sizes=f["layer_sizes"], checkpoint_path=checkpoint_path)
- else:
- raise NotImplementedError("Model type {} unrecognized".format(
- f["type"]))
-
-
-class Model(object):
- """
- Model parent class
- """
- def __init__(self, n_train, sess=None, checkpoint_path=None):
- """
- Make sure child classes call _build() after this!
- """
- if sess is None:
- sess = tf.get_default_session()
- self.sess = sess
-
- if checkpoint_path is None:
- checkpoint_path = os.path.join(mkdtemp(), "model.ckpt")
- if not os.path.isdir(os.path.dirname(checkpoint_path)):
- os.makedirs(os.path.dirname(checkpoint_path))
- self.checkpoint_path = checkpoint_path
-
- # self._batch_size = batch_size
- self._n_train = n_train
- self.target = tf.placeholder(tf.float32, shape=(None, 1))
- self.prediction = None
- self.total_prediction_loss = None
- self.rmse = None
- self.loss = None # Includes any regularization
-
- # @property
- # def batch_size(self):
- # return self._batch_size
-
- @property
- def n_train(self):
- return self._n_train
-
- def load(self, checkpoint_path=None):
- checkpoint_path = checkpoint_path or self.checkpoint_path
-
- try:
- ckpt = tf.train.get_checkpoint_state(checkpoint_path)
- print("Loading model: {}".format(ckpt.model_checkpoint_path))
- self.saver.restore(self.sess, ckpt.model_checkpoint_path)
- except Exception as e:
- print("Error while attempting to load model: {}".format(e))
-
- @abc.abstractclassmethod
- def predict(self, x):
- """
- A nice function for prediction.
- :param x: input array (length=n)
- :type x: array-like
- :return: (array-like) corresponding predicted outputs (length=n)
- """
- raise NotImplementedError("Implement predict()")
-
- def save(self, iter, checkpoint_path=None):
- checkpoint_path = checkpoint_path or self.checkpoint_path
- self.saver.save(self.sess, checkpoint_path, global_step=iter)
-
- def _build(self):
- self.prediction = self._build_prediction()
- self.loss = self._build_loss()
-
- # Launch the session
- self.sess.run(tf.global_variables_initializer())
- self.saver = tf.train.Saver(tf.global_variables())
-
- def _build_loss(self):
- self.total_prediction_loss = tf.losses.mean_squared_error(self.target,
- self.prediction, weights=self.n_train)
-
- # Don't count this as a loss!
- self.rmse = tf.sqrt(
- self.total_prediction_loss / self.n_train)
-
- return tf.losses.get_total_loss()
-
- @abc.abstractclassmethod
- def _build_prediction(self):
- raise NotImplementedError('Implement prediction for model')
-
-
-class Autoregressive(Model):
- """
- Autoregressive models that take in a few of the most recent input samples
- and predict the output at the last time point.
- """
- def __init__(self, n_train, input_length, sess=None,
- checkpoint_path=None):
- super().__init__(n_train, sess=sess, checkpoint_path=checkpoint_path)
- self._input_length = input_length
- self.x = tf.placeholder(tf.float32, shape=(None, self.input_length))
-
- @property
- def input_length(self):
- return self._input_length
-
- def predict(self, x, batch_size=None, verbose=False):
- """
- Return 1D array of predictions same length as x
- """
- n = x.size
- batch_size = batch_size or n
- # Pad x with leading zeros:
- x = np.concatenate((np.zeros(self.input_length - 1), x))
- i = 0
- y = []
- while i < n:
- if verbose:
- print("model.predict {}/{}".format(i, n))
- this_batch_size = np.min([batch_size, n - i])
- # Reshape into a batch:
- x_mtx = np.stack([x[j: j + self.input_length]
- for j in range(i, i + this_batch_size)])
- # Predict and flatten.
- y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
- .flatten())
- i += this_batch_size
- return np.concatenate(y)
-
-
-class FullyConnected(Autoregressive):
- """
- Autoregressive model taking in a sequence of the most recent inputs, putting
- them through a series of FC layers, and outputting the single output at the
- last time step.
- """
- def __init__(self, n_train, input_length, layer_sizes=(512,),
- sess=None, checkpoint_path=None):
- super().__init__(n_train, input_length, sess=sess,
- checkpoint_path=checkpoint_path)
- self._layer_sizes = layer_sizes
- self._build()
-
- def _build_prediction(self):
- h = self.x
- for m in self._layer_sizes:
- h = tf.contrib.layers.fully_connected(h, m)
- y = -1.0 + 2.0 * tf.contrib.layers.fully_connected(h, 1,
- activation_fn=tf.nn.sigmoid)
- return y
diff --git a/nam/__init__.py b/nam/__init__.py
@@ -0,0 +1,9 @@
+# File: __init__.py
+# File Created: Tuesday, 2nd February 2021 9:42:50 pm
+# Author: Steven Atkinson ([email protected])
+
+from ._version import __version__ # Must be before models or else circular
+
+from . import _core # noqa F401
+from . import data # noqa F401
+from . import models # noqa F401
diff --git a/nam/_core.py b/nam/_core.py
@@ -0,0 +1,13 @@
+# File: core.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+
+class InitializableFromConfig(object):
+ @classmethod
+ def init_from_config(cls, config):
+ return cls(**cls.parse_config(config))
+
+ @classmethod
+ def parse_config(cls, config):
+ return config
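+
+
+# Example (illustrative, not part of the library): any class mixing this in
+# can be constructed straight from a config dict:
+#
+# class Gain(InitializableFromConfig):
+# def __init__(self, scale: float):
+# self.scale = scale
+#
+# Gain.init_from_config({"scale": 0.5}) # equivalent to Gain(scale=0.5)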
diff --git a/nam/_version.py b/nam/_version.py
@@ -0,0 +1 @@
+__version__ = "0.2.0"
diff --git a/nam/data.py b/nam/data.py
@@ -0,0 +1,221 @@
+# File: data.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import wavio
+from torch.utils.data import Dataset as _Dataset
+
+from ._core import InitializableFromConfig
+
+_REQUIRED_SAMPWIDTH = 3
+_REQUIRED_RATE = 48_000
+_REQUIRED_CHANNELS = 1 # Mono
+
+
+class Split(Enum):
+ TRAIN = "train"
+ VALIDATION = "validation"
+
+
+@dataclass
+class WavInfo:
+ sampwidth: int
+ rate: int
+
+
+def wav_to_np(
+ filename: Union[str, Path],
+ require_match: Optional[Union[str, Path]] = None,
+ required_shape: Optional[Tuple[int]] = None,
+ required_wavinfo: Optional[WavInfo] = None,
+ preroll: Optional[int] = None,
+ info: bool = False,
+) -> Union[np.ndarray, Tuple[np.ndarray, WavInfo]]:
+ """
+ :param preroll: Drop this many samples off the front
+ """
+ x_wav = wavio.read(str(filename))
+ assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono"
+ assert x_wav.sampwidth == _REQUIRED_SAMPWIDTH, "24-bit"
+ assert x_wav.rate == _REQUIRED_RATE, "48 kHz"
+
+ if require_match is not None:
+ assert required_shape is None
+ assert required_wavinfo is None
+ y_wav = wavio.read(str(require_match))
+ required_shape = y_wav.data.shape
+ required_wavinfo = WavInfo(y_wav.sampwidth, y_wav.rate)
+ if required_wavinfo is not None:
+ if x_wav.rate != required_wavinfo.rate:
+ raise ValueError(
+ f"Mismatched rates {x_wav.rate} versus {required_wavinfo.rate}"
+ )
+ arr_premono = x_wav.data[preroll:] / (2.0 ** (8 * x_wav.sampwidth - 1))
+ if required_shape is not None:
+ if arr_premono.shape != required_shape:
+ raise ValueError(
+ f"Mismatched shapes {arr_premono.shape} versus {required_shape}"
+ )
+ # sampwidth fine--we're just casting to 32-bit float anyways
+ arr = arr_premono[:, 0]
+ return arr if not info else (arr, WavInfo(x_wav.sampwidth, x_wav.rate))
+
+
+def wav_to_tensor(
+ *args, info: bool = False, **kwargs
+) -> Union[torch.Tensor, Tuple[torch.Tensor, WavInfo]]:
+ out = wav_to_np(*args, info=info, **kwargs)
+ if info:
+ arr, info = out
+ return torch.Tensor(arr), info
+ else:
+ arr = out
+ return torch.Tensor(arr)
+
+
+def tensor_to_wav(
+ x: torch.Tensor,
+ filename: Union[str, Path],
+ rate: int = 48_000,
+ sampwidth: int = 3,
+ scale="none",
+):
+ wavio.write(
+ filename,
+ (torch.clamp(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1)))
+ .detach()
+ .cpu()
+ .numpy()
+ .astype(np.int32),
+ rate,
+ scale=scale,
+ sampwidth=sampwidth,
+ )
+
+
+class AbstractDataset(_Dataset, abc.ABC):
+ @abc.abstractmethod
+ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
+ pass
+
+
+class Dataset(AbstractDataset, InitializableFromConfig):
+ """
+ Take a pair of matched audio files and serve input + output pairs
+ """
+
+ def __init__(
+ self,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ nx: int,
+ ny: Optional[int],
+ start: Optional[int] = None,
+ stop: Optional[int] = None,
+ delay: Optional[int] = None,
+ ):
+ """
+ :param start: In samples
+ :param stop: In samples
+ :param delay: In samples. Positive means we get rid of the start of x, end of y.
+ """
+ x, y = [z[start:stop] for z in (x, y)]
+ if delay is not None:
+ if delay > 0:
+ x = x[:-delay]
+ y = y[delay:]
+ else:
+ x = x[-delay:]
+ y = y[:delay]
+ self._validate_inputs(x, y, nx, ny)
+ self._x = x
+ self._y = y
+ self._nx = nx
+ self._ny = ny if ny is not None else len(x) - nx + 1
+
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ if idx >= len(self):
+ raise IndexError(f"Attempted to access datum {idx}, but len is {len(self)}")
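+ # Worked example (hypothetical sizes): with nx=3, ny=2, idx=1 -> i=2, j=4,
+ # so we serve x[2:6] (nx + ny - 1 = 4 input samples) and y[4:6] (ny = 2
+ # target samples).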
+ i = idx * self._ny
+ j = i + self.y_offset
+ return self.x[i : i + self._nx + self._ny - 1], self.y[j : j + self._ny]
+
+ def __len__(self) -> int:
+ n = len(self.x)
+ # If ny were 1
+ single_pairs = n - self._nx + 1
+ return single_pairs // self._ny
+
+ @property
+ def x(self):
+ return self._x
+
+ @property
+ def y(self):
+ return self._y
+
+ @property
+ def y_offset(self) -> int:
+ return self._nx - 1
+
+ @classmethod
+ def parse_config(cls, config):
+ x, x_wavinfo = wav_to_tensor(config["x_path"], info=True)
+ y = wav_to_tensor(
+ config["y_path"],
+ preroll=config.get("y_preroll"),
+ required_shape=(len(x), 1),
+ required_wavinfo=x_wavinfo,
+ )
+ return {
+ "x": x,
+ "y": y,
+ "nx": config["nx"],
+ "ny": config["ny"],
+ "start": config.get("start"),
+ "stop": config.get("stop"),
+ }
+
+ def _validate_inputs(self, x, y, nx, ny):
+ assert x.ndim == 1
+ assert y.ndim == 1
+ assert len(x) == len(y)
+ assert nx <= len(x)
+ if ny is not None:
+ assert ny <= len(y) - nx + 1
+
+
+class ConcatDataset(AbstractDataset, InitializableFromConfig):
+ def __init__(self, datasets: Sequence[Dataset]):
+ self._datasets = datasets
+
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
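+ # Walk the datasets in order; e.g. if len(d0) == 3, then idx == 4 resolves
+ # to d1[1].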
+ for d in self._datasets:
+ if idx < len(d):
+ return d[idx]
+ else:
+ idx = idx - len(d)
+
+ def __len__(self) -> int:
+ return sum(len(d) for d in self._datasets)
+
+ @classmethod
+ def parse_config(cls, config):
+ return {"datasets": tuple(Dataset.init_from_config(c) for c in config)}
+
+
+def init_dataset(config, split: Split) -> AbstractDataset:
+ base_config = config[split.value]
+ common = config.get("common", {})
+ if isinstance(base_config, dict):
+ return Dataset.init_from_config({**common, **base_config})
+ elif isinstance(base_config, list):
+ return ConcatDataset.init_from_config([{**common, **c} for c in base_config])
diff --git a/nam/models/__init__.py b/nam/models/__init__.py
@@ -0,0 +1,9 @@
+# File: __init__.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+from . import _base # noqa F401
+from . import _exportable # noqa F401
+from .base import Model # noqa F401
+from .linear import Linear # noqa F401
+from .conv_net import ConvNet # noqa F401
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -0,0 +1,58 @@
+# File: _base.py
+# Created Date: Tuesday February 8th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .._core import InitializableFromConfig
+from ._exportable import Exportable
+
+
+class BaseNet(nn.Module, InitializableFromConfig, Exportable):
+ @abc.abstractproperty
+ def pad_start_default(self) -> bool:
+ pass
+
+ @abc.abstractproperty
+ def receptive_field(self) -> int:
+ """
+ Receptive field of the model
+ """
+ pass
+
+ def forward(self, x: torch.Tensor, pad_start: bool = None):
+ pad_start = self.pad_start_default if pad_start is None else pad_start
+ scalar = x.ndim == 1
+ if scalar:
+ x = x[None]
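+ # Prepending (receptive_field - 1) zeros makes the output the same length
+ # as the unpadded input.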
+ if pad_start:
+ x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
+ y = self._forward(x)
+ if scalar:
+ y = y[0]
+ return y
+
+ @abc.abstractmethod
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ The true forward method.
+
+ :param x: (N,L1)
+ :return: (N,L1-RF+1)
+ """
+ pass
+
+ def _test_signal(
+ self, seed=0, extra_length=13
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = torch.Tensor(
+ np.random.default_rng(seed).normal(
+ size=(self.receptive_field + extra_length,)
+ )
+ )
+ return x, self(x, pad_start=False)
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -0,0 +1,33 @@
+# File: _exportable.py
+# Created Date: Tuesday February 8th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+from pathlib import Path
+
+
+class Exportable(abc.ABC):
+ """
+ Interface for my custom export format for use in the plugin.
+ """
+
+ @abc.abstractmethod
+ def export(self, outdir: Path):
+ """
+ Interface for exporting.
+ You should create at least a `config.json` containing the following fields:
+ * "version" (str)
+ * "architecture" (str)
+ * "config": (dict w/ other necessary data like tensor shapes etc)
+
+ :param outdir: Assumed to exist. Can be edited inside at will.
+ """
+ pass
+
+ @abc.abstractmethod
+ def export_cpp_header(self, filename: Path):
+ """
+ Export a .h file to compile into the plugin with the weights written right out
+ as text
+ """
+ pass
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -0,0 +1,97 @@
+# File: base.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Lightning stuff
+"""
+
+from typing import Optional
+
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+from .._core import InitializableFromConfig
+from .linear import Linear
+from .conv_net import ConvNet
+
+
+class Model(pl.LightningModule, InitializableFromConfig):
+ def __init__(
+ self,
+ net,
+ optimizer_config: Optional[dict] = None,
+ scheduler_config: Optional[dict] = None,
+ ):
+ """
+ :param scheduler_config: contains
+ Required:
+ * "class"
+ * "kwargs"
+ Optional (defaults to Lightning defaults):
+ * "interval" ("epoch" of "step")
+ * "frequency" (int)
+ * "monitor" (str)
+ """
+ super().__init__()
+ self._net = net
+ self._optimizer_config = {} if optimizer_config is None else optimizer_config
+ self._scheduler_config = scheduler_config
+
+ @classmethod
+ def init_from_config(cls, config):
+ checkpoint_path = config.get("checkpoint_path")
+ config = cls.parse_config(config)
+ return (
+ cls(**config)
+ if checkpoint_path is None
+ else cls.load_from_checkpoint(checkpoint_path, **config)
+ )
+
+ @classmethod
+ def parse_config(cls, config):
+ config = super().parse_config(config)
+ net_config = config["net"]
+ net = {"Linear": Linear.init_from_config, "ConvNet": ConvNet.init_from_config}[
+ net_config["name"]
+ ](net_config["config"])
+ return {
+ "net": net,
+ "optimizer_config": config["optimizer"],
+ "scheduler_config": config["lr_scheduler"],
+ }
+
+ @property
+ def net(self) -> nn.Module:
+ return self._net
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
+ if self._scheduler_config is None:
+ return optimizer
+ else:
+ lr_scheduler = getattr(
+ torch.optim.lr_scheduler, self._scheduler_config["class"]
+ )(optimizer, **self._scheduler_config["kwargs"])
+ lr_scheduler_config = {"scheduler": lr_scheduler}
+ for key in ("interval", "frequency", "monitor"):
+ if key in self._scheduler_config:
+ lr_scheduler_config[key] = self._scheduler_config[key]
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+ def forward(self, *args, **kwargs):
+ return self.net(*args, **kwargs)
+
+ def _shared_step(self, batch):
+ sources, targets = batch
+ preds = self(sources, pad_start=False)
+ return nn.MSELoss()(preds, targets)
+
+ def training_step(self, batch, batch_idx):
+ return self._shared_step(batch)
+
+ def validation_step(self, batch, batch_idx):
+ val_loss = self._shared_step(batch)
+ self.log_dict({"val_loss": val_loss})
+ return val_loss
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -0,0 +1,290 @@
+# File: conv_net.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson ([email protected])
+
+import json
+from enum import Enum
+from functools import partial
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from .. import __version__
+from ..data import wav_to_tensor
+from ._base import BaseNet
+
+_CONV_NAME = "conv"
+_BATCHNORM_NAME = "batchnorm"
+_ACTIVATION_NAME = "activation"
+
+
+class TrainStrategy(Enum):
+ STRIDE = "stride"
+ DILATE = "dilate"
+
+
+default_train_strategy = TrainStrategy.DILATE
+
+
+class _Functional(nn.Module):
+ """
+ Define a layer by a function w/ no params
+ """
+
+ def __init__(self, op):
+ super().__init__()
+ self._op = op
+
+ def forward(self, *args, **kwargs):
+ return self._op(*args, **kwargs)
+
+
+class _IR(nn.Module):
+ def __init__(self, filename: Union[str, Path]):
+ super().__init__()
+ self.register_buffer("_weight", reversed(wav_to_tensor(filename))[None, None])
+
+ @property
+ def length(self) -> int:
+ return self._weight.shape[-1]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ :param x: (N,D)
+ :return: (N,D-length+1)
+ """
+ return F.conv1d(x[:, None], self._weight)[:, 0]
+
+
+def _conv_net(
+ channels: int = 32,
+ dilations: Sequence[int] = None,
+ batchnorm: bool = False,
+ activation: str = "Tanh",
+) -> nn.Sequential:
+ def block(cin, cout, dilation):
+ net = nn.Sequential()
+ net.add_module(
+ _CONV_NAME, nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
+ )
+ if batchnorm:
+ net.add_module(_BATCHNORM_NAME, nn.BatchNorm1d(cout))
+ net.add_module(_ACTIVATION_NAME, getattr(nn, activation)())
+ return net
+
+ def check_and_expand(n, x):
+ if x.shape[1] < n:
+ raise ValueError(
+ f"Input of length {x.shape[1]} is shorter than model receptive field ({n})"
+ )
+ return x[:, None, :]
+
+ dilations = [1, 2, 4, 8] if dilations is None else dilations
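+ # Each block's conv has kernel size 2, so each layer adds `dilation` samples
+ # of context; stacking gives a receptive field of sum(dilations) + 1.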
+ receptive_field = sum(dilations) + 1
+ net = nn.Sequential()
+ net.add_module("expand", _Functional(partial(check_and_expand, receptive_field)))
+ cin = 1
+ cout = channels
+ for i, dilation in enumerate(dilations):
+ net.add_module(f"block_{i}", block(cin, cout, dilation))
+ cin = cout
+ net.add_module("head", nn.Conv1d(channels, 1, 1))
+ net.add_module("flatten", nn.Flatten())
+ return net
+
+
+class ConvNet(BaseNet):
+ """
+ A straightforward convolutional neural network.
+
+ Works surprisingly well!
+ """
+
+ def __init__(
+ self,
+ *args,
+ train_strategy: TrainStrategy = default_train_strategy,
+ ir: Optional[_IR] = None,
+ **kwargs,
+ ):
+ super().__init__()
+ self._net = _conv_net(*args, **kwargs)
+ assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported"
+ self._train_strategy = train_strategy
+ self._num_blocks = self._get_num_blocks(self._net)
+ self._pad_start_default = True
+ self._ir = ir
+
+ @classmethod
+ def parse_config(cls, config):
+ config = super().parse_config(config)
+ config["train_strategy"] = TrainStrategy(
+ config.get("train_strategy", default_train_strategy.value)
+ )
+ config["ir"] = (
+ None if "ir_filename" not in config else _IR(config.pop("ir_filename"))
+ )
+ return config
+
+ @property
+ def pad_start_default(self) -> bool:
+ return self._pad_start_default
+
+ @property
+ def receptive_field(self) -> int:
+ net_rf = 1 + sum(
+ self._net._modules[f"block_{i}"]._modules["conv"].dilation[0]
+ for i in range(self._num_blocks)
+ )
+ # Minus 1 because it composes w/ the net
+ ir_rf = 0 if self._ir is None else self._ir.length - 1
+ return net_rf + ir_rf
+
+ @property
+ def _activation(self):
+ return (
+ self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__
+ )
+
+ @property
+ def _channels(self) -> int:
+ return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0]
+
+ @property
+ def _num_layers(self) -> int:
+ return self._num_blocks
+
+ @property
+ def _batchnorm(self) -> bool:
+ return _BATCHNORM_NAME in self._net._modules["block_0"]._modules
+
+ def export(self, outdir: Path):
+ """
+ Files created:
+ * config.json
+ * weights.npy
+ * input.npy
+ * output.npy
+
+ weights are serialized to weights.npy in the following order:
+ * (expand: no params)
+ * loop blocks 0,...,L-1
+ * conv:
+ * weight (Cout, Cin, K)
+ * bias (if no batchnorm) (Cout)
+ * BN
+ * running mean
+ * running_var
+ * weight (Cout)
+ * bias (Cout)
+ * eps ()
+ * head
+ * weight (1, C, 1)
+ * bias (1,)
+ * (flatten: no params)
+
+ A test input & output are also provided, input.npy and output.npy
+ """
+ training = self.training
+ self.eval()
+ with open(Path(outdir, "config.json"), "w") as fp:
+ json.dump(
+ {
+ "version": __version__,
+ "architecture": "ConvNet",
+ "config": {
+ "channels": self._channels,
+ "dilations": self._get_dilations(),
+ "batchnorm": self._batchnorm,
+ "activation": self._activation,
+ },
+ },
+ fp,
+ indent=4,
+ )
+
+ params = []
+ for i in range(self._num_layers):
+ block_name = f"block_{i}"
+ block = self._net._modules[block_name]
+ conv = block._modules[_CONV_NAME]
+ params.append(conv.weight.flatten())
+ if conv.bias is not None:
+ params.append(conv.bias.flatten())
+ if self._batchnorm:
+ bn = block._modules[_BATCHNORM_NAME]
+ params.append(bn.running_mean.flatten())
+ params.append(bn.running_var.flatten())
+ params.append(bn.weight.flatten())
+ params.append(bn.bias.flatten())
+ params.append(torch.Tensor([bn.eps]).to(bn.weight.device))
+ head = self._net._modules["head"]
+ params.append(head.weight.flatten())
+ params.append(head.bias.flatten())
+ params = torch.cat(params).detach().cpu().numpy()
+ # Hope I don't regret using np.save...
+ np.save(Path(outdir, "weights.npy"), params)
+
+ # And an input/output to verify correct computation:
+ x, y = self._test_signal()
+ np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
+ np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
+
+ # And resume training state
+ self.train(training)
+
+ def export_cpp_header(self, filename: Path):
+ with TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ self.export(Path(tmpdir))
+ with open(Path(tmpdir, "config.json"), "r") as fp:
+ _c = json.load(fp)
+ version = _c["version"]
+ config = _c["config"]
+ with open(filename, "w") as f:
+ f.writelines(
+ (
+ "#pragma once\n",
+ "// Automatically-generated model file\n",
+ "#include <vector>\n",
+ f'#define PYTHON_MODEL_VERSION "{version}"\n',
+ f"const int CHANNELS = {config['channels']};\n",
+ f"const bool BATCHNORM = {'true' if config['batchnorm'] else 'false'};\n",
+ "std::vector<int> DILATIONS{"
+ + ",".join([str(d) for d in config["dilations"]])
+ + "};\n",
+ f"const std::string ACTIVATION = \"{config['activation']}\";\n",
+ "std::vector<float> PARAMS{"
+ + ",".join(
+ [f"{w:.16f}" for w in np.load(Path(tmpdir, "weights.npy"))]
+ )
+ + "};\n",
+ )
+ )
+
+ def _forward(self, x):
+ y = self._net(x)
+ if self._ir is not None:
+ y = self._ir(y)
+ return y
+
+ def _get_dilations(self) -> Tuple[int]:
+ return tuple(
+ self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0]
+ for i in range(self._num_blocks)
+ )
+
+ def _get_num_blocks(self, net: nn.Sequential):
+ i = 0
+ while True:
+ if f"block_{i}" not in net._modules:
+ break
+ else:
+ i += 1
+ return i
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -0,0 +1,70 @@
+# File: linear.py
+# Created Date: Tuesday February 8th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Linear model
+"""
+
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .._version import __version__
+from ._base import BaseNet
+
+
+class Linear(BaseNet):
+ def __init__(self, receptive_field: int, bias: bool = False):
+ super().__init__()
+ self._net = nn.Conv1d(1, 1, receptive_field, bias=bias)
+
+ @property
+ def pad_start_default(self) -> bool:
+ return True
+
+ @property
+ def receptive_field(self) -> int:
+ return self._net.weight.shape[2]
+
+ def export(self, outdir: Path):
+ training = self.training
+ self.eval()
+ with open(Path(outdir, "config.json"), "w") as fp:
+ json.dump(
+ {
+ "version": __version__,
+ "architecture": self.__class__.__name__,
+ "config": {
+ "receptive_field": self.receptive_field,
+ "bias": self._bias,
+ },
+ },
+ fp,
+ indent=4,
+ )
+
+ params = [self._net.weight.flatten()]
+ if self._bias:
+ params.append(self._net.bias.flatten())
+ params = torch.cat(params).detach().cpu().numpy()
+ # Hope I don't regret using np.save...
+ np.save(Path(outdir, "weights.npy"), params)
+
+ # And an input/output to verify correct computation:
+ x, y = self._test_signal()
+ np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
+ np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
+
+ # And resume training state
+ self.train(training)
+
+ @property
+ def _bias(self) -> bool:
+ return self._net.bias is not None
+
+ def _forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self._net(x[:, None])[:, 0]
diff --git a/reamp.py b/reamp.py
@@ -1,66 +0,0 @@
-"""
-Reamp a .wav file
-
-Assumes 24-bit WAV files
-"""
-
-from argparse import ArgumentParser
-import tensorflow as tf
-import os
-import wavio
-import matplotlib.pyplot as plt
-
-import models
-
-
-def _sampwidth_to_bits(x):
- return {2: 16, 3: 24, 4: 32}[x]
-
-
-if __name__ == "__main__":
- parser = ArgumentParser()
- parser.add_argument("architecture", type=str,
- help="JSON filename containing NN architecture")
- parser.add_argument("checkpoint_dir", type=str,
- help="directory holding model checkpoint to use")
- parser.add_argument("input_file", type=str,
- help="Input .wav file to convert")
- parser.add_argument("--output_file", type=str, default=None,
- help="Where to save the output")
- parser.add_argument("--batch_size", type=int, default=8192,
- help="How many samples to process at a time. " +
- "Reduce if there are out-of-memory issues.")
- parser.add_argument("--target_file", type=str, default=None,
- help=".wav file of the true output (if you want to compare)")
- args = parser.parse_args()
-
- if args.output_file is None:
- args.output_file = args.input_file.rstrip(".wav") + "_reamped.wav"
-
- if os.path.isfile(args.output_file):
- print("Output file exists; skip")
- exit(1)
-
- x = wavio.read(args.input_file)
- rate, sampwidth = x.rate, x.sampwidth
- bits = _sampwidth_to_bits(sampwidth)
- x_data = x.data.flatten() / 2 ** (bits - 1)
-
- with tf.Session() as sess:
- model = models.from_json(args.architecture,
- checkpoint_path=args.checkpoint_dir)
- model.load()
- y = model.predict(x_data, batch_size=args.batch_size, verbose=True)
- wavio.write(args.output_file, y * 2 ** (bits - 1), rate, scale="none",
- sampwidth=sampwidth)
-
- if args.target_file is not None and os.path.isfile(args.target_file):
- t = wavio.read(args.target_file)
- t_data = t.data.flatten() / 2 ** (_sampwidth_to_bits(t.sampwidth) - 1)
- plt.figure()
- plt.plot(x_data)
- plt.plot(t_data)
- plt.plot(y)
- plt.legend(["Input", "Target", "Prediction"])
- plt.show()
-
-\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
@@ -1,2 +1,18 @@
-tensorflow
+# File: requirements.txt
+# Created Date: 2021-01-24
+# Author: Steven Atkinson ([email protected])
+
+black
+flake8
+matplotlib
+numpy
+pip
+pre-commit
+pytest
+pytorch_lightning
+scipy
+sounddevice
+torch
+tqdm
wavio
+wheel
diff --git a/setup.py b/setup.py
@@ -0,0 +1,31 @@
+# File: setup.py
+# Created Date: 2020-04-08
+# Author: Steven Atkinson ([email protected])
+
+from distutils.util import convert_path
+from setuptools import setup, find_packages
+
+main_ns = {}
+ver_path = convert_path("nam/_version.py")
+with open(ver_path) as ver_file:
+ exec(ver_file.read(), main_ns)
+
+requirements = [] # torch...
+
+try:
+ import torch # noqa F401
+except ImportError as e:
+ raise ImportError(
+ f"PyTorch not found. Please install it as needed.\nOriginal error: {e}"
+ )
+
+setup(
+ name="nam",
+ version=main_ns["__version__"],
+ description="Neural amp modeler",
+ author="Steven Atkinson",
+ author_email="[email protected]",
+ url="https://github.com/sdatkinson/",
+ install_requires=requirements,
+ packages=find_packages(),
+)
diff --git a/tests/test_install.py b/tests/test_install.py
@@ -0,0 +1,13 @@
+# File: test_install.py
+# File Created: Tuesday, 2nd February 2021 9:46:01 pm
+# Author: Steven Atkinson ([email protected])
+
+import pytest
+
+
+def test_torch():
+ import torch # noqa F401
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/test_nam/__init__.py b/tests/test_nam/__init__.py
diff --git a/tests/test_nam/test_importable.py b/tests/test_nam/test_importable.py
@@ -0,0 +1,13 @@
+# File: test_importable.py
+# Created Date: Sunday December 12th 2021
+# Author: Steven Atkinson ([email protected])
+
+import pytest
+
+
+def test_importable():
+ import nam # noqa F401
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/train.py b/train.py
@@ -1,198 +0,0 @@
-# File: train.py
-# # File Created: Sunday, 30th December 2018 2:08:54 pm
-# Author: Steven Atkinson ([email protected])
-
-"""
-Here's a script for training new models.
-"""
-
-from argparse import ArgumentParser
-import numpy as np
-import abc
-import tensorflow as tf
-import matplotlib.pyplot as plt
-from time import time
-import wavio
-import json
-import os
-
-import models
-
-# Parameters for training
-check_training_at = np.concatenate([10 ** pwr * np.array([1, 2, 5])
- for pwr in np.arange(0, 7)])
-plot_kwargs = {
- "window": (30000, 40000)
-}
-wav_kwargs = {"window": (0, 5 * 44100)}
-
-
-class Data(object):
- """
- Object for holding data and spitting out minibatches for training/segments
- for testing.
- """
- def __init__(self, fname, input_length, batch_size=None):
- xy = np.load(fname)
- self.x = xy[0]
- self.y = xy[1]
- self._n = self.x.size - input_length + 1
- self.input_length = input_length
- self.batch_size = batch_size
-
- @property
- def n(self):
- return self._n
-
- def minibatch(self, n=None, ilist=None):
- """
- Pull a random minibatch out of the data set
- """
- if ilist is None:
- n = n or self.batch_size
- ilist = np.random.randint(0, self.n, size=(n,))
- x = np.stack([self.x[i: i + self.input_length] for i in ilist])
- y = np.array([self.y[i + self.input_length - 1] for i in ilist]) \
- [:, np.newaxis]
- return x, y
-
- def sequence(self, start, end):
- end += self.input_length - 1
- return self.x[start: end], self.y[start + self.input_length - 1: end]
-
-
-def train(model, train_data, batch_size=None, n_minibatches=10,
- validation_data=None, plot_at=(), wav_at=(), plot_kwargs={},
- wav_kwargs={}, save_every=100, validate_every=100):
- save_dir = os.path.dirname(model.checkpoint_path)
- sess = model.sess
- opt = tf.train.AdamOptimizer().minimize(model.loss)
- sess.run(tf.global_variables_initializer()) # For opt
- t_loss_list, v_loss_list = [], []
- t0 = time()
- for i in range(n_minibatches):
- x, y = train_data.minibatch(batch_size)
- t_loss, _ = sess.run((model.rmse, opt),
- feed_dict={model.x: x, model.target: y})
- t_loss_list.append([i, t_loss])
- print("t={:7} | MB {:>7} / {:>7} | TLoss={:8}".format(int(time() - t0),
- i + 1, n_minibatches, t_loss))
-
- # Callbacks, basically...
- if i + 1 in plot_at:
- plot_predictions(model,
- validation_data if validation_data is not None else train_data,
- title="Minibatch {}".format(i + 1),
- fname="{}/mb_{}.png".format(save_dir, i + 1),
- **plot_kwargs)
- if i + 1 in wav_at:
- print("Making wav for mb {}".format(i + 1))
- predict(model, validation_data,
- save_wav_file="{}/predict_{}.wav".format(save_dir, i + 1),
- **wav_kwargs)
- if (i + 1) % save_every == 0:
- model.save(iter=i + 1)
- if i == 0 or (i + 1) % validate_every == 0:
- v_loss, _ = sess.run((model.rmse, opt),
- feed_dict={model.x: x, model.target: y})
- print("VLoss={:8}".format(v_loss))
- v_loss_list.append([i, v_loss])
-
- # After training loop...
- if validation_data is not None:
- x, y = validation_data.minibatch(train_data.batch_size)
- v_loss = sess.run(model.rmse,
- feed_dict={model.x: x, model.target: y})
- print("Validation loss={}".format(v_loss))
- return np.array(t_loss_list).T, np.array(v_loss_list).T
-
-
-def plot_predictions(model, data, title=None, fname=None, window=None):
- x, y, t = predict(model, data, window=window)
- plt.figure(figsize=(12, 4))
- plt.plot(x)
- plt.plot(t)
- plt.plot(y)
- plt.legend(('Input', 'Target', 'Prediction'))
- if title is not None:
- plt.title(title)
- if fname is not None:
- print("Saving to {}...".format(fname))
- plt.savefig(fname)
- plt.close()
- else:
- plt.show()
-
-
-def plot_loss(t_loss, v_loss, fname):
- plt.figure()
- plt.loglog(t_loss[0], t_loss[1])
- plt.loglog(v_loss[0], v_loss[1])
- plt.xlabel("Minibatch")
- plt.ylabel("RMSE")
- plt.legend(("Training", "Validation"))
- plt.savefig(fname)
- plt.close()
-
-
-def predict(model, data, window=None, save_wav_file=None):
- x, t = data.x, data.y
- if window is not None:
- x, t = x[window[0]: window[1]], t[window[0]: window[1]]
- y = model.predict(x).flatten()
-
- if save_wav_file is not None:
- rate = 44100 # TODO from data
- sampwidth = 3 # 24-bit
- wavio.write(save_wav_file, y * 2 ** 23, rate, scale="none",
- sampwidth=sampwidth)
-
- return x, y, t
-
-
-def _get_input_length(archfile):
- return json.load(open(archfile, "r"))["input_length"]
-
-
-if __name__ == "__main__":
- parser = ArgumentParser()
- parser.add_argument("model_arch", type=str,
- help="JSON containing model architecture")
- parser.add_argument("train_data", type=str,
- help="Filename for training data")
- parser.add_argument("validation_data", type=str,
- help="Filename for validation data")
- parser.add_argument("--save_dir", type=str, default=None,
- help="Where to save the run data (checkpoints, prediction...)")
- parser.add_argument("--batch_size", type=str, default=4096,
- help="Number of data per minibatch")
- parser.add_argument("--minibatches", type=int, default=10,
- help="Number of minibatches to train for")
- args = parser.parse_args()
- input_length = _get_input_length(args.model_arch) # Ugh, kludge
-
- # Load the data
- train_data = Data(args.train_data, input_length,
- batch_size=args.batch_size)
- validate_data = Data(args.validation_data, input_length)
-
- # Training
- with tf.Session() as sess:
- model = models.from_json(args.model_arch, train_data.n,
- checkpoint_path=os.path.join(args.save_dir, "model.ckpt"))
- t_loss_list, v_loss_list = train(
- model,
- train_data,
- validation_data=validate_data,
- n_minibatches=args.minibatches,
- plot_at=check_training_at,
- wav_at=check_training_at,
- plot_kwargs=plot_kwargs,
- wav_kwargs=wav_kwargs)
- plot_predictions(model, validate_data, window=(0, 44100))
- print("Predict the full output")
- predict(model, validate_data,
- save_wav_file="{}/predict.wav".format(
- os.path.dirname(model.checkpoint_path)))
-
- plot_loss(t_loss_list, v_loss_list, "{}/loss.png".format(args.save_dir))