commit 825b810545516ce4556ba4e6cbde9f86641cb37b
parent df619915ed021dd9c5ab233785493eda3960f219
Author: Steven Atkinson <[email protected]>
Date: Mon, 20 May 2024 23:16:49 -0500
[FEATURE,GUI,BREAKING] Validate data before training (#425)
* Training metadata
* Optional kwarg training_metadata on export
* Save training metadata in .nam file, tests
* Refactor delay calibration to return more info in preparation for including in metadata
* Got standardizedd training metadata all sorted out
* Fix bugs, add end-to-end test of core train
* Some cleanup
* Flake
* Validate data before starting training.
Diffstat:
5 files changed, 434 insertions(+), 70 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -40,7 +40,15 @@ class WavInfo:
rate: int
-class AudioShapeMismatchError(ValueError):
+class DataError(Exception):
+ """
+ Parent class for all special exceptions raised by NAM data sets
+ """
+
+ pass
+
+
+class AudioShapeMismatchError(ValueError, DataError):
"""
Exception where the shape (number of samples, number of channels) of two audio files
don't match but were supposed to.
@@ -191,7 +199,7 @@ def _interpolate_delay(
)
-class XYError(ValueError):
+class XYError(ValueError, DataError):
"""
Exceptions related to invalid x and y provided for data sets
"""
@@ -199,7 +207,7 @@ class XYError(ValueError):
pass
-class StartStopError(ValueError):
+class StartStopError(ValueError, DataError):
"""
Exceptions related to invalid start and stop arguments
"""
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -25,7 +25,7 @@ from pydantic import BaseModel
from pytorch_lightning.utilities.warnings import PossibleUserWarning
from torch.utils.data import DataLoader
-from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
+from ..data import DataError, Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.exportable import Exportable
from ..models.losses import esr
@@ -34,10 +34,20 @@ from ..util import filter_warnings
from ._version import PROTEUS_VERSION, Version
from . import metadata
-__all__ = ["train"]
+__all__ = [
+ "Architecture",
+ "DataValidationOutput",
+ "STANDARD_SAMPLE_RATE",
+ "TrainOutput",
+ "train",
+ "validate_data",
+ "validate_input",
+]
# Training using the simplified trainers in NAM is done at 48k.
STANDARD_SAMPLE_RATE = 48_000.0
+# Default number of output samples per datum.
+_NY_DEFAULT = 8192
class Architecture(Enum):
@@ -47,6 +57,10 @@ class Architecture(Enum):
NANO = "nano"
+class _InputValidationError(ValueError):
+ pass
+
+
def _detect_input_version(input_path) -> Tuple[Version, bool]:
"""
Check to see if the input matches any of the known inputs
@@ -227,7 +241,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
print("Falling back to weak-matching...")
version = detect_weak(input_path)
if version is None:
- raise ValueError(
+ raise _InputValidationError(
f"Input file at {input_path} cannot be recognized as any known version!"
)
strong_match = False
@@ -353,22 +367,37 @@ def _calibrate_latency_v_all(
:param y: The output audio, in complete.
"""
- def report_any_delay_warnings(delays: Sequence[int]):
+ def report_any_latency_warnings(
+ delays: Sequence[int],
+ ) -> metadata.LatencyCalibrationWarnings:
# Warnings associated with any single delay:
+ # "Lookahead warning": if the delay is equal to the lookahead, then it's
+ # probably an error.
lookahead_warnings = [i for i, d in enumerate(delays, 1) if d == -lookahead]
- if len(lookahead_warnings) > 0:
+ matches_lookahead = len(lookahead_warnings) > 0
+ if matches_lookahead:
print(_warn_lookaheads(lookahead_warnings))
# Ensemble warnings
# If they're _really_ different, then something might be wrong.
- if np.max(delays) - np.min(delays) >= 20:
+ max_disagreement_threshold = 20
+ max_disagreement_too_high = (
+ np.max(delays) - np.min(delays) >= max_disagreement_threshold
+ )
+ if max_disagreement_too_high:
print(
- "WARNING: Delays are anomalously different from each other. If this model "
- "turns out badly, then you might need to provide the delay manually."
+ "WARNING: Latencies are anomalously different from each other (more "
+ f"than {max_disagreement_threshold} samples). If this model turns out "
+ "badly, then you might need to provide the latency manually."
)
+ return metadata.LatencyCalibrationWarnings(
+ matches_lookahead=matches_lookahead,
+ disagreement_too_high=max_disagreement_too_high,
+ )
+
lookahead = 1_000
lookback = 10_000
# Calibrate the trigger:
@@ -422,7 +451,7 @@ def _calibrate_latency_v_all(
print("Delays:")
for i_rel, d in enumerate(delays, 1):
print(f" Blip {i_rel:2d}: {d}")
- report_any_delay_warnings(delays)
+ warnings = report_any_latency_warnings(delays)
delay_post_safety_factor = int(np.min(delays)) - safety_factor
print(
@@ -434,6 +463,7 @@ def _calibrate_latency_v_all(
delays=delays,
safety_factor=safety_factor,
recommended=delay_post_safety_factor,
+ warnings=warnings,
)
@@ -876,20 +906,9 @@ _CAB_MRSTFT_PRE_EMPH_WEIGHT = 2.0e-4
_CAB_MRSTFT_PRE_EMPH_COEF = 0.85
-def _get_configs(
- input_version: Version,
- input_path: str,
- output_path: str,
- delay: int,
- epochs: int,
- model_type: str,
- architecture: Architecture,
- ny: int,
- lr: float,
- lr_decay: float,
- batch_size: int,
- fit_cab: bool,
-):
+def _get_data_config(
+ input_version: Version, input_path: Path, output_path: Path, ny: int, latency: int
+) -> dict:
def get_kwargs(data_info: _DataInfo):
if data_info.major_version == 1:
train_val_split = data_info.validation_start
@@ -941,9 +960,34 @@ def _get_configs(
"common": {
"x_path": input_path,
"y_path": output_path,
- "delay": delay,
+ "delay": latency,
},
}
+ return data_config
+
+
+def _get_configs(
+ input_version: Version,
+ input_path: str,
+ output_path: str,
+ latency: int,
+ epochs: int,
+ model_type: str,
+ architecture: Architecture,
+ ny: int,
+ lr: float,
+ lr_decay: float,
+ batch_size: int,
+ fit_cab: bool,
+):
+
+ data_config = _get_data_config(
+ input_version=input_version,
+ input_path=input_path,
+ output_path=output_path,
+ ny=ny,
+ latency=latency,
+ )
if model_type == "WaveNet":
model_config = {
@@ -1228,6 +1272,16 @@ class TrainOutput(NamedTuple):
metadata: metadata.TrainingMetadata
+def _get_final_latency(latency_analysis: metadata.Latency) -> int:
+ if latency_analysis.manual is not None:
+ latency = latency_analysis.manual
+ print(f"Latency provided as {latency_analysis.manual}; override calibration")
+ else:
+ latency = latency_analysis.calibration.recommended
+ print(f"Set latency to recommended {latency_analysis.calibration.recommended}")
+ return latency
+
+
def train(
input_path: str,
output_path: str,
@@ -1239,7 +1293,7 @@ def train(
model_type: str = "WaveNet",
architecture: Union[Architecture, str] = Architecture.STANDARD,
batch_size: int = 16,
- ny: int = 8192,
+ ny: int = _NY_DEFAULT,
lr=0.004,
lr_decay=0.007,
seed: Optional[int] = 0,
@@ -1254,6 +1308,7 @@ def train(
fast_dev_run: Union[bool, int] = False,
) -> Optional[TrainOutput]:
"""
+ :param lr_decay: =1-gamma for Exponential learning rate decay.
:param threshold_esr: Stop training if ESR is better than this. Ignore if `None`.
:param fast_dev_run: One-step training, used for tests.
"""
@@ -1276,17 +1331,12 @@ def train(
user_latency = parse_user_latency(delay, latency)
latency_analysis = _analyze_latency(
- latency, input_version, input_path, output_path, silent=silent
+ user_latency, input_version, input_path, output_path, silent=silent
)
- if latency_analysis.manual is not None:
- latency = latency_analysis.manual
- print(f"Latency provided as {user_latency}; override calibration")
- else:
- latency = latency_analysis.calibration.recommended
- print(f"Set latency to recommended {latency_analysis.calibration.recommended}")
+ final_latency = _get_final_latency(latency_analysis)
data_check_output = _check_data(
- input_path, output_path, input_version, latency, silent
+ input_path, output_path, input_version, final_latency, silent
)
if data_check_output is not None:
if data_check_output.passed:
@@ -1322,7 +1372,7 @@ def train(
input_version,
input_path,
output_path,
- latency,
+ final_latency,
epochs,
model_type,
Architecture(architecture),
@@ -1418,3 +1468,119 @@ def train(
validation_esr=validation_esr,
),
)
+
+
+class DataInputValidation(BaseModel):
+ passed: bool
+
+
+def validate_input(input_path) -> DataInputValidation:
+ """
+ :return: Could it be validated?
+ """
+ try:
+ _detect_input_version(input_path)
+ # succeeded...
+ return DataInputValidation(passed=True)
+ except _InputValidationError as e:
+ print(f"Input validation failed!\n\n{e}")
+ return DataInputValidation(passed=False)
+
+
+class _PyTorchDataSplitValidation(BaseModel):
+ """
+ :param msg: On exception, catch and assign. Otherwise None
+ """
+
+ passed: bool
+ msg: Optional[str]
+
+
+class _PyTorchDataValidation(BaseModel):
+ passed: bool
+ train: _PyTorchDataSplitValidation # cf Split.TRAIN
+ validation: _PyTorchDataSplitValidation # Split.VALIDATION
+
+
+class DataValidationOutput(BaseModel):
+ passed: bool
+ input_version: str
+ latency: metadata.Latency
+ checks: metadata.DataChecks
+ pytorch: _PyTorchDataValidation
+
+
+def validate_data(
+ input_path: Path,
+ output_path: Path,
+ user_latency: Optional[int],
+ num_output_samples_per_datum: int = _NY_DEFAULT,
+):
+ """
+ Just do the checks to make sure that the data are ok.
+
+ * Version identification
+ * Latency calibration
+ * Other checks
+ """
+ passed = True # Until proven otherwise
+
+ # Data version ID
+ input_version, strong_match = _detect_input_version(input_path)
+
+ # Latency analysis
+ latency_analysis = _analyze_latency(
+ user_latency, input_version, input_path, output_path, silent=True
+ )
+ if latency_analysis.manual is None and any(
+ val for val in latency_analysis.calibration.warnings.model_dump().values()
+ ):
+ passed = False
+ final_latency = _get_final_latency(latency_analysis)
+
+ # Other data checks based on input file version
+ data_checks = _check_data(
+ input_path,
+ output_path,
+ input_version,
+ latency_analysis.calibration.recommended,
+ silent=True,
+ )
+ passed = passed and data_checks.passed
+
+ # Finally, try to make the PyTorch Dataset objects and note any failures:
+ data_config = _get_data_config(
+ input_version=input_version,
+ input_path=input_path,
+ output_path=output_path,
+ ny=num_output_samples_per_datum,
+ latency=final_latency,
+ )
+ # HACK this should depend on the model that's going to be used, but I think it will
+ # be unlikely to make a difference. Still, would be nice to fix.
+ data_config["common"]["nx"] = 4096
+
+ pytorch_data_split_validation_dict: Dict[str, _PyTorchDataSplitValidation] = {}
+ for split in Split:
+ try:
+ init_dataset(data_config, split)
+ pytorch_data_split_validation_dict[split.value] = (
+ _PyTorchDataSplitValidation(passed=True, msg=None)
+ )
+ except DataError as e:
+ pytorch_data_split_validation_dict[split.value] = (
+ _PyTorchDataSplitValidation(passed=False, msg=str(e))
+ )
+ pytorch_data_validation = _PyTorchDataValidation(
+ passed=all(v.passed for v in pytorch_data_split_validation_dict.values()),
+ **pytorch_data_split_validation_dict,
+ )
+ passed = passed and pytorch_data_validation.passed
+
+ return DataValidationOutput(
+ passed=passed,
+ input_version=str(input_version),
+ latency=latency_analysis,
+ checks=data_checks,
+ pytorch=pytorch_data_validation,
+ )
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -39,11 +39,13 @@ try: # 3rd-party and 1st-party imports
import torch
from nam import __version__
+ from nam.data import Split
from nam.train import core
from nam.train.gui._resources import settings
from nam.models.metadata import GearType, UserMetadata, ToneType
# Ok private access here--this is technically allowed access
+ from nam.train import metadata
from nam.train._names import INPUT_BASENAMES, LATEST_VERSION
from nam.train.metadata import TRAINING_KEY
@@ -115,6 +117,7 @@ class _PathButton(object):
path_key: settings.PathKey,
hooks: Optional[Sequence[Callable[[], None]]] = None,
color_when_not_set: str = "#EF0000", # Darker red
+ color_when_set: str = "systemTextColor",
default: Optional[Path] = None,
):
"""
@@ -132,7 +135,6 @@ class _PathButton(object):
text=button_text,
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=self._set_val,
)
self._widgets["button"].pack(side=tk.LEFT)
@@ -140,13 +142,13 @@ class _PathButton(object):
self._frame,
width=_TEXT_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
bg=None,
anchor="w",
)
self._widgets["label"].pack(side=tk.LEFT)
self._hooks = hooks
self._color_when_not_set = color_when_not_set
+ self._color_when_set = color_when_set
self._set_text()
def __setitem__(self, key, val):
@@ -172,7 +174,7 @@ class _PathButton(object):
else:
val = self.val
val = val[0] if isinstance(val, tuple) and len(val) == 1 else val
- self._widgets["label"]["fg"] = "black"
+ self._widgets["label"]["fg"] = self._color_when_set
self._widgets["label"][
"text"
] = f"{self._button_text.capitalize()} set to {val}"
@@ -212,7 +214,6 @@ class _InputPathButton(_PathButton):
text="Download input file",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=self._download_input_file,
)
self._widgets["button_download_input"].pack(side=tk.RIGHT)
@@ -252,7 +253,6 @@ class _ClearablePathButton(_PathButton):
text="Clear",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=self._clear_path,
)
self._widgets["button_clear"].pack(side=tk.RIGHT)
@@ -270,7 +270,6 @@ class _CheckboxKeys(Enum):
FIT_CAB = "fit_cab"
SILENT_TRAINING = "silent_training"
SAVE_PLOT = "save_plot"
- IGNORE_DATA_CHECKS = "ignore_data_checks"
class _TopLevelWithOk(tk.Toplevel):
@@ -295,7 +294,42 @@ class _TopLevelWithOk(tk.Toplevel):
super().destroy()
-class _BasicModal(object):
+class _TopLevelWithYesNo(tk.Toplevel):
+ """
+ Toplevel holding functions for yes/no buttons to close
+ """
+
+ def __init__(
+ self,
+ on_yes: Callable[[None], None],
+ on_no: Callable[[None], None],
+ on_close: Optional[Callable[[None], None]],
+ resume_main: Callable[[None], None],
+ ):
+ """
+ :param on_yes: What to do when "Yes" button is pressed.
+ :param on_no: What to do when "No" button is pressed.
+ :param on_close: Do this regardless when closing (via yes/no/x) before
+ resuming.
+ """
+ super().__init__()
+ self._on_yes = on_yes
+ self._on_no = on_no
+ self._on_close = on_close
+ self._resume_main = resume_main
+
+ def destroy(self, pressed_yes: bool = False, pressed_no: bool = False):
+ if pressed_yes:
+ self._on_yes()
+ if pressed_no:
+ self._on_no()
+ if self._on_close is not None:
+ self._on_close()
+ self._resume_main()
+ super().destroy()
+
+
+class _OkModal(object):
"""
Message and OK button
"""
@@ -309,12 +343,49 @@ class _BasicModal(object):
text="Ok",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=lambda: self._root.destroy(pressed_ok=True),
)
self._ok.pack()
+class _YesNoModal(object):
+ """
+ Modal w/ yes/no buttons
+ """
+
+ def __init__(
+ self,
+ on_yes: Callable[[None], None],
+ on_no: Callable[[None], None],
+ resume_main,
+ msg: str,
+ on_close: Optional[Callable[[None], None]] = None,
+ label_kwargs: Optional[dict] = None,
+ ):
+ label_kwargs = {} if label_kwargs is None else label_kwargs
+ self._root = _TopLevelWithYesNo(on_yes, on_no, on_close, resume_main)
+ self._text = tk.Label(self._root, text=msg, **label_kwargs)
+ self._text.pack()
+ self._buttons_frame = tk.Frame(self._root)
+ self._buttons_frame.pack()
+ self._yes = tk.Button(
+ self._buttons_frame,
+ text="Yes",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ command=lambda: self._root.destroy(pressed_yes=True),
+ )
+ self._yes.pack(side=tk.LEFT)
+ self._no = tk.Button(
+ self._buttons_frame,
+ text="No",
+ width=_BUTTON_WIDTH,
+ height=_BUTTON_HEIGHT,
+ command=lambda: self._root.destroy(pressed_no=True),
+ )
+ self._no.pack(side=tk.RIGHT)
+
+
class _GUIWidgets(Enum):
INPUT_PATH = "input_path"
OUTPUT_PATH = "output_path"
@@ -373,7 +444,6 @@ class _GUI(object):
text="Metadata...",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=self._open_metadata,
)
self._widgets["metadata"].pack()
@@ -405,7 +475,6 @@ class _GUI(object):
text="Advanced options...",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=self._open_advanced_options,
)
self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack()
@@ -417,7 +486,6 @@ class _GUI(object):
text="Train",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=self._train,
)
self._widgets[_GUIWidgets.TRAIN].pack()
@@ -472,11 +540,6 @@ class _GUI(object):
False,
)
make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True)
- make_checkbox(
- _CheckboxKeys.IGNORE_DATA_CHECKS,
- "Ignore data quality checks (DO AT YOUR OWN RISK!)",
- False,
- )
# Grid them:
row = 1
@@ -513,10 +576,20 @@ class _GUI(object):
widget["state"] = state
def _train(self):
+ input_path = self._widgets[_GUIWidgets.INPUT_PATH].val
+ output_paths = self._widgets[_GUIWidgets.OUTPUT_PATH].val
+ # Validate all files before running:
+ success = self._validate_all_data(input_path, output_paths)
+ if success:
+ self._train2()
+
+ def _train2(self, ignore_checks=False):
+ input_path = self._widgets[_GUIWidgets.INPUT_PATH].val
+
# Advanced options:
num_epochs = self.advanced_options.num_epochs
architecture = self.advanced_options.architecture
- delay = self.advanced_options.latency
+ user_latency = self.advanced_options.latency
file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val
threshold_esr = self.advanced_options.threshold_esr
@@ -527,21 +600,20 @@ class _GUI(object):
lr_decay = _DEFAULT_LR_DECAY
batch_size = _DEFAULT_BATCH_SIZE
seed = 0
-
# Run it
for file in file_list:
- print("Now training {}".format(file))
+ print(f"Now training {file}")
basename = re.sub(r"\.wav$", "", file.split("/")[-1])
user_metadata = (
self.user_metadata if self.user_metadata_flag else UserMetadata()
)
train_output = core.train(
- self._widgets[_GUIWidgets.INPUT_PATH].val,
+ input_path,
file,
self._widgets[_GUIWidgets.TRAINING_DESTINATION].val,
epochs=num_epochs,
- latency=delay,
+ latency=user_latency,
architecture=architecture,
batch_size=batch_size,
lr=lr,
@@ -550,9 +622,7 @@ class _GUI(object):
silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(),
save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(),
modelname=basename,
- ignore_checks=self._checkboxes[
- _CheckboxKeys.IGNORE_DATA_CHECKS
- ].variable.get(),
+ ignore_checks=ignore_checks,
local=True,
fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(),
threshold_esr=threshold_esr,
@@ -570,7 +640,9 @@ class _GUI(object):
outdir,
basename=basename,
user_metadata=user_metadata,
- other_metadata={TRAINING_KEY: train_output.metadata.model_dump()},
+ other_metadata={
+ metadata.TRAINING_KEY: train_output.metadata.model_dump()
+ },
)
print("Done!")
@@ -578,6 +650,104 @@ class _GUI(object):
# the user re-visits the window and clicks "ok"
self.user_metadata_flag = False
+ def _validate_all_data(
+ self, input_path: Path, output_paths: Sequence[Path]
+ ) -> bool:
+ """
+ Validate all the data.
+ If something doesn't pass, then alert the user and ask them whether they
+ want to continue.
+
+ :return: whether we passed (NOTE: Training in spite of failure is
+ triggered by a modal that is produced on failure.)
+ """
+
+ def make_message_for_file(
+ output_path: str, validation_output: core.DataValidationOutput
+ ) -> str:
+ """
+ File and explain what's wrong with it.
+ """
+ # TODO put this closer to what it looks at, i.e. core.DataValidationOutput
+ msg = f" {Path(output_path).name}:\n" # They all have the same directory so
+ if validation_output.latency.manual is None:
+ if validation_output.latency.calibration.warnings.matches_lookahead:
+ msg += (
+ " * The calibrated latency is the maximum allowed. This is "
+ "probably because the latency calibration was triggered by noise.\n"
+ )
+ if validation_output.latency.calibration.warnings.disagreement_too_high:
+ msg += " * The calculated latencies are too different from each other.\n"
+ if not validation_output.checks.passed:
+ msg += " * A data check failed (TODO in more detail).\n"
+ if not validation_output.pytorch.passed:
+ msg += " * PyTorch data set errors:\n"
+ for split in Split:
+ split_validation = getattr(validation_output.pytorch, split.value)
+ if not split_validation.passed:
+ msg += f" * {split.value:10s}: {split_validation.msg}\n"
+ return msg
+
+ # Validate input
+ input_validation = core.validate_input(input_path)
+ if not input_validation.passed:
+ self._wait_while_func(
+ (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)),
+ f"Input file {input_path} is not recognized as a standardized input "
+ "file.\nTraining cannot proceed.",
+ )
+ return False
+
+ user_latency = self.advanced_options.latency
+ file_validation_outputs = {
+ output_path: core.validate_data(
+ input_path,
+ output_path,
+ user_latency,
+ )
+ for output_path in output_paths
+ }
+ if any(not fv.passed for fv in file_validation_outputs.values()):
+ msg = (
+ "The following output files failed checks:\n"
+ + "".join(
+ [
+ make_message_for_file(output_path, fv)
+ for output_path, fv in file_validation_outputs.items()
+ if not fv.passed
+ ]
+ )
+ + "\nIgnore and proceed?"
+ )
+
+ # Hacky to listen to the modal:
+ modal_listener = {"proceed": False, "still_open": True}
+
+ def on_yes():
+ modal_listener["proceed"] = True
+
+ def on_no():
+ modal_listener["proceed"] = False
+
+ def on_close():
+ if modal_listener["proceed"]:
+ self._train2(ignore_checks=True)
+
+ self._wait_while_func(
+ (
+ lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal(
+ on_yes, on_no, resume, *args, **kwargs
+ )
+ ),
+ on_yes=on_yes,
+ on_no=on_no,
+ msg=msg,
+ on_close=on_close,
+ label_kwargs={"anchor": "w"},
+ )
+ return False
+ return True
+
def _wait_while_func(self, func, *args, **kwargs):
"""
Disable this GUI while something happens.
@@ -631,12 +801,10 @@ class _LabeledOptionMenu(object):
self._choices = choices
height = _BUTTON_HEIGHT
bg = None
- fg = "black"
self._label = tk.Label(
frame,
width=_ADVANCED_OPTIONS_LEFT_WIDTH,
height=height,
- fg=fg,
bg=bg,
anchor="w",
text=label,
@@ -696,7 +864,6 @@ class _LabeledText(object):
frame,
width=left_width,
height=label_height,
- fg="black",
bg=None,
anchor="w",
text=label,
@@ -707,7 +874,6 @@ class _LabeledText(object):
frame,
width=right_width,
height=text_height,
- fg="black",
bg=None,
)
self._text.pack(side=tk.RIGHT)
@@ -779,7 +945,7 @@ class _AdvancedOptionsGUI(object):
type=_float_or_null,
)
- # "Ok": apply and destory
+ # "Ok": apply and destroy
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
self._button_ok = tk.Button(
@@ -787,7 +953,6 @@ class _AdvancedOptionsGUI(object):
text="Ok",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=lambda: self._root.destroy(pressed_ok=True),
)
self._button_ok.pack()
@@ -879,7 +1044,7 @@ class _UserMetadataGUI(object):
default=parent.user_metadata.tone_type,
)
- # "Ok": apply and destory
+ # "Ok": apply and destroy
self._frame_ok = tk.Frame(self._root)
self._frame_ok.pack()
self._button_ok = tk.Button(
@@ -887,7 +1052,6 @@ class _UserMetadataGUI(object):
text="Ok",
width=_BUTTON_WIDTH,
height=_BUTTON_HEIGHT,
- fg="black",
command=lambda: self._root.destroy(pressed_ok=True),
)
self._button_ok.pack()
diff --git a/nam/train/metadata.py b/nam/train/metadata.py
@@ -18,6 +18,7 @@ __all__ = [
"DataChecks",
"Latency",
"LatencyCalibration",
+ "LatencyCalibrationWarnings",
"Settings",
"TrainingMetadata",
"TRAINING_KEY",
@@ -36,11 +37,28 @@ class Settings(BaseModel):
ignore_checks: bool
+class LatencyCalibrationWarnings(BaseModel):
+ """
+ Things that aren't necessarily wrong with the latency calibration but are
+ worth looking into.
+
+ :param matches_lookahead: The calibrated latency is as far forard as
+ possible, i.e. the very first sample we looked at tripped the trigger.
+ That's probably not a coincidence but the trigger is too sensitive.
+ :param max_disagreement: The max disagreement between latency estimates. If
+ it's too large, then there's a risk that something was warong.
+ """
+
+ matches_lookahead: bool
+ disagreement_too_high: int
+
+
class LatencyCalibration(BaseModel):
algorithm_version: int
delays: List[int]
safety_factor: int
recommended: int
+ warnings: LatencyCalibrationWarnings
class Latency(BaseModel):
diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py
@@ -75,6 +75,10 @@ class TestExportable(object):
delays=[1, 3],
safety_factor=4,
recommended=-3,
+ warnings=train_metadata.LatencyCalibrationWarnings(
+ matches_lookahead=False,
+ disagreement_too_high=False,
+ ),
),
),
checks=train_metadata.DataChecks(version=4, passed=True),
@@ -105,6 +109,10 @@ class TestExportable(object):
delays=[1, 3],
safety_factor=4,
recommended=-3,
+ warnings=train_metadata.LatencyCalibrationWarnings(
+ matches_lookahead=False,
+ disagreement_too_high=False,
+ ),
),
),
checks=train_metadata.DataChecks(version=4, passed=True),