commit 560ef39566465adf26988de3582c96a7ddde1e2e
parent f2c3ff91c94bd06b19ed3e0052ade1e4dea02391
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 24 Nov 2024 23:53:46 -0800
[BREAKING] Rename modules (#511)
* Re-organize modules and add deprecations
* nam.models.base to nam.train.lightning_module
* Remove recurrent.LSTMCore
* Refactor lightning module tests to new file
* Black
* typo
* nam.models._base to nam.models.base
Diffstat:
17 files changed, 742 insertions(+), 858 deletions(-)
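A rough migration sketch for the breaking rename above; model_config is a placeholder for the same config dict accepted before and after the rename:

    # Before this commit
    from nam.models import Model
    model = Model.init_from_config(model_config)

    # After this commit
    from nam.train.lightning_module import LightningModule
    model = LightningModule.init_from_config(model_config)

The same substitution appears verbatim in the nam/train/core.py and nam/train/full.py hunks below.
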
diff --git a/nam/models/__init__.py b/nam/models/__init__.py
@@ -6,10 +6,10 @@
NAM's neural networks
"""
-from . import _base # noqa F401
+from . import base # noqa F401
from . import exportable # noqa F401
from . import losses # noqa F401
-from . import wavenet # noqa F401
-from .base import Model # noqa F401
-from .linear import Linear # noqa F401
from .conv_net import ConvNet # noqa F401
+from .linear import Linear # noqa F401
+from .recurrent import LSTM # noqa F401
+from .wavenet import WaveNet # noqa F401
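With the reworked __init__.py, the architecture classes are re-exported from the package root while the Lightning wrapper no longer is. A minimal sketch of imports that should resolve after this commit (assuming nothing else re-exports Model):

    from nam.models import ConvNet, Linear, LSTM, WaveNet    # architectures
    from nam.models.base import BaseNet                       # was nam.models._base
    from nam.train.lightning_module import LightningModule    # was nam.models.Model
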
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -1,256 +0,0 @@
-# File: _base.py
-# Created Date: Tuesday February 8th 2022
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-"""
-The foundation of the model without the PyTorch Lightning attributes (losses, training
-steps)
-"""
-
-import abc
-import math
-import pkg_resources
-from typing import Any, Dict, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn as nn
-
-from .._core import InitializableFromConfig
-from ..data import wav_to_tensor
-from .exportable import Exportable
-
-
-class _Base(nn.Module, InitializableFromConfig, Exportable):
- def __init__(self, sample_rate: Optional[float] = None):
- super().__init__()
- self.register_buffer(
- "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
- )
- self.register_buffer(
- "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
- )
-
- @property
- @abc.abstractmethod
- def pad_start_default(self) -> bool:
- pass
-
- @property
- @abc.abstractmethod
- def receptive_field(self) -> int:
- """
- Receptive field of the model
- """
- pass
-
- @abc.abstractmethod
- def forward(self, *args, **kwargs) -> torch.Tensor:
- pass
-
- @classmethod
- def _metadata_loudness_x(cls) -> torch.Tensor:
- return wav_to_tensor(
- pkg_resources.resource_filename(
- "nam", "models/_resources/loudness_input.wav"
- )
- )
-
- @property
- def device(self) -> Optional[torch.device]:
- """
- Helpful property, where the parameters of the model live.
- """
- # We can do this because the models are tiny and I don't expect a NAM to be on
- # multiple devices
- try:
- return next(self.parameters()).device
- except StopIteration:
- return None
-
- @property
- def sample_rate(self) -> Optional[float]:
- return self._sample_rate.item() if self._has_sample_rate else None
-
- @sample_rate.setter
- def sample_rate(self, val: Optional[float]):
- self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool)
- self._sample_rate = torch.tensor(0.0 if val is None else val)
-
- def _get_export_dict(self):
- d = super()._get_export_dict()
- sample_rate_key = "sample_rate"
- if sample_rate_key in d:
- raise RuntimeError(
- "Model wants to put 'sample_rate' into model export dict, but the key "
- "is already taken!"
- )
- d[sample_rate_key] = self.sample_rate
- return d
-
- def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
- """
- How loud is this model when given a standardized input?
- In dB
-
- :param gain: Multiplies input signal
- """
- x = self._metadata_loudness_x().to(self.device)
- y = self._at_nominal_settings(gain * x)
- loudness = torch.sqrt(torch.mean(torch.square(y)))
- if db:
- loudness = 20.0 * torch.log10(loudness)
- return loudness.item()
-
- def _metadata_gain(self) -> float:
- """
- Between 0 and 1, how much gain / compression does the model seem to have?
- """
- x = np.linspace(0.0, 1.0, 11)
- y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
- #
- # O ^ o o o o o o
- # u | o x +-------------------------------------+
- # t | o x | x: Minimum gain (no compression) |
- # p | o x | o: Max gain (100% compression) |
- # u | o x +-------------------------------------+
- # t | o
- # +------------->
- # Input
- #
- max_gain = y[-1] * len(x) # "Square"
- min_gain = 0.5 * max_gain # "Triangle"
- gain_range = max_gain - min_gain
- this_gain = y.sum()
- normalized_gain = (this_gain - min_gain) / gain_range
- return np.clip(normalized_gain, 0.0, 1.0)
-
- def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
- # parametric?...
- raise NotImplementedError()
-
- @abc.abstractmethod
- def _forward(self, *args) -> torch.Tensor:
- """
- The true forward method.
-
- :param x: (N,L1)
- :return: (N,L1-RF+1)
- """
- pass
-
- def _export_input_output_args(self) -> Tuple[Any]:
- """
- Create any other args necessesary (e.g. params to eval at)
- """
- return ()
-
- def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
- args = self._export_input_output_args()
- rate = self.sample_rate
- if rate is None:
- raise RuntimeError(
- "Cannot export model's input and output without a sample rate."
- )
- x = torch.cat(
- [
- torch.zeros((rate,)),
- 0.5
- * torch.sin(
- 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
- ),
- torch.zeros((rate,)),
- ]
- )
- # Use pad start to ensure same length as requested by ._export_input_output()
- return (
- x.detach().cpu().numpy(),
- self(*args, x, pad_start=True).detach().cpu().numpy(),
- )
-
-
-class BaseNet(_Base):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self._mps_65536_fallback = False
-
- def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
- pad_start = self.pad_start_default if pad_start is None else pad_start
- scalar = x.ndim == 1
- if scalar:
- x = x[None]
- if pad_start:
- x = torch.cat(
- (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
- )
- if x.shape[1] < self.receptive_field:
- raise ValueError(
- f"Input has {x.shape[1]} samples, which is too few for this model with "
- f"receptive field {self.receptive_field}!"
- )
- y = self._forward_mps_safe(x, **kwargs)
- if scalar:
- y = y[0]
- return y
-
- def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
- return self(x)
-
- def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
- """
- Wrap `._forward()` to protect against MPS-unsupported inptu lengths
- beyond 65,536 samples.
-
- Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
- then.
- """
- if not self._mps_65536_fallback:
- try:
- return self._forward(x, **kwargs)
- except NotImplementedError as e:
- if "Output channels > 65536 not supported at the MPS device." in str(e):
- print(
- "===WARNING===\n"
- "NAM encountered a bug in PyTorch's MPS backend and will "
- "switch to a fallback.\n"
- f"Your version of PyTorch is {torch.__version__}.\n"
- "Please report this in an Issue at:\n"
- "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
- "\n"
- "so that NAM's dependencies can avoid buggy versions of "
- "PyTorch and the associated performance hit."
- )
- self._mps_65536_fallback = True
- return self._forward_mps_safe(x, **kwargs)
- else:
- raise e
- else:
- # Stitch together the output one piece at a time to avoid the MPS error
- stride = 65_536 - (self.receptive_field - 1)
- # We need to make sure that the last segment is big enough that we have the required history for the receptive field.
- out_list = []
- for i in range(0, x.shape[1], stride):
- j = min(i+65_536, x.shape[1])
- xi = x[:, i:j]
- out_list.append(self._forward(xi, **kwargs))
- # Bit hacky, but correct.
- if j == x.shape[1]:
- break
- return torch.cat(out_list, dim=1)
-
-
- @abc.abstractmethod
- def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
- """
- The true forward method.
-
- :param x: (N,L1)
- :return: (N,L1-RF+1)
- """
- pass
-
- def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
- d = super()._get_non_user_metadata()
- d["loudness"] = self._metadata_loudness()
- d["gain"] = self._metadata_gain()
- return d
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -1,395 +1,255 @@
-# File: base.py
-# Created Date: Saturday February 5th 2022
+# File: _base.py
+# Created Date: Tuesday February 8th 2022
# Author: Steven Atkinson (steven@atkinson.mn)
"""
-Implements the base PyTorch Lightning module.
-This is meant to combine an actual model (subclassed from `._base.BaseNet`)
-along with loss function boilerplate.
-
-For the base *PyTorch* model containing the actual architecture, see `._base`.
+The foundation of the model without the PyTorch Lightning attributes (losses, training
+steps)
"""
-from dataclasses import dataclass
-from enum import Enum
-from typing import Any, Dict, NamedTuple, Optional, Tuple
+import abc
+import math
+import pkg_resources
+from typing import Any, Dict, Optional, Tuple, Union
-import auraloss
-import logging
-import pytorch_lightning as pl
+import numpy as np
import torch
import torch.nn as nn
from .._core import InitializableFromConfig
-from .conv_net import ConvNet
-from .linear import Linear
-from .losses import apply_pre_emphasis_filter, esr, multi_resolution_stft_loss, mse_fft
-from .recurrent import LSTM
-from .wavenet import WaveNet
-
-logger = logging.getLogger(__name__)
-
+from ..data import wav_to_tensor
+from .exportable import Exportable
-class ValidationLoss(Enum):
- """
- mse: mean squared error
- esr: error signal ratio (Eq. (10) from
- https://www.mdpi.com/2076-3417/10/3/766/htm
- NOTE: Be careful when computing ESR on minibatches! The average ESR over
- a minibatch of data not the same as the ESR of all of the same data in
- the minibatch calculated over at once (because of the denominator).
- (Hint: think about what happens if one item in the minibatch is all
- zeroes...)
- """
- MSE = "mse"
- ESR = "esr"
+class _Base(nn.Module, InitializableFromConfig, Exportable):
+ def __init__(self, sample_rate: Optional[float] = None):
+ super().__init__()
+ self.register_buffer(
+ "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
+ )
+ self.register_buffer(
+ "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
+ )
+ @property
+ @abc.abstractmethod
+ def pad_start_default(self) -> bool:
+ pass
-@dataclass
-class LossConfig(InitializableFromConfig):
- """
- :param mrstft_weight: Multi-resolution short-time Fourier transform loss
- coefficient. None means to skip; 2e-4 works pretty well if one wants to use it.
- :param mask_first: How many of the first samples to ignore when comptuing the loss.
- :param dc_weight: Weight for the DC loss term. If 0, ignored.
- :params val_loss: Which loss to track for the best model checkpoint.
- :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
- https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
- :param pre_
- """
+ @property
+ @abc.abstractmethod
+ def receptive_field(self) -> int:
+ """
+ Receptive field of the model
+ """
+ pass
- mrstft_weight: Optional[float] = None
- fourier: bool = False
- mask_first: int = 0
- dc_weight: float = None
- val_loss: ValidationLoss = ValidationLoss.MSE
- pre_emph_weight: Optional[float] = None
- pre_emph_coef: Optional[float] = None
- pre_emph_mrstft_weight: Optional[float] = None
- pre_emph_mrstft_coef: Optional[float] = None
+ @abc.abstractmethod
+ def forward(self, *args, **kwargs) -> torch.Tensor:
+ pass
@classmethod
- def parse_config(cls, config):
- config = super().parse_config(config)
- return {
- "fourier": config.get("fourier", False),
- "mask_first": config.get("mask_first", 0),
- "dc_weight": config.get("dc_weight"),
- "val_loss": ValidationLoss(config.get("val_loss", "mse")),
- "pre_emph_coef": config.get("pre_emph_coef"),
- "pre_emph_weight": config.get("pre_emph_weight"),
- "mrstft_weight": cls._get_mrstft_weight(config),
- "pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"),
- "pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"),
- }
+ def _metadata_loudness_x(cls) -> torch.Tensor:
+ return wav_to_tensor(
+ pkg_resources.resource_filename(
+ "nam", "models/_resources/loudness_input.wav"
+ )
+ )
- def apply_mask(self, *args):
+ @property
+ def device(self) -> Optional[torch.device]:
"""
- :param args: (L,) or (B,)
- :return: (L-M,) or (B, L-M)
+ Helpful property, where the parameters of the model live.
"""
- return tuple(a[..., self.mask_first :] for a in args)
-
- @classmethod
- def _get_mrstft_weight(cls, config) -> Optional[float]:
- key = "mrstft_weight"
- wrong_key = "mstft_key" # Backward compatibility
- if key in config:
- if "mstft_weight" in config:
- raise ValueError(
- f"Received loss configuration with both '{key}' and "
- f"'{wrong_key}'. Provide only '{key}'."
- )
- return config[key]
- elif wrong_key in config:
- logger.warning(
- f"Use of '{wrong_key}' is deprecated and will be removed in a future "
- f"version. Use '{key}' instead."
- )
- return config[wrong_key]
- else:
+ # We can do this because the models are tiny and I don't expect a NAM to be on
+ # multiple devices
+ try:
+ return next(self.parameters()).device
+ except StopIteration:
return None
+ @property
+ def sample_rate(self) -> Optional[float]:
+ return self._sample_rate.item() if self._has_sample_rate else None
+
+ @sample_rate.setter
+ def sample_rate(self, val: Optional[float]):
+ self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool)
+ self._sample_rate = torch.tensor(0.0 if val is None else val)
+
+ def _get_export_dict(self):
+ d = super()._get_export_dict()
+ sample_rate_key = "sample_rate"
+ if sample_rate_key in d:
+ raise RuntimeError(
+ "Model wants to put 'sample_rate' into model export dict, but the key "
+ "is already taken!"
+ )
+ d[sample_rate_key] = self.sample_rate
+ return d
-class _LossItem(NamedTuple):
- weight: Optional[float]
- value: Optional[torch.Tensor]
-
-
-_model_net_init_registry = {
- "ConvNet": ConvNet.init_from_config,
- "Linear": Linear.init_from_config,
- "LSTM": LSTM.init_from_config,
- "WaveNet": WaveNet.init_from_config,
-}
-
+ def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
+ """
+ How loud is this model when given a standardized input?
+ In dB
-class Model(pl.LightningModule, InitializableFromConfig):
- def __init__(
- self,
- net,
- optimizer_config: Optional[dict] = None,
- scheduler_config: Optional[dict] = None,
- loss_config: Optional[LossConfig] = None,
- ):
+ :param gain: Multiplies input signal
"""
- :param scheduler_config: contains
- Required:
- * "class"
- * "kwargs"
- Optional (defaults to Lightning defaults):
- * "interval" ("epoch" of "step")
- * "frequency" (int)
- * "monitor" (str)
+ x = self._metadata_loudness_x().to(self.device)
+ y = self._at_nominal_settings(gain * x)
+ loudness = torch.sqrt(torch.mean(torch.square(y)))
+ if db:
+ loudness = 20.0 * torch.log10(loudness)
+ return loudness.item()
+
+ def _metadata_gain(self) -> float:
"""
- super().__init__()
- self._net = net
- self._optimizer_config = {} if optimizer_config is None else optimizer_config
- self._scheduler_config = scheduler_config
- self._loss_config = LossConfig() if loss_config is None else loss_config
- self._mrstft = None # Multi-resolution short-time Fourier transform loss
- # Where to compute the MRSTFT.
- # Keeping it on-device is preferable, but if that fails, then remember to drop
- # it to cpu from then on.
- self._mrstft_device: Optional[torch.device] = None
-
- @classmethod
- def init_from_config(cls, config):
- checkpoint_path = config.get("checkpoint_path")
- config = cls.parse_config(config)
- return (
- cls(**config)
- if checkpoint_path is None
- else cls.load_from_checkpoint(checkpoint_path, **config)
- )
-
- @classmethod
- def parse_config(cls, config):
+ Between 0 and 1, how much gain / compression does the model seem to have?
"""
- e.g.
-
- {
- "net": {
- "name": "ConvNet",
- "config": {...}
- },
- "loss": {
- "dc_weight": 0.1
- },
- "optimizer": {
- "lr": 0.0003
- },
- "lr_scheduler": {
- "class": "ReduceLROnPlateau",
- "kwargs": {
- "factor": 0.8,
- "patience": 10,
- "cooldown": 15,
- "min_lr": 1e-06,
- "verbose": true
- },
- "monitor": "val_loss"
- }
- }
+ x = np.linspace(0.0, 1.0, 11)
+ y = np.array([self._metadata_loudness(gain=gain, db=False) for gain in x])
+ #
+ # O ^ o o o o o o
+ # u | o x +-------------------------------------+
+ # t | o x | x: Minimum gain (no compression) |
+ # p | o x | o: Max gain (100% compression) |
+ # u | o x +-------------------------------------+
+ # t | o
+ # +------------->
+ # Input
+ #
+ max_gain = y[-1] * len(x) # "Square"
+ min_gain = 0.5 * max_gain # "Triangle"
+ gain_range = max_gain - min_gain
+ this_gain = y.sum()
+ normalized_gain = (this_gain - min_gain) / gain_range
+ return np.clip(normalized_gain, 0.0, 1.0)
+
+ def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
+ # parametric?...
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def _forward(self, *args) -> torch.Tensor:
"""
- config = super().parse_config(config)
- net_config = config["net"]
- net = _model_net_init_registry[net_config["name"]](net_config["config"])
- loss_config = LossConfig.init_from_config(config.get("loss", {}))
- return {
- "net": net,
- "optimizer_config": config["optimizer"],
- "scheduler_config": config["lr_scheduler"],
- "loss_config": loss_config,
- }
-
- @classmethod
- def register_net_initializer(cls, name, constructor, overwrite: bool = False):
- if name in _model_net_init_registry and not overwrite:
- raise KeyError(
- f"A constructor for net name '{name}' is already registered!"
- )
- _model_net_init_registry[name] = constructor
+ The true forward method.
- @property
- def net(self) -> nn.Module:
- return self._net
-
- def configure_optimizers(self):
- optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
- if self._scheduler_config is None:
- return optimizer
- else:
- lr_scheduler = getattr(
- torch.optim.lr_scheduler, self._scheduler_config["class"]
- )(optimizer, **self._scheduler_config["kwargs"])
- lr_scheduler_config = {"scheduler": lr_scheduler}
- for key in ("interval", "frequency", "monitor"):
- if key in self._scheduler_config:
- lr_scheduler_config[key] = self._scheduler_config[key]
- return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
-
- def forward(self, *args, **kwargs):
- return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.
-
- def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
- # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
- self.net.sample_rate = checkpoint["sample_rate"]
-
- def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
- # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
- checkpoint["sample_rate"] = self.net.sample_rate
-
- def _shared_step(
- self, batch
- ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
+ :param x: (N,L1)
+ :return: (N,L1-RF+1)
"""
- B: Batch size
- L: Sequence length
+ pass
- :return: (B,L), (B,L)
+ def _export_input_output_args(self) -> Tuple[Any]:
"""
- args, targets = batch[:-1], batch[-1]
- preds = self(*args, pad_start=False)
-
- # Compute all relevant losses.
- loss_dict = {} # Mind keys versus validation loss requested...
- # Prediction aka MSE loss
- if self._loss_config.fourier:
- loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets))
- else:
- loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets))
- # Pre-emphasized MSE
- if self._loss_config.pre_emph_weight is not None:
- if (self._loss_config.pre_emph_coef is None) != (
- self._loss_config.pre_emph_weight is None
- ):
- raise ValueError("Invalid pre-emph")
- loss_dict["Pre-emphasized MSE"] = _LossItem(
- self._loss_config.pre_emph_weight,
- self._mse_loss(
- preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef
- ),
- )
- # Multi-resolution short-time Fourier transform loss
- if self._loss_config.mrstft_weight is not None:
- loss_dict["MRSTFT"] = _LossItem(
- self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets)
+ Create any other args necessary (e.g. params to eval at)
+ """
+ return ()
+
+ def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
+ args = self._export_input_output_args()
+ rate = self.sample_rate
+ if rate is None:
+ raise RuntimeError(
+ "Cannot export model's input and output without a sample rate."
)
- # Pre-emphasized MRSTFT
- if self._loss_config.pre_emph_mrstft_weight is not None:
- loss_dict["Pre-emphasized MRSTFT"] = _LossItem(
- self._loss_config.pre_emph_mrstft_weight,
- self._mrstft_loss(
- preds, targets, pre_emph_coef=self._loss_config.pre_emph_mrstft_coef
+ x = torch.cat(
+ [
+ torch.zeros((rate,)),
+ 0.5
+ * torch.sin(
+ 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
),
- )
- # DC loss
- dc_weight = self._loss_config.dc_weight
- if dc_weight is not None and dc_weight > 0.0:
- # Denominator could be a bad idea. I'm going to omit it esp since I'm
- # using mini batches
- mean_dims = torch.arange(1, preds.ndim).tolist()
- dc_loss = nn.MSELoss()(
- preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
- )
- loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss)
-
- return preds, targets, loss_dict
-
- def training_step(self, batch, batch_idx):
- _, _, loss_dict = self._shared_step(batch)
+ torch.zeros((rate,)),
+ ]
+ )
+ # Use pad start to ensure same length as requested by ._export_input_output()
+ return (
+ x.detach().cpu().numpy(),
+ self(*args, x, pad_start=True).detach().cpu().numpy(),
+ )
- loss = 0.0
- for v in loss_dict.values():
- if v.weight is not None and v.weight > 0.0:
- loss = loss + v.weight * v.value
- return loss
- def validation_step(self, batch, batch_idx):
- preds, targets, loss_dict = self._shared_step(batch)
+class BaseNet(_Base):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._mps_65536_fallback = False
- def get_val_loss():
- # "esr" -> "ESR"
- # "mse" -> "MSE"
- # Others unsupported...
- # TODO better mapping from Enum to dict keys
- val_loss_type = self._loss_config.val_loss
- val_loss_key_for_loss_dict = val_loss_type.value.upper()
- if val_loss_key_for_loss_dict in loss_dict:
- return loss_dict[val_loss_key_for_loss_dict].value
- else:
- raise RuntimeError(
- f"Undefined validation loss routine for {val_loss_type}"
- )
+ def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None, **kwargs):
+ pad_start = self.pad_start_default if pad_start is None else pad_start
+ scalar = x.ndim == 1
+ if scalar:
+ x = x[None]
+ if pad_start:
+ x = torch.cat(
+ (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
+ )
+ if x.shape[1] < self.receptive_field:
+ raise ValueError(
+ f"Input has {x.shape[1]} samples, which is too few for this model with "
+ f"receptive field {self.receptive_field}!"
+ )
+ y = self._forward_mps_safe(x, **kwargs)
+ if scalar:
+ y = y[0]
+ return y
- loss_dict["ESR"] = _LossItem(None, self._esr_loss(preds, targets))
- val_loss = get_val_loss()
- self.log_dict(
- {
- "val_loss": val_loss,
- **{key: value.value for key, value in loss_dict.items()},
- }
- )
- return val_loss
+ def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
+ return self(x)
- def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
- Error signal ratio aka ESR loss.
-
- Eq. (10), from
- https://www.mdpi.com/2076-3417/10/3/766/htm
+ Wrap `._forward()` to protect against MPS-unsupported input lengths
+ beyond 65,536 samples.
- B: Batch size
- L: Sequence length
-
- :param preds: (B,L)
- :param targets: (B,L)
- :return: ()
+ Check this again when PyTorch 2.5.2 is released--hopefully it's fixed
+ then.
"""
- return esr(preds, targets)
-
- def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
- if pre_emph_coef is not None:
- preds, targets = [
- apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
- ]
- return nn.MSELoss()(preds, targets)
-
- def _mrstft_loss(
- self,
- preds: torch.Tensor,
- targets: torch.Tensor,
- pre_emph_coef: Optional[float] = None,
- ) -> torch.Tensor:
+ if not self._mps_65536_fallback:
+ try:
+ return self._forward(x, **kwargs)
+ except NotImplementedError as e:
+ if "Output channels > 65536 not supported at the MPS device." in str(e):
+ print(
+ "===WARNING===\n"
+ "NAM encountered a bug in PyTorch's MPS backend and will "
+ "switch to a fallback.\n"
+ f"Your version of PyTorch is {torch.__version__}.\n"
+ "Please report this in an Issue at:\n"
+ "https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
+ "\n"
+ "so that NAM's dependencies can avoid buggy versions of "
+ "PyTorch and the associated performance hit."
+ )
+ self._mps_65536_fallback = True
+ return self._forward_mps_safe(x, **kwargs)
+ else:
+ raise e
+ else:
+ # Stitch together the output one piece at a time to avoid the MPS error
+ stride = 65_536 - (self.receptive_field - 1)
+ # We need to make sure that the last segment is big enough that we have the required history for the receptive field.
+ out_list = []
+ for i in range(0, x.shape[1], stride):
+ j = min(i + 65_536, x.shape[1])
+ xi = x[:, i:j]
+ out_list.append(self._forward(xi, **kwargs))
+ # Bit hacky, but correct.
+ if j == x.shape[1]:
+ break
+ return torch.cat(out_list, dim=1)
+
+ @abc.abstractmethod
+ def _forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
- Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
- B: Batch size
- L: Sequence length
+ The true forward method.
- :param preds: (B,L)
- :param targets: (B,L)
- :return: ()
+ :param x: (N,L1)
+ :return: (N,L1-RF+1)
"""
- if self._mrstft is None:
- self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
- backup_device = "cpu"
-
- if pre_emph_coef is not None:
- preds, targets = [
- apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
- ]
+ pass
- try:
- return multi_resolution_stft_loss(
- preds, targets, self._mrstft, device=self._mrstft_device
- )
- except Exception as e:
- if self._mrstft_device == backup_device:
- raise e
- logger.warning("MRSTFT failed on device; falling back to CPU")
- self._mrstft_device = backup_device
- return multi_resolution_stft_loss(
- preds, targets, self._mrstft, device=self._mrstft_device
- )
+ def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+ d = super()._get_non_user_metadata()
+ d["loudness"] = self._metadata_loudness()
+ d["gain"] = self._metadata_gain()
+ return d
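To make the BaseNet contract above concrete, here is a hypothetical minimal subclass (the _Gain class is invented for illustration; the export stub names follow the test suite's MockBaseNet and are assumptions, since their abstract declarations live in nam/models/exportable.py, which this commit does not touch):

    import torch

    from nam.models.base import BaseNet


    class _Gain(BaseNet):
        """Toy memoryless model: multiply the input by a fixed gain."""

        def __init__(self, gain: float, *args, **kwargs):
            super().__init__(*args, **kwargs)  # accepts sample_rate=...
            self._gain = gain

        @property
        def pad_start_default(self) -> bool:
            return True

        @property
        def receptive_field(self) -> int:
            return 1  # one sample of history: output length equals input length

        def _forward(self, x: torch.Tensor) -> torch.Tensor:
            # (N, L) in, (N, L - receptive_field + 1) out
            return self._gain * x

        # Export hooks inherited from Exportable (names assumed; stubbed only)
        def export_cpp_header(self, filename):
            raise NotImplementedError()

        def _export_config(self):
            raise NotImplementedError()

        def _export_weights(self):
            raise NotImplementedError()

Per the forward() shown above, such a net accepts 1-D (L,) or 2-D (N, L) tensors, left-pads receptive_field - 1 zeros when pad_start is requested, and dispatches through _forward_mps_safe, which only falls back to chunked processing if the MPS 65,536-channel limitation is hit.
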
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -19,7 +19,7 @@ import torch.nn.functional as F
from .. import __version__
from ..data import wav_to_tensor
from ._activations import get_activation
-from ._base import BaseNet
+from .base import BaseNet
from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -11,7 +11,7 @@ import torch
import torch.nn as nn
from .._version import __version__
-from ._base import BaseNet
+from .base import BaseNet
class Linear(BaseNet):
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -18,10 +18,7 @@ import numpy as np
import torch
import torch.nn as nn
-from ._base import BaseNet
-
-
-# TODO merge LSTMCore into LSTM
+from .base import BaseNet
class _L(nn.LSTM):
@@ -55,72 +52,6 @@ _LSTMCellType = torch.Tensor
_LSTMHiddenCellType = Tuple[_LSTMHiddenType, _LSTMCellType]
-class LSTMCore(_L):
- def __init__(
- self,
- *args,
- train_burn_in: Optional[int] = None,
- train_truncate: Optional[int] = None,
- **kwargs,
- ):
- super().__init__(*args, **kwargs)
- if not self.batch_first:
- raise NotImplementedError("Need batch first")
- self._train_burn_in = train_burn_in
- self._train_truncate = train_truncate
- assert len(args) < 3, "Provide as kwargs"
- self._initial_cell = nn.Parameter(
- torch.zeros((self.num_layers, self.hidden_size))
- )
- self._initial_hidden = nn.Parameter(
- torch.zeros((self.num_layers, self.hidden_size))
- )
-
- def forward(self, x, hidden_state=None):
- """
- Same as nn.LSTM.forward except:
- * Learned inital state
- * truncated BPTT when .training
- """
- if x.ndim != 3:
- raise NotImplementedError("Need (B,L,D)")
- last_hidden_state = (
- self._initial_state(None if x.ndim == 2 else len(x))
- if hidden_state is None
- else hidden_state
- )
- if not self.training or self._train_truncate is None:
- output_features = super().forward(x, last_hidden_state)[0]
- else:
- output_features_list = []
- if self._train_burn_in is not None:
- last_output_features, last_hidden_state = super().forward(
- x[:, : self._train_burn_in, :], last_hidden_state
- )
- output_features_list.append(last_output_features.detach())
- burn_in_offset = 0 if self._train_burn_in is None else self._train_burn_in
- for i in range(burn_in_offset, x.shape[1], self._train_truncate):
- if i > burn_in_offset:
- # Don't detach the burn-in state so that we can learn it.
- last_hidden_state = tuple(z.detach() for z in last_hidden_state)
- last_output_features, last_hidden_state = super().forward(
- x[:, i : i + self._train_truncate, :], last_hidden_state
- )
- output_features_list.append(last_output_features)
- output_features = torch.cat(output_features_list, dim=1)
- return output_features
-
- def _initial_state(self, n: Optional[int]) -> _LSTMHiddenCellType:
- return (
- (self._initial_hidden, self._initial_cell)
- if n is None
- else (
- torch.tile(self._initial_hidden[:, None], (1, n, 1)),
- torch.tile(self._initial_cell[:, None], (1, n, 1)),
- )
- )
-
-
# TODO get this somewhere more core-ish
class _ExportsWeights(abc.ABC):
@abc.abstractmethod
@@ -371,91 +302,3 @@ class LSTM(BaseNet):
torch.tile(self._initial_cell[:, None], (1, n, 1)),
)
)
-
-
-# TODO refactor together
-
-
-class _SkippyLSTM(nn.Module):
- def __init__(
- self, input_size, hidden_size, skip_in: bool = False, num_layers=1, **kwargs
- ):
- super().__init__()
- layers_per_lstm = 1
- self._skip_in = skip_in
- self._lstms = nn.ModuleList(
- [
- _L(
- self._layer_input_size(input_size, hidden_size, i),
- hidden_size,
- layers_per_lstm,
- batch_first=True,
- )
- for i in range(num_layers)
- ]
- )
- self._initial_hidden = nn.Parameter(
- torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size))
- )
- self._initial_cell = nn.Parameter(
- torch.zeros((self.num_layers, layers_per_lstm, self.hidden_size))
- )
-
- @property
- def hidden_size(self):
- return self._lstms[0].hidden_size
-
- @property
- def input_size(self):
- return self._lstms[0].input_size
-
- @property
- def num_layers(self):
- return len(self._lstms)
-
- @property
- def output_size(self):
- return self.num_layers * self.hidden_size
-
- def forward(self, input, state=None):
- """
- :param input: (N,L,DX)
- :param state: ((L,Li,N,DH), (L,Li,N,DH))
-
- :return: (N,L,L*DH), ((L,Li,N,DH), (L,Li,N,DH))
- """
- h0, c0 = self.initial_state(input) if state is None else state
- hiddens, h_arr, c_arr, hidden = [], [], [], None
- for layer, h0i, c0i in zip(self._lstms, h0, c0):
- if self._skip_in:
- # TODO dense-block
- layer_input = (
- input if hidden is None else torch.cat([input, hidden], dim=2)
- )
- else:
- layer_input = input if hidden is None else hidden
- hidden, (hi, ci) = layer(layer_input, (h0i, c0i))
- hiddens.append(hidden)
- h_arr.append(hi)
- c_arr.append(ci)
- return (torch.cat(hiddens, dim=2), (torch.stack(h_arr), torch.stack(c_arr)))
-
- def initial_state(self, input: torch.Tensor):
- """
- Initial states for all the layers
-
- :return: (L,B,Li,DH)
- """
- assert input.ndim == 3, "Batch only for now"
- batch_size = len(input) # Assume batch_first
- return (
- torch.tile(self._initial_hidden[:, :, None], (1, 1, batch_size, 1)),
- torch.tile(self._initial_cell[:, :, None], (1, 1, batch_size, 1)),
- )
-
- def _layer_input_size(self, input_size, hidden_size, i) -> int:
- # TODO dense-block
- if self._skip_in:
- return input_size + (0 if i == 0 else hidden_size)
- else:
- return input_size if i == 0 else hidden_size
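recurrent.LSTMCore and the experimental _SkippyLSTM are removed without a deprecation shim, so imports of them now fail. The exported LSTM class (still re-exported from the package root, per the __init__.py hunk above) is the remaining recurrent architecture; a sketch of the import-level migration only, since constructor arguments are architecture-specific and not shown in this diff:

    # Before: from nam.models.recurrent import LSTMCore   # removed in this commit
    from nam.models.recurrent import LSTM  # equivalently: from nam.models import LSTM
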
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
@@ -18,7 +18,7 @@ import torch
import torch.nn as nn
from ._activations import get_activation
-from ._base import BaseNet
+from .base import BaseNet
from ._names import ACTIVATION_NAME, CONV_NAME
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -26,12 +26,12 @@ from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
from ..data import DataError, Split, init_dataset, wav_to_np, wav_to_tensor
-from ..models import Model
from ..models.exportable import Exportable
from ..models.losses import esr
from ..models.metadata import UserMetadata
from ..util import filter_warnings
from ._version import PROTEUS_VERSION, Version
+from .lightning_module import LightningModule
from . import metadata
# Training using the simplified trainers in NAM is done at 48k.
@@ -424,13 +424,16 @@ def _calibrate_latency_v_all(
print(msg)
print("SHARE THIS PLOT IF YOU ASK FOR HELP")
plt.figure()
- plt.plot(np.arange(-lookahead, lookback), y_scan_average, color="C0", label="Signal average")
+ plt.plot(
+ np.arange(-lookahead, lookback),
+ y_scan_average,
+ color="C0",
+ label="Signal average",
+ )
for y_scan in y_scans:
plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2)
plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
- plt.axhline(
- y=-trigger_threshold, color="k", linestyle="--", label="Threshold"
- )
+ plt.axhline(y=-trigger_threshold, color="k", linestyle="--", label="Threshold")
plt.axhline(y=trigger_threshold, color="k", linestyle="--")
plt.xlim((-lookahead, lookback))
plt.xlabel("Samples")
@@ -1050,7 +1053,7 @@ def _get_configs(
def _get_dataloaders(
- data_config: Dict, learning_config: Dict, model: Model
+ data_config: Dict, learning_config: Dict, model: LightningModule
) -> Tuple[DataLoader, DataLoader]:
data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)]
data_config["common"]["nx"] = model.net.receptive_field
@@ -1205,7 +1208,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
super()._save_checkpoint(trainer, filepath)
# Save the .nam:
nam_filepath = self._get_nam_filepath(filepath)
- pl_model: Model = trainer.model
+ pl_model: LightningModule = trainer.model
nam_model = pl_model.net
outdir = nam_filepath.parent
# HACK: Assume the extension
@@ -1273,7 +1276,7 @@ class TrainOutput(NamedTuple):
simplified trainer.
"""
- model: Optional[Model]
+ model: Optional[LightningModule]
metadata: metadata.TrainingMetadata
@@ -1414,7 +1417,7 @@ def train(
# * Model is re-instantiated after training anyways.
# (Hacky) solution: set sample rate in model from dataloader after second
# instantiation from final checkpoint.
- model = Model.init_from_config(model_config)
+ model = LightningModule.init_from_config(model_config)
train_dataloader, val_dataloader = _get_dataloaders(
data_config, learning_config, model
)
@@ -1449,9 +1452,9 @@ def train(
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
- model = Model.load_from_checkpoint(
+ model = LightningModule.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
- **Model.parse_config(model_config),
+ **LightningModule.parse_config(model_config),
)
model.cpu()
model.eval()
diff --git a/nam/train/full.py b/nam/train/full.py
@@ -16,7 +16,7 @@ import torch
from torch.utils.data import DataLoader
from nam.data import ConcatDataset, Split, init_dataset
-from nam.models import Model
+from nam.train.lightning_module import LightningModule
from nam.util import filter_warnings
torch.manual_seed(0)
@@ -143,7 +143,7 @@ def main(
with open(Path(outdir, f"config_{basename}.json"), "w") as fp:
json.dump(config, fp, indent=4)
- model = Model.init_from_config(model_config)
+ model = LightningModule.init_from_config(model_config)
# Add receptive field to data config:
data_config["common"] = data_config.get("common", {})
if "nx" in data_config["common"]:
@@ -178,9 +178,9 @@ def main(
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
if best_checkpoint != "":
- model = Model.load_from_checkpoint(
+ model = LightningModule.load_from_checkpoint(
trainer.checkpoint_callback.best_model_path,
- **Model.parse_config(model_config),
+ **LightningModule.parse_config(model_config),
)
model.cpu()
model.eval()
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -20,6 +20,7 @@ import webbrowser
from dataclasses import dataclass
from enum import Enum
from functools import partial
+
try: # Not supported in Colab
from idlelib.tooltip import Hovertip
except ModuleNotFoundError:
@@ -28,8 +29,11 @@ except ModuleNotFoundError:
"""
Shell class
"""
+
def __init__(self, *args, **kwargs):
pass
+
+
from pathlib import Path
from tkinter import filedialog
from typing import Any, Callable, Dict, NamedTuple, Optional, Sequence
diff --git a/nam/train/lightning_module.py b/nam/train/lightning_module.py
@@ -0,0 +1,405 @@
+# File: base.py
+# Created Date: Saturday February 5th 2022
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+"""
+Implements the base PyTorch Lightning module.
+This is meant to combine an actual model (subclassed from `..models.base.BaseNet`)
+along with loss function boilerplate.
+
+For the base *PyTorch* model containing the actual architecture, see `..models.base`.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Dict, NamedTuple, Optional, Tuple
+
+import auraloss
+import logging
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+
+from .._core import InitializableFromConfig
+from ..models.conv_net import ConvNet
+from ..models.linear import Linear
+from ..models.losses import (
+ apply_pre_emphasis_filter,
+ esr,
+ multi_resolution_stft_loss,
+ mse_fft,
+)
+from ..models.recurrent import LSTM
+from ..models.wavenet import WaveNet
+
+logger = logging.getLogger(__name__)
+
+
+class ValidationLoss(Enum):
+ """
+ mse: mean squared error
+ esr: error signal ratio (Eq. (10) from
+ https://www.mdpi.com/2076-3417/10/3/766/htm
+ NOTE: Be careful when computing ESR on minibatches! The average ESR over
+ a minibatch of data is not the same as the ESR of the same data
+ calculated all at once (because of the denominator).
+ (Hint: think about what happens if one item in the minibatch is all
+ zeroes...)
+ """
+
+ MSE = "mse"
+ ESR = "esr"
+
+
+@dataclass
+class LossConfig(InitializableFromConfig):
+ """
+ :param mrstft_weight: Multi-resolution short-time Fourier transform loss
+ coefficient. None means to skip; 2e-4 works pretty well if one wants to use it.
+ :param mask_first: How many of the first samples to ignore when computing the loss.
+ :param dc_weight: Weight for the DC loss term. If 0, ignored.
+ :param val_loss: Which loss to track for the best model checkpoint.
+ :param pre_emph_coef: Coefficient of 1st-order pre-emphasis filter from
+ https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
+ :param pre_
+ """
+
+ mrstft_weight: Optional[float] = None
+ fourier: bool = False
+ mask_first: int = 0
+ dc_weight: float = None
+ val_loss: ValidationLoss = ValidationLoss.MSE
+ pre_emph_weight: Optional[float] = None
+ pre_emph_coef: Optional[float] = None
+ pre_emph_mrstft_weight: Optional[float] = None
+ pre_emph_mrstft_coef: Optional[float] = None
+
+ @classmethod
+ def parse_config(cls, config):
+ config = super().parse_config(config)
+ return {
+ "fourier": config.get("fourier", False),
+ "mask_first": config.get("mask_first", 0),
+ "dc_weight": config.get("dc_weight"),
+ "val_loss": ValidationLoss(config.get("val_loss", "mse")),
+ "pre_emph_coef": config.get("pre_emph_coef"),
+ "pre_emph_weight": config.get("pre_emph_weight"),
+ "mrstft_weight": cls._get_mrstft_weight(config),
+ "pre_emph_mrstft_weight": config.get("pre_emph_mrstft_weight"),
+ "pre_emph_mrstft_coef": config.get("pre_emph_mrstft_coef"),
+ }
+
+ def apply_mask(self, *args):
+ """
+ :param args: (L,) or (B,)
+ :return: (L-M,) or (B, L-M)
+ """
+ return tuple(a[..., self.mask_first :] for a in args)
+
+ @classmethod
+ def _get_mrstft_weight(cls, config) -> Optional[float]:
+ key = "mrstft_weight"
+ wrong_key = "mstft_weight" # Backward compatibility
+ if key in config:
+ if "mstft_weight" in config:
+ raise ValueError(
+ f"Received loss configuration with both '{key}' and "
+ f"'{wrong_key}'. Provide only '{key}'."
+ )
+ return config[key]
+ elif wrong_key in config:
+ logger.warning(
+ f"Use of '{wrong_key}' is deprecated and will be removed in a future "
+ f"version. Use '{key}' instead."
+ )
+ return config[wrong_key]
+ else:
+ return None
+
+
+class _LossItem(NamedTuple):
+ weight: Optional[float]
+ value: Optional[torch.Tensor]
+
+
+_model_net_init_registry = {
+ "ConvNet": ConvNet.init_from_config,
+ "Linear": Linear.init_from_config,
+ "LSTM": LSTM.init_from_config,
+ "WaveNet": WaveNet.init_from_config,
+}
+
+
+class LightningModule(pl.LightningModule, InitializableFromConfig):
+ """
+ The PyTorch Lightning Module that unites the model with its loss and
+ optimization recipe.
+ """
+
+ def __init__(
+ self,
+ net,
+ optimizer_config: Optional[dict] = None,
+ scheduler_config: Optional[dict] = None,
+ loss_config: Optional[LossConfig] = None,
+ ):
+ """
+ :param scheduler_config: contains
+ Required:
+ * "class"
+ * "kwargs"
+ Optional (defaults to Lightning defaults):
+ * "interval" ("epoch" or "step")
+ * "frequency" (int)
+ * "monitor" (str)
+ """
+ super().__init__()
+ self._net = net
+ self._optimizer_config = {} if optimizer_config is None else optimizer_config
+ self._scheduler_config = scheduler_config
+ self._loss_config = LossConfig() if loss_config is None else loss_config
+ self._mrstft = None # Multi-resolution short-time Fourier transform loss
+ # Where to compute the MRSTFT.
+ # Keeping it on-device is preferable, but if that fails, then remember to drop
+ # it to cpu from then on.
+ self._mrstft_device: Optional[torch.device] = None
+
+ @classmethod
+ def init_from_config(cls, config):
+ checkpoint_path = config.get("checkpoint_path")
+ config = cls.parse_config(config)
+ return (
+ cls(**config)
+ if checkpoint_path is None
+ else cls.load_from_checkpoint(checkpoint_path, **config)
+ )
+
+ @classmethod
+ def parse_config(cls, config):
+ """
+ e.g.
+
+ {
+ "net": {
+ "name": "ConvNet",
+ "config": {...}
+ },
+ "loss": {
+ "dc_weight": 0.1
+ },
+ "optimizer": {
+ "lr": 0.0003
+ },
+ "lr_scheduler": {
+ "class": "ReduceLROnPlateau",
+ "kwargs": {
+ "factor": 0.8,
+ "patience": 10,
+ "cooldown": 15,
+ "min_lr": 1e-06,
+ "verbose": true
+ },
+ "monitor": "val_loss"
+ }
+ }
+ """
+ config = super().parse_config(config)
+ net_config = config["net"]
+ net = _model_net_init_registry[net_config["name"]](net_config["config"])
+ loss_config = LossConfig.init_from_config(config.get("loss", {}))
+ return {
+ "net": net,
+ "optimizer_config": config["optimizer"],
+ "scheduler_config": config["lr_scheduler"],
+ "loss_config": loss_config,
+ }
+
+ @classmethod
+ def register_net_initializer(cls, name, constructor, overwrite: bool = False):
+ if name in _model_net_init_registry and not overwrite:
+ raise KeyError(
+ f"A constructor for net name '{name}' is already registered!"
+ )
+ _model_net_init_registry[name] = constructor
+
+ @property
+ def net(self) -> nn.Module:
+ return self._net
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), **self._optimizer_config)
+ if self._scheduler_config is None:
+ return optimizer
+ else:
+ lr_scheduler = getattr(
+ torch.optim.lr_scheduler, self._scheduler_config["class"]
+ )(optimizer, **self._scheduler_config["kwargs"])
+ lr_scheduler_config = {"scheduler": lr_scheduler}
+ for key in ("interval", "frequency", "monitor"):
+ if key in self._scheduler_config:
+ lr_scheduler_config[key] = self._scheduler_config[key]
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+ def forward(self, *args, **kwargs):
+ return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.
+
+ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+ # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
+ self.net.sample_rate = checkpoint["sample_rate"]
+
+ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+ # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
+ checkpoint["sample_rate"] = self.net.sample_rate
+
+ def _shared_step(
+ self, batch
+ ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
+ """
+ B: Batch size
+ L: Sequence length
+
+ :return: (B,L), (B,L)
+ """
+ args, targets = batch[:-1], batch[-1]
+ preds = self(*args, pad_start=False)
+
+ # Compute all relevant losses.
+ loss_dict = {} # Mind keys versus validation loss requested...
+ # Prediction aka MSE loss
+ if self._loss_config.fourier:
+ loss_dict["MSE_FFT"] = _LossItem(1.0, mse_fft(preds, targets))
+ else:
+ loss_dict["MSE"] = _LossItem(1.0, self._mse_loss(preds, targets))
+ # Pre-emphasized MSE
+ if self._loss_config.pre_emph_weight is not None:
+ if (self._loss_config.pre_emph_coef is None) != (
+ self._loss_config.pre_emph_weight is None
+ ):
+ raise ValueError("Invalid pre-emph")
+ loss_dict["Pre-emphasized MSE"] = _LossItem(
+ self._loss_config.pre_emph_weight,
+ self._mse_loss(
+ preds, targets, pre_emph_coef=self._loss_config.pre_emph_coef
+ ),
+ )
+ # Multi-resolution short-time Fourier transform loss
+ if self._loss_config.mrstft_weight is not None:
+ loss_dict["MRSTFT"] = _LossItem(
+ self._loss_config.mrstft_weight, self._mrstft_loss(preds, targets)
+ )
+ # Pre-emphasized MRSTFT
+ if self._loss_config.pre_emph_mrstft_weight is not None:
+ loss_dict["Pre-emphasized MRSTFT"] = _LossItem(
+ self._loss_config.pre_emph_mrstft_weight,
+ self._mrstft_loss(
+ preds, targets, pre_emph_coef=self._loss_config.pre_emph_mrstft_coef
+ ),
+ )
+ # DC loss
+ dc_weight = self._loss_config.dc_weight
+ if dc_weight is not None and dc_weight > 0.0:
+ # Denominator could be a bad idea. I'm going to omit it esp since I'm
+ # using mini batches
+ mean_dims = torch.arange(1, preds.ndim).tolist()
+ dc_loss = nn.MSELoss()(
+ preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
+ )
+ loss_dict["DC MSE"] = _LossItem(dc_weight, dc_loss)
+
+ return preds, targets, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ _, _, loss_dict = self._shared_step(batch)
+
+ loss = 0.0
+ for v in loss_dict.values():
+ if v.weight is not None and v.weight > 0.0:
+ loss = loss + v.weight * v.value
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ preds, targets, loss_dict = self._shared_step(batch)
+
+ def get_val_loss():
+ # "esr" -> "ESR"
+ # "mse" -> "MSE"
+ # Others unsupported...
+ # TODO better mapping from Enum to dict keys
+ val_loss_type = self._loss_config.val_loss
+ val_loss_key_for_loss_dict = val_loss_type.value.upper()
+ if val_loss_key_for_loss_dict in loss_dict:
+ return loss_dict[val_loss_key_for_loss_dict].value
+ else:
+ raise RuntimeError(
+ f"Undefined validation loss routine for {val_loss_type}"
+ )
+
+ loss_dict["ESR"] = _LossItem(None, self._esr_loss(preds, targets))
+ val_loss = get_val_loss()
+ self.log_dict(
+ {
+ "val_loss": val_loss,
+ **{key: value.value for key, value in loss_dict.items()},
+ }
+ )
+ return val_loss
+
+ def _esr_loss(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """
+ Error signal ratio aka ESR loss.
+
+ Eq. (10), from
+ https://www.mdpi.com/2076-3417/10/3/766/htm
+
+ B: Batch size
+ L: Sequence length
+
+ :param preds: (B,L)
+ :param targets: (B,L)
+ :return: ()
+ """
+ return esr(preds, targets)
+
+ def _mse_loss(self, preds, targets, pre_emph_coef: Optional[float] = None):
+ if pre_emph_coef is not None:
+ preds, targets = [
+ apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+ ]
+ return nn.MSELoss()(preds, targets)
+
+ def _mrstft_loss(
+ self,
+ preds: torch.Tensor,
+ targets: torch.Tensor,
+ pre_emph_coef: Optional[float] = None,
+ ) -> torch.Tensor:
+ """
+ Experimental Multi Resolution Short Time Fourier Transform Loss using the auraloss implementation.
+ B: Batch size
+ L: Sequence length
+
+ :param preds: (B,L)
+ :param targets: (B,L)
+ :return: ()
+ """
+ if self._mrstft is None:
+ self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
+ backup_device = "cpu"
+
+ if pre_emph_coef is not None:
+ preds, targets = [
+ apply_pre_emphasis_filter(z, pre_emph_coef) for z in (preds, targets)
+ ]
+
+ try:
+ return multi_resolution_stft_loss(
+ preds, targets, self._mrstft, device=self._mrstft_device
+ )
+ except Exception as e:
+ if self._mrstft_device == backup_device:
+ raise e
+ logger.warning("MRSTFT failed on device; falling back to CPU")
+ self._mrstft_device = backup_device
+ return multi_resolution_stft_loss(
+ preds, targets, self._mrstft, device=self._mrstft_device
+ )
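Besides init_from_config, the relocated wrapper can be wired up by hand, much as the new test file at the bottom of this diff does. A minimal sketch, assuming my_net is any concrete nam.models.base.BaseNet subclass instance (placeholder; no particular architecture implied):

    from nam.train.lightning_module import LightningModule, LossConfig

    def build_module(my_net):
        return LightningModule(
            my_net,
            optimizer_config={"lr": 3e-4},               # forwarded to torch.optim.Adam
            loss_config=LossConfig(mrstft_weight=2e-4),  # the docstring suggests 2e-4
        )

Custom architectures can also be made visible to parse_config via LightningModule.register_net_initializer(name, constructor), using the registry shown above.
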
diff --git a/setup.py b/setup.py
@@ -18,7 +18,7 @@ def get_additional_requirements():
additional_requirements.append("transformers>=4")
except ModuleNotFoundError:
pass
-
+
# Issue 494
def get_numpy_requirement() -> str:
need_numpy_1 = True # Until proven otherwise
@@ -34,9 +34,10 @@ def get_additional_requirements():
except ModuleNotFoundError:
# Until I see PyTorch 2.3 come out:
pass
- return "numpy<2"if need_numpy_1 else "numpy"
+ return "numpy<2" if need_numpy_1 else "numpy"
+
additional_requirements.append(get_numpy_requirement())
-
+
return additional_requirements
diff --git a/tests/test_nam/test_models/_convolutional.py b/tests/test_nam/test_models/_convolutional.py
@@ -13,7 +13,9 @@ from .base import Base as _Base
class Convolutional(_Base):
- @_pytest.mark.skipif(not _torch.backends.mps.is_available(), reason="MPS-specific test")
+ @_pytest.mark.skipif(
+ not _torch.backends.mps.is_available(), reason="MPS-specific test"
+ )
def test_process_input_longer_than_65536(self):
"""
Processing inputs longer than 65,536 samples using the MPS backend can
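This test exercises the _forward_mps_safe fallback defined in nam/models/base.py above. A worked sketch of the chunking arithmetic, with a made-up receptive field and input length (any values behave the same way):

    # Hypothetical numbers: receptive field R = 8_192, input length L = 200_000
    R, L, CHUNK = 8_192, 200_000, 65_536
    stride = CHUNK - (R - 1)  # 57_345; consecutive chunks overlap by R - 1 samples
    out_lengths = []
    for i in range(0, L, stride):
        j = min(i + CHUNK, L)
        out_lengths.append((j - i) - (R - 1))  # output length of _forward(x[:, i:j])
        if j == L:
            break
    assert sum(out_lengths) == L - (R - 1)  # same total as one unchunked _forward call
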
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -14,12 +14,11 @@ from typing import Optional
import numpy as np
import pytest
import torch
-from auraloss.freq import MultiResolutionSTFTLoss
-from nam.models import _base, base
+from nam.models import base
-class _MockBaseNet(_base.BaseNet):
+class MockBaseNet(base.BaseNet):
def __init__(self, gain: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gain = gain
@@ -46,14 +45,14 @@ class _MockBaseNet(_base.BaseNet):
def test_metadata_gain():
- obj = _MockBaseNet(1.0)
+ obj = MockBaseNet(1.0)
g = obj._metadata_gain()
# It's linear, so gain is zero.
assert g == 0.0
def test_metadata_loudness():
- obj = _MockBaseNet(1.0)
+ obj = MockBaseNet(1.0)
y = obj._metadata_loudness()
obj.gain = 2.0
y2 = obj._metadata_loudness()
@@ -62,55 +61,6 @@ def test_metadata_loudness():
assert y2 == pytest.approx(y + 20.0 * math.log10(2.0))
-@pytest.mark.parametrize(
- "batch_size,sequence_length", ((16, 8192), (3, 2048), (1, 4000))
-)
-def test_mrstft_loss(batch_size: int, sequence_length: int):
- obj = base.Model(
- _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002)
- )
- preds = torch.randn((batch_size, sequence_length))
- targets = torch.randn(preds.shape)
- loss = obj._mrstft_loss(preds, targets)
- assert isinstance(loss, torch.Tensor)
- assert loss.ndim == 0
-
-
-def test_mrstft_loss_cpu_fallback(mocker):
- """
- Assert that fallback to CPU happens on failure
-
- :param mocker: Provided by pytest-mock
- """
-
- def mocked_loss(
- preds: torch.Tensor,
- targets: torch.Tensor,
- loss_func: Optional[MultiResolutionSTFTLoss] = None,
- device: Optional[torch.device] = None,
- ) -> torch.Tensor:
- """
- As if the device doesn't support it
- """
- if device != "cpu":
- raise RuntimeError("Trigger fallback")
- return torch.tensor(1.0)
-
- mocker.patch("nam.models.base.multi_resolution_stft_loss", mocked_loss)
-
- batch_size = 3
- sequence_length = 4096
- obj = base.Model(
- _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002)
- )
- preds = torch.randn((batch_size, sequence_length))
- targets = torch.randn(preds.shape)
-
- assert obj._mrstft_device is None
- obj._mrstft_loss(preds, targets) # Should trigger fallback
- assert obj._mrstft_device == "cpu"
-
-
class TestSampleRate(object):
"""
Tests for sample_rate interface
@@ -118,12 +68,12 @@ class TestSampleRate(object):
@pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
def test_on_init(self, expected_sample_rate: Optional[float]):
- model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
+ model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
self._wrap_assert(model, expected_sample_rate)
@pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
def test_setter(self, expected_sample_rate: Optional[float]):
- model = _MockBaseNet(gain=1.0)
+ model = MockBaseNet(gain=1.0)
model.sample_rate = expected_sample_rate
self._wrap_assert(model, expected_sample_rate)
@@ -134,16 +84,16 @@ class TestSampleRate(object):
https://github.com/sdatkinson/neural-amp-modeler/issues/351
"""
- model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
+ model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
with TemporaryDirectory() as tmpdir:
model_path = Path(tmpdir, "model.pt")
torch.save(model.state_dict(), model_path)
- model2 = _MockBaseNet(gain=1.0)
+ model2 = MockBaseNet(gain=1.0)
model2.load_state_dict(torch.load(model_path))
self._wrap_assert(model2, expected_sample_rate)
@classmethod
- def _wrap_assert(cls, model: _MockBaseNet, expected: Optional[float]):
+ def _wrap_assert(cls, model: MockBaseNet, expected: Optional[float]):
actual = model.sample_rate
if expected is None:
assert actual is None
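The test_metadata_gain assertion above (g == 0.0 for the linear MockBaseNet) follows from the square/triangle construction commented in _metadata_gain in nam/models/base.py. A worked check under the assumption that loudness is exactly proportional to gain:

    import numpy as np

    # For a purely linear model, loudness(gain) = c * gain for some constant c.
    c = 0.123  # arbitrary positive constant; the result is independent of it
    x = np.linspace(0.0, 1.0, 11)
    y = c * x
    max_gain = y[-1] * len(x)  # "square": 11 * c
    min_gain = 0.5 * max_gain  # "triangle": 5.5 * c
    normalized = (y.sum() - min_gain) / (max_gain - min_gain)
    assert np.isclose(normalized, 0.0)  # hence test_metadata_gain's `assert g == 0.0`
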
diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py
@@ -24,7 +24,7 @@ class TestWaveNet(_Convolutional):
"head_size": 1,
"channels": 1,
"kernel_size": 1,
- "dilations": [1]
+ "dilations": [1],
}
]
}
@@ -46,7 +46,7 @@ class TestWaveNet(_Convolutional):
assert not torch.allclose(y2_before, y1)
assert torch.allclose(y2_after, y1)
-
+
if __name__ == "__main__":
pytest.main()
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -17,7 +17,7 @@ from nam.data import (
wav_to_tensor,
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
)
-from nam.models import Model
+from nam.train.lightning_module import LightningModule
from nam.train import core
from nam.train._version import Version
@@ -292,7 +292,7 @@ def test_end_to_end():
fast_dev_run=True,
)
# Assertions...
- assert isinstance(train_output.model, Model)
+ assert isinstance(train_output.model, LightningModule)
def test_get_callbacks():
@@ -305,19 +305,23 @@ def test_get_callbacks():
# dumb example of a user-extended custom callback
class CustomCallback:
pass
+
extended_callbacks = callbacks + [CustomCallback()]
# sanity default callbacks
- assert any(isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks), \
- "Expected _ModelCheckpoint to be part of the default callbacks."
+ assert any(
+ isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks
+ ), "Expected _ModelCheckpoint to be part of the default callbacks."
# custom callback
- assert any(isinstance(cb, CustomCallback) for cb in extended_callbacks), \
- "Expected CustomCallback to be added to the extended callbacks."
+ assert any(
+ isinstance(cb, CustomCallback) for cb in extended_callbacks
+ ), "Expected CustomCallback to be added to the extended callbacks."
# _ValidationStopping cb when threshold_esr is provided
- assert any(isinstance(cb, core._ValidationStopping) for cb in extended_callbacks), \
- "_ValidationStopping should still be present after adding a custom callback."
+ assert any(
+ isinstance(cb, core._ValidationStopping) for cb in extended_callbacks
+ ), "_ValidationStopping should still be present after adding a custom callback."
if __name__ == "__main__":
diff --git a/tests/test_nam/test_train/test_lightning_module.py b/tests/test_nam/test_train/test_lightning_module.py
@@ -0,0 +1,68 @@
+# File: test_lightning_module.py
+# Created Date: Sunday November 24th 2024
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+from typing import Optional as _Optional
+
+import pytest as _pytest
+import torch as _torch
+from auraloss.freq import MultiResolutionSTFTLoss as _MultiResolutionSTFTLoss
+
+from nam.train import lightning_module as _lightning_module
+
+from ..test_models.test_base import MockBaseNet as _MockBaseNet
+
+
+@_pytest.mark.parametrize(
+ "batch_size,sequence_length", ((16, 8192), (3, 2048), (1, 4000))
+)
+def test_mrstft_loss(batch_size: int, sequence_length: int):
+ obj = _lightning_module.LightningModule(
+ _MockBaseNet(1.0),
+ loss_config=_lightning_module.LossConfig(mrstft_weight=0.0002),
+ )
+ preds = _torch.randn((batch_size, sequence_length))
+ targets = _torch.randn(preds.shape)
+ loss = obj._mrstft_loss(preds, targets)
+ assert isinstance(loss, _torch.Tensor)
+ assert loss.ndim == 0
+
+
+def test_mrstft_loss_cpu_fallback(mocker):
+ """
+ Assert that fallback to CPU happens on failure
+
+ :param mocker: Provided by pytest-mock
+ """
+
+ def mocked_loss(
+ preds: _torch.Tensor,
+ targets: _torch.Tensor,
+ loss_func: _Optional[_MultiResolutionSTFTLoss] = None,
+ device: _Optional[_torch.device] = None,
+ ) -> _torch.Tensor:
+ """
+ As if the device doesn't support it
+ """
+ if device != "cpu":
+ raise RuntimeError("Trigger fallback")
+ return _torch.tensor(1.0)
+
+ mocker.patch("nam.train.lightning_module.multi_resolution_stft_loss", mocked_loss)
+
+ batch_size = 3
+ sequence_length = 4096
+ obj = _lightning_module.LightningModule(
+ _MockBaseNet(1.0),
+ loss_config=_lightning_module.LossConfig(mrstft_weight=0.0002),
+ )
+ preds = _torch.randn((batch_size, sequence_length))
+ targets = _torch.randn(preds.shape)
+
+ assert obj._mrstft_device is None
+ obj._mrstft_loss(preds, targets) # Should trigger fallback
+ assert obj._mrstft_device == "cpu"
+
+
+if __name__ == "__main__":
+ _pytest.main()