commit 89496f5d87590cc2dfb7431403a6f7b4a5a6827d
parent f5b3909ad2be15eb62d44cd43e0bbc88edc1c389
Author: Steven Atkinson <[email protected]>
Date: Sat, 29 Apr 2023 13:34:56 -0700
CPU-friendly training parameters (#219)
Diffstat:
1 file changed, 16 insertions(+), 2 deletions(-)
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -35,18 +35,30 @@ from tkinter import filedialog
from typing import Callable, Optional, Sequence
try:
+ import torch
+
from nam import __version__
from nam.train import core
from nam.models.metadata import GearType, UserMetadata, ToneType
_install_is_valid = True
+ _HAVE_ACCELERATOR = torch.cuda.is_available() or torch.backends.mps.is_available()
except ImportError:
_install_is_valid = False
+ _HAVE_ACCELERATOR = False
+
+if _HAVE_ACCELERATOR:
+ _DEFAULT_NUM_EPOCHS = 100
+ _DEFAULT_BATCH_SIZE = 16
+ _DEFAULT_LR_DECAY = 0.007
+else:
+ _DEFAULT_NUM_EPOCHS = 20
+ _DEFAULT_BATCH_SIZE = 1
+ _DEFAULT_LR_DECAY = 0.05
_BUTTON_WIDTH = 20
_BUTTON_HEIGHT = 2
_TEXT_WIDTH = 70
-_DEFAULT_NUM_EPOCHS = 100
_DEFAULT_DELAY = None
_ADVANCED_OPTIONS_LEFT_WIDTH = 12
@@ -276,7 +288,8 @@ class _GUI(object):
# If you're poking around looking for these, then maybe it's time to learn to
# use the command-line scripts ;)
lr = 0.004
- lr_decay = 0.007
+ lr_decay = _DEFAULT_LR_DECAY
+ batch_size = _DEFAULT_BATCH_SIZE
seed = 0
# Run it
@@ -291,6 +304,7 @@ class _GUI(object):
epochs=num_epochs,
delay=delay,
architecture=architecture,
+ batch_size=batch_size,
lr=lr,
lr_decay=lr_decay,
seed=seed,