// RTNeuralLSTM.cpp
1 #include "RTNeuralLSTM.h" 2 3 using Vec2d = std::vector<std::vector<float>>; 4 5 Vec2d transpose(const Vec2d& x) 6 { 7 auto outer_size = x.size(); 8 auto inner_size = x[0].size(); 9 Vec2d y(inner_size, std::vector<float>(outer_size, 0.0f)); 10 11 for (size_t i = 0; i < outer_size; ++i) 12 { 13 for (size_t j = 0; j < inner_size; ++j) 14 y[j][i] = x[i][j]; 15 } 16 17 return y; 18 } 19 20 template <typename T1> 21 void RT_LSTM::set_weights(T1 model, const char* filename) 22 { 23 // Initialize the correct model 24 auto& lstm = (*model).template get<0>(); 25 auto& dense = (*model).template get<1>(); 26 27 // read a JSON file 28 std::ifstream i2(filename); 29 nlohmann::json weights_json; 30 i2 >> weights_json; 31 32 Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer]; 33 lstm.setWVals(transpose(lstm_weights_ih)); 34 35 Vec2d lstm_weights_hh = weights_json["/state_dict/rec.weight_hh_l0"_json_pointer]; 36 lstm.setUVals(transpose(lstm_weights_hh)); 37 38 std::vector<float> lstm_bias_ih = weights_json["/state_dict/rec.bias_ih_l0"_json_pointer]; 39 std::vector<float> lstm_bias_hh = weights_json["/state_dict/rec.bias_hh_l0"_json_pointer]; 40 for (int i = 0; i < 80; ++i) 41 lstm_bias_hh[i] += lstm_bias_ih[i]; 42 lstm.setBVals(lstm_bias_hh); 43 44 Vec2d dense_weights = weights_json["/state_dict/lin.weight"_json_pointer]; 45 dense.setWeights(dense_weights); 46 47 std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer]; 48 dense.setBias(dense_bias.data()); 49 50 } 51 void RT_LSTM::load_json(const char* filename) 52 { 53 // Read in the JSON file 54 std::ifstream i2(filename); 55 nlohmann::json weights_json; 56 i2 >> weights_json; 57 58 // Get the input size of the JSON file 59 int input_size_json = weights_json["/model_data/input_size"_json_pointer]; 60 input_size = input_size_json; 61 62 // Load the appropriate model 63 if (input_size == 1) { 64 set_weights(&model, filename); 65 } 66 else if (input_size == 2) { 67 
set_weights(&model_cond1, filename); 68 } 69 else if (input_size == 3) { 70 set_weights(&model_cond2, filename); 71 } 72 } 73 74 75 void RT_LSTM::reset() 76 { 77 if (input_size == 1) { 78 model.reset(); 79 } else { 80 model_cond1.reset(); 81 } 82 } 83 84 void RT_LSTM::process(const float* inData, float* outData, int numSamples) 85 { 86 for (int i = 0; i < numSamples; ++i) 87 outData[i] = model.forward(inData + i) + inData[i]; 88 } 89 90 void RT_LSTM::process(const float* inData, float param, float* outData, int numSamples) 91 { 92 for (int i = 0; i < numSamples; ++i) { 93 inArray1[0] = inData[i]; 94 inArray1[1] = param; 95 outData[i] = model_cond1.forward(inArray1) + inData[i]; 96 } 97 } 98 99 void RT_LSTM::process(const float* inData, float param1, float param2, float* outData, int numSamples) 100 { 101 for (int i = 0; i < numSamples; ++i) { 102 inArray2[0] = inData[i]; 103 inArray2[1] = param1; 104 inArray2[2] = param2; 105 outData[i] = model_cond2.forward(inArray2) + inData[i]; 106 } 107 } 108