neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit 64ebab66e1539d66896bd484c4ba2d4e7f560316
parent 0d6e110f85905a1849dbc07b1412943f6dfa1984
Author: Steven Atkinson <[email protected]>
Date:   Sat, 10 Dec 2022 18:56:55 -0500

Improvements to Easy Colab (#55)

* Start working on new capture signal version

* Fix some text

* Fix indent

* Fix delay calibration

* Fix delay plot

* Improve delay trigger

* Revert "Improve delay trigger"

This reverts commit a93a02cfda8259d702c48ee074456db3d25d1bcf.

* Point at easy colab at main again
Diffstat:
Mbin/train/easy_colab.ipynb | 14++++++--------
Mnam/train/colab.py | 64++++++++++++++++++++++++++++++++++++++++++++++++++--------------
2 files changed, 56 insertions(+), 22 deletions(-)

diff --git a/bin/train/easy_colab.ipynb b/bin/train/easy_colab.ipynb @@ -6,15 +6,15 @@ "id": "TC3XkMetGWtK" }, "source": [ - "# Neural Amp Modeler (\"Easy mode\" Trainer)\n", + "# Neural Amp Modeler (\"Easy Mode\" Trainer)\n", "This notebook allows you to train a neural amp model based on a pair of input/output WAV files that you have of the amp you want to model.\n", "\n", - "**To use this notebook**:\n", - "Go to [colab.research.google.com](https://colab.research.google.com/), select the \"GitHub\" tab, and select this notebook. Or, if you've cloned the repo, you can upload it from your computer.\n", + "**Note**:\n", + "This notebook is meant to be used on [Google Colab](https://colab.research.google.com/github/sdatkinson/neural-amp-modeler/blob/main/bin/train/easy_colab.ipynb).\n", "\n", "🔶**Before you run**🔶\n", "\n", - "Make sure to get a GPU! (Runtime->Change runtime type->Select \"GPU\" from the \"Hardware accelerator dropdown menu)\n", + "Make sure to get a GPU! (From the upper-left menu, click Runtime->Change runtime type->Select \"GPU\" from the \"Hardware accelerator dropdown menu)\n", "\n", "⚠**Warning**⚠\n", "\n", @@ -39,16 +39,14 @@ "\n", "### Step 1.1: Download the capture signal\n", "\"Easy mode\" uses a pre-crafted \"capture signal\".\n", - "Download it [here](https://drive.google.com/file/d/1sVHtubguuXmDHRM8w1TmDaJEfB7-H9Vz/view?usp=share_link).\n", + "Download it [here](https://drive.google.com/file/d/12GVN9FXtzcAZnmflhKysEcfvq37h2wOk/view?usp=share_link).\n", "\n", "### Step 1.2 Reamp your gear\n", "Then reamp the gear you want to model using it. Save that reamp as \"output.wav\".\n", "**Please use 48kHz, 24-bit, mono.** We'll support other sample rates etc in the future; sit tight!\n", "\n", "### Step 1.3: upload!\n", - "Upload the input (DI) and output (amped) files you want to use by clicking the Folder icon on the left ⬅ and then clicking the upload icon.\n", - "\n", - "Once you're done, run the next cell and I'll check that everything looks good." + "Upload the input (DI) and output (amped) files you want to use by clicking the Folder icon on the left ⬅ and then clicking the upload icon." ] }, { diff --git a/nam/train/colab.py b/nam/train/colab.py @@ -8,7 +8,7 @@ Hide the mess in Colab to make things look pretty for users. from pathlib import Path from time import time -from typing import Optional +from typing import Optional, Tuple import matplotlib.pyplot as plt import numpy as np @@ -19,20 +19,48 @@ from torch.utils.data import DataLoader from nam.data import REQUIRED_RATE, Split, init_dataset, wav_to_np from nam.models import Model -_INPUT_BASENAME = "v1.wav" + +class _Version: + def __init__(self, major: int, minor: int, patch: int): + self.major = major + self.minor = minor + self.patch = patch + + def __lt__(self, other) -> bool: + if self.major != other.major: + return self.major < other.major + if self.minor != other.minor: + return self.minor < other.minor + if self.patch != other.patch: + return self.patch < other.patch + + def __str__(self) -> str: + return f"{self.major}.{self.minor}.{self.patch}" + + +_INPUT_BASENAMES = ((_Version(1, 1, 0), "v1_1_0.wav"), (_Version(1, 0, 0), "v1.wav")) _OUTPUT_BASENAME = "output.wav" -def _check_for_files(): +def _check_for_files() -> Tuple[_Version, str]: print("Checking that we have all of the required audio files...") - if not Path(_INPUT_BASENAME).exists(): + for i, (input_version, input_basename) in enumerate(_INPUT_BASENAMES): + if Path(input_basename).exists(): + if i > 0: + print( + f"WARNING: Using out-of-date input file {input_basename}. " + "Recommend downloading and using the latest version." + ) + break + else: raise FileNotFoundError( - f"Didn't find NAM's input audio file. Please upload {_INPUT_BASENAME}" + f"Didn't find NAM's input audio file. Please upload {_INPUT_BASENAMES[0][1]}" ) if not Path(_OUTPUT_BASENAME).exists(): raise FileNotFoundError( f"Didn't find your reamped output audio file. Please upload {_OUTPUT_BASENAME}." ) + return input_version, input_basename def _calibrate_delay_v1() -> int: @@ -59,9 +87,9 @@ def _calibrate_delay_v1() -> int: return delay -def _plot_delay_v1(delay: int): +def _plot_delay_v1(delay: int, input_basename: str): print("Plotting the delay for manual inspection...") - x = wav_to_np(_INPUT_BASENAME)[:48_000] + x = wav_to_np(input_basename)[:48_000] y = wav_to_np(_OUTPUT_BASENAME)[:48_000] i = np.where(np.abs(x) > 0.1)[0][0] # In case resampled poorly di = 20 @@ -78,18 +106,26 @@ def _plot_delay_v1(delay: int): plt.show() # This doesn't freeze the notebook -def _calibrate_delay(delay: Optional[int]) -> int: - calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 +def _calibrate_delay( + delay: Optional[int], input_version: _Version, input_basename: str +) -> int: + if input_version.major == 1: + calibrate, plot = _calibrate_delay_v1, _plot_delay_v1 + else: + raise NotImplementedError( + f"Input calibration not implemented for input version {input_version}" + ) if delay is not None: print(f"Delay is specified as {delay}") else: print("Delay wasn't provided; attempting to calibrate automatically...") delay = calibrate() - plot(delay) + plot(delay, input_basename) return delay def _get_configs( + input_basename: str, delay: int, epochs: int, stage_1_channels: int, @@ -103,7 +139,7 @@ def _get_configs( "train": {"ny": 8192, "stop": train_val_split}, "validation": {"ny": None, "start": train_val_split}, "common": { - "x_path": _INPUT_BASENAME, + "x_path": input_basename, "y_path": _OUTPUT_BASENAME, "delay": delay, }, @@ -230,10 +266,10 @@ def run( :param seed: RNG seed for reproducibility. """ torch.manual_seed(seed) - _check_for_files() - delay = _calibrate_delay(delay) + input_version, input_basename = _check_for_files() + delay = _calibrate_delay(delay, input_version, input_basename) data_config, model_config, learning_config = _get_configs( - delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay + input_basename, delay, epochs, stage_1_channels, stage_2_channels, lr, lr_decay ) print("Starting training. Let's rock!")