commit 4e97b46cc24417a550d92a82bce2e1d012f10ed5
parent 89496f5d87590cc2dfb7431403a6f7b4a5a6827d
Author: Steven Atkinson <steven@atkinson.mn>
Date: Sat, 29 Apr 2023 14:50:14 -0700
Strong and weak matches for recognized training files (#220)
Diffstat:
M | nam/train/core.py | | | 110 | +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------- |
1 file changed, 91 insertions(+), 19 deletions(-)
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -34,25 +34,98 @@ def _detect_input_version(input_path) -> Version:
"""
Check to see if the input matches any of the known inputs
"""
- md5 = hashlib.md5()
- buffer_size = 65536
- with open(input_path, "rb") as f:
- while True:
- data = f.read(buffer_size)
- if not data:
- break
- md5.update(data)
- file_hash = md5.hexdigest()
-
- version = {
- "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
- "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
- "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0),
- }.get(file_hash)
+
+    def detect_strong(input_path) -> Optional[Version]:
+        """Return the Version whose whole-file MD5 matches, else None."""
+
+        def assign_hash(path):
+            # Use this to create hashes for new files
+            # Reads `path` (not the enclosing input_path) in fixed-size chunks
+            md5 = hashlib.md5()
+            buffer_size = 65536
+            with open(path, "rb") as f:
+                # Two-argument iter() stops once read() returns b"" at end of file
+                for chunk in iter(lambda: f.read(buffer_size), b""):
+                    md5.update(chunk)
+            return md5.hexdigest()
+
+        file_hash = assign_hash(input_path)
+        print(f"Strong hash: {file_hash}")
+
+        version = {
+            "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0),
+            "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
+            "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0),
+        }.get(file_hash)
+        if version is None:
+            print(
+                f"Provided input file {input_path} does not strong-match any known "
+                "standard input files."
+            )
+        return version
+
+    def detect_weak(input_path) -> Optional[Version]:
+        """Return the Version matching the start/end segment hashes, else None."""
+
+        def assign_hash(path):
+            # Use this to create recognized hashes for new files (None if wrong rate)
+            x, info = wav_to_np(path, info=True)
+            if info.rate != REQUIRED_RATE:
+                return None
+            # Times of intervals, in seconds
+            t_blips = 1
+            t_sweep = 3
+            t_white = 3
+            t_validation = 9
+            # v1 and v2 start with blips, sine sweeps, and white noise
+            n_start = (t_blips + t_sweep + t_white) * info.rate
+            start_hash = hashlib.md5(x[:n_start]).hexdigest()
+            # v1 ends with validation signal
+            end_hash_v1 = hashlib.md5(x[-t_validation * info.rate :]).hexdigest()
+            # v2 ends with 2x validation & blips
+            n_end_v2 = (2 * t_validation + t_blips) * info.rate
+            end_hash_v2 = hashlib.md5(x[-n_end_v2:]).hexdigest()
+            return start_hash, end_hash_v1, end_hash_v2
+
+        hashes = assign_hash(input_path)
+        if hashes is None:  # Sample rate mismatch: not a standard input file
+            return None
+        start_hash, end_hash_v1, end_hash_v2 = hashes
+        print(
+            "Weak hashes:\n"
+            f" Start: {start_hash}\n"
+            f" End (v1): {end_hash_v1}\n"
+            f" End (v2): {end_hash_v2}\n",
+        )
+        # Check for v2 matches first
+        version = {
+            (
+                "068a17d92274a136807523baad4913ff",
+                "74f924e8b245c8f7dce007765911545a",
+            ): Version(2, 0, 0)
+        }.get((start_hash, end_hash_v2))
+        if version is not None:
+            return version
+        # Fallback to v1
+        return {
+            (
+                "bb4e140c9299bae67560d280917eb52b",
+                "9b2468fcb6e9460a399fc5f64389d353",
+            ): Version(1, 0, 0),
+            (
+                "9f20c6b5f7fef68dd88307625a573a14",
+                "8458126969a3f9d8e19a53554eb1fd52",
+            ): Version(1, 1, 1),
+        }.get((start_hash, end_hash_v1))
+
+ version = detect_strong(input_path)
+ if version is not None:
+ return version
+ print("Falling back to weak-matching...")
+ version = detect_weak(input_path)
if version is None:
- raise RuntimeError(
- f"Provided input file {input_path} does not match any known standard input "
- "files."
+ raise ValueError(
+ f"Input file at {input_path} cannot be recognized as any known version!"
)
return version
@@ -76,7 +149,6 @@ def _calibrate_delay_v1(
delays = []
for blip_index, i in enumerate(locations, 1):
-
start_looking = i - lookahead
stop_looking = i + lookback
y_scan = y[start_looking:stop_looking]