neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 039e858eae4e5daf129bc2c3c3c540e07d36cc01
parent 088556bbf713167135b0ab49237548a243c3a940
Author: Steven Atkinson <[email protected]>
Date:   Sat, 11 Feb 2023 17:10:03 -0800

Non-integer delays (#84)

* Refactor, int delay correct

* Non-integer delays
Diffstat:
M nam/data.py | 122 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------
M tests/test_nam/test_data.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----
2 files changed, 213 insertions(+), 29 deletions(-)

diff --git a/nam/data.py b/nam/data.py @@ -13,6 +13,7 @@ from typing import Dict, Optional, Sequence, Tuple, Union import numpy as np import torch import wavio +from scipy.interpolate import interp1d from torch.utils.data import Dataset as _Dataset from tqdm import tqdm @@ -120,10 +121,39 @@ class AbstractDataset(_Dataset, abc.ABC): pass +class _DelayInterpolationMethod(Enum): + """ + :param LINEAR: Linear interpolation + :param CUBIC: Cubic spline interpolation + """ + + # Note: these match scipy.interpolate.interp1d kwarg "kind" + LINEAR = "linear" + CUBIC = "cubic" + + +def _interpolate_delay( + x: torch.Tensor, delay: float, method: _DelayInterpolationMethod +) -> np.ndarray: + """ + NOTE: This breaks the gradient tape! + """ + t_in = np.arange(len(x)) + n_out = len(x) - int(np.ceil(np.abs(delay))) + if delay > 0: + t_out = np.arange(n_out) + delay + elif delay < 0: + t_out = np.arange(len(x) - n_out, len(x)) - np.abs(delay) + + return torch.Tensor( + interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out) + ) + + class Dataset(AbstractDataset, InitializableFromConfig): """ Take a pair of matched audio files and serve input + output pairs. - + No conditioning parameters associated w/ the data. """ @@ -135,11 +165,14 @@ class Dataset(AbstractDataset, InitializableFromConfig): ny: Optional[int], start: Optional[int] = None, stop: Optional[int] = None, - delay: Optional[int] = None, + delay: Optional[Union[int, float]] = None, + delay_interpolation_method: Union[ + str, _DelayInterpolationMethod + ] = _DelayInterpolationMethod.CUBIC, y_scale: float = 1.0, x_path: Optional[Union[str, Path]] = None, y_path: Optional[Union[str, Path]] = None, - input_gain: float=0.0 + input_gain: float = 0.0, ): """ :param x: The input signal. A 1D array. @@ -148,31 +181,29 @@ class Dataset(AbstractDataset, InitializableFromConfig): for a ConvNet, this would be the receptive field. :param ny: How many samples to provide as the output array for a single "datum". 
It's usually more computationally-efficient to provide a larger `ny` than 1 - so that the forward pass can process more audio all at once. However, this - shouldn't be too large or else you won't be able to provide a large batch - size (where each input-output pair could be something substantially + so that the forward pass can process more audio all at once. However, this + shouldn't be too large or else you won't be able to provide a large batch + size (where each input-output pair could be something substantially different and improve batch diversity). :param start: In samples; clip x and y up to this point. :param stop: In samples; clip x and y past this point. - :param y_scale: Multiplies the output signal by a factor (e.g. if the data are + :param delay: In samples. Positive means we get rid of the start of x, end of y + (i.e. we are correcting for an alignment error in which y is delayed behind + x). If a non-integer delay is provided, then y is interpolated, with + the extra sample removed. + :param y_scale: Multiplies the output signal by a factor (e.g. if the data are too quiet). - :param delay: In samples. Positive means we get rid of the start of x, end of y. - :param input_gain: In dB. If the input signal wasn't fed to the amp at unity - gain, you can indicate the gain here. The data set will multipy the raw - audio file by the specified gain so that the true input signal amplitude + :param input_gain: In dB. If the input signal wasn't fed to the amp at unity + gain, you can indicate the gain here. The data set will multipy the raw + audio file by the specified gain so that the true input signal amplitude experienced by the signal chain will be provided as input to the model. If - you are using a reamping setup, you can estimate this by reamping a - completely dry signal (i.e. connecting the interface output directly back + you are using a reamping setup, you can estimate this by reamping a + completely dry signal (i.e. 
connecting the interface output directly back into the input with which the guitar was originally recorded.) """ x, y = [z[start:stop] for z in (x, y)] - if delay is not None: - if delay > 0: - x = x[:-delay] - y = y[delay:] - elif delay < 0: - x = x[-delay:] - y = y[:delay] + if delay is not None and delay != 0: + x, y = self._apply_delay(x, y, delay, delay_interpolation_method) x_scale = 10.0 ** (input_gain / 20.0) x = x * x_scale y = y * y_scale @@ -186,7 +217,7 @@ class Dataset(AbstractDataset, InitializableFromConfig): def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """ - :return: + :return: Input (NX+NY-1,) Output (NY,) """ @@ -240,12 +271,55 @@ class Dataset(AbstractDataset, InitializableFromConfig): "y_path": config["y_path"], } + @classmethod + def _apply_delay( + cls, + x: torch.Tensor, + y: torch.Tensor, + delay: Union[int, float], + method: _DelayInterpolationMethod, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if isinstance(delay, int): + return cls._apply_delay_int(x, y, delay) + elif isinstance(delay, float): + return cls._apply_delay_float(x, y, delay, method) + + @classmethod + def _apply_delay_int( + cls, x: torch.Tensor, y: torch.Tensor, delay: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + if delay > 0: + x = x[:-delay] + y = y[delay:] + elif delay < 0: + x = x[-delay:] + y = y[:delay] + return x, y + + @classmethod + def _apply_delay_float( + cls, + x: torch.Tensor, + y: torch.Tensor, + delay: float, + method: _DelayInterpolationMethod, + ) -> Tuple[torch.Tensor, torch.Tensor]: + n_out = len(y) - int(np.ceil(np.abs(delay))) + if delay > 0: + x = x[:n_out] + elif delay < 0: + x = x[-n_out:] + y = _interpolate_delay(y, delay, method) + return x, y + def _validate_inputs(self, x, y, nx, ny): assert x.ndim == 1 assert y.ndim == 1 assert len(x) == len(y) if nx > len(x): - raise RuntimeError(f"Input of length {len(x)}, but receptive field is {nx}.") + raise RuntimeError( + f"Input of length {len(x)}, but receptive field is 
{nx}." + ) if ny is not None: assert ny <= len(y) - nx + 1 if torch.abs(y).max() >= 1.0: @@ -304,7 +378,7 @@ class ParametricDataset(Dataset): def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - :return: + :return: Parameter values (D,) Input (NX+NY-1,) Output (NY,) @@ -379,7 +453,7 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig): j += 1 lookup[i] = (j, offset) offset += 1 - assert j == len(self.datasets)-1 + assert j == len(self.datasets) - 1 assert offset == len(self.datasets[-1]) return lookup diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py @@ -3,14 +3,81 @@ # Author: Steven Atkinson ([email protected]) import math +from enum import Enum +from typing import Sequence +import numpy as np import pytest import torch from nam import data +class _XYMethod(Enum): + ARANGE = "arange" + RAND = "rand" + STEP = "step" + + class TestDataset(object): + """ + Assertions about nam.data.Dataset + """ + + def test_apply_delay_zero(self): + """ + Assert proper function of Dataset._apply_delay() when zero delay is given, i.e. + no change. 
+ """ + x, y = self._create_xy() + x_out, y_out = data.Dataset._apply_delay( + x, y, 0, data._DelayInterpolationMethod.CUBIC + ) + assert torch.all(x == x_out) + assert torch.all(y == y_out) + + @pytest.mark.parametrize("method", (data._DelayInterpolationMethod)) + def test_apply_delay_float_negative(self, method): + n = 7 + delay = -2.5 + x_out, y_out = self._t_apply_delay_float(n, delay, method) + + assert torch.all(x_out == torch.Tensor([3, 4, 5, 6])) + assert torch.all(y_out == torch.Tensor([0.5, 1.5, 2.5, 3.5])) + + @pytest.mark.parametrize("method", (data._DelayInterpolationMethod)) + def test_apply_delay_float_positive(self, method): + n = 7 + delay = 2.5 + x_out, y_out = self._t_apply_delay_float(n, delay, method) + + assert torch.all(x_out == torch.Tensor([0, 1, 2, 3])) + assert torch.all(y_out == torch.Tensor([2.5, 3.5, 4.5, 5.5])) + + def test_apply_delay_int_negative(self): + """ + Assert proper function of Dataset._apply_delay() when a positive integer delay + is given. + """ + n = 7 + delay = -3 + x_out, y_out = self._t_apply_delay_int(n, delay) + + assert torch.all(x_out == torch.Tensor([3, 4, 5, 6])) + assert torch.all(y_out == torch.Tensor([0, 1, 2, 3])) + + def test_apply_delay_int_positive(self): + """ + Assert proper function of Dataset._apply_delay() when a positive integer delay + is given. 
+ """ + n = 7 + delay = 3 + x_out, y_out = self._t_apply_delay_int(n, delay) + + assert torch.all(x_out == torch.Tensor([0, 1, 2, 3])) + assert torch.all(y_out == torch.Tensor([3, 4, 5, 6])) + def test_init(self): x, y = self._create_xy() data.Dataset(x, y, 3, None) @@ -22,7 +89,6 @@ class TestDataset(object): x, y = self._create_xy() data.Dataset(x, y, 3, None, delay=0) - def test_input_gain(self): """ Checks correctness of input gain parameter @@ -40,9 +106,54 @@ class TestDataset(object): sample_x2 = d2[0][0] assert torch.allclose(sample_x1 * x_scale, sample_x2) - def _create_xy(self): - return 0.99 * (2.0 * torch.rand((2, 7)) - 1.0) # Don't clip + def _create_xy( + self, + n: int = 7, + method: _XYMethod = _XYMethod.RAND, + must_be_in_valid_range: bool = True, + ) -> Sequence[torch.Tensor]: + """ + :return: (n,), (n,) + """ + if method == _XYMethod.ARANGE: + # note: this isn't "valid" data in the sense that it's beyond (-1, 1). + # But it is useful for the delay code. + assert not must_be_in_valid_range + return torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1)) + elif method == _XYMethod.RAND: + return 0.99 * (2.0 * torch.rand((2, n)) - 1.0) # Don't clip + elif method == _XYMethod.STEP: + return torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1)) + + def _t_apply_delay_float( + self, n: int, delay: int, method: data._DelayInterpolationMethod + ): + x, y = self._create_xy( + n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False + ) + + x_out, y_out = data.Dataset._apply_delay(x, y, delay, method) + # 7, +/-2.5 -> 4 + n_out = n - int(np.ceil(np.abs(delay))) + assert len(x_out) == n_out + assert len(y_out) == n_out + + return x_out, y_out + + def _t_apply_delay_int(self, n: int, delay: int): + x, y = self._create_xy( + n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False + ) + + x_out, y_out = data.Dataset._apply_delay( + x, y, delay, data._DelayInterpolationMethod.CUBIC + ) + n_out = n - np.abs(delay) + assert len(x_out) == 
n_out + assert len(y_out) == n_out + + return x_out, y_out if __name__ == "__main__": - pytest.main() -\ No newline at end of file + pytest.main()