commit 64f79fdc4c68b8735b879f495d4e5099cd39aa6d
parent 0c5112907742f1e65a4fd6ff22656a38253505d0
Author: Steven Atkinson <[email protected]>
Date: Sat, 28 Sep 2024 10:08:49 -0700
Data validation: Check audio length and sample rates for matches (#474)
* Raise DataError instead of ValueError
* Add sample rate and length validation for data (require exact match for now)
* Fix docstring
* Define critical checks
* Don't allow ignoring critical checks
Diffstat:
3 files changed, 164 insertions(+), 36 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -431,7 +431,8 @@ class Dataset(AbstractDataset, InitializableFromConfig):
msg += (
f"\n * The output is {sample_to_time(y_samples, sample_rate)} long"
)
- raise ValueError(msg)
+ msg += f"\n\nOriginal exception:\n{e}"
+ raise DataError(msg)
return {"x": x, "y": y, "sample_rate": sample_rate, **config}
@classmethod
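The first bullet above swaps ValueError for DataError so that callers can tell bad training data apart from generic argument errors. DataError itself isn't shown in this hunk; a minimal sketch of what such an exception presumably looks like in nam/data.py, plus a hypothetical call site (init_from_config is assumed to be the InitializableFromConfig entry point):

    class DataError(Exception):
        """Raised when the training data fails validation."""

    # Hypothetical caller: data problems are now distinguishable from
    # ordinary ValueErrors.
    try:
        dataset = Dataset.init_from_config(config)
    except DataError as e:
        print(f"Bad training data:\n{e}")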
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -1327,6 +1327,25 @@ def train(
if seed is not None:
torch.manual_seed(seed)
+ # HACK: We need to check the sample rates and lengths of the audio here, or else
+ # it will look like a bad self-ESR (Issue 473).
+ # Can move this into the "v3 checks" once the others are deprecated.
+ # And honestly, remake this whole thing as a data processing pipeline.
+ sample_rate_validation = _check_audio_sample_rates(input_path, output_path)
+ if not sample_rate_validation.passed:
+ raise ValueError(
+ "Different sample rates detected for input "
+ f"({sample_rate_validation.input}) and output "
+ f"({sample_rate_validation.output}) audio!"
+ )
+ length_validation = _check_audio_lengths(input_path, output_path)
+ if not length_validation.passed:
+ raise ValueError(
+ "Your recording differs in length from the input file by "
+ f"{length_validation.delta_seconds:.2f} seconds. Check your reamp "
+ "in your DAW and ensure that they are the same length."
+ )
+
if input_version is None:
input_version, strong_match = _detect_input_version(input_path)
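Taken on its own, the new pre-flight block surfaces to a caller of train() like this (a sketch; train() takes many more keyword arguments than shown, and the file names are invented):

    try:
        train("input.wav", "output.wav", "./train_dir")
    except ValueError as e:
        # Sample-rate or length mismatches are reported before any
        # training actually runs (Issue 473).
        print(e)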
@@ -1501,14 +1520,79 @@ class _PyTorchDataValidation(BaseModel):
validation: _PyTorchDataSplitValidation # Split.VALIDATION
+class _SampleRateValidation(BaseModel):
+ passed: bool
+ input: int
+ output: int
+
+
+class _LengthValidation(BaseModel):
+ passed: bool
+ delta_seconds: float
+
+
class DataValidationOutput(BaseModel):
passed: bool
+ passed_critical: bool
+ sample_rate: _SampleRateValidation
+ length: _LengthValidation
input_version: str
latency: metadata.Latency
checks: metadata.DataChecks
pytorch: _PyTorchDataValidation
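The two new sub-models are plain pydantic records; invented values, just to show their shape:

    srv = _SampleRateValidation(passed=False, input=44100, output=48000)
    lv = _LengthValidation(passed=True, delta_seconds=0.0)
    print(srv.passed, srv.input, srv.output)  # False 44100 48000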
+def _check_audio_sample_rates(
+ input_path: Path,
+ output_path: Path,
+) -> _SampleRateValidation:
+ _, x_info = wav_to_np(input_path, info=True)
+ _, y_info = wav_to_np(output_path, info=True)
+
+ return _SampleRateValidation(
+ passed=x_info.rate == y_info.rate,
+ input=x_info.rate,
+ output=y_info.rate,
+ )
+
+
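A quick way to exercise the sample-rate check by hand (file names invented):

    validation = _check_audio_sample_rates(Path("input.wav"), Path("output.wav"))
    if not validation.passed:
        print(f"Mismatch: {validation.input} Hz in, {validation.output} Hz out")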
+def _check_audio_lengths(
+ input_path: Path,
+ output_path: Path,
+ max_under_seconds: Optional[float] = 0.0,
+ max_over_seconds: Optional[float] = 0.0,
+) -> _LengthValidation:
+ """
+ Check that the input and output have the right lengths compared to each
+ other.
+
+ :param input_path: Path to input audio
+ :param output_path: Path to output audio
+ :param max_under_seconds: If not None, the maximum amount by which the
+ output can be shorter than the input. Should be non-negative, i.e. a
+ value of 1.0 means that the output can't be more than a second shorter
+ than the input.
+ :param max_over_seconds: If not None, the maximum amount by which the
+ output can be longer than the input. Should be non-negative, i.e. a
+ value of 1.0 means that the output can't be more than a second longer
+ than the input.
+ """
+ x, x_info = wav_to_np(input_path, info=True)
+ y, y_info = wav_to_np(output_path, info=True)
+
+ length_input = len(x) / x_info.rate
+ length_output = len(y) / y_info.rate
+ delta_seconds = length_output - length_input
+
+ passed = True
+ if max_under_seconds is not None and delta_seconds < -max_under_seconds:
+ passed = False
+ if max_over_seconds is not None and delta_seconds > max_over_seconds:
+ passed = False
+
+ return _LengthValidation(passed=passed, delta_seconds=delta_seconds)
+
+
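Note that delta_seconds is signed: negative means the output is shorter than the input. With the zero-tolerance defaults, any nonzero delta fails; a worked example assuming output.wav is 0.30 seconds shorter than input.wav (files invented):

    result = _check_audio_lengths(Path("input.wav"), Path("output.wav"))
    # result.delta_seconds == -0.30, so the default max_under_seconds=0.0 fails.
    # Loosening the tolerance to a full second would let it pass:
    relaxed = _check_audio_lengths(
        Path("input.wav"), Path("output.wav"), max_under_seconds=1.0
    )
    assert relaxed.passed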
def validate_data(
input_path: Path,
output_path: Path,
@@ -1522,7 +1606,17 @@ def validate_data(
* Latency calibration
* Other checks
"""
+ print("Validating data...")
passed = True # Until proven otherwise
+ passed_critical = True # These can't be ignored
+
+ sample_rate_validation = _check_audio_sample_rates(input_path, output_path)
+ passed = passed and sample_rate_validation.passed
+ passed_critical = passed_critical and sample_rate_validation.passed
+
+ length_validation = _check_audio_lengths(input_path, output_path)
+ passed = passed and length_validation.passed
+ passed_critical = passed_critical and length_validation.passed
# Data version ID
input_version, strong_match = _detect_input_version(input_path)
@@ -1575,9 +1669,13 @@ def validate_data(
**pytorch_data_split_validation_dict,
)
passed = passed and pytorch_data_validation.passed
+ passed_critical = passed_critical and pytorch_data_validation.passed
return DataValidationOutput(
passed=passed,
+ passed_critical=passed_critical,
+ sample_rate=sample_rate_validation,
+ length=length_validation,
input_version=str(input_version),
latency=latency_analysis,
checks=data_checks,
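The GUI changes below key off the passed/passed_critical split; outside the GUI, the same decision looks like this (a sketch; validate_data takes additional arguments that are omitted here):

    out = validate_data(Path("input.wav"), Path("output.wav"))
    if not out.passed_critical:
        print("Critical errors found, cannot ignore.")
    elif not out.passed:
        print("Checks failed, but you may ignore them and proceed.")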
diff --git a/nam/train/gui/__init__.py b/nam/train/gui/__init__.py
@@ -327,9 +327,11 @@ class _OkModal(object):
Message and OK button
"""
- def __init__(self, resume_main, msg: str):
+ def __init__(self, resume_main, msg: str, label_kwargs: Optional[dict] = None):
+ label_kwargs = {} if label_kwargs is None else label_kwargs
+
self._root = _TopLevelWithOk((lambda: None), resume_main)
- self._text = tk.Label(self._root, text=msg)
+ self._text = tk.Label(self._root, text=msg, **label_kwargs)
self._text.pack()
self._ok = tk.Button(
self._root,
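The label_kwargs passthrough lets call sites style the message label, e.g. left-justifying multi-line validation reports as done further down:

    # Hypothetical call site; resume_main resumes the main window.
    _OkModal(resume_main, "Line one\nLine two", label_kwargs={"justify": "left"})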
@@ -764,12 +766,32 @@ class GUI(object):
output_path: str, validation_output: core.DataValidationOutput
) -> str:
"""
- File and explain what's wrong with it.
+ State the file and explain what's wrong with it.
"""
# TODO put this closer to what it looks at, i.e. core.DataValidationOutput
msg = (
f"\t{Path(output_path).name}:\n" # They all have the same directory so
)
+ if not validation_output.sample_rate.passed:
+ msg += (
+ "\t\t There are different sample rates for the input ("
+ f"{validation_output.sample_rate.input}) and output ("
+ f"{validation_output.sample_rate.output}).\n"
+ )
+ if not validation_output.length.passed:
+ msg += (
+ "\t\t* The input and output audio files are too different in length"
+ )
+ if validation_output.length.delta_seconds > 0:
+ msg += (
+ f" (the output is {validation_output.length.delta_seconds:.2f} "
+ "seconds longer than the input)\n"
+ )
+ else:
+ msg += (
+ f" (the output is {-validation_output.length.delta_seconds:.2f}"
+ " seconds shorter than the input)\n"
+ )
if validation_output.latency.manual is None:
if validation_output.latency.calibration.warnings.matches_lookahead:
msg += (
@@ -808,44 +830,51 @@ class GUI(object):
for output_path in output_paths
}
if any(not fv.passed for fv in file_validation_outputs.values()):
- msg = (
- "The following output files failed checks:\n"
- + "".join(
- [
- make_message_for_file(output_path, fv)
- for output_path, fv in file_validation_outputs.items()
- if not fv.passed
- ]
- )
- + "\nIgnore and proceed?"
+ msg = "The following output files failed checks:\n" + "".join(
+ [
+ make_message_for_file(output_path, fv)
+ for output_path, fv in file_validation_outputs.items()
+ if not fv.passed
+ ]
)
+ if all(fv.passed_critical for fv in file_validation_outputs.values()):
+ msg += "\nIgnore and proceed?"
- # Hacky to listen to the modal:
- modal_listener = {"proceed": False, "still_open": True}
+ # Hacky to listen to the modal:
+ modal_listener = {"proceed": False, "still_open": True}
- def on_yes():
- modal_listener["proceed"] = True
+ def on_yes():
+ modal_listener["proceed"] = True
- def on_no():
- modal_listener["proceed"] = False
+ def on_no():
+ modal_listener["proceed"] = False
- def on_close():
- if modal_listener["proceed"]:
- self._train2(ignore_checks=True)
+ def on_close():
+ if modal_listener["proceed"]:
+ self._train2(ignore_checks=True)
+
+ self._wait_while_func(
+ (
+ lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal(
+ on_yes, on_no, resume, *args, **kwargs
+ )
+ ),
+ on_yes=on_yes,
+ on_no=on_no,
+ msg=msg,
+ on_close=on_close,
+ label_kwargs={"justify": "left"},
+ )
+ return False  # We still failed checks, so say so.
+ else:
+ msg += "\nCritical errors found, cannot ignore."
+ self._wait_while_func(
+ lambda resume, msg, **kwargs: _OkModal(resume, msg, **kwargs),
+ msg=msg,
+ label_kwargs={"justify": "left"},
+ )
+ return False
- self._wait_while_func(
- (
- lambda resume, on_yes, on_no, *args, **kwargs: _YesNoModal(
- on_yes, on_no, resume, *args, **kwargs
- )
- ),
- on_yes=on_yes,
- on_no=on_no,
- msg=msg,
- on_close=on_close,
- label_kwargs={"justify": "left"},
- )
- return False
return True
def _wait_while_func(self, func, *args, **kwargs):