dft-impl.hpp (18591B)
1 /** @addtogroup dft 2 * @{ 3 */ 4 /* 5 Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com) 6 This file is part of KFR 7 8 KFR is free software: you can redistribute it and/or modify 9 it under the terms of the GNU General Public License as published by 10 the Free Software Foundation, either version 2 of the License, or 11 (at your option) any later version. 12 13 KFR is distributed in the hope that it will be useful, 14 but WITHOUT ANY WARRANTY; without even the implied warranty of 15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 GNU General Public License for more details. 17 18 You should have received a copy of the GNU General Public License 19 along with KFR. 20 21 If GPL is not suitable for your project, you must purchase a commercial license to use KFR. 22 Buying a commercial license is mandatory as soon as you develop commercial activities without 23 disclosing the source code of your own applications. 24 See https://www.kfrlib.com for details. 25 */ 26 #pragma once 27 28 #include <kfr/base/math_expressions.hpp> 29 #include <kfr/base/simd_expressions.hpp> 30 #include "dft-fft.hpp" 31 32 CMT_PRAGMA_GNU(GCC diagnostic push) 33 #if CMT_HAS_WARNING("-Wshadow") 34 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wshadow") 35 #endif 36 #if CMT_HAS_WARNING("-Wunused-lambda-capture") 37 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wunused-lambda-capture") 38 #endif 39 #if CMT_HAS_WARNING("-Wpass-failed") 40 CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wpass-failed") 41 #endif 42 43 CMT_PRAGMA_MSVC(warning(push)) 44 CMT_PRAGMA_MSVC(warning(disable : 4100)) 45 46 namespace kfr 47 { 48 49 inline namespace CMT_ARCH_NAME 50 { 51 constexpr csizes_t<2, 3, 4, 5, 6, 7, 8, 9, 10> dft_radices{}; 52 53 namespace intrinsics 54 { 55 56 template <typename T> 57 void dft_stage_fixed_initialize(dft_stage<T>* stage, size_t width) 58 { 59 complex<T>* twiddle = ptr_cast<complex<T>>(stage->data); 60 const size_t N = stage->repeats * stage->radix; 61 const size_t Nord = stage->repeats; 62 size_t i = 0; 63 64 while (width > 0) 65 { 66 CMT_LOOP_NOUNROLL 67 for (; i < Nord / width * width; i += width) 68 { 69 CMT_LOOP_NOUNROLL 70 for (size_t j = 1; j < stage->radix; j++) 71 { 72 CMT_LOOP_NOUNROLL 73 for (size_t k = 0; k < width; k++) 74 { 75 cvec<T, 1> xx = cossin_conj(broadcast<2, T>(c_pi<T, 2> * (i + k) * j / N)); 76 ref_cast<cvec<T, 1>>(twiddle[k]) = xx; 77 } 78 twiddle += width; 79 } 80 } 81 width = width / 2; 82 } 83 } 84 85 template <typename T, size_t fixed_radix> 86 struct dft_stage_fixed_impl : dft_stage<T> 87 { 88 dft_stage_fixed_impl(size_t, size_t iterations, size_t blocks) 89 { 90 this->name = dft_name(this); 91 this->radix = fixed_radix; 92 this->blocks = blocks; 93 this->repeats = iterations; 94 this->recursion = false; // true; 95 this->stage_size = fixed_radix * iterations * blocks; 96 this->data_size = align_up((this->repeats * (fixed_radix - 1)) * sizeof(complex<T>), 97 platform<>::native_cache_alignment); 98 } 99 100 constexpr static size_t rradix = fixed_radix; 101 102 constexpr static size_t width = fixed_radix >= 7 ? fft_vector_width<T> / 2 103 : fixed_radix >= 4 ? fft_vector_width<T> 104 : fft_vector_width<T> * 2; 105 virtual void do_initialize(size_t) override final { dft_stage_fixed_initialize(this, width); } 106 107 DFT_STAGE_FN 108 template <bool inverse> 109 KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*) 110 { 111 const size_t Nord = this->repeats; 112 const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); 113 114 const size_t N = Nord * fixed_radix; 115 CMT_LOOP_NOUNROLL 116 for (size_t b = 0; b < this->blocks; b++) 117 { 118 butterflies(Nord, csize<width>, csize<fixed_radix>, cbool<inverse>, out, in, twiddle, Nord); 119 in += N; 120 out += N; 121 } 122 } 123 }; 124 125 template <typename T, size_t fixed_radix> 126 struct dft_stage_fixed_final_impl : dft_stage<T> 127 { 128 dft_stage_fixed_final_impl(size_t, size_t iterations, size_t blocks) 129 { 130 this->name = dft_name(this); 131 this->radix = fixed_radix; 132 this->blocks = blocks; 133 this->repeats = iterations; 134 this->stage_size = fixed_radix * iterations * blocks; 135 this->recursion = false; 136 this->can_inplace = false; 137 } 138 constexpr static size_t width = fixed_radix >= 7 ? fft_vector_width<T> / 2 139 : fixed_radix >= 4 ? fft_vector_width<T> 140 : fft_vector_width<T> * 2; 141 142 DFT_STAGE_FN 143 template <bool inverse> 144 KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*) 145 { 146 const size_t b = this->blocks; 147 148 butterflies(b, csize<width>, csize<fixed_radix>, cbool<inverse>, out, in, b); 149 } 150 }; 151 152 template <typename E> 153 inline E& apply_conj(E& e, cfalse_t) 154 { 155 return e; 156 } 157 158 template <typename E> 159 inline auto apply_conj(E& e, ctrue_t) 160 { 161 return cconj(e); 162 } 163 164 /// [0, N - 1, N - 2, N - 3, ..., 3, 2, 1] 165 template <typename E> 166 struct fft_inverse : expression_with_traits<E> 167 { 168 using value_type = typename expression_with_traits<E>::value_type; 169 170 KFR_MEM_INTRINSIC fft_inverse(E&& expr) CMT_NOEXCEPT : expression_with_traits<E>(std::forward<E>(expr)) {} 171 172 friend KFR_INTRINSIC vec<value_type, 1> get_elements(const fft_inverse& self, shape<1> index, 173 axis_params<0, 1>) 174 { 175 const size_t size = get_shape(self).front(); 176 return get_elements(self.first(), index.front() == 0 ? 0 : size - index, axis_params<0, 1>()); 177 } 178 179 template <size_t N> 180 friend KFR_MEM_INTRINSIC vec<value_type, N> get_elements(const fft_inverse& self, shape<1> index, 181 axis_params<0, N>) 182 { 183 const size_t size = get_shape(self).front(); 184 if (index.front() == 0) 185 { 186 return concat(get_elements(self.first(), index, axis_params<0, 1>()), 187 reverse(get_elements(self.first(), size - (N - 1), axis_params<0, N - 1>()))); 188 } 189 return reverse(get_elements(self.first(), size - index - (N - 1), axis_params<0, N>())); 190 } 191 }; 192 193 template <typename E> 194 inline auto apply_fft_inverse(E&& e) 195 { 196 return fft_inverse<E>(std::forward<E>(e)); 197 } 198 199 template <typename T> 200 struct dft_arblen_stage_impl : dft_stage<T> 201 { 202 dft_arblen_stage_impl(size_t size) 203 : size(size), fftsize(next_poweroftwo(size) * 2), plan(fftsize, dft_order::internal) 204 { 205 this->name = dft_name(this); 206 this->radix = size; 207 this->blocks = 1; 208 this->repeats = 1; 209 this->recursion = false; 210 this->can_inplace = false; 211 this->temp_size = plan.temp_size; 212 this->stage_size = size; 213 214 chirp_ = render(cexp(sqr(linspace(T(1) - size, size - T(1), size * 2 - 1, true, ctrue)) * 215 complex<T>(0, -1) * c_pi<T> / size)); 216 217 ichirpp_ = render(truncate(padded(1 / slice(chirp_, 0, 2 * size - 1)), fftsize)); 218 219 univector<u8> temp(plan.temp_size); 220 plan.execute(ichirpp_, ichirpp_, temp); 221 xp.resize(fftsize, 0); 222 xp_fft.resize(fftsize); 223 invN2 = T(1) / fftsize; 224 } 225 226 DFT_STAGE_FN 227 template <bool inverse> 228 KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8* temp) 229 { 230 const size_t n = this->size; 231 232 auto&& chirp = apply_conj(chirp_, cbool<inverse>); 233 234 xp.slice(0, n) = make_univector(in, n) * slice(chirp, n - 1); 235 236 plan.execute(xp_fft.data(), xp.data(), temp); 237 238 if (inverse) 239 xp_fft = xp_fft * cconj(apply_fft_inverse(ichirpp_)); 240 else 241 xp_fft = xp_fft * ichirpp_; 242 plan.execute(xp_fft.data(), xp_fft.data(), temp, ctrue); 243 244 make_univector(out, n) = xp_fft.slice(n - 1, n) * slice(chirp, n - 1, n) * invN2; 245 } 246 247 const size_t size; 248 const size_t fftsize; 249 T invN2; 250 dft_plan<T> plan; 251 univector<complex<T>> chirp_; 252 univector<complex<T>> ichirpp_; 253 univector<complex<T>> xp; 254 univector<complex<T>> xp_fft; 255 }; 256 257 template <typename T, size_t radix1, size_t radix2, size_t size = radix1 * radix2> 258 struct dft_special_stage_impl : dft_stage<T> 259 { 260 dft_special_stage_impl() : stage1(radix1, size / radix1, 1), stage2(radix2, 1, size / radix2) 261 { 262 this->name = dft_name(this); 263 this->radix = size; 264 this->blocks = 1; 265 this->repeats = 1; 266 this->recursion = false; 267 this->can_inplace = false; 268 this->stage_size = size; 269 this->temp_size = stage1.temp_size + stage2.temp_size + sizeof(complex<T>) * size; 270 this->data_size = stage1.data_size + stage2.data_size; 271 } 272 void dump() const override 273 { 274 dft_stage<T>::dump(); 275 printf(" "); 276 stage1.dump(); 277 printf(" "); 278 stage2.dump(); 279 } 280 void do_initialize(size_t stage_size) override 281 { 282 stage1.data = this->data; 283 stage2.data = this->data + stage1.data_size; 284 stage1.initialize(stage_size); 285 stage2.initialize(stage_size); 286 } 287 DFT_STAGE_FN 288 template <bool inverse> 289 KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8* temp) 290 { 291 complex<T>* scratch = ptr_cast<complex<T>>(temp + stage1.temp_size + stage2.temp_size); 292 stage1.do_execute(cbool<inverse>, scratch, in, temp); 293 stage2.do_execute(cbool<inverse>, out, scratch, temp + stage1.temp_size); 294 } 295 dft_stage_fixed_impl<T, radix1> stage1; 296 dft_stage_fixed_final_impl<T, radix2> stage2; 297 }; 298 299 template <typename T, bool final> 300 struct dft_stage_generic_impl : dft_stage<T> 301 { 302 dft_stage_generic_impl(size_t radix, size_t iterations, size_t blocks) 303 { 304 this->name = dft_name(this); 305 this->radix = radix; 306 this->blocks = blocks; 307 this->repeats = iterations; 308 this->recursion = false; // true; 309 this->can_inplace = false; 310 this->stage_size = radix * iterations * blocks; 311 this->temp_size = align_up(sizeof(complex<T>) * radix, platform<>::native_cache_alignment); 312 this->data_size = 313 align_up(sizeof(complex<T>) * sqr(this->radix / 2), platform<>::native_cache_alignment); 314 } 315 316 protected: 317 virtual void do_initialize(size_t) override final 318 { 319 complex<T>* twiddle = ptr_cast<complex<T>>(this->data); 320 CMT_LOOP_NOUNROLL 321 for (size_t i = 0; i < this->radix / 2; i++) 322 { 323 CMT_LOOP_NOUNROLL 324 for (size_t j = 0; j < this->radix / 2; j++) 325 { 326 cwrite<1>(twiddle++, cossin_conj(broadcast<2>((i + 1) * (j + 1) * c_pi<T, 2> / this->radix))); 327 } 328 } 329 } 330 331 DFT_STAGE_FN 332 template <bool inverse> 333 KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8* temp) 334 { 335 const complex<T>* twiddle = ptr_cast<complex<T>>(this->data); 336 const size_t bl = this->blocks; 337 338 CMT_LOOP_NOUNROLL 339 for (size_t b = 0; b < bl; b++) 340 generic_butterfly(this->radix, cbool<inverse>, out + b, in + b * this->radix, 341 ptr_cast<complex<T>>(temp), twiddle, bl); 342 } 343 }; 344 345 template <typename T, typename Tr2> 346 inline void dft_permute(complex<T>* out, const complex<T>* in, size_t r0, size_t r1, Tr2 first_radix) 347 { 348 CMT_ASSUME(r0 > 1); 349 CMT_ASSUME(r1 > 1); 350 351 CMT_LOOP_NOUNROLL 352 for (size_t p = 0; p < r0; p++) 353 { 354 const complex<T>* in1 = in; 355 CMT_LOOP_NOUNROLL 356 for (size_t i = 0; i < r1; i++) 357 { 358 const complex<T>* in2 = in1; 359 CMT_LOOP_UNROLL 360 for (size_t j = 0; j < first_radix; j++) 361 { 362 *out++ = *in2; 363 in2 += r1; 364 } 365 in1++; 366 in += first_radix; 367 } 368 } 369 } 370 371 template <typename T, typename Tr2> 372 inline void dft_permute_deep(complex<T>*& out, const complex<T>* in, const size_t* radices, size_t count, 373 size_t index, size_t inscale, size_t inner_size, Tr2 first_radix) 374 { 375 const bool b = index == 1; 376 const size_t radix = radices[index]; 377 if (b) 378 { 379 CMT_LOOP_NOUNROLL 380 for (size_t i = 0; i < radix; i++) 381 { 382 const complex<T>* in1 = in; 383 CMT_LOOP_UNROLL 384 for (size_t j = 0; j < first_radix; j++) 385 { 386 *out++ = *in1; 387 in1 += inner_size; 388 } 389 in += inscale; 390 } 391 } 392 else 393 { 394 const size_t steps = radix; 395 const size_t inscale_next = inscale * radix; 396 CMT_LOOP_NOUNROLL 397 for (size_t i = 0; i < steps; i++) 398 { 399 dft_permute_deep(out, in, radices, count, index - 1, inscale_next, inner_size, first_radix); 400 in += inscale; 401 } 402 } 403 } 404 405 template <typename T> 406 struct dft_reorder_stage_impl : dft_stage<T> 407 { 408 dft_reorder_stage_impl(const int* radices, size_t count) : count(count) 409 { 410 this->name = dft_name(this); 411 this->can_inplace = false; 412 this->data_size = 0; 413 std::copy(radices, radices + count, this->radices); 414 this->inner_size = 1; 415 this->size = 1; 416 for (size_t r = 0; r < count; r++) 417 { 418 if (r != 0 && r != count - 1) 419 this->inner_size *= radices[r]; 420 this->size *= radices[r]; 421 } 422 this->stage_size = this->size; 423 } 424 425 protected: 426 size_t radices[32]; 427 size_t count = 0; 428 size_t size = 0; 429 size_t inner_size = 0; 430 virtual void do_initialize(size_t) override final {} 431 432 DFT_STAGE_FN 433 template <bool inverse> 434 KFR_MEM_INTRINSIC void do_execute(complex<T>* out, const complex<T>* in, u8*) 435 { 436 cswitch( 437 dft_radices, radices[0], 438 [&](auto first_radix) 439 { 440 if (count == 3) 441 { 442 dft_permute(out, in, radices[2], radices[1], first_radix); 443 } 444 else 445 { 446 const size_t rlast = radices[count - 1]; 447 for (size_t p = 0; p < rlast; p++) 448 { 449 dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, first_radix); 450 in += size / rlast; 451 } 452 } 453 }, 454 [&]() 455 { 456 if (count == 3) 457 { 458 dft_permute(out, in, radices[2], radices[1], radices[0]); 459 } 460 else 461 { 462 const size_t rlast = radices[count - 1]; 463 for (size_t p = 0; p < rlast; p++) 464 { 465 dft_permute_deep(out, in, radices, count, count - 2, 1, inner_size, radices[0]); 466 in += size / rlast; 467 } 468 } 469 }); 470 } 471 }; 472 } // namespace intrinsics 473 474 template <bool is_final, typename T> 475 void prepare_dft_stage(dft_plan<T>* self, size_t radix, size_t iterations, size_t blocks, cbool_t<is_final>) 476 { 477 return cswitch( 478 dft_radices, radix, 479 [self, iterations, blocks](auto radix) CMT_INLINE_LAMBDA 480 { 481 add_stage<std::conditional_t<is_final, intrinsics::dft_stage_fixed_final_impl<T, val_of(radix)>, 482 intrinsics::dft_stage_fixed_impl<T, val_of(radix)>>>( 483 self, radix, iterations, blocks); 484 }, 485 [self, radix, iterations, blocks]() 486 { add_stage<intrinsics::dft_stage_generic_impl<T, is_final>>(self, radix, iterations, blocks); }); 487 } 488 489 template <typename T> 490 void init_dft(dft_plan<T>* self, size_t size, dft_order) 491 { 492 if (size == 60) 493 { 494 add_stage<intrinsics::dft_special_stage_impl<T, 6, 10>>(self); 495 } 496 else if (size == 48) 497 { 498 add_stage<intrinsics::dft_special_stage_impl<T, 6, 8>>(self); 499 } 500 else 501 { 502 size_t cur_size = size; 503 constexpr size_t radices_count = dft_radices.back() + 1; 504 u8 count[radices_count] = { 0 }; 505 int radices[32] = { 0 }; 506 size_t radices_size = 0; 507 508 cforeach(dft_radices[csizeseq<dft_radices.size(), dft_radices.size() - 1, -1>], 509 [&](auto radix) 510 { 511 while (cur_size && cur_size % val_of(radix) == 0) 512 { 513 count[val_of(radix)]++; 514 cur_size /= val_of(radix); 515 } 516 }); 517 518 int num_stages = 0; 519 if (cur_size >= 101) 520 { 521 add_stage<intrinsics::dft_arblen_stage_impl<T>>(self, size); 522 ++num_stages; 523 self->arblen = true; 524 } 525 else 526 { 527 size_t blocks = 1; 528 size_t iterations = size; 529 530 for (size_t r = dft_radices.front(); r <= dft_radices.back(); r++) 531 { 532 for (size_t i = 0; i < count[r]; i++) 533 { 534 iterations /= r; 535 radices[radices_size++] = static_cast<int>(r); 536 if (iterations == 1) 537 prepare_dft_stage(self, r, iterations, blocks, ctrue); 538 else 539 prepare_dft_stage(self, r, iterations, blocks, cfalse); 540 ++num_stages; 541 blocks *= r; 542 } 543 } 544 545 if (cur_size > 1) 546 { 547 iterations /= cur_size; 548 radices[radices_size++] = static_cast<int>(cur_size); 549 if (iterations == 1) 550 prepare_dft_stage(self, cur_size, iterations, blocks, ctrue); 551 else 552 prepare_dft_stage(self, cur_size, iterations, blocks, cfalse); 553 ++num_stages; 554 } 555 556 if (num_stages > 2) 557 add_stage<intrinsics::dft_reorder_stage_impl<T>>(self, radices, radices_size); 558 } 559 } 560 } 561 562 } // namespace CMT_ARCH_NAME 563 564 } // namespace kfr 565 566 CMT_PRAGMA_GNU(GCC diagnostic pop) 567 568 CMT_PRAGMA_MSVC(warning(pop))