kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

reference_dft.hpp (7036B)


      1 /** @addtogroup dft
      2  *  @{
      3  */
      4 /*
      5   Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com)
      6   This file is part of KFR
      7 
      8   KFR is free software: you can redistribute it and/or modify
      9   it under the terms of the GNU General Public License as published by
     10   the Free Software Foundation, either version 2 of the License, or
     11   (at your option) any later version.
     12 
     13   KFR is distributed in the hope that it will be useful,
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     16   GNU General Public License for more details.
     17 
     18   You should have received a copy of the GNU General Public License
     19   along with KFR.
     20 
     21   If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
     22   Buying a commercial license is mandatory as soon as you develop commercial activities without
     23   disclosing the source code of your own applications.
     24   See https://www.kfrlib.com for details.
     25  */
     26 #pragma once
     27 
     28 #include "../base/memory.hpp"
     29 #include "../base/univector.hpp"
     30 #include "../simd/complex.hpp"
     31 #include "../simd/constants.hpp"
     32 #include "../simd/read_write.hpp"
     33 #include "../simd/vec.hpp"
     34 #include <cmath>
     35 #include <vector>
     36 
     37 namespace kfr
     38 {
     39 
     40 namespace internal_generic
     41 {
     42 
     43 template <typename T>
     44 void reference_dft_po2_pass(size_t N, int flag, const complex<T>* in, complex<T>* out, complex<T>* scratch,
     45                             size_t in_delta = 1, size_t out_delta = 1, size_t scratch_delta = 1)
     46 {
     47     const T pi2        = c_pi<T, 2, 1>;
     48     const size_t N2    = N / 2;
     49     const complex<T> w = pi2 * complex<T>{ 0, -T(flag) };
     50 
     51     if (N != 2)
     52     {
     53         reference_dft_po2_pass(N2, flag, in, scratch, out, 2 * in_delta, 2 * scratch_delta, 2 * out_delta);
     54         reference_dft_po2_pass(N2, flag, in + in_delta, scratch + scratch_delta, out + out_delta,
     55                                2 * in_delta, 2 * scratch_delta, 2 * out_delta);
     56 
     57         for (size_t k = 0; k < N2; k++)
     58         {
     59             const T m                 = static_cast<T>(k) / N;
     60             const complex<T> tw       = std::exp(w * m);
     61             const complex<T> tmp      = scratch[(2 * k + 1) * scratch_delta] * tw;
     62             out[(k + N2) * out_delta] = scratch[(2 * k) * scratch_delta] - tmp;
     63             out[(k)*out_delta]        = scratch[(2 * k) * scratch_delta] + tmp;
     64         }
     65     }
     66     else
     67     {
     68         out[out_delta] = in[0] - in[in_delta];
     69         out[0]         = in[0] + in[in_delta];
     70     }
     71 }
     72 
     73 template <typename T>
     74 void reference_dft_po2(complex<T>* out, const complex<T>* in, size_t size, bool inversion,
     75                        size_t out_delta = 1, size_t in_delta = 1)
     76 {
     77     if (size < 1)
     78         return;
     79     if (size == 1)
     80     {
     81         out[0] = in[0];
     82         return;
     83     }
     84     std::vector<complex<T>> temp(size);
     85     reference_dft_po2_pass(size, inversion ? -1 : +1, in, out, temp.data(), in_delta, out_delta, 1);
     86 }
     87 
     88 /// @brief Performs Complex FFT using reference implementation (slow, used for testing)
     89 template <typename T>
     90 void reference_dft_nonpo2(complex<T>* out, const complex<T>* in, size_t size, bool inversion,
     91                           size_t out_delta = 1, size_t in_delta = 1)
     92 {
     93     constexpr T pi2    = c_pi<T, 2>;
     94     const complex<T> w = pi2 * complex<T>{ 0, T(inversion ? +1 : -1) };
     95     if (size < 2)
     96         return;
     97     {
     98         complex<T> sum = 0;
     99         for (size_t j = 0; j < size; j++)
    100             sum += in[j * in_delta];
    101         out[0] = sum;
    102     }
    103     for (size_t i = 1; i < size; i++)
    104     {
    105         complex<T> sum = in[0];
    106         for (size_t j = 1; j < size; j++)
    107         {
    108             complex<T> tw = std::exp(w * (static_cast<T>(i) * j / size));
    109             sum += tw * in[j * in_delta];
    110         }
    111         out[i * out_delta] = sum;
    112     }
    113 }
    114 } // namespace internal_generic
    115 
    116 /// @brief Performs Complex DFT using reference implementation (slow, used for testing)
    117 template <typename T>
    118 void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false,
    119                    size_t out_delta = 1, size_t in_delta = 1)
    120 {
    121     if (in == out)
    122     {
    123         std::vector<complex<T>> tmpin(size);
    124         for (int i = 0; i < size; ++i)
    125             tmpin[i] = in[i * in_delta];
    126         return reference_dft(out, tmpin.data(), size, inversion, out_delta, 1);
    127     }
    128     if (is_poweroftwo(size))
    129     {
    130         return internal_generic::reference_dft_po2(out, in, size, inversion, out_delta, in_delta);
    131     }
    132     else
    133     {
    134         return internal_generic::reference_dft_nonpo2(out, in, size, inversion, out_delta, in_delta);
    135     }
    136 }
    137 
    138 /// @brief Performs Direct Real DFT using reference implementation (slow, used for testing)
    139 template <typename T>
    140 void reference_dft(complex<T>* out, const T* in, size_t size, size_t out_delta = 1, size_t in_delta = 1)
    141 {
    142     if (size < 1)
    143         return;
    144     std::vector<complex<T>> tmpin(size);
    145     for (index_t i = 0; i < size; ++i)
    146         tmpin[i] = in[i * in_delta];
    147     std::vector<complex<T>> tmpout(size);
    148     reference_dft(tmpout.data(), tmpin.data(), size, false, 1, 1);
    149     for (index_t i = 0; i < size / 2 + 1; i++)
    150         out[i * out_delta] = tmpout[i];
    151 }
    152 
    153 /// @brief Performs Multidimensional Complex DFT using reference implementation (slow, used for testing)
    154 template <typename T>
    155 void reference_dft_md(complex<T>* out, const complex<T>* in, shape<dynamic_shape> size,
    156                       bool inversion = false, size_t out_delta = 1, size_t in_delta = 1)
    157 {
    158     index_t total = size.product();
    159     if (total < 1)
    160         return;
    161     if (total == 1)
    162     {
    163         out[0] = in[0];
    164         return;
    165     }
    166     index_t inner = 1;
    167     index_t outer = total;
    168     for (int axis = size.dims() - 1; axis >= 0; --axis)
    169     {
    170         index_t d = size[axis];
    171         outer /= d;
    172         for (index_t o = 0; o < outer; ++o)
    173         {
    174             for (index_t i = 0; i < inner; ++i)
    175             {
    176                 reference_dft(out + (i + o * inner * d) * out_delta, in + (i + o * inner * d) * in_delta, d,
    177                               inversion, out_delta * inner, in_delta * inner);
    178             }
    179         }
    180         in       = out;
    181         in_delta = out_delta;
    182         inner *= d;
    183     }
    184 }
    185 
    186 /// @brief Performs Multidimensional Direct Real DFT using reference implementation (slow, used for testing)
    187 template <typename T>
    188 void reference_dft_md(complex<T>* out, const T* in, shape<dynamic_shape> shape, bool inversion = false,
    189                       size_t out_delta = 1, size_t in_delta = 1)
    190 {
    191     index_t size = shape.product();
    192     if (size < 1)
    193         return;
    194     std::vector<complex<T>> tmpin(size);
    195     for (index_t i = 0; i < size; ++i)
    196         tmpin[i] = in[i * in_delta];
    197     std::vector<complex<T>> tmpout(size);
    198     reference_dft_md(tmpout.data(), tmpin.data(), shape, inversion, 1, 1);
    199     index_t last = shape.back() / 2 + 1;
    200     for (index_t i = 0; i < std::max(index_t(1), shape.remove_back().product()); ++i)
    201         for (index_t j = 0; j < last; j++)
    202             out[(i * last + j) * out_delta] = tmpout[i * shape.back() + j];
    203 }
    204 
    205 } // namespace kfr