commit 1d354a00dc75719e49624cb5af296ee66d5e5155
parent 75a1b3442d1b0bab2e74f88ee8fe6451866a93c7
Author: Steven Atkinson <[email protected]>
Date: Fri, 31 Mar 2023 00:28:57 -0500
Run multi-resolution STFT loss on-device unless it fails (#165)
* MRSTFT on device
* MRSTFT loss device fallback
Diffstat:
6 files changed, 95 insertions(+), 11 deletions(-)
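The gist of the change: evaluate auraloss's MultiResolutionSTFTLoss on whatever device the tensors already live on, and only pin the computation to the CPU after it has actually failed there once. A minimal standalone sketch of the pattern (the class name and structure are illustrative, not the project's code):

    import torch
    from auraloss.freq import MultiResolutionSTFTLoss

    class FallbackMRSTFT:
        """Try MRSTFT on the tensors' device; after the first failure, stay on CPU."""

        def __init__(self):
            self._loss = MultiResolutionSTFTLoss()
            self._device = None  # None until a failure forces "cpu"

        def __call__(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            try:
                if self._device is not None:
                    preds, targets = preds.to(self._device), targets.to(self._device)
                return self._loss(preds, targets)
            except Exception:
                if self._device == "cpu":
                    raise  # Already on CPU; nothing left to fall back to.
                self._device = "cpu"  # Remember the failure; compute on CPU from now on.
                return self._loss(preds.cpu(), targets.cpu())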
diff --git a/environment_cpu.yml b/environment_cpu.yml
@@ -4,6 +4,7 @@
name: nam
channels:
+ - conda-forge # pytest-mock
- pytorch
dependencies:
- black
@@ -14,6 +15,7 @@ dependencies:
- numpy
- pip
- pytest
+ - pytest-mock
- pytorch
- scipy
- semver
diff --git a/environment_gpu.yml b/environment_gpu.yml
@@ -4,6 +4,7 @@
name: nam
channels:
+ - conda-forge # pytest-mock
- pytorch
- nvidia
dependencies:
@@ -15,6 +16,7 @@ dependencies:
- numpy
- pip
- pytest
+ - pytest-mock
- pytorch
- pytorch-cuda=11.7
- scipy
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -15,6 +15,7 @@ from enum import Enum
from typing import Optional, Tuple
import auraloss
+import logging
import pytorch_lightning as pl
import torch
import torch.nn as nn
@@ -22,12 +23,14 @@ import torch.nn as nn
from .._core import InitializableFromConfig
from .conv_net import ConvNet
from .linear import Linear
-from .losses import esr, mse_fft
+from .losses import esr, multi_resolution_stft_loss, mse_fft
from .parametric.catnets import CatLSTM, CatWaveNet
from .parametric.hyper_net import HyperConvNet
from .recurrent import LSTM
from .wavenet import WaveNet
+logger = logging.getLogger(__name__)
+
class ValidationLoss(Enum):
"""
@@ -55,7 +58,7 @@ class LossConfig(InitializableFromConfig):
https://www.mdpi.com/2076-3417/10/3/766. Paper value: 0.95.
"""
- mstft_weight: float = 0.0 # 0.0 means no multiresolution stft loss, 2e-4 works pretty well if one wants to use it
+ mrstft_weight: float = 0.0 # 0.0 means no multi-resolution STFT loss; 2e-4 works pretty well if one wants to use it
fourier: bool = False
mask_first: int = 0
dc_weight: float = 0.0
@@ -72,6 +75,7 @@ class LossConfig(InitializableFromConfig):
mask_first = config.get("mask_first", 0)
pre_emph_coef = config.get("pre_emph_coef")
pre_emph_weight = config.get("pre_emph_weight")
+ mrstft_weight = config.get("mstft_weight", 0.0)  # Note: reads the pre-rename "mstft_weight" config key
return {
"fourier": fourier,
"mask_first": mask_first,
@@ -79,6 +83,7 @@ class LossConfig(InitializableFromConfig):
"val_loss": val_loss,
"pre_emph_coef": pre_emph_coef,
"pre_emph_weight": pre_emph_weight,
+ "mrstft_weight": mrstft_weight,
}
def apply_mask(self, *args):
@@ -113,6 +118,10 @@ class Model(pl.LightningModule, InitializableFromConfig):
self._scheduler_config = scheduler_config
self._loss_config = LossConfig() if loss_config is None else loss_config
self._mrstft = None # Multi-resolution short-time Fourier transform loss
+ # Where to compute the MRSTFT loss.
+ # Keeping it on-device is preferable, but if that fails, fall back to the CPU
+ # and remember that choice from then on.
+ self._mrstft_device: Optional[torch.device] = None
@classmethod
def init_from_config(cls, config):
@@ -214,8 +223,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
loss = loss + mse_fft(preds, targets)
else:
loss = loss + self._mse_loss(preds, targets)
- if self._loss_config.mstft_weight > 0.0:
- loss = loss + self._loss_config.mstft_weight * self._mrstft_loss(
+ if self._loss_config.mrstft_weight > 0.0:
+ loss = loss + self._loss_config.mrstft_weight * self._mrstft_loss(
preds, targets
)
# Pre-emphasized MSE
@@ -248,7 +257,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
self._loss_config.val_loss
]
dict_to_log = {"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss}
- if self._loss_config.mstft_weight > 0.0 and self._mrstft is not None:
+ if self._loss_config.mrstft_weight > 0.0 and self._mrstft is not None:
mrstft_loss = self._mrstft_loss(preds, targets)
dict_to_log.update({"MRSTFT": mrstft_loss})
self.log_dict(dict_to_log)
@@ -290,9 +299,17 @@ class Model(pl.LightningModule, InitializableFromConfig):
if self._mrstft is None:
self._mrstft = auraloss.freq.MultiResolutionSTFTLoss()
- device = "cpu" # not all platforms support this on gpu yet
- preds_cpu = preds.to(device)
- targets_cpu = targets.to(device)
+ backup_device = "cpu"
- loss = self._mrstft(preds_cpu, targets_cpu)
- return loss
+ try:
+ return multi_resolution_stft_loss(
+ preds, targets, self._mrstft, device=self._mrstft_device
+ )
+ except Exception:
+ if self._mrstft_device == backup_device:
+ raise
+ logger.warning("MRSTFT failed on device; falling back to CPU")
+ self._mrstft_device = backup_device
+ return multi_resolution_stft_loss(
+ preds, targets, self._mrstft, device=self._mrstft_device
+ )
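The fallback in the Lightning model is stateful: the first call tries the MRSTFT on the current device; if that raises, the warning above is logged, self._mrstft_device is set to "cpu", and every later call goes straight to the CPU path. Enabling the loss mirrors what the new test below does; a small sketch (the commented line is a placeholder, since constructing a real net is out of scope here):

    from nam.models import base

    loss_config = base.LossConfig(mrstft_weight=2e-4)  # 0.0 (the default) disables the MRSTFT term
    # model = base.Model(net, loss_config=loss_config)  # net: any NAM network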
diff --git a/nam/models/losses.py b/nam/models/losses.py
@@ -6,7 +6,10 @@
Loss functions
"""
+from typing import Optional
+
import torch
+from auraloss.freq import MultiResolutionSTFTLoss
def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
@@ -33,6 +36,30 @@ def esr(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
)
+def multi_resolution_stft_loss(
+ preds: torch.Tensor,
+ targets: torch.Tensor,
+ loss_func: Optional[MultiResolutionSTFTLoss] = None,
+ device: Optional[torch.device] = None,
+) -> torch.Tensor:
+ """
+ Experimental multi-resolution short-time Fourier transform (MRSTFT) loss using the auraloss implementation.
+ B: Batch size
+ L: Sequence length
+
+ :param preds: (B,L)
+ :param targets: (B,L)
+ :param loss_func: A pre-initialized instance of the loss function module. Providing
+ this avoids re-initializing the module on every call.
+ :param device: If provided, move preds and targets to this device before computing the loss.
+ :return: ()
+ """
+ loss_func = MultiResolutionSTFTLoss() if loss_func is None else loss_func
+ if device is not None:
+ preds, targets = [z.to(device) for z in (preds, targets)]
+ return loss_func(preds, targets)
+
+
def mse_fft(preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
"""
Fourier loss
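The new helper can also be called on its own; a quick usage sketch based on its signature (shapes follow the docstring; omitting loss_func constructs the auraloss module internally):

    import torch
    from nam.models.losses import multi_resolution_stft_loss

    preds, targets = torch.randn(3, 4096), torch.randn(3, 4096)
    loss = multi_resolution_stft_loss(preds, targets)  # scalar tensor, shape ()
    loss_cpu = multi_resolution_stft_loss(preds, targets, device="cpu")  # force CPU evaluation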
diff --git a/requirements.txt b/requirements.txt
@@ -12,6 +12,7 @@ onnxruntime
pip
pre-commit
pytest
+pytest-mock
pytorch_lightning
scipy
sounddevice
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -4,10 +4,12 @@
import math
from pathlib import Path
+from typing import Optional
import numpy as np
import pytest
import torch
+from auraloss.freq import MultiResolutionSTFTLoss
from nam.models import _base, base
@@ -53,7 +55,7 @@ def test_loudness():
)
def test_mrstft_loss(batch_size: int, sequence_length: int):
obj = base.Model(
- _MockBaseNet(1.0), loss_config=base.LossConfig(mstft_weight=0.0002)
+ _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002)
)
preds = torch.randn((batch_size, sequence_length))
targets = torch.randn(preds.shape)
@@ -62,5 +64,38 @@ def test_mrstft_loss(batch_size: int, sequence_length: int):
assert loss.ndim == 0
+def test_mrstft_loss_cpu_fallback(mocker):
+ """
+ Assert that fallback to CPU happens on failure
+ """
+
+ def mocked_loss(
+ preds: torch.Tensor,
+ targets: torch.Tensor,
+ loss_func: Optional[MultiResolutionSTFTLoss] = None,
+ device: Optional[torch.device] = None,
+ ) -> torch.Tensor:
+ """
+ As if the device doesn't support it
+ """
+ if device != "cpu":
+ raise RuntimeError("Trigger fallback")
+ return torch.tensor(1.0)
+
+ mocker.patch("nam.models.base.multi_resolution_stft_loss", mocked_loss)
+
+ batch_size = 3
+ sequence_length = 4096
+ obj = base.Model(
+ _MockBaseNet(1.0), loss_config=base.LossConfig(mrstft_weight=0.0002)
+ )
+ preds = torch.randn((batch_size, sequence_length))
+ targets = torch.randn(preds.shape)
+
+ assert obj._mrstft_device is None
+ obj._mrstft_loss(preds, targets) # Should trigger fallback
+ assert obj._mrstft_device == "cpu"
+
+
if __name__ == "__main__":
pytest.main()
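The fallback test uses the mocker fixture, which is why pytest-mock is added to the environment files and requirements.txt above. Running just the MRSTFT tests should look something like (exact invocation depends on your setup):

    pytest tests/test_nam/test_models/test_base.py -k mrstft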