commit 84b9d8ea9e0050889602baf36daaa26aea32941c
parent a895e321fb66722689d3f139a7078da0553b48cb
Author: gopher2008 <cjhu2008@gmail.com>
Date: Tue, 13 Dec 2022 12:13:56 +0800
merge with origin
Diffstat:
2 files changed, 33 insertions(+), 12 deletions(-)
diff --git a/include/kfr/dft/impl/fft-impl.hpp b/include/kfr/dft/impl/fft-impl.hpp
@@ -1003,7 +1003,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp
constexpr size_t width = vector_width<T> * 2;
const cvec<T, 1> dc = cread<1>(out);
- const size_t count = csize / 2;
+ const size_t count = (csize + 1) / 2;
block_process(count - 1, csizes_t<width, 1>(),
[&](size_t i, auto w)
@@ -1022,6 +1022,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp
cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(T(0.5) * (f1k - t))));
});
+ if (is_even(csize))
{
size_t k = csize / 2;
const cvec<T, 1> fpk = cread<1>(in + k);
@@ -1068,7 +1069,7 @@ void from_fmt(size_t real_size, complex<T>* rtwiddle, complex<T>* out, const com
}
constexpr size_t width = vector_width<T> * 2;
- const size_t count = csize / 2;
+ const size_t count = (csize + 1) / 2;
block_process(count - 1, csizes_t<width, 1>(),
[&](size_t i, auto w)
@@ -1080,13 +1081,13 @@ void from_fmt(size_t real_size, complex<T>* rtwiddle, complex<T>* out, const com
const cvec<T, width> fpk = cread<width>(in + i);
const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1)));
- const cvec<T, width> f1k = fpk + fpnk;
- const cvec<T, width> f2k = fpk - fpnk;
- const cvec<T, width> t = cmul_conj(f2k, tw);
- cwrite<width>(out + i, f1k + t);
- cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t)));
- });
-
+ const cvec<T, width> f1k = fpk + fpnk;
+ const cvec<T, width> f2k = fpk - fpnk;
+ const cvec<T, width> t = cmul_conj(f2k, tw);
+ cwrite<width>(out + i, f1k + t);
+ cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t)));
+ });
+ if (is_even(csize))
{
size_t k = csize / 2;
const cvec<T, 1> fpk = cread<1>(in + k);
@@ -1132,7 +1133,8 @@ public:
{
this->user = static_cast<int>(fmt);
this->stage_size = real_size;
- this->data_size = align_up(sizeof(complex<T>) * (real_size / 4), platform<>::native_cache_alignment);
+ const size_t count = (real_size / 2 + 1) / 2;
+ this->data_size = align_up(sizeof(complex<T>) * count, platform<>::native_cache_alignment);
}
void do_initialize(size_t) override
{
@@ -1140,14 +1142,15 @@ public:
constexpr size_t width = vector_width<T> * 2;
size_t real_size = this->stage_size;
complex<T>* rtwiddle = ptr_cast<complex<T>>(this->data);
- block_process(real_size / 4, csizes_t<width, 1>(),
+ const size_t count = (real_size / 2 + 1) / 2;
+ block_process(count, csizes_t<width, 1>(),
[=](size_t i, auto w)
{
constexpr size_t width = val_of(decltype(w)());
cwrite<width>(
rtwiddle + i,
cossin(dup(-constants<T>::pi *
- ((enumerate<T, width>() + i + real_size / 4) / (real_size / 2)))));
+ ((enumerate<T, width>() + i + real_size / T(4)) / (real_size / 2)))));
});
}
void do_execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) override
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -175,6 +175,24 @@ TEST(fft_real)
CHECK(rms(rev - in) <= 0.00001f);
}
+TEST(fft_real_not_size_4N)
+{
+ kfr::univector<double, 6> in = counter();
+ auto out = realdft(in);
+ kfr::univector<kfr::complex<double>> expected {
+ 15.0, { -3, 5.19615242}, {-3, +1.73205081}, -3.0 };
+ CHECK(rms(cabs(out - expected)) <= 0.00001f);
+ kfr::univector<double, 6> rev = irealdft(out) / 6;
+ CHECK(rms(rev - in) <= 0.00001f);
+
+ random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
+ constexpr size_t size = 66;
+ kfr::univector<double, size> in2 = gen_random_range<double>(gen, -1.0, +1.0);
+ kfr::univector<kfr::complex<double>, size / 2 + 1> out2 = realdft(in2);
+ kfr::univector<double, size> rev2 = irealdft(out2) / size;
+ CHECK(rms(rev2 - in2) <= 0.00001f);
+}
+
TEST(fft_accuracy)
{
#ifdef DEBUG_DFT_PROGRESS