commit 2d06ada3e4a6ada872ff224a119113834a03f906
parent a681ebb4a2ece786e3e7c4d23a506eddca6eb568
Author: Steven Atkinson <[email protected]>
Date: Tue, 14 May 2024 00:03:21 -0700
Fix bug (#413)
Diffstat:
2 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -55,6 +55,18 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
)
@property
+ def device(self) -> Optional[torch.device]:
+ """
+ Helpful property: the device on which the model's parameters live.
+ """
+ # We can do this because the models are tiny and I don't expect a NAM to be
+ # split across multiple devices.
+ try:
+ return next(self.parameters()).device
+ except StopIteration:
+ return None
+
+ @property
def sample_rate(self) -> Optional[float]:
return self._sample_rate.item() if self._has_sample_rate else None
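The new property follows a common PyTorch idiom: a module does not itself
carry a device, so we peek at the first parameter. A minimal sketch of the
behavior outside this repo (the Tiny and Empty modules below are
illustrative, not part of nam):

    import torch
    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 1)

        @property
        def device(self):
            # Same pattern as above: device of the first parameter,
            # or None when the module has no parameters at all.
            try:
                return next(self.parameters()).device
            except StopIteration:
                return None

    print(Tiny().device)  # cpu (cuda:0 after .cuda() on a GPU machine)

    class Empty(Tiny):
        def __init__(self):
            nn.Module.__init__(self)  # register no parameters

    print(Empty().device)  # None: parameters() is empty, so next() raises StopIteration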
@@ -81,7 +93,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable):
:param gain: Multiplies input signal
"""
- x = self._metadata_loudness_x()
+ x = self._metadata_loudness_x().to(self.device)
y = self._at_nominal_settings(gain * x)
loudness = torch.sqrt(torch.mean(torch.square(y)))
if db:
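Why the .to(self.device) matters: the probe signal from
_metadata_loudness_x() is presumably created on the CPU, so feeding it to a
model whose parameters live on a GPU would trip PyTorch's device check. A
sketch of that failure mode with a stand-in model (CUDA machine assumed;
nn.Linear stands in for the NAM here):

    import torch
    import torch.nn as nn

    model = nn.Linear(1, 1).cuda()
    x = torch.zeros(8, 1)  # CPU tensor, like the pre-fix probe signal
    try:
        model(x)  # RuntimeError: tensors on different devices
    except RuntimeError as e:
        print(e)
    # The fix: move the input to wherever the parameters live.
    y = model(x.to(next(model.parameters()).device))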
diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py
@@ -318,7 +318,7 @@ class _WaveNet(nn.Module):
weights = torch.cat([layer.export_weights() for layer in self._layers])
if self._head is not None:
weights = torch.cat([weights, self._head.export_weights()])
- weights = torch.cat([weights, torch.Tensor([self._head_scale])])
+ weights = torch.cat([weights.cpu(), torch.Tensor([self._head_scale])])
return weights.detach().cpu().numpy()
def import_weights(self, weights: torch.Tensor):
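The wavenet.py change fixes the mirror image of the same problem:
torch.Tensor([self._head_scale]) is always allocated on the CPU, so
torch.cat() would mix a GPU tensor with a CPU one when exporting a model
trained on a GPU. Moving weights to the CPU first is harmless here, since
the function ends with .cpu().numpy() anyway. The pattern in isolation
(illustrative values, CUDA machine assumed):

    import torch

    weights = torch.randn(4, device="cuda")
    head_scale = 0.5
    try:
        torch.cat([weights, torch.Tensor([head_scale])])  # cuda + cpu
    except RuntimeError as e:
        print(e)  # "Expected all tensors to be on the same device..."
    out = torch.cat([weights.cpu(), torch.Tensor([head_scale])])
    print(out.device)  # cpu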