commit c0ee34aeb6429cf8a5a630d049dea479d610cdd9
parent 3e97d16af133075c579274d5bb7722a7e528c9ba
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Wed, 7 Sep 2016 17:11:45 +0300
DFT plan cache
Diffstat:
3 files changed, 112 insertions(+), 10 deletions(-)
diff --git a/include/kfr/dft/cache.hpp b/include/kfr/dft/cache.hpp
@@ -0,0 +1,99 @@
+/** @addtogroup dft
+ * @{
+ */
+/*
+ Copyright (C) 2016 D Levin (https://www.kfrlib.com)
+ This file is part of KFR
+
+ KFR is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ KFR is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with KFR.
+
+ If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
+ Buying a commercial license is mandatory as soon as you develop commercial activities without
+ disclosing the source code of your own applications.
+ See https://www.kfrlib.com for details.
+ */
+#pragma once
+
+#include "fft.hpp"
+#include <memory>
+#include <mutex>
+#include <vector>
+
+namespace kfr
+{
+
+template <typename T>
+using dft_plan_ptr = std::shared_ptr<const dft_plan<T>>;
+
+struct dft_cache
+{
+ static dft_cache& instance()
+ {
+ static dft_cache cache;
+ return cache;
+ }
+ dft_plan_ptr<f32> get(ctype_t<f32>, size_t size)
+ {
+#ifndef KFR_SINGLE_THREAD
+ std::lock_guard<std::mutex> guard(mutex);
+#endif
+ return get_or_create(cache_f32, size);
+ }
+ dft_plan_ptr<f64> get(ctype_t<f64>, size_t size)
+ {
+#ifndef KFR_SINGLE_THREAD
+ std::lock_guard<std::mutex> guard(mutex);
+#endif
+ return get_or_create(cache_f64, size);
+ }
+ void clear()
+ {
+#ifndef KFR_SINGLE_THREAD
+ std::lock_guard<std::mutex> guard(mutex);
+#endif
+ cache_f32.clear();
+ cache_f64.clear();
+ }
+
+private:
+ template <typename T>
+ std::shared_ptr<const dft_plan<T>> get_or_create(std::vector<dft_plan_ptr<T>>& cache, size_t size)
+ {
+ for (dft_plan_ptr<T>& dft : cache)
+ {
+ if (dft->size == size)
+ return dft;
+ }
+ dft_plan_ptr<T> sh = std::make_shared<dft_plan<T>>(size);
+ cache.push_back(sh);
+ return sh;
+ }
+
+ std::vector<dft_plan_ptr<f32>> cache_f32;
+ std::vector<dft_plan_ptr<f64>> cache_f64;
+#ifndef KFR_SINGLE_THREAD
+ std::mutex mutex;
+#endif
+};
+
+template <typename T, size_t Tag>
+univector<complex<T>> dft(const univector<complex<T>, Tag>& input)
+{
+ dft_plan_ptr<T> dft = dft_cache::instance().get(ctype<T>, input.size());
+ univector<T> output(input.size());
+ univector<u8> temp(dft->temp_size);
+ dft->execute(output, input, temp);
+ return output;
+}
+}
diff --git a/include/kfr/dft/conv.hpp b/include/kfr/dft/conv.hpp
@@ -31,6 +31,7 @@
#include "../base/read_write.hpp"
#include "../base/vec.hpp"
+#include "cache.hpp"
#include "fft.hpp"
#pragma clang diagnostic push
@@ -49,12 +50,13 @@ KFR_INTRIN univector<T> convolve(const univector<T, Tag1>& src1, const univector
univector<complex<T>> src2padded = src2;
src1padded.resize(size, 0);
src2padded.resize(size, 0);
- dft_plan<T> plan(size);
- univector<u8> temp(plan.temp_size);
- plan.execute(src1padded, src1padded, temp);
- plan.execute(src2padded, src2padded, temp);
+
+ dft_plan_ptr<T> dft = dft_cache::instance().get(ctype<T>, size);
+ univector<u8> temp(dft->temp_size);
+ dft->execute(src1padded, src1padded, temp);
+ dft->execute(src2padded, src2padded, temp);
src1padded = src1padded * src2padded;
- plan.execute(src1padded, src1padded, temp, true);
+ dft->execute(src1padded, src1padded, temp, true);
const T invsize = reciprocal<T>(size);
return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
}
@@ -67,12 +69,12 @@ KFR_INTRIN univector<T> correlate(const univector<T, Tag1>& src1, const univecto
univector<complex<T>> src2padded = reverse(src2);
src1padded.resize(size, 0);
src2padded.resize(size, 0);
- dft_plan<T> plan(size);
- univector<u8> temp(plan.temp_size);
- plan.execute(src1padded, src1padded, temp);
- plan.execute(src2padded, src2padded, temp);
+ dft_plan_ptr<T> dft = dft_cache::instance().get(ctype<T>, size);
+ univector<u8> temp(dft->temp_size);
+ dft->execute(src1padded, src1padded, temp);
+ dft->execute(src2padded, src2padded, temp);
src1padded = src1padded * src2padded;
- plan.execute(src1padded, src1padded, temp, true);
+ dft->execute(src1padded, src1padded, temp, true);
const T invsize = reciprocal<T>(size);
return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
}
diff --git a/sources.cmake b/sources.cmake
@@ -68,6 +68,7 @@ set(
${PROJECT_SOURCE_DIR}/include/kfr/data/bitrev.hpp
${PROJECT_SOURCE_DIR}/include/kfr/data/sincos.hpp
${PROJECT_SOURCE_DIR}/include/kfr/dft/bitrev.hpp
+ ${PROJECT_SOURCE_DIR}/include/kfr/dft/cache.hpp
${PROJECT_SOURCE_DIR}/include/kfr/dft/conv.hpp
${PROJECT_SOURCE_DIR}/include/kfr/dft/fft.hpp
${PROJECT_SOURCE_DIR}/include/kfr/dft/ft.hpp