kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

dft-impl.hpp (18591B)


      1 /** @addtogroup dft
      2  *  @{
      3  */
      4 /*
      5   Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com)
      6   This file is part of KFR
      7 
      8   KFR is free software: you can redistribute it and/or modify
      9   it under the terms of the GNU General Public License as published by
     10   the Free Software Foundation, either version 2 of the License, or
     11   (at your option) any later version.
     12 
     13   KFR is distributed in the hope that it will be useful,
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     16   GNU General Public License for more details.
     17 
     18   You should have received a copy of the GNU General Public License
     19   along with KFR.
     20 
     21   If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
     22   Buying a commercial license is mandatory as soon as you develop commercial activities without
     23   disclosing the source code of your own applications.
     24   See https://www.kfrlib.com for details.
     25  */
     26 #pragma once
     27 
     28 #include <kfr/base/math_expressions.hpp>
     29 #include <kfr/base/simd_expressions.hpp>
     30 #include "dft-fft.hpp"
     31 
     32 CMT_PRAGMA_GNU(GCC diagnostic push)
     33 #if CMT_HAS_WARNING("-Wshadow")
     34 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow")
     35 #endif
     36 #if CMT_HAS_WARNING("-Wunused-lambda-capture")
     37 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wunused-lambda-capture")
     38 #endif
     39 #if CMT_HAS_WARNING("-Wpass-failed")
     40 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wpass-failed")
     41 #endif
     42 
     43 CMT_PRAGMA_MSVC(warning(push))
     44 CMT_PRAGMA_MSVC(warning(disable : 4100))
     45 
     46 namespace kfr
     47 {
     48 
     49 inline namespace CMT_ARCH_NAME
     50 {
     51 constexpr csizes_t<2, 3, 4, 5, 6, 7, 8, 9, 10> dft_radices{};
     52 
     53 namespace intrinsics
     54 {
     55 
     56 template <typename T>
     57 void dft_stage_fixed_initialize(dft_stage<T>* stage, size_t width)
     58 {
     59     complex<T>* twiddle = ptr_cast<complex<T>>(stage->data);
     60     const size_t N      = stage->repeats * stage->radix;
     61     const size_t Nord   = stage->repeats;
     62     size_t i            = 0;
     63 
     64     while (width > 0)
     65     {
     66         CMT_LOOP_NOUNROLL
     67         for (; i < Nord / width * width; i += width)
     68         {
     69             CMT_LOOP_NOUNROLL
     70             for (size_t j = 1; j < stage->radix; j++)
     71             {
     72                 CMT_LOOP_NOUNROLL
     73                 for (size_t k = 0; k < width; k++)
     74                 {
     75                     cvec<T, 1> xx = cossin_conj(broadcast<2, T>(c_pi<T, 2> * (i + k) * j / N));
     76                     ref_cast<cvec<T, 1>>(twiddle[k]) = xx;
     77                 }
     78                 twiddle += width;
     79             }
     80         }
     81         width = width / 2;
     82     }
     83 }
     84 
     85 template <typename T, size_t fixed_radix>
     86 struct dft_stage_fixed_impl : dft_stage<T>
     87 {
     88     dft_stage_fixed_impl(size_t, size_t iterations, size_t blocks)
     89     {
     90         this->name       = dft_name(this);
     91         this->radix      = fixed_radix;
     92         this->blocks     = blocks;
     93         this->repeats    = iterations;
     94         this->recursion  = false; // true;
     95         this->stage_size = fixed_radix * iterations * blocks;
     96         this->data_size  = align_up((this->repeats * (fixed_radix - 1)) * sizeof(complex<T>),
     97                                     platform<>::native_cache_alignment);
     98     }
     99 
    100     constexpr static size_t rradix = fixed_radix;
    101 
    102     constexpr static size_t width = fixed_radix >= 7   ? fft_vector_width<T> / 2
    103                                     : fixed_radix >= 4 ? fft_vector_width<T>
    104                                                        : fft_vector_width<T> * 2;
    105     virtual void do_initialize(size_t) override final { dft_stage_fixed_initialize(this, width); }
    106 
    107     DFT_STAGE_FN
    108     template <bool inverse>
    109     KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
    110     {
    111         const size_t Nord         = this->repeats;
    112         const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
    113 
    114         const size_t N = Nord * fixed_radix;
    115         CMT_LOOP_NOUNROLL
    116         for (size_t b = 0; b < this->blocks; b++)
    117         {
    118             butterflies(Nord, csize<width>, csize<fixed_radix>, cbool<inverse>, out, in, twiddle, Nord);
    119             in += N;
    120             out += N;
    121         }
    122     }
    123 };
    124 
    125 template <typename T, size_t fixed_radix>
    126 struct dft_stage_fixed_final_impl : dft_stage<T>
    127 {
    128     dft_stage_fixed_final_impl(size_t, size_t iterations, size_t blocks)
    129     {
    130         this->name        = dft_name(this);
    131         this->radix       = fixed_radix;
    132         this->blocks      = blocks;
    133         this->repeats     = iterations;
    134         this->stage_size  = fixed_radix * iterations * blocks;
    135         this->recursion   = false;
    136         this->can_inplace = false;
    137     }
    138     constexpr static size_t width = fixed_radix >= 7   ? fft_vector_width<T> / 2
    139                                     : fixed_radix >= 4 ? fft_vector_width<T>
    140                                                        : fft_vector_width<T> * 2;
    141 
    142     DFT_STAGE_FN
    143     template <bool inverse>
    144     KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
    145     {
    146         const size_t b = this->blocks;
    147 
    148         butterflies(b, csize<width>, csize<fixed_radix>, cbool<inverse>, out, in, b);
    149     }
    150 };
    151 
    152 template <typename E>
    153 inline E& apply_conj(E& e, cfalse_t)
    154 {
    155     return e;
    156 }
    157 
    158 template <typename E>
    159 inline auto apply_conj(E& e, ctrue_t)
    160 {
    161     return cconj(e);
    162 }
    163 
    164 /// [0, N - 1, N - 2, N - 3, ..., 3, 2, 1]
    165 template <typename E>
    166 struct fft_inverse : expression_with_traits<E>
    167 {
    168     using value_type = typename expression_with_traits<E>::value_type;
    169 
    170     KFR_MEM_INTRINSIC fft_inverse(E&& expr) CMT_NOEXCEPT : expression_with_traits<E>(std::forward<E>(expr)) {}
    171 
    172     friend KFR_INTRINSIC vec<value_type, 1> get_elements(const fft_inverse& self, shape<1> index,
    173                                                          axis_params<0, 1>)
    174     {
    175         const size_t size = get_shape(self).front();
    176         return get_elements(self.first(), index.front() == 0 ? 0 : size - index, axis_params<0, 1>());
    177     }
    178 
    179     template <size_t N>
    180     friend KFR_MEM_INTRINSIC vec<value_type, N> get_elements(const fft_inverse& self, shape<1> index,
    181                                                              axis_params<0, N>)
    182     {
    183         const size_t size = get_shape(self).front();
    184         if (index.front() == 0)
    185         {
    186             return concat(get_elements(self.first(), index, axis_params<0, 1>()),
    187                           reverse(get_elements(self.first(), size - (N - 1), axis_params<0, N - 1>())));
    188         }
    189         return reverse(get_elements(self.first(), size - index - (N - 1), axis_params<0, N>()));
    190     }
    191 };
    192 
    193 template <typename E>
    194 inline auto apply_fft_inverse(E&& e)
    195 {
    196     return fft_inverse<E>(std::forward<E>(e));
    197 }
    198 
    199 template <typename T>
    200 struct dft_arblen_stage_impl : dft_stage<T>
    201 {
    202     dft_arblen_stage_impl(size_t size)
    203         : size(size), fftsize(next_poweroftwo(size) * 2), plan(fftsize, dft_order::internal)
    204     {
    205         this->name        = dft_name(this);
    206         this->radix       = size;
    207         this->blocks      = 1;
    208         this->repeats     = 1;
    209         this->recursion   = false;
    210         this->can_inplace = false;
    211         this->temp_size   = plan.temp_size;
    212         this->stage_size  = size;
    213 
    214         chirp_ = render(cexp(sqr(linspace(T(1) - size, size - T(1), size * 2 - 1, true, ctrue)) *
    215                              complex<T>(0, -1) * c_pi<T> / size));
    216 
    217         ichirpp_ = render(truncate(padded(1 / slice(chirp_, 0, 2 * size - 1)), fftsize));
    218 
    219         univector<u8> temp(plan.temp_size);
    220         plan.execute(ichirpp_, ichirpp_, temp);
    221         xp.resize(fftsize, 0);
    222         xp_fft.resize(fftsize);
    223         invN2 = T(1) / fftsize;
    224     }
    225 
    226     DFT_STAGE_FN
    227     template <bool inverse>
    228     KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8* temp)
    229     {
    230         const size_t n = this->size;
    231 
    232         auto&& chirp = apply_conj(chirp_, cbool<inverse>);
    233 
    234         xp.slice(0, n) = make_univector(in, n) * slice(chirp, n - 1);
    235 
    236         plan.execute(xp_fft.data(), xp.data(), temp);
    237 
    238         if (inverse)
    239             xp_fft = xp_fft * cconj(apply_fft_inverse(ichirpp_));
    240         else
    241             xp_fft = xp_fft * ichirpp_;
    242         plan.execute(xp_fft.data(), xp_fft.data(), temp, ctrue);
    243 
    244         make_univector(out, n) = xp_fft.slice(n - 1, n) * slice(chirp, n - 1, n) * invN2;
    245     }
    246 
    247     const size_t size;
    248     const size_t fftsize;
    249     T invN2;
    250     dft_plan<T> plan;
    251     univector<complex<T>> chirp_;
    252     univector<complex<T>> ichirpp_;
    253     univector<complex<T>> xp;
    254     univector<complex<T>> xp_fft;
    255 };
    256 
    257 template <typename T, size_t radix1, size_t radix2, size_t size = radix1 * radix2>
    258 struct dft_special_stage_impl : dft_stage<T>
    259 {
    260     dft_special_stage_impl() : stage1(radix1, size / radix1, 1), stage2(radix2, 1, size / radix2)
    261     {
    262         this->name        = dft_name(this);
    263         this->radix       = size;
    264         this->blocks      = 1;
    265         this->repeats     = 1;
    266         this->recursion   = false;
    267         this->can_inplace = false;
    268         this->stage_size  = size;
    269         this->temp_size   = stage1.temp_size + stage2.temp_size + sizeof(complex<T>) * size;
    270         this->data_size   = stage1.data_size + stage2.data_size;
    271     }
    272     void dump() const override
    273     {
    274         dft_stage<T>::dump();
    275         printf("    ");
    276         stage1.dump();
    277         printf("    ");
    278         stage2.dump();
    279     }
    280     void do_initialize(size_t stage_size) override
    281     {
    282         stage1.data = this->data;
    283         stage2.data = this->data + stage1.data_size;
    284         stage1.initialize(stage_size);
    285         stage2.initialize(stage_size);
    286     }
    287     DFT_STAGE_FN
    288     template <bool inverse>
    289     KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8* temp)
    290     {
    291         complex<T>* scratch = ptr_cast<complex<T>>(temp + stage1.temp_size + stage2.temp_size);
    292         stage1.do_execute(cbool<inverse>, scratch, in, temp);
    293         stage2.do_execute(cbool<inverse>, out, scratch, temp + stage1.temp_size);
    294     }
    295     dft_stage_fixed_impl<T, radix1> stage1;
    296     dft_stage_fixed_final_impl<T, radix2> stage2;
    297 };
    298 
    299 template <typename T, bool final>
    300 struct dft_stage_generic_impl : dft_stage<T>
    301 {
    302     dft_stage_generic_impl(size_t radix, size_t iterations, size_t blocks)
    303     {
    304         this->name        = dft_name(this);
    305         this->radix       = radix;
    306         this->blocks      = blocks;
    307         this->repeats     = iterations;
    308         this->recursion   = false; // true;
    309         this->can_inplace = false;
    310         this->stage_size  = radix * iterations * blocks;
    311         this->temp_size   = align_up(sizeof(complex<T>) * radix, platform<>::native_cache_alignment);
    312         this->data_size =
    313             align_up(sizeof(complex<T>) * sqr(this->radix / 2), platform<>::native_cache_alignment);
    314     }
    315 
    316 protected:
    317     virtual void do_initialize(size_t) override final
    318     {
    319         complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
    320         CMT_LOOP_NOUNROLL
    321         for (size_t i = 0; i < this->radix / 2; i++)
    322         {
    323             CMT_LOOP_NOUNROLL
    324             for (size_t j = 0; j < this->radix / 2; j++)
    325             {
    326                 cwrite<1>(twiddle++, cossin_conj(broadcast<2>((i + 1) * (j + 1) * c_pi<T, 2> / this->radix)));
    327             }
    328         }
    329     }
    330 
    331     DFT_STAGE_FN
    332     template <bool inverse>
    333     KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8* temp)
    334     {
    335         const complex<T>* twiddle = ptr_cast<complex<T>>(this->data);
    336         const size_t bl           = this->blocks;
    337 
    338         CMT_LOOP_NOUNROLL
    339         for (size_t b = 0; b < bl; b++)
    340             generic_butterfly(this->radix, cbool<inverse>, out + b, in + b * this->radix,
    341                               ptr_cast<complex<T>>(temp), twiddle, bl);
    342     }
    343 };
    344 
    345 template <typename T, typename Tr2>
    346 inline void dft_permute(complex<T>* out, const complex<T>* in, size_t r0, size_t r1, Tr2 first_radix)
    347 {
    348     CMT_ASSUME(r0 > 1);
    349     CMT_ASSUME(r1 > 1);
    350 
    351     CMT_LOOP_NOUNROLL
    352     for (size_t p = 0; p < r0; p++)
    353     {
    354         const complex<T>* in1 = in;
    355         CMT_LOOP_NOUNROLL
    356         for (size_t i = 0; i < r1; i++)
    357         {
    358             const complex<T>* in2 = in1;
    359             CMT_LOOP_UNROLL
    360             for (size_t j = 0; j < first_radix; j++)
    361             {
    362                 *out++ = *in2;
    363                 in2 += r1;
    364             }
    365             in1++;
    366             in += first_radix;
    367         }
    368     }
    369 }
    370 
    371 template <typename T, typename Tr2>
    372 inline void dft_permute_deep(complex<T>*& out, const complex<T>* in, const size_t* radices, size_t count,
    373                              size_t index, size_t inscale, size_t inner_size, Tr2 first_radix)
    374 {
    375     const bool b       = index == 1;
    376     const size_t radix = radices[index];
    377     if (b)
    378     {
    379         CMT_LOOP_NOUNROLL
    380         for (size_t i = 0; i < radix; i++)
    381         {
    382             const complex<T>* in1 = in;
    383             CMT_LOOP_UNROLL
    384             for (size_t j = 0; j < first_radix; j++)
    385             {
    386                 *out++ = *in1;
    387                 in1 += inner_size;
    388             }
    389             in += inscale;
    390         }
    391     }
    392     else
    393     {
    394         const size_t steps        = radix;
    395         const size_t inscale_next = inscale * radix;
    396         CMT_LOOP_NOUNROLL
    397         for (size_t i = 0; i < steps; i++)
    398         {
    399             dft_permute_deep(out, in, radices, count, index - 1, inscale_next, inner_size, first_radix);
    400             in += inscale;
    401         }
    402     }
    403 }
    404 
    405 template <typename T>
    406 struct dft_reorder_stage_impl : dft_stage<T>
    407 {
    408     dft_reorder_stage_impl(const int* radices, size_t count) : count(count)
    409     {
    410         this->name        = dft_name(this);
    411         this->can_inplace = false;
    412         this->data_size   = 0;
    413         std::copy(radices, radices + count, this->radices);
    414         this->inner_size = 1;
    415         this->size       = 1;
    416         for (size_t r = 0; r < count; r++)
    417         {
    418             if (r != 0 && r != count - 1)
    419                 this->inner_size *= radices[r];
    420             this->size *= radices[r];
    421         }
    422         this->stage_size = this->size;
    423     }
    424 
    425 protected:
    426     size_t radices[32];
    427     size_t count      = 0;
    428     size_t size       = 0;
    429     size_t inner_size = 0;
    430     virtual void do_initialize(size_t) override final {}
    431 
    432     DFT_STAGE_FN
    433     template <bool inverse>
    434     KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*)
    435     {
    436         cswitch(
    437             dft_radices, radices[0],
    438             [&](auto first_radix)
    439             {
    440                 if (count == 3)
    441                 {
    442                     dft_permute(out, in, radices[2], radices[1], first_radix);
    443                 }
    444                 else
    445                 {
    446                     const size_t rlast = radices[count - 1];
    447                     for (size_t p = 0; p < rlast; p++)
    448                     {
    449                         dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, first_radix);
    450                         in += size / rlast;
    451                     }
    452                 }
    453             },
    454             [&]()
    455             {
    456                 if (count == 3)
    457                 {
    458                     dft_permute(out, in, radices[2], radices[1], radices[0]);
    459                 }
    460                 else
    461                 {
    462                     const size_t rlast = radices[count - 1];
    463                     for (size_t p = 0; p < rlast; p++)
    464                     {
    465                         dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, radices[0]);
    466                         in += size / rlast;
    467                     }
    468                 }
    469             });
    470     }
    471 };
    472 } // namespace intrinsics
    473 
    474 template <bool is_final, typename T>
    475 void prepare_dft_stage(dft_plan<T>* self, size_t radix, size_t iterations, size_t blocks, cbool_t<is_final>)
    476 {
    477     return cswitch(
    478         dft_radices, radix,
    479         [self, iterations, blocks](auto radix) CMT_INLINE_LAMBDA
    480         {
    481             add_stage<std::conditional_t<is_final, intrinsics::dft_stage_fixed_final_impl<T, val_of(radix)>,
    482                                          intrinsics::dft_stage_fixed_impl<T, val_of(radix)>>>(
    483                 self, radix, iterations, blocks);
    484         },
    485         [self, radix, iterations, blocks]()
    486         { add_stage<intrinsics::dft_stage_generic_impl<T, is_final>>(self, radix, iterations, blocks); });
    487 }
    488 
    489 template <typename T>
    490 void init_dft(dft_plan<T>* self, size_t size, dft_order)
    491 {
    492     if (size == 60)
    493     {
    494         add_stage<intrinsics::dft_special_stage_impl<T, 6, 10>>(self);
    495     }
    496     else if (size == 48)
    497     {
    498         add_stage<intrinsics::dft_special_stage_impl<T, 6, 8>>(self);
    499     }
    500     else
    501     {
    502         size_t cur_size                = size;
    503         constexpr size_t radices_count = dft_radices.back() + 1;
    504         u8 count[radices_count]        = { 0 };
    505         int radices[32]                = { 0 };
    506         size_t radices_size            = 0;
    507 
    508         cforeach(dft_radices[csizeseq<dft_radices.size(), dft_radices.size() - 1, -1>],
    509                  [&](auto radix)
    510                  {
    511                      while (cur_size && cur_size % val_of(radix) == 0)
    512                      {
    513                          count[val_of(radix)]++;
    514                          cur_size /= val_of(radix);
    515                      }
    516                  });
    517 
    518         int num_stages = 0;
    519         if (cur_size >= 101)
    520         {
    521             add_stage<intrinsics::dft_arblen_stage_impl<T>>(self, size);
    522             ++num_stages;
    523             self->arblen = true;
    524         }
    525         else
    526         {
    527             size_t blocks     = 1;
    528             size_t iterations = size;
    529 
    530             for (size_t r = dft_radices.front(); r <= dft_radices.back(); r++)
    531             {
    532                 for (size_t i = 0; i < count[r]; i++)
    533                 {
    534                     iterations /= r;
    535                     radices[radices_size++] = static_cast<int>(r);
    536                     if (iterations == 1)
    537                         prepare_dft_stage(self, r, iterations, blocks, ctrue);
    538                     else
    539                         prepare_dft_stage(self, r, iterations, blocks, cfalse);
    540                     ++num_stages;
    541                     blocks *= r;
    542                 }
    543             }
    544 
    545             if (cur_size > 1)
    546             {
    547                 iterations /= cur_size;
    548                 radices[radices_size++] = static_cast<int>(cur_size);
    549                 if (iterations == 1)
    550                     prepare_dft_stage(self, cur_size, iterations, blocks, ctrue);
    551                 else
    552                     prepare_dft_stage(self, cur_size, iterations, blocks, cfalse);
    553                 ++num_stages;
    554             }
    555 
    556             if (num_stages > 2)
    557                 add_stage<intrinsics::dft_reorder_stage_impl<T>>(self, radices, radices_size);
    558         }
    559     }
    560 }
    561 
    562 } // namespace CMT_ARCH_NAME
    563 
    564 } // namespace kfr
    565 
    566 CMT_PRAGMA_GNU(GCC diagnostic pop)
    567 
    568 CMT_PRAGMA_MSVC(warning(pop))