test_conv_net.py (1019B)
# File: test_conv_net.py
# Created Date: Friday May 6th 2022
# Author: Steven Atkinson ([email protected])

import pytest as _pytest

from nam.models import conv_net

from .base import Base as _Base

# (batchnorm, activation) combinations exercised by both the init and export
# tests. Defined once so the two parametrizations cannot drift apart.
_parametrize_batchnorm_activation = _pytest.mark.parametrize(
    ("batchnorm", "activation"), ((False, "ReLU"), (True, "Tanh"))
)


class TestConvNet(_Base):
    """Tests for ``conv_net.ConvNet`` using the shared ``_Base`` test harness."""

    @classmethod
    def setup_class(cls):
        """Hand ``conv_net.ConvNet`` and its configuration to ``_Base.setup_class``.

        Passes the model class, an args tuple ``(channels, dilations)`` and a
        kwargs dict. Presumably these are the constructor's positional and
        keyword arguments; confirm against ``_Base``.
        """
        channels = 3
        dilations = [1, 2, 4]
        return super().setup_class(
            conv_net.ConvNet,
            (channels, dilations),
            {"batchnorm": False, "activation": "Tanh"},
        )

    @_parametrize_batchnorm_activation
    def test_init(self, batchnorm, activation):
        """Run the base init test with the given batchnorm/activation kwargs."""
        super().test_init(kwargs={"batchnorm": batchnorm, "activation": activation})

    @_parametrize_batchnorm_activation
    def test_export(self, batchnorm, activation):
        """Run the base export test with the given batchnorm/activation kwargs."""
        super().test_export(kwargs={"batchnorm": batchnorm, "activation": activation})


if __name__ == "__main__":
    # Target this module explicitly. With no args, pytest.main() would use
    # sys.argv[1:] and collect everything discoverable from the working
    # directory.
    _pytest.main([__file__])