STNTest.cpp (4171B)
1 #include "Processors/Hysteresis/HysteresisSTN.h" 2 3 namespace STNTestUtils 4 { 5 constexpr double sampleRate = 48000.0; 6 constexpr double trainingSampleRate = 96000.0; 7 constexpr auto sampleRateCorr = trainingSampleRate / sampleRate; 8 9 double input alignas (16)[5] = { 1.0, 1.0, 1.0, 1.0, 1.0 }; 10 } // namespace STNTestUtils 11 12 class STNTest : public UnitTest 13 { 14 public: 15 STNTest() : UnitTest ("STNTest") 16 { 17 } 18 19 void runTest() override 20 { 21 #if JUCE_LINUX 22 return; // @TODO: figure out why this fails! 23 #endif 24 beginTest ("STN Accuracy Test"); 25 accTest(); 26 27 beginTest ("STN Performance Test"); 28 perfTest(); // Keep this disabled most of the time for CI 29 } 30 31 void accTest() 32 { 33 using namespace STNTestUtils; 34 35 HysteresisSTN stn; 36 stn.prepare (sampleRate); 37 stn.setParams (0.5f, 0.5f); 38 39 auto refModel = loadModel(); 40 41 for (int i = 0; i < 10; ++i) 42 { 43 auto x = stn.process (input); 44 auto xRef = refModel->forward (input) * sampleRateCorr; 45 expectWithinAbsoluteError (x, xRef, 1.0e-15, "STN output is incorrect!"); 46 } 47 } 48 49 void perfTest() 50 { 51 using namespace STNTestUtils; 52 53 HysteresisSTN stn; 54 stn.prepare (sampleRate); 55 stn.setParams (0.5f, 0.5f); 56 auto refModel = loadModel(); 57 58 constexpr int nIter = 400000; 59 double result = 0.0; 60 61 // ref timing 62 double durationRef = 0.0f; 63 { 64 Time time; 65 auto start = time.getMillisecondCounterHiRes(); 66 for (int i = 0; i < nIter; ++i) 67 result = refModel->forward (input) * sampleRateCorr; 68 auto end = time.getMillisecondCounterHiRes(); 69 durationRef = (end - start) / 1000.0; 70 } 71 std::cout << "Reference output: " << result << std::endl; 72 std::cout << "Reference duration: " << durationRef << std::endl; 73 74 // static STN timing 75 auto durationStatic = durationRef; 76 { 77 auto jsonStream = std::make_unique<MemoryInputStream> (BinaryData::hyst_width_50_json, BinaryData::hyst_width_50_jsonSize, false); 78 auto modelsJson = nlohmann::json::parse (jsonStream->readEntireStreamAsString().toStdString()); 79 auto thisModelJson = modelsJson["drive_50_50"]; 80 RTNeural::ModelT<double, 5, 1, RTNeural::DenseT<double, 5, 4>, RTNeural::TanhActivationT<double, 4>, RTNeural::DenseT<double, 4, 4>, RTNeural::TanhActivationT<double, 4>, RTNeural::DenseT<double, 4, 1>> staticModel; 81 staticModel.parseJson (thisModelJson); 82 83 Time time; 84 auto start = time.getMillisecondCounterHiRes(); 85 for (int i = 0; i < nIter; ++i) 86 result = staticModel.forward (input) * sampleRateCorr; 87 auto end = time.getMillisecondCounterHiRes(); 88 durationStatic = (end - start) / 1000.0; 89 } 90 std::cout << "Static output: " << result << std::endl; 91 std::cout << "Static duration: " << durationStatic << std::endl; 92 93 // plugin timing 94 auto durationReal = durationRef; 95 { 96 Time time; 97 auto start = time.getMillisecondCounterHiRes(); 98 for (int i = 0; i < nIter; ++i) 99 result = stn.process (input); 100 auto end = time.getMillisecondCounterHiRes(); 101 durationReal = (end - start) / 1000.0; 102 } 103 std::cout << "Actual output: " << result << std::endl; 104 std::cout << "Actual duration: " << durationReal << std::endl; 105 106 expectLessThan (durationReal, durationRef * 1.25, "Plugin STN processing is too slow!"); 107 } 108 109 std::unique_ptr<RTNeural::Model<double>> loadModel() 110 { 111 auto jsonStream = std::make_unique<MemoryInputStream> (BinaryData::hyst_width_50_json, BinaryData::hyst_width_50_jsonSize, false); 112 auto modelsJson = nlohmann::json::parse (jsonStream->readEntireStreamAsString().toStdString()); 113 auto thisModelJson = modelsJson["drive_50_50"]; 114 return RTNeural::json_parser::parseJson<double> (thisModelJson); 115 } 116 }; 117 118 #if ! JUCE_MAC 119 static STNTest stnTest; 120 #endif