neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit e3d1d0b7ba51d08978875641a9dfc2b7ff2c2e66
parent 62a8f7e47582cc0137831eaa494d99e1659ae0f7
Author: Steven Atkinson <[email protected]>
Date:   Mon,  6 Nov 2023 17:17:32 -0800

[FEATURE] V3 reamp file (#330)

* First work on V3 reamp file

* Add v3 checks, remove v2 blip check failure

* Update Colab code

* Fix f-string

* Fix bug

* Update README.md

* Update README.md

* Tests

* Update README.md

* Automatic use of most recent reamp file name in GUI

* Fix bug
Diffstat:
M README.md | 3 ++-
M bin/train/easy_colab.ipynb | 4 ++--
A nam/train/_names.py | 25 +++++++++++++++++++++++++
M nam/train/colab.py | 20 ++++++++------------
M nam/train/core.py | 251 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------
M nam/train/gui.py | 2 +-
M tests/resources/__init__.py | 9 ++++++++-
M tests/test_nam/test_train/test_core.py | 29 +++++++++++++++++++++++------
8 files changed, 278 insertions(+), 65 deletions(-)

diff --git a/README.md b/README.md @@ -126,7 +126,8 @@ NAM can train using any paired audio files, but the simplified trainers (Colab a You can use any of the following files: -* [v2_0_0.wav](https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link) (preferred) +* [v3_0_0.wav](https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link) (preferred) +* [v2_0_0.wav](https://drive.google.com/file/d/1xnyJP_IZ7NuyDSTJfn-Jmc5lw0IE7nfu/view?usp=drive_link) * [v1_1_1.wav](https://drive.google.com/file/d/1CMj2uv_x8GIs-3X1reo7squHOVfkOa6s/view?usp=drive_link) * [v1.wav](https://drive.google.com/file/d/1jxwTHOCx3Zf03DggAsuDTcVqsgokNyhm/view?usp=drive_link) diff --git a/bin/train/easy_colab.ipynb b/bin/train/easy_colab.ipynb @@ -42,7 +42,7 @@ "\n", "### Step 1.1: Download the capture signal\n", "\"Easy mode\" uses a pre-crafted \"capture signal\".\n", - "Download it here: [v2_0_0.wav](https://drive.google.com/file/d/1EPUJNbVXtRnCwqVQtPRUgz3jiexNmvMD/view?usp=drive_link).\n", + "Download it here: [v3_0_0.wav](https://drive.google.com/file/d/1Pgf8PdE0rKB1TD4TRPKbpNo1ByR3IOm9/view?usp=drive_link).\n", "\n", "### Step 1.2 Reamp your gear\n", "Then reamp the gear you want to model using it. Save that reamp as \"output.wav\".\n", @@ -114,7 +114,7 @@ "gear_model = \"GearAmp\" #@param {type:\"string\"}\n", "#@markdown Gear type:\n", "# This needs to be a literal. 
You need to change it by hand if you change the enum!\n", - "gear_type = f\"Pick from: amp, pedal, pedal_amp, amp_cab, amp_pedal_cab, preamp, studio\" #@param {type:\"string\"}\n", + "gear_type = \"Pick from: amp, pedal, pedal_amp, amp_cab, amp_pedal_cab, preamp, studio\" #@param {type:\"string\"}\n", "#@markdown Tone type:\n", "tone_type = \"Pick from: clean, overdrive, crunch, hi_gain, fuzz\" #@param {type:\"string\"}\n", "\n", diff --git a/nam/train/_names.py b/nam/train/_names.py @@ -0,0 +1,25 @@ +# File: _names.py +# Created Date: Monday November 6th 2023 +# Author: Steven Atkinson ([email protected]) + +from typing import NamedTuple + +from ._version import Version + +__all__ = ["INPUT_BASENAMES", "LATEST_VERSION", "VersionAndName"] + + +class VersionAndName(NamedTuple): + version: Version + name: str + + +# From most the least recently-released +INPUT_BASENAMES = ( + VersionAndName(Version(3, 0, 0), "v3_0_0.wav"), + VersionAndName(Version(2, 0, 0), "v2_0_0.wav"), + VersionAndName(Version(1, 1, 1), "v1_1_1.wav"), + VersionAndName(Version(1, 0, 0), "v1.wav"), +) + +LATEST_VERSION = INPUT_BASENAMES[0] diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -8,19 +8,14 @@ Hide the mess in Colab to make things look pretty for users. from pathlib import Path -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple from ..models.metadata import UserMetadata +from ._names import INPUT_BASENAMES, LATEST_VERSION, Version from ._version import Version from .core import train -_INPUT_BASENAMES = ( - (Version(2, 0, 0), "v2_0_0.wav"), - (Version(1, 1, 1), "v1_1_1.wav"), - (Version(1, 0, 0), "v1.wav"), -) -_LATEST_VERSION = _INPUT_BASENAMES[0] _BUGGY_INPUT_BASENAMES = { # 1.1.0 has the spikes at the wrong spots. "v1_1_0.wav" @@ -30,24 +25,25 @@ _TRAIN_PATH = "." def _check_for_files() -> Tuple[Version, str]: + # TODO use hash logic as in GUI trainer! 
print("Checking that we have all of the required audio files...") for name in _BUGGY_INPUT_BASENAMES: if Path(name).exists(): raise RuntimeError( - f"Detected input signal {name} that has known bugs. Please download the latest input signal, {_LATEST_VERSION[1]}" + f"Detected input signal {name} that has known bugs. Please download the latest input signal, {LATEST_VERSION[1]}" ) - for input_version, input_basename in _INPUT_BASENAMES: + for input_version, input_basename in INPUT_BASENAMES: if Path(input_basename).exists(): - if input_version != _LATEST_VERSION[0]: + if input_version != LATEST_VERSION.version: print( f"WARNING: Using out-of-date input file {input_basename}. " "Recommend downloading and using the latest version, " - f"{_LATEST_VERSION[1]}." + f"{LATEST_VERSION.name}." ) break else: raise FileNotFoundError( - f"Didn't find NAM's input audio file. Please upload {_LATEST_VERSION[1]}" + f"Didn't find NAM's input audio file. Please upload {LATEST_VERSION.name}" ) if not Path(_OUTPUT_BASENAME).exists(): raise FileNotFoundError( diff --git a/nam/train/core.py b/nam/train/core.py @@ -62,6 +62,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), "ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0), + "36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0), }.get(file_hash) if version is None: print( @@ -72,7 +73,12 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: def detect_weak(input_path) -> Optional[Version]: def assign_hash(path): - def assign_hashes_v1(path) -> Tuple[Optional[str], Optional[str]]: + Hashes = Tuple[Optional[str], Optional[str]] + + def _hash(x: np.ndarray) -> str: + return hashlib.md5(x).hexdigest() + + def assign_hashes_v1(path) -> Hashes: # Use this to create recognized hashes for new files x, info = wav_to_np(path, info=True) rate = info.rate @@ -84,12 +90,12 @@ def _detect_input_version(input_path) -> 
Tuple[Version, bool]: t_white = 3 * rate t_validation = _V1_DATA_INFO.t_validate # v1 and v2 start with 1 blips, sine sweeps, and white noise - start_hash = hashlib.md5(x[: t_blips + t_sweep + t_white]).hexdigest() + start_hash = _hash(x[: t_blips + t_sweep + t_white]) # v1 ends with validation signal - end_hash = hashlib.md5(x[-t_validation:]).hexdigest() + end_hash = _hash(x[-t_validation:]) return start_hash, end_hash - def assign_hashes_v2(path): + def assign_hashes_v2(path) -> Hashes: # Use this to create recognized hashes for new files x, info = wav_to_np(path, info=True) rate = info.rate @@ -101,25 +107,64 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: t_white = 3 * rate t_validation = _V1_DATA_INFO.t_validate # v1 and v2 start with 1 blips, sine sweeps, and white noise - start_hash = hashlib.md5(x[: (t_blips + t_sweep + t_white)]).hexdigest() + start_hash = _hash(x[: (t_blips + t_sweep + t_white)]) # v2 ends with 2x validation & blips - end_hash = hashlib.md5(x[-(2 * t_validation + t_blips) :]).hexdigest() + end_hash = _hash(x[-(2 * t_validation + t_blips) :]) + return start_hash, end_hash + + def assign_hashes_v3(path) -> Hashes: + # Use this to create recognized hashes for new files + x, info = wav_to_np(path, info=True) + rate = info.rate + if rate != _V3_DATA_INFO.rate: + return None, None + # Times of intervals, in seconds + # See below. 
+ end_of_start_interval = 17 * rate # Start at 0 + start_of_end_interval = -9 * rate + start_hash = _hash(x[:end_of_start_interval]) + end_hash = _hash(x[start_of_end_interval:]) return start_hash, end_hash start_hash_v1, end_hash_v1 = assign_hashes_v1(path) start_hash_v2, end_hash_v2 = assign_hashes_v2(path) - return start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2 + start_hash_v3, end_hash_v3 = assign_hashes_v3(path) + return ( + start_hash_v1, + end_hash_v1, + start_hash_v2, + end_hash_v2, + start_hash_v3, + end_hash_v3, + ) - start_hash_v1, end_hash_v1, start_hash_v2, end_hash_v2 = assign_hash(input_path) + ( + start_hash_v1, + end_hash_v1, + start_hash_v2, + end_hash_v2, + start_hash_v3, + end_hash_v3, + ) = assign_hash(input_path) print( "Weak hashes:\n" f" Start (v1) : {start_hash_v1}\n" f" End (v1) : {end_hash_v1}\n" f" Start (v2) : {start_hash_v2}\n" - f" End (v2) : {end_hash_v2}\n", + f" End (v2) : {end_hash_v2}\n" + f" Start (v3) : {start_hash_v3}\n" + f" End (v3) : {end_hash_v3}\n" ) - # Check for v2 matches first + # Check for matches, starting with most recent + version = { + ( + "dadb5d62f6c3973a59bf01439799809b", + "8458126969a3f9d8e19a53554eb1fd52", + ): Version(3, 0, 0) + }.get((start_hash_v3, end_hash_v3)) + if version is not None: + return version version = { ( "1c4d94fbcb47e4d820bef611c1d4ae65", @@ -128,7 +173,6 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: }.get((start_hash_v2, end_hash_v2)) if version is not None: return version - # Fallback to v1 version = { ( "bb4e140c9299bae67560d280917eb52b", @@ -162,43 +206,78 @@ class _DataInfo(BaseModel): :param major_version: Data major version :param rate: Sample rate, in Hz :param t_blips: How long the blips are, in samples + :param first_blips_start: When the first blips section starts, in samples :param t_validate: Validation signal length, in samples + :param train_start: Where training signal starts, in samples. 
:param validation_start: Where validation signal starts, in samples. Less than zero (from the end of the array). :param noise_interval: Inside which we quantify the noise level - :param start_blip_locations: In samples - :param end_blip_locations: In samples, negative (from end) + :param blip_locations: In samples, absolute location in the file. Negative values + mean from the end instead of from the start (typical "Python" negastive + indexing). """ major_version: int rate: Optional[int] t_blips: int + first_blips_start: int t_validate: int + train_start: int validation_start: int noise_interval: Tuple[int, int] - start_blip_locations: Sequence[int] - end_blip_locations: Optional[Sequence[int]] + blip_locations: Sequence[Sequence[int]] _V1_DATA_INFO = _DataInfo( major_version=1, rate=REQUIRED_RATE, t_blips=48_000, + first_blips_start=0, t_validate=432_000, + train_start=0, validation_start=-432_000, noise_interval=(0, 6000), - start_blip_locations=(12_000, 36_000), - end_blip_locations=None, + blip_locations=((12_000, 36_000),), ) +# V2: +# (0:00-0:02) Blips at 0:00.5 and 0:01.5 +# (0:02-0:05) Chirps +# (0:05-0:07) Noise +# (0:07-2:50.5) General training data +# (2:50.5-2:51) Silence +# (2:51-3:00) Validation 1 +# (3:00-3:09) Validation 2 +# (3:09-3:11) Blips at 3:09.5 and 3:10.5 _V2_DATA_INFO = _DataInfo( major_version=2, rate=REQUIRED_RATE, t_blips=96_000, + first_blips_start=0, t_validate=432_000, + train_start=0, validation_start=-960_000, # 96_000 + 2 * 432_000 noise_interval=(12_000, 18_000), - start_blip_locations=(24_000, 72_000), - end_blip_locations=(-72_000, -24_000), + blip_locations=((24_000, 72_000), (-72_000, -24_000)), +) +# V3: +# (0:00-0:09) Validation 1 +# (0:09-0:10) Silence +# (0:10-0:12) Blips at 0:10.5 and 0:11.5 +# (0:12-0:15) Chirps +# (0:15-0:17) Noise +# (0:17-3:00.5) General training data +# (3:00.5-3:01) Silence +# (3:01-3:10) Validation 2 +_V3_DATA_INFO = _DataInfo( + major_version=3, + rate=REQUIRED_RATE, + t_blips=96_000, + 
first_blips_start=480_000, + t_validate=432_000, + train_start=480_000, + validation_start=-432_000, + noise_interval=(492_000, 498_000), + blip_locations=((504_000, 552_000),), ) _DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003 @@ -221,6 +300,13 @@ def _calibrate_delay_v_all( rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD, safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR, ) -> int: + """ + Calibrate the delay in teh input-output pair based on blips. + This only uses the blips in the first set of blip locations! + + :param y: The output audio, in complete. + """ + def report_any_delay_warnings(delays: Sequence[int]): # Warnings associated with any single delay: @@ -240,9 +326,15 @@ def _calibrate_delay_v_all( lookahead = 1_000 lookback = 10_000 # Calibrate the trigger: - y = y[: data_info.t_blips] + y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips] background_level = np.max( - np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]]) + np.abs( + y[ + data_info.noise_interval[0] + - data_info.first_blips_start : data_info.noise_interval[1] + - data_info.first_blips_start + ] + ) ) trigger_threshold = max( background_level + abs_threshold, @@ -250,9 +342,11 @@ def _calibrate_delay_v_all( ) delays = [] - for blip_index, i in enumerate(data_info.start_blip_locations, 1): - start_looking = i - lookahead - stop_looking = i + lookback + for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1): + # Relative to start of the data + i_rel = i_abs - data_info.first_blips_start + start_looking = i_rel - lookahead + stop_looking = i_rel + lookback y_scan = y[start_looking:stop_looking] triggered = np.where(np.abs(y_scan) > trigger_threshold)[0] if len(triggered) == 0: @@ -277,11 +371,11 @@ def _calibrate_delay_v_all( raise RuntimeError(msg) else: j = triggered[0] - delays.append(j + start_looking - i) + delays.append(j + start_looking - i_rel) print("Delays:") - for i, d in enumerate(delays, 1): - print(f" Blip {i:2d}: {d}") + for 
i_rel, d in enumerate(delays, 1): + print(f" Blip {i_rel:2d}: {d}") report_any_delay_warnings(delays) delay = int(np.min(delays)) - safety_factor @@ -291,14 +385,19 @@ def _calibrate_delay_v_all( _calibrate_delay_v1 = partial(_calibrate_delay_v_all, _V1_DATA_INFO) _calibrate_delay_v2 = partial(_calibrate_delay_v_all, _V2_DATA_INFO) +_calibrate_delay_v3 = partial(_calibrate_delay_v_all, _V3_DATA_INFO) def _plot_delay_v_all( data_info: _DataInfo, delay: int, input_path: str, output_path: str, _nofail=True ): print("Plotting the delay for manual inspection...") - x = wav_to_np(input_path)[: data_info.t_blips] - y = wav_to_np(output_path)[: data_info.t_blips] + x = wav_to_np(input_path)[ + data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips + ] + y = wav_to_np(output_path)[ + data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips + ] # Only get the blips we really want. i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] if len(i) == 0: @@ -307,10 +406,13 @@ def _plot_delay_v_all( "Plotting the input and output; there should be spikes at around the " "marked locations." ) - expected_spikes = data_info.start_blip_locations # For v1 specifically + t = np.arange( + data_info.first_blips_start, data_info.first_blips_start + data_info.t_blips + ) + expected_spikes = data_info.blip_locations[0] # For v1 specifically fig, axs = plt.subplots(len((x, y)), 1) for ax, curve in zip(axs, (x, y)): - ax.plot(curve) + ax.plot(t, curve) [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes] plt.show() if _nofail: @@ -318,6 +420,7 @@ def _plot_delay_v_all( else: plt.figure() di = 20 + # V1's got not a spike but a longer plateau; take the front of it. 
if data_info.major_version == 1: i = [i[0]] for e, ii in enumerate(i, 1): @@ -334,6 +437,7 @@ def _plot_delay_v_all( _plot_delay_v1 = partial(_plot_delay_v_all, _V1_DATA_INFO) _plot_delay_v2 = partial(_plot_delay_v_all, _V2_DATA_INFO) +_plot_delay_v3 = partial(_plot_delay_v_all, _V3_DATA_INFO) def _calibrate_delay( @@ -347,6 +451,8 @@ def _calibrate_delay( calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 elif input_version.major == 2: calibrate, plot = _calibrate_delay_v2, _plot_delay_v2 + elif input_version.major == 3: + calibrate, plot = _calibrate_delay_v3, _plot_delay_v3 else: raise NotImplementedError( f"Input calibration not implemented for input version {input_version}" @@ -394,6 +500,18 @@ def _check_v1(*args, **kwargs): return True +def _esr_validation_replicate_msg(threshold: float) -> str: + return ( + f"Validation replicates have a self-ESR of over {threshold}. " + "Your gear doesn't sound like itself when played twice!\n\n" + "Possible causes:" + " * Your signal chain is too noisy." + " * There's a time-based effect (chorus, delay, reverb) turned on." + " * Some knob got moved while reamping." + " * You started reamping before the amp had time to warm up fully." + ) + + def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: with torch.no_grad(): print("V2 checks...") @@ -405,6 +523,9 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: y_val_2 = y[-(t_blips + t_validate) : -t_blips] esr_replicate = esr(y_val_1, y_val_2).item() print(f"Replicate ESR is {esr_replicate:.8f}.") + esr_replicate_threshold = 0.01 + if esr_replicate > esr_replicate_threshold: + print(_esr_validation_replicate_msg(esr_replicate_threshold)) # Do the blips line up? # If the ESR is too bad, then flag it. 
@@ -414,8 +535,8 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: """ :return: [start/end,replicate] """ - i0, i1 = _V2_DATA_INFO.start_blip_locations - j0, j1 = _V2_DATA_INFO.end_blip_locations + i0, i1 = _V2_DATA_INFO.blip_locations[0] + j0, j1 = _V2_DATA_INFO.blip_locations[1] i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)] start = -10 @@ -442,8 +563,16 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: esr_threshold = 1.0e-2 - def plot_esr_blip_error(silent, msg, arrays, labels): - if not silent: + def plot_esr_blip_error( + show_plot: bool, + msg: str, + arrays: Sequence[Sequence[float]], + labels: Sequence[str], + ): + """ + :param silent: Whether to make and show a plot about it + """ + if show_plot: plt.figure() [plt.plot(array, label=label) for array, label in zip(arrays, labels)] plt.xlabel("Sample") @@ -451,43 +580,72 @@ def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: plt.legend() plt.grid() print(msg) - if not silent: + if show_plot: plt.show() + print( + "This is known to be a very sensitive test, so training will continue. " + "If the model doesn't look good, then this may be why!" + ) # Check consecutive blips + show_blip_plots = False for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")): if e >= esr_threshold: plot_esr_blip_error( - silent, + show_blip_plots, f"Failed consecutive blip check at {when} of training signal. The " "target tone doesn't seem to be replicable over short timespans." "\n\n" " Possible causes:\n\n" " * Your recording setup is really noisy.\n" " * There's a noise gate that's messing things up.\n" - " * There's a time-based effect (compressor, delay, reverb) in " + " * There's a time-based effect (chorus, delay, reverb) in " "the signal chain", blip_pair, ("Replicate 1", "Replicate 2"), ) - return False + # return False # Stop bothering me! 
:( # Check blips between start & end of train signal for e, blip_pair, replicate in zip( (esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2) ): if e >= esr_threshold: plot_esr_blip_error( - silent, + show_blip_plots, f"Failed start-to-end blip check for blip replicate {replicate}. " "The target tone doesn't seem to be same at the end of the reamp " "as it was at the start. Did some setting change during reamping?", blip_pair, (f"Start, replicate {replicate}", f"End, replicate {replicate}"), ) - return False + # return False # Stop bothering me! :( return True +def _check_v3(input_path, output_path, *args, **kwargs) -> bool: + with torch.no_grad(): + print("V3 checks...") + rate = _V3_DATA_INFO.rate + y = wav_to_tensor(output_path, rate=rate) + y_val_1 = y[: _V3_DATA_INFO.t_validate] + y_val_2 = y[-_V3_DATA_INFO.t_validate :] + esr_replicate = esr(y_val_1, y_val_2).item() + print(f"Replicate ESR is {esr_replicate:.8f}.") + esr_replicate_threshold = 0.01 + if esr_replicate > esr_replicate_threshold: + print(_esr_validation_replicate_msg(esr_replicate_threshold)) + plt.figure() + t = np.arange(len(y_val_1)) / rate + plt.plot(t, y_val_1, label="Validation 1") + plt.plot(t, y_val_2, label="Validation 2") + plt.xlabel("Time (sec)") + plt.legend() + plt.title("V3 check: Validation replicate FAILURE") + plt.show() + return False + return True + + def _check( input_path: str, output_path: str, input_version: Version, delay: int, silent: bool ) -> bool: @@ -500,6 +658,8 @@ def _check( f = _check_v1 elif input_version.major == 2: f = _check_v2 + elif input_version.major == 3: + f = _check_v3 else: print(f"Checks not implemented for input version {input_version}; skip") return True @@ -648,11 +808,18 @@ def _get_configs( validation_stop = validation_start + data_info.t_validate train_kwargs = {"stop": train_stop} validation_kwargs = {"start": validation_start, "stop": validation_stop} + elif data_info.major_version == 3: + validation_start = data_info.validation_start + 
train_stop = validation_start + train_kwargs = {"start": 480_000, "stop": train_stop} + validation_kwargs = {"start": validation_start} else: raise NotImplementedError(f"kwargs for input version {input_version}") return train_kwargs, validation_kwargs - data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO}[input_version.major] + data_info = {1: _V1_DATA_INFO, 2: _V2_DATA_INFO, 3: _V3_DATA_INFO}[ + input_version.major + ] train_kwargs, validation_kwargs = get_kwargs(data_info) data_config = { "train": {"ny": ny, **train_kwargs}, diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -170,7 +170,7 @@ class _GUI(object): self._path_button_input = _PathButton( self._frame_input_path, "Input Audio", - "Select input DI file (eg: v1_1_1.wav)", + "Select input DI file (eg: v3_0_0.wav)", _PathType.FILE, hooks=[self._check_button_states], ) diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py @@ -12,7 +12,13 @@ from pathlib import Path import pytest -__all__ = ["requires_v1_0_0", "requires_v1_1_1", "requires_v2_0_0", "resource_path"] +__all__ = [ + "requires_v1_0_0", + "requires_v1_1_1", + "requires_v2_0_0", + "requires_v3_0_0", + "resource_path", +] def _requires_v(name: str): @@ -26,6 +32,7 @@ def _requires_v(name: str): requires_v1_0_0 = _requires_v("v1.wav") requires_v1_1_1 = _requires_v("v1_1_1.wav") requires_v2_0_0 = _requires_v("v2_0_0.wav") +requires_v3_0_0 = _requires_v("v3_0_0.wav") def resource_path(name: str) -> Path: diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -24,6 +24,7 @@ from ...resources import ( requires_v1_0_0, requires_v1_1_1, requires_v2_0_0, + requires_v3_0_0, resource_path, ) @@ -100,8 +101,12 @@ class _TCalibrateDelay(object): @pytest.mark.parametrize("expected_delay", (-10, 0, 5, 100)) def test_calibrate_delay(self, expected_delay: int): - x = np.zeros((self._data_info.t_blips)) - for i in self._data_info.start_blip_locations: + x = np.zeros((self._data_info.first_blips_start + 
self._data_info.t_blips,)) + # This test only works with the first set of blip locations. Any other set of + # blip locations is used to check the data, not to calibrate the delay. + for i in self._data_info.blip_locations[0]: + # The blip locations are absolute in the file, not relative to the start of + # the blip section, so `first_blips_start` isn't used. x[i + expected_delay] = 1.0 delay = self._calibrate_delay(x) @@ -116,10 +121,11 @@ class _TCalibrateDelay(object): """ # Make the response loud enough to trigger the threshold everywhere. - # Use the absolute threshold since the relative will be zero (I'll make it - # silent where it's calibrated.) + # Use the absolute threshold since the relative will be zero (The signal will be + # zeroed next so it's silent where the thresholds are calibrated.) y = np.full( - (self._data_info.t_blips,), core._DELAY_CALIBRATION_ABS_THRESHOLD + 0.01 + (self._data_info.first_blips_start + self._data_info.t_blips,), + core._DELAY_CALIBRATION_ABS_THRESHOLD + 0.01, ) # Make the signal silent where the threshold is calibrated so the absolute # threshold is used. 
@@ -139,8 +145,9 @@ class _TCalibrateDelay(object): with Capturing() as output: self._calibrate_delay(y) + # `[0]` -- Only look in the first set of blip locations expected_warning = core._warn_lookaheads( - list(range(1, len(self._data_info.start_blip_locations) + 1)) + list(range(1, len(self._data_info.blip_locations[0]) + 1)) ) assert any(o == expected_warning for o in output), output @@ -155,6 +162,11 @@ class TestCalibrateDelayV2(_TCalibrateDelay): _data_info = core._V2_DATA_INFO +class TestCalibrateDelayV3(_TCalibrateDelay): + _calibrate_delay = core._calibrate_delay_v3 + _data_info = core._V3_DATA_INFO + + def _make_t_validation_dataset_class( version: Version, decorator, data_info: core._DataInfo ): @@ -189,5 +201,10 @@ TestValidationDatasetV2_0_0 = _make_t_validation_dataset_class( ) +TestValidationDatasetV3_0_0 = _make_t_validation_dataset_class( + Version(3, 0, 0), requires_v3_0_0, core._V3_DATA_INFO +) + + if __name__ == "__main__": pytest.main()