neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_exportable.py (8409B)


      1 # File: test_exportable.py
      2 # Created Date: Sunday January 29th 2023
      3 # Author: Steven Atkinson ([email protected])
      4 
      5 """
      6 Test export behavior of models
      7 """
      8 
      9 import json
     10 from enum import Enum
     11 from pathlib import Path
     12 from tempfile import TemporaryDirectory
     13 from typing import Optional, Tuple
     14 
     15 import numpy as np
     16 import pytest
     17 import torch
     18 import torch.nn as nn
     19 from pydantic import BaseModel
     20 
     21 from nam.models import exportable, metadata
     22 from nam.train import metadata as train_metadata
     23 
     24 
     25 class TestExportable(object):
     26     def test_export(self):
     27         """
     28         Does it work?
     29         """
     30 
     31         model = self._get_model()
     32         with TemporaryDirectory() as tmpdir:
     33             model.export(tmpdir)
     34             model_basename = "model.nam"
     35             model_path = Path(tmpdir, model_basename)
     36             assert model_path.exists()
     37             with open(model_path, "r") as fp:
     38                 model_dict = json.load(fp)
     39             required_keys = {"version", "architecture", "config", "weights"}
     40             for key in required_keys:
     41                 assert key in model_dict
     42             weights_list = model_dict["weights"]
     43             assert isinstance(weights_list, list)
     44             assert len(weights_list) == 2
     45             assert all(isinstance(w, float) for w in weights_list)
     46 
     47     @pytest.mark.parametrize(
     48         "user_metadata,other_metadata",
     49         (
     50             (None, None),
     51             (metadata.UserMetadata(), None),
     52             (
     53                 metadata.UserMetadata(
     54                     name="My Model",
     55                     modeled_by="Steve",
     56                     gear_type=metadata.GearType.AMP,
     57                     gear_make="SteveCo",
     58                     gear_model="SteveAmp",
     59                     tone_type=metadata.ToneType.HI_GAIN,
     60                     input_level_dbu=-6.5,
     61                     output_level_dbu=-12.5,
     62                 ),
     63                 None,
     64             ),
     65             (
     66                 None,
     67                 {
     68                     train_metadata.TRAINING_KEY: train_metadata.TrainingMetadata(
     69                         settings=train_metadata.Settings(
     70                             fit_cab=True, ignore_checks=False
     71                         ),
     72                         data=train_metadata.Data(
     73                             latency=train_metadata.Latency(
     74                                 manual=None,
     75                                 calibration=train_metadata.LatencyCalibration(
     76                                     algorithm_version=1,
     77                                     delays=[1, 3],
     78                                     safety_factor=4,
     79                                     recommended=-3,
     80                                     warnings=train_metadata.LatencyCalibrationWarnings(
     81                                         matches_lookahead=False,
     82                                         disagreement_too_high=False,
     83                                     ),
     84                                 ),
     85                             ),
     86                             checks=train_metadata.DataChecks(version=4, passed=True),
     87                         ),
     88                         validation_esr=0.01,
     89                     ).model_dump()
     90                 },
     91             ),
     92             (
     93                 metadata.UserMetadata(
     94                     name="My Model",
     95                     modeled_by="Steve",
     96                     gear_type=metadata.GearType.AMP,
     97                     gear_make="SteveCo",
     98                     gear_model="SteveAmp",
     99                     tone_type=metadata.ToneType.HI_GAIN,
    100                 ),
    101                 {
    102                     train_metadata.TRAINING_KEY: train_metadata.TrainingMetadata(
    103                         settings=train_metadata.Settings(
    104                             fit_cab=True, ignore_checks=False
    105                         ),
    106                         data=train_metadata.Data(
    107                             latency=train_metadata.Latency(
    108                                 manual=None,
    109                                 calibration=train_metadata.LatencyCalibration(
    110                                     algorithm_version=1,
    111                                     delays=[1, 3],
    112                                     safety_factor=4,
    113                                     recommended=-3,
    114                                     warnings=train_metadata.LatencyCalibrationWarnings(
    115                                         matches_lookahead=False,
    116                                         disagreement_too_high=False,
    117                                     ),
    118                                 ),
    119                             ),
    120                             checks=train_metadata.DataChecks(version=4, passed=True),
    121                         ),
    122                         validation_esr=0.01,
    123                     ).model_dump()
    124                 },
    125             ),
    126         ),
    127     )
    128     def test_export_metadata(
    129         self,
    130         user_metadata: Optional[metadata.UserMetadata],
    131         other_metadata: Optional[dict],
    132     ):
    133         """
    134         Assert export behavior when metadata is provided
    135         """
    136 
    137         def assert_metadata(actual: dict, expected: dict):
    138             assert isinstance(actual, dict)
    139             for key, expected_value in expected.items():
    140                 assert key in actual
    141                 actual_value = actual[key]
    142                 if isinstance(expected_value, BaseModel):
    143                     assert_metadata(actual_value, expected_value)
    144                 else:
    145                     if isinstance(expected_value, Enum):
    146                         expected_value = expected_value.value
    147                     assert actual_value == expected_value
    148 
    149         model = self._get_model()
    150         with TemporaryDirectory() as tmpdir:
    151             model.export(
    152                 tmpdir, user_metadata=user_metadata, other_metadata=other_metadata
    153             )
    154             model_basename = "model.nam"
    155             model_path = Path(tmpdir, model_basename)
    156             assert model_path.exists()
    157             with open(model_path, "r") as fp:
    158                 model_dict = json.load(fp)
    159             metadata_key = "metadata"
    160             training_key = train_metadata.TRAINING_KEY
    161             assert metadata_key in model_dict
    162             model_dict_metadata = model_dict[metadata_key]
    163             if user_metadata is not None:
    164                 assert_metadata(model_dict_metadata, user_metadata.model_dump())
    165             if other_metadata is not None:
    166                 assert training_key in model_dict_metadata
    167                 assert_metadata(model_dict_metadata, other_metadata)
    168 
    169     @pytest.mark.parametrize("include_snapshot", (True, False))
    170     def test_include_snapshot(self, include_snapshot):
    171         """
    172         Does the option to include a snapshot work?
    173         """
    174         model = self._get_model()
    175 
    176         with TemporaryDirectory() as tmpdir:
    177             model.export(tmpdir, include_snapshot=include_snapshot)
    178             input_path = Path(tmpdir, "test_inputs.npy")
    179             output_path = Path(tmpdir, "test_outputs.npy")
    180             if include_snapshot:
    181                 assert input_path.exists()
    182                 assert output_path.exists()
    183                 # And check that the output is correct
    184                 x = np.load(input_path)
    185                 y = np.load(output_path)
    186                 preds = model(torch.Tensor(x)).detach().cpu().numpy()
    187                 assert preds == pytest.approx(y)
    188             else:
    189                 assert not input_path.exists()
    190                 assert not output_path.exists()
    191 
    192     @classmethod
    193     def _get_model(cls):
    194         class Model(nn.Module, exportable.Exportable):
    195             def __init__(self):
    196                 super().__init__()
    197                 self._scale = nn.Parameter(torch.tensor(0.0))
    198                 self._bias = nn.Parameter(torch.tensor(0.0))
    199 
    200             def forward(self, x: torch.Tensor):
    201                 return self._scale * x + self._bias
    202 
    203             def export_cpp_header(self, filename: Path):
    204                 pass
    205 
    206             def _export_config(self):
    207                 return {}
    208 
    209             def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
    210                 x = 0.01 * np.random.randn(
    211                     3,
    212                 )
    213                 y = self(torch.Tensor(x)).detach().cpu().numpy()
    214                 return x, y
    215 
    216             def _export_weights(self) -> np.ndarray:
    217                 return torch.stack([self._scale, self._bias]).detach().cpu().numpy()
    218 
    219         return Model()