neural-amp-modeler

Neural network emulator for guitar amplifiers

data.py


# File: data.py
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson ([email protected])

"""
Functions and classes for working with audio data with NAM
"""

import abc as _abc
import logging as _logging
from collections import namedtuple as _namedtuple
from copy import deepcopy as _deepcopy
from dataclasses import dataclass as _dataclass
from enum import Enum as _Enum
from pathlib import Path as _Path
from typing import (
    Any as _Any,
    Callable as _Callable,
    Dict as _Dict,
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Union as _Union,
)

import numpy as _np
import torch as _torch
import wavio as _wavio
from scipy.interpolate import interp1d as _interp1d
from torch.utils.data import Dataset as _Dataset
from tqdm import tqdm as _tqdm

from ._core import InitializableFromConfig as _InitializableFromConfig

logger = _logging.getLogger(__name__)

_REQUIRED_CHANNELS = 1  # Mono


class Split(_Enum):
    TRAIN = "train"
    VALIDATION = "validation"


@_dataclass
class WavInfo:
    sampwidth: int
    rate: int


class DataError(Exception):
    """
    Parent class for all special exceptions raised by NAM data sets
    """

    pass


class AudioShapeMismatchError(ValueError, DataError):
    """
    Exception raised when the shapes (number of samples, number of channels) of two
    audio files don't match but were supposed to.
    """

    def __init__(self, shape_expected, shape_actual, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._shape_expected = shape_expected
        self._shape_actual = shape_actual

    @property
    def shape_expected(self):
        return self._shape_expected

    @property
    def shape_actual(self):
        return self._shape_actual


def wav_to_np(
    filename: _Union[str, _Path],
    rate: _Optional[int] = None,
    require_match: _Optional[_Union[str, _Path]] = None,
    required_shape: _Optional[_Tuple[int, ...]] = None,
    required_wavinfo: _Optional[WavInfo] = None,
    preroll: _Optional[int] = None,
    info: bool = False,
) -> _Union[_np.ndarray, _Tuple[_np.ndarray, WavInfo]]:
    """
    :param filename: Where to load from
    :param rate: Expected sample rate. `None` allows for anything.
    :param require_match: If not `None`, assert that the data you get matches the shape
        and other characteristics of another audio file at the provided location
    :param required_shape: If not `None`, assert that the audio loaded is of shape
        `(num_samples, num_channels)`.
    :param required_wavinfo: If not `None`, assert that the WAV info of the loaded audio
        matches that provided.
    :param preroll: Drop this many samples off the front
    :param info: If `True`, also return the WAV info of this file.
    """
    x_wav = _wavio.read(str(filename))
    assert (
        x_wav.data.shape[1] == _REQUIRED_CHANNELS
    ), f"Expected mono audio, but found {x_wav.data.shape[1]} channels"
    if rate is not None and x_wav.rate != rate:
        raise RuntimeError(
            f"Explicitly expected sample rate of {rate}, but found {x_wav.rate} in "
            f"file {filename}!"
        )

    if require_match is not None:
        assert required_shape is None
        assert required_wavinfo is None
        y_wav = _wavio.read(str(require_match))
        required_shape = y_wav.data.shape
        required_wavinfo = WavInfo(y_wav.sampwidth, y_wav.rate)
    if required_wavinfo is not None:
        if x_wav.rate != required_wavinfo.rate:
            raise ValueError(
                f"Mismatched rates {x_wav.rate} versus {required_wavinfo.rate}"
            )
    arr_premono = x_wav.data[preroll:] / (2.0 ** (8 * x_wav.sampwidth - 1))
    if required_shape is not None:
        if arr_premono.shape != required_shape:
            raise AudioShapeMismatchError(
                required_shape,  # Expected
                arr_premono.shape,  # Actual
                f"Mismatched shapes. Expected {required_shape}, but this is "
                f"{arr_premono.shape}!",
            )
        # sampwidth fine--we're just casting to 32-bit float anyways
    arr = arr_premono[:, 0]
    return arr if not info else (arr, WavInfo(x_wav.sampwidth, x_wav.rate))

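# Example: loading a hypothetical 48 kHz, 24-bit mono file (a sketch):
#
#     x, wavinfo = wav_to_np("input.wav", rate=48000, info=True)
#     x.shape   # (num_samples,); floats scaled to [-1, 1)
#     wavinfo   # WavInfo(sampwidth=3, rate=48000)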

def wav_to_tensor(
    *args, info: bool = False, **kwargs
) -> _Union[_torch.Tensor, _Tuple[_torch.Tensor, WavInfo]]:
    out = wav_to_np(*args, info=info, **kwargs)
    if info:
        arr, info = out
        return _torch.Tensor(arr), info
    else:
        arr = out
        return _torch.Tensor(arr)


def tensor_to_wav(x: _torch.Tensor, *args, **kwargs):
    np_to_wav(x.detach().cpu().numpy(), *args, **kwargs)


def np_to_wav(
    x: _np.ndarray,
    filename: _Union[str, _Path],
    rate: int = 48_000,
    sampwidth: int = 3,
    scale=None,
    **kwargs,
):
    if _wavio.__version__ <= "0.0.4" and scale is None:
        scale = "none"
    _wavio.write(
        str(filename),
        (_np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(_np.int32),
        rate,
        scale=scale,
        sampwidth=sampwidth,
        **kwargs,
    )

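# Example: round-tripping a signal through a hypothetical "out.wav" (a sketch).
# With sampwidth=3 the file is written as 24-bit PCM, so the reload agrees with
# the original to within one quantization step (~2 ** -23):
#
#     t = _np.arange(48000) / 48000
#     y = 0.5 * _np.sin(2 * _np.pi * 440 * t)
#     np_to_wav(y, "out.wav", rate=48000, sampwidth=3)
#     y_again = wav_to_np("out.wav", rate=48000)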

class AbstractDataset(_Dataset, _abc.ABC):
    @_abc.abstractmethod
    def __getitem__(self, idx: int):
        """
        Get the input and output audio segments for training / evaluation.

        :return: The input and output audio as a pair of tensors, `(x, y)`.
        """
        pass


class _DelayInterpolationMethod(_Enum):
    """
    :param LINEAR: Linear interpolation
    :param CUBIC: Cubic spline interpolation
    """

    # Note: these match scipy.interpolate.interp1d kwarg "kind"
    LINEAR = "linear"
    CUBIC = "cubic"


def _interpolate_delay(
    x: _torch.Tensor, delay: float, method: _DelayInterpolationMethod
) -> _torch.Tensor:
    """
    NOTE: This breaks the gradient tape!
    """
    if delay == 0.0:
        return x
    t_in = _np.arange(len(x))
    n_out = len(x) - int(_np.ceil(_np.abs(delay)))
    if delay > 0:
        t_out = _np.arange(n_out) + delay
    elif delay < 0:
        t_out = _np.arange(len(x) - n_out, len(x)) - _np.abs(delay)

    return _torch.Tensor(
        _interp1d(t_in, x.detach().cpu().numpy(), kind=method.value)(t_out)
    )

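# Example: a fractional delay of +0.5 samples on a ramp (a sketch). One sample
# is dropped to make room for the shift, and the rest are interpolated:
#
#     x = _torch.arange(10.0)
#     _interpolate_delay(x, 0.5, _DelayInterpolationMethod.LINEAR)
#     # tensor([0.5000, 1.5000, ..., 8.5000])  (9 samples)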

class XYError(ValueError, DataError):
    """
    Exceptions related to invalid x and y provided for data sets
    """

    pass


class StartStopError(ValueError, DataError):
    """
    Exceptions related to invalid start and stop arguments
    """

    pass


class StartError(StartStopError):
    pass


class StopError(StartStopError):
    pass


# In seconds. Can't be 0.5 or else v1.wav is invalid! Oops!
_DEFAULT_REQUIRE_INPUT_PRE_SILENCE = 0.4


def _sample_to_time(s, rate):
    seconds = s // rate
    remainder = s % rate
    hours, minutes = 0, 0
    seconds_per_hour = 3600
    while seconds >= seconds_per_hour:
        hours += 1
        seconds -= seconds_per_hour
    seconds_per_minute = 60
    while seconds >= seconds_per_minute:
        minutes += 1
        seconds -= seconds_per_minute
    return f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"

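# Worked example: at 48 kHz, sample 178_704_100 is 3723 seconds plus 100
# samples, i.e. 1 hour, 2 minutes, 3 seconds:
#
#     _sample_to_time(3723 * 48000 + 100, 48000)  # "1:02:03 and 100 samples"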

class Dataset(AbstractDataset, _InitializableFromConfig):
    """
    Take a pair of matched audio files and serve input + output pairs.
    """

    def __init__(
        self,
        x: _torch.Tensor,
        y: _torch.Tensor,
        nx: int,
        ny: _Optional[int],
        start: _Optional[int] = None,
        stop: _Optional[int] = None,
        start_samples: _Optional[int] = None,
        stop_samples: _Optional[int] = None,
        start_seconds: _Optional[_Union[int, float]] = None,
        stop_seconds: _Optional[_Union[int, float]] = None,
        delay: _Optional[_Union[int, float]] = None,
        delay_interpolation_method: _Union[
            str, _DelayInterpolationMethod
        ] = _DelayInterpolationMethod.CUBIC,
        y_scale: float = 1.0,
        x_path: _Optional[_Union[str, _Path]] = None,
        y_path: _Optional[_Union[str, _Path]] = None,
        input_gain: float = 0.0,
        sample_rate: _Optional[float] = None,
        require_input_pre_silence: _Optional[
            float
        ] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
    ):
        """
        :param x: The input signal. A 1D array.
        :param y: The associated output from the model. A 1D array.
        :param nx: The number of samples required as input for the model. For example,
            for a ConvNet, this would be the receptive field.
        :param ny: How many samples to provide as the output array for a single "datum".
            It's usually more computationally efficient to provide a `ny` larger than 1
            so that the forward pass can process more audio all at once. However, it
            shouldn't be too large or else you won't be able to use a large batch size
            (where each input-output pair can be substantially different, improving
            batch diversity).
        :param start: [DEPRECATED; use start_samples instead.] In samples; clip x and y
            at this point. Negative values are taken from the end of the audio.
        :param stop: [DEPRECATED; use stop_samples instead.] In samples; clip x and y at
            this point. Negative values are taken from the end of the audio.
        :param start_samples: Clip x and y at this point. Negative values are taken from
            the end of the audio.
        :param stop_samples: Clip x and y at this point. Negative values are taken from
            the end of the audio.
        :param start_seconds: Clip x and y at this point. Negative values are taken from
            the end of the audio. Requires providing `sample_rate`.
        :param stop_seconds: Clip x and y at this point. Negative values are taken from
            the end of the audio. Requires providing `sample_rate`.
        :param delay: In samples. Positive means we get rid of the start of x, end of y
            (i.e. we are correcting for an alignment error in which y is delayed behind
            x). If a non-integer delay is provided, then y is interpolated, with
            the extra sample removed.
        :param delay_interpolation_method: How to interpolate y when a non-integer
            `delay` is provided.
        :param y_scale: Multiplies the output signal by a factor (e.g. if the data are
            too quiet).
        :param input_gain: In dB. If the input signal wasn't fed to the amp at unity
            gain, you can indicate the gain here. The data set will multiply the raw
            audio file by the specified gain so that the true input signal amplitude
            experienced by the signal chain will be provided as input to the model. If
            you are using a reamping setup, you can estimate this by reamping a
            completely dry signal (i.e. connecting the interface output directly back
            into the input with which the guitar was originally recorded).
        :param sample_rate: Sample rate for the data
        :param require_input_pre_silence: If provided, require that this much time (in
            seconds) preceding the start of the data set (`start`) have a silent input.
            If it's not, then raise an exception because the output due to it will leak
            into the data set that we're trying to use. If `None`, don't assert.
        """
        self._validate_x_y(x, y)
        self._sample_rate = sample_rate
        start, stop = self._validate_start_stop(
            x,
            y,
            start,
            stop,
            start_samples,
            stop_samples,
            start_seconds,
            stop_seconds,
            self.sample_rate,
        )
        if not isinstance(delay_interpolation_method, _DelayInterpolationMethod):
            delay_interpolation_method = _DelayInterpolationMethod(
                delay_interpolation_method
            )
        if require_input_pre_silence is not None:
            self._validate_preceding_silence(
                x, start, require_input_pre_silence, self.sample_rate
            )
        x, y = [z[start:stop] for z in (x, y)]
        if delay is not None and delay != 0:
            x, y = self._apply_delay(x, y, delay, delay_interpolation_method)
        x_scale = 10.0 ** (input_gain / 20.0)
        x = x * x_scale
        y = y * y_scale
        self._x_path = x_path
        self._y_path = y_path
        self._validate_inputs_after_processing(x, y, nx, ny)
        self._x = x
        self._y = y
        self._nx = nx
        self._ny = ny if ny is not None else len(x) - nx + 1

    def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        """
        :return:
            Input (NX+NY-1,)
            Output (NY,)
        """
        if idx >= len(self):
            raise IndexError(f"Attempted to access datum {idx}, but len is {len(self)}")
        i = idx * self._ny
        j = i + self.y_offset
        return self.x[i : i + self._nx + self._ny - 1], self.y[j : j + self._ny]

    def __len__(self) -> int:
        n = len(self.x)
        # If ny were 1
        single_pairs = n - self._nx + 1
        return single_pairs // self._ny

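    # Example of the windowing arithmetic (a sketch): with len(x) == 10,
    # nx == 3, and ny == 2, there are 10 - 3 + 1 == 8 possible single pairs, so
    # len(self) == 8 // 2 == 4. Then self[0] returns x[0:4]
    # (nx + ny - 1 == 4 input samples) and y[2:4] (ny == 2 output samples,
    # offset by y_offset == nx - 1 == 2).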
    @property
    def ny(self) -> int:
        return self._ny

    @property
    def sample_rate(self) -> _Optional[float]:
        return self._sample_rate

    @property
    def x(self) -> _torch.Tensor:
        """
        The input audio data

        :return: (N,)
        """
        return self._x

    @property
    def y(self) -> _torch.Tensor:
        """
        The output audio data

        :return: (N,)
        """
        return self._y

    @property
    def y_offset(self) -> int:
        return self._nx - 1

    @classmethod
    def parse_config(cls, config):
        """
        :param config:
            Must contain:
                x_path (path-like)
                y_path (path-like)
            May contain:
                sample_rate (int)
                y_preroll (int)
                allow_unequal_lengths (bool)
            Must NOT contain:
                x (torch.Tensor) - loaded from x_path
                y (torch.Tensor) - loaded from y_path
            Everything else is passed on to __init__
        """
        config = _deepcopy(config)
        sample_rate = config.pop("sample_rate", None)
        x, x_wavinfo = wav_to_tensor(config.pop("x_path"), info=True, rate=sample_rate)
        sample_rate = x_wavinfo.rate
        if config.pop("allow_unequal_lengths", False):
            y = wav_to_tensor(
                config.pop("y_path"),
                rate=sample_rate,
                preroll=config.pop("y_preroll", None),
                required_wavinfo=x_wavinfo,
            )
            # Truncate to the shorter of the two
            if len(x) == 0:
                raise DataError("Input is zero-length!")
            if len(y) == 0:
                raise DataError("Output is zero-length!")
            n = min(len(x), len(y))
            if n < len(x):
                print(f"Truncating input to {_sample_to_time(n, sample_rate)}")
            if n < len(y):
                print(f"Truncating output to {_sample_to_time(n, sample_rate)}")
            x, y = [z[:n] for z in (x, y)]
        else:
            try:
                y = wav_to_tensor(
                    config.pop("y_path"),
                    rate=sample_rate,
                    preroll=config.pop("y_preroll", None),
                    required_shape=(len(x), 1),
                    required_wavinfo=x_wavinfo,
                )
            except AudioShapeMismatchError as e:
                # Really verbose message since users see this.
                x_samples, x_channels = e.shape_expected
                y_samples, y_channels = e.shape_actual
                msg = "Your audio files aren't the same shape as each other!"
                if x_channels != y_channels:
                    channels_to_stereo_mono = {1: "mono", 2: "stereo"}
                    msg += f"\n * The input is {channels_to_stereo_mono[x_channels]}, but the output is {channels_to_stereo_mono[y_channels]}!"
                if x_samples != y_samples:
                    msg += f"\n * The input is {_sample_to_time(x_samples, sample_rate)} long"
                    msg += f"\n * The output is {_sample_to_time(y_samples, sample_rate)} long"
                msg += f"\n\nOriginal exception:\n{e}"
                raise DataError(msg)
        return {"x": x, "y": y, "sample_rate": sample_rate, **config}

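    # Example config for `init_from_config` (a sketch; the paths are
    # hypothetical):
    #
    #     config = {
    #         "x_path": "input.wav",
    #         "y_path": "output.wav",
    #         "nx": 8192,
    #         "ny": 8192,
    #     }
    #     dataset = Dataset.init_from_config(config)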
    @classmethod
    def _apply_delay(
        cls,
        x: _torch.Tensor,
        y: _torch.Tensor,
        delay: _Union[int, float],
        method: _DelayInterpolationMethod,
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        # Check for floats that could be treated like ints (simpler algorithm)
        if isinstance(delay, float) and int(delay) == delay:
            delay = int(delay)
        if isinstance(delay, int):
            return cls._apply_delay_int(x, y, delay)
        elif isinstance(delay, float):
            return cls._apply_delay_float(x, y, delay, method)
        else:
            raise TypeError(type(delay))

    @classmethod
    def _apply_delay_int(
        cls, x: _torch.Tensor, y: _torch.Tensor, delay: int
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        if delay > 0:
            x = x[:-delay]
            y = y[delay:]
        elif delay < 0:
            x = x[-delay:]
            y = y[:delay]
        return x, y

    @classmethod
    def _apply_delay_float(
        cls,
        x: _torch.Tensor,
        y: _torch.Tensor,
        delay: float,
        method: _DelayInterpolationMethod,
    ) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        n_out = len(y) - int(_np.ceil(_np.abs(delay)))
        if delay > 0:
            x = x[:n_out]
        elif delay < 0:
            x = x[-n_out:]
        y = _interpolate_delay(y, delay, method)
        return x, y

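    # Example (a sketch): with delay == 2, x loses its last two samples and y
    # its first two (x[:-2], y[2:]); with delay == -2, the reverse (x[2:],
    # y[:-2]). A fractional delay like 0.5 trims one sample and interpolates y
    # between its original samples via `_interpolate_delay`.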
    @classmethod
    def _validate_start_stop(
        cls,
        x: _torch.Tensor,
        y: _torch.Tensor,
        start: _Optional[int] = None,
        stop: _Optional[int] = None,
        start_samples: _Optional[int] = None,
        stop_samples: _Optional[int] = None,
        start_seconds: _Optional[_Union[int, float]] = None,
        stop_seconds: _Optional[_Union[int, float]] = None,
        sample_rate: _Optional[int] = None,
    ) -> _Tuple[_Optional[int], _Optional[int]]:
        """
        Parse the requested start and stop trim points.

        These may be valid indices in Python, but probably point to invalid usage, so
        we will raise an exception if something fishy is going on (e.g. starting after
        the end of the file, etc)

        :return: parsed start/stop (if valid).
        """

        def parse_start_stop(s, samples, seconds, rate):
            # Assumes validated inputs
            if s is not None:
                return s
            if samples is not None:
                return samples
            if seconds is not None:
                return int(seconds * rate)
            # else
            return None

        # Resolve different ways of asking for start/stop...
        if start is not None:
            logger.warning("Using `start` is deprecated; use `start_samples` instead.")
        if stop is not None:
            logger.warning("Using `stop` is deprecated; use `stop_samples` instead.")
        if (
            int(start is not None)
            + int(start_samples is not None)
            + int(start_seconds is not None)
            >= 2
        ):
            raise ValueError(
                "More than one start provided. Use only one of `start`, `start_samples`, or `start_seconds`!"
            )
        if (
            int(stop is not None)
            + int(stop_samples is not None)
            + int(stop_seconds is not None)
            >= 2
        ):
            raise ValueError(
                "More than one stop provided. Use only one of `stop`, `stop_samples`, or `stop_seconds`!"
            )
        if start_seconds is not None and sample_rate is None:
            raise ValueError(
                "Provided `start_seconds` without sample rate; cannot resolve into samples!"
            )
        if stop_seconds is not None and sample_rate is None:
            raise ValueError(
                "Provided `stop_seconds` without sample rate; cannot resolve into samples!"
            )

        # By this point, we should have a valid, unambiguous way of asking.
        start = parse_start_stop(start, start_samples, start_seconds, sample_rate)
        stop = parse_start_stop(stop, stop_samples, stop_seconds, sample_rate)
        # And only use start/stop from this point.
        # We could do this whole thing with `if len(x[start:stop]) == 0`, but being
        # more explicit makes the error messages better for users.
        if start is None and stop is None:
            return start, stop
        if len(x) != len(y):
            raise ValueError(
                f"Input and output are different length. Input has {len(x)} samples, "
                f"and output has {len(y)}"
            )
        n = len(x)
        if start is not None:
            # Start after the files' end?
            if start >= n:
                raise StartError(
                    f"Arrays are only {n} samples long, but start was provided as {start}, "
                    "which is beyond the end of the array!"
                )
            # Start before the files' beginning?
            if start < -n:
                raise StartError(
                    f"Arrays are only {n} samples long, but start was provided as {start}, "
                    "which is before the beginning of the array!"
                )
        if stop is not None:
            # Stop after the files' end?
            if stop > n:
                raise StopError(
                    f"Arrays are only {n} samples long, but stop was provided as {stop}, "
                    "which is beyond the end of the array!"
                )
            # Stop before the files' beginning?
            if stop <= -n:
                raise StopError(
                    f"Arrays are only {n} samples long, but stop was provided as {stop}, "
                    "which is before the beginning of the array!"
                )
        # Just in case...
        if len(x[start:stop]) == 0:
            raise StartStopError(
                f"Array length {n} with start={start} and stop={stop} would get "
                "rid of all of the data!"
            )
        return start, stop

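    # Example (a sketch): for 10-second files at 48 kHz,
    #
    #     Dataset._validate_start_stop(x, y, start_seconds=1.0,
    #                                  stop_seconds=-1.0, sample_rate=48000)
    #
    # resolves to (48000, -48000). Providing both `start` and `start_samples`
    # raises a ValueError, as does `start_seconds` without a `sample_rate`.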
    @classmethod
    def _validate_x_y(cls, x, y):
        if len(x) != len(y):
            raise XYError(
                f"Input and output aren't the same lengths! ({len(x)} vs {len(y)})"
            )
        # TODO channels
        n = len(x)
        if n == 0:
            raise XYError("Input and output are empty!")

    def _validate_inputs_after_processing(self, x, y, nx, ny):
        assert x.ndim == 1
        assert y.ndim == 1
        assert len(x) == len(y)
        if nx > len(x):
            raise RuntimeError(  # TODO XYError?
                f"Input of length {len(x)}, but receptive field is {nx}."
            )
        if ny is not None:
            assert ny <= len(y) - nx + 1
        if _torch.abs(y).max() >= 1.0:
            msg = "Output clipped."
            if self._y_path is not None:
                msg += f" Source is {self._y_path}"
            raise ValueError(msg)

    @classmethod
    def _validate_preceding_silence(
        cls,
        x: _torch.Tensor,
        start: _Optional[int],
        silent_seconds: float,
        sample_rate: _Optional[float],
    ):
        """
        Make sure that the input is silent before the starting index.
        If it's not, then the output from that non-silent input will leak into the data
        set and couldn't be predicted!

        This assumes that silence is indeed required. If it's not, then don't call this!

        See: Issue #252

        :param x: Input
        :param start: Where the data starts
        :param silent_seconds: How long the input is expected to be silent, in seconds
        :param sample_rate: Sample rate of the input
        """
        if sample_rate is None:
            raise ValueError(
                f"Pre-silence was required for {silent_seconds} seconds, but no sample "
                "rate was provided!"
            )
        silent_samples = int(silent_seconds * sample_rate)
        if start is None:
            return
        raw_check_start = start - silent_samples
        check_start = max(raw_check_start, 0) if start >= 0 else min(raw_check_start, 0)
        check_end = start
        if not _torch.all(x[check_start:check_end] == 0.0):
            raise XYError(
                f"Input provided isn't silent for at least {silent_samples} samples "
                "before the starting index. Responses to this non-silent input may "
                "leak into the dataset!"
            )

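# Usage sketch for `Dataset` (hypothetical values; `sample_rate` is needed here
# because the default `require_input_pre_silence` check requires it):
#
#     x = _torch.zeros(100_000)
#     y = 0.5 * _torch.rand(100_000)
#     ds = Dataset(x, y, nx=4096, ny=8192, sample_rate=48000.0)
#     len(ds)         # (100_000 - 4096 + 1) // 8192 == 11
#     xi, yi = ds[0]  # shapes: (12287,) and (8192,)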

class ConcatDataset(AbstractDataset, _InitializableFromConfig):
    def __init__(self, datasets: _Sequence[Dataset], flatten=True):
        if flatten:
            datasets = self._flatten_datasets(datasets)
        self._validate_datasets(datasets)
        self._datasets = datasets
        self._lookup = self._make_lookup()

    def __getitem__(self, idx: int) -> _Tuple[_torch.Tensor, _torch.Tensor]:
        i, j = self._lookup[idx]
        return self.datasets[i][j]

    def __len__(self) -> int:
        """
        The total number of data points across all data sets in this data set
        """
        return sum(len(d) for d in self._datasets)

    @property
    def datasets(self):
        return self._datasets

    @classmethod
    def parse_config(cls, config):
        init = _dataset_init_registry[config.get("type", "dataset")]
        return {
            "datasets": tuple(
                init(c) for c in _tqdm(config["dataset_configs"], desc="Loading data")
            )
        }

    def _flatten_datasets(self, datasets):
        """
        If any dataset is a ConcatDataset, pull it out
        """
        flattened = []
        for d in datasets:
            if isinstance(d, ConcatDataset):
                flattened.extend(d.datasets)
            else:
                flattened.append(d)
        return flattened

    def _make_lookup(self):
        """
        For faster __getitem__
        """
        lookup = {}
        offset = 0
        j = 0  # Dataset index
        for i in range(len(self)):
            if offset == len(self.datasets[j]):
                offset -= len(self.datasets[j])
                j += 1
            lookup[i] = (j, offset)
            offset += 1
        # Assert that we got to the last data set
        if j != len(self.datasets) - 1:
            raise RuntimeError(
                f"During lookup population, didn't get to the last dataset (index "
                f"{len(self.datasets)-1}). Instead index ended at {j}."
            )
        if offset != len(self.datasets[-1]):
            raise RuntimeError(
                "During lookup population, didn't end at the index of the last datum "
                f"in the last dataset. Expected index {len(self.datasets[-1])}, got "
                f"{offset} instead."
            )
        return lookup

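    # Example (a sketch): for child datasets of lengths 3 and 2, the lookup is
    #
    #     {0: (0, 0), 1: (0, 1), 2: (0, 2), 3: (1, 0), 4: (1, 1)}
    #
    # i.e. global index -> (dataset index, index within that dataset).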
    @classmethod
    def _validate_datasets(cls, datasets: _Sequence[Dataset]):
        Reference = _namedtuple("Reference", ("index", "val"))
        ref_ny = None
        for i, d in enumerate(datasets):
            ref_ny = Reference(i, d.ny) if ref_ny is None else ref_ny
            if d.ny != ref_ny.val:
                raise ValueError(
                    f"Mismatch between ny of datasets {ref_ny.index} ({ref_ny.val}) and {i} ({d.ny})"
                )


_dataset_init_registry = {"dataset": Dataset.init_from_config}


def register_dataset_initializer(
    name: str, constructor: _Callable[[_Any], AbstractDataset], overwrite=False
):
    """
    If you have other data set types, you can register their initializer by name using
    this.

    For example, the basic NAM dataset is registered by default under the name
    "dataset", but if it weren't, you could register it like this:

    >>> from nam import data
    >>> data.register_dataset_initializer("dataset", Dataset.init_from_config)

    :param name: The name that'll be used in the config to ask for the data set type
    :param constructor: The constructor that'll be fed the config.
    """
    if name in _dataset_init_registry and not overwrite:
        raise KeyError(
            f"A constructor for dataset name '{name}' is already registered!"
        )
    _dataset_init_registry[name] = constructor


def init_dataset(config, split: Split) -> AbstractDataset:
    name = config.get("type", "dataset")
    base_config = config[split.value]
    common = config.get("common", {})
    if isinstance(base_config, dict):
        init = _dataset_init_registry[name]
        return init({**common, **base_config})
    elif isinstance(base_config, list):
        return ConcatDataset.init_from_config(
            {
                "type": name,
                "dataset_configs": [{**common, **c} for c in base_config],
            }
        )
    else:
        raise TypeError(
            f"Data set config for split '{split.value}' must be a dict or a list, "
            f"but got {type(base_config)}!"
        )
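

# Example config for `init_dataset` (a sketch; paths and values are
# hypothetical). Entries under "common" are shared by both splits:
#
#     config = {
#         "train": {"x_path": "input.wav", "y_path": "train_out.wav", "ny": 8192},
#         "validation": {"x_path": "input.wav", "y_path": "val_out.wav", "ny": None},
#         "common": {"nx": 4096},
#     }
#     train_set = init_dataset(config, Split.TRAIN)
#     val_set = init_dataset(config, Split.VALIDATION)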