commit 4c71ff40429034de868a49b271a679a83f8ca350
parent 90d2413d5d17671eaf87820de9ee461649f65da7
Author: Steven Atkinson <[email protected]>
Date: Wed, 8 Nov 2023 08:41:52 -0800
[FEATURE] Allow non-48k, non-24-bit training data (#332)
* Support non-48k sample rates
* Allow other sample widths too
Diffstat:
4 files changed, 58 insertions(+), 25 deletions(-)
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -185,17 +185,14 @@ def main_inner(
dataset_train = init_dataset(data_config, Split.TRAIN)
dataset_validation = init_dataset(data_config, Split.VALIDATION)
- train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
- val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
- if train_dataloader.dataset.sample_rate != val_dataloader.dataset.sample_rate:
+ if dataset_train.sample_rate != dataset_validation.sample_rate:
raise RuntimeError(
"Train and validation data loaders have different data set sample rates: "
- f"{train_dataloader.dataset.sample_rate}, "
- f"{val_dataloader.dataset.sample_rate}"
+ f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
)
+ train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
+ val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
- # ckpt_path = Path(outdir, "checkpoints")
- # ckpt_path.mkdir()
trainer = pl.Trainer(
callbacks=_create_callbacks(learning_config),
default_root_dir=outdir,
diff --git a/nam/data.py b/nam/data.py
@@ -22,8 +22,8 @@ from ._core import InitializableFromConfig
logger = logging.getLogger(__name__)
-_REQUIRED_SAMPWIDTH = 3
-REQUIRED_RATE = 48_000
+REQUIRED_RATE = 48_000 # FIXME: misnomer; 48 kHz is no longer required, only the default
+_DEFAULT_RATE = REQUIRED_RATE # Sample rate assumed when none is specified
_REQUIRED_CHANNELS = 1 # Mono
@@ -60,7 +60,7 @@ class AudioShapeMismatchError(ValueError):
def wav_to_np(
filename: Union[str, Path],
- rate: Optional[int] = REQUIRED_RATE,
+ rate: Optional[int] = _DEFAULT_RATE,
require_match: Optional[Union[str, Path]] = None,
required_shape: Optional[Tuple[int]] = None,
required_wavinfo: Optional[WavInfo] = None,
@@ -72,7 +72,6 @@ def wav_to_np(
"""
x_wav = wavio.read(str(filename))
assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono"
- assert x_wav.sampwidth == _REQUIRED_SAMPWIDTH, "24-bit"
if rate is not None and x_wav.rate != rate:
raise RuntimeError(
f"Explicitly expected sample rate of {rate}, but found {x_wav.rate} in "
@@ -268,7 +267,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
you are using a reamping setup, you can estimate this by reamping a
completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
- :param rate: Sample rate for the data
+ :param sample_rate: Sample rate for the data
+ :param rate: Sample rate for the data (deprecated)
:param require_input_pre_silence: If provided, require that this much time (in
seconds) preceding the start of the data set (`start`) have a silent input.
If it's not, then raise an exception because the output due to it will leak
@@ -349,7 +349,9 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
- x, x_wavinfo = wav_to_tensor(config["x_path"], info=True)
+ x, x_wavinfo = wav_to_tensor(
+ config["x_path"], info=True, rate=config.get("rate")
+ )
rate = x_wavinfo.rate
try:
y = wav_to_tensor(
@@ -402,7 +404,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
"y_scale": config.get("y_scale", 1.0),
"x_path": config["x_path"],
"y_path": config["y_path"],
- "rate": config.get("rate", REQUIRED_RATE),
+ "sample_rate": rate,
"require_input_pre_silence": config.get(
"require_input_pre_silence", _DEFAULT_REQUIRE_INPUT_PRE_SILENCE
),
@@ -457,7 +459,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
cls, sample_rate: Optional[float], rate: Optional[int]
) -> float:
if sample_rate is None and rate is None: # Default value
- return REQUIRED_RATE
+ return _DEFAULT_RATE
if rate is not None:
if sample_rate is not None:
raise ValueError(
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -15,6 +15,8 @@ import torch
from nam import data
+_sample_rates = (44_100, 48_000, 88_200, 96_000)
+
class _XYMethod(Enum):
ARANGE = "arange"
@@ -117,6 +119,22 @@ class TestDataset(object):
sample_x2 = d2[0][0]
assert torch.allclose(sample_x1 * x_scale, sample_x2)
+ @pytest.mark.parametrize("sample_rate", _sample_rates)
+ def test_sample_rates(self, sample_rate: int):
+ """
+ Test that datasets with various sample rates can be made
+ """
+ x = np.random.rand(16) - 0.5
+ y = x
+ with TemporaryDirectory() as tmpdir:
+ x_path = Path(tmpdir, "input.wav")
+ y_path = Path(tmpdir, "output.wav")
+ data.np_to_wav(x, x_path, rate=sample_rate)
+ data.np_to_wav(y, y_path, rate=sample_rate)
+ config = {"x_path": str(x_path), "y_path": str(y_path), "nx": 4, "ny": 2}
+ parsed_config = data.Dataset.parse_config(config)
+ assert parsed_config["sample_rate"] == sample_rate
+
@pytest.mark.parametrize(
"n,start,valid",
(
@@ -271,16 +289,18 @@ class TestWav(object):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)
- def test_np_to_wav_to_np_44khz(self, tmpdir):
- # Create random numpy array
- x = np.random.rand(1000)
- # Save numpy array as WAV file with sampling rate of 44 kHz
- filename = os.path.join(tmpdir, "test.wav")
- data.np_to_wav(x, filename, rate=44100)
- # Load WAV file with sampling rate of 44 kHz
- y = data.wav_to_np(filename, rate=44100)
- # Check if the two arrays are equal
- assert y == pytest.approx(x, abs=self.tolerance)
+ @pytest.mark.parametrize("sample_rate", _sample_rates)
+ def test_np_to_wav_to_np_sample_rates(self, sample_rate: int):
+ with TemporaryDirectory() as tmpdir:
+ # Create random numpy array
+ x = np.random.rand(8)
+ # Save numpy array as WAV file at the parametrized sample rate
+ filename = Path(tmpdir, "x.wav")
+ data.np_to_wav(x, filename, rate=sample_rate)
+ # Load WAV file, requiring the same parametrized sample rate
+ y = data.wav_to_np(filename, rate=sample_rate)
+ # Check if the two arrays are equal
+ assert y == pytest.approx(x, abs=self.tolerance)
def test_np_to_wav_to_np_scale_arg(self, tmpdir):
# Create random numpy array
@@ -293,6 +313,18 @@ class TestWav(object):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)
+ @pytest.mark.parametrize("sample_width", (2, 3))
+ def test_sample_widths(self, sample_width: int):
+ """
+ Test that WAV files with various sample widths can be written and read back
+ """
+ x = np.random.rand(16) - 0.5
+ with TemporaryDirectory() as tmpdir:
+ x_path = Path(tmpdir, "x.wav")
+ data.np_to_wav(x, x_path, sampwidth=sample_width)
+ _, info = data.wav_to_np(x_path, info=True)
+ assert info.sampwidth == sample_width
+
def test_audio_mismatch_shapes_in_order():
"""
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -74,6 +74,8 @@ def test_mrstft_loss(batch_size: int, sequence_length: int):
def test_mrstft_loss_cpu_fallback(mocker):
"""
Assert that fallback to CPU happens on failure
+
+ :param mocker: Provided by pytest-mock
"""
def mocked_loss(