neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 0270e365da1d0cdf45d9d5fd43887c7be6dcc80a
parent adee42ee3f7f9a4aad345c0830998fd8782a5d25
Author: Steven Atkinson <[email protected]>
Date:   Thu,  1 Jun 2023 08:46:52 -0700

Tweak cab modeling (#260)

* fit_cab option in Colab

* Refactor losses, add pre-emphasized MRSTFT

* Fit cab: use pre-emphasized MRSTFT only

* Print input version found

* Fix pre-emph

* Squash bugs

* More bugs

* Define private constants for cab modeling MRSTFT

* Always report ESR on validation

* Fix ESR validation

* Fix loss calculation

* Restore MRSTFT weight

* Update gui.py

Put back "Cab modeling" checkbox.
Refactor checkboxes code

* Minor version bump
Diffstat:
M nam/_version.py | 2 +-
M nam/models/base.py | 150 +++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------
M nam/models/losses.py | 14 +++++++++++++-
M nam/train/colab.py | 3 +++
M nam/train/core.py | 11 ++++++++++-
M nam/train/gui.py | 86 ++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------
M tests/test_nam/test_models/test_losses.py | 22 ++++++++++++++++++++++
7 files changed, 199 insertions(+), 89 deletions(-)
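
The "cab modeling" fit added here trains against a pre-emphasized multi-resolution STFT loss: both the prediction and the target are first run through the first-order pre-emphasis filter y[n] = x[n] - coef * x[n-1] (the new apply_pre_emphasis_filter in nam/models/losses.py), and auraloss's MultiResolutionSTFTLoss is then evaluated on the filtered signals. A minimal sketch of that combination, using made-up tensors and the cab-modeling coefficient 0.85 set in nam/train/core.py:

    import torch
    from auraloss.freq import MultiResolutionSTFTLoss

    def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
        # First-order pre-emphasis (a mild high-pass); output is one sample shorter.
        return x[..., 1:] - coef * x[..., :-1]

    # Illustrative signals only: (batch, channels, samples).
    preds = torch.randn(1, 1, 44100)
    targets = torch.randn(1, 1, 44100)

    # Pre-emphasize both signals, then compare their multi-resolution spectra.
    mrstft = MultiResolutionSTFTLoss()
    loss = mrstft(
        apply_pre_emphasis_filter(preds, 0.85),
        apply_pre_emphasis_filter(targets, 0.85),
    )

During training this term enters the total loss with a small weight (2e-4 here), alongside the plain MSE term.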

diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.5.3"
+__version__ = "0.6.0"
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -12,7 +12,7 @@ For the base *PyTorch* model containing the actual architecture, see `._base`.

 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional, Tuple
+from typing import Dict, NamedTuple, Optional, Tuple

 import auraloss
 import logging
@@ -23,7 +23,7 @@ import torch.nn as nn
 from .._core import InitializableFromConfig
 from .conv_net import ConvNet
 from .linear import Linear
-from .losses import esr, multi_resolution_stft_loss, mse_fft
+from .losses import apply_pre_emphasis_filter, esr, multi_resolution_stft_loss, mse_fft
 from .parametric.catnets import CatLSTM, CatWaveNet
 from .parametric.hyper_net import HyperConvNet
 from .recurrent import LSTM
@@ -51,39 +51,39 @@ class ValidationLoss(Enum):
 @dataclass
 class LossConfig(InitializableFromConfig):
     """
+    :param mrstft_weight: Multi-resolution short-time Fourier transform loss
+        coefficient. None means to skip; 2e-4 works pretty well if one wants to use it.
     :param mask_first: How many of the first samples to ignore when comptuing the loss.
     :param dc_weight: Weight for the DC loss term. If 0, ignored.
     :params val_loss: Which loss to track for the best model checkpoint.
     :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
         https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
+    :param pre_
     """

-    mrstft_weight: float = 0.0  # 0.0 means no multiresolution stft loss, 2e-4 works pretty well if one wants to use it
+    mrstft_weight: Optional[float] = None
     fourier: bool = False
     mask_first: int = 0
-    dc_weight: float = 0.0
+    dc_weight: float = None
     val_loss: ValidationLoss = ValidationLoss.MSE
     pre_emph_weight: Optional[float] = None
     pre_emph_coef: Optional[float] = None
+    pre_emph_mrstft_weight: Optional[float] = None
+    pre_emph_mrstft_coef: Optional[float] = None

     @classmethod
     def parse_config(cls, config):
         config = super().parse_config(config)
-        fourier = config.get("fourier", False)
-        dc_weight = config.get("dc_weight", 0.0)
-        val_loss = ValidationLoss(config.get("val_loss", "mse"))
-        mask_first = config.get("mask_first", 0)
-        pre_emph_coef = config.get("pre_emph_coef")
-        pre_emph_weight = config.get("pre_emph_weight")
-        mrstft_weight = cls._get_mrstft_weight(config)
         return {
-            "fourier": fourier,
-            "mask_first": mask_first,
-            "dc_weight": dc_weight,
-            "val_loss": val_loss,
-            "pre_emph_coef": pre_emph_coef,
-            "pre_emph_weight": pre_emph_weight,
-            "mrstft_weight": mrstft_weight,
+            "fourier": config.get("fourier", False),
+            "mask_first": config.get("mask_first", 0),
+            "dc_weight": config.get("dc_weight"),
+            "val_loss": ValidationLoss(config.get("val_loss", "mse")),
+            "pre_emph_coef": config.get("pre_emph_coef"),
+            "pre_emph_weight": config.get("pre_emph_weight"),
+            "mrstft_weight": cls._get_mrstft_weight(config),
+            "pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"),
+            "pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"),
         }

     def apply_mask(self, *args):
@@ -94,7 +94,7 @@ class LossConfig(InitializableFromConfig):
         return tuple(a[..., self.mask_first :] for a in args)

     @classmethod
-    def _get_mrstft_weight(cls, config) -> float:
+    def _get_mrstft_weight(cls, config) -> Optional[float]:
         key = "mrstft_weight"
         wrong_key = "mstft_key"  # Backward compatibility
         if key in config:
@@ -111,7 +111,12 @@ class LossConfig(InitializableFromConfig):
             )
             return config[wrong_key]
         else:
-            return 0.0
+            return None
+
+
+class _LossItem(NamedTuple):
+    weight: Optional[float]
+    value: Optional[torch.Tensor]


 class Model(pl.LightningModule, InitializableFromConfig):
@@ -222,7 +227,9 @@ class Model(pl.LightningModule, InitializableFromConfig):
     def forward(self, *args, **kwargs):
         return self.net(*args, **kwargs)  # TODO deprecate--use self.net() instead.

-    def _shared_step(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _shared_step(
+        self, batch
+    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
         """
         B: Batch size
         L: Sequence length
@@ -232,55 +239,85 @@ class Model(pl.LightningModule, InitializableFromConfig):
         args, targets = batch[:-1], batch[-1]
         preds = self(*args, pad_start=False)

-        return preds, targets
-
-    def training_step(self, batch, batch_idx):
-        preds, targets = self._shared_step(batch)
-
-        loss = 0.0
+        # Compute all relevant losses.
+        loss_dict = {}  # Mind keys versus validation loss requested...
         # Prediction aka MSE loss
         if self._loss_config.fourier:
-            loss = loss + mse_fft(preds, targets)
+            loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets))
         else:
-            loss = loss + self._mse_loss(preds, targets)
-        if self._loss_config.mrstft_weight > 0.0:
-            loss = loss + self._loss_config.mrstft_weight * self._mrstft_loss(
-                preds, targets
-            )
+            loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets))
         # Pre-emphasized MSE
         if self._loss_config.pre_emph_weight is not None:
             if (self._loss_config.pre_emph_coef is None) != (
                 self._loss_config.pre_emph_weight is None
             ):
                 raise ValueError("Invalid pre-emph")
-            loss = loss + self._loss_config.pre_emph_weight * self._mse_loss(
-                preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef
+            loss_dict["Pre-emphasized MSE"] = _LossItem(
+                self._loss_config.pre_emph_weight,
+                self._mse_loss(
+                    preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef
+                ),
+            )
+        # Multi-resolution short-time Fourier transform loss
+        if self._loss_config.mrstft_weight is not None:
+            loss_dict["MRSTFT"] = _LossItem(
+                self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets)
+            )
+        # Pre-emphasized MRSTFT
+        if self._loss_config.pre_emph_mrstft_weight is not None:
+            loss_dict["Pre-emphasized MRSTFT"] = _LossItem(
+                self._loss_config.pre_emph_mrstft_weight,
+                self._mrstft_loss(
+                    preds, targets, pre_emph_coef=self._loss_config.pre_emph_mrstft_coef
+                ),
             )
-
         # DC loss
         dc_weight = self._loss_config.dc_weight
-        if dc_weight > 0.0:
+        if dc_weight is not None and dc_weight > 0.0:
             # Denominator could be a bad idea. I'm going to omit it esp since I'm
             # using mini batches
             mean_dims = torch.arange(1, preds.ndim).tolist()
             dc_loss = nn.MSELoss()(
                 preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
             )
-            loss = loss + dc_weight * dc_loss
+            loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss)
+
+        return preds, targets, loss_dict
+
+    def training_step(self, batch, batch_idx):
+        _, _, loss_dict = self._shared_step(batch)
+
+        loss = 0.0
+        for v in loss_dict.values():
+            if v.weight is not None and v.weight > 0.0:
+                loss = loss + v.weight * v.value
         return loss

     def validation_step(self, batch, batch_idx):
-        preds, targets = self._shared_step(batch)
-        mse_loss = self._mse_loss(preds, targets)
-        esr_loss = self._esr_loss(preds, targets)
-        val_loss = {ValidationLoss.MSE: mse_loss, ValidationLoss.ESR: esr_loss}[
-            self._loss_config.val_loss
-        ]
-        dict_to_log = {"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss}
-        if self._loss_config.mrstft_weight > 0.0 and self._mrstft is not None:
-            mrstft_loss = self._mrstft_loss(preds, targets)
-            dict_to_log.update({"MRSTFT": mrstft_loss})
-        self.log_dict(dict_to_log)
+        preds, targets, loss_dict = self._shared_step(batch)
+
+        def get_val_loss():
+            # "esr" -> "ESR"
+            # "mse" -> "MSE"
+            # Others unsupported...
+            # TODO better mapping from Enum to dict keys
+            val_loss_type = self._loss_config.val_loss
+            val_loss_key_for_loss_dict = val_loss_type.value.upper()
+            if val_loss_key_for_loss_dict in loss_dict:
+                return loss_dict[val_loss_key_for_loss_dict].value
+            else:
+                raise RuntimeError(
+                    f"Undefined validation loss routine for {val_loss_type}"
+                )
+
+        loss_dict["ESR"] = _LossItem(None, self._esr_loss(preds, targets))
+        val_loss = get_val_loss()
+        self.log_dict(
+            {
+                "val_loss": val_loss,
+                **{key: value.value for key, value in loss_dict.items()},
+            }
+        )
         return val_loss

     def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -302,11 +339,16 @@ class Model(pl.LightningModule, InitializableFromConfig):
     def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
         if pre_emph_coef is not None:
             preds, targets = [
-                z[..., 1:] - pre_emph_coef * z[..., :-1] for z in (preds, targets)
+                apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
             ]
         return nn.MSELoss()(preds, targets)

-    def _mrstft_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+    def _mrstft_loss(
+        self,
+        preds: torch.Tensor,
+        targets: torch.Tensor,
+        pre_emph_coef: Optional[float] = None,
+    ) -> torch.Tensor:
         """
         Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
         B: Batch size
@@ -318,9 +360,13 @@ class Model(pl.LightningModule, InitializableFromConfig):
         """
         if self._mrstft is None:
             self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
-
         backup_device = "cpu"

+        if pre_emph_coef is not None:
+            preds, targets = [
+                apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+            ]
+
         try:
             return multi_resolution_stft_loss(
                 preds, targets, self._mrstft, device=self._mrstft_device
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -11,7 +11,19 @@ from typing import Optional
 import torch
 from auraloss.freq import MultiResolutionSTFTLoss

-___all__ = ["esr", "multi_resolution_stft_loss"]
+___all__ = ["apply_pre_emphasis_filter", "esr", "multi_resolution_stft_loss"]
+
+
+def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
+    """
+    Apply first-order pre-emphsis filter
+
+    :param x: (*, L)
+    :param coef: The coefficient
+
+    :return: (*, L-1)
+    """
+    return x[..., 1:] - coef * x[..., :-1]


 def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -53,6 +53,7 @@ def _check_for_files() -> Tuple[Version, str]:
         raise FileNotFoundError(
             f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
         )
+    print(f"Found {input_basename}, version {input_version}")
     return input_version, input_basename


@@ -76,6 +77,7 @@ def run(
     seed: Optional[int] = 0,
     user_metadata: Optional[UserMetadata] = None,
     ignore_checks: bool = False,
+    fit_cab: bool = False,
 ):
     """
     :param epochs: How amny epochs we'll train for.
@@ -106,6 +108,7 @@ def run(
         seed=seed,
         local=False,
         ignore_checks=ignore_checks,
+        fit_cab=fit_cab,
     )

     if model is None:
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -555,6 +555,10 @@ def _get_wavenet_config(architecture):
     }[architecture]


+_CAB_MRSTFT_PRE_EMPH_WEIGHT = 2.0e-4
+_CAB_MRSTFT_PRE_EMPH_COEF = 0.85
+
+
 def _get_configs(
     input_version: Version,
     input_path: str,
@@ -567,6 +571,7 @@ def _get_configs(
     lr: float,
     lr_decay: float,
     batch_size: int,
+    fit_cab: bool,
 ):
     def get_kwargs(data_info: _DataInfo):
         if data_info.major_version == 1:
@@ -625,7 +630,9 @@ def _get_configs(
         "optimizer": {"lr": 0.01},
         "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
     }
-    model_config["loss"]["mrstft_weight"] = 2e-4
+    if fit_cab:
+        model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
+        model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF

     if torch.cuda.is_available():
         device_config = {"accelerator": "gpu", "devices": 1}
@@ -771,6 +778,7 @@ def train(
     modelname: str = "model",
     ignore_checks: bool = False,
     local: bool = False,
+    fit_cab: bool = False,
 ) -> Optional[Model]:
     if seed is not None:
         torch.manual_seed(seed)
@@ -815,6 +823,7 @@ def train(
         lr,
         lr_decay,
         batch_size,
+        fit_cab,
     )

     print("Starting training. It's time to kick ass and chew bubblegum!")
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -32,7 +32,7 @@ from enum import Enum
 from functools import partial
 from pathlib import Path
 from tkinter import filedialog
-from typing import Callable, Optional, Sequence
+from typing import Callable, Dict, Optional, Sequence

 try:
     import torch
@@ -148,6 +148,17 @@ class _PathButton(object):
             h()


+class _CheckboxKeys(Enum):
+    """
+    Keys for checkboxes
+    """
+
+    FIT_CAB = "fit_cab"
+    SILENT_TRAINING = "silent_training"
+    SAVE_PLOT = "save_plot"
+    IGNORE_DATA_CHECKS = "ignore_data_checks"
+
+
 class _GUI(object):
     def __init__(self):
         self._root = tk.Tk()
@@ -257,41 +268,45 @@ class _GUI(object):

     def _get_additional_options_frame(self):
         # Checkboxes
+        # TODO get these definitions into __init__()
         self._frame_checkboxes = tk.Frame(self._root)
         self._frame_checkboxes.pack(side=tk.LEFT)
-
-        # Silent run (bypass popups)
         row = 1
-        self._silent = tk.BooleanVar()
-        self._checkbox_silent = tk.Checkbutton(
-            self._frame_checkboxes,
-            text="Silent run (suggested for batch training)",
-            variable=self._silent,
-        )
-        self._checkbox_silent.grid(row=row, column=1, sticky="W")
-        row += 1
-
-        # Auto save the end plot
-        self._save_plot = tk.BooleanVar()
-        self._save_plot.set(True)  # default this to true
-        self._checkbox_save_plot = tk.Checkbutton(
-            self._frame_checkboxes,
-            text="Save ESR plot automatically",
-            variable=self._save_plot,
+
+        @dataclass
+        class Checkbox(object):
+            variable: tk.BooleanVar
+            check_button: tk.Checkbutton
+
+        def make_checkbox(
+            key: _CheckboxKeys, text: str, default_value: bool
+        ) -> Checkbox:
+            variable = tk.BooleanVar()
+            variable.set(default_value)
+            check_button = tk.Checkbutton(
+                self._frame_checkboxes, text=text, variable=variable
+            )
+            self._checkboxes[key] = Checkbox(variable, check_button)
+
+        self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
+        make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
+        make_checkbox(
+            _CheckboxKeys.SILENT_TRAINING,
+            "Silent run (suggested for batch training)",
+            False,
         )
-        self._checkbox_save_plot.grid(row=row, column=1, sticky="W")
-        row += 1
-
-        # Skip the data quality checks!
-        self._ignore_checks = tk.BooleanVar()
-        self._ignore_checks.set(False)
-        self._checkbox_ignore_checks = tk.Checkbutton(
-            self._frame_checkboxes,
-            text="Ignore data quality checks (DO AT YOUR OWN RISK!)",
-            variable=self._ignore_checks,
+        make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True)
+        make_checkbox(
+            _CheckboxKeys.IGNORE_DATA_CHECKS,
+            "Ignore data quality checks (DO AT YOUR OWN RISK!)",
+            False,
         )
-        self._checkbox_ignore_checks.grid(row=row, column=1, sticky="W")
-        row += 1
+
+        # Grid them:
+        row = 1
+        for v in self._checkboxes.values():
+            v.check_button.grid(row=row, column=1, sticky="W")
+            row += 1

     def mainloop(self):
         self._root.mainloop()
@@ -344,11 +359,14 @@ class _GUI(object):
             lr=lr,
             lr_decay=lr_decay,
             seed=seed,
-            silent=self._silent.get(),
-            save_plot=self._save_plot.get(),
+            silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
+            save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
             modelname=basename,
-            ignore_checks=self._ignore_checks.get(),
+            ignore_checks=self._checkboxes[
+                _CheckboxKeys.IGNORE_DATA_CHECKS
+            ].variable.get(),
             local=True,
+            fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
         )
         if trained_model is None:
             print("Model training failed! Skip exporting...")
diff --git a/tests/test_nam/test_models/test_losses.py b/tests/test_nam/test_models/test_losses.py
@@ -9,6 +9,28 @@ import torch.nn as nn
 from nam.models import losses


[email protected](
+    "x,coef,y_expected",
+    (
+        (torch.Tensor([0.0, 1.0, 2.0]), 1.0, torch.Tensor([1.0, 1.0])),
+        (torch.Tensor([0.0, 1.0, 2.0]), 0.5, torch.Tensor([1.0, 1.5])),
+        (
+            torch.Tensor([[0.0, 1.0, 0.0], [1.0, 1.5, 2.0]]),
+            0.5,
+            torch.Tensor([[1.0, -0.5], [1.0, 1.25]]),
+        ),
+    ),
+)
+def test_apply_pre_emphasis_filter_1d(
+    x: torch.Tensor, coef: float, y_expected: torch.Tensor
+):
+    y_actual = losses.apply_pre_emphasis_filter(x, coef)
+    assert isinstance(y_actual, torch.Tensor)
+    assert y_actual.ndim == y_expected.ndim
+    assert y_actual.shape == y_expected.shape
+    assert torch.allclose(y_actual, y_expected)
+
+
 def test_esr():
     """
     Is the ESR calculation correct?
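
With this refactor, the fit_cab switch (the new Colab/core argument and the GUI's "Cab modeling" checkbox) does exactly one thing in _get_configs(): it sets the two pre-emphasized-MRSTFT keys in the loss config. A hand-written training config could enable the same behavior directly; a sketch of the relevant fragment, where every value except the two pre_emph_mrstft entries is only illustrative:

    # Loss section of a model config as consumed by LossConfig.parse_config().
    model_config = {
        # ... network / optimizer / lr_scheduler sections omitted ...
        "loss": {
            "val_loss": "esr",                 # or "mse"; tracked for the best checkpoint
            "mask_first": 4096,                # illustrative value
            "pre_emph_mrstft_weight": 2.0e-4,  # what fit_cab sets (_CAB_MRSTFT_PRE_EMPH_WEIGHT)
            "pre_emph_mrstft_coef": 0.85,      # what fit_cab sets (_CAB_MRSTFT_PRE_EMPH_COEF)
        },
    }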