kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit c0ee34aeb6429cf8a5a630d049dea479d610cdd9
parent 3e97d16af133075c579274d5bb7722a7e528c9ba
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date:   Wed,  7 Sep 2016 17:11:45 +0300

DFT plan cache

Diffstat:
A include/kfr/dft/cache.hpp | 99 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
M include/kfr/dft/conv.hpp | 22 ++++++++++++----------
M sources.cmake | 1 +
3 files changed, 112 insertions(+), 10 deletions(-)

diff --git a/include/kfr/dft/cache.hpp b/include/kfr/dft/cache.hpp @@ -0,0 +1,99 @@ +/** @addtogroup dft + * @{ + */ +/* + Copyright (C) 2016 D Levin (https://www.kfrlib.com) + This file is part of KFR + + KFR is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + KFR is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with KFR. + + If GPL is not suitable for your project, you must purchase a commercial license to use KFR. + Buying a commercial license is mandatory as soon as you develop commercial activities without + disclosing the source code of your own applications. + See https://www.kfrlib.com for details. 
+ */ +#pragma once + +#include "fft.hpp" +#include <memory> +#include <mutex> +#include <vector> + +namespace kfr +{ + +template <typename T> +using dft_plan_ptr = std::shared_ptr<const dft_plan<T>>; + +struct dft_cache +{ + static dft_cache& instance() + { + static dft_cache cache; + return cache; + } + dft_plan_ptr<f32> get(ctype_t<f32>, size_t size) + { +#ifndef KFR_SINGLE_THREAD + std::lock_guard<std::mutex> guard(mutex); +#endif + return get_or_create(cache_f32, size); + } + dft_plan_ptr<f64> get(ctype_t<f64>, size_t size) + { +#ifndef KFR_SINGLE_THREAD + std::lock_guard<std::mutex> guard(mutex); +#endif + return get_or_create(cache_f64, size); + } + void clear() + { +#ifndef KFR_SINGLE_THREAD + std::lock_guard<std::mutex> guard(mutex); +#endif + cache_f32.clear(); + cache_f64.clear(); + } + +private: + template <typename T> + std::shared_ptr<const dft_plan<T>> get_or_create(std::vector<dft_plan_ptr<T>>& cache, size_t size) + { + for (dft_plan_ptr<T>& dft : cache) + { + if (dft->size == size) + return dft; + } + dft_plan_ptr<T> sh = std::make_shared<dft_plan<T>>(size); + cache.push_back(sh); + return sh; + } + + std::vector<dft_plan_ptr<f32>> cache_f32; + std::vector<dft_plan_ptr<f64>> cache_f64; +#ifndef KFR_SINGLE_THREAD + std::mutex mutex; +#endif +}; + +template <typename T, size_t Tag> +univector<complex<T>> dft(const univector<complex<T>, Tag>& input) +{ + dft_plan_ptr<T> dft = dft_cache::instance().get(ctype<T>, input.size()); + univector<T> output(input.size()); + univector<u8> temp(dft->temp_size); + dft->execute(output, input, temp); + return output; +} +} diff --git a/include/kfr/dft/conv.hpp b/include/kfr/dft/conv.hpp @@ -31,6 +31,7 @@ #include "../base/read_write.hpp" #include "../base/vec.hpp" +#include "cache.hpp" #include "fft.hpp" #pragma clang diagnostic push @@ -49,12 +50,13 @@ KFR_INTRIN univector<T> convolve(const univector<T, Tag1>& src1, const univector univector<complex<T>> src2padded = src2; src1padded.resize(size, 0); 
src2padded.resize(size, 0); - dft_plan<T> plan(size); - univector<u8> temp(plan.temp_size); - plan.execute(src1padded, src1padded, temp); - plan.execute(src2padded, src2padded, temp); + + dft_plan_ptr<T> dft = dft_cache::instance().get(ctype<T>, size); + univector<u8> temp(dft->temp_size); + dft->execute(src1padded, src1padded, temp); + dft->execute(src2padded, src2padded, temp); src1padded = src1padded * src2padded; - plan.execute(src1padded, src1padded, temp, true); + dft->execute(src1padded, src1padded, temp, true); const T invsize = reciprocal<T>(size); return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; } @@ -67,12 +69,12 @@ KFR_INTRIN univector<T> correlate(const univector<T, Tag1>& src1, const univecto univector<complex<T>> src2padded = reverse(src2); src1padded.resize(size, 0); src2padded.resize(size, 0); - dft_plan<T> plan(size); - univector<u8> temp(plan.temp_size); - plan.execute(src1padded, src1padded, temp); - plan.execute(src2padded, src2padded, temp); + dft_plan_ptr<T> dft = dft_cache::instance().get(ctype<T>, size); + univector<u8> temp(dft->temp_size); + dft->execute(src1padded, src1padded, temp); + dft->execute(src2padded, src2padded, temp); src1padded = src1padded * src2padded; - plan.execute(src1padded, src1padded, temp, true); + dft->execute(src1padded, src1padded, temp, true); const T invsize = reciprocal<T>(size); return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; } diff --git a/sources.cmake b/sources.cmake @@ -68,6 +68,7 @@ set( ${PROJECT_SOURCE_DIR}/include/kfr/data/bitrev.hpp ${PROJECT_SOURCE_DIR}/include/kfr/data/sincos.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/bitrev.hpp + ${PROJECT_SOURCE_DIR}/include/kfr/dft/cache.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/conv.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/fft.hpp ${PROJECT_SOURCE_DIR}/include/kfr/dft/ft.hpp