commit d489ef8ad80898d6105b3e13b19195e0cc71e82f
parent 2f9da8ae28999b6c207cc5f6e248a6af638cd034
Author: Steven Atkinson <[email protected]>
Date: Fri, 19 May 2023 09:17:52 -0700
Improve v2_0_0.wav (#253)
* Update core.py
Data info
Better detection of input version, update for new v2_0_0.wav
* Update core.py
Delay calibration tested and working
* Verify that validation sets are preceded by silence
* Tests resources cleanup
* Fix for v1.wav
* Fix tests due to non-silent pre-input
* Fix enumeration
* Fix tests due to silence req
* Option for how much silence, skip for end-to-end test for speed.
* Fix stuff
Diffstat:
8 files changed, 394 insertions(+), 86 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -204,6 +204,10 @@ class StopError(StartStopError):
pass
+# In seconds. Can't be 0.5 or else v1.wav is invalid! Oops!
+_DEFAULT_REQUIRE_INPUT_PRE_SILENCE = 0.4
+
+
class Dataset(AbstractDataset, InitializableFromConfig):
"""
Take a pair of matched audio files and serve input + output pairs.
@@ -227,6 +231,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
x_path: Optional[Union[str, Path]] = None,
y_path: Optional[Union[str, Path]] = None,
input_gain: float = 0.0,
+ rate: int = REQUIRED_RATE,
+ require_input_pre_silence: Optional[float] = _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
):
"""
:param x: The input signal. A 1D array.
@@ -254,6 +260,11 @@ class Dataset(AbstractDataset, InitializableFromConfig):
you are using a reamping setup, you can estimate this by reamping a
completely dry signal (i.e. connecting the interface output directly back
into the input with which the guitar was originally recorded.)
+ :param rate: Sample rate for the data
+ :param require_input_pre_silence: If provided, require that this much time (in
+ seconds) preceding the start of the data set (`start`) have a silent input.
+ If it's not, then raise an exception because the output due to it will leak
+ into the data set that we're trying to use. If `None`, don't assert.
"""
self._validate_x_y(x, y)
self._validate_start_stop(x, y, start, stop)
@@ -261,6 +272,10 @@ class Dataset(AbstractDataset, InitializableFromConfig):
delay_interpolation_method = _DelayInterpolationMethod(
delay_interpolation_method
)
+ if require_input_pre_silence is not None:
+ self._validate_preceding_silence(
+ x, start, int(require_input_pre_silence * rate)
+ )
x, y = [z[start:stop] for z in (x, y)]
if delay is not None and delay != 0:
x, y = self._apply_delay(x, y, delay, delay_interpolation_method)
@@ -274,6 +289,7 @@ class Dataset(AbstractDataset, InitializableFromConfig):
self._y = y
self._nx = nx
self._ny = ny if ny is not None else len(x) - nx + 1
+ self._rate = rate
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@@ -374,6 +390,10 @@ class Dataset(AbstractDataset, InitializableFromConfig):
"y_scale": config.get("y_scale", 1.0),
"x_path": config["x_path"],
"y_path": config["y_path"],
+ "rate": config.get("rate", REQUIRED_RATE),
+ "require_input_pre_silence": config.get(
+ "require_input_pre_silence", _DEFAULT_REQUIRE_INPUT_PRE_SILENCE
+ ),
}
@classmethod
@@ -505,6 +525,33 @@ class Dataset(AbstractDataset, InitializableFromConfig):
msg += f"Source is {self._y_path}"
raise ValueError(msg)
+ @classmethod
+ def _validate_preceding_silence(
+ cls, x: torch.Tensor, start: Optional[int], silent_samples: int
+ ):
+ """
+ Make sure that the input is silent before the starting index.
+ If it's not, then the output from that non-silent input will leak into the data
+ set and couldn't be predicted!
+
+ See: Issue #252
+
+ :param x: Input
+ :param start: Where the data starts
+ :param silent_samples: How many are expected to be silent
+ """
+ if start is None:
+ return
+ raw_check_start = start - silent_samples
+ check_start = max(raw_check_start, 0) if start >= 0 else min(raw_check_start, 0)
+ check_end = start
+ if not torch.all(x[check_start:check_end] == 0.0):
+ raise XYError(
+ f"Input provided isn't silent for at least {silent_samples} samples "
+ "before the starting index. Responses to this non-silent input may "
+ "leak into the dataset!"
+ )
+
class ParametricDataset(Dataset):
"""
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -10,6 +10,7 @@ import hashlib
import tkinter as tk
from copy import deepcopy
from enum import Enum
+from functools import partial
from time import time
from typing import Dict, Optional, Sequence, Tuple, Union
@@ -17,6 +18,7 @@ import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
+from pydantic import BaseModel
from torch.utils.data import DataLoader
from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
@@ -31,9 +33,11 @@ class Architecture(Enum):
FEATHER = "feather"
-def _detect_input_version(input_path) -> Version:
+def _detect_input_version(input_path) -> Tuple[Version, bool]:
"""
Check to see if the input matches any of the known inputs
+
+ :return: version, strong match
"""
def detect_strong(input_path) -> Optional[Version]:
@@ -41,7 +45,7 @@ def _detect_input_version(input_path) -> Version:
# Use this to create hashes for new files
md5 = hashlib.md5()
buffer_size = 65536
- with open(input_path, "rb") as f:
+ with open(path, "rb") as f:
while True:
data = f.read(buffer_size)
if not data:
@@ -56,7 +60,7 @@ def _detect_input_version(input_path) -> Version:
version = {
"4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
- "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0),
+ "ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
}.get(file_hash)
if version is None:
print(
@@ -67,43 +71,60 @@ def _detect_input_version(input_path) -> Version:
def detect_weak(input_path) -> Optional[Version]:
def assign_hash(path):
- # Use this to create recognized hashes for new files
- x, info = wav_to_np(path, info=True)
- rate = info.rate
- if rate != REQUIRED_RATE:
- return None
- # Times of intervals, in seconds
- t_blips = 1
- t_sweep = 3
- t_white = 3
- t_validation = 9
- # v1 and v2 start with 1 blips, sine sweeps, and white noise
- start_hash = hashlib.md5(
- x[: (t_blips + t_sweep + t_white) * rate]
- ).hexdigest()
- # v1 ends with validation signal
- end_hash_v1 = hashlib.md5(x[-t_validation * rate :]).hexdigest()
- # v2 ends with 2x validation & blips
- end_hash_v2 = hashlib.md5(
- x[-(2 * t_validation + t_blips) * rate :]
- ).hexdigest()
- return start_hash, end_hash_v1, end_hash_v2
-
- start_hash, end_hash_v1, end_hash_v2 = assign_hash(input_path)
+ def assign_hashes_v1(path) -> Tuple[Optional[str], Optional[str]]:
+ # Use this to create recognized hashes for new files
+ x, info = wav_to_np(path, info=True)
+ rate = info.rate
+ if rate != _V1_DATA_INFO.rate:
+ return None, None
+ # Times of intervals, in seconds
+ t_blips = _V1_DATA_INFO.t_blips
+ t_sweep = 3 * rate
+ t_white = 3 * rate
+ t_validation = _V1_DATA_INFO.t_validate
+ # v1 and v2 start with 1 blips, sine sweeps, and white noise
+ start_hash = hashlib.md5(x[: t_blips + t_sweep + t_white]).hexdigest()
+ # v1 ends with validation signal
+ end_hash = hashlib.md5(x[-t_validation:]).hexdigest()
+ return start_hash, end_hash
+
+ def assign_hashes_v2(path):
+ # Use this to create recognized hashes for new files
+ x, info = wav_to_np(path, info=True)
+ rate = info.rate
+ if rate != _V2_DATA_INFO.rate:
+ return None, None
+ # Times of intervals, in seconds
+ t_blips = _V2_DATA_INFO.t_blips
+ t_sweep = 3 * rate
+ t_white = 3 * rate
+ t_validation = _V1_DATA_INFO.t_validate
+ # v1 and v2 start with 1 blips, sine sweeps, and white noise
+ start_hash = hashlib.md5(x[: (t_blips + t_sweep + t_white)]).hexdigest()
+ # v2 ends with 2x validation & blips
+ end_hash = hashlib.md5(x[-(2 * t_validation + t_blips) :]).hexdigest()
+ return start_hash, end_hash
+
+ start_hash_v1, end_hash_v1 = assign_hashes_v1(path)
+ start_hash_v2, end_hash_v2 = assign_hashes_v2(path)
+ return start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2
+
+ start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2 = assign_hash(input_path)
print(
"Weak hashes:\n"
- f" Start: {start_hash}\n"
- f" End (v1): {end_hash_v1}\n"
- f" End (v2): {end_hash_v2}\n",
+ f" Start (v1) : {start_hash_v1}\n"
+ f" End (v1) : {end_hash_v1}\n"
+ f" Start (v2) : {start_hash_v2}\n"
+ f" End (v2) : {end_hash_v2}\n",
)
# Check for v2 matches first
version = {
(
- "068a17d92274a136807523baad4913ff",
- "74f924e8b245c8f7dce007765911545a",
+ "1c4d94fbcb47e4d820bef611c1d4ae65",
+ "28694e7bf9ab3f8ae6ef86e9545d4663",
): Version(2, 0, 0)
- }.get((start_hash, end_hash_v2))
+ }.get((start_hash_v2, end_hash_v2))
if version is not None:
return version
# Fallback to v1
@@ -111,50 +132,96 @@ def _detect_input_version(input_path) -> Version:
(
"bb4e140c9299bae67560d280917eb52b",
"9b2468fcb6e9460a399fc5f64389d353",
- ): Version(1, 0, 0),
+ ): Version(
+ 1, 0, 0
+ ), # FIXME!
(
"9f20c6b5f7fef68dd88307625a573a14",
"8458126969a3f9d8e19a53554eb1fd52",
): Version(1, 1, 1),
- }.get((start_hash, end_hash_v1))
+ }.get((start_hash_v1, end_hash_v1))
return version
version = detect_strong(input_path)
if version is not None:
- return version
+ strong_match = True
+ return version, strong_match
print("Falling back to weak-matching...")
version = detect_weak(input_path)
if version is None:
raise ValueError(
f"Input file at {input_path} cannot be recognized as any known version!"
)
- return version
+ strong_match = False
+ return version, strong_match
-_V1_BLIP_LOCATIONS = 12_000, 36_000
-_V2_START_BLIP_LOCATIONS = _V1_BLIP_LOCATIONS
-_V2_END_BLIP_LOCATIONS = -36_000, -12_000
+class _DataInfo(BaseModel):
+ """
+ :param major_version: Data major version
+ :param rate: Sample rate, in Hz
+ :param t_blips: How long the blips are, in seconds
+ :param t_validate: Validation signal length, in samples
+ :param validation_start: Where validation signal starts, in samples. Less than zero
+ (from the end of the array).
+ :param noise_interval: Inside which we quantify the noise level
+ :param start_blip_locations: In samples
+ :param end_blip_locations: In samples, negative (from end)
+ """
+
+ major_version: int
+ rate: Optional[int]
+ t_blips: int
+ t_validate: int
+ validation_start: int
+ noise_interval: Tuple[int, int]
+ start_blip_locations: Sequence[int]
+ end_blip_locations: Optional[Sequence[int]]
+
+
+_V1_DATA_INFO = _DataInfo(
+ major_version=1,
+ rate=REQUIRED_RATE,
+ t_blips=48_000,
+ t_validate=432_000,
+ validation_start=-432_000,
+ noise_interval=(0, 6000),
+ start_blip_locations=(12_000, 36_000),
+ end_blip_locations=None,
+)
+_V2_DATA_INFO = _DataInfo(
+ major_version=2,
+ rate=REQUIRED_RATE,
+ t_blips=96_000,
+ t_validate=432_000,
+ validation_start=-960_000, # 96_000 + 2 * 432_000
+ noise_interval=(12_000, 18_000),
+ start_blip_locations=(24_000, 72_000),
+ end_blip_locations=(-72_000, -24_000),
+)
+
_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0001
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
+_DELAY_CALIBRATION_SAFETY_FACTOR = 4
-def _calibrate_delay_v1(
- input_path, output_path, locations: Sequence[int] = _V1_BLIP_LOCATIONS
-) -> int:
+def _calibrate_delay_v_all(data_info: _DataInfo, input_path, output_path) -> int:
lookahead = 1_000
lookback = 10_000
- safety_factor = 4
+ safety_factor = _DELAY_CALIBRATION_SAFETY_FACTOR
# Calibrate the trigger:
- y = wav_to_np(output_path)[:48_000]
- background_level = np.max(np.abs(y[:6_000]))
+ y = wav_to_np(output_path)[: data_info.t_blips]
+ background_level = np.max(
+ np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]])
+ )
trigger_threshold = max(
background_level + _DELAY_CALIBRATION_ABS_THRESHOLD,
(1.0 + _DELAY_CALIBRATION_REL_THRESHOLD) * background_level,
)
delays = []
- for blip_index, i in enumerate(locations, 1):
+ for blip_index, i in enumerate(data_info.start_blip_locations, 1):
start_looking = i - lookahead
stop_looking = i + lookback
y_scan = y[start_looking:stop_looking]
@@ -191,25 +258,26 @@ def _calibrate_delay_v1(
return delay
-def _calibrate_delay_v2(
- input_path, output_path, locations: Sequence[int] = _V2_START_BLIP_LOCATIONS
-) -> int:
- return _calibrate_delay_v1(input_path, output_path, locations=locations)
+_calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO)
+_calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO)
-def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True):
+def _plot_delay_v_all(
+ data_info: _DataInfo, delay: int, input_path: str, output_path: str, _nofail=True
+):
print("Plotting the delay for manual inspection...")
- x = wav_to_np(input_path)[:48_000]
- y = wav_to_np(output_path)[:48_000]
- i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] # In case resampled poorly
+ x = wav_to_np(input_path)[: data_info.t_blips]
+ y = wav_to_np(output_path)[: data_info.t_blips]
+ # Only get the blips we really want.
+ i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0]
if len(i) == 0:
print("Failed to find the spike in the input file.")
print(
"Plotting the input and output; there should be spikes at around the "
"marked locations."
)
- expected_spikes = 12_000, 36_000 # For v1 specifically
- fig, axs = plt.subplots(2, 1)
+ expected_spikes = data_info.start_blip_locations # For v1 specifically
+ fig, axs = plt.subplots(len((x, y)), 1)
for ax, curve in zip(axs, (x, y)):
ax.plot(curve)
[ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes]
@@ -217,22 +285,24 @@ def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True):
if _nofail:
raise RuntimeError("Failed to plot delay")
else:
- i = i[0]
- di = 20
plt.figure()
- # plt.plot(x[i - di : i + di], ".-", label="Input")
- plt.plot(
- np.arange(-di, di),
- y[i - di + delay : i + di + delay],
- ".-",
- label="Output",
- )
- plt.axvline(x=0, linestyle="--", color="C1")
+ di = 20
+ if data_info.major_version == 1:
+ i = [i[0]]
+ for e, ii in enumerate(i, 1):
+ plt.plot(
+ np.arange(-di, di),
+ y[ii - di + delay : ii + di + delay],
+ ".-",
+ label=f"Output {e}",
+ )
+ plt.axvline(x=0, linestyle="--", color="k")
plt.legend()
plt.show() # This doesn't freeze the notebook
-_plot_delay_v2 = _plot_delay_v1
+_plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO)
+_plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO)
def _calibrate_delay(
@@ -290,10 +360,12 @@ def _check_v1(*args, **kwargs):
def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
with torch.no_grad():
print("V2 checks...")
- rate = REQUIRED_RATE
+ rate = _V2_DATA_INFO.rate
y = wav_to_tensor(output_path, rate=rate)
- y_val_1 = y[-19 * rate : -10 * rate]
- y_val_2 = y[-10 * rate : -1 * rate]
+ t_blips = _V2_DATA_INFO.t_blips
+ t_validate = _V2_DATA_INFO.t_validate
+ y_val_1 = y[-(t_blips + 2 * t_validate) : -(t_blips + t_validate)]
+ y_val_2 = y[-(t_blips + t_validate) : -t_blips]
esr_replicate = esr(y_val_1, y_val_2).item()
print(f"Replicate ESR is {esr_replicate:.8f}.")
@@ -305,8 +377,8 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
"""
:return: [start/end,replicate]
"""
- i0, i1 = rate // 4, 3 * rate // 4
- j0, j1 = -3 * rate // 4, -rate // 4
+ i0, i1 = _V2_DATA_INFO.start_blip_locations
+ j0, j1 = _V2_DATA_INFO.end_blip_locations
i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)]
start = -10
@@ -496,26 +568,23 @@ def _get_configs(
lr_decay: float,
batch_size: int,
):
- def get_kwargs():
- val_seconds = 9
- rate = REQUIRED_RATE
- if input_version.major == 1:
- train_val_split = -val_seconds * rate
+ def get_kwargs(data_info: _DataInfo):
+ if data_info.major_version == 1:
+ train_val_split = data_info.validation_start
train_kwargs = {"stop": train_val_split}
validation_kwargs = {"start": train_val_split}
- elif input_version.major == 2:
- blip_seconds = 1
- val_replicates = 2
- train_stop = -(blip_seconds + val_replicates * val_seconds) * rate
- validation_start = train_stop
- validation_stop = -(blip_seconds + val_seconds) * rate
+ elif data_info.major_version == 2:
+ validation_start = data_info.validation_start
+ train_stop = validation_start
+ validation_stop = validation_start + data_info.t_validate
train_kwargs = {"stop": train_stop}
validation_kwargs = {"start": validation_start, "stop": validation_stop}
else:
raise NotImplementedError(f"kwargs for input version {input_version}")
return train_kwargs, validation_kwargs
- train_kwargs, validation_kwargs = get_kwargs()
+ data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO}[input_version.major]
+ train_kwargs, validation_kwargs = get_kwargs(data_info)
data_config = {
"train": {"ny": ny, **train_kwargs},
"validation": {"ny": None, **validation_kwargs},
@@ -707,7 +776,7 @@ def train(
torch.manual_seed(seed)
if input_version is None:
- input_version = _detect_input_version(input_path)
+ input_version, strong_match = _detect_input_version(input_path)
if delay is None:
delay = _calibrate_delay(
diff --git a/tests/__init__.py b/tests/__init__.py
diff --git a/tests/resources/.gitignore b/tests/resources/.gitignore
@@ -0,0 +1 @@
+*.wav
+\ No newline at end of file
diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py
@@ -0,0 +1,26 @@
+# File: __init__.py
+# Created Date: Thursday May 18th 2023
+# Author: Steven Atkinson ([email protected])
+
+from pathlib import Path
+
+import pytest
+
+__all__ = ["requires_v1_0_0", "requires_v1_1_1", "requires_v2_0_0", "resource_path"]
+
+
+def _requires_v(name: str):
+ path = Path(__file__).parent / Path(name)
+ return pytest.mark.skipif(
+ not path.exists(),
+ reason=f"Requires {name}, which hasn't been downloaded to {path}.",
+ )
+
+
+requires_v1_0_0 = _requires_v("v1.wav")
+requires_v1_1_1 = _requires_v("v1_1_1.wav")
+requires_v2_0_0 = _requires_v("v2_0_0.wav")
+
+
+def resource_path(name: str) -> Path:
+ return Path(__file__).absolute().parent / Path(name)
diff --git a/tests/test_bin/test_train/test_main.py b/tests/test_bin/test_train/test_main.py
@@ -13,7 +13,7 @@ import numpy as np
import pytest
import torch
-from nam.data import np_to_wav
+from nam.data import REQUIRED_RATE, np_to_wav
_BIN_TRAIN_MAIN_PY_PATH = Path(__file__).absolute().parent.parent.parent.parent / Path(
"bin", "train", "main.py"
@@ -67,6 +67,7 @@ class Test(object):
"x_path": str(self._x_path(root_path)),
"y_path": str(self._y_path(root_path)),
"delay": 0,
+ "require_input_pre_silence": None,
},
}
stage_channels = (3, 2)
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -119,12 +119,20 @@ class TestDataset(object):
),
)
def test_validate_start(self, n: int, start: int, valid: bool):
+ """
+ Assert that a data set can be successfully instantiated when valid args are
+ given, including `start`.
+ Assert that `StartError` is raised if invalid start is provided
+ """
+
def init():
data.Dataset(x, y, nx, ny, start=start)
nx = 1
ny = None
x, y = self._create_xy(n=n)
+ if start is not None:
+ x[:start] = 0.0 # Ensure silent input before the start
if valid:
init()
assert True # No problem!
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -0,0 +1,155 @@
+# File: test_core.py
+# Created Date: Thursday May 18th 2023
+# Author: Steven Atkinson ([email protected])
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import numpy as np
+import pytest
+
+from nam.data import (
+ Dataset,
+ np_to_wav,
+ wav_to_np,
+ wav_to_tensor,
+ _DEFAULT_REQUIRE_INPUT_PRE_SILENCE,
+)
+from nam.train import core
+from nam.train._version import Version
+
+from ...resources import (
+ requires_v1_0_0,
+ requires_v1_1_1,
+ requires_v2_0_0,
+ resource_path,
+)
+
+__all__ = []
+
+
+def _resource_path(version: Version) -> Path:
+ if version == Version(1, 0, 0):
+ name = "v1.wav"
+ else:
+ name = f'v{str(version).replace(".", "_")}.wav'
+ return resource_path(name)
+
+
+class TestDetectInputVersion(object):
+ @requires_v1_0_0
+ def test_detect_input_version_v1_0_0_strong(self):
+ self._t_detect_input_version_strong(Version(1, 0, 0))
+
+ @requires_v1_1_1
+ def test_detect_input_version_v1_1_1_strong(self):
+ self._t_detect_input_version_strong(Version(1, 1, 1))
+
+ @requires_v2_0_0
+ def test_detect_input_version_v2_0_0_strong(self):
+ self._t_detect_input_version_strong(Version(2, 0, 0))
+
+ @requires_v1_0_0
+ def test_detect_input_version_v1_0_0_weak(self):
+ self._t_detect_input_version_weak(Version(1, 0, 0))
+
+ @requires_v1_1_1
+ def test_detect_input_version_v1_1_1_weak(self):
+ self._t_detect_input_version_weak(Version(1, 1, 1))
+
+ @requires_v2_0_0
+ def test_detect_input_version_v2_0_0_weak(self):
+ self._t_detect_input_version_weak(Version(2, 0, 0))
+
+ @classmethod
+ def _customize_resource(cls, path_in, path_out):
+ x, info = wav_to_np(path_in, info=True)
+ # Should be safe...
+ i = info.rate * 60
+ y = np.concatenate([x[:i], np.zeros((1,)), x[i:]])
+ np_to_wav(y, path_out)
+
+ @classmethod
+ def _t_detect_input_version(
+ cls,
+ path: Path,
+ expected_input_version: Version,
+ expected_strong_match: bool,
+ ):
+ input_version, strong_match = core._detect_input_version(path)
+ assert input_version == expected_input_version
+ assert strong_match == expected_strong_match
+
+ @classmethod
+ def _t_detect_input_version_strong(cls, version: Version):
+ cls._t_detect_input_version(_resource_path(version), version, True)
+
+ @classmethod
+ def _t_detect_input_version_weak(cls, version: Version):
+ with TemporaryDirectory() as tmpdir:
+ path = Path(tmpdir, "temp.wav")
+ cls._customize_resource(_resource_path(version), path)
+ cls._t_detect_input_version(path, version, False)
+
+
+class _TCalibrateDelay(object):
+ _calibrate_delay = None
+ _data_info: core._DataInfo = None
+
+ @pytest.mark.parametrize("expected_delay", (-10, 0, 5, 100))
+ def test_calibrate_delay(self, expected_delay: int):
+ x = np.zeros((self._data_info.t_blips))
+ for i in self._data_info.start_blip_locations:
+ x[i + expected_delay] = 1.0
+ with TemporaryDirectory() as tmpdir:
+ path = Path(tmpdir, "output.wav")
+ np_to_wav(x, path)
+ delay = self._calibrate_delay(None, path)
+ assert delay == expected_delay - core._DELAY_CALIBRATION_SAFETY_FACTOR
+
+
+class TestCalibrateDelayV1(_TCalibrateDelay):
+ _calibrate_delay = core._calibrate_delay_v1
+ _data_info = core._V1_DATA_INFO
+
+
+class TestCalibrateDelayV2(_TCalibrateDelay):
+ _calibrate_delay = core._calibrate_delay_v2
+ _data_info = core._V2_DATA_INFO
+
+
+def _make_t_validation_dataset_class(
+ version: Version, decorator, data_info: core._DataInfo
+):
+ class C(object):
+ @decorator
+ def test_validation_preceded_by_silence(self):
+ """
+ Validate that the datasets that we've made are valid
+ """
+ x = wav_to_tensor(_resource_path(version))
+ Dataset._validate_preceding_silence(
+ x,
+ data_info.validation_start,
+ int(_DEFAULT_REQUIRE_INPUT_PRE_SILENCE * data_info.rate),
+ )
+
+ return C
+
+
+TestValidationDatasetV1_0_0 = _make_t_validation_dataset_class(
+ Version(1, 0, 0), requires_v1_0_0, core._V1_DATA_INFO
+)
+
+
+TestValidationDatasetV1_1_1 = _make_t_validation_dataset_class(
+ Version(1, 1, 1), requires_v1_1_1, core._V1_DATA_INFO
+)
+
+
+TestValidationDatasetV2_0_0 = _make_t_validation_dataset_class(
+ Version(2, 0, 0), requires_v2_0_0, core._V2_DATA_INFO
+)
+
+if __name__ == "__main__":
+ pytest.main()