neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 50e060872f89a74506d462cbe7b4bbbd93dc5791
parent 55fe7374adebdf7725c6676c568b53cc2a5c1de2
Author: Steven Atkinson <[email protected]>
Date:   Sun,  7 May 2023 12:57:38 -0700

Fix v2 checks

Diffstat:
Mnam/train/core.py | 111++++++++++++++++++++++++++++++++++++++++++++++++++-----------------------------
1 file changed, 70 insertions(+), 41 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -256,47 +256,76 @@ def _check_v1(*args, **kwargs): def _check_v2(input_path, output_path) -> bool: - print("V2 checks...") - rate = REQUIRED_RATE - y = wav_to_tensor(output_path, rate=rate) - y_val_1 = y[-19 * rate : -10 * rate] - y_val_2 = y[-10 * rate : -1 * rate] - esr_replicate = esr(y_val_1, y_val_2).item() - print(f"Replicate ESR is {esr_replicate:.8f}.") - # Do the blips line up? - # [start/end,replicate] - blips = [ - [y[: rate // 2], y[rate // 2 : rate]], - [y[-rate : -rate // 2], y[-rate // 2 :]], - ] - mse = nn.MSELoss() - mse_0 = mse(blips[0][0], blips[0][1]).item() # Within start - mse_1 = mse(blips[1][0], blips[1][1]).item() # Within end - mse_cross_0 = mse(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end - mse_cross_1 = mse(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end - - mse_max = max(mse_0, mse_1) - # mse_range = mse_max - min(mse_0, mse_1) - safety_factor = 2.0 - if mse_cross_0 > safety_factor * mse_max or mse_cross_1 > safety_factor * mse_max: - plt.plot() - [ - [ - plt.plot(b, label=f"{startend}, replicate {replicate}") - for replicate, b in enumerate(bi, 1) - ] - for startend, bi in zip(("start", "end"), blips) - ] - plt.xlabel("Sample") - plt.ylabel("Output") - plt.legend() - plt.grid() - print( - "Failed blip checks. Did something change between the start and end of reamping?" - ) - plt.show() - return False - return True + with torch.no_grad(): + print("V2 checks...") + rate = REQUIRED_RATE + y = wav_to_tensor(output_path, rate=rate) + y_val_1 = y[-19 * rate : -10 * rate] + y_val_2 = y[-10 * rate : -1 * rate] + esr_replicate = esr(y_val_1, y_val_2).item() + print(f"Replicate ESR is {esr_replicate:.8f}.") + + # Do the blips line up? + # If the ESR is too bad, then flag it. + def get_blips(y): + """ + :return: [start/end,replicate] + """ + i0, i1 = rate // 4, 3 * rate // 4 + j0, j1 = -3 * rate // 4, -rate // 4 + start = -1000 + end = 4000 + blips = torch.stack( + [ + torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]), + torch.stack([y[j0 + start : j0 + end], y[j1 + start : j1 + end]]), + ] + ) + return blips + + blips = get_blips(y) + esr_0 = esr(blips[0][0], blips[0][1]).item() # Within start + esr_1 = esr(blips[1][0], blips[1][1]).item() # Within end + esr_cross_0 = esr(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end + esr_cross_1 = esr(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end + + esr_threshold = 1.0e-3 + + def plot_esr_blip_error(msg, arrays, labels): + plt.figure() + [plt.plot(array, label=label) for array, label in zip(arrays, labels)] + plt.xlabel("Sample") + plt.ylabel("Output") + plt.legend() + plt.grid() + print(msg) + plt.show() + + # Check consecutive blips + for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")): + if e >= esr_threshold: + plot_esr_blip_error( + f"Failed consecutive blip check at {when} of training signal. The " + "target tone doesn't seem to be replicable over short timespans. " + "Is there a noise gate or a time-based effect in the signal chain?", + blip_pair, + ("Replicate 1", "Replicate 2"), + ) + return False + # Check blips between start & end of train signal + for e, blip_pair, replicate in zip( + (esr_cross_0, esr_cross_1), blips.permute(1, 0, 2), (1, 2) + ): + if e >= esr_threshold: + plot_esr_blip_error( + f"Failed start-to-end blip check for blip replicate {replicate}. " + "The target tone doesn't seem to be same at the end of the reamp " + "as it was at the start. Did some setting change during reamping?", + blip_pair, + (f"Start, replicate {replicate}", f"End, replicate {replicate}"), + ) + return False + return True def _check(input_path: str, output_path: str, input_version: Version) -> bool: