test_exportable.py (8409B)
# File: test_exportable.py
# Created Date: Sunday January 29th 2023
# Author: Steven Atkinson ([email protected])

"""
Test export behavior of models
"""

import json
from enum import Enum
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Tuple

import numpy as np
import pytest
import torch
import torch.nn as nn
from pydantic import BaseModel

from nam.models import exportable, metadata
from nam.train import metadata as train_metadata


class TestExportable(object):
    """
    Tests for the `.nam` export path of `exportable.Exportable` models.
    """

    def test_export(self):
        """
        Basic export: the `.nam` file is written, is valid JSON, contains the
        required top-level keys, and serializes the weights as a flat list of
        floats.
        """
        model = self._get_model()
        with TemporaryDirectory() as tmpdir:
            model.export(tmpdir)
            model_basename = "model.nam"
            model_path = Path(tmpdir, model_basename)
            assert model_path.exists()
            with open(model_path, "r") as fp:
                model_dict = json.load(fp)
            required_keys = {"version", "architecture", "config", "weights"}
            for key in required_keys:
                assert key in model_dict
            weights_list = model_dict["weights"]
            assert isinstance(weights_list, list)
            # The toy model below has exactly two parameters: scale and bias.
            assert len(weights_list) == 2
            assert all(isinstance(w, float) for w in weights_list)

    @pytest.mark.parametrize(
        "user_metadata,other_metadata",
        (
            (None, None),
            (metadata.UserMetadata(), None),
            (
                metadata.UserMetadata(
                    name="My Model",
                    modeled_by="Steve",
                    gear_type=metadata.GearType.AMP,
                    gear_make="SteveCo",
                    gear_model="SteveAmp",
                    tone_type=metadata.ToneType.HI_GAIN,
                    input_level_dbu=-6.5,
                    output_level_dbu=-12.5,
                ),
                None,
            ),
            (
                None,
                {
                    train_metadata.TRAINING_KEY: train_metadata.TrainingMetadata(
                        settings=train_metadata.Settings(
                            fit_cab=True, ignore_checks=False
                        ),
                        data=train_metadata.Data(
                            latency=train_metadata.Latency(
                                manual=None,
                                calibration=train_metadata.LatencyCalibration(
                                    algorithm_version=1,
                                    delays=[1, 3],
                                    safety_factor=4,
                                    recommended=-3,
                                    warnings=train_metadata.LatencyCalibrationWarnings(
                                        matches_lookahead=False,
                                        disagreement_too_high=False,
                                    ),
                                ),
                            ),
                            checks=train_metadata.DataChecks(version=4, passed=True),
                        ),
                        validation_esr=0.01,
                    ).model_dump()
                },
            ),
            (
                metadata.UserMetadata(
                    name="My Model",
                    modeled_by="Steve",
                    gear_type=metadata.GearType.AMP,
                    gear_make="SteveCo",
                    gear_model="SteveAmp",
                    tone_type=metadata.ToneType.HI_GAIN,
                ),
                {
                    train_metadata.TRAINING_KEY: train_metadata.TrainingMetadata(
                        settings=train_metadata.Settings(
                            fit_cab=True, ignore_checks=False
                        ),
                        data=train_metadata.Data(
                            latency=train_metadata.Latency(
                                manual=None,
                                calibration=train_metadata.LatencyCalibration(
                                    algorithm_version=1,
                                    delays=[1, 3],
                                    safety_factor=4,
                                    recommended=-3,
                                    warnings=train_metadata.LatencyCalibrationWarnings(
                                        matches_lookahead=False,
                                        disagreement_too_high=False,
                                    ),
                                ),
                            ),
                            checks=train_metadata.DataChecks(version=4, passed=True),
                        ),
                        validation_esr=0.01,
                    ).model_dump()
                },
            ),
        ),
    )
    def test_export_metadata(
        self,
        user_metadata: Optional[metadata.UserMetadata],
        other_metadata: Optional[dict],
    ):
        """
        Assert export behavior when metadata is provided

        :param user_metadata: Optional user-facing metadata to embed in the
            exported file.
        :param other_metadata: Optional extra metadata dict (e.g. training
            info keyed by ``train_metadata.TRAINING_KEY``).
        """

        def assert_metadata(actual: dict, expected: dict):
            # Recursively check that every (possibly nested) expected entry is
            # present in `actual` with an equal value.
            assert isinstance(actual, dict)
            for key, expected_value in expected.items():
                assert key in actual
                actual_value = actual[key]
                if isinstance(expected_value, BaseModel):
                    # Bug fix: recurse on the dumped dict. `BaseModel` has no
                    # `.items()`, so recursing with the model itself would
                    # raise AttributeError if this branch were ever hit.
                    assert_metadata(actual_value, expected_value.model_dump())
                else:
                    if isinstance(expected_value, Enum):
                        # model_dump() may keep Enum members; compare values.
                        expected_value = expected_value.value
                    assert actual_value == expected_value

        model = self._get_model()
        with TemporaryDirectory() as tmpdir:
            model.export(
                tmpdir, user_metadata=user_metadata, other_metadata=other_metadata
            )
            model_basename = "model.nam"
            model_path = Path(tmpdir, model_basename)
            assert model_path.exists()
            with open(model_path, "r") as fp:
                model_dict = json.load(fp)
            metadata_key = "metadata"
            training_key = train_metadata.TRAINING_KEY
            assert metadata_key in model_dict
            model_dict_metadata = model_dict[metadata_key]
            if user_metadata is not None:
                assert_metadata(model_dict_metadata, user_metadata.model_dump())
            if other_metadata is not None:
                assert training_key in model_dict_metadata
                assert_metadata(model_dict_metadata, other_metadata)

    @pytest.mark.parametrize("include_snapshot", (True, False))
    def test_include_snapshot(self, include_snapshot):
        """
        Does the option to include a snapshot work?

        When enabled, the export writes test input/output .npy files and the
        saved outputs must match what the model currently produces.
        """
        model = self._get_model()

        with TemporaryDirectory() as tmpdir:
            model.export(tmpdir, include_snapshot=include_snapshot)
            input_path = Path(tmpdir, "test_inputs.npy")
            output_path = Path(tmpdir, "test_outputs.npy")
            if include_snapshot:
                assert input_path.exists()
                assert output_path.exists()
                # And check that the output is correct
                x = np.load(input_path)
                y = np.load(output_path)
                preds = model(torch.Tensor(x)).detach().cpu().numpy()
                assert preds == pytest.approx(y)
            else:
                assert not input_path.exists()
                assert not output_path.exists()

    @classmethod
    def _get_model(cls):
        """
        Build a minimal exportable model (y = scale * x + bias) for testing.
        """

        class Model(nn.Module, exportable.Exportable):
            def __init__(self):
                super().__init__()
                self._scale = nn.Parameter(torch.tensor(0.0))
                self._bias = nn.Parameter(torch.tensor(0.0))

            def forward(self, x: torch.Tensor):
                return self._scale * x + self._bias

            def export_cpp_header(self, filename: Path):
                # Not needed for these tests.
                pass

            def _export_config(self):
                return {}

            def _export_input_output(self) -> Tuple[np.ndarray, np.ndarray]:
                x = 0.01 * np.random.randn(
                    3,
                )
                y = self(torch.Tensor(x)).detach().cpu().numpy()
                return x, y

            def _export_weights(self) -> np.ndarray:
                return torch.stack([self._scale, self._bias]).detach().cpu().numpy()

        return Model()