neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit a94215def62f69875d504c207fe993ea56ed8068
parent 92a241f230743536afef4da92c0090829b59dbd4
Author: vossen <44332958+vossenv@users.noreply.github.com>
Date:   Thu, 23 Mar 2023 22:35:19 -0500

Batched training and silent mode (#145)

* Batched training and silent


Add GUI checkbox for these options


rev


rev

* Feedback revision 

Path display, misc
Black format


Uneeded import


don't hardcode true
Diffstat:
Mnam/models/_exportable.py | 4++--
Mnam/train/core.py | 24+++++++++++++++++++-----
Mnam/train/gui.py | 79+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------
3 files changed, 80 insertions(+), 27 deletions(-)

diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py @@ -21,7 +21,7 @@ class Exportable(abc.ABC): Interface for my custon export format for use in the plugin. """ - def export(self, outdir: Path, include_snapshot: bool = False): + def export(self, outdir: Path, include_snapshot: bool = False, modelname: str = "model"): """ Interface for exporting. You should create at least a `config.json` containing the two fields: @@ -37,7 +37,7 @@ class Exportable(abc.ABC): """ training = self.training self.eval() - with open(Path(outdir, "model.nam"), "w") as fp: + with open(Path(outdir, modelname + ".nam"), "w") as fp: json.dump( { "version": __version__, diff --git a/nam/train/core.py b/nam/train/core.py @@ -146,6 +146,7 @@ def _calibrate_delay( input_version: Version, input_path: str, output_path: str, + silent: bool=False ) -> int: if input_version.major == 1: calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 @@ -158,7 +159,8 @@ def _calibrate_delay( else: print("Delay wasn't provided; attempting to calibrate automatically...") delay = calibrate(input_path, output_path) - plot(delay, input_path, output_path) + if not silent: + plot(delay, input_path, output_path) return delay @@ -308,7 +310,12 @@ def _esr(pred: torch.Tensor, target: torch.Tensor) -> float: def _plot( - model, ds, window_start: Optional[int] = None, window_end: Optional[int] = None + model, + ds, + window_start: Optional[int] = None, + window_end: Optional[int] = None, + filepath: Optional[str] = None, + silent: bool = False ): print("Plotting a comparison of your model with the target output...") with torch.no_grad(): @@ -339,8 +346,10 @@ def _plot( plt.plot(ds.y[window_start:window_end], linestyle="--", label="Target") plt.title(f"ESR={esr:.3f}") plt.legend() - plt.show() - + if filepath is not None: + plt.savefig(filepath + ".png") + if not silent: + plt.show() def train( input_path: str, @@ -353,6 +362,9 @@ def train( lr=0.004, lr_decay=0.007, seed: Optional[int] = 0, + save_plot: bool=False, + silent: bool=False, + modelname: str="model" ): if seed is not None: torch.manual_seed(seed) @@ -363,7 +375,7 @@ def train( if delay is None: if input_version is None: input_version = _detect_input_version(input_path) - delay = _calibrate_delay(delay, input_version, input_path, output_path) + delay = _calibrate_delay(delay, input_version, input_path, output_path, silent=silent) else: print(f"Delay provided as {delay}; skip calibration") @@ -416,5 +428,7 @@ def train( dataset_validation, window_start=100_000, # Start of the plotting window, in samples window_end=101_000, # End of the plotting window, in samples + filepath=train_path +'/'+ modelname if save_plot else None, + silent=silent ) return model diff --git a/nam/train/gui.py b/nam/train/gui.py @@ -24,6 +24,7 @@ def _ensure_graceful_shutdowns(): _ensure_graceful_shutdowns() +import re import tkinter as tk from dataclasses import dataclass from enum import Enum @@ -38,7 +39,6 @@ try: _install_is_valid = True except ImportError: _install_is_valid = False - _BUTTON_WIDTH = 20 _BUTTON_HEIGHT = 2 _TEXT_WIDTH = 70 @@ -57,6 +57,7 @@ class _AdvancedOptions(object): class _PathType(Enum): FILE = "file" DIRECTORY = "directory" + MULTIFILE = "multifile" class _PathButton(object): @@ -105,13 +106,16 @@ class _PathButton(object): self._label["fg"] = "red" self._label["text"] = f"{self._info_str} is not set!" else: + val = self.val + val = val[0] if isinstance(val, tuple) and len(val) == 1 else val self._label["fg"] = "black" - self._label["text"] = f"{self._info_str} set to {self.val}" + self._label["text"] = f"{self._info_str} set to {val}" def _set_val(self): res = { _PathType.FILE: filedialog.askopenfilename, _PathType.DIRECTORY: filedialog.askdirectory, + _PathType.MULTIFILE: filedialog.askopenfilenames, }[self._path_type]() if res != "": self._path = res @@ -144,7 +148,7 @@ class _GUI(object): self._frame_output_path, "Output Audio", "Output audio", - _PathType.FILE, + _PathType.MULTIFILE, hooks=[self._check_button_states], ) @@ -158,6 +162,9 @@ class _GUI(object): hooks=[self._check_button_states], ) + # This should probably be to the right somewhere + self._get_additional_options_frame() + # Advanced options for training default_architecture = core.Architecture.STANDARD self.advanced_options = _AdvancedOptions( @@ -191,6 +198,30 @@ class _GUI(object): self._check_button_states() + def _get_additional_options_frame(self): + # Checkboxes + self._frame_silent = tk.Frame(self._root) + self._frame_silent.pack(side=tk.LEFT) + + # Silent run (bypass popups) + self._silent = tk.BooleanVar() + self._chkbox_silent = tk.Checkbutton( + self._frame_silent, + text="Silent run", + variable=self._silent, + ) + self._chkbox_silent.grid(row=1, column=1, sticky="W") + + # Auto save the end plot + self._save_plot = tk.BooleanVar() + self._save_plot.set(True) # default this to true + self._chkbox_save_plot = tk.Checkbutton( + self._frame_silent, + text="Save plot automatically", + variable=self._save_plot, + ) + self._chkbox_save_plot.grid(row=2, column=1, sticky="W") + def mainloop(self): self._root.mainloop() @@ -208,6 +239,7 @@ class _GUI(object): num_epochs = self.advanced_options.num_epochs architecture = self.advanced_options.architecture delay = self.advanced_options.delay + file_list = self._path_button_output.val # Advanced-er options # If you're poking around looking for these, then maybe it's time to learn to @@ -217,23 +249,30 @@ class _GUI(object): seed = 0 # Run it - trained_model = core.train( - self._path_button_input.val, - self._path_button_output.val, - self._path_button_train_destination.val, - epochs=num_epochs, - delay=delay, - architecture=architecture, - lr=lr, - lr_decay=lr_decay, - seed=seed, - ) - print("Model training complete!") - print("Exporting...") - outdir = self._path_button_train_destination.val - print(f"Exporting trained model to {outdir}...") - trained_model.net.export(outdir) - print("Done!") + for file in file_list: + print("Now training {}".format(file)) + modelname = re.sub(r"\.wav$", "", file.split("/")[-1]) + + trained_model = core.train( + self._path_button_input.val, + file, + self._path_button_train_destination.val, + epochs=num_epochs, + delay=delay, + architecture=architecture, + lr=lr, + lr_decay=lr_decay, + seed=seed, + silent=self._silent.get(), + save_plot=self._save_plot.get(), + modelname=modelname, + ) + print("Model training complete!") + print("Exporting...") + outdir = self._path_button_train_destination.val + print(f"Exporting trained model to {outdir}...") + trained_model.net.export(outdir, modelname=modelname) + print("Done!") def _check_button_states(self): """