commit a94215def62f69875d504c207fe993ea56ed8068
parent 92a241f230743536afef4da92c0090829b59dbd4
Author: vossen <44332958+vossenv@users.noreply.github.com>
Date: Thu, 23 Mar 2023 22:35:19 -0500
Batched training and silent mode (#145)
* Batched training and silent
Add GUI checkbox for these options
rev
rev
* Feedback revision
Path display, misc
Black format
Uneeded import
don't hardcode true
Diffstat:
3 files changed, 80 insertions(+), 27 deletions(-)
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -21,7 +21,7 @@ class Exportable(abc.ABC):
Interface for my custon export format for use in the plugin.
"""
- def export(self, outdir: Path, include_snapshot: bool = False):
+ def export(self, outdir: Path, include_snapshot: bool = False, modelname: str = "model"):
"""
Interface for exporting.
You should create at least a `config.json` containing the two fields:
@@ -37,7 +37,7 @@ class Exportable(abc.ABC):
"""
training = self.training
self.eval()
- with open(Path(outdir, "model.nam"), "w") as fp:
+ with open(Path(outdir, modelname + ".nam"), "w") as fp:
json.dump(
{
"version": __version__,
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -146,6 +146,7 @@ def _calibrate_delay(
input_version: Version,
input_path: str,
output_path: str,
+ silent: bool=False
) -> int:
if input_version.major == 1:
calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
@@ -158,7 +159,8 @@ def _calibrate_delay(
else:
print("Delay wasn't provided; attempting to calibrate automatically...")
delay = calibrate(input_path, output_path)
- plot(delay, input_path, output_path)
+ if not silent:
+ plot(delay, input_path, output_path)
return delay
@@ -308,7 +310,12 @@ def _esr(pred: torch.Tensor, target: torch.Tensor) -> float:
def _plot(
- model, ds, window_start: Optional[int] = None, window_end: Optional[int] = None
+ model,
+ ds,
+ window_start: Optional[int] = None,
+ window_end: Optional[int] = None,
+ filepath: Optional[str] = None,
+ silent: bool = False
):
print("Plotting a comparison of your model with the target output...")
with torch.no_grad():
@@ -339,8 +346,10 @@ def _plot(
plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target")
plt.title(f"ESR={esr:.3f}")
plt.legend()
- plt.show()
-
+ if filepath is not None:
+ plt.savefig(filepath + ".png")
+ if not silent:
+ plt.show()
def train(
input_path: str,
@@ -353,6 +362,9 @@ def train(
lr=0.004,
lr_decay=0.007,
seed: Optional[int] = 0,
+ save_plot: bool=False,
+ silent: bool=False,
+ modelname: str="model"
):
if seed is not None:
torch.manual_seed(seed)
@@ -363,7 +375,7 @@ def train(
if delay is None:
if input_version is None:
input_version = _detect_input_version(input_path)
- delay = _calibrate_delay(delay, input_version, input_path, output_path)
+ delay = _calibrate_delay(delay, input_version, input_path, output_path, silent=silent)
else:
print(f"Delay provided as {delay}; skip calibration")
@@ -416,5 +428,7 @@ def train(
dataset_validation,
window_start=100_000, # Start of the plotting window, in samples
window_end=101_000, # End of the plotting window, in samples
+ filepath=train_path +'/'+ modelname if save_plot else None,
+ silent=silent
)
return model
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -24,6 +24,7 @@ def _ensure_graceful_shutdowns():
_ensure_graceful_shutdowns()
+import re
import tkinter as tk
from dataclasses import dataclass
from enum import Enum
@@ -38,7 +39,6 @@ try:
_install_is_valid = True
except ImportError:
_install_is_valid = False
-
_BUTTON_WIDTH = 20
_BUTTON_HEIGHT = 2
_TEXT_WIDTH = 70
@@ -57,6 +57,7 @@ class _AdvancedOptions(object):
class _PathType(Enum):
FILE = "file"
DIRECTORY = "directory"
+ MULTIFILE = "multifile"
class _PathButton(object):
@@ -105,13 +106,16 @@ class _PathButton(object):
self._label["fg"] = "red"
self._label["text"] = f"{self._info_str} is not set!"
else:
+ val = self.val
+ val = val[0] if isinstance(val, tuple) and len(val) == 1 else val
self._label["fg"] = "black"
- self._label["text"] = f"{self._info_str} set to {self.val}"
+ self._label["text"] = f"{self._info_str} set to {val}"
def _set_val(self):
res = {
_PathType.FILE: filedialog.askopenfilename,
_PathType.DIRECTORY: filedialog.askdirectory,
+ _PathType.MULTIFILE: filedialog.askopenfilenames,
}[self._path_type]()
if res != "":
self._path = res
@@ -144,7 +148,7 @@ class _GUI(object):
self._frame_output_path,
"Output Audio",
"Output audio",
- _PathType.FILE,
+ _PathType.MULTIFILE,
hooks=[self._check_button_states],
)
@@ -158,6 +162,9 @@ class _GUI(object):
hooks=[self._check_button_states],
)
+ # This should probably be to the right somewhere
+ self._get_additional_options_frame()
+
# Advanced options for training
default_architecture = core.Architecture.STANDARD
self.advanced_options = _AdvancedOptions(
@@ -191,6 +198,30 @@ class _GUI(object):
self._check_button_states()
+ def _get_additional_options_frame(self):
+ # Checkboxes
+ self._frame_silent = tk.Frame(self._root)
+ self._frame_silent.pack(side=tk.LEFT)
+
+ # Silent run (bypass popups)
+ self._silent = tk.BooleanVar()
+ self._chkbox_silent = tk.Checkbutton(
+ self._frame_silent,
+ text="Silent run",
+ variable=self._silent,
+ )
+ self._chkbox_silent.grid(row=1, column=1, sticky="W")
+
+ # Auto save the end plot
+ self._save_plot = tk.BooleanVar()
+ self._save_plot.set(True) # default this to true
+ self._chkbox_save_plot = tk.Checkbutton(
+ self._frame_silent,
+ text="Save plot automatically",
+ variable=self._save_plot,
+ )
+ self._chkbox_save_plot.grid(row=2, column=1, sticky="W")
+
def mainloop(self):
self._root.mainloop()
@@ -208,6 +239,7 @@ class _GUI(object):
num_epochs = self.advanced_options.num_epochs
architecture = self.advanced_options.architecture
delay = self.advanced_options.delay
+ file_list = self._path_button_output.val
# Advanced-er options
# If you're poking around looking for these, then maybe it's time to learn to
@@ -217,23 +249,30 @@ class _GUI(object):
seed = 0
# Run it
- trained_model = core.train(
- self._path_button_input.val,
- self._path_button_output.val,
- self._path_button_train_destination.val,
- epochs=num_epochs,
- delay=delay,
- architecture=architecture,
- lr=lr,
- lr_decay=lr_decay,
- seed=seed,
- )
- print("Model training complete!")
- print("Exporting...")
- outdir = self._path_button_train_destination.val
- print(f"Exporting trained model to {outdir}...")
- trained_model.net.export(outdir)
- print("Done!")
+ for file in file_list:
+ print("Now training {}".format(file))
+ modelname = re.sub(r"\.wav$", "", file.split("/")[-1])
+
+ trained_model = core.train(
+ self._path_button_input.val,
+ file,
+ self._path_button_train_destination.val,
+ epochs=num_epochs,
+ delay=delay,
+ architecture=architecture,
+ lr=lr,
+ lr_decay=lr_decay,
+ seed=seed,
+ silent=self._silent.get(),
+ save_plot=self._save_plot.get(),
+ modelname=modelname,
+ )
+ print("Model training complete!")
+ print("Exporting...")
+ outdir = self._path_button_train_destination.val
+ print(f"Exporting trained model to {outdir}...")
+ trained_model.net.export(outdir, modelname=modelname)
+ print("Done!")
def _check_button_states(self):
"""