neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_base.py (2963B)


      1 # File: test_base.py
      2 # Created Date: Thursday March 16th 2023
      3 # Author: Steven Atkinson ([email protected])
      4 
      5 """
      6 Tests for the base network and Lightning module
      7 """
      8 
      9 import math
     10 from pathlib import Path
     11 from tempfile import TemporaryDirectory
     12 from typing import Optional
     13 
     14 import numpy as np
     15 import pytest
     16 import torch
     17 
     18 from nam.models import base
     19 
     20 
     21 class MockBaseNet(base.BaseNet):
     22     def __init__(self, gain: float, *args, **kwargs):
     23         super().__init__(*args, **kwargs)
     24         self.gain = gain
     25 
     26     @property
     27     def pad_start_default(self) -> bool:
     28         return True
     29 
     30     @property
     31     def receptive_field(self) -> int:
     32         return 1
     33 
     34     def export_cpp_header(self, filename: Path):
     35         pass
     36 
     37     def _export_config(self):
     38         pass
     39 
     40     def _export_weights(self) -> np.ndarray:
     41         pass
     42 
     43     def _forward(self, x: torch.Tensor) -> torch.Tensor:
     44         return self.gain * x
     45 
     46 
     47 def test_metadata_gain():
     48     obj = MockBaseNet(1.0)
     49     g = obj._metadata_gain()
     50     # It's linear, so gain is zero.
     51     assert g == 0.0
     52 
     53 
     54 def test_metadata_loudness():
     55     obj = MockBaseNet(1.0)
     56     y = obj._metadata_loudness()
     57     obj.gain = 2.0
     58     y2 = obj._metadata_loudness()
     59     assert isinstance(y, float)
     60     # 2x louder = +6dB
     61     assert y2 == pytest.approx(y + 20.0 * math.log10(2.0))
     62 
     63 
     64 class TestSampleRate(object):
     65     """
     66     Tests for sample_rate interface
     67     """
     68 
     69     @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
     70     def test_on_init(self, expected_sample_rate: Optional[float]):
     71         model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
     72         self._wrap_assert(model, expected_sample_rate)
     73 
     74     @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
     75     def test_setter(self, expected_sample_rate: Optional[float]):
     76         model = MockBaseNet(gain=1.0)
     77         model.sample_rate = expected_sample_rate
     78         self._wrap_assert(model, expected_sample_rate)
     79 
     80     @pytest.mark.parametrize("expected_sample_rate", (None, 44_100.0, 48_000.0))
     81     def test_state_dict(self, expected_sample_rate: Optional[float]):
     82         """
     83         Assert that it makes it into the state dict
     84 
     85         https://github.com/sdatkinson/neural-amp-modeler/issues/351
     86         """
     87         model = MockBaseNet(gain=1.0, sample_rate=expected_sample_rate)
     88         with TemporaryDirectory() as tmpdir:
     89             model_path = Path(tmpdir, "model.pt")
     90             torch.save(model.state_dict(), model_path)
     91             model2 = MockBaseNet(gain=1.0)
     92             model2.load_state_dict(torch.load(model_path))
     93         self._wrap_assert(model2, expected_sample_rate)
     94 
     95     @classmethod
     96     def _wrap_assert(cls, model: MockBaseNet, expected: Optional[float]):
     97         actual = model.sample_rate
     98         if expected is None:
     99             assert actual is None
    100         else:
    101             assert isinstance(actual, float)
    102             assert actual == expected
    103 
    104 
    105 if __name__ == "__main__":
    106     pytest.main()