commit f743c037305f60b4e9800866ece4948e7cc508a9
parent 6c4c78dd5dae7782dc0f5b3056fdc4e4ab5abd99
Author: Steven Atkinson <[email protected]>
Date: Thu, 22 Jun 2023 00:12:25 -0700
Log sample rate in `.nam` files (#284)
* Include the sample rate as an attribute of models. Define a `sample_rate` kwarg for `Dataset` along with a `sample_rate` property
* Bump the model file version and save the sample rate to the model in the core trainer routine
* Add the same to bin/train/main.py, validating that the train and validation datasets match in sample rate
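In practice, the new keyword mirrors the test added below: a `Dataset` built with `sample_rate` exposes it as a read-only property, and both trainer entry points refuse to proceed when the train and validation sets disagree. A minimal sketch (assuming `x` and `y` are equal-length mono tensors, as in the tests):

    import torch
    from nam.data import Dataset

    x, y = torch.zeros(8192), torch.zeros(8192)
    ds = Dataset(x, y, 3, None, sample_rate=48_000.0)
    assert ds.sample_rate == 48_000.0
    # The old `rate` kwarg still works but logs a deprecation warning;
    # passing both `sample_rate` and `rate` raises a ValueError.
    ds_legacy = Dataset(x, y, 3, None, rate=48_000)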
Diffstat:
10 files changed, 89 insertions(+), 17 deletions(-)
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -181,6 +181,12 @@ def main_inner(
dataset_validation = init_dataset(data_config, Split.VALIDATION)
train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+ if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate:
+ raise RuntimeError(
+ "Train and validation data loaders have different data set sample rates: "
+ f"{train_dataloader.dataset.sample_rate}, "
+ f"{val_dataloader.dataset.sample_rate}"
+ )
# ckpt_path = Path(outdir, "checkpoints")
# ckpt_path.mkdir()
@@ -204,6 +210,7 @@ def main_inner(
)
model.cpu()
model.eval()
+ model.net.sample_rate = train_dataloader.dataset.sample_rate
if make_plots:
plot(
model,
diff --git a/nam/data.py b/nam/data.py
@@ -3,6 +3,7 @@
# Author: Steven Atkinson ([email protected])
import abc
+import logging
from collections import namedtuple
from copy import deepcopy
from dataclasses import dataclass
@@ -19,6 +20,8 @@ from tqdm import tqdm
from ._core import InitializableFromConfig
+logger = logging.getLogger(__name__)
+
_REQUIRED_SAMPWIDTH = 3
REQUIRED_RATE = 48_000
_REQUIRED_CHANNELS = 1 # Mono
@@ -94,7 +97,7 @@ def wav_to_np(
required_shape, # Expected
arr_premono.shape, # Actual
f"Mismatched shapes. Expected {required_shape}, but this is "
- f"{arr_premono.shape}!"
+ f"{arr_premono.shape}!",
)
# sampwidth fine--we're just casting to 32-bit float anyways
arr = arr_premono[:, 0]
@@ -122,8 +125,8 @@ def np_to_wav(
filename: Union[str, Path],
rate: int = 48_000,
sampwidth: int = 3,
- scale = None,
- **kwargs
+ scale=None,
+ **kwargs,
):
if wavio.__version__ <= "0.0.4" and scale is None:
scale = "none"
@@ -133,7 +136,7 @@ def np_to_wav(
rate,
scale=scale,
sampwidth=sampwidth,
- **kwargs
+ **kwargs,
)
@@ -235,7 +238,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
input_gain: float = 0.0,
- rate: int = REQUIRED_RATE,
+ sample_rate: Optional[int] = None,
+ rate: Optional[int] = None,
require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
@@ -272,13 +276,14 @@ class Dataset(AbstractDataset, InitializableFromConfig):
"""
self._validate_x_y(x, y)
self._validate_start_stop(x, y, start, stop)
+ self._sample_rate = self._validate_sample_rate(sample_rate, rate)
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
)
if require_input_pre_silence is not None:
self._validate_preceding_silence(
- x, start, int(require_input_pre_silence * rate)
+ x, start, int(require_input_pre_silence * self._sample_rate)
)
x, y = [z[start:stop] for z in (x, y)]
if delay is not None and delay != 0:
@@ -293,7 +298,6 @@ class Dataset(AbstractDataset, InitializableFromConfig):
self._y = y
self._nx = nx
self._ny = ny if ny is not None else len(x) - nx + 1
- self._rate = rate
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -318,6 +322,10 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return self._ny
@property
+ def sample_rate(self) -> Optional[float]:
+ return self._sample_rate
+
+ @property
def x(self) -> torch.Tensor:
"""
The input audio data
@@ -445,6 +453,25 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return x, y
@classmethod
+ def _validate_sample_rate(
+ cls, sample_rate: Optional[float], rate: Optional[int]
+ ) -> float:
+ if sample_rate is None and rate is None: # Default value
+ return REQUIRED_RATE
+ if rate is not None:
+ if sample_rate is not None:
+ raise ValueError(
+ "Provided both sample_rate and rate. Provide only sample_rate!"
+ )
+ else:
+ logger.warning(
+ "Use of 'rate' is deprecated and will be removed. Use sample_rate instead"
+ )
+ return float(rate)
+ else:
+ return sample_rate
+
+ @classmethod
def _validate_start_stop(
self,
x: torch.Tensor,
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -22,6 +22,10 @@ from ._exportable import Exportable
class _Base(nn.Module, InitializableFromConfig, Exportable):
+ def __init__(self, sample_rate: Optional[float] = None):
+ super().__init__()
+ self.sample_rate = sample_rate
+
@abc.abstractproperty
def pad_start_default(self) -> bool:
pass
@@ -45,6 +49,17 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
)
)
+ def _get_export_dict(self):
+ d = super()._get_export_dict()
+ sample_rate_key = "sample_rate"
+ if sample_rate_key in d:
+ raise RuntimeError(
+ "Model wants to put 'sample_rate' into model export dict, but the key "
+ "is already taken!"
+ )
+ d[sample_rate_key] = self.sample_rate
+ return d
+
def _metadata_loudness(self, gain: float = 1.0, db: bool = True) -> float:
"""
How loud is this model when given a standardized input?
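The practical effect of `_get_export_dict` is that every exported model dict gains a top-level "sample_rate" key, set to `None` when the model was never told its rate. A sketch using the updated `Linear` net (the receptive field value is illustrative; `_get_export_dict` is the internal hook used by `Exportable`):

    from nam.models.linear import Linear

    net = Linear(receptive_field=8192, sample_rate=48_000.0)
    d = net._get_export_dict()
    assert d["sample_rate"] == 48_000.0
    # bin/train/main.py and nam/train/core.py instead assign it after training:
    net.sample_rate = 44_100.0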
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
# Model version is independent from package version as of package version 0.5.2 so that
# the API of the package can iterate at a different pace from that of the model files.
-_MODEL_VERSION = "0.5.1"
+_MODEL_VERSION = "0.5.2"
def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -110,9 +110,10 @@ class ConvNet(BaseNet):
*args,
train_strategy: TrainStrategy = default_train_strategy,
ir: Optional[_IR] = None,
+ sample_rate: Optional[float] = None,
**kwargs,
):
- super().__init__()
+ super().__init__(sample_rate=sample_rate)
self._net = _conv_net(*args, **kwargs)
assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported"
self._train_strategy = train_strategy
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -18,8 +18,8 @@ from ._base import BaseNet
class Linear(BaseNet):
- def __init__(self, receptive_field: int, bias: bool = False):
- super().__init__()
+ def __init__(self, receptive_field: int, *args, bias: bool = False, **kwargs):
+ super().__init__(*args, **kwargs)
self._net = nn.Conv1d(1, 1, receptive_field, bias=bias)
@property
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -131,6 +131,7 @@ class LSTM(BaseNet):
train_burn_in: Optional[int] = None,
train_truncate: Optional[int] = None,
input_size: int = 1,
+ sample_rate: Optional[float] = None,
**lstm_kwargs,
):
"""
@@ -144,7 +145,7 @@ class LSTM(BaseNet):
:param input_size: Usually 1 (mono input). A catnet extending this might change
it and provide the parametric inputs as additional input dimensions.
"""
- super().__init__()
+ super().__init__(sample_rate=sample_rate)
if "batch_first" in lstm_kwargs:
raise ValueError("batch_first cannot be set.")
self._input_size = input_size
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
@@ -339,8 +339,8 @@ class _WaveNet(nn.Module):
class WaveNet(BaseNet):
- def __init__(self, *args, **kwargs):
- super().__init__()
+ def __init__(self, *args, sample_rate: Optional[float] = None, **kwargs):
+ super().__init__(sample_rate=sample_rate)
self._net = _WaveNet(*args, **kwargs)
@property
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -873,10 +873,21 @@ def train(
)
print("Starting training. It's time to kick ass and chew bubblegum!")
+ # Issue:
+ # * Model needs sample rate from data, but data set needs nx from model.
+ # * Model is re-instantiated after training anyway.
+ # (Hacky) solution: set sample rate in model from dataloader after second
+ # instantiation from final checkpoint.
model = Model.init_from_config(model_config)
train_dataloader, val_dataloader = _get_dataloaders(
data_config, learning_config, model
)
+ if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate:
+ raise RuntimeError(
+ "Train and validation data loaders have different data set sample rates: "
+ f"{train_dataloader.dataset.sample_rate}, "
+ f"{val_dataloader.dataset.sample_rate}"
+ )
trainer = pl.Trainer(
callbacks=[
@@ -904,6 +915,7 @@ def train(
)
model.cpu()
model.eval()
+ model.net.sample_rate = train_dataloader.dataset.sample_rate
def window_kwargs(version: Version):
if version.major == 1:
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -85,6 +85,14 @@ class TestDataset(object):
x, y = self._create_xy()
data.Dataset(x, y, 3, None)
+ def test_init_sample_rate(self):
+ x, y = self._create_xy()
+ sample_rate = 48_000.0
+ d = data.Dataset(x, y, 3, None, sample_rate=sample_rate)
+ assert hasattr(d, "sample_rate")
+ assert isinstance(d.sample_rate, float)
+ assert d.sample_rate == sample_rate
+
def test_init_zero_delay(self):
"""
Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
@@ -285,6 +293,7 @@ class TestWav(object):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)
+
def test_audio_mismatch_shapes_in_order():
"""
https://github.com/sdatkinson/neural-amp-modeler/issues/257
@@ -293,12 +302,12 @@ def test_audio_mismatch_shapes_in_order():
num_channels = 1
x, y = [np.zeros((n, num_channels)) for n in (x_samples, y_samples)]
-
+
with TemporaryDirectory() as tmpdir:
y_path = Path(tmpdir, "y.wav")
data.np_to_wav(y, y_path)
f = lambda: data.wav_to_np(y_path, required_shape=x.shape)
-
+
with pytest.raises(data.AudioShapeMismatchError) as e:
f()
@@ -309,7 +318,7 @@ def test_audio_mismatch_shapes_in_order():
# x is loaded first; we expect that y matches.
assert e.shape_expected == (x_samples, num_channels)
assert e.shape_actual == (y_samples, num_channels)
-
+
if __name__ == "__main__":
pytest.main()
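Since an exported `.nam` file is the export dict serialized as JSON, downstream tooling can read the rate back directly; files written before model version 0.5.2 simply lack the key. A hypothetical reader ("model.nam" is a placeholder path):

    import json

    with open("model.nam") as fp:
        d = json.load(fp)
    sample_rate = d.get("sample_rate")  # e.g. 48000.0; None for pre-0.5.2 files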