HysteresisSTN.cpp (3016B)
1 #include "HysteresisSTN.h" 2 #include <future> 3 4 namespace 5 { 6 constexpr double trainingSampleRate = 96e3; 7 8 constexpr float satIdxMult = (float) HysteresisSTN::numSatModels - 1.0f; 9 constexpr float widthIdxMult = (float) HysteresisSTN::numWidthModels - 1.0f; 10 11 const std::array<String, HysteresisSTN::numWidthModels> widthTags { "0", "10", "20", "30", "40", "50", "60", "70", "80", "90", "100" }; 12 13 const std::array<String, HysteresisSTN::numSatModels> satTags { "0", "5", "10", "15", "20", "25", "30", "35", "40", "45", "50", "55", "60", "65", "70", "75", "80", "85", "90", "95", "100" }; 14 } // namespace 15 16 constexpr size_t getSatIdx (float satParam) 17 { 18 return (size_t) std::clamp (static_cast<int> (satIdxMult * satParam), 0, HysteresisSTN::numSatModels - 1); 19 } 20 21 constexpr size_t getWidthIdx (float widthParam) 22 { 23 return (size_t) std::clamp (static_cast<int> (widthIdxMult * widthParam), 0, HysteresisSTN::numWidthModels - 1); 24 } 25 26 std::unique_ptr<MemoryInputStream> getModelFileStream (const String& modelFile) 27 { 28 std::unique_ptr<MemoryInputStream> stream; 29 for (int i = 0; i < BinaryData::namedResourceListSize; ++i) 30 { 31 if (String (BinaryData::originalFilenames[i]) == modelFile) 32 { 33 int fileSize = 0; 34 auto* fileData = BinaryData::getNamedResource (BinaryData::namedResourceList[i], fileSize); 35 stream = std::make_unique<MemoryInputStream> (fileData, fileSize, false); 36 return std::move (stream); 37 } 38 } 39 40 return {}; 41 } 42 43 HysteresisSTN::HysteresisSTN() 44 { 45 // Since we have a lot of models to load 46 // let's split them up and load them asychronously! 47 // This cuts down the model loading time for both 48 // channels from ~100 ms to ~30 ms 49 size_t widthLoadIdx = 0; 50 std::vector<std::future<void>> futures; 51 for (const auto& width : widthTags) 52 { 53 auto loadModelSet = [=] (size_t widthModelIdx) 54 { 55 auto modelsStream = getModelFileStream ("hyst_width_" + width + ".json"); 56 jassert (modelsStream != nullptr); 57 58 auto modelsJson = nlohmann::json::parse (modelsStream->readEntireStreamAsString().toStdString()); 59 size_t satLoadIdx = 0; 60 for (const auto& sat : satTags) 61 { 62 String modelTag = "drive_" + sat + "_" + width; 63 auto thisModelJson = modelsJson[modelTag.toStdString()]; 64 stnModels[widthModelIdx][satLoadIdx].loadModel (thisModelJson); 65 satLoadIdx++; 66 } 67 }; 68 69 futures.push_back (std::async (std::launch::async, 70 [=, &widthLoadIdx] 71 { loadModelSet (widthLoadIdx++); })); 72 } 73 74 for (auto& f : futures) 75 f.wait(); 76 } 77 78 void HysteresisSTN::prepare (double sampleRate) 79 { 80 sampleRateCorr = trainingSampleRate / sampleRate; 81 } 82 83 void HysteresisSTN::setParams (float saturation, float width) 84 { 85 satIdx = getSatIdx (saturation); 86 widthIdx = getWidthIdx (width); 87 }