neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 4c71ff40429034de868a49b271a679a83f8ca350
parent 90d2413d5d17671eaf87820de9ee461649f65da7
Author: Steven Atkinson <[email protected]>
Date:   Wed,  8 Nov 2023 08:41:52 -0800

[FEATURE] Allow non-48k, non-24-bit training data (#332)

* Support non-48k sample rates

* Allow other sample widths too
Diffstat:
M bin/train/main.py | 11 ++++-------
M nam/data.py | 18 ++++++++++--------
M tests/test_nam/test_data.py | 52 ++++++++++++++++++++++++++++++++++++++++++----------
M tests/test_nam/test_models/test_base.py | 2 ++
4 files changed, 58 insertions(+), 25 deletions(-)

diff --git a/bin/train/main.py b/bin/train/main.py @@ -185,17 +185,14 @@ def main_inner( dataset_train = init_dataset(data_config, Split.TRAIN) dataset_validation = init_dataset(data_config, Split.VALIDATION) - train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) - val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) - if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate: + if dataset_train.sample_rate != dataset_validation.sample_rate: raise RuntimeError( "Train and validation data loaders have different data set sample rates: " - f"{train_dataloader.dataset.sample_rate}, " - f"{val_dataloader.dataset.sample_rate}" + f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}" ) + train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"]) + val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"]) - # ckpt_path = Path(outdir, "checkpoints") - # ckpt_path.mkdir() trainer = pl.Trainer( callbacks=_create_callbacks(learning_config), default_root_dir=outdir, diff --git a/nam/data.py b/nam/data.py @@ -22,8 +22,8 @@ from ._core import InitializableFromConfig logger = logging.getLogger(__name__) -_REQUIRED_SAMPWIDTH = 3 -REQUIRED_RATE = 48_000 +REQUIRED_RATE = 48_000 # FIXME not "required" anymore! 
+_DEFAULT_RATE = REQUIRED_RATE # There we go :) _REQUIRED_CHANNELS = 1 # Mono @@ -60,7 +60,7 @@ class AudioShapeMismatchError(ValueError): def wav_to_np( filename: Union[str, Path], - rate: Optional[int] = REQUIRED_RATE, + rate: Optional[int] = _DEFAULT_RATE, require_match: Optional[Union[str, Path]] = None, required_shape: Optional[Tuple[int]] = None, required_wavinfo: Optional[WavInfo] = None, @@ -72,7 +72,6 @@ def wav_to_np( """ x_wav = wavio.read(str(filename)) assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono" - assert x_wav.sampwidth == _REQUIRED_SAMPWIDTH, "24-bit" if rate is not None and x_wav.rate != rate: raise RuntimeError( f"Explicitly expected sample rate of {rate}, but found {x_wav.rate} in " @@ -268,7 +267,8 @@ class Dataset(AbstractDataset, InitializableFromConfig): you are using a reamping setup, you can estimate this by reamping a completely dry signal (i.e. connecting the interface output directly back into the input with which the guitar was originally recorded.) - :param rate: Sample rate for the data + :param sample_rate: Sample rate for the data + :param rate: Sample rate for the data (deprecated) :param require_input_pre_silence: If provided, require that this much time (in seconds) preceding the start of the data set (`start`) have a silent input. 
If it's not, then raise an exception because the output due to it will leak @@ -349,7 +349,9 @@ class Dataset(AbstractDataset, InitializableFromConfig): @classmethod def parse_config(cls, config): - x, x_wavinfo = wav_to_tensor(config["x_path"], info=True) + x, x_wavinfo = wav_to_tensor( + config["x_path"], info=True, rate=config.get("rate") + ) rate = x_wavinfo.rate try: y = wav_to_tensor( @@ -402,7 +404,7 @@ class Dataset(AbstractDataset, InitializableFromConfig): "y_scale": config.get("y_scale", 1.0), "x_path": config["x_path"], "y_path": config["y_path"], - "rate": config.get("rate", REQUIRED_RATE), + "sample_rate": rate, "require_input_pre_silence": config.get( "require_input_pre_silence", _DEFAULT_REQUIRE_INPUT_PRE_SILENCE ), @@ -457,7 +459,7 @@ class Dataset(AbstractDataset, InitializableFromConfig): cls, sample_rate: Optional[float], rate: Optional[int] ) -> float: if sample_rate is None and rate is None: # Default value - return REQUIRED_RATE + return _DEFAULT_RATE if rate is not None: if sample_rate is not None: raise ValueError( diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py @@ -15,6 +15,8 @@ import torch from nam import data +_sample_rates = (44_100, 48_000, 88_200, 96_000) + class _XYMethod(Enum): ARANGE = "arange" @@ -117,6 +119,22 @@ class TestDataset(object): sample_x2 = d2[0][0] assert torch.allclose(sample_x1 * x_scale, sample_x2) + @pytest.mark.parametrize("sample_rate", _sample_rates) + def test_sample_rates(self, sample_rate: int): + """ + Test that datasets with various sample rates can be made + """ + x = np.random.rand(16) - 0.5 + y = x + with TemporaryDirectory() as tmpdir: + x_path = Path(tmpdir, "input.wav") + y_path = Path(tmpdir, "output.wav") + data.np_to_wav(x, x_path, rate=sample_rate) + data.np_to_wav(y, y_path, rate=sample_rate) + config = {"x_path": str(x_path), "y_path": str(y_path), "nx": 4, "ny": 2} + parsed_config = data.Dataset.parse_config(config) + assert parsed_config["sample_rate"] == sample_rate + 
@pytest.mark.parametrize( "n,start,valid", ( @@ -271,16 +289,18 @@ class TestWav(object): # Check if the two arrays are equal assert y == pytest.approx(x, abs=self.tolerance) - def test_np_to_wav_to_np_44khz(self, tmpdir): - # Create random numpy array - x = np.random.rand(1000) - # Save numpy array as WAV file with sampling rate of 44 kHz - filename = os.path.join(tmpdir, "test.wav") - data.np_to_wav(x, filename, rate=44100) - # Load WAV file with sampling rate of 44 kHz - y = data.wav_to_np(filename, rate=44100) - # Check if the two arrays are equal - assert y == pytest.approx(x, abs=self.tolerance) + @pytest.mark.parametrize("sample_rate", _sample_rates) + def test_np_to_wav_to_np_sample_rates(self, sample_rate: int): + with TemporaryDirectory() as tmpdir: + # Create random numpy array + x = np.random.rand(8) + # Save numpy array as WAV file with sampling rate of 44 kHz + filename = Path(tmpdir, "x.wav") + data.np_to_wav(x, filename, rate=sample_rate) + # Load WAV file with sampling rate of 44 kHz + y = data.wav_to_np(filename, rate=sample_rate) + # Check if the two arrays are equal + assert y == pytest.approx(x, abs=self.tolerance) def test_np_to_wav_to_np_scale_arg(self, tmpdir): # Create random numpy array @@ -293,6 +313,18 @@ class TestWav(object): # Check if the two arrays are equal assert y == pytest.approx(x, abs=self.tolerance) + @pytest.mark.parametrize("sample_width", (2, 3)) + def test_sample_widths(self, sample_width: int): + """ + Test that datasets with various sample widths can be made + """ + x = np.random.rand(16) - 0.5 + with TemporaryDirectory() as tmpdir: + x_path = Path(tmpdir, "x.wav") + data.np_to_wav(x, x_path, sampwidth=sample_width) + _, info = data.wav_to_np(x_path, info=True) + assert info.sampwidth == sample_width + def test_audio_mismatch_shapes_in_order(): """ diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py @@ -74,6 +74,8 @@ def test_mrstft_loss(batch_size: int, sequence_length: int): 
def test_mrstft_loss_cpu_fallback(mocker): """ Assert that fallback to CPU happens on failure + + :param mocker: Provided by pytest-mock """ def mocked_loss(