kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 3b1f2bdb633397d871afb1945a5afd9dec45b9b6
parent 239dd0c83b36ec3d47487cddaceaee8b205f736b
Author: [email protected] <[email protected]>
Date:   Tue, 25 Feb 2020 06:32:52 +0000

Allow invert_t in real dft execute/refactoring

Diffstat:
Minclude/kfr/dft/convolution.hpp | 22+++++++++++++++++++++-
Minclude/kfr/dft/fft.hpp | 8++++----
Minclude/kfr/dft/impl/convolution-impl.cpp | 49++++++-------------------------------------------
3 files changed, 31 insertions(+), 48 deletions(-)

diff --git a/include/kfr/dft/convolution.hpp b/include/kfr/dft/convolution.hpp @@ -76,6 +76,26 @@ univector<T> autocorrelate(const univector<T, Tag1>& src) return intrinsics::autocorrelate(src.slice()); } +namespace internal +{ +/// @brief Utility class to abstract real/complex differences +template <typename T> +struct dft_conv_plan: public dft_plan_real<T> +{ + dft_conv_plan(size_t size) : dft_plan_real<T>(size, dft_pack_format::Perm) {} + + size_t csize() const { return this->size / 2; } +}; + +template <typename T> +struct dft_conv_plan<complex<T>>: public dft_plan<T> +{ + dft_conv_plan(size_t size) : dft_plan<T>(size) {} + + size_t csize() const { return this->size; } +}; +} // namespace internal + /// @brief Convolution using Filter API template <typename T> class convolve_filter : public filter<T> @@ -98,7 +118,7 @@ protected: using ST = subtype<T>; static constexpr auto real_fft = !std::is_same<T, complex<ST>>::value; - using plan_t = std::conditional_t<real_fft, dft_plan_real<T>, dft_plan<ST>>; + using plan_t = internal::dft_conv_plan<T>; // Length of filter data. size_t data_size; diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp @@ -347,12 +347,12 @@ struct dft_plan_real : dft_plan<T> void execute(univector<complex<T>, Tag1>&, const univector<complex<T>, Tag2>&, univector<u8, Tag3>&, cbool_t<inverse>) const = delete; - KFR_MEM_INTRINSIC void execute(complex<T>* out, const T* in, u8* temp) const + KFR_MEM_INTRINSIC void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const { this->execute_dft(cfalse, out, ptr_cast<complex<T>>(in), temp); fmt_stage->execute(cfalse, out, out, nullptr); } - KFR_MEM_INTRINSIC void execute(T* out, const complex<T>* in, u8* temp) const + KFR_MEM_INTRINSIC void execute(T* out, const complex<T>* in, u8* temp, cinvert_t = {}) const { complex<T>* outdata = ptr_cast<complex<T>>(out); fmt_stage->execute(ctrue, outdata, in, nullptr); @@ -361,14 +361,14 @@ struct dft_plan_real : dft_plan<T> template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3> KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<T, Tag2>& in, - univector<u8, Tag3>& temp) const + univector<u8, Tag3>& temp, cdirect_t = {}) const { this->execute_dft(cfalse, out.data(), ptr_cast<complex<T>>(in.data()), temp.data()); fmt_stage->execute(cfalse, out.data(), out.data(), nullptr); } template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3> KFR_MEM_INTRINSIC void execute(univector<T, Tag1>& out, const univector<complex<T>, Tag2>& in, - univector<u8, Tag3>& temp) const + univector<u8, Tag3>& temp, cinvert_t = {}) const { complex<T>* outdata = ptr_cast<complex<T>>(out.data()); fmt_stage->execute(ctrue, outdata, in.data(), nullptr); diff --git a/include/kfr/dft/impl/convolution-impl.cpp b/include/kfr/dft/impl/convolution-impl.cpp @@ -82,49 +82,12 @@ univector<T> autocorrelate(const univector_ref<const T>& src1) } // namespace intrinsics -// Create a helper template struct to handle the differences between real and complex FFT. -template <typename T, typename ST = subtype<T>, - typename plan_t = - std::conditional_t<std::is_same<T, complex<ST>>::value, dft_plan<ST>, dft_plan_real<T>>> -struct convolve_filter_fft -{ - static plan_t make(size_t size); - static inline void ifft(plan_t const& plan, univector<T>& out, const univector<complex<T>>& in, - univector<u8>& temp); - static size_t csize(plan_t const& plan); -}; -// Partial template specializations for complex and real cases: -template <typename ST> -struct convolve_filter_fft<complex<ST>, ST, dft_plan<ST>> -{ - static dft_plan<ST> make(size_t size) { return dft_plan<ST>(size); } - static inline void ifft(dft_plan<ST> const& plan, univector<complex<ST>>& out, - const univector<complex<ST>>& in, univector<u8>& temp) - { - plan.execute(out, in, temp, ctrue); - } - static size_t csize(dft_plan<ST> const& plan) { return plan.size; } -}; -template <typename T> -struct convolve_filter_fft<T, T, dft_plan_real<T>> -{ - static dft_plan_real<T> make(size_t size) { return dft_plan_real<T>(size, dft_pack_format::Perm); } - static inline void ifft(dft_plan_real<T> const& plan, univector<T>& out, const univector<complex<T>>& in, - univector<u8>& temp) - { - plan.execute(out, in, temp); - } - static size_t csize(dft_plan_real<T> const& plan) { return plan.size / 2; } -}; - template <typename T> convolve_filter<T>::convolve_filter(size_t size_, size_t block_size_) - : data_size(size_), block_size(next_poweroftwo(block_size_)), - fft(convolve_filter_fft<T>::make(2 * block_size)), temp(fft.temp_size), + : data_size(size_), block_size(next_poweroftwo(block_size_)), fft(2 * block_size), temp(fft.temp_size), segments((data_size + block_size - 1) / block_size), ir_segments(segments.size()), input_position(0), - saved_input(block_size), premul(convolve_filter_fft<T>::csize(fft)), - cscratch(convolve_filter_fft<T>::csize(fft)), scratch1(fft.size), scratch2(fft.size), - overlap(block_size), position(0) + saved_input(block_size), premul(fft.csize()), cscratch(fft.csize()), scratch1(fft.size), + scratch2(fft.size), overlap(block_size), position(0) { } @@ -145,8 +108,8 @@ void convolve_filter<T>::set_data(const univector_ref<const T>& data) const ST ifftsize = reciprocal(ST(fft.size)); for (size_t i = 0; i < ir_segments.size(); i++) { - segments[i].resize(convolve_filter_fft<T>::csize(fft)); - ir_segments[i].resize(convolve_filter_fft<T>::csize(fft)); + segments[i].resize(fft.csize()); + ir_segments[i].resize(fft.csize()); input = padded(data.slice(i * block_size, block_size)); fft.execute(ir_segments[i], input, temp); @@ -217,7 +180,7 @@ void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size) fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], fft_multiply_pack); } // y_k = IFFT( Y_k ) - convolve_filter_fft<T>::ifft(fft, scratch2, cscratch, temp); + fft.execute(scratch2, cscratch, temp, cinvert_t{}); // z_k = y_k + overlap process(make_univector(output + processed, processing),