commit adee42ee3f7f9a4aad345c0830998fd8782a5d25
parent 213b72b70917e8acb655bed899b34d509a6448c0
Author: Steven Atkinson <[email protected]>
Date: Mon, 29 May 2023 14:35:22 -0700
Fix AudioShapeMismatchError mixed up arguments (#258)
Diffstat:
2 files changed, 30 insertions(+), 3 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -91,10 +91,10 @@ def wav_to_np(
if required_shape is not None:
if arr_premono.shape != required_shape:
raise AudioShapeMismatchError(
- arr_premono.shape,
- required_shape,
+ required_shape, # Expected
+ arr_premono.shape, # Actual
f"Mismatched shapes. Expected {required_shape}, but this is "
- f"{arr_premono.shape}!",
+ f"{arr_premono.shape}!"
)
# sampwidth fine--we're just casting to 32-bit float anyways
arr = arr_premono[:, 0]
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -4,6 +4,8 @@
import math
from enum import Enum
+from pathlib import Path
+from tempfile import TemporaryDirectory
from typing import Tuple
import numpy as np
@@ -241,5 +243,30 @@ class TestDataset(object):
return x_out, y_out
+def test_audio_mismatch_shapes_in_order():
+ """
+ https://github.com/sdatkinson/neural-amp-modeler/issues/257
+ """
+ x_samples, y_samples = 5, 7
+ num_channels = 1
+
+ x, y = [np.zeros((n, num_channels)) for n in (x_samples, y_samples)]
+
+ with TemporaryDirectory() as tmpdir:
+ y_path = Path(tmpdir, "y.wav")
+ data.np_to_wav(y, y_path)
+ f = lambda: data.wav_to_np(y_path, required_shape=x.shape)
+
+ with pytest.raises(data.AudioShapeMismatchError) as e:
+ f()
+
+ try:
+ f()
+ assert False, "Shouldn't have succeeded!"
+ except data.AudioShapeMismatchError as e:
+ # x is loaded first; we expect that y matches.
+ assert e.shape_expected == (x_samples, num_channels)
+ assert e.shape_actual == (y_samples, num_channels)
+
if __name__ == "__main__":
pytest.main()