neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 413d031b92e011ec0b3e6ab3b865b8632725a219
parent df664e0fa9d360f20051638f74fbaceb0814b144
Author: Steven Atkinson <[email protected]>
Date:   Mon, 13 Mar 2023 00:38:43 -0500

Better error messages (#125)

* More helpful error message when the input/output files aren't the same length.

* deepcopy btw

* Plot diagnostics when can't plot delay

* Better error reporting in delay calibration trigger

* More plotting tweaks
Diffstat:
Mnam/_core.py | 4+++-
Mnam/data.py | 71+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------
Mnam/train/core.py | 111++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------
3 files changed, 145 insertions(+), 41 deletions(-)

diff --git a/nam/_core.py b/nam/_core.py @@ -2,6 +2,8 @@ # Created Date: Saturday February 5th 2022 # Author: Steven Atkinson ([email protected]) +from copy import deepcopy + class InitializableFromConfig(object): @classmethod @@ -10,4 +12,4 @@ class InitializableFromConfig(object): @classmethod def parse_config(cls, config): - return config + return deepcopy(config) diff --git a/nam/data.py b/nam/data.py @@ -35,6 +35,26 @@ class WavInfo: rate: int +class AudioShapeMismatchError(ValueError): + """ + Exception where the shape (number of samples, number of channels) of two audio files + don't match but were supposed to. + """ + + def __init__(self, shape_expected, shape_actual, *args, **kwargs): + super().__init__(*args, **kwargs) + self._shape_expected = shape_expected + self._shape_actual = shape_actual + + @property + def shape_expected(self): + return self._shape_expected + + @property + def shape_actual(self): + return self._shape_actual + + def wav_to_np( filename: Union[str, Path], rate: Optional[int] = REQUIRED_RATE, @@ -70,8 +90,11 @@ def wav_to_np( arr_premono = x_wav.data[preroll:] / (2.0 ** (8 * x_wav.sampwidth - 1)) if required_shape is not None: if arr_premono.shape != required_shape: - raise ValueError( - f"Mismatched shapes {arr_premono.shape} versus {required_shape}" + raise AudioShapeMismatchError( + arr_premono.shape, + required_shape, + f"Mismatched shapes. Expected {required_shape}, but this is " + f"{arr_premono.shape}!", ) # sampwidth fine--we're just casting to 32-bit float anyways arr = arr_premono[:, 0] @@ -289,12 +312,44 @@ class Dataset(AbstractDataset, InitializableFromConfig): @classmethod def parse_config(cls, config): x, x_wavinfo = wav_to_tensor(config["x_path"], info=True) - y = wav_to_tensor( - config["y_path"], - preroll=config.get("y_preroll"), - required_shape=(len(x), 1), - required_wavinfo=x_wavinfo, - ) + rate = x_wavinfo.rate + try: + y = wav_to_tensor( + config["y_path"], + rate=rate, + preroll=config.get("y_preroll"), + required_shape=(len(x), 1), + required_wavinfo=x_wavinfo, + ) + except AudioShapeMismatchError as e: + # Really verbose message since users see this. + x_samples, x_channels = e.shape_expected + y_samples, y_channels = e.shape_actual + msg = "Your audio files aren't the same shape as each other!" + if x_channels != y_channels: + ctosm = {1: "mono", 2: "stereo"} + msg += f"\n * The input is {ctosm[x_channels]}, but the output is {ctosm[y_channels]}!" + if x_samples != y_samples: + + def sample_to_time(s, rate): + seconds = s // rate + remainder = s % rate + hours, minutes = 0, 0 + seconds_per_hour = 3600 + while seconds >= seconds_per_hour: + hours += 1 + seconds -= seconds_per_hour + seconds_per_minute = 60 + while seconds >= seconds_per_minute: + minutes += 1 + seconds -= seconds_per_minute + return ( + f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples" + ) + + msg += f"\n * The input is {sample_to_time(x_samples, rate)} long" + msg += f"\n * The output is {sample_to_time(y_samples, rate)} long" + raise ValueError(msg) return { "x": x, "y": y, diff --git a/nam/train/core.py b/nam/train/core.py @@ -54,51 +54,98 @@ def _detect_input_version(input_path) -> Version: return version +_V1_BLIP_LOCATIONS = 12_000, 36_000 + + def _calibrate_delay_v1(input_path, output_path) -> int: + lookahead = 1_000 + lookback = 10_000 safety_factor = 4 - # Locations of blips in v1 signal file: - i1, i2 = 12_000, 36_000 - j1_start_looking = i1 - 1_000 - j2_start_looking = i2 - 1_000 + # Calibrate the trigger: y = wav_to_np(output_path)[:48_000] - background_level = np.max(np.abs(y[:6_000])) trigger_threshold = max(background_level + 0.01, 1.01 * background_level) - j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][ - 0 - ] - j2 = np.where(np.abs(y[j2_start_looking:]) > trigger_threshold)[0][0] - - delay_1 = (j1 + j1_start_looking) - i1 - delay_2 = (j2 + j2_start_looking) - i2 - print(f"Delays: {delay_1}, {delay_2}") - delay = int(np.min([delay_1, delay_2])) - safety_factor - print(f"Final delay is {delay}") + + delays = [] + for blip_index, i in enumerate(_V1_BLIP_LOCATIONS, 1): + + start_looking = i - lookahead + stop_looking = i + lookback + y_scan = y[start_looking:stop_looking] + triggered = np.where(np.abs(y_scan) > trigger_threshold)[0] + if len(triggered) == 0: + msg = ( + f"No response activated the trigger in response to blip " + f"{blip_index}. Is something wrong with the reamp?" + ) + print(msg) + print("SHARE THIS PLOT IF YOU ASK FOR HELP") + plt.figure() + plt.plot(np.arange(-lookahead, lookback), y_scan, label="Signal") + plt.axvline(x=0, color="C1", linestyle="--", label="Trigger") + plt.axhline( + y=-trigger_threshold, color="k", linestyle="--", label="Threshold" + ) + plt.axhline(y=trigger_threshold, color="k", linestyle="--") + plt.xlim((-lookahead, lookback)) + plt.xlabel("Samples") + plt.ylabel("Response") + plt.legend() + plt.show() + raise RuntimeError(msg) + else: + j = triggered[0] + delays.append(j + start_looking - i) + + print("Delays:") + for d in delays: + print(" {d}") + delay = int(np.min(delays)) - safety_factor + print(f"After aplying safety factor, final delay is {delay}") return delay -def _plot_delay_v1(delay: int, input_path: str, output_path: str): +def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True): print("Plotting the delay for manual inspection...") x = wav_to_np(input_path)[:48_000] y = wav_to_np(output_path)[:48_000] - i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly - di = 20 - plt.figure() - # plt.plot(x[i - di : i + di], ".-", label="Input") - plt.plot( - np.arange(-di, di), - y[i - di + delay : i + di + delay], - ".-", - label="Output", - ) - plt.axvline(x=0, linestyle="--", color="C1") - plt.legend() - plt.show() # This doesn't freeze the notebook + i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] # In case resampled poorly + if len(i) == 0: + print("Failed to find the spike in the input file.") + print( + "Plotting the input and output; there should be spikes at around the " + "marked locations." + ) + expected_spikes = 12_000, 36_000 # For v1 specifically + fig, axs = plt.subplots(2, 1) + for ax, curve in zip(axs, (x, y)): + ax.plot(curve) + [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes] + plt.show() + if _nofail: + raise RuntimeError("Failed to plot delay") + else: + i = i[0] + di = 20 + plt.figure() + # plt.plot(x[i - di : i + di], ".-", label="Input") + plt.plot( + np.arange(-di, di), + y[i - di + delay : i + di + delay], + ".-", + label="Output", + ) + plt.axvline(x=0, linestyle="--", color="C1") + plt.legend() + plt.show() # This doesn't freeze the notebook def _calibrate_delay( - delay: Optional[int], input_version: Version, input_path: str, output_path: str, + delay: Optional[int], + input_version: Version, + input_path: str, + output_path: str, ) -> int: if input_version.major == 1: calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 @@ -226,7 +273,7 @@ def _get_configs( "name": "WaveNet", # This should do decently. If you really want a nice model, try turning up # "channels" in the first block and "input_size" in the second from 12 to 16. - "config": _get_wavenet_config(architecture) + "config": _get_wavenet_config(architecture), }, "loss": {"val_loss": "esr"}, "optimizer": {"lr": lr}, @@ -302,7 +349,7 @@ def train( input_version: Optional[Version] = None, epochs=100, delay=None, - architecture: Union[Architecture, str]=Architecture.STANDARD, + architecture: Union[Architecture, str] = Architecture.STANDARD, lr=0.004, lr_decay=0.007, seed: Optional[int] = 0,