commit 413d031b92e011ec0b3e6ab3b865b8632725a219
parent df664e0fa9d360f20051638f74fbaceb0814b144
Author: Steven Atkinson <[email protected]>
Date: Mon, 13 Mar 2023 00:38:43 -0500
Better error messages (#125)
* More helpful error message when the input/output files aren't the same length.
* deepcopy btw
* Plot diagnostics when can't plot delay
* Better error reporting in delay calibration trigger
* More plotting tweaks
Diffstat:
M | nam/_core.py | | | 4 | +++- |
M | nam/data.py | | | 71 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------- |
M | nam/train/core.py | | | 111 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------- |
3 files changed, 145 insertions(+), 41 deletions(-)
diff --git a/nam/_core.py b/nam/_core.py
@@ -2,6 +2,8 @@
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson ([email protected])
+from copy import deepcopy
+
class InitializableFromConfig(object):
@classmethod
@@ -10,4 +12,4 @@ class InitializableFromConfig(object):
@classmethod
def parse_config(cls, config):
- return config
+ return deepcopy(config)
diff --git a/nam/data.py b/nam/data.py
@@ -35,6 +35,26 @@ class WavInfo:
rate: int
+class AudioShapeMismatchError(ValueError):
+ """
+ Exception where the shape (number of samples, number of channels) of two audio files
+ don't match but were supposed to.
+ """
+
+ def __init__(self, shape_expected, shape_actual, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._shape_expected = shape_expected
+ self._shape_actual = shape_actual
+
+ @property
+ def shape_expected(self):
+ return self._shape_expected
+
+ @property
+ def shape_actual(self):
+ return self._shape_actual
+
+
def wav_to_np(
filename: Union[str, Path],
rate: Optional[int] = REQUIRED_RATE,
@@ -70,8 +90,11 @@ def wav_to_np(
arr_premono = x_wav.data[preroll:] / (2.0 ** (8 * x_wav.sampwidth - 1))
if required_shape is not None:
if arr_premono.shape != required_shape:
- raise ValueError(
- f"Mismatched shapes {arr_premono.shape} versus {required_shape}"
+ raise AudioShapeMismatchError(
+ arr_premono.shape,
+ required_shape,
+ f"Mismatched shapes. Expected {required_shape}, but this is "
+ f"{arr_premono.shape}!",
)
# sampwidth fine--we're just casting to 32-bit float anyways
arr = arr_premono[:, 0]
@@ -289,12 +312,44 @@ class Dataset(AbstractDataset, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
x, x_wavinfo = wav_to_tensor(config["x_path"], info=True)
- y = wav_to_tensor(
- config["y_path"],
- preroll=config.get("y_preroll"),
- required_shape=(len(x), 1),
- required_wavinfo=x_wavinfo,
- )
+ rate = x_wavinfo.rate
+ try:
+ y = wav_to_tensor(
+ config["y_path"],
+ rate=rate,
+ preroll=config.get("y_preroll"),
+ required_shape=(len(x), 1),
+ required_wavinfo=x_wavinfo,
+ )
+ except AudioShapeMismatchError as e:
+ # Really verbose message since users see this.
+ x_samples, x_channels = e.shape_expected
+ y_samples, y_channels = e.shape_actual
+ msg = "Your audio files aren't the same shape as each other!"
+ if x_channels != y_channels:
+ ctosm = {1: "mono", 2: "stereo"}
+ msg += f"\n * The input is {ctosm[x_channels]}, but the output is {ctosm[y_channels]}!"
+ if x_samples != y_samples:
+
+ def sample_to_time(s, rate):
+ seconds = s // rate
+ remainder = s % rate
+ hours, minutes = 0, 0
+ seconds_per_hour = 3600
+ while seconds >= seconds_per_hour:
+ hours += 1
+ seconds -= seconds_per_hour
+ seconds_per_minute = 60
+ while seconds >= seconds_per_minute:
+ minutes += 1
+ seconds -= seconds_per_minute
+ return (
+ f"{hours}:{minutes:02d}:{seconds:02d} and {remainder} samples"
+ )
+
+ msg += f"\n * The input is {sample_to_time(x_samples, rate)} long"
+ msg += f"\n * The output is {sample_to_time(y_samples, rate)} long"
+ raise ValueError(msg)
return {
"x": x,
"y": y,
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -54,51 +54,98 @@ def _detect_input_version(input_path) -> Version:
return version
+_V1_BLIP_LOCATIONS = 12_000, 36_000
+
+
def _calibrate_delay_v1(input_path, output_path) -> int:
+ lookahead = 1_000
+ lookback = 10_000
safety_factor = 4
- # Locations of blips in v1 signal file:
- i1, i2 = 12_000, 36_000
- j1_start_looking = i1 - 1_000
- j2_start_looking = i2 - 1_000
+ # Calibrate the trigger:
y = wav_to_np(output_path)[:48_000]
-
background_level = np.max(np.abs(y[:6_000]))
trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
- j1 = np.where(np.abs(y[j1_start_looking:j2_start_looking]) > trigger_threshold)[0][
- 0
- ]
- j2 = np.where(np.abs(y[j2_start_looking:]) > trigger_threshold)[0][0]
-
- delay_1 = (j1 + j1_start_looking) - i1
- delay_2 = (j2 + j2_start_looking) - i2
- print(f"Delays: {delay_1}, {delay_2}")
- delay = int(np.min([delay_1, delay_2])) - safety_factor
- print(f"Final delay is {delay}")
+
+ delays = []
+ for blip_index, i in enumerate(_V1_BLIP_LOCATIONS, 1):
+
+ start_looking = i - lookahead
+ stop_looking = i + lookback
+ y_scan = y[start_looking:stop_looking]
+ triggered = np.where(np.abs(y_scan) > trigger_threshold)[0]
+ if len(triggered) == 0:
+ msg = (
+ f"No response activated the trigger in response to blip "
+ f"{blip_index}. Is something wrong with the reamp?"
+ )
+ print(msg)
+ print("SHARE THIS PLOT IF YOU ASK FOR HELP")
+ plt.figure()
+ plt.plot(np.arange(-lookahead, lookback), y_scan, label="Signal")
+ plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
+ plt.axhline(
+ y=-trigger_threshold, color="k", linestyle="--", label="Threshold"
+ )
+ plt.axhline(y=trigger_threshold, color="k", linestyle="--")
+ plt.xlim((-lookahead, lookback))
+ plt.xlabel("Samples")
+ plt.ylabel("Response")
+ plt.legend()
+ plt.show()
+ raise RuntimeError(msg)
+ else:
+ j = triggered[0]
+ delays.append(j + start_looking - i)
+
+ print("Delays:")
+ for d in delays:
+ print(" {d}")
+ delay = int(np.min(delays)) - safety_factor
+ print(f"After aplying safety factor, final delay is {delay}")
return delay
-def _plot_delay_v1(delay: int, input_path: str, output_path: str):
+def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True):
print("Plotting the delay for manual inspection...")
x = wav_to_np(input_path)[:48_000]
y = wav_to_np(output_path)[:48_000]
- i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly
- di = 20
- plt.figure()
- # plt.plot(x[i - di : i + di], ".-", label="Input")
- plt.plot(
- np.arange(-di, di),
- y[i - di + delay : i + di + delay],
- ".-",
- label="Output",
- )
- plt.axvline(x=0, linestyle="--", color="C1")
- plt.legend()
- plt.show() # This doesn't freeze the notebook
+ i = np.where(np.abs(x) > 0.5 * np.abs(x).max())[0] # In case resampled poorly
+ if len(i) == 0:
+ print("Failed to find the spike in the input file.")
+ print(
+ "Plotting the input and output; there should be spikes at around the "
+ "marked locations."
+ )
+ expected_spikes = 12_000, 36_000 # For v1 specifically
+ fig, axs = plt.subplots(2, 1)
+ for ax, curve in zip(axs, (x, y)):
+ ax.plot(curve)
+ [ax.axvline(x=es, color="C1", linestyle="--") for es in expected_spikes]
+ plt.show()
+ if _nofail:
+ raise RuntimeError("Failed to plot delay")
+ else:
+ i = i[0]
+ di = 20
+ plt.figure()
+ # plt.plot(x[i - di : i + di], ".-", label="Input")
+ plt.plot(
+ np.arange(-di, di),
+ y[i - di + delay : i + di + delay],
+ ".-",
+ label="Output",
+ )
+ plt.axvline(x=0, linestyle="--", color="C1")
+ plt.legend()
+ plt.show() # This doesn't freeze the notebook
def _calibrate_delay(
- delay: Optional[int], input_version: Version, input_path: str, output_path: str,
+ delay: Optional[int],
+ input_version: Version,
+ input_path: str,
+ output_path: str,
) -> int:
if input_version.major == 1:
calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
@@ -226,7 +273,7 @@ def _get_configs(
"name": "WaveNet",
# This should do decently. If you really want a nice model, try turning up
# "channels" in the first block and "input_size" in the second from 12 to 16.
- "config": _get_wavenet_config(architecture)
+ "config": _get_wavenet_config(architecture),
},
"loss": {"val_loss": "esr"},
"optimizer": {"lr": lr},
@@ -302,7 +349,7 @@ def train(
input_version: Optional[Version] = None,
epochs=100,
delay=None,
- architecture: Union[Architecture, str]=Architecture.STANDARD,
+ architecture: Union[Architecture, str] = Architecture.STANDARD,
lr=0.004,
lr_decay=0.007,
seed: Optional[int] = 0,