commit 28f9e2e94c2b50101c2a2a24cc0cf8a7693f2267
parent f0e27f454a4615d1184710f518d2ad55ff13dd65
Author: Steven Atkinson <[email protected]>
Date: Sun, 6 Oct 2024 21:10:52 -0700
[ENHANCEMENT] Better latency calculation via averaging (#485)
* Take impulse response replicates' mean to denoise for latency calculation
* Reduce safety factor to 1
* Fix test
Diffstat:
2 files changed, 39 insertions(+), 37 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -333,7 +333,7 @@ _V4_DATA_INFO = _DataInfo(
_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
-_DELAY_CALIBRATION_SAFETY_FACTOR = 4
+_DELAY_CALIBRATION_SAFETY_FACTOR = 1 # Might be able to make this zero...
def _warn_lookaheads(indices: Sequence[int]) -> str:
@@ -391,7 +391,7 @@ def _calibrate_latency_v_all(
lookahead = 1_000
lookback = 10_000
- # Calibrate the trigger:
+ # Calibrate the level for the trigger:
y = y[data_info.first_blips_start : data_info.first_blips_start + data_info.t_blips]
background_level = np.max(
np.abs(
@@ -407,51 +407,53 @@ def _calibrate_latency_v_all(
(1.0 + rel_threshold) * background_level,
)
- delays = []
+ y_scans = []
for blip_index, i_abs in enumerate(data_info.blip_locations[0], 1):
# Relative to start of the data
i_rel = i_abs - data_info.first_blips_start
start_looking = i_rel - lookahead
stop_looking = i_rel + lookback
- y_scan = y[start_looking:stop_looking]
- triggered = np.where(np.abs(y_scan) > trigger_threshold)[0]
- if len(triggered) == 0:
- msg = (
- f"No response activated the trigger in response to blip "
- f"{blip_index}. Is something wrong with the reamp?"
- )
- print(msg)
- print("SHARE THIS PLOT IF YOU ASK FOR HELP")
- plt.figure()
- plt.plot(np.arange(-lookahead, lookback), y_scan, label="Signal")
- plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
- plt.axhline(
- y=-trigger_threshold, color="k", linestyle="--", label="Threshold"
- )
- plt.axhline(y=trigger_threshold, color="k", linestyle="--")
- plt.xlim((-lookahead, lookback))
- plt.xlabel("Samples")
- plt.ylabel("Response")
- plt.legend()
- plt.show()
- raise RuntimeError(msg)
- else:
- j = triggered[0]
- delays.append(j + start_looking - i_rel)
+ y_scans.append(y[start_looking:stop_looking])
+ y_scan_average = np.mean(np.stack(y_scans), axis=0)
+ triggered = np.where(np.abs(y_scan_average) > trigger_threshold)[0]
+ if len(triggered) == 0:
+ msg = (
+ "No response activated the trigger in response to input spikes. "
+ "Is something wrong with the reamp?"
+ )
+ print(msg)
+ print("SHARE THIS PLOT IF YOU ASK FOR HELP")
+ plt.figure()
+ plt.plot(np.arange(-lookahead, lookback), y_scan_average, color="C0", label="Signal average")
+ for y_scan in y_scans:
+ plt.plot(np.arange(-lookahead, lookback), y_scan, color="C0", alpha=0.2)
+ plt.axvline(x=0, color="C1", linestyle="--", label="Trigger")
+ plt.axhline(
+ y=-trigger_threshold, color="k", linestyle="--", label="Threshold"
+ )
+ plt.axhline(y=trigger_threshold, color="k", linestyle="--")
+ plt.xlim((-lookahead, lookback))
+ plt.xlabel("Samples")
+ plt.ylabel("Response")
+ plt.legend()
+ plt.title("SHARE THIS PLOT IF YOU ASK FOR HELP")
+ plt.show()
+ raise RuntimeError(msg)
+ else:
+ j = triggered[0]
+ delay = j + start_looking - i_rel
- print("Delays:")
- for i_rel, d in enumerate(delays, 1):
- print(f" Blip {i_rel:2d}: {d}")
- warnings = report_any_latency_warnings(delays)
+ print(f"Delay based on average is {delay}")
+ warnings = report_any_latency_warnings([delay])
- delay_post_safety_factor = int(np.min(delays)) - safety_factor
+ delay_post_safety_factor = delay - safety_factor
print(
f"After aplying safety factor of {safety_factor}, the final delay is "
f"{delay_post_safety_factor}"
)
return metadata.LatencyCalibration(
algorithm_version=1,
- delays=delays,
+ delays=[delay],
safety_factor=safety_factor,
recommended=delay_post_safety_factor,
warnings=warnings,
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -161,9 +161,9 @@ class _TCalibrateDelay(object):
with Capturing() as output:
self._calibrate_delay(y)
# `[0]` -- Only look in the first set of blip locations
- expected_warning = core._warn_lookaheads(
- list(range(1, len(self._data_info.blip_locations[0]) + 1))
- )
+ # With #485, we average them all together so there's only one index.
+ # TODO clean this up.
+ expected_warning = core._warn_lookaheads([1]) # "Blip 1"
assert any(o == expected_warning for o in output), output