neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

test_recurrent.py (1555B)


      1 # File: test_recurrent.py
      2 # Created Date: Sunday July 17th 2022
      3 # Author: Steven Atkinson (steven@atkinson.mn)
      4 
      5 import pytest as _pytest
      6 import torch as _torch
      7 
      8 from nam.models import recurrent as _recurrent
      9 
     10 from .base import Base as _Base
     11 
     12 _metadata_loudness_x_mocked = 0.1 * _torch.randn((11,))  # Shorter for speed
     13 
     14 
     15 class TestLSTM(_Base):
     16     @classmethod
     17     def setup_class(cls):
     18         class LSTMWithMocks(_recurrent.LSTM):
     19             def __init__(self, *args, **kwargs):
     20                 super().__init__(*args, **kwargs)
     21                 self._get_initial_state_burn_in = 7
     22 
     23             @classmethod
     24             def _metadata_loudness_x(cls) -> _torch.Tensor:
     25                 return _metadata_loudness_x_mocked
     26 
     27         num_layers = 2
     28         hidden_size = 3
     29         super().setup_class(
     30             LSTMWithMocks,
     31             args=(hidden_size,),
     32             kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": num_layers},
     33         )
     34         cls._num_layers = num_layers
     35         cls._hidden_size = hidden_size
     36 
     37     @_pytest.mark.parametrize(
     38         "device",
     39         (
     40             "cpu",
     41             _pytest.param(
     42                 "cuda",
     43                 marks=_pytest.mark.skipif(
     44                     not _torch.cuda.is_available(), reason="GPU test"
     45                 ),
     46             ),
     47         ),
     48     )
     49     def test_get_initial_state_on(self, device: str):
     50         model = self._construct().to(device)
     51         h, c = model._get_initial_state()
     52         assert isinstance(h, _torch.Tensor)
     53         assert isinstance(c, _torch.Tensor)