neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit e61ebb5b285b667d40232e071f35b488e296d8aa
parent 06547c8579cbec9f0eb4f8ca732083e91b672e4e
Author: Steven Atkinson <[email protected]>
Date:   Sat,  9 Sep 2023 18:32:45 -0700

Warn if calibrated delay is the maximum lookahead (#306)

* Warn if detected delay is the maximum lookahead

* Comment out snapshot export at end of training
Diffstat:
Mbin/train/main.py | 9++++++++-
Mnam/models/_base.py | 16++++++++++++----
Mnam/train/core.py | 48+++++++++++++++++++++++++++++++++---------------
Mtests/resources/__init__.py | 6++++++
Mtests/test_nam/test_train/test_core.py | 48+++++++++++++++++++++++++++++++++++++++++++-----
5 files changed, 102 insertions(+), 25 deletions(-)

diff --git a/bin/train/main.py b/bin/train/main.py @@ -87,7 +87,11 @@ def plot( args = (ds.vals, ds.x) if isinstance(ds, ParametricDataset) else (ds.x,) output = model(*args).flatten().cpu().numpy() t1 = time() - print(f"Took {t1 - t0:.2f} ({tx / (t1 - t0):.2f}x)") + try: + rt = f"{tx / (t1 - t0):.2f}" + except ZeroDivisionError as e: + rt = "???" + print(f"Took {t1 - t0:.2f} ({rt}x)") plt.figure(figsize=(16, 5)) # plt.plot(ds.x[window_start:window_end], label="Input") @@ -221,6 +225,9 @@ def main_inner( show=False, ) plot(model, dataset_validation, show=not no_show) + # Would like to, but this doesn't work for all cases. + # If you're making snapshot models, you may find this convenient to uncomment :) + # model.net.export(outdir) if __name__ == "__main__": diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -144,7 +144,9 @@ class BaseNet(_Base): if scalar: x = x[None] if pad_start: - x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1) + x = torch.cat( + (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 + ) y = self._forward(x, **kwargs) if scalar: y = y[0] @@ -176,7 +178,11 @@ class ParametricBaseNet(_Base): """ def forward( - self, params: torch.Tensor, x: torch.Tensor, pad_start: Optional[bool] = None + self, + params: torch.Tensor, + x: torch.Tensor, + pad_start: Optional[bool] = None, + **kwargs ): pad_start = self.pad_start_default if pad_start is None else pad_start scalar = x.ndim == 1 @@ -184,8 +190,10 @@ class ParametricBaseNet(_Base): x = x[None] params = params[None] if pad_start: - x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1) - y = self._forward(params, x) + x = torch.cat( + (torch.zeros((len(x), self.receptive_field - 1)).to(x.device), x), dim=1 + ) + y = self._forward(params, x, **kwargs) if scalar: y = y[0] return y diff --git a/nam/train/core.py b/nam/train/core.py @@ -161,7 +161,7 @@ class _DataInfo(BaseModel): """ :param major_version: Data major version :param 
rate: Sample rate, in Hz - :param t_blips: How long the blips are, in seconds + :param t_blips: How long the blips are, in samples :param t_validate: Validation signal length, in samples :param validation_start: Where validation signal starts, in samples. Less than zero (from the end of the array). @@ -206,19 +206,41 @@ _DELAY_CALIBRATION_REL_THRESHOLD = 0.001 _DELAY_CALIBRATION_SAFETY_FACTOR = 4 +def _warn_lookaheads(indices: Sequence[int]) -> str: + return ( + f"WARNING: delays from some blips ({','.join([str(i) for i in indices])}) are " + "at the minimum value possible. This usually means that something is " + "wrong with your data. Check if training ends with a poor result!" + ) + + def _calibrate_delay_v_all( data_info: _DataInfo, - input_path, - output_path, + y, abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD, rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD, safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR, ) -> int: + def report_any_delay_warnings(delays: Sequence[int]): + # Warnings associated with any single delay: + + lookahead_warnings = [i for i, d in enumerate(delays, 1) if d == -lookahead] + if len(lookahead_warnings) > 0: + print(_warn_lookaheads(lookahead_warnings)) + + # Ensemble warnings + + # If they're _really_ different, then something might be wrong. + if np.max(delays) - np.min(delays) >= 20: + print( + "WARNING: Delays are anomalously different from each other. If this model " + "turns out badly, then you might need to provide the delay manually." + ) + lookahead = 1_000 lookback = 10_000 - # Calibrate the trigger: - y = wav_to_np(output_path)[: data_info.t_blips] + y = y[: data_info.t_blips] background_level = np.max( np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]]) ) @@ -258,16 +280,12 @@ def _calibrate_delay_v_all( delays.append(j + start_looking - i) print("Delays:") - for d in delays: - print(f" {d}") - # If theyr'e _really_ different, then something might be wrong. 
- if np.max(delays) - np.min(delays) >= 20: - print( - "WARNING: Delays are anomalously different from each other. If this model " - "turns out badly, then you might need to provide the delay manually." - ) + for i, d in enumerate(delays, 1): + print(f" Blip {i:2d}: {d}") + report_any_delay_warnings(delays) + delay = int(np.min(delays)) - safety_factor - print(f"After aplying safety factor, final delay is {delay}") + print(f"After applying safety factor of {safety_factor}, the final delay is {delay}") return delay @@ -337,7 +355,7 @@ def _calibrate_delay( print(f"Delay is specified as {delay}") else: print("Delay wasn't provided; attempting to calibrate automatically...") - delay = calibrate(input_path, output_path) + delay = calibrate(wav_to_np(output_path)) if not silent: plot(delay, input_path, output_path) return delay diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py @@ -2,6 +2,12 @@ # Created Date: Thursday May 18th 2023 # Author: Steven Atkinson ([email protected]) +""" +Download the standardized reamping files to the directory containing this file. 
+See: +https://github.com/sdatkinson/neural-amp-modeler/tree/main#standardized-reamping-files +""" + from pathlib import Path import pytest diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -2,6 +2,8 @@ # Created Date: Thursday May 18th 2023 # Author: Steven Atkinson ([email protected]) +import sys +from io import StringIO from pathlib import Path from tempfile import TemporaryDirectory @@ -101,11 +103,46 @@ class _TCalibrateDelay(object): x = np.zeros((self._data_info.t_blips)) for i in self._data_info.start_blip_locations: x[i + expected_delay] = 1.0 - with TemporaryDirectory() as tmpdir: - path = Path(tmpdir, "output.wav") - np_to_wav(x, path) - delay = self._calibrate_delay(None, path) - assert delay == expected_delay - core._DELAY_CALIBRATION_SAFETY_FACTOR + + delay = self._calibrate_delay(x) + assert delay == expected_delay - core._DELAY_CALIBRATION_SAFETY_FACTOR + + def test_lookahead_warning(self): + """ + If the delay is equal to the (negative) lookahead, then something is probably wrong. + Assert that we're warned. + + See: https://github.com/sdatkinson/neural-amp-modeler/issues/304 + """ + + # Make the response loud enough to trigger the threshold everywhere. + # Use the absolute threshold since the relative will be zero (I'll make it + # silent where it's calibrated.) + y = np.full( + (self._data_info.t_blips,), core._DELAY_CALIBRATION_ABS_THRESHOLD + 0.01 + ) + # Make the signal silent where the threshold is calibrated so the absolute + # threshold is used. + y[self._data_info.noise_interval[0] : self._data_info.noise_interval[1]] = 0.0 + + # Prepare to capture the output and look for a warning. 
+ class Capturing(list): + def __enter__(self): + self._stdout = sys.stdout + sys.stdout = self._stringio = StringIO() + return self + + def __exit__(self, *args): + self.extend(self._stringio.getvalue().splitlines()) + del self._stringio + sys.stdout = self._stdout + + with Capturing() as output: + self._calibrate_delay(y) + expected_warning = core._warn_lookaheads( + list(range(1, len(self._data_info.start_blip_locations) + 1)) + ) + assert any(o == expected_warning for o in output), output class TestCalibrateDelayV1(_TCalibrateDelay): @@ -151,5 +188,6 @@ TestValidationDatasetV2_0_0 = _make_t_validation_dataset_class( Version(2, 0, 0), requires_v2_0_0, core._V2_DATA_INFO ) + if __name__ == "__main__": pytest.main()