commit a284650353cefa986f3ba6a8876a7375a0205a7c
parent a2f54f02050c6a87d313880b7f805c757f10e4f9
Author: Steven Atkinson <[email protected]>
Date: Tue, 9 Jan 2024 00:12:46 -0800
[BUGFIX] Store model's expected sample rate in Lightning checkpoints (#357)
* Rework sample_rate implementation to be storable in PyTorch artifacts
* Store sample rate in checkpoint
Diffstat:
5 files changed, 79 insertions(+), 10 deletions(-)
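Why the rework in the first bullet is needed: a plain Python attribute such as self.sample_rate = 48_000.0 never appears in an nn.Module's state_dict(), so it is silently dropped from saved PyTorch artifacts, whereas a registered buffer is serialized alongside the weights without becoming a trainable parameter. A standalone sketch of the distinction (names here are illustrative, not from the codebase):

    import torch
    import torch.nn as nn

    net = nn.Module()
    net.sample_rate = 48_000.0  # plain attribute: invisible to state_dict()
    net.register_buffer("sample_rate_buf", torch.tensor(48_000.0))
    print(net.state_dict())  # OrderedDict([('sample_rate_buf', tensor(48000.))])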
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -33,6 +33,7 @@ from torch.utils.data import DataLoader
from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
from nam.models import Model
+from nam.models._base import BaseNet # HACK access
from nam.util import filter_warnings, timestamp
torch.manual_seed(0)
@@ -191,6 +192,7 @@ def main_inner(
"Train and validation data loaders have different data set sample rates: "
f"{dataset_train.sample_rate}, {dataset_validation.sample_rate}"
)
+ model.net.sample_rate = dataset_train.sample_rate
train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
@@ -215,7 +217,6 @@ def main_inner(
)
model.cpu()
model.eval()
- model.net.sample_rate = train_dataloader.dataset.sample_rate
if make_plots:
plot(
model,
@@ -226,9 +227,9 @@ def main_inner(
show=False,
)
plot(model, dataset_validation, show=not no_show)
- # Would like to, but this doesn't work for all cases.
- # If you're making snapshot models, you may find this convenient to uncomment :)
- # model.net.export(outdir)
+ # Convenient export for snapshot models:
+ if isinstance(model.net, BaseNet):
+ model.net.export(outdir)
if __name__ == "__main__":
diff --git a/nam/data.py b/nam/data.py
@@ -753,7 +753,7 @@ _dataset_init_registry = {
def register_dataset_initializer(
- name: str, constructor: Callable[[Any], AbstractDataset]
+ name: str, constructor: Callable[[Any], AbstractDataset], overwrite: bool = False
):
"""
If you have other data set types, you can register their initializer by name using
@@ -768,7 +768,7 @@ def register_dataset_initializer(
:param name: The name that'll be used in the config to ask for the data set type
:param constructor: The constructor that'll be fed the config.
"""
- if name in _dataset_init_registry:
+ if name in _dataset_init_registry and not overwrite:
raise KeyError(
f"A constructor for dataset name '{name}' is already registered!"
)
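With the new overwrite flag, re-registering a dataset initializer (e.g. when re-running a notebook cell) can be opted into instead of always raising. A hedged usage sketch, where MyDataset and its init_from_config classmethod are hypothetical stand-ins for a user-defined data set type:

    from nam.data import register_dataset_initializer

    register_dataset_initializer("my_dataset", MyDataset.init_from_config)
    # A second registration under the same name used to raise KeyError;
    # now it can be done explicitly:
    register_dataset_initializer(
        "my_dataset", MyDataset.init_from_config, overwrite=True
    )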
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -24,7 +24,12 @@ from ._exportable import Exportable
class _Base(nn.Module, InitializableFromConfig, Exportable):
def __init__(self, sample_rate: Optional[float] = None):
super().__init__()
- self.sample_rate = sample_rate
+ self.register_buffer(
+ "_has_sample_rate", torch.tensor(sample_rate is not None, dtype=torch.bool)
+ )
+ self.register_buffer(
+ "_sample_rate", torch.tensor(0.0 if sample_rate is None else sample_rate)
+ )
@abc.abstractproperty
def pad_start_default(self) -> bool:
@@ -49,6 +54,15 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
)
)
+ @property
+ def sample_rate(self) -> Optional[float]:
+ return self._sample_rate.item() if self._has_sample_rate else None
+
+ @sample_rate.setter
+ def sample_rate(self, val: Optional[float]):
+ self._has_sample_rate = torch.tensor(val is not None, dtype=torch.bool)
+ self._sample_rate = torch.tensor(0.0 if val is None else val)
+
def _get_export_dict(self):
d = super()._get_export_dict()
sample_rate_key = "sample_rate"
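A subtlety the setter above relies on: nn.Module.__setattr__ routes a Tensor assigned to a name that is already a registered buffer into the module's buffer table rather than shadowing it with a plain attribute, so values written through the property setter still reach state_dict(). A standalone check of that PyTorch behavior (not NAM code):

    import torch
    import torch.nn as nn

    m = nn.Module()
    m.register_buffer("_sample_rate", torch.tensor(0.0))
    m._sample_rate = torch.tensor(44_100.0)  # updates the registered buffer
    assert m.state_dict()["_sample_rate"].item() == 44_100.0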
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -12,7 +12,7 @@ For the base *PyTorch* model containing the actual architecture, see `._base`.
from dataclasses import dataclass
from enum import Enum
-from typing import Dict, NamedTuple, Optional, Tuple
+from typing import Any, Dict, NamedTuple, Optional, Tuple
import auraloss
import logging
@@ -210,8 +210,8 @@ class Model(pl.LightningModule, InitializableFromConfig):
}
@classmethod
- def register_net_initializer(cls, name, constructor):
- if name in _model_net_init_registry:
+ def register_net_initializer(cls, name, constructor, overwrite: bool = False):
+ if name in _model_net_init_registry and not overwrite:
raise KeyError(
f"A constructor for net name '{name}' is already registered!"
)
@@ -238,6 +238,14 @@ class Model(pl.LightningModule, InitializableFromConfig):
def forward(self, *args, **kwargs):
return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.
+ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+ # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
+ # .get(): checkpoints saved before this commit have no "sample_rate" key;
+ # load them with an unknown (None) sample rate instead of raising KeyError.
+ self.net.sample_rate = checkpoint.get("sample_rate")
+
+ def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+ # Resolves https://github.com/sdatkinson/neural-amp-modeler/issues/351
+ checkpoint["sample_rate"] = self.net.sample_rate
+
def _shared_step(
self, batch
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, _LossItem]]:
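on_save_checkpoint and on_load_checkpoint are Lightning's hooks for stashing extra state in a .ckpt: the former runs whenever the trainer writes a checkpoint, the latter whenever one is restored (including via load_from_checkpoint). Note that checkpoints written before this commit carry no "sample_rate" key; with the .get() above they load with sample_rate None rather than failing. A hedged round-trip sketch, reusing the model and dataloaders from bin/train/main.py above; the Trainer config and path are illustrative, and load_from_checkpoint additionally needs the module's constructor arguments to be recoverable (e.g. via save_hyperparameters):

    import pytorch_lightning as pl
    from nam.models import Model

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("model.ckpt")  # on_save_checkpoint adds the key

    restored = Model.load_from_checkpoint("model.ckpt")  # on_load_checkpoint runs
    assert restored.net.sample_rate == model.net.sample_rate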
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -2,8 +2,13 @@
# Created Date: Thursday March 16th 2023
# Author: Steven Atkinson ([email protected])
+"""
+Tests for the base network and Lightning module
+"""
+
import math
from pathlib import Path
+from tempfile import TemporaryDirectory
from typing import Optional
import numpy as np
@@ -106,5 +111,46 @@ def test_mrstft_loss_cpu_fallback(mocker):
assert obj._mrstft_device == "cpu"
+class TestSampleRate(object):
+ """
+ Tests for sample_rate interface
+ """
+
+ @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
+ def test_on_init(self, expected_sample_rate: Optional[float]):
+ model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
+ self._wrap_assert(model, expected_sample_rate)
+
+ @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
+ def test_setter(self, expected_sample_rate: Optional[float]):
+ model = _MockBaseNet(gain=1.0)
+ model.sample_rate = expected_sample_rate
+ self._wrap_assert(model, expected_sample_rate)
+
+ @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
+ def test_state_dict(self, expected_sample_rate: Optional[float]):
+ """
+ Assert that it makes it into the state dict
+
+ https://github.com/sdatkinson/neural-amp-modeler/issues/351
+ """
+ model = _MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
+ with TemporaryDirectory() as tmpdir:
+ model_path = Path(tmpdir, "model.pt")
+ torch.save(model.state_dict(), model_path)
+ model2 = _MockBaseNet(gain=1.0)
+ model2.load_state_dict(torch.load(model_path))
+ self._wrap_assert(model2, expected_sample_rate)
+
+ @classmethod
+ def _wrap_assert(cls, model: _MockBaseNet, expected: Optional[float]):
+ actual = model.sample_rate
+ if expected is None:
+ assert actual is None
+ else:
+ assert isinstance(actual, float)
+ assert actual == expected
+
+
if __name__ == "__main__":
pytest.main()