commit 223c69260c6eaab67e689d172a9fb6f9195099d1
parent e88a8ec94c679f1562ff911880b3d855ae712ab1
Author: mtseng15 <[email protected]>
Date: Fri, 5 Nov 2021 18:55:27 -0700
started to implement templating
Diffstat:
3 files changed, 49 insertions(+), 14 deletions(-)
diff --git a/NeuralPi.jucer b/NeuralPi.jucer
@@ -79,9 +79,9 @@
</LINUX_MAKE>
<VS2019 targetFolder="Builds/VisualStudio2019">
<CONFIGURATIONS>
- <CONFIGURATION isDebug="1" name="Debug" targetName="NeuralPi" headerPath="C:\Users\rache\Desktop\dev\json-develop\include C:\Users\rache\Desktop\dev\NeuralPi\modules\RTNeural C:\Users\rache\Desktop\dev\NeuralPi\modules\RTNeural\modules\xsimd\include"
+ <CONFIGURATION isDebug="1" name="Debug" targetName="NeuralPi" headerPath="C:\Users\tseng\Personal\mlamsk\NeuralPi\modules\json\include C:\Users\tseng\Personal\mlamsk\NeuralPi\modules\RTNeural C:\Users\tseng\Personal\mlamsk\NeuralPi\modules\RTNeural\modules\xsimd\include"
defines="USE_XSIMD=1"/>
- <CONFIGURATION isDebug="0" name="Release" targetName="NeuralPi" headerPath="C:\Users\rache\Desktop\dev\json-develop\include C:\Users\rache\Desktop\dev\NeuralPi\modules\RTNeural C:\Users\rache\Desktop\dev\NeuralPi\modules\RTNeural\modules\xsimd\include"
+ <CONFIGURATION isDebug="0" name="Release" targetName="NeuralPi" headerPath="C:\Users\tseng\Personal\mlamsk\NeuralPi\modules\json\include C:\Users\tseng\Personal\mlamsk\NeuralPi\modules\RTNeural C:\Users\tseng\Personal\mlamsk\NeuralPi\modules\RTNeural\modules\xsimd\include"
defines="USE_XSIMD=1"/>
</CONFIGURATIONS>
<MODULEPATHS>
diff --git a/Source/RTNeuralLSTM.cpp b/Source/RTNeuralLSTM.cpp
@@ -17,13 +17,10 @@ Vec2d transpose(const Vec2d& x)
return y;
}
-void RT_LSTM::load_json(const char* filename)
+template <typename T1, typename T2>
+void RT_LSTM::set_weights(T1 lstm, T2 dense, const char* filename)
{
-
- auto& lstm = model.get<0>();
- auto& dense = model.get<1>();
-
- // read a JSON file
+ // read a JSON file
std::ifstream i2(filename);
nlohmann::json weights_json;
i2 >> weights_json;
@@ -45,19 +42,21 @@ void RT_LSTM::load_json(const char* filename)
std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer];
dense.setBias(dense_bias.data());
+
}
-
-void RT_LSTM::load_json2(const char* filename)
+void RT_LSTM::load_json(const char* filename)
{
-
- auto& lstm = model_cond1.get<0>();
- auto& dense = model_cond1.get<1>();
+ // Initialize the correct model
+ auto& lstm = model.get<0>();
+ auto& dense = model.get<1>();
+
+ // set_weights(lstm, dense, filename);
// read a JSON file
std::ifstream i2(filename);
nlohmann::json weights_json;
i2 >> weights_json;
-
+
Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer];
lstm.setWVals(transpose(lstm_weights_ih));
@@ -77,6 +76,40 @@ void RT_LSTM::load_json2(const char* filename)
dense.setBias(dense_bias.data());
}
+void RT_LSTM::load_json2(const char* filename)
+{
+
+ auto& lstm = model_cond1.get<0>();
+ auto& dense = model_cond1.get<1>();
+
+ set_weights(lstm, dense, filename);
+
+ // // read a JSON file
+ // std::ifstream i2(filename);
+ // nlohmann::json weights_json;
+ // i2 >> weights_json;
+
+ // Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer];
+ // lstm.setWVals(transpose(lstm_weights_ih));
+
+ // Vec2d lstm_weights_hh = weights_json["/state_dict/rec.weight_hh_l0"_json_pointer];
+ // lstm.setUVals(transpose(lstm_weights_hh));
+
+ // std::vector<float> lstm_bias_ih = weights_json["/state_dict/rec.bias_ih_l0"_json_pointer];
+ // std::vector<float> lstm_bias_hh = weights_json["/state_dict/rec.bias_hh_l0"_json_pointer];
+ // for (int i = 0; i < 80; ++i)
+ // lstm_bias_hh[i] += lstm_bias_ih[i];
+ // lstm.setBVals(lstm_bias_hh);
+
+ // Vec2d dense_weights = weights_json["/state_dict/lin.weight"_json_pointer];
+ // dense.setWeights(dense_weights);
+
+ // std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer];
+ // dense.setBias(dense_bias.data());
+}
+
+
+
void RT_LSTM::reset()
{
if (input_size == 1) {
diff --git a/Source/RTNeuralLSTM.h b/Source/RTNeuralLSTM.h
@@ -10,6 +10,8 @@ public:
void reset();
void load_json(const char* filename);
void load_json2(const char* filename);
+ template <typename T1, typename T2>
+ void set_weights(T1 lstm, T2 dense, const char* filename);
void process(const float* inData, float* outData, int numSamples);
void process(const float* inData, float param, float* outData, int numSamples);