kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

shape.hpp (26399B)


      1 /** @addtogroup types
      2  *  @{
      3  */
      4 /*
      5   Copyright (C) 2016-2023 Dan Cazarin (https://www.kfrlib.com)
      6   This file is part of KFR
      7 
      8   KFR is free software: you can redistribute it and/or modify
      9   it under the terms of the GNU General Public License as published by
     10   the Free Software Foundation, either version 2 of the License, or
     11   (at your option) any later version.
     12 
     13   KFR is distributed in the hope that it will be useful,
     14   but WITHOUT ANY WARRANTY; without even the implied warranty of
     15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     16   GNU General Public License for more details.
     17 
     18   You should have received a copy of the GNU General Public License
     19   along with KFR.
     20 
     21   If GPL is not suitable for your project, you must purchase a commercial license to use KFR.
     22   Buying a commercial license is mandatory as soon as you develop commercial activities without
     23   disclosing the source code of your own applications.
     24   See https://www.kfrlib.com for details.
     25  */
     26 #pragma once
     27 
#include "../except.hpp"
#include "impl/static_array.hpp"

#include "../cometa/string.hpp"
#include "../simd/logical.hpp"
#include "../simd/min_max.hpp"
#include "../simd/shuffle.hpp"
#include "../simd/types.hpp"

#include <algorithm>
#include <array>
#include <bitset>
#include <cstdint>
#include <initializer_list>
#include <limits>
#include <optional>
#include <string>
#include <utility>
#include <vector>
     39 
     40 namespace kfr
     41 {
     42 
     43 #ifndef KFR_32BIT_INDICES
     44 #if SIZE_MAX == UINT64_MAX
     45 using index_t        = uint64_t;
     46 using signed_index_t = int64_t;
     47 #else
     48 using index_t        = uint32_t;
     49 using signed_index_t = int32_t;
     50 #endif
     51 #else
     52 using index_t        = uint32_t;
     53 using signed_index_t = int32_t;
     54 #endif
     55 constexpr inline index_t max_index_t         = std::numeric_limits<index_t>::max();
     56 constexpr inline signed_index_t max_sindex_t = std::numeric_limits<signed_index_t>::max();
     57 
     58 template <index_t val>
     59 using cindex_t = cval_t<index_t, val>;
     60 
     61 template <index_t val>
     62 constexpr inline cindex_t<val> cindex{};
     63 
     64 constexpr inline index_t infinite_size = max_index_t;
     65 
     66 constexpr inline index_t undefined_size = 0;
     67 
     68 constexpr inline index_t maximum_dims = 16;
     69 
     70 CMT_INTRINSIC constexpr size_t size_add(size_t x, size_t y)
     71 {
     72     return (x == infinite_size || y == infinite_size) ? infinite_size : x + y;
     73 }
     74 
     75 CMT_INTRINSIC constexpr size_t size_sub(size_t x, size_t y)
     76 {
     77     return (x == infinite_size || y == infinite_size) ? infinite_size : (x > y ? x - y : 0);
     78 }
     79 
     80 CMT_INTRINSIC constexpr size_t size_min(size_t x) CMT_NOEXCEPT { return x; }
     81 
     82 template <typename... Ts>
     83 CMT_INTRINSIC constexpr size_t size_min(size_t x, size_t y, Ts... rest) CMT_NOEXCEPT
     84 {
     85     return size_min(x < y ? x : y, rest...);
     86 }
     87 
/// Per-dimension i8 mask of length maximum_dims; shape::tomask() fills it right-aligned.
using dimset = static_array_of_size<i8, maximum_dims>; // std::array<i8, maximum_dims>;

/// Forward declaration of the fixed-rank shape template (defined below).
template <index_t dims>
struct shape;

namespace internal_generic
{
// Forward declaration so shape<> can refer to it; the default for `dim` (the last axis)
// is supplied here, so the later definition must not repeat it.
template <index_t dims>
KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& start, const shape<dims>& stop,
                                     index_t dim = dims - 1);
} // namespace internal_generic
     99 
    100 template <index_t Dims>
    101 struct shape : static_array_base<index_t, csizeseq_t<Dims>>
    102 {
    103     static_assert(Dims <= 256, "Too many dimensions");
    104     using base = static_array_base<index_t, csizeseq_t<Dims>>;
    105 
    106     using base::base;
    107 
    108     constexpr shape(const base& a) : base(a) {}
    109 
    110     static_assert(Dims <= maximum_dims);
    111 
    112     static constexpr size_t dims() { return base::static_size; }
    113 
    114     template <int dummy = 0, KFR_ENABLE_IF(dummy == 0 && Dims == 1)>
    115     operator index_t() const
    116     {
    117         return this->front();
    118     }
    119 
    120     template <typename TI>
    121     static constexpr shape from_std_array(const std::array<TI, Dims>& a)
    122     {
    123         shape result;
    124         std::copy(a.begin(), a.end(), result.begin());
    125         return result;
    126     }
    127 
    128     template <typename TI = index_t>
    129     constexpr std::array<TI, Dims> to_std_array() const
    130     {
    131         std::array<TI, Dims> result{};
    132         std::copy(this->begin(), this->end(), result.begin());
    133         return result;
    134     }
    135 
    136     bool ge(const shape& other) const
    137     {
    138         if constexpr (Dims == 1)
    139         {
    140             return this->front() >= other.front();
    141         }
    142         else
    143         {
    144             return all(**this >= *other);
    145         }
    146     }
    147 
    148     index_t trailing_zeros() const
    149     {
    150         for (index_t i = 0; i < Dims; ++i)
    151         {
    152             if (revindex(i) != 0)
    153                 return i;
    154         }
    155         return Dims;
    156     }
    157 
    158     bool le(const shape& other) const
    159     {
    160         if constexpr (Dims == 1)
    161         {
    162             return this->front() <= other.front();
    163         }
    164         else
    165         {
    166             return all(**this <= *other);
    167         }
    168     }
    169 
    170     constexpr shape add(index_t value) const
    171     {
    172         shape result = *this;
    173         result.back() += value;
    174         return result;
    175     }
    176     template <index_t Axis>
    177     constexpr shape add_at(index_t value, cval_t<index_t, Axis> = {}) const
    178     {
    179         shape result = *this;
    180         result[Axis] += value;
    181         return result;
    182     }
    183     constexpr shape add(const shape& other) const { return **this + *other; }
    184     constexpr shape sub(const shape& other) const { return **this - *other; }
    185     constexpr index_t sum() const { return (*this)->sum(); }
    186 
    187     constexpr bool has_infinity() const
    188     {
    189         for (index_t i = 0; i < Dims; ++i)
    190         {
    191             if (CMT_UNLIKELY(this->operator[](i) == infinite_size))
    192                 return true;
    193         }
    194         return false;
    195     }
    196 
    197     friend constexpr shape add_shape(const shape& lhs, const shape& rhs)
    198     {
    199         return lhs.bin(rhs, [](index_t x, index_t y) { return std::max(std::max(x, y), x + y); });
    200     }
    201     friend constexpr shape sub_shape(const shape& lhs, const shape& rhs)
    202     {
    203         return lhs.bin(rhs, [](index_t x, index_t y)
    204                        { return std::max(x, y) == infinite_size ? infinite_size : x - y; });
    205     }
    206     friend constexpr shape add_shape_undef(const shape& lhs, const shape& rhs)
    207     {
    208         return lhs.bin(rhs,
    209                        [](index_t x, index_t y)
    210                        {
    211                            bool inf   = std::max(x, y) == infinite_size;
    212                            bool undef = std::min(x, y) == undefined_size;
    213                            return inf ? infinite_size : undef ? undefined_size : x + y;
    214                        });
    215     }
    216     friend constexpr shape sub_shape_undef(const shape& lhs, const shape& rhs)
    217     {
    218         return lhs.bin(rhs,
    219                        [](index_t x, index_t y)
    220                        {
    221                            bool inf   = std::max(x, y) == infinite_size;
    222                            bool undef = std::min(x, y) == undefined_size;
    223                            return inf ? infinite_size : undef ? undefined_size : x - y;
    224                        });
    225     }
    226 
    227     friend constexpr shape min(const shape& x, const shape& y) { return x->min(*y); }
    228 
    229     constexpr const base& operator*() const { return static_cast<const base&>(*this); }
    230 
    231     constexpr const base* operator->() const { return static_cast<const base*>(this); }
    232 
    233     KFR_MEM_INTRINSIC constexpr size_t to_flat(const shape<Dims>& indices) const
    234     {
    235         if constexpr (Dims == 1)
    236         {
    237             return indices[0];
    238         }
    239         else if constexpr (Dims == 2)
    240         {
    241             return (*this)[1] * indices[0] + indices[1];
    242         }
    243         else
    244         {
    245             size_t result = 0;
    246             size_t scale  = 1;
    247             CMT_LOOP_UNROLL
    248             for (size_t i = 0; i < Dims; ++i)
    249             {
    250                 result += scale * indices[Dims - 1 - i];
    251                 scale *= (*this)[Dims - 1 - i];
    252             }
    253             return result;
    254         }
    255     }
    256     KFR_MEM_INTRINSIC constexpr shape<Dims> from_flat(size_t index) const
    257     {
    258         if constexpr (Dims == 1)
    259         {
    260             return { static_cast<index_t>(index) };
    261         }
    262         else if constexpr (Dims == 2)
    263         {
    264             index_t sz = (*this)[1];
    265             return { static_cast<index_t>(index / sz), static_cast<index_t>(index % sz) };
    266         }
    267         else
    268         {
    269             shape<Dims> indices;
    270             CMT_LOOP_UNROLL
    271             for (size_t i = 0; i < Dims; ++i)
    272             {
    273                 size_t sz             = (*this)[Dims - 1 - i];
    274                 indices[Dims - 1 - i] = index % sz;
    275                 index /= sz;
    276             }
    277             return indices;
    278         }
    279     }
    280 
    281     KFR_MEM_INTRINSIC constexpr index_t dot(const shape& other) const { return (*this)->dot(*other); }
    282 
    283     template <index_t indims, bool stop = false>
    284     KFR_MEM_INTRINSIC constexpr shape adapt(const shape<indims>& other, cbool_t<stop> = {}) const
    285     {
    286         static_assert(indims >= Dims);
    287         if constexpr (stop)
    288             return other.template trim<Dims>()->min(**this);
    289         else
    290             return other.template trim<Dims>()->min(**this - 1);
    291     }
    292 
    293     KFR_MEM_INTRINSIC constexpr index_t product() const { return (*this)->product(); }
    294 
    295     KFR_MEM_INTRINSIC constexpr dimset tomask() const
    296     {
    297         dimset result(0);
    298         for (index_t i = 0; i < Dims; ++i)
    299         {
    300             result[i + maximum_dims - Dims] = this->operator[](i) == 1 ? 0 : -1;
    301         }
    302         return result;
    303     }
    304 
    305     template <index_t new_dims>
    306     constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const
    307     {
    308         static_assert(new_dims >= Dims);
    309         if constexpr (new_dims == Dims)
    310             return *this;
    311         else
    312             return shape<new_dims>{ shape<new_dims - Dims>(value), *this };
    313     }
    314 
    315     template <index_t odims>
    316     constexpr shape<odims> trim() const
    317     {
    318         static_assert(odims <= Dims);
    319         if constexpr (odims > 0)
    320         {
    321             return this->template slice<Dims - odims, odims>();
    322         }
    323         else
    324         {
    325             return {};
    326         }
    327     }
    328 
    329     // 0,1,2,3 -> 1,2,3,0
    330     constexpr KFR_MEM_INTRINSIC shape rotate_left() const
    331     {
    332         return this->shuffle(csizeseq<Dims, 1> % csize<Dims>);
    333     }
    334 
    335     // 0,1,2,3 -> 3,0,1,2
    336     constexpr KFR_MEM_INTRINSIC shape rotate_right() const
    337     {
    338         return this->shuffle(csizeseq<Dims, Dims - 1> % csize<Dims>);
    339     }
    340 
    341     constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_back() const
    342     {
    343         if constexpr (Dims > 1)
    344         {
    345             return this->template slice<0, Dims - 1>();
    346         }
    347         else
    348         {
    349             return {};
    350         }
    351     }
    352     constexpr KFR_MEM_INTRINSIC shape<Dims - 1> remove_front() const
    353     {
    354         if constexpr (Dims > 1)
    355         {
    356             return this->template slice<1, Dims - 1>();
    357         }
    358         else
    359         {
    360             return {};
    361         }
    362     }
    363 
    364     constexpr KFR_MEM_INTRINSIC shape<Dims - 1> trunc() const { return remove_back(); }
    365 
    366     KFR_MEM_INTRINSIC constexpr index_t revindex(size_t index) const
    367     {
    368         return index < Dims ? this->operator[](Dims - 1 - index) : 1;
    369     }
    370     KFR_MEM_INTRINSIC constexpr void set_revindex(size_t index, index_t val)
    371     {
    372         if (CMT_LIKELY(index < Dims))
    373             this->operator[](Dims - 1 - index) = val;
    374     }
    375 
    376     KFR_MEM_INTRINSIC constexpr shape transpose() const
    377     {
    378         return this->shuffle(csizeseq<Dims, Dims - 1, -1>);
    379     }
    380 };
    381 
    382 template <>
    383 struct shape<0>
    384 {
    385     static constexpr size_t static_size = 0;
    386 
    387     static constexpr size_t size() { return static_size; }
    388 
    389     static constexpr size_t dims() { return static_size; }
    390 
    391     constexpr shape() = default;
    392     constexpr shape(index_t value) {}
    393 
    394     constexpr bool has_infinity() const { return false; }
    395 
    396     KFR_MEM_INTRINSIC size_t to_flat(const shape<0>& indices) const { return 0; }
    397     KFR_MEM_INTRINSIC shape<0> from_flat(size_t index) const { return {}; }
    398 
    399     template <index_t odims, bool stop = false>
    400     KFR_MEM_INTRINSIC shape<0> adapt(const shape<odims>& other, cbool_t<stop> = {}) const
    401     {
    402         return {};
    403     }
    404 
    405     index_t trailing_zeros() const { return 0; }
    406 
    407     KFR_MEM_INTRINSIC index_t dot(const shape& other) const { return 0; }
    408 
    409     KFR_MEM_INTRINSIC index_t product() const { return 0; }
    410 
    411     KFR_MEM_INTRINSIC dimset tomask() const { return dimset(-1); }
    412 
    413     template <index_t new_dims>
    414     constexpr KFR_MEM_INTRINSIC shape<new_dims> extend(index_t value = infinite_size) const
    415     {
    416         if constexpr (new_dims == 0)
    417             return *this;
    418         else
    419             return shape<new_dims>{ value };
    420     }
    421 
    422     template <index_t new_dims>
    423     constexpr shape<new_dims> trim() const
    424     {
    425         static_assert(new_dims == 0);
    426         return {};
    427     }
    428 
    429     KFR_MEM_INTRINSIC constexpr bool operator==(const shape<0>& other) const { return true; }
    430     KFR_MEM_INTRINSIC constexpr bool operator!=(const shape<0>& other) const { return false; }
    431 
    432     KFR_MEM_INTRINSIC constexpr index_t revindex(size_t index) const { return 1; }
    433     KFR_MEM_INTRINSIC void set_revindex(size_t index, index_t val) {}
    434 };
    435 
/// Template argument that selects the runtime-rank specialization shape<dynamic_shape>
constexpr inline size_t dynamic_shape = std::numeric_limits<size_t>::max();
    437 
    438 template <>
    439 struct shape<dynamic_shape> : protected std::vector<index_t>
    440 {
    441     using std::vector<index_t>::vector;
    442 
    443     using std::vector<index_t>::begin;
    444     using std::vector<index_t>::end;
    445     using std::vector<index_t>::data;
    446     using std::vector<index_t>::size;
    447     using std::vector<index_t>::front;
    448     using std::vector<index_t>::back;
    449     using std::vector<index_t>::operator[];
    450 
    451     template <index_t Dims, CMT_ENABLE_IF(Dims != dynamic_shape)>
    452     shape(shape<Dims> sh) : shape(sh.begin(), sh.end())
    453     {
    454     }
    455 
    456     size_t dims() const { return size(); }
    457 
    458     KFR_MEM_INTRINSIC index_t product() const
    459     {
    460         if (std::vector<index_t>::empty())
    461             return 0;
    462         index_t p = this->front();
    463         for (size_t i = 1; i < size(); ++i)
    464         {
    465             p *= this->operator[](i);
    466         }
    467         return p;
    468     }
    469 
    470     // 0,1,2,3 -> 1,2,3,0
    471     KFR_MEM_INTRINSIC shape rotate_left() const
    472     {
    473         shape result = *this;
    474         if (result.size() > 1)
    475             std::rotate(result.begin(), result.begin() + 1, result.end());
    476         return result;
    477     }
    478 
    479     // 0,1,2,3 -> 3,0,1,2
    480     KFR_MEM_INTRINSIC shape rotate_right() const
    481     {
    482         shape result = *this;
    483         if (result.size() > 1)
    484             std::rotate(result.begin(), result.end() - 1, result.end());
    485         return result;
    486     }
    487 
    488     KFR_MEM_INTRINSIC shape remove_back() const
    489     {
    490         shape result = *this;
    491         if (!result.empty())
    492             result.erase(result.end() - 1);
    493         return result;
    494     }
    495     KFR_MEM_INTRINSIC shape remove_front() const
    496     {
    497         shape result = *this;
    498         if (!result.empty())
    499         {
    500             result.erase(result.begin());
    501         }
    502         return result;
    503     }
    504 };
    505 
/// Deduction guide: `shape{a, b, c}` deduces shape<3> (one dimension per argument)
template <typename... Args>
shape(Args&&... args) -> shape<sizeof...(Args)>;
    508 
    509 namespace internal_generic
    510 {
    511 
    512 template <index_t outdims, index_t indims>
    513 KFR_MEM_INTRINSIC shape<outdims> adapt(const shape<indims>& in, const dimset& set)
    514 {
    515     static_assert(indims >= outdims);
    516     if constexpr (outdims == 0)
    517     {
    518         return {};
    519     }
    520     else
    521     {
    522         const static_array_of_size<index_t, maximum_dims> eset = set.template cast<index_t>();
    523         return in->template slice<indims - outdims, outdims>() &
    524                eset.template slice<maximum_dims - outdims, outdims>();
    525     }
    526 }
    527 template <index_t outdims>
    528 KFR_MEM_INTRINSIC shape<outdims> adapt(const shape<0>& in, const dimset& set)
    529 {
    530     static_assert(outdims == 0);
    531     return {};
    532 }
    533 } // namespace internal_generic
    534 
/// Aggregate of three `Dims`-dimensional shapes: a `current` position plus `minimum` and
/// `maximum` shapes (presumably iteration bounds — confirm against users of this type).
template <size_t Dims>
struct cursor
{
    shape<Dims> current;
    shape<Dims> minimum;
    shape<Dims> maximum;
};
    542 
/// Optional signed index; std::nullopt means "not specified"
using opt_index_t = std::optional<signed_index_t>;

/// Slice specification for one axis: optional start, stop and step.
/// NOTE(review): interpretation of missing/negative values is up to the consumer — confirm there.
struct tensor_range
{
    opt_index_t start;
    opt_index_t stop;
    opt_index_t step;
};
    551 
    552 constexpr KFR_INTRINSIC tensor_range trange(std::optional<signed_index_t> start = std::nullopt,
    553                                             std::optional<signed_index_t> stop  = std::nullopt,
    554                                             std::optional<signed_index_t> step  = std::nullopt)
    555 {
    556     return { start, stop, step };
    557 }
    558 
    559 constexpr KFR_INTRINSIC tensor_range tall() { return trange(); }
    560 constexpr KFR_INTRINSIC tensor_range tstart(signed_index_t start, signed_index_t step = 1)
    561 {
    562     return trange(start, std::nullopt, step);
    563 }
    564 constexpr KFR_INTRINSIC tensor_range tstop(signed_index_t stop, signed_index_t step = 1)
    565 {
    566     return trange(std::nullopt, stop, step);
    567 }
    568 constexpr KFR_INTRINSIC tensor_range tstep(signed_index_t step = 1)
    569 {
    570     return trange(std::nullopt, std::nullopt, step);
    571 }
    572 
    573 namespace internal_generic
    574 {
    575 
/// Sentinel index value; increment_indices_return() fills its whole result with it once the
/// iteration is exhausted.
constexpr inline index_t null_index = max_index_t;
    577 
    578 template <index_t dims, bool fortran_order = false>
    579 constexpr KFR_INTRINSIC shape<dims> strides_for_shape(const shape<dims>& sh, index_t stride = 1)
    580 {
    581     shape<dims> strides;
    582     if constexpr (dims > 0)
    583     {
    584         index_t n = stride;
    585         for (index_t i = 0; i < dims; ++i)
    586         {
    587             strides[fortran_order ? i : dims - 1 - i] = n;
    588             n *= sh[fortran_order ? i : dims - 1 - i];
    589         }
    590     }
    591     return strides;
    592 }
    593 
    594 template <size_t dims, size_t outdims, bool... ranges>
    595 constexpr KFR_INTRINSIC shape<outdims> compact_shape(const shape<dims>& in)
    596 {
    597     shape<outdims> result;
    598     constexpr std::array flags{ ranges... };
    599     size_t j = 0;
    600     for (size_t i = 0; i < dims; ++i)
    601     {
    602         if (CMT_LIKELY(i >= flags.size() || flags[i]))
    603         {
    604             result[j++] = in[i];
    605         }
    606     }
    607     return result;
    608 }
    609 
    610 template <index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)>
    611 constexpr bool can_assign_from(const shape<dims1>& dst_shape, const shape<dims2>& src_shape)
    612 {
    613     if constexpr (dims2 == 0)
    614     {
    615         return true;
    616     }
    617     else
    618     {
    619         for (size_t i = 0; i < outdims; ++i)
    620         {
    621             index_t dst_size = dst_shape.revindex(i);
    622             index_t src_size = src_shape.revindex(i);
    623             if (CMT_LIKELY(src_size == 1 || src_size == infinite_size || src_size == dst_size ||
    624                            dst_size == infinite_size))
    625             {
    626             }
    627             else
    628             {
    629                 return false;
    630             }
    631         }
    632         return true;
    633     }
    634 }
    635 
    636 template <bool checked = false, index_t dims>
    637 constexpr shape<dims> common_shape(const shape<dims>& shape)
    638 {
    639     return shape;
    640 }
    641 
    642 template <bool checked = false, index_t dims1, index_t dims2, index_t outdims = const_max(dims1, dims2)>
    643 KFR_MEM_INTRINSIC constexpr shape<outdims> common_shape(const shape<dims1>& shape1,
    644                                                         const shape<dims2>& shape2)
    645 {
    646     shape<outdims> result;
    647     for (size_t i = 0; i < outdims; ++i)
    648     {
    649         index_t size1 = shape1.revindex(i);
    650         index_t size2 = shape2.revindex(i);
    651         if (CMT_UNLIKELY(!size1 || !size2))
    652         {
    653             result[outdims - 1 - i] = 0;
    654             continue;
    655         }
    656 
    657         if (CMT_UNLIKELY(size1 == infinite_size))
    658         {
    659             if (CMT_UNLIKELY(size2 == infinite_size))
    660             {
    661                 result[outdims - 1 - i] = infinite_size;
    662             }
    663             else
    664             {
    665                 result[outdims - 1 - i] = size2 == 1 ? infinite_size : size2;
    666             }
    667         }
    668         else
    669         {
    670             if (CMT_UNLIKELY(size2 == infinite_size))
    671             {
    672                 result[outdims - 1 - i] = size1 == 1 ? infinite_size : size1;
    673             }
    674             else
    675             {
    676                 if (CMT_LIKELY(size1 == 1 || size2 == 1 || size1 == size2))
    677                 {
    678                     result[outdims - 1 - i] = std::max(size1, size2);
    679                 }
    680                 else
    681                 {
    682                     // broadcast failed
    683                     if constexpr (checked)
    684                     {
    685                         KFR_LOGIC_CHECK(false, "invalid or incompatible shapes: ", shape1, " and ", shape2);
    686                     }
    687                     else
    688                     {
    689                         result = shape<outdims>(0);
    690                         return result;
    691                     }
    692                 }
    693             }
    694         }
    695     }
    696     return result;
    697 }
    698 
    699 template <bool checked = false>
    700 KFR_MEM_INTRINSIC constexpr shape<0> common_shape(const shape<0>& shape1, const shape<0>& shape2)
    701 {
    702     return {};
    703 }
    704 
    705 template <bool checked    = false, index_t dims1, index_t dims2, index_t... dims,
    706           index_t outdims = const_max(dims1, dims2, dims...)>
    707 KFR_MEM_INTRINSIC constexpr shape<outdims> common_shape(const shape<dims1>& shape1,
    708                                                         const shape<dims2>& shape2,
    709                                                         const shape<dims>&... shapes)
    710 {
    711     return common_shape<checked>(shape1, common_shape(shape2, shapes...));
    712 }
    713 
    714 template <index_t dims1, index_t dims2>
    715 KFR_MEM_INTRINSIC bool same_layout(const shape<dims1>& x, const shape<dims2>& y)
    716 {
    717     for (index_t i = 0, j = 0;;)
    718     {
    719         while (i < dims1 && x[i] == 1)
    720             ++i;
    721         while (j < dims2 && y[j] == 1)
    722             ++j;
    723         if (i == dims1 && j == dims2)
    724         {
    725             return true;
    726         }
    727         if (i < dims1 && j < dims2)
    728         {
    729             if (x[i] != y[j])
    730                 return false;
    731         }
    732         else
    733         {
    734             return false;
    735         }
    736         ++i;
    737         ++j;
    738     }
    739 }
    740 
    741 #ifdef KFR_VEC_INDICES
    742 template <size_t step, index_t dims>
    743 KFR_INTRINSIC vec<index_t, dims> increment_indices(vec<index_t, dims> indices,
    744                                                    const vec<index_t, dims>& start,
    745                                                    const vec<index_t, dims>& stop)
    746 {
    747     indices = indices + make_vector(cconcat(cvalseq<index_t, dims - 1 - step, 0, 0>, cvalseq<index_t, 1, 1>,
    748                                             cvalseq<index_t, step, 0, 0>));
    749 
    750     if constexpr (step + 1 < dims)
    751     {
    752         vec<bit<index_t>, dims> mask = indices >= stop;
    753         if (CMT_LIKELY(!any(mask)))
    754             return indices;
    755         indices = blend(indices, start, cconcat(csizeseq<dims - step - 1, 0, 0>, csizeseq<step + 1, 1, 0>));
    756 
    757         return increment_indices<step + 1>(indices, stop);
    758     }
    759     else
    760     {
    761         return indices;
    762     }
    763 }
    764 #endif
    765 
    766 template <index_t dims>
    767 KFR_INTRINSIC bool compare_indices(const shape<dims>& indices, const shape<dims>& stop,
    768                                    index_t dim = dims - 1)
    769 {
    770     CMT_LOOP_UNROLL
    771     for (int i = static_cast<int>(dim); i >= 0; --i)
    772     {
    773         if (CMT_UNLIKELY(indices[i] >= stop[i]))
    774             return false;
    775     }
    776     return true;
    777 }
    778 
    779 template <index_t dims>
    780 KFR_INTRINSIC bool increment_indices(shape<dims>& indices, const shape<dims>& start, const shape<dims>& stop,
    781                                      index_t dim)
    782 {
    783 #ifdef KFR_VEC_INDICES
    784     vec<index_t, dims> idx = increment_indices<0>(*indices, *start, *stop);
    785     indices                = idx;
    786     if (any(idx == *stop))
    787         return false;
    788     return true;
    789 #else
    790     if constexpr (dims > 0)
    791     {
    792         indices[dim] += 1;
    793         CMT_LOOP_UNROLL
    794         for (int i = static_cast<int>(dim); i >= 0;)
    795         {
    796             if (CMT_LIKELY(indices[i] < stop[i]))
    797                 return true;
    798             // carry
    799             indices[i] = start[i];
    800             --i;
    801             if (i < 0)
    802             {
    803                 return false;
    804             }
    805             indices[i] += 1;
    806         }
    807         return true;
    808     }
    809     else
    810     {
    811         return false;
    812     }
    813 #endif
    814 }
    815 
    816 template <index_t dims>
    817 KFR_INTRINSIC shape<dims> increment_indices_return(const shape<dims>& indices, const shape<dims>& start,
    818                                                    const shape<dims>& stop, index_t dim = dims - 1)
    819 {
    820     shape<dims> result = indices;
    821     if (CMT_LIKELY(increment_indices(result, start, stop, dim)))
    822     {
    823         return result;
    824     }
    825     else
    826     {
    827         return shape<dims>(null_index);
    828     }
    829 }
    830 
    831 template <typename... Index>
    832 constexpr KFR_INTRINSIC size_t count_dimensions()
    833 {
    834     return ((std::is_same_v<std::decay_t<Index>, tensor_range> ? 1 : 0) + ...);
    835 }
    836 
/// Innermost element type of a (possibly nested) std::initializer_list.
/// Non-list types map to themselves.
template <typename U>
struct type_of_list
{
    using value_type = U;
};

/// Strips one level of std::initializer_list and recurses.
template <typename U>
struct type_of_list<std::initializer_list<U>>
{
    using value_type = typename type_of_list<U>::value_type;
};

template <typename U>
using type_of_list_t = typename type_of_list<U>::value_type;
    851 
/// 1-D shape of a flat initializer list: its element count.
template <typename U>
constexpr KFR_INTRINSIC shape<1> shape_of_list(const std::initializer_list<U>& list)
{
    return list.size();
}
    857 
    858 template <typename U>
    859 constexpr KFR_INTRINSIC auto shape_of_list(const std::initializer_list<std::initializer_list<U>>& list)
    860 {
    861     return shape_of_list(*list.begin());
    862 }
    863 
    864 template <typename U>
    865 constexpr KFR_INTRINSIC U list_get(const std::initializer_list<U>& list, const shape<1>& idx)
    866 {
    867     return list[idx.front()];
    868 }
    869 
    870 template <typename U, index_t dims>
    871 constexpr KFR_INTRINSIC auto list_get(const std::initializer_list<std::initializer_list<U>>& list,
    872                                       const shape<dims>& idx)
    873 {
    874     return list_get(list[idx[0]], idx.template trim<dims - 1>());
    875 }
    876 
    877 template <typename U, typename T>
    878 KFR_FUNCTION T* list_copy_recursively(const std::initializer_list<U>& list, T* dest)
    879 {
    880     for (const auto& value : list)
    881         *dest++ = static_cast<T>(value);
    882     return dest;
    883 }
    884 
    885 template <typename U, typename T>
    886 KFR_FUNCTION T* list_copy_recursively(const std::initializer_list<std::initializer_list<U>>& list, T* dest)
    887 {
    888     for (const auto& sublist : list)
    889         dest = list_copy_recursively(sublist, dest);
    890     return dest;
    891 }
    892 
    893 } // namespace internal_generic
    894 
    895 template <index_t dims>
    896 constexpr KFR_INTRINSIC index_t size_of_shape(const shape<dims>& shape)
    897 {
    898     index_t n = 1;
    899     if constexpr (dims > 0)
    900     {
    901         for (index_t i = 0; i < dims; ++i)
    902         {
    903             n *= shape[i];
    904         }
    905     }
    906     return n;
    907 }
    908 
    909 template <index_t Axis, size_t N>
    910 struct axis_params
    911 {
    912     constexpr static index_t axis  = Axis;
    913     constexpr static index_t width = N;
    914     constexpr static index_t value = N;
    915 
    916     constexpr axis_params() = default;
    917 };
    918 
    919 template <index_t Axis, size_t N>
    920 constexpr inline const axis_params<Axis, N> axis_params_v{};
    921 
    922 } // namespace kfr
    923 
    924 namespace cometa
    925 {
    926 template <kfr::index_t dims>
    927 struct representation<kfr::shape<dims>>
    928 {
    929     using type = std::string;
    930     static std::string get(const kfr::shape<dims>& value)
    931     {
    932         if constexpr (dims == 0)
    933         {
    934             return "shape{}";
    935         }
    936         else
    937         {
    938             return "shape" + array_to_string(dims, value.data());
    939         }
    940     }
    941 };
    942 
    943 } // namespace cometa