neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 055c1bb53e10e6038812d67f181bda10ad18b3dd
parent bc6f76d80f3c06212b7e3661f0997b3b4f4b9c59
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Mon,  8 May 2023 22:16:00 -0700

Tweak v2_0_0.wav checks (#239)

* Better v2 checks, add in kwarg to ignore in core.py

* GUI trainer with ignore checks working

* local and ignore_checks

* Colab tweaks
Diffstat:
Mnam/train/colab.py | 18+++++++++++++-----
Mnam/train/core.py | 153++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
Mnam/train/gui.py | 56++++++++++++++++++++++++++++++++++++++------------------
3 files changed, 167 insertions(+), 60 deletions(-)

diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -75,6 +75,7 @@ def run( lr_decay: float = 0.007, seed: Optional[int] = 0, user_metadata: Optional[UserMetadata] = None, + ignore_checks: bool = False, ): """ :param epochs: How amny epochs we'll train for. @@ -85,6 +86,8 @@ def run( :param lr: The initial learning rate :param lr_decay: The amount by which the learning rate decays each epoch :param seed: RNG seed for reproducibility. + :param user_metadata: To include in the exported model + :param ignore_checks: Ignores the data quality checks and YOLOs it """ input_version, input_basename = _check_for_files() @@ -101,10 +104,15 @@ def run( lr=lr, lr_decay=lr_decay, seed=seed, + local=False, + ignore_checks=ignore_checks, ) - print("Exporting your model...") - model_export_outdir = _get_valid_export_directory() - model_export_outdir.mkdir(parents=True, exist_ok=False) - model.net.export(model_export_outdir, user_metadata=user_metadata) - print(f"Model exported to {model_export_outdir}. Enjoy!") + if model is None: + print("No model returned; skip exporting!") + else: + print("Exporting your model...") + model_export_outdir = _get_valid_export_directory() + model_export_outdir.mkdir(parents=True, exist_ok=False) + model.net.export(model_export_outdir, user_metadata=user_metadata) + print(f"Model exported to {model_export_outdir}. Enjoy!") diff --git a/nam/train/core.py b/nam/train/core.py @@ -7,6 +7,7 @@ Functions used by the GUI trainer. """ import hashlib +import tkinter as tk from enum import Enum from time import time from typing import Optional, Sequence, Union @@ -250,6 +251,7 @@ def _calibrate_delay( plot(delay, input_path, output_path) return delay + def _get_lstm_config(architecture): return { Architecture.STANDARD: { @@ -272,11 +274,12 @@ def _get_lstm_config(architecture): }, }[architecture] + def _check_v1(*args, **kwargs): return True -def _check_v2(input_path, output_path) -> bool: +def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool: with torch.no_grad(): print("V2 checks...") rate = REQUIRED_RATE @@ -288,14 +291,18 @@ def _check_v2(input_path, output_path) -> bool: # Do the blips line up? # If the ESR is too bad, then flag it. + print("Checking blips...") + def get_blips(y): """ :return: [start/end,replicate] """ i0, i1 = rate // 4, 3 * rate // 4 j0, j1 = -3 * rate // 4, -rate // 4 - start = -1000 - end = 4000 + + i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)] + start = -10 + end = 1000 blips = torch.stack( [ torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]), @@ -310,25 +317,39 @@ def _check_v2(input_path, output_path) -> bool: esr_cross_0 = esr(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end esr_cross_1 = esr(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end - esr_threshold = 1.0e-3 - - def plot_esr_blip_error(msg, arrays, labels): - plt.figure() - [plt.plot(array, label=label) for array, label in zip(arrays, labels)] - plt.xlabel("Sample") - plt.ylabel("Output") - plt.legend() - plt.grid() + print(" ESRs:") + print(f" Start : {esr_0}") + print(f" End : {esr_1}") + print(f" Cross (1) : {esr_cross_0}") + print(f" Cross (2) : {esr_cross_1}") + + esr_threshold = 1.0e-2 + + def plot_esr_blip_error(silent, msg, arrays, labels): + if not silent: + plt.figure() + [plt.plot(array, label=label) for array, label in zip(arrays, labels)] + plt.xlabel("Sample") + plt.ylabel("Output") + plt.legend() + plt.grid() print(msg) - plt.show() + if not silent: + plt.show() # Check consecutive blips for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")): if e >= esr_threshold: plot_esr_blip_error( + silent, f"Failed consecutive blip check at {when} of training signal. The " - "target tone doesn't seem to be replicable over short timespans. " - "Is there a noise gate or a time-based effect in the signal chain?", + "target tone doesn't seem to be replicable over short timespans." + "\n\n" + " Possible causes:\n\n" + " * Your recording setup is really noisy.\n" + " * There's a noise gate that's messing things up.\n" + " * There's a time-based effect (compressor, delay, reverb) in " + "the signal chain", blip_pair, ("Replicate 1", "Replicate 2"), ) @@ -339,6 +360,7 @@ def _check_v2(input_path, output_path) -> bool: ): if e >= esr_threshold: plot_esr_blip_error( + silent, f"Failed start-to-end blip check for blip replicate {replicate}. " "The target tone doesn't seem to be same at the end of the reamp " "as it was at the start. Did some setting change during reamping?", @@ -349,7 +371,9 @@ def _check_v2(input_path, output_path) -> bool: return True -def _check(input_path: str, output_path: str, input_version: Version) -> bool: +def _check( + input_path: str, output_path: str, input_version: Version, delay: int, silent: bool +) -> bool: """ Ensure that everything should go smoothly @@ -362,7 +386,7 @@ def _check(input_path: str, output_path: str, input_version: Version) -> bool: else: print(f"Checks not implemented for input version {input_version}; skip") return True - return f(input_path, output_path) + return f(input_path, output_path, delay, silent) def _get_wavenet_config(architecture): @@ -493,7 +517,7 @@ def _get_configs( "delay": delay, }, } - + if model_type == "WaveNet": model_config = { "net": { @@ -504,7 +528,10 @@ def _get_configs( }, "loss": {"val_loss": "esr"}, "optimizer": {"lr": lr}, - "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}}, + "lr_scheduler": { + "class": "ExponentialLR", + "kwargs": {"gamma": 1.0 - lr_decay}, + }, } else: model_config = { @@ -514,20 +541,13 @@ def _get_configs( }, "loss": { "val_loss": "mse", - "mask_first": 4096, + "mask_first": 4096, "pre_emph_weight": 1.0, - "pre_emph_coef": 0.85 + "pre_emph_coef": 0.85, }, - "optimizer": { - "lr": 0.01 - }, - "lr_scheduler": { - "class": "ExponentialLR", - "kwargs": { - "gamma": 0.995 - } - } - } + "optimizer": {"lr": 0.01}, + "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}}, + } if torch.cuda.is_available(): device_config = {"accelerator": "gpu", "devices": 1} @@ -600,6 +620,48 @@ def _plot( plt.show() +def _print_nasty_checks_warning(): + """ + "ffs" -Dom + """ + print( + "\n" + "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n" + "X X\n" + "X WARNING: X\n" + "X X\n" + "X You are ignoring the checks! Your model might turn out bad! X\n" + "X X\n" + "X I warned you! X\n" + "X X\n" + "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n" + ) + + +def _nasty_checks_modal(): + msg = "You are ignoring the checks!\nYour model might turn out bad!" + + root = tk.Tk() + root.withdraw() # hide the root window + modal = tk.Toplevel(root) + modal.geometry("300x100") + modal.title("Warning!") + label = tk.Label(modal, text=msg) + label.pack(pady=10) + ok_button = tk.Button( + modal, + text="I can only blame myself!", + command=lambda: [modal.destroy(), root.quit()], + ) + ok_button.pack() + modal.grab_set() # disable interaction with root window while modal is open + modal.mainloop() + + +# Example usage: +# show_modal("Hello, World!") + + def train( input_path: str, output_path: str, @@ -617,22 +679,39 @@ def train( save_plot: bool = False, silent: bool = False, modelname: str = "model", -): + ignore_checks: bool = False, + local: bool = False, +) -> Optional[Model]: if seed is not None: torch.manual_seed(seed) + if input_version is None: + input_version = _detect_input_version(input_path) + if delay is None: - if input_version is None: - input_version = _detect_input_version(input_path) delay = _calibrate_delay( delay, input_version, input_path, output_path, silent=silent ) else: print(f"Delay provided as {delay}; skip calibration") - if not _check(input_path, output_path, input_version): - print("Failed checks; exit training") - return + if _check(input_path, output_path, input_version, delay, silent): + print("-Checks passed") + else: + print("Failed checks!") + if ignore_checks: + if local and not silent: + _nasty_checks_modal() + else: + _print_nasty_checks_warning() + elif not local: # And not ignore_checks + print( + "(To disable this check, run AT YOUR OWN RISK with " + "`ignore_checks=True`.)" + ) + if not ignore_checks: + print("Exiting core training...") + return data_config, model_config, learning_config = _get_configs( input_version, diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -60,6 +60,7 @@ _BUTTON_HEIGHT = 2 _TEXT_WIDTH = 70 _DEFAULT_DELAY = None +_DEFAULT_IGNORE_CHECKS = False _ADVANCED_OPTIONS_LEFT_WIDTH = 12 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12 @@ -71,6 +72,7 @@ class _AdvancedOptions(object): architecture: core.Architecture num_epochs: int delay: Optional[int] + ignore_checks: bool class _PathType(Enum): @@ -203,7 +205,10 @@ class _GUI(object): # Advanced options for training default_architecture = core.Architecture.STANDARD self.advanced_options = _AdvancedOptions( - default_architecture, _DEFAULT_NUM_EPOCHS, _DEFAULT_DELAY + default_architecture, + _DEFAULT_NUM_EPOCHS, + _DEFAULT_DELAY, + _DEFAULT_IGNORE_CHECKS, ) # Window to edit them: self._frame_advanced_options = tk.Frame(self._root) @@ -233,6 +238,23 @@ class _GUI(object): self._check_button_states() + def _check_button_states(self): + """ + Determine if any buttons should be disabled + """ + # Train button is disabled unless all paths are set + if any( + pb.val is None + for pb in ( + self._path_button_input, + self._path_button_output, + self._path_button_train_destination, + ) + ): + self._button_train["state"] = tk.DISABLED + return + self._button_train["state"] = tk.NORMAL + def _get_additional_options_frame(self): # Checkboxes self._frame_silent = tk.Frame(self._root) @@ -257,6 +279,16 @@ class _GUI(object): ) self._chkbox_save_plot.grid(row=2, column=1, sticky="W") + # Skip the data quality checks! + self._ignore_checks = tk.BooleanVar() + self._ignore_checks.set(False) + self._chkbox_ignore_checks = tk.Checkbutton( + self._frame_silent, + text="Ignore data quality checks (DO AT YOUR OWN RISK!)", + variable=self._ignore_checks, + ) + self._chkbox_ignore_checks.grid(row=3, column=1, sticky="W") + def mainloop(self): self._root.mainloop() @@ -311,7 +343,12 @@ class _GUI(object): silent=self._silent.get(), save_plot=self._save_plot.get(), modelname=basename, + ignore_checks=self._ignore_checks.get(), + local=True, ) + if trained_model is None: + print("Model training failed! Skip exporting...") + continue print("Model training complete!") print("Exporting...") outdir = self._path_button_train_destination.val @@ -329,23 +366,6 @@ class _GUI(object): # the user re-visits the window and clicks "ok" self.user_metadata_flag = False - def _check_button_states(self): - """ - Determine if any buttons should be disabled - """ - # Train button is diabled unless all paths are set - if any( - pb.val is None - for pb in ( - self._path_button_input, - self._path_button_output, - self._path_button_train_destination, - ) - ): - self._button_train["state"] = tk.DISABLED - return - self._button_train["state"] = tk.NORMAL - # some typing functions def _non_negative_int(val):