neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 1caec259ba4d6526cb7da5cd3b752f34d680a459
parent f8509336e28eb292557c9abb8af019cb31892543
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 18 Feb 2023 13:20:52 -0800

Improve data validation (#97)

* Start validation

* Stop validation

* x and y validation
Diffstat:
Mnam/data.py | 101++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
Mtests/test_nam/test_data.py | 88++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
2 files changed, 181 insertions(+), 8 deletions(-)

diff --git a/nam/data.py b/nam/data.py @@ -152,6 +152,30 @@ def _interpolate_delay( ) +class XYError(ValueError): + """ + Exceptions related to invalid x and y provided for data sets + """ + + pass + + +class StartStopError(ValueError): + """ + Exceptions related to invalid start and stop arguments + """ + + pass + + +class StartError(StartStopError): + pass + + +class StopError(StartStopError): + pass + + class Dataset(AbstractDataset, InitializableFromConfig): """ Take a pair of matched audio files and serve input + output pairs. @@ -203,6 +227,8 @@ class Dataset(AbstractDataset, InitializableFromConfig): completely dry signal (i.e. connecting the interface output directly back into the input with which the guitar was originally recorded.) """ + self._validate_x_y(x, y) + self._validate_start_stop(x, y, start, stop) if not isinstance(delay_interpolation_method, _DelayInterpolationMethod): delay_interpolation_method = _DelayInterpolationMethod( delay_interpolation_method @@ -215,7 +241,7 @@ class Dataset(AbstractDataset, InitializableFromConfig): y = y * y_scale self._x_path = x_path self._y_path = y_path - self._validate_inputs(x, y, nx, ny) + self._validate_inputs_after_processing(x, y, nx, ny) self._x = x self._y = y self._nx = nx @@ -324,12 +350,81 @@ class Dataset(AbstractDataset, InitializableFromConfig): y = _interpolate_delay(y, delay, method) return x, y - def _validate_inputs(self, x, y, nx, ny): + @classmethod + def _validate_start_stop( + self, + x: torch.Tensor, + y: torch.Tensor, + start: Optional[int] = None, + stop: Optional[int] = None, + ): + """ + Check for potential input errors. + + These may be valid indices in Python, but probably point to invalid usage, so + we will raise an exception if something fishy is going on (e.g. starting after + the end of the file, etc) + """ + # We could do this whole thing with `if len(x[start: stop]==0`, but being more + # explicit makes the error messages better for users. + if start is None and stop is None: + return + if len(x) != len(y): + raise ValueError( + f"Input and output are different length. Input has {len(x)} samples, " + f"and output has {len(y)}" + ) + n = len(x) + if start is not None: + # Start after the files' end? + if start >= n: + raise StartError( + f"Arrays are only {n} samples long, but start was provided as {start}, " + "which is beyond the end of the array!" + ) + # Start before the files' beginning? + if start < -n: + raise StartError( + f"Arrays are only {n} samples long, but start was provided as {start}, " + "which is before the beginning of the array!" + ) + if stop is not None: + # Stop after the files' end? + if stop > n: + raise StopError( + f"Arrays are only {n} samples long, but stop was provided as {stop}, " + "which is beyond the end of the array!" + ) + # Start before the files' beginning? + if stop <= -n: + raise StopError( + f"Arrays are only {n} samples long, but stop was provided as {stop}, " + "which is before the beginning of the array!" + ) + # Just in case... + if len(x[start:stop]) == 0: + raise StartStopError( + f"Array length {n} with start={start} and stop={stop} would get " + "rid of all of the data!" + ) + + @classmethod + def _validate_x_y(self, x, y): + if len(x) != len(y): + raise XYError( + f"Input and output aren't the same lengths! ({len(x)} vs {len(y)})" + ) + # TODO channels + n = len(x) + if n == 0: + raise XYError("Input and output are empty!") + + def _validate_inputs_after_processing(self, x, y, nx, ny): assert x.ndim == 1 assert y.ndim == 1 assert len(x) == len(y) if nx > len(x): - raise RuntimeError( + raise RuntimeError( # TODO XYError? f"Input of length {len(x)}, but receptive field is {nx}." ) if ny is not None: diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py @@ -4,7 +4,7 @@ import math from enum import Enum -from typing import Sequence +from typing import Tuple import numpy as np import pytest @@ -106,12 +106,86 @@ class TestDataset(object): sample_x2 = d2[0][0] assert torch.allclose(sample_x1 * x_scale, sample_x2) + @pytest.mark.parametrize( + "n,start,valid", + ( + (13, None, True), # No start restrictions; nothing wrong + (13, 2, True), # Starts before the end; fine. + (13, 12, True), # Starts w/ one to go--ok + (13, 13, False), # Starts after the end + (13, -5, True), # Starts counting back from the end, fine + (13, -13, True), # Starts at the beginning of the array--ok + (13, -14, False), # Starts before the beginning of the array--invalid + ), + ) + def test_validate_start(self, n: int, start: int, valid: bool): + def init(): + data.Dataset(x, y, nx, ny, start=start) + + nx = 1 + ny = None + x, y = self._create_xy(n=n) + if valid: + init() + assert True # No problem! + else: + with pytest.raises(data.StartError): + init() + + @pytest.mark.parametrize( + "n,stop,valid", + ( + (13, None, True), # No stop restrictions; nothing wrong + (13, 2, True), # Stops before the end; fine. + (13, 13, True), # Stops at the end--ok + (13, 14, False), # Stops after the end--not ok + (13, -5, True), # Stops counting back from the end, fine + (13, -12, True), # Stops w/ one sample--ok + (13, -13, False), # Stops w/ no samples--not ok + ), + ) + def test_validate_stop(self, n: int, stop: int, valid: bool): + def init(): + data.Dataset(x, y, nx, ny, stop=stop) + + nx = 1 + ny = None + x, y = self._create_xy(n=n) + if valid: + init() + assert True # No problem! + else: + with pytest.raises(data.StopError): + init() + + @pytest.mark.parametrize( + "lenx,leny,valid", + ((3, 3, True), (3, 4, False), (0, 0, False)), # Lenght mismatch # Empty! + ) + def test_validate_x_y(self, lenx: int, leny: int, valid: bool): + def init(): + data.Dataset(x, y, nx, ny) + + x, y = self._create_xy() + assert len(x) >= lenx, "Invalid test!" + assert len(y) >= leny, "Invalid test!" + x = x[:lenx] + y = y[:leny] + nx = 1 + ny = None + if valid: + init() + assert True # It worked! + else: + with pytest.raises(data.XYError): + init() + def _create_xy( self, n: int = 7, method: _XYMethod = _XYMethod.RAND, must_be_in_valid_range: bool = True, - ) -> Sequence[torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: """ :return: (n,), (n,) """ @@ -119,11 +193,15 @@ class TestDataset(object): # note: this isn't "valid" data in the sense that it's beyond (-1, 1). # But it is useful for the delay code. assert not must_be_in_valid_range - return torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1)) + return tuple( + torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1)) + ) elif method == _XYMethod.RAND: - return 0.99 * (2.0 * torch.rand((2, n)) - 1.0) # Don't clip + return tuple(0.99 * (2.0 * torch.rand((2, n)) - 1.0)) # Don't clip elif method == _XYMethod.STEP: - return torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1)) + return tuple( + torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1)) + ) def _t_apply_delay_float( self, n: int, delay: int, method: data._DelayInterpolationMethod