commit bc6f76d80f3c06212b7e3661f0997b3b4f4b9c59
parent b5568563c55d3e736c7e2a353214e32304068307
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sun, 7 May 2023 23:35:48 -0700
Fix misspelled key for mrstft_weight (#237)
* Fix misspelled key for mrstft_weight
* Backwards-compatible
* Deprecation notice
Diffstat:
1 file changed, 21 insertions(+), 1 deletion(-)
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -75,7 +75,7 @@ class LossConfig(InitializableFromConfig):
mask_first = config.get("mask_first", 0)
pre_emph_coef = config.get("pre_emph_coef")
pre_emph_weight = config.get("pre_emph_weight")
- mrstft_weight = config.get("mstft_weight", 0.0)
+ mrstft_weight = cls._get_mrstft_weight(config)
return {
"fourier": fourier,
"mask_first": mask_first,
@@ -93,6 +93,26 @@ class LossConfig(InitializableFromConfig):
"""
return tuple(a[..., self.mask_first :] for a in args)
+ @classmethod
+ def _get_mrstft_weight(cls, config) -> float:
+ key = "mrstft_weight"
+ wrong_key = "mstft_key" # Backward compatibility
+ if key in config:
+ if "mstft_weight" in config:
+ raise ValueError(
+ f"Received loss configuration with both '{key}' and "
+ f"'{wrong_key}'. Provide only '{key}'."
+ )
+ return config[key]
+ elif wrong_key in config:
+ logger.warning(
+ f"Use of '{wrong_key}' is deprecated and will be removed in a future "
+ f"version. Use '{key}' instead."
+ )
+ return config[wrong_key]
+ else:
+ return 0.0
+
class Model(pl.LightningModule, InitializableFromConfig):
def __init__(