commit 559e65a922e227027c63840554ac8e7eed79261a
parent 1caec259ba4d6526cb7da5cd3b752f34d680a459
Author: Steven Atkinson <[email protected]>
Date: Sat, 25 Feb 2023 17:44:59 -0800
GUI Trainer (#101)
* Bare-bones GUI trainer
* 100 epochs
* Version 0.6.0
* GUI better. Need epochs and architecture and we're good
* Only calibrate delay if unknown
* Installation checker, advanced options currently a mess.
* Disable advanced options
* Option menu for architectures working
* Better default handling, allow for non-sets on advanced options
* GUI trainer in a working state.
* Update README.md and trainer config files
* Make GUI trainer accessible via an entry point
* Fix relative imports
* Update README.md
* Fix environment name
Diffstat:
10 files changed, 921 insertions(+), 368 deletions(-)
diff --git a/README.md b/README.md
@@ -54,23 +54,18 @@ Then activate the environment you've created with
conda activate nam
```
-### Things you can do
+### Train models (GUI)
+Open a GUI trainer by running
-Here are the primary ways this is meant to be used:
-
-#### Train a model
-
-You'll need at least two mono wav files: the input (DI) and the amped sound (without the cab).
-You can either record enough to have a training and validation set in the same file and
-split the file, or you can use 4 files (input/output for train/test).
-Also, you can provide _multiple_ file pairs for training (or validation).
+```bash
+nam
+```
-For the first option, Modify `bin/train/inputs/config_data_single_pair.json` to point at the audio files, and set the
-start/stop to the point (in samples) where the training segment ends and the validation
-starts.
-For the second option, modify and use `bin/train/inputs/config_data_two_pairs.json`.
+from the terminal.
-Then run:
+### Train models (Python script)
+For users looking to get more fine-grained control over the modeling process,
+NAM includes a training script that can be run from the terminal, e.g.:
```bash
python bin/train/main.py \
@@ -80,40 +75,42 @@ bin/train/inputs/config_learning.json \
bin/train/outputs/MyAmp
```
-#### Run a model on an input signal ("reamping")
+where `config_data.json` contains the information about the data you're training
+on, `config_model.json` contains information about the model architecture that
+is being trained, and `config_learning.json` contains information about the
+training run itself (e.g. number of epochs).
+You'll need to configure the data JSON to the specifics of the data you're
+training on. The others may work for your needs out-of-the-box with no
+modification.
-Handy if you want to just check it out without going through the trouble of building the
-plugin.
+Since NAM uses [PyTorch Lightning](https://lightning.ai/pages/open-source/)
+under the hood as a modeling framework, many of the configuration options that
+are passed to its components can be configured from the data/model/learning
+JSONs.
-For example:
+#### Export a model (to use with [the plugin](https://github.com/sdatkinson/NeuralAmpModelerPlugin))
+Exporting the trained model to a `.nam` file for use with the plugin can be done
+with:
```bash
-python bin/run.py \
-path/to/source.wav \
+python bin/export.py \
path/to/config_model.json \
path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
-path/to/output.wav
+path/to/exported_models/MyAmp
```
-#### Export a model (to use with [the plugin](https://github.com/sdatkinson/iPlug2))
+Then, point the plugin at the exported `model.nam` file and you're good to go!
+
+### Other utilities
-Let's get ready to rock!
+#### Run a model on an input signal ("reamping")
+
+Handy if you want to just check it out without needing to use the plugin:
```bash
-python bin/export.py \
+python bin/run.py \
+path/to/source.wav \
path/to/config_model.json \
path/to/checkpoints/epoch=123_val_loss=0.000010.ckpt \
-path/to/exported_models/MyAmp
+path/to/output.wav
```
-
-Then point the plugin at the exported model directory and you're good to go!
-
-## Advanced usage
-
-The model architectures and cofigurations in `bin/train/inputs/models` should work plenty well out of the box.
-However, feel free to play around with it; sometimes some tweaks can help improve performance.
-
-Also, you can train for shorter or longer.
-1000 epochs is typically overkill, but how little you can get away with depends on the model you're using.
-I recommend watching the checkpoints and keeping an eye out for when the ESR drops below
-0.01--usually it'll sound pretty good by that point.
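Editor's note: to make the PyTorch Lightning remark above concrete, the `trainer` block of the learning JSON is unpacked directly into `pl.Trainer(...)` (see `nam/train/core.py` later in this diff), so any `Trainer` keyword argument can ride along. A minimal sketch, assuming the keys from `bin/train/inputs/learning/default.json`; `gradient_clip_val` is a hypothetical extra, not part of this commit:

```python
# Sketch (not part of this commit): how the learning JSON reaches
# PyTorch Lightning. The "trainer" dict is splatted straight into pl.Trainer.
import json

import pytorch_lightning as pl

learning_config = json.loads("""
{
    "train_dataloader": {"batch_size": 16, "shuffle": true, "pin_memory": true,
                         "drop_last": true, "num_workers": 0},
    "val_dataloader": {},
    "trainer": {"max_epochs": 100, "gradient_clip_val": 1.0},
    "trainer_fit_kwargs": {}
}
""")

trainer = pl.Trainer(**learning_config["trainer"])  # accepts any Trainer kwarg
```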
diff --git a/bin/train/inputs/data/single_pair.json b/bin/train/inputs/data/single_pair.json
@@ -1,11 +1,11 @@
{
"train": {
"start": null,
- "stop": 36576000,
- "ny": 1024
+ "stop": -432000,
+ "ny": 8192
},
"validation": {
- "start": 36576000,
+ "start": -432000,
"stop": null,
"ny": null
},
diff --git a/bin/train/inputs/data/two_pairs.json b/bin/train/inputs/data/two_pairs.json
@@ -2,7 +2,7 @@
"train": {
"x_path": "C:\\path\\to\\train\\source.wav",
"y_path": "C:\\path\\to\\train\\target.wav",
- "ny": 1024
+ "ny": 8192
},
"validation": {
"x_path": "C:\\path\\to\\validation\\source.wav",
diff --git a/bin/train/inputs/learning/default.json b/bin/train/inputs/learning/default.json
@@ -6,12 +6,10 @@
"drop_last": true,
"num_workers": 0
},
- "val_dataloader": {
- },
+ "val_dataloader": {},
"trainer": {
"gpus": 1,
- "max_epochs": 1000
+ "max_epochs": 100
},
- "trainer_fit_kwargs": {
- }
+ "trainer_fit_kwargs": {}
}
\ No newline at end of file
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -2,7 +2,7 @@
# Created Date: Saturday February 13th 2021
# Author: Steven Atkinson ([email protected])
-name: nam-test
+name: nam
channels:
- pytorch
- nvidia
@@ -24,7 +24,7 @@ dependencies:
- wheel
- pip:
- onnx
- - onnxruntime # TODO GPU...
+ - onnxruntime # TODO GPU...
- pre-commit
- pytorch_lightning
- sounddevice
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -0,0 +1,24 @@
+# File: _version.py
+# Created Date: Tuesday December 20th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Version utility
+"""
+
+class Version:
+ def __init__(self, major: int, minor: int, patch: int):
+ self.major = major
+ self.minor = minor
+ self.patch = patch
+
+ def __lt__(self, other) -> bool:
+ if self.major != other.major:
+ return self.major < other.major
+ if self.minor != other.minor:
+ return self.minor < other.minor
+ if self.patch != other.patch:
+            return self.patch < other.patch
+        return False  # versions are equal
+
+ def __str__(self) -> str:
+ return f"{self.major}.{self.minor}.{self.patch}"
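Editor's note: a quick, hypothetical usage sketch of the extracted class (not part of the diff); comparison is lexicographic over major/minor/patch:

```python
from nam.train._version import Version

v_old, v_new = Version(1, 0, 0), Version(1, 1, 1)
assert v_old < v_new        # same major; minor breaks the tie
assert not (v_new < v_old)
assert not (v_old < v_old)  # equal versions are not less-than
print(v_new)                # prints "1.1.1"
```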
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -6,54 +6,24 @@
Hide the mess in Colab to make things look pretty for users.
"""
-from enum import Enum
+
from pathlib import Path
-from time import time
from typing import Optional, Tuple
-import matplotlib.pyplot as plt
-import numpy as np
-import pytorch_lightning as pl
-import torch
-from torch.utils.data import DataLoader
-
-from nam.data import REQUIRED_RATE, Split, init_dataset, wav_to_np
-from nam.models import Model
-
-
-class _Architecture(Enum):
- STANDARD = "standard"
- LITE = "lite"
- FEATHER = "feather"
-
-
-class _Version:
- def __init__(self, major: int, minor: int, patch: int):
- self.major = major
- self.minor = minor
- self.patch = patch
+from ._version import Version
+from .core import train
- def __lt__(self, other) -> bool:
- if self.major != other.major:
- return self.major < other.major
- if self.minor != other.minor:
- return self.minor < other.minor
- if self.patch != other.patch:
- return self.patch < other.patch
- def __str__(self) -> str:
- return f"{self.major}.{self.minor}.{self.patch}"
-
-
-_INPUT_BASENAMES = ((_Version(1, 1, 1), "v1_1_1.wav"), (_Version(1, 0, 0), "v1.wav"))
+_INPUT_BASENAMES = ((Version(1, 1, 1), "v1_1_1.wav"), (Version(1, 0, 0), "v1.wav"))
_BUGGY_INPUT_BASENAMES = {
# 1.1.0 has the spikes at the wrong spots.
"v1_1_0.wav"
}
_OUTPUT_BASENAME = "output.wav"
+_TRAIN_PATH = "."
-def _check_for_files() -> Tuple[_Version, str]:
+def _check_for_files() -> Tuple[Version, str]:
print("Checking that we have all of the required audio files...")
for name in _BUGGY_INPUT_BASENAMES:
if Path(name).exists():
@@ -79,239 +49,6 @@ def _check_for_files() -> Tuple[_Version, str]:
return input_version, input_basename
-def _calibrate_delay_v1() -> int:
- safety_factor = 4
- # Locations of blips in v1 signal file:
- i1, i2 = 12_000, 36_000
- j1_start_looking = i1 - 1_000
- j2_start_looking = i2 - 1_000
-
- y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
-
- background_level = np.max(np.abs(y[:6_000]))
- trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
- j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][
- 0
- ]
- j2 = np.where(np.abs(y[j2_start_looking:]) > trigger_threshold)[0][0]
-
- delay_1 = (j1 + j1_start_looking) - i1
- delay_2 = (j2 + j2_start_looking) - i2
- print(f"Delays: {delay_1}, {delay_2}")
- delay = int(np.min([delay_1, delay_2])) - safety_factor
- print(f"Final delay is {delay}")
- return delay
-
-
-def _plot_delay_v1(delay: int, input_basename: str):
- print("Plotting the delay for manual inspection...")
- x = wav_to_np(input_basename)[:48_000]
- y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
- i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly
- di = 20
- plt.figure()
- # plt.plot(x[i - di : i + di], ".-", label="Input")
- plt.plot(
- np.arange(-di, di),
- y[i - di + delay : i + di + delay],
- ".-",
- label="Output",
- )
- plt.axvline(x=0, linestyle="--", color="C1")
- plt.legend()
- plt.show() # This doesn't freeze the notebook
-
-
-def _calibrate_delay(
- delay: Optional[int], input_version: _Version, input_basename: str
-) -> int:
- if input_version.major == 1:
- calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
- else:
- raise NotImplementedError(
- f"Input calibration not implemented for input version {input_version}"
- )
- if delay is not None:
- print(f"Delay is specified as {delay}")
- else:
- print("Delay wasn't provided; attempting to calibrate automatically...")
- delay = calibrate()
- plot(delay, input_basename)
- return delay
-
-
-def _get_wavenet_config(architecture):
- return {
- _Architecture.STANDARD: {
- "layers_configs": [
- {
- "input_size": 1,
- "condition_size": 1,
- "channels": 16,
- "head_size": 8,
- "kernel_size": 3,
- "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
- "activation": "Tanh",
- "gated": False,
- "head_bias": False,
- },
- {
- "condition_size": 1,
- "input_size": 16,
- "channels": 8,
- "head_size": 1,
- "kernel_size": 3,
- "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
- "activation": "Tanh",
- "gated": False,
- "head_bias": True,
- },
- ],
- "head_scale": 0.02,
- },
- _Architecture.LITE: {
- "layers_configs": [
- {
- "input_size": 1,
- "condition_size": 1,
- "channels": 12,
- "head_size": 6,
- "kernel_size": 3,
- "dilations": [1, 2, 4, 8, 16, 32, 64],
- "activation": "Tanh",
- "gated": False,
- "head_bias": False,
- },
- {
- "condition_size": 1,
- "input_size": 12,
- "channels": 6,
- "head_size": 1,
- "kernel_size": 3,
- "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
- "activation": "Tanh",
- "gated": False,
- "head_bias": True,
- },
- ],
- "head_scale": 0.02,
- },
- _Architecture.FEATHER: {
- "layers_configs": [
- {
- "input_size": 1,
- "condition_size": 1,
- "channels": 8,
- "head_size": 4,
- "kernel_size": 3,
- "dilations": [1, 2, 4, 8, 16, 32, 64],
- "activation": "Tanh",
- "gated": False,
- "head_bias": False,
- },
- {
- "condition_size": 1,
- "input_size": 8,
- "channels": 4,
- "head_size": 1,
- "kernel_size": 3,
- "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
- "activation": "Tanh",
- "gated": False,
- "head_bias": True,
- },
- ],
- "head_scale": 0.02,
- },
- }[architecture]
-
-
-def _get_configs(
- input_basename: str,
- delay: int,
- epochs: int,
- architecture: _Architecture,
- lr: float,
- lr_decay: float,
-):
- val_seconds = 9
- train_val_split = -val_seconds * REQUIRED_RATE
- data_config = {
- "train": {"ny": 8192, "stop": train_val_split},
- "validation": {"ny": None, "start": train_val_split},
- "common": {
- "x_path": input_basename,
- "y_path": _OUTPUT_BASENAME,
- "delay": delay,
- },
- }
- model_config = {
- "net": {
- "name": "WaveNet",
- # This should do decently. If you really want a nice model, try turning up
- # "channels" in the first block and "input_size" in the second from 12 to 16.
- "config": _get_wavenet_config(architecture),
- },
- "loss": {"val_loss": "esr"},
- "optimizer": {"lr": lr},
- "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}},
- }
- learning_config = {
- "train_dataloader": {
- "batch_size": 16,
- "shuffle": True,
- "pin_memory": True,
- "drop_last": True,
- "num_workers": 0,
- },
- "val_dataloader": {},
- "trainer": {"accelerator": "gpu", "devices": 1, "max_epochs": epochs},
- }
- return data_config, model_config, learning_config
-
-
-def _esr(pred: torch.Tensor, target: torch.Tensor) -> float:
- return (
- torch.mean(torch.square(pred - target)).item()
- / torch.mean(torch.square(target)).item()
- )
-
-
-def _plot(
- model, ds, window_start: Optional[int] = None, window_end: Optional[int] = None
-):
- print("Plotting a comparison of your model with the target output...")
- with torch.no_grad():
- tx = len(ds.x) / 48_000
- print(f"Run (t={tx:.2f} sec)")
- t0 = time()
- output = model(ds.x).flatten().cpu().numpy()
- t1 = time()
- print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)")
-
- esr = _esr(torch.Tensor(output), ds.y)
- # Trying my best to put numbers to it...
- if esr < 0.01:
- esr_comment = "Great!"
- elif esr < 0.035:
- esr_comment = "Not bad!"
- elif esr < 0.1:
- esr_comment = "...This *might* sound ok!"
- elif esr < 0.3:
- esr_comment = "...This probably won't sound great :("
- else:
- esr_comment = "...Something seems to have gone wrong."
- print(f"Error-signal ratio = {esr:.3f}")
- print(esr_comment)
-
- plt.figure(figsize=(16, 5))
- plt.plot(output[window_start:window_end], label="Prediction")
- plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
- plt.title(f"ESR={esr:.3f}")
- plt.legend()
- plt.show()
-
-
def _get_valid_export_directory():
def get_path(version):
return Path("exported_models", f"version_{version}")
@@ -326,9 +63,9 @@ def run(
epochs: int = 100,
delay: Optional[int] = None,
architecture: str = "standard",
- lr=0.004,
- lr_decay=0.007,
- seed=0,
+    lr: float = 0.004,
+    lr_decay: float = 0.007,
+    seed: Optional[int] = 0,
):
"""
    :param epochs: How many epochs we'll train for.
@@ -340,56 +77,20 @@ def run(
:param lr_decay: The amount by which the learning rate decays each epoch
:param seed: RNG seed for reproducibility.
"""
- torch.manual_seed(seed)
- input_version, input_basename = _check_for_files()
- delay = _calibrate_delay(delay, input_version, input_basename)
- data_config, model_config, learning_config = _get_configs(
- input_basename,
- delay,
- epochs,
- _Architecture(architecture),
- lr,
- lr_decay,
- )
-
- print("Starting training. Let's rock!")
- model = Model.init_from_config(model_config)
- data_config["common"]["nx"] = model.net.receptive_field
- dataset_train = init_dataset(data_config, Split.TRAIN)
- dataset_validation = init_dataset(data_config, Split.VALIDATION)
- train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
- val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
- trainer = pl.Trainer(
- callbacks=[
- pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4f}_{MSE:.3e}",
- save_top_k=3,
- monitor="val_loss",
- every_n_epochs=1,
- ),
- pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
- ),
- ],
- **learning_config["trainer"],
- )
- trainer.fit(model, train_dataloader, val_dataloader)
-
- # Go to best checkpoint
- best_checkpoint = trainer.checkpoint_callback.best_model_path
- if best_checkpoint != "":
- model = Model.load_from_checkpoint(
- trainer.checkpoint_callback.best_model_path,
- **Model.parse_config(model_config),
- )
- model.eval()
+ input_version, input_basename = _check_for_files()
- _plot(
- model,
- dataset_validation,
- window_start=100_000, # Start of the plotting window, in samples
- window_end=101_000, # End of the plotting window, in samples
+ model = train(
+ input_basename,
+ _OUTPUT_BASENAME,
+ _TRAIN_PATH,
+ input_version=input_version,
+ epochs=epochs,
+ delay=delay,
+ architecture=architecture,
+ lr=lr,
+ lr_decay=lr_decay,
+ seed=seed,
)
print("Exporting your model...")
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -0,0 +1,371 @@
+# File: core.py
+# Created Date: Tuesday December 20th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Functions used by the GUI trainer.
+"""
+
+import hashlib
+from enum import Enum
+from time import time
+from typing import Optional, Union
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch.utils.data import DataLoader
+
+from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np
+from ..models import Model
+from ._version import Version
+
+
+class Architecture(Enum):
+ STANDARD = "standard"
+ LITE = "lite"
+ FEATHER = "feather"
+
+
+def _detect_input_version(input_path) -> Version:
+ """
+ Check to see if the input matches any of the known inputs
+ """
+ md5 = hashlib.md5()
+ buffer_size = 65536
+ with open(input_path, "rb") as f:
+ while True:
+ data = f.read(buffer_size)
+ if not data:
+ break
+ md5.update(data)
+ file_hash = md5.hexdigest()
+
+ version = {
+ "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
+ "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
+ }.get(file_hash)
+ if version is None:
+ raise RuntimeError(
+ f"Provided input file {input_path} does not match any known standard input "
+ "files."
+ )
+ return version
+
+
+def _calibrate_delay_v1(input_path, output_path) -> int:
+ safety_factor = 4
+ # Locations of blips in v1 signal file:
+ i1, i2 = 12_000, 36_000
+ j1_start_looking = i1 - 1_000
+ j2_start_looking = i2 - 1_000
+
+ y = wav_to_np(output_path)[:48_000]
+
+ background_level = np.max(np.abs(y[:6_000]))
+ trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
+ j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][
+ 0
+ ]
+ j2 = np.where(np.abs(y[j2_start_looking:]) > trigger_threshold)[0][0]
+
+ delay_1 = (j1 + j1_start_looking) - i1
+ delay_2 = (j2 + j2_start_looking) - i2
+ print(f"Delays: {delay_1}, {delay_2}")
+ delay = int(np.min([delay_1, delay_2])) - safety_factor
+ print(f"Final delay is {delay}")
+ return delay
+
+
+def _plot_delay_v1(delay: int, input_path: str, output_path: str):
+ print("Plotting the delay for manual inspection...")
+ x = wav_to_np(input_path)[:48_000]
+ y = wav_to_np(output_path)[:48_000]
+ i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly
+ di = 20
+ plt.figure()
+ # plt.plot(x[i - di : i + di], ".-", label="Input")
+ plt.plot(
+ np.arange(-di, di),
+ y[i - di + delay : i + di + delay],
+ ".-",
+ label="Output",
+ )
+ plt.axvline(x=0, linestyle="--", color="C1")
+ plt.legend()
+ plt.show() # This doesn't freeze the notebook
+
+
+def _calibrate_delay(
+ delay: Optional[int], input_version: Version, input_path: str, output_path: str,
+) -> int:
+ if input_version.major == 1:
+ calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
+ else:
+ raise NotImplementedError(
+ f"Input calibration not implemented for input version {input_version}"
+ )
+ if delay is not None:
+ print(f"Delay is specified as {delay}")
+ else:
+ print("Delay wasn't provided; attempting to calibrate automatically...")
+ delay = calibrate(input_path, output_path)
+ plot(delay, input_path, output_path)
+ return delay
+
+
+def _get_wavenet_config(architecture):
+ return {
+ Architecture.STANDARD: {
+ "layers_configs": [
+ {
+ "input_size": 1,
+ "condition_size": 1,
+ "channels": 16,
+ "head_size": 8,
+ "kernel_size": 3,
+ "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": False,
+ },
+ {
+ "condition_size": 1,
+ "input_size": 16,
+ "channels": 8,
+ "head_size": 1,
+ "kernel_size": 3,
+ "dilations": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": True,
+ },
+ ],
+ "head_scale": 0.02,
+ },
+ Architecture.LITE: {
+ "layers_configs": [
+ {
+ "input_size": 1,
+ "condition_size": 1,
+ "channels": 12,
+ "head_size": 6,
+ "kernel_size": 3,
+ "dilations": [1, 2, 4, 8, 16, 32, 64],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": False,
+ },
+ {
+ "condition_size": 1,
+ "input_size": 12,
+ "channels": 6,
+ "head_size": 1,
+ "kernel_size": 3,
+ "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": True,
+ },
+ ],
+ "head_scale": 0.02,
+ },
+ Architecture.FEATHER: {
+ "layers_configs": [
+ {
+ "input_size": 1,
+ "condition_size": 1,
+ "channels": 8,
+ "head_size": 4,
+ "kernel_size": 3,
+ "dilations": [1, 2, 4, 8, 16, 32, 64],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": False,
+ },
+ {
+ "condition_size": 1,
+ "input_size": 8,
+ "channels": 4,
+ "head_size": 1,
+ "kernel_size": 3,
+ "dilations": [128, 256, 512, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
+ "activation": "Tanh",
+ "gated": False,
+ "head_bias": True,
+ },
+ ],
+ "head_scale": 0.02,
+ },
+ }[architecture]
+
+
+def _get_configs(
+ input_basename: str,
+ output_basename: str,
+ delay: int,
+ epochs: int,
+ architecture: Architecture,
+ lr: float,
+ lr_decay: float,
+):
+ val_seconds = 9
+ train_val_split = -val_seconds * REQUIRED_RATE
+ data_config = {
+ "train": {"ny": 8192, "stop": train_val_split},
+ "validation": {"ny": None, "start": train_val_split},
+ "common": {
+ "x_path": input_basename,
+ "y_path": output_basename,
+ "delay": delay,
+ },
+ }
+ model_config = {
+ "net": {
+ "name": "WaveNet",
+ # This should do decently. If you really want a nice model, try turning up
+ # "channels" in the first block and "input_size" in the second from 12 to 16.
+ "config": _get_wavenet_config(architecture)
+ },
+ "loss": {"val_loss": "esr"},
+ "optimizer": {"lr": lr},
+ "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}},
+ }
+ if torch.cuda.is_available():
+ device_config = {"accelerator": "gpu", "devices": 1}
+ else:
+ print("WARNING: No GPU was found. Training will be very slow!")
+ device_config = {}
+ learning_config = {
+ "train_dataloader": {
+ "batch_size": 16,
+ "shuffle": True,
+ "pin_memory": True,
+ "drop_last": True,
+ "num_workers": 0,
+ },
+ "val_dataloader": {},
+ "trainer": {"max_epochs": epochs, **device_config},
+ }
+ return data_config, model_config, learning_config
+
+
+def _esr(pred: torch.Tensor, target: torch.Tensor) -> float:
+ return (
+ torch.mean(torch.square(pred - target)).item()
+ / torch.mean(torch.square(target)).item()
+ )
+
+
+def _plot(
+ model, ds, window_start: Optional[int] = None, window_end: Optional[int] = None
+):
+ print("Plotting a comparison of your model with the target output...")
+ with torch.no_grad():
+ tx = len(ds.x) / 48_000
+ print(f"Run (t={tx:.2f} sec)")
+ t0 = time()
+ output = model(ds.x).flatten().cpu().numpy()
+ t1 = time()
+ print(f"Took {t1 - t0:.2f} sec ({tx / (t1 - t0):.2f}x)")
+
+ esr = _esr(torch.Tensor(output), ds.y)
+ # Trying my best to put numbers to it...
+ if esr < 0.01:
+ esr_comment = "Great!"
+ elif esr < 0.035:
+ esr_comment = "Not bad!"
+ elif esr < 0.1:
+ esr_comment = "...This *might* sound ok!"
+ elif esr < 0.3:
+ esr_comment = "...This probably won't sound great :("
+ else:
+ esr_comment = "...Something seems to have gone wrong."
+ print(f"Error-signal ratio = {esr:.3f}")
+ print(esr_comment)
+
+ plt.figure(figsize=(16, 5))
+ plt.plot(output[window_start:window_end], label="Prediction")
+ plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
+ plt.title(f"ESR={esr:.3f}")
+ plt.legend()
+ plt.show()
+
+
+def train(
+ input_path: str,
+ output_path: str,
+ train_path: str,
+ input_version: Optional[Version] = None,
+ epochs=100,
+ delay=None,
+    architecture: Union[Architecture, str] = Architecture.STANDARD,
+ lr=0.004,
+ lr_decay=0.007,
+ seed: Optional[int] = 0,
+):
+ if seed is not None:
+ torch.manual_seed(seed)
+
+ # This needs more thought...
+ # 1. Does the user want me to calibrate the delay?
+ # 2. Does the user want to see what the chosen (by them or me) delay looks like?
+ if delay is None:
+ if input_version is None:
+ input_version = _detect_input_version(input_path)
+ delay = _calibrate_delay(delay, input_version, input_path, output_path)
+ else:
+ print(f"Delay provided as {delay}; skip calibration")
+
+ data_config, model_config, learning_config = _get_configs(
+ input_path,
+ output_path,
+ delay,
+ epochs,
+ Architecture(architecture),
+ lr,
+ lr_decay,
+ )
+
+ print("Starting training. Let's rock!")
+ model = Model.init_from_config(model_config)
+ data_config["common"]["nx"] = model.net.receptive_field
+ dataset_train = init_dataset(data_config, Split.TRAIN)
+ dataset_validation = init_dataset(data_config, Split.VALIDATION)
+ train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
+ val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+
+ trainer = pl.Trainer(
+ callbacks=[
+ pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4f}_{MSE:.3e}",
+ save_top_k=3,
+ monitor="val_loss",
+ every_n_epochs=1,
+ ),
+ pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
+ ),
+ ],
+ default_root_dir=train_path,
+ **learning_config["trainer"],
+ )
+ trainer.fit(model, train_dataloader, val_dataloader)
+
+ # Go to best checkpoint
+ best_checkpoint = trainer.checkpoint_callback.best_model_path
+ if best_checkpoint != "":
+ model = Model.load_from_checkpoint(
+ trainer.checkpoint_callback.best_model_path,
+ **Model.parse_config(model_config),
+ )
+ model.eval()
+
+ _plot(
+ model,
+ dataset_validation,
+ window_start=100_000, # Start of the plotting window, in samples
+ window_end=101_000, # End of the plotting window, in samples
+ )
+ return model
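Editor's note: `core.train` is now the single engine that both the Colab and GUI front ends call into. A minimal sketch of driving it directly (paths are placeholders; the export call mirrors what `gui.py` below does after training):

```python
from nam.train.core import Architecture, train

model = train(
    "v1_1_1.wav",  # standard reamp input; version is auto-detected by MD5
    "output.wav",  # your recorded amp response
    ".",           # default_root_dir for Lightning logs and checkpoints
    epochs=100,
    architecture=Architecture.STANDARD,
)
model.net.export("exported_models/MyAmp")
```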
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -0,0 +1,457 @@
+# File: gui.py
+# Created Date: Saturday February 25th 2023
+# Author: Steven Atkinson ([email protected])
+
+"""
+GUI for training
+
+Usage:
+>>> import nam.train.gui
+"""
+
+import tkinter as tk
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from tkinter import filedialog
+from typing import Callable, Optional, Sequence
+
+try:
+ from .. import __version__
+ from . import core
+
+ _install_is_valid = True
+except ImportError:
+ _install_is_valid = False
+
+_BUTTON_WIDTH = 20
+_BUTTON_HEIGHT = 2
+_TEXT_WIDTH = 70
+
+_DEFAULT_NUM_EPOCHS = 100
+_DEFAULT_DELAY = None
+
+
+@dataclass
+class _AdvancedOptions(object):
+ architecture: core.Architecture
+ num_epochs: int
+ delay: Optional[int]
+
+
+class _PathType(Enum):
+ FILE = "file"
+ DIRECTORY = "directory"
+
+
+class _PathButton(object):
+ """
+ Button and the path
+ """
+
+ def __init__(
+ self,
+ frame: tk.Frame,
+ button_text,
+ info_str: str,
+ path_type: _PathType,
+ hooks: Optional[Sequence[Callable[[], None]]] = None,
+ ):
+ self._info_str = info_str
+ self._path: Optional[Path] = None
+ self._path_type = path_type
+ self._button = tk.Button(
+ frame,
+ text=button_text,
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._set_val,
+ )
+ self._button.pack(side=tk.LEFT)
+ self._label = tk.Label(
+ frame,
+ width=_TEXT_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ bg=None,
+ anchor="w",
+ )
+ self._label.pack(side=tk.RIGHT)
+ self._hooks = hooks
+ self._set_text()
+
+ @property
+ def val(self) -> Optional[Path]:
+ return self._path
+
+ def _set_text(self):
+ if self._path is None:
+ self._label["fg"] = "red"
+ self._label["text"] = f"{self._info_str} is not set!"
+ else:
+ self._label["fg"] = "black"
+ self._label["text"] = f"{self._info_str} set to {self.val}"
+
+ def _set_val(self):
+ res = {
+ _PathType.FILE: filedialog.askopenfilename,
+ _PathType.DIRECTORY: filedialog.askdirectory,
+ }[self._path_type]()
+ if res != "":
+ self._path = res
+ self._set_text()
+
+ if self._hooks is not None:
+ for h in self._hooks:
+ h()
+
+
+class _GUI(object):
+ def __init__(self):
+ self._root = tk.Tk()
+ self._root.title(f"NAM Trainer - v{__version__}")
+
+ # Buttons for paths:
+ self._frame_input_path = tk.Frame(self._root)
+ self._frame_input_path.pack()
+ self._path_button_input = _PathButton(
+ self._frame_input_path,
+ "Input Audio",
+ "Input audio",
+ _PathType.FILE,
+ hooks=[self._check_button_states],
+ )
+
+ self._frame_output_path = tk.Frame(self._root)
+ self._frame_output_path.pack()
+ self._path_button_output = _PathButton(
+ self._frame_output_path,
+ "Output Audio",
+ "Output audio",
+ _PathType.FILE,
+ hooks=[self._check_button_states],
+ )
+
+ self._frame_train_destination = tk.Frame(self._root)
+ self._frame_train_destination.pack()
+ self._path_button_train_destination = _PathButton(
+ self._frame_train_destination,
+ "Train Destination",
+ "Train destination",
+ _PathType.DIRECTORY,
+ hooks=[self._check_button_states],
+ )
+
+ # Advanced options for training
+ default_architecture = core.Architecture.STANDARD
+ self.advanced_options = _AdvancedOptions(
+ default_architecture, _DEFAULT_NUM_EPOCHS, _DEFAULT_DELAY
+ )
+ # Window to edit them:
+ self._frame_advanced_options = tk.Frame(self._root)
+ self._frame_advanced_options.pack()
+ self._button_advanced_options = tk.Button(
+ self._frame_advanced_options,
+ text="Advanced options...",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._open_advanced_options,
+ )
+ self._button_advanced_options.pack()
+
+ # Train button
+ self._frame_train = tk.Frame(self._root)
+ self._frame_train.pack()
+ self._button_train = tk.Button(
+ self._frame_train,
+ text="Train",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._train,
+ )
+ self._button_train.pack()
+
+ self._check_button_states()
+
+ def mainloop(self):
+ self._root.mainloop()
+
+ def _open_advanced_options(self):
+ """
+ Open advanced options
+ """
+ ao = _AdvancedOptionsGUI(self)
+ # I should probably disable the main GUI...
+ ao.mainloop()
+ # ...and then re-enable it once it gets closed.
+
+ def _train(self):
+ # Advanced options:
+ num_epochs = self.advanced_options.num_epochs
+ architecture = self.advanced_options.architecture
+ delay = self.advanced_options.delay
+
+ # Advanced-er options
+ # If you're poking around looking for these, then maybe it's time to learn to
+ # use the command-line scripts ;)
+ lr = 0.004
+ lr_decay = 0.007
+ seed = 0
+
+ # Run it
+ trained_model = core.train(
+ self._path_button_input.val,
+ self._path_button_output.val,
+ self._path_button_train_destination.val,
+ epochs=num_epochs,
+ delay=delay,
+ architecture=architecture,
+ lr=lr,
+ lr_decay=lr_decay,
+ seed=seed,
+ )
+ print("Model training complete!")
+ print("Exporting...")
+ outdir = self._path_button_train_destination.val
+ print(f"Exporting trained model to {outdir}...")
+ trained_model.net.export(outdir)
+ print("Done!")
+
+ def _check_button_states(self):
+ """
+ Determine if any buttons should be disabled
+ """
+        # Train button is disabled unless all paths are set
+ if any(
+ pb.val is None
+ for pb in (
+ self._path_button_input,
+ self._path_button_output,
+ self._path_button_train_destination,
+ )
+ ):
+ self._button_train["state"] = tk.DISABLED
+ return
+ self._button_train["state"] = tk.NORMAL
+
+
+_ADVANCED_OPTIONS_LEFT_WIDTH = 12
+_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
+
+
+class _LabeledOptionMenu(object):
+ """
+    Label (left) and option menu (right)
+ """
+
+ def __init__(
+ self, frame: tk.Frame, label: str, choices: Enum, default: Optional[Enum] = None
+ ):
+ """
+ :param command: Called to propagate option selection. Is provided with the
+ value corresponding to the radio button selected.
+ """
+ self._frame = frame
+ self._choices = choices
+ height = _BUTTON_HEIGHT
+ bg = None
+ fg = "black"
+ self._label = tk.Label(
+ frame,
+ width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+ height=height,
+ fg=fg,
+ bg=bg,
+ anchor="w",
+ text=label,
+ )
+ self._label.pack(side=tk.LEFT)
+
+ frame_menu = tk.Frame(frame)
+ frame_menu.pack(side=tk.RIGHT)
+
+ self._selected_value = None
+ default = (list(choices)[0] if default is None else default).value
+ self._menu = tk.OptionMenu(
+ frame_menu,
+ tk.StringVar(master=frame, value=default, name=label),
+ # default,
+ *[choice.value for choice in choices], # if choice.value!=default],
+ command=self._set,
+ )
+ self._menu.config(width=_ADVANCED_OPTIONS_RIGHT_WIDTH)
+ self._menu.pack(side=tk.RIGHT)
+ # Initialize
+ self._set(default)
+
+ def get(self) -> Enum:
+ return self._selected_value
+
+ def _set(self, val: str):
+ """
+ Set the value selected
+ """
+ self._selected_value = self._choices(val)
+
+
+class _LabeledText(object):
+ """
+ Label (left) and text input (right)
+ """
+
+ def __init__(self, frame: tk.Frame, label: str, default=None, type=None):
+ """
+ :param command: Called to propagate option selection. Is provided with the
+ value corresponding to the radio button selected.
+ :param type: If provided, casts value to given type
+ """
+ self._frame = frame
+ label_height = 2
+ text_height = 1
+ self._label = tk.Label(
+ frame,
+ width=_ADVANCED_OPTIONS_LEFT_WIDTH,
+ height=label_height,
+ fg="black",
+ bg=None,
+ anchor="w",
+ text=label,
+ )
+ self._label.pack(side=tk.LEFT)
+
+ self._text = tk.Text(
+ frame,
+ width=_ADVANCED_OPTIONS_RIGHT_WIDTH,
+ height=text_height,
+ fg="black",
+ bg=None,
+ )
+ self._text.pack(side=tk.RIGHT)
+
+ self._type = type
+
+ if default is not None:
+ self._text.insert("1.0", str(default))
+
+ def get(self):
+ try:
+ val = self._text.get("1.0", tk.END) # Line 1, character zero (wat)
+ if self._type is not None:
+ val = self._type(val)
+ return val
+ except tk.TclError:
+ return None
+
+
+class _AdvancedOptionsGUI(object):
+ """
+ A window to hold advanced options (Architecture and number of epochs)
+ """
+
+ def __init__(self, parent: _GUI):
+ self._parent = parent
+ self._root = tk.Tk()
+ self._root.title("Advanced Options")
+
+ # Architecture: radio buttons
+ self._frame_architecture = tk.Frame(self._root)
+ self._frame_architecture.pack()
+ self._architecture = _LabeledOptionMenu(
+ self._frame_architecture,
+ "Architecture",
+ core.Architecture,
+ default=self._parent.advanced_options.architecture,
+ )
+
+ # Number of epochs: text box
+ self._frame_epochs = tk.Frame(self._root)
+ self._frame_epochs.pack()
+
+ def non_negative_int(val):
+ val = int(val)
+ if val < 0:
+ val = 0
+ return val
+
+ self._epochs = _LabeledText(
+ self._frame_epochs,
+ "Epochs",
+ default=str(self._parent.advanced_options.num_epochs),
+ type=non_negative_int,
+ )
+
+ # Delay: text box
+ self._frame_delay = tk.Frame(self._root)
+ self._frame_delay.pack()
+
+ def int_or_null(val):
+ val = val.rstrip()
+ if val == "null":
+ return val
+ return int(val)
+
+ def int_or_null_inv(val):
+ return "null" if val is None else str(val)
+
+ self._delay = _LabeledText(
+ self._frame_delay,
+ "Delay",
+ default=int_or_null_inv(self._parent.advanced_options.delay),
+ type=int_or_null,
+ )
+
+ # "Ok": apply and destory
+ self._frame_ok = tk.Frame(self._root)
+ self._frame_ok.pack()
+ self._button_ok = tk.Button(
+ self._frame_ok,
+ text="Ok",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ fg="black",
+ command=self._apply_and_destroy,
+ )
+ self._button_ok.pack()
+
+ def mainloop(self):
+ self._root.mainloop()
+
+ def _apply_and_destroy(self):
+ """
+ Set values to parent and destroy this object
+ """
+ self._parent.advanced_options.architecture = self._architecture.get()
+ epochs = self._epochs.get()
+ if epochs is not None:
+ self._parent.advanced_options.num_epochs = epochs
+ delay = self._delay.get()
+ # Value None is returned as "null" to disambiguate from non-set.
+ if delay is not None:
+ self._parent.advanced_options.delay = None if delay == "null" else delay
+ self._root.destroy()
+
+
+def _install_error():
+ window = tk.Tk()
+ window.title("ERROR")
+ label = tk.Label(
+ window,
+ width=45,
+ height=2,
+ text="The NAM training software has not been installed correctly.",
+ )
+ label.pack()
+ button = tk.Button(window, width=10, height=2, text="Quit", command=window.destroy)
+ button.pack()
+ window.mainloop()
+
+
+def run():
+ if _install_is_valid:
+ _gui = _GUI()
+ _gui.mainloop()
+ else:
+ _install_error()
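Editor's note: with the `setup.py` entry point below, this module is what the `nam` console command invokes; the equivalent from Python is simply:

```python
from nam.train import gui

gui.run()  # shows the install-error window instead if imports failed
```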
diff --git a/setup.py b/setup.py
@@ -31,4 +31,9 @@ setup(
url="https://github.com/sdatkinson/",
install_requires=requirements,
packages=find_packages(),
+ entry_points={
+ 'console_scripts': [
+ 'nam = nam.train.gui:run',
+ ]
+ }
)