neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 559e65a922e227027c63840554ac8e7eed79261a
parent 1caec259ba4d6526cb7da5cd3b752f34d680a459
Author: Steven Atkinson <[email protected]>
Date:   Sat, 25 Feb 2023 17:44:59 -0800

GUI Trainer (#101)

* Bare-bones GUI trainer

* 100 epochs

* Version 0.6.0

* GUI better. Need epochs and architecture and we're good

* Only calibrate delay if unknown

* Installation checker, advanced options currently a mess.

* Disable advanced options

* Option menu for architectures working

* Better default handling, allow for non-sets on advanced options

* GUI trainer in a working state.

* Update README.md and trainer config files

* Make GUI trainer accessible via an entry point

* Fix relative imports

* Update README.md

* Fix environment name
Diffstat:
M README.md                              |  71 ++++++++++++++++++++++++++++++++++-------------------------------------
M bin/train/inputs/data/single_pair.json |   6 +++---
M bin/train/inputs/data/two_pairs.json   |   2 +-
M bin/train/inputs/learning/default.json |   8 +++-----
M environment_gpu.yml                    |   4 ++--
A nam/train/_version.py                  |  24 ++++++++++++++++++++++++
M nam/train/colab.py                     | 341 +++++--------------------------------------------------------------------
A nam/train/core.py                      | 371 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
A nam/train/gui.py                       | 457 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M setup.py                               |   5 +++++
10 files changed, 921 insertions(+), 368 deletions(-)

diff --git a/README.md b/README.md
@@ -54,23 +54,18 @@ Then activate the environment you've created with
 conda activate nam
 ```
 
-### Things you can do
+### Train models (GUI)
+Open a GUI trainer by running
 
-Here are the primary ways this is meant to be used:
-
-#### Train a model
-
-You'll need at least two mono wav files: the input (DI) and the amped sound (without the cab).
-You can either record enough to have a training and validation set in the same file and
-split the file, or you can use 4 files (input/output for train/test).
-Also, you can provide _multiple_ file pairs for training (or validation).
+```bash
+nam
+```
 
-For the first option, Modify `bin/train/inputs/config_data_single_pair.json` to point at the audio files, and set the
-start/stop to the point (in samples) where the training segment ends and the validation
-starts.
-For the second option, modify and use `bin/train/inputs/config_data_two_pairs.json`.
+from the terminal.
 
-Then run:
+### Train models (Python script)
+For users looking to get more fine-grained control over the modeling process,
+NAM includes a training script that can be run from the terminal, e.g.:
 
 ```bash
 python bin/train/main.py \
@@ -80,40 +75,42 @@ bin/train/inputs/config_learning.json \
 bin/train/outputs/MyAmp
 ```
 
-#### Run a model on an input signal ("reamping")
+where `config_data.json` contains the information about the data you're training
+on, `config_model.json` contains information about the model architecture that
+is being trained, and `config_learning.json` contains information about the
+training run itself (e.g. number of epochs).
+You'll need to configure the data JSON to the specifics of the data you're
+training on. The others may work for your needs out-of-the-box with no
+modification.
 
-Handy if you want to just check it out without going through the trouble of building the
-plugin.
+Since NAM uses [PyTorch Lightning](https://lightning.ai/pages/open-source/)
+under the hood as a modeling framework, many of the configuration options that
+are passed to its components can be configured from the data/model/learning
+JSONs.
 
-For example:
+#### Export a model (to use with [the plugin](https://github.com/sdatkinson/NeuralAmpModelerPlugin))
+Exporting the trained model to a `.nam` file for use with the plugin can be done
+with:
 
 ```bash
-python bin/run.py \
-path/to/source.wav \
+python bin/export.py \
 path/to/config_model.json \
 path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
-path/to/output.wav
+path/to/exported_models/MyAmp
 ```
 
-#### Export a model (to use with [the plugin](https://github.com/sdatkinson/iPlug2))
+Then, point the plugin at the exported `model.nam` file and you're good to go!
+
+### Other utilities
 
-Let's get ready to rock!
+#### Run a model on an input signal ("reamping")
+
+Handy if you want to just check it out without needing to use the plugin:
 
 ```bash
-python bin/export.py \
+python bin/run.py \
+path/to/source.wav \
 path/to/config_model.json \
 path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
-path/to/exported_models/MyAmp
+path/to/output.wav
 ```
-
-Then point the plugin at the exported model directory and you're good to go!
-
-## Advanced usage
-
-The model architectures and cofigurations in `bin/train/inputs/models` should work plenty well out of the box.
-However, feel free to play around with it; sometimes some tweaks can help improve performance.
-
-Also, you can train for shorter or longer.
-1000 epochs is typically overkill, but how little you can get away with depends on the model you're using.
-I recommend watching the checkpoints and keeping an eye out for when the ESR drops below
-0.01--usually it'll sound pretty good by that point.
diff --git a/bin/train/inputs/data/single_pair.json b/bin/train/inputs/data/single_pair.json
@@ -1,11 +1,11 @@
 {
     "train": {
         "start": null,
-        "stop": 36576000,
-        "ny": 1024
+        "stop": -432000,
+        "ny": 8192
     },
     "validation": {
-        "start": 36576000,
+        "start": -432000,
         "stop": null,
         "ny": null
     },
diff --git a/bin/train/inputs/data/two_pairs.json b/bin/train/inputs/data/two_pairs.json
@@ -2,7 +2,7 @@
     "train": {
         "x_path": "C:\\path\\to\\train\\source.wav",
         "y_path": "C:\\path\\to\\train\\target.wav",
-        "ny": 1024
+        "ny": 8192
     },
     "validation": {
         "x_path": "C:\\path\\to\\validation\\source.wav",
diff --git a/bin/train/inputs/learning/default.json b/bin/train/inputs/learning/default.json
@@ -6,12 +6,10 @@
         "drop_last": true,
         "num_workers": 0
     },
-    "val_dataloader": {
-    },
+    "val_dataloader": {},
    "trainer": {
         "gpus": 1,
-        "max_epochs": 1000
+        "max_epochs": 100
     },
-    "trainer_fit_kwargs": {
-    }
+    "trainer_fit_kwargs": {}
 }
\ No newline at end of file
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -2,7 +2,7 @@
 # Created Date: Saturday February 13th 2021
 # Author: Steven Atkinson ([email protected])
 
-name: nam-test
+name: nam
 channels:
   - pytorch
   - nvidia
@@ -24,7 +24,7 @@ dependencies:
   - wheel
   - pip:
     - onnx
-    - onnxruntime  # TODO GPU...
+    - onnxruntime # TODO GPU...
     - pre-commit
     - pytorch_lightning
     - sounddevice
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -0,0 +1,24 @@
+# File: _version.py
+# Created Date: Tuesday December 20th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Version utility
+"""
+
+class Version:
+    def __init__(self, major: int, minor: int, patch: int):
+        self.major = major
+        self.minor = minor
+        self.patch = patch
+
+    def __lt__(self, other) -> bool:
+        if self.major != other.major:
+            return self.major < other.major
+        if self.minor != other.minor:
+            return self.minor < other.minor
+        if self.patch != other.patch:
+            return self.patch < other.patch
+
+    def __str__(self) -> str:
+        return f"{self.major}.{self.minor}.{self.patch}"
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -6,54 +6,24 @@
 Hide the mess in Colab to make things look pretty for users.
 """
 
-from enum import Enum
+
 from pathlib import Path
-from time import time
 from typing import Optional, Tuple
 
-import matplotlib.pyplot as plt
-import numpy as np
-import pytorch_lightning as pl
-import torch
-from torch.utils.data import DataLoader
-
-from nam.data import REQUIRED_RATE, Split, init_dataset, wav_to_np
-from nam.models import Model
-
-
-class _Architecture(Enum):
-    STANDARD = "standard"
-    LITE = "lite"
-    FEATHER = "feather"
-
-
-class _Version:
-    def __init__(self, major: int, minor: int, patch: int):
-        self.major = major
-        self.minor = minor
-        self.patch = patch
+from ._version import Version
+from .core import train
 
-    def __lt__(self, other) -> bool:
-        if self.major != other.major:
-            return self.major < other.major
-        if self.minor != other.minor:
-            return self.minor < other.minor
-        if self.patch != other.patch:
-            return self.patch < other.patch
 
-    def __str__(self) -> str:
-        return f"{self.major}.{self.minor}.{self.patch}"
-
-
-_INPUT_BASENAMES = ((_Version(1, 1, 1), "v1_1_1.wav"), (_Version(1, 0, 0), "v1.wav"))
+_INPUT_BASENAMES = ((Version(1, 1, 1), "v1_1_1.wav"), (Version(1, 0, 0), "v1.wav"))
 _BUGGY_INPUT_BASENAMES = {
     # 1.1.0 has the spikes at the wrong spots.
     "v1_1_0.wav"
 }
 _OUTPUT_BASENAME = "output.wav"
+_TRAIN_PATH = "."
 
 
-def _check_for_files() -> Tuple[_Version, str]:
+def _check_for_files() -> Tuple[Version, str]:
     print("Checking that we have all of the required audio files...")
     for name in _BUGGY_INPUT_BASENAMES:
         if Path(name).exists():
@@ -79,239 +49,6 @@ def _check_for_files() -> Tuple[_Version, str]:
     return input_version, input_basename
 
-
-def _calibrate_delay_v1() -> int:
-    safety_factor = 4
-    # Locations of blips in v1 signal file:
-    i1, i2 = 12_000, 36_000
-    j1_start_looking = i1 - 1_000
-    j2_start_looking = i2 - 1_000
-
-    y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
-
-    background_level = np.max(np.abs(y[:6_000]))
-    trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
-    j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][
-        0
-    ]
-    j2 = np.where(np.abs(y[j2_start_looking:]) > trigger_threshold)[0][0]
-
-    delay_1 = (j1 + j1_start_looking) - i1
-    delay_2 = (j2 + j2_start_looking) - i2
-    print(f"Delays: {delay_1}, {delay_2}")
-    delay = int(np.min([delay_1, delay_2])) - safety_factor
-    print(f"Final delay is {delay}")
-    return delay
-
-
-def _plot_delay_v1(delay: int, input_basename: str):
-    print("Plotting the delay for manual inspection...")
-    x = wav_to_np(input_basename)[:48_000]
-    y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
-    i = np.where(np.abs(x) > 0.1)[0][0]  # In case resampled poorly
-    di = 20
-    plt.figure()
-    # plt.plot(x[i - di : i + di], ".-", label="Input")
-    plt.plot(
-        np.arange(-di, di),
-        y[i - di + delay : i + di + delay],
-        ".-",
-        label="Output",
-    )
-    plt.axvline(x=0, linestyle="--", color="C1")
-    plt.legend()
-    plt.show()  # This doesn't freeze the notebook
-
-
-def _calibrate_delay(
-    delay: Optional[int], input_version: _Version, input_basename: str
-) -> int:
-    if input_version.major == 1:
-        calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
-    else:
-        raise NotImplementedError(
-            f"Input calibration not implemented for input version {input_version}"
-        )
-    if delay is not None:
-        print(f"Delay is specified as {delay}")
-    else:
-        print("Delay wasn't provided; attempting to calibrate automatically...")
-        delay = calibrate()
-    plot(delay, input_basename)
-    return delay
-
-
-def _get_wavenet_config(architecture):
-    return {
-        _Architecture.STANDARD: {
-            "layers_configs": [
-                {
-                    "input_size": 1,
-                    "condition_size": 1,
-                    "channels": 16,
-                    "head_size": 8,
-                    "kernel_size": 3,
-                    "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
-                    "activation": "Tanh",
-                    "gated": False,
-                    "head_bias": False,
-                },
-                {
-                    "condition_size": 1,
-                    "input_size": 16,
-                    "channels": 8,
-                    "head_size": 1,
-                    "kernel_size": 3,
-                    "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
-                    "activation": "Tanh",
-                    "gated": False,
-                    "head_bias": True,
-                },
-            ],
-            "head_scale": 0.02,
-        },
-        _Architecture.LITE: {
-            "layers_configs": [
-                {
-                    "input_size": 1,
-                    "condition_size": 1,
-                    "channels": 12,
-                    "head_size": 6,
-                    "kernel_size": 3,
-                    "dilations": [1, 2, 4, 8, 16, 32, 64],
-                    "activation": "Tanh",
-                    "gated": False,
-                    "head_bias": False,
-                },
-                {
-                    "condition_size": 1,
-                    "input_size": 12,
-                    "channels": 6,
-                    "head_size": 1,
-                    "kernel_size": 3,
-                    "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
-                    "activation": "Tanh",
-                    "gated": False,
-                    "head_bias": True,
-                },
-            ],
-            "head_scale": 0.02,
-        },
-        _Architecture.FEATHER: {
-            "layers_configs": [
-                {
-                    "input_size": 1,
-                    "condition_size": 1,
-                    "channels": 8,
-                    "head_size": 4,
-                    "kernel_size": 3,
-                    "dilations": [1, 2, 4, 8, 16, 32, 64],
-                    "activation": "Tanh",
-                    "gated": False,
-                    "head_bias": False,
-                },
-                {
"condition_size": 1, - "input_size": 8, - "channels": 4, - "head_size": 1, - "kernel_size": 3, - "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512], - "activation": "Tanh", - "gated": False, - "head_bias": True, - }, - ], - "head_scale": 0.02, - }, - }[architecture] - - -def _get_configs( - input_basename: str, - delay: int, - epochs: int, - architecture: _Architecture, - lr: float, - lr_decay: float, -): - val_seconds = 9 - train_val_split = -val_seconds * REQUIRED_RATE - data_config = { - "train": {"ny": 8192, "stop": train_val_split}, - "validation": {"ny": None, "start": train_val_split}, - "common": { - "x_path": input_basename, - "y_path": _OUTPUT_BASENAME, - "delay": delay, - }, - } - model_config = { - "net": { - "name": "WaveNet", - # This should do decently. If you really want a nice model, try turning up - # "channels" in the first block and "input_size" in the second from 12 to 16. - "config": _get_wavenet_config(architecture), - }, - "loss": {"val_loss": "esr"}, - "optimizer": {"lr": lr}, - "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}}, - } - learning_config = { - "train_dataloader": { - "batch_size": 16, - "shuffle": True, - "pin_memory": True, - "drop_last": True, - "num_workers": 0, - }, - "val_dataloader": {}, - "trainer": {"accelerator": "gpu", "devices": 1, "max_epochs": epochs}, - } - return data_config, model_config, learning_config - - -def _esr(pred: torch.Tensor, target: torch.Tensor) -> float: - return ( - torch.mean(torch.square(pred - target)).item() - / torch.mean(torch.square(target)).item() - ) - - -def _plot( - model, ds, window_start: Optional[int] = None, window_end: Optional[int] = None -): - print("Plotting a comparison of your model with the target output...") - with torch.no_grad(): - tx = len(ds.x) / 48_000 - print(f"Run (t={tx:.2f} sec)") - t0 = time() - output = model(ds.x).flatten().cpu().numpy() - t1 = time() - print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)") - - esr = _esr(torch.Tensor(output), ds.y) - # Trying my best to put numbers to it... - if esr < 0.01: - esr_comment = "Great!" - elif esr < 0.035: - esr_comment = "Not bad!" - elif esr < 0.1: - esr_comment = "...This *might* sound ok!" - elif esr < 0.3: - esr_comment = "...This probably won't sound great :(" - else: - esr_comment = "...Something seems to have gone wrong." - print(f"Error-signal ratio = {esr:.3f}") - print(esr_comment) - - plt.figure(figsize=(16, 5)) - plt.plot(output[window_start:window_end], label="Prediction") - plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") - plt.title(f"ESR={esr:.3f}") - plt.legend() - plt.show() - - def _get_valid_export_directory(): def get_path(version): return Path("exported_models", f"version_{version}") @@ -326,9 +63,9 @@ def run( epochs: int = 100, delay: Optional[int] = None, architecture: str = "standard", - lr=0.004, - lr_decay=0.007, - seed=0, + lr: float=0.004, + lr_decay: float=0.007, + seed: Optional[int]=0, ): """ :param epochs: How amny epochs we'll train for. @@ -340,56 +77,20 @@ def run( :param lr_decay: The amount by which the learning rate decays each epoch :param seed: RNG seed for reproducibility. """ - torch.manual_seed(seed) - input_version, input_basename = _check_for_files() - delay = _calibrate_delay(delay, input_version, input_basename) - data_config, model_config, learning_config = _get_configs( - input_basename, - delay, - epochs, - _Architecture(architecture), - lr, - lr_decay, - ) - - print("Starting training. 
Let's rock!") - model = Model.init_from_config(model_config) - data_config["common"]["nx"] = model.net.receptive_field - dataset_train = init_dataset(data_config, Split.TRAIN) - dataset_validation = init_dataset(data_config, Split.VALIDATION) - train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) - val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) - trainer = pl.Trainer( - callbacks=[ - pl.callbacks.model_checkpoint.ModelCheckpoint( - filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4f}_{MSE:.3e}", - save_top_k=3, - monitor="val_loss", - every_n_epochs=1, - ), - pl.callbacks.model_checkpoint.ModelCheckpoint( - filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1 - ), - ], - **learning_config["trainer"], - ) - trainer.fit(model, train_dataloader, val_dataloader) - - # Go to best checkpoint - best_checkpoint = trainer.checkpoint_callback.best_model_path - if best_checkpoint != "": - model = Model.load_from_checkpoint( - trainer.checkpoint_callback.best_model_path, - **Model.parse_config(model_config), - ) - model.eval() + input_version, input_basename = _check_for_files() - _plot( - model, - dataset_validation, - window_start=100_000, # Start of the plotting window, in samples - window_end=101_000, # End of the plotting window, in samples + model = train( + input_basename, + _OUTPUT_BASENAME, + _TRAIN_PATH, + input_version=input_version, + epochs=epochs, + delay=delay, + architecture=architecture, + lr=lr, + lr_decay=lr_decay, + seed=seed, ) print("Exporting your model...") diff --git a/nam/train/core.py b/nam/train/core.py @@ -0,0 +1,371 @@ +# File: gui.py +# Created Date: Tuesday December 20th 2022 +# Author: Steven Atkinson ([email protected]) + +""" +Functions used by the GUI trainer. +""" + +import hashlib +from enum import Enum +from time import time +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + +from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np +from ..models import Model +from ._version import Version + + +class Architecture(Enum): + STANDARD = "standard" + LITE = "lite" + FEATHER = "feather" + + +def _detect_input_version(input_path) -> Version: + """ + Check to see if the input matches any of the known inputs + """ + md5 = hashlib.md5() + buffer_size = 65536 + with open(input_path, "rb") as f: + while True: + data = f.read(buffer_size) + if not data: + break + md5.update(data) + file_hash = md5.hexdigest() + + version = { + "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), + "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), + }.get(file_hash) + if version is None: + raise RuntimeError( + f"Provided input file {input_path} does not match any known standard input " + "files." 
+        )
+    return version
+
+
+def _calibrate_delay_v1(input_path, output_path) -> int:
+    safety_factor = 4
+    # Locations of blips in v1 signal file:
+    i1, i2 = 12_000, 36_000
+    j1_start_looking = i1 - 1_000
+    j2_start_looking = i2 - 1_000
+
+    y = wav_to_np(output_path)[:48_000]
+
+    background_level = np.max(np.abs(y[:6_000]))
+    trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
+    j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][
+        0
+    ]
+    j2 = np.where(np.abs(y[j2_start_looking:]) > trigger_threshold)[0][0]
+
+    delay_1 = (j1 + j1_start_looking) - i1
+    delay_2 = (j2 + j2_start_looking) - i2
+    print(f"Delays: {delay_1}, {delay_2}")
+    delay = int(np.min([delay_1, delay_2])) - safety_factor
+    print(f"Final delay is {delay}")
+    return delay
+
+
+def _plot_delay_v1(delay: int, input_path: str, output_path: str):
+    print("Plotting the delay for manual inspection...")
+    x = wav_to_np(input_path)[:48_000]
+    y = wav_to_np(output_path)[:48_000]
+    i = np.where(np.abs(x) > 0.1)[0][0]  # In case resampled poorly
+    di = 20
+    plt.figure()
+    # plt.plot(x[i - di : i + di], ".-", label="Input")
+    plt.plot(
+        np.arange(-di, di),
+        y[i - di + delay : i + di + delay],
+        ".-",
+        label="Output",
+    )
+    plt.axvline(x=0, linestyle="--", color="C1")
+    plt.legend()
+    plt.show()  # This doesn't freeze the notebook
+
+
+def _calibrate_delay(
+    delay: Optional[int], input_version: Version, input_path: str, output_path: str,
+) -> int:
+    if input_version.major == 1:
+        calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
+    else:
+        raise NotImplementedError(
+            f"Input calibration not implemented for input version {input_version}"
+        )
+    if delay is not None:
+        print(f"Delay is specified as {delay}")
+    else:
+        print("Delay wasn't provided; attempting to calibrate automatically...")
+        delay = calibrate(input_path, output_path)
+    plot(delay, input_path, output_path)
+    return delay
+
+
+def _get_wavenet_config(architecture):
+    return {
+        Architecture.STANDARD: {
+            "layers_configs": [
+                {
+                    "input_size": 1,
+                    "condition_size": 1,
+                    "channels": 16,
+                    "head_size": 8,
+                    "kernel_size": 3,
+                    "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+                    "activation": "Tanh",
+                    "gated": False,
+                    "head_bias": False,
+                },
+                {
+                    "condition_size": 1,
+                    "input_size": 16,
+                    "channels": 8,
+                    "head_size": 1,
+                    "kernel_size": 3,
+                    "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+                    "activation": "Tanh",
+                    "gated": False,
+                    "head_bias": True,
+                },
+            ],
+            "head_scale": 0.02,
+        },
+        Architecture.LITE: {
+            "layers_configs": [
+                {
+                    "input_size": 1,
+                    "condition_size": 1,
+                    "channels": 12,
+                    "head_size": 6,
+                    "kernel_size": 3,
+                    "dilations": [1, 2, 4, 8, 16, 32, 64],
+                    "activation": "Tanh",
+                    "gated": False,
+                    "head_bias": False,
+                },
+                {
+                    "condition_size": 1,
+                    "input_size": 12,
+                    "channels": 6,
+                    "head_size": 1,
+                    "kernel_size": 3,
+                    "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+                    "activation": "Tanh",
+                    "gated": False,
+                    "head_bias": True,
+                },
+            ],
+            "head_scale": 0.02,
+        },
+        Architecture.FEATHER: {
+            "layers_configs": [
+                {
+                    "input_size": 1,
+                    "condition_size": 1,
+                    "channels": 8,
+                    "head_size": 4,
+                    "kernel_size": 3,
+                    "dilations": [1, 2, 4, 8, 16, 32, 64],
+                    "activation": "Tanh",
+                    "gated": False,
+                    "head_bias": False,
+                },
+                {
+                    "condition_size": 1,
+                    "input_size": 8,
+                    "channels": 4,
+                    "head_size": 1,
+                    "kernel_size": 3,
+                    "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+                    "activation": "Tanh",
+                    "gated": False,
+                    "head_bias": True,
+                },
+            ],
"head_scale": 0.02, + }, + }[architecture] + + +def _get_configs( + input_basename: str, + output_basename: str, + delay: int, + epochs: int, + architecture: Architecture, + lr: float, + lr_decay: float, +): + val_seconds = 9 + train_val_split = -val_seconds * REQUIRED_RATE + data_config = { + "train": {"ny": 8192, "stop": train_val_split}, + "validation": {"ny": None, "start": train_val_split}, + "common": { + "x_path": input_basename, + "y_path": output_basename, + "delay": delay, + }, + } + model_config = { + "net": { + "name": "WaveNet", + # This should do decently. If you really want a nice model, try turning up + # "channels" in the first block and "input_size" in the second from 12 to 16. + "config": _get_wavenet_config(architecture) + }, + "loss": {"val_loss": "esr"}, + "optimizer": {"lr": lr}, + "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}}, + } + if torch.cuda.is_available(): + device_config = {"accelerator": "gpu", "devices": 1} + else: + print("WARNING: No GPU was found. Training will be very slow!") + device_config = {} + learning_config = { + "train_dataloader": { + "batch_size": 16, + "shuffle": True, + "pin_memory": True, + "drop_last": True, + "num_workers": 0, + }, + "val_dataloader": {}, + "trainer": {"max_epochs": epochs, **device_config}, + } + return data_config, model_config, learning_config + + +def _esr(pred: torch.Tensor, target: torch.Tensor) -> float: + return ( + torch.mean(torch.square(pred - target)).item() + / torch.mean(torch.square(target)).item() + ) + + +def _plot( + model, ds, window_start: Optional[int] = None, window_end: Optional[int] = None +): + print("Plotting a comparison of your model with the target output...") + with torch.no_grad(): + tx = len(ds.x) / 48_000 + print(f"Run (t={tx:.2f} sec)") + t0 = time() + output = model(ds.x).flatten().cpu().numpy() + t1 = time() + print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)") + + esr = _esr(torch.Tensor(output), ds.y) + # Trying my best to put numbers to it... + if esr < 0.01: + esr_comment = "Great!" + elif esr < 0.035: + esr_comment = "Not bad!" + elif esr < 0.1: + esr_comment = "...This *might* sound ok!" + elif esr < 0.3: + esr_comment = "...This probably won't sound great :(" + else: + esr_comment = "...Something seems to have gone wrong." + print(f"Error-signal ratio = {esr:.3f}") + print(esr_comment) + + plt.figure(figsize=(16, 5)) + plt.plot(output[window_start:window_end], label="Prediction") + plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") + plt.title(f"ESR={esr:.3f}") + plt.legend() + plt.show() + + +def train( + input_path: str, + output_path: str, + train_path: str, + input_version: Optional[Version] = None, + epochs=100, + delay=None, + architecture: Union[Architecture, str]=Architecture.STANDARD, + lr=0.004, + lr_decay=0.007, + seed: Optional[int] = 0, +): + if seed is not None: + torch.manual_seed(seed) + + # This needs more thought... + # 1. Does the user want me to calibrate the delay? + # 2. Does the user want to see what the chosen (by them or me) delay looks like? + if delay is None: + if input_version is None: + input_version = _detect_input_version(input_path) + delay = _calibrate_delay(delay, input_version, input_path, output_path) + else: + print(f"Delay provided as {delay}; skip calibration") + + data_config, model_config, learning_config = _get_configs( + input_path, + output_path, + delay, + epochs, + Architecture(architecture), + lr, + lr_decay, + ) + + print("Starting training. 
Let's rock!") + model = Model.init_from_config(model_config) + data_config["common"]["nx"] = model.net.receptive_field + dataset_train = init_dataset(data_config, Split.TRAIN) + dataset_validation = init_dataset(data_config, Split.VALIDATION) + train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + + trainer = pl.Trainer( + callbacks=[ + pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4f}_{MSE:.3e}", + save_top_k=3, + monitor="val_loss", + every_n_epochs=1, + ), + pl.callbacks.model_checkpoint.ModelCheckpoint( + filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1 + ), + ], + default_root_dir=train_path, + **learning_config["trainer"], + ) + trainer.fit(model, train_dataloader, val_dataloader) + + # Go to best checkpoint + best_checkpoint = trainer.checkpoint_callback.best_model_path + if best_checkpoint != "": + model = Model.load_from_checkpoint( + trainer.checkpoint_callback.best_model_path, + **Model.parse_config(model_config), + ) + model.eval() + + _plot( + model, + dataset_validation, + window_start=100_000, # Start of the plotting window, in samples + window_end=101_000, # End of the plotting window, in samples + ) + return model diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -0,0 +1,457 @@ +# File: __init__.py +# Created Date: Saturday February 25th 2023 +# Author: Steven Atkinson ([email protected]) + +""" +GUI for training + +Usage: +>>> import nam.train.gui +""" + +import tkinter as tk +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from tkinter import filedialog +from typing import Callable, Optional, Sequence + +try: + from .. import __version__ + from . import core + + _install_is_valid = True +except ImportError: + _install_is_valid = False + +_BUTTON_WIDTH = 20 +_BUTTON_HEIGHT = 2 +_TEXT_WIDTH = 70 + +_DEFAULT_NUM_EPOCHS = 100 +_DEFAULT_DELAY = None + + +@dataclass +class _AdvancedOptions(object): + architecture: core.Architecture + num_epochs: int + delay: Optional[int] + + +class _PathType(Enum): + FILE = "file" + DIRECTORY = "directory" + + +class _PathButton(object): + """ + Button and the path + """ + + def __init__( + self, + frame: tk.Frame, + button_text, + info_str: str, + path_type: _PathType, + hooks: Optional[Sequence[Callable[[], None]]] = None, + ): + self._info_str = info_str + self._path: Optional[Path] = None + self._path_type = path_type + self._button = tk.Button( + frame, + text=button_text, + width=_BUTTON_WIDTH, + height=_BUTTON_HEIGHT, + fg="black", + command=self._set_val, + ) + self._button.pack(side=tk.LEFT) + self._label = tk.Label( + frame, + width=_TEXT_WIDTH, + height=_BUTTON_HEIGHT, + fg="black", + bg=None, + anchor="w", + ) + self._label.pack(side=tk.RIGHT) + self._hooks = hooks + self._set_text() + + @property + def val(self) -> Optional[Path]: + return self._path + + def _set_text(self): + if self._path is None: + self._label["fg"] = "red" + self._label["text"] = f"{self._info_str} is not set!" 
+        else:
+            self._label["fg"] = "black"
+            self._label["text"] = f"{self._info_str} set to {self.val}"
+
+    def _set_val(self):
+        res = {
+            _PathType.FILE: filedialog.askopenfilename,
+            _PathType.DIRECTORY: filedialog.askdirectory,
+        }[self._path_type]()
+        if res != "":
+            self._path = res
+            self._set_text()
+
+        if self._hooks is not None:
+            for h in self._hooks:
+                h()
+
+
+class _GUI(object):
+    def __init__(self):
+        self._root = tk.Tk()
+        self._root.title(f"NAM Trainer - v{__version__}")
+
+        # Buttons for paths:
+        self._frame_input_path = tk.Frame(self._root)
+        self._frame_input_path.pack()
+        self._path_button_input = _PathButton(
+            self._frame_input_path,
+            "Input Audio",
+            "Input audio",
+            _PathType.FILE,
+            hooks=[self._check_button_states],
+        )
+
+        self._frame_output_path = tk.Frame(self._root)
+        self._frame_output_path.pack()
+        self._path_button_output = _PathButton(
+            self._frame_output_path,
+            "Output Audio",
+            "Output audio",
+            _PathType.FILE,
+            hooks=[self._check_button_states],
+        )
+
+        self._frame_train_destination = tk.Frame(self._root)
+        self._frame_train_destination.pack()
+        self._path_button_train_destination = _PathButton(
+            self._frame_train_destination,
+            "Train Destination",
+            "Train destination",
+            _PathType.DIRECTORY,
+            hooks=[self._check_button_states],
+        )
+
+        # Advanced options for training
+        default_architecture = core.Architecture.STANDARD
+        self.advanced_options = _AdvancedOptions(
+            default_architecture, _DEFAULT_NUM_EPOCHS, _DEFAULT_DELAY
+        )
+        # Window to edit them:
+        self._frame_advanced_options = tk.Frame(self._root)
+        self._frame_advanced_options.pack()
+        self._button_advanced_options = tk.Button(
+            self._frame_advanced_options,
+            text="Advanced options...",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._open_advanced_options,
+        )
+        self._button_advanced_options.pack()
+
+        # Train button
+        self._frame_train = tk.Frame(self._root)
+        self._frame_train.pack()
+        self._button_train = tk.Button(
+            self._frame_train,
+            text="Train",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._train,
+        )
+        self._button_train.pack()
+
+        self._check_button_states()
+
+    def mainloop(self):
+        self._root.mainloop()
+
+    def _open_advanced_options(self):
+        """
+        Open advanced options
+        """
+        ao = _AdvancedOptionsGUI(self)
+        # I should probably disable the main GUI...
+        ao.mainloop()
+        # ...and then re-enable it once it gets closed.
+
+    def _train(self):
+        # Advanced options:
+        num_epochs = self.advanced_options.num_epochs
+        architecture = self.advanced_options.architecture
+        delay = self.advanced_options.delay
+
+        # Advanced-er options
+        # If you're poking around looking for these, then maybe it's time to learn to
+        # use the command-line scripts ;)
+        lr = 0.004
+        lr_decay = 0.007
+        seed = 0
+
+        # Run it
+        trained_model = core.train(
+            self._path_button_input.val,
+            self._path_button_output.val,
+            self._path_button_train_destination.val,
+            epochs=num_epochs,
+            delay=delay,
+            architecture=architecture,
+            lr=lr,
+            lr_decay=lr_decay,
+            seed=seed,
+        )
+        print("Model training complete!")
+        print("Exporting...")
+        outdir = self._path_button_train_destination.val
+        print(f"Exporting trained model to {outdir}...")
+        trained_model.net.export(outdir)
+        print("Done!")
+
+    def _check_button_states(self):
+        """
+        Determine if any buttons should be disabled
+        """
+        # Train button is disabled unless all paths are set
+        if any(
+            pb.val is None
+            for pb in (
+                self._path_button_input,
+                self._path_button_output,
+                self._path_button_train_destination,
+            )
+        ):
+            self._button_train["state"] = tk.DISABLED
+            return
+        self._button_train["state"] = tk.NORMAL
+
+
+_ADVANCED_OPTIONS_LEFT_WIDTH = 12
+_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+
+
+class _LabeledOptionMenu(object):
+    """
+    Label (left) and radio buttons (right)
+    """
+
+    def __init__(
+        self, frame: tk.Frame, label: str, choices: Enum, default: Optional[Enum] = None
+    ):
+        """
+        :param command: Called to propagate option selection. Is provided with the
+            value corresponding to the radio button selected.
+        """
+        self._frame = frame
+        self._choices = choices
+        height = _BUTTON_HEIGHT
+        bg = None
+        fg = "black"
+        self._label = tk.Label(
+            frame,
+            width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+            height=height,
+            fg=fg,
+            bg=bg,
+            anchor="w",
+            text=label,
+        )
+        self._label.pack(side=tk.LEFT)
+
+        frame_menu = tk.Frame(frame)
+        frame_menu.pack(side=tk.RIGHT)
+
+        self._selected_value = None
+        default = (list(choices)[0] if default is None else default).value
+        self._menu = tk.OptionMenu(
+            frame_menu,
+            tk.StringVar(master=frame, value=default, name=label),
+            # default,
+            *[choice.value for choice in choices],  # if choice.value!=default],
+            command=self._set,
+        )
+        self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
+        self._menu.pack(side=tk.RIGHT)
+        # Initialize
+        self._set(default)
+
+    def get(self) -> Enum:
+        return self._selected_value
+
+    def _set(self, val: str):
+        """
+        Set the value selected
+        """
+        self._selected_value = self._choices(val)
+
+
+class _LabeledText(object):
+    """
+    Label (left) and text input (right)
+    """
+
+    def __init__(self, frame: tk.Frame, label: str, default=None, type=None):
+        """
+        :param command: Called to propagate option selection. Is provided with the
+            value corresponding to the radio button selected.
+        :param type: If provided, casts value to given type
+        """
+        self._frame = frame
+        label_height = 2
+        text_height = 1
+        self._label = tk.Label(
+            frame,
+            width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+            height=label_height,
+            fg="black",
+            bg=None,
+            anchor="w",
+            text=label,
+        )
+        self._label.pack(side=tk.LEFT)
+
+        self._text = tk.Text(
+            frame,
+            width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+            height=text_height,
+            fg="black",
+            bg=None,
+        )
+        self._text.pack(side=tk.RIGHT)
+
+        self._type = type
+
+        if default is not None:
+            self._text.insert("1.0", str(default))
+
+    def get(self):
+        try:
+            val = self._text.get("1.0", tk.END)  # Line 1, character zero (wat)
+            if self._type is not None:
+                val = self._type(val)
+            return val
+        except tk.TclError:
+            return None
+
+
+class _AdvancedOptionsGUI(object):
+    """
+    A window to hold advanced options (Architecture and number of epochs)
+    """
+
+    def __init__(self, parent: _GUI):
+        self._parent = parent
+        self._root = tk.Tk()
+        self._root.title("Advanced Options")
+
+        # Architecture: radio buttons
+        self._frame_architecture = tk.Frame(self._root)
+        self._frame_architecture.pack()
+        self._architecture = _LabeledOptionMenu(
+            self._frame_architecture,
+            "Architecture",
+            core.Architecture,
+            default=self._parent.advanced_options.architecture,
+        )
+
+        # Number of epochs: text box
+        self._frame_epochs = tk.Frame(self._root)
+        self._frame_epochs.pack()
+
+        def non_negative_int(val):
+            val = int(val)
+            if val < 0:
+                val = 0
+            return val
+
+        self._epochs = _LabeledText(
+            self._frame_epochs,
+            "Epochs",
+            default=str(self._parent.advanced_options.num_epochs),
+            type=non_negative_int,
+        )
+
+        # Delay: text box
+        self._frame_delay = tk.Frame(self._root)
+        self._frame_delay.pack()
+
+        def int_or_null(val):
+            val = val.rstrip()
+            if val == "null":
+                return val
+            return int(val)
+
+        def int_or_null_inv(val):
+            return "null" if val is None else str(val)
+
+        self._delay = _LabeledText(
+            self._frame_delay,
+            "Delay",
+            default=int_or_null_inv(self._parent.advanced_options.delay),
+            type=int_or_null,
+        )
+
+        # "Ok": apply and destroy
+        self._frame_ok = tk.Frame(self._root)
+        self._frame_ok.pack()
+        self._button_ok = tk.Button(
+            self._frame_ok,
+            text="Ok",
+            width=_BUTTON_WIDTH,
+            height=_BUTTON_HEIGHT,
+            fg="black",
+            command=self._apply_and_destroy,
+        )
+        self._button_ok.pack()
+
+    def mainloop(self):
+        self._root.mainloop()
+
+    def _apply_and_destroy(self):
+        """
+        Set values to parent and destroy this object
+        """
+        self._parent.advanced_options.architecture = self._architecture.get()
+        epochs = self._epochs.get()
+        if epochs is not None:
+            self._parent.advanced_options.num_epochs = epochs
+        delay = self._delay.get()
+        # Value None is returned as "null" to disambiguate from non-set.
+        if delay is not None:
+            self._parent.advanced_options.delay = None if delay == "null" else delay
+        self._root.destroy()
+
+
+def _install_error():
+    window = tk.Tk()
+    window.title("ERROR")
+    label = tk.Label(
+        window,
+        width=45,
+        height=2,
+        text="The NAM training software has not been installed correctly.",
+    )
+    label.pack()
+    button = tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
+    button.pack()
+    window.mainloop()
+
+
+def run():
+    if _install_is_valid:
+        _gui = _GUI()
+        _gui.mainloop()
+    else:
+        _install_error()
diff --git a/setup.py b/setup.py
@@ -31,4 +31,9 @@ setup(
     url="https://github.com/sdatkinson/",
     install_requires=requirements,
     packages=find_packages(),
+    entry_points={
+        'console_scripts': [
+            'nam = nam.train.gui:run',
+        ]
+    }
 )
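
For reference, the `console_scripts` entry point added to `setup.py` above is what makes the `nam` command from the updated README work: installing the package registers a `nam` executable that calls `nam.train.gui:run`. A minimal sketch of trying it out; the editable-install step is an assumption, not part of this commit:

```bash
# Assumes the conda environment from environment_gpu.yml is active.
pip install -e .   # installs NAM and registers the "nam" console script
nam                # launches the GUI trainer (nam.train.gui:run)
```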
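The GUI itself is a thin wrapper around the new `nam.train.core.train()`, which can also be driven directly from Python. A minimal sketch under assumptions: the file names below are hypothetical, the input file must be one of the known standard inputs (its MD5 is checked by `_detect_input_version` unless `input_version` is passed), and `delay=None` triggers the automatic blip-based calibration:

```python
from nam.train.core import Architecture, train

# Train from a standard input file and a reamped output recording.
model = train(
    "v1_1_1.wav",   # input (DI) audio; must match a known standard input
    "output.wav",   # the amp's output, recorded against that input
    ".",            # train_path: where Lightning writes checkpoints
    epochs=100,
    delay=None,     # None -> calibrate the latency automatically
    architecture=Architecture.STANDARD,
    seed=0,
)

# Export for the plugin, as gui.py does once training finishes.
model.net.export(".")
```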
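The new defaults in `single_pair.json` lean on negative sample indices: NAM's required rate is 48 kHz, so `-432000` is the last 9 seconds of the file, matching `val_seconds = 9` in `core._get_configs`. A sketch of the resulting split; the `common` block and its paths are illustrative, not shown in the hunk above:

```json
{
    "train": {"start": null, "stop": -432000, "ny": 8192},
    "validation": {"start": -432000, "stop": null, "ny": null},
    "common": {"x_path": "source.wav", "y_path": "target.wav", "delay": 0}
}
```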
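Finally, the `esr` validation loss that both the old Colab flow and the new `core._plot` report is the mean squared error normalized by the target signal's power. A self-contained sketch of the metric, with the rating thresholds copied from `core._plot`'s rule of thumb:

```python
import torch

def esr(pred: torch.Tensor, target: torch.Tensor) -> float:
    # ESR = mean((pred - target)^2) / mean(target^2)
    return (
        torch.mean(torch.square(pred - target)) / torch.mean(torch.square(target))
    ).item()

# core._plot rates esr < 0.01 as "Great!" and esr < 0.035 as "Not bad!"
```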