neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 2d06ada3e4a6ada872ff224a119113834a03f906
parent a681ebb4a2ece786e3e7c4d23a506eddca6eb568
Author: Steven Atkinson <[email protected]>
Date:   Tue, 14 May 2024 00:03:21 -0700

Fix bug (#413)


Diffstat:
Mnam/models/_base.py | 14+++++++++++++-
Mnam/models/wavenet.py | 2+-
2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/nam/models/_base.py b/nam/models/_base.py @@ -55,6 +55,18 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): ) @property + def device(self) -> Optional[torch.device]: + """ + Helpful property, where the parameters of the model live. + """ + # We can do this because the models are tiny and I don't expect a NAM to be on + # multiple devices + try: + return next(self.parameters()).device + except StopIteration: + return None + + @property def sample_rate(self) -> Optional[float]: return self._sample_rate.item() if self._has_sample_rate else None @@ -81,7 +93,7 @@ class _Base(nn.Module, InitializableFromConfig, Exportable): :param gain: Multiplies input signal """ - x = self._metadata_loudness_x() + x = self._metadata_loudness_x().to(self.device) y = self._at_nominal_settings(gain * x) loudness = torch.sqrt(torch.mean(torch.square(y))) if db: diff --git a/nam/models/wavenet.py b/nam/models/wavenet.py @@ -318,7 +318,7 @@ class _WaveNet(nn.Module): weights = torch.cat([layer.export_weights() for layer in self._layers]) if self._head is not None: weights = torch.cat([weights, self._head.export_weights()]) - weights = torch.cat([weights, torch.Tensor([self._head_scale])]) + weights = torch.cat([weights.cpu(), torch.Tensor([self._head_scale])]) return weights.detach().cpu().numpy() def import_weights(self, weights: torch.Tensor):