# File: test_base.py
# Created Date: Thursday March 16th 2023
# Author: Steven Atkinson ([email protected])

"""
Tests for the base network and Lightning module
"""

import math
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional

import numpy as np
import pytest
import torch

from nam.models import base


class MockBaseNet(base.BaseNet):
    """
    Minimal concrete subclass of ``base.BaseNet``: a memoryless linear gain.

    Exists only to exercise the base-class machinery (metadata computation,
    sample-rate handling) without a real model; the export hooks are no-ops.
    """

    def __init__(self, gain: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Scalar multiplier applied in _forward; tests mutate this directly.
        self.gain = gain

    @property
    def pad_start_default(self) -> bool:
        return True

    @property
    def receptive_field(self) -> int:
        # Memoryless: each output sample depends only on the current input.
        return 1

    def export_cpp_header(self, filename: Path):
        pass

    def _export_config(self):
        pass

    def _export_weights(self) -> np.ndarray:
        pass

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.gain * x


def test_metadata_gain():
    obj = MockBaseNet(1.0)
    g = obj._metadata_gain()
    # It's linear, so gain is zero.
    assert g == 0.0


def test_metadata_loudness():
    obj = MockBaseNet(1.0)
    y = obj._metadata_loudness()
    obj.gain = 2.0
    y2 = obj._metadata_loudness()
    assert isinstance(y, float)
    # 2x louder = +6dB
    assert y2 == pytest.approx(y + 20.0 * math.log10(2.0))


class TestSampleRate(object):
    """
    Tests for sample_rate interface
    """

    @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
    def test_on_init(self, expected_sample_rate: Optional[float]):
        model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
        self._wrap_assert(model, expected_sample_rate)

    @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
    def test_setter(self, expected_sample_rate: Optional[float]):
        model = MockBaseNet(gain=1.0)
        model.sample_rate = expected_sample_rate
        self._wrap_assert(model, expected_sample_rate)

    @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
    def test_state_dict(self, expected_sample_rate: Optional[float]):
        """
        Assert that it makes it into the state dict

        https://github.com/sdatkinson/neural-amp-modeler/issues/351
        """
        model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
        with TemporaryDirectory() as tmpdir:
            model_path = Path(tmpdir, "model.pt")
            torch.save(model.state_dict(), model_path)
            model2 = MockBaseNet(gain=1.0)
            # NOTE(review): consider torch.load(..., weights_only=True) once the
            # minimum supported torch version allows it (safer deserialization).
            model2.load_state_dict(torch.load(model_path))
            self._wrap_assert(model2, expected_sample_rate)

    @classmethod
    def _wrap_assert(cls, model: MockBaseNet, expected: Optional[float]):
        # None must round-trip as None; numeric rates must come back as float.
        actual = model.sample_rate
        if expected is None:
            assert actual is None
        else:
            assert isinstance(actual, float)
            assert actual == expected


if __name__ == "__main__":
    pytest.main()