neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit bb348f79db93db4e108794da245512e8481c70e6
parent 2228effbf05273f3e0616fa60a1a7179c27490e8
Author: Steven Atkinson <[email protected]>
Date:   Mon, 25 Sep 2023 18:30:43 -0700

[Enhancement] Extensible registry of `BaseNet` subclass constructors (#310)

Update nam/models/base.py

Add extensible registry for BaseNet constructors
Diffstat:
Mnam/models/base.py | 29++++++++++++++++++++---------
1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/nam/models/base.py b/nam/models/base.py @@ -119,6 +119,17 @@ class _LossItem(NamedTuple): value: Optional[torch.Tensor] +_model_net_init_registry = { + "CatLSTM": CatLSTM.init_from_config, + "CatWaveNet": CatWaveNet.init_from_config, + "ConvNet": ConvNet.init_from_config, + "HyperConvNet": HyperConvNet.init_from_config, + "Linear": Linear.init_from_config, + "LSTM": LSTM.init_from_config, + "WaveNet": WaveNet.init_from_config, +} + + class Model(pl.LightningModule, InitializableFromConfig): def __init__( self, @@ -189,15 +200,7 @@ class Model(pl.LightningModule, InitializableFromConfig): """ config = super().parse_config(config) net_config = config["net"] - net = { - "CatLSTM": CatLSTM.init_from_config, - "CatWaveNet": CatWaveNet.init_from_config, - "ConvNet": ConvNet.init_from_config, - "HyperConvNet": HyperConvNet.init_from_config, - "Linear": Linear.init_from_config, - "LSTM": LSTM.init_from_config, - "WaveNet": WaveNet.init_from_config, - }[net_config["name"]](net_config["config"]) + net = _model_net_init_registry[net_config["name"]](net_config["config"]) loss_config = LossConfig.init_from_config(config.get("loss", {})) return { "net": net, @@ -206,6 +209,14 @@ class Model(pl.LightningModule, InitializableFromConfig): "loss_config": loss_config, } + @classmethod + def register_net_initializer(cls, name, constructor): + if name in _model_net_init_registry: + raise KeyError( + f"A constructor for net name '{name}' is already registered!" + ) + _model_net_init_registry[name] = constructor + @property def net(self) -> nn.Module: return self._net