commit e61ebb5b285b667d40232e071f35b488e296d8aa
parent 06547c8579cbec9f0eb4f8ca732083e91b672e4e
Author: Steven Atkinson <[email protected]>
Date: Sat, 9 Sep 2023 18:32:45 -0700
Warn if calibrated delay is the maximum lookahead (#306)
* Warn if detected delay is the maximum lookahead
* Comment out snapshot export at end of training
Diffstat:
5 files changed, 102 insertions(+), 25 deletions(-)
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -87,7 +87,11 @@ def plot(
args = (ds.vals, ds.x) if isinstance(ds, ParametricDataset) else (ds.x,)
output = model(*args).flatten().cpu().numpy()
t1 = time()
- print(f"Took {t1 - t0:.2f} ({tx / (t1 - t0):.2f}x)")
+ try:
+ rt = f"{tx / (t1 - t0):.2f}"
+ except ZeroDivisionError as e:
+ rt = "???"
+ print(f"Took {t1 - t0:.2f} ({rt}x)")
plt.figure(figsize=(16, 5))
# plt.plot(ds.x[window_start:window_end], label="Input")
@@ -221,6 +225,9 @@ def main_inner(
show=False,
)
plot(model, dataset_validation, show=not no_show)
+ # Would like to, but this doesn't work for all cases.
+ # If you're making snapshot models, you may find this convenient to uncomment :)
+ # model.net.export(outdir)
if __name__ == "__main__":
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -144,7 +144,9 @@ class BaseNet(_Base):
if scalar:
x = x[None]
if pad_start:
- x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
+ x = torch.cat(
+ (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
+ )
y = self._forward(x, **kwargs)
if scalar:
y = y[0]
@@ -176,7 +178,11 @@ class ParametricBaseNet(_Base):
"""
def forward(
- self, params: torch.Tensor, x: torch.Tensor, pad_start: Optional[bool] = None
+ self,
+ params: torch.Tensor,
+ x: torch.Tensor,
+ pad_start: Optional[bool] = None,
+ **kwargs
):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
@@ -184,8 +190,10 @@ class ParametricBaseNet(_Base):
x = x[None]
params = params[None]
if pad_start:
- x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
- y = self._forward(params, x)
+ x = torch.cat(
+ (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1
+ )
+ y = self._forward(params, x, **kwargs)
if scalar:
y = y[0]
return y
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -161,7 +161,7 @@ class _DataInfo(BaseModel):
"""
:param major_version: Data major version
:param rate: Sample rate, in Hz
- :param t_blips: How long the blips are, in seconds
+ :param t_blips: How long the blips are, in samples
:param t_validate: Validation signal length, in samples
:param validation_start: Where validation signal starts, in samples. Less than zero
(from the end of the array).
@@ -206,19 +206,41 @@ _DELAY_CALIBRATION_REL_THRESHOLD = 0.001
_DELAY_CALIBRATION_SAFETY_FACTOR = 4
+def _warn_lookaheads(indices: Sequence[int]) -> str:
+ return (
+ f"WARNING: delays from some blips ({','.join([str(i) for i in indices])}) are "
+ "at the minimum value possible. This usually means that something is "
+        "wrong with your data. Check if training ends with a poor result!"
+ )
+
+
def _calibrate_delay_v_all(
data_info: _DataInfo,
- input_path,
- output_path,
+ y,
abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD,
rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD,
safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR,
) -> int:
+ def report_any_delay_warnings(delays: Sequence[int]):
+ # Warnings associated with any single delay:
+
+ lookahead_warnings = [i for i, d in enumerate(delays, 1) if d == -lookahead]
+ if len(lookahead_warnings) > 0:
+ print(_warn_lookaheads(lookahead_warnings))
+
+ # Ensemble warnings
+
+ # If they're _really_ different, then something might be wrong.
+ if np.max(delays) - np.min(delays) >= 20:
+ print(
+ "WARNING: Delays are anomalously different from each other. If this model "
+ "turns out badly, then you might need to provide the delay manually."
+ )
+
lookahead = 1_000
lookback = 10_000
-
# Calibrate the trigger:
- y = wav_to_np(output_path)[: data_info.t_blips]
+ y = y[: data_info.t_blips]
background_level = np.max(
np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]])
)
@@ -258,16 +280,12 @@ def _calibrate_delay_v_all(
delays.append(j + start_looking - i)
print("Delays:")
- for d in delays:
- print(f" {d}")
- # If theyr'e _really_ different, then something might be wrong.
- if np.max(delays) - np.min(delays) >= 20:
- print(
- "WARNING: Delays are anomalously different from each other. If this model "
- "turns out badly, then you might need to provide the delay manually."
- )
+ for i, d in enumerate(delays, 1):
+ print(f" Blip {i:2d}: {d}")
+ report_any_delay_warnings(delays)
+
delay = int(np.min(delays)) - safety_factor
- print(f"After aplying safety factor, final delay is {delay}")
+    print(f"After applying safety factor of {safety_factor}, the final delay is {delay}")
return delay
@@ -337,7 +355,7 @@ def _calibrate_delay(
print(f"Delay is specified as {delay}")
else:
print("Delay wasn't provided; attempting to calibrate automatically...")
- delay = calibrate(input_path, output_path)
+ delay = calibrate(wav_to_np(output_path))
if not silent:
plot(delay, input_path, output_path)
return delay
diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py
@@ -2,6 +2,12 @@
# Created Date: Thursday May 18th 2023
# Author: Steven Atkinson ([email protected])
+"""
+Download the standardized reamping files to the directory containing this file.
+See:
+https://github.com/sdatkinson/neural-amp-modeler/tree/main#standardized-reamping-files
+"""
+
from pathlib import Path
import pytest
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -2,6 +2,8 @@
# Created Date: Thursday May 18th 2023
# Author: Steven Atkinson ([email protected])
+import sys
+from io import StringIO
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -101,11 +103,46 @@ class _TCalibrateDelay(object):
x = np.zeros((self._data_info.t_blips))
for i in self._data_info.start_blip_locations:
x[i + expected_delay] = 1.0
- with TemporaryDirectory() as tmpdir:
- path = Path(tmpdir, "output.wav")
- np_to_wav(x, path)
- delay = self._calibrate_delay(None, path)
- assert delay == expected_delay - core._DELAY_CALIBRATION_SAFETY_FACTOR
+
+ delay = self._calibrate_delay(x)
+ assert delay == expected_delay - core._DELAY_CALIBRATION_SAFETY_FACTOR
+
+ def test_lookahead_warning(self):
+ """
+ If the delay is equal to the (negative) lookahead, then something is probably wrong.
+ Assert that we're warned.
+
+ See: https://github.com/sdatkinson/neural-amp-modeler/issues/304
+ """
+
+ # Make the response loud enough to trigger the threshold everywhere.
+ # Use the absolute threshold since the relative will be zero (I'll make it
+ # silent where it's calibrated.)
+ y = np.full(
+ (self._data_info.t_blips,), core._DELAY_CALIBRATION_ABS_THRESHOLD + 0.01
+ )
+ # Make the signal silent where the threshold is calibrated so the absolute
+ # threshold is used.
+ y[self._data_info.noise_interval[0] : self._data_info.noise_interval[1]] = 0.0
+
+ # Prepare to capture the output and look for a warning.
+ class Capturing(list):
+ def __enter__(self):
+ self._stdout = sys.stdout
+ sys.stdout = self._stringio = StringIO()
+ return self
+
+ def __exit__(self, *args):
+ self.extend(self._stringio.getvalue().splitlines())
+ del self._stringio
+ sys.stdout = self._stdout
+
+ with Capturing() as output:
+ self._calibrate_delay(y)
+ expected_warning = core._warn_lookaheads(
+ list(range(1, len(self._data_info.start_blip_locations) + 1))
+ )
+ assert any(o == expected_warning for o in output), output
class TestCalibrateDelayV1(_TCalibrateDelay):
@@ -151,5 +188,6 @@ TestValidationDatasetV2_0_0 = _make_t_validation_dataset_class(
Version(2, 0, 0), requires_v2_0_0, core._V2_DATA_INFO
)
+
if __name__ == "__main__":
pytest.main()