neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 067077812a0b23f77090bb64d75bedeb7e7f0a2b
parent 4c982e5b530cd9c849b2b55c0f0da1cf1dab8cdf
Author: woodbury shortridge <[email protected]>
Date:   Fri,  8 Nov 2024 02:09:25 -0500

Make core get_callbacks public for extensions (#504)

public get_callbacks for better extensions
Diffstat:
Mnam/train/core.py | 4++--
Mtests/test_nam/test_train/test_core.py | 25+++++++++++++++++++++++++
2 files changed, 27 insertions(+), 2 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -1235,7 +1235,7 @@ class _ModelCheckpoint(pl.callbacks.model_checkpoint.ModelCheckpoint): nam_path.unlink() -def _get_callbacks( +def get_callbacks( threshold_esr: Optional[float], user_metadata: Optional[UserMetadata] = None, settings_metadata: Optional[metadata.Settings] = None, @@ -1432,7 +1432,7 @@ def train( data_metadata = metadata.Data(latency=latency_analysis, checks=data_check_output) trainer = pl.Trainer( - callbacks=_get_callbacks( + callbacks=get_callbacks( threshold_esr, user_metadata=user_metadata, settings_metadata=settings_metadata, diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py @@ -295,5 +295,30 @@ def test_end_to_end(): assert isinstance(train_output.model, Model) +def test_get_callbacks(): + """ + Sanity check for get_callbacks with a custom extension callback and threshold_esr + """ + threshold_esr = 0.01 + callbacks = core.get_callbacks(threshold_esr=threshold_esr) + + # dumb example of a user-extended custom callback + class CustomCallback: + pass + extended_callbacks = callbacks + [CustomCallback()] + + # sanity default callbacks + assert any(isinstance(cb, core._ModelCheckpoint) for cb in extended_callbacks), \ + "Expected _ModelCheckpoint to be part of the default callbacks." + + # custom callback + assert any(isinstance(cb, CustomCallback) for cb in extended_callbacks), \ + "Expected CustomCallback to be added to the extended callbacks." + + # _ValidationStopping cb when threshold_esr is prvided + assert any(isinstance(cb, core._ValidationStopping) for cb in extended_callbacks), \ + "_ValidationStopping should still be present after adding a custom callback." + + if __name__ == "__main__": pytest.main()