commit e3d1d0b7ba51d08978875641a9dfc2b7ff2c2e66
parent 62a8f7e47582cc0137831eaa494d99e1659ae0f7
Author: Steven Atkinson <[email protected]>
Date: Mon, 6 Nov 2023 17:17:32 -0800
[FEATURE] V3 reamp file (#330)
* First work on V3 reamp file
* Add v3 checks, remove v2 blip check failure
* Update Colab code
* Fix f-string
* Fix bug
* Update README.md
* Update README.md
* Tests
* Update README.md
* Automatic use of most recent reamp file name in GUI
* Fix bug
Diffstat:
8 files changed, 278 insertions(+), 65 deletions(-)
diff --git a/README.md b/README.md
@@ -126,7 +126,8 @@ NAM can train using any paired audio files, but the simplified trainers (Colab a
You can use any of the following files:
-* [v2_0_0.wav](https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link) (preferred)
+* [v3_0_0.wav](https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link) (preferred)
+* [v2_0_0.wav](https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link)
* [v1_1_1.wav](https://drive.google.com/file/d/1CMj2uv_x8GIs-3X1reo7squHOVfkOa6s/view?usp=drive_link)
* [v1.wav](https://drive.google.com/file/d/1jxwTHOCx3Zf03DggAsuDTcVqsgokNyhm/view?usp=drive_link)
diff --git a/bin/train/easy_colab.ipynb b/bin/train/easy_colab.ipynb
@@ -42,7 +42,7 @@
"\n",
"### Step 1.1: Download the capture signal\n",
"\"Easy mode\" uses a pre-crafted \"capture signal\".\n",
- "Download it here: [v2_0_0.wav](https://drive.google.com/file/d/1EPUJNbVXtRnCwqVQtPRUgz3jiexNmvMD/view?usp=drive_link).\n",
+ "Download it here: [v3_0_0.wav](https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link).\n",
"\n",
"### Step 1.2 Reamp your gear\n",
"Then reamp the gear you want to model using it. Save that reamp as \"output.wav\".\n",
@@ -114,7 +114,7 @@
"gear_model = \"GearAmp\" #@param {type:\"string\"}\n",
"#@markdown Gear type:\n",
"# This needs to be a literal. You need to change it by hand if you change the enum!\n",
- "gear_type = f\"Pick from: amp, pedal, pedal_amp, amp_cab, amp_pedal_cab, preamp, studio\" #@param {type:\"string\"}\n",
+ "gear_type = \"Pick from: amp, pedal, pedal_amp, amp_cab, amp_pedal_cab, preamp, studio\" #@param {type:\"string\"}\n",
"#@markdown Tone type:\n",
"tone_type = \"Pick from: clean, overdrive, crunch, hi_gain, fuzz\" #@param {type:\"string\"}\n",
"\n",
diff --git a/nam/train/_names.py b/nam/train/_names.py
@@ -0,0 +1,25 @@
+# File: _names.py
+# Created Date: Monday November 6th 2023
+# Author: Steven Atkinson ([email protected])
+
+from typing import NamedTuple
+
+from ._version import Version
+
+__all__ = ["INPUT_BASENAMES", "LATEST_VERSION", "VersionAndName"]
+
+
+class VersionAndName(NamedTuple):
+ version: Version
+ name: str
+
+
+# From the most to the least recently released
+INPUT_BASENAMES = (
+ VersionAndName(Version(3, 0, 0), "v3_0_0.wav"),
+ VersionAndName(Version(2, 0, 0), "v2_0_0.wav"),
+ VersionAndName(Version(1, 1, 1), "v1_1_1.wav"),
+ VersionAndName(Version(1, 0, 0), "v1.wav"),
+)
+
+LATEST_VERSION = INPUT_BASENAMES[0]
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -8,19 +8,14 @@ Hide the mess in Colab to make things look pretty for users.
from pathlib import Path
-from typing import Optional, Tuple
+from typing import NamedTuple, Optional, Tuple
from ..models.metadata import UserMetadata
+from ._names import INPUT_BASENAMES, LATEST_VERSION, Version
from ._version import Version
from .core import train
-_INPUT_BASENAMES = (
- (Version(2, 0, 0), "v2_0_0.wav"),
- (Version(1, 1, 1), "v1_1_1.wav"),
- (Version(1, 0, 0), "v1.wav"),
-)
-_LATEST_VERSION = _INPUT_BASENAMES[0]
_BUGGY_INPUT_BASENAMES = {
# 1.1.0 has the spikes at the wrong spots.
"v1_1_0.wav"
@@ -30,24 +25,25 @@ _TRAIN_PATH = "."
def _check_for_files() -> Tuple[Version, str]:
+ # TODO use hash logic as in GUI trainer!
print("Checking that we have all of the required audio files...")
for name in _BUGGY_INPUT_BASENAMES:
if Path(name).exists():
raise RuntimeError(
- f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}"
+ f"Detected input signal {name} that has known bugs. Please download the latest input signal, {LATEST_VERSION[1]}"
)
- for input_version, input_basename in _INPUT_BASENAMES:
+ for input_version, input_basename in INPUT_BASENAMES:
if Path(input_basename).exists():
- if input_version != _LATEST_VERSION[0]:
+ if input_version != LATEST_VERSION.version:
print(
f"WARNING: Using out-of-date input file {input_basename}. "
"Recommend downloading and using the latest version, "
- f"{_LATEST_VERSION[1]}."
+ f"{LATEST_VERSION.name}."
)
break
else:
raise FileNotFoundError(
- f"Didn't find NAM's input audio file. Please upload {_LATEST_VERSION[1]}"
+ f"Didn't find NAM's input audio file. Please upload {LATEST_VERSION.name}"
)
if not Path(_OUTPUT_BASENAME).exists():
raise FileNotFoundError(
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -62,6 +62,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
"4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
"ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
+ "36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0),
}.get(file_hash)
if version is None:
print(
@@ -72,7 +73,12 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
def detect_weak(input_path) -> Optional[Version]:
def assign_hash(path):
- def assign_hashes_v1(path) -> Tuple[Optional[str], Optional[str]]:
+ Hashes = Tuple[Optional[str], Optional[str]]
+
+ def _hash(x: np.ndarray) -> str:
+ return hashlib.md5(x).hexdigest()
+
+ def assign_hashes_v1(path) -> Hashes:
# Use this to create recognized hashes for new files
x, info = wav_to_np(path, info=True)
rate = info.rate
@@ -84,12 +90,12 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
t_white = 3 * rate
t_validation = _V1_DATA_INFO.t_validate
# v1 and v2 start with 1 blips, sine sweeps, and white noise
- start_hash = hashlib.md5(x[: t_blips + t_sweep + t_white]).hexdigest()
+ start_hash = _hash(x[: t_blips + t_sweep + t_white])
# v1 ends with validation signal
- end_hash = hashlib.md5(x[-t_validation:]).hexdigest()
+ end_hash = _hash(x[-t_validation:])
return start_hash, end_hash
- def assign_hashes_v2(path):
+ def assign_hashes_v2(path) -> Hashes:
# Use this to create recognized hashes for new files
x, info = wav_to_np(path, info=True)
rate = info.rate
@@ -101,25 +107,64 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
t_white = 3 * rate
t_validation = _V1_DATA_INFO.t_validate
# v1 and v2 start with 1 blips, sine sweeps, and white noise
- start_hash = hashlib.md5(x[: (t_blips + t_sweep + t_white)]).hexdigest()
+ start_hash = _hash(x[: (t_blips + t_sweep + t_white)])
# v2 ends with 2x validation & blips
- end_hash = hashlib.md5(x[-(2 * t_validation + t_blips) :]).hexdigest()
+ end_hash = _hash(x[-(2 * t_validation + t_blips) :])
+ return start_hash, end_hash
+
+ def assign_hashes_v3(path) -> Hashes:
+ # Use this to create recognized hashes for new files
+ x, info = wav_to_np(path, info=True)
+ rate = info.rate
+ if rate != _V3_DATA_INFO.rate:
+ return None, None
+ # Interval boundaries, in samples (17 s from the start; 9 s from the end).
+ # See the V3 layout notes by `_V3_DATA_INFO` below.
+ end_of_start_interval = 17 * rate # Start at 0
+ start_of_end_interval = -9 * rate
+ start_hash = _hash(x[:end_of_start_interval])
+ end_hash = _hash(x[start_of_end_interval:])
return start_hash, end_hash
start_hash_v1, end_hash_v1 = assign_hashes_v1(path)
start_hash_v2, end_hash_v2 = assign_hashes_v2(path)
- return start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2
+ start_hash_v3, end_hash_v3 = assign_hashes_v3(path)
+ return (
+ start_hash_v1,
+ end_hash_v1,
+ start_hash_v2,
+ end_hash_v2,
+ start_hash_v3,
+ end_hash_v3,
+ )
- start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2 = assign_hash(input_path)
+ (
+ start_hash_v1,
+ end_hash_v1,
+ start_hash_v2,
+ end_hash_v2,
+ start_hash_v3,
+ end_hash_v3,
+ ) = assign_hash(input_path)
print(
"Weak hashes:\n"
f" Start (v1) : {start_hash_v1}\n"
f" End (v1) : {end_hash_v1}\n"
f" Start (v2) : {start_hash_v2}\n"
- f" End (v2) : {end_hash_v2}\n",
+ f" End (v2) : {end_hash_v2}\n"
+ f" Start (v3) : {start_hash_v3}\n"
+ f" End (v3) : {end_hash_v3}\n"
)
- # Check for v2 matches first
+ # Check for matches, starting with most recent
+ version = {
+ (
+ "dadb5d62f6c3973a59bf01439799809b",
+ "8458126969a3f9d8e19a53554eb1fd52",
+ ): Version(3, 0, 0)
+ }.get((start_hash_v3, end_hash_v3))
+ if version is not None:
+ return version
version = {
(
"1c4d94fbcb47e4d820bef611c1d4ae65",
@@ -128,7 +173,6 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
}.get((start_hash_v2, end_hash_v2))
if version is not None:
return version
- # Fallback to v1
version = {
(
"bb4e140c9299bae67560d280917eb52b",
@@ -162,43 +206,78 @@ class _DataInfo(BaseModel):
:param major_version: Data major version
:param rate: Sample rate, in Hz
:param t_blips: How long the blips are, in samples
+ :param first_blips_start: When the first blips section starts, in samples
:param t_validate: Validation signal length, in samples
+ :param train_start: Where training signal starts, in samples.
:param validation_start: Where validation signal starts, in samples. Less than zero
(from the end of the array).
:param noise_interval: Inside which we quantify the noise level
- :param start_blip_locations: In samples
- :param end_blip_locations: In samples, negative (from end)
+ :param blip_locations: In samples, absolute location in the file. Negative values
+ mean from the end instead of from the start (typical "Python" negative
+ indexing).
"""
major_version: int
rate: Optional[int]
t_blips: int
+ first_blips_start: int
t_validate: int
+ train_start: int
validation_start: int
noise_interval: Tuple[int, int]
- start_blip_locations: Sequence[int]
- end_blip_locations: Optional[Sequence[int]]
+ blip_locations: Sequence[Sequence[int]]
_V1_DATA_INFO = _DataInfo(
major_version=1,
rate=REQUIRED_RATE,
t_blips=48_000,
+ first_blips_start=0,
t_validate=432_000,
+ train_start=0,
validation_start=-432_000,
noise_interval=(0, 6000),
- start_blip_locations=(12_000, 36_000),
- end_blip_locations=None,
+ blip_locations=((12_000, 36_000),),
)
+# V2:
+# (0:00-0:02) Blips at 0:00.5 and 0:01.5
+# (0:02-0:05) Chirps
+# (0:05-0:07) Noise
+# (0:07-2:50.5) General training data
+# (2:50.5-2:51) Silence
+# (2:51-3:00) Validation 1
+# (3:00-3:09) Validation 2
+# (3:09-3:11) Blips at 3:09.5 and 3:10.5
_V2_DATA_INFO = _DataInfo(
major_version=2,
rate=REQUIRED_RATE,
t_blips=96_000,
+ first_blips_start=0,
t_validate=432_000,
+ train_start=0,
validation_start=-960_000, # 96_000 + 2 * 432_000
noise_interval=(12_000, 18_000),
- start_blip_locations=(24_000, 72_000),
- end_blip_locations=(-72_000, -24_000),
+ blip_locations=((24_000, 72_000), (-72_000, -24_000)),
+)
+# V3:
+# (0:00-0:09) Validation 1
+# (0:09-0:10) Silence
+# (0:10-0:12) Blips at 0:10.5 and 0:11.5
+# (0:12-0:15) Chirps
+# (0:15-0:17) Noise
+# (0:17-3:00.5) General training data
+# (3:00.5-3:01) Silence
+# (3:01-3:10) Validation 2
+_V3_DATA_INFO = _DataInfo(
+ major_version=3,
+ rate=REQUIRED_RATE,
+ t_blips=96_000,
+ first_blips_start=480_000,
+ t_validate=432_000,
+ train_start=480_000,
+ validation_start=-432_000,
+ noise_interval=(492_000, 498_000),
+ blip_locations=((504_000, 552_000),),
)
_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
@@ -221,6 +300,13 @@ def _calibrate_delay_v_all(
rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD,
safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR,
) -> int:
+ """
+ Calibrate the delay in the input-output pair based on blips.
+ This only uses the blips in the first set of blip locations!
+
+ :param y: The output audio, in its entirety.
+ """
+
def report_any_delay_warnings(delays: Sequence[int]):
# Warnings associated with any single delay:
@@ -240,9 +326,15 @@ def _calibrate_delay_v_all(
lookahead = 1_000
lookback = 10_000
# Calibrate the trigger:
- y = y[: data_info.t_blips]
+ y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips]
background_level = np.max(
- np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]])
+ np.abs(
+ y[
+ data_info.noise_interval[0]
+ - data_info.first_blips_start : data_info.noise_interval[1]
+ - data_info.first_blips_start
+ ]
+ )
)
trigger_threshold = max(
background_level + abs_threshold,
@@ -250,9 +342,11 @@ def _calibrate_delay_v_all(
)
delays = []
- for blip_index, i in enumerate(data_info.start_blip_locations, 1):
- start_looking = i - lookahead
- stop_looking = i + lookback
+ for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1):
+ # Relative to start of the data
+ i_rel = i_abs - data_info.first_blips_start
+ start_looking = i_rel - lookahead
+ stop_looking = i_rel + lookback
y_scan = y[start_looking:stop_looking]
triggered = np.where(np.abs(y_scan) > trigger_threshold)[0]
if len(triggered) == 0:
@@ -277,11 +371,11 @@ def _calibrate_delay_v_all(
raise RuntimeError(msg)
else:
j = triggered[0]
- delays.append(j + start_looking - i)
+ delays.append(j + start_looking - i_rel)
print("Delays:")
- for i, d in enumerate(delays, 1):
- print(f" Blip {i:2d}: {d}")
+ for i_rel, d in enumerate(delays, 1):
+ print(f" Blip {i_rel:2d}: {d}")
report_any_delay_warnings(delays)
delay = int(np.min(delays)) - safety_factor
@@ -291,14 +385,19 @@ def _calibrate_delay_v_all(
_calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO)
_calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO)
+_calibrate_delay_v3 = partial(_calibrate_delay_v_all, _V3_DATA_INFO)
def _plot_delay_v_all(
data_info: _DataInfo, delay: int, input_path: str, output_path: str, _nofail=True
):
print("Plotting the delay for manual inspection...")
- x = wav_to_np(input_path)[: data_info.t_blips]
- y = wav_to_np(output_path)[: data_info.t_blips]
+ x = wav_to_np(input_path)[
+ data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips
+ ]
+ y = wav_to_np(output_path)[
+ data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips
+ ]
# Only get the blips we really want.
i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0]
if len(i) == 0:
@@ -307,10 +406,13 @@ def _plot_delay_v_all(
"Plotting the input and output; there should be spikes at around the "
"marked locations."
)
- expected_spikes = data_info.start_blip_locations # For v1 specifically
+ t = np.arange(
+ data_info.first_blips_start, data_info.first_blips_start + data_info.t_blips
+ )
+ expected_spikes = data_info.blip_locations[0] # For v1 specifically
fig, axs = plt.subplots(len((x, y)), 1)
for ax, curve in zip(axs, (x, y)):
- ax.plot(curve)
+ ax.plot(t, curve)
[ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes]
plt.show()
if _nofail:
@@ -318,6 +420,7 @@ def _plot_delay_v_all(
else:
plt.figure()
di = 20
+ # V1 has a longer plateau rather than a spike; take the front of it.
if data_info.major_version == 1:
i = [i[0]]
for e, ii in enumerate(i, 1):
@@ -334,6 +437,7 @@ def _plot_delay_v_all(
_plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO)
_plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO)
+_plot_delay_v3 = partial(_plot_delay_v_all, _V3_DATA_INFO)
def _calibrate_delay(
@@ -347,6 +451,8 @@ def _calibrate_delay(
calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
elif input_version.major == 2:
calibrate, plot = _calibrate_delay_v2, _plot_delay_v2
+ elif input_version.major == 3:
+ calibrate, plot = _calibrate_delay_v3, _plot_delay_v3
else:
raise NotImplementedError(
f"Input calibration not implemented for input version {input_version}"
@@ -394,6 +500,18 @@ def _check_v1(*args, **kwargs):
return True
+def _esr_validation_replicate_msg(threshold: float) -> str:
+ return (
+ f"Validation replicates have a self-ESR of over {threshold}. "
+ "Your gear doesn't sound like itself when played twice!\n\n"
+ "Possible causes:"
+ " * Your signal chain is too noisy."
+ " * There's a time-based effect (chorus, delay, reverb) turned on."
+ " * Some knob got moved while reamping."
+ " * You started reamping before the amp had time to warm up fully."
+ )
+
+
def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
with torch.no_grad():
print("V2 checks...")
@@ -405,6 +523,9 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
y_val_2 = y[-(t_blips + t_validate) : -t_blips]
esr_replicate = esr(y_val_1, y_val_2).item()
print(f"Replicate ESR is {esr_replicate:.8f}.")
+ esr_replicate_threshold = 0.01
+ if esr_replicate > esr_replicate_threshold:
+ print(_esr_validation_replicate_msg(esr_replicate_threshold))
# Do the blips line up?
# If the ESR is too bad, then flag it.
@@ -414,8 +535,8 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
"""
:return: [start/end,replicate]
"""
- i0, i1 = _V2_DATA_INFO.start_blip_locations
- j0, j1 = _V2_DATA_INFO.end_blip_locations
+ i0, i1 = _V2_DATA_INFO.blip_locations[0]
+ j0, j1 = _V2_DATA_INFO.blip_locations[1]
i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)]
start = -10
@@ -442,8 +563,16 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
esr_threshold = 1.0e-2
- def plot_esr_blip_error(silent, msg, arrays, labels):
- if not silent:
+ def plot_esr_blip_error(
+ show_plot: bool,
+ msg: str,
+ arrays: Sequence[Sequence[float]],
+ labels: Sequence[str],
+ ):
+ """
+ :param show_plot: Whether to make and show a plot about it
+ """
+ if show_plot:
plt.figure()
[plt.plot(array, label=label) for array, label in zip(arrays, labels)]
plt.xlabel("Sample")
@@ -451,43 +580,72 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
plt.legend()
plt.grid()
print(msg)
- if not silent:
+ if show_plot:
plt.show()
+ print(
+ "This is known to be a very sensitive test, so training will continue. "
+ "If the model doesn't look good, then this may be why!"
+ )
# Check consecutive blips
+ show_blip_plots = False
for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")):
if e >= esr_threshold:
plot_esr_blip_error(
- silent,
+ show_blip_plots,
f"Failed consecutive blip check at {when} of training signal. The "
"target tone doesn't seem to be replicable over short timespans."
"\n\n"
" Possible causes:\n\n"
" * Your recording setup is really noisy.\n"
" * There's a noise gate that's messing things up.\n"
- " * There's a time-based effect (compressor, delay, reverb) in "
+ " * There's a time-based effect (chorus, delay, reverb) in "
"the signal chain",
blip_pair,
("Replicate 1", "Replicate 2"),
)
- return False
+ # return False # Stop bothering me! :(
# Check blips between start & end of train signal
for e, blip_pair, replicate in zip(
(esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2)
):
if e >= esr_threshold:
plot_esr_blip_error(
- silent,
+ show_blip_plots,
f"Failed start-to-end blip check for blip replicate {replicate}. "
"The target tone doesn't seem to be same at the end of the reamp "
"as it was at the start. Did some setting change during reamping?",
blip_pair,
(f"Start, replicate {replicate}", f"End, replicate {replicate}"),
)
- return False
+ # return False # Stop bothering me! :(
return True
+def _check_v3(input_path, output_path, *args, **kwargs) -> bool:
+ with torch.no_grad():
+ print("V3 checks...")
+ rate = _V3_DATA_INFO.rate
+ y = wav_to_tensor(output_path, rate=rate)
+ y_val_1 = y[: _V3_DATA_INFO.t_validate]
+ y_val_2 = y[-_V3_DATA_INFO.t_validate :]
+ esr_replicate = esr(y_val_1, y_val_2).item()
+ print(f"Replicate ESR is {esr_replicate:.8f}.")
+ esr_replicate_threshold = 0.01
+ if esr_replicate > esr_replicate_threshold:
+ print(_esr_validation_replicate_msg(esr_replicate_threshold))
+ plt.figure()
+ t = np.arange(len(y_val_1)) / rate
+ plt.plot(t, y_val_1, label="Validation 1")
+ plt.plot(t, y_val_2, label="Validation 2")
+ plt.xlabel("Time (sec)")
+ plt.legend()
+ plt.title("V3 check: Validation replicate FAILURE")
+ plt.show()
+ return False
+ return True
+
+
def _check(
input_path: str, output_path: str, input_version: Version, delay: int, silent: bool
) -> bool:
@@ -500,6 +658,8 @@ def _check(
f = _check_v1
elif input_version.major == 2:
f = _check_v2
+ elif input_version.major == 3:
+ f = _check_v3
else:
print(f"Checks not implemented for input version {input_version}; skip")
return True
@@ -648,11 +808,18 @@ def _get_configs(
validation_stop = validation_start + data_info.t_validate
train_kwargs = {"stop": train_stop}
validation_kwargs = {"start": validation_start, "stop": validation_stop}
+ elif data_info.major_version == 3:
+ validation_start = data_info.validation_start
+ train_stop = validation_start
+ train_kwargs = {"start": 480_000, "stop": train_stop}
+ validation_kwargs = {"start": validation_start}
else:
raise NotImplementedError(f"kwargs for input version {input_version}")
return train_kwargs, validation_kwargs
- data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO}[input_version.major]
+ data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO, 3: _V3_DATA_INFO}[
+ input_version.major
+ ]
train_kwargs, validation_kwargs = get_kwargs(data_info)
data_config = {
"train": {"ny": ny, **train_kwargs},
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -170,7 +170,7 @@ class _GUI(object):
self._path_button_input = _PathButton(
self._frame_input_path,
"Input Audio",
- "Select input DI file (eg: v1_1_1.wav)",
+ "Select input DI file (eg: v3_0_0.wav)",
_PathType.FILE,
hooks=[self._check_button_states],
)
diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py
@@ -12,7 +12,13 @@ from pathlib import Path
import pytest
-__all__ = ["requires_v1_0_0", "requires_v1_1_1", "requires_v2_0_0", "resource_path"]
+__all__ = [
+ "requires_v1_0_0",
+ "requires_v1_1_1",
+ "requires_v2_0_0",
+ "requires_v3_0_0",
+ "resource_path",
+]
def _requires_v(name: str):
@@ -26,6 +32,7 @@ def _requires_v(name: str):
requires_v1_0_0 = _requires_v("v1.wav")
requires_v1_1_1 = _requires_v("v1_1_1.wav")
requires_v2_0_0 = _requires_v("v2_0_0.wav")
+requires_v3_0_0 = _requires_v("v3_0_0.wav")
def resource_path(name: str) -> Path:
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -24,6 +24,7 @@ from ...resources import (
requires_v1_0_0,
requires_v1_1_1,
requires_v2_0_0,
+ requires_v3_0_0,
resource_path,
)
@@ -100,8 +101,12 @@ class _TCalibrateDelay(object):
@pytest.mark.parametrize("expected_delay", (-10, 0, 5, 100))
def test_calibrate_delay(self, expected_delay: int):
- x = np.zeros((self._data_info.t_blips))
- for i in self._data_info.start_blip_locations:
+ x = np.zeros((self._data_info.first_blips_start + self._data_info.t_blips,))
+ # This test only works with the first set of blip locations. Any other set of
+ # blip locations is used to check the data, not to calibrate the delay.
+ for i in self._data_info.blip_locations[0]:
+ # The blip locations are absolute in the file, not relative to the start of
+ # the blip section, so `first_blips_start` isn't used.
x[i + expected_delay] = 1.0
delay = self._calibrate_delay(x)
@@ -116,10 +121,11 @@ class _TCalibrateDelay(object):
"""
# Make the response loud enough to trigger the threshold everywhere.
- # Use the absolute threshold since the relative will be zero (I'll make it
- # silent where it's calibrated.)
+ # Use the absolute threshold since the relative will be zero (The signal will be
+ # zeroed next so it's silent where the thresholds are calibrated.)
y = np.full(
- (self._data_info.t_blips,), core._DELAY_CALIBRATION_ABS_THRESHOLD + 0.01
+ (self._data_info.first_blips_start + self._data_info.t_blips,),
+ core._DELAY_CALIBRATION_ABS_THRESHOLD + 0.01,
)
# Make the signal silent where the threshold is calibrated so the absolute
# threshold is used.
@@ -139,8 +145,9 @@ class _TCalibrateDelay(object):
with Capturing() as output:
self._calibrate_delay(y)
+ # `[0]` -- Only look in the first set of blip locations
expected_warning = core._warn_lookaheads(
- list(range(1, len(self._data_info.start_blip_locations) + 1))
+ list(range(1, len(self._data_info.blip_locations[0]) + 1))
)
assert any(o == expected_warning for o in output), output
@@ -155,6 +162,11 @@ class TestCalibrateDelayV2(_TCalibrateDelay):
_data_info = core._V2_DATA_INFO
+class TestCalibrateDelayV3(_TCalibrateDelay):
+ _calibrate_delay = core._calibrate_delay_v3
+ _data_info = core._V3_DATA_INFO
+
+
def _make_t_validation_dataset_class(
version: Version, decorator, data_info: core._DataInfo
):
@@ -189,5 +201,10 @@ TestValidationDatasetV2_0_0 = _make_t_validation_dataset_class(
)
+TestValidationDatasetV3_0_0 = _make_t_validation_dataset_class(
+ Version(3, 0, 0), requires_v3_0_0, core._V3_DATA_INFO
+)
+
+
if __name__ == "__main__":
pytest.main()