neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 5304a28f26daf5e5d0c387fbfc9037e8c9092f75
parent 1d354a00dc75719e49624cb5af296ee66d5e5155
Author: Steven Atkinson <[email protected]>
Date:   Sat,  1 Apr 2023 14:22:43 -0700

Metadata in model files, loudness as first metadatum (#168)

Define metadata, store loudness (#137)
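
In the exported .nam file (which is JSON), this change adds a "metadata" object alongside the existing "version", "architecture", "config", and "weights" keys; BaseNet models record their measured loudness there. A minimal sketch of reading it back, assuming a model has already been exported to a file named model.nam (the name is hypothetical; Exportable.export() writes modelname + ".nam"):

    import json

    # Read an exported model file and inspect the new metadata block.
    with open("model.nam") as fp:
        d = json.load(fp)

    print(d["version"])               # e.g. "0.5.1" after this commit
    print(d["architecture"])          # model class name
    print(d["metadata"]["loudness"])  # loudness in dB, the first metadatum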
Diffstat:
M nam/_version.py                                        |  2 +-
M nam/models/_activations.py                             |  3 +--
M nam/models/_base.py                                    | 35 ++++++++++++++++++++++++-----------
M nam/models/_exportable.py                              | 15 +++++++++------
M nam/models/parametric/__init__.py                      |  1 -
M nam/models/parametric/catnets.py                       | 24 ++++++++++++++++++------
M nam/train/_version.py                                  |  1 +
M tests/test_nam/test_models/base.py                     |  1 -
M tests/test_nam/test_models/test_base.py                |  1 +
M tests/test_nam/test_models/test_parametric/__init__.py |  1 -
M tests/test_nam/test_models/test_recurrent.py           |  6 ++++--
11 files changed, 59 insertions(+), 31 deletions(-)
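
The mechanism in the diff below is a small template-method hook: Exportable._get_export_dict() assembles the JSON payload, _Base overrides it to add an empty "metadata" object, and BaseNet overrides it again to record loudness. A self-contained sketch of that pattern (stand-in values, not the package's code; the real "config" and "weights" keys are omitted):

    # Sketch of the hook pattern this commit introduces; values are stand-ins.
    class Exportable:
        def _get_export_dict(self):
            return {"version": "0.5.1", "architecture": type(self).__name__}

    class _Base(Exportable):
        def _get_export_dict(self):
            d = super()._get_export_dict()
            d["metadata"] = {}  # _Base reserves the metadata block
            return d

    class BaseNet(_Base):
        def _loudness(self) -> float:
            return -18.0  # stand-in; the real method measures a reference WAV

        def _get_export_dict(self):
            d = super()._get_export_dict()
            d["metadata"]["loudness"] = self._loudness()  # first metadatum
            return d

    print(BaseNet()._get_export_dict())
    # {'version': '0.5.1', 'architecture': 'BaseNet', 'metadata': {'loudness': -18.0}}

Because each layer extends the dict returned by super(), any subclass can add metadata without touching the export logic in Exportable.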

diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.5.0"
+__version__ = "0.5.1"
diff --git a/nam/models/_activations.py b/nam/models/_activations.py
@@ -6,4 +6,4 @@ import torch.nn as nn
 
 
 def get_activation(name: str) -> nn.Module:
-    return getattr(nn, name)()
\ No newline at end of file
+    return getattr(nn, name)()
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -10,7 +10,6 @@ steps)
 import abc
 import math
 import pkg_resources
-from pathlib import Path
 from typing import Any, Optional, Tuple
 
 import numpy as np
@@ -38,17 +37,21 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
     def forward(self, *args, **kwargs) -> torch.Tensor:
         pass
 
-    def _loudness(self, gain: float=1.0) -> float:
+    def _loudness(self, gain: float = 1.0) -> float:
         """
         How loud is this model when given a standardized input?
         In dB
 
         :param gain: Multiplies input signal
         """
-        x = wav_to_tensor(pkg_resources.resource_filename("nam", "models/_resources/loudness_input.wav"))
+        x = wav_to_tensor(
+            pkg_resources.resource_filename(
+                "nam", "models/_resources/loudness_input.wav"
+            )
+        )
         y = self._at_nominal_settings(gain * x)
         return 10.0 * torch.log10(torch.mean(torch.square(y))).item()
-    
+
     def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
         # parametric?...
         raise NotImplementedError()
@@ -59,7 +62,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
         The true forward method.
 
         :param x: (N,L1)
-        :return: (N,L1-RF+1)
+        :return: (N,L1-RF+1)
         """
         pass
 
@@ -84,10 +87,15 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
         )
         # Use pad start to ensure same length as requested by ._export_input_output()
         return (
-            x.detach().cpu().numpy(),
-            self(*args, x, pad_start=True).detach().cpu().numpy()
+            x.detach().cpu().numpy(),
+            self(*args, x, pad_start=True).detach().cpu().numpy(),
        )
-    
+
+    def _get_export_dict(self):
+        d = super()._get_export_dict()
+        d["metadata"] = {}
+        return d
+
 
 class BaseNet(_Base):
     def forward(self, x: torch.Tensor, pad_start: Optional[bool] = None):
@@ -101,7 +109,7 @@ class BaseNet(_Base):
         if scalar:
             y = y[0]
         return y
-    
+
     def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
         return self(x)
 
@@ -111,10 +119,15 @@ class BaseNet(_Base):
         The true forward method.
 
         :param x: (N,L1)
-        :return: (N,L1-RF+1)
+        :return: (N,L1-RF+1)
         """
         pass
 
+    def _get_export_dict(self):
+        d = super()._get_export_dict()
+        d["metadata"]["loudness"] = self._loudness()
+        return d
+
 
 class ParametricBaseNet(_Base):
     """
@@ -143,6 +156,6 @@ class ParametricBaseNet(_Base):
 
         :param params: (N,D)
         :param x: (N,L1)
-        :return: (N,L1-RF+1)
+        :return: (N,L1-RF+1)
         """
         pass
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -39,12 +39,7 @@ class Exportable(abc.ABC):
         self.eval()
         with open(Path(outdir, modelname + ".nam"), "w") as fp:
             json.dump(
-                {
-                    "version": __version__,
-                    "architecture": self.__class__.__name__,
-                    "config": self._export_config(),
-                    "weights": self._export_weights().tolist(),
-                },
+                self._get_export_dict(),
                 fp,
                 indent=4,
             )
@@ -103,3 +98,11 @@ class Exportable(abc.ABC):
         Flatten the weights out to a 1D array
         """
         pass
+
+    def _get_export_dict(self):
+        return {
+            "version": __version__,
+            "architecture": self.__class__.__name__,
+            "config": self._export_config(),
+            "weights": self._export_weights().tolist(),
+        }
diff --git a/nam/models/parametric/__init__.py b/nam/models/parametric/__init__.py
@@ -1,4 +1,3 @@
 # File: __init__.py
 # Created Date: Sunday July 17th 2022
 # Author: Steven Atkinson ([email protected])
-
diff --git a/nam/models/parametric/catnets.py b/nam/models/parametric/catnets.py
@@ -8,6 +8,7 @@ input samples
 """
 
 import abc
+import logging
 from enum import Enum
 from contextlib import contextmanager
 from pathlib import Path
@@ -21,6 +22,8 @@ from ..recurrent import LSTM
 from ..wavenet import WaveNet
 from .params import Param
 
+logger = logging.getLogger(__name__)
+
 class _ShapeType(Enum):
     CONV = "conv"  # (B,C,L)
@@ -49,7 +52,7 @@ class _CatMixin(ParametricBaseNet):
 
     @abc.abstractproperty
     def _single_class(self):
-        """"
+        """ "
         The class for the non-parametric model that this is extending
         """
         # TODO verify that single class satisfies requirements
@@ -99,8 +102,11 @@ class _CatMixin(ParametricBaseNet):
     def _export_cpp_header_parametric(self, config):
         if config is None:
             return self._single_class._export_cpp_head_parametric(self, config)
-        s_parametric = ['nlohmann::json PARAMETRIC = nlohmann::json::parse(R"(\n', "  {\n"]
-        for i, (key, val) in enumerate(config.items(), 1):
+        s_parametric = [
+            'nlohmann::json PARAMETRIC = nlohmann::json::parse(R"(\n',
+            "  {\n",
+        ]
+        for i, (key, val) in enumerate(config.items(), 1):
             s_parametric.append(f'  "{key}": ' "{\n")
             for j, (k2, v2) in enumerate(val.items(), 1):
                 v_str = f'"{v2}"' if isinstance(v2, str) else str(v2)
@@ -112,7 +118,6 @@ class _CatMixin(ParametricBaseNet):
         s_parametric.append(')");\n')
         return tuple(s_parametric)
 
-
     def _export_input_output_args(self) -> Tuple[torch.Tensor]:
         return (self._sidedoor_params_to_tensor(),)
 
@@ -186,6 +191,15 @@ class CatLSTM(_CatMixin, LSTM):
             dim=2,
         )
 
+    def _at_nominal_settings(self, x: torch.Tensor) -> torch.Tensor:
+        if self._input_size != 1:
+            logger.warning(
+                "Nominal settings aren't defined for parametric models; outputting unity"
+            )
+            return x
+        params = torch.zeros(()).to(x.device)
+        return self(params, x)
+
     def _get_initial_state(self) -> Tuple[torch.Tensor, torch.Tensor]:
         inputs = self._append_default_params(torch.zeros((1, 48_000)))
         return super()._get_initial_state(inputs=inputs)
@@ -199,4 +213,3 @@ class CatWaveNet(_CatMixin, WaveNet):
     @property
     def _single_class(self):
         return WaveNet
-
\ No newline at end of file
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -6,6 +6,7 @@
 Version utility
 """
 
+
 class Version:
     def __init__(self, major: int, minor: int, patch: int):
         self.major = major
diff --git a/tests/test_nam/test_models/base.py b/tests/test_nam/test_models/base.py
@@ -28,4 +28,3 @@ class Base(abc.ABC):
         args = args if args is not None else self._args
         kwargs = kwargs if kwargs is not None else self._kwargs
         return C(*args, **kwargs)
-
diff --git a/tests/test_nam/test_models/test_base.py b/tests/test_nam/test_models/test_base.py
@@ -40,6 +40,7 @@ class _MockBaseNet(_base.BaseNet):
         return self.gain * x
 
 
+
 def test_loudness():
     obj = _MockBaseNet(1.0)
     y = obj._loudness()
diff --git a/tests/test_nam/test_models/test_parametric/__init__.py b/tests/test_nam/test_models/test_parametric/__init__.py
@@ -1,4 +1,3 @@
 # File: __init__.py
 # Created Date: Sunday July 17th 2022
 # Author: Steven Atkinson ([email protected])
-
diff --git a/tests/test_nam/test_models/test_recurrent.py b/tests/test_nam/test_models/test_recurrent.py
@@ -47,12 +47,14 @@ class TestLSTM(Base):
         ]
 
         input_names = [z.name for z in session.get_inputs()]
-        onnx_inputs = {i: z.detach().cpu().numpy() for i, z in zip(input_names, (x, hin, cin))}
+        onnx_inputs = {
+            i: z.detach().cpu().numpy() for i, z in zip(input_names, (x, hin, cin))
+        }
         y_actual, hout_actual, cout_actual = session.run([], onnx_inputs)
 
         def approx(val):
             return pytest.approx(val, rel=1.0e-6, abs=1.0e-6)
-        
+
         assert y_expected == approx(y_actual)
         assert hout_expected == approx(hout_actual)
         assert cout_expected == approx(cout_actual)
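
A note on the quantity being stored: _Base._loudness() measures the model's response to a standardized input WAV and reports mean-square power in dB, i.e. 10 * log10(mean(y^2)). A standalone NumPy sketch of just that arithmetic (a stand-in signal, not the package's API):

    import numpy as np

    # Mean-square power in dB, as computed in _Base._loudness();
    # y here is a stand-in for the model's output signal.
    y = np.sin(2 * np.pi * 440.0 * np.arange(48_000) / 48_000.0)
    loudness_db = 10.0 * np.log10(np.mean(np.square(y)))
    print(loudness_db)  # about -3.01 dB for a full-scale sine

For a full-scale sine the mean square is 0.5, giving about -3.01 dB; quieter model outputs yield more negative values.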