commit 00cf58292862e3028972e5a314b452ee7a106d15
parent 55fe7374adebdf7725c6676c568b53cc2a5c1de2
Author: Steven Atkinson <steven@atkinson.mn>
Date: Fri, 5 May 2023 16:45:50 -0700
Importing weights for WaveNet (#234)
* Importing weights for WaveNet
* Cast to torch.Tensor early
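
Example usage of the new API (a sketch mirroring the new test in this change; `Architecture` and `_get_wavenet_config` come from `nam.train.core`):

    from nam.models.wavenet import WaveNet
    from nam.train.core import Architecture, _get_wavenet_config

    config = _get_wavenet_config(Architecture.FEATHER)
    source = WaveNet.init_from_config(config)
    target = WaveNet.init_from_config(config)  # same architecture, fresh random init

    # Load the flat weight vector exported from one model into the other
    target.import_weights(source._export_weights())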
Diffstat:
3 files changed, 82 insertions(+), 1 deletion(-)
diff --git a/nam/models/_exportable.py b/nam/models/_exportable.py
@@ -8,7 +8,7 @@ import logging
from datetime import datetime
from enum import Enum
from pathlib import Path
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
import numpy as np
@@ -97,6 +97,15 @@ class Exportable(abc.ABC):
f"{self.__class__.__name__}"
)
+    def import_weights(self, weights: Sequence[float]):
+        """
+        Inverse of `._export_weights()`
+        """
+        raise NotImplementedError(
+            f"Importing weights for models of type {self.__class__.__name__} isn't "
+            "implemented yet."
+        )
+
@abc.abstractmethod
def _export_config(self):
"""
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
@@ -34,6 +34,21 @@ class Conv1d(nn.Conv1d):
else:
return torch.cat(tensors)
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        if self.weight is not None:
+            n = self.weight.numel()
+            self.weight.data = (
+                weights[i : i + n].reshape(self.weight.shape).to(self.weight.device)
+            )
+            i += n
+        if self.bias is not None:
+            n = self.bias.numel()
+            self.bias.data = (
+                weights[i : i + n].reshape(self.bias.shape).to(self.bias.device)
+            )
+            i += n
+        return i
+
class _Layer(nn.Module):
def __init__(
@@ -110,6 +125,11 @@ class _Layer(nn.Module):
post_activation[:, :, -out_length:],
)
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        i = self.conv.import_weights(weights, i)
+        i = self._input_mixer.import_weights(weights, i)
+        return self._1x1.import_weights(weights, i)
+
@property
def _channels(self) -> int:
return self._1x1.in_channels
@@ -176,6 +196,12 @@ class _Layers(nn.Module):
+ [self._head_rechannel.export_weights()]
)
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        i = self._rechannel.import_weights(weights, i)
+        for layer in self._layers:
+            i = layer.import_weights(weights, i)
+        return self._head_rechannel.import_weights(weights, i)
+
def forward(
self,
x: torch.Tensor,
@@ -255,6 +281,11 @@ class _Head(nn.Module):
def forward(self, *args, **kwargs):
return self._layers(*args, **kwargs)
+    def import_weights(self, weights: torch.Tensor, i: int) -> int:
+        for layer in self._layers:
+            i = layer[1].import_weights(weights, i)
+        return i
+
class _WaveNet(nn.Module):
def __init__(
@@ -290,6 +321,11 @@ class _WaveNet(nn.Module):
weights = torch.cat([weights, torch.Tensor([self._head_scale])])
return weights.detach().cpu().numpy()
+    def import_weights(self, weights: torch.Tensor):
+        i = 0
+        for layer in self._layers:
+            i = layer.import_weights(weights, i)
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
:param x: (B,Cx,L)
@@ -376,6 +412,11 @@ class WaveNet(BaseNet):
)
)
+    def import_weights(self, weights: Sequence[float]):
+        if not isinstance(weights, torch.Tensor):
+            weights = torch.Tensor(weights)
+        self._net.import_weights(weights)
+
def _export_config(self):
return self._net.export_config()
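
Each module-level `import_weights(weights, i)` consumes its own slice of the flat vector and returns the index where the next module should start reading, so the traversal order must match `export_weights()`. A standalone illustration of the per-`Conv1d` layout (plain `torch.nn.Conv1d`, hypothetical sizes, not part of this diff):

    import torch

    conv = torch.nn.Conv1d(in_channels=2, out_channels=3, kernel_size=4, bias=True)
    flat = torch.cat([conv.weight.flatten(), conv.bias.flatten()])

    i = 0
    n = conv.weight.numel()  # 3 * 2 * 4 = 24 values for the kernel
    restored_weight = flat[i : i + n].reshape(conv.weight.shape)
    i += n
    n = conv.bias.numel()  # 3 values for the bias
    restored_bias = flat[i : i + n].reshape(conv.bias.shape)
    i += n  # the next module's slice starts here

    assert torch.equal(restored_weight, conv.weight)
    assert torch.equal(restored_bias, conv.bias)
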
diff --git a/tests/test_nam/test_models/test_wavenet.py b/tests/test_nam/test_models/test_wavenet.py
@@ -0,0 +1,31 @@
+# File: test_wavenet.py
+# Created Date: Friday May 5th 2023
+# Author: Steven Atkinson (steven@atkinson.mn)
+
+import pytest
+import torch
+
+from nam.models.wavenet import WaveNet
+from nam.train.core import Architecture, _get_wavenet_config
+
+
+# from .base import Base
+
+
+class TestWaveNet(object):
+    def test_import_weights(self):
+        config = _get_wavenet_config(Architecture.FEATHER)
+        model_1 = WaveNet.init_from_config(config)
+        model_2 = WaveNet.init_from_config(config)
+
+        batch_size = 2
+        x = torch.randn(batch_size, model_1.receptive_field + 23)
+
+        y1 = model_1(x)
+        y2_before = model_2(x)
+
+        model_2.import_weights(model_1._export_weights())
+        y2_after = model_2(x)
+
+        assert not torch.allclose(y2_before, y1)
+        assert torch.allclose(y2_after, y1)
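
The new test can be run on its own with, e.g., `pytest tests/test_nam/test_models/test_wavenet.py`.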