neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 1d354a00dc75719e49624cb5af296ee66d5e5155
parent 75a1b3442d1b0bab2e74f88ee8fe6451866a93c7
Author: Steven Atkinson <[email protected]>
Date:   Fri, 31 Mar 2023 00:28:57 -0500

Run multi-resolution STFT loss on-device unless it fails (#165)

* MRSTFT on device

* MRSTFT loss device fallback
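
In short: the MRSTFT loss is now computed on whatever device the tensors already live on; if that raises (not every backend supports the underlying STFT ops), the model logs a warning, retries on the CPU, and remembers the CPU as the device for all later calls. A minimal standalone sketch of that fallback pattern follows; the names compute_loss and _fallback_device are illustrative stand-ins, not identifiers from the codebase:

    import logging
    from typing import Optional

    import torch

    logger = logging.getLogger(__name__)
    _fallback_device: Optional[str] = None  # None means "stay on the tensors' own device"


    def compute_loss(
        preds: torch.Tensor, targets: torch.Tensor, device: Optional[str] = None
    ) -> torch.Tensor:
        # Hypothetical stand-in for the auraloss MultiResolutionSTFTLoss call in the
        # diff below; moves the tensors only when a device override is given.
        if device is not None:
            preds, targets = preds.to(device), targets.to(device)
        return torch.nn.functional.mse_loss(preds, targets)


    def loss_with_fallback(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        global _fallback_device
        try:
            return compute_loss(preds, targets, device=_fallback_device)
        except Exception:
            if _fallback_device == "cpu":
                raise  # already on the backup device; nothing left to try
            logger.warning("Loss failed on device; falling back to CPU")
            _fallback_device = "cpu"  # cache the failure so later calls go straight to CPU
            return compute_loss(preds, targets, device=_fallback_device)

The one-time caching is the point: the failed on-device attempt is paid once per process, not once per training step.
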
Diffstat:
M environment_cpu.yml                      |   2 ++
M environment_gpu.yml                      |   2 ++
M nam/models/base.py                       |  37 +++++++++++++++++++++++++++----------
M nam/models/losses.py                     |  27 +++++++++++++++++++++++++++
M requirements.txt                         |   1 +
M tests/test_nam/test_models/test_base.py  |  37 ++++++++++++++++++++++++++++++++++++-
6 files changed, 95 insertions(+), 11 deletions(-)

diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -4,6 +4,7 @@
 name: nam
 channels:
+  - conda-forge  # pytest-mock
   - pytorch
 dependencies:
   - black
@@ -14,6 +15,7 @@ dependencies:
   - numpy
   - pip
   - pytest
+  - pytest-mock
   - pytorch
   - scipy
   - semver
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -4,6 +4,7 @@
 name: nam
 channels:
+  - conda-forge  # pytest-mock
   - pytorch
   - nvidia
 dependencies:
@@ -15,6 +16,7 @@ dependencies:
   - numpy
   - pip
   - pytest
+  - pytest-mock
   - pytorch
   - pytorch-cuda=11.7
   - scipy
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -15,6 +15,7 @@ from enum import Enum
 from typing import Optional, Tuple

 import auraloss
+import logging
 import pytorch_lightning as pl
 import torch
 import torch.nn as nn
@@ -22,12 +23,14 @@ import torch.nn as nn
 from .._core import InitializableFromConfig
 from .conv_net import ConvNet
 from .linear import Linear
-from .losses import esr, mse_fft
+from .losses import esr, multi_resolution_stft_loss, mse_fft
 from .parametric.catnets import CatLSTM, CatWaveNet
 from .parametric.hyper_net import HyperConvNet
 from .recurrent import LSTM
 from .wavenet import WaveNet

+logger = logging.getLogger(__name__)
+

 class ValidationLoss(Enum):
     """
@@ -55,7 +58,7 @@ class LossConfig(InitializableFromConfig):
     https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
     """

-    mstft_weight: float = 0.0  # 0.0 means no multiresolution stft loss, 2e-4 works pretty well if one wants to use it
+    mrstft_weight: float = 0.0  # 0.0 means no multiresolution stft loss, 2e-4 works pretty well if one wants to use it
     fourier: bool = False
     mask_first: int = 0
     dc_weight: float = 0.0
@@ -72,6 +75,7 @@ class LossConfig(InitializableFromConfig):
         mask_first = config.get("mask_first", 0)
         pre_emph_coef = config.get("pre_emph_coef")
         pre_emph_weight = config.get("pre_emph_weight")
+        mrstft_weight = config.get("mstft_weight", 0.0)
         return {
             "fourier": fourier,
             "mask_first": mask_first,
@@ -79,6 +83,7 @@ class LossConfig(InitializableFromConfig):
             "val_loss": val_loss,
             "pre_emph_coef": pre_emph_coef,
             "pre_emph_weight": pre_emph_weight,
+            "mrstft_weight": mrstft_weight,
         }

     def apply_mask(self, *args):
@@ -113,6 +118,10 @@ class Model(pl.LightningModule, InitializableFromConfig):
         self._scheduler_config = scheduler_config
         self._loss_config = LossConfig() if loss_config is None else loss_config
         self._mrstft = None  # Multi-resolution short-time Fourier transform loss
+        # Where to compute the MRSTFT.
+        # Keeping it on-device is preferable, but if that fails, then remember to drop
+        # it to cpu from then on.
+        self._mrstft_device: Optional[torch.device] = None

     @classmethod
     def init_from_config(cls, config):
@@ -214,8 +223,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
             loss = loss + mse_fft(preds, targets)
         else:
             loss = loss + self._mse_loss(preds, targets)
-        if self._loss_config.mstft_weight > 0.0:
-            loss = loss + self._loss_config.mstft_weight * self._mrstft_loss(
+        if self._loss_config.mrstft_weight > 0.0:
+            loss = loss + self._loss_config.mrstft_weight * self._mrstft_loss(
                 preds, targets
             )
         # Pre-emphasized MSE
@@ -248,7 +257,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
             self._loss_config.val_loss
         ]
         dict_to_log = {"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss}
-        if self._loss_config.mstft_weight > 0.0 and self._mrstft is not None:
+        if self._loss_config.mrstft_weight > 0.0 and self._mrstft is not None:
             mrstft_loss = self._mrstft_loss(preds, targets)
             dict_to_log.update({"MRSTFT": mrstft_loss})
         self.log_dict(dict_to_log)
@@ -290,9 +299,17 @@ class Model(pl.LightningModule, InitializableFromConfig):
         if self._mrstft is None:
             self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()

-        device = "cpu"  # not all platforms support this on gpu yet
-        preds_cpu = preds.to(device)
-        targets_cpu = targets.to(device)
+        backup_device = "cpu"

-        loss = self._mrstft(preds_cpu, targets_cpu)
-        return loss
+        try:
+            return multi_resolution_stft_loss(
+                preds, targets, self._mrstft, device=self._mrstft_device
+            )
+        except Exception as e:
+            if self._mrstft_device == backup_device:
+                raise e
+            logger.warning("MRSTFT failed on device; falling back to CPU")
+            self._mrstft_device = backup_device
+            return multi_resolution_stft_loss(
+                preds, targets, self._mrstft, device=self._mrstft_device
+            )
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -6,7 +6,10 @@ Loss functions
 """

+from typing import Optional
+
 import torch
+from auraloss.freq import MultiResolutionSTFTLoss


 def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -33,6 +36,30 @@ def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     )


+def multi_resolution_stft_loss(
+    preds: torch.Tensor,
+    targets: torch.Tensor,
+    loss_func: Optional[MultiResolutionSTFTLoss] = None,
+    device: Optional[torch.device] = None,
+) -> torch.Tensor:
+    """
+    Experimental Multi Resolution Short Time Fourier Transform Loss using auraloss implementation.
+    B: Batch size
+    L: Sequence length
+
+    :param preds: (B,L)
+    :param targets: (B,L)
+    :param loss_func: A pre-initialized instance of the loss function module. Providing
+        this saves time.
+    :param device: If provided, send the preds and targets to the provided device.
+    :return: ()
+    """
+    loss_func = MultiResolutionSTFTLoss() if loss_func is None else loss_func
+    if device is not None:
+        preds, targets = [z.to(device) for z in (preds, targets)]
+    return loss_func(preds, targets)
+
+
 def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
     """
     Fourier loss
diff --git a/requirements.txt b/requirements.txt
@@ -12,6 +12,7 @@ onnxruntime
 pip
 pre-commit
 pytest
+pytest-mock
 pytorch_lightning
 scipy
 sounddevice
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -4,10 +4,12 @@
 import math
 from pathlib import Path
+from typing import Optional

 import numpy as np
 import pytest
 import torch
+from auraloss.freq import MultiResolutionSTFTLoss

 from nam.models import _base, base
@@ -53,7 +55,7 @@ def test_loudness():
 )
 def test_mrstft_loss(batch_size: int, sequence_length: int):
     obj = base.Model(
-        _MockBaseNet(1.0), loss_config=base.LossConfig(mstft_weight=0.0002)
+        _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002)
    )
     preds = torch.randn((batch_size, sequence_length))
     targets = torch.randn(preds.shape)
@@ -62,5 +64,38 @@


+def test_mrstft_loss_cpu_fallback(mocker):
+    """
+    Assert that fallback to CPU happens on failure
+    """
+
+    def mocked_loss(
+        preds: torch.Tensor,
+        targets: torch.Tensor,
+        loss_func: Optional[MultiResolutionSTFTLoss] = None,
+        device: Optional[torch.device] = None,
+    ) -> torch.Tensor:
+        """
+        As if the device doesn't support it
+        """
+        if device != "cpu":
+            raise RuntimeError("Trigger fallback")
+        return torch.tensor(1.0)
+
+    mocker.patch("nam.models.base.multi_resolution_stft_loss", mocked_loss)
+
+    batch_size = 3
+    sequence_length = 4096
+    obj = base.Model(
+        _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002)
+    )
+    preds = torch.randn((batch_size, sequence_length))
+    targets = torch.randn(preds.shape)
+
+    assert obj._mrstft_device is None
+    obj._mrstft_loss(preds, targets)  # Should trigger fallback
+    assert obj._mrstft_device == "cpu"
+
+
 if __name__ == "__main__":
     pytest.main()
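
For reference, a quick usage sketch of the new standalone multi_resolution_stft_loss from nam/models/losses.py, assuming auraloss is installed; the tensor shapes follow the docstring's (B, L) convention:

    import torch
    from auraloss.freq import MultiResolutionSTFTLoss

    from nam.models.losses import multi_resolution_stft_loss

    preds = torch.randn((3, 4096))  # (B, L)
    targets = torch.randn(preds.shape)

    # Passing a pre-initialized loss module avoids rebuilding the STFT setup on
    # every call, which is why Model caches one in self._mrstft.
    loss_func = MultiResolutionSTFTLoss()
    loss = multi_resolution_stft_loss(preds, targets, loss_func, device="cpu")
    assert loss.ndim == 0  # scalar, matching the ":return: ()" in the docstring
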