commit aa489d720868354495f84e1f9f035e3869dd4dea
parent c80d4d44af2c80af985f5978f733334c9bbfd60b
Author: Steven Atkinson <[email protected]>
Date: Sat, 5 Jan 2019 13:20:54 -0500
Batched predicts to address memory issues, default n_train
Diffstat:
1 file changed, 17 insertions(+), 7 deletions(-)
diff --git a/models.py b/models.py
@@ -10,7 +10,7 @@ import os
import json
-def from_json(f, n_train, checkpoint_path=None):
+def from_json(f, n_train=1, checkpoint_path=None):
if isinstance(f, str):
with open(f, "r") as json_file:
f = json.load(json_file)
@@ -119,18 +119,28 @@ class Autoregressive(Model):
def input_length(self):
return self._input_length
- def predict(self, x):
+ def predict(self, x, batch_size=None, verbose=False):
"""
Return 1D array of predictions same length as x
"""
n = x.size
+ batch_size = batch_size or n
# Pad x with leading zeros:
x = np.concatenate((np.zeros(self.input_length - 1), x))
- # Reshape into a batch:
- x_mtx = np.stack([x[i: i + self.input_length] for i in range(n)])
- # Predict and flatten.
- return self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
- .flatten()
+ i = 0
+ y = []
+ while i < n:
+ if verbose:
+ print("model.predict {}/{}".format(i, n))
+ this_batch_size = np.min([batch_size, n - i])
+ # Reshape into a batch:
+ x_mtx = np.stack([x[j: j + self.input_length]
+ for j in range(i, i + this_batch_size)])
+ # Predict and flatten.
+ y.append(self.sess.run(self.prediction, feed_dict={self.x: x_mtx}) \
+ .flatten())
+ i += this_batch_size
+ return np.concatenate(y)
class FullyConnected(Autoregressive):