commit 0270e365da1d0cdf45d9d5fd43887c7be6dcc80a
parent adee42ee3f7f9a4aad345c0830998fd8782a5d25
Author: Steven Atkinson <[email protected]>
Date: Thu, 1 Jun 2023 08:46:52 -0700
Tweak cab modeling (#260)
* fit_cab option in Colab
* Refactor losses, add pre-emphasized MRSTFT
* fit_cab uses the pre-emphasized MRSTFT only (see the config sketch below)
* Print input version found
* Fix pre-emph
* Squash bugs
* More bugs
* Define private constants for cab modeling MRSTFT
* Always report ESR on validation
* Fix ESR validation
* Fix loss calculation
* Restore MRSTFT weight
* Update gui.py
Put back "Cab modeling" checkbox.
Refactor the checkbox code
* Minor version bump
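For reference, the new loss-configuration keys look like this in a model config. This is a minimal sketch with illustrative values; 2.0e-4 and 0.85 are the cab-modeling defaults defined in nam/train/core.py below, and the other keys keep their defaults:

    loss_config = {
        "val_loss": "mse",                 # loss tracked for the best checkpoint
        "mrstft_weight": None,             # None skips the plain MRSTFT term
        "pre_emph_mrstft_weight": 2.0e-4,  # new: pre-emphasized MRSTFT weight
        "pre_emph_mrstft_coef": 0.85,      # new: pre-emphasis filter coefficient
    }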
Diffstat:
7 files changed, 199 insertions(+), 89 deletions(-)
diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.5.3"
+__version__ = "0.6.0"
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -12,7 +12,7 @@ For the base *PyTorch* model containing the actual architecture, see `._base`.
from dataclasses import dataclass
from enum import Enum
-from typing import Optional, Tuple
+from typing import Dict, NamedTuple, Optional, Tuple
import auraloss
import logging
@@ -23,7 +23,7 @@ import torch.nn as nn
from .._core import InitializableFromConfig
from .conv_net import ConvNet
from .linear import Linear
-from .losses import esr, multi_resolution_stft_loss, mse_fft
+from .losses import apply_pre_emphasis_filter, esr, multi_resolution_stft_loss, mse_fft
from .parametric.catnets import CatLSTM, CatWaveNet
from .parametric.hyper_net import HyperConvNet
from .recurrent import LSTM
@@ -51,39 +51,39 @@ class ValidationLoss(Enum):
@dataclass
class LossConfig(InitializableFromConfig):
"""
+ :param mrstft_weight: Multi-resolution short-time Fourier transform loss
+ coefficient. None means to skip; 2e-4 works pretty well if one wants to use it.
:param mask_first: How many of the first samples to ignore when computing the loss.
:param dc_weight: Weight for the DC loss term. If 0, ignored.
:param val_loss: Which loss to track for the best model checkpoint.
:param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
+ :param pre_emph_mrstft_weight: Weight for the pre-emphasized MRSTFT loss term.
+ None means to skip.
+ :param pre_emph_mrstft_coef: Coefficient of the pre-emphasis filter applied
+ before computing the MRSTFT loss.
"""
- mrstft_weight: float = 0.0 # 0.0 means no multiresolution stft loss, 2e-4 works pretty well if one wants to use it
+ mrstft_weight: Optional[float] = None
fourier: bool = False
mask_first: int = 0
- dc_weight: float = 0.0
+ dc_weight: Optional[float] = None
val_loss: ValidationLoss = ValidationLoss.MSE
pre_emph_weight: Optional[float] = None
pre_emph_coef: Optional[float] = None
+ pre_emph_mrstft_weight: Optional[float] = None
+ pre_emph_mrstft_coef: Optional[float] = None
@classmethod
def parse_config(cls, config):
config = super().parse_config(config)
- fourier = config.get("fourier", False)
- dc_weight = config.get("dc_weight", 0.0)
- val_loss = ValidationLoss(config.get("val_loss", "mse"))
- mask_first = config.get("mask_first", 0)
- pre_emph_coef = config.get("pre_emph_coef")
- pre_emph_weight = config.get("pre_emph_weight")
- mrstft_weight = cls._get_mrstft_weight(config)
return {
- "fourier": fourier,
- "mask_first": mask_first,
- "dc_weight": dc_weight,
- "val_loss": val_loss,
- "pre_emph_coef": pre_emph_coef,
- "pre_emph_weight": pre_emph_weight,
- "mrstft_weight": mrstft_weight,
+ "fourier": config.get("fourier", False),
+ "mask_first": config.get("mask_first", 0),
+ "dc_weight": config.get("dc_weight"),
+ "val_loss": ValidationLoss(config.get("val_loss", "mse")),
+ "pre_emph_coef": config.get("pre_emph_coef"),
+ "pre_emph_weight": config.get("pre_emph_weight"),
+ "mrstft_weight": cls._get_mrstft_weight(config),
+ "pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"),
+ "pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"),
}
def apply_mask(self, *args):
@@ -94,7 +94,7 @@ class LossConfig(InitializableFromConfig):
return tuple(a[..., self.mask_first :] for a in args)
@classmethod
- def _get_mrstft_weight(cls, config) -> float:
+ def _get_mrstft_weight(cls, config) -> Optional[float]:
key = "mrstft_weight"
wrong_key = "mstft_key" # Backward compatibility
if key in config:
@@ -111,7 +111,12 @@ class LossConfig(InitializableFromConfig):
)
return config[wrong_key]
else:
- return 0.0
+ return None
+
+
+class _LossItem(NamedTuple):
+ weight: Optional[float]
+ value: Optional[torch.Tensor]
class Model(pl.LightningModule, InitializableFromConfig):
@@ -222,7 +227,9 @@ class Model(pl.LightningModule, InitializableFromConfig):
def forward(self, *args, **kwargs):
return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.
- def _shared_step(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
+ def _shared_step(
+ self, batch
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
"""
B: Batch size
L: Sequence length
@@ -232,55 +239,85 @@ class Model(pl.LightningModule, InitializableFromConfig):
args, targets = batch[:-1], batch[-1]
preds = self(*args, pad_start=False)
- return preds, targets
-
- def training_step(self, batch, batch_idx):
- preds, targets = self._shared_step(batch)
-
- loss = 0.0
+ # Compute all relevant losses.
+ loss_dict = {}  # Keys matter: validation_step looks up the requested val loss by key.
# Prediction aka MSE loss
if self._loss_config.fourier:
- loss = loss + mse_fft(preds, targets)
+ loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets))
else:
- loss = loss + self._mse_loss(preds, targets)
- if self._loss_config.mrstft_weight > 0.0:
- loss = loss + self._loss_config.mrstft_weight * self._mrstft_loss(
- preds, targets
- )
+ loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets))
# Pre-emphasized MSE
if self._loss_config.pre_emph_weight is not None:
if (self._loss_config.pre_emph_coef is None) != (
self._loss_config.pre_emph_weight is None
):
raise ValueError("Invalid pre-emph")
- loss = loss + self._loss_config.pre_emph_weight * self._mse_loss(
- preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef
+ loss_dict["Pre-emphasized MSE"] = _LossItem(
+ self._loss_config.pre_emph_weight,
+ self._mse_loss(
+ preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef
+ ),
+ )
+ # Multi-resolution short-time Fourier transform loss
+ if self._loss_config.mrstft_weight is not None:
+ loss_dict["MRSTFT"] = _LossItem(
+ self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets)
+ )
+ # Pre-emphasized MRSTFT
+ if self._loss_config.pre_emph_mrstft_weight is not None:
+ loss_dict["Pre-emphasized MRSTFT"] = _LossItem(
+ self._loss_config.pre_emph_mrstft_weight,
+ self._mrstft_loss(
+ preds, targets, pre_emph_coef=self._loss_config.pre_emph_mrstft_coef
+ ),
)
-
# DC loss
dc_weight = self._loss_config.dc_weight
- if dc_weight > 0.0:
+ if dc_weight is not None and dc_weight > 0.0:
# Denominator could be a bad idea. I'm going to omit it, especially since
# I'm using mini-batches.
mean_dims = torch.arange(1, preds.ndim).tolist()
dc_loss = nn.MSELoss()(
preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
)
- loss = loss + dc_weight * dc_loss
+ loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss)
+
+ return preds, targets, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ _, _, loss_dict = self._shared_step(batch)
+
+ loss = 0.0
+ for v in loss_dict.values():
+ if v.weight is not None and v.weight > 0.0:
+ loss = loss + v.weight * v.value
return loss
def validation_step(self, batch, batch_idx):
- preds, targets = self._shared_step(batch)
- mse_loss = self._mse_loss(preds, targets)
- esr_loss = self._esr_loss(preds, targets)
- val_loss = {ValidationLoss.MSE: mse_loss, ValidationLoss.ESR: esr_loss}[
- self._loss_config.val_loss
- ]
- dict_to_log = {"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss}
- if self._loss_config.mrstft_weight > 0.0 and self._mrstft is not None:
- mrstft_loss = self._mrstft_loss(preds, targets)
- dict_to_log.update({"MRSTFT": mrstft_loss})
- self.log_dict(dict_to_log)
+ preds, targets, loss_dict = self._shared_step(batch)
+
+ def get_val_loss():
+ # "esr" -> "ESR"
+ # "mse" -> "MSE"
+ # Others unsupported...
+ # TODO better mapping from Enum to dict keys
+ val_loss_type = self._loss_config.val_loss
+ val_loss_key_for_loss_dict = val_loss_type.value.upper()
+ if val_loss_key_for_loss_dict in loss_dict:
+ return loss_dict[val_loss_key_for_loss_dict].value
+ else:
+ raise RuntimeError(
+ f"Undefined validation loss routine for {val_loss_type}"
+ )
+
+ loss_dict["ESR"] = _LossItem(None, self._esr_loss(preds, targets))
+ val_loss = get_val_loss()
+ self.log_dict(
+ {
+ "val_loss": val_loss,
+ **{key: value.value for key, value in loss_dict.items()},
+ }
+ )
return val_loss
def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -302,11 +339,16 @@ class Model(pl.LightningModule, InitializableFromConfig):
def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
if pre_emph_coef is not None:
preds, targets = [
- z[..., 1:] - pre_emph_coef * z[..., :-1] for z in (preds, targets)
+ apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
]
return nn.MSELoss()(preds, targets)
- def _mrstft_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ def _mrstft_loss(
+ self,
+ preds: torch.Tensor,
+ targets: torch.Tensor,
+ pre_emph_coef: Optional[float] = None,
+ ) -> torch.Tensor:
"""
Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
B: Batch size
@@ -318,9 +360,13 @@ class Model(pl.LightningModule, InitializableFromConfig):
"""
if self._mrstft is None:
self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
-
backup_device = "cpu"
+ if pre_emph_coef is not None:
+ preds, targets = [
+ apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+ ]
+
try:
return multi_resolution_stft_loss(
preds, targets, self._mrstft, device=self._mrstft_device
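To summarize the base.py refactor: _shared_step now returns a dict of _LossItems, training_step sums the weighted ones, and validation_step logs them all. A minimal sketch of the totaling logic, mirroring training_step above (the loss tensors here are stand-ins):

    # Each _LossItem pairs an optional weight with a loss tensor. A weight
    # of None (e.g. the ESR item logged during validation) means
    # "log only, don't train on it".
    loss_dict = {
        "MSE": _LossItem(1.0, mse_value),
        "Pre-emphasized MRSTFT": _LossItem(2.0e-4, pre_emph_mrstft_value),
        "ESR": _LossItem(None, esr_value),
    }
    loss = 0.0
    for item in loss_dict.values():
        if item.weight is not None and item.weight > 0.0:
            loss = loss + item.weight * item.value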
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -11,7 +11,19 @@ from typing import Optional
import torch
from auraloss.freq import MultiResolutionSTFTLoss
-___all__ = ["esr", "multi_resolution_stft_loss"]
+__all__ = ["apply_pre_emphasis_filter", "esr", "multi_resolution_stft_loss"]
+
+
+def apply_pre_emphasis_filter(x: torch.Tensor, coef: float) -> torch.Tensor:
+ """
+ Apply first-order pre-emphasis filter.
+
+ :param x: (*, L)
+ :param coef: The coefficient
+
+ :return: (*, L-1)
+ """
+ return x[..., 1:] - coef * x[..., :-1]
def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
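A quick sanity check of the new filter, matching the first cases in the tests added below (illustrative only; assumes a local PyTorch install):

    import torch
    from nam.models.losses import apply_pre_emphasis_filter

    x = torch.Tensor([0.0, 1.0, 2.0])
    y = apply_pre_emphasis_filter(x, coef=0.5)
    # y[n] = x[n + 1] - 0.5 * x[n] -> tensor([1.0, 1.5]);
    # the output is one sample shorter than the input.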
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -53,6 +53,7 @@ def _check_for_files() -> Tuple[Version, str]:
raise FileNotFoundError(
f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
)
+ print(f"Found {input_basename}, version {input_version}")
return input_version, input_basename
@@ -76,6 +77,7 @@ def run(
seed: Optional[int] = 0,
user_metadata: Optional[UserMetadata] = None,
ignore_checks: bool = False,
+ fit_cab: bool = False,
):
"""
:param epochs: How many epochs we'll train for.
@@ -106,6 +108,7 @@ def run(
seed=seed,
local=False,
ignore_checks=ignore_checks,
+ fit_cab=fit_cab,
)
if model is None:
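In the notebook, the new flag is simply passed through to run(). A hypothetical call (the epochs value is arbitrary):

    from nam.train.colab import run

    # fit_cab=True turns on the pre-emphasized MRSTFT loss during training.
    run(epochs=100, fit_cab=True)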
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -555,6 +555,10 @@ def _get_wavenet_config(architecture):
}[architecture]
+_CAB_MRSTFT_PRE_EMPH_WEIGHT = 2.0e-4
+_CAB_MRSTFT_PRE_EMPH_COEF = 0.85
+
+
def _get_configs(
input_version: Version,
input_path: str,
@@ -567,6 +571,7 @@ def _get_configs(
lr: float,
lr_decay: float,
batch_size: int,
+ fit_cab: bool,
):
def get_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
@@ -625,7 +630,9 @@ def _get_configs(
"optimizer": {"lr": 0.01},
"lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
}
- model_config["loss"]["mrstft_weight"] = 2e-4
+ if fit_cab:
+ model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
+ model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
if torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
@@ -771,6 +778,7 @@ def train(
modelname: str = "model",
ignore_checks: bool = False,
local: bool = False,
+ fit_cab: bool = False,
) -> Optional[Model]:
if seed is not None:
torch.manual_seed(seed)
@@ -815,6 +823,7 @@ def train(
lr,
lr_decay,
batch_size,
+ fit_cab,
)
print("Starting training. It's time to kick ass and chew bubblegum!")
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -32,7 +32,7 @@ from enum import Enum
from functools import partial
from pathlib import Path
from tkinter import filedialog
-from typing import Callable, Optional, Sequence
+from typing import Callable, Dict, Optional, Sequence
try:
import torch
@@ -148,6 +148,17 @@ class _PathButton(object):
h()
+class _CheckboxKeys(Enum):
+ """
+ Keys for checkboxes
+ """
+
+ FIT_CAB = "fit_cab"
+ SILENT_TRAINING = "silent_training"
+ SAVE_PLOT = "save_plot"
+ IGNORE_DATA_CHECKS = "ignore_data_checks"
+
+
class _GUI(object):
def __init__(self):
self._root = tk.Tk()
@@ -257,41 +268,45 @@ class _GUI(object):
def _get_additional_options_frame(self):
# Checkboxes
+ # TODO get these definitions into __init__()
self._frame_checkboxes = tk.Frame(self._root)
self._frame_checkboxes.pack(side=tk.LEFT)
-
- # Silent run (bypass popups)
row = 1
- self._silent = tk.BooleanVar()
- self._checkbox_silent = tk.Checkbutton(
- self._frame_checkboxes,
- text="Silent run (suggested for batch training)",
- variable=self._silent,
- )
- self._checkbox_silent.grid(row=row, column=1, sticky="W")
- row += 1
-
- # Auto save the end plot
- self._save_plot = tk.BooleanVar()
- self._save_plot.set(True) # default this to true
- self._checkbox_save_plot = tk.Checkbutton(
- self._frame_checkboxes,
- text="Save ESR plot automatically",
- variable=self._save_plot,
+
+ @dataclass
+ class Checkbox(object):
+ variable: tk.BooleanVar
+ check_button: tk.Checkbutton
+
+ def make_checkbox(
+ key: _CheckboxKeys, text: str, default_value: bool
+ ) -> None:
+ variable = tk.BooleanVar()
+ variable.set(default_value)
+ check_button = tk.Checkbutton(
+ self._frame_checkboxes, text=text, variable=variable
+ )
+ self._checkboxes[key] = Checkbox(variable, check_button)
+
+ self._checkboxes: Dict[_CheckboxKeys, Checkbox] = dict()
+ make_checkbox(_CheckboxKeys.FIT_CAB, "Cab modeling", False)
+ make_checkbox(
+ _CheckboxKeys.SILENT_TRAINING,
+ "Silent run (suggested for batch training)",
+ False,
)
- self._checkbox_save_plot.grid(row=row, column=1, sticky="W")
- row += 1
-
- # Skip the data quality checks!
- self._ignore_checks = tk.BooleanVar()
- self._ignore_checks.set(False)
- self._checkbox_ignore_checks = tk.Checkbutton(
- self._frame_checkboxes,
- text="Ignore data quality checks (DO AT YOUR OWN RISK!)",
- variable=self._ignore_checks,
+ make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True)
+ make_checkbox(
+ _CheckboxKeys.IGNORE_DATA_CHECKS,
+ "Ignore data quality checks (DO AT YOUR OWN RISK!)",
+ False,
)
- self._checkbox_ignore_checks.grid(row=row, column=1, sticky="W")
- row += 1
+
+ # Grid them:
+ row = 1
+ for v in self._checkboxes.values():
+ v.check_button.grid(row=row, column=1, sticky="W")
+ row += 1
def mainloop(self):
self._root.mainloop()
@@ -344,11 +359,14 @@ class _GUI(object):
lr=lr,
lr_decay=lr_decay,
seed=seed,
- silent=self._silent.get(),
- save_plot=self._save_plot.get(),
+ silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
+ save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
modelname=basename,
- ignore_checks=self._ignore_checks.get(),
+ ignore_checks=self._checkboxes[
+ _CheckboxKeys.IGNORE_DATA_CHECKS
+ ].variable.get(),
local=True,
+ fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
)
if trained_model is None:
print("Model training failed! Skip exporting...")
diff --git a/tests/test_nam/test_models/test_losses.py b/tests/test_nam/test_models/test_losses.py
@@ -9,6 +9,28 @@ import torch.nn as nn
from nam.models import losses
[email protected](
+ "x,coef,y_expected",
+ (
+ (torch.Tensor([0.0, 1.0, 2.0]), 1.0, torch.Tensor([1.0, 1.0])),
+ (torch.Tensor([0.0, 1.0, 2.0]), 0.5, torch.Tensor([1.0, 1.5])),
+ (
+ torch.Tensor([[0.0, 1.0, 0.0], [1.0, 1.5, 2.0]]),
+ 0.5,
+ torch.Tensor([[1.0, -0.5], [1.0, 1.25]]),
+ ),
+ ),
+)
+def test_apply_pre_emphasis_filter_1d(
+ x: torch.Tensor, coef: float, y_expected: torch.Tensor
+):
+ y_actual = losses.apply_pre_emphasis_filter(x, coef)
+ assert isinstance(y_actual, torch.Tensor)
+ assert y_actual.ndim == y_expected.ndim
+ assert y_actual.shape == y_expected.shape
+ assert torch.allclose(y_actual, y_expected)
+
+
def test_esr():
"""
Is the ESR calculation correct?