commit 321cf99edbcdf451290aaa36fbf4ff60454ac950
parent 631c3168538b69f23e34504bb2dee4da54f53fec
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Sun, 25 Nov 2018 20:52:51 +0300
Merge branch 'dft' into dev
Diffstat:
5 files changed, 562 insertions(+), 176 deletions(-)
diff --git a/include/kfr/dft/cache.hpp b/include/kfr/dft/cache.hpp
@@ -127,7 +127,7 @@ template <typename T, size_t Tag>
univector<complex<T>> dft(const univector<complex<T>, Tag>& input)
{
dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), input.size());
- univector<complex<T>> output(input.size());
+ univector<complex<T>> output(input.size(), std::numeric_limits<T>::quiet_NaN());
univector<u8> temp(dft->temp_size);
dft->execute(output, input, temp);
return output;
@@ -137,7 +137,7 @@ template <typename T, size_t Tag>
univector<complex<T>> idft(const univector<complex<T>, Tag>& input)
{
dft_plan_ptr<T> dft = dft_cache::instance().get(ctype_t<T>(), input.size());
- univector<complex<T>> output(input.size());
+ univector<complex<T>> output(input.size(), std::numeric_limits<T>::quiet_NaN());
univector<u8> temp(dft->temp_size);
dft->execute(output, input, temp, ctrue);
return output;
@@ -147,7 +147,7 @@ template <typename T, size_t Tag>
univector<complex<T>> realdft(const univector<T, Tag>& input)
{
dft_plan_real_ptr<T> dft = dft_cache::instance().getreal(ctype_t<T>(), input.size());
- univector<complex<T>> output(input.size() / 2 + 1);
+ univector<complex<T>> output(input.size() / 2 + 1, std::numeric_limits<T>::quiet_NaN());
univector<u8> temp(dft->temp_size);
dft->execute(output, input, temp);
return output;
@@ -157,7 +157,7 @@ template <typename T, size_t Tag>
univector<T> irealdft(const univector<complex<T>, Tag>& input)
{
dft_plan_real_ptr<T> dft = dft_cache::instance().getreal(ctype_t<T>(), (input.size() - 1) * 2);
- univector<T> output((input.size() - 1) * 2);
+ univector<T> output((input.size() - 1) * 2, std::numeric_limits<T>::quiet_NaN());
univector<u8> temp(dft->temp_size);
dft->execute(output, input, temp);
return output;
diff --git a/include/kfr/dft/dft-src.cpp b/include/kfr/dft/dft-src.cpp
@@ -45,11 +45,16 @@ CMT_PRAGMA_MSVC(warning(disable : 4100))
namespace kfr
{
+constexpr csizes_t<2, 3, 4, 5, 6, 7, 8, 10> dft_radices{}; // radices that prepare_dft_stage maps to the fixed-radix stage classes; any other radix gets dft_stage_generic_impl
+
#define DFT_ASSERT TESTO_ASSERT_INACTIVE
template <typename T>
constexpr size_t fft_vector_width = platform<T>::vector_width;
+using cdirect_t = cfalse_t; // tag type selecting the do_execute overload that runs with inverse == false
+using cinvert_t = ctrue_t; // tag type selecting the do_execute overload that runs with inverse == true
+
template <typename T>
struct dft_stage
{
@@ -59,19 +64,42 @@ struct dft_stage
u8* data = nullptr;
size_t repeats = 1;
size_t out_offset = 0;
+ size_t width = 0;
+ size_t blocks = 0;
const char* name;
- bool recursion = false;
+ bool recursion = false;
+ bool can_inplace = true;
+ bool inplace = false;
+ bool to_scratch = false;
void initialize(size_t size) { do_initialize(size); }
- KFR_INTRIN void execute(complex<T>* out, const complex<T>* in, u8* temp) { do_execute(out, in, temp); }
+ KFR_INTRIN void execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp)
+ {
+ do_execute(cdirect_t(), out, in, temp);
+ }
+ KFR_INTRIN void execute(cinvert_t, complex<T>* out, const complex<T>* in, u8* temp)
+ {
+ do_execute(cinvert_t(), out, in, temp);
+ }
virtual ~dft_stage() {}
protected:
virtual void do_initialize(size_t) {}
- virtual void do_execute(complex<T>*, const complex<T>*, u8* temp) = 0;
+ virtual void do_execute(cdirect_t, complex<T>*, const complex<T>*, u8* temp) = 0;
+ virtual void do_execute(cinvert_t, complex<T>*, const complex<T>*, u8* temp) = 0;
};
+#define DFT_STAGE_FN /* expands to both virtual do_execute overrides; each forwards to the enclosing class's do_execute<inverse> member template, and `T` must name the element type there */ \
+ void do_execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) override /* cdirect_t tag -> inverse = false */ \
+ { \
+ return do_execute<false>(out, in, temp); \
+ } \
+ void do_execute(cinvert_t, complex<T>* out, const complex<T>* in, u8* temp) override /* cinvert_t tag -> inverse = true */ \
+ { \
+ return do_execute<true>(out, in, temp); \
+ }
+
CMT_PRAGMA_GNU(GCC diagnostic push)
#if CMT_HAS_WARNING("-Wassume")
CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wassume")
@@ -440,7 +468,265 @@ KFR_SINTRIN ctrue_t radix4_pass(csize_t<4>, size_t blocks, csize_t<width>, cfals
return {};
}
-template <typename T, bool splitin, bool is_even, bool inverse>
+// Fills stage->data with the twiddle table of a fixed-radix stage. Columns are taken in groups of
+// `width`; the group width is then halved repeatedly so the leftover columns are covered as well.
+template <typename T>
+static void dft_stage_fixed_initialize(dft_stage<T>* stage, size_t width)
+{
+    const size_t columns = stage->repeats;
+    const size_t total   = columns * stage->stage_size;
+    complex<T>* dest     = ptr_cast<complex<T>>(stage->data);
+
+    size_t column = 0; // shared by all passes: each narrower pass resumes where the previous one stopped
+    for (; width > 0; width = width / 2)
+    {
+        CMT_LOOP_NOUNROLL
+        for (; column < columns / width * width; column += width)
+        {
+            CMT_LOOP_NOUNROLL
+            for (size_t row = 1; row < stage->stage_size; row++, dest += width)
+            {
+                CMT_LOOP_NOUNROLL
+                for (size_t lane = 0; lane < width; lane++)
+                {
+                    ref_cast<cvec<T, 1>>(dest[lane]) =
+                        cossin_conj(broadcast<2, T>(c_pi<T, 2> * (column + lane) * row / total));
+                }
+            }
+        }
+    }
+}
+
+// One non-final pass of a mixed-radix DFT whose radix is a compile-time constant.
+template <typename T, size_t radix>
+struct dft_stage_fixed_impl : dft_stage<T>
+{
+    // The first argument repeats the `radix` template parameter and is not read.
+    dft_stage_fixed_impl(size_t radix_, size_t iterations, size_t blocks)
+    {
+        this->recursion  = false; // true;
+        this->stage_size = radix;
+        this->repeats    = iterations;
+        this->blocks     = blocks;
+        // room for (radix - 1) twiddles per repeat, rounded up to the cache alignment
+        const size_t twiddle_bytes = (this->repeats * (radix - 1)) * sizeof(complex<T>);
+        this->data_size            = align_up(twiddle_bytes, platform<>::native_cache_alignment);
+    }
+
+protected:
+    constexpr static size_t width = fft_vector_width<T>;
+    virtual void do_initialize(size_t size) override final { dft_stage_fixed_initialize(this, width); }
+
+    DFT_STAGE_FN
+    // Runs the radix butterflies once per block, stepping in/out by one block each time.
+    template <bool inverse>
+    void do_execute(complex<T>* out, const complex<T>* in, u8*)
+    {
+        const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+        const size_t per_block    = this->repeats;
+        const size_t block_len    = per_block * this->stage_size;
+        size_t remaining          = this->blocks;
+        CMT_LOOP_NOUNROLL
+        for (; remaining > 0; remaining--, in += block_len, out += block_len)
+            butterflies(per_block, csize<width>, csize<radix>, cbool<inverse>, out, in, twiddle, per_block);
+    }
+};
+
+// Final pass of a mixed-radix DFT for a compile-time radix: one `butterflies` call, no twiddle table.
+template <typename T, size_t radix>
+struct dft_stage_fixed_final_impl : dft_stage<T>
+{
+    // `radix_` duplicates the template parameter and is not read.
+    dft_stage_fixed_final_impl(size_t radix_, size_t iterations, size_t blocks)
+    {
+        this->stage_size  = radix;
+        this->blocks      = blocks;
+        this->repeats     = iterations;
+        this->recursion   = false; // true;
+        this->can_inplace = false; // stage is flagged as not runnable in place
+    }
+    constexpr static size_t width = fft_vector_width<T>;
+    DFT_STAGE_FN
+    template <bool inverse>
+    void do_execute(complex<T>* out, const complex<T>* in, u8*)
+    {
+        const size_t b = this->blocks; // passed both as the count and as the last argument
+        butterflies(b, csize<width>, csize<radix>, cbool<inverse>, out, in, b);
+    }
+};
+
+// DFT pass for a radix that is not listed in dft_radices; the radix is a run-time value.
+template <typename T, bool final>
+struct dft_stage_generic_impl : dft_stage<T>
+{
+    dft_stage_generic_impl(size_t radix, size_t iterations, size_t blocks)
+    {
+        this->stage_size  = radix;
+        this->blocks      = blocks;
+        this->repeats     = iterations;
+        this->recursion   = false; // true;
+        this->can_inplace = false;
+        this->temp_size   = align_up(sizeof(complex<T>) * radix, platform<>::native_cache_alignment);
+        this->data_size =
+            align_up(sizeof(complex<T>) * sqr(this->stage_size / 2), platform<>::native_cache_alignment);
+    }
+
+protected:
+    // Writes the (stage_size / 2) x (stage_size / 2) twiddle table into this->data, row by row.
+    virtual void do_initialize(size_t size) override final
+    {
+        complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+        CMT_LOOP_NOUNROLL
+        for (size_t i = 0; i < this->stage_size / 2; i++)
+        {
+            CMT_LOOP_NOUNROLL
+            for (size_t j = 0; j < this->stage_size / 2; j++)
+            {
+                cwrite<1>(twiddle++,
+                          cossin_conj(broadcast<2>((i + 1) * (j + 1) * c_pi<T, 2> / this->stage_size)));
+            }
+        }
+    }
+
+    DFT_STAGE_FN
+    // One generic_butterfly call per block; `temp` is handed to it as complex<T> storage.
+    template <bool inverse>
+    void do_execute(complex<T>* out, const complex<T>* in, u8* temp)
+    {
+        const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
+        const size_t bl           = this->blocks;
+        CMT_LOOP_NOUNROLL
+        for (size_t b = 0; b < bl; b++)
+            generic_butterfly(this->stage_size, cbool<inverse>, out + b, in + b * this->stage_size,
+                              ptr_cast<complex<T>>(temp), twiddle, bl);
+    }
+};
+
+template <typename T, typename Tr2>
+inline void dft_permute(complex<T>* out, const complex<T>* in, size_t r0, size_t r1, Tr2 r2) // for each of r0 blocks of r1 * r2 elements: reads the block as r2 rows of r1 and writes it transposed
+{
+ CMT_ASSUME(r0 > 1);
+ CMT_ASSUME(r1 > 1);
+
+ CMT_LOOP_NOUNROLL
+ for (size_t p = 0; p < r0; p++)
+ {
+ const complex<T>* in1 = in; // first element of the current input block
+ CMT_LOOP_NOUNROLL
+ for (size_t i = 0; i < r1; i++)
+ {
+ const complex<T>* in2 = in1;
+ CMT_LOOP_UNROLL
+ for (size_t j = 0; j < r2; j++)
+ {
+ *out++ = *in2; // output is written sequentially
+ in2 += r1; // input is gathered with stride r1
+ }
+ in1++;
+ in += r2; // runs r1 times per block, so `in` ends up r1 * r2 further on: the next block
+ }
+ }
+}
+
+// Recursive gather used by dft_reorder_stage_impl. Levels above index 1 only offset `in` and recurse;
+// level 1 copies radices[1] runs of radices[0] elements each, reading with stride inner_size.
+// `out` is taken by reference so that every level keeps appending where the previous one stopped.
+template <typename T>
+inline void dft_permute_deep(complex<T>*& out, const complex<T>* in, const size_t* radices, size_t count,
+                             size_t index, size_t inscale, size_t inner_size)
+{
+    const size_t radix = radices[index];
+    if (index != 1)
+    {
+        const size_t child_scale = inscale * radix;
+        CMT_LOOP_NOUNROLL
+        for (size_t step = 0; step < radix; step++)
+        {
+            dft_permute_deep(out, in, radices, count, index - 1, child_scale, inner_size);
+            in += inscale;
+        }
+        return;
+    }
+
+    CMT_LOOP_NOUNROLL
+    for (size_t run = 0; run < radix; run++)
+    {
+        const complex<T>* src = in;
+        CMT_LOOP_UNROLL
+        for (size_t j = 0; j < radices[0]; j++)
+        {
+            *out++ = *src;
+            src += inner_size;
+        }
+        in += inscale;
+    }
+}
+
+template <typename T>
+struct dft_reorder_stage_impl : dft_stage<T> // stage that only permutes data, driven by the list of stage radices
+{
+ dft_reorder_stage_impl(const int* radices, size_t count) : count(count) // radices: `count` stage radices in the order the plan added them
+ {
+ this->can_inplace = false;
+ this->data_size = 0;
+ std::copy(std::make_reverse_iterator(radices + count), std::make_reverse_iterator(radices), this->radices); // member keeps the radices reversed; NOTE(review): nothing here checks count <= 32, the capacity of this->radices
+ this->inner_size = 1;
+ this->size = 1;
+ for (size_t r = 0; r < count; r++)
+ {
+ if (r != 0 && r != count - 2) // unqualified `radices` below is the constructor argument (original order), not the reversed member
+ this->inner_size *= radices[r];
+ this->size *= radices[r]; // size ends up as the product of all radices
+ }
+ }
+
+protected:
+ size_t radices[32]; // reversed copy of the constructor argument
+ size_t count = 0;
+ size_t size = 0;
+ size_t inner_size = 0;
+ virtual void do_initialize(size_t) override final {}
+
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
+ {
+ // std::copy(in, in + this->size, out);
+ // return;
+ if (count == 3) // exactly three radices: a single dft_permute call
+ {
+ dft_permute(out, in, radices[2], radices[1], radices[0]);
+ }
+ else
+ {
+ const size_t rlast = radices[count - 1];
+ for (size_t p = 0; p < rlast; p++)
+ {
+ dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size); // advances `out`, which dft_permute_deep takes by reference
+ in += size / rlast; // next of the rlast equal slices of the input
+ }
+ }
+ // if (count == 1)
+ // {
+ // cswitch(dft_radices, radices[0], [&](auto rfirst) { dft_permute3(out, in, 1, 1, rfirst);
+ // },
+ // [&]() { dft_permute3(out, in, 1, 1, radices[0]); });
+ // }
+
+ // const size_t rlast = radices[count - 1];
+ // const size_t size = this->size / rlast;
+ //
+ // cswitch(dft_radices, radices[0], [&](auto rfirst) {
+ // for (size_t p = 0; p < rlast; p++)
+ // {
+ // reorder_impl(out, in, index, 1, rfirst);
+ // in += size;
+ // }
+ // });
+ }
+};
+
+template <typename T, bool splitin, bool is_even>
struct fft_stage_impl : dft_stage<T>
{
fft_stage_impl(size_t stage_size)
@@ -463,7 +749,9 @@ protected:
initialize_twiddles<T, width>(twiddle, this->stage_size, size, true);
}
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
if (splitin)
@@ -476,7 +764,7 @@ protected:
}
};
-template <typename T, bool splitin, size_t size, bool inverse>
+template <typename T, bool splitin, size_t size>
struct fft_final_stage_impl : dft_stage<T>
{
fft_final_stage_impl(size_t)
@@ -513,13 +801,15 @@ protected:
init_twiddles(csize<size>, total_size, cbool<splitin>, twiddle);
}
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_stage(csize<size>, 1, cbool<splitin>, out, in, twiddle);
+ final_stage<inverse>(csize<size>, 1, cbool<splitin>, out, in, twiddle);
}
- template <typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)>
+ template <bool inverse, typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)>
KFR_INTRIN void final_stage(csize_t<32>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
const complex<T>*& twiddle)
{
@@ -527,7 +817,7 @@ protected:
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
- template <typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)>
+ template <bool inverse, typename U = T, KFR_ENABLE_IF(is_same<U, float>::value)>
KFR_INTRIN void final_stage(csize_t<16>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
const complex<T>*& twiddle)
{
@@ -535,6 +825,7 @@ protected:
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
+ template <bool inverse>
KFR_INTRIN void final_stage(csize_t<8>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
const complex<T>*& twiddle)
{
@@ -542,6 +833,7 @@ protected:
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
+ template <bool inverse>
KFR_INTRIN void final_stage(csize_t<4>, size_t invN, cfalse_t, complex<T>* out, const complex<T>*,
const complex<T>*& twiddle)
{
@@ -549,7 +841,7 @@ protected:
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
- template <size_t N, bool pass_splitin>
+ template <bool inverse, size_t N, bool pass_splitin>
KFR_INTRIN void final_stage(csize_t<N>, size_t invN, cbool_t<pass_splitin>, complex<T>* out,
const complex<T>* in, const complex<T>*& twiddle)
{
@@ -561,7 +853,7 @@ protected:
radix4_pass(N, invN, csize_t<pass_width>(), cbool<pass_split>, cbool_t<pass_splitin>(),
cbool_t<use_br2>(), cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, in,
twiddle);
- final_stage(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle);
+ final_stage<inverse>(csize<N / 4>, invN * 4, cbool<pass_split>, out, out, twiddle);
}
};
@@ -580,23 +872,28 @@ protected:
virtual void do_initialize(size_t) override final {}
- virtual void do_execute(complex<T>* out, const complex<T>*, u8* /*temp*/) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
fft_reorder(out, log2n, cbool_t<!is_even>());
}
};
-template <typename T, size_t log2n, bool inverse>
+template <typename T, size_t log2n>
struct fft_specialization;
-template <typename T, bool inverse>
-struct fft_specialization<T, 1, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 1> : dft_stage<T>
{
fft_specialization(size_t) {}
protected:
constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ DFT_STAGE_FN
+
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
cvec<T, 1> a0, a1;
split(cread<2, aligned>(in), a0, a1);
@@ -604,14 +901,16 @@ protected:
}
};
-template <typename T, bool inverse>
-struct fft_specialization<T, 2, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 2> : dft_stage<T>
{
fft_specialization(size_t) {}
protected:
constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
cvec<T, 1> a0, a1, a2, a3;
split(cread<4>(in), a0, a1, a2, a3);
@@ -620,14 +919,16 @@ protected:
}
};
-template <typename T, bool inverse>
-struct fft_specialization<T, 3, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 3> : dft_stage<T>
{
fft_specialization(size_t) {}
protected:
constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
cvec<T, 8> v8 = cread<8, aligned>(in);
butterfly8<inverse>(v8);
@@ -635,14 +936,16 @@ protected:
}
};
-template <typename T, bool inverse>
-struct fft_specialization<T, 4, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 4> : dft_stage<T>
{
fft_specialization(size_t) {}
protected:
constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
cvec<T, 16> v16 = cread<16, aligned>(in);
butterfly16<inverse>(v16);
@@ -650,14 +953,16 @@ protected:
}
};
-template <typename T, bool inverse>
-struct fft_specialization<T, 5, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 5> : dft_stage<T>
{
fft_specialization(size_t) {}
protected:
constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
cvec<T, 32> v32 = cread<32, aligned>(in);
butterfly32<inverse>(v32);
@@ -665,21 +970,23 @@ protected:
}
};
-template <typename T, bool inverse>
-struct fft_specialization<T, 6, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 6> : dft_stage<T>
{
fft_specialization(size_t) {}
protected:
constexpr static bool aligned = false;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8*) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
butterfly64(cbool_t<inverse>(), cbool_t<aligned>(), out, in);
}
};
-template <typename T, bool inverse>
-struct fft_specialization<T, 7, inverse> : dft_stage<T>
+template <typename T>
+struct fft_specialization<T, 7> : dft_stage<T>
{
fft_specialization(size_t)
{
@@ -704,13 +1011,16 @@ protected:
initialize_twiddles<T, width>(twiddle, 8, total_size, split_format);
}
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
- final_pass(csize_t<final_size>(), out, in, twiddle);
+ final_pass<inverse>(csize_t<final_size>(), out, in, twiddle);
fft_reorder(out, csize_t<7>());
}
+ template <bool inverse>
KFR_INTRIN void final_pass(csize_t<8>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
{
radix4_pass(128, 1, csize_t<width>(), ctrue, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
@@ -721,6 +1031,7 @@ protected:
cbool_t<prefetch>(), cbool_t<inverse>(), cbool_t<aligned>(), out, out, twiddle);
}
+ template <bool inverse>
KFR_INTRIN void final_pass(csize_t<32>, complex<T>* out, const complex<T>* in, const complex<T>* twiddle)
{
radix4_pass(128, 1, csize_t<width>(), cfalse, cfalse, cbool_t<use_br2>(), cbool_t<prefetch>(),
@@ -730,13 +1041,16 @@ protected:
}
};
-template <bool inverse>
-struct fft_specialization<float, 8, inverse> : dft_stage<float>
+template <>
+struct fft_specialization<float, 8> : dft_stage<float>
{
fft_specialization(size_t) { this->temp_size = sizeof(complex<float>) * 256; }
protected:
- virtual void do_execute(complex<float>* out, const complex<float>* in, u8* temp) override final
+ using T = float;
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8* temp)
{
complex<float>* scratch = ptr_cast<complex<float>>(temp);
if (out == in)
@@ -766,59 +1080,47 @@ protected:
}
};
-template <bool inverse>
-struct fft_specialization<double, 8, inverse> : fft_final_stage_impl<double, false, 256, inverse>
+template <>
+struct fft_specialization<double, 8> : fft_final_stage_impl<double, false, 256>
{
using T = double;
- using fft_final_stage_impl<double, false, 256, inverse>::fft_final_stage_impl;
+ using fft_final_stage_impl<double, false, 256>::fft_final_stage_impl;
- virtual void do_execute(complex<T>* out, const complex<T>* in, u8* /*temp*/) override final
+ DFT_STAGE_FN
+ template <bool inverse>
+ void do_execute(complex<T>* out, const complex<T>* in, u8*)
{
- fft_final_stage_impl<double, false, 256, inverse>::do_execute(out, in, nullptr);
+ fft_final_stage_impl<double, false, 256>::template do_execute<inverse>(out, in, nullptr);
fft_reorder(out, csize_t<8>());
}
};
-
-template <typename T, bool splitin, bool is_even>
-struct fft_stage_impl_t
-{
- template <bool inverse>
- using type = internal::fft_stage_impl<T, splitin, is_even, inverse>;
-};
-template <typename T, bool splitin, size_t size>
-struct fft_final_stage_impl_t
-{
- template <bool inverse>
- using type = internal::fft_final_stage_impl<T, splitin, size, inverse>;
-};
-template <typename T, bool is_even>
-struct fft_reorder_stage_impl_t
-{
- template <bool>
- using type = internal::fft_reorder_stage_impl<T, is_even>;
-};
-template <typename T, size_t log2n, bool aligned>
-struct fft_specialization_t
-{
- template <bool inverse>
- using type = internal::fft_specialization<T, log2n, inverse>;
-};
} // namespace internal
//
template <typename T>
-template <template <bool inverse> class Stage>
-void dft_plan<T>::add_stage(size_t stage_size)
+template <typename Stage, typename... Args> // Stage: concrete stage class to instantiate; args go to its constructor
+void dft_plan<T>::add_stage(Args... args)
+{
+ dft_stage<T>* stage = new Stage(args...); // wrapped in dft_stage_ptr below (presumably an owning pointer — confirm in fft.hpp)
+ stage->name = nullptr;
+ this->data_size += stage->data_size; // plan-wide total; initialize() allocates data_size bytes and slices them per stage
+ this->temp_size += stage->temp_size;
+ stages.push_back(dft_stage_ptr(stage));
+}
+
+template <typename T>
+template <bool is_final>
+void dft_plan<T>::prepare_dft_stage(size_t radix, size_t iterations, size_t blocks, cbool_t<is_final>) // adds the stage for one radix: a fixed-radix class when cswitch matches radix against dft_radices, the generic class otherwise
+{
- dft_stage<T>* direct_stage = new Stage<false>(stage_size);
- direct_stage->name = nullptr;
- this->data_size += direct_stage->data_size;
- this->temp_size += direct_stage->temp_size;
- stages[0].push_back(dft_stage_ptr(direct_stage));
- dft_stage<T>* inverse_stage = new Stage<true>(stage_size);
- inverse_stage->name = nullptr;
- stages[1].push_back(dft_stage_ptr(inverse_stage));
+ return cswitch(
+ dft_radices, radix,
+ [&](auto radix) CMT_INLINE_LAMBDA { // this `radix` shadows the run-time parameter; val_of(radix) is usable as a template argument
+ add_stage<conditional<is_final, internal::dft_stage_fixed_final_impl<T, val_of(radix)>,
+ internal::dft_stage_fixed_impl<T, val_of(radix)>>>(radix, iterations,
+ blocks);
+ },
+ [&]() { add_stage<internal::dft_stage_generic_impl<T, is_final>>(radix, iterations, blocks); }); // presumably the cswitch fallback for a radix not in dft_radices — confirm against cswitch
}
template <typename T>
@@ -827,70 +1129,118 @@ void dft_plan<T>::make_fft(size_t stage_size, cbool_t<is_even>, cbool_t<first>)
{
constexpr size_t final_size = is_even ? 1024 : 512;
- using fft_stage_impl_t = internal::fft_stage_impl_t<T, !first, is_even>;
- using fft_final_stage_impl_t = internal::fft_final_stage_impl_t<T, !first, final_size>;
-
if (stage_size >= 2048)
{
- add_stage<fft_stage_impl_t::template type>(stage_size);
+ add_stage<internal::fft_stage_impl<T, !first, is_even>>(stage_size);
make_fft(stage_size / 4, cbool_t<is_even>(), cfalse);
}
else
{
- add_stage<fft_final_stage_impl_t::template type>(final_size);
+ add_stage<internal::fft_final_stage_impl<T, !first, final_size>>(final_size);
}
}
template <typename T>
+struct reverse_wrapper // non-owning view: holds a reference so the free begin()/end() overloads can hand out reverse iterators
+{
+ T& iterable; // the wrapped range; it has to outlive the wrapper
+};
+
+// Reverse-begin of the wrapped range (std::rbegin of the referenced object).
+// reverse_wrapper has no member begin(), so range-based for reaches this overload through
+// argument-dependent lookup.
+template <typename T>
+auto begin(reverse_wrapper<T> wrapped) { return std::rbegin(wrapped.iterable); }
+
+// Reverse-end of the wrapped range (std::rend of the referenced object); the counterpart that
+// range-based for pairs with the begin() overload for reverse_wrapper.
+// The wrapper is passed by value, which only copies the reference it holds.
+template <typename T>
+auto end(reverse_wrapper<T> wrapped) { return std::rend(wrapped.iterable); }
+
+// Adapts an lvalue range for back-to-front iteration in a range-based for loop.
+// The parameter is T& rather than T&&: reverse_wrapper only stores a reference, so accepting a
+// temporary would leave that reference dangling before the loop body ever runs.
+template <typename T>
+reverse_wrapper<T> reversed(T& iterable) { return { iterable }; }
+
+template <typename T>
void dft_plan<T>::initialize()
{
data = autofree<u8>(data_size);
size_t offset = 0;
- for (dft_stage_ptr& stage : stages[0])
+ for (dft_stage_ptr& stage : stages) // give each stage its slice of the shared data block, then let the stage fill it
{
stage->data = data.data() + offset;
stage->initialize(this->size);
offset += stage->data_size;
}
- offset = 0;
- for (dft_stage_ptr& stage : stages[1])
+
+ bool to_scratch = false; // starting value applies to the last stage, so that one writes to the caller's output
+ for (dft_stage_ptr& stage : reversed(stages)) // walk back to front, deciding each stage's destination buffer
{
- stage->data = data.data() + offset;
- offset += stage->data_size;
+ if (to_scratch) // NOTE(review): grows temp_size by a whole scratch buffer for every scratch-writing stage, while execute_dft only addresses one buffer at the end of temp — confirm whether a single reservation was intended
+ {
+ this->temp_size += align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment);
+ }
+ stage->to_scratch = to_scratch;
+ if (!stage->can_inplace) // a stage that cannot run in place needs its input in the other buffer, so the preceding stage's destination flips
+ {
+ to_scratch = !to_scratch;
+ }
}
}
template <typename T>
+const complex<T>* dft_plan<T>::select_in(size_t stage, const complex<T>* out, const complex<T>* in,
+                                         const complex<T>* scratch) const
+{
+    // Stage 0 reads the caller's input; a later stage reads whatever the stage before it wrote to.
+    return stage == 0 ? in : (stages[stage - 1]->to_scratch ? scratch : out);
+}
+
+// Output buffer of stage `stage`, chosen by the to_scratch flag that initialize() assigns.
+template <typename T>
+complex<T>* dft_plan<T>::select_out(size_t stage, complex<T>* out, complex<T>* scratch) const
+{
+    return stages[stage]->to_scratch ? scratch : out;
+}
+
+template <typename T>
template <bool inverse>
void dft_plan<T>::execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const
{
size_t stack[32] = { 0 };
- const size_t count = stages[inverse].size();
+ complex<T>* scratch =
+ ptr_cast<complex<T>>(temp + this->temp_size -
+ align_up(sizeof(complex<T>) * this->size, platform<>::native_cache_alignment));
+
+ const size_t count = stages.size();
for (size_t depth = 0; depth < count;)
{
- if (stages[inverse][depth]->recursion)
+ if (stages[depth]->recursion)
{
- complex<T>* rout = out;
- const complex<T>* rin = in;
- size_t rdepth = depth;
- size_t maxdepth = depth;
+ size_t offset = 0;
+ size_t rdepth = depth;
+ size_t maxdepth = depth;
do
{
- if (stack[rdepth] == stages[inverse][rdepth]->repeats)
+ if (stack[rdepth] == stages[rdepth]->repeats)
{
stack[rdepth] = 0;
rdepth--;
}
else
{
- stages[inverse][rdepth]->execute(rout, rin, temp);
- rout += stages[inverse][rdepth]->out_offset;
- rin = rout;
+ complex<T>* rout = select_out(rdepth, out, scratch);
+ const complex<T>* rin = select_in(rdepth, out, in, scratch);
+ stages[rdepth]->execute(cbool<inverse>, rout + offset, rin + offset, temp);
+ offset += stages[rdepth]->out_offset;
stack[rdepth]++;
- if (rdepth < count - 1 && stages[inverse][rdepth + 1]->recursion)
+ if (rdepth < count - 1 && stages[rdepth + 1]->recursion)
rdepth++;
else
maxdepth = rdepth;
@@ -900,45 +1250,42 @@ void dft_plan<T>::execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T
}
else
{
- stages[inverse][depth]->execute(out, in, temp);
+ stages[depth]->execute(cbool<inverse>, select_out(depth, out, scratch),
+ select_in(depth, out, in, scratch), temp);
depth++;
}
- in = out;
}
}
-constexpr csizes_t<2, 3, 4, 5, 6, 8, 10> dft_radices{};
-
template <typename T>
dft_plan<T>::dft_plan(size_t size, dft_type) : size(size), temp_size(0), data_size(0)
{
if (is_poweroftwo(size))
{
const size_t log2n = ilog2(size);
- cswitch(
- csizes_t<1, 2, 3, 4, 5, 6, 7, 8>(), log2n,
- [&](auto log2n) {
- (void)log2n;
- this->add_stage<
- internal::fft_specialization_t<T, val_of(decltype(log2n)()), false>::template type>(size);
- },
- [&]() {
- cswitch(cfalse_true, is_even(log2n), [&](auto is_even) {
- this->make_fft(size, is_even, ctrue);
- this->add_stage<
- internal::fft_reorder_stage_impl_t<T, val_of(decltype(is_even)())>::template type>(
- size);
+ cswitch(csizes_t<1, 2, 3, 4, 5, 6, 7, 8>(), log2n,
+ [&](auto log2n) {
+ (void)log2n;
+ constexpr size_t log2nv = val_of(decltype(log2n)());
+ this->add_stage<internal::fft_specialization<T, log2nv>>(size);
+ },
+ [&]() {
+ cswitch(cfalse_true, is_even(log2n), [&](auto is_even) {
+ this->make_fft(size, is_even, ctrue);
+ constexpr size_t is_evenv = val_of(decltype(is_even)());
+ this->add_stage<internal::fft_reorder_stage_impl<T, is_evenv>>(size);
+ });
});
- });
}
-#if 0
else
{
size_t cur_size = size;
constexpr size_t radices_count = dft_radices.back() + 1;
u8 count[radices_count] = { 0 };
+ int radices[32] = { 0 };
+ size_t radices_size = 0;
- cforeach(dft_radices, [&](auto radix) {
+ cforeach(dft_radices[csizeseq<dft_radices.size(), dft_radices.size() - 1, -1>], [&](auto radix) {
while (cur_size && cur_size % val_of(radix) == 0)
{
count[val_of(radix)]++;
@@ -954,7 +1301,11 @@ dft_plan<T>::dft_plan(size_t size, dft_type) : size(size), temp_size(0), data_si
for (size_t i = 0; i < count[r]; i++)
{
iterations /= r;
- this->add_dft_stage(r, iterations, blocks, fft_vector_width<T>, type);
+ radices[radices_size++] = r;
+ if (iterations == 1)
+ this->prepare_dft_stage(r, iterations, blocks, ctrue);
+ else
+ this->prepare_dft_stage(r, iterations, blocks, cfalse);
blocks *= r;
}
}
@@ -962,10 +1313,16 @@ dft_plan<T>::dft_plan(size_t size, dft_type) : size(size), temp_size(0), data_si
if (cur_size > 1)
{
iterations /= cur_size;
- this->add_dft_stage(cur_size, iterations, blocks, fft_vector_width<T>);
+ radices[radices_size++] = cur_size;
+ if (iterations == 1)
+ this->prepare_dft_stage(cur_size, iterations, blocks, ctrue);
+ else
+ this->prepare_dft_stage(cur_size, iterations, blocks, cfalse);
}
+
+ if (stages.size() > 2)
+ this->add_stage<internal::dft_reorder_stage_impl<T>>(radices, radices_size);
}
-#endif
initialize();
}
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -100,10 +100,13 @@ struct dft_plan
protected:
autofree<u8> data;
size_t data_size;
- std::vector<dft_stage_ptr> stages[2];
+ std::vector<dft_stage_ptr> stages;
- template <template <bool inverse> class Stage>
- void add_stage(size_t stage_size);
+ template <typename Stage, typename... Args>
+ void add_stage(Args... args);
+
+ template <bool is_final>
+ void prepare_dft_stage(size_t radix, size_t iterations, size_t blocks, cbool_t<is_final>);
template <bool is_even, bool first>
void make_fft(size_t stage_size, cbool_t<is_even>, cbool_t<first>);
@@ -111,6 +114,9 @@ protected:
void initialize();
template <bool inverse>
void execute_dft(cbool_t<inverse>, complex<T>* out, const complex<T>* in, u8* temp) const;
+
+ const complex<T>* select_in(size_t stage, const complex<T>* out, const complex<T>* in, const complex<T>* scratch) const;
+ complex<T>* select_out(size_t stage, complex<T>* out, complex<T>* scratch) const;
};
enum class dft_pack_format
diff --git a/include/kfr/dft/ft.hpp b/include/kfr/dft/ft.hpp
@@ -376,16 +376,16 @@ KFR_INTRIN void transpose4x8(const cvec<T, 8>& z0, const cvec<T, 8>& z1, const c
{
cvec<T, 16> a = concat(low(z0), low(z1), low(z2), low(z3));
cvec<T, 16> b = concat(high(z0), high(z1), high(z2), high(z3));
- a = digitreverse4<2>(a);
- b = digitreverse4<2>(b);
- w0 = part<4, 0>(a);
- w1 = part<4, 1>(a);
- w2 = part<4, 2>(a);
- w3 = part<4, 3>(a);
- w4 = part<4, 0>(b);
- w5 = part<4, 1>(b);
- w6 = part<4, 2>(b);
- w7 = part<4, 3>(b);
+ a = digitreverse4<2>(a);
+ b = digitreverse4<2>(b);
+ w0 = part<4, 0>(a);
+ w1 = part<4, 1>(a);
+ w2 = part<4, 2>(a);
+ w3 = part<4, 3>(a);
+ w4 = part<4, 0>(b);
+ w5 = part<4, 1>(b);
+ w6 = part<4, 2>(b);
+ w7 = part<4, 3>(b);
}
template <typename T>
@@ -396,12 +396,12 @@ KFR_INTRIN void transpose4x8(const cvec<T, 4>& w0, const cvec<T, 4>& w1, const c
{
cvec<T, 16> a = concat(w0, w1, w2, w3);
cvec<T, 16> b = concat(w4, w5, w6, w7);
- a = digitreverse4<2>(a);
- b = digitreverse4<2>(b);
- z0 = concat(part<4, 0>(a), part<4, 0>(b));
- z1 = concat(part<4, 1>(a), part<4, 1>(b));
- z2 = concat(part<4, 2>(a), part<4, 2>(b));
- z3 = concat(part<4, 3>(a), part<4, 3>(b));
+ a = digitreverse4<2>(a);
+ b = digitreverse4<2>(b);
+ z0 = concat(part<4, 0>(a), part<4, 0>(b));
+ z1 = concat(part<4, 1>(a), part<4, 1>(b));
+ z2 = concat(part<4, 2>(a), part<4, 2>(b));
+ z3 = concat(part<4, 3>(a), part<4, 3>(b));
}
template <typename T>
@@ -493,7 +493,7 @@ CMT_PRAGMA_GNU(GCC diagnostic pop)
template <typename T, size_t N>
CMT_NOINLINE static vec<T, N> cossin_conj(const vec<T, N>& x)
{
- return cconj(cossin(x));
+ return negodd(cossin(x));
}
template <size_t k, size_t size, bool inverse = false, typename T, size_t width,
@@ -743,7 +743,7 @@ KFR_INTRIN void apply_twiddle(const cvec<T, N>& a1, const cvec<T, N>& tw1, cvec<
cvec<T, N> tw1_ = tw1;
if (inverse)
tw1_ = -(tw1_);
- w1 = subadd(b1, a1_ * dupodd(tw1_));
+ w1 = subadd(b1, a1_ * dupodd(tw1_));
}
}
@@ -1020,13 +1020,13 @@ template <size_t N, bool inverse = false, typename T>
KFR_INTRIN void butterfly3(const cvec<T, N>& a00, const cvec<T, N>& a01, const cvec<T, N>& a02,
cvec<T, N>& w00, cvec<T, N>& w01, cvec<T, N>& w02)
{
- constexpr cvec<T, N> tw3r1 = static_cast<T>(-0.5);
- constexpr cvec<T, N> tw3i1 =
+ const static cvec<T, N> tw3r1 = static_cast<T>(-0.5 - 1.0);
+ const static cvec<T, N> tw3i1 =
static_cast<T>(0.86602540378443864676372317075) * twiddleimagmask<T, N, inverse>();
const cvec<T, N> sum1 = a01 + a02;
const cvec<T, N> dif1 = swap<2>(a01 - a02);
- w00 = a00 + sum1;
+ w00 = a00 + sum1;
const cvec<T, N> s1 = w00 + sum1 * tw3r1;
@@ -1055,7 +1055,7 @@ KFR_INTRIN void butterfly6(const cvec<T, N>& a0, const cvec<T, N>& a1, const cve
split(a03, t0, t1);
split(a25, t2, t3);
split(a41, t4, t5);
- t3 = -t3;
+ t3 = -t3;
cvec<T, N* 2> a04 = concat(t0, t4);
cvec<T, N* 2> a15 = concat(t1, t5);
cvec<T, N * 2> w02, w35;
@@ -1079,14 +1079,14 @@ KFR_INTRIN void butterfly7(const cvec<T, N>& a00, const cvec<T, N>& a01, const c
const cvec<T, N>& a06, cvec<T, N>& w00, cvec<T, N>& w01, cvec<T, N>& w02,
cvec<T, N>& w03, cvec<T, N>& w04, cvec<T, N>& w05, cvec<T, N>& w06)
{
- constexpr cvec<T, N> tw7r1 = static_cast<T>(0.623489801858733530525004884);
- constexpr cvec<T, N> tw7i1 =
+ const static cvec<T, N> tw7r1 = static_cast<T>(0.623489801858733530525004884 - 1.0);
+ const static cvec<T, N> tw7i1 =
static_cast<T>(0.78183148246802980870844452667) * twiddleimagmask<T, N, inverse>();
- constexpr cvec<T, N> tw7r2 = static_cast<T>(-0.2225209339563144042889025645);
- constexpr cvec<T, N> tw7i2 =
+ const static cvec<T, N> tw7r2 = static_cast<T>(-0.2225209339563144042889025645 - 1.0);
+ const static cvec<T, N> tw7i2 =
static_cast<T>(0.97492791218182360701813168299) * twiddleimagmask<T, N, inverse>();
- constexpr cvec<T, N> tw7r3 = static_cast<T>(-0.90096886790241912623610231951);
- constexpr cvec<T, N> tw7i3 =
+ const static cvec<T, N> tw7r3 = static_cast<T>(-0.90096886790241912623610231951 - 1.0);
+ const static cvec<T, N> tw7i3 =
static_cast<T>(0.43388373911755812047576833285) * twiddleimagmask<T, N, inverse>();
const cvec<T, N> sum1 = a01 + a06;
@@ -1095,7 +1095,7 @@ KFR_INTRIN void butterfly7(const cvec<T, N>& a00, const cvec<T, N>& a01, const c
const cvec<T, N> dif2 = swap<2>(a02 - a05);
const cvec<T, N> sum3 = a03 + a04;
const cvec<T, N> dif3 = swap<2>(a03 - a04);
- w00 = a00 + sum1 + sum2 + sum3;
+ w00 = a00 + sum1 + sum2 + sum3;
const cvec<T, N> s1 = w00 + sum1 * tw7r1 + sum2 * tw7r2 + sum3 * tw7r3;
const cvec<T, N> s2 = w00 + sum1 * tw7r2 + sum2 * tw7r3 + sum3 * tw7r1;
@@ -1125,18 +1125,18 @@ KFR_INTRIN void butterfly5(const cvec<T, N>& a00, const cvec<T, N>& a01, const c
const cvec<T, N>& a03, const cvec<T, N>& a04, cvec<T, N>& w00, cvec<T, N>& w01,
cvec<T, N>& w02, cvec<T, N>& w03, cvec<T, N>& w04)
{
- constexpr cvec<T, N> tw5r1 = static_cast<T>(0.30901699437494742410229341718);
- constexpr cvec<T, N> tw5i1 =
+ const static cvec<T, N> tw5r1 = static_cast<T>(0.30901699437494742410229341718 - 1.0);
+ const static cvec<T, N> tw5i1 =
static_cast<T>(0.95105651629515357211643933338) * twiddleimagmask<T, N, inverse>();
- constexpr cvec<T, N> tw5r2 = static_cast<T>(-0.80901699437494742410229341718);
- constexpr cvec<T, N> tw5i2 =
+ const static cvec<T, N> tw5r2 = static_cast<T>(-0.80901699437494742410229341718 - 1.0);
+ const static cvec<T, N> tw5i2 =
static_cast<T>(0.58778525229247312916870595464) * twiddleimagmask<T, N, inverse>();
const cvec<T, N> sum1 = a01 + a04;
const cvec<T, N> dif1 = swap<2>(a01 - a04);
const cvec<T, N> sum2 = a02 + a03;
const cvec<T, N> dif2 = swap<2>(a02 - a03);
- w00 = a00 + sum1 + sum2;
+ w00 = a00 + sum1 + sum2;
const cvec<T, N> s1 = w00 + sum1 * tw5r1 + sum2 * tw5r2;
const cvec<T, N> s2 = w00 + sum1 * tw5r2 + sum2 * tw5r1;
@@ -1268,7 +1268,7 @@ KFR_INTRIN void cread_transposed(cbool_t<true>, const complex<f32>* ptr, cvec<f3
{
cvec<f32, 4> w3;
cvec<f32, 16> v16 = concat(cread<4>(ptr), cread<4>(ptr + 3), cread<4>(ptr + 6), cread<4>(ptr + 9));
- v16 = digitreverse4<2>(v16);
+ v16 = digitreverse4<2>(v16);
split(v16, w0, w1, w2, w3);
}
@@ -1276,7 +1276,7 @@ KFR_INTRIN void cread_transposed(cbool_t<true>, const complex<f32>* ptr, cvec<f3
cvec<f32, 4>& w2, cvec<f32, 4>& w3, cvec<f32, 4>& w4)
{
cvec<f32, 16> v16 = concat(cread<4>(ptr), cread<4>(ptr + 5), cread<4>(ptr + 10), cread<4>(ptr + 15));
- v16 = digitreverse4<2>(v16);
+ v16 = digitreverse4<2>(v16);
split(v16, w0, w1, w2, w3);
w4 = cgather<4, 5>(ptr + 4);
}
@@ -1384,13 +1384,16 @@ KFR_INTRIN void generic_butterfly_cycle(csize_t<width>, size_t radix, cbool_t<in
const cvec<T, 1> inb = cread<1>(in + radix - (j + 1));
cvec<T, width> tw = cread<width>(twiddle);
if (inverse)
- tw = cconj(tw);
+ tw = negodd /*cconj*/ (tw);
cmul_2conj(sum0, sum1, ina, inb, tw);
twiddle += halfradix;
}
twiddle = twiddle - halfradix_sqr + width;
+ // if (inverse)
+ // std::swap(sum0, sum1);
+
if (is_constant_val(ostride))
{
cwrite<width>(out + (1 + i), sum0);
@@ -1407,6 +1410,17 @@ KFR_INTRIN void generic_butterfly_cycle(csize_t<width>, size_t radix, cbool_t<in
halfradix_sqr, twiddle, i);
}
+template <typename T>
+KFR_SINTRIN vec<T, 2> hcadd(vec<T, 2> value) // two-element overload: the case the wider hcadd overload recurses down to
+{
+ return value; // nothing left to add; the input is handed back unchanged
+}
+template <typename T, size_t N, KFR_ENABLE_IF(N >= 4)> // this overload only participates for N >= 4
+KFR_SINTRIN vec<T, 2> hcadd(vec<T, N> value) // reduces an N-element vector to a two-element result
+{
+ return hcadd(low(value) + high(value)); // adds low(value) to high(value) (presumably the two halves - confirm), then recurses on that sum
+}
+
template <size_t width, typename T, bool inverse, typename Tstride = csize_t<1>>
KFR_INTRIN void generic_butterfly_w(size_t radix, cbool_t<inverse>, complex<T>* out, const complex<T>* in,
const complex<T>* twiddle, Tstride ostride = Tstride{})
@@ -1414,7 +1428,7 @@ KFR_INTRIN void generic_butterfly_w(size_t radix, cbool_t<inverse>, complex<T>*
CMT_ASSUME(radix > 0);
{
cvec<T, width> sum = T();
- size_t j = 0;
+ size_t j = 0;
CMT_LOOP_NOUNROLL
for (; j < radix / width * width; j += width)
{
@@ -1448,13 +1462,14 @@ KFR_INTRIN void generic_butterfly(size_t radix, cbool_t<inverse>, complex<T>* ou
}
constexpr size_t width = platform<T>::vector_width;
- cswitch(csizes_t<11>(), radix,
+ generic_butterfly_w<width>(radix, cbool_t<inverse>(), out, in, twiddle, ostride);
+ /*cswitch(csizes_t<11>(), radix,
[&](auto radix_) CMT_INLINE_LAMBDA {
generic_butterfly_w<width>(decltype(radix_)(), cbool_t<inverse>(), out, in, twiddle, ostride);
},
[&]() CMT_INLINE_LAMBDA {
generic_butterfly_w<width>(radix, cbool_t<inverse>(), out, in, twiddle, ostride);
- });
+ });*/
}
template <typename T, size_t N>
@@ -1521,7 +1536,7 @@ KFR_INTRIN void cdigitreverse4_write<false, f64, 32>(complex<f64>* dest, const v
cwrite<1>(dest + 15, part<16, 15>(x));
}
#endif
-}
-}
+} // namespace internal
+} // namespace kfr
CMT_PRAGMA_MSVC(warning(pop))
diff --git a/include/kfr/dft/reference_dft.hpp b/include/kfr/dft/reference_dft.hpp
@@ -41,8 +41,8 @@ void reference_fft_pass(Tnumber pi2, size_t N, size_t offset, size_t delta, int
Tnumber (*X)[2], Tnumber (*XX)[2])
{
const size_t N2 = N / 2;
- using std::sin;
using std::cos;
+ using std::sin;
if (N != 2)
{
@@ -125,14 +125,14 @@ void reference_fft(T* out, const complex<T>* in, size_t size)
reference_fft_pass<Tnumber>(pi2, size, 0, 1, inversion ? -1 : +1, Tcmplx(datain.data()),
Tcmplx(dataout.data()), Tcmplx(temp.data()));
for (size_t i = 0; i < size; i++)
- out[i] = dataout[i].real();
+ out[i] = dataout[i].real();
}
template <typename Tnumber = double, typename T>
void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false)
{
- using std::sin;
using std::cos;
+ using std::sin;
if (is_poweroftwo(size))
{
return reference_fft<Tnumber>(out, in, size, inversion);
@@ -177,6 +177,14 @@ void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inve
}
}
+template <typename Tnumber = double, typename T> // convenience overload: runs the pointer-based reference_dft over a whole univector
+inline univector<complex<T>> reference_dft(const univector<complex<T>>& in, bool inversion = false)
+{
+ univector<complex<T>> out(in.size()); // result has the same number of elements as the input
+ reference_dft<Tnumber>(&out[0], &in[0], in.size(), inversion); // forward Tnumber explicitly; a bare call would always use the default (double)
+ return out;
+}
+
template <typename T>
struct reference_dft_plan
{
@@ -195,4 +203,4 @@ struct reference_dft_plan
static constexpr size_t temp_size = 0;
const size_t size;
};
-}
+} // namespace kfr