commit 1caec259ba4d6526cb7da5cd3b752f34d680a459
parent f8509336e28eb292557c9abb8af019cb31892543
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 18 Feb 2023 13:20:52 -0800
Improve data validation (#97)
* Start validation
* Stop validation
* x and y validation
Diffstat:
M | nam/data.py | | | 101 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- |
M | tests/test_nam/test_data.py | | | 88 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----- |
2 files changed, 181 insertions(+), 8 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -152,6 +152,30 @@ def _interpolate_delay(
)
+class XYError(ValueError):
+    """
+    Raised when the x (input) and y (output) arrays given to a data set are
+    invalid, e.g. their lengths differ or they are empty.
+    """
+
+
+class StartStopError(ValueError):
+    """
+    Raised when the start and/or stop arguments of a data set are invalid.
+    """
+
+
+class StartError(StartStopError):
+    """
+    Raised when the start argument is invalid.
+    """
+
+
+class StopError(StartStopError):
+    """
+    Raised when the stop argument is invalid.
+    """
+
+
class Dataset(AbstractDataset, InitializableFromConfig):
"""
Take a pair of matched audio files and serve input + output pairs.
@@ -203,6 +227,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
"""
+ self._validate_x_y(x, y)
+ self._validate_start_stop(x, y, start, stop)
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
@@ -215,7 +241,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
y = y * y_scale
self._x_path = x_path
self._y_path = y_path
- self._validate_inputs(x, y, nx, ny)
+ self._validate_inputs_after_processing(x, y, nx, ny)
self._x = x
self._y = y
self._nx = nx
@@ -324,12 +350,81 @@ class Dataset(AbstractDataset, InitializableFromConfig):
y = _interpolate_delay(y, delay, method)
return x, y
- def _validate_inputs(self, x, y, nx, ny):
+    @classmethod
+    def _validate_start_stop(
+        cls,
+        x: torch.Tensor,
+        y: torch.Tensor,
+        start: Optional[int] = None,
+        stop: Optional[int] = None,
+    ):
+        """
+        Check for potential input errors.
+
+        These may be valid indices in Python, but probably point to invalid usage, so
+        we will raise an exception if something fishy is going on (e.g. starting after
+        the end of the file, etc)
+
+        :param x: Input array; only its length and its `[start:stop]` slice are used.
+        :param y: Output array; must be the same length as `x`.
+        :param start: Slice start index, or None for no restriction.
+        :param stop: Slice stop index, or None for no restriction.
+
+        :raises XYError: If `x` and `y` have different lengths (only checked when
+            `start` or `stop` is provided).
+        :raises StartError: If `start` lies at/after the end of the arrays or before
+            their beginning.
+        :raises StopError: If `stop` lies after the end of the arrays or at/before
+            their beginning.
+        :raises StartStopError: If `start` and `stop` are individually fine but the
+            slice `[start:stop]` is still empty.
+        """
+        # We could do this whole thing with `if len(x[start:stop]) == 0`, but being
+        # more explicit makes the error messages better for users.
+        if start is None and stop is None:
+            return
+        if len(x) != len(y):
+            # XYError is a ValueError, so callers catching ValueError are unaffected;
+            # using it keeps this consistent with `_validate_x_y`.
+            raise XYError(
+                f"Input and output are different length. Input has {len(x)} samples, "
+                f"and output has {len(y)}"
+            )
+        n = len(x)
+        if start is not None:
+            # Start after the files' end?
+            if start >= n:
+                raise StartError(
+                    f"Arrays are only {n} samples long, but start was provided as {start}, "
+                    "which is beyond the end of the array!"
+                )
+            # Start before the files' beginning?
+            if start < -n:
+                raise StartError(
+                    f"Arrays are only {n} samples long, but start was provided as {start}, "
+                    "which is before the beginning of the array!"
+                )
+        if stop is not None:
+            # Stop after the files' end?
+            if stop > n:
+                raise StopError(
+                    f"Arrays are only {n} samples long, but stop was provided as {stop}, "
+                    "which is beyond the end of the array!"
+                )
+            # Stop at or before the files' beginning? (`stop == -n` keeps no samples.)
+            if stop <= -n:
+                raise StopError(
+                    f"Arrays are only {n} samples long, but stop was provided as {stop}, "
+                    "which is before the beginning of the array!"
+                )
+        # Just in case... (e.g. start and stop are each in range, but start >= stop)
+        if len(x[start:stop]) == 0:
+            raise StartStopError(
+                f"Array length {n} with start={start} and stop={stop} would get "
+                "rid of all of the data!"
+            )
+
+    @classmethod
+    def _validate_x_y(cls, x, y):
+        """
+        Check that the input and output arrays are usable as a matched pair.
+
+        :param x: Input array.
+        :param y: Output array.
+
+        :raises XYError: If `x` and `y` have different lengths, or if they are empty.
+        """
+        if len(x) != len(y):
+            raise XYError(
+                f"Input and output aren't the same lengths! ({len(x)} vs {len(y)})"
+            )
+        # TODO channels
+        # Lengths match at this point, so checking x alone covers both arrays.
+        if len(x) == 0:
+            raise XYError("Input and output are empty!")
+
+ def _validate_inputs_after_processing(self, x, y, nx, ny):
assert x.ndim == 1
assert y.ndim == 1
assert len(x) == len(y)
if nx > len(x):
- raise RuntimeError(
+ raise RuntimeError( # TODO XYError?
f"Input of length {len(x)}, but receptive field is {nx}."
)
if ny is not None:
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -4,7 +4,7 @@
import math
from enum import Enum
-from typing import Sequence
+from typing import Tuple
import numpy as np
import pytest
@@ -106,12 +106,86 @@ class TestDataset(object):
sample_x2 = d2[0][0]
assert torch.allclose(sample_x1 * x_scale, sample_x2)
+    @pytest.mark.parametrize(
+        "n,start,valid",
+        (
+            (13, None, True),  # No start given: nothing to reject
+            (13, 2, True),  # Well inside the array
+            (13, 12, True),  # Exactly one sample left
+            (13, 13, False),  # At the end: no samples left
+            (13, -5, True),  # Negative index counted from the end
+            (13, -13, True),  # Negative index landing on the first sample
+            (13, -14, False),  # Negative index before the first sample
+        ),
+    )
+    def test_validate_start(self, n: int, start: int, valid: bool):
+        """
+        Building a Dataset with an acceptable `start` succeeds; an unacceptable one
+        raises `data.StartError`.
+        """
+        x, y = self._create_xy(n=n)
+        receptive_field, output_length = 1, None
+        if not valid:
+            with pytest.raises(data.StartError):
+                data.Dataset(x, y, receptive_field, output_length, start=start)
+            return
+        # Valid case: construction must simply not raise.
+        data.Dataset(x, y, receptive_field, output_length, start=start)
+
+    @pytest.mark.parametrize(
+        "n,stop,valid",
+        (
+            (13, None, True),  # No stop given: nothing to reject
+            (13, 2, True),  # Well inside the array
+            (13, 13, True),  # Exactly at the end
+            (13, 14, False),  # Past the end
+            (13, -5, True),  # Negative index counted from the end
+            (13, -12, True),  # Keeps exactly one sample
+            (13, -13, False),  # Keeps no samples
+        ),
+    )
+    def test_validate_stop(self, n: int, stop: int, valid: bool):
+        """
+        Building a Dataset with an acceptable `stop` succeeds; an unacceptable one
+        raises `data.StopError`.
+        """
+        x, y = self._create_xy(n=n)
+        receptive_field, output_length = 1, None
+        if not valid:
+            with pytest.raises(data.StopError):
+                data.Dataset(x, y, receptive_field, output_length, stop=stop)
+            return
+        # Valid case: construction must simply not raise.
+        data.Dataset(x, y, receptive_field, output_length, stop=stop)
+
+    @pytest.mark.parametrize(
+        "lenx,leny,valid",
+        (
+            (3, 3, True),  # Matching, non-empty
+            (3, 4, False),  # Length mismatch
+            (0, 0, False),  # Empty!
+        ),
+    )
+    def test_validate_x_y(self, lenx: int, leny: int, valid: bool):
+        """
+        Building a Dataset from x and y truncated to the given lengths succeeds when
+        `valid`, and raises `data.XYError` otherwise.
+        """
+        full_x, full_y = self._create_xy()
+        # The fixture must be long enough to be truncated to the requested lengths.
+        assert len(full_x) >= lenx, "Invalid test!"
+        assert len(full_y) >= leny, "Invalid test!"
+        x, y = full_x[:lenx], full_y[:leny]
+        receptive_field, output_length = 1, None
+        if not valid:
+            with pytest.raises(data.XYError):
+                data.Dataset(x, y, receptive_field, output_length)
+            return
+        # Valid case: construction must simply not raise.
+        data.Dataset(x, y, receptive_field, output_length)
+
def _create_xy(
self,
n: int = 7,
method: _XYMethod = _XYMethod.RAND,
must_be_in_valid_range: bool = True,
- ) -> Sequence[torch.Tensor]:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:return: (n,), (n,)
"""
@@ -119,11 +193,15 @@ class TestDataset(object):
# note: this isn't "valid" data in the sense that it's beyond (-1, 1).
# But it is useful for the delay code.
assert not must_be_in_valid_range
- return torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1))
+ return tuple(
+ torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1))
+ )
elif method == _XYMethod.RAND:
- return 0.99 * (2.0 * torch.rand((2, n)) - 1.0) # Don't clip
+ return tuple(0.99 * (2.0 * torch.rand((2, n)) - 1.0)) # Don't clip
elif method == _XYMethod.STEP:
- return torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1))
+ return tuple(
+ torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1))
+ )
def _t_apply_delay_float(
self, n: int, delay: int, method: data._DelayInterpolationMethod