commit b71db729a7fa1be4cea3bd72836c0677318954ab
parent 0ee6fd6c3a0c918035156dc9a6c54bee8d9470bb
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 12 May 2024 23:03:51 -0700
[FEATURE] Checkpoints save `.nam` files in addition to `.ckpt`s (#408)
* Update core.py
Extend PyTorch Lightning ModelCheckpoint to save and remove .nam
files alongside the .ckpt files.
* Remove unneeded checkpoint code
* Get sample rate into .nam checkpoints
* Update test_exportable.py
Diffstat:
9 files changed, 237 insertions(+), 254 deletions(-)
diff --git a/nam/models/__init__.py b/nam/models/__init__.py
@@ -7,7 +7,7 @@ NAM's neural networks
"""
from . import _base # noqa F401
-from . import _exportable # noqa F401
+from . import exportable # noqa F401
from . import losses # noqa F401
from . import wavenet # noqa F401
from .base import Model # noqa F401
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -18,7 +18,7 @@ import torch.nn as nn
from .._core import InitializableFromConfig
from ..data import wav_to_tensor
-from ._exportable import Exportable
+from .exportable import Exportable
class _Base(nn.Module, InitializableFromConfig, Exportable):
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -1,159 +0,0 @@
-# File: _exportable.py
-# Created Date: Tuesday February 8th 2022
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-import abc
-import json
-import logging
-from datetime import datetime
-from enum import Enum
-from pathlib import Path
-from typing import Any, Dict, Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-from ..data import np_to_wav
-from .metadata import Date, UserMetadata
-
-logger = logging.getLogger(__name__)
-
-# Model version is independent from package version as of package version 0.5.2 so that
-# the API of the package can iterate at a different pace from that of the model files.
-_MODEL_VERSION = "0.5.2"
-
-
-def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
- """
- Casts enum-type keys to their values
- """
- out = {}
- for key, val in d.items():
- if isinstance(val, Enum):
- val = val.value
- out[key] = val
- return out
-
-
-class Exportable(abc.ABC):
- """
- Interface for my custon export format for use in the plugin.
- """
-
- def export(
- self,
- outdir: Path,
- include_snapshot: bool = False,
- basename: str = "model",
- user_metadata: Optional[UserMetadata] = None,
- ):
- """
- Interface for exporting.
- You should create at least a `config.json` containing the two fields:
- * "version" (str)
- * "architecture" (str)
- * "config": (dict w/ other necessary data like tensor shapes etc)
-
- :param outdir: Assumed to exist. Can be edited inside at will.
- :param include_snapshots: If True, outputs `input.npy` and `output.npy`
- Containing an example input/output pair that the model creates. This
- Can be used to debug e.g. the implementation of the model in the
- plugin.
- """
- model_dict = self._get_export_dict()
- model_dict["metadata"].update(
- {} if user_metadata is None else _cast_enums(user_metadata.model_dump())
- )
-
- training = self.training
- self.eval()
- with open(Path(outdir, f"{basename}.nam"), "w") as fp:
- json.dump(model_dict, fp)
- if include_snapshot:
- x, y = self._export_input_output()
- x_path = Path(outdir, "test_inputs.npy")
- y_path = Path(outdir, "test_outputs.npy")
- logger.debug(f"Saving snapshot input to {x_path}")
- np.save(x_path, x)
- logger.debug(f"Saving snapshot output to {y_path}")
- np.save(y_path, y)
-
- # And resume training state
- self.train(training)
-
- @abc.abstractmethod
- def export_cpp_header(self, filename: Path):
- """
- Export a .h file to compile into the plugin with the weights written right out
- as text
- """
- pass
-
- def export_onnx(self, filename: Path):
- """
- Export model in format for ONNX Runtime
- """
- raise NotImplementedError(
- "Exporting to ONNX is not supported for models of type "
- f"{self.__class__.__name__}"
- )
-
- def import_weights(self, weights: Sequence[float]):
- """
- Inverse of `._export_weights()
- """
- raise NotImplementedError(
- f"Importing weights for models of type {self.__class__.__name__} isn't "
- "implemented yet."
- )
-
- @abc.abstractmethod
- def _export_config(self):
- """
- Creates the JSON of the model's archtecture hyperparameters (number of layers,
- number of units, etc)
-
- :return: a JSON serializable object
- """
- pass
-
- @abc.abstractmethod
- def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
- """
- Create an input and corresponding output signal to verify its behavior.
-
- They should be the same length, but the start of the output might have transient
- effects. Up to you to interpret.
- """
- pass
-
- @abc.abstractmethod
- def _export_weights(self) -> np.ndarray:
- """
- Flatten the weights out to a 1D array
- """
- pass
-
- def _get_export_dict(self):
- return {
- "version": _MODEL_VERSION,
- "metadata": self._get_non_user_metadata(),
- "architecture": self.__class__.__name__,
- "config": self._export_config(),
- "weights": self._export_weights().tolist(),
- }
-
- def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
- """
- Get any metadata that's non-user-provided (date, loudness, gain)
- """
- t = datetime.now()
- return {
- "date": Date(
- year=t.year,
- month=t.month,
- day=t.day,
- hour=t.hour,
- minute=t.minute,
- second=t.second,
- ).model_dump()
- }
diff --git a/nam/models/exportable.py b/nam/models/exportable.py
@@ -0,0 +1,161 @@
+# File: exportable.py
+# Created Date: Tuesday February 8th 2022
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import abc
+import json
+import logging
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from ..data import np_to_wav
+from .metadata import Date, UserMetadata
+
+logger = logging.getLogger(__name__)
+
+# Model version is independent from package version as of package version 0.5.2 so that
+# the API of the package can iterate at a different pace from that of the model files.
+_MODEL_VERSION = "0.5.2"
+
+
+def _cast_enums(d: Dict[Any, Any]) -> Dict[Any, Any]:
+ """
+ Casts enum-type keys to their values
+ """
+ out = {}
+ for key, val in d.items():
+ if isinstance(val, Enum):
+ val = val.value
+ out[key] = val
+ return out
+
+
+class Exportable(abc.ABC):
+ """
+ Interface for my custom export format for use in the plugin.
+ """
+
+ FILE_EXTENSION = ".nam"
+
+ def export(
+ self,
+ outdir: Path,
+ include_snapshot: bool = False,
+ basename: str = "model",
+ user_metadata: Optional[UserMetadata] = None,
+ ):
+ """
+ Interface for exporting.
+ You should create at least the `.nam` model file containing these fields:
+ * "version" (str)
+ * "architecture" (str)
+ * "config": (dict w/ other necessary data like tensor shapes etc)
+
+ :param outdir: Assumed to exist. Can be edited inside at will.
+ :param include_snapshot: If True, also output `test_inputs.npy` and
+ `test_outputs.npy` containing an example input/output pair from the model.
+ This can be used to debug e.g. the implementation of the model in the
+ plugin.
+ """
+ model_dict = self._get_export_dict()
+ model_dict["metadata"].update(
+ {} if user_metadata is None else _cast_enums(user_metadata.model_dump())
+ )
+
+ training = self.training
+ self.eval()
+ with open(Path(outdir, f"{basename}{self.FILE_EXTENSION}"), "w") as fp:
+ json.dump(model_dict, fp)
+ if include_snapshot:
+ x, y = self._export_input_output()
+ x_path = Path(outdir, "test_inputs.npy")
+ y_path = Path(outdir, "test_outputs.npy")
+ logger.debug(f"Saving snapshot input to {x_path}")
+ np.save(x_path, x)
+ logger.debug(f"Saving snapshot output to {y_path}")
+ np.save(y_path, y)
+
+ # And resume training state
+ self.train(training)
+
+ @abc.abstractmethod
+ def export_cpp_header(self, filename: Path):
+ """
+ Export a .h file to compile into the plugin with the weights written right out
+ as text
+ """
+ pass
+
+ def export_onnx(self, filename: Path):
+ """
+ Export model in format for ONNX Runtime
+ """
+ raise NotImplementedError(
+ "Exporting to ONNX is not supported for models of type "
+ f"{self.__class__.__name__}"
+ )
+
+ def import_weights(self, weights: Sequence[float]):
+ """
+ Inverse of `._export_weights()`
+ """
+ raise NotImplementedError(
+ f"Importing weights for models of type {self.__class__.__name__} isn't "
+ "implemented yet."
+ )
+
+ @abc.abstractmethod
+ def _export_config(self):
+ """
+ Creates the JSON of the model's architecture hyperparameters (number of layers,
+ number of units, etc)
+
+ :return: a JSON serializable object
+ """
+ pass
+
+ @abc.abstractmethod
+ def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Create an input and corresponding output signal to verify the model's behavior.
+
+ They should be the same length, but the start of the output might have transient
+ effects. Up to you to interpret.
+ """
+ pass
+
+ @abc.abstractmethod
+ def _export_weights(self) -> np.ndarray:
+ """
+ Flatten the weights out to a 1D array
+ """
+ pass
+
+ def _get_export_dict(self):
+ return {
+ "version": _MODEL_VERSION,
+ "metadata": self._get_non_user_metadata(),
+ "architecture": self.__class__.__name__,
+ "config": self._export_config(),
+ "weights": self._export_weights().tolist(),
+ }
+
+ def _get_non_user_metadata(self) -> Dict[str, Union[str, int, float]]:
+ """
+ Get any metadata that's non-user-provided (date, loudness, gain)
+ """
+ t = datetime.now()
+ return {
+ "date": Date(
+ year=t.year,
+ month=t.month,
+ day=t.day,
+ hour=t.hour,
+ minute=t.minute,
+ second=t.second,
+ ).model_dump()
+ }
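For context on the interface above: a minimal sketch of a class implementing Exportable and what export() writes. ScaleModel is a hypothetical toy model patterned after the one in tests/test_nam/test_models/test_exportable.py, and the snippet assumes the nam package is importable; it is illustrative, not part of this change.

# Hypothetical sketch: a toy model implementing the Exportable interface so that
# .export() writes a model.nam JSON file (version/metadata/architecture/config/weights).
import tempfile
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn

from nam.models import exportable


class ScaleModel(nn.Module, exportable.Exportable):
    """Multiplies the input by a learnable scalar."""

    def __init__(self):
        super().__init__()
        self._scale = nn.Parameter(torch.tensor(2.0))

    def forward(self, x):
        return self._scale * x

    def export_cpp_header(self, filename: Path):
        raise NotImplementedError("Not needed for this sketch")

    def _export_config(self):
        return {}  # A real model would describe its architecture here

    def _export_input_output(self):
        x = np.linspace(-1.0, 1.0, num=16, dtype=np.float32)
        with torch.no_grad():
            y = self(torch.from_numpy(x)).numpy()
        return x, y

    def _export_weights(self) -> np.ndarray:
        return self._scale.detach().cpu().numpy().reshape(-1)


with tempfile.TemporaryDirectory() as outdir:
    ScaleModel().export(Path(outdir), include_snapshot=True)
    # -> ['model.nam', 'test_inputs.npy', 'test_outputs.npy']
    print(sorted(p.name for p in Path(outdir).iterdir()))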
diff --git a/nam/train/_errors.py b/nam/train/_errors.py
@@ -1,18 +0,0 @@
-# File: _errors.py
-# Created Date: Saturday April 13th 2024
-# Author: Steven Atkinson (steven@atkinson.mn)
-
-"""
-"What could go wrong?"
-"""
-
-__all__ = ["IncompatibleCheckpointError"]
-
-
-class IncompatibleCheckpointError(RuntimeError):
- """
- Raised when model loading fails because the checkpoint didn't match the model
- or its hyperparameters
- """
-
- pass
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -81,7 +81,7 @@ def run(
fit_cab: bool = False,
):
"""
- :param epochs: How amny epochs we'll train for.
+ :param epochs: How many epochs we'll train for.
:param delay: How far the output lags the input due to round-trip latency during
reamping, in samples.
:param stage_1_channels: The number of channels in the WaveNet's first stage.
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -25,9 +25,9 @@ from torch.utils.data import DataLoader
from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
+from ..models.exportable import Exportable
from ..models.losses import esr
from ..util import filter_warnings
-from ._errors import IncompatibleCheckpointError
from ._version import PROTEUS_VERSION, Version
__all__ = ["train"]
@@ -870,7 +870,6 @@ def _get_configs(
lr_decay: float,
batch_size: int,
fit_cab: bool,
- checkpoint: Optional[Path] = None,
):
def get_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
@@ -960,8 +959,6 @@ def _get_configs(
if fit_cab:
model_config["loss"]["pre_emph_mrstft_weight"] = _CAB_MRSTFT_PRE_EMPH_WEIGHT
model_config["loss"]["pre_emph_mrstft_coef"] = _CAB_MRSTFT_PRE_EMPH_COEF
- if checkpoint:
- model_config["checkpoint_path"] = checkpoint
if torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
@@ -1095,15 +1092,56 @@ class _ValidationStopping(pl.callbacks.EarlyStopping):
self.patience = np.inf
+class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
+ """
+ Extension of ModelCheckpoint that saves a .nam file alongside each .ckpt file.
+ """
+
+ _NAM_FILE_EXTENSION = Exportable.FILE_EXTENSION
+
+ @classmethod
+ def _get_nam_filepath(cls, filepath: str) -> Path:
+ """
+ Given a .ckpt filepath, derive the corresponding .nam filepath.
+ """
+ if not filepath.endswith(cls.FILE_EXTENSION):
+ raise ValueError(
+ f"Checkpoint filepath {filepath} doesn't end in expected extension "
+ f"{cls.FILE_EXTENSION}"
+ )
+ return Path(filepath[: -len(cls.FILE_EXTENSION)] + cls._NAM_FILE_EXTENSION)
+
+ def _save_checkpoint(self, trainer: pl.Trainer, filepath: str):
+ # Save the .ckpt:
+ super()._save_checkpoint(trainer, filepath)
+ # Save the .nam:
+ nam_filepath = self._get_nam_filepath(filepath)
+ pl_model: Model = trainer.model
+ nam_model = pl_model.net
+ outdir = nam_filepath.parent
+ # HACK: Assume the extension
+ basename = nam_filepath.name[: -len(self._NAM_FILE_EXTENSION)]
+ nam_model.export(
+ outdir,
+ basename=basename,
+ )
+
+ def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None:
+ super()._remove_checkpoint(trainer, filepath)
+ nam_path = self._get_nam_filepath(filepath)
+ if nam_path.exists():
+ nam_path.unlink()
+
+
def _get_callbacks(threshold_esr: Optional[float]):
callbacks = [
- pl.callbacks.model_checkpoint.ModelCheckpoint(
+ _ModelCheckpoint(
filename="checkpoint_best_{epoch:04d}_{step}_{ESR:.4g}_{MSE:.3e}",
save_top_k=3,
monitor="val_loss",
every_n_epochs=1,
),
- pl.callbacks.model_checkpoint.ModelCheckpoint(
+ _ModelCheckpoint(
filename="checkpoint_last_{epoch:04d}_{step}", every_n_epochs=1
),
]
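For reference, the .ckpt -> .nam pairing that _ModelCheckpoint adds can be summarized with a standalone sketch; nam_path_for is an illustrative helper (not part of the package) and the checkpoint path below is a hypothetical example.

# Standalone sketch (illustrative only) of how _ModelCheckpoint pairs files.
from pathlib import Path

CKPT_EXTENSION = ".ckpt"  # pl.callbacks.ModelCheckpoint.FILE_EXTENSION
NAM_EXTENSION = ".nam"    # Exportable.FILE_EXTENSION


def nam_path_for(ckpt_path: str) -> Path:
    # Mirrors _ModelCheckpoint._get_nam_filepath above
    if not ckpt_path.endswith(CKPT_EXTENSION):
        raise ValueError(f"{ckpt_path} doesn't end in {CKPT_EXTENSION}")
    return Path(ckpt_path[: -len(CKPT_EXTENSION)] + NAM_EXTENSION)


ckpt = "lightning_logs/version_0/checkpoints/checkpoint_last_0009_1000.ckpt"  # hypothetical
print(nam_path_for(ckpt))  # .../checkpoint_last_0009_1000.nam
# On save, _save_checkpoint() also exports trainer.model.net to that .nam path;
# when save_top_k rotates a .ckpt out, _remove_checkpoint() unlinks the paired .nam.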
@@ -1135,7 +1173,6 @@ def train(
local: bool = False,
fit_cab: bool = False,
threshold_esr: Optional[bool] = None,
- checkpoint: Optional[Path] = None,
) -> Optional[Model]:
"""
:param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
@@ -1184,7 +1221,6 @@ def train(
lr_decay,
batch_size,
fit_cab,
- checkpoint=checkpoint,
)
print("Starting training. It's time to kick ass and chew bubblegum!")
@@ -1193,16 +1229,7 @@ def train(
# * Model is re-instantiated after training anyways.
# (Hacky) solution: set sample rate in model from dataloader after second
# instantiation from final checkpoint.
- try:
- model = Model.init_from_config(model_config)
- except RuntimeError as e:
- if "Error(s) in loading state_dict for Model:" in str(e):
- raise IncompatibleCheckpointError(
- "Model initialization failed; the checkpoint used seems to be "
- f"incompatible.\n\nOriginal error:\n\n{e}"
- )
- else:
- raise e
+ model = Model.init_from_config(model_config)
train_dataloader, val_dataloader = _get_dataloaders(
data_config, learning_config, model
)
@@ -1212,6 +1239,8 @@ def train(
f"{train_dataloader.dataset.sample_rate}, "
f"{val_dataloader.dataset.sample_rate}"
)
+ sample_rate = train_dataloader.dataset.sample_rate
+ model.net.sample_rate = sample_rate
trainer = pl.Trainer(
callbacks=_get_callbacks(threshold_esr),
@@ -1220,8 +1249,7 @@ def train(
)
# Suppress the PossibleUserWarning about num_workers (Issue 345)
with filter_warnings("ignore", category=PossibleUserWarning):
- trainer_fit_kwargs = {} if checkpoint is None else {"ckpt_path": checkpoint}
- trainer.fit(model, train_dataloader, val_dataloader, **trainer_fit_kwargs)
+ trainer.fit(model, train_dataloader, val_dataloader)
# Go to best checkpoint
best_checkpoint = trainer.checkpoint_callback.best_model_path
@@ -1232,7 +1260,8 @@ def train(
)
model.cpu()
model.eval()
- model.net.sample_rate = train_dataloader.dataset.sample_rate
+ # HACK set again
+ model.net.sample_rate = sample_rate
def window_kwargs(version: Version):
if version.major == 1:
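As an aside on the sample-rate handling above: sample_rate is a plain attribute on the exportable net, not part of the checkpoint's state_dict, so it is set before trainer.fit (so the .nam files exported by _ModelCheckpoint during training carry it) and set again after the model is re-instantiated from the best checkpoint. A self-contained toy illustration (hypothetical TinyNet class) of why the attribute does not survive a state_dict round trip:

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.sample_rate = None  # plain attribute: not a registered buffer or parameter


net = TinyNet()
net.sample_rate = 48_000.0
state = net.state_dict()  # contains only "scale"

restored = TinyNet()
restored.load_state_dict(state)
print(restored.sample_rate)  # None -- has to be set again after reloading
restored.sample_rate = 48_000.0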
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -43,7 +43,6 @@ try: # 3rd-party and 1st-party imports
from nam.models.metadata import GearType, UserMetadata, ToneType
# Ok private access here--this is technically allowed access
- from nam.train._errors import IncompatibleCheckpointError
from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
_install_is_valid = True
@@ -67,7 +66,6 @@ _TEXT_WIDTH = 70
_DEFAULT_DELAY = None
_DEFAULT_IGNORE_CHECKS = False
_DEFAULT_THRESHOLD_ESR = None
-_DEFAULT_CHECKPOINT = None
_ADVANCED_OPTIONS_LEFT_WIDTH = 12
_ADVANCED_OPTIONS_RIGHT_WIDTH = 12
@@ -84,7 +82,6 @@ class _AdvancedOptions(object):
:param ignore_checks: Keep going even if a check says that something is wrong.
:param threshold_esr: Stop training if the ESR gets better than this. If None, don't
stop.
- :param checkpoint: If provided, try to restart from this checkpoint.
"""
architecture: core.Architecture
@@ -92,7 +89,6 @@ class _AdvancedOptions(object):
latency: Optional[int]
ignore_checks: bool
threshold_esr: Optional[float]
- checkpoint: Optional[Path]
class _PathType(Enum):
@@ -364,7 +360,6 @@ class _GUI(object):
_DEFAULT_DELAY,
_DEFAULT_IGNORE_CHECKS,
_DEFAULT_THRESHOLD_ESR,
- _DEFAULT_CHECKPOINT,
)
# Window to edit them:
@@ -487,7 +482,6 @@ class _GUI(object):
delay = self.advanced_options.latency
file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
threshold_esr = self.advanced_options.threshold_esr
- checkpoint = self.advanced_options.checkpoint
# Advanced-er options
# If you're poking around looking for these, then maybe it's time to learn to
@@ -502,36 +496,27 @@ class _GUI(object):
print("Now training {}".format(file))
basename = re.sub(r"\.wav$", "", file.split("/")[-1])
- try:
- trained_model = core.train(
- self._widgets[_GUIWidgets.INPUT_PATH].val,
- file,
- self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
- epochs=num_epochs,
- delay=delay,
- architecture=architecture,
- batch_size=batch_size,
- lr=lr,
- lr_decay=lr_decay,
- seed=seed,
- silent=self._checkboxes[
- _CheckboxKeys.SILENT_TRAINING
- ].variable.get(),
- save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
- modelname=basename,
- ignore_checks=self._checkboxes[
- _CheckboxKeys.IGNORE_DATA_CHECKS
- ].variable.get(),
- local=True,
- fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
- threshold_esr=threshold_esr,
- checkpoint=checkpoint,
- )
- except IncompatibleCheckpointError as e:
- trained_model = None
- self._wait_while_func(
- _BasicModal, "Training failed due to incompatible checkpoint!"
- )
+ trained_model = core.train(
+ self._widgets[_GUIWidgets.INPUT_PATH].val,
+ file,
+ self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
+ epochs=num_epochs,
+ delay=delay,
+ architecture=architecture,
+ batch_size=batch_size,
+ lr=lr,
+ lr_decay=lr_decay,
+ seed=seed,
+ silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
+ save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
+ modelname=basename,
+ ignore_checks=self._checkboxes[
+ _CheckboxKeys.IGNORE_DATA_CHECKS
+ ].variable.get(),
+ local=True,
+ fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
+ threshold_esr=threshold_esr,
+ )
if trained_model is None:
print("Model training failed! Skip exporting...")
@@ -755,17 +740,6 @@ class _AdvancedOptionsGUI(object):
type=_float_or_null,
)
- # Restart from a checkpoint
- self._frame_checkpoint = tk.Frame(self._root)
- self._frame_checkpoint.pack()
- self._path_button_checkpoint = _ClearablePathButton(
- self._frame_checkpoint,
- "Checkpoint",
- "[Optional] Select a checkpoint (.ckpt file) to restart training from",
- _PathType.FILE,
- default=self._parent.advanced_options.checkpoint,
- )
-
# "Ok": apply and destory
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
@@ -798,10 +772,6 @@ class _AdvancedOptionsGUI(object):
self._parent.advanced_options.threshold_esr = (
None if threshold_esr == "null" else threshold_esr
)
- checkpoint_path = self._path_button_checkpoint.val
- self._parent.advanced_options.checkpoint = (
- None if checkpoint_path is None else Path(checkpoint_path)
- )
self._root.destroy()
self._resume_main()
diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py
@@ -17,7 +17,7 @@ import pytest
import torch
import torch.nn as nn
-from nam.models import _exportable
+from nam.models import exportable
from nam.models import metadata
@@ -105,7 +105,7 @@ class TestExportable(object):
@classmethod
def _get_model(cls):
- class Model(nn.Module, _exportable.Exportable):
+ class Model(nn.Module, exportable.Exportable):
def __init__(self):
super().__init__()
self._scale = nn.Parameter(torch.tensor(0.0))