reference_dft.hpp (7036B)
1 /** @addtogroup dft 2 * @{ 3 */ 4 /* 5 Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com) 6 This file is part of KFR 7 8 KFR is free software: you can redistribute it and/or modify 9 it under the terms of the GNU General Public License as published by 10 the Free Software Foundation, either version 2 of the License, or 11 (at your option) any later version. 12 13 KFR is distributed in the hope that it will be useful, 14 but WITHOUT ANY WARRANTY; without even the implied warranty of 15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 GNU General Public License for more details. 17 18 You should have received a copy of the GNU General Public License 19 along with KFR. 20 21 If GPL is not suitable for your project, you must purchase a commercial license to use KFR. 22 Buying a commercial license is mandatory as soon as you develop commercial activities without 23 disclosing the source code of your own applications. 24 See https://www.kfrlib.com for details. 25 */ 26 #pragma once 27 28 #include "../base/memory.hpp" 29 #include "../base/univector.hpp" 30 #include "../simd/complex.hpp" 31 #include "../simd/constants.hpp" 32 #include "../simd/read_write.hpp" 33 #include "../simd/vec.hpp" 34 #include <cmath> 35 #include <vector> 36 37 namespace kfr 38 { 39 40 namespace internal_generic 41 { 42 43 template <typename T> 44 void reference_dft_po2_pass(size_t N, int flag, const complex<T>* in, complex<T>* out, complex<T>* scratch, 45 size_t in_delta = 1, size_t out_delta = 1, size_t scratch_delta = 1) 46 { 47 const T pi2 = c_pi<T, 2, 1>; 48 const size_t N2 = N / 2; 49 const complex<T> w = pi2 * complex<T>{ 0, -T(flag) }; 50 51 if (N != 2) 52 { 53 reference_dft_po2_pass(N2, flag, in, scratch, out, 2 * in_delta, 2 * scratch_delta, 2 * out_delta); 54 reference_dft_po2_pass(N2, flag, in + in_delta, scratch + scratch_delta, out + out_delta, 55 2 * in_delta, 2 * scratch_delta, 2 * out_delta); 56 57 for (size_t k = 0; k < N2; k++) 58 { 59 const T m = static_cast<T>(k) / N; 60 const complex<T> tw = std::exp(w * m); 61 const complex<T> tmp = scratch[(2 * k + 1) * scratch_delta] * tw; 62 out[(k + N2) * out_delta] = scratch[(2 * k) * scratch_delta] - tmp; 63 out[(k)*out_delta] = scratch[(2 * k) * scratch_delta] + tmp; 64 } 65 } 66 else 67 { 68 out[out_delta] = in[0] - in[in_delta]; 69 out[0] = in[0] + in[in_delta]; 70 } 71 } 72 73 template <typename T> 74 void reference_dft_po2(complex<T>* out, const complex<T>* in, size_t size, bool inversion, 75 size_t out_delta = 1, size_t in_delta = 1) 76 { 77 if (size < 1) 78 return; 79 if (size == 1) 80 { 81 out[0] = in[0]; 82 return; 83 } 84 std::vector<complex<T>> temp(size); 85 reference_dft_po2_pass(size, inversion ? -1 : +1, in, out, temp.data(), in_delta, out_delta, 1); 86 } 87 88 /// @brief Performs Complex FFT using reference implementation (slow, used for testing) 89 template <typename T> 90 void reference_dft_nonpo2(complex<T>* out, const complex<T>* in, size_t size, bool inversion, 91 size_t out_delta = 1, size_t in_delta = 1) 92 { 93 constexpr T pi2 = c_pi<T, 2>; 94 const complex<T> w = pi2 * complex<T>{ 0, T(inversion ? +1 : -1) }; 95 if (size < 2) 96 return; 97 { 98 complex<T> sum = 0; 99 for (size_t j = 0; j < size; j++) 100 sum += in[j * in_delta]; 101 out[0] = sum; 102 } 103 for (size_t i = 1; i < size; i++) 104 { 105 complex<T> sum = in[0]; 106 for (size_t j = 1; j < size; j++) 107 { 108 complex<T> tw = std::exp(w * (static_cast<T>(i) * j / size)); 109 sum += tw * in[j * in_delta]; 110 } 111 out[i * out_delta] = sum; 112 } 113 } 114 } // namespace internal_generic 115 116 /// @brief Performs Complex DFT using reference implementation (slow, used for testing) 117 template <typename T> 118 void reference_dft(complex<T>* out, const complex<T>* in, size_t size, bool inversion = false, 119 size_t out_delta = 1, size_t in_delta = 1) 120 { 121 if (in == out) 122 { 123 std::vector<complex<T>> tmpin(size); 124 for (int i = 0; i < size; ++i) 125 tmpin[i] = in[i * in_delta]; 126 return reference_dft(out, tmpin.data(), size, inversion, out_delta, 1); 127 } 128 if (is_poweroftwo(size)) 129 { 130 return internal_generic::reference_dft_po2(out, in, size, inversion, out_delta, in_delta); 131 } 132 else 133 { 134 return internal_generic::reference_dft_nonpo2(out, in, size, inversion, out_delta, in_delta); 135 } 136 } 137 138 /// @brief Performs Direct Real DFT using reference implementation (slow, used for testing) 139 template <typename T> 140 void reference_dft(complex<T>* out, const T* in, size_t size, size_t out_delta = 1, size_t in_delta = 1) 141 { 142 if (size < 1) 143 return; 144 std::vector<complex<T>> tmpin(size); 145 for (index_t i = 0; i < size; ++i) 146 tmpin[i] = in[i * in_delta]; 147 std::vector<complex<T>> tmpout(size); 148 reference_dft(tmpout.data(), tmpin.data(), size, false, 1, 1); 149 for (index_t i = 0; i < size / 2 + 1; i++) 150 out[i * out_delta] = tmpout[i]; 151 } 152 153 /// @brief Performs Multidimensional Complex DFT using reference implementation (slow, used for testing) 154 template <typename T> 155 void reference_dft_md(complex<T>* out, const complex<T>* in, shape<dynamic_shape> size, 156 bool inversion = false, size_t out_delta = 1, size_t in_delta = 1) 157 { 158 index_t total = size.product(); 159 if (total < 1) 160 return; 161 if (total == 1) 162 { 163 out[0] = in[0]; 164 return; 165 } 166 index_t inner = 1; 167 index_t outer = total; 168 for (int axis = size.dims() - 1; axis >= 0; --axis) 169 { 170 index_t d = size[axis]; 171 outer /= d; 172 for (index_t o = 0; o < outer; ++o) 173 { 174 for (index_t i = 0; i < inner; ++i) 175 { 176 reference_dft(out + (i + o * inner * d) * out_delta, in + (i + o * inner * d) * in_delta, d, 177 inversion, out_delta * inner, in_delta * inner); 178 } 179 } 180 in = out; 181 in_delta = out_delta; 182 inner *= d; 183 } 184 } 185 186 /// @brief Performs Multidimensional Direct Real DFT using reference implementation (slow, used for testing) 187 template <typename T> 188 void reference_dft_md(complex<T>* out, const T* in, shape<dynamic_shape> shape, bool inversion = false, 189 size_t out_delta = 1, size_t in_delta = 1) 190 { 191 index_t size = shape.product(); 192 if (size < 1) 193 return; 194 std::vector<complex<T>> tmpin(size); 195 for (index_t i = 0; i < size; ++i) 196 tmpin[i] = in[i * in_delta]; 197 std::vector<complex<T>> tmpout(size); 198 reference_dft_md(tmpout.data(), tmpin.data(), shape, inversion, 1, 1); 199 index_t last = shape.back() / 2 + 1; 200 for (index_t i = 0; i < std::max(index_t(1), shape.remove_back().product()); ++i) 201 for (index_t j = 0; j < last; j++) 202 out[(i * last + j) * out_delta] = tmpout[i * shape.back() + j]; 203 } 204 205 } // namespace kfr