commit 3b1f2bdb633397d871afb1945a5afd9dec45b9b6
parent 239dd0c83b36ec3d47487cddaceaee8b205f736b
Author: [email protected] <[email protected]>
Date: Tue, 25 Feb 2020 06:32:52 +0000
Allow invert_t in real dft execute/refactoring
Diffstat:
3 files changed, 31 insertions(+), 48 deletions(-)
diff --git a/include/kfr/dft/convolution.hpp b/include/kfr/dft/convolution.hpp
@@ -76,6 +76,26 @@ univector<T> autocorrelate(const univector<T, Tag1>& src)
return intrinsics::autocorrelate(src.slice());
}
+namespace internal
+{
+/// @brief Utility class to abstract real/complex differences
+template <typename T>
+struct dft_conv_plan: public dft_plan_real<T>
+{
+ dft_conv_plan(size_t size) : dft_plan_real<T>(size, dft_pack_format::Perm) {}
+
+ size_t csize() const { return this->size / 2; }
+};
+
+template <typename T>
+struct dft_conv_plan<complex<T>>: public dft_plan<T>
+{
+ dft_conv_plan(size_t size) : dft_plan<T>(size) {}
+
+ size_t csize() const { return this->size; }
+};
+} // namespace internal
+
/// @brief Convolution using Filter API
template <typename T>
class convolve_filter : public filter<T>
@@ -98,7 +118,7 @@ protected:
using ST = subtype<T>;
static constexpr auto real_fft = !std::is_same<T, complex<ST>>::value;
- using plan_t = std::conditional_t<real_fft, dft_plan_real<T>, dft_plan<ST>>;
+ using plan_t = internal::dft_conv_plan<T>;
// Length of filter data.
size_t data_size;
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -347,12 +347,12 @@ struct dft_plan_real : dft_plan<T>
void execute(univector<complex<T>, Tag1>&, const univector<complex<T>, Tag2>&, univector<u8, Tag3>&,
cbool_t<inverse>) const = delete;
- KFR_MEM_INTRINSIC void execute(complex<T>* out, const T* in, u8* temp) const
+ KFR_MEM_INTRINSIC void execute(complex<T>* out, const T* in, u8* temp, cdirect_t = {}) const
{
this->execute_dft(cfalse, out, ptr_cast<complex<T>>(in), temp);
fmt_stage->execute(cfalse, out, out, nullptr);
}
- KFR_MEM_INTRINSIC void execute(T* out, const complex<T>* in, u8* temp) const
+ KFR_MEM_INTRINSIC void execute(T* out, const complex<T>* in, u8* temp, cinvert_t = {}) const
{
complex<T>* outdata = ptr_cast<complex<T>>(out);
fmt_stage->execute(ctrue, outdata, in, nullptr);
@@ -361,14 +361,14 @@ struct dft_plan_real : dft_plan<T>
template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3>
KFR_MEM_INTRINSIC void execute(univector<complex<T>, Tag1>& out, const univector<T, Tag2>& in,
- univector<u8, Tag3>& temp) const
+ univector<u8, Tag3>& temp, cdirect_t = {}) const
{
this->execute_dft(cfalse, out.data(), ptr_cast<complex<T>>(in.data()), temp.data());
fmt_stage->execute(cfalse, out.data(), out.data(), nullptr);
}
template <univector_tag Tag1, univector_tag Tag2, univector_tag Tag3>
KFR_MEM_INTRINSIC void execute(univector<T, Tag1>& out, const univector<complex<T>, Tag2>& in,
- univector<u8, Tag3>& temp) const
+ univector<u8, Tag3>& temp, cinvert_t = {}) const
{
complex<T>* outdata = ptr_cast<complex<T>>(out.data());
fmt_stage->execute(ctrue, outdata, in.data(), nullptr);
diff --git a/include/kfr/dft/impl/convolution-impl.cpp b/include/kfr/dft/impl/convolution-impl.cpp
@@ -82,49 +82,12 @@ univector<T> autocorrelate(const univector_ref<const T>& src1)
} // namespace intrinsics
-// Create a helper template struct to handle the differences between real and complex FFT.
-template <typename T, typename ST = subtype<T>,
- typename plan_t =
- std::conditional_t<std::is_same<T, complex<ST>>::value, dft_plan<ST>, dft_plan_real<T>>>
-struct convolve_filter_fft
-{
- static plan_t make(size_t size);
- static inline void ifft(plan_t const& plan, univector<T>& out, const univector<complex<T>>& in,
- univector<u8>& temp);
- static size_t csize(plan_t const& plan);
-};
-// Partial template specializations for complex and real cases:
-template <typename ST>
-struct convolve_filter_fft<complex<ST>, ST, dft_plan<ST>>
-{
- static dft_plan<ST> make(size_t size) { return dft_plan<ST>(size); }
- static inline void ifft(dft_plan<ST> const& plan, univector<complex<ST>>& out,
- const univector<complex<ST>>& in, univector<u8>& temp)
- {
- plan.execute(out, in, temp, ctrue);
- }
- static size_t csize(dft_plan<ST> const& plan) { return plan.size; }
-};
-template <typename T>
-struct convolve_filter_fft<T, T, dft_plan_real<T>>
-{
- static dft_plan_real<T> make(size_t size) { return dft_plan_real<T>(size, dft_pack_format::Perm); }
- static inline void ifft(dft_plan_real<T> const& plan, univector<T>& out, const univector<complex<T>>& in,
- univector<u8>& temp)
- {
- plan.execute(out, in, temp);
- }
- static size_t csize(dft_plan_real<T> const& plan) { return plan.size / 2; }
-};
-
template <typename T>
convolve_filter<T>::convolve_filter(size_t size_, size_t block_size_)
- : data_size(size_), block_size(next_poweroftwo(block_size_)),
- fft(convolve_filter_fft<T>::make(2 * block_size)), temp(fft.temp_size),
+ : data_size(size_), block_size(next_poweroftwo(block_size_)), fft(2 * block_size), temp(fft.temp_size),
segments((data_size + block_size - 1) / block_size), ir_segments(segments.size()), input_position(0),
- saved_input(block_size), premul(convolve_filter_fft<T>::csize(fft)),
- cscratch(convolve_filter_fft<T>::csize(fft)), scratch1(fft.size), scratch2(fft.size),
- overlap(block_size), position(0)
+ saved_input(block_size), premul(fft.csize()), cscratch(fft.csize()), scratch1(fft.size),
+ scratch2(fft.size), overlap(block_size), position(0)
{
}
@@ -145,8 +108,8 @@ void convolve_filter<T>::set_data(const univector_ref<const T>& data)
const ST ifftsize = reciprocal(ST(fft.size));
for (size_t i = 0; i < ir_segments.size(); i++)
{
- segments[i].resize(convolve_filter_fft<T>::csize(fft));
- ir_segments[i].resize(convolve_filter_fft<T>::csize(fft));
+ segments[i].resize(fft.csize());
+ ir_segments[i].resize(fft.csize());
input = padded(data.slice(i * block_size, block_size));
fft.execute(ir_segments[i], input, temp);
@@ -217,7 +180,7 @@ void convolve_filter<T>::process_buffer(T* output, const T* input, size_t size)
fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], fft_multiply_pack);
}
// y_k = IFFT( Y_k )
- convolve_filter_fft<T>::ifft(fft, scratch2, cscratch, temp);
+ fft.execute(scratch2, cscratch, temp, cinvert_t{});
// z_k = y_k + overlap
process(make_univector(output + processed, processing),