commit bb348f79db93db4e108794da245512e8481c70e6
parent 2228effbf05273f3e0616fa60a1a7179c27490e8
Author: Steven Atkinson <[email protected]>
Date: Mon, 25 Sep 2023 18:30:43 -0700
[Enhancement] Extensible registry of `BaseNet` subclass constructors (#310)
Update nam/models/base.py
Add extensible registry for BaseNet constructors
Diffstat:
1 file changed, 20 insertions(+), 9 deletions(-)
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -119,6 +119,17 @@ class _LossItem(NamedTuple):
value: Optional[torch.Tensor]
+_model_net_init_registry = {
+ "CatLSTM": CatLSTM.init_from_config,
+ "CatWaveNet": CatWaveNet.init_from_config,
+ "ConvNet": ConvNet.init_from_config,
+ "HyperConvNet": HyperConvNet.init_from_config,
+ "Linear": Linear.init_from_config,
+ "LSTM": LSTM.init_from_config,
+ "WaveNet": WaveNet.init_from_config,
+}
+
+
class Model(pl.LightningModule, InitializableFromConfig):
def __init__(
self,
@@ -189,15 +200,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
"""
config = super().parse_config(config)
net_config = config["net"]
- net = {
- "CatLSTM": CatLSTM.init_from_config,
- "CatWaveNet": CatWaveNet.init_from_config,
- "ConvNet": ConvNet.init_from_config,
- "HyperConvNet": HyperConvNet.init_from_config,
- "Linear": Linear.init_from_config,
- "LSTM": LSTM.init_from_config,
- "WaveNet": WaveNet.init_from_config,
- }[net_config["name"]](net_config["config"])
+ net = _model_net_init_registry[net_config["name"]](net_config["config"])
loss_config = LossConfig.init_from_config(config.get("loss", {}))
return {
"net": net,
@@ -206,6 +209,14 @@ class Model(pl.LightningModule, InitializableFromConfig):
"loss_config": loss_config,
}
+ @classmethod
+ def register_net_initializer(cls, name, constructor):
+ if name in _model_net_init_registry:
+ raise KeyError(
+ f"A constructor for net name '{name}' is already registered!"
+ )
+ _model_net_init_registry[name] = constructor
+
@property
def net(self) -> nn.Module:
return self._net