commit 055c1bb53e10e6038812d67f181bda10ad18b3dd
parent bc6f76d80f3c06212b7e3661f0997b3b4f4b9c59
Author: Steven Atkinson <steven@atkinson.mn>
Date: Mon, 8 May 2023 22:16:00 -0700
Tweak v2_0_0.wav checks (#239)
* Better v2 checks; add a kwarg in core.py to ignore them
* GUI trainer with ignore checks working
* Add local and ignore_checks kwargs
* Colab tweaks
Diffstat:
3 files changed, 167 insertions(+), 60 deletions(-)
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -75,6 +75,7 @@ def run(
lr_decay: float = 0.007,
seed: Optional[int] = 0,
user_metadata: Optional[UserMetadata] = None,
+ ignore_checks: bool = False,
):
"""
:param epochs: How many epochs we'll train for.
@@ -85,6 +86,8 @@ def run(
:param lr: The initial learning rate
:param lr_decay: The amount by which the learning rate decays each epoch
:param seed: RNG seed for reproducibility.
+ :param user_metadata: Metadata to include in the exported model
+ :param ignore_checks: Ignore the data quality checks and YOLO it (at your own risk)
"""
input_version, input_basename = _check_for_files()
@@ -101,10 +104,15 @@ def run(
lr=lr,
lr_decay=lr_decay,
seed=seed,
+ local=False,
+ ignore_checks=ignore_checks,
)
- print("Exporting your model...")
- model_export_outdir = _get_valid_export_directory()
- model_export_outdir.mkdir(parents=True, exist_ok=False)
- model.net.export(model_export_outdir, user_metadata=user_metadata)
- print(f"Model exported to {model_export_outdir}. Enjoy!")
+ if model is None:
+ print("No model returned; skip exporting!")
+ else:
+ print("Exporting your model...")
+ model_export_outdir = _get_valid_export_directory()
+ model_export_outdir.mkdir(parents=True, exist_ok=False)
+ model.net.export(model_export_outdir, user_metadata=user_metadata)
+ print(f"Model exported to {model_export_outdir}. Enjoy!")
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -7,6 +7,7 @@ Functions used by the GUI trainer.
"""
import hashlib
+import tkinter as tk
from enum import Enum
from time import time
from typing import Optional, Sequence, Union
@@ -250,6 +251,7 @@ def _calibrate_delay(
plot(delay, input_path, output_path)
return delay
+
def _get_lstm_config(architecture):
return {
Architecture.STANDARD: {
@@ -272,11 +274,12 @@ def _get_lstm_config(architecture):
},
}[architecture]
+
def _check_v1(*args, **kwargs):
return True
-def _check_v2(input_path, output_path) -> bool:
+def _check_v2(input_path, output_path, delay: int, silent: bool) -> bool:
with torch.no_grad():
print("V2 checks...")
rate = REQUIRED_RATE
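For reference, the esr() used below is the usual error-to-signal ratio. A minimal sketch of that metric (the package's own implementation may differ in detail):

    import torch

    def esr(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Residual energy relative to target energy: 0.0 is a perfect match;
        # values at or above esr_threshold flag a bad blip pair.
        return torch.sum((pred - target) ** 2) / torch.sum(target ** 2)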
@@ -288,14 +291,18 @@ def _check_v2(input_path, output_path) -> bool:
# Do the blips line up?
# If the ESR is too bad, then flag it.
+ print("Checking blips...")
+
def get_blips(y):
"""
:return: [start/end,replicate]
"""
i0, i1 = rate // 4, 3 * rate // 4
j0, j1 = -3 * rate // 4, -rate // 4
- start = -1000
- end = 4000
+
+ i0, i1, j0, j1 = [i + delay for i in (i0, i1, j0, j1)]
+ start = -10
+ end = 1000
blips = torch.stack(
[
torch.stack([y[i0 + start : i0 + end], y[i1 + start : i1 + end]]),
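To make the new indexing concrete: at the required 48 kHz rate the first two blips sit near 0.25 s and 0.75 s, and each window now runs from 10 samples before the blip to 1000 after, with all four anchors shifted by the calibrated delay. A worked sketch (the delay value is made up):

    rate, delay = 48000, 123              # REQUIRED_RATE; delay from _calibrate_delay()
    i0 = rate // 4 + delay                # 12_123: first blip, ~0.25 s into the file
    i1 = 3 * rate // 4 + delay            # 36_123: second blip, ~0.75 s in
    window_0 = slice(i0 - 10, i0 + 1000)  # 1010 samples around the blip onset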
@@ -310,25 +317,39 @@ def _check_v2(input_path, output_path) -> bool:
esr_cross_0 = esr(blips[0][0], blips[1][0]).item() # 1st repeat, start vs end
esr_cross_1 = esr(blips[0][1], blips[1][1]).item() # 2nd repeat, start vs end
- esr_threshold = 1.0e-3
-
- def plot_esr_blip_error(msg, arrays, labels):
- plt.figure()
- [plt.plot(array, label=label) for array, label in zip(arrays, labels)]
- plt.xlabel("Sample")
- plt.ylabel("Output")
- plt.legend()
- plt.grid()
+ print(" ESRs:")
+ print(f" Start : {esr_0}")
+ print(f" End : {esr_1}")
+ print(f" Cross (1) : {esr_cross_0}")
+ print(f" Cross (2) : {esr_cross_1}")
+
+ esr_threshold = 1.0e-2
+
+ def plot_esr_blip_error(silent, msg, arrays, labels):
+ if not silent:
+ plt.figure()
+ [plt.plot(array, label=label) for array, label in zip(arrays, labels)]
+ plt.xlabel("Sample")
+ plt.ylabel("Output")
+ plt.legend()
+ plt.grid()
print(msg)
- plt.show()
+ if not silent:
+ plt.show()
# Check consecutive blips
for e, blip_pair, when in zip((esr_0, esr_1), blips, ("start", "end")):
if e >= esr_threshold:
plot_esr_blip_error(
+ silent,
f"Failed consecutive blip check at {when} of training signal. The "
- "target tone doesn't seem to be replicable over short timespans. "
- "Is there a noise gate or a time-based effect in the signal chain?",
+ "target tone doesn't seem to be replicable over short timespans."
+ "\n\n"
+ " Possible causes:\n\n"
+ " * Your recording setup is really noisy.\n"
+ " * There's a noise gate that's messing things up.\n"
+ " * There's a time-based effect (compressor, delay, reverb) in "
+ "the signal chain",
blip_pair,
("Replicate 1", "Replicate 2"),
)
@@ -339,6 +360,7 @@ def _check_v2(input_path, output_path) -> bool:
):
if e >= esr_threshold:
plot_esr_blip_error(
+ silent,
f"Failed start-to-end blip check for blip replicate {replicate}. "
"The target tone doesn't seem to be same at the end of the reamp "
"as it was at the start. Did some setting change during reamping?",
@@ -349,7 +371,9 @@ def _check_v2(input_path, output_path) -> bool:
return True
-def _check(input_path: str, output_path: str, input_version: Version) -> bool:
+def _check(
+ input_path: str, output_path: str, input_version: Version, delay: int, silent: bool
+) -> bool:
"""
Ensure that everything will go smoothly
@@ -362,7 +386,7 @@ def _check(input_path: str, output_path: str, input_version: Version) -> bool:
else:
print(f"Checks not implemented for input version {input_version}; skip")
return True
- return f(input_path, output_path)
+ return f(input_path, output_path, delay, silent)
def _get_wavenet_config(architecture):
@@ -493,7 +517,7 @@ def _get_configs(
"delay": delay,
},
}
-
+
if model_type == "WaveNet":
model_config = {
"net": {
@@ -504,7 +528,10 @@ def _get_configs(
},
"loss": {"val_loss": "esr"},
"optimizer": {"lr": lr},
- "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 1.0 - lr_decay}},
+ "lr_scheduler": {
+ "class": "ExponentialLR",
+ "kwargs": {"gamma": 1.0 - lr_decay},
+ },
}
else:
model_config = {
@@ -514,20 +541,13 @@ def _get_configs(
},
"loss": {
"val_loss": "mse",
- "mask_first": 4096,
+ "mask_first": 4096,
"pre_emph_weight": 1.0,
- "pre_emph_coef": 0.85
+ "pre_emph_coef": 0.85,
},
- "optimizer": {
- "lr": 0.01
- },
- "lr_scheduler": {
- "class": "ExponentialLR",
- "kwargs": {
- "gamma": 0.995
- }
- }
- }
+ "optimizer": {"lr": 0.01},
+ "lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
+ }
if torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
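Aside on the WaveNet scheduler config above: ExponentialLR multiplies the learning rate by gamma after each epoch, so gamma = 1.0 - lr_decay gives lr_n = lr * (1 - lr_decay)**n. With the default lr_decay=0.007 and an illustrative lr of 0.004:

    lr, lr_decay = 0.004, 0.007  # lr is illustrative; lr_decay is the default above
    for epoch in (0, 50, 100):
        print(epoch, lr * (1.0 - lr_decay) ** epoch)
    # -> 0.004, ~0.00282, ~0.00198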
@@ -600,6 +620,48 @@ def _plot(
plt.show()
+def _print_nasty_checks_warning():
+ """
+ "ffs" -Dom
+ """
+ print(
+ "\n"
+ "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"
+ "X X\n"
+ "X WARNING: X\n"
+ "X X\n"
+ "X You are ignoring the checks! Your model might turn out bad! X\n"
+ "X X\n"
+ "X I warned you! X\n"
+ "X X\n"
+ "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX\n"
+ )
+
+
+def _nasty_checks_modal():
+ msg = "You are ignoring the checks!\nYour model might turn out bad!"
+
+ root = tk.Tk()
+ root.withdraw() # hide the root window
+ modal = tk.Toplevel(root)
+ modal.geometry("300x100")
+ modal.title("Warning!")
+ label = tk.Label(modal, text=msg)
+ label.pack(pady=10)
+ ok_button = tk.Button(
+ modal,
+ text="I can only blame myself!",
+ command=lambda: [modal.destroy(), root.quit()],
+ )
+ ok_button.pack()
+ modal.grab_set() # disable interaction with root window while modal is open
+ modal.mainloop()
+
+
+# Example usage:
+# _nasty_checks_modal()
+
+
def train(
input_path: str,
output_path: str,
@@ -617,22 +679,39 @@ def train(
save_plot: bool = False,
silent: bool = False,
modelname: str = "model",
-):
+ ignore_checks: bool = False,
+ local: bool = False,
+) -> Optional[Model]:
if seed is not None:
torch.manual_seed(seed)
+ if input_version is None:
+ input_version = _detect_input_version(input_path)
+
if delay is None:
- if input_version is None:
- input_version = _detect_input_version(input_path)
delay = _calibrate_delay(
delay, input_version, input_path, output_path, silent=silent
)
else:
print(f"Delay provided as {delay}; skip calibration")
- if not _check(input_path, output_path, input_version):
- print("Failed checks; exit training")
- return
+ if _check(input_path, output_path, input_version, delay, silent):
+ print("-Checks passed")
+ else:
+ print("Failed checks!")
+ if ignore_checks:
+ if local and not silent:
+ _nasty_checks_modal()
+ else:
+ _print_nasty_checks_warning()
+ elif not local: # And not ignore_checks
+ print(
+ "(To disable this check, run AT YOUR OWN RISK with "
+ "`ignore_checks=True`.)"
+ )
+ if not ignore_checks:
+ print("Exiting core training...")
+ return
data_config, model_config, learning_config = _get_configs(
input_version,
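The failure-handling branches above reduce to the following decision; this is a restatement of the hunk, not new behavior:

    def proceed_after_failed_checks(ignore_checks: bool, local: bool, silent: bool) -> bool:
        # True if train() should keep going despite failed checks.
        if ignore_checks:
            if local and not silent:
                _nasty_checks_modal()          # GUI: blocking tkinter warning
            else:
                _print_nasty_checks_warning()  # Colab/headless: printed warning
            return True
        if not local:
            print(
                "(To disable this check, run AT YOUR OWN RISK with "
                "`ignore_checks=True`.)"
            )
        return False  # train() then prints "Exiting core training..." and returns None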
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -60,6 +60,7 @@ _BUTTON_HEIGHT = 2
_TEXT_WIDTH = 70
_DEFAULT_DELAY = None
+_DEFAULT_IGNORE_CHECKS = False
_ADVANCED_OPTIONS_LEFT_WIDTH = 12
_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -71,6 +72,7 @@ class _AdvancedOptions(object):
architecture: core.Architecture
num_epochs: int
delay: Optional[int]
+ ignore_checks: bool
class _PathType(Enum):
@@ -203,7 +205,10 @@ class _GUI(object):
# Advanced options for training
default_architecture = core.Architecture.STANDARD
self.advanced_options = _AdvancedOptions(
- default_architecture, _DEFAULT_NUM_EPOCHS, _DEFAULT_DELAY
+ default_architecture,
+ _DEFAULT_NUM_EPOCHS,
+ _DEFAULT_DELAY,
+ _DEFAULT_IGNORE_CHECKS,
)
# Window to edit them:
self._frame_advanced_options = tk.Frame(self._root)
@@ -233,6 +238,23 @@ class _GUI(object):
self._check_button_states()
+ def _check_button_states(self):
+ """
+ Determine if any buttons should be disabled
+ """
+ # Train button is disabled unless all paths are set
+ if any(
+ pb.val is None
+ for pb in (
+ self._path_button_input,
+ self._path_button_output,
+ self._path_button_train_destination,
+ )
+ ):
+ self._button_train["state"] = tk.DISABLED
+ return
+ self._button_train["state"] = tk.NORMAL
+
def _get_additional_options_frame(self):
# Checkboxes
self._frame_silent = tk.Frame(self._root)
@@ -257,6 +279,16 @@ class _GUI(object):
)
self._chkbox_save_plot.grid(row=2, column=1, sticky="W")
+ # Skip the data quality checks!
+ self._ignore_checks = tk.BooleanVar()
+ self._ignore_checks.set(False)
+ self._chkbox_ignore_checks = tk.Checkbutton(
+ self._frame_silent,
+ text="Ignore data quality checks (DO AT YOUR OWN RISK!)",
+ variable=self._ignore_checks,
+ )
+ self._chkbox_ignore_checks.grid(row=3, column=1, sticky="W")
+
def mainloop(self):
self._root.mainloop()
@@ -311,7 +343,12 @@ class _GUI(object):
silent=self._silent.get(),
save_plot=self._save_plot.get(),
modelname=basename,
+ ignore_checks=self._ignore_checks.get(),
+ local=True,
)
+ if trained_model is None:
+ print("Model training failed! Skip exporting...")
+ continue
print("Model training complete!")
print("Exporting...")
outdir = self._path_button_train_destination.val
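On the GUI side the new kwargs flow straight from the checkbox into core.train(); condensed from the hunk above (unchanged arguments elided):

    trained_model = core.train(
        ...,                                       # paths, epochs, etc. (unchanged)
        ignore_checks=self._ignore_checks.get(),   # tk.BooleanVar -> plain bool
        local=True,                                # local GUI run -> use the modal
    )
    if trained_model is None:                      # checks failed and not ignored
        print("Model training failed! Skip exporting...")
        continue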
@@ -329,23 +366,6 @@ class _GUI(object):
# the user re-visits the window and clicks "ok"
self.user_metadata_flag = False
- def _check_button_states(self):
- """
- Determine if any buttons should be disabled
- """
- # Train button is diabled unless all paths are set
- if any(
- pb.val is None
- for pb in (
- self._path_button_input,
- self._path_button_output,
- self._path_button_train_destination,
- )
- ):
- self._button_train["state"] = tk.DISABLED
- return
- self._button_train["state"] = tk.NORMAL
-
# some typing functions
def _non_negative_int(val):