commit 2959e5303e60d8d2ca1924ac630cf818c3c1dd0c
parent 7808e90e2a8527adfae42bca4942c3399f2d8f46
Author: Steven Atkinson <[email protected]>
Date: Tue, 6 Feb 2024 22:40:12 -0800
[BREAKING] Get rid of `REQUIRED_RATE` (#375)
Get rid of `REQUIRED_RATE`
Diffstat:
7 files changed, 56 insertions(+), 55 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -22,8 +22,6 @@ from ._core import InitializableFromConfig
logger = logging.getLogger(__name__)
-REQUIRED_RATE = 48_000 # FIXME not "required" anymore!
-_DEFAULT_RATE = REQUIRED_RATE # There we go :)
_REQUIRED_CHANNELS = 1 # Mono
@@ -242,8 +240,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
input_gain: float = 0.0,
- sample_rate: Optional[int] = None,
- rate: Optional[int] = None,
+ sample_rate: Optional[float] = None,
require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
@@ -283,16 +280,13 @@ class Dataset(AbstractDataset, InitializableFromConfig):
completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
:param sample_rate: Sample rate for the data
- :param rate: Sample rate for the data (deprecated)
:param require_input_pre_silence: If provided, require that this much time (in
seconds) preceding the start of the data set (`start`) have a silent input.
If it's not, then raise an exception because the output due to it will leak
into the data set that we're trying to use. If `None`, don't assert.
"""
self._validate_x_y(x, y)
- self._sample_rate = self._validate_sample_rate(
- sample_rate, rate, default=_DEFAULT_RATE
- )
+ self._sample_rate = sample_rate
start, stop = self._validate_start_stop(
x,
y,
@@ -302,7 +296,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
stop_samples,
start_seconds,
stop_seconds,
- self._sample_rate,
+ self.sample_rate,
)
if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
delay_interpolation_method = _DelayInterpolationMethod(
@@ -310,7 +304,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
)
if require_input_pre_silence is not None:
self._validate_preceding_silence(
- x, start, int(require_input_pre_silence * self._sample_rate)
+ x, start, require_input_pre_silence, self.sample_rate
)
x, y = [z[start:stop] for z in (x, y)]
if delay is not None and delay != 0:
@@ -377,9 +371,12 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
config = deepcopy(config)
- sample_rate = cls._validate_sample_rate(
- config.pop("sample_rate", None), config.pop("rate", None)
- )
+ if "rate" in config:
+ raise ValueError(
+ "use of `rate` was deprecated in version 0.8. Use `sample_rate` "
+ "instead."
+ )
+ sample_rate = config.pop("sample_rate", None)
x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
sample_rate = x_wavinfo.rate
try:
@@ -470,25 +467,6 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return x, y
@classmethod
- def _validate_sample_rate(
- cls, sample_rate: Optional[float], rate: Optional[int], default=None
- ) -> float:
- if sample_rate is None and rate is None: # Default value
- return default
- if rate is not None:
- if sample_rate is not None:
- raise ValueError(
- "Provided both sample_rate and rate. Provide only sample_rate!"
- )
- else:
- logger.warning(
- "Use of 'rate' is deprecated and will be removed. Use sample_rate instead"
- )
- return float(rate)
- else:
- return sample_rate
-
- @classmethod
def _validate_start_stop(
cls,
x: torch.Tensor,
@@ -632,19 +610,27 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def _validate_preceding_silence(
- cls, x: torch.Tensor, start: Optional[int], silent_samples: int
+ cls, x: torch.Tensor, start: Optional[int], silent_seconds: float, sample_rate: Optional[float]
):
"""
Make sure that the input is silent before the starting index.
If it's not, then the output from that non-silent input will leak into the data
set and couldn't be predicted!
+ This assumes that silence is indeed required. If it's not, then don't call this!
+
See: Issue #252
:param x: Input
:param start: Where the data starts
:param silent_samples: How many are expected to be silent
"""
+ if sample_rate is None:
+ raise ValueError(
+ f"Pre-silence was required for {silent_seconds} seconds, but no sample "
+ "rate was provided!"
+ )
+ silent_samples = int(silent_seconds * sample_rate)
if start is None:
return
raw_check_start = start - silent_samples
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -17,7 +17,7 @@ import torch
import torch.nn as nn
from .._core import InitializableFromConfig
-from ..data import REQUIRED_RATE, wav_to_tensor
+from ..data import wav_to_tensor
from ._exportable import Exportable
@@ -133,7 +133,11 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
args = self._export_input_output_args()
- rate = REQUIRED_RATE
+ rate = self.sample_rate
+ if rate is None:
+ raise RuntimeError(
+ "Cannot export model's input and output without a sample rate."
+ )
x = torch.cat(
[
torch.zeros((rate,)),
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -17,7 +17,7 @@ import torch.nn.functional as F
from .. import __version__
-from ..data import REQUIRED_RATE, wav_to_tensor
+from ..data import wav_to_tensor
from ._activations import get_activation
from ._base import BaseNet
from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
@@ -217,7 +217,11 @@ class ConvNet(BaseNet):
"""
:return: (L,)
"""
- rate = REQUIRED_RATE
+ rate = self.sample_rate
+ if rate is None:
+ raise RuntimeError(
+ "Cannot export model's input and output without a sample rate."
+ )
return torch.cat(
[
torch.zeros((rate,)),
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -22,7 +22,7 @@ from pydantic import BaseModel
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
-from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
+from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.losses import esr
from ..util import filter_warnings
@@ -30,6 +30,9 @@ from ._version import Version
__all__ = ["train"]
+# Training using the simplified trainers in NAM is done at a 48 kHz sample rate.
+STANDARD_SAMPLE_RATE = 48_000.0
+
class Architecture(Enum):
STANDARD = "standard"
@@ -222,7 +225,7 @@ class _DataInfo(BaseModel):
"""
major_version: int
- rate: Optional[int]
+ rate: Optional[float]
t_blips: int
first_blips_start: int
t_validate: int
@@ -234,7 +237,7 @@ class _DataInfo(BaseModel):
_V1_DATA_INFO = _DataInfo(
major_version=1,
- rate=REQUIRED_RATE,
+ rate=STANDARD_SAMPLE_RATE,
t_blips=48_000,
first_blips_start=0,
t_validate=432_000,
@@ -254,7 +257,7 @@ _V1_DATA_INFO = _DataInfo(
# (3:09-3:11) Blips at 3:09.5 and 3:10.5
_V2_DATA_INFO = _DataInfo(
major_version=2,
- rate=REQUIRED_RATE,
+ rate=STANDARD_SAMPLE_RATE,
t_blips=96_000,
first_blips_start=0,
t_validate=432_000,
@@ -274,7 +277,7 @@ _V2_DATA_INFO = _DataInfo(
# (3:01-3:10) Validation 2
_V3_DATA_INFO = _DataInfo(
major_version=3,
- rate=REQUIRED_RATE,
+ rate=STANDARD_SAMPLE_RATE,
t_blips=96_000,
first_blips_start=480_000,
t_validate=432_000,
diff --git a/tests/test_bin/test_train/test_main.py b/tests/test_bin/test_train/test_main.py
@@ -13,7 +13,7 @@ import numpy as np
import pytest
import torch
-from nam.data import REQUIRED_RATE, np_to_wav
+from nam.data import np_to_wav
_BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path(
"bin", "train", "main.py"
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -15,7 +15,8 @@ import torch
from nam import data
-_sample_rates = (44_100, 48_000, 88_200, 96_000)
+_SAMPLE_RATES = (44_100.0, 48_000.0, 88_200.0, 96_000.0)
+_DEFAULT_SAMPLE_RATE = 48_000.0
class _XYMethod(Enum):
@@ -85,11 +86,11 @@ class TestDataset(object):
def test_init(self):
x, y = self._create_xy()
- data.Dataset(x, y, 3, None)
+ data.Dataset(x, y, 3, None, sample_rate=_DEFAULT_SAMPLE_RATE)
def test_init_sample_rate(self):
x, y = self._create_xy()
- sample_rate = 48_000.0
+ sample_rate = _DEFAULT_SAMPLE_RATE
d = data.Dataset(x, y, 3, None, sample_rate=sample_rate)
assert hasattr(d, "sample_rate")
assert isinstance(d.sample_rate, float)
@@ -100,7 +101,7 @@ class TestDataset(object):
Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
"""
x, y = self._create_xy()
- data.Dataset(x, y, 3, None, delay=0)
+ data.Dataset(x, y, 3, None, delay=0, sample_rate=_DEFAULT_SAMPLE_RATE)
def test_input_gain(self):
"""
@@ -112,14 +113,16 @@ class TestDataset(object):
nx = 3
ny = None
args = (x, y, nx, ny)
- d1 = data.Dataset(*args)
- d2 = data.Dataset(*args, input_gain=input_gain)
+ d1 = data.Dataset(*args, sample_rate=_DEFAULT_SAMPLE_RATE)
+ d2 = data.Dataset(
+ *args, sample_rate=_DEFAULT_SAMPLE_RATE, input_gain=input_gain
+ )
sample_x1 = d1[0][0]
sample_x2 = d2[0][0]
assert torch.allclose(sample_x1 * x_scale, sample_x2)
- @pytest.mark.parametrize("sample_rate", _sample_rates)
+ @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
def test_sample_rates(self, sample_rate: int):
"""
Test that datasets with various sample rates can be made
@@ -155,7 +158,7 @@ class TestDataset(object):
"""
def init():
- data.Dataset(x, y, nx, ny, start=start)
+ data.Dataset(x, y, nx, ny, start=start, sample_rate=_DEFAULT_SAMPLE_RATE)
nx = 1
ny = None
@@ -239,7 +242,7 @@ class TestDataset(object):
)
def test_validate_stop(self, n: int, stop: int, valid: bool):
def init():
- data.Dataset(x, y, nx, ny, stop=stop)
+ data.Dataset(x, y, nx, ny, stop=stop, sample_rate=_DEFAULT_SAMPLE_RATE)
nx = 1
ny = None
@@ -257,7 +260,7 @@ class TestDataset(object):
)
def test_validate_x_y(self, lenx: int, leny: int, valid: bool):
def init():
- data.Dataset(x, y, nx, ny)
+ data.Dataset(x, y, nx, ny, sample_rate=_DEFAULT_SAMPLE_RATE)
x, y = self._create_xy()
assert len(x) >= lenx, "Invalid test!"
@@ -345,7 +348,7 @@ class TestWav(object):
# Check if the two arrays are equal
assert y == pytest.approx(x, abs=self.tolerance)
- @pytest.mark.parametrize("sample_rate", _sample_rates)
+ @pytest.mark.parametrize("sample_rate", _SAMPLE_RATES)
def test_np_to_wav_to_np_sample_rates(self, sample_rate: int):
with TemporaryDirectory() as tmpdir:
# Create random numpy array
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -180,7 +180,8 @@ def _make_t_validation_dataset_class(
Dataset._validate_preceding_silence(
x,
data_info.validation_start,
- int(_DEFAULT_REQUIRE_INPUT_PRE_SILENCE * data_info.rate),
+ _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
+ data_info.rate,
)
return C