NeuralPi

Raspberry Pi guitar pedal using neural networks to emulate real amps and effects
Log | Files | Refs | Submodules | README

commit 4c789d9e136fe426cd6842b933922bf287a5cd53
parent 223c69260c6eaab67e689d172a9fb6f9195099d1
Author: mtseng15 <tseng.micah@gmail.com>
Date:   Sat,  6 Nov 2021 20:59:58 -0700

Implemented a Templated system for loading models

Diffstat:
MSource/RTNeuralLSTM.cpp | 80++++++++++++++++++++-----------------------------------------------------------
MSource/RTNeuralLSTM.h | 6+++---
2 files changed, 23 insertions(+), 63 deletions(-)

diff --git a/Source/RTNeuralLSTM.cpp b/Source/RTNeuralLSTM.cpp @@ -17,10 +17,14 @@ Vec2d transpose(const Vec2d& x) return y; } -template <typename T1, typename T2> -void RT_LSTM::set_weights(T1 lstm, T2 dense, const char* filename) +template <typename T1> +void RT_LSTM::set_weights(T1 model, const char* filename) { - // read a JSON file + // Initialize the correct model + auto& lstm = (*model).get<0>(); + auto& dense = (*model).get<1>(); + + // read a JSON file std::ifstream i2(filename); nlohmann::json weights_json; i2 >> weights_json; @@ -46,70 +50,26 @@ void RT_LSTM::set_weights(T1 lstm, T2 dense, const char* filename) } void RT_LSTM::load_json(const char* filename) { - // Initialize the correct model - auto& lstm = model.get<0>(); - auto& dense = model.get<1>(); - - // set_weights(lstm, dense, filename); - - // read a JSON file + // Read in the JSON file std::ifstream i2(filename); - nlohmann::json weights_json; - i2 >> weights_json; - - Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer]; - lstm.setWVals(transpose(lstm_weights_ih)); - - Vec2d lstm_weights_hh = weights_json["/state_dict/rec.weight_hh_l0"_json_pointer]; - lstm.setUVals(transpose(lstm_weights_hh)); - - std::vector<float> lstm_bias_ih = weights_json["/state_dict/rec.bias_ih_l0"_json_pointer]; - std::vector<float> lstm_bias_hh = weights_json["/state_dict/rec.bias_hh_l0"_json_pointer]; - for (int i = 0; i < 80; ++i) - lstm_bias_hh[i] += lstm_bias_ih[i]; - lstm.setBVals(lstm_bias_hh); - - Vec2d dense_weights = weights_json["/state_dict/lin.weight"_json_pointer]; - dense.setWeights(dense_weights); - - std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer]; - dense.setBias(dense_bias.data()); -} - -void RT_LSTM::load_json2(const char* filename) -{ + nlohmann::json weights_json; + i2 >> weights_json; - auto& lstm = model_cond1.get<0>(); - auto& dense = model_cond1.get<1>(); + // Get the input size of the JSON file + int input_size_json = weights_json["/model_data/input_size"_json_pointer]; + input_size = input_size_json; - set_weights(lstm, dense, filename); - - // // read a JSON file - // std::ifstream i2(filename); - // nlohmann::json weights_json; - // i2 >> weights_json; - - // Vec2d lstm_weights_ih = weights_json["/state_dict/rec.weight_ih_l0"_json_pointer]; - // lstm.setWVals(transpose(lstm_weights_ih)); - - // Vec2d lstm_weights_hh = weights_json["/state_dict/rec.weight_hh_l0"_json_pointer]; - // lstm.setUVals(transpose(lstm_weights_hh)); - - // std::vector<float> lstm_bias_ih = weights_json["/state_dict/rec.bias_ih_l0"_json_pointer]; - // std::vector<float> lstm_bias_hh = weights_json["/state_dict/rec.bias_hh_l0"_json_pointer]; - // for (int i = 0; i < 80; ++i) - // lstm_bias_hh[i] += lstm_bias_ih[i]; - // lstm.setBVals(lstm_bias_hh); - - // Vec2d dense_weights = weights_json["/state_dict/lin.weight"_json_pointer]; - // dense.setWeights(dense_weights); + // Load the appropriate model + if (input_size == 1) { + set_weights(&model, filename); + } + else { + set_weights(&model_cond1, filename); + } - // std::vector<float> dense_bias = weights_json["/state_dict/lin.bias"_json_pointer]; - // dense.setBias(dense_bias.data()); } - void RT_LSTM::reset() { if (input_size == 1) { diff --git a/Source/RTNeuralLSTM.h b/Source/RTNeuralLSTM.h @@ -9,9 +9,9 @@ public: void reset(); void load_json(const char* filename); - void load_json2(const char* filename); - template <typename T1, typename T2> - void set_weights(T1 lstm, T2 dense, const char* filename); + template <typename T1> + + void set_weights(T1 model, const char* filename); void process(const float* inData, float* outData, int numSamples); void process(const float* inData, float param, float* outData, int numSamples);