// STNModel.h
1 #pragma once 2 3 #include <JuceHeader.h> 4 #include <RTNeural/RTNeural.h> 5 6 #define USE_RTNEURAL_POLY 0 7 #define USE_RTNEURAL_STATIC 1 8 9 namespace STNSpace 10 { 11 class STNModel 12 { 13 public: 14 STNModel(); 15 STNModel (STNModel&&) noexcept = default; 16 17 inline double forward (const double* input) noexcept 18 { 19 #if USE_RTNEURAL_STATIC 20 return model.forward (input); 21 #elif USE_RTNEURAL_POLY 22 return model->forward (input); 23 #endif 24 } 25 26 void loadModel (const nlohmann::json& modelJ); 27 28 private: 29 #if USE_RTNEURAL_STATIC 30 RTNeural::ModelT<double, 5, 1, RTNeural::DenseT<double, 5, 4>, RTNeural::TanhActivationT<double, 4>, RTNeural::DenseT<double, 4, 4>, RTNeural::TanhActivationT<double, 4>, RTNeural::DenseT<double, 4, 1>> model; 31 #elif USE_RTNEURAL_POLY 32 std::unique_ptr<RTNeural::Model<double>> model; 33 #endif 34 35 JUCE_DECLARE_NON_COPYABLE_WITH_LEAK_DETECTOR (STNModel) 36 }; 37 38 } // namespace STNSpace