neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit d489ef8ad80898d6105b3e13b19195e0cc71e82f
parent 2f9da8ae28999b6c207cc5f6e248a6af638cd034
Author: Steven Atkinson <[email protected]>
Date:   Fri, 19 May 2023 09:17:52 -0700

Improve v2_0_0.wav (#253)

* Update core.py

Data info
Better detection of input version, update for new v2_0_0.wav

* Update core.py

Delay calibration tested and working

* Verify that validation sets are preceded by silence

* Tests resources cleanup

* Fix for v1.wav

* Fix tests due to non-silent pre-input

* Fix enumeration

* Fix tests due to silence req

* Option for how much silence, skip for end-to-end test for speed.

* Fix stuff
Diffstat:
Mnam/data.py | 47+++++++++++++++++++++++++++++++++++++++++++++++
Mnam/train/core.py | 239+++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------
Atests/__init__.py | 0
Atests/resources/.gitignore | 2++
Atests/resources/__init__.py | 26++++++++++++++++++++++++++
Mtests/test_bin/test_train/test_main.py | 3++-
Mtests/test_nam/test_data.py | 8++++++++
Atests/test_nam/test_train/test_core.py | 155+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
8 files changed, 394 insertions(+), 86 deletions(-)

diff --git a/nam/data.py b/nam/data.py @@ -204,6 +204,10 @@ class StopError(StartStopError): pass +# In seconds. Can't be 0.5 or else v1.wav is invalid! Oops! +_DEFAULT_REQUIRE_INPUT_PRE_SILENCE = 0.4 + + class Dataset(AbstractDataset, InitializableFromConfig): """ Take a pair of matched audio files and serve input + output pairs. @@ -227,6 +231,8 @@ class Dataset(AbstractDataset, InitializableFromConfig): x_path: Optional[Union[str, Path]] = None, y_path: Optional[Union[str, Path]] = None, input_gain: float = 0.0, + rate: int = REQUIRED_RATE, + require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, ): """ :param x: The input signal. A 1D array. @@ -254,6 +260,11 @@ class Dataset(AbstractDataset, InitializableFromConfig): you are using a reamping setup, you can estimate this by reamping a completely dry signal (i.e. connecting the interface output directly back into the input with which the guitar was originally recorded.) + :param rate: Sample rate for the data + :param require_input_pre_silence: If provided, require that this much time (in + seconds) preceding the start of the data set (`start`) have a silent input. + If it's not, then raise an exception because the output due to it will leak + into the data set that we're trying to use. If `None`, don't assert. """ self._validate_x_y(x, y) self._validate_start_stop(x, y, start, stop) @@ -261,6 +272,10 @@ class Dataset(AbstractDataset, InitializableFromConfig): delay_interpolation_method = _DelayInterpolationMethod( delay_interpolation_method ) + if require_input_pre_silence is not None: + self._validate_preceding_silence( + x, start, int(require_input_pre_silence * rate) + ) x, y = [z[start:stop] for z in (x, y)] if delay is not None and delay != 0: x, y = self._apply_delay(x, y, delay, delay_interpolation_method) @@ -274,6 +289,7 @@ class Dataset(AbstractDataset, InitializableFromConfig): self._y = y self._nx = nx self._ny = ny if ny is not None else len(x) - nx + 1 + self._rate = rate def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -374,6 +390,10 @@ class Dataset(AbstractDataset, InitializableFromConfig): "y_scale": config.get("y_scale", 1.0), "x_path": config["x_path"], "y_path": config["y_path"], + "rate": config.get("rate", REQUIRED_RATE), + "require_input_pre_silence": config.get( + "require_input_pre_silence", _DEFAULT_REQUIRE_INPUT_PRE_SILENCE + ), } @classmethod @@ -505,6 +525,33 @@ class Dataset(AbstractDataset, InitializableFromConfig): msg += f"Source is {self._y_path}" raise ValueError(msg) + @classmethod + def _validate_preceding_silence( + cls, x: torch.Tensor, start: Optional[int], silent_samples: int + ): + """ + Make sure that the input is silent before the starting index. + If it's not, then the output from that non-silent input will leak into the data + set and couldn't be predicted! + + See: Issue #252 + + :param x: Input + :param start: Where the data starts + :param silent_samples: How many are expected to be silent + """ + if start is None: + return + raw_check_start = start - silent_samples + check_start = max(raw_check_start, 0) if start >= 0 else min(raw_check_start, 0) + check_end = start + if not torch.all(x[check_start:check_end] == 0.0): + raise XYError( + f"Input provided isn't silent for at least {silent_samples} samples " + "before the starting index. Responses to this non-silent input may " + "leak into the dataset!" + ) + class ParametricDataset(Dataset): """ diff --git a/nam/train/core.py b/nam/train/core.py @@ -10,6 +10,7 @@ import hashlib import tkinter as tk from copy import deepcopy from enum import Enum +from functools import partial from time import time from typing import Dict, Optional, Sequence, Tuple, Union @@ -17,6 +18,7 @@ import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch +from pydantic import BaseModel from torch.utils.data import DataLoader from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor @@ -31,9 +33,11 @@ class Architecture(Enum): FEATHER = "feather" -def _detect_input_version(input_path) -> Version: +def _detect_input_version(input_path) -> Tuple[Version, bool]: """ Check to see if the input matches any of the known inputs + + :return: version, strong match """ def detect_strong(input_path) -> Optional[Version]: @@ -41,7 +45,7 @@ def _detect_input_version(input_path) -> Version: # Use this to create hashes for new files md5 = hashlib.md5() buffer_size = 65536 - with open(input_path, "rb") as f: + with open(path, "rb") as f: while True: data = f.read(buffer_size) if not data: @@ -56,7 +60,7 @@ def _detect_input_version(input_path) -> Version: version = { "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), - "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0), + "ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0), }.get(file_hash) if version is None: print( @@ -67,43 +71,60 @@ def _detect_input_version(input_path) -> Version: def detect_weak(input_path) -> Optional[Version]: def assign_hash(path): - # Use this to create recognized hashes for new files - x, info = wav_to_np(path, info=True) - rate = info.rate - if rate != REQUIRED_RATE: - return None - # Times of intervals, in seconds - t_blips = 1 - t_sweep = 3 - t_white = 3 - t_validation = 9 - # v1 and v2 start with 1 blips, sine sweeps, and white noise - start_hash = hashlib.md5( - x[: (t_blips + t_sweep + t_white) * rate] - ).hexdigest() - # v1 ends with validation signal - end_hash_v1 = hashlib.md5(x[-t_validation * rate :]).hexdigest() - # v2 ends with 2x validation & blips - end_hash_v2 = hashlib.md5( - x[-(2 * t_validation + t_blips) * rate :] - ).hexdigest() - return start_hash, end_hash_v1, end_hash_v2 - - start_hash, end_hash_v1, end_hash_v2 = assign_hash(input_path) + def assign_hashes_v1(path) -> Tuple[Optional[str], Optional[str]]: + # Use this to create recognized hashes for new files + x, info = wav_to_np(path, info=True) + rate = info.rate + if rate != _V1_DATA_INFO.rate: + return None, None + # Times of intervals, in seconds + t_blips = _V1_DATA_INFO.t_blips + t_sweep = 3 * rate + t_white = 3 * rate + t_validation = _V1_DATA_INFO.t_validate + # v1 and v2 start with 1 blips, sine sweeps, and white noise + start_hash = hashlib.md5(x[: t_blips + t_sweep + t_white]).hexdigest() + # v1 ends with validation signal + end_hash = hashlib.md5(x[-t_validation:]).hexdigest() + return start_hash, end_hash + + def assign_hashes_v2(path): + # Use this to create recognized hashes for new files + x, info = wav_to_np(path, info=True) + rate = info.rate + if rate != _V2_DATA_INFO.rate: + return None, None + # Times of intervals, in seconds + t_blips = _V2_DATA_INFO.t_blips + t_sweep = 3 * rate + t_white = 3 * rate + t_validation = _V1_DATA_INFO.t_validate + # v1 and v2 start with 1 blips, sine sweeps, and white noise + start_hash = hashlib.md5(x[: (t_blips + t_sweep + t_white)]).hexdigest() + # v2 ends with 2x validation & blips + end_hash = hashlib.md5(x[-(2 * t_validation + t_blips) :]).hexdigest() + return start_hash, end_hash + + start_hash_v1, end_hash_v1 = assign_hashes_v1(path) + start_hash_v2, end_hash_v2 = assign_hashes_v2(path) + return start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2 + + start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2 = assign_hash(input_path) print( "Weak hashes:\n" - f" Start: {start_hash}\n" - f" End (v1): {end_hash_v1}\n" - f" End (v2): {end_hash_v2}\n", + f" Start (v1) : {start_hash_v1}\n" + f" End (v1) : {end_hash_v1}\n" + f" Start (v2) : {start_hash_v2}\n" + f" End (v2) : {end_hash_v2}\n", ) # Check for v2 matches first version = { ( - "068a17d92274a136807523baad4913ff", - "74f924e8b245c8f7dce007765911545a", + "1c4d94fbcb47e4d820bef611c1d4ae65", + "28694e7bf9ab3f8ae6ef86e9545d4663", ): Version(2, 0, 0) - }.get((start_hash, end_hash_v2)) + }.get((start_hash_v2, end_hash_v2)) if version is not None: return version # Fallback to v1 @@ -111,50 +132,96 @@ def _detect_input_version(input_path) -> Version: ( "bb4e140c9299bae67560d280917eb52b", "9b2468fcb6e9460a399fc5f64389d353", - ): Version(1, 0, 0), + ): Version( + 1, 0, 0 + ), # FIXME! ( "9f20c6b5f7fef68dd88307625a573a14", "8458126969a3f9d8e19a53554eb1fd52", ): Version(1, 1, 1), - }.get((start_hash, end_hash_v1)) + }.get((start_hash_v1, end_hash_v1)) return version version = detect_strong(input_path) if version is not None: - return version + strong_match = True + return version, strong_match print("Falling back to weak-matching...") version = detect_weak(input_path) if version is None: raise ValueError( f"Input file at {input_path} cannot be recognized as any known version!" ) - return version + strong_match = False + return version, strong_match -_V1_BLIP_LOCATIONS = 12_000, 36_000 -_V2_START_BLIP_LOCATIONS = _V1_BLIP_LOCATIONS -_V2_END_BLIP_LOCATIONS = -36_000, -12_000 +class _DataInfo(BaseModel): + """ + :param major_version: Data major version + :param rate: Sample rate, in Hz + :param t_blips: How long the blips are, in seconds + :param t_validate: Validation signal length, in samples + :param validation_start: Where validation signal starts, in samples. Less than zero + (from the end of the array). + :param noise_interval: Inside which we quantify the noise level + :param start_blip_locations: In samples + :param end_blip_locations: In samples, negative (from end) + """ + + major_version: int + rate: Optional[int] + t_blips: int + t_validate: int + validation_start: int + noise_interval: Tuple[int, int] + start_blip_locations: Sequence[int] + end_blip_locations: Optional[Sequence[int]] + + +_V1_DATA_INFO = _DataInfo( + major_version=1, + rate=REQUIRED_RATE, + t_blips=48_000, + t_validate=432_000, + validation_start=-432_000, + noise_interval=(0, 6000), + start_blip_locations=(12_000, 36_000), + end_blip_locations=None, +) +_V2_DATA_INFO = _DataInfo( + major_version=2, + rate=REQUIRED_RATE, + t_blips=96_000, + t_validate=432_000, + validation_start=-960_000, # 96_000 + 2 * 432_000 + noise_interval=(12_000, 18_000), + start_blip_locations=(24_000, 72_000), + end_blip_locations=(-72_000, -24_000), +) + _DELAY_CALIBRATION_ABS_THRESHOLD = 0.0001 _DELAY_CALIBRATION_REL_THRESHOLD = 0.001 +_DELAY_CALIBRATION_SAFETY_FACTOR = 4 -def _calibrate_delay_v1( - input_path, output_path, locations: Sequence[int] = _V1_BLIP_LOCATIONS -) -> int: +def _calibrate_delay_v_all(data_info: _DataInfo, input_path, output_path) -> int: lookahead = 1_000 lookback = 10_000 - safety_factor = 4 + safety_factor = _DELAY_CALIBRATION_SAFETY_FACTOR # Calibrate the trigger: - y = wav_to_np(output_path)[:48_000] - background_level = np.max(np.abs(y[:6_000])) + y = wav_to_np(output_path)[: data_info.t_blips] + background_level = np.max( + np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]]) + ) trigger_threshold = max( background_level + _DELAY_CALIBRATION_ABS_THRESHOLD, (1.0 + _DELAY_CALIBRATION_REL_THRESHOLD) * background_level, ) delays = [] - for blip_index, i in enumerate(locations, 1): + for blip_index, i in enumerate(data_info.start_blip_locations, 1): start_looking = i - lookahead stop_looking = i + lookback y_scan = y[start_looking:stop_looking] @@ -191,25 +258,26 @@ def _calibrate_delay_v1( return delay -def _calibrate_delay_v2( - input_path, output_path, locations: Sequence[int] = _V2_START_BLIP_LOCATIONS -) -> int: - return _calibrate_delay_v1(input_path, output_path, locations=locations) +_calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO) +_calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO) -def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True): +def _plot_delay_v_all( + data_info: _DataInfo, delay: int, input_path: str, output_path: str, _nofail=True +): print("Plotting the delay for manual inspection...") - x = wav_to_np(input_path)[:48_000] - y = wav_to_np(output_path)[:48_000] - i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] # In case resampled poorly + x = wav_to_np(input_path)[: data_info.t_blips] + y = wav_to_np(output_path)[: data_info.t_blips] + # Only get the blips we really want. + i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] if len(i) == 0: print("Failed to find the spike in the input file.") print( "Plotting the input and output; there should be spikes at around the " "marked locations." ) - expected_spikes = 12_000, 36_000 # For v1 specifically - fig, axs = plt.subplots(2, 1) + expected_spikes = data_info.start_blip_locations # For v1 specifically + fig, axs = plt.subplots(len((x, y)), 1) for ax, curve in zip(axs, (x, y)): ax.plot(curve) [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes] @@ -217,22 +285,24 @@ def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True): if _nofail: raise RuntimeError("Failed to plot delay") else: - i = i[0] - di = 20 plt.figure() - # plt.plot(x[i - di : i + di], ".-", label="Input") - plt.plot( - np.arange(-di, di), - y[i - di + delay : i + di + delay], - ".-", - label="Output", - ) - plt.axvline(x=0, linestyle="--", color="C1") + di = 20 + if data_info.major_version == 1: + i = [i[0]] + for e, ii in enumerate(i, 1): + plt.plot( + np.arange(-di, di), + y[ii - di + delay : ii + di + delay], + ".-", + label=f"Output {e}", + ) + plt.axvline(x=0, linestyle="--", color="k") plt.legend() plt.show() # This doesn't freeze the notebook -_plot_delay_v2 = _plot_delay_v1 +_plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO) +_plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO) def _calibrate_delay( @@ -290,10 +360,12 @@ def _check_v1(*args, **kwargs): def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: with torch.no_grad(): print("V2 checks...") - rate = REQUIRED_RATE + rate = _V2_DATA_INFO.rate y = wav_to_tensor(output_path, rate=rate) - y_val_1 = y[-19 * rate : -10 * rate] - y_val_2 = y[-10 * rate : -1 * rate] + t_blips = _V2_DATA_INFO.t_blips + t_validate = _V2_DATA_INFO.t_validate + y_val_1 = y[-(t_blips + 2 * t_validate) : -(t_blips + t_validate)] + y_val_2 = y[-(t_blips + t_validate) : -t_blips] esr_replicate = esr(y_val_1, y_val_2).item() print(f"Replicate ESR is {esr_replicate:.8f}.") @@ -305,8 +377,8 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: """ :return: [start/end,replicate] """ - i0, i1 = rate // 4, 3 * rate // 4 - j0, j1 = -3 * rate // 4, -rate // 4 + i0, i1 = _V2_DATA_INFO.start_blip_locations + j0, j1 = _V2_DATA_INFO.end_blip_locations i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)] start = -10 @@ -496,26 +568,23 @@ def _get_configs( lr_decay: float, batch_size: int, ): - def get_kwargs(): - val_seconds = 9 - rate = REQUIRED_RATE - if input_version.major == 1: - train_val_split = -val_seconds * rate + def get_kwargs(data_info: _DataInfo): + if data_info.major_version == 1: + train_val_split = data_info.validation_start train_kwargs = {"stop": train_val_split} validation_kwargs = {"start": train_val_split} - elif input_version.major == 2: - blip_seconds = 1 - val_replicates = 2 - train_stop = -(blip_seconds + val_replicates * val_seconds) * rate - validation_start = train_stop - validation_stop = -(blip_seconds + val_seconds) * rate + elif data_info.major_version == 2: + validation_start = data_info.validation_start + train_stop = validation_start + validation_stop = validation_start + data_info.t_validate train_kwargs = {"stop": train_stop} validation_kwargs = {"start": validation_start, "stop": validation_stop} else: raise NotImplementedError(f"kwargs for input version {input_version}") return train_kwargs, validation_kwargs - train_kwargs, validation_kwargs = get_kwargs() + data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO}[input_version.major] + train_kwargs, validation_kwargs = get_kwargs(data_info) data_config = { "train": {"ny": ny, **train_kwargs}, "validation": {"ny": None, **validation_kwargs}, @@ -707,7 +776,7 @@ def train( torch.manual_seed(seed) if input_version is None: - input_version = _detect_input_version(input_path) + input_version, strong_match = _detect_input_version(input_path) if delay is None: delay = _calibrate_delay( diff --git a/tests/__init__.py b/tests/__init__.py diff --git a/tests/resources/.gitignore b/tests/resources/.gitignore @@ -0,0 +1 @@ +*.wav +\ No newline at end of file diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py @@ -0,0 +1,26 @@ +# File: __init__.py +# Created Date: Thursday May 18th 2023 +# Author: Steven Atkinson ([email protected]) + +from pathlib import Path + +import pytest + +__all__ = ["requires_v1_0_0", "requires_v1_1_1", "requires_v2_0_0", "resource_path"] + + +def _requires_v(name: str): + path = Path(__file__).parent / Path(name) + return pytest.mark.skipif( + not path.exists(), + reason=f"Requires {name}, which hasn't been downloaded to {path}.", + ) + + +requires_v1_0_0 = _requires_v("v1.wav") +requires_v1_1_1 = _requires_v("v1_1_1.wav") +requires_v2_0_0 = _requires_v("v2_0_0.wav") + + +def resource_path(name: str) -> Path: + return Path(__file__).absolute().parent / Path(name) diff --git a/tests/test_bin/test_train/test_main.py b/tests/test_bin/test_train/test_main.py @@ -13,7 +13,7 @@ import numpy as np import pytest import torch -from nam.data import np_to_wav +from nam.data import REQUIRED_RATE, np_to_wav _BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path( "bin", "train", "main.py" @@ -67,6 +67,7 @@ class Test(object): "x_path": str(self._x_path(root_path)), "y_path": str(self._y_path(root_path)), "delay": 0, + "require_input_pre_silence": None, }, } stage_channels = (3, 2) diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py @@ -119,12 +119,20 @@ class TestDataset(object): ), ) def test_validate_start(self, n: int, start: int, valid: bool): + """ + Assert that a data set can be successfully instantiated when valid args are + given, including `start`. + Assert that `StartError` is raised if invalid start is provided + """ + def init(): data.Dataset(x, y, nx, ny, start=start) nx = 1 ny = None x, y = self._create_xy(n=n) + if start is not None: + x[:start] = 0.0 # Ensure silent input before the start if valid: init() assert True # No problem! diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -0,0 +1,155 @@ +# File: test_core.py +# Created Date: Thursday May 18th 2023 +# Author: Steven Atkinson ([email protected]) + +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import pytest + +from nam.data import ( + Dataset, + np_to_wav, + wav_to_np, + wav_to_tensor, + _DEFAULT_REQUIRE_INPUT_PRE_SILENCE, +) +from nam.train import core +from nam.train._version import Version + +from ...resources import ( + requires_v1_0_0, + requires_v1_1_1, + requires_v2_0_0, + resource_path, +) + +__all__ = [] + + +def _resource_path(version: Version) -> Path: + if version == Version(1, 0, 0): + name = "v1.wav" + else: + name = f'v{str(version).replace(".", "_")}.wav' + return resource_path(name) + + +class TestDetectInputVersion(object): + @requires_v1_0_0 + def test_detect_input_version_v1_0_0_strong(self): + self._t_detect_input_version_strong(Version(1, 0, 0)) + + @requires_v1_1_1 + def test_detect_input_version_v1_1_1_strong(self): + self._t_detect_input_version_strong(Version(1, 1, 1)) + + @requires_v2_0_0 + def test_detect_input_version_v2_0_0_strong(self): + self._t_detect_input_version_strong(Version(2, 0, 0)) + + @requires_v1_0_0 + def test_detect_input_version_v1_0_0_weak(self): + self._t_detect_input_version_weak(Version(1, 0, 0)) + + @requires_v1_1_1 + def test_detect_input_version_v1_1_1_weak(self): + self._t_detect_input_version_weak(Version(1, 1, 1)) + + @requires_v2_0_0 + def test_detect_input_version_v2_0_0_weak(self): + self._t_detect_input_version_weak(Version(2, 0, 0)) + + @classmethod + def _customize_resource(cls, path_in, path_out): + x, info = wav_to_np(path_in, info=True) + # Should be safe... + i = info.rate * 60 + y = np.concatenate([x[:i], np.zeros((1,)), x[i:]]) + np_to_wav(y, path_out) + + @classmethod + def _t_detect_input_version( + cls, + path: Path, + expected_input_version: Version, + expected_strong_match: bool, + ): + input_version, strong_match = core._detect_input_version(path) + assert input_version == expected_input_version + assert strong_match == expected_strong_match + + @classmethod + def _t_detect_input_version_strong(cls, version: Version): + cls._t_detect_input_version(_resource_path(version), version, True) + + @classmethod + def _t_detect_input_version_weak(cls, version: Version): + with TemporaryDirectory() as tmpdir: + path = Path(tmpdir, "temp.wav") + cls._customize_resource(_resource_path(version), path) + cls._t_detect_input_version(path, version, False) + + +class _TCalibrateDelay(object): + _calibrate_delay = None + _data_info: core._DataInfo = None + + @pytest.mark.parametrize("expected_delay", (-10, 0, 5, 100)) + def test_calibrate_delay(self, expected_delay: int): + x = np.zeros((self._data_info.t_blips)) + for i in self._data_info.start_blip_locations: + x[i + expected_delay] = 1.0 + with TemporaryDirectory() as tmpdir: + path = Path(tmpdir, "output.wav") + np_to_wav(x, path) + delay = self._calibrate_delay(None, path) + assert delay == expected_delay - core._DELAY_CALIBRATION_SAFETY_FACTOR + + +class TestCalibrateDelayV1(_TCalibrateDelay): + _calibrate_delay = core._calibrate_delay_v1 + _data_info = core._V1_DATA_INFO + + +class TestCalibrateDelayV2(_TCalibrateDelay): + _calibrate_delay = core._calibrate_delay_v2 + _data_info = core._V2_DATA_INFO + + +def _make_t_validation_dataset_class( + version: Version, decorator, data_info: core._DataInfo +): + class C(object): + @decorator + def test_validation_preceded_by_silence(self): + """ + Validate that the datasets that we've made are valid + """ + x = wav_to_tensor(_resource_path(version)) + Dataset._validate_preceding_silence( + x, + data_info.validation_start, + int(_DEFAULT_REQUIRE_INPUT_PRE_SILENCE * data_info.rate), + ) + + return C + + +TestValidationDatasetV1_0_0 = _make_t_validation_dataset_class( + Version(1, 0, 0), requires_v1_0_0, core._V1_DATA_INFO +) + + +TestValidationDatasetV1_1_1 = _make_t_validation_dataset_class( + Version(1, 1, 1), requires_v1_1_1, core._V1_DATA_INFO +) + + +TestValidationDatasetV2_0_0 = _make_t_validation_dataset_class( + Version(2, 0, 0), requires_v2_0_0, core._V2_DATA_INFO +) + +if __name__ == "__main__": + pytest.main()