neural-amp-modeler

Neural network emulator for guitar amplifiers

commit 825b810545516ce4556ba4e6cbde9f86641cb37b
parent df619915ed021dd9c5ab233785493eda3960f219
Author: Steven Atkinson <[email protected]>
Date:   Mon, 20 May 2024 23:16:49 -0500

[FEATURE,GUI,BREAKING] Validate data before training (#425)

* Training metadata

* Optional kwarg training_metadata on export

* Save training metadata in .nam file, tests

* Refactor delay calibration to return more info in preparation for including it in metadata

* Got standardized training metadata all sorted out

* Fix bugs, add end-to-end test of core train

* Some cleanup

* Flake

* Validate data before starting training.
Diffstat:
M nam/data.py | 14 +++++++++++---
M nam/train/core.py | 234 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
M nam/train/gui/__init__.py | 230 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------
M nam/train/metadata.py | 18 ++++++++++++++++++
M tests/test_nam/test_models/test_exportable.py | 8 ++++++++
5 files changed, 434 insertions(+), 70 deletions(-)
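
The headline change is a pre-training validation pass: nam/train/core.py now exports validate_input() and validate_data() alongside train(). As a rough sketch (not part of the commit; the file paths and printed fields are only illustrative), the new entry points can be exercised like this before starting a run:

    from nam.train import core

    input_path = "input.wav"    # a standardized capture signal
    output_path = "output.wav"  # the corresponding reamped output

    # Step 1: is the input file recognized as a known standardized version?
    if core.validate_input(input_path).passed:
        # Step 2: latency calibration, data checks, and a dry run of the
        # PyTorch dataset construction -- no training happens here.
        result = core.validate_data(input_path, output_path, user_latency=None)
        print(result.passed)          # overall verdict
        print(result.input_version)   # detected input file version
        print(result.latency.calibration.recommended)
        print(result.pytorch.train.passed, result.pytorch.validation.passed)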

diff --git a/nam/data.py b/nam/data.py @@ -40,7 +40,15 @@ class WavInfo: rate: int -class AudioShapeMismatchError(ValueError): +class DataError(Exception): + """ + Parent class for all special exceptions raised by NAM data sets + """ + + pass + + +class AudioShapeMismatchError(ValueError, DataError): """ Exception where the shape (number of samples, number of channels) of two audio files don't match but were supposed to. @@ -191,7 +199,7 @@ def _interpolate_delay( ) -class XYError(ValueError): +class XYError(ValueError, DataError): """ Exceptions related to invalid x and y provided for data sets """ @@ -199,7 +207,7 @@ class XYError(ValueError): pass -class StartStopError(ValueError): +class StartStopError(ValueError, DataError): """ Exceptions related to invalid start and stop arguments """ diff --git a/nam/train/core.py b/nam/train/core.py @@ -25,7 +25,7 @@ from pydantic import BaseModel from pytorch_lightning.utilities.warnings import PossibleUserWarning from torch.utils.data import DataLoader -from ..data import Split, init_dataset, wav_to_np, wav_to_tensor +from ..data import DataError, Split, init_dataset, wav_to_np, wav_to_tensor from ..models import Model from ..models.exportable import Exportable from ..models.losses import esr @@ -34,10 +34,20 @@ from ..util import filter_warnings from ._version import PROTEUS_VERSION, Version from . import metadata -__all__ = ["train"] +__all__ = [ + "Architecture", + "DataValidationOutput", + "STANDARD_SAMPLE_RATE", + "TrainOutput", + "train", + "validate_data", + "validate_input", +] # Training using the simplified trainers in NAM is done at 48k. STANDARD_SAMPLE_RATE = 48_000.0 +# Default number of output samples per datum. +_NY_DEFAULT = 8192 class Architecture(Enum): @@ -47,6 +57,10 @@ class Architecture(Enum): NANO = "nano" +class _InputValidationError(ValueError): + pass + + def _detect_input_version(input_path) -> Tuple[Version, bool]: """ Check to see if the input matches any of the known inputs @@ -227,7 +241,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]: print("Falling back to weak-matching...") version = detect_weak(input_path) if version is None: - raise ValueError( + raise _InputValidationError( f"Input file at {input_path} cannot be recognized as any known version!" ) strong_match = False @@ -353,22 +367,37 @@ def _calibrate_latency_v_all( :param y: The output audio, in complete. """ - def report_any_delay_warnings(delays: Sequence[int]): + def report_any_latency_warnings( + delays: Sequence[int], + ) -> metadata.LatencyCalibrationWarnings: # Warnings associated with any single delay: + # "Lookahead warning": if the delay is equal to the lookahead, then it's + # probably an error. lookahead_warnings = [i for i, d in enumerate(delays, 1) if d == -lookahead] - if len(lookahead_warnings) > 0: + matches_lookahead = len(lookahead_warnings) > 0 + if matches_lookahead: print(_warn_lookaheads(lookahead_warnings)) # Ensemble warnings # If they're _really_ different, then something might be wrong. - if np.max(delays) - np.min(delays) >= 20: + max_disagreement_threshold = 20 + max_disagreement_too_high = ( + np.max(delays) - np.min(delays) >= max_disagreement_threshold + ) + if max_disagreement_too_high: print( - "WARNING: Delays are anomalously different from each other. If this model " - "turns out badly, then you might need to provide the delay manually." + "WARNING: Latencies are anomalously different from each other (more " + f"than {max_disagreement_threshold} samples). 
If this model turns out " + "badly, then you might need to provide the latency manually." ) + return metadata.LatencyCalibrationWarnings( + matches_lookahead=matches_lookahead, + disagreement_too_high=max_disagreement_too_high, + ) + lookahead = 1_000 lookback = 10_000 # Calibrate the trigger: @@ -422,7 +451,7 @@ def _calibrate_latency_v_all( print("Delays:") for i_rel, d in enumerate(delays, 1): print(f" Blip {i_rel:2d}: {d}") - report_any_delay_warnings(delays) + warnings = report_any_latency_warnings(delays) delay_post_safety_factor = int(np.min(delays)) - safety_factor print( @@ -434,6 +463,7 @@ def _calibrate_latency_v_all( delays=delays, safety_factor=safety_factor, recommended=delay_post_safety_factor, + warnings=warnings, ) @@ -876,20 +906,9 @@ _CAB_MRSTFT_PRE_EMPH_WEIGHT = 2.0e-4 _CAB_MRSTFT_PRE_EMPH_COEF = 0.85 -def _get_configs( - input_version: Version, - input_path: str, - output_path: str, - delay: int, - epochs: int, - model_type: str, - architecture: Architecture, - ny: int, - lr: float, - lr_decay: float, - batch_size: int, - fit_cab: bool, -): +def _get_data_config( + input_version: Version, input_path: Path, output_path: Path, ny: int, latency: int +) -> dict: def get_kwargs(data_info: _DataInfo): if data_info.major_version == 1: train_val_split = data_info.validation_start @@ -941,9 +960,34 @@ def _get_configs( "common": { "x_path": input_path, "y_path": output_path, - "delay": delay, + "delay": latency, }, } + return data_config + + +def _get_configs( + input_version: Version, + input_path: str, + output_path: str, + latency: int, + epochs: int, + model_type: str, + architecture: Architecture, + ny: int, + lr: float, + lr_decay: float, + batch_size: int, + fit_cab: bool, +): + + data_config = _get_data_config( + input_version=input_version, + input_path=input_path, + output_path=output_path, + ny=ny, + latency=latency, + ) if model_type == "WaveNet": model_config = { @@ -1228,6 +1272,16 @@ class TrainOutput(NamedTuple): metadata: metadata.TrainingMetadata +def _get_final_latency(latency_analysis: metadata.Latency) -> int: + if latency_analysis.manual is not None: + latency = latency_analysis.manual + print(f"Latency provided as {latency_analysis.manual}; override calibration") + else: + latency = latency_analysis.calibration.recommended + print(f"Set latency to recommended {latency_analysis.calibration.recommended}") + return latency + + def train( input_path: str, output_path: str, @@ -1239,7 +1293,7 @@ def train( model_type: str = "WaveNet", architecture: Union[Architecture, str] = Architecture.STANDARD, batch_size: int = 16, - ny: int = 8192, + ny: int = _NY_DEFAULT, lr=0.004, lr_decay=0.007, seed: Optional[int] = 0, @@ -1254,6 +1308,7 @@ def train( fast_dev_run: Union[bool, int] = False, ) -> Optional[TrainOutput]: """ + :param lr_decay: =1-gamma for Exponential learning rate decay. :param threshold_esr: Stop training if ESR is better than this. Ignore if `None`. :param fast_dev_run: One-step training, used for tests. 
""" @@ -1276,17 +1331,12 @@ def train( user_latency = parse_user_latency(delay, latency) latency_analysis = _analyze_latency( - latency, input_version, input_path, output_path, silent=silent + user_latency, input_version, input_path, output_path, silent=silent ) - if latency_analysis.manual is not None: - latency = latency_analysis.manual - print(f"Latency provided as {user_latency}; override calibration") - else: - latency = latency_analysis.calibration.recommended - print(f"Set latency to recommended {latency_analysis.calibration.recommended}") + final_latency = _get_final_latency(latency_analysis) data_check_output = _check_data( - input_path, output_path, input_version, latency, silent + input_path, output_path, input_version, final_latency, silent ) if data_check_output is not None: if data_check_output.passed: @@ -1322,7 +1372,7 @@ def train( input_version, input_path, output_path, - latency, + final_latency, epochs, model_type, Architecture(architecture), @@ -1418,3 +1468,119 @@ def train( validation_esr=validation_esr, ), ) + + +class DataInputValidation(BaseModel): + passed: bool + + +def validate_input(input_path) -> DataInputValidation: + """ + :return: Could it be validated? + """ + try: + _detect_input_version(input_path) + # succeeded... + return DataInputValidation(passed=True) + except _InputValidationError as e: + print(f"Input validation failed!\n\n{e}") + return DataInputValidation(passed=False) + + +class _PyTorchDataSplitValidation(BaseModel): + """ + :param msg: On exception, catch and assign. Otherwise None + """ + + passed: bool + msg: Optional[str] + + +class _PyTorchDataValidation(BaseModel): + passed: bool + train: _PyTorchDataSplitValidation # cf Split.TRAIN + validation: _PyTorchDataSplitValidation # Split.VALIDATION + + +class DataValidationOutput(BaseModel): + passed: bool + input_version: str + latency: metadata.Latency + checks: metadata.DataChecks + pytorch: _PyTorchDataValidation + + +def validate_data( + input_path: Path, + output_path: Path, + user_latency: Optional[int], + num_output_samples_per_datum: int = _NY_DEFAULT, +): + """ + Just do the checks to make sure that the data are ok. + + * Version identification + * Latency calibration + * Other checks + """ + passed = True # Until proven otherwise + + # Data version ID + input_version, strong_match = _detect_input_version(input_path) + + # Latency analysis + latency_analysis = _analyze_latency( + user_latency, input_version, input_path, output_path, silent=True + ) + if latency_analysis.manual is None and any( + val for val in latency_analysis.calibration.warnings.model_dump().values() + ): + passed = False + final_latency = _get_final_latency(latency_analysis) + + # Other data checks based on input file version + data_checks = _check_data( + input_path, + output_path, + input_version, + latency_analysis.calibration.recommended, + silent=True, + ) + passed = passed and data_checks.passed + + # Finally, try to make the PyTorch Dataset objects and note any failures: + data_config = _get_data_config( + input_version=input_version, + input_path=input_path, + output_path=output_path, + ny=num_output_samples_per_datum, + latency=final_latency, + ) + # HACK this should depend on the model that's going to be used, but I think it will + # be unlikely to make a difference. Still, would be nice to fix. 
+ data_config["common"]["nx"] = 4096 + + pytorch_data_split_validation_dict: Dict[str, _PyTorchDataSplitValidation] = {} + for split in Split: + try: + init_dataset(data_config, split) + pytorch_data_split_validation_dict[split.value] = ( + _PyTorchDataSplitValidation(passed=True, msg=None) + ) + except DataError as e: + pytorch_data_split_validation_dict[split.value] = ( + _PyTorchDataSplitValidation(passed=False, msg=str(e)) + ) + pytorch_data_validation = _PyTorchDataValidation( + passed=all(v.passed for v in pytorch_data_split_validation_dict.values()), + **pytorch_data_split_validation_dict, + ) + passed = passed and pytorch_data_validation.passed + + return DataValidationOutput( + passed=passed, + input_version=str(input_version), + latency=latency_analysis, + checks=data_checks, + pytorch=pytorch_data_validation, + ) diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py @@ -39,11 +39,13 @@ try: # 3rd-party and 1st-party imports import torch from nam import __version__ + from nam.data import Split from nam.train import core from nam.train.gui._resources import settings from nam.models.metadata import GearType, UserMetadata, ToneType # Ok private access here--this is technically allowed access + from nam.train import metadata from nam.train._names import INPUT_BASENAMES, LATEST_VERSION from nam.train.metadata import TRAINING_KEY @@ -115,6 +117,7 @@ class _PathButton(object): path_key: settings.PathKey, hooks: Optional[Sequence[Callable[[], None]]] = None, color_when_not_set: str = "#EF0000", # Darker red + color_when_set: str = "systemTextColor", default: Optional[Path] = None, ): """ @@ -132,7 +135,6 @@ class _PathButton(object): text=button_text, width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=self._set_val, ) self._widgets["button"].pack(side=tk.LEFT) @@ -140,13 +142,13 @@ class _PathButton(object): self._frame, width=_TEXT_WIDTH, height=_BUTTON_HEIGHT, - fg="black", bg=None, anchor="w", ) self._widgets["label"].pack(side=tk.LEFT) self._hooks = hooks self._color_when_not_set = color_when_not_set + self._color_when_set = color_when_set self._set_text() def __setitem__(self, key, val): @@ -172,7 +174,7 @@ class _PathButton(object): else: val = self.val val = val[0] if isinstance(val, tuple) and len(val) == 1 else val - self._widgets["label"]["fg"] = "black" + self._widgets["label"]["fg"] = self._color_when_set self._widgets["label"][ "text" ] = f"{self._button_text.capitalize()} set to {val}" @@ -212,7 +214,6 @@ class _InputPathButton(_PathButton): text="Download input file", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=self._download_input_file, ) self._widgets["button_download_input"].pack(side=tk.RIGHT) @@ -252,7 +253,6 @@ class _ClearablePathButton(_PathButton): text="Clear", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=self._clear_path, ) self._widgets["button_clear"].pack(side=tk.RIGHT) @@ -270,7 +270,6 @@ class _CheckboxKeys(Enum): FIT_CAB = "fit_cab" SILENT_TRAINING = "silent_training" SAVE_PLOT = "save_plot" - IGNORE_DATA_CHECKS = "ignore_data_checks" class _TopLevelWithOk(tk.Toplevel): @@ -295,7 +294,42 @@ class _TopLevelWithOk(tk.Toplevel): super().destroy() -class _BasicModal(object): +class _TopLevelWithYesNo(tk.Toplevel): + """ + Toplevel holding functions for yes/no buttons to close + """ + + def __init__( + self, + on_yes: Callable[[None], None], + on_no: Callable[[None], None], + on_close: Optional[Callable[[None], None]], + resume_main: Callable[[None], None], + ): + """ + :param on_yes: What 
to do when "Yes" button is pressed. + :param on_no: What to do when "No" button is pressed. + :param on_close: Do this regardless when closing (via yes/no/x) before + resuming. + """ + super().__init__() + self._on_yes = on_yes + self._on_no = on_no + self._on_close = on_close + self._resume_main = resume_main + + def destroy(self, pressed_yes: bool = False, pressed_no: bool = False): + if pressed_yes: + self._on_yes() + if pressed_no: + self._on_no() + if self._on_close is not None: + self._on_close() + self._resume_main() + super().destroy() + + +class _OkModal(object): """ Message and OK button """ @@ -309,12 +343,49 @@ class _BasicModal(object): text="Ok", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=lambda: self._root.destroy(pressed_ok=True), ) self._ok.pack() +class _YesNoModal(object): + """ + Modal w/ yes/no buttons + """ + + def __init__( + self, + on_yes: Callable[[None], None], + on_no: Callable[[None], None], + resume_main, + msg: str, + on_close: Optional[Callable[[None], None]] = None, + label_kwargs: Optional[dict] = None, + ): + label_kwargs = {} if label_kwargs is None else label_kwargs + self._root = _TopLevelWithYesNo(on_yes, on_no, on_close, resume_main) + self._text = tk.Label(self._root, text=msg, **label_kwargs) + self._text.pack() + self._buttons_frame = tk.Frame(self._root) + self._buttons_frame.pack() + self._yes = tk.Button( + self._buttons_frame, + text="Yes", + width=_BUTTON_WIDTH, + height=_BUTTON_HEIGHT, + command=lambda: self._root.destroy(pressed_yes=True), + ) + self._yes.pack(side=tk.LEFT) + self._no = tk.Button( + self._buttons_frame, + text="No", + width=_BUTTON_WIDTH, + height=_BUTTON_HEIGHT, + command=lambda: self._root.destroy(pressed_no=True), + ) + self._no.pack(side=tk.RIGHT) + + class _GUIWidgets(Enum): INPUT_PATH = "input_path" OUTPUT_PATH = "output_path" @@ -373,7 +444,6 @@ class _GUI(object): text="Metadata...", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=self._open_metadata, ) self._widgets["metadata"].pack() @@ -405,7 +475,6 @@ class _GUI(object): text="Advanced options...", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=self._open_advanced_options, ) self._widgets[_GUIWidgets.ADVANCED_OPTIONS].pack() @@ -417,7 +486,6 @@ class _GUI(object): text="Train", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=self._train, ) self._widgets[_GUIWidgets.TRAIN].pack() @@ -472,11 +540,6 @@ class _GUI(object): False, ) make_checkbox(_CheckboxKeys.SAVE_PLOT, "Save ESR plot automatically", True) - make_checkbox( - _CheckboxKeys.IGNORE_DATA_CHECKS, - "Ignore data quality checks (DO AT YOUR OWN RISK!)", - False, - ) # Grid them: row = 1 @@ -513,10 +576,20 @@ class _GUI(object): widget["state"] = state def _train(self): + input_path = self._widgets[_GUIWidgets.INPUT_PATH].val + output_paths = self._widgets[_GUIWidgets.OUTPUT_PATH].val + # Validate all files before running: + success = self._validate_all_data(input_path, output_paths) + if success: + self._train2() + + def _train2(self, ignore_checks=False): + input_path = self._widgets[_GUIWidgets.INPUT_PATH].val + # Advanced options: num_epochs = self.advanced_options.num_epochs architecture = self.advanced_options.architecture - delay = self.advanced_options.latency + user_latency = self.advanced_options.latency file_list = self._widgets[_GUIWidgets.OUTPUT_PATH].val threshold_esr = self.advanced_options.threshold_esr @@ -527,21 +600,20 @@ class _GUI(object): lr_decay = _DEFAULT_LR_DECAY batch_size = _DEFAULT_BATCH_SIZE seed = 0 - 
# Run it for file in file_list: - print("Now training {}".format(file)) + print(f"Now training {file}") basename = re.sub(r"\.wav$", "", file.split("/")[-1]) user_metadata = ( self.user_metadata if self.user_metadata_flag else UserMetadata() ) train_output = core.train( - self._widgets[_GUIWidgets.INPUT_PATH].val, + input_path, file, self._widgets[_GUIWidgets.TRAINING_DESTINATION].val, epochs=num_epochs, - latency=delay, + latency=user_latency, architecture=architecture, batch_size=batch_size, lr=lr, @@ -550,9 +622,7 @@ class _GUI(object): silent=self._checkboxes[_CheckboxKeys.SILENT_TRAINING].variable.get(), save_plot=self._checkboxes[_CheckboxKeys.SAVE_PLOT].variable.get(), modelname=basename, - ignore_checks=self._checkboxes[ - _CheckboxKeys.IGNORE_DATA_CHECKS - ].variable.get(), + ignore_checks=ignore_checks, local=True, fit_cab=self._checkboxes[_CheckboxKeys.FIT_CAB].variable.get(), threshold_esr=threshold_esr, @@ -570,7 +640,9 @@ class _GUI(object): outdir, basename=basename, user_metadata=user_metadata, - other_metadata={TRAINING_KEY: train_output.metadata.model_dump()}, + other_metadata={ + metadata.TRAINING_KEY: train_output.metadata.model_dump() + }, ) print("Done!") @@ -578,6 +650,104 @@ class _GUI(object): # the user re-visits the window and clicks "ok" self.user_metadata_flag = False + def _validate_all_data( + self, input_path: Path, output_paths: Sequence[Path] + ) -> bool: + """ + Validate all the data. + If something doesn't pass, then alert the user and ask them whether they + want to continue. + + :return: whether we passed (NOTE: Training in spite of failure is + triggered by a modal that is produced on failure.) + """ + + def make_message_for_file( + output_path: str, validation_output: core.DataValidationOutput + ) -> str: + """ + File and explain what's wrong with it. + """ + # TODO put this closer to what it looks at, i.e. core.DataValidationOutput + msg = f" {Path(output_path).name}:\n" # They all have the same directory so + if validation_output.latency.manual is None: + if validation_output.latency.calibration.warnings.matches_lookahead: + msg += ( + " * The calibrated latency is the maximum allowed. 
This is " + "probably because the latency calibration was triggered by noise.\n" + ) + if validation_output.latency.calibration.warnings.disagreement_too_high: + msg += " * The calculated latencies are too different from each other.\n" + if not validation_output.checks.passed: + msg += " * A data check failed (TODO in more detail).\n" + if not validation_output.pytorch.passed: + msg += " * PyTorch data set errors:\n" + for split in Split: + split_validation = getattr(validation_output.pytorch, split.value) + if not split_validation.passed: + msg += f" * {split.value:10s}: {split_validation.msg}\n" + return msg + + # Validate input + input_validation = core.validate_input(input_path) + if not input_validation.passed: + self._wait_while_func( + (lambda resume, *args, **kwargs: _OkModal(resume, *args, **kwargs)), + f"Input file {input_path} is not recognized as a standardized input " + "file.\nTraining cannot proceed.", + ) + return False + + user_latency = self.advanced_options.latency + file_validation_outputs = { + output_path: core.validate_data( + input_path, + output_path, + user_latency, + ) + for output_path in output_paths + } + if any(not fv.passed for fv in file_validation_outputs.values()): + msg = ( + "The following output files failed checks:\n" + + "".join( + [ + make_message_for_file(output_path, fv) + for output_path, fv in file_validation_outputs.items() + if not fv.passed + ] + ) + + "\nIgnore and proceed?" + ) + + # Hacky to listen to the modal: + modal_listener = {"proceed": False, "still_open": True} + + def on_yes(): + modal_listener["proceed"] = True + + def on_no(): + modal_listener["proceed"] = False + + def on_close(): + if modal_listener["proceed"]: + self._train2(ignore_checks=True) + + self._wait_while_func( + ( + lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal( + on_yes, on_no, resume, *args, **kwargs + ) + ), + on_yes=on_yes, + on_no=on_no, + msg=msg, + on_close=on_close, + label_kwargs={"anchor": "w"}, + ) + return False + return True + def _wait_while_func(self, func, *args, **kwargs): """ Disable this GUI while something happens. 
@@ -631,12 +801,10 @@ class _LabeledOptionMenu(object): self._choices = choices height = _BUTTON_HEIGHT bg = None - fg = "black" self._label = tk.Label( frame, width=_ADVANCED_OPTIONS_LEFT_WIDTH, height=height, - fg=fg, bg=bg, anchor="w", text=label, @@ -696,7 +864,6 @@ class _LabeledText(object): frame, width=left_width, height=label_height, - fg="black", bg=None, anchor="w", text=label, @@ -707,7 +874,6 @@ class _LabeledText(object): frame, width=right_width, height=text_height, - fg="black", bg=None, ) self._text.pack(side=tk.RIGHT) @@ -779,7 +945,7 @@ class _AdvancedOptionsGUI(object): type=_float_or_null, ) - # "Ok": apply and destory + # "Ok": apply and destroy self._frame_ok = tk.Frame(self._root) self._frame_ok.pack() self._button_ok = tk.Button( @@ -787,7 +953,6 @@ class _AdvancedOptionsGUI(object): text="Ok", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=lambda: self._root.destroy(pressed_ok=True), ) self._button_ok.pack() @@ -879,7 +1044,7 @@ class _UserMetadataGUI(object): default=parent.user_metadata.tone_type, ) - # "Ok": apply and destory + # "Ok": apply and destroy self._frame_ok = tk.Frame(self._root) self._frame_ok.pack() self._button_ok = tk.Button( @@ -887,7 +1052,6 @@ class _UserMetadataGUI(object): text="Ok", width=_BUTTON_WIDTH, height=_BUTTON_HEIGHT, - fg="black", command=lambda: self._root.destroy(pressed_ok=True), ) self._button_ok.pack() diff --git a/nam/train/metadata.py b/nam/train/metadata.py @@ -18,6 +18,7 @@ __all__ = [ "DataChecks", "Latency", "LatencyCalibration", + "LatencyCalibrationWarnings", "Settings", "TrainingMetadata", "TRAINING_KEY", @@ -36,11 +37,28 @@ class Settings(BaseModel): ignore_checks: bool +class LatencyCalibrationWarnings(BaseModel): + """ + Things that aren't necessarily wrong with the latency calibration but are + worth looking into. + + :param matches_lookahead: The calibrated latency is as far forard as + possible, i.e. the very first sample we looked at tripped the trigger. + That's probably not a coincidence but the trigger is too sensitive. + :param max_disagreement: The max disagreement between latency estimates. If + it's too large, then there's a risk that something was warong. + """ + + matches_lookahead: bool + disagreement_too_high: int + + class LatencyCalibration(BaseModel): algorithm_version: int delays: List[int] safety_factor: int recommended: int + warnings: LatencyCalibrationWarnings class Latency(BaseModel): diff --git a/tests/test_nam/test_models/test_exportable.py b/tests/test_nam/test_models/test_exportable.py @@ -75,6 +75,10 @@ class TestExportable(object): delays=[1, 3], safety_factor=4, recommended=-3, + warnings=train_metadata.LatencyCalibrationWarnings( + matches_lookahead=False, + disagreement_too_high=False, + ), ), ), checks=train_metadata.DataChecks(version=4, passed=True), @@ -105,6 +109,10 @@ class TestExportable(object): delays=[1, 3], safety_factor=4, recommended=-3, + warnings=train_metadata.LatencyCalibrationWarnings( + matches_lookahead=False, + disagreement_too_high=False, + ), ), ), checks=train_metadata.DataChecks(version=4, passed=True),
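
One concretely breaking piece of the change: the LatencyCalibration model in nam/train/metadata.py gains a required warnings field, which is why the exportable-model tests above now construct it explicitly. A minimal sketch of building the new metadata by hand, with illustrative values:

    from nam.train import metadata

    calibration = metadata.LatencyCalibration(
        algorithm_version=1,  # illustrative
        delays=[1, 3],
        safety_factor=4,
        recommended=-3,
        warnings=metadata.LatencyCalibrationWarnings(
            matches_lookahead=False,
            disagreement_too_high=False,
        ),
    )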