commit 1f11c8a73039ca36cb96add0412f21c89b87e7b7
parent f6f7761d197e88f491a69a3a77d89ee35f937be5
Author: Steven Atkinson <steven@atkinson.mn>
Date: Wed, 14 Dec 2022 23:59:35 -0800
Expose head_scale as a settable parameter, default to 0.02 (#61)
Diffstat:
1 file changed, 12 insertions(+), 2 deletions(-)
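
Usage sketch (not part of the commit): with this change, the Colab entry point run() in nam/train/colab.py accepts head_scale as a keyword argument, defaulting to 0.02. A minimal call might look like the following; epochs is assumed to be an existing parameter of run(), and the other keyword values simply mirror the defaults visible in this diff. The call also assumes the training .wav files that _check_for_files() looks for are already in the working directory.

    from nam.train import colab

    # head_scale is now user-settable; 0.02 is the new default.
    colab.run(
        epochs=100,              # assumed existing parameter; value illustrative
        delay=None,
        stage_1_channels=16,
        stage_2_channels=8,
        head_scale=0.02,         # new keyword exposed by this commit
        lr=0.004,
        lr_decay=0.007,
        seed=0,
    )
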
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -139,6 +139,7 @@ def _get_configs(
epochs: int,
stage_1_channels: int,
stage_2_channels: int,
+ head_scale: float,
lr: float,
lr_decay: float,
):
@@ -182,7 +183,8 @@ def _get_configs(
"gated": False,
"head_bias": True,
},
- ]
+ ],
+ "head_scale": head_scale,
},
},
"loss": {"val_loss": "esr"},
@@ -260,6 +262,7 @@ def run(
delay=None,
stage_1_channels=16,
stage_2_channels=8,
+ head_scale: float = 0.02,
lr=0.004,
lr_decay=0.007,
seed=0,
@@ -278,7 +281,14 @@ def run(
input_version, input_basename = _check_for_files()
delay = _calibrate_delay(delay, input_version, input_basename)
data_config, model_config, learning_config = _get_configs(
- input_basename, delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay
+ input_basename,
+ delay,
+ epochs,
+ stage_1_channels,
+ stage_2_channels,
+ head_scale,
+ lr,
+ lr_decay,
)
print("Starting training. Let's rock!")