commit 8fecd53188e05df328b6762f70fe0ca03350a8c5
parent 6a610ec9cb89f47d57b2d1d30779c21a642e649b
Author: Steven Atkinson <[email protected]>
Date: Sun, 15 Sep 2024 18:30:38 -0700
[BREAKING] Remove cab-fitting checkbox from GUI trainer (#462)
* Remove cab-fitting option from GUI trainer.
It's always on.
fit_cab to fit_mrstft.
Remove from Settings metadata since it's always on.
* colab.run: rename fit_cab kwarg to fit_mrstft, on by default
Diffstat:
4 files changed, 21 insertions(+), 25 deletions(-)
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -84,7 +84,7 @@ def run(
seed: Optional[int] = 0,
user_metadata: Optional[UserMetadata] = None,
ignore_checks: bool = False,
- fit_cab: bool = False,
+ fit_mrstft: bool = True,
):
"""
:param epochs: How many epochs we'll train for.
@@ -115,7 +115,7 @@ def run(
seed=seed,
local=False,
ignore_checks=ignore_checks,
- fit_cab=fit_cab,
+ fit_mrstft=fit_mrstft,
)
model = train_output.model
training_metadata = train_output.metadata
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -971,7 +971,7 @@ def _get_configs(
lr: float,
lr_decay: float,
batch_size: int,
- fit_cab: bool,
+ fit_mrstft: bool,
):
data_config = _get_data_config(
@@ -1012,7 +1012,7 @@ def _get_configs(
"optimizer": {"lr": 0.01},
"lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
}
- if fit_cab:
+ if fit_mrstft:
model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
@@ -1295,7 +1295,7 @@ def train(
modelname: str = "model",
ignore_checks: bool = False,
local: bool = False,
- fit_cab: bool = False,
+ fit_mrstft: bool = True,
threshold_esr: Optional[bool] = None,
user_metadata: Optional[UserMetadata] = None,
fast_dev_run: Union[bool, int] = False,
@@ -1351,9 +1351,7 @@ def train(
return TrainOutput(
model=None,
metadata=metadata.TrainingMetadata(
- settings=metadata.Settings(
- fit_cab=fit_cab, ignore_checks=ignore_checks
- ),
+ settings=metadata.Settings(ignore_checks=ignore_checks),
data=metadata.Data(
latency=latency_analysis, checks=data_check_output
),
@@ -1373,7 +1371,7 @@ def train(
lr,
lr_decay,
batch_size,
- fit_cab,
+ fit_mrstft,
)
assert (
"fast_dev_run" not in learning_config
@@ -1399,7 +1397,7 @@ def train(
model.net.sample_rate = sample_rate
# Put together the metadata that's needed in checkpoints:
- settings_metadata = metadata.Settings(fit_cab=fit_cab, ignore_checks=ignore_checks)
+ settings_metadata = metadata.Settings(ignore_checks=ignore_checks)
data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output)
trainer = pl.Trainer(
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -259,7 +259,6 @@ class _CheckboxKeys(Enum):
Keys for checkboxes
"""
- FIT_CAB = "fit_cab"
SILENT_TRAINING = "silent_training"
SAVE_PLOT = "save_plot"
@@ -484,6 +483,18 @@ class _GUI(object):
self._check_button_states()
+ def get_mrstft_fit(self) -> bool:
+ """
+ Use a pre-emphasized multi-resolution shot-time Fourier transform loss during
+ training.
+
+ This improves agreement in the high frequencies, usually with a minimial loss in
+ ESR.
+ """
+ # Leave this as a public method to anticipate an extension to make it
+ # changeable.
+ return True
+
def _check_button_states(self):
"""
Determine if any buttons should be disabled
@@ -525,7 +536,6 @@ class _GUI(object):
self._widgets[key] = check_button # For tracking in set-all-widgets ops
self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
- make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
make_checkbox(
_CheckboxKeys.SILENT_TRAINING,
"Silent run (suggested for batch training)",
@@ -616,7 +626,7 @@ class _GUI(object):
modelname=basename,
ignore_checks=ignore_checks,
local=True,
- fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+ fit_mrstft=self.get_mrstft_fit(),
threshold_esr=threshold_esr,
user_metadata=user_metadata,
)
diff --git a/nam/train/metadata.py b/nam/train/metadata.py
@@ -13,17 +13,6 @@ from typing import List, Optional
from pydantic import BaseModel
-__all__ = [
- "Data",
- "DataChecks",
- "Latency",
- "LatencyCalibration",
- "LatencyCalibrationWarnings",
- "Settings",
- "TrainingMetadata",
- "TRAINING_KEY",
-]
-
# The key under which the metadata are saved in the .nam:
TRAINING_KEY = "training"
@@ -33,7 +22,6 @@ class Settings(BaseModel):
User-provided settings
"""
- fit_cab: bool
ignore_checks: bool