commit ced3ec4355568faa1b89ff722afdf2987f219be2
parent 1c53cba5ddd5a2d61a74ba63ff51478299c568c2
Author: Steven Atkinson <[email protected]>
Date: Wed, 20 Jul 2022 23:01:13 -0400
Version 0.3.0 (#39)
PR: https://github.com/sdatkinson/neural-amp-modeler/pull/39
# New features
* Hypernet parametric model (#23) (not supported by the plugin yet)
* DC loss (#31)
* ESR loss (#33)
* LSTM model (#35)
* Conditional LSTM model (#36)
Diffstat:
29 files changed, 1846 insertions(+), 160 deletions(-)
diff --git a/bin/train/inputs/config_learning.json b/bin/train/inputs/config_learning.json
@@ -1,17 +0,0 @@
-{
- "train_dataloader": {
- "batch_size": 16,
- "shuffle": true,
- "pin_memory": true,
- "drop_last": true,
- "num_workers": 4
- },
- "val_dataloader": {
- },
- "trainer": {
- "gpus": 1,
- "max_epochs": 1000
- },
- "trainer_fit_kwargs": {
- }
-}
-\ No newline at end of file
diff --git a/bin/train/inputs/config_data_single_pair.json b/bin/train/inputs/data/single_pair.json
diff --git a/bin/train/inputs/config_data_two_pairs.json b/bin/train/inputs/data/two_pairs.json
diff --git a/bin/train/inputs/learning/default.json b/bin/train/inputs/learning/default.json
@@ -0,0 +1,17 @@
+{
+ "train_dataloader": {
+ "batch_size": 32,
+ "shuffle": true,
+ "pin_memory": true,
+ "drop_last": true,
+ "num_workers": 4
+ },
+ "val_dataloader": {
+ },
+ "trainer": {
+ "gpus": 1,
+ "max_epochs": 1000
+ },
+ "trainer_fit_kwargs": {
+ }
+}
+\ No newline at end of file
diff --git a/bin/train/inputs/models/catlstm.json b/bin/train/inputs/models/catlstm.json
@@ -0,0 +1,41 @@
+{
+ "_comments": [
+ "Parametric extension of the LSTM model. All LSTM tips apply plus:",
+ " * Make sure that `input_size` is the number of knobs plus one. I've set it",
+ " up like we're modeling a tube screamer (drive/tone/level), so 1+3=4.",
+ " * I've messed around with weight decay, but I don't think it's actually",
+ " helpful. Still, it's in there so you can see how to use it if you're",
+ " curious.",
+ " * Doesn't seem like the model needs to be all that bigger than the",
+ " non-parametric version, even if you're modeling a fair number of knobs.",
+ " * You'll probably have a much larger dataset, so validating every so often ",
+ " in steps instead of epochs helps. Make sure to also set val_check_interval",
+ " under the trainer dict in your learning.json."
+ ],
+ "net": {
+ "name": "CatLSTM",
+ "config": {
+ "num_layers": 3,
+ "hidden_size": 24,
+ "train_truncate": 1024,
+ "train_burn_in": 4096,
+ "input_size": 4
+ }
+ },
+ "loss": {
+ "val_loss": "mse",
+ "mask_first": 4096
+ },
+ "optimizer": {
+ "lr": 0.01,
+ "weight_decay": 1e-09
+ },
+ "lr_scheduler": {
+ "class": "ExponentialLR",
+ "kwargs": {
+ "gamma": 0.995
+ },
+ "interval": "step",
+ "frequency": 200
+ }
+}
+\ No newline at end of file
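A minimal sketch (not part of this diff) of why `input_size` is one plus the number of knobs: per the CatLSTM code added below, the knob values are tiled along time and concatenated with each audio sample, so every LSTM input frame is `[audio, knob_1, ..., knob_D]`. The knob names and values here are hypothetical.

```
import torch

params = {"drive": 0.5, "tone": 0.5, "level": 0.5}  # hypothetical tube-screamer knobs
vals = torch.Tensor([params[k] for k in sorted(params)])  # (D,) = (3,)
x = torch.zeros(16)  # (L,) mono audio snippet

# Each time step becomes [sample, drive, level, tone] -> width 1 + D = 4,
# which matches "input_size": 4 in the config above.
augmented = torch.cat([x[:, None], torch.tile(vals[None, :], (len(x), 1))], dim=1)
assert augmented.shape == (16, 4)
```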
diff --git a/bin/train/inputs/config_model.json b/bin/train/inputs/models/convnet.json
diff --git a/bin/train/inputs/models/lstm.json b/bin/train/inputs/models/lstm.json
@@ -0,0 +1,32 @@
+{
+ "_comments": [
+ "Reminders and tips:",
+ " * For your data, use nx=1, and use a long ny like 32768.",
+ " * For this model, it really helps if you have the delay in your data set",
+ " correctly. I've seen improvements fixing a delay that was off by 10",
+ " samples.",
+ " * gamma below is picked so that we end up with a learning rate of about",
+ " 1e-4 after 1000 epochs. I've found LSTMs to work with a pretty aggressive",
+ " learning rate that would be out of the question for other architectures.",
+ " * Number of units between 8 and 96, layers from 1 to 5 all seem to be ok",
+ " depending on the dataset, though bigger models might not make real-time."
+ ],
+ "net": {
+ "name": "LSTM",
+ "config": {
+ "hidden_size": 24,
+ "train_burn_in": 4096,
+ "train_truncate": 1024,
+ "num_layers": 3
+ }
+ },
+ "optimizer": {
+ "lr": 0.01
+ },
+ "lr_scheduler": {
+ "class": "ExponentialLR",
+ "kwargs": {
+ "gamma": 0.995
+ }
+ }
+}
+\ No newline at end of file
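A quick check (not from the repo) of the gamma comment above, assuming the scheduler steps once per epoch:

```
lr0, gamma, epochs = 0.01, 0.995, 1000
print(lr0 * gamma ** epochs)  # ~6.7e-5, i.e. roughly the "about 1e-4" claimed above
```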
diff --git a/bin/train/main.py b/bin/train/main.py
@@ -15,7 +15,7 @@ import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
-from nam.data import Split, init_dataset
+from nam.data import ConcatDataset, ParametricDataset, Split, init_dataset
from nam.models import Model
torch.manual_seed(0)
@@ -49,11 +49,32 @@ def plot(
window_start: Optional[int] = None,
window_end: Optional[int] = None,
):
+ if isinstance(ds, ConcatDataset):
+
+ def extend_savefig(i, savefig):
+ if savefig is None:
+ return None
+ savefig = Path(savefig)
+ extension = savefig.name.split(".")[-1]
+ stem = savefig.name[: -len(extension) - 1]
+ return Path(savefig.parent, f"{stem}_{i}.{extension}")
+
+ for i, ds_i in enumerate(ds.datasets):
+ plot(
+ model,
+ ds_i,
+ savefig=extend_savefig(i, savefig),
+ show=show and i == len(ds.datasets) - 1,
+ window_start=window_start,
+ window_end=window_end,
+ )
+ return
with torch.no_grad():
tx = len(ds.x) / 48_000
print(f"Run (t={tx})")
t0 = time()
- output = model(ds.x).flatten().cpu().numpy()
+ args = (ds.vals, ds.x) if isinstance(ds, ParametricDataset) else (ds.x,)
+ output = model(*args).flatten().cpu().numpy()
t1 = time()
print(f"Took {t1 - t0} ({tx / (t1 - t0):.2f}x)")
@@ -72,6 +93,31 @@ def plot(
plt.show()
+def _create_callbacks(learning_config):
+ """
+ Checkpointing, essentially
+ """
+ # Checkpoints should be run every time the validation check is run.
+ # So base it off of learning_config["trainer"]["val_check_interval"] if it's there.
+ if "val_check_interval" in learning_config["trainer"]:
+ kwargs = {
+ "every_n_train_steps": learning_config["trainer"]["val_check_interval"]
+ }
+ else:
+ kwargs = {"every_n_epochs": 1}
+
+ checkpoint_best = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="{epoch:04d}_{step}_{ESR:.3e}_{MSE:.3e}",
+ save_top_k=3,
+ monitor="val_loss",
+ **kwargs,
+ )
+ checkpoint_last = pl.callbacks.model_checkpoint.ModelCheckpoint(
+ filename="checkpoint_last_{epoch:04d}_{step}", **kwargs
+ )
+ return [checkpoint_best, checkpoint_last]
+
+
def main(args):
outdir = ensure_outdir(args.outdir)
# Read
@@ -100,17 +146,7 @@ def main(args):
# ckpt_path = Path(outdir, "checkpoints")
# ckpt_path.mkdir()
trainer = pl.Trainer(
- callbacks=[
- pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="{epoch}_{val_loss:.6f}",
- save_top_k=3,
- monitor="val_loss",
- every_n_epochs=1,
- ),
- pl.callbacks.model_checkpoint.ModelCheckpoint(
- filename="checkpoint_last_{epoch:04d}", every_n_epochs=1
- ),
- ],
+ callbacks=_create_callbacks(learning_config),
default_root_dir=outdir,
**learning_config["trainer"],
)
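A hedged example (values are illustrative, not from the repo) of a learning config that exercises the step-based branch of `_create_callbacks()` above: when `val_check_interval` appears under `trainer`, both checkpoint callbacks fire every that many training steps; otherwise they fire once per epoch.

```
learning_config = {
    "train_dataloader": {"batch_size": 32, "shuffle": True, "drop_last": True},
    "val_dataloader": {},
    # With val_check_interval set, the ModelCheckpoint callbacks use
    # every_n_train_steps=10_000 instead of every_n_epochs=1.
    "trainer": {"gpus": 1, "max_epochs": 1000, "val_check_interval": 10_000},
    "trainer_fit_kwargs": {},
}
```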
diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.2.1"
+__version__ = "0.3.0"
diff --git a/nam/data.py b/nam/data.py
@@ -4,10 +4,11 @@
import abc
from collections import namedtuple
+from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
-from typing import Optional, Sequence, Tuple, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
@@ -18,7 +19,7 @@ from tqdm import tqdm
from ._core import InitializableFromConfig
_REQUIRED_SAMPWIDTH = 3
-_REQUIRED_RATE = 48_000
+REQUIRED_RATE = 48_000
_REQUIRED_CHANNELS = 1 # Mono
@@ -47,7 +48,7 @@ def wav_to_np(
x_wav = wavio.read(str(filename))
assert x_wav.data.shape[1] == _REQUIRED_CHANNELS, "Mono"
assert x_wav.sampwidth == _REQUIRED_SAMPWIDTH, "24-bit"
- assert x_wav.rate == _REQUIRED_RATE, "48 kHz"
+ assert x_wav.rate == REQUIRED_RATE, "48 kHz"
if require_match is not None:
assert required_shape is None
@@ -83,20 +84,20 @@ def wav_to_tensor(
return torch.Tensor(arr)
-def tensor_to_wav(
- x: torch.Tensor,
+def tensor_to_wav(x: torch.Tensor, *args, **kwargs):
+ np_to_wav(x.detach().cpu().numpy(), *args, **kwargs)
+
+
+def np_to_wav(
+ x: np.ndarray,
filename: Union[str, Path],
rate: int = 48_000,
sampwidth: int = 3,
scale="none",
):
wavio.write(
- filename,
- (torch.clamp(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1)))
- .detach()
- .cpu()
- .numpy()
- .astype(np.int32),
+ str(filename),
+ (np.clip(x, -1.0, 1.0) * (2 ** (8 * sampwidth - 1))).astype(np.int32),
rate,
scale=scale,
sampwidth=sampwidth,
@@ -105,13 +106,20 @@ def tensor_to_wav(
class AbstractDataset(_Dataset, abc.ABC):
@abc.abstractmethod
- def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
+ def __getitem__(
+ self, idx
+ ) -> Union[
+ Tuple[torch.Tensor, torch.Tensor],
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ ]:
pass
class Dataset(AbstractDataset, InitializableFromConfig):
"""
- Take a pair of matched audio files and serve input + output pairs
+ Take a pair of matched audio files and serve input + output pairs.
+
+ No conditioning parameters associated w/ the data.
"""
def __init__(
@@ -204,7 +212,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
assert x.ndim == 1
assert y.ndim == 1
assert len(x) == len(y)
- assert nx <= len(x)
+ if nx > len(x):
+ raise RuntimeError(f"Input of length {len(x)}, but receptive field is {nx}.")
if ny is not None:
assert ny <= len(y) - nx + 1
if torch.abs(y).max() >= 1.0:
@@ -214,8 +223,71 @@ class Dataset(AbstractDataset, InitializableFromConfig):
raise ValueError(msg)
+class ParametricDataset(Dataset):
+ """
+ Additionally tracks some conditioning parameters
+ """
+
+ def __init__(self, params: Dict[str, Union[bool, float, int]], *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._keys = sorted(tuple(k for k in params.keys()))
+ self._vals = torch.Tensor([float(params[k]) for k in self._keys])
+
+ @classmethod
+ def init_from_config(cls, config):
+ if "slices" not in config:
+ return super().init_from_config(config)
+ else:
+ return cls.init_from_config_with_slices(config)
+
+ @classmethod
+ def init_from_config_with_slices(cls, config):
+ config, x, y, slices = cls.parse_config_with_slices(config)
+ datasets = []
+ for s in tqdm(slices, desc="Slices..."):
+ c = deepcopy(config)
+ start, stop, params = [s[k] for k in ("start", "stop", "params")]
+ c.update(x=x[start:stop], y=y[start:stop], params=params)
+ datasets.append(ParametricDataset(**c))
+ return ConcatDataset(datasets)
+
+ @classmethod
+ def parse_config(cls, config):
+ assert "slices" not in config
+ params = config["params"]
+ return {
+ "params": params,
+ "id": config.get("id"),
+ "common_params": config.get("common_params"),
+ "param_map": config.get("param_map"),
+ **super().parse_config(config),
+ }
+
+ @classmethod
+ def parse_config_with_slices(cls, config):
+ slices = config["slices"]
+ config = super().parse_config(config)
+ x, y = [config.pop(k) for k in "xy"]
+ return config, x, y, slices
+
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # FIXME don't override signature
+ x, y = super().__getitem__(idx)
+ return self.vals, x, y
+
+ @property
+ def keys(self) -> Tuple[str]:
+ return self._keys
+
+ @property
+ def vals(self):
+ return self._vals
+
+
class ConcatDataset(AbstractDataset, InitializableFromConfig):
- def __init__(self, datasets: Sequence[Dataset]):
+ def __init__(self, datasets: Sequence[Dataset], flatten=True):
+ if flatten:
+ datasets = self._flatten_datasets(datasets)
self._validate_datasets(datasets)
self._datasets = datasets
@@ -229,36 +301,72 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
def __len__(self) -> int:
return sum(len(d) for d in self._datasets)
+ @property
+ def datasets(self):
+ return self._datasets
+
@classmethod
def parse_config(cls, config):
+ init = (
+ ParametricDataset.init_from_config
+ if config["parametric"]
+ else Dataset.init_from_config
+ )
return {
"datasets": tuple(
- Dataset.init_from_config(c)
- for c in tqdm(config["dataset_configs"], desc="Loading data")
+ init(c) for c in tqdm(config["dataset_configs"], desc="Loading data")
)
}
+ def _flatten_datasets(self, datasets):
+ """
+        If any dataset is a ConcatDataset, pull out its constituent datasets
+ """
+ flattened = []
+ for d in datasets:
+ if isinstance(d, ConcatDataset):
+ flattened.extend(d.datasets)
+ else:
+ flattened.append(d)
+ return flattened
+
@classmethod
def _validate_datasets(cls, datasets: Sequence[Dataset]):
Reference = namedtuple("Reference", ("index", "val"))
- ref_ny = None
+ ref_keys, ref_ny = None, None
for i, d in enumerate(datasets):
ref_ny = Reference(i, d.ny) if ref_ny is None else ref_ny
if d.ny != ref_ny.val:
raise ValueError(
- f"Mismatch between ny of datasets {ref_ny.index} ({ref_ny.val}) and"
- f" {i} ({d.ny})"
+ f"Mismatch between ny of datasets {ref_ny.index} ({ref_ny.val}) and {i} ({d.ny})"
)
+ if isinstance(d, ParametricDataset):
+ val = d.keys
+ if ref_keys is None:
+ ref_keys = Reference(i, val)
+ if val != ref_keys.val:
+ raise ValueError(
+ f"Mismatch between keys of datasets {ref_keys.index} "
+ f"({ref_keys.val}) and {i} ({val})"
+ )
def init_dataset(config, split: Split) -> AbstractDataset:
+ parametric = config.get("parametric", False)
base_config = config[split.value]
common = config.get("common", {})
if isinstance(base_config, dict):
- return Dataset.init_from_config({**common, **base_config})
+ init = (
+ ParametricDataset.init_from_config
+ if parametric
+ else Dataset.init_from_config
+ )
+ return init({**common, **base_config})
elif isinstance(base_config, list):
return ConcatDataset.init_from_config(
- {"dataset_configs": [{**common, **c} for c in base_config]}
+ {
+ "parametric": parametric,
+ "dataset_configs": [{**common, **c} for c in base_config],
+ }
)
- else:
- raise TypeError(f"Unrecognized config type {type(base_config)}")
+
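A hedged sketch of a parametric data config consumed by the new `init_dataset()` / `ParametricDataset.init_from_config_with_slices()` paths above. The `parametric`, `common`, `slices`, and `start`/`stop`/`params` keys come from this diff; the wav-path keys and the split names (`train`/`validation`) are placeholders for whatever `Dataset.parse_config()` and `Split` actually use.

```
data_config = {
    "parametric": True,
    "common": {"x_path": "input.wav", "y_path": "output.wav", "ny": 32768},
    "train": {
        "slices": [
            # Each slice becomes one ParametricDataset; all get wrapped in a ConcatDataset.
            {"start": 0, "stop": 480_000, "params": {"drive": 0.0}},
            {"start": 480_000, "stop": 960_000, "params": {"drive": 1.0}},
        ]
    },
    "validation": {
        "slices": [{"start": 960_000, "stop": 1_008_000, "params": {"drive": 0.5}}]
    },
}
```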
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -3,7 +3,7 @@
# Author: Steven Atkinson ([email protected])
import abc
-from typing import Tuple
+from typing import Optional, Tuple
import numpy as np
import torch
@@ -13,7 +13,7 @@ from .._core import InitializableFromConfig
from ._exportable import Exportable
-class BaseNet(nn.Module, InitializableFromConfig, Exportable):
+class _Base(nn.Module, InitializableFromConfig, Exportable):
@abc.abstractproperty
def pad_start_default(self) -> bool:
pass
@@ -25,7 +25,33 @@ class BaseNet(nn.Module, InitializableFromConfig, Exportable):
"""
pass
- def forward(self, x: torch.Tensor, pad_start: bool = None):
+ @abc.abstractmethod
+ def forward(self, *args, **kwargs) -> torch.Tensor:
+ pass
+
+ @abc.abstractmethod
+ def _forward(self, *args) -> torch.Tensor:
+ """
+ The true forward method.
+
+ :param x: (N,L1)
+ :return: (N,L1-RF+1)
+ """
+ pass
+
+ def _export_input_output(
+ self, seed=0, extra_length=13
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = torch.Tensor(
+ np.random.default_rng(seed).normal(
+ size=(self.receptive_field + extra_length,)
+ )
+ )
+ return x, self(x, pad_start=False)
+
+
+class BaseNet(_Base):
+ def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None):
pad_start = self.pad_start_default if pad_start is None else pad_start
scalar = x.ndim == 1
if scalar:
@@ -47,12 +73,34 @@ class BaseNet(nn.Module, InitializableFromConfig, Exportable):
"""
pass
- def _test_signal(
- self, seed=0, extra_length=13
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- x = torch.Tensor(
- np.random.default_rng(seed).normal(
- size=(self.receptive_field + extra_length,)
- )
- )
- return x, self(x, pad_start=False)
+
+class ParametricBaseNet(_Base):
+ """
+ Parametric inputs
+ """
+
+ def forward(
+ self, params: torch.Tensor, x: torch.Tensor, pad_start: Optional[bool] = None
+ ):
+ pad_start = self.pad_start_default if pad_start is None else pad_start
+ scalar = x.ndim == 1
+ if scalar:
+ x = x[None]
+ params = params[None]
+ if pad_start:
+ x = torch.cat((torch.zeros((len(x), self.receptive_field - 1)), x), dim=1)
+ y = self._forward(params, x)
+ if scalar:
+ y = y[0]
+ return y
+
+ @abc.abstractmethod
+ def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ """
+ The true forward method.
+
+ :param params: (N,D)
+ :param x: (N,L1)
+ :return: (N,L1-RF+1)
+ """
+ pass
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -3,7 +3,14 @@
# Author: Steven Atkinson ([email protected])
import abc
+import json
from pathlib import Path
+from typing import Tuple
+
+import numpy as np
+
+from .._version import __version__
+from ..data import np_to_wav
class Exportable(abc.ABC):
@@ -11,7 +18,6 @@ class Exportable(abc.ABC):
    Interface for my custom export format for use in the plugin.
"""
- @abc.abstractmethod
def export(self, outdir: Path):
"""
Interface for exporting.
@@ -22,7 +28,27 @@ class Exportable(abc.ABC):
:param outdir: Assumed to exist. Can be edited inside at will.
"""
- pass
+ training = self.training
+ self.eval()
+ with open(Path(outdir, "config.json"), "w") as fp:
+ json.dump(
+ {
+ "version": __version__,
+ "architecture": self.__class__.__name__,
+ "config": self._export_config(),
+ },
+ fp,
+ indent=4,
+ )
+ np.save(Path(outdir, "weights.npy"), self._export_weights())
+ x, y = self._export_input_output()
+ np.save(Path(outdir, "inputs.npy"), x)
+ np.save(Path(outdir, "outputs.npy"), y)
+ np_to_wav(x, Path(outdir, "input.wav"))
+ np_to_wav(y, Path(outdir, "output.wav"))
+
+ # And resume training state
+ self.train(training)
@abc.abstractmethod
def export_cpp_header(self, filename: Path):
@@ -31,3 +57,27 @@ class Exportable(abc.ABC):
as text
"""
pass
+
+ @abc.abstractmethod
+ def _export_config(self):
+ """
+        Creates the JSON of the model's architecture hyperparameters (number of layers,
+        number of units, etc.)
+
+ :return: a JSON serializable object
+ """
+ pass
+
+ @abc.abstractmethod
+ def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Create an input and corresponding output signal to verify its behavior.
+ """
+ pass
+
+ @abc.abstractmethod
+ def _export_weights(self) -> np.ndarray:
+ """
+ Flatten the weights out to a 1D array
+ """
+ pass
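Usage sketch for the new shared `Exportable.export()` above; here `model` stands for any subclass that implements the three new `_export_*` abstract methods (e.g. a trained ConvNet).

```
from pathlib import Path

outdir = Path("exported_model")
outdir.mkdir(exist_ok=True)  # export() assumes the directory already exists
model.export(outdir)
# Writes: config.json (version / architecture / config), weights.npy,
#         inputs.npy, outputs.npy, input.wav, output.wav
```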
diff --git a/nam/models/_names.py b/nam/models/_names.py
@@ -0,0 +1,13 @@
+# File: _names.py
+# Created Date: Wednesday June 22nd 2022
+# Author: Steven Atkinson ([email protected])
+
+
+"""
+Layer names
+"""
+
+ACTIVATION_NAME = "activation"
+BATCHNORM_NAME = "batchnorm"
+CONV_NAME = "conv"
+LAYERNORM_NAME = "layernorm"
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -6,6 +6,8 @@
Lightning stuff
"""
+from dataclasses import dataclass
+from enum import Enum
from typing import Optional
import pytorch_lightning as pl
@@ -13,8 +15,55 @@ import torch
import torch.nn as nn
from .._core import InitializableFromConfig
-from .linear import Linear
from .conv_net import ConvNet
+from .linear import Linear
+from .parametric.catnets import CatLSTM
+from .parametric.hyper_net import HyperConvNet
+from .recurrent import LSTM
+
+
+class ValidationLoss(Enum):
+ """
+ mse: mean squared error
+    esr: error-to-signal ratio (Eq. (10) from
+        https://www.mdpi.com/2076-3417/10/3/766/htm)
+        NOTE: Be careful when computing ESR on minibatches! The average ESR over
+        a minibatch of data is not the same as the ESR of all of that data
+        computed at once (because each item has its own denominator).
+ (Hint: think about what happens if one item in the minibatch is all
+ zeroes...)
+ """
+
+ MSE = "mse"
+ ESR = "esr"
+
+
+@dataclass
+class LossConfig(InitializableFromConfig):
+ """
+    :param mask_first: How many of the first samples to ignore when computing the loss.
+    :param dc_weight: Weight for the DC loss term. If 0, ignored.
+    :param val_loss: Which loss to track for the best model checkpoint.
+ """
+
+ mask_first: int = 0
+ dc_weight: float = 0.0
+ val_loss: ValidationLoss = ValidationLoss.MSE
+
+ @classmethod
+ def parse_config(cls, config):
+ config = super().parse_config(config)
+ dc_weight = config.get("dc_weight", 0.0)
+ val_loss = ValidationLoss(config.get("val_loss", "mse"))
+ mask_first = config.get("mask_first", 0)
+ return {"mask_first": mask_first, "dc_weight": dc_weight, "val_loss": val_loss}
+
+ def apply_mask(self, *args):
+ """
+ :param args: (L,) or (B,)
+ :return: (L-M,) or (B, L-M)
+ """
+ return tuple(a[..., self.mask_first :] for a in args)
class Model(pl.LightningModule, InitializableFromConfig):
@@ -23,6 +72,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
net,
optimizer_config: Optional[dict] = None,
scheduler_config: Optional[dict] = None,
+ loss_config: Optional[LossConfig] = None,
):
"""
:param scheduler_config: contains
@@ -38,6 +88,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
self._net = net
self._optimizer_config = {} if optimizer_config is None else optimizer_config
self._scheduler_config = scheduler_config
+ self._loss_config = LossConfig() if loss_config is None else loss_config
@classmethod
def init_from_config(cls, config):
@@ -51,15 +102,48 @@ class Model(pl.LightningModule, InitializableFromConfig):
@classmethod
def parse_config(cls, config):
+ """
+ e.g.
+
+ {
+ "net": {
+ "name": "ConvNet",
+ "config": {...}
+ },
+ "loss": {
+ "dc_weight": 0.1
+ },
+ "optimizer": {
+ "lr": 0.0003
+ },
+ "lr_scheduler": {
+ "class": "ReduceLROnPlateau",
+ "kwargs": {
+ "factor": 0.8,
+ "patience": 10,
+ "cooldown": 15,
+ "min_lr": 1e-06,
+ "verbose": true
+ },
+ "monitor": "val_loss"
+ }
+ }
+ """
config = super().parse_config(config)
net_config = config["net"]
- net = {"Linear": Linear.init_from_config, "ConvNet": ConvNet.init_from_config}[
- net_config["name"]
- ](net_config["config"])
+ net = {
+ "CatLSTM": CatLSTM.init_from_config,
+ "ConvNet": ConvNet.init_from_config,
+ "HyperConvNet": HyperConvNet.init_from_config,
+ "Linear": Linear.init_from_config,
+ "LSTM": LSTM.init_from_config,
+ }[net_config["name"]](net_config["config"])
+ loss_config = LossConfig.init_from_config(config.get("loss", {}))
return {
"net": net,
"optimizer_config": config["optimizer"],
"scheduler_config": config["lr_scheduler"],
+ "loss_config": loss_config,
}
@property
@@ -84,14 +168,48 @@ class Model(pl.LightningModule, InitializableFromConfig):
return self.net(*args, **kwargs)
def _shared_step(self, batch):
- sources, targets = batch
- preds = self(sources, pad_start=False)
- return nn.MSELoss()(preds, targets)
+ args, targets = batch[:-1], batch[-1]
+ preds = self(*args, pad_start=False)
+
+ return preds, targets
def training_step(self, batch, batch_idx):
- return self._shared_step(batch)
+ preds, targets = self._shared_step(batch)
+
+ loss = 0.0
+ # Prediction aka MSE aka "ESR" loss
+ loss = loss + self._mse_loss(preds, targets)
+
+ # DC loss
+ dc_weight = self._loss_config.dc_weight
+ if dc_weight > 0.0:
+            # The denominator could be a bad idea. I'm going to omit it, especially
+            # since I'm using minibatches.
+ mean_dims = torch.arange(1, preds.ndim).tolist()
+ dc_loss = nn.MSELoss()(
+ preds.mean(dim=mean_dims), targets.mean(dim=mean_dims)
+ )
+ loss = loss + dc_weight * dc_loss
+ return loss
def validation_step(self, batch, batch_idx):
- val_loss = self._shared_step(batch)
- self.log_dict({"val_loss": val_loss})
+ preds, targets = self._shared_step(batch)
+ mse_loss = self._mse_loss(preds, targets)
+ esr_loss = self._esr_loss(preds, targets)
+ val_loss = {ValidationLoss.MSE: mse_loss, ValidationLoss.ESR: esr_loss}[
+ self._loss_config.val_loss
+ ]
+ self.log_dict({"MSE": mse_loss, "ESR": esr_loss, "val_loss": val_loss})
return val_loss
+
+ def _esr_loss(self, preds, targets):
+ """
+        Error-to-signal ratio (ESR) loss.
+
+ Eq. (10), from
+ https://www.mdpi.com/2076-3417/10/3/766/htm
+ """
+ return nn.MSELoss()(preds, targets) / nn.MSELoss()(targets, 0.0 * targets)
+
+ def _mse_loss(self, preds, targets):
+ return nn.MSELoss()(preds, targets)
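A small numeric illustration (not from the repo) of the `ValidationLoss` caveat above: averaging per-item ESRs over a minibatch is not the same as computing one ESR over the whole minibatch, because each item brings its own denominator (and an all-zero target item would blow it up).

```
import torch
import torch.nn as nn

mse = nn.MSELoss()
preds = torch.tensor([[0.1, 0.1], [0.0, 0.0]])
targets = torch.tensor([[1.0, 1.0], [0.1, 0.1]])

# Per-item ESRs, then averaged over the minibatch
per_item = torch.stack(
    [mse(p, t) / mse(t, torch.zeros_like(t)) for p, t in zip(preds, targets)]
).mean()
# One ESR computed over everything at once
whole_batch = mse(preds, targets) / mse(targets, torch.zeros_like(targets))
print(per_item.item(), whole_batch.item())  # ~0.905 vs ~0.812 -- not equal
```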
diff --git a/nam/models/conv_net.py b/nam/models/conv_net.py
@@ -3,6 +3,7 @@
# Author: Steven Atkinson ([email protected])
import json
+import math
from enum import Enum
from functools import partial
from pathlib import Path
@@ -16,12 +17,9 @@ import torch.nn.functional as F
from .. import __version__
-from ..data import wav_to_tensor
+from ..data import REQUIRED_RATE, wav_to_tensor
from ._base import BaseNet
-
-_CONV_NAME = "conv"
-_BATCHNORM_NAME = "batchnorm"
-_ACTIVATION_NAME = "activation"
+from ._names import ACTIVATION_NAME, BATCHNORM_NAME, CONV_NAME
class TrainStrategy(Enum):
@@ -71,11 +69,11 @@ def _conv_net(
def block(cin, cout, dilation):
net = nn.Sequential()
net.add_module(
- _CONV_NAME, nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
+ CONV_NAME, nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
)
if batchnorm:
- net.add_module(_BATCHNORM_NAME, nn.BatchNorm1d(cout))
- net.add_module(_ACTIVATION_NAME, getattr(nn, activation)())
+ net.add_module(BATCHNORM_NAME, nn.BatchNorm1d(cout))
+ net.add_module(ACTIVATION_NAME, getattr(nn, activation)())
return net
def check_and_expand(n, x):
@@ -149,12 +147,12 @@ class ConvNet(BaseNet):
@property
def _activation(self):
return (
- self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__
+ self._net._modules["block_0"]._modules[ACTIVATION_NAME].__class__.__name__
)
@property
def _channels(self) -> int:
- return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0]
+ return self._net._modules["block_0"]._modules[CONV_NAME].weight.shape[0]
@property
def _num_layers(self) -> int:
@@ -162,16 +160,75 @@ class ConvNet(BaseNet):
@property
def _batchnorm(self) -> bool:
- return _BATCHNORM_NAME in self._net._modules["block_0"]._modules
+ return BATCHNORM_NAME in self._net._modules["block_0"]._modules
- def export(self, outdir: Path):
+ def export_cpp_header(self, filename: Path):
+ with TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ self.export(Path(tmpdir))
+ with open(Path(tmpdir, "config.json"), "r") as fp:
+ _c = json.load(fp)
+ version = _c["version"]
+ config = _c["config"]
+ with open(filename, "w") as f:
+ f.writelines(
+ (
+ "#pragma once\n",
+ "// Automatically-generated model file\n",
+ "#include <vector>\n",
+ f'#define PYTHON_MODEL_VERSION "{version}"\n',
+ f"const int CHANNELS = {config['channels']};\n",
+ f"const bool BATCHNORM = {'true' if config['batchnorm'] else 'false'};\n",
+ "std::vector<int> DILATIONS{"
+ + ",".join([str(d) for d in config["dilations"]])
+ + "};\n",
+ f"const std::string ACTIVATION = \"{config['activation']}\";\n",
+ "std::vector<float> PARAMS{"
+ + ",".join(
+ [f"{w:.16f}" for w in np.load(Path(tmpdir, "weights.npy"))]
+ )
+ + "};\n",
+ )
+ )
+
+ def _export_config(self):
+ return {
+ "channels": self._channels,
+ "dilations": self._get_dilations(),
+ "batchnorm": self._batchnorm,
+ "activation": self._activation,
+ }
+
+ def _export_input_output(self, x=None) -> Tuple[np.ndarray, np.ndarray]:
"""
- Files created:
- * config.json
- * weights.npy
- * input.npy
- * output.npy
+ :return: (L,), (L,)
+ """
+ with torch.no_grad():
+ training = self.training
+ self.eval()
+ x = self._export_input_signal() if x is None else x
+ y = self(x, pad_start=True)
+ self.train(training)
+ return tuple(z.detach().cpu().numpy() for z in (x, y))
+
+ def _export_input_signal(self):
+ """
+ :return: (L,)
+ """
+ rate = REQUIRED_RATE
+ return torch.cat(
+ [
+ torch.zeros((rate,)),
+ 0.5
+ * torch.sin(
+ 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
+ ),
+ torch.zeros((rate,)),
+ ]
+ )
+ def _export_weights(self) -> np.ndarray:
+ """
weights are serialized to weights.npy in the following order:
* (expand: no params)
* loop blocks 0,...,L-1
@@ -188,37 +245,17 @@ class ConvNet(BaseNet):
* weight (C, 1, 1)
* bias (1, 1)
* (flatten: no params)
-
- A test input & output are also provided, input.npy and output.npy
"""
- training = self.training
- self.eval()
- with open(Path(outdir, "config.json"), "w") as fp:
- json.dump(
- {
- "version": __version__,
- "architecture": "ConvNet",
- "config": {
- "channels": self._channels,
- "dilations": self._get_dilations(),
- "batchnorm": self._batchnorm,
- "activation": self._activation,
- },
- },
- fp,
- indent=4,
- )
-
params = []
for i in range(self._num_layers):
block_name = f"block_{i}"
block = self._net._modules[block_name]
- conv = block._modules[_CONV_NAME]
+ conv = block._modules[CONV_NAME]
params.append(conv.weight.flatten())
if conv.bias is not None:
params.append(conv.bias.flatten())
if self._batchnorm:
- bn = block._modules[_BATCHNORM_NAME]
+ bn = block._modules[BATCHNORM_NAME]
params.append(bn.running_mean.flatten())
params.append(bn.running_var.flatten())
params.append(bn.weight.flatten())
@@ -228,45 +265,7 @@ class ConvNet(BaseNet):
params.append(head.weight.flatten())
params.append(head.bias.flatten())
params = torch.cat(params).detach().cpu().numpy()
- # Hope I don't regret using np.save...
- np.save(Path(outdir, "weights.npy"), params)
-
- # And an input/output to verify correct computation:
- x, y = self._test_signal()
- np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
- np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
-
- # And resume training state
- self.train(training)
-
- def export_cpp_header(self, filename: Path):
- with TemporaryDirectory() as tmpdir:
- tmpdir = Path(tmpdir)
- self.export(Path(tmpdir))
- with open(Path(tmpdir, "config.json"), "r") as fp:
- _c = json.load(fp)
- version = _c["version"]
- config = _c["config"]
- with open(filename, "w") as f:
- f.writelines(
- (
- "#pragma once\n",
- "// Automatically-generated model file\n",
- "#include <vector>\n",
- f'#define PYTHON_MODEL_VERSION "{version}"\n',
- f"const int CHANNELS = {config['channels']};\n",
- f"const bool BATCHNORM = {'true' if config['batchnorm'] else 'false'};\n",
- "std::vector<int> DILATIONS{"
- + ",".join([str(d) for d in config["dilations"]])
- + "};\n",
- f"const std::string ACTIVATION = \"{config['activation']}\";\n",
- "std::vector<float> PARAMS{"
- + ",".join(
- [f"{w:.16f}" for w in np.load(Path(tmpdir, "weights.npy"))]
- )
- + "};\n",
- )
- )
+ return params
def _forward(self, x):
y = self._net(x)
@@ -276,7 +275,7 @@ class ConvNet(BaseNet):
def _get_dilations(self) -> Tuple[int]:
return tuple(
- self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0]
+ self._net._modules[f"block_{i}"]._modules[CONV_NAME].dilation[0]
for i in range(self._num_blocks)
)
diff --git a/nam/models/linear.py b/nam/models/linear.py
@@ -55,7 +55,7 @@ class Linear(BaseNet):
np.save(Path(outdir, "weights.npy"), params)
# And an input/output to verify correct computation:
- x, y = self._test_signal()
+ x, y = self._export_input_output()
np.save(Path(outdir, "input.npy"), x.detach().cpu().numpy())
np.save(Path(outdir, "output.npy"), y.detach().cpu().numpy())
diff --git a/nam/models/parametric/__init__.py b/nam/models/parametric/__init__.py
@@ -0,0 +1,4 @@
+# File: __init__.py
+# Created Date: Sunday July 17th 2022
+# Author: Steven Atkinson ([email protected])
+
diff --git a/nam/models/parametric/catnets.py b/nam/models/parametric/catnets.py
@@ -0,0 +1,169 @@
+# File: catnets.py
+# Created Date: Wednesday June 22nd 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+"Cat nets" -- parametric models where the parametric input is concatenated to the
+input samples
+"""
+
+import abc
+from enum import Enum
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, Tuple
+
+import torch
+
+from .._base import ParametricBaseNet
+from ..recurrent import LSTM
+from .params import Param
+
+
+class _ShapeType(Enum):
+ CONV = "conv" # (B,C,L)
+ RNN = "rnn" # (B,L,D)
+
+
+class _CatMixin(ParametricBaseNet):
+ """
+    Parametric nets that concatenate the params with the input at each time point.
+ Mix in with a non-parametric class like
+
+ ```
+ class CatLSTM(LSTM, _CatMixin):
+ pass
+ ```
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ # Hacky, see .export()
+ self._sidedoor_parametric_config = None
+
+ @abc.abstractproperty
+ def _shape_type(self) -> _ShapeType:
+ pass
+
+ @abc.abstractproperty
+ def _single_class(self):
+ """"
+ The class for the non-parametric model that this is extending
+ """
+ # TODO verify that single class satisfies requirements
+ # ._export_weights()
+ # ._export_input_output()
+ pass # HACK
+
+ def export(self, outdir: Path, parametric_config: Dict[str, Param]):
+ """
+ Interface for exporting.
+ You should create at least a `config.json` containing the two fields:
+ * "version" (str)
+ * "architecture" (str)
+ * "config": (dict w/ other necessary data like tensor shapes etc)
+
+ :param outdir: Assumed to exist. Can be edited inside at will.
+ """
+ with self._use_parametric_config(parametric_config):
+ return super().export(outdir)
+
+ def export_cpp_header(self, filename: Path, parametric_config: Dict[str, Param]):
+ with self._use_parametric_config(parametric_config):
+ return super().export_cpp_header(filename)
+
+ def _export_config(self):
+ """
+ Adds in the sidedoored parametric pieces
+
+        :parametric_config: the dict of parameter info (name, type, etc.)
+ """
+ config = super()._export_config()
+ if not isinstance(config, dict):
+ raise TypeError(
+ f"Parameteric models' base configs must be a dict; got {type(config)}"
+ )
+ parametric_key = "parametric"
+ if parametric_key in config:
+ raise ValueError(
+ f'Already found parametric key "{parametric_key}" in base config dict.'
+ )
+ # Yucky sidedoor
+ config[parametric_key] = {
+ k: v.to_json() for k, v in self._sidedoor_parametric_config.items()
+ }
+ return config
+
+ def _forward(self, params, x):
+ """
+ :param params: (N,D)
+ :param x: (N,L1)
+
+ :return: (N,L2)
+ """
+ sequence_length = x.shape[1]
+ x_augmented = (
+ torch.cat(
+ [
+ x[..., None],
+ torch.tile(params[:, None, :], (1, sequence_length, 1)),
+ ],
+ dim=2,
+ )
+ if self._shape_type == _ShapeType.RNN
+ else torch.cat(
+ [x[:, None, :], torch.tile(params[..., None], (1, 1, sequence_length))],
+ dim=1,
+ )
+ )
+ return self._single_class._forward(self, x_augmented)
+
+ @contextmanager
+ def _use_parametric_config(self, c):
+ """
+ Sneaks in the parametric config while exporting
+ """
+ try:
+ self._sidedoor_parametric_config = c
+ yield None
+ finally:
+ self._sidedoor_parametric_config = None
+
+
+class CatLSTM(_CatMixin, LSTM):
+ @property
+ def _shape_type(self) -> _ShapeType:
+ return _ShapeType.RNN
+
+ @property
+ def _single_class(self):
+ return LSTM
+
+ def _append_default_params(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Requires sidedoor'd params
+
+ :param x: (B,L)
+ :return: (B,L,1+D)
+ """
+ assert x.ndim == 2
+ param_names = sorted([k for k in self._sidedoor_parametric_config.keys()])
+ params = torch.Tensor(
+ [self._sidedoor_parametric_config[k].default_value for k in param_names]
+ )
+ sequence_length = x.shape[1]
+ return torch.cat(
+ [
+ x[:, :, None],
+ torch.tile(params[None, None, :], (1, sequence_length, 1)),
+ ],
+ dim=2,
+ )
+
+ def _get_initial_state(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ inputs = self._append_default_params(torch.zeros((1, 48_000)))
+ return super()._get_initial_state(inputs=inputs)
+
+ def _export_input_output(self):
+ x = self._append_default_params(self._export_input_signal()[None])
+ return super()._export_input_output(x=x)
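Usage sketch for CatLSTM export (mirrors tests/test_nam/test_models/test_parametric/test_catnets.py below): the export methods take the parametric config as an extra argument, and it ends up in config.json under the "parametric" key.

```
from pathlib import Path

from nam.models.parametric import catnets, params

knobs = {
    "gain": params.ContinuousParam(0.5, 0.0, 1.0),
    "tone": params.ContinuousParam(0.5, 0.0, 1.0),
    "level": params.ContinuousParam(0.5, 0.0, 1.0),
}
model = catnets.CatLSTM(
    num_layers=1,
    hidden_size=2,
    train_truncate=11,
    train_burn_in=7,
    input_size=1 + len(knobs),
)
outdir = Path("exported")
outdir.mkdir(exist_ok=True)  # export() assumes the directory exists
model.export(outdir, knobs)
```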
diff --git a/nam/models/parametric/hyper_net.py b/nam/models/parametric/hyper_net.py
@@ -0,0 +1,558 @@
+# File: hyper_net.py
+# Created Date: Sunday May 29th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+import json
+import math
+from dataclasses import dataclass
+from enum import Enum
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any, Callable, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import (
+ calculate_gain,
+ _calculate_correct_fan,
+ _calculate_fan_in_and_fan_out,
+)
+
+from ..._version import __version__
+from .._base import ParametricBaseNet
+
+
+class SpecialLayers(Enum):
+ CONV = "conv"
+ BATCHNORM = "batchnorm"
+
+
+@dataclass
+class LayerSpec:
+ """
+ Helps the hypernet
+ """
+
+ special_type: Optional[str]
+ shapes: Tuple[Tuple[int]]
+ norms: Tuple[float]
+ biases: Tuple[float]
+
+
+class _NetLayer(nn.Module, abc.ABC):
+ @abc.abstractproperty
+ def num_tensors(self) -> int:
+ pass
+
+ @abc.abstractmethod
+ def get_spec(self) -> LayerSpec:
+ pass
+
+
+class _Conv(nn.Conv1d, _NetLayer):
+ @property
+ def num_tensors(self):
+ return 2 if self.bias is not None else 1
+
+ def forward(self, params, inputs):
+ # Use depthwise convolution trick to process the convolutions together
+ cout, cin, kernel_size = self.weight.shape
+ n = len(params[0])
+        weight = params[0].reshape((n * cout, cin, kernel_size))  # (n*Cout, Cin, K)
+ bias = params[1].flatten() if self.bias is not None else None
+ groups = n
+ return F.conv1d(
+ inputs.reshape((1, n * cin, -1)),
+ weight,
+ bias=bias,
+ stride=self.stride,
+ padding=self.padding,
+ dilation=self.dilation,
+ groups=groups,
+ ).reshape((n, cout, -1))
+
+ def get_spec(self):
+ shapes = (
+ (self.weight.shape,)
+ if self.bias is None
+ else (self.weight.shape, self.bias.shape)
+ )
+ norms = (
+ (self._weight_norm(),)
+ if self.bias is None
+ else (self._weight_norm(), self._bias_norm())
+ )
+ biases = (0,) if self.bias is None else (0, 0)
+ return LayerSpec(SpecialLayers("conv"), shapes, norms, biases)
+
+ def _bias_norm(self):
+ # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv1d
+ fan = _calculate_fan_in_and_fan_out(self.weight.data)[0]
+ bound = 1.0 / math.sqrt(fan)
+ std = math.sqrt(1.0 / 12.0) * (2 * bound)
+ # LayerNorm handles division by number of dimensions...
+ return std
+
+ def _weight_norm(self):
+ """
+        Std of the uniform distribution used in initialization
+ """
+ # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv1d
+ fan = _calculate_correct_fan(self.weight.data, "fan_in")
+ # Kaiming uniform w/ a=sqrt(5)
+ gain = calculate_gain("leaky_relu", 5.0)
+ std = gain / math.sqrt(fan)
+ # LayerNorm handles division by number of dimensions...
+ return std
+
+
+class _BatchNorm(nn.BatchNorm1d, _NetLayer):
+ def __init__(self, num_features, *args, affine=True, **kwargs):
+ # Handle affine params outside of parent class
+ super().__init__(num_features, *args, affine=False, **kwargs)
+ self._num_features = num_features
+ assert affine
+ self._affine = affine
+
+ @property
+ def num_tensors(self) -> int:
+ return 2
+
+ def get_spec(self) -> LayerSpec:
+ return LayerSpec(
+ SpecialLayers.BATCHNORM,
+ ((self._num_features,), (self._num_features,)),
+ (1.0e-5, 1.0e-5),
+ (1.0, 0.0),
+ )
+
+ def forward(self, params, inputs):
+ """
+ Only change is we need to provide *params into F.batch_norm instead of
+ self.weight, self.bias
+ """
+ # Also use "inputs" instead of "input" to not collide w/ builtin (ew)
+ weight, bias = [z[:, :, None] for z in params]
+ pre_affine = super().forward(inputs)
+ return weight * pre_affine + bias
+
+
+class _Affine(nn.Module):
+ def __init__(self, scale: torch.Tensor, bias: torch.Tensor):
+ super().__init__()
+ self._weight = nn.Parameter(scale)
+ self._bias = nn.Parameter(bias)
+
+ @property
+ def bias(self) -> nn.Parameter:
+ return self._bias
+
+ @property
+ def weight(self) -> nn.Parameter:
+ return self._weight
+
+ def forward(self, inputs):
+ return self._weight * inputs + self._bias
+
+
+class HyperNet(nn.Module):
+ """
+ MLP followed by layer norms on split-up dims
+ """
+
+ def __init__(self, d_in, net, numels, norms, biases):
+ super().__init__()
+ self._net = net
+ # Figure out the scale factor empirically
+ norm0 = net(torch.randn((10_000, d_in))).std(dim=0).mean().item()
+ self._cum_numel = torch.cat(
+ [torch.LongTensor([0]), torch.cumsum(torch.LongTensor(numels), dim=0)]
+ )
+ affine_scale = torch.cat(
+ [torch.full((numel,), norm / norm0) for numel, norm in zip(numels, norms)]
+ )
+ affine_bias = torch.cat(
+ [
+ torch.full((numel,), bias, dtype=torch.float)
+ for numel, bias in zip(numels, biases)
+ ]
+ )
+ self._affine = _Affine(affine_scale, affine_bias)
+
+ @property
+ def batchnorm(self) -> bool:
+ """
+        Does the hypernet use batchnorm layers?
+ """
+ return any(isinstance(m, nn.BatchNorm1d) for m in self.modules())
+
+ @property
+ def input_dim(self) -> int:
+ return self._net[0][0].weight.shape[1]
+
+ @property
+ def num_layers(self) -> int:
+ return len([layer for layer in self._net if isinstance(layer, _HyperNetBlock)])
+
+ @property
+ def num_units(self) -> int:
+ return self._net[0][0].weight.shape[0]
+
+ def forward(self, x) -> Tuple[torch.Tensor]:
+ """
+ Just return a flat array of param tensors for now
+ """
+ y = self._affine(self._net(x))
+ return tuple(
+ y[:, i:j] for i, j in zip(self._cum_numel[:-1], self._cum_numel[1:])
+ )
+
+ def get_export_params(self) -> np.ndarray:
+ params = []
+ for block in self._net[:-1]:
+ linear = block[0]
+ params.append(linear.weight.flatten())
+ params.append(linear.bias.flatten())
+ if self.batchnorm:
+ bn = block[1]
+ params.append(bn.running_mean.flatten())
+ params.append(bn.running_var.flatten())
+ params.append(bn.weight.flatten())
+ params.append(bn.bias.flatten())
+ params.append(torch.Tensor([bn.eps]).to(bn.weight.device))
+ assert len(block) <= 3, "Linear-(BN)-activation"
+ assert (
+ len([p for p in block[-1].parameters()]) == 0
+ ), "No params in activation"
+ head = self._net[-1]
+ params.append(head.weight.flatten())
+ params.append(head.bias.flatten())
+ affine = self._affine
+ params.append(affine.weight.flatten())
+ params.append(affine.bias.flatten())
+ return torch.cat(params).detach().cpu().numpy()
+
+
+class _Activation(abc.ABC):
+ """
+ Indicate that a module is an activation within the main net
+ """
+
+ @abc.abstractproperty
+ def name(self) -> str:
+ """
+ What to call the layer by when making a config w/ it
+ """
+ pass
+
+
+def _extend_activation(C, name: str) -> _Activation:
+ class Activation(C, _NetLayer, _Activation):
+ @property
+ def name(self):
+ return name
+
+ @property
+ def num_tensors(self) -> int:
+ return 0
+
+ def get_spec(self) -> LayerSpec:
+ return LayerSpec(None, (), (), ())
+
+ def forward(self, params, inputs):
+ return super().forward(inputs)
+
+ return Activation
+
+
+_Tanh = _extend_activation(nn.Tanh, "Tanh")
+_ReLU = _extend_activation(nn.ReLU, "ReLU")
+_Flatten = _extend_activation(nn.Flatten, "Flatten") # Hah, works
+
+
+def _get_activation(name):
+ return {"Tanh": _Tanh, "ReLU": _ReLU}[name]()
+
+
+class _HyperNetBlock(nn.Sequential):
+ """
+ For IDing blocks of the hypernet
+ """
+
+ pass
+
+
+class HyperConvNet(ParametricBaseNet):
+ """
+    For parametric data
+
+ Conditioning is input to a hypernetwork that outputs the parameters of the conv net.
+ """
+
+ def __init__(
+ self, hyper_net: HyperNet, net: Callable[[Any, torch.Tensor], torch.Tensor]
+ ):
+ super().__init__()
+ self._hyper_net = hyper_net
+ self._net = net
+
+ @classmethod
+ def parse_config(cls, config):
+ config = super().parse_config(config)
+ net, specs = cls._get_net(config["net"])
+ hyper_net = cls._get_hyper_net(config["hyper_net"], specs)
+ return {"hyper_net": hyper_net, "net": net}
+
+ @property
+ def pad_start_default(self) -> bool:
+ return True
+
+ @property
+ def receptive_field(self) -> int:
+ # Last conv is the collapser--compensate w/ a minus 1
+ return sum([m.dilation[0] for m in self._net if isinstance(m, _Conv)]) + 1 - 1
+
+ def export(self, outdir: Path):
+ """
+ Files created:
+ * config.json
+ * weights.npy
+ * test_signal_params.npy
+ * test_signal_input.npy
+ * test_signal_output.npy
+
+ weights are serialized to weights.npy in the following order:
+ * Hypernet
+ * Loop layers:
+ * Linear
+ * weights (din*dout)
+ * biases (dout)
+ * BN
+ * running_mean (d)
+ * running_var (d)
+ * weight (d)
+ * bias (d)
+ * eps ()
+ * activation
+ * (Assume no params bc Tanh)
+ * Linear out
+ * weights (units*dy)
+ * bias (dy)
+ * affine
+ * weights (dy)
+ * bias (dy)
+ * Main net
+ (Layers are conv-BN-act)
+ (All params are outputted by the hypernet except the BatchNorm buffers!)
+ * Loop layers
+ * BN
+ * running_mean
+ * running_var
+ * eps ()
+ * (flatten: no params)
+
+ A test input & output are also provided, input.npy and output.npy
+ """
+ training = self.training
+ self.eval()
+ with open(Path(outdir, "config.json"), "w") as fp:
+ json.dump(self._export_config(), fp, indent=4)
+
+ # Hope I don't regret using np.save...
+ np.save(Path(outdir, "weights.npy"), self._export_weights())
+
+ # And an input/output to verify correct computation:
+ params, x, y = self._export_input_output()
+ np.save(Path(outdir, "test_signal_params.npy"), params.detach().cpu().numpy())
+ np.save(Path(outdir, "test_signal_input.npy"), x.detach().cpu().numpy())
+ np.save(Path(outdir, "test_signal_output.npy"), y.detach().cpu().numpy())
+
+ # And resume training state
+ self.train(training)
+
+ def export_cpp_header(self, filename: Path):
+ with TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ self.export(Path(tmpdir))
+ with open(Path(tmpdir, "config.json"), "r") as fp:
+ _c = json.load(fp)
+ version = _c["version"]
+ config = _c["config"]
+ params = np.load(Path(tmpdir, "weights.npy"))
+ with open(filename, "w") as f:
+ f.writelines(
+ (
+ "#pragma once\n",
+ "// Automatically-generated model file\n",
+ "// HyperConvNet model\n" "#include <vector>\n",
+ f'#define PYTHON_MODEL_VERSION "{version}"\n',
+ f'const int HYPER_NET_INPUT_DIM = {config["hyper_net"]["input_dim"]};\n',
+ f'const int HYPER_NET_NUM_LAYERS = {config["hyper_net"]["num_layers"]};\n',
+ f'const int HYPER_NET_NUM_UNITS = {config["hyper_net"]["num_units"]};\n',
+ f'const int HYPER_NET_BATCHNORM = {"true" if config["hyper_net"]["batchnorm"] else "false"};\n',
+ f"const int CHANNELS = {config['net']['channels']};\n",
+ f"const bool BATCHNORM = {'true' if config['net']['batchnorm'] else 'false'};\n",
+ "std::vector<int> DILATIONS{"
+ + ",".join([str(d) for d in config["net"]["dilations"]])
+ + "};\n",
+ f"const std::string ACTIVATION = \"{config['net']['activation']}\";\n",
+ "std::vector<float> PARAMS{"
+ + ",".join([f"{w:.16f}" for w in params])
+ + "};\n",
+ )
+ )
+
+ @classmethod
+ def _get_net(cls, config):
+ channels = config["channels"]
+ dilations = config["dilations"]
+ batchnorm = config["batchnorm"]
+ activation = config["activation"]
+
+ layers = []
+ layer_specs = []
+ cin = 1
+ for dilation in dilations:
+ layer = _Conv(cin, channels, 2, dilation=dilation, bias=not batchnorm)
+ layers.append(layer)
+ layer_specs.append(layer.get_spec())
+ if batchnorm:
+ # Slow momentum on main net bc it's wild
+ layer = _BatchNorm(channels, momentum=0.01)
+ layers.append(layer)
+ layer_specs.append(layer.get_spec())
+ layer = _get_activation(activation)
+ layers.append(layer)
+ layer_specs.append(layer.get_spec())
+ cin = channels
+ layer = _Conv(cin, 1, 1)
+ layers.append(layer)
+ layer_specs.append(layer.get_spec())
+ layer = _Flatten()
+ layers.append(layer)
+ layer_specs.append(layer.get_spec())
+
+ return nn.ModuleList(layers), layer_specs
+
+ @classmethod
+ def _get_hyper_net(cls, config, specs) -> HyperNet:
+ def block(dx, dy, batchnorm) -> _HyperNetBlock:
+ layer_list = [nn.Linear(dx, dy)]
+ if batchnorm:
+ layer_list.append(nn.BatchNorm1d(dy))
+ layer_list.append(nn.ReLU())
+ return _HyperNetBlock(*layer_list)
+
+ num_inputs = config["num_inputs"]
+ num_layers = config["num_layers"]
+ num_units = config["num_units"]
+ batchnorm = config["batchnorm"]
+ # Flatten specs
+ numels = [np.prod(np.array(shape)) for spec in specs for shape in spec.shapes]
+ norms = [norm for spec in specs for norm in spec.norms]
+ biases = [bias for spec in specs for bias in spec.biases]
+ num_outputs = sum(numels)
+
+ din, layer_list = num_inputs, []
+ for _ in range(num_layers):
+ layer_list.append(block(din, num_units, batchnorm))
+ din = num_units
+ layer_list.append(nn.Linear(din, num_outputs))
+ net = nn.Sequential(*layer_list)
+
+ return HyperNet(num_inputs, net, numels, norms, biases)
+
+ @property
+ def _activation(self) -> str:
+ """
+        What activation does the main net use?
+ """
+ for m in self._net.modules():
+ if isinstance(m, _Activation):
+ return m.name
+
+ @property
+ def _batchnorm(self) -> bool:
+ return any(isinstance(x, _BatchNorm) for x in self._net)
+
+ @property
+ def _channels(self) -> int:
+ return self._net[0].weight.shape[0]
+
+ @property
+ def _net_no_head(self):
+ return self._net[:-2]
+
+ def _forward(self, params: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ net_params = self._hyper_net(params)
+ i = 0
+ for m in self._net:
+ j = i + m.num_tensors
+ x = m(net_params[i:j], x)
+ i = j
+ assert j == len(net_params)
+ return x
+
+ def _get_dilations(self) -> List[int]:
+ dilations = []
+ for (
+ layer
+ ) in self._net_no_head: # Last two layers are a 1D conv head and flatten
+ if isinstance(layer, _Conv):
+ dilations.append(layer.dilation[0])
+ return dilations
+
+ def _export_config(self):
+ return {
+ "version": __version__,
+ "architecture": "HyperConvNet",
+ "config": {
+ "hyper_net": {
+ "input_dim": self._hyper_net.input_dim,
+ "num_layers": self._hyper_net.num_layers,
+ "num_units": self._hyper_net.num_units,
+ "batchnorm": self._hyper_net.batchnorm,
+ },
+ "net": {
+ "channels": self._channels,
+ "dilations": self._get_dilations(),
+ "batchnorm": self._batchnorm,
+ "activation": self._activation,
+ },
+ },
+ }
+
+ def _export_weights(self) -> np.ndarray:
+ """
+ Flatten the parameters of the network to be exported.
+        See docstring for .export() for ensured layout.
+ """
+ return np.concatenate(
+ [self._hyper_net.get_export_params(), self._export_net_weights()]
+ )
+
+ def _export_net_weights(self) -> np.ndarray:
+ """
+ Only the buffers--parameters are outputted by the hypernet!
+ """
+ params = []
+ for bn in self._net_no_head:
+ if isinstance(bn, _BatchNorm):
+ params.append(bn.running_mean.flatten())
+ params.append(bn.running_var.flatten())
+ params.append(torch.Tensor([bn.eps]).to(bn.running_mean.device))
+ return (
+ np.array([])
+ if len(params) == 0
+ else torch.cat(params).detach().cpu().numpy()
+ )
+
+ def _export_input_output(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ params = torch.randn((self._hyper_net.input_dim,))
+ x = torch.randn((2 * self.receptive_field,))
+ x = 0.5 * x / x.abs().max()
+ y = self(params, x)
+ return params, x, y
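Usage sketch for HyperConvNet (the config shape mirrors tests/test_nam/test_models/test_parametric/test_hyper_net.py below): the conditioning vector goes through the hypernet, which emits the main conv net's weights, and the forward pass takes (params, x).

```
import torch

from nam.models.parametric import hyper_net

model = hyper_net.HyperConvNet.init_from_config(
    {
        "net": {
            "channels": 3,
            "dilations": [1, 2, 4],
            "batchnorm": True,
            "activation": "Tanh",
        },
        "hyper_net": {"num_inputs": 3, "num_layers": 2, "num_units": 11, "batchnorm": True},
    }
)
params = torch.randn((1, 3))  # (N, num_inputs) conditioning
x = torch.randn((1, 48_000))  # (N, L) audio
model.eval()  # batchnorm in eval mode, as export() does internally
with torch.no_grad():
    y = model(params, x)  # (N, L); the start is zero-padded by default
```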
diff --git a/nam/models/parametric/params.py b/nam/models/parametric/params.py
@@ -0,0 +1,71 @@
+# File: params.py
+# Created Date: Sunday July 17th 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Handling parametric inputs
+"""
+
+import abc
+import inspect
+from dataclasses import dataclass, fields
+from enum import Enum
+from typing import Any
+
+from ..._core import InitializableFromConfig
+
+
+# class ParamType(Enum):
+# CONTINUOUS = "continuous"
+# BOOLEAN = "boolean"
+
+
+@dataclass
+class Param(InitializableFromConfig):
+ default_value: Any
+
+ @classmethod
+ def init_from_config(cls, config):
+ C, kwargs = cls.parse_config(config)
+ return C(**kwargs)
+
+ @classmethod
+ def parse_config(cls, config):
+ for C in [
+ _C
+ for _C in globals().values()
+ if inspect.isclass(_C) and _C is not Param and issubclass(_C, Param)
+ ]:
+ if C.typestr() == config["type"]:
+ config.pop("type")
+ break
+ else:
+ raise ValueError(f"Unrecognized aprameter type {config['type']}")
+ return C, config
+
+ @abc.abstractclassmethod
+ def typestr(cls) -> str:
+ pass
+
+ def to_json(self):
+ return {
+ "type": self.typestr(),
+ **{f.name: getattr(self, f.name) for f in fields(self)},
+ }
+
+
+@dataclass
+class BooleanParam(Param):
+ @classmethod
+ def typestr(cls) -> str:
+ return "boolean"
+
+
+@dataclass
+class ContinuousParam(Param):
+ minval: float
+ maxval: float
+
+ @classmethod
+    def typestr(cls) -> str:
+ return "continuous"
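A sketch of the dispatch in `Param.parse_config()` above: the "type" field selects the Param subclass via `typestr()`, and the remaining fields become that dataclass's keyword arguments.

```
from nam.models.parametric.params import Param

knob = Param.init_from_config(
    {"type": "continuous", "default_value": 0.5, "minval": 0.0, "maxval": 1.0}
)
switch = Param.init_from_config({"type": "boolean", "default_value": False})
print(knob.to_json())
# -> {'type': 'continuous', 'default_value': 0.5, 'minval': 0.0, 'maxval': 1.0}
```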
diff --git a/nam/models/recurrent.py b/nam/models/recurrent.py
@@ -0,0 +1,183 @@
+# File: recurrent.py
+# Created Date: Saturday July 2nd 2022
+# Author: Steven Atkinson ([email protected])
+
+"""
+Recurrent models (LSTM)
+
+TODO batch_first=False (I get it...)
+"""
+
+import math
+from pathlib import Path
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..data import REQUIRED_RATE
+from ._base import BaseNet
+
+
+class LSTM(BaseNet):
+ """
+    LSTM model (also serves as the base for parametric extensions like CatLSTM)
+ """
+
+ def __init__(
+ self,
+ hidden_size,
+ train_burn_in: Optional[int] = None,
+ train_truncate: Optional[int] = None,
+ input_size: int = 1,
+ **lstm_kwargs,
+ ):
+ """
+ :param hidden_size: for LSTM
+ :param train_burn_in: Detach calculations from first (this many) samples when
+ training to burn in the hidden state.
+        :param train_truncate: Detach the hidden & cell states every this many steps
+        during training so that backpropagation through time is shorter and to
+        simulate better starting states for h(t0) & c(t0) (instead of zeros).
+ TODO recognition head to start the hidden state in a good place?
+ :param input_size: Usually 1 (mono input). A catnet extending this might change
+ it and provide the parametric inputs as additional input dimensions.
+ """
+ super().__init__()
+ if "batch_first" in lstm_kwargs:
+ raise ValueError("batch_first cannot be set.")
+ self._input_size = input_size
+ self._core = nn.LSTM(
+ self._input_size, hidden_size, batch_first=True, **lstm_kwargs
+ )
+ self._head = nn.Linear(hidden_size, 1)
+ self._train_burn_in = train_burn_in
+ self._train_truncate = train_truncate
+
+ @property
+ def receptive_field(self) -> int:
+ return 1
+
+ @property
+ def pad_start_default(self) -> bool:
+ # I should simplify this...
+ return True
+
+ def export_cpp_header(self, filename: Path):
+ raise NotImplementedError()
+
+ def _forward(self, x):
+ """
+ :param x: (B,L) or (B,L,D)
+ :return: (B,L)
+ """
+        if x.ndim == 2:
+ x = x[:, :, None]
+ if not self.training or self._train_truncate is None:
+ output_features = self._core(x)[0]
+ else:
+ last_hidden_state = None
+ output_features_list = []
+ if self._train_burn_in is not None:
+ last_output_features, last_hidden_state = self._core(
+ x[:, : self._train_burn_in, :]
+ )
+ output_features_list.append(last_output_features.detach())
+ burn_in_offset = 0 if self._train_burn_in is None else self._train_burn_in
+ for i in range(burn_in_offset, x.shape[1], self._train_truncate):
+ last_output_features, last_hidden_state = self._core(
+                    x[:, i : i + self._train_truncate, :],
+ None
+ if last_hidden_state is None
+ else tuple(z.detach() for z in last_hidden_state),
+ )
+ output_features_list.append(last_output_features)
+ output_features = torch.cat(output_features_list, dim=1)
+ return self._head(output_features)[:, :, 0]
+
+ def _export_cell_weights(
+ self, i: int, hidden_state: torch.Tensor, cell_state: torch.Tensor
+ ) -> np.ndarray:
+ """
+ * weight matrix (xh -> ifco)
+ * bias vector
+ * Initial hidden state
+ * Initial cell state
+ """
+
+ tensors = [
+ torch.cat(
+ [
+ getattr(self._core, f"weight_ih_l{i}").data,
+ getattr(self._core, f"weight_hh_l{i}").data,
+ ],
+ dim=1,
+ ),
+ getattr(self._core, f"bias_ih_l{i}").data
+ + getattr(self._core, f"bias_hh_l{i}").data,
+ hidden_state,
+ cell_state,
+ ]
+ return np.concatenate([z.detach().cpu().numpy().flatten() for z in tensors])
+
+ def _export_config(self):
+ return {
+ "input_size": self._core.input_size,
+ "hidden_size": self._core.hidden_size,
+ "num_layers": self._core.num_layers,
+ }
+
+ def _export_input_output(self, x=None):
+ x = self._export_input_signal()[None, :, None] if x is None else x
+ y = self._head(self._core(x, self._get_initial_state())[0][0])[:, 0]
+ return x[0, :, 0].detach().cpu().numpy(), y.detach().cpu().numpy().flatten()
+
+ def _export_input_signal(self):
+ rate = REQUIRED_RATE
+ return torch.cat(
+ [
+ torch.zeros((rate,)),
+ 0.5
+ * torch.sin(
+ 2.0 * math.pi * 220.0 * torch.linspace(0.0, 1.0, rate + 1)[:-1]
+ ),
+ torch.zeros((rate,)),
+ ]
+ )
+
+ def _export_weights(self):
+ """
+ * Loop over cells:
+ * weight matrix (xh -> ifco)
+ * bias vector
+ * Initial hidden state
+ * Initial cell state
+ * Head weights
+ * Head bias
+ """
+ return np.concatenate(
+ [
+ self._export_cell_weights(i, h, c)
+ for i, (h, c) in enumerate(zip(*self._get_initial_state()))
+ ]
+ + [
+ self._head.weight.data.detach().cpu().numpy().flatten(),
+ self._head.bias.data.detach().cpu().numpy().flatten(),
+ ]
+ )
+
+ def _get_initial_state(self, inputs=None) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Convenience function to find a good hidden state to start the plugin at
+
+ DX=input size
+ L=num layers
+ S=sequence length
+ :param inputs: (1,S,DX)
+
+ :return: (L,DH), (L,DH)
+ """
+ inputs = torch.zeros((1, 48_000, 1)) if inputs is None else inputs
+ _, (h, c) = self._core(inputs)
+ return h, c
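A worked illustration (not from the repo) of the training-time chunking in `LSTM._forward()` above, using the values from bin/train/inputs/models/lstm.json (train_burn_in=4096, train_truncate=1024) and a hypothetical sequence length:

```
burn_in, truncate, seq_len = 4096, 1024, 8192
chunks = [(i, min(i + truncate, seq_len)) for i in range(burn_in, seq_len, truncate)]
# Burn-in segment [0, 4096) runs with its output features detached; then
# chunks == [(4096, 5120), (5120, 6144), (6144, 7168), (7168, 8192)], and the
# hidden/cell states are detached between chunks, keeping backprop-through-time short.
```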
diff --git a/tests/test_nam/test_data.py b/tests/test_nam/test_data.py
@@ -0,0 +1,24 @@
+# File: test_data.py
+# Created Date: Friday May 6th 2022
+# Author: Steven Atkinson ([email protected])
+
+import pytest
+import torch
+
+from nam import data
+
+
+class TestDataset(object):
+ def test_init(self):
+ x, y = self._create_xy()
+ data.Dataset(x, y, 3, None)
+
+ def test_init_zero_delay(self):
+ """
+ Assert https://github.com/sdatkinson/neural-amp-modeler/issues/15 fixed
+ """
+ x, y = self._create_xy()
+ data.Dataset(x, y, 3, None, delay=0)
+
+ def _create_xy(self):
+ return 0.99 * (2.0 * torch.rand((2, 7)) - 1.0) # Don't clip
diff --git a/tests/test_nam/test_models/__init__.py b/tests/test_nam/test_models/__init__.py
diff --git a/tests/test_nam/test_models/base.py b/tests/test_nam/test_models/base.py
@@ -0,0 +1,31 @@
+# File: base.py
+# Created Date: Saturday June 4th 2022
+# Author: Steven Atkinson ([email protected])
+
+import abc
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+
+class Base(abc.ABC):
+ @classmethod
+ def setup_class(cls, C, args=None, kwargs=None):
+ cls._C = C
+ cls._args = () if args is None else args
+ cls._kwargs = {} if kwargs is None else kwargs
+
+ def test_init(self, args=None, kwargs=None):
+ obj = self._construct(args=args, kwargs=kwargs)
+ assert isinstance(obj, self._C)
+
+ def test_export(self, args=None, kwargs=None):
+ model = self._construct(args=args, kwargs=kwargs)
+ with TemporaryDirectory() as tmpdir:
+ model.export(Path(tmpdir))
+
+ def _construct(self, C=None, args=None, kwargs=None):
+ C = self._C if C is None else C
+ args = args if args is not None else self._args
+ kwargs = kwargs if kwargs is not None else self._kwargs
+ return C(*args, **kwargs)
+
diff --git a/tests/test_nam/test_models/test_conv_net.py b/tests/test_nam/test_models/test_conv_net.py
@@ -0,0 +1,40 @@
+# File: test_conv_net.py
+# Created Date: Friday May 6th 2022
+# Author: Steven Atkinson ([email protected])
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from nam.models import conv_net
+
+from .base import Base
+
+
+class TestConvNet(Base):
+ @classmethod
+ def setup_class(cls):
+ channels = 3
+ dilations = [1, 2, 4]
+ return super().setup_class(
+ conv_net.ConvNet,
+ (channels, dilations),
+ {"batchnorm": False, "activation": "Tanh"},
+ )
+
+ @pytest.mark.parametrize(
+ ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
+ )
+ def test_init(self, batchnorm, activation):
+ super().test_init(kwargs={"batchnorm": batchnorm, "activation": activation})
+
+ @pytest.mark.parametrize(
+ ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
+ )
+ def test_export(self, batchnorm, activation):
+ super().test_export(kwargs={"batchnorm": batchnorm, "activation": activation})
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/test_nam/test_models/test_parametric/__init__.py b/tests/test_nam/test_models/test_parametric/__init__.py
@@ -0,0 +1,4 @@
+# File: __init__.py
+# Created Date: Sunday July 17th 2022
+# Author: Steven Atkinson ([email protected])
+
diff --git a/tests/test_nam/test_models/test_parametric/test_catnets.py b/tests/test_nam/test_models/test_parametric/test_catnets.py
@@ -0,0 +1,50 @@
+# File: test_catnets.py
+# Created Date: Sunday July 17th 2022
+# Author: Steven Atkinson ([email protected])
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from ..base import Base
+
+from nam.models.parametric import catnets, params
+
+
+_mock_params = {
+ "gain": params.ContinuousParam(0.5, 0.0, 1.0),
+ "tone": params.ContinuousParam(0.5, 0.0, 1.0),
+ "level": params.ContinuousParam(0.5, 0.0, 1.0),
+}
+
+
+class _ParametricBase(Base):
+ pass
+
+
+class TestCatLSTM(_ParametricBase):
+ @classmethod
+ def setup_class(cls):
+ # Using init_from_config
+ return super().setup_class(
+ catnets.CatLSTM,
+ args=(),
+ kwargs={
+ "num_layers": 1,
+ "hidden_size": 2,
+ "train_truncate": 11,
+ "train_burn_in": 7,
+ "input_size": 1 + len(_mock_params),
+ },
+ )
+
+ def test_export(self, args=None, kwargs=None):
+ # Override to provide params info
+ model = self._construct(args=args, kwargs=kwargs)
+ with TemporaryDirectory() as tmpdir:
+ model.export(Path(tmpdir), _mock_params)
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/test_nam/test_models/test_parametric/test_hyper_net.py b/tests/test_nam/test_models/test_parametric/test_hyper_net.py
@@ -0,0 +1,86 @@
+# File: test_hyper_net.py
+# Created Date: Saturday June 4th 2022
+# Author: Steven Atkinson ([email protected])
+
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import pytest
+
+from nam.models.parametric import hyper_net
+
+from ..base import Base
+
+
+class TestHyperConvNet(Base):
+ @classmethod
+ def setup_class(cls):
+ return super().setup_class(hyper_net.HyperConvNet, (), {})
+
+ @pytest.mark.parametrize(
+ ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
+ )
+ def test_init(self, batchnorm, activation):
+ # TODO refactor
+ channels = 3
+ dilations = [1, 2, 4]
+ assert isinstance(
+ self._construct(
+ self._config(
+ batchnorm=batchnorm,
+ activation=activation,
+ dilations=dilations,
+ channels=channels,
+ )
+ ),
+ hyper_net.HyperConvNet,
+ )
+
+ @pytest.mark.parametrize(
+ ("batchnorm,activation"), ((False, "ReLU"), (True, "Tanh"))
+ )
+ def test_export(self, batchnorm, activation):
+ # TODO refactor
+ channels = 3
+ dilations = [1, 2, 4]
+ model = self._construct(
+ self._config(
+ batchnorm=batchnorm,
+ activation=activation,
+ dilations=dilations,
+ channels=channels,
+ )
+ )
+ with TemporaryDirectory() as tmpdir:
+ model.export(Path(tmpdir))
+
+ def test_export_cpp_header(self):
+ # TODO refactor
+ with TemporaryDirectory() as tmpdir:
+ self._construct().export_cpp_header(Path(tmpdir, "model.h"))
+
+ def _config(self, batchnorm=True, activation="Tanh", dilations=None, channels=7):
+ dilations = [1, 2, 4] if dilations is None else dilations
+ return {
+ "net": {
+ "channels": channels,
+ "dilations": dilations,
+ "batchnorm": batchnorm,
+ "activation": activation,
+ },
+ "hyper_net": {
+ "num_inputs": 3,
+ "num_layers": 2,
+ "num_units": 11,
+ "batchnorm": True,
+ },
+ }
+
+ def _construct(self, config=None):
+ # Override for simplicity...
+ config = self._config() if config is None else config
+ return self._C.init_from_config(config)
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -0,0 +1,19 @@
+# File: test_recurrent.py
+# Created Date: Sunday July 17th 2022
+# Author: Steven Atkinson ([email protected])
+
+from .base import Base
+
+from nam.models import recurrent
+
+
+class TestLSTM(Base):
+ @classmethod
+ def setup_class(cls):
+ hidden_size = 3
+ return super().setup_class(
+ recurrent.LSTM,
+ args=(hidden_size,),
+ kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": 2},
+ )
+