neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 00cf58292862e3028972e5a314b452ee7a106d15
parent 55fe7374adebdf7725c6676c568b53cc2a5c1de2
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Fri,  5 May 2023 16:45:50 -0700

Importing weights for WaveNet (#234)

* Importing weights for WaveNet

* Cast to torch.Tensor early
Diffstat:
M nam/models/_exportable.py                  | 11 ++++++++++-
M nam/models/wavenet.py                      | 41 +++++++++++++++++++++++++++++++++++++++++
A tests/test_nam/test_models/test_wavenet.py | 31 +++++++++++++++++++++++++++++++
3 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -8,7 +8,7 @@ import logging
 from datetime import datetime
 from enum import Enum
 from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
 
 import numpy as np
@@ -97,6 +97,15 @@ class Exportable(abc.ABC):
             f"{self.__class__.__name__}"
         )
 
+    def import_weights(self, weights: Sequence[float]):
+        """
+        Inverse of `._export_weights()
+        """
+        raise NotImplementedError(
+            f"Importing weights for models of type {self.__class__.__name__} isn't "
+            "implemented yet."
+        )
+
     @abc.abstractmethod
     def _export_config(self):
         """
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
@@ -34,6 +34,21 @@ class Conv1d(nn.Conv1d):
         else:
             return torch.cat(tensors)
 
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        if self.weight is not None:
+            n = self.weight.numel()
+            self.weight.data = (
+                weights[i : i + n].reshape(self.weight.shape).to(self.weight.device)
+            )
+            i += n
+        if self.bias is not None:
+            n = self.bias.numel()
+            self.bias.data = (
+                weights[i : i + n].reshape(self.bias.shape).to(self.bias.device)
+            )
+            i += n
+        return i
+
 
 class _Layer(nn.Module):
     def __init__(
@@ -110,6 +125,11 @@ class _Layer(nn.Module):
             post_activation[:, :, -out_length:],
         )
 
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        i = self.conv.import_weights(weights, i)
+        i = self._input_mixer.import_weights(weights, i)
+        return self._1x1.import_weights(weights, i)
+
     @property
     def _channels(self) -> int:
         return self._1x1.in_channels
@@ -176,6 +196,12 @@ class _Layers(nn.Module):
             + [self._head_rechannel.export_weights()]
         )
 
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        i = self._rechannel.import_weights(weights, i)
+        for layer in self._layers:
+            i = layer.import_weights(weights, i)
+        return self._head_rechannel.import_weights(weights, i)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -255,6 +281,11 @@ class _Head(nn.Module):
     def forward(self, *args, **kwargs):
         return self._layers(*args, **kwargs)
 
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        for layer in self._layers:
+            i = layer[1].import_weights(weights, i)
+        return i
+
 
 class _WaveNet(nn.Module):
     def __init__(
@@ -290,6 +321,11 @@ class _WaveNet(nn.Module):
         weights = torch.cat([weights, torch.Tensor([self._head_scale])])
         return weights.detach().cpu().numpy()
 
+    def import_weights(self, weights: torch.Tensor):
+        i = 0
+        for layer in self._layers:
+            i = layer.import_weights(weights, i)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         :param x: (B,Cx,L)
@@ -376,6 +412,11 @@ class WaveNet(BaseNet):
             )
         )
 
+    def import_weights(self, weights: Sequence[float]):
+        if not isinstance(weights, torch.Tensor):
+            weights = torch.Tensor(weights)
+        self._net.import_weights(weights)
+
     def _export_config(self):
         return self._net.export_config()
diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py
@@ -0,0 +1,31 @@
+# File: test_wavenet.py
+# Created Date: Friday May 5th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import pytest
+import torch
+
+from nam.models.wavenet import WaveNet
+from nam.train.core import Architecture, _get_wavenet_config
+
+
+# from .base import Base
+
+
+class TestWaveNet(object):
+    def test_import_weights(self):
+        config = _get_wavenet_config(Architecture.FEATHER)
+        model_1 = WaveNet.init_from_config(config)
+        model_2 = WaveNet.init_from_config(config)
+
+        batch_size = 2
+        x = torch.randn(batch_size, model_1.receptive_field + 23)
+
+        y1 = model_1(x)
+        y2_before = model_2(x)
+
+        model_2.import_weights(model_1._export_weights())
+        y2_after = model_2(x)
+
+        assert not torch.allclose(y2_before, y1)
+        assert torch.allclose(y2_after, y1)
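
The new import_weights methods mirror _export_weights(): each submodule consumes its slice of a flat weight vector and returns the advanced cursor i, so weights exported from one model can be loaded into another model of the same architecture. A minimal usage sketch, assuming the nam package is importable and reusing the same internal helpers exercised by the new test (Architecture, _get_wavenet_config):

import torch

from nam.models.wavenet import WaveNet
from nam.train.core import Architecture, _get_wavenet_config

# Two models with the same architecture but independent random initializations.
config = _get_wavenet_config(Architecture.FEATHER)
source = WaveNet.init_from_config(config)
target = WaveNet.init_from_config(config)

# Export the source's weights as a flat array and load them into the target.
target.import_weights(source._export_weights())

# The two models should now agree on inputs longer than the receptive field.
x = torch.randn(2, source.receptive_field + 100)
assert torch.allclose(target(x), source(x))

Per the diff, Conv1d.import_weights copies the weight tensor and then the bias, the same order used by export_weights, which is what keeps the flat vector layout consistent between export and import.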