commit 067077812a0b23f77090bb64d75bedeb7e7f0a2b
parent 4c982e5b530cd9c849b2b55c0f0da1cf1dab8cdf
Author: woodbury shortridge <[email protected]>
Date: Fri, 8 Nov 2024 02:09:25 -0500
Make core get_callbacks public for extensions (#504)
public get_callbacks for better extensions
Diffstat:
2 files changed, 27 insertions(+), 2 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -1235,7 +1235,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint):
nam_path.unlink()
-def _get_callbacks(
+def get_callbacks(
threshold_esr: Optional[float],
user_metadata: Optional[UserMetadata] = None,
settings_metadata: Optional[metadata.Settings] = None,
@@ -1432,7 +1432,7 @@ def train(
data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output)
trainer = pl.Trainer(
- callbacks=_get_callbacks(
+ callbacks=get_callbacks(
threshold_esr,
user_metadata=user_metadata,
settings_metadata=settings_metadata,
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -295,5 +295,30 @@ def test_end_to_end():
assert isinstance(train_output.model, Model)
+def test_get_callbacks():
+ """
+ Sanity check for get_callbacks with a custom extension callback and threshold_esr
+ """
+ threshold_esr = 0.01
+ callbacks = core.get_callbacks(threshold_esr=threshold_esr)
+
+ # dumb example of a user-extended custom callback
+ class CustomCallback:
+ pass
+ extended_callbacks = callbacks + [CustomCallback()]
+
+ # sanity default callbacks
+ assert any(isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks), \
+ "Expected _ModelCheckpoint to be part of the default callbacks."
+
+ # custom callback
+ assert any(isinstance(cb, CustomCallback) for cb in extended_callbacks), \
+ "Expected CustomCallback to be added to the extended callbacks."
+
+ # _ValidationStopping cb when threshold_esr is prvided
+ assert any(isinstance(cb, core._ValidationStopping) for cb in extended_callbacks), \
+ "_ValidationStopping should still be present after adding a custom callback."
+
+
if __name__ == "__main__":
pytest.main()