commit 50e060872f89a74506d462cbe7b4bbbd93dc5791
parent 55fe7374adebdf7725c6676c568b53cc2a5c1de2
Author: Steven Atkinson <[email protected]>
Date: Sun, 7 May 2023 12:57:38 -0700
Fix v2 checks
Diffstat:
M | nam/train/core.py | | | 111 | ++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------- |
1 file changed, 70 insertions(+), 41 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -256,47 +256,76 @@ def _check_v1(*args, **kwargs):
def _check_v2(input_path, output_path) -> bool:
- print("V2 checks...")
- rate = REQUIRED_RATE
- y = wav_to_tensor(output_path, rate=rate)
- y_val_1 = y[-19 * rate : -10 * rate]
- y_val_2 = y[-10 * rate : -1 * rate]
- esr_replicate = esr(y_val_1, y_val_2).item()
- print(f"Replicate ESR is {esr_replicate:.8f}.")
- # Do the blips line up?
- # [start/end,replicate]
- blips = [
- [y[: rate // 2], y[rate // 2 : rate]],
- [y[-rate : -rate // 2], y[-rate // 2 :]],
- ]
- mse = nn.MSELoss()
- mse_0 = mse(blips[0][0], blips[0][1]).item() # Within start
- mse_1 = mse(blips[1][0], blips[1][1]).item() # Within end
- mse_cross_0 = mse(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end
- mse_cross_1 = mse(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end
-
- mse_max = max(mse_0, mse_1)
- # mse_range = mse_max - min(mse_0, mse_1)
- safety_factor = 2.0
- if mse_cross_0 > safety_factor * mse_max or mse_cross_1 > safety_factor * mse_max:
- plt.plot()
- [
- [
- plt.plot(b, label=f"{startend}, replicate {replicate}")
- for replicate, b in enumerate(bi, 1)
- ]
- for startend, bi in zip(("start", "end"), blips)
- ]
- plt.xlabel("Sample")
- plt.ylabel("Output")
- plt.legend()
- plt.grid()
- print(
- "Failed blip checks. Did something change between the start and end of reamping?"
- )
- plt.show()
- return False
- return True
+ with torch.no_grad():
+ print("V2 checks...")
+ rate = REQUIRED_RATE
+ y = wav_to_tensor(output_path, rate=rate)
+ y_val_1 = y[-19 * rate : -10 * rate]
+ y_val_2 = y[-10 * rate : -1 * rate]
+ esr_replicate = esr(y_val_1, y_val_2).item()
+ print(f"Replicate ESR is {esr_replicate:.8f}.")
+
+ # Do the blips line up?
+ # If the ESR is too bad, then flag it.
+ def get_blips(y):
+ """
+ :return: [start/end,replicate]
+ """
+ i0, i1 = rate // 4, 3 * rate // 4
+ j0, j1 = -3 * rate // 4, -rate // 4
+ start = -1000
+ end = 4000
+ blips = torch.stack(
+ [
+ torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
+ torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]),
+ ]
+ )
+ return blips
+
+ blips = get_blips(y)
+ esr_0 = esr(blips[0][0], blips[0][1]).item() # Within start
+ esr_1 = esr(blips[1][0], blips[1][1]).item() # Within end
+ esr_cross_0 = esr(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end
+ esr_cross_1 = esr(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end
+
+ esr_threshold = 1.0e-3
+
+ def plot_esr_blip_error(msg, arrays, labels):
+ plt.figure()
+ [plt.plot(array, label=label) for array, label in zip(arrays, labels)]
+ plt.xlabel("Sample")
+ plt.ylabel("Output")
+ plt.legend()
+ plt.grid()
+ print(msg)
+ plt.show()
+
+ # Check consecutive blips
+ for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")):
+ if e >= esr_threshold:
+ plot_esr_blip_error(
+ f"Failed consecutive blip check at {when} of training signal. The "
+ "target tone doesn't seem to be replicable over short timespans. "
+ "Is there a noise gate or a time-based effect in the signal chain?",
+ blip_pair,
+ ("Replicate 1", "Replicate 2"),
+ )
+ return False
+ # Check blips between start & end of train signal
+ for e, blip_pair, replicate in zip(
+ (esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2)
+ ):
+ if e >= esr_threshold:
+ plot_esr_blip_error(
+ f"Failed start-to-end blip check for blip replicate {replicate}. "
+ "The target tone doesn't seem to be same at the end of the reamp "
+ "as it was at the start. Did some setting change during reamping?",
+ blip_pair,
+ (f"Start, replicate {replicate}", f"End, replicate {replicate}"),
+ )
+ return False
+ return True
def _check(input_path: str, output_path: str, input_version: Version) -> bool: