AnalogTapeModel

Physical modelling signal processing for analog tape recording
Log | Files | Refs | Submodules | README | LICENSE

STNTest.cpp (4171B)


      1 #include "Processors/Hysteresis/HysteresisSTN.h"
      2 
      3 namespace STNTestUtils
      4 {
      5 constexpr double sampleRate = 48000.0;
      6 constexpr double trainingSampleRate = 96000.0;
      7 constexpr auto sampleRateCorr = trainingSampleRate / sampleRate;
      8 
      9 double input alignas (16)[5] = { 1.0, 1.0, 1.0, 1.0, 1.0 };
     10 } // namespace STNTestUtils
     11 
     12 class STNTest : public UnitTest
     13 {
     14 public:
     15     STNTest() : UnitTest ("STNTest")
     16     {
     17     }
     18 
     19     void runTest() override
     20     {
     21 #if JUCE_LINUX
     22         return; // @TODO: figure out why this fails!
     23 #endif
     24         beginTest ("STN Accuracy Test");
     25         accTest();
     26 
     27         beginTest ("STN Performance Test");
     28         perfTest(); // Keep this disabled most of the time for CI
     29     }
     30 
     31     void accTest()
     32     {
     33         using namespace STNTestUtils;
     34 
     35         HysteresisSTN stn;
     36         stn.prepare (sampleRate);
     37         stn.setParams (0.5f, 0.5f);
     38 
     39         auto refModel = loadModel();
     40 
     41         for (int i = 0; i < 10; ++i)
     42         {
     43             auto x = stn.process (input);
     44             auto xRef = refModel->forward (input) * sampleRateCorr;
     45             expectWithinAbsoluteError (x, xRef, 1.0e-15, "STN output is incorrect!");
     46         }
     47     }
     48 
     49     void perfTest()
     50     {
     51         using namespace STNTestUtils;
     52 
     53         HysteresisSTN stn;
     54         stn.prepare (sampleRate);
     55         stn.setParams (0.5f, 0.5f);
     56         auto refModel = loadModel();
     57 
     58         constexpr int nIter = 400000;
     59         double result = 0.0;
     60 
     61         // ref timing
     62         double durationRef = 0.0f;
     63         {
     64             Time time;
     65             auto start = time.getMillisecondCounterHiRes();
     66             for (int i = 0; i < nIter; ++i)
     67                 result = refModel->forward (input) * sampleRateCorr;
     68             auto end = time.getMillisecondCounterHiRes();
     69             durationRef = (end - start) / 1000.0;
     70         }
     71         std::cout << "Reference output: " << result << std::endl;
     72         std::cout << "Reference duration: " << durationRef << std::endl;
     73 
     74         // static STN timing
     75         auto durationStatic = durationRef;
     76         {
     77             auto jsonStream = std::make_unique<MemoryInputStream> (BinaryData::hyst_width_50_json, BinaryData::hyst_width_50_jsonSize, false);
     78             auto modelsJson = nlohmann::json::parse (jsonStream->readEntireStreamAsString().toStdString());
     79             auto thisModelJson = modelsJson["drive_50_50"];
     80             RTNeural::ModelT<double, 5, 1, RTNeural::DenseT<double, 5, 4>, RTNeural::TanhActivationT<double, 4>, RTNeural::DenseT<double, 4, 4>, RTNeural::TanhActivationT<double, 4>, RTNeural::DenseT<double, 4, 1>> staticModel;
     81             staticModel.parseJson (thisModelJson);
     82 
     83             Time time;
     84             auto start = time.getMillisecondCounterHiRes();
     85             for (int i = 0; i < nIter; ++i)
     86                 result = staticModel.forward (input) * sampleRateCorr;
     87             auto end = time.getMillisecondCounterHiRes();
     88             durationStatic = (end - start) / 1000.0;
     89         }
     90         std::cout << "Static output: " << result << std::endl;
     91         std::cout << "Static duration: " << durationStatic << std::endl;
     92 
     93         // plugin timing
     94         auto durationReal = durationRef;
     95         {
     96             Time time;
     97             auto start = time.getMillisecondCounterHiRes();
     98             for (int i = 0; i < nIter; ++i)
     99                 result = stn.process (input);
    100             auto end = time.getMillisecondCounterHiRes();
    101             durationReal = (end - start) / 1000.0;
    102         }
    103         std::cout << "Actual output: " << result << std::endl;
    104         std::cout << "Actual duration: " << durationReal << std::endl;
    105 
    106         expectLessThan (durationReal, durationRef * 1.25, "Plugin STN processing is too slow!");
    107     }
    108 
    109     std::unique_ptr<RTNeural::Model<double>> loadModel()
    110     {
    111         auto jsonStream = std::make_unique<MemoryInputStream> (BinaryData::hyst_width_50_json, BinaryData::hyst_width_50_jsonSize, false);
    112         auto modelsJson = nlohmann::json::parse (jsonStream->readEntireStreamAsString().toStdString());
    113         auto thisModelJson = modelsJson["drive_50_50"];
    114         return RTNeural::json_parser::parseJson<double> (thisModelJson);
    115     }
    116 };
    117 
    118 #if ! JUCE_MAC
    119 static STNTest stnTest;
    120 #endif