commit 4c789d9e136fe426cd6842b933922bf287a5cd53
parent 223c69260c6eaab67e689d172a9fb6f9195099d1
Author: mtseng15 <tseng.micah@gmail.com>
Date: Sat, 6 Nov 2021 20:59:58 -0700
Implemented a Templated system for loading models
Diffstat:
2 files changed, 23 insertions(+), 63 deletions(-)
diff --git a/Source/RTNeuralLSTM.cpp b/Source/RTNeuralLSTM.cpp
@@ -17,10 +17,14 @@ Vec2d transpose(const Vec2d& x)
return y;
}
-template <typename T1, typename T2>
-void RT_LSTM::set_weights(T1 lstm, T2 dense, const char* filename)
+template <typename T1>
+void RT_LSTM::set_weights(T1 model, const char* filename)
{
- // read a JSON file
+ // Initialize the correct model
+ auto& lstm = (*model).get<0>();
+ auto& dense = (*model).get<1>();
+
+ // read a JSON file
std::ifstream i2(filename);
nlohmann::json weights_json;
i2 >> weights_json;
@@ -46,70 +50,26 @@ void RT_LSTM::set_weights(T1 lstm, T2 dense, const char* filename)
}
void RT_LSTM::load_json(const char* filename)
{
- // Initialize the correct model
- auto& lstm = model.get<0>();
- auto& dense = model.get<1>();
-
- // set_weights(lstm, dense, filename);
-
- // read a JSON file
+ // Read in the JSON file
std::ifstream i2(filename);
- nlohmann::json weights_json;
- i2 >> weights_json;
-
- Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer];
- lstm.setWVals(transpose(lstm_weights_ih));
-
- Vec2d lstm_weights_hh = weights_json["/state_dict/rec.weight_hh_l0"_json_pointer];
- lstm.setUVals(transpose(lstm_weights_hh));
-
- std::vector<float> lstm_bias_ih = weights_json["/state_dict/rec.bias_ih_l0"_json_pointer];
- std::vector<float> lstm_bias_hh = weights_json["/state_dict/rec.bias_hh_l0"_json_pointer];
- for (int i = 0; i < 80; ++i)
- lstm_bias_hh[i] += lstm_bias_ih[i];
- lstm.setBVals(lstm_bias_hh);
-
- Vec2d dense_weights = weights_json["/state_dict/lin.weight"_json_pointer];
- dense.setWeights(dense_weights);
-
- std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer];
- dense.setBias(dense_bias.data());
-}
-
-void RT_LSTM::load_json2(const char* filename)
-{
+ nlohmann::json weights_json;
+ i2 >> weights_json;
- auto& lstm = model_cond1.get<0>();
- auto& dense = model_cond1.get<1>();
+ // Get the input size of the JSON file
+ int input_size_json = weights_json["/model_data/input_size"_json_pointer];
+ input_size = input_size_json;
- set_weights(lstm, dense, filename);
-
- // // read a JSON file
- // std::ifstream i2(filename);
- // nlohmann::json weights_json;
- // i2 >> weights_json;
-
- // Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer];
- // lstm.setWVals(transpose(lstm_weights_ih));
-
- // Vec2d lstm_weights_hh = weights_json["/state_dict/rec.weight_hh_l0"_json_pointer];
- // lstm.setUVals(transpose(lstm_weights_hh));
-
- // std::vector<float> lstm_bias_ih = weights_json["/state_dict/rec.bias_ih_l0"_json_pointer];
- // std::vector<float> lstm_bias_hh = weights_json["/state_dict/rec.bias_hh_l0"_json_pointer];
- // for (int i = 0; i < 80; ++i)
- // lstm_bias_hh[i] += lstm_bias_ih[i];
- // lstm.setBVals(lstm_bias_hh);
-
- // Vec2d dense_weights = weights_json["/state_dict/lin.weight"_json_pointer];
- // dense.setWeights(dense_weights);
+ // Load the appropriate model
+ if (input_size == 1) {
+ set_weights(&model, filename);
+ }
+ else {
+ set_weights(&model_cond1, filename);
+ }
- // std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer];
- // dense.setBias(dense_bias.data());
}
-
void RT_LSTM::reset()
{
if (input_size == 1) {
diff --git a/Source/RTNeuralLSTM.h b/Source/RTNeuralLSTM.h
@@ -9,9 +9,9 @@ public:
void reset();
void load_json(const char* filename);
- void load_json2(const char* filename);
- template <typename T1, typename T2>
- void set_weights(T1 lstm, T2 dense, const char* filename);
+ template <typename T1>
+
+ void set_weights(T1 model, const char* filename);
void process(const float* inData, float* outData, int numSamples);
void process(const float* inData, float param, float* outData, int numSamples);