kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

dft_test.cpp (19544B)


      1 /**
      2  * KFR (https://www.kfrlib.com)
      3  * Copyright (C) 2016-2023 Dan Cazarin
      4  * See LICENSE.txt for details
      5  */
      6 
#include <kfr/testo/testo.hpp>

#include <atomic>
#include <chrono>
#include <kfr/base.hpp>
#include <kfr/dft.hpp>
#include <kfr/dsp.hpp>
#include <kfr/io/tostring.hpp>
#include <set>
     15 
     16 using namespace kfr;
     17 
     18 namespace CMT_ARCH_NAME
     19 {
     20 
     21 TEST(print_vector_capacity)
     22 {
     23     println("vector_capacity<float> = ", vector_capacity<float>);
     24     println("vector_capacity<double> = ", vector_capacity<double>);
     25 
     26     println("fft_config<float>::process_width = ", vector_capacity<float> / 16);
     27     println("fft_config<double>::process_width = ", vector_capacity<double> / 16);
     28 }
     29 
     30 #ifdef CMT_NATIVE_F64
     31 constexpr ctypes_t<float, double> dft_float_types{};
     32 #else
     33 constexpr ctypes_t<float> dft_float_types{};
     34 #endif
     35 
     36 #if !defined(KFR_NO_PERF_TESTS)
     37 
     38 static void full_barrier()
     39 {
     40 #if defined(CMT_ARCH_NEON)
     41     asm volatile("dmb ish" ::: "memory");
     42 #elif defined(CMT_COMPILER_GNU)
     43     asm volatile("mfence" ::: "memory");
     44 #else
     45     _ReadWriteBarrier();
     46 #endif
     47 }
     48 static CMT_NOINLINE void dont_optimize(const void* in)
     49 {
     50 #ifdef CMT_COMPILER_GNU
     51     asm volatile("" : "+m"(in));
     52 #else
     53     volatile uint8_t a = *reinterpret_cast<const uint8_t*>(in);
     54 #endif
     55 }
     56 
     57 template <typename T>
     58 static void perf_test_t(int size)
     59 {
     60     print("[PERFORMANCE] DFT ", fmt<'s', 6>(type_name<T>()), " ", fmt<'d', 6>(size), "...");
     61     random_state gen1 = random_init(2247448713, 915890490, 864203735, 2982561);
     62     random_state gen2 = random_init(2982561, 2247448713, 915890490, 864203735);
     63     std::chrono::high_resolution_clock::duration duration(0);
     64     dft_plan<T> dft(size);
     65     univector<u8> tmp(dft.temp_size);
     66     uint64_t counter = 0;
     67     while (duration < std::chrono::seconds(1))
     68     {
     69         univector<complex<T>> data(size);
     70         data = make_complex(gen_random_range<T>(gen1, -1.0, +1.0), gen_random_range<T>(gen2, -1.0, +1.0));
     71         full_barrier();
     72         auto start = std::chrono::high_resolution_clock::now();
     73         dft.execute(data, data, tmp);
     74 
     75         full_barrier();
     76         duration += std::chrono::high_resolution_clock::now() - start;
     77         dont_optimize(data.data());
     78         ++counter;
     79     }
     80     double opspersecond = counter / (std::chrono::nanoseconds(duration).count() / 1'000'000'000.0);
     81     println(" ", fmt<'f', 12, 1>(opspersecond), " ops/second");
     82 }
     83 
     84 static void perf_test(int size)
     85 {
     86     perf_test_t<float>(size);
     87     perf_test_t<double>(size);
     88 }
     89 
     90 TEST(test_performance)
     91 {
     92     for (int size = 16; size <= 65536; size <<= 1)
     93     {
     94         perf_test(size);
     95     }
     96 
     97 #ifndef KFR_DFT_NO_NPo2
     98     perf_test(210);
     99     perf_test(3150);
    100     perf_test(211);
    101     perf_test(3163);
    102 #endif
    103 }
    104 #endif
    105 
    106 TEST(test_convolve)
    107 {
    108     univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
    109     univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
    110     univector<fbase> c = convolve(a, b);
    111     CHECK(c.size() == 9u);
    112     CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 })) < 0.0001);
    113 }
    114 
    115 TEST(test_complex_convolve)
    116 {
    117     univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 });
    118     univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
    119     univector<complex<fbase>> c = convolve(a, b);
    120     CHECK(c.size() == 9u);
    121     CHECK(rms(cabs(c - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 }))) <
    122           0.0001);
    123 }
    124 
    125 TEST(test_convolve_filter)
    126 {
    127     univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
    128     univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
    129     univector<fbase, 5> dest;
    130     convolve_filter<fbase> filter(a);
    131     filter.apply(dest, b);
    132     CHECK(rms(dest - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75 })) < 0.0001);
    133 }
    134 
    135 TEST(test_complex_convolve_filter)
    136 {
    137     univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 });
    138     univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
    139     univector<complex<fbase>, 5> dest;
    140     convolve_filter<complex<fbase>> filter(a);
    141     filter.apply(dest, b);
    142     CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) < 0.0001);
    143     filter.apply(dest, b);
    144     CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) > 0.0001);
    145     filter.reset();
    146     filter.apply(dest, b);
    147     CHECK(rms(cabs(dest - univector<complex<fbase>>({ 0.25, 1., 2.75, 2.5, 3.75 }))) < 0.0001);
    148 }
    149 
    150 TEST(test_correlate)
    151 {
    152     univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
    153     univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
    154     univector<fbase> c = correlate(a, b);
    155     CHECK(c.size() == 9u);
    156     CHECK(rms(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 })) < 0.0001);
    157 }
    158 
    159 TEST(test_complex_correlate)
    160 {
    161     univector<complex<fbase>, 5> a({ 1, 2, 3, 4, 5 });
    162     univector<complex<fbase>, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
    163     univector<complex<fbase>> c = correlate(a, b);
    164     CHECK(c.size() == 9u);
    165     CHECK(rms(cabs(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 }))) < 0.0001);
    166 }
    167 
    168 #if defined CMT_ARCH_ARM || !defined NDEBUG
    169 constexpr size_t fft_stopsize = 12;
    170 #ifndef KFR_DFT_NO_NPo2
    171 constexpr size_t dft_stopsize = 101;
    172 #endif
    173 #else
    174 constexpr size_t fft_stopsize = 21;
    175 #ifndef KFR_DFT_NO_NPo2
    176 constexpr size_t dft_stopsize = 257;
    177 #endif
    178 #endif
    179 
    180 TEST(fft_accuracy)
    181 {
    182 #ifdef DEBUG_DFT_PROGRESS
    183     testo::active_test()->show_progress = true;
    184 #endif
    185     random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
    186     std::set<size_t> size_set;
    187     univector<size_t> sizes = truncate(counter(), fft_stopsize);
    188     sizes                   = round(pow(2.0, sizes));
    189 
    190 #ifndef KFR_DFT_NO_NPo2
    191     univector<size_t> sizes2 = truncate(2 + counter(), dft_stopsize - 2);
    192     for (size_t s : sizes2)
    193     {
    194         if (std::find(sizes.begin(), sizes.end(), s) == sizes.end())
    195             sizes.push_back(s);
    196     }
    197 #endif
    198 #ifdef DEBUG_DFT_PROGRESS
    199     println(sizes);
    200 #endif
    201 
    202     testo::matrix(
    203         named("type") = dft_float_types, //
    204         named("size") = sizes, //
    205         [&gen](auto type, size_t size)
    206         {
    207             using float_type      = type_of<decltype(type)>;
    208             const double min_prec = 0.000001 * std::log(size) * size;
    209 
    210             for (bool inverse : { false, true })
    211             {
    212                 testo::scope s(inverse ? "complex-inverse" : "complex-direct");
    213                 univector<complex<float_type>> in =
    214                     truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
    215                 univector<complex<float_type>> out    = in;
    216                 univector<complex<float_type>> refout = out;
    217                 univector<complex<float_type>> outo   = in;
    218                 const dft_plan<float_type> dft(size);
    219                 double min_prec2 = dft.arblen ? 2 * min_prec : min_prec;
    220                 if (!inverse)
    221                 {
    222 #if DEBUG_DFT_PROGRESS
    223                     dft.dump();
    224 #endif
    225                 }
    226                 univector<u8> temp(dft.temp_size);
    227 
    228                 reference_dft(refout.data(), in.data(), size, inverse);
    229                 dft.execute(outo, in, temp, inverse);
    230                 dft.execute(out, out, temp, inverse);
    231 
    232                 const float_type rms_diff_inplace = rms(cabs(refout - out));
    233                 CHECK(rms_diff_inplace <= min_prec2);
    234                 const float_type rms_diff_outofplace = rms(cabs(refout - outo));
    235                 CHECK(rms_diff_outofplace <= min_prec2);
    236             }
    237 
    238             if (is_even(size))
    239             {
    240                 index_t csize = dft_plan_real<float_type>::complex_size_for(size, dft_pack_format::CCs);
    241                 univector<float_type> in = truncate(gen_random_range<float_type>(gen, -1.0, +1.0), size);
    242 
    243                 univector<complex<float_type>> out    = truncate(dimensions<1>(scalar(qnan)), csize);
    244                 univector<complex<float_type>> refout = truncate(dimensions<1>(scalar(qnan)), csize);
    245                 const dft_plan_real<float_type> dft(size);
    246                 univector<u8> temp(dft.temp_size);
    247 
    248                 {
    249                     testo::scope s("real-direct");
    250                     reference_dft(refout.data(), in.data(), size);
    251                     dft.execute(out, in, temp);
    252                     float_type rms_diff_outofplace = rms(cabs(refout - out));
    253                     CHECK(rms_diff_outofplace <= min_prec);
    254 
    255                     univector<complex<float_type>> outi(csize);
    256                     outi = padded(make_univector(ptr_cast<complex<float_type>>(in.data()), size / 2),
    257                                   complex<float_type>{ 0.f });
    258                     dft.execute(outi.data(), ptr_cast<float_type>(outi.data()), temp.data());
    259                     float_type rms_diff_inplace = rms(cabs(refout - outi.truncate(csize)));
    260                     CHECK(rms_diff_inplace <= min_prec);
    261                 }
    262 
    263                 {
    264                     testo::scope s("real-inverse");
    265                     univector<float_type> out2(size, 0.f);
    266                     dft.execute(out2, out, temp);
    267                     out2                           = out2 / size;
    268                     float_type rms_diff_outofplace = rms(in - out2);
    269                     CHECK(rms_diff_outofplace <= min_prec);
    270 
    271                     univector<float_type> outi(2 * csize);
    272                     outi = make_univector(ptr_cast<float_type>(out.data()), 2 * csize);
    273 
    274                     dft.execute(outi.data(), ptr_cast<complex<float_type>>(outi.data()), temp.data());
    275                     outi                        = outi / size;
    276                     float_type rms_diff_inplace = rms(in - outi.truncate(size));
    277                     CHECK(rms_diff_inplace <= min_prec);
    278                 }
    279             }
    280         });
    281 }
    282 
    283 TEST(dct)
    284 {
    285     constexpr size_t size = 16;
    286     dct_plan<float> plan(size);
    287     univector<float, size> in = counter();
    288     univector<float, size> out;
    289     univector<float, size> outinv;
    290     univector<u8> tmp(plan.temp_size);
    291     plan.execute(out, in, tmp, false);
    292 
    293     univector<float, size> refout = { 120.f, -51.79283109806667f,  0.f, -5.6781471211595695f,
    294                                       0.f,   -1.9843883778092053f, 0.f, -0.9603691873838152f,
    295                                       0.f,   -0.5308329190495176f, 0.f, -0.3030379000702155f,
    296                                       0.f,   -0.1584982220313824f, 0.f, -0.0494839805703826f };
    297 
    298     CHECK(rms(refout - out) < 0.00001f);
    299 
    300     plan.execute(outinv, in, tmp, true);
    301 
    302     univector<float, size> refoutinv = { 59.00747544192212f,  -65.54341437693878f,  27.70332758523579f,
    303                                          -24.56124678824279f, 15.546989102481612f,  -14.293082621965974f,
    304                                          10.08224348063459f,  -9.38097406470581f,   6.795411054455922f,
    305                                          -6.320715753372687f, 4.455202292297903f,   -4.0896421269390455f,
    306                                          2.580439536964837f,  -2.2695816108369176f, 0.9311870090070382f,
    307                                          -0.643618159997807f };
    308 
    309     CHECK(rms(refoutinv - outinv) < 0.00001f);
    310 }
    311 
    312 template <typename T, index_t Dims, typename dft_type, typename dft_real_type>
    313 static void test_dft_md_t(random_state& gen, shape<Dims> shape)
    314 {
    315     index_t size = shape.product();
    316     testo::scope s(as_string("shape=", shape));
    317 
    318     const double min_prec = 0.000002 * std::log(size) * size;
    319 
    320     {
    321         const dft_type dft(shape);
    322 #if DEBUG_DFT_PROGRESS
    323         dft.dump();
    324 #endif
    325         univector<complex<T>> in = truncate(gen_random_range<T>(gen, -1.0, +1.0), size);
    326         for (bool inverse : { false, true })
    327         {
    328             testo::scope s(inverse ? "complex-inverse" : "complex-direct");
    329             univector<complex<T>> out    = in;
    330             univector<complex<T>> refout = out;
    331             univector<complex<T>> outo   = in;
    332             univector<u8> temp(dft.temp_size);
    333 
    334             reference_dft_md(refout.data(), in.data(), shape, inverse);
    335             dft.execute(outo.data(), in.data(), temp.data(), inverse);
    336             dft.execute(out.data(), out.data(), temp.data(), inverse);
    337 
    338             const T rms_diff_inplace = rms(cabs(refout - out));
    339             CHECK(rms_diff_inplace <= min_prec);
    340             const T rms_diff_outofplace = rms(cabs(refout - outo));
    341             CHECK(rms_diff_outofplace <= min_prec);
    342         }
    343     }
    344 
    345     if (is_even(shape.back()))
    346     {
    347         index_t csize   = dft_plan_md_real<float, Dims>::complex_size_for(shape).product();
    348         univector<T> in = truncate(gen_random_range<T>(gen, -1.0, +1.0), size);
    349 
    350         univector<complex<T>> out    = truncate(dimensions<1>(scalar(qnan)), csize);
    351         univector<complex<T>> refout = truncate(dimensions<1>(scalar(qnan)), csize);
    352         const dft_real_type dft(shape, true);
    353 #if DEBUG_DFT_PROGRESS
    354         dft.dump();
    355 #endif
    356         univector<u8> temp(dft.temp_size);
    357 
    358         {
    359             testo::scope s("real-direct");
    360             reference_dft_md(refout.data(), in.data(), shape);
    361             dft.execute(out.data(), in.data(), temp.data());
    362             T rms_diff_outofplace = rms(cabs(refout - out));
    363             CHECK(rms_diff_outofplace <= min_prec);
    364 
    365             univector<complex<T>> outi(csize);
    366             outi = padded(make_univector(ptr_cast<complex<T>>(in.data()), size / 2), complex<T>{ 0.f });
    367             dft.execute(outi.data(), ptr_cast<T>(outi.data()), temp.data());
    368             T rms_diff_inplace = rms(cabs(refout - outi));
    369             CHECK(rms_diff_inplace <= min_prec);
    370         }
    371 
    372         {
    373             testo::scope s("real-inverse");
    374             univector<T> out2(dft.real_out_size(), 0.f);
    375             dft.execute(out2.data(), out.data(), temp.data());
    376             out2                  = out2 / size;
    377             T rms_diff_outofplace = rms(in - out2.truncate(size));
    378             CHECK(rms_diff_outofplace <= min_prec);
    379 
    380             univector<T> outi(2 * csize);
    381             outi = make_univector(ptr_cast<T>(out.data()), 2 * csize);
    382             dft.execute(outi.data(), ptr_cast<complex<T>>(outi.data()), temp.data());
    383             outi               = outi / size;
    384             T rms_diff_inplace = rms(in - outi.truncate(size));
    385             CHECK(rms_diff_inplace <= min_prec);
    386         }
    387     }
    388 }
    389 
    390 template <typename T, index_t Dims>
    391 static void test_dft_md(random_state& gen, shape<Dims> shape)
    392 {
    393     {
    394         testo::scope s("compile-time dims");
    395         test_dft_md_t<T, Dims, dft_plan_md<T, Dims>, dft_plan_md_real<T, Dims>>(gen, shape);
    396     }
    397     {
    398         testo::scope s("runtime dims");
    399         test_dft_md_t<T, Dims, dft_plan_md<T, dynamic_shape>, dft_plan_md_real<T, dynamic_shape>>(gen, shape);
    400     }
    401 }
    402 
    403 TEST(dft_md)
    404 {
    405     random_state gen = random_init(2247448713, 915890490, 864203735, 2982561);
    406 
    407     testo::matrix(named("type") = dft_float_types, //
    408                   [&gen](auto type)
    409                   {
    410                       using float_type = type_of<decltype(type)>;
    411                       test_dft_md<float_type>(gen, shape{ 120 });
    412                       test_dft_md<float_type>(gen, shape{ 2, 60 });
    413                       test_dft_md<float_type>(gen, shape{ 3, 40 });
    414                       test_dft_md<float_type>(gen, shape{ 4, 30 });
    415                       test_dft_md<float_type>(gen, shape{ 5, 24 });
    416                       test_dft_md<float_type>(gen, shape{ 6, 20 });
    417                       test_dft_md<float_type>(gen, shape{ 8, 15 });
    418                       test_dft_md<float_type>(gen, shape{ 10, 12 });
    419                       test_dft_md<float_type>(gen, shape{ 12, 10 });
    420                       test_dft_md<float_type>(gen, shape{ 15, 8 });
    421                       test_dft_md<float_type>(gen, shape{ 20, 6 });
    422                       test_dft_md<float_type>(gen, shape{ 24, 5 });
    423                       test_dft_md<float_type>(gen, shape{ 30, 4 });
    424                       test_dft_md<float_type>(gen, shape{ 40, 3 });
    425                       test_dft_md<float_type>(gen, shape{ 60, 2 });
    426 
    427                       test_dft_md<float_type>(gen, shape{ 2, 3, 24 });
    428                       test_dft_md<float_type>(gen, shape{ 12, 5, 2 });
    429                       test_dft_md<float_type>(gen, shape{ 5, 12, 2 });
    430 
    431                       test_dft_md<float_type>(gen, shape{ 2, 3, 2, 12 });
    432                       test_dft_md<float_type>(gen, shape{ 3, 4, 5, 2 });
    433                       test_dft_md<float_type>(gen, shape{ 5, 4, 3, 2 });
    434 
    435                       test_dft_md<float_type>(gen, shape{ 5, 2, 2, 3, 2 });
    436                       test_dft_md<float_type>(gen, shape{ 2, 5, 2, 2, 3 });
    437 
    438                       test_dft_md<float_type>(gen, shape{ 1, 120 });
    439                       test_dft_md<float_type>(gen, shape{ 120, 1 });
    440                       test_dft_md<float_type>(gen, shape{ 2, 1, 1, 60 });
    441                       test_dft_md<float_type>(gen, shape{ 1, 2, 10, 2, 1, 3 });
    442 
    443                       test_dft_md<float_type>(gen, shape{ 4, 4 });
    444                       test_dft_md<float_type>(gen, shape{ 4, 4, 4 });
    445                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4 });
    446                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4 });
    447                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4 });
    448                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4 });
    449                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4 });
    450                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4, 4 });
    451                       test_dft_md<float_type>(gen, shape{ 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 });
    452 #if defined NDEBUG
    453                       test_dft_md<float_type>(gen, shape{ 512, 512 });
    454                       test_dft_md<float_type>(gen, shape{ 32, 32, 32 });
    455                       test_dft_md<float_type>(gen, shape{ 8, 8, 8, 8 });
    456                       test_dft_md<float_type>(gen, shape{ 2, 2, 2, 2, 2, 2 });
    457 
    458                       test_dft_md<float_type>(gen, shape{ 1, 65536 });
    459                       test_dft_md<float_type>(gen, shape{ 2, 65536 });
    460                       test_dft_md<float_type>(gen, shape{ 3, 65536 });
    461                       test_dft_md<float_type>(gen, shape{ 4, 65536 });
    462                       test_dft_md<float_type>(gen, shape{ 65536, 1 });
    463                       test_dft_md<float_type>(gen, shape{ 65536, 2 });
    464                       test_dft_md<float_type>(gen, shape{ 65536, 3 });
    465                       test_dft_md<float_type>(gen, shape{ 65536, 4 });
    466 
    467                       test_dft_md<float_type>(gen, shape{ 1, 2 });
    468                       test_dft_md<float_type>(gen, shape{ 1, 2, 3 });
    469                       test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4 });
    470                       test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5 });
    471                       test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6 });
    472                       test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6, 7 });
    473                       test_dft_md<float_type>(gen, shape{ 1, 2, 3, 4, 5, 6, 7, 8 });
    474                       test_dft_md<float_type>(gen, shape{ 2, 1 });
    475                       test_dft_md<float_type>(gen, shape{ 3, 2, 1 });
    476                       test_dft_md<float_type>(gen, shape{ 4, 3, 2, 1 });
    477                       test_dft_md<float_type>(gen, shape{ 5, 4, 3, 2, 1 });
    478                       test_dft_md<float_type>(gen, shape{ 6, 5, 4, 3, 2, 1 });
    479                       test_dft_md<float_type>(gen, shape{ 7, 6, 5, 4, 3, 2, 1 });
    480                       test_dft_md<float_type>(gen, shape{ 8, 7, 6, 5, 4, 3, 2, 1 });
    481 #endif
    482                   });
    483 }
    484 
    485 } // namespace CMT_ARCH_NAME
    486 
    487 #ifndef KFR_NO_MAIN
    488 int main()
    489 {
    490     println(library_version(), " running on ", cpu_runtime());
    491 
    492     return testo::run_all("", false);
    493 }
    494 #endif