commit 351b6f77af644f85fd47e393bd04e301b7efd8f2
parent 61cf8a69d0811b7b6d93d836f12dde966cd275e1
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Mon, 16 Oct 2023 08:14:02 +0100
DFT: support degenerate sizes, more DFT tests
Diffstat:
4 files changed, 56 insertions(+), 45 deletions(-)
diff --git a/include/kfr/dft/fft.hpp b/include/kfr/dft/fft.hpp
@@ -311,6 +311,7 @@ struct dft_plan_real : dft_plan<T>
explicit dft_plan_real(cpu_t cpu, size_t size, dft_pack_format fmt = dft_pack_format::CCs)
: dft_plan<T>(typename dft_plan<T>::noinit{}, size / 2), size(size), fmt(fmt)
{
+ KFR_LOGIC_CHECK(is_even(size), "dft_plan_real requires size to be even");
#ifdef KFR_DFT_MULTI
if (cpu == cpu_t::runtime)
cpu = get_cpu();
diff --git a/include/kfr/dft/impl/fft-impl.hpp b/include/kfr/dft/impl/fft-impl.hpp
@@ -591,6 +591,21 @@ template <typename T, size_t log2n>
struct fft_specialization;
template <typename T>
+struct fft_specialization<T, 0> : dft_stage<T>
+{
+ fft_specialization(size_t) { this->name = dft_name(this); }
+
+ constexpr static bool aligned = false;
+ DFT_STAGE_FN
+
+ template <bool inverse>
+ KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
+ {
+ out[0] = in[0];
+ }
+};
+
+template <typename T>
struct fft_specialization<T, 1> : dft_stage<T>
{
fft_specialization(size_t) { this->name = dft_name(this); }
@@ -953,7 +968,7 @@ KFR_INTRINSIC void init_fft(dft_plan<T>* self, size_t size, dft_order)
{
const size_t log2n = ilog2(size);
cswitch(
- csizes_t<1, 2, 3, 4, 5, 6, 7, 8, 9, 10>(), log2n,
+ csizes_t<0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10>(), log2n,
[&](auto log2n)
{
(void)log2n;
@@ -1123,6 +1138,8 @@ KFR_INTRINSIC void initialize_stages(dft_plan<T>* self)
template <typename T>
void dft_initialize(dft_plan<T>& plan)
{
+ if (plan.size == 0)
+ return;
initialize_stages(&plan);
initialize_data(&plan);
initialize_order(&plan);
@@ -1171,6 +1188,8 @@ public:
template <typename T>
void dft_real_initialize(dft_plan_real<T>& plan)
{
+ if (plan.size == 0)
+ return;
initialize_stages(&plan);
plan.fmt_stage.reset(new dft_stage_real_repack<T>(plan.size, plan.fmt));
plan.data_size += plan.fmt_stage->data_size;
diff --git a/include/kfr/dft/reference_dft.hpp b/include/kfr/dft/reference_dft.hpp
@@ -42,6 +42,7 @@ template <typename Tnumber = double>
void reference_fft_pass(Tnumber pi2, size_t N, size_t offset, size_t delta, int flag, Tnumber (*x)[2],
Tnumber (*X)[2], Tnumber (*XX)[2])
{
+ KFR_LOGIC_CHECK(N >= 2, "reference_fft_pass: invalid N");
const size_t N2 = N / 2;
using std::cos;
using std::sin;
@@ -84,26 +85,12 @@ template <typename Tnumber = double, typename T>
void reference_fft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false)
{
using Tcmplx = Tnumber(*)[2];
- if (size < 2)
+ if (size < 1)
return;
- std::vector<complex<Tnumber>> datain(size);
- std::vector<complex<Tnumber>> dataout(size);
- std::vector<complex<Tnumber>> temp(size);
- std::copy(in, in + size, datain.begin());
- const Tnumber pi2 = c_pi<Tnumber, 2, 1>;
- reference_fft_pass<Tnumber>(pi2, size, 0, 1, inversion ? -1 : +1, Tcmplx(datain.data()),
- Tcmplx(dataout.data()), Tcmplx(temp.data()));
- std::copy(dataout.begin(), dataout.end(), out);
-}
-
-/// @brief Performs Direct Real FFT using reference implementation (slow, used for testing)
-template <typename Tnumber = double, typename T>
-void reference_fft(complex<T>* out, const T* in, size_t size)
-{
- constexpr bool inversion = false;
- using Tcmplx = Tnumber(*)[2];
- if (size < 2)
+ if (size == 1) {
+ out[0] = in[0];
return;
+ }
std::vector<complex<Tnumber>> datain(size);
std::vector<complex<Tnumber>> dataout(size);
std::vector<complex<Tnumber>> temp(size);
@@ -114,25 +101,6 @@ void reference_fft(complex<T>* out, const T* in, size_t size)
std::copy(dataout.begin(), dataout.end(), out);
}
-/// @brief Performs Inverse Real FFT using reference implementation (slow, used for testing)
-template <typename Tnumber = double, typename T>
-void reference_fft(T* out, const complex<T>* in, size_t size)
-{
- constexpr bool inversion = true;
- using Tcmplx = Tnumber(*)[2];
- if (size < 2)
- return;
- std::vector<complex<Tnumber>> datain(size);
- std::vector<complex<Tnumber>> dataout(size);
- std::vector<complex<Tnumber>> temp(size);
- std::copy(in, in + size, datain.begin());
- const Tnumber pi2 = c_pi<Tnumber, 2, 1>;
- reference_fft_pass<Tnumber>(pi2, size, 0, 1, inversion ? -1 : +1, Tcmplx(datain.data()),
- Tcmplx(dataout.data()), Tcmplx(temp.data()));
- for (size_t i = 0; i < size; i++)
- out[i] = dataout[i].real();
-}
-
/// @brief Performs Complex DFT using reference implementation (slow, used for testing)
template <typename Tnumber = double, typename T>
void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false)
@@ -183,6 +151,29 @@ void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inve
}
}
+/// @brief Performs Direct Real DFT using reference implementation (slow, used for testing)
+template <typename T>
+void reference_dft(complex<T>* out, const T* in, size_t size)
+{
+ if (size < 1)
+ return;
+ std::vector<complex<T>> datain(size);
+ std::copy(in, in + size, datain.begin());
+ reference_dft(out, datain.data(), size, false);
+}
+
+/// @brief Performs Inverse Real DFT using reference implementation (slow, used for testing)
+template <typename T>
+void reference_dft(T* out, const complex<T>* in, size_t size)
+{
+ if (size < 1)
+ return;
+ std::vector<complex<T>> dataout(size);
+ reference_dft(dataout.data(), in, size, true);
+ for (size_t i = 0; i < size; i++)
+ out[i] = dataout[i].real();
+}
+
/// @brief Performs DFT using reference implementation (slow, used for testing)
template <typename Tnumber = double, typename T>
inline univector<complex<T>> reference_dft(const univector<complex<T>>& in, bool inversion = false)
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -206,7 +206,7 @@ TEST(fft_accuracy)
#endif
random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
std::set<size_t> size_set;
- univector<size_t> sizes = truncate(1 + counter(), fft_stopsize - 1);
+ univector<size_t> sizes = truncate(counter(), fft_stopsize);
sizes = round(pow(2.0, sizes));
#ifndef KFR_DFT_NO_NPo2
@@ -251,12 +251,12 @@ TEST(fft_accuracy)
dft.execute(out, out, temp, inverse);
const float_type rms_diff_inplace = rms(cabs(refout - out));
- CHECK(rms_diff_inplace < min_prec2);
+ CHECK(rms_diff_inplace <= min_prec2);
const float_type rms_diff_outofplace = rms(cabs(refout - outo));
- CHECK(rms_diff_outofplace < min_prec2);
+ CHECK(rms_diff_outofplace <= min_prec2);
}
- if (size >= 4 && is_poweroftwo(size))
+ if (is_even(size))
{
univector<float_type> in =
truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
@@ -267,18 +267,18 @@ TEST(fft_accuracy)
univector<u8> temp(dft.temp_size);
testo::scope s("real-direct");
- reference_fft(refout.data(), in.data(), size);
+ reference_dft(refout.data(), in.data(), size);
dft.execute(out, in, temp);
float_type rms_diff =
rms(cabs(refout.truncate(size / 2 + 1) - out.truncate(size / 2 + 1)));
- CHECK(rms_diff < min_prec);
+ CHECK(rms_diff <= min_prec);
univector<float_type> out2(size, 0.f);
s.text = "real-inverse";
dft.execute(out2, out, temp);
out2 = out2 / size;
rms_diff = rms(in - out2);
- CHECK(rms_diff < min_prec);
+ CHECK(rms_diff <= min_prec);
}
});
}