commit faa1ead1d80bd664cc42305ac2b8bcf975d8e685
parent 840f10e4d2fdeae5ce51fb8c9586935a1c495e59
Author: Steven Atkinson <[email protected]>
Date: Mon, 12 Jun 2023 19:34:07 -0700
Faster recurrent testing with shorter inputs (#275)
Diffstat:
3 files changed, 27 insertions(+), 7 deletions(-)
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -37,6 +37,14 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
def forward(self, *args, **kwargs) -> torch.Tensor:
pass
+ @classmethod
+ def _metadata_loudness_x(cls) -> torch.Tensor:
+ return wav_to_tensor(
+ pkg_resources.resource_filename(
+ "nam", "models/_resources/loudness_input.wav"
+ )
+ )
+
def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
"""
How loud is this model when given a standardized input?
@@ -44,11 +52,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
:param gain: Multiplies input signal
"""
- x = wav_to_tensor(
- pkg_resources.resource_filename(
- "nam", "models/_resources/loudness_input.wav"
- )
- )
+ x = self._metadata_loudness_x()
y = self._at_nominal_settings(gain * x)
loudness = torch.sqrt(torch.mean(torch.square(y)))
if db:
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -158,6 +158,7 @@ class LSTM(BaseNet):
self._initial_hidden = nn.Parameter(
torch.zeros((lstm_kwargs.get("num_layers", 1), hidden_size))
)
+ self._get_initial_state_burn_in = 48_000
@property
def receptive_field(self) -> int:
@@ -340,7 +341,11 @@ class LSTM(BaseNet):
:return: (L,DH), (L,DH)
"""
- inputs = torch.zeros((1, 48_000, 1)) if inputs is None else inputs
+ inputs = (
+ torch.zeros((1, self._get_initial_state_burn_in, 1))
+ if inputs is None
+ else inputs
+ )
_, (h, c) = self._core(inputs)
return h, c
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -14,14 +14,25 @@ from nam.models import recurrent
from .base import Base
+_metadata_loudness_x_mocked = 0.1 * torch.randn((11,)) # Shorter for speed
+
class TestLSTM(Base):
@classmethod
def setup_class(cls):
+ class LSTMWithMocks(recurrent.LSTM):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._get_initial_state_burn_in = 7
+
+ @classmethod
+ def _metadata_loudness_x(cls) -> torch.Tensor:
+ return _metadata_loudness_x_mocked
+
num_layers = 2
hidden_size = 3
super().setup_class(
- recurrent.LSTM,
+ LSTMWithMocks,
args=(hidden_size,),
kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": num_layers},
)