commit 27c6a048025e7894e0d89579cfda6c59d93e0f20
parent 16a2108fe64602661daa55303e6a06bf348b59d9
Author: Steven Atkinson <[email protected]>
Date: Sat, 10 Feb 2024 10:08:16 -0800
Support Proteus in Colab (#378)
* Support Proteus in Colab
* Minor version bump to 0.8.0
Diffstat:
5 files changed, 19 insertions(+), 8 deletions(-)
diff --git a/nam/_version.py b/nam/_version.py
@@ -1 +1 @@
-__version__ = "0.7.4"
+__version__ = "0.8.0"
diff --git a/nam/train/_names.py b/nam/train/_names.py
@@ -4,7 +4,7 @@
from typing import NamedTuple
-from ._version import Version
+from ._version import PROTEUS_VERSION, Version
__all__ = ["INPUT_BASENAMES", "LATEST_VERSION", "VersionAndName"]
@@ -20,6 +20,7 @@ INPUT_BASENAMES = (
VersionAndName(Version(2, 0, 0), "v2_0_0.wav"),
VersionAndName(Version(1, 1, 1), "v1_1_1.wav"),
VersionAndName(Version(1, 0, 0), "v1.wav"),
+ VersionAndName(PROTEUS_VERSION, "Proteus_Capture.wav"),
)
LATEST_VERSION = INPUT_BASENAMES[0]
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -6,6 +6,8 @@
Version utility
"""
+__all__ = ["PROTEUS_VERSION", "Version"]
+
class Version:
def __init__(self, major: int, minor: int, patch: int):
@@ -30,3 +32,6 @@ class Version:
def __str__(self) -> str:
return f"{self.major}.{self.minor}.{self.patch}"
+
+
+PROTEUS_VERSION = Version(4, 0, 0)
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -12,7 +12,7 @@ from typing import NamedTuple, Optional, Tuple
from ..models.metadata import UserMetadata
from ._names import INPUT_BASENAMES, LATEST_VERSION, Version
-from ._version import Version
+from ._version import PROTEUS_VERSION, Version
from .core import train
@@ -34,7 +34,9 @@ def _check_for_files() -> Tuple[Version, str]:
)
for input_version, input_basename in INPUT_BASENAMES:
if Path(input_basename).exists():
- if input_version != LATEST_VERSION.version:
+ if input_version == PROTEUS_VERSION:
+ print(f"Using Proteus input file...")
+ elif input_version != LATEST_VERSION.version:
print(
f"WARNING: Using out-of-date input file {input_basename}. "
"Recommend downloading and using the latest version, "
@@ -49,7 +51,10 @@ def _check_for_files() -> Tuple[Version, str]:
raise FileNotFoundError(
f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
)
- print(f"Found {input_basename}, version {input_version}")
+ if input_version != PROTEUS_VERSION:
+ print(f"Found {input_basename}, version {input_version}")
+ else:
+ print(f"Found Proteus input {input_basename}.")
return input_version, input_basename
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -26,7 +26,7 @@ from ..data import Split, init_dataset, wav_to_np, wav_to_tensor
from ..models import Model
from ..models.losses import esr
from ..util import filter_warnings
-from ._version import Version
+from ._version import PROTEUS_VERSION, Version
__all__ = ["train"]
@@ -70,7 +70,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
"7c3b6119c74465f79d96c761a0e27370": Version(1, 1, 1),
"ede3b9d82135ce10c7ace3bb27469422": Version(2, 0, 0),
"36cd1af62985c2fac3e654333e36431e": Version(3, 0, 0),
- "80e224bd5622fd6153ff1fd9f34cb3bd": Version(4, 0, 0),
+ "80e224bd5622fd6153ff1fd9f34cb3bd": PROTEUS_VERSION,
}.get(file_hash)
if version is None:
print(
@@ -211,7 +211,7 @@ def _detect_input_version(input_path) -> Tuple[Version, bool]:
}.get((start_hash_v1, end_hash_v1))
if version is not None:
return version
- version = {"46151c8030798081acc00a725325a07d": Version(4, 0, 0)}.get(hash_v4)
+ version = {"46151c8030798081acc00a725325a07d": PROTEUS_VERSION}.get(hash_v4)
return version
version = detect_strong(input_path)