neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit faa1ead1d80bd664cc42305ac2b8bcf975d8e685
parent 840f10e4d2fdeae5ce51fb8c9586935a1c495e59
Author: Steven Atkinson <[email protected]>
Date:   Mon, 12 Jun 2023 19:34:07 -0700

Faster recurrent testing with shorter inputs (#275)


Diffstat:
Mnam/models/_base.py | 14+++++++++-----
Mnam/models/recurrent.py | 7++++++-
Mtests/test_nam/test_models/test_recurrent.py | 13++++++++++++-
3 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -37,6 +37,14 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): def forward(self, *args, **kwargs) -> torch.Tensor: pass + @classmethod + def _metadata_loudness_x(cls) -> torch.Tensor: + return wav_to_tensor( + pkg_resources.resource_filename( + "nam", "models/_resources/loudness_input.wav" + ) + ) + def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float: """ How loud is this model when given a standardized input? @@ -44,11 +52,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): :param gain: Multiplies input signal """ - x = wav_to_tensor( - pkg_resources.resource_filename( - "nam", "models/_resources/loudness_input.wav" - ) - ) + x = self._metadata_loudness_x() y = self._at_nominal_settings(gain * x) loudness = torch.sqrt(torch.mean(torch.square(y))) if db: diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py @@ -158,6 +158,7 @@ class LSTM(BaseNet): self._initial_hidden = nn.Parameter( torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size)) ) + self._get_initial_state_burn_in = 48_000 @property def receptive_field(self) -> int: @@ -340,7 +341,11 @@ class LSTM(BaseNet): :return: (L,DH), (L,DH) """ - inputs = torch.zeros((1, 48_000, 1)) if inputs is None else inputs + inputs = ( + torch.zeros((1, self._get_initial_state_burn_in, 1)) + if inputs is None + else inputs + ) _, (h, c) = self._core(inputs) return h, c diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py @@ -14,14 +14,25 @@ from nam.models import recurrent from .base import Base +_metadata_loudness_x_mocked = 0.1 * torch.randn((11,)) # Shorter for speed + class TestLSTM(Base): @classmethod def setup_class(cls): + class LSTMWithMocks(recurrent.LSTM): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._get_initial_state_burn_in = 7 + + @classmethod + def _metadata_loudness_x(cls) -> torch.Tensor: + return _metadata_loudness_x_mocked + num_layers = 2 hidden_size = 3 super().setup_class( - recurrent.LSTM, + LSTMWithMocks, args=(hidden_size,), kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": num_layers}, )