commit 8d85f3b4e6a8ae109c3462ff78fccf783afcc434
parent 0270e365da1d0cdf45d9d5fd43887c7be6dcc80a
Author: Steven Atkinson <steven@atkinson.mn>
Date: Fri, 2 Jun 2023 21:36:51 -0700
Tune delay calculation threshold (#262)
* fit_cab option in Colab
* Refactor losses, add pre-emphasized MRSTFT
* Fit cab use pre-emphasized MRSTFT only
* Print input version found
* Fix pre-emph
* Squash bugs
* More bugs
* Define private constants for cab modeling MRSTFT
* Always report ESR on validation
* Fix ESR validation
* Fix loss calculation
* Restore MRSTFT weight
* Update gui.py
Put back "Cab modeling" checkbox.
Refactor checkboxes code
* Minor version bump
* Adjust threshold, warn if delays are too different
Diffstat:
1 file changed, 17 insertions(+), 5 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -200,15 +200,21 @@ _V2_DATA_INFO = _DataInfo(
end_blip_locations=(-72_000, -24_000),
)
-_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0001
+_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0003
_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
_DELAY_CALIBRATION_SAFETY_FACTOR = 4
-def _calibrate_delay_v_all(data_info: _DataInfo, input_path, output_path) -> int:
+def _calibrate_delay_v_all(
+ data_info: _DataInfo,
+ input_path,
+ output_path,
+ abs_threshold=_DELAY_CALIBRATION_ABS_THRESHOLD,
+ rel_threshold=_DELAY_CALIBRATION_REL_THRESHOLD,
+ safety_factor=_DELAY_CALIBRATION_SAFETY_FACTOR,
+) -> int:
lookahead = 1_000
lookback = 10_000
- safety_factor = _DELAY_CALIBRATION_SAFETY_FACTOR
# Calibrate the trigger:
y = wav_to_np(output_path)[: data_info.t_blips]
@@ -216,8 +222,8 @@ def _calibrate_delay_v_all(data_info: _DataInfo, input_path, output_path) -> int
np.abs(y[data_info.noise_interval[0] : data_info.noise_interval[1]])
)
trigger_threshold = max(
- background_level + _DELAY_CALIBRATION_ABS_THRESHOLD,
- (1.0 + _DELAY_CALIBRATION_REL_THRESHOLD) * background_level,
+ background_level + abs_threshold,
+ (1.0 + rel_threshold) * background_level,
)
delays = []
@@ -253,6 +259,12 @@ def _calibrate_delay_v_all(data_info: _DataInfo, input_path, output_path) -> int
print("Delays:")
for d in delays:
print(f" {d}")
+ # If theyr'e _really_ different, then something might be wrong.
+ if np.max(delays) - np.min(delays) >= 20:
+ print(
+ "WARNING: Delays are anomalously different from each other. If this model "
+ "turns out badly, then you might need to provide the delay manually."
+ )
delay = int(np.min(delays)) - safety_factor
print(f"After aplying safety factor, final delay is {delay}")
return delay