kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 84b9d8ea9e0050889602baf36daaa26aea32941c
parent a895e321fb66722689d3f139a7078da0553b48cb
Author: gopher2008 <cjhu2008@gmail.com>
Date:   Tue, 13 Dec 2022 12:13:56 +0800

merge with origin

Diffstat:
Minclude/kfr/dft/impl/fft-impl.hpp | 27+++++++++++++++------------
Mtests/dft_test.cpp | 18++++++++++++++++++
2 files changed, 33 insertions(+), 12 deletions(-)

diff --git a/include/kfr/dft/impl/fft-impl.hpp b/include/kfr/dft/impl/fft-impl.hpp @@ -1003,7 +1003,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp constexpr size_t width = vector_width<T> * 2; const cvec<T, 1> dc = cread<1>(out); - const size_t count = csize / 2; + const size_t count = (csize + 1) / 2; block_process(count - 1, csizes_t<width, 1>(), [&](size_t i, auto w) @@ -1022,6 +1022,7 @@ to_fmt(size_t real_size, const complex<T>* rtwiddle, complex<T>* out, const comp cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(T(0.5) * (f1k - t)))); }); + if (is_even(csize)) { size_t k = csize / 2; const cvec<T, 1> fpk = cread<1>(in + k); @@ -1068,7 +1069,7 @@ void from_fmt(size_t real_size, complex<T>* rtwiddle, complex<T>* out, const com } constexpr size_t width = vector_width<T> * 2; - const size_t count = csize / 2; + const size_t count = (csize + 1) / 2; block_process(count - 1, csizes_t<width, 1>(), [&](size_t i, auto w) @@ -1080,13 +1081,13 @@ void from_fmt(size_t real_size, complex<T>* rtwiddle, complex<T>* out, const com const cvec<T, width> fpk = cread<width>(in + i); const cvec<T, width> fpnk = reverse<2>(negodd(cread<width>(in + csize - i - widthm1))); - const cvec<T, width> f1k = fpk + fpnk; - const cvec<T, width> f2k = fpk - fpnk; - const cvec<T, width> t = cmul_conj(f2k, tw); - cwrite<width>(out + i, f1k + t); - cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t))); - }); - + const cvec<T, width> f1k = fpk + fpnk; + const cvec<T, width> f2k = fpk - fpnk; + const cvec<T, width> t = cmul_conj(f2k, tw); + cwrite<width>(out + i, f1k + t); + cwrite<width>(out + csize - i - widthm1, reverse<2>(negodd(f1k - t))); + }); + if (is_even(csize)) { size_t k = csize / 2; const cvec<T, 1> fpk = cread<1>(in + k); @@ -1132,7 +1133,8 @@ public: { this->user = static_cast<int>(fmt); this->stage_size = real_size; - this->data_size = align_up(sizeof(complex<T>) * (real_size / 4), platform<>::native_cache_alignment); + const size_t count = (real_size / 2 + 1) / 2; + this->data_size = align_up(sizeof(complex<T>) * count, platform<>::native_cache_alignment); } void do_initialize(size_t) override { @@ -1140,14 +1142,15 @@ public: constexpr size_t width = vector_width<T> * 2; size_t real_size = this->stage_size; complex<T>* rtwiddle = ptr_cast<complex<T>>(this->data); - block_process(real_size / 4, csizes_t<width, 1>(), + const size_t count = (real_size / 2 + 1) / 2; + block_process(count, csizes_t<width, 1>(), [=](size_t i, auto w) { constexpr size_t width = val_of(decltype(w)()); cwrite<width>( rtwiddle + i, cossin(dup(-constants<T>::pi * - ((enumerate<T, width>() + i + real_size / 4) / (real_size / 2))))); + ((enumerate<T, width>() + i + real_size / T(4)) / (real_size / 2))))); }); } void do_execute(cdirect_t, complex<T>* out, const complex<T>* in, u8* temp) override diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp @@ -175,6 +175,24 @@ TEST(fft_real) CHECK(rms(rev - in) <= 0.00001f); } +TEST(fft_real_not_size_4N) +{ + kfr::univector<double, 6> in = counter(); + auto out = realdft(in); + kfr::univector<kfr::complex<double>> expected { + 15.0, { -3, 5.19615242}, {-3, +1.73205081}, -3.0 }; + CHECK(rms(cabs(out - expected)) <= 0.00001f); + kfr::univector<double, 6> rev = irealdft(out) / 6; + CHECK(rms(rev - in) <= 0.00001f); + + random_state gen = random_init(2247448713, 915890490, 864203735, 2982561); + constexpr size_t size = 66; + kfr::univector<double, size> in2 = gen_random_range<double>(gen, -1.0, +1.0); + kfr::univector<kfr::complex<double>, size / 2 + 1> out2 = realdft(in2); + kfr::univector<double, size> rev2 = irealdft(out2) / size; + CHECK(rms(rev2 - in2) <= 0.00001f); +} + TEST(fft_accuracy) { #ifdef DEBUG_DFT_PROGRESS