commit 039e858eae4e5daf129bc2c3c3c540e07d36cc01
parent 088556bbf713167135b0ab49237548a243c3a940
Author: Steven Atkinson <[email protected]>
Date: Sat, 11 Feb 2023 17:10:03 -0800
Non-integer delays (#84)
* Refactor, int delay correct
* Non-integer delays
Diffstat:
M | nam/data.py | | | 122 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------- |
M | tests/test_nam/test_data.py | | | 120 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---- |
2 files changed, 213 insertions(+), 29 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -13,6 +13,7 @@ from typing import Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import wavio
+from scipy.interpolate import interp1d
from torch.utils.data import Dataset as _Dataset
from tqdm import tqdm
@@ -120,10 +121,39 @@ class AbstractDataset(_Dataset, abc.ABC):
pass
+class _DelayInterpolationMethod(Enum):
+ """
+ :param LINEAR: Linear interpolation
+ :param CUBIC: Cubic spline interpolation
+ """
+
+ # Note: these match scipy.interpolate.interp1d kwarg "kind"
+ LINEAR = "linear"
+ CUBIC = "cubic"
+
+
+def _interpolate_delay(
+ x: torch.Tensor, delay: float, method: _DelayInterpolationMethod
+) -> np.ndarray:
+ """
+ NOTE: This breaks the gradient tape!
+ """
+ t_in = np.arange(len(x))
+ n_out = len(x) - int(np.ceil(np.abs(delay)))
+ if delay > 0:
+ t_out = np.arange(n_out) + delay
+ elif delay < 0:
+ t_out = np.arange(len(x) - n_out, len(x)) - np.abs(delay)
+
+ return torch.Tensor(
+ interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
+ )
+
+
class Dataset(AbstractDataset, InitializableFromConfig):
"""
Take a pair of matched audio files and serve input + output pairs.
-
+
No conditioning parameters associated w/ the data.
"""
@@ -135,11 +165,14 @@ class Dataset(AbstractDataset, InitializableFromConfig):
ny: Optional[int],
start: Optional[int] = None,
stop: Optional[int] = None,
- delay: Optional[int] = None,
+ delay: Optional[Union[int, float]] = None,
+ delay_interpolation_method: Union[
+ str, _DelayInterpolationMethod
+ ] = _DelayInterpolationMethod.CUBIC,
y_scale: float = 1.0,
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
- input_gain: float=0.0
+ input_gain: float = 0.0,
):
"""
:param x: The input signal. A 1D array.
@@ -148,31 +181,29 @@ class Dataset(AbstractDataset, InitializableFromConfig):
for a ConvNet, this would be the receptive field.
:param ny: How many samples to provide as the output array for a single "datum".
It's usually more computationally-efficient to provide a larger `ny` than 1
- so that the forward pass can process more audio all at once. However, this
- shouldn't be too large or else you won't be able to provide a large batch
- size (where each input-output pair could be something substantially
+ so that the forward pass can process more audio all at once. However, this
+ shouldn't be too large or else you won't be able to provide a large batch
+ size (where each input-output pair could be something substantially
different and improve batch diversity).
:param start: In samples; clip x and y up to this point.
:param stop: In samples; clip x and y past this point.
- :param y_scale: Multiplies the output signal by a factor (e.g. if the data are
+ :param delay: In samples. Positive means we get rid of the start of x, end of y
+ (i.e. we are correcting for an alignment error in which y is delayed behind
+ x). If a non-integer delay is provided, then y is interpolated, with
+ the extra sample removed.
+ :param y_scale: Multiplies the output signal by a factor (e.g. if the data are
too quiet).
- :param delay: In samples. Positive means we get rid of the start of x, end of y.
- :param input_gain: In dB. If the input signal wasn't fed to the amp at unity
- gain, you can indicate the gain here. The data set will multipy the raw
- audio file by the specified gain so that the true input signal amplitude
+ :param input_gain: In dB. If the input signal wasn't fed to the amp at unity
+ gain, you can indicate the gain here. The data set will multipy the raw
+ audio file by the specified gain so that the true input signal amplitude
experienced by the signal chain will be provided as input to the model. If
- you are using a reamping setup, you can estimate this by reamping a
- completely dry signal (i.e. connecting the interface output directly back
+ you are using a reamping setup, you can estimate this by reamping a
+ completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
"""
x, y = [z[start:stop] for z in (x, y)]
- if delay is not None:
- if delay > 0:
- x = x[:-delay]
- y = y[delay:]
- elif delay < 0:
- x = x[-delay:]
- y = y[:delay]
+ if delay is not None and delay != 0:
+ x, y = self._apply_delay(x, y, delay, delay_interpolation_method)
x_scale = 10.0 ** (input_gain / 20.0)
x = x * x_scale
y = y * y_scale
@@ -186,7 +217,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
- :return:
+ :return:
Input (NX+NY-1,)
Output (NY,)
"""
@@ -240,12 +271,55 @@ class Dataset(AbstractDataset, InitializableFromConfig):
"y_path": config["y_path"],
}
+ @classmethod
+ def _apply_delay(
+ cls,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ delay: Union[int, float],
+ method: _DelayInterpolationMethod,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(delay, int):
+ return cls._apply_delay_int(x, y, delay)
+ elif isinstance(delay, float):
+ return cls._apply_delay_float(x, y, delay, method)
+
+ @classmethod
+ def _apply_delay_int(
+ cls, x: torch.Tensor, y: torch.Tensor, delay: int
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if delay > 0:
+ x = x[:-delay]
+ y = y[delay:]
+ elif delay < 0:
+ x = x[-delay:]
+ y = y[:delay]
+ return x, y
+
+ @classmethod
+ def _apply_delay_float(
+ cls,
+ x: torch.Tensor,
+ y: torch.Tensor,
+ delay: float,
+ method: _DelayInterpolationMethod,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ n_out = len(y) - int(np.ceil(np.abs(delay)))
+ if delay > 0:
+ x = x[:n_out]
+ elif delay < 0:
+ x = x[-n_out:]
+ y = _interpolate_delay(y, delay, method)
+ return x, y
+
def _validate_inputs(self, x, y, nx, ny):
assert x.ndim == 1
assert y.ndim == 1
assert len(x) == len(y)
if nx > len(x):
- raise RuntimeError(f"Input of length {len(x)}, but receptive field is {nx}.")
+ raise RuntimeError(
+ f"Input of length {len(x)}, but receptive field is {nx}."
+ )
if ny is not None:
assert ny <= len(y) - nx + 1
if torch.abs(y).max() >= 1.0:
@@ -304,7 +378,7 @@ class ParametricDataset(Dataset):
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
- :return:
+ :return:
Parameter values (D,)
Input (NX+NY-1,)
Output (NY,)
@@ -379,7 +453,7 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
j += 1
lookup[i] = (j, offset)
offset += 1
- assert j == len(self.datasets)-1
+ assert j == len(self.datasets) - 1
assert offset == len(self.datasets[-1])
return lookup
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -3,14 +3,81 @@
# Author: Steven Atkinson ([email protected])
import math
+from enum import Enum
+from typing import Sequence
+import numpy as np
import pytest
import torch
from nam import data
+class _XYMethod(Enum):
+ ARANGE = "arange"
+ RAND = "rand"
+ STEP = "step"
+
+
class TestDataset(object):
+ """
+ Assertions about nam.data.Dataset
+ """
+
+ def test_apply_delay_zero(self):
+ """
+ Assert proper function of Dataset._apply_delay() when zero delay is given, i.e.
+ no change.
+ """
+ x, y = self._create_xy()
+ x_out, y_out = data.Dataset._apply_delay(
+ x, y, 0, data._DelayInterpolationMethod.CUBIC
+ )
+ assert torch.all(x == x_out)
+ assert torch.all(y == y_out)
+
+ @pytest.mark.parametrize("method", (data._DelayInterpolationMethod))
+ def test_apply_delay_float_negative(self, method):
+ n = 7
+ delay = -2.5
+ x_out, y_out = self._t_apply_delay_float(n, delay, method)
+
+ assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
+ assert torch.all(y_out == torch.Tensor([0.5, 1.5, 2.5, 3.5]))
+
+ @pytest.mark.parametrize("method", (data._DelayInterpolationMethod))
+ def test_apply_delay_float_positive(self, method):
+ n = 7
+ delay = 2.5
+ x_out, y_out = self._t_apply_delay_float(n, delay, method)
+
+ assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
+ assert torch.all(y_out == torch.Tensor([2.5, 3.5, 4.5, 5.5]))
+
+ def test_apply_delay_int_negative(self):
+ """
+ Assert proper function of Dataset._apply_delay() when a positive integer delay
+ is given.
+ """
+ n = 7
+ delay = -3
+ x_out, y_out = self._t_apply_delay_int(n, delay)
+
+ assert torch.all(x_out == torch.Tensor([3, 4, 5, 6]))
+ assert torch.all(y_out == torch.Tensor([0, 1, 2, 3]))
+
+ def test_apply_delay_int_positive(self):
+ """
+ Assert proper function of Dataset._apply_delay() when a positive integer delay
+ is given.
+ """
+ n = 7
+ delay = 3
+ x_out, y_out = self._t_apply_delay_int(n, delay)
+
+ assert torch.all(x_out == torch.Tensor([0, 1, 2, 3]))
+ assert torch.all(y_out == torch.Tensor([3, 4, 5, 6]))
+
def test_init(self):
x, y = self._create_xy()
data.Dataset(x, y, 3, None)
@@ -22,7 +89,6 @@ class TestDataset(object):
x, y = self._create_xy()
data.Dataset(x, y, 3, None, delay=0)
-
def test_input_gain(self):
"""
Checks correctness of input gain parameter
@@ -40,9 +106,54 @@ class TestDataset(object):
sample_x2 = d2[0][0]
assert torch.allclose(sample_x1 * x_scale, sample_x2)
- def _create_xy(self):
- return 0.99 * (2.0 * torch.rand((2, 7)) - 1.0) # Don't clip
+ def _create_xy(
+ self,
+ n: int = 7,
+ method: _XYMethod = _XYMethod.RAND,
+ must_be_in_valid_range: bool = True,
+ ) -> Sequence[torch.Tensor]:
+ """
+ :return: (n,), (n,)
+ """
+ if method == _XYMethod.ARANGE:
+ # note: this isn't "valid" data in the sense that it's beyond (-1, 1).
+ # But it is useful for the delay code.
+ assert not must_be_in_valid_range
+ return torch.tile(torch.arange(n, dtype=torch.float)[None, :], (2, 1))
+ elif method == _XYMethod.RAND:
+ return 0.99 * (2.0 * torch.rand((2, n)) - 1.0) # Don't clip
+ elif method == _XYMethod.STEP:
+ return torch.tile((torch.linspace(0.0, 1.0, n) > 0.5)[None, :], (2, 1))
+
+ def _t_apply_delay_float(
+ self, n: int, delay: int, method: data._DelayInterpolationMethod
+ ):
+ x, y = self._create_xy(
+ n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
+ )
+
+ x_out, y_out = data.Dataset._apply_delay(x, y, delay, method)
+ # 7, +/-2.5 -> 4
+ n_out = n - int(np.ceil(np.abs(delay)))
+ assert len(x_out) == n_out
+ assert len(y_out) == n_out
+
+ return x_out, y_out
+
+ def _t_apply_delay_int(self, n: int, delay: int):
+ x, y = self._create_xy(
+ n=n, method=_XYMethod.ARANGE, must_be_in_valid_range=False
+ )
+
+ x_out, y_out = data.Dataset._apply_delay(
+ x, y, delay, data._DelayInterpolationMethod.CUBIC
+ )
+ n_out = n - np.abs(delay)
+ assert len(x_out) == n_out
+ assert len(y_out) == n_out
+
+ return x_out, y_out
if __name__ == "__main__":
- pytest.main()
-\ No newline at end of file
+ pytest.main()