neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 8fecd53188e05df328b6762f70fe0ca03350a8c5
parent 6a610ec9cb89f47d57b2d1d30779c21a642e649b
Author: Steven Atkinson <[email protected]>
Date:   Sun, 15 Sep 2024 18:30:38 -0700

[BREAKING] Remove cab-fitting checkbox from GUI trainer (#462)

* Remove cab-fitting option from GUI trainer.

It's always on now.
Rename fit_cab to fit_mrstft.
Remove fit_cab from Settings metadata since it's always on.

* colab.run: rename fit_cab kwarg to fit_mrstft, on by default
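
For downstream callers, the rename looks like this (an illustrative before/after sketch; the epochs value is made up, and other arguments are omitted):

    # Before this commit: the MRSTFT ("cab") loss term was opt-in.
    from nam.train import colab

    colab.run(
        epochs=100,
        fit_cab=True,  # keyword removed by this commit
    )

    # After this commit: on by default under the new name.
    colab.run(
        epochs=100,
        fit_mrstft=True,  # the default; pass False to disable the loss term
    )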
Diffstat:
M nam/train/colab.py        |  4 ++--
M nam/train/core.py         | 14 ++++++--------
M nam/train/gui/__init__.py | 16 +++++++++++++---
M nam/train/metadata.py     | 12 ------------
4 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -84,7 +84,7 @@ def run(
     seed: Optional[int] = 0,
     user_metadata: Optional[UserMetadata] = None,
     ignore_checks: bool = False,
-    fit_cab: bool = False,
+    fit_mrstft: bool = True,
 ):
     """
     :param epochs: How many epochs we'll train for.
@@ -115,7 +115,7 @@ def run(
         seed=seed,
         local=False,
         ignore_checks=ignore_checks,
-        fit_cab=fit_cab,
+        fit_mrstft=fit_mrstft,
     )
     model = train_output.model
     training_metadata = train_output.metadata
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -971,7 +971,7 @@ def _get_configs(
     lr: float,
     lr_decay: float,
     batch_size: int,
-    fit_cab: bool,
+    fit_mrstft: bool,
 ):
 
     data_config = _get_data_config(
@@ -1012,7 +1012,7 @@
         "optimizer": {"lr": 0.01},
         "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
     }
-    if fit_cab:
+    if fit_mrstft:
         model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
         model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
 
@@ -1295,7 +1295,7 @@
     modelname: str = "model",
     ignore_checks: bool = False,
     local: bool = False,
-    fit_cab: bool = False,
+    fit_mrstft: bool = True,
     threshold_esr: Optional[bool] = None,
     user_metadata: Optional[UserMetadata] = None,
     fast_dev_run: Union[bool, int] = False,
@@ -1351,9 +1351,7 @@
         return TrainOutput(
             model=None,
             metadata=metadata.TrainingMetadata(
-                settings=metadata.Settings(
-                    fit_cab=fit_cab, ignore_checks=ignore_checks
-                ),
+                settings=metadata.Settings(ignore_checks=ignore_checks),
                 data=metadata.Data(
                     latency=latency_analysis, checks=data_check_output
                 ),
@@ -1373,7 +1371,7 @@
         lr,
         lr_decay,
         batch_size,
-        fit_cab,
+        fit_mrstft,
     )
     assert (
         "fast_dev_run" not in learning_config
@@ -1399,7 +1397,7 @@
     model.net.sample_rate = sample_rate

     # Put together the metadata that's needed in checkpoints:
-    settings_metadata = metadata.Settings(fit_cab=fit_cab, ignore_checks=ignore_checks)
+    settings_metadata = metadata.Settings(ignore_checks=ignore_checks)
     data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output)

     trainer = pl.Trainer(
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -259,7 +259,6 @@ class _CheckboxKeys(Enum):
     Keys for checkboxes
     """

-    FIT_CAB = "fit_cab"
     SILENT_TRAINING = "silent_training"
     SAVE_PLOT = "save_plot"

@@ -484,6 +483,18 @@ class _GUI(object):

         self._check_button_states()

+    def get_mrstft_fit(self) -> bool:
+        """
+        Use a pre-emphasized multi-resolution short-time Fourier transform loss during
+        training.
+
+        This improves agreement in the high frequencies, usually with a minimal loss in
+        ESR.
+        """
+        # Leave this as a public method to anticipate an extension to make it
+        # changeable.
+        return True
+
     def _check_button_states(self):
         """
         Determine if any buttons should be disabled
@@ -525,7 +536,6 @@
             self._widgets[key] = check_button  # For tracking in set-all-widgets ops

         self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
-        make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
         make_checkbox(
             _CheckboxKeys.SILENT_TRAINING,
             "Silent run (suggested for batch training)",
@@ -616,7 +626,7 @@
                 modelname=basename,
                 ignore_checks=ignore_checks,
                 local=True,
-                fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+                fit_mrstft=self.get_mrstft_fit(),
                 threshold_esr=threshold_esr,
                 user_metadata=user_metadata,
             )
diff --git a/nam/train/metadata.py b/nam/train/metadata.py
@@ -13,17 +13,6 @@
 from typing import List, Optional

 from pydantic import BaseModel

-__all__ = [
-    "Data",
-    "DataChecks",
-    "Latency",
-    "LatencyCalibration",
-    "LatencyCalibrationWarnings",
-    "Settings",
-    "TrainingMetadata",
-    "TRAINING_KEY",
-]
-
 # The key under which the metadata are saved in the .nam:
 TRAINING_KEY = "training"
@@ -33,7 +22,6 @@ class Settings(BaseModel):
     User-provided settings
     """

-    fit_cab: bool
     ignore_checks: bool
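
Background for readers: the pre_emph_mrstft_weight and pre_emph_mrstft_coef keys above configure a pre-emphasized multi-resolution STFT loss term. Below is a minimal sketch of the idea, assuming the auraloss package; it is illustrative only, and NAM's actual implementation may differ in filter form and STFT settings:

    import torch
    from auraloss.freq import MultiResolutionSTFTLoss

    def pre_emphasized_mrstft(
        preds: torch.Tensor,    # model output, shape (batch, num_samples)
        targets: torch.Tensor,  # reference audio, same shape
        coef: float,            # pre-emphasis coefficient (cf. pre_emph_mrstft_coef)
    ) -> torch.Tensor:
        # First-order pre-emphasis y[n] = x[n] - coef * x[n-1] tilts the
        # comparison toward high frequencies before the spectral loss.
        def pre_emph(x: torch.Tensor) -> torch.Tensor:
            return x[..., 1:] - coef * x[..., :-1]

        mrstft = MultiResolutionSTFTLoss()
        # auraloss expects (batch, channels, samples); add a channel axis.
        return mrstft(pre_emph(preds)[:, None, :], pre_emph(targets)[:, None, :])

    # Weighted into the total objective along the lines the config suggests:
    #   loss = base_loss + pre_emph_mrstft_weight * pre_emphasized_mrstft(...)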