conv_net.py (9676B)
1 # File: conv_net.py 2 # Created Date: Saturday February 5th 2022 3 # Author: Steven Atkinson (steven@atkinson.mn) 4 5 import json as _json 6 import math as _math 7 from enum import Enum as _Enum 8 from functools import partial as _partial 9 from pathlib import Path as _Path 10 from tempfile import TemporaryDirectory as _TemporaryDirectory 11 from typing import ( 12 Optional as _Optional, 13 Sequence as _Sequence, 14 Tuple as _Tuple, 15 Union as _Union, 16 ) 17 18 import numpy as _np 19 import torch as _torch 20 import torch.nn as _nn 21 import torch.nn.functional as _F 22 23 24 from .. import __version__ 25 from ..data import wav_to_tensor as _wav_to_tensor 26 from ._activations import get_activation as _get_activation 27 from .base import BaseNet as _BaseNet 28 from ._names import ( 29 ACTIVATION_NAME as _ACTIVATION_NAME, 30 BATCHNORM_NAME as _BATCHNORM_NAME, 31 CONV_NAME as _CONV_NAME, 32 ) 33 34 35 class TrainStrategy(_Enum): 36 STRIDE = "stride" 37 DILATE = "dilate" 38 39 40 default_train_strategy = TrainStrategy.DILATE 41 42 43 class _Functional(_nn.Module): 44 """ 45 Define a layer by a function w/ no params 46 """ 47 48 def __init__(self, op): 49 super().__init__() 50 self._op = op 51 52 def forward(self, *args, **kwargs): 53 return self._op(*args, **kwargs) 54 55 56 class _IR(_nn.Module): 57 def __init__(self, filename: _Union[str, _Path]): 58 super().__init__() 59 self.register_buffer("_weight", reversed(_wav_to_tensor(filename))[None, None]) 60 61 @property 62 def length(self) -> int: 63 return self._weight.shape[-1] 64 65 def forward(self, x: _torch.Tensor) -> _torch.Tensor: 66 """ 67 :param x: (N,D) 68 :return: (N,D-length+1) 69 """ 70 return _F.conv1d(x[:, None], self._weight)[:, 0] 71 72 73 def _conv_net( 74 channels: int = 32, 75 dilations: _Sequence[int] = None, 76 batchnorm: bool = False, 77 activation: str = "Tanh", 78 ) -> _nn.Sequential: 79 def block(cin, cout, dilation): 80 net = _nn.Sequential() 81 net.add_module( 82 _CONV_NAME, _nn.Conv1d(cin, cout, 2, dilation=dilation, bias=not batchnorm) 83 ) 84 if batchnorm: 85 net.add_module(_BATCHNORM_NAME, _nn.BatchNorm1d(cout)) 86 net.add_module(_ACTIVATION_NAME, _get_activation(activation)) 87 return net 88 89 def check_and_expand(n, x): 90 if x.shape[1] < n: 91 raise ValueError( 92 f"Input of length {x.shape[1]} is shorter than model receptive field ({n})" 93 ) 94 return x[:, None, :] 95 96 dilations = [1, 2, 4, 8] if dilations is None else dilations 97 receptive_field = sum(dilations) + 1 98 net = _nn.Sequential() 99 net.add_module("expand", _Functional(_partial(check_and_expand, receptive_field))) 100 cin = 1 101 cout = channels 102 for i, dilation in enumerate(dilations): 103 net.add_module(f"block_{i}", block(cin, cout, dilation)) 104 cin = cout 105 net.add_module("head", _nn.Conv1d(channels, 1, 1)) 106 net.add_module("flatten", _nn.Flatten()) 107 return net 108 109 110 class ConvNet(_BaseNet): 111 """ 112 A straightforward convolutional neural network. 113 114 Works surprisingly well! 115 """ 116 117 def __init__( 118 self, 119 *args, 120 train_strategy: TrainStrategy = default_train_strategy, 121 ir: _Optional[_IR] = None, 122 sample_rate: _Optional[float] = None, 123 **kwargs, 124 ): 125 super().__init__(sample_rate=sample_rate) 126 self._net = _conv_net(*args, **kwargs) 127 assert train_strategy == TrainStrategy.DILATE, "Stride no longer supported" 128 self._train_strategy = train_strategy 129 self._num_blocks = self._get_num_blocks(self._net) 130 self._pad_start_default = True 131 self._ir = ir 132 133 @classmethod 134 def parse_config(cls, config): 135 config = super().parse_config(config) 136 config["train_strategy"] = TrainStrategy( 137 config.get("train_strategy", default_train_strategy.value) 138 ) 139 config["ir"] = ( 140 None if "ir_filename" not in config else _IR(config.pop("ir_filename")) 141 ) 142 return config 143 144 @property 145 def pad_start_default(self) -> bool: 146 return self._pad_start_default 147 148 @property 149 def receptive_field(self) -> int: 150 net_rf = 1 + sum( 151 self._net._modules[f"block_{i}"]._modules["conv"].dilation[0] 152 for i in range(self._num_blocks) 153 ) 154 # Minus 1 because it composes w/ the net 155 ir_rf = 0 if self._ir is None else self._ir.length - 1 156 return net_rf + ir_rf 157 158 @property 159 def _activation(self): 160 return ( 161 self._net._modules["block_0"]._modules[_ACTIVATION_NAME].__class__.__name__ 162 ) 163 164 @property 165 def _channels(self) -> int: 166 return self._net._modules["block_0"]._modules[_CONV_NAME].weight.shape[0] 167 168 @property 169 def _num_layers(self) -> int: 170 return self._num_blocks 171 172 @property 173 def _batchnorm(self) -> bool: 174 return _BATCHNORM_NAME in self._net._modules["block_0"]._modules 175 176 def export_cpp_header(self, filename: _Path): 177 with _TemporaryDirectory() as tmpdir: 178 tmpdir = _Path(tmpdir) 179 self.export(_Path(tmpdir)) 180 with open(_Path(tmpdir, "config.json"), "r") as fp: 181 _c = _json.load(fp) 182 version = _c["version"] 183 config = _c["config"] 184 with open(filename, "w") as f: 185 f.writelines( 186 ( 187 "#pragma once\n", 188 "// Automatically-generated model file\n", 189 "#include <vector>\n", 190 f'#define PYTHON_MODEL_VERSION "{version}"\n', 191 f"const int CHANNELS = {config['channels']};\n", 192 f"const bool BATCHNORM = {'true' if config['batchnorm'] else 'false'};\n", 193 "std::vector<int> DILATIONS{" 194 + ",".join([str(d) for d in config["dilations"]]) 195 + "};\n", 196 f"const std::string ACTIVATION = \"{config['activation']}\";\n", 197 "std::vector<float> PARAMS{" 198 + ",".join( 199 [ 200 f"{w:.16f}" 201 for w in _np.load(_Path(tmpdir, "weights.npy")) 202 ] 203 ) 204 + "};\n", 205 ) 206 ) 207 208 def _export_config(self): 209 return { 210 "channels": self._channels, 211 "dilations": self._get_dilations(), 212 "batchnorm": self._batchnorm, 213 "activation": self._activation, 214 } 215 216 def _export_input_output(self, x=None) -> _Tuple[_np.ndarray, _np.ndarray]: 217 """ 218 :return: (L,), (L,) 219 """ 220 with _torch.no_grad(): 221 training = self.training 222 self.eval() 223 x = self._export_input_signal() if x is None else x 224 y = self(x, pad_start=True) 225 self.train(training) 226 return tuple(z.detach().cpu().numpy() for z in (x, y)) 227 228 def _export_input_signal(self): 229 """ 230 :return: (L,) 231 """ 232 rate = self.sample_rate 233 if rate is None: 234 raise RuntimeError( 235 "Cannot export model's input and output without a sample rate." 236 ) 237 return _torch.cat( 238 [ 239 _torch.zeros((rate,)), 240 0.5 241 * _torch.sin( 242 2.0 * _math.pi * 220.0 * _torch.linspace(0.0, 1.0, rate + 1)[:-1] 243 ), 244 _torch.zeros((rate,)), 245 ] 246 ) 247 248 def _export_weights(self) -> _np.ndarray: 249 """ 250 weights are serialized to weights.npy in the following order: 251 * (expand: no params) 252 * loop blocks 0,...,L-1 253 * conv: 254 * weight (Cout, Cin, K) 255 * bias (if no batchnorm) (Cout) 256 * BN 257 * running mean 258 * running_var 259 * weight (Cout) 260 * bias (Cout) 261 * eps () 262 * head 263 * weight (C, 1, 1) 264 * bias (1, 1) 265 * (flatten: no params) 266 """ 267 params = [] 268 for i in range(self._num_layers): 269 block_name = f"block_{i}" 270 block = self._net._modules[block_name] 271 conv = block._modules[_CONV_NAME] 272 params.append(conv.weight.flatten()) 273 if conv.bias is not None: 274 params.append(conv.bias.flatten()) 275 if self._batchnorm: 276 bn = block._modules[_BATCHNORM_NAME] 277 params.append(bn.running_mean.flatten()) 278 params.append(bn.running_var.flatten()) 279 params.append(bn.weight.flatten()) 280 params.append(bn.bias.flatten()) 281 params.append(_torch.Tensor([bn.eps]).to(bn.weight.device)) 282 head = self._net._modules["head"] 283 params.append(head.weight.flatten()) 284 params.append(head.bias.flatten()) 285 params = _torch.cat(params).detach().cpu().numpy() 286 return params 287 288 def _forward(self, x): 289 y = self._net(x) 290 if self._ir is not None: 291 y = self._ir(y) 292 return y 293 294 def _get_dilations(self) -> _Tuple[int]: 295 return tuple( 296 self._net._modules[f"block_{i}"]._modules[_CONV_NAME].dilation[0] 297 for i in range(self._num_blocks) 298 ) 299 300 def _get_num_blocks(self, net: _nn.Sequential): 301 i = 0 302 while True: 303 if f"block_{i}" not in net._modules: 304 break 305 else: 306 i += 1 307 return i