neural-amp-modeler

Neural network emulator for guitar amplifiers
Log | Files | Refs | README | LICENSE

commit e3f2402d1432065b35b8395e49a2a9d72761e167
parent 8ed7f988273d61001202094f9ac8335c6eea878d
Author: sdatkinson <steven@atkinson.mn>
Date:   Sun, 21 Feb 2021 15:09:46 -0500

GPU

Diffstat:
Mmodels.py | 2+-
Mtrain.py | 20++++++++++++--------
2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/models.py b/models.py @@ -72,7 +72,7 @@ class Autoregressive(Model): n = x.numel() batch_size = batch_size or n # Pad x with leading zeros: - x = torch.cat((torch.zeros(self.input_length - 1), x)) + x = torch.cat((torch.zeros(self.input_length - 1).to(x.device), x)) i = 0 y = [] while i < n: diff --git a/train.py b/train.py @@ -30,9 +30,10 @@ check_training_at = np.concatenate( ) plot_kwargs = {"window": (30000, 40000)} wav_kwargs = {"window": (0, 5 * 44100)} -save_dir = "output" +save_dir = "output2" torch.manual_seed(0) +_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def ensure_save_dir(): @@ -67,7 +68,7 @@ class Data(object): ilist = torch.randint(0, self.n, size=(n,)) x = torch.stack([self.x[i : i + self.input_length] for i in ilist]) y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None] - return x, y + return x.to(_device), y.to(_device) def sequence(self, start, end): end += self.input_length - 1 @@ -155,9 +156,9 @@ def plot_predictions(model, data, title=None, fname=None, window=None): with torch.no_grad(): x, y, t = predict(model, data, window=window) plt.figure(figsize=(12, 4)) - plt.plot(x) - plt.plot(t) - plt.plot(y) + plt.plot(x.cpu()) + plt.plot(t.cpu()) + plt.plot(y.cpu()) plt.legend(("Input", "Target", "Prediction")) if title is not None: plt.title(title) @@ -182,7 +183,7 @@ def plot_loss(t_loss, v_loss, fname): def predict(model, data, window=None, save_wav_file=None): with torch.no_grad(): - x, t = data.x, data.y + x, t = data.x.to(_device), data.y.to(_device) if window is not None: x, t = x[window[0] : window[1]], t[window[0] : window[1]] y = model.predict_sequence(x).squeeze() @@ -192,7 +193,7 @@ def predict(model, data, window=None, save_wav_file=None): sampwidth = 3 # 24-bit wavio.write( save_wav_file, - y.numpy() * 2 ** 23, + y.cpu().numpy() * 2 ** 23, rate, scale="none", sampwidth=sampwidth, @@ -235,7 +236,7 @@ if __name__ == "__main__": validate_data = Data(args.validation_data, input_length) # Training - model = models.from_json(args.model_arch) + model = models.from_json(args.model_arch).to(_device) t_loss_list, v_loss_list = train( model, train_data, @@ -247,7 +248,10 @@ if __name__ == "__main__": wav_kwargs=wav_kwargs, ) plot_predictions(model, validate_data, window=(0, 44100)) + print("Predict the full output") + _device = "cpu" + model.to(_device) predict( model, validate_data,