commit e3f2402d1432065b35b8395e49a2a9d72761e167
parent 8ed7f988273d61001202094f9ac8335c6eea878d
Author: sdatkinson <steven@atkinson.mn>
Date: Sun, 21 Feb 2021 15:09:46 -0500
GPU
Diffstat:
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/models.py b/models.py
@@ -72,7 +72,7 @@ class Autoregressive(Model):
n = x.numel()
batch_size = batch_size or n
# Pad x with leading zeros:
- x = torch.cat((torch.zeros(self.input_length - 1), x))
+ x = torch.cat((torch.zeros(self.input_length - 1).to(x.device), x))
i = 0
y = []
while i < n:
diff --git a/train.py b/train.py
@@ -30,9 +30,10 @@ check_training_at = np.concatenate(
)
plot_kwargs = {"window": (30000, 40000)}
wav_kwargs = {"window": (0, 5 * 44100)}
-save_dir = "output"
+save_dir = "output2"
torch.manual_seed(0)
+_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def ensure_save_dir():
@@ -67,7 +68,7 @@ class Data(object):
ilist = torch.randint(0, self.n, size=(n,))
x = torch.stack([self.x[i : i + self.input_length] for i in ilist])
y = torch.stack([self.y[i + self.input_length - 1] for i in ilist])[:, None]
- return x, y
+ return x.to(_device), y.to(_device)
def sequence(self, start, end):
end += self.input_length - 1
@@ -155,9 +156,9 @@ def plot_predictions(model, data, title=None, fname=None, window=None):
with torch.no_grad():
x, y, t = predict(model, data, window=window)
plt.figure(figsize=(12, 4))
- plt.plot(x)
- plt.plot(t)
- plt.plot(y)
+ plt.plot(x.cpu())
+ plt.plot(t.cpu())
+ plt.plot(y.cpu())
plt.legend(("Input", "Target", "Prediction"))
if title is not None:
plt.title(title)
@@ -182,7 +183,7 @@ def plot_loss(t_loss, v_loss, fname):
def predict(model, data, window=None, save_wav_file=None):
with torch.no_grad():
- x, t = data.x, data.y
+ x, t = data.x.to(_device), data.y.to(_device)
if window is not None:
x, t = x[window[0] : window[1]], t[window[0] : window[1]]
y = model.predict_sequence(x).squeeze()
@@ -192,7 +193,7 @@ def predict(model, data, window=None, save_wav_file=None):
sampwidth = 3 # 24-bit
wavio.write(
save_wav_file,
- y.numpy() * 2 ** 23,
+ y.cpu().numpy() * 2 ** 23,
rate,
scale="none",
sampwidth=sampwidth,
@@ -235,7 +236,7 @@ if __name__ == "__main__":
validate_data = Data(args.validation_data, input_length)
# Training
- model = models.from_json(args.model_arch)
+ model = models.from_json(args.model_arch).to(_device)
t_loss_list, v_loss_list = train(
model,
train_data,
@@ -247,7 +248,10 @@ if __name__ == "__main__":
wav_kwargs=wav_kwargs,
)
plot_predictions(model, validate_data, window=(0, 44100))
+
print("Predict the full output")
+ _device = "cpu"
+ model.to(_device)
predict(
model,
validate_data,