neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 4e97b46cc24417a550d92a82bce2e1d012f10ed5
parent 89496f5d87590cc2dfb7431403a6f7b4a5a6827d
Author: Steven Atkinson <steven@atkinson.mn>
Date:   Sat, 29 Apr 2023 14:50:14 -0700

Strong and weak matches for recognized training files (#220)


Diffstat:
Mnam/train/core.py | 110+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------
1 file changed, 91 insertions(+), 19 deletions(-)

diff --git a/nam/train/core.py b/nam/train/core.py @@ -34,25 +34,98 @@ def _detect_input_version(input_path) -> Version: """ Check to see if the input matches any of the known inputs """ - md5 = hashlib.md5() - buffer_size = 65536 - with open(input_path, "rb") as f: - while True: - data = f.read(buffer_size) - if not data: - break - md5.update(data) - file_hash = md5.hexdigest() - - version = { - "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), - "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), - "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0), - }.get(file_hash) + + def detect_strong(input_path) -> Optional[Version]: + def assign_hash(path): + # Use this to create hashes for new files + md5 = hashlib.md5() + buffer_size = 65536 + with open(input_path, "rb") as f: + while True: + data = f.read(buffer_size) + if not data: + break + md5.update(data) + file_hash = md5.hexdigest() + return file_hash + + file_hash = assign_hash(input_path) + print(f"Strong hash: {file_hash}") + + version = { + "4d54a958861bf720ec4637f43d44a7ef": Version(1, 0, 0), + "7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1), + "cff9de79975f7fa2ba9060ad77cde04d": Version(2, 0, 0), + }.get(file_hash) + if version is None: + print( + f"Provided input file {input_path} does not strong-match any known " + "standard input files." + ) + return version + + def detect_weak(input_path) -> Optional[Version]: + def assign_hash(path): + # Use this to create recognized hashes for new files + x, info = wav_to_np(path, info=True) + rate = info.rate + if rate != REQUIRED_RATE: + return None + # Times of intervals, in seconds + t_blips = 1 + t_sweep = 3 + t_white = 3 + t_validation = 9 + # v1 and v2 start with 1 blips, sine sweeps, and white noise + start_hash = hashlib.md5( + x[: (t_blips + t_sweep + t_white) * rate] + ).hexdigest() + # v1 ends with validation signal + end_hash_v1 = hashlib.md5(x[-t_validation * rate :]).hexdigest() + # v2 ends with 2x validation & blips + end_hash_v2 = hashlib.md5( + x[-(2 * t_validation + t_blips) * rate :] + ).hexdigest() + return start_hash, end_hash_v1, end_hash_v2 + + start_hash, end_hash_v1, end_hash_v2 = assign_hash(input_path) + print( + "Weak hashes:\n" + f" Start: {start_hash}\n" + f" End (v1): {end_hash_v1}\n" + f" End (v2): {end_hash_v2}\n", + ) + + # Check for v2 matches first + version = { + ( + "068a17d92274a136807523baad4913ff", + "74f924e8b245c8f7dce007765911545a", + ): Version(2, 0, 0) + }.get((start_hash, end_hash_v2)) + if version is not None: + return version + # Fallback to v1 + version = { + ( + "bb4e140c9299bae67560d280917eb52b", + "9b2468fcb6e9460a399fc5f64389d353", + ): Version(1, 0, 0), + ( + "9f20c6b5f7fef68dd88307625a573a14", + "8458126969a3f9d8e19a53554eb1fd52", + ): Version(1, 1, 1), + }.get((start_hash, end_hash_v1)) + return version + + version = detect_strong(input_path) + if version is not None: + return version + print("Falling back to weak-matching...") + version = detect_weak(input_path) if version is None: - raise RuntimeError( - f"Provided input file {input_path} does not match any known standard input " - "files." + raise ValueError( + f"Input file at {input_path} cannot be recognized as any known version!" ) return version @@ -76,7 +149,6 @@ def _calibrate_delay_v1( delays = [] for blip_index, i in enumerate(locations, 1): - start_looking = i - lookahead stop_looking = i + lookback y_scan = y[start_looking:stop_looking]