commit 29d1fb038f892a227e27b461553ea0ce350d3eeb
parent 46560898db9b4617d17d9a9586231bd1d43769eb
Author: Steven Atkinson <steven@atkinson.mn>
Date: Tue, 16 May 2023 19:04:35 -0700
Support modeling cabs (#250)
* Update data.py
* Sample rate more flexible in wav_to_x
* Better error messages in ConcatDataset
* Black
* Docstrings on Dataset
* Amp and cab training in core
* Cab modeling option in GUI trainer
* Cab modeling in Colab
* Add ir.py, update core.py
Drop in the real IR fitter
* Remove cab modeling option. Leave the setting on.
* Remove IR-fitting code
* Remove ir.py
* Revert easy_colab.ipynb
* Remove _with_ir.py
* Remove WithIR code
* Revert hyper_net.py
* Clean up WithIR code
* Clean up WithIR code
Diffstat:
8 files changed, 84 insertions(+), 33 deletions(-)
diff --git a/nam/data.py b/nam/data.py
@@ -298,11 +298,21 @@ class Dataset(AbstractDataset, InitializableFromConfig):
return self._ny
@property
- def x(self):
+ def x(self) -> torch.Tensor:
+ """
+ The input audio data
+
+ :return: (N,)
+ """
return self._x
@property
- def y(self):
+ def y(self) -> torch.Tensor:
+ """
+ The output audio data
+
+ :return: (N,)
+ """
return self._y
@property
@@ -578,6 +588,9 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
return self.datasets[i][j]
def __len__(self) -> int:
+ """
+ How many data points are in this data set, totaled across all of its sub-datasets
+ """
return sum(len(d) for d in self._datasets)
@property
@@ -622,8 +635,18 @@ class ConcatDataset(AbstractDataset, InitializableFromConfig):
j += 1
lookup[i] = (j, offset)
offset += 1
- assert j == len(self.datasets) - 1
- assert offset == len(self.datasets[-1])
+ # Check that we ended on the last data set
+ if j != len(self.datasets) - 1:
+ raise RuntimeError(
+ f"During lookup population, didn't get to the last dataset (index "
+ f"{len(self.datasets)-1}). Instead index ended at {j}."
+ )
+ if offset != len(self.datasets[-1]):
+ raise RuntimeError(
+ "During lookup population, didn't end at the index of the last datum "
+ f"in the last dataset. Expected index {len(self.datasets[-1])}, got "
+ f"{offset} instead."
+ )
return lookup
@classmethod
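For context, the lookup those new errors guard maps a flat index into the concatenated dataset onto a (sub-dataset index, offset) pair. A minimal standalone sketch of that mapping, using a hypothetical make_lookup() name in place of the actual method:

    from typing import Dict, Sequence, Tuple

    def make_lookup(datasets: Sequence[Sequence]) -> Dict[int, Tuple[int, int]]:
        # Map flat index i -> (dataset index j, offset within dataset j).
        lookup: Dict[int, Tuple[int, int]] = {}
        j, offset = 0, 0
        for i in range(sum(len(d) for d in datasets)):
            if offset == len(datasets[j]):
                # Walked off the end of dataset j; continue in the next one.
                j += 1
                offset = 0
            lookup[i] = (j, offset)
            offset += 1
        # Same sanity checks as above: finish in the last dataset, one past its last datum.
        if j != len(datasets) - 1:
            raise RuntimeError(f"Expected to finish in dataset {len(datasets) - 1}, finished in {j}.")
        if offset != len(datasets[-1]):
            raise RuntimeError(f"Expected final offset {len(datasets[-1])}, got {offset}.")
        return lookup

    # make_lookup([[10, 11], [20, 21, 22]])
    # -> {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1), 4: (1, 2)}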
diff --git a/nam/models/_base.py b/nam/models/_base.py
@@ -3,7 +3,7 @@
# Author: Steven Atkinson (steven@atkinson.mn)
"""
-The foundation of the model without the PyTorch Lightning attributes (losses, training
+The foundation of the model without the PyTorch Lightning attributes (losses, training
steps)
"""
diff --git a/nam/models/base.py b/nam/models/base.py
@@ -220,7 +220,7 @@ class Model(pl.LightningModule, InitializableFromConfig):
return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
def forward(self, *args, **kwargs):
- return self.net(*args, **kwargs)
+ return self.net(*args, **kwargs) # TODO deprecate--use self.net() instead.
def _shared_step(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
"""
diff --git a/nam/models/metadata.py b/nam/models/metadata.py
@@ -23,6 +23,7 @@ class GearType(Enum):
PREAMP = "preamp"
STUDIO = "studio"
+
# Note: if you change this enum, you need to update the options in easy_colab.ipynb!
class ToneType(Enum):
CLEAN = "clean"
diff --git a/nam/train/_version.py b/nam/train/_version.py
@@ -14,7 +14,11 @@ class Version:
self.patch = patch
def __eq__(self, other) -> bool:
- return self.major == other.major and self.minor == other.minor and self.patch == other.patch
+ return (
+ self.major == other.major
+ and self.minor == other.minor
+ and self.patch == other.patch
+ )
def __lt__(self, other) -> bool:
if self.major != other.major:
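The hunk above reformats __eq__ and is cut off at the start of __lt__; the ordering is presumably lexicographic over (major, minor, patch), which the self-contained sketch below assumes (it is not the file's literal contents):

    class Version:
        # Sketch only; the real class lives in nam/train/_version.py.
        def __init__(self, major: int, minor: int, patch: int):
            self.major, self.minor, self.patch = major, minor, patch

        def __eq__(self, other) -> bool:
            return (
                self.major == other.major
                and self.minor == other.minor
                and self.patch == other.patch
            )

        def __lt__(self, other) -> bool:
            # Assumed ordering: compare major, then minor, then patch.
            if self.major != other.major:
                return self.major < other.major
            if self.minor != other.minor:
                return self.minor < other.minor
            return self.patch < other.patch

    # Version(1, 9, 9) < Version(2, 0, 0)  -> True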
diff --git a/nam/train/colab.py b/nam/train/colab.py
@@ -86,7 +86,7 @@ def run(
:param lr: The initial learning rate
:param lr_decay: The amount by which the learning rate decays each epoch
:param seed: RNG seed for reproducibility.
- :param user_metadata: To include in the exported model
+ :param user_metadata: User-specified metadata to include in the .nam file.
:param ignore_checks: Ignores the data quality checks and YOLOs it
"""
diff --git a/nam/train/core.py b/nam/train/core.py
@@ -8,15 +8,15 @@ Functions used by the GUI trainer.
import hashlib
import tkinter as tk
+from copy import deepcopy
from enum import Enum
from time import time
-from typing import Optional, Sequence, Union
+from typing import Dict, Optional, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
-import torch.nn as nn
from torch.utils.data import DataLoader
from ..data import REQUIRED_RATE, Split, init_dataset, wav_to_np, wav_to_tensor
@@ -134,6 +134,8 @@ def _detect_input_version(input_path) -> Version:
_V1_BLIP_LOCATIONS = 12_000, 36_000
_V2_START_BLIP_LOCATIONS = _V1_BLIP_LOCATIONS
_V2_END_BLIP_LOCATIONS = -36_000, -12_000
+_DELAY_CALIBRATION_ABS_THRESHOLD = 0.0001
+_DELAY_CALIBRATION_REL_THRESHOLD = 0.001
def _calibrate_delay_v1(
@@ -146,7 +148,10 @@ def _calibrate_delay_v1(
# Calibrate the trigger:
y = wav_to_np(output_path)[:48_000]
background_level = np.max(np.abs(y[:6_000]))
- trigger_threshold = max(background_level + 0.01, 1.01 * background_level)
+ trigger_threshold = max(
+ background_level + _DELAY_CALIBRATION_ABS_THRESHOLD,
+ (1.0 + _DELAY_CALIBRATION_REL_THRESHOLD) * background_level,
+ )
delays = []
for blip_index, i in enumerate(locations, 1):
@@ -186,7 +191,10 @@ def _calibrate_delay_v1(
return delay
-_calibrate_delay_v2 = _calibrate_delay_v1
+def _calibrate_delay_v2(
+ input_path, output_path, locations: Sequence[int] = _V2_START_BLIP_LOCATIONS
+) -> int:
+ return _calibrate_delay_v1(input_path, output_path, locations=locations)
def _plot_delay_v1(delay: int, input_path: str, output_path: str, _nofail=True):
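Together, the two new constants make the blip trigger explicit: a sample counts as the blip onset once it exceeds the measured background level by at least the absolute threshold and by at least the relative threshold. A rough standalone sketch of that detection step (the function name and scan window here are illustrative, not the trainer's API):

    import numpy as np

    _DELAY_CALIBRATION_ABS_THRESHOLD = 0.0001
    _DELAY_CALIBRATION_REL_THRESHOLD = 0.001

    def first_trigger_index(y: np.ndarray, start: int, background_samples: int = 6_000) -> int:
        # Background level is measured from the quiet lead-in of the recording.
        background_level = np.max(np.abs(y[:background_samples]))
        trigger_threshold = max(
            background_level + _DELAY_CALIBRATION_ABS_THRESHOLD,
            (1.0 + _DELAY_CALIBRATION_REL_THRESHOLD) * background_level,
        )
        # First sample at or after `start` that rises above the threshold.
        hits = np.where(np.abs(y[start:]) >= trigger_threshold)[0]
        if len(hits) == 0:
            raise RuntimeError("No blip found above the trigger threshold.")
        return start + int(hits[0])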
@@ -477,8 +485,8 @@ def _get_wavenet_config(architecture):
def _get_configs(
input_version: Version,
- input_basename: str,
- output_basename: str,
+ input_path: str,
+ output_path: str,
delay: int,
epochs: int,
model_type: str,
@@ -512,8 +520,8 @@ def _get_configs(
"train": {"ny": ny, **train_kwargs},
"validation": {"ny": None, **validation_kwargs},
"common": {
- "x_path": input_basename,
- "y_path": output_basename,
+ "x_path": input_path,
+ "y_path": output_path,
"delay": delay,
},
}
@@ -548,6 +556,7 @@ def _get_configs(
"optimizer": {"lr": 0.01},
"lr_scheduler": {"class": "ExponentialLR", "kwargs": {"gamma": 0.995}},
}
+ model_config["loss"]["mrstft_weight"] = 2e-4
if torch.cuda.is_available():
device_config = {"accelerator": "gpu", "devices": 1}
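The one-line mrstft_weight addition appears to be the "setting" the commit message leaves on in place of an explicit cab-modeling option: a multi-resolution STFT term, weighted at 2e-4, is added to the loss, presumably so the model can match the cab's frequency response as well as the waveform. A rough sketch of what such a term computes (not the trainer's actual implementation, which may use a dedicated library and different resolutions):

    import torch

    def mrstft_loss(pred: torch.Tensor, target: torch.Tensor, fft_sizes=(512, 1024, 2048)) -> torch.Tensor:
        # Mean L1 distance between log-magnitude spectrograms at several FFT sizes.
        losses = []
        for n_fft in fft_sizes:
            window = torch.hann_window(n_fft, device=pred.device)
            spec_p = torch.stft(pred, n_fft, hop_length=n_fft // 4, window=window, return_complex=True)
            spec_t = torch.stft(target, n_fft, hop_length=n_fft // 4, window=window, return_complex=True)
            losses.append(torch.mean(torch.abs(torch.log(spec_p.abs() + 1e-8) - torch.log(spec_t.abs() + 1e-8))))
        return torch.stack(losses).mean()

    # Combined with the waveform loss as, e.g., total = waveform_loss + 2e-4 * mrstft_loss(pred, target)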
@@ -570,6 +579,18 @@ def _get_configs(
return data_config, model_config, learning_config
+def _get_dataloaders(
+ data_config: Dict, learning_config: Dict, model: Model
+) -> Tuple[DataLoader, DataLoader]:
+ data_config, learning_config = [deepcopy(c) for c in (data_config, learning_config)]
+ data_config["common"]["nx"] = model.net.receptive_field
+ dataset_train = init_dataset(data_config, Split.TRAIN)
+ dataset_validation = init_dataset(data_config, Split.VALIDATION)
+ train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
+ val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+ return train_dataloader, val_dataloader
+
+
def _esr(pred: torch.Tensor, target: torch.Tensor) -> float:
return (
torch.mean(torch.square(pred - target)).item()
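For reference, _esr (whose first lines appear above) is the error-to-signal ratio reported after training; assuming the denominator cut off by the hunk is the target's mean square, it computes

    ESR(\hat{y}, y) = \frac{\mathbb{E}[(\hat{y} - y)^2]}{\mathbb{E}[y^2]}

so 0 is a perfect match and lower is better.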
@@ -729,11 +750,9 @@ def train(
print("Starting training. It's time to kick ass and chew bubblegum!")
model = Model.init_from_config(model_config)
- data_config["common"]["nx"] = model.net.receptive_field
- dataset_train = init_dataset(data_config, Split.TRAIN)
- dataset_validation = init_dataset(data_config, Split.VALIDATION)
- train_dataloader = DataLoader(dataset_train, **learning_config["train_dataloader"])
- val_dataloader = DataLoader(dataset_validation, **learning_config["val_dataloader"])
+ train_dataloader, val_dataloader = _get_dataloaders(
+ data_config, learning_config, model
+ )
trainer = pl.Trainer(
callbacks=[
@@ -782,7 +801,7 @@ def train(
_plot(
model,
- dataset_validation,
+ val_dataloader.dataset,
filepath=train_path + "/" + modelname if save_plot else None,
silent=silent,
**window_kwargs(input_version),
diff --git a/nam/train/gui.py b/nam/train/gui.py
@@ -257,37 +257,41 @@ class _GUI(object):
def _get_additional_options_frame(self):
# Checkboxes
- self._frame_silent = tk.Frame(self._root)
- self._frame_silent.pack(side=tk.LEFT)
+ self._frame_checkboxes = tk.Frame(self._root)
+ self._frame_checkboxes.pack(side=tk.LEFT)
# Silent run (bypass popups)
+ row = 1
self._silent = tk.BooleanVar()
- self._chkbox_silent = tk.Checkbutton(
- self._frame_silent,
+ self._checkbox_silent = tk.Checkbutton(
+ self._frame_checkboxes,
text="Silent run (suggested for batch training)",
variable=self._silent,
)
- self._chkbox_silent.grid(row=1, column=1, sticky="W")
+ self._checkbox_silent.grid(row=row, column=1, sticky="W")
+ row += 1
# Auto save the end plot
self._save_plot = tk.BooleanVar()
self._save_plot.set(True) # default this to true
- self._chkbox_save_plot = tk.Checkbutton(
- self._frame_silent,
+ self._checkbox_save_plot = tk.Checkbutton(
+ self._frame_checkboxes,
text="Save ESR plot automatically",
variable=self._save_plot,
)
- self._chkbox_save_plot.grid(row=2, column=1, sticky="W")
+ self._checkbox_save_plot.grid(row=row, column=1, sticky="W")
+ row += 1
# Skip the data quality checks!
self._ignore_checks = tk.BooleanVar()
self._ignore_checks.set(False)
- self._chkbox_ignore_checks = tk.Checkbutton(
- self._frame_silent,
+ self._checkbox_ignore_checks = tk.Checkbutton(
+ self._frame_checkboxes,
text="Ignore data quality checks (DO AT YOUR OWN RISK!)",
variable=self._ignore_checks,
)
- self._chkbox_ignore_checks.grid(row=3, column=1, sticky="W")
+ self._checkbox_ignore_checks.grid(row=row, column=1, sticky="W")
+ row += 1
def mainloop(self):
self._root.mainloop()
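The checkbox changes in gui.py are a rename (_frame_silent -> _frame_checkboxes, _chkbox_* -> _checkbox_*) plus a running row counter, so another checkbox can be added without renumbering the grid() calls below it. A toy illustration of the counter pattern, with abbreviated labels and hypothetical names:

    import tkinter as tk

    root = tk.Tk()
    frame_checkboxes = tk.Frame(root)
    frame_checkboxes.pack(side=tk.LEFT)

    variables = {}
    row = 1
    for label in ("Silent run", "Save ESR plot automatically", "Ignore data quality checks"):
        variables[label] = tk.BooleanVar()
        tk.Checkbutton(frame_checkboxes, text=label, variable=variables[label]).grid(
            row=row, column=1, sticky="W"
        )
        row += 1  # the counter, not a hard-coded literal, positions the next checkbox

    root.mainloop()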