commit 64ebab66e1539d66896bd484c4ba2d4e7f560316
parent 0d6e110f85905a1849dbc07b1412943f6dfa1984
Author: Steven Atkinson <[email protected]>
Date: Sat, 10 Dec 2022 18:56:55 -0500
Improvements to Easy Colab (#55)
* Start working on new capture signal version
* Fix some text
* Fix indent
* Fix delay calibration
* Fix delay plot
* Improve delay trigger
* Revert "Improve delay trigger"
This reverts commit a93a02cfda8259d702c48ee074456db3d25d1bcf.
* Point at easy colab at main again
Diffstat:
2 files changed, 56 insertions(+), 22 deletions(-)
diff --git a/bin/train/easy_colab.ipynb b/bin/train/easy_colab.ipynb
@@ -6,15 +6,15 @@
"id": "TC3XkMetGWtK"
},
"source": [
- "# Neural Amp Modeler (\"Easy mode\" Trainer)\n",
+ "# Neural Amp Modeler (\"Easy Mode\" Trainer)\n",
"This notebook allows you to train a neural amp model based on a pair of input/output WAV files that you have of the amp you want to model.\n",
"\n",
- "**To use this notebook**:\n",
- "Go to [colab.research.google.com](https://colab.research.google.com/), select the \"GitHub\" tab, and select this notebook. Or, if you've cloned the repo, you can upload it from your computer.\n",
+ "**Note**:\n",
+ "This notebook is meant to be used on [Google Colab](https://colab.research.google.com/github/sdatkinson/neural-amp-modeler/blob/main/bin/train/easy_colab.ipynb).\n",
"\n",
"🔶**Before you run**🔶\n",
"\n",
- "Make sure to get a GPU! (Runtime->Change runtime type->Select \"GPU\" from the \"Hardware accelerator dropdown menu)\n",
+ "Make sure to get a GPU! (From the upper-left menu, click Runtime->Change runtime type->Select \"GPU\" from the \"Hardware accelerator\" dropdown menu)\n",
"\n",
"⚠**Warning**⚠\n",
"\n",
@@ -39,16 +39,14 @@
"\n",
"### Step 1.1: Download the capture signal\n",
"\"Easy mode\" uses a pre-crafted \"capture signal\".\n",
- "Download it [here](https://drive.google.com/file/d/1sVHtubguuXmDHRM8w1TmDaJEfB7-H9Vz/view?usp=share_link).\n",
+ "Download it [here](https://drive.google.com/file/d/12GVN9FXtzcAZnmflhKysEcfvq37h2wOk/view?usp=share_link).\n",
"\n",
"### Step 1.2: Reamp your gear\n",
"Then reamp the gear you want to model using it. Save that reamp as \"output.wav\".\n",
"**Please use 48kHz, 24-bit, mono.** We'll support other sample rates etc in the future; sit tight!\n",
"\n",
"### Step 1.3: Upload!\n",
- "Upload the input (DI) and output (amped) files you want to use by clicking the Folder icon on the left ⬅ and then clicking the upload icon.\n",
- "\n",
- "Once you're done, run the next cell and I'll check that everything looks good."
+ "Upload the input (DI) and output (amped) files you want to use by clicking the Folder icon on the left ⬅ and then clicking the upload icon."
]
},
{
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -8,7 +8,7 @@ Hide the mess in Colab to make things look pretty for users.
from pathlib import Path
from time import time
-from typing import Optional
+from typing import Optional, Tuple
import matplotlib.pyplot as plt
import numpy as np
@@ -19,20 +19,48 @@ from torch.utils.data import DataLoader
from nam.data import REQUIRED_RATE, Split, init_dataset, wav_to_np
from nam.models import Model
-_INPUT_BASENAME = "v1.wav"
+
+class _Version:
+ def __init__(self, major: int, minor: int, patch: int):
+ self.major = major
+ self.minor = minor
+ self.patch = patch
+
+    def __lt__(self, other) -> bool:
+        if self.major != other.major:
+            return self.major < other.major
+        if self.minor != other.minor:
+            return self.minor < other.minor
+        # Compare patch last; returns False (not an implicit None) when equal,
+        # keeping the declared -> bool contract for fully-equal versions.
+        return self.patch < other.patch
+
+ def __str__(self) -> str:
+ return f"{self.major}.{self.minor}.{self.patch}"
+
+
+_INPUT_BASENAMES = ((_Version(1, 1, 0), "v1_1_0.wav"), (_Version(1, 0, 0), "v1.wav"))
_OUTPUT_BASENAME = "output.wav"
-def _check_for_files():
+def _check_for_files() -> Tuple[_Version, str]:
print("Checking that we have all of the required audio files...")
- if not Path(_INPUT_BASENAME).exists():
+ for i, (input_version, input_basename) in enumerate(_INPUT_BASENAMES):
+ if Path(input_basename).exists():
+ if i > 0:
+ print(
+ f"WARNING: Using out-of-date input file {input_basename}. "
+ "Recommend downloading and using the latest version."
+ )
+ break
+ else:
raise FileNotFoundError(
- f"Didn't find NAM's input audio file. Please upload {_INPUT_BASENAME}"
+ f"Didn't find NAM's input audio file. Please upload {_INPUT_BASENAMES[0][1]}"
)
if not Path(_OUTPUT_BASENAME).exists():
raise FileNotFoundError(
f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}."
)
+ return input_version, input_basename
def _calibrate_delay_v1() -> int:
@@ -59,9 +87,9 @@ def _calibrate_delay_v1() -> int:
return delay
-def _plot_delay_v1(delay: int):
+def _plot_delay_v1(delay: int, input_basename: str):
print("Plotting the delay for manual inspection...")
- x = wav_to_np(_INPUT_BASENAME)[:48_000]
+ x = wav_to_np(input_basename)[:48_000]
y = wav_to_np(_OUTPUT_BASENAME)[:48_000]
i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly
di = 20
@@ -78,18 +106,26 @@ def _plot_delay_v1(delay: int):
plt.show() # This doesn't freeze the notebook
-def _calibrate_delay(delay: Optional[int]) -> int:
- calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
+def _calibrate_delay(
+ delay: Optional[int], input_version: _Version, input_basename: str
+) -> int:
+ if input_version.major == 1:
+ calibrate, plot = _calibrate_delay_v1, _plot_delay_v1
+ else:
+ raise NotImplementedError(
+ f"Input calibration not implemented for input version {input_version}"
+ )
if delay is not None:
print(f"Delay is specified as {delay}")
else:
print("Delay wasn't provided; attempting to calibrate automatically...")
delay = calibrate()
- plot(delay)
+ plot(delay, input_basename)
return delay
def _get_configs(
+ input_basename: str,
delay: int,
epochs: int,
stage_1_channels: int,
@@ -103,7 +139,7 @@ def _get_configs(
"train": {"ny": 8192, "stop": train_val_split},
"validation": {"ny": None, "start": train_val_split},
"common": {
- "x_path": _INPUT_BASENAME,
+ "x_path": input_basename,
"y_path": _OUTPUT_BASENAME,
"delay": delay,
},
@@ -230,10 +266,10 @@ def run(
:param seed: RNG seed for reproducibility.
"""
torch.manual_seed(seed)
- _check_for_files()
- delay = _calibrate_delay(delay)
+ input_version, input_basename = _check_for_files()
+ delay = _calibrate_delay(delay, input_version, input_basename)
data_config, model_config, learning_config = _get_configs(
- delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay
+ input_basename, delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay
)
print("Starting training. Let's rock!")