commit 17273de0bd13eb691dbf7416c3140b43e11535ac
parent a24caf30e69bf7f2b02e89e428abe12eeba5adc1
Author: Steven Atkinson <[email protected]>
Date: Tue, 17 Sep 2024 08:56:04 -0700
Deprecate non-v3.0.0 input signals (#464)
Diffstat:
3 files changed, 22 insertions(+), 1 deletion(-)
diff --git a/nam/train/_names.py b/nam/train/_names.py
@@ -16,10 +16,14 @@ class VersionAndName(NamedTuple):
# From most- to the least-recently-released:
INPUT_BASENAMES = (
VersionAndName(Version(3, 0, 0), "input.wav", {"v3_0_0.wav"}),
+ # ==================================================================================
+ # These are deprecated and will be removed in v0.11. If you still want them, you'll
+ # need to write an extension.
VersionAndName(Version(2, 0, 0), "v2_0_0.wav", None),
VersionAndName(Version(1, 1, 1), "v1_1_1.wav", None),
VersionAndName(Version(1, 0, 0), "v1.wav", None),
VersionAndName(PROTEUS_VERSION, "Proteus_Capture.wav", None),
+ # ==================================================================================
)
LATEST_VERSION = INPUT_BASENAMES[0]
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -235,6 +235,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
f"Input file at {input_path} cannot be recognized as any known version!"
)
strong_match = False
+
return version, strong_match
@@ -776,7 +777,15 @@ def _check_data(
else:
print(f"Checks not implemented for input version {input_version}; skip")
return None
- return f(input_path, output_path, delay, silent)
+ out = f(input_path, output_path, delay, silent)
+ # Issue 442: Deprecate inputs
+ if input_version.major != 3:
+ print(
+ f"Input version {input_version} is deprecated and will be removed in "
+ "version 0.11 of the trainer. To continue using it, you must ignore checks."
+ )
+ out.passed = False
+ return out
def _get_wavenet_config(architecture):
diff --git a/tests/test_nam/test_train/test_core.py b/tests/test_nam/test_train/test_core.py
@@ -56,6 +56,10 @@ class TestDetectInputVersion(object):
def test_detect_input_version_v2_0_0_strong(self):
self._t_detect_input_version_strong(Version(2, 0, 0))
+ @requires_v3_0_0
+ def test_detect_input_version_v3_0_0_strong(self):
+ self._t_detect_input_version_strong(Version(3, 0, 0))
+
@requires_v1_0_0
def test_detect_input_version_v1_0_0_weak(self):
self._t_detect_input_version_weak(Version(1, 0, 0))
@@ -68,6 +72,10 @@ class TestDetectInputVersion(object):
def test_detect_input_version_v2_0_0_weak(self):
self._t_detect_input_version_weak(Version(2, 0, 0))
+ @requires_v3_0_0
+ def test_detect_input_version_v3_0_0_weak(self):
+ self._t_detect_input_version_weak(Version(3, 0, 0))
+
@classmethod
def _customize_resource(cls, path_in, path_out):
x, info = wav_to_np(path_in, info=True)