neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 560ef39566465adf26988de3582c96a7ddde1e2e
parent f2c3ff91c94bd06b19ed3e0052ade1e4dea02391
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 24 Nov 2024 23:53:46 -0800

[BREAKING] Rename modules (#511)

* Re-organize and deprecations

* nam.models.base to nam.train.lightning_module
* Remove recurrent.LSTMCore

* Refactor lightning module tests to new file

* Black

* typo

* nam.models._base to nam.models.base
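For downstream code, the practical impact of this commit is a set of import-path changes. A minimal migration sketch, inferred from the renames above and the diff below rather than from separate documentation:

    # Before this commit (removed):
    # from nam.models import Model
    # from nam.models.base import Model

    # After this commit, the Lightning wrapper lives in nam.train:
    from nam.train.lightning_module import LightningModule

    # The plain PyTorch architectures stay under nam.models; the private
    # nam.models._base module is now the public nam.models.base:
    from nam.models.base import BaseNet
    from nam.models import ConvNet, Linear, LSTM, WaveNet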
Diffstat:
 M nam/models/__init__.py                             |   8 ++++----
 D nam/models/_base.py                                | 256 ------------------------------------
 M nam/models/base.py                                 | 558 ++++++++++++++++--------------------
 M nam/models/conv_net.py                             |   2 +-
 M nam/models/linear.py                               |   2 +-
 M nam/models/recurrent.py                            | 159 +-----------------------------------
 M nam/models/wavenet.py                              |   2 +-
 M nam/train/core.py                                  |  25 ++++++++++++++-----------
 M nam/train/full.py                                  |   8 ++++----
 M nam/train/gui/__init__.py                          |   4 ++++
 A nam/train/lightning_module.py                      | 405 ++++++++++++++++++++++++++++++++++++
 M setup.py                                           |   7 ++++---
 M tests/test_nam/test_models/_convolutional.py       |   4 +++-
 M tests/test_nam/test_models/test_base.py            |  68 ++++--------------------------------
 M tests/test_nam/test_models/test_wavenet.py         |   4 ++--
 M tests/test_nam/test_train/test_core.py             |  20 ++++++++++++--------
 A tests/test_nam/test_train/test_lightning_module.py |  68 ++++++++++++++++++++++++++++++++++++
17 files changed, 742 insertions(+), 858 deletions(-)
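The renamed class keeps the config-driven interface, so existing training configs carry over unchanged; only the class being instantiated moves. A hedged sketch based on the `parse_config` docstring preserved in the diff below (the net body is elided as `{...}` there, so it stays elided here):

    from nam.train.lightning_module import LightningModule  # formerly nam.models.Model

    config = {
        "net": {"name": "ConvNet", "config": {...}},  # architecture-specific config elided
        "loss": {"dc_weight": 0.1},
        "optimizer": {"lr": 0.0003},
        "lr_scheduler": {
            "class": "ReduceLROnPlateau",
            "kwargs": {"factor": 0.8, "patience": 10, "cooldown": 15, "min_lr": 1e-06},
            "monitor": "val_loss",
        },
    }
    model = LightningModule.init_from_config(config)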

diff --git a/nam/models/__init__.py b/nam/models/__init__.py @@ -6,10 +6,10 @@ NAM's neural networks """ -from . import _base # noqa F401 +from . import base # noqa F401 from . import exportable # noqa F401 from . import losses # noqa F401 -from . import wavenet # noqa F401 -from .base import Model # noqa F401 -from .linear import Linear # noqa F401 from .conv_net import ConvNet # noqa F401 +from .linear import Linear # noqa F401 +from .recurrent import LSTM # noqa F401 +from .wavenet import WaveNet # noqa F401 diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -1,256 +0,0 @@ -# File: _base.py -# Created Date: Tuesday February 8th 2022 -# Author: Steven Atkinson (steven@atkinson.mn) - -""" -The foundation of the model without the PyTorch Lightning attributes (losses, training -steps) -""" - -import abc -import math -import pkg_resources -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn as nn - -from .._core import InitializableFromConfig -from ..data import wav_to_tensor -from .exportable import Exportable - - -class _Base(nn.Module, InitializableFromConfig, Exportable): - def __init__(self, sample_rate: Optional[float] = None): - super().__init__() - self.register_buffer( - "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool) - ) - self.register_buffer( - "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate) - ) - - @property - @abc.abstractmethod - def pad_start_default(self) -> bool: - pass - - @property - @abc.abstractmethod - def receptive_field(self) -> int: - """ - Receptive field of the model - """ - pass - - @abc.abstractmethod - def forward(self, *args, **kwargs) -> torch.Tensor: - pass - - @classmethod - def _metadata_loudness_x(cls) -> torch.Tensor: - return wav_to_tensor( - pkg_resources.resource_filename( - "nam", "models/_resources/loudness_input.wav" - ) - ) - - @property - def device(self) -> Optional[torch.device]: - """ - Helpful property, where the parameters of the model live. - """ - # We can do this because the models are tiny and I don't expect a NAM to be on - # multiple devices - try: - return next(self.parameters()).device - except StopIteration: - return None - - @property - def sample_rate(self) -> Optional[float]: - return self._sample_rate.item() if self._has_sample_rate else None - - @sample_rate.setter - def sample_rate(self, val: Optional[float]): - self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool) - self._sample_rate = torch.tensor(0.0 if val is None else val) - - def _get_export_dict(self): - d = super()._get_export_dict() - sample_rate_key = "sample_rate" - if sample_rate_key in d: - raise RuntimeError( - "Model wants to put 'sample_rate' into model export dict, but the key " - "is already taken!" - ) - d[sample_rate_key] = self.sample_rate - return d - - def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float: - """ - How loud is this model when given a standardized input? - In dB - - :param gain: Multiplies input signal - """ - x = self._metadata_loudness_x().to(self.device) - y = self._at_nominal_settings(gain * x) - loudness = torch.sqrt(torch.mean(torch.square(y))) - if db: - loudness = 20.0 * torch.log10(loudness) - return loudness.item() - - def _metadata_gain(self) -> float: - """ - Between 0 and 1, how much gain / compression does the model seem to have? 
- """ - x = np.linspace(0.0, 1.0, 11) - y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x]) - # - # O ^ o o o o o o - # u | o x +-------------------------------------+ - # t | o x | x: Minimum gain (no compression) | - # p | o x | o: Max gain (100% compression) | - # u | o x +-------------------------------------+ - # t | o - # +-------------> - # Input - # - max_gain = y[-1] * len(x) # "Square" - min_gain = 0.5 * max_gain # "Triangle" - gain_range = max_gain - min_gain - this_gain = y.sum() - normalized_gain = (this_gain - min_gain) / gain_range - return np.clip(normalized_gain, 0.0, 1.0) - - def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor: - # parametric?... - raise NotImplementedError() - - @abc.abstractmethod - def _forward(self, *args) -> torch.Tensor: - """ - The true forward method. - - :param x: (N,L1) - :return: (N,L1-RF+1) - """ - pass - - def _export_input_output_args(self) -> Tuple[Any]: - """ - Create any other args necessesary (e.g. params to eval at) - """ - return () - - def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]: - args = self._export_input_output_args() - rate = self.sample_rate - if rate is None: - raise RuntimeError( - "Cannot export model's input and output without a sample rate." - ) - x = torch.cat( - [ - torch.zeros((rate,)), - 0.5 - * torch.sin( - 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1] - ), - torch.zeros((rate,)), - ] - ) - # Use pad start to ensure same length as requested by ._export_input_output() - return ( - x.detach().cpu().numpy(), - self(*args, x, pad_start=True).detach().cpu().numpy(), - ) - - -class BaseNet(_Base): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._mps_65536_fallback = False - - def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs): - pad_start = self.pad_start_default if pad_start is None else pad_start - scalar = x.ndim == 1 - if scalar: - x = x[None] - if pad_start: - x = torch.cat( - (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 - ) - if x.shape[1] < self.receptive_field: - raise ValueError( - f"Input has {x.shape[1]} samples, which is too few for this model with " - f"receptive field {self.receptive_field}!" - ) - y = self._forward_mps_safe(x, **kwargs) - if scalar: - y = y[0] - return y - - def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor: - return self(x) - - def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - """ - Wrap `._forward()` to protect against MPS-unsupported inptu lengths - beyond 65,536 samples. - - Check this again when PyTorch 2.5.2 is released--hopefully it's fixed - then. - """ - if not self._mps_65536_fallback: - try: - return self._forward(x, **kwargs) - except NotImplementedError as e: - if "Output channels > 65536 not supported at the MPS device." in str(e): - print( - "===WARNING===\n" - "NAM encountered a bug in PyTorch's MPS backend and will " - "switch to a fallback.\n" - f"Your version of PyTorch is {torch.__version__}.\n" - "Please report this in an Issue at:\n" - "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose" - "\n" - "so that NAM's dependencies can avoid buggy versions of " - "PyTorch and the associated performance hit." 
- ) - self._mps_65536_fallback = True - return self._forward_mps_safe(x, **kwargs) - else: - raise e - else: - # Stitch together the output one piece at a time to avoid the MPS error - stride = 65_536 - (self.receptive_field - 1) - # We need to make sure that the last segment is big enough that we have the required history for the receptive field. - out_list = [] - for i in range(0, x.shape[1], stride): - j = min(i+65_536, x.shape[1]) - xi = x[:, i:j] - out_list.append(self._forward(xi, **kwargs)) - # Bit hacky, but correct. - if j == x.shape[1]: - break - return torch.cat(out_list, dim=1) - - - @abc.abstractmethod - def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - """ - The true forward method. - - :param x: (N,L1) - :return: (N,L1-RF+1) - """ - pass - - def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]: - d = super()._get_non_user_metadata() - d["loudness"] = self._metadata_loudness() - d["gain"] = self._metadata_gain() - return d diff --git a/nam/models/base.py b/nam/models/base.py @@ -1,395 +1,255 @@ -# File: base.py -# Created Date: Saturday February 5th 2022 +# File: _base.py +# Created Date: Tuesday February 8th 2022 # Author: Steven Atkinson (steven@atkinson.mn) """ -Implements the base PyTorch Lightning module. -This is meant to combine an actual model (subclassed from `._base.BaseNet`) -along with loss function boilerplate. - -For the base *PyTorch* model containing the actual architecture, see `._base`. +The foundation of the model without the PyTorch Lightning attributes (losses, training +steps) """ -from dataclasses import dataclass -from enum import Enum -from typing import Any, Dict, NamedTuple, Optional, Tuple +import abc +import math +import pkg_resources +from typing import Any, Dict, Optional, Tuple, Union -import auraloss -import logging -import pytorch_lightning as pl +import numpy as np import torch import torch.nn as nn from .._core import InitializableFromConfig -from .conv_net import ConvNet -from .linear import Linear -from .losses import apply_pre_emphasis_filter, esr, multi_resolution_stft_loss, mse_fft -from .recurrent import LSTM -from .wavenet import WaveNet - -logger = logging.getLogger(__name__) - +from ..data import wav_to_tensor +from .exportable import Exportable -class ValidationLoss(Enum): - """ - mse: mean squared error - esr: error signal ratio (Eq. (10) from - https://www.mdpi.com/2076-3417/10/3/766/htm - NOTE: Be careful when computing ESR on minibatches! The average ESR over - a minibatch of data not the same as the ESR of all of the same data in - the minibatch calculated over at once (because of the denominator). - (Hint: think about what happens if one item in the minibatch is all - zeroes...) - """ - MSE = "mse" - ESR = "esr" +class _Base(nn.Module, InitializableFromConfig, Exportable): + def __init__(self, sample_rate: Optional[float] = None): + super().__init__() + self.register_buffer( + "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool) + ) + self.register_buffer( + "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate) + ) + @property + @abc.abstractmethod + def pad_start_default(self) -> bool: + pass -@dataclass -class LossConfig(InitializableFromConfig): - """ - :param mrstft_weight: Multi-resolution short-time Fourier transform loss - coefficient. None means to skip; 2e-4 works pretty well if one wants to use it. - :param mask_first: How many of the first samples to ignore when comptuing the loss. - :param dc_weight: Weight for the DC loss term. 
If 0, ignored. - :params val_loss: Which loss to track for the best model checkpoint. - :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from - https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95. - :param pre_ - """ + @property + @abc.abstractmethod + def receptive_field(self) -> int: + """ + Receptive field of the model + """ + pass - mrstft_weight: Optional[float] = None - fourier: bool = False - mask_first: int = 0 - dc_weight: float = None - val_loss: ValidationLoss = ValidationLoss.MSE - pre_emph_weight: Optional[float] = None - pre_emph_coef: Optional[float] = None - pre_emph_mrstft_weight: Optional[float] = None - pre_emph_mrstft_coef: Optional[float] = None + @abc.abstractmethod + def forward(self, *args, **kwargs) -> torch.Tensor: + pass @classmethod - def parse_config(cls, config): - config = super().parse_config(config) - return { - "fourier": config.get("fourier", False), - "mask_first": config.get("mask_first", 0), - "dc_weight": config.get("dc_weight"), - "val_loss": ValidationLoss(config.get("val_loss", "mse")), - "pre_emph_coef": config.get("pre_emph_coef"), - "pre_emph_weight": config.get("pre_emph_weight"), - "mrstft_weight": cls._get_mrstft_weight(config), - "pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"), - "pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"), - } + def _metadata_loudness_x(cls) -> torch.Tensor: + return wav_to_tensor( + pkg_resources.resource_filename( + "nam", "models/_resources/loudness_input.wav" + ) + ) - def apply_mask(self, *args): + @property + def device(self) -> Optional[torch.device]: """ - :param args: (L,) or (B,) - :return: (L-M,) or (B, L-M) + Helpful property, where the parameters of the model live. """ - return tuple(a[..., self.mask_first :] for a in args) - - @classmethod - def _get_mrstft_weight(cls, config) -> Optional[float]: - key = "mrstft_weight" - wrong_key = "mstft_key" # Backward compatibility - if key in config: - if "mstft_weight" in config: - raise ValueError( - f"Received loss configuration with both '{key}' and " - f"'{wrong_key}'. Provide only '{key}'." - ) - return config[key] - elif wrong_key in config: - logger.warning( - f"Use of '{wrong_key}' is deprecated and will be removed in a future " - f"version. Use '{key}' instead." - ) - return config[wrong_key] - else: + # We can do this because the models are tiny and I don't expect a NAM to be on + # multiple devices + try: + return next(self.parameters()).device + except StopIteration: return None + @property + def sample_rate(self) -> Optional[float]: + return self._sample_rate.item() if self._has_sample_rate else None + + @sample_rate.setter + def sample_rate(self, val: Optional[float]): + self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool) + self._sample_rate = torch.tensor(0.0 if val is None else val) + + def _get_export_dict(self): + d = super()._get_export_dict() + sample_rate_key = "sample_rate" + if sample_rate_key in d: + raise RuntimeError( + "Model wants to put 'sample_rate' into model export dict, but the key " + "is already taken!" 
+ ) + d[sample_rate_key] = self.sample_rate + return d -class _LossItem(NamedTuple): - weight: Optional[float] - value: Optional[torch.Tensor] - - -_model_net_init_registry = { - "ConvNet": ConvNet.init_from_config, - "Linear": Linear.init_from_config, - "LSTM": LSTM.init_from_config, - "WaveNet": WaveNet.init_from_config, -} - + def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float: + """ + How loud is this model when given a standardized input? + In dB -class Model(pl.LightningModule, InitializableFromConfig): - def __init__( - self, - net, - optimizer_config: Optional[dict] = None, - scheduler_config: Optional[dict] = None, - loss_config: Optional[LossConfig] = None, - ): + :param gain: Multiplies input signal """ - :param scheduler_config: contains - Required: - * "class" - * "kwargs" - Optional (defaults to Lightning defaults): - * "interval" ("epoch" of "step") - * "frequency" (int) - * "monitor" (str) + x = self._metadata_loudness_x().to(self.device) + y = self._at_nominal_settings(gain * x) + loudness = torch.sqrt(torch.mean(torch.square(y))) + if db: + loudness = 20.0 * torch.log10(loudness) + return loudness.item() + + def _metadata_gain(self) -> float: """ - super().__init__() - self._net = net - self._optimizer_config = {} if optimizer_config is None else optimizer_config - self._scheduler_config = scheduler_config - self._loss_config = LossConfig() if loss_config is None else loss_config - self._mrstft = None # Multi-resolution short-time Fourier transform loss - # Where to compute the MRSTFT. - # Keeping it on-device is preferable, but if that fails, then remember to drop - # it to cpu from then on. - self._mrstft_device: Optional[torch.device] = None - - @classmethod - def init_from_config(cls, config): - checkpoint_path = config.get("checkpoint_path") - config = cls.parse_config(config) - return ( - cls(**config) - if checkpoint_path is None - else cls.load_from_checkpoint(checkpoint_path, **config) - ) - - @classmethod - def parse_config(cls, config): + Between 0 and 1, how much gain / compression does the model seem to have? """ - e.g. - - { - "net": { - "name": "ConvNet", - "config": {...} - }, - "loss": { - "dc_weight": 0.1 - }, - "optimizer": { - "lr": 0.0003 - }, - "lr_scheduler": { - "class": "ReduceLROnPlateau", - "kwargs": { - "factor": 0.8, - "patience": 10, - "cooldown": 15, - "min_lr": 1e-06, - "verbose": true - }, - "monitor": "val_loss" - } - } + x = np.linspace(0.0, 1.0, 11) + y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x]) + # + # O ^ o o o o o o + # u | o x +-------------------------------------+ + # t | o x | x: Minimum gain (no compression) | + # p | o x | o: Max gain (100% compression) | + # u | o x +-------------------------------------+ + # t | o + # +-------------> + # Input + # + max_gain = y[-1] * len(x) # "Square" + min_gain = 0.5 * max_gain # "Triangle" + gain_range = max_gain - min_gain + this_gain = y.sum() + normalized_gain = (this_gain - min_gain) / gain_range + return np.clip(normalized_gain, 0.0, 1.0) + + def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor: + # parametric?... 
+ raise NotImplementedError() + + @abc.abstractmethod + def _forward(self, *args) -> torch.Tensor: """ - config = super().parse_config(config) - net_config = config["net"] - net = _model_net_init_registry[net_config["name"]](net_config["config"]) - loss_config = LossConfig.init_from_config(config.get("loss", {})) - return { - "net": net, - "optimizer_config": config["optimizer"], - "scheduler_config": config["lr_scheduler"], - "loss_config": loss_config, - } - - @classmethod - def register_net_initializer(cls, name, constructor, overwrite: bool = False): - if name in _model_net_init_registry and not overwrite: - raise KeyError( - f"A constructor for net name '{name}' is already registered!" - ) - _model_net_init_registry[name] = constructor + The true forward method. - @property - def net(self) -> nn.Module: - return self._net - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config) - if self._scheduler_config is None: - return optimizer - else: - lr_scheduler = getattr( - torch.optim.lr_scheduler, self._scheduler_config["class"] - )(optimizer, **self._scheduler_config["kwargs"]) - lr_scheduler_config = {"scheduler": lr_scheduler} - for key in ("interval", "frequency", "monitor"): - if key in self._scheduler_config: - lr_scheduler_config[key] = self._scheduler_config[key] - return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} - - def forward(self, *args, **kwargs): - return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead. - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351 - self.net.sample_rate = checkpoint["sample_rate"] - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351 - checkpoint["sample_rate"] = self.net.sample_rate - - def _shared_step( - self, batch - ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]: + :param x: (N,L1) + :return: (N,L1-RF+1) """ - B: Batch size - L: Sequence length + pass - :return: (B,L), (B,L) + def _export_input_output_args(self) -> Tuple[Any]: """ - args, targets = batch[:-1], batch[-1] - preds = self(*args, pad_start=False) - - # Compute all relevant losses. - loss_dict = {} # Mind keys versus validation loss requested... - # Prediction aka MSE loss - if self._loss_config.fourier: - loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets)) - else: - loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets)) - # Pre-emphasized MSE - if self._loss_config.pre_emph_weight is not None: - if (self._loss_config.pre_emph_coef is None) != ( - self._loss_config.pre_emph_weight is None - ): - raise ValueError("Invalid pre-emph") - loss_dict["Pre-emphasized MSE"] = _LossItem( - self._loss_config.pre_emph_weight, - self._mse_loss( - preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef - ), - ) - # Multi-resolution short-time Fourier transform loss - if self._loss_config.mrstft_weight is not None: - loss_dict["MRSTFT"] = _LossItem( - self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets) + Create any other args necessesary (e.g. params to eval at) + """ + return () + + def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]: + args = self._export_input_output_args() + rate = self.sample_rate + if rate is None: + raise RuntimeError( + "Cannot export model's input and output without a sample rate." 
) - # Pre-emphasized MRSTFT - if self._loss_config.pre_emph_mrstft_weight is not None: - loss_dict["Pre-emphasized MRSTFT"] = _LossItem( - self._loss_config.pre_emph_mrstft_weight, - self._mrstft_loss( - preds, targets, pre_emph_coef=self._loss_config.pre_emph_mrstft_coef + x = torch.cat( + [ + torch.zeros((rate,)), + 0.5 + * torch.sin( + 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1] ), - ) - # DC loss - dc_weight = self._loss_config.dc_weight - if dc_weight is not None and dc_weight > 0.0: - # Denominator could be a bad idea. I'm going to omit it esp since I'm - # using mini batches - mean_dims = torch.arange(1, preds.ndim).tolist() - dc_loss = nn.MSELoss()( - preds.mean(dim=mean_dims), targets.mean(dim=mean_dims) - ) - loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss) - - return preds, targets, loss_dict - - def training_step(self, batch, batch_idx): - _, _, loss_dict = self._shared_step(batch) + torch.zeros((rate,)), + ] + ) + # Use pad start to ensure same length as requested by ._export_input_output() + return ( + x.detach().cpu().numpy(), + self(*args, x, pad_start=True).detach().cpu().numpy(), + ) - loss = 0.0 - for v in loss_dict.values(): - if v.weight is not None and v.weight > 0.0: - loss = loss + v.weight * v.value - return loss - def validation_step(self, batch, batch_idx): - preds, targets, loss_dict = self._shared_step(batch) +class BaseNet(_Base): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._mps_65536_fallback = False - def get_val_loss(): - # "esr" -> "ESR" - # "mse" -> "MSE" - # Others unsupported... - # TODO better mapping from Enum to dict keys - val_loss_type = self._loss_config.val_loss - val_loss_key_for_loss_dict = val_loss_type.value.upper() - if val_loss_key_for_loss_dict in loss_dict: - return loss_dict[val_loss_key_for_loss_dict].value - else: - raise RuntimeError( - f"Undefined validation loss routine for {val_loss_type}" - ) + def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs): + pad_start = self.pad_start_default if pad_start is None else pad_start + scalar = x.ndim == 1 + if scalar: + x = x[None] + if pad_start: + x = torch.cat( + (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 + ) + if x.shape[1] < self.receptive_field: + raise ValueError( + f"Input has {x.shape[1]} samples, which is too few for this model with " + f"receptive field {self.receptive_field}!" + ) + y = self._forward_mps_safe(x, **kwargs) + if scalar: + y = y[0] + return y - loss_dict["ESR"] = _LossItem(None, self._esr_loss(preds, targets)) - val_loss = get_val_loss() - self.log_dict( - { - "val_loss": val_loss, - **{key: value.value for key, value in loss_dict.items()}, - } - ) - return val_loss + def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor: + return self(x) - def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor: """ - Error signal ratio aka ESR loss. - - Eq. (10), from - https://www.mdpi.com/2076-3417/10/3/766/htm + Wrap `._forward()` to protect against MPS-unsupported input lengths + beyond 65,536 samples. - B: Batch size - L: Sequence length - - :param preds: (B,L) - :param targets: (B,L) - :return: () + Check this again when PyTorch 2.5.2 is released--hopefully it's fixed + then. 
""" - return esr(preds, targets) - - def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None): - if pre_emph_coef is not None: - preds, targets = [ - apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets) - ] - return nn.MSELoss()(preds, targets) - - def _mrstft_loss( - self, - preds: torch.Tensor, - targets: torch.Tensor, - pre_emph_coef: Optional[float] = None, - ) -> torch.Tensor: + if not self._mps_65536_fallback: + try: + return self._forward(x, **kwargs) + except NotImplementedError as e: + if "Output channels > 65536 not supported at the MPS device." in str(e): + print( + "===WARNING===\n" + "NAM encountered a bug in PyTorch's MPS backend and will " + "switch to a fallback.\n" + f"Your version of PyTorch is {torch.__version__}.\n" + "Please report this in an Issue at:\n" + "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose" + "\n" + "so that NAM's dependencies can avoid buggy versions of " + "PyTorch and the associated performance hit." + ) + self._mps_65536_fallback = True + return self._forward_mps_safe(x, **kwargs) + else: + raise e + else: + # Stitch together the output one piece at a time to avoid the MPS error + stride = 65_536 - (self.receptive_field - 1) + # We need to make sure that the last segment is big enough that we have the required history for the receptive field. + out_list = [] + for i in range(0, x.shape[1], stride): + j = min(i + 65_536, x.shape[1]) + xi = x[:, i:j] + out_list.append(self._forward(xi, **kwargs)) + # Bit hacky, but correct. + if j == x.shape[1]: + break + return torch.cat(out_list, dim=1) + + @abc.abstractmethod + def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: """ - Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation. - B: Batch size - L: Sequence length + The true forward method. - :param preds: (B,L) - :param targets: (B,L) - :return: () + :param x: (N,L1) + :return: (N,L1-RF+1) """ - if self._mrstft is None: - self._mrstft = auraloss.freq.MultiResolutionSTFTLoss() - backup_device = "cpu" - - if pre_emph_coef is not None: - preds, targets = [ - apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets) - ] + pass - try: - return multi_resolution_stft_loss( - preds, targets, self._mrstft, device=self._mrstft_device - ) - except Exception as e: - if self._mrstft_device == backup_device: - raise e - logger.warning("MRSTFT failed on device; falling back to CPU") - self._mrstft_device = backup_device - return multi_resolution_stft_loss( - preds, targets, self._mrstft, device=self._mrstft_device - ) + def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]: + d = super()._get_non_user_metadata() + d["loudness"] = self._metadata_loudness() + d["gain"] = self._metadata_gain() + return d diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from .. 
import __version__ from ..data import wav_to_tensor from ._activations import get_activation -from ._base import BaseNet +from .base import BaseNet from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME diff --git a/nam/models/linear.py b/nam/models/linear.py @@ -11,7 +11,7 @@ import torch import torch.nn as nn from .._version import __version__ -from ._base import BaseNet +from .base import BaseNet class Linear(BaseNet): diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py @@ -18,10 +18,7 @@ import numpy as np import torch import torch.nn as nn -from ._base import BaseNet - - -# TODO merge LSTMCore into LSTM +from .base import BaseNet class _L(nn.LSTM): @@ -55,72 +52,6 @@ _LSTMCellType = torch.Tensor _LSTMHiddenCellType = Tuple[_LSTMHiddenType, _LSTMCellType] -class LSTMCore(_L): - def __init__( - self, - *args, - train_burn_in: Optional[int] = None, - train_truncate: Optional[int] = None, - **kwargs, - ): - super().__init__(*args, **kwargs) - if not self.batch_first: - raise NotImplementedError("Need batch first") - self._train_burn_in = train_burn_in - self._train_truncate = train_truncate - assert len(args) < 3, "Provide as kwargs" - self._initial_cell = nn.Parameter( - torch.zeros((self.num_layers, self.hidden_size)) - ) - self._initial_hidden = nn.Parameter( - torch.zeros((self.num_layers, self.hidden_size)) - ) - - def forward(self, x, hidden_state=None): - """ - Same as nn.LSTM.forward except: - * Learned inital state - * truncated BPTT when .training - """ - if x.ndim != 3: - raise NotImplementedError("Need (B,L,D)") - last_hidden_state = ( - self._initial_state(None if x.ndim == 2 else len(x)) - if hidden_state is None - else hidden_state - ) - if not self.training or self._train_truncate is None: - output_features = super().forward(x, last_hidden_state)[0] - else: - output_features_list = [] - if self._train_burn_in is not None: - last_output_features, last_hidden_state = super().forward( - x[:, : self._train_burn_in, :], last_hidden_state - ) - output_features_list.append(last_output_features.detach()) - burn_in_offset = 0 if self._train_burn_in is None else self._train_burn_in - for i in range(burn_in_offset, x.shape[1], self._train_truncate): - if i > burn_in_offset: - # Don't detach the burn-in state so that we can learn it. 
- last_hidden_state = tuple(z.detach() for z in last_hidden_state) - last_output_features, last_hidden_state = super().forward( - x[:, i : i + self._train_truncate, :], last_hidden_state - ) - output_features_list.append(last_output_features) - output_features = torch.cat(output_features_list, dim=1) - return output_features - - def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType: - return ( - (self._initial_hidden, self._initial_cell) - if n is None - else ( - torch.tile(self._initial_hidden[:, None], (1, n, 1)), - torch.tile(self._initial_cell[:, None], (1, n, 1)), - ) - ) - - # TODO get this somewhere more core-ish class _ExportsWeights(abc.ABC): @abc.abstractmethod @@ -371,91 +302,3 @@ class LSTM(BaseNet): torch.tile(self._initial_cell[:, None], (1, n, 1)), ) ) - - -# TODO refactor together - - -class _SkippyLSTM(nn.Module): - def __init__( - self, input_size, hidden_size, skip_in: bool = False, num_layers=1, **kwargs - ): - super().__init__() - layers_per_lstm = 1 - self._skip_in = skip_in - self._lstms = nn.ModuleList( - [ - _L( - self._layer_input_size(input_size, hidden_size, i), - hidden_size, - layers_per_lstm, - batch_first=True, - ) - for i in range(num_layers) - ] - ) - self._initial_hidden = nn.Parameter( - torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size)) - ) - self._initial_cell = nn.Parameter( - torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size)) - ) - - @property - def hidden_size(self): - return self._lstms[0].hidden_size - - @property - def input_size(self): - return self._lstms[0].input_size - - @property - def num_layers(self): - return len(self._lstms) - - @property - def output_size(self): - return self.num_layers * self.hidden_size - - def forward(self, input, state=None): - """ - :param input: (N,L,DX) - :param state: ((L,Li,N,DH), (L,Li,N,DH)) - - :return: (N,L,L*DH), ((L,Li,N,DH), (L,Li,N,DH)) - """ - h0, c0 = self.initial_state(input) if state is None else state - hiddens, h_arr, c_arr, hidden = [], [], [], None - for layer, h0i, c0i in zip(self._lstms, h0, c0): - if self._skip_in: - # TODO dense-block - layer_input = ( - input if hidden is None else torch.cat([input, hidden], dim=2) - ) - else: - layer_input = input if hidden is None else hidden - hidden, (hi, ci) = layer(layer_input, (h0i, c0i)) - hiddens.append(hidden) - h_arr.append(hi) - c_arr.append(ci) - return (torch.cat(hiddens, dim=2), (torch.stack(h_arr), torch.stack(c_arr))) - - def initial_state(self, input: torch.Tensor): - """ - Initial states for all the layers - - :return: (L,B,Li,DH) - """ - assert input.ndim == 3, "Batch only for now" - batch_size = len(input) # Assume batch_first - return ( - torch.tile(self._initial_hidden[:, :, None], (1, 1, batch_size, 1)), - torch.tile(self._initial_cell[:, :, None], (1, 1, batch_size, 1)), - ) - - def _layer_input_size(self, input_size, hidden_size, i) -> int: - # TODO dense-block - if self._skip_in: - return input_size + (0 if i == 0 else hidden_size) - else: - return input_size if i == 0 else hidden_size diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py @@ -18,7 +18,7 @@ import torch import torch.nn as nn from ._activations import get_activation -from ._base import BaseNet +from .base import BaseNet from ._names import ACTIVATION_NAME, CONV_NAME diff --git a/nam/train/core.py b/nam/train/core.py @@ -26,12 +26,12 @@ from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader from ..data import DataError, Split, init_dataset, wav_to_np, wav_to_tensor 
-from ..models import Model from ..models.exportable import Exportable from ..models.losses import esr from ..models.metadata import UserMetadata from ..util import filter_warnings from ._version import PROTEUS_VERSION, Version +from .lightning_module import LightningModule from . import metadata # Training using the simplified trainers in NAM is done at 48k. @@ -424,13 +424,16 @@ def _calibrate_latency_v_all( print(msg) print("SHARE THIS PLOT IF YOU ASK FOR HELP") plt.figure() - plt.plot(np.arange(-lookahead, lookback), y_scan_average, color="C0", label="Signal average") + plt.plot( + np.arange(-lookahead, lookback), + y_scan_average, + color="C0", + label="Signal average", + ) for y_scan in y_scans: plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2) plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") - plt.axhline( - y=-trigger_threshold, color="k", linestyle="--", label="Threshold" - ) + plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold") plt.axhline(y=trigger_threshold, color="k", linestyle="--") plt.xlim((-lookahead, lookback)) plt.xlabel("Samples") @@ -1050,7 +1053,7 @@ def _get_configs( def _get_dataloaders( - data_config: Dict, learning_config: Dict, model: Model + data_config: Dict, learning_config: Dict, model: LightningModule ) -> Tuple[DataLoader, DataLoader]: data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)] data_config["common"]["nx"] = model.net.receptive_field @@ -1205,7 +1208,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): super()._save_checkpoint(trainer, filepath) # Save the .nam: nam_filepath = self._get_nam_filepath(filepath) - pl_model: Model = trainer.model + pl_model: LightningModule = trainer.model nam_model = pl_model.net outdir = nam_filepath.parent # HACK: Assume the extension @@ -1273,7 +1276,7 @@ class TrainOutput(NamedTuple): simplified trainer. """ - model: Optional[Model] + model: Optional[LightningModule] metadata: metadata.TrainingMetadata @@ -1414,7 +1417,7 @@ def train( # * Model is re-instantiated after training anyways. # (Hacky) solution: set sample rate in model from dataloader after second # instantiation from final checkpoint. 
- model = Model.init_from_config(model_config) + model = LightningModule.init_from_config(model_config) train_dataloader, val_dataloader = _get_dataloaders( data_config, learning_config, model ) @@ -1449,9 +1452,9 @@ def train( # Go to best checkpoint best_checkpoint = trainer.checkpoint_callback.best_model_path if best_checkpoint != "": - model = Model.load_from_checkpoint( + model = LightningModule.load_from_checkpoint( trainer.checkpoint_callback.best_model_path, - **Model.parse_config(model_config), + **LightningModule.parse_config(model_config), ) model.cpu() model.eval() diff --git a/nam/train/full.py b/nam/train/full.py @@ -16,7 +16,7 @@ import torch from torch.utils.data import DataLoader from nam.data import ConcatDataset, Split, init_dataset -from nam.models import Model +from nam.train.lightning_module import LightningModule from nam.util import filter_warnings torch.manual_seed(0) @@ -143,7 +143,7 @@ def main( with open(Path(outdir, f"config_{basename}.json"), "w") as fp: json.dump(config, fp, indent=4) - model = Model.init_from_config(model_config) + model = LightningModule.init_from_config(model_config) # Add receptive field to data config: data_config["common"] = data_config.get("common", {}) if "nx" in data_config["common"]: @@ -178,9 +178,9 @@ def main( # Go to best checkpoint best_checkpoint = trainer.checkpoint_callback.best_model_path if best_checkpoint != "": - model = Model.load_from_checkpoint( + model = LightningModule.load_from_checkpoint( trainer.checkpoint_callback.best_model_path, - **Model.parse_config(model_config), + **LightningModule.parse_config(model_config), ) model.cpu() model.eval() diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py @@ -20,6 +20,7 @@ import webbrowser from dataclasses import dataclass from enum import Enum from functools import partial + try: # Not supported in Colab from idlelib.tooltip import Hovertip except ModuleNotFoundError: @@ -28,8 +29,11 @@ except ModuleNotFoundError: """ Shell class """ + def __init__(self, *args, **kwargs): pass + + from pathlib import Path from tkinter import filedialog from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence diff --git a/nam/train/lightning_module.py b/nam/train/lightning_module.py @@ -0,0 +1,405 @@ +# File: base.py +# Created Date: Saturday February 5th 2022 +# Author: Steven Atkinson (steven@atkinson.mn) + +""" +Implements the base PyTorch Lightning module. +This is meant to combine an actual model (subclassed from `..models.base.BaseNet`) +along with loss function boilerplate. + +For the base *PyTorch* model containing the actual architecture, see `..models.base`. +""" + +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, NamedTuple, Optional, Tuple + +import auraloss +import logging +import pytorch_lightning as pl +import torch +import torch.nn as nn + +from .._core import InitializableFromConfig +from ..models.conv_net import ConvNet +from ..models.linear import Linear +from ..models.losses import ( + apply_pre_emphasis_filter, + esr, + multi_resolution_stft_loss, + mse_fft, +) +from ..models.recurrent import LSTM +from ..models.wavenet import WaveNet + +logger = logging.getLogger(__name__) + + +class ValidationLoss(Enum): + """ + mse: mean squared error + esr: error signal ratio (Eq. (10) from + https://www.mdpi.com/2076-3417/10/3/766/htm + NOTE: Be careful when computing ESR on minibatches! 
The average ESR over + a minibatch of data not the same as the ESR of all of the same data in + the minibatch calculated over at once (because of the denominator). + (Hint: think about what happens if one item in the minibatch is all + zeroes...) + """ + + MSE = "mse" + ESR = "esr" + + +@dataclass +class LossConfig(InitializableFromConfig): + """ + :param mrstft_weight: Multi-resolution short-time Fourier transform loss + coefficient. None means to skip; 2e-4 works pretty well if one wants to use it. + :param mask_first: How many of the first samples to ignore when comptuing the loss. + :param dc_weight: Weight for the DC loss term. If 0, ignored. + :params val_loss: Which loss to track for the best model checkpoint. + :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from + https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95. + :param pre_ + """ + + mrstft_weight: Optional[float] = None + fourier: bool = False + mask_first: int = 0 + dc_weight: float = None + val_loss: ValidationLoss = ValidationLoss.MSE + pre_emph_weight: Optional[float] = None + pre_emph_coef: Optional[float] = None + pre_emph_mrstft_weight: Optional[float] = None + pre_emph_mrstft_coef: Optional[float] = None + + @classmethod + def parse_config(cls, config): + config = super().parse_config(config) + return { + "fourier": config.get("fourier", False), + "mask_first": config.get("mask_first", 0), + "dc_weight": config.get("dc_weight"), + "val_loss": ValidationLoss(config.get("val_loss", "mse")), + "pre_emph_coef": config.get("pre_emph_coef"), + "pre_emph_weight": config.get("pre_emph_weight"), + "mrstft_weight": cls._get_mrstft_weight(config), + "pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"), + "pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"), + } + + def apply_mask(self, *args): + """ + :param args: (L,) or (B,) + :return: (L-M,) or (B, L-M) + """ + return tuple(a[..., self.mask_first :] for a in args) + + @classmethod + def _get_mrstft_weight(cls, config) -> Optional[float]: + key = "mrstft_weight" + wrong_key = "mstft_key" # Backward compatibility + if key in config: + if "mstft_weight" in config: + raise ValueError( + f"Received loss configuration with both '{key}' and " + f"'{wrong_key}'. Provide only '{key}'." + ) + return config[key] + elif wrong_key in config: + logger.warning( + f"Use of '{wrong_key}' is deprecated and will be removed in a future " + f"version. Use '{key}' instead." + ) + return config[wrong_key] + else: + return None + + +class _LossItem(NamedTuple): + weight: Optional[float] + value: Optional[torch.Tensor] + + +_model_net_init_registry = { + "ConvNet": ConvNet.init_from_config, + "Linear": Linear.init_from_config, + "LSTM": LSTM.init_from_config, + "WaveNet": WaveNet.init_from_config, +} + + +class LightningModule(pl.LightningModule, InitializableFromConfig): + """ + The PyTorch Lightning Module that unites the model with its loss and + optimization recipe. 
+ """ + + def __init__( + self, + net, + optimizer_config: Optional[dict] = None, + scheduler_config: Optional[dict] = None, + loss_config: Optional[LossConfig] = None, + ): + """ + :param scheduler_config: contains + Required: + * "class" + * "kwargs" + Optional (defaults to Lightning defaults): + * "interval" ("epoch" of "step") + * "frequency" (int) + * "monitor" (str) + """ + super().__init__() + self._net = net + self._optimizer_config = {} if optimizer_config is None else optimizer_config + self._scheduler_config = scheduler_config + self._loss_config = LossConfig() if loss_config is None else loss_config + self._mrstft = None # Multi-resolution short-time Fourier transform loss + # Where to compute the MRSTFT. + # Keeping it on-device is preferable, but if that fails, then remember to drop + # it to cpu from then on. + self._mrstft_device: Optional[torch.device] = None + + @classmethod + def init_from_config(cls, config): + checkpoint_path = config.get("checkpoint_path") + config = cls.parse_config(config) + return ( + cls(**config) + if checkpoint_path is None + else cls.load_from_checkpoint(checkpoint_path, **config) + ) + + @classmethod + def parse_config(cls, config): + """ + e.g. + + { + "net": { + "name": "ConvNet", + "config": {...} + }, + "loss": { + "dc_weight": 0.1 + }, + "optimizer": { + "lr": 0.0003 + }, + "lr_scheduler": { + "class": "ReduceLROnPlateau", + "kwargs": { + "factor": 0.8, + "patience": 10, + "cooldown": 15, + "min_lr": 1e-06, + "verbose": true + }, + "monitor": "val_loss" + } + } + """ + config = super().parse_config(config) + net_config = config["net"] + net = _model_net_init_registry[net_config["name"]](net_config["config"]) + loss_config = LossConfig.init_from_config(config.get("loss", {})) + return { + "net": net, + "optimizer_config": config["optimizer"], + "scheduler_config": config["lr_scheduler"], + "loss_config": loss_config, + } + + @classmethod + def register_net_initializer(cls, name, constructor, overwrite: bool = False): + if name in _model_net_init_registry and not overwrite: + raise KeyError( + f"A constructor for net name '{name}' is already registered!" + ) + _model_net_init_registry[name] = constructor + + @property + def net(self) -> nn.Module: + return self._net + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config) + if self._scheduler_config is None: + return optimizer + else: + lr_scheduler = getattr( + torch.optim.lr_scheduler, self._scheduler_config["class"] + )(optimizer, **self._scheduler_config["kwargs"]) + lr_scheduler_config = {"scheduler": lr_scheduler} + for key in ("interval", "frequency", "monitor"): + if key in self._scheduler_config: + lr_scheduler_config[key] = self._scheduler_config[key] + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} + + def forward(self, *args, **kwargs): + return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead. 
+ + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351 + self.net.sample_rate = checkpoint["sample_rate"] + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351 + checkpoint["sample_rate"] = self.net.sample_rate + + def _shared_step( + self, batch + ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]: + """ + B: Batch size + L: Sequence length + + :return: (B,L), (B,L) + """ + args, targets = batch[:-1], batch[-1] + preds = self(*args, pad_start=False) + + # Compute all relevant losses. + loss_dict = {} # Mind keys versus validation loss requested... + # Prediction aka MSE loss + if self._loss_config.fourier: + loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets)) + else: + loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets)) + # Pre-emphasized MSE + if self._loss_config.pre_emph_weight is not None: + if (self._loss_config.pre_emph_coef is None) != ( + self._loss_config.pre_emph_weight is None + ): + raise ValueError("Invalid pre-emph") + loss_dict["Pre-emphasized MSE"] = _LossItem( + self._loss_config.pre_emph_weight, + self._mse_loss( + preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef + ), + ) + # Multi-resolution short-time Fourier transform loss + if self._loss_config.mrstft_weight is not None: + loss_dict["MRSTFT"] = _LossItem( + self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets) + ) + # Pre-emphasized MRSTFT + if self._loss_config.pre_emph_mrstft_weight is not None: + loss_dict["Pre-emphasized MRSTFT"] = _LossItem( + self._loss_config.pre_emph_mrstft_weight, + self._mrstft_loss( + preds, targets, pre_emph_coef=self._loss_config.pre_emph_mrstft_coef + ), + ) + # DC loss + dc_weight = self._loss_config.dc_weight + if dc_weight is not None and dc_weight > 0.0: + # Denominator could be a bad idea. I'm going to omit it esp since I'm + # using mini batches + mean_dims = torch.arange(1, preds.ndim).tolist() + dc_loss = nn.MSELoss()( + preds.mean(dim=mean_dims), targets.mean(dim=mean_dims) + ) + loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss) + + return preds, targets, loss_dict + + def training_step(self, batch, batch_idx): + _, _, loss_dict = self._shared_step(batch) + + loss = 0.0 + for v in loss_dict.values(): + if v.weight is not None and v.weight > 0.0: + loss = loss + v.weight * v.value + return loss + + def validation_step(self, batch, batch_idx): + preds, targets, loss_dict = self._shared_step(batch) + + def get_val_loss(): + # "esr" -> "ESR" + # "mse" -> "MSE" + # Others unsupported... + # TODO better mapping from Enum to dict keys + val_loss_type = self._loss_config.val_loss + val_loss_key_for_loss_dict = val_loss_type.value.upper() + if val_loss_key_for_loss_dict in loss_dict: + return loss_dict[val_loss_key_for_loss_dict].value + else: + raise RuntimeError( + f"Undefined validation loss routine for {val_loss_type}" + ) + + loss_dict["ESR"] = _LossItem(None, self._esr_loss(preds, targets)) + val_loss = get_val_loss() + self.log_dict( + { + "val_loss": val_loss, + **{key: value.value for key, value in loss_dict.items()}, + } + ) + return val_loss + + def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Error signal ratio aka ESR loss. + + Eq. 
(10), from + https://www.mdpi.com/2076-3417/10/3/766/htm + + B: Batch size + L: Sequence length + + :param preds: (B,L) + :param targets: (B,L) + :return: () + """ + return esr(preds, targets) + + def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None): + if pre_emph_coef is not None: + preds, targets = [ + apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets) + ] + return nn.MSELoss()(preds, targets) + + def _mrstft_loss( + self, + preds: torch.Tensor, + targets: torch.Tensor, + pre_emph_coef: Optional[float] = None, + ) -> torch.Tensor: + """ + Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation. + B: Batch size + L: Sequence length + + :param preds: (B,L) + :param targets: (B,L) + :return: () + """ + if self._mrstft is None: + self._mrstft = auraloss.freq.MultiResolutionSTFTLoss() + backup_device = "cpu" + + if pre_emph_coef is not None: + preds, targets = [ + apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets) + ] + + try: + return multi_resolution_stft_loss( + preds, targets, self._mrstft, device=self._mrstft_device + ) + except Exception as e: + if self._mrstft_device == backup_device: + raise e + logger.warning("MRSTFT failed on device; falling back to CPU") + self._mrstft_device = backup_device + return multi_resolution_stft_loss( + preds, targets, self._mrstft, device=self._mrstft_device + ) diff --git a/setup.py b/setup.py @@ -18,7 +18,7 @@ def get_additional_requirements(): additional_requirements.append("transformers>=4") except ModuleNotFoundError: pass - + # Issue 494 def get_numpy_requirement() -> str: need_numpy_1 = True # Until proven otherwise @@ -34,9 +34,10 @@ def get_additional_requirements(): except ModuleNotFoundError: # Until I see PyTorch 2.3 come out: pass - return "numpy<2"if need_numpy_1 else "numpy" + return "numpy<2" if need_numpy_1 else "numpy" + additional_requirements.append(get_numpy_requirement()) - + return additional_requirements diff --git a/tests/test_nam/test_models/_convolutional.py b/tests/test_nam/test_models/_convolutional.py @@ -13,7 +13,9 @@ from .base import Base as _Base class Convolutional(_Base): - @_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test") + @_pytest.mark.skipif( + not _torch.backends.mps.is_available(), reason="MPS-specific test" + ) def test_process_input_longer_than_65536(self): """ Processing inputs longer than 65,536 samples using the MPS backend can diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py @@ -14,12 +14,11 @@ from typing import Optional import numpy as np import pytest import torch -from auraloss.freq import MultiResolutionSTFTLoss -from nam.models import _base, base +from nam.models import base -class _MockBaseNet(_base.BaseNet): +class MockBaseNet(base.BaseNet): def __init__(self, gain: float, *args, **kwargs): super().__init__(*args, **kwargs) self.gain = gain @@ -46,14 +45,14 @@ class _MockBaseNet(_base.BaseNet): def test_metadata_gain(): - obj = _MockBaseNet(1.0) + obj = MockBaseNet(1.0) g = obj._metadata_gain() # It's linear, so gain is zero. 
assert g == 0.0 def test_metadata_loudness(): - obj = _MockBaseNet(1.0) + obj = MockBaseNet(1.0) y = obj._metadata_loudness() obj.gain = 2.0 y2 = obj._metadata_loudness() @@ -62,55 +61,6 @@ def test_metadata_loudness(): assert y2 == pytest.approx(y + 20.0 * math.log10(2.0)) -@pytest.mark.parametrize( - "batch_size,sequence_length", ((16, 8192), (3, 2048), (1, 4000)) -) -def test_mrstft_loss(batch_size: int, sequence_length: int): - obj = base.Model( - _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002) - ) - preds = torch.randn((batch_size, sequence_length)) - targets = torch.randn(preds.shape) - loss = obj._mrstft_loss(preds, targets) - assert isinstance(loss, torch.Tensor) - assert loss.ndim == 0 - - -def test_mrstft_loss_cpu_fallback(mocker): - """ - Assert that fallback to CPU happens on failure - - :param mocker: Provided by pytest-mock - """ - - def mocked_loss( - preds: torch.Tensor, - targets: torch.Tensor, - loss_func: Optional[MultiResolutionSTFTLoss] = None, - device: Optional[torch.device] = None, - ) -> torch.Tensor: - """ - As if the device doesn't support it - """ - if device != "cpu": - raise RuntimeError("Trigger fallback") - return torch.tensor(1.0) - - mocker.patch("nam.models.base.multi_resolution_stft_loss", mocked_loss) - - batch_size = 3 - sequence_length = 4096 - obj = base.Model( - _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002) - ) - preds = torch.randn((batch_size, sequence_length)) - targets = torch.randn(preds.shape) - - assert obj._mrstft_device is None - obj._mrstft_loss(preds, targets) # Should trigger fallback - assert obj._mrstft_device == "cpu" - - class TestSampleRate(object): """ Tests for sample_rate interface @@ -118,12 +68,12 @@ class TestSampleRate(object): @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0)) def test_on_init(self, expected_sample_rate: Optional[float]): - model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate) + model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate) self._wrap_assert(model, expected_sample_rate) @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0)) def test_setter(self, expected_sample_rate: Optional[float]): - model = _MockBaseNet(gain=1.0) + model = MockBaseNet(gain=1.0) model.sample_rate = expected_sample_rate self._wrap_assert(model, expected_sample_rate) @@ -134,16 +84,16 @@ class TestSampleRate(object): https://github.com/sdatkinson/neural-amp-modeler/issues/351 """ - model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate) + model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate) with TemporaryDirectory() as tmpdir: model_path = Path(tmpdir, "model.pt") torch.save(model.state_dict(), model_path) - model2 = _MockBaseNet(gain=1.0) + model2 = MockBaseNet(gain=1.0) model2.load_state_dict(torch.load(model_path)) self._wrap_assert(model2, expected_sample_rate) @classmethod - def _wrap_assert(cls, model: _MockBaseNet, expected: Optional[float]): + def _wrap_assert(cls, model: MockBaseNet, expected: Optional[float]): actual = model.sample_rate if expected is None: assert actual is None diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py @@ -24,7 +24,7 @@ class TestWaveNet(_Convolutional): "head_size": 1, "channels": 1, "kernel_size": 1, - "dilations": [1] + "dilations": [1], } ] } @@ -46,7 +46,7 @@ class TestWaveNet(_Convolutional): assert not torch.allclose(y2_before, y1) assert torch.allclose(y2_after, y1) - + if __name__ == "__main__": 
pytest.main() diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -17,7 +17,7 @@ from nam.data import ( wav_to_tensor, _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, ) -from nam.models import Model +from nam.train.lightning_module import LightningModule from nam.train import core from nam.train._version import Version @@ -292,7 +292,7 @@ def test_end_to_end(): fast_dev_run=True, ) # Assertions... - assert isinstance(train_output.model, Model) + assert isinstance(train_output.model, LightningModule) def test_get_callbacks(): @@ -305,19 +305,23 @@ def test_get_callbacks(): # dumb example of a user-extended custom callback class CustomCallback: pass + extended_callbacks = callbacks + [CustomCallback()] # sanity default callbacks - assert any(isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks), \ - "Expected _ModelCheckpoint to be part of the default callbacks." + assert any( + isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks + ), "Expected _ModelCheckpoint to be part of the default callbacks." # custom callback - assert any(isinstance(cb, CustomCallback) for cb in extended_callbacks), \ - "Expected CustomCallback to be added to the extended callbacks." + assert any( + isinstance(cb, CustomCallback) for cb in extended_callbacks + ), "Expected CustomCallback to be added to the extended callbacks." # _ValidationStopping cb when threshold_esr is prvided - assert any(isinstance(cb, core._ValidationStopping) for cb in extended_callbacks), \ - "_ValidationStopping should still be present after adding a custom callback." + assert any( + isinstance(cb, core._ValidationStopping) for cb in extended_callbacks + ), "_ValidationStopping should still be present after adding a custom callback." if __name__ == "__main__": diff --git a/tests/test_nam/test_train/test_lightning_module.py b/tests/test_nam/test_train/test_lightning_module.py @@ -0,0 +1,68 @@ +# File: test_lightning_module.py +# Created Date: Sunday November 24th 2024 +# Author: Steven Atkinson (steven@atkinson.mn) + +from typing import Optional as _Optional + +import pytest as _pytest +import torch as _torch +from auraloss.freq import MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss + +from nam.train import lightning_module as _lightning_module + +from ..test_models.test_base import MockBaseNet as _MockBaseNet + + +@_pytest.mark.parametrize( + "batch_size,sequence_length", ((16, 8192), (3, 2048), (1, 4000)) +) +def test_mrstft_loss(batch_size: int, sequence_length: int): + obj = _lightning_module.LightningModule( + _MockBaseNet(1.0), + loss_config=_lightning_module.LossConfig(mrstft_weight=0.0002), + ) + preds = _torch.randn((batch_size, sequence_length)) + targets = _torch.randn(preds.shape) + loss = obj._mrstft_loss(preds, targets) + assert isinstance(loss, _torch.Tensor) + assert loss.ndim == 0 + + +def test_mrstft_loss_cpu_fallback(mocker): + """ + Assert that fallback to CPU happens on failure + + :param mocker: Provided by pytest-mock + """ + + def mocked_loss( + preds: _torch.Tensor, + targets: _torch.Tensor, + loss_func: _Optional[_MultiResolutionSTFTLoss] = None, + device: _Optional[_torch.device] = None, + ) -> _torch.Tensor: + """ + As if the device doesn't support it + """ + if device != "cpu": + raise RuntimeError("Trigger fallback") + return _torch.tensor(1.0) + + mocker.patch("nam.train.lightning_module.multi_resolution_stft_loss", mocked_loss) + + batch_size = 3 + sequence_length = 4096 + obj = _lightning_module.LightningModule( + _MockBaseNet(1.0), + 
loss_config=_lightning_module.LossConfig(mrstft_weight=0.0002), + ) + preds = _torch.randn((batch_size, sequence_length)) + targets = _torch.randn(preds.shape) + + assert obj._mrstft_device is None + obj._mrstft_loss(preds, targets) # Should trigger fallback + assert obj._mrstft_device == "cpu" + + +if __name__ == "__main__": + _pytest.main()
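
One piece of behavior worth calling out from `BaseNet._forward_mps_safe` (moved, not changed, by this commit): when PyTorch's MPS backend rejects inputs longer than 65,536 samples, the model falls back to windowed processing. The windows overlap by `receptive_field - 1` samples so the per-window outputs tile seamlessly. A standalone restatement of that logic (the function and parameter names here are illustrative, not part of NAM's API):

    import torch

    def chunked_forward(forward, x: torch.Tensor, receptive_field: int,
                        max_len: int = 65_536) -> torch.Tensor:
        # forward() maps (N, L) -> (N, L - receptive_field + 1), so a window
        # covering samples [i, j) yields outputs for samples [i + RF - 1, j).
        # Starting the next window at j - (RF - 1) makes the outputs contiguous.
        stride = max_len - (receptive_field - 1)
        out_list = []
        for i in range(0, x.shape[1], stride):
            j = min(i + max_len, x.shape[1])
            out_list.append(forward(x[:, i:j]))
            if j == x.shape[1]:  # final window reached the end of the input
                break
        return torch.cat(out_list, dim=1)

The `break` mirrors the original code: once a window reaches the end of the input, any further iteration would produce a window shorter than the receptive field, which the model cannot process.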