test_recurrent.py (1555B)
# File: test_recurrent.py
# Created Date: Sunday July 17th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

"""Tests for the LSTM model defined in ``nam.models.recurrent``."""

import pytest as _pytest
import torch as _torch

from nam.models import recurrent as _recurrent

from .base import Base as _Base

_metadata_loudness_x_mocked = 0.1 * _torch.randn((11,))  # Shorter for speed


class TestLSTM(_Base):
    """
    Run the shared ``Base`` test suite against a lightly mocked ``LSTM``, plus
    LSTM-specific checks of its initial hidden/cell state.
    """

    @classmethod
    def setup_class(cls):
        class LSTMWithMocks(_recurrent.LSTM):
            """``LSTM`` with its slow-to-compute pieces replaced by small stand-ins."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Override the burn-in length used when computing the initial
                # state; presumably kept small (like the loudness signal below)
                # so tests run quickly.
                self._get_initial_state_burn_in = 7

            @classmethod
            def _metadata_loudness_x(cls) -> _torch.Tensor:
                # Return the short module-level random signal instead of the
                # real loudness input.
                return _metadata_loudness_x_mocked

        num_layers = 2
        hidden_size = 3
        # Hand the mocked class and its constructor arguments to the shared
        # base setup (which presumably uses them in ``_construct``).
        super().setup_class(
            LSTMWithMocks,
            args=(hidden_size,),
            kwargs={"train_burn_in": 3, "train_truncate": 5, "num_layers": num_layers},
        )
        cls._num_layers = num_layers
        cls._hidden_size = hidden_size

    @_pytest.mark.parametrize(
        "device",
        (
            "cpu",
            _pytest.param(
                "cuda",
                marks=_pytest.mark.skipif(
                    not _torch.cuda.is_available(), reason="GPU test"
                ),
            ),
        ),
    )
    def test_get_initial_state_on(self, device: str):
        """
        The initial (hidden, cell) state must be a pair of tensors living on
        the same device the model was moved to.

        :param device: Device type to move the model to ("cpu" or "cuda"; the
            latter is skipped when no GPU is available).
        """
        model = self._construct().to(device)
        h, c = model._get_initial_state()
        assert isinstance(h, _torch.Tensor)
        assert isinstance(c, _torch.Tensor)
        # Compare device *types*: moving to "cuda" yields tensors on e.g.
        # "cuda:0", which would not compare equal to torch.device("cuda").
        assert h.device.type == device, f"hidden state on {h.device}, not {device}"
        assert c.device.type == device, f"cell state on {c.device}, not {device}"