neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 64f79fdc4c68b8735b879f495d4e5099cd39aa6d
parent 0c5112907742f1e65a4fd6ff22656a38253505d0
Author: Steven Atkinson <[email protected]>
Date:   Sat, 28 Sep 2024 10:08:49 -0700

Data validation: Check audio length and sample rates for matches (#474)

* Raise DataError instead of ValueError

* Add sample rate and length validation for data (require exact match for now)

* Fix docstring

* Define critical checks

* Don't allow ignoring critical checks
Diffstat:
Mnam/data.py | 3++-
Mnam/train/core.py | 98+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mnam/train/gui/__init__.py | 99+++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------
3 files changed, 164 insertions(+), 36 deletions(-)

diff --git a/nam/data.py b/nam/data.py @@ -431,7 +431,8 @@ class Dataset(AbstractDataset, InitializableFromConfig): msg += ( f"\n * The output is {sample_to_time(y_samples, sample_rate)} long" ) - raise ValueError(msg) + msg += f"\n\nOriginal exception:\n{e}" + raise DataError(msg) return {"x": x, "y": y, "sample_rate": sample_rate, **config} @classmethod diff --git a/nam/train/core.py b/nam/train/core.py @@ -1327,6 +1327,25 @@ def train( if seed is not None: torch.manual_seed(seed) + # HACK: We need to check the sample rates and lengths of the audio here or else + # It will look like a bad self-ESR (Issue 473) + # Can move this into the "v3 checks" once the others are deprecated. + # And honestly remake this whole thing as a data processing pipeline. + sample_rate_validation = _check_audio_sample_rates(input_path, output_path) + if not sample_rate_validation.passed: + raise ValueError( + "Different sample rates detected for input " + f"({sample_rate_validation.input}) and output " + f"({sample_rate_validation.output}) audio!" + ) + length_validation = _check_audio_lengths(input_path, output_path) + if not length_validation.passed: + raise ValueError( + "Your recording differs in length from the input file by " + f"{length_validation.delta_seconds:.2f} seconds. Check your reamp " + "in your DAW and ensure that they are the same length." 
+ ) + if input_version is None: input_version, strong_match = _detect_input_version(input_path) @@ -1501,14 +1520,79 @@ class _PyTorchDataValidation(BaseModel): validation: _PyTorchDataSplitValidation # Split.VALIDATION +class _SampleRateValidation(BaseModel): + passed: bool + input: int + output: int + + +class _LengthValidation(BaseModel): + passed: bool + delta_seconds: float + + class DataValidationOutput(BaseModel): passed: bool + passed_critical: bool + sample_rate: _SampleRateValidation + length: _LengthValidation input_version: str latency: metadata.Latency checks: metadata.DataChecks pytorch: _PyTorchDataValidation +def _check_audio_sample_rates( + input_path: Path, + output_path: Path, +) -> _SampleRateValidation: + _, x_info = wav_to_np(input_path, info=True) + _, y_info = wav_to_np(output_path, info=True) + + return _SampleRateValidation( + passed=x_info.rate == y_info.rate, + input=x_info.rate, + output=y_info.rate, + ) + + +def _check_audio_lengths( + input_path: Path, + output_path: Path, + max_under_seconds: Optional[float] = 0.0, + max_over_seconds: Optional[float] = 0.0, +) -> _LengthValidation: + """ + Check that the input and output have the right lengths compared to each + other. + + :param input_path: Path to input audio + :param output_path: Path to output audio + :param max_under_seconds: If not None, the maximum amount by which the + output can be shorter than the input. Should be non-negative i.e. a + value of 1.0 means that the output can't be more than a second shorter + than the input. + :param max_over_seconds: If not None, the maximum amount by which the + output can be longer than the input. Should be non-negative i.e. a + value of 1.0 means that the output can't be more than a second longer + than the input. 
+ """ + x, x_info = wav_to_np(input_path, info=True) + y, y_info = wav_to_np(output_path, info=True) + + length_input = len(x) / x_info.rate + length_output = len(y) / y_info.rate + delta_seconds = length_output - length_input + + passed = True + if max_under_seconds is not None and delta_seconds < -max_under_seconds: + passed = False + if max_over_seconds is not None and delta_seconds > max_under_seconds: + passed = False + + return _LengthValidation(passed=passed, delta_seconds=delta_seconds) + + def validate_data( input_path: Path, output_path: Path, @@ -1522,7 +1606,17 @@ def validate_data( * Latency calibration * Other checks """ + print("Validating data...") passed = True # Until proven otherwise + passed_critical = True # These can't be ignored + + sample_rate_validation = _check_audio_sample_rates(input_path, output_path) + passed = passed and sample_rate_validation.passed + passed_critical = passed_critical and sample_rate_validation.passed + + length_validation = _check_audio_lengths(input_path, output_path) + passed = passed and length_validation.passed + passed_critical = passed_critical and length_validation.passed # Data version ID input_version, strong_match = _detect_input_version(input_path) @@ -1575,9 +1669,13 @@ def validate_data( **pytorch_data_split_validation_dict, ) passed = passed and pytorch_data_validation.passed + passed_critical = passed_critical and pytorch_data_validation.passed return DataValidationOutput( passed=passed, + passed_critical=passed_critical, + sample_rate=sample_rate_validation, + length=length_validation, input_version=str(input_version), latency=latency_analysis, checks=data_checks, diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py @@ -327,9 +327,11 @@ class _OkModal(object): Message and OK button """ - def __init__(self, resume_main, msg: str): + def __init__(self, resume_main, msg: str, label_kwargs: Optional[dict] = None): + label_kwargs = {} if label_kwargs is None else label_kwargs + self._root = 
_TopLevelWithOk((lambda: None), resume_main) - self._text = tk.Label(self._root, text=msg) + self._text = tk.Label(self._root, text=msg, **label_kwargs) self._text.pack() self._ok = tk.Button( self._root, @@ -764,12 +766,32 @@ class GUI(object): output_path: str, validation_output: core.DataValidationOutput ) -> str: """ - File and explain what's wrong with it. + State the file and explain what's wrong with it. """ # TODO put this closer to what it looks at, i.e. core.DataValidationOutput msg = ( f"\t{Path(output_path).name}:\n" # They all have the same directory so ) + if not validation_output.sample_rate.passed: + msg += ( + "\t\t* There are different sample rates for the input (" + f"{validation_output.sample_rate.input}) and output (" + f"{validation_output.sample_rate.output}).\n" + ) + if not validation_output.length.passed: + msg += ( + "\t\t* The input and output audio files are too different in length" + ) + if validation_output.length.delta_seconds > 0: + msg += ( + f" (the output is {validation_output.length.delta_seconds:.2f} " + "seconds longer than the input)\n" + ) + else: + msg += ( + f" (the output is {-validation_output.length.delta_seconds:.2f}" + " seconds shorter than the input)\n" + ) if validation_output.latency.manual is None: if validation_output.latency.calibration.warnings.matches_lookahead: msg += ( @@ -808,44 +830,51 @@ class GUI(object): for output_path in output_paths } if any(not fv.passed for fv in file_validation_outputs.values()): - msg = ( - "The following output files failed checks:\n" - + "".join( - [ - make_message_for_file(output_path, fv) - for output_path, fv in file_validation_outputs.items() - if not fv.passed - ] - ) - + "\nIgnore and proceed?" 
+ msg = "The following output files failed checks:\n" + "".join( + [ + make_message_for_file(output_path, fv) + for output_path, fv in file_validation_outputs.items() + if not fv.passed + ] ) + if all(fv.passed_critical for fv in file_validation_outputs.values()): + msg += "\nIgnore and proceed?" - # Hacky to listen to the modal: - modal_listener = {"proceed": False, "still_open": True} + # Hacky to listen to the modal: + modal_listener = {"proceed": False, "still_open": True} - def on_yes(): - modal_listener["proceed"] = True + def on_yes(): + modal_listener["proceed"] = True - def on_no(): - modal_listener["proceed"] = False + def on_no(): + modal_listener["proceed"] = False - def on_close(): - if modal_listener["proceed"]: - self._train2(ignore_checks=True) + def on_close(): + if modal_listener["proceed"]: + self._train2(ignore_checks=True) + + self._wait_while_func( + ( + lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal( + on_yes, on_no, resume, *args, **kwargs + ) + ), + on_yes=on_yes, + on_no=on_no, + msg=msg, + on_close=on_close, + label_kwargs={"justify": "left"}, + ) + return False # we still failed checks so say so. + else: + msg += "\nCritical errors found, cannot ignore." + self._wait_while_func( + lambda resume, msg, **kwargs: _OkModal(resume, msg, **kwargs), + msg=msg, + label_kwargs={"justify": "left"}, + ) + return False - self._wait_while_func( - ( - lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal( - on_yes, on_no, resume, *args, **kwargs - ) - ), - on_yes=on_yes, - on_no=on_no, - msg=msg, - on_close=on_close, - label_kwargs={"justify": "left"}, - ) - return False return True def _wait_while_func(self, func, *args, **kwargs):