neural-amp-modeler

Neural network emulator for guitar amplifiers

conv_net.py (9676B)


# File: conv_net.py
# Created Date: Saturday February 5th 2022
# Author: Steven Atkinson (steven@atkinson.mn)

import json as _json
import math as _math
from enum import Enum as _Enum
from functools import partial as _partial
from pathlib import Path as _Path
from tempfile import TemporaryDirectory as _TemporaryDirectory
from typing import (
    Optional as _Optional,
    Sequence as _Sequence,
    Tuple as _Tuple,
    Union as _Union,
)

import numpy as _np
import torch as _torch
import torch.nn as _nn
import torch.nn.functional as _F

from .. import __version__
from ..data import wav_to_tensor as _wav_to_tensor
from ._activations import get_activation as _get_activation
from .base import BaseNet as _BaseNet
from ._names import (
    ACTIVATION_NAME as _ACTIVATION_NAME,
    BATCHNORM_NAME as _BATCHNORM_NAME,
    CONV_NAME as _CONV_NAME,
)


class TrainStrategy(_Enum):
    STRIDE = "stride"
    DILATE = "dilate"


default_train_strategy = TrainStrategy.DILATE


class _Functional(_nn.Module):
    """
    Define a layer by a function w/ no params
    """

    def __init__(self, op):
        super().__init__()
        self._op = op

    def forward(self, *args, **kwargs):
        return self._op(*args, **kwargs)

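# Usage sketch (illustrative, not part of the module's API): `_Functional`
# turns any parameter-free callable into a module that can sit inside an
# `_nn.Sequential`:
#
#   squeeze = _Functional(_partial(_torch.squeeze, dim=1))
#   squeeze(_torch.zeros(8, 1, 16)).shape  # torch.Size([8, 16])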

class _IR(_nn.Module):
    def __init__(self, filename: _Union[str, _Path]):
        super().__init__()
        # Store the IR reversed: conv1d computes cross-correlation, so a
        # flipped kernel realizes a true convolution with the impulse response.
        self.register_buffer("_weight", reversed(_wav_to_tensor(filename))[None, None])

    @property
    def length(self) -> int:
        return self._weight.shape[-1]

    def forward(self, x: _torch.Tensor) -> _torch.Tensor:
        """
        :param x: (N,D)
        :return: (N,D-length+1)
        """
        return _F.conv1d(x[:, None], self._weight)[:, 0]

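# Shape sketch (toy numbers): a length-3 IR applied to an 8-sample batch
# behaves like a "valid" convolution:
#
#   h = _torch.tensor([1.0, 0.5, 0.25])  # pretend IR
#   y = _F.conv1d(_torch.zeros(1, 8)[:, None], reversed(h)[None, None])[:, 0]
#   y.shape  # torch.Size([1, 6]), i.e. (N, D - length + 1)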

def _conv_net(
    channels: int = 32,
    dilations: _Optional[_Sequence[int]] = None,
    batchnorm: bool = False,
    activation: str = "Tanh",
) -> _nn.Sequential:
    def block(cin, cout, dilation):
        net = _nn.Sequential()
        net.add_module(
            _CONV_NAME, _nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm)
        )
        if batchnorm:
            net.add_module(_BATCHNORM_NAME, _nn.BatchNorm1d(cout))
        net.add_module(_ACTIVATION_NAME, _get_activation(activation))
        return net

    def check_and_expand(n, x):
        if x.shape[1] < n:
            raise ValueError(
                f"Input of length {x.shape[1]} is shorter than model receptive field ({n})"
            )
        return x[:, None, :]

    dilations = [1, 2, 4, 8] if dilations is None else dilations
    receptive_field = sum(dilations) + 1
    net = _nn.Sequential()
    net.add_module("expand", _Functional(_partial(check_and_expand, receptive_field)))
    cin = 1
    cout = channels
    for i, dilation in enumerate(dilations):
        net.add_module(f"block_{i}", block(cin, cout, dilation))
        cin = cout
    net.add_module("head", _nn.Conv1d(channels, 1, 1))
    net.add_module("flatten", _nn.Flatten())
    return net

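# Receptive-field sketch: each kernel-size-2 conv consumes `dilation` extra
# samples, so the default dilations [1, 2, 4, 8] give 1 + 1 + 2 + 4 + 8 = 16:
#
#   net = _conv_net(channels=8)
#   net(_torch.zeros(3, 100)).shape  # torch.Size([3, 85]) = (N, L - 16 + 1)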

class ConvNet(_BaseNet):
    """
    A straightforward convolutional neural network.

    Works surprisingly well!
    """

    def __init__(
        self,
        *args,
        train_strategy: TrainStrategy = default_train_strategy,
        ir: _Optional[_IR] = None,
        sample_rate: _Optional[float] = None,
        **kwargs,
    ):
        super().__init__(sample_rate=sample_rate)
        self._net = _conv_net(*args, **kwargs)
        assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported"
        self._train_strategy = train_strategy
        self._num_blocks = self._get_num_blocks(self._net)
        self._pad_start_default = True
        self._ir = ir

    @classmethod
    def parse_config(cls, config):
        config = super().parse_config(config)
        config["train_strategy"] = TrainStrategy(
            config.get("train_strategy", default_train_strategy.value)
        )
        config["ir"] = (
            None if "ir_filename" not in config else _IR(config.pop("ir_filename"))
        )
        return config

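    # Illustrative config (keys inferred from this method and `_conv_net`;
    # "ir_filename" and "train_strategy" are optional):
    #
    #   {
    #       "channels": 32,
    #       "dilations": [1, 2, 4, 8],
    #       "batchnorm": true,
    #       "activation": "Tanh",
    #       "train_strategy": "dilate"
    #   }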
    @property
    def pad_start_default(self) -> bool:
        return self._pad_start_default

    @property
    def receptive_field(self) -> int:
        net_rf = 1 + sum(
            self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0]
            for i in range(self._num_blocks)
        )
        # Minus 1 because it composes w/ the net
        ir_rf = 0 if self._ir is None else self._ir.length - 1
        return net_rf + ir_rf

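    # Example arithmetic (hypothetical numbers): dilations [1, 2, 4, 8] give
    # net_rf = 16; composing with a 512-sample IR adds 512 - 1 = 511, for a
    # total receptive field of 527 samples.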
    @property
    def _activation(self):
        return (
            self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__
        )

    @property
    def _channels(self) -> int:
        return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0]

    @property
    def _num_layers(self) -> int:
        return self._num_blocks

    @property
    def _batchnorm(self) -> bool:
        return _BATCHNORM_NAME in self._net._modules["block_0"]._modules

    def export_cpp_header(self, filename: _Path):
        with _TemporaryDirectory() as tmpdir:
            tmpdir = _Path(tmpdir)
            self.export(tmpdir)
            with open(_Path(tmpdir, "config.json"), "r") as fp:
                _c = _json.load(fp)
            version = _c["version"]
            config = _c["config"]
            with open(filename, "w") as f:
                f.writelines(
                    (
                        "#pragma once\n",
                        "// Automatically-generated model file\n",
                        "#include <string>\n",
                        "#include <vector>\n",
                        f'#define PYTHON_MODEL_VERSION "{version}"\n',
                        f"const int CHANNELS = {config['channels']};\n",
                        f"const bool BATCHNORM = {'true' if config['batchnorm'] else 'false'};\n",
                        "std::vector<int> DILATIONS{"
                        + ",".join([str(d) for d in config["dilations"]])
                        + "};\n",
                        f"const std::string ACTIVATION = \"{config['activation']}\";\n",
                        "std::vector<float> PARAMS{"
                        + ",".join(
                            [
                                f"{w:.16f}"
                                for w in _np.load(_Path(tmpdir, "weights.npy"))
                            ]
                        )
                        + "};\n",
                    )
                )

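    # Sketch of the emitted header for a hypothetical model (values made up):
    #
    #   #pragma once
    #   // Automatically-generated model file
    #   #include <string>
    #   #include <vector>
    #   #define PYTHON_MODEL_VERSION "<version>"
    #   const int CHANNELS = 32;
    #   const bool BATCHNORM = false;
    #   std::vector<int> DILATIONS{1,2,4,8};
    #   const std::string ACTIVATION = "Tanh";
    #   std::vector<float> PARAMS{ ... };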
    def _export_config(self):
        return {
            "channels": self._channels,
            "dilations": self._get_dilations(),
            "batchnorm": self._batchnorm,
            "activation": self._activation,
        }

    def _export_input_output(self, x=None) -> _Tuple[_np.ndarray, _np.ndarray]:
        """
        :return: (L,), (L,)
        """
        with _torch.no_grad():
            training = self.training
            self.eval()
            x = self._export_input_signal() if x is None else x
            y = self(x, pad_start=True)
            self.train(training)
            return tuple(z.detach().cpu().numpy() for z in (x, y))

    def _export_input_signal(self):
        """
        :return: (L,)
        """
        rate = self.sample_rate
        if rate is None:
            raise RuntimeError(
                "Cannot export model's input and output without a sample rate."
            )
        # Tensor sizes must be integers; the sample rate may be stored as a float.
        rate = int(rate)
        return _torch.cat(
            [
                _torch.zeros((rate,)),
                0.5
                * _torch.sin(
                    2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1]
                ),
                _torch.zeros((rate,)),
            ]
        )

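    # Probe-signal sketch (assuming rate = 48000): one second of silence, one
    # second of a 220 Hz sine at amplitude 0.5, then one more second of
    # silence, for 3 * 48000 = 144000 samples in total.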
    def _export_weights(self) -> _np.ndarray:
        """
        Weights are serialized to weights.npy in the following order:
        * (expand: no params)
        * loop blocks 0,...,L-1
            * conv:
                * weight (Cout, Cin, K)
                * bias (if no batchnorm) (Cout,)
            * BN
                * running_mean (Cout,)
                * running_var (Cout,)
                * weight (Cout,)
                * bias (Cout,)
                * eps ()
        * head
            * weight (1, C, 1)
            * bias (1,)
        * (flatten: no params)
        """
        params = []
        for i in range(self._num_layers):
            block_name = f"block_{i}"
            block = self._net._modules[block_name]
            conv = block._modules[_CONV_NAME]
            params.append(conv.weight.flatten())
            if conv.bias is not None:
                params.append(conv.bias.flatten())
            if self._batchnorm:
                bn = block._modules[_BATCHNORM_NAME]
                params.append(bn.running_mean.flatten())
                params.append(bn.running_var.flatten())
                params.append(bn.weight.flatten())
                params.append(bn.bias.flatten())
                params.append(_torch.Tensor([bn.eps]).to(bn.weight.device))
        head = self._net._modules["head"]
        params.append(head.weight.flatten())
        params.append(head.bias.flatten())
        params = _torch.cat(params).detach().cpu().numpy()
        return params

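    # Parameter-count sketch for channels C = 32, dilations [1, 2, 4, 8],
    # batchnorm = False: block 0 has a (32, 1, 2) conv weight plus a (32,)
    # bias = 96 values; blocks 1-3 each have (32, 32, 2) + (32,) = 2080;
    # the head adds (1, 32, 1) + (1,) = 33. Total: 96 + 3 * 2080 + 33 = 6369.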
    def _forward(self, x):
        y = self._net(x)
        if self._ir is not None:
            y = self._ir(y)
        return y

    def _get_dilations(self) -> _Tuple[int, ...]:
        return tuple(
            self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0]
            for i in range(self._num_blocks)
        )

    def _get_num_blocks(self, net: _nn.Sequential):
        # Count the consecutively-named "block_i" submodules.
        i = 0
        while f"block_{i}" in net._modules:
            i += 1
        return i
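

# End-to-end usage sketch (constructor arguments assumed from `_conv_net`;
# not canonical):
#
#   model = ConvNet(channels=16, dilations=[1, 2, 4, 8], batchnorm=True,
#                   activation="ReLU", sample_rate=48000.0)
#   y = model(_torch.zeros(1, 48000))  # default start padding keeps length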