neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 29d1fb038f892a227e27b461553ea0ce350d3eeb
parent 46560898db9b4617d17d9a9586231bd1d43769eb
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Tue, 16 May 2023 19:04:35 -0700

Support modeling cabs (#250)

* Update data.py

* Sample rate more flexible in wav_to_x
* Better error messages in ConcatDataset

* Black

* Docstrings on Dataset

* Amp and cab training in core

* Cab modeling option in GUI trainer

* Cab modeling in Colab

* Add ir.py, update core.py

Drop in the real IR fitter

* Remove cab modeling option. Leave the setting on.

* Remove IR-fitting code

* Remove ir.py

* Revert easy_colab.ipynb

* Remove _with_ir.py

* Remove WithIR code

* Revert hyper_net.py

* Clean up WithIR code

* Clean up WithIR code
Diffstat:
M nam/data.py | 31 +++++++++++++++++++++++++++----
M nam/models/_base.py | 2 +-
M nam/models/base.py | 2 +-
M nam/models/metadata.py | 1 +
M nam/train/_version.py | 6 +++++-
M nam/train/colab.py | 2 +-
M nam/train/core.py | 47 +++++++++++++++++++++++++++++++++--------------
M nam/train/gui.py | 26 +++++++++++++++-----------
8 files changed, 84 insertions(+), 33 deletions(-)

diff --git a/nam/data.py b/nam/data.py @@ -298,11 +298,21 @@ class Dataset(AbstractDataset, InitializableFromConfig): return self._ny @property - def x(self): + def x(self) -> torch.Tensor: + """ + The input audio data + + :return: (N,) + """ return self._x @property - def y(self): + def y(self) -> torch.Tensor: + """ + The output audio data + + :return: (N,) + """ return self._y @property @@ -578,6 +588,9 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig): return self.datasets[i][j] def __len__(self) -> int: + """ + How many data sets are in this data set + """ return sum(len(d) for d in self._datasets) @property @@ -622,8 +635,18 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig): j += 1 lookup[i] = (j, offset) offset += 1 - assert j == len(self.datasets) - 1 - assert offset == len(self.datasets[-1]) + # Assert that we got to the last data set + if j != len(self.datasets) - 1: + raise RuntimeError( + f"During lookup population, didn't get to the last dataset (index " + f"{len(self.datasets)-1}). Instead index ended at {j}." + ) + if offset != len(self.datasets[-1]): + raise RuntimeError( + "During lookup population, didn't end at the index of the last datum " + f"in the last dataset. Expected index {len(self.datasets[-1])}, got " + f"{offset} instead." + ) return lookup @classmethod diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -3,7 +3,7 @@ # Author: Steven Atkinson (steven@atkinson.mn) """ -The foundation of the model without the PyTorch Lightning attributes (losses, training +The foundation of the model without the PyTorch Lightning attributes (losses, training steps) """ diff --git a/nam/models/base.py b/nam/models/base.py @@ -220,7 +220,7 @@ class Model(pl.LightningModule, InitializableFromConfig): return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} def forward(self, *args, **kwargs): - return self.net(*args, **kwargs) + return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead. 
def _shared_step(self, batch) -> Tuple[torch.Tensor, torch.Tensor]: """ diff --git a/nam/models/metadata.py b/nam/models/metadata.py @@ -23,6 +23,7 @@ class GearType(Enum): PREAMP = "preamp" STUDIO = "studio" + # Note: if you change this enum, you need to update the options in easy_colab.ipynb! class ToneType(Enum): CLEAN = "clean" diff --git a/nam/train/_version.py b/nam/train/_version.py @@ -14,7 +14,11 @@ class Version: self.patch = patch def __eq__(self, other) -> bool: - return self.major == other.major and self.minor == other.minor and self.patch == other.patch + return ( + self.major == other.major + and self.minor == other.minor + and self.patch == other.patch + ) def __lt__(self, other) -> bool: if self.major != other.major: diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -86,7 +86,7 @@ def run( :param lr: The initial learning rate :param lr_decay: The amount by which the learning rate decays each epoch :param seed: RNG seed for reproducibility. - :param user_metadata: To include in the exported model + :param user_metadata: User-specified metadata to include in the .nam file. :param ignore_checks: Ignores the data quality checks and YOLOs it """ diff --git a/nam/train/core.py b/nam/train/core.py @@ -8,15 +8,15 @@ Functions used by the GUI trainer. 
import hashlib import tkinter as tk +from copy import deepcopy from enum import Enum from time import time -from typing import Optional, Sequence, Union +from typing import Dict, Optional, Sequence, Tuple, Union import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch -import torch.nn as nn from torch.utils.data import DataLoader from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor @@ -134,6 +134,8 @@ def _detect_input_version(input_path) -> Version: _V1_BLIP_LOCATIONS = 12_000, 36_000 _V2_START_BLIP_LOCATIONS = _V1_BLIP_LOCATIONS _V2_END_BLIP_LOCATIONS = -36_000, -12_000 +_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0001 +_DELAY_CALIBRATION_REL_THRESHOLD = 0.001 def _calibrate_delay_v1( @@ -146,7 +148,10 @@ def _calibrate_delay_v1( # Calibrate the trigger: y = wav_to_np(output_path)[:48_000] background_level = np.max(np.abs(y[:6_000])) - trigger_threshold = max(background_level + 0.01, 1.01 * background_level) + trigger_threshold = max( + background_level + _DELAY_CALIBRATION_ABS_THRESHOLD, + (1.0 + _DELAY_CALIBRATION_REL_THRESHOLD) * background_level, + ) delays = [] for blip_index, i in enumerate(locations, 1): @@ -186,7 +191,10 @@ def _calibrate_delay_v1( return delay -_calibrate_delay_v2 = _calibrate_delay_v1 +def _calibrate_delay_v2( + input_path, output_path, locations: Sequence[int] = _V2_START_BLIP_LOCATIONS +) -> int: + return _calibrate_delay_v1(input_path, output_path, locations=locations) def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True): @@ -477,8 +485,8 @@ def _get_wavenet_config(architecture): def _get_configs( input_version: Version, - input_basename: str, - output_basename: str, + input_path: str, + output_path: str, delay: int, epochs: int, model_type: str, @@ -512,8 +520,8 @@ def _get_configs( "train": {"ny": ny, **train_kwargs}, "validation": {"ny": None, **validation_kwargs}, "common": { - "x_path": input_basename, - "y_path": output_basename, + "x_path": 
input_path, + "y_path": output_path, "delay": delay, }, } @@ -548,6 +556,7 @@ def _get_configs( "optimizer": {"lr": 0.01}, "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}}, } + model_config["loss"]["mrstft_weight"] = 2e-4 if torch.cuda.is_available(): device_config = {"accelerator": "gpu", "devices": 1} @@ -570,6 +579,18 @@ def _get_configs( return data_config, model_config, learning_config +def _get_dataloaders( + data_config: Dict, learning_config: Dict, model: Model +) -> Tuple[DataLoader, DataLoader]: + data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)] + data_config["common"]["nx"] = model.net.receptive_field + dataset_train = init_dataset(data_config, Split.TRAIN) + dataset_validation = init_dataset(data_config, Split.VALIDATION) + train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + return train_dataloader, val_dataloader + + def _esr(pred: torch.Tensor, target: torch.Tensor) -> float: return ( torch.mean(torch.square(pred - target)).item() @@ -729,11 +750,9 @@ def train( print("Starting training. 
It's time to kick ass and chew bubblegum!") model = Model.init_from_config(model_config) - data_config["common"]["nx"] = model.net.receptive_field - dataset_train = init_dataset(data_config, Split.TRAIN) - dataset_validation = init_dataset(data_config, Split.VALIDATION) - train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) - val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + train_dataloader, val_dataloader = _get_dataloaders( + data_config, learning_config, model + ) trainer = pl.Trainer( callbacks=[ @@ -782,7 +801,7 @@ def train( _plot( model, - dataset_validation, + val_dataloader.dataset, filepath=train_path + "/" + modelname if save_plot else None, silent=silent, **window_kwargs(input_version), diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -257,37 +257,41 @@ class _GUI(object): def _get_additional_options_frame(self): # Checkboxes - self._frame_silent = tk.Frame(self._root) - self._frame_silent.pack(side=tk.LEFT) + self._frame_checkboxes = tk.Frame(self._root) + self._frame_checkboxes.pack(side=tk.LEFT) # Silent run (bypass popups) + row = 1 self._silent = tk.BooleanVar() - self._chkbox_silent = tk.Checkbutton( - self._frame_silent, + self._checkbox_silent = tk.Checkbutton( + self._frame_checkboxes, text="Silent run (suggested for batch training)", variable=self._silent, ) - self._chkbox_silent.grid(row=1, column=1, sticky="W") + self._checkbox_silent.grid(row=row, column=1, sticky="W") + row += 1 # Auto save the end plot self._save_plot = tk.BooleanVar() self._save_plot.set(True) # default this to true - self._chkbox_save_plot = tk.Checkbutton( - self._frame_silent, + self._checkbox_save_plot = tk.Checkbutton( + self._frame_checkboxes, text="Save ESR plot automatically", variable=self._save_plot, ) - self._chkbox_save_plot.grid(row=2, column=1, sticky="W") + self._checkbox_save_plot.grid(row=row, column=1, sticky="W") + row += 1 # Skip the data quality checks! 
self._ignore_checks = tk.BooleanVar() self._ignore_checks.set(False) - self._chkbox_ignore_checks = tk.Checkbutton( - self._frame_silent, + self._checkbox_ignore_checks = tk.Checkbutton( + self._frame_checkboxes, text="Ignore data quality checks (DO AT YOUR OWN RISK!)", variable=self._ignore_checks, ) - self._chkbox_ignore_checks.grid(row=3, column=1, sticky="W") + self._checkbox_ignore_checks.grid(row=row, column=1, sticky="W") + row += 1 def mainloop(self): self._root.mainloop()