shape.hpp (26399B)
1 /** @addtogroup types 2 * @{ 3 */ 4 /* 5 Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com) 6 This file is part of KFR 7 8 KFR is free software: you can redistribute it and/or modify 9 it under the terms of the GNU General Public License as published by 10 the Free Software Foundation, either version 2 of the License, or 11 (at your option) any later version. 12 13 KFR is distributed in the hope that it will be useful, 14 but WITHOUT ANY WARRANTY; without even the implied warranty of 15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 GNU General Public License for more details. 17 18 You should have received a copy of the GNU General Public License 19 along with KFR. 20 21 If GPL is not suitable for your project, you must purchase a commercial license to use KFR. 22 Buying a commercial license is mandatory as soon as you develop commercial activities without 23 disclosing the source code of your own applications. 24 See https://www.kfrlib.com for details. 25 */ 26 #pragma once 27 28 #include "../except.hpp" 29 #include "impl/static_array.hpp" 30 31 #include "../cometa/string.hpp" 32 #include "../simd/logical.hpp" 33 #include "../simd/min_max.hpp" 34 #include "../simd/shuffle.hpp" 35 #include "../simd/types.hpp" 36 37 #include <bitset> 38 #include <optional> 39 40 namespace kfr 41 { 42 43 #ifndef KFR_32BIT_INDICES 44 #if SIZE_MAX == UINT64_MAX 45 using index_t = uint64_t; 46 using signed_index_t = int64_t; 47 #else 48 using index_t = uint32_t; 49 using signed_index_t = int32_t; 50 #endif 51 #else 52 using index_t = uint32_t; 53 using signed_index_t = int32_t; 54 #endif 55 constexpr inline index_t max_index_t = std::numeric_limits<index_t>::max(); 56 constexpr inline signed_index_t max_sindex_t = std::numeric_limits<signed_index_t>::max(); 57 58 template <index_t val> 59 using cindex_t = cval_t<index_t, val>; 60 61 template <index_t val> 62 constexpr inline cindex_t<val> cindex{}; 63 64 constexpr inline index_t infinite_size = max_index_t; 65 66 constexpr inline index_t undefined_size = 0; 67 68 constexpr inline index_t maximum_dims = 16; 69 70 CMT_INTRINSIC constexpr size_t size_add(size_t x, size_t y) 71 { 72 return (x == infinite_size || y == infinite_size) ? infinite_size : x + y; 73 } 74 75 CMT_INTRINSIC constexpr size_t size_sub(size_t x, size_t y) 76 { 77 return (x == infinite_size || y == infinite_size) ? infinite_size : (x > y ? x - y : 0); 78 } 79 80 CMT_INTRINSIC constexpr size_t size_min(size_t x) CMT_NOEXCEPT { return x; } 81 82 template <typename... Ts> 83 CMT_INTRINSIC constexpr size_t size_min(size_t x, size_t y, Ts... rest) CMT_NOEXCEPT 84 { 85 return size_min(x < y ? x : y, rest...); 86 } 87 88 using dimset = static_array_of_size<i8, maximum_dims>; // std::array<i8, maximum_dims>; 89 90 template <index_t dims> 91 struct shape; 92 93 namespace internal_generic 94 { 95 template <index_t dims> 96 KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& start, const shape<dims>& stop, 97 index_t dim = dims - 1); 98 } // namespace internal_generic 99 100 template <index_t Dims> 101 struct shape : static_array_base<index_t, csizeseq_t<Dims>> 102 { 103 static_assert(Dims <= 256, "Too many dimensions"); 104 using base = static_array_base<index_t, csizeseq_t<Dims>>; 105 106 using base::base; 107 108 constexpr shape(const base& a) : base(a) {} 109 110 static_assert(Dims <= maximum_dims); 111 112 static constexpr size_t dims() { return base::static_size; } 113 114 template <int dummy = 0, KFR_ENABLE_IF(dummy == 0 && Dims == 1)> 115 operator index_t() const 116 { 117 return this->front(); 118 } 119 120 template <typename TI> 121 static constexpr shape from_std_array(const std::array<TI, Dims>& a) 122 { 123 shape result; 124 std::copy(a.begin(), a.end(), result.begin()); 125 return result; 126 } 127 128 template <typename TI = index_t> 129 constexpr std::array<TI, Dims> to_std_array() const 130 { 131 std::array<TI, Dims> result{}; 132 std::copy(this->begin(), this->end(), result.begin()); 133 return result; 134 } 135 136 bool ge(const shape& other) const 137 { 138 if constexpr (Dims == 1) 139 { 140 return this->front() >= other.front(); 141 } 142 else 143 { 144 return all(**this >= *other); 145 } 146 } 147 148 index_t trailing_zeros() const 149 { 150 for (index_t i = 0; i < Dims; ++i) 151 { 152 if (revindex(i) != 0) 153 return i; 154 } 155 return Dims; 156 } 157 158 bool le(const shape& other) const 159 { 160 if constexpr (Dims == 1) 161 { 162 return this->front() <= other.front(); 163 } 164 else 165 { 166 return all(**this <= *other); 167 } 168 } 169 170 constexpr shape add(index_t value) const 171 { 172 shape result = *this; 173 result.back() += value; 174 return result; 175 } 176 template <index_t Axis> 177 constexpr shape add_at(index_t value, cval_t<index_t, Axis> = {}) const 178 { 179 shape result = *this; 180 result[Axis] += value; 181 return result; 182 } 183 constexpr shape add(const shape& other) const { return **this + *other; } 184 constexpr shape sub(const shape& other) const { return **this - *other; } 185 constexpr index_t sum() const { return (*this)->sum(); } 186 187 constexpr bool has_infinity() const 188 { 189 for (index_t i = 0; i < Dims; ++i) 190 { 191 if (CMT_UNLIKELY(this->operator[](i) == infinite_size)) 192 return true; 193 } 194 return false; 195 } 196 197 friend constexpr shape add_shape(const shape& lhs, const shape& rhs) 198 { 199 return lhs.bin(rhs, [](index_t x, index_t y) { return std::max(std::max(x, y), x + y); }); 200 } 201 friend constexpr shape sub_shape(const shape& lhs, const shape& rhs) 202 { 203 return lhs.bin(rhs, [](index_t x, index_t y) 204 { return std::max(x, y) == infinite_size ? infinite_size : x - y; }); 205 } 206 friend constexpr shape add_shape_undef(const shape& lhs, const shape& rhs) 207 { 208 return lhs.bin(rhs, 209 [](index_t x, index_t y) 210 { 211 bool inf = std::max(x, y) == infinite_size; 212 bool undef = std::min(x, y) == undefined_size; 213 return inf ? infinite_size : undef ? undefined_size : x + y; 214 }); 215 } 216 friend constexpr shape sub_shape_undef(const shape& lhs, const shape& rhs) 217 { 218 return lhs.bin(rhs, 219 [](index_t x, index_t y) 220 { 221 bool inf = std::max(x, y) == infinite_size; 222 bool undef = std::min(x, y) == undefined_size; 223 return inf ? infinite_size : undef ? undefined_size : x - y; 224 }); 225 } 226 227 friend constexpr shape min(const shape& x, const shape& y) { return x->min(*y); } 228 229 constexpr const base& operator*() const { return static_cast<const base&>(*this); } 230 231 constexpr const base* operator->() const { return static_cast<const base*>(this); } 232 233 KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<Dims>& indices) const 234 { 235 if constexpr (Dims == 1) 236 { 237 return indices[0]; 238 } 239 else if constexpr (Dims == 2) 240 { 241 return (*this)[1] * indices[0] + indices[1]; 242 } 243 else 244 { 245 size_t result = 0; 246 size_t scale = 1; 247 CMT_LOOP_UNROLL 248 for (size_t i = 0; i < Dims; ++i) 249 { 250 result += scale * indices[Dims - 1 - i]; 251 scale *= (*this)[Dims - 1 - i]; 252 } 253 return result; 254 } 255 } 256 KFR_MEM_INTRINSIC constexpr shape<Dims> from_flat(size_t index) const 257 { 258 if constexpr (Dims == 1) 259 { 260 return { static_cast<index_t>(index) }; 261 } 262 else if constexpr (Dims == 2) 263 { 264 index_t sz = (*this)[1]; 265 return { static_cast<index_t>(index / sz), static_cast<index_t>(index % sz) }; 266 } 267 else 268 { 269 shape<Dims> indices; 270 CMT_LOOP_UNROLL 271 for (size_t i = 0; i < Dims; ++i) 272 { 273 size_t sz = (*this)[Dims - 1 - i]; 274 indices[Dims - 1 - i] = index % sz; 275 index /= sz; 276 } 277 return indices; 278 } 279 } 280 281 KFR_MEM_INTRINSIC constexpr index_t dot(const shape& other) const { return (*this)->dot(*other); } 282 283 template <index_t indims, bool stop = false> 284 KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other, cbool_t<stop> = {}) const 285 { 286 static_assert(indims >= Dims); 287 if constexpr (stop) 288 return other.template trim<Dims>()->min(**this); 289 else 290 return other.template trim<Dims>()->min(**this - 1); 291 } 292 293 KFR_MEM_INTRINSIC constexpr index_t product() const { return (*this)->product(); } 294 295 KFR_MEM_INTRINSIC constexpr dimset tomask() const 296 { 297 dimset result(0); 298 for (index_t i = 0; i < Dims; ++i) 299 { 300 result[i + maximum_dims - Dims] = this->operator[](i) == 1 ? 0 : -1; 301 } 302 return result; 303 } 304 305 template <index_t new_dims> 306 constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const 307 { 308 static_assert(new_dims >= Dims); 309 if constexpr (new_dims == Dims) 310 return *this; 311 else 312 return shape<new_dims>{ shape<new_dims - Dims>(value), *this }; 313 } 314 315 template <index_t odims> 316 constexpr shape<odims> trim() const 317 { 318 static_assert(odims <= Dims); 319 if constexpr (odims > 0) 320 { 321 return this->template slice<Dims - odims, odims>(); 322 } 323 else 324 { 325 return {}; 326 } 327 } 328 329 // 0,1,2,3 -> 1,2,3,0 330 constexpr KFR_MEM_INTRINSIC shape rotate_left() const 331 { 332 return this->shuffle(csizeseq<Dims, 1> % csize<Dims>); 333 } 334 335 // 0,1,2,3 -> 3,0,1,2 336 constexpr KFR_MEM_INTRINSIC shape rotate_right() const 337 { 338 return this->shuffle(csizeseq<Dims, Dims - 1> % csize<Dims>); 339 } 340 341 constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_back() const 342 { 343 if constexpr (Dims > 1) 344 { 345 return this->template slice<0, Dims - 1>(); 346 } 347 else 348 { 349 return {}; 350 } 351 } 352 constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_front() const 353 { 354 if constexpr (Dims > 1) 355 { 356 return this->template slice<1, Dims - 1>(); 357 } 358 else 359 { 360 return {}; 361 } 362 } 363 364 constexpr KFR_MEM_INTRINSIC shape<Dims - 1> trunc() const { return remove_back(); } 365 366 KFR_MEM_INTRINSIC constexpr index_t revindex(size_t index) const 367 { 368 return index < Dims ? this->operator[](Dims - 1 - index) : 1; 369 } 370 KFR_MEM_INTRINSIC constexpr void set_revindex(size_t index, index_t val) 371 { 372 if (CMT_LIKELY(index < Dims)) 373 this->operator[](Dims - 1 - index) = val; 374 } 375 376 KFR_MEM_INTRINSIC constexpr shape transpose() const 377 { 378 return this->shuffle(csizeseq<Dims, Dims - 1, -1>); 379 } 380 }; 381 382 template <> 383 struct shape<0> 384 { 385 static constexpr size_t static_size = 0; 386 387 static constexpr size_t size() { return static_size; } 388 389 static constexpr size_t dims() { return static_size; } 390 391 constexpr shape() = default; 392 constexpr shape(index_t value) {} 393 394 constexpr bool has_infinity() const { return false; } 395 396 KFR_MEM_INTRINSIC size_t to_flat(const shape<0>& indices) const { return 0; } 397 KFR_MEM_INTRINSIC shape<0> from_flat(size_t index) const { return {}; } 398 399 template <index_t odims, bool stop = false> 400 KFR_MEM_INTRINSIC shape<0> adapt(const shape<odims>& other, cbool_t<stop> = {}) const 401 { 402 return {}; 403 } 404 405 index_t trailing_zeros() const { return 0; } 406 407 KFR_MEM_INTRINSIC index_t dot(const shape& other) const { return 0; } 408 409 KFR_MEM_INTRINSIC index_t product() const { return 0; } 410 411 KFR_MEM_INTRINSIC dimset tomask() const { return dimset(-1); } 412 413 template <index_t new_dims> 414 constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const 415 { 416 if constexpr (new_dims == 0) 417 return *this; 418 else 419 return shape<new_dims>{ value }; 420 } 421 422 template <index_t new_dims> 423 constexpr shape<new_dims> trim() const 424 { 425 static_assert(new_dims == 0); 426 return {}; 427 } 428 429 KFR_MEM_INTRINSIC constexpr bool operator==(const shape<0>& other) const { return true; } 430 KFR_MEM_INTRINSIC constexpr bool operator!=(const shape<0>& other) const { return false; } 431 432 KFR_MEM_INTRINSIC constexpr index_t revindex(size_t index) const { return 1; } 433 KFR_MEM_INTRINSIC void set_revindex(size_t index, index_t val) {} 434 }; 435 436 constexpr inline size_t dynamic_shape = std::numeric_limits<size_t>::max(); 437 438 template <> 439 struct shape<dynamic_shape> : protected std::vector<index_t> 440 { 441 using std::vector<index_t>::vector; 442 443 using std::vector<index_t>::begin; 444 using std::vector<index_t>::end; 445 using std::vector<index_t>::data; 446 using std::vector<index_t>::size; 447 using std::vector<index_t>::front; 448 using std::vector<index_t>::back; 449 using std::vector<index_t>::operator[]; 450 451 template <index_t Dims, CMT_ENABLE_IF(Dims != dynamic_shape)> 452 shape(shape<Dims> sh) : shape(sh.begin(), sh.end()) 453 { 454 } 455 456 size_t dims() const { return size(); } 457 458 KFR_MEM_INTRINSIC index_t product() const 459 { 460 if (std::vector<index_t>::empty()) 461 return 0; 462 index_t p = this->front(); 463 for (size_t i = 1; i < size(); ++i) 464 { 465 p *= this->operator[](i); 466 } 467 return p; 468 } 469 470 // 0,1,2,3 -> 1,2,3,0 471 KFR_MEM_INTRINSIC shape rotate_left() const 472 { 473 shape result = *this; 474 if (result.size() > 1) 475 std::rotate(result.begin(), result.begin() + 1, result.end()); 476 return result; 477 } 478 479 // 0,1,2,3 -> 3,0,1,2 480 KFR_MEM_INTRINSIC shape rotate_right() const 481 { 482 shape result = *this; 483 if (result.size() > 1) 484 std::rotate(result.begin(), result.end() - 1, result.end()); 485 return result; 486 } 487 488 KFR_MEM_INTRINSIC shape remove_back() const 489 { 490 shape result = *this; 491 if (!result.empty()) 492 result.erase(result.end() - 1); 493 return result; 494 } 495 KFR_MEM_INTRINSIC shape remove_front() const 496 { 497 shape result = *this; 498 if (!result.empty()) 499 { 500 result.erase(result.begin()); 501 } 502 return result; 503 } 504 }; 505 506 template <typename... Args> 507 shape(Args&&... args) -> shape<sizeof...(Args)>; 508 509 namespace internal_generic 510 { 511 512 template <index_t outdims, index_t indims> 513 KFR_MEM_INTRINSIC shape<outdims> adapt(const shape<indims>& in, const dimset& set) 514 { 515 static_assert(indims >= outdims); 516 if constexpr (outdims == 0) 517 { 518 return {}; 519 } 520 else 521 { 522 const static_array_of_size<index_t, maximum_dims> eset = set.template cast<index_t>(); 523 return in->template slice<indims - outdims, outdims>() & 524 eset.template slice<maximum_dims - outdims, outdims>(); 525 } 526 } 527 template <index_t outdims> 528 KFR_MEM_INTRINSIC shape<outdims> adapt(const shape<0>& in, const dimset& set) 529 { 530 static_assert(outdims == 0); 531 return {}; 532 } 533 } // namespace internal_generic 534 535 template <size_t Dims> 536 struct cursor 537 { 538 shape<Dims> current; 539 shape<Dims> minimum; 540 shape<Dims> maximum; 541 }; 542 543 using opt_index_t = std::optional<signed_index_t>; 544 545 struct tensor_range 546 { 547 opt_index_t start; 548 opt_index_t stop; 549 opt_index_t step; 550 }; 551 552 constexpr KFR_INTRINSIC tensor_range trange(std::optional<signed_index_t> start = std::nullopt, 553 std::optional<signed_index_t> stop = std::nullopt, 554 std::optional<signed_index_t> step = std::nullopt) 555 { 556 return { start, stop, step }; 557 } 558 559 constexpr KFR_INTRINSIC tensor_range tall() { return trange(); } 560 constexpr KFR_INTRINSIC tensor_range tstart(signed_index_t start, signed_index_t step = 1) 561 { 562 return trange(start, std::nullopt, step); 563 } 564 constexpr KFR_INTRINSIC tensor_range tstop(signed_index_t stop, signed_index_t step = 1) 565 { 566 return trange(std::nullopt, stop, step); 567 } 568 constexpr KFR_INTRINSIC tensor_range tstep(signed_index_t step = 1) 569 { 570 return trange(std::nullopt, std::nullopt, step); 571 } 572 573 namespace internal_generic 574 { 575 576 constexpr inline index_t null_index = max_index_t; 577 578 template <index_t dims, bool fortran_order = false> 579 constexpr KFR_INTRINSIC shape<dims> strides_for_shape(const shape<dims>& sh, index_t stride = 1) 580 { 581 shape<dims> strides; 582 if constexpr (dims > 0) 583 { 584 index_t n = stride; 585 for (index_t i = 0; i < dims; ++i) 586 { 587 strides[fortran_order ? i : dims - 1 - i] = n; 588 n *= sh[fortran_order ? i : dims - 1 - i]; 589 } 590 } 591 return strides; 592 } 593 594 template <size_t dims, size_t outdims, bool... ranges> 595 constexpr KFR_INTRINSIC shape<outdims> compact_shape(const shape<dims>& in) 596 { 597 shape<outdims> result; 598 constexpr std::array flags{ ranges... }; 599 size_t j = 0; 600 for (size_t i = 0; i < dims; ++i) 601 { 602 if (CMT_LIKELY(i >= flags.size() || flags[i])) 603 { 604 result[j++] = in[i]; 605 } 606 } 607 return result; 608 } 609 610 template <index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)> 611 constexpr bool can_assign_from(const shape<dims1>& dst_shape, const shape<dims2>& src_shape) 612 { 613 if constexpr (dims2 == 0) 614 { 615 return true; 616 } 617 else 618 { 619 for (size_t i = 0; i < outdims; ++i) 620 { 621 index_t dst_size = dst_shape.revindex(i); 622 index_t src_size = src_shape.revindex(i); 623 if (CMT_LIKELY(src_size == 1 || src_size == infinite_size || src_size == dst_size || 624 dst_size == infinite_size)) 625 { 626 } 627 else 628 { 629 return false; 630 } 631 } 632 return true; 633 } 634 } 635 636 template <bool checked = false, index_t dims> 637 constexpr shape<dims> common_shape(const shape<dims>& shape) 638 { 639 return shape; 640 } 641 642 template <bool checked = false, index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)> 643 KFR_MEM_INTRINSIC constexpr shape<outdims> common_shape(const shape<dims1>& shape1, 644 const shape<dims2>& shape2) 645 { 646 shape<outdims> result; 647 for (size_t i = 0; i < outdims; ++i) 648 { 649 index_t size1 = shape1.revindex(i); 650 index_t size2 = shape2.revindex(i); 651 if (CMT_UNLIKELY(!size1 || !size2)) 652 { 653 result[outdims - 1 - i] = 0; 654 continue; 655 } 656 657 if (CMT_UNLIKELY(size1 == infinite_size)) 658 { 659 if (CMT_UNLIKELY(size2 == infinite_size)) 660 { 661 result[outdims - 1 - i] = infinite_size; 662 } 663 else 664 { 665 result[outdims - 1 - i] = size2 == 1 ? infinite_size : size2; 666 } 667 } 668 else 669 { 670 if (CMT_UNLIKELY(size2 == infinite_size)) 671 { 672 result[outdims - 1 - i] = size1 == 1 ? infinite_size : size1; 673 } 674 else 675 { 676 if (CMT_LIKELY(size1 == 1 || size2 == 1 || size1 == size2)) 677 { 678 result[outdims - 1 - i] = std::max(size1, size2); 679 } 680 else 681 { 682 // broadcast failed 683 if constexpr (checked) 684 { 685 KFR_LOGIC_CHECK(false, "invalid or incompatible shapes: ", shape1, " and ", shape2); 686 } 687 else 688 { 689 result = shape<outdims>(0); 690 return result; 691 } 692 } 693 } 694 } 695 } 696 return result; 697 } 698 699 template <bool checked = false> 700 KFR_MEM_INTRINSIC constexpr shape<0> common_shape(const shape<0>& shape1, const shape<0>& shape2) 701 { 702 return {}; 703 } 704 705 template <bool checked = false, index_t dims1, index_t dims2, index_t... dims, 706 index_t outdims = const_max(dims1, dims2, dims...)> 707 KFR_MEM_INTRINSIC constexpr shape<outdims> common_shape(const shape<dims1>& shape1, 708 const shape<dims2>& shape2, 709 const shape<dims>&... shapes) 710 { 711 return common_shape<checked>(shape1, common_shape(shape2, shapes...)); 712 } 713 714 template <index_t dims1, index_t dims2> 715 KFR_MEM_INTRINSIC bool same_layout(const shape<dims1>& x, const shape<dims2>& y) 716 { 717 for (index_t i = 0, j = 0;;) 718 { 719 while (i < dims1 && x[i] == 1) 720 ++i; 721 while (j < dims2 && y[j] == 1) 722 ++j; 723 if (i == dims1 && j == dims2) 724 { 725 return true; 726 } 727 if (i < dims1 && j < dims2) 728 { 729 if (x[i] != y[j]) 730 return false; 731 } 732 else 733 { 734 return false; 735 } 736 ++i; 737 ++j; 738 } 739 } 740 741 #ifdef KFR_VEC_INDICES 742 template <size_t step, index_t dims> 743 KFR_INTRINSIC vec<index_t, dims> increment_indices(vec<index_t, dims> indices, 744 const vec<index_t, dims>& start, 745 const vec<index_t, dims>& stop) 746 { 747 indices = indices + make_vector(cconcat(cvalseq<index_t, dims - 1 - step, 0, 0>, cvalseq<index_t, 1, 1>, 748 cvalseq<index_t, step, 0, 0>)); 749 750 if constexpr (step + 1 < dims) 751 { 752 vec<bit<index_t>, dims> mask = indices >= stop; 753 if (CMT_LIKELY(!any(mask))) 754 return indices; 755 indices = blend(indices, start, cconcat(csizeseq<dims - step - 1, 0, 0>, csizeseq<step + 1, 1, 0>)); 756 757 return increment_indices<step + 1>(indices, stop); 758 } 759 else 760 { 761 return indices; 762 } 763 } 764 #endif 765 766 template <index_t dims> 767 KFR_INTRINSIC bool compare_indices(const shape<dims>& indices, const shape<dims>& stop, 768 index_t dim = dims - 1) 769 { 770 CMT_LOOP_UNROLL 771 for (int i = static_cast<int>(dim); i >= 0; --i) 772 { 773 if (CMT_UNLIKELY(indices[i] >= stop[i])) 774 return false; 775 } 776 return true; 777 } 778 779 template <index_t dims> 780 KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& start, const shape<dims>& stop, 781 index_t dim) 782 { 783 #ifdef KFR_VEC_INDICES 784 vec<index_t, dims> idx = increment_indices<0>(*indices, *start, *stop); 785 indices = idx; 786 if (any(idx == *stop)) 787 return false; 788 return true; 789 #else 790 if constexpr (dims > 0) 791 { 792 indices[dim] += 1; 793 CMT_LOOP_UNROLL 794 for (int i = static_cast<int>(dim); i >= 0;) 795 { 796 if (CMT_LIKELY(indices[i] < stop[i])) 797 return true; 798 // carry 799 indices[i] = start[i]; 800 --i; 801 if (i < 0) 802 { 803 return false; 804 } 805 indices[i] += 1; 806 } 807 return true; 808 } 809 else 810 { 811 return false; 812 } 813 #endif 814 } 815 816 template <index_t dims> 817 KFR_INTRINSIC shape<dims> increment_indices_return(const shape<dims>& indices, const shape<dims>& start, 818 const shape<dims>& stop, index_t dim = dims - 1) 819 { 820 shape<dims> result = indices; 821 if (CMT_LIKELY(increment_indices(result, start, stop, dim))) 822 { 823 return result; 824 } 825 else 826 { 827 return shape<dims>(null_index); 828 } 829 } 830 831 template <typename... Index> 832 constexpr KFR_INTRINSIC size_t count_dimensions() 833 { 834 return ((std::is_same_v<std::decay_t<Index>, tensor_range> ? 1 : 0) + ...); 835 } 836 837 template <typename U> 838 struct type_of_list 839 { 840 using value_type = U; 841 }; 842 843 template <typename U> 844 struct type_of_list<std::initializer_list<U>> 845 { 846 using value_type = typename type_of_list<U>::value_type; 847 }; 848 849 template <typename U> 850 using type_of_list_t = typename type_of_list<U>::value_type; 851 852 template <typename U> 853 constexpr KFR_INTRINSIC shape<1> shape_of_list(const std::initializer_list<U>& list) 854 { 855 return list.size(); 856 } 857 858 template <typename U> 859 constexpr KFR_INTRINSIC auto shape_of_list(const std::initializer_list<std::initializer_list<U>>& list) 860 { 861 return shape_of_list(*list.begin()); 862 } 863 864 template <typename U> 865 constexpr KFR_INTRINSIC U list_get(const std::initializer_list<U>& list, const shape<1>& idx) 866 { 867 return list[idx.front()]; 868 } 869 870 template <typename U, index_t dims> 871 constexpr KFR_INTRINSIC auto list_get(const std::initializer_list<std::initializer_list<U>>& list, 872 const shape<dims>& idx) 873 { 874 return list_get(list[idx[0]], idx.template trim<dims - 1>()); 875 } 876 877 template <typename U, typename T> 878 KFR_FUNCTION T* list_copy_recursively(const std::initializer_list<U>& list, T* dest) 879 { 880 for (const auto& value : list) 881 *dest++ = static_cast<T>(value); 882 return dest; 883 } 884 885 template <typename U, typename T> 886 KFR_FUNCTION T* list_copy_recursively(const std::initializer_list<std::initializer_list<U>>& list, T* dest) 887 { 888 for (const auto& sublist : list) 889 dest = list_copy_recursively(sublist, dest); 890 return dest; 891 } 892 893 } // namespace internal_generic 894 895 template <index_t dims> 896 constexpr KFR_INTRINSIC index_t size_of_shape(const shape<dims>& shape) 897 { 898 index_t n = 1; 899 if constexpr (dims > 0) 900 { 901 for (index_t i = 0; i < dims; ++i) 902 { 903 n *= shape[i]; 904 } 905 } 906 return n; 907 } 908 909 template <index_t Axis, size_t N> 910 struct axis_params 911 { 912 constexpr static index_t axis = Axis; 913 constexpr static index_t width = N; 914 constexpr static index_t value = N; 915 916 constexpr axis_params() = default; 917 }; 918 919 template <index_t Axis, size_t N> 920 constexpr inline const axis_params<Axis, N> axis_params_v{}; 921 922 } // namespace kfr 923 924 namespace cometa 925 { 926 template <kfr::index_t dims> 927 struct representation<kfr::shape<dims>> 928 { 929 using type = std::string; 930 static std::string get(const kfr::shape<dims>& value) 931 { 932 if constexpr (dims == 0) 933 { 934 return "shape{}"; 935 } 936 else 937 { 938 return "shape" + array_to_string(dims, value.data()); 939 } 940 } 941 }; 942 943 } // namespace cometa