neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit f743c037305f60b4e9800866ece4948e7cc508a9
parent 6c4c78dd5dae7782dc0f5b3056fdc4e4ab5abd99
Author: Steven Atkinson <[email protected]>
Date:   Thu, 22 Jun 2023 00:12:25 -0700

Log sample rate in `.nam` files (#284)

* Include sample rate as an attribute of models. Define sample_rate kwarg for Dataset and sample_rate property

* Bump model file version, save sample rate to model in core trainer routine

* Add to bin/train/main.py, validate data sets match in sample rate
Diffstat:
M bin/train/main.py | 7 +++++++
M nam/data.py | 41 ++++++++++++++++++++++++++++++++++-------
M nam/models/_base.py | 15 +++++++++++++++
M nam/models/_exportable.py | 2 +-
M nam/models/conv_net.py | 3 ++-
M nam/models/linear.py | 4 ++--
M nam/models/recurrent.py | 3 ++-
M nam/models/wavenet.py | 4 ++--
M nam/train/core.py | 12 ++++++++++++
M tests/test_nam/test_data.py | 15 ++++++++++++---
10 files changed, 89 insertions(+), 17 deletions(-)

diff --git a/bin/train/main.py b/bin/train/main.py @@ -181,6 +181,12 @@ def main_inner( dataset_validation = init_dataset(data_config, Split.VALIDATION) train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) + if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate: + raise RuntimeError( + "Train and validation data loaders have different data set sample rates: " + f"{train_dataloader.dataset.sample_rate}, " + f"{val_dataloader.dataset.sample_rate}" + ) # ckpt_path = Path(outdir, "checkpoints") # ckpt_path.mkdir() @@ -204,6 +210,7 @@ def main_inner( ) model.cpu() model.eval() + model.net.sample_rate = train_dataloader.dataset.sample_rate if make_plots: plot( model, diff --git a/nam/data.py b/nam/data.py @@ -3,6 +3,7 @@ # Author: Steven Atkinson ([email protected]) import abc +import logging from collections import namedtuple from copy import deepcopy from dataclasses import dataclass @@ -19,6 +20,8 @@ from tqdm import tqdm from ._core import InitializableFromConfig +logger = logging.getLogger(__name__) + _REQUIRED_SAMPWIDTH = 3 REQUIRED_RATE = 48_000 _REQUIRED_CHANNELS = 1 # Mono @@ -94,7 +97,7 @@ def wav_to_np( required_shape, # Expected arr_premono.shape, # Actual f"Mismatched shapes. Expected {required_shape}, but this is " - f"{arr_premono.shape}!" 
+ f"{arr_premono.shape}!", ) # sampwidth fine--we're just casting to 32-bit float anyways arr = arr_premono[:, 0] @@ -122,8 +125,8 @@ def np_to_wav( filename: Union[str, Path], rate: int = 48_000, sampwidth: int = 3, - scale = None, - **kwargs + scale=None, + **kwargs, ): if wavio.__version__ <= "0.0.4" and scale is None: scale = "none" @@ -133,7 +136,7 @@ def np_to_wav( rate, scale=scale, sampwidth=sampwidth, - **kwargs + **kwargs, ) @@ -235,7 +238,8 @@ class Dataset(AbstractDataset, InitializableFromConfig): x_path: Optional[Union[str, Path]] = None, y_path: Optional[Union[str, Path]] = None, input_gain: float = 0.0, - rate: int = REQUIRED_RATE, + sample_rate: Optional[int] = None, + rate: Optional[int] = None, require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, ): """ @@ -272,13 +276,14 @@ class Dataset(AbstractDataset, InitializableFromConfig): """ self._validate_x_y(x, y) self._validate_start_stop(x, y, start, stop) + self._sample_rate = self._validate_sample_rate(sample_rate, rate) if not isinstance(delay_interpolation_method, _DelayInterpolationMethod): delay_interpolation_method = _DelayInterpolationMethod( delay_interpolation_method ) if require_input_pre_silence is not None: self._validate_preceding_silence( - x, start, int(require_input_pre_silence * rate) + x, start, int(require_input_pre_silence * self._sample_rate) ) x, y = [z[start:stop] for z in (x, y)] if delay is not None and delay != 0: @@ -293,7 +298,6 @@ class Dataset(AbstractDataset, InitializableFromConfig): self._y = y self._nx = nx self._ny = ny if ny is not None else len(x) - nx + 1 - self._rate = rate def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -318,6 +322,10 @@ class Dataset(AbstractDataset, InitializableFromConfig): return self._ny @property + def sample_rate(self) -> Optional[float]: + return self._sample_rate + + @property def x(self) -> torch.Tensor: """ The input audio data @@ -445,6 +453,25 @@ class 
Dataset(AbstractDataset, InitializableFromConfig): return x, y @classmethod + def _validate_sample_rate( + cls, sample_rate: Optional[float], rate: Optional[int] + ) -> float: + if sample_rate is None and rate is None: # Default value + return REQUIRED_RATE + if rate is not None: + if sample_rate is not None: + raise ValueError( + "Provided both sample_rate and rate. Provide only sample_rate!" + ) + else: + logger.warning( + "Use of 'rate' is deprecated and will be removed. Use sample_rate instead" + ) + return float(rate) + else: + return sample_rate + + @classmethod def _validate_start_stop( self, x: torch.Tensor, diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -22,6 +22,10 @@ from ._exportable import Exportable class _Base(nn.Module, InitializableFromConfig, Exportable): + def __init__(self, sample_rate: Optional[float] = None): + super().__init__() + self.sample_rate = sample_rate + @abc.abstractproperty def pad_start_default(self) -> bool: pass @@ -45,6 +49,17 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): ) ) + def _get_export_dict(self): + d = super()._get_export_dict() + sample_rate_key = "sample_rate" + if sample_rate_key in d: + raise RuntimeError( + "Model wants to put 'sample_rate' into model export dict, but the key " + "is already taken!" + ) + d[sample_rate_key] = self.sample_rate + return d + def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float: """ How loud is this model when given a standardized input? diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py @@ -19,7 +19,7 @@ logger = logging.getLogger(__name__) # Model version is independent from package version as of package version 0.5.2 so that # the API of the package can iterate at a different pace from that of the model files. 
-_MODEL_VERSION = "0.5.1" +_MODEL_VERSION = "0.5.2" def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]: diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py @@ -110,9 +110,10 @@ class ConvNet(BaseNet): *args, train_strategy: TrainStrategy = default_train_strategy, ir: Optional[_IR] = None, + sample_rate: Optional[float] = None, **kwargs, ): - super().__init__() + super().__init__(sample_rate=sample_rate) self._net = _conv_net(*args, **kwargs) assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported" self._train_strategy = train_strategy diff --git a/nam/models/linear.py b/nam/models/linear.py @@ -18,8 +18,8 @@ from ._base import BaseNet class Linear(BaseNet): - def __init__(self, receptive_field: int, bias: bool = False): - super().__init__() + def __init__(self, receptive_field: int, *args, bias: bool = False, **kwargs): + super().__init__(*args, **kwargs) self._net = nn.Conv1d(1, 1, receptive_field, bias=bias) @property diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py @@ -131,6 +131,7 @@ class LSTM(BaseNet): train_burn_in: Optional[int] = None, train_truncate: Optional[int] = None, input_size: int = 1, + sample_rate: Optional[float] = None, **lstm_kwargs, ): """ @@ -144,7 +145,7 @@ class LSTM(BaseNet): :param input_size: Usually 1 (mono input). A catnet extending this might change it and provide the parametric inputs as additional input dimensions. 
""" - super().__init__() + super().__init__(sample_rate=sample_rate) if "batch_first" in lstm_kwargs: raise ValueError("batch_first cannot be set.") self._input_size = input_size diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py @@ -339,8 +339,8 @@ class _WaveNet(nn.Module): class WaveNet(BaseNet): - def __init__(self, *args, **kwargs): - super().__init__() + def __init__(self, *args, sample_rate: Optional[float] = None, **kwargs): + super().__init__(sample_rate=sample_rate) self._net = _WaveNet(*args, **kwargs) @property diff --git a/nam/train/core.py b/nam/train/core.py @@ -873,10 +873,21 @@ def train( ) print("Starting training. It's time to kick ass and chew bubblegum!") + # Issue: + # * Model needs sample rate from data, but data set needs nx from model. + # * Model is re-instantiated after training anyways. + # (Hacky) solution: set sample rate in model from dataloader after second + # instantiation from final checkpoint. model = Model.init_from_config(model_config) train_dataloader, val_dataloader = _get_dataloaders( data_config, learning_config, model ) + if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate: + raise RuntimeError( + "Train and validation data loaders have different data set sample rates: " + f"{train_dataloader.dataset.sample_rate}, " + f"{val_dataloader.dataset.sample_rate}" + ) trainer = pl.Trainer( callbacks=[ @@ -904,6 +915,7 @@ def train( ) model.cpu() model.eval() + model.net.sample_rate = train_dataloader.dataset.sample_rate def window_kwargs(version: Version): if version.major == 1: diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py @@ -85,6 +85,14 @@ class TestDataset(object): x, y = self._create_xy() data.Dataset(x, y, 3, None) + def test_init_sample_rate(self): + x, y = self._create_xy() + sample_rate = 48_000.0 + d = data.Dataset(x, y, 3, None, sample_rate=sample_rate) + assert hasattr(d, "sample_rate") + assert isinstance(d.sample_rate, float) + assert d.sample_rate == 
sample_rate + def test_init_zero_delay(self): """ Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed @@ -285,6 +293,7 @@ class TestWav(object): # Check if the two arrays are equal assert y == pytest.approx(x, abs=self.tolerance) + def test_audio_mismatch_shapes_in_order(): """ https://github.com/sdatkinson/neural-amp-modeler/issues/257 @@ -293,12 +302,12 @@ def test_audio_mismatch_shapes_in_order(): num_channels = 1 x, y = [np.zeros((n, num_channels)) for n in (x_samples, y_samples)] - + with TemporaryDirectory() as tmpdir: y_path = Path(tmpdir, "y.wav") data.np_to_wav(y, y_path) f = lambda: data.wav_to_np(y_path, required_shape=x.shape) - + with pytest.raises(data.AudioShapeMismatchError) as e: f() @@ -309,7 +318,7 @@ def test_audio_mismatch_shapes_in_order(): # x is loaded first; we expect that y matches. assert e.shape_expected == (x_samples, num_channels) assert e.shape_actual == (y_samples, num_channels) - + if __name__ == "__main__": pytest.main()