neural-amp-modeler

Neural network emulator for guitar amplifiers

commit b71db729a7fa1be4cea3bd72836c0677318954ab
parent 0ee6fd6c3a0c918035156dc9a6c54bee8d9470bb
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sun, 12 May 2024 23:03:51 -0700

[FEATURE] Checkpoints save `.nam` files in addition to `.ckpt`s (#408)

* Update core.py

Extend PyTorch Lightning ModelCheckpoint to save and remove .nam
files alongside the .ckpt files.

* Remove unneeded checkpoint code

* Get sample rate into .nam checkpoints

* Update test_exportable.py
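
The mechanism is small: subclass PyTorch Lightning's ModelCheckpoint and mirror its
save/remove hooks so that a .nam file always sits next to each .ckpt. A minimal sketch
of that pattern (the names here are illustrative, not the code in this commit;
`write_nam_sidecar` stands in for `Exportable.export()`, and the real implementation is
`_ModelCheckpoint` in nam/train/core.py below):

    from pathlib import Path

    import pytorch_lightning as pl


    def write_nam_sidecar(model, path: Path) -> None:
        # Stand-in for the real export; the actual payload comes from
        # Exportable._get_export_dict() (see nam/models/exportable.py).
        path.write_text("{}")


    class NamSidecarCheckpoint(pl.callbacks.ModelCheckpoint):
        SIDECAR_EXTENSION = ".nam"

        def _sidecar_path(self, filepath: str) -> Path:
            # ".../checkpoint_best_....ckpt" -> ".../checkpoint_best_....nam"
            return Path(filepath).with_suffix(self.SIDECAR_EXTENSION)

        def _save_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
            super()._save_checkpoint(trainer, filepath)  # write the .ckpt as usual
            write_nam_sidecar(trainer.model, self._sidecar_path(filepath))

        def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
            super()._remove_checkpoint(trainer, filepath)  # drop the .ckpt
            sidecar = self._sidecar_path(filepath)
            if sidecar.exists():
                sidecar.unlink()
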
Diffstat:
M nam/models/__init__.py | 2 +-
M nam/models/_base.py | 2 +-
D nam/models/_exportable.py | 159 -------------------------------------------------------------------------------
A nam/models/exportable.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
D nam/train/_errors.py | 18 ------------------
M nam/train/colab.py | 2 +-
M nam/train/core.py | 71 ++++++++++++++++++++++++++++++++++++++++++++++++++---------------------
M nam/train/gui.py | 72 +++++++++++++++++---------------------------------------------------
M tests/test_nam/test_models/test_exportable.py | 4 ++--
9 files changed, 237 insertions(+), 254 deletions(-)

diff --git a/nam/models/__init__.py b/nam/models/__init__.py
@@ -7,7 +7,7 @@ NAM's neural networks
 """
 
 from . import _base  # noqa F401
-from . import _exportable  # noqa F401
+from . import exportable  # noqa F401
 from . import losses  # noqa F401
 from . import wavenet  # noqa F401
 from .base import Model  # noqa F401
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -18,7 +18,7 @@ import torch.nn as nn
 
 from .._core import InitializableFromConfig
 from ..data import wav_to_tensor
-from ._exportable import Exportable
+from .exportable import Exportable
 
 
 class _Base(nn.Module, InitializableFromConfig, Exportable):
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -1,159 +0,0 @@
-# File: _exportable.py
-# Created Date: Tuesday February 8th 2022
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-import abc
-import json
-import logging
-from datetime import datetime
-from enum import Enum
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-from ..data import np_to_wav
-from .metadata import Date, UserMetadata
-
-logger = logging.getLogger(__name__)
-
-# Model version is independent from package version as of package version 0.5.2 so that
-# the API of the package can iterate at a different pace from that of the model files.
-_MODEL_VERSION = "0.5.2"
-
-
-def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
-    """
-    Casts enum-type keys to their values
-    """
-    out = {}
-    for key, val in d.items():
-        if isinstance(val, Enum):
-            val = val.value
-        out[key] = val
-    return out
-
-
-class Exportable(abc.ABC):
-    """
-    Interface for my custon export format for use in the plugin.
-    """
-
-    def export(
-        self,
-        outdir: Path,
-        include_snapshot: bool = False,
-        basename: str = "model",
-        user_metadata: Optional[UserMetadata] = None,
-    ):
-        """
-        Interface for exporting.
-        You should create at least a `config.json` containing the two fields:
-        * "version" (str)
-        * "architecture" (str)
-        * "config": (dict w/ other necessary data like tensor shapes etc)
-
-        :param outdir: Assumed to exist. Can be edited inside at will.
-        :param include_snapshots: If True, outputs `input.npy` and `output.npy`
-            Containing an example input/output pair that the model creates. This
-            Can be used to debug e.g. the implementation of the model in the
-            plugin.
- """ - model_dict = self._get_export_dict() - model_dict["metadata"].update( - {} if user_metadata is None else _cast_enums(user_metadata.model_dump()) - ) - - training = self.training - self.eval() - with open(Path(outdir, f"{basename}.nam"), "w") as fp: - json.dump(model_dict, fp) - if include_snapshot: - x, y = self._export_input_output() - x_path = Path(outdir, "test_inputs.npy") - y_path = Path(outdir, "test_outputs.npy") - logger.debug(f"Saving snapshot input to {x_path}") - np.save(x_path, x) - logger.debug(f"Saving snapshot output to {y_path}") - np.save(y_path, y) - - # And resume training state - self.train(training) - - @abc.abstractmethod - def export_cpp_header(self, filename: Path): - """ - Export a .h file to compile into the plugin with the weights written right out - as text - """ - pass - - def export_onnx(self, filename: Path): - """ - Export model in format for ONNX Runtime - """ - raise NotImplementedError( - "Exporting to ONNX is not supported for models of type " - f"{self.__class__.__name__}" - ) - - def import_weights(self, weights: Sequence[float]): - """ - Inverse of `._export_weights() - """ - raise NotImplementedError( - f"Importing weights for models of type {self.__class__.__name__} isn't " - "implemented yet." - ) - - @abc.abstractmethod - def _export_config(self): - """ - Creates the JSON of the model's archtecture hyperparameters (number of layers, - number of units, etc) - - :return: a JSON serializable object - """ - pass - - @abc.abstractmethod - def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]: - """ - Create an input and corresponding output signal to verify its behavior. - - They should be the same length, but the start of the output might have transient - effects. Up to you to interpret. - """ - pass - - @abc.abstractmethod - def _export_weights(self) -> np.ndarray: - """ - Flatten the weights out to a 1D array - """ - pass - - def _get_export_dict(self): - return { - "version": _MODEL_VERSION, - "metadata": self._get_non_user_metadata(), - "architecture": self.__class__.__name__, - "config": self._export_config(), - "weights": self._export_weights().tolist(), - } - - def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]: - """ - Get any metadata that's non-user-provided (date, loudness, gain) - """ - t = datetime.now() - return { - "date": Date( - year=t.year, - month=t.month, - day=t.day, - hour=t.hour, - minute=t.minute, - second=t.second, - ).model_dump() - } diff --git a/nam/models/exportable.py b/nam/models/exportable.py @@ -0,0 +1,161 @@ +# File: _exportable.py +# Created Date: Tuesday February 8th 2022 +# Author: Steven Atkinson (steven@atkinson.mn) + +import abc +import json +import logging +from datetime import datetime +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import numpy as np + +from ..data import np_to_wav +from .metadata import Date, UserMetadata + +logger = logging.getLogger(__name__) + +# Model version is independent from package version as of package version 0.5.2 so that +# the API of the package can iterate at a different pace from that of the model files. +_MODEL_VERSION = "0.5.2" + + +def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]: + """ + Casts enum-type keys to their values + """ + out = {} + for key, val in d.items(): + if isinstance(val, Enum): + val = val.value + out[key] = val + return out + + +class Exportable(abc.ABC): + """ + Interface for my custon export format for use in the plugin. 
+ """ + + FILE_EXTENSION = ".nam" + + def export( + self, + outdir: Path, + include_snapshot: bool = False, + basename: str = "model", + user_metadata: Optional[UserMetadata] = None, + ): + """ + Interface for exporting. + You should create at least a `config.json` containing the two fields: + * "version" (str) + * "architecture" (str) + * "config": (dict w/ other necessary data like tensor shapes etc) + + :param outdir: Assumed to exist. Can be edited inside at will. + :param include_snapshots: If True, outputs `input.npy` and `output.npy` + Containing an example input/output pair that the model creates. This + Can be used to debug e.g. the implementation of the model in the + plugin. + """ + model_dict = self._get_export_dict() + model_dict["metadata"].update( + {} if user_metadata is None else _cast_enums(user_metadata.model_dump()) + ) + + training = self.training + self.eval() + with open(Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp: + json.dump(model_dict, fp) + if include_snapshot: + x, y = self._export_input_output() + x_path = Path(outdir, "test_inputs.npy") + y_path = Path(outdir, "test_outputs.npy") + logger.debug(f"Saving snapshot input to {x_path}") + np.save(x_path, x) + logger.debug(f"Saving snapshot output to {y_path}") + np.save(y_path, y) + + # And resume training state + self.train(training) + + @abc.abstractmethod + def export_cpp_header(self, filename: Path): + """ + Export a .h file to compile into the plugin with the weights written right out + as text + """ + pass + + def export_onnx(self, filename: Path): + """ + Export model in format for ONNX Runtime + """ + raise NotImplementedError( + "Exporting to ONNX is not supported for models of type " + f"{self.__class__.__name__}" + ) + + def import_weights(self, weights: Sequence[float]): + """ + Inverse of `._export_weights() + """ + raise NotImplementedError( + f"Importing weights for models of type {self.__class__.__name__} isn't " + "implemented yet." + ) + + @abc.abstractmethod + def _export_config(self): + """ + Creates the JSON of the model's archtecture hyperparameters (number of layers, + number of units, etc) + + :return: a JSON serializable object + """ + pass + + @abc.abstractmethod + def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]: + """ + Create an input and corresponding output signal to verify its behavior. + + They should be the same length, but the start of the output might have transient + effects. Up to you to interpret. + """ + pass + + @abc.abstractmethod + def _export_weights(self) -> np.ndarray: + """ + Flatten the weights out to a 1D array + """ + pass + + def _get_export_dict(self): + return { + "version": _MODEL_VERSION, + "metadata": self._get_non_user_metadata(), + "architecture": self.__class__.__name__, + "config": self._export_config(), + "weights": self._export_weights().tolist(), + } + + def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]: + """ + Get any metadata that's non-user-provided (date, loudness, gain) + """ + t = datetime.now() + return { + "date": Date( + year=t.year, + month=t.month, + day=t.day, + hour=t.hour, + minute=t.minute, + second=t.second, + ).model_dump() + } diff --git a/nam/train/_errors.py b/nam/train/_errors.py @@ -1,18 +0,0 @@ -# File: _errors.py -# Created Date: Saturday April 13th 2024 -# Author: Steven Atkinson (steven@atkinson.mn) - -""" -"What could go wrong?" 
-""" - -__all__ = ["IncompatibleCheckpointError"] - - -class IncompatibleCheckpointError(RuntimeError): - """ - Raised when model loading fails because the checkpoint didn't match the model - or its hyperparameters - """ - - pass diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -81,7 +81,7 @@ def run( fit_cab: bool = False, ): """ - :param epochs: How amny epochs we'll train for. + :param epochs: How many epochs we'll train for. :param delay: How far the output algs the input due to round-trip latency during reamping, in samples. :param stage_1_channels: The number of channels in the WaveNet's first stage. diff --git a/nam/train/core.py b/nam/train/core.py @@ -25,9 +25,9 @@ from torch.utils.data import DataLoader from ..data import Split, init_dataset, wav_to_np, wav_to_tensor from ..models import Model +from ..models.exportable import Exportable from ..models.losses import esr from ..util import filter_warnings -from ._errors import IncompatibleCheckpointError from ._version import PROTEUS_VERSION, Version __all__ = ["train"] @@ -870,7 +870,6 @@ def _get_configs( lr_decay: float, batch_size: int, fit_cab: bool, - checkpoint: Optional[Path] = None, ): def get_kwargs(data_info: _DataInfo): if data_info.major_version == 1: @@ -960,8 +959,6 @@ def _get_configs( if fit_cab: model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF - if checkpoint: - model_config["checkpoint_path"] = checkpoint if torch.cuda.is_available(): device_config = {"accelerator": "gpu", "devices": 1} @@ -1095,15 +1092,56 @@ class _ValidationStopping(pl.callbacks.EarlyStopping): self.patience = np.inf +class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): + """ + Extension to model checkpoint to save a .nam file as well as the .ckpt file. + """ + + _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION + + @classmethod + def _get_nam_filepath(cls, filepath: str) -> Path: + """ + Given a .ckpt filepath, figure out a .nam for it. 
+ """ + if not filepath.endswith(cls.FILE_EXTENSION): + raise ValueError( + f"Checkpoint filepath {filepath} doesn't end in expected extension " + f"{cls.FILE_EXTENSION}" + ) + return Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION) + + def _save_checkpoint(self, trainer: pl.Trainer, filepath: str): + # Save the .ckpt: + super()._save_checkpoint(trainer, filepath) + # Save the .nam: + nam_filepath = self._get_nam_filepath(filepath) + pl_model: Model = trainer.model + nam_model = pl_model.net + outdir = nam_filepath.parent + # HACK: Assume the extension + basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)] + nam_model.export( + outdir, + basename=basename, + ) + + def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + super()._remove_checkpoint(trainer, filepath) + nam_path = self._get_nam_filepath(filepath) + if nam_path.exists(): + nam_path.unlink() + + def _get_callbacks(threshold_esr: Optional[float]): callbacks = [ - pl.callbacks.model_checkpoint.ModelCheckpoint( + _ModelCheckpoint( filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}", save_top_k=3, monitor="val_loss", every_n_epochs=1, ), - pl.callbacks.model_checkpoint.ModelCheckpoint( + _ModelCheckpoint( filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1 ), ] @@ -1135,7 +1173,6 @@ def train( local: bool = False, fit_cab: bool = False, threshold_esr: Optional[bool] = None, - checkpoint: Optional[Path] = None, ) -> Optional[Model]: """ :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`. @@ -1184,7 +1221,6 @@ def train( lr_decay, batch_size, fit_cab, - checkpoint=checkpoint, ) print("Starting training. It's time to kick ass and chew bubblegum!") @@ -1193,16 +1229,7 @@ def train( # * Model is re-instantiated after training anyways. # (Hacky) solution: set sample rate in model from dataloader after second # instantiation from final checkpoint. 
-    try:
-        model = Model.init_from_config(model_config)
-    except RuntimeError as e:
-        if "Error(s) in loading state_dict for Model:" in str(e):
-            raise IncompatibleCheckpointError(
-                "Model initialization failed; the checkpoint used seems to be "
-                f"incompatible.\n\nOriginal error:\n\n{e}"
-            )
-        else:
-            raise e
+    model = Model.init_from_config(model_config)
     train_dataloader, val_dataloader = _get_dataloaders(
         data_config, learning_config, model
     )
@@ -1212,6 +1239,8 @@
             f"{train_dataloader.dataset.sample_rate}, "
             f"{val_dataloader.dataset.sample_rate}"
         )
+    sample_rate = train_dataloader.dataset.sample_rate
+    model.net.sample_rate = sample_rate
 
     trainer = pl.Trainer(
         callbacks=_get_callbacks(threshold_esr),
@@ -1220,8 +1249,7 @@
     )
     # Suppress the PossibleUserWarning about num_workers (Issue 345)
     with filter_warnings("ignore", category=PossibleUserWarning):
-        trainer_fit_kwargs = {} if checkpoint is None else {"ckpt_path": checkpoint}
-        trainer.fit(model, train_dataloader, val_dataloader, **trainer_fit_kwargs)
+        trainer.fit(model, train_dataloader, val_dataloader)
 
     # Go to best checkpoint
     best_checkpoint = trainer.checkpoint_callback.best_model_path
@@ -1232,7 +1260,8 @@
         )
     model.cpu()
     model.eval()
-    model.net.sample_rate = train_dataloader.dataset.sample_rate
+    # HACK set again
+    model.net.sample_rate = sample_rate
 
     def window_kwargs(version: Version):
         if version.major == 1:
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -43,7 +43,6 @@ try:
     # 3rd-party and 1st-party imports
     from nam.models.metadata import GearType, UserMetadata, ToneType
     # Ok private access here--this is technically allowed access
-    from nam.train._errors import IncompatibleCheckpointError
     from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
 
     _install_is_valid = True
@@ -67,7 +66,6 @@ _TEXT_WIDTH = 70
 _DEFAULT_DELAY = None
 _DEFAULT_IGNORE_CHECKS = False
 _DEFAULT_THRESHOLD_ESR = None
-_DEFAULT_CHECKPOINT = None
 
 _ADVANCED_OPTIONS_LEFT_WIDTH = 12
 _ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -84,7 +82,6 @@ class _AdvancedOptions(object):
     :param ignore_checks: Keep going even if a check says that something is wrong.
     :param threshold_esr: Stop training if the ESR gets better than this. If None,
         don't stop.
-    :param checkpoint: If provided, try to restart from this checkpoint.
""" architecture: core.Architecture @@ -92,7 +89,6 @@ class _AdvancedOptions(object): latency: Optional[int] ignore_checks: bool threshold_esr: Optional[float] - checkpoint: Optional[Path] class _PathType(Enum): @@ -364,7 +360,6 @@ class _GUI(object): _DEFAULT_DELAY, _DEFAULT_IGNORE_CHECKS, _DEFAULT_THRESHOLD_ESR, - _DEFAULT_CHECKPOINT, ) # Window to edit them: @@ -487,7 +482,6 @@ class _GUI(object): delay = self.advanced_options.latency file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val threshold_esr = self.advanced_options.threshold_esr - checkpoint = self.advanced_options.checkpoint # Advanced-er options # If you're poking around looking for these, then maybe it's time to learn to @@ -502,36 +496,27 @@ class _GUI(object): print("Now training {}".format(file)) basename = re.sub(r"\.wav$", "", file.split("/")[-1]) - try: - trained_model = core.train( - self._widgets[_GUIWidgets.INPUT_PATH].val, - file, - self._widgets[_GUIWidgets.TRAINING_DESTINATION].val, - epochs=num_epochs, - delay=delay, - architecture=architecture, - batch_size=batch_size, - lr=lr, - lr_decay=lr_decay, - seed=seed, - silent=self._checkboxes[ - _CheckboxKeys.SILENT_TRAINING - ].variable.get(), - save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), - modelname=basename, - ignore_checks=self._checkboxes[ - _CheckboxKeys.IGNORE_DATA_CHECKS - ].variable.get(), - local=True, - fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(), - threshold_esr=threshold_esr, - checkpoint=checkpoint, - ) - except IncompatibleCheckpointError as e: - trained_model = None - self._wait_while_func( - _BasicModal, "Training failed due to incompatible checkpoint!" - ) + trained_model = core.train( + self._widgets[_GUIWidgets.INPUT_PATH].val, + file, + self._widgets[_GUIWidgets.TRAINING_DESTINATION].val, + epochs=num_epochs, + delay=delay, + architecture=architecture, + batch_size=batch_size, + lr=lr, + lr_decay=lr_decay, + seed=seed, + silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(), + save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), + modelname=basename, + ignore_checks=self._checkboxes[ + _CheckboxKeys.IGNORE_DATA_CHECKS + ].variable.get(), + local=True, + fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(), + threshold_esr=threshold_esr, + ) if trained_model is None: print("Model training failed! 
Skip exporting...") @@ -755,17 +740,6 @@ class _AdvancedOptionsGUI(object): type=_float_or_null, ) - # Restart from a checkpoint - self._frame_checkpoint = tk.Frame(self._root) - self._frame_checkpoint.pack() - self._path_button_checkpoint = _ClearablePathButton( - self._frame_checkpoint, - "Checkpoint", - "[Optional] Select a checkpoint (.ckpt file) to restart training from", - _PathType.FILE, - default=self._parent.advanced_options.checkpoint, - ) - # "Ok": apply and destory self._frame_ok = tk.Frame(self._root) self._frame_ok.pack() @@ -798,10 +772,6 @@ class _AdvancedOptionsGUI(object): self._parent.advanced_options.threshold_esr = ( None if threshold_esr == "null" else threshold_esr ) - checkpoint_path = self._path_button_checkpoint.val - self._parent.advanced_options.checkpoint = ( - None if checkpoint_path is None else Path(checkpoint_path) - ) self._root.destroy() self._resume_main() diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py @@ -17,7 +17,7 @@ import pytest import torch import torch.nn as nn -from nam.models import _exportable +from nam.models import exportable from nam.models import metadata @@ -105,7 +105,7 @@ class TestExportable(object): @classmethod def _get_model(cls): - class Model(nn.Module, _exportable.Exportable): + class Model(nn.Module, exportable.Exportable): def __init__(self): super().__init__() self._scale = nn.Parameter(torch.tensor(0.0))