kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

tensor.cpp (28749B)


      1 /**
      2  * KFR (https://www.kfrlib.com)
      3  * Copyright (C) 2016-2023 Dan Cazarin
      4  * See LICENSE.txt for details
      5  */
      6 
      7 #include <kfr/base/basic_expressions.hpp>
      8 #include <kfr/base/math_expressions.hpp>
      9 #include <kfr/base/npy.hpp>
     10 #include <kfr/base/reduce.hpp>
     11 #include <kfr/base/simd_expressions.hpp>
     12 #include <kfr/base/tensor.hpp>
     13 #include <kfr/io/tostring.hpp>
     14 #include <kfr/simd.hpp>
     15 
     16 CMT_PRAGMA_MSVC(warning(push))
     17 CMT_PRAGMA_MSVC(warning(disable : 5051))
     18 CMT_PRAGMA_MSVC(warning(disable : 4244))
     19 
     20 namespace kfr
     21 {
     22 
     23 inline namespace CMT_ARCH_NAME
     24 {
     25 
     26 TEST(tensor_base)
     27 {
     28     tensor<float, 2> t{ shape{ 20, 40 } };
     29     CHECK(t.shape() == shape{ 20, 40 });
     30     CHECK(t.strides() == shape{ 40, 1 });
     31     CHECK(t.size() == 800);
     32     CHECK(t.is_contiguous());
     33 
     34     t(0, 0) = 123;
     35     t(1, 1) = 456;
     36 
     37     tensor<float, 2> t2 = t(tstop(10), tstop(10));
     38     CHECK(t2.shape() == shape{ 10, 10 });
     39     CHECK(t2.strides() == shape{ 40, 1 });
     40     CHECK(t2.size() == 100);
     41     CHECK(!t2.is_contiguous());
     42 
     43     CHECK(t2(0, 0) == 123);
     44     CHECK(t2.data() == t.data());
     45     CHECK(t2.finalizer() == t.finalizer());
     46 
     47     tensor<float, 2> t3 = t(tstart(1), tstart(1));
     48     CHECK(t3.shape() == shape{ 19, 39 });
     49     CHECK(t3.strides() == shape{ 40, 1 });
     50     CHECK(t3.size() == 741);
     51     CHECK(!t3.is_contiguous());
     52 
     53     CHECK(t3(0, 0) == 456);
     54     CHECK(t3.data() == t.data() + 40 + 1);
     55     CHECK(t3.finalizer() == t.finalizer());
     56 }
     57 
     58 TEST(tensor_memory)
     59 {
     60     // reference
     61     std::vector<float> vector{ 1.23f };
     62     tensor<float, 1> t{ vector.data(), shape{ 1 }, nullptr };
     63     CHECK(t.shape() == shape{ 1 });
     64     CHECK(t(0) == 1.23f);
     65 
     66     // adapt
     67     std::vector<float> vector2{ 2.34f };
     68     tensor<float, 1> t2 = tensor_from_container(std::move(vector2));
     69     CHECK(t2.shape() == shape{ 1 });
     70     CHECK(t2(0) == 2.34f);
     71 
     72     struct Container
     73     {
     74         std::array<double, 1> arr;
     75         int* refs;
     76         double* data() { return arr.data(); }
     77         size_t size() const { return arr.size(); }
     78         using value_type = double;
     79         Container(std::array<double, 1> arr, int* refs) : arr(arr), refs(refs) {}
     80         Container() { ++*refs; }
     81         Container(Container&& p) : arr(p.arr), refs(p.refs) { ++*refs; }
     82         Container(const Container&) = delete;
     83         ~Container() { --*refs; }
     84     };
     85 
     86     int refs = 0;
     87     Container cont{ { 3.45 }, &refs };
     88     {
     89         tensor<double, 1> t3 = tensor_from_container(std::move(cont));
     90         CHECK(t3.shape() == shape{ 1 });
     91         CHECK(t3(0) == 3.45);
     92         CHECK(refs == 1);
     93     }
     94     CHECK(refs == 0);
     95 }
     96 
     97 TEST(tensor_expression)
     98 {
     99     tensor<float, 1> t1{ shape{ 32 }, 0.f };
    100     tensor<float, 1> t2{ shape{ 32 }, 100.f };
    101     tensor<float, 1> t3{ shape{ 32 }, 0.f };
    102 
    103     t1 = counter();
    104 
    105     CHECK(t1.size() == 32);
    106     CHECK(t1(0) == 0.f);
    107     CHECK(t1(1) == 1.f);
    108     CHECK(t1(31) == 31.f);
    109 
    110     t3 = t1 + t2;
    111 
    112     CHECK(t3.size() == 32);
    113     CHECK(t3(0) == 100.f);
    114     CHECK(t3(1) == 101.f);
    115     CHECK(t3(31) == 131.f);
    116 
    117     tensor<float, 2> t4{ shape{ 6, 6 }, 0.f };
    118 
    119     t4 = 1.f;
    120     CHECK(t4(0, 0) == 1.f);
    121     CHECK(t4(5, 5) == 1.f);
    122     CHECK(minof(t4) == 1);
    123     CHECK(maxof(t4) == 1);
    124     CHECK(sum(t4) == 36);
    125 
    126     t4(trange(2, 4), trange(2, 4)) = scalar(10);
    127 
    128     CHECK(t4(0, 0) == 1.f);
    129     CHECK(t4(1, 1) == 1.f);
    130     CHECK(t4(2, 2) == 10.f);
    131     CHECK(t4(3, 3) == 10.f);
    132     CHECK(t4(5, 5) == 1.f);
    133     CHECK(sum(t4) == 72);
    134 
    135     t4(trange(2, 4), trange(2, 4)) = 10 + counter(0, 2, 1);
    136 
    137     CHECK(t4(2, 2) == 10.f);
    138     CHECK(t4(2, 3) == 11.f);
    139     CHECK(t4(3, 2) == 12.f);
    140     CHECK(t4(3, 3) == 13.f);
    141     CHECK(sum(t4) == 78);
    142 }
    143 
    144 TEST(tensor_broadcast)
    145 {
    146     tensor<float, 2> t1{ shape{ 1, 5 }, { 1.f, 2.f, 3.f, 4.f, 5.f } };
    147     tensor<float, 2> t2{ shape{ 5, 1 }, { 10.f, 20.f, 30.f, 40.f, 50.f } };
    148     tensor<float, 1> t4{ shape{ 5 }, { 1.f, 2.f, 3.f, 4.f, 5.f } };
    149     tensor<float, 2> tresult{ shape{ 5, 5 }, { 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 31, 32, 33,
    150                                                34, 35, 41, 42, 43, 44, 45, 51, 52, 53, 54, 55 } };
    151 
    152     tensor<float, 2> t3 = t1 + t2;
    153 
    154     CHECK(t3.shape() == shape{ 5, 5 });
    155     CHECK(t3 == tresult);
    156 
    157     tensor<float, 2> t5 = t4 + t2;
    158     // tensor<float, 2> t5 = t4 + t2;
    159     CHECK(t5 == tresult);
    160 }
    161 } // namespace CMT_ARCH_NAME
    162 
    163 template <typename T, size_t N1>
    164 struct expression_traits<std::array<T, N1>> : expression_traits_defaults
    165 {
    166     using value_type             = T;
    167     constexpr static size_t dims = 1;
    168 
    169     constexpr static shape<1> get_shape(const std::array<T, N1>& self) { return shape<1>{ N1 }; }
    170     constexpr static shape<1> get_shape() { return shape<1>{ N1 }; }
    171 };
    172 
    173 template <typename T, size_t N1, size_t N2>
    174 struct expression_traits<std::array<std::array<T, N1>, N2>> : expression_traits_defaults
    175 {
    176     using value_type             = T;
    177     constexpr static size_t dims = 2;
    178 
    179     constexpr static shape<2> get_shape(const std::array<std::array<T, N1>, N2>& self)
    180     {
    181         return shape<2>{ N2, N1 };
    182     }
    183     constexpr static shape<2> get_shape() { return shape<2>{ N2, N1 }; }
    184 };
    185 
    186 inline namespace CMT_ARCH_NAME
    187 {
    188 
    189 template <typename T, size_t N1, index_t Axis, size_t N>
    190 KFR_INTRINSIC vec<T, N> get_elements(const std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index,
    191                                      const axis_params<Axis, N>&)
    192 {
    193     const T* CMT_RESTRICT const data = self.data();
    194     return read<N>(data + index[0]);
    195 }
    196 
    197 template <typename T, size_t N1, index_t Axis, size_t N>
    198 KFR_INTRINSIC void set_elements(std::array<T, N1>& CMT_RESTRICT self, const shape<1>& index,
    199                                 const axis_params<Axis, N>&, const identity<vec<T, N>>& val)
    200 {
    201     T* CMT_RESTRICT const data = self.data();
    202     write(data + index[0], val);
    203 }
    204 
    205 template <typename T, size_t N1, size_t N2, index_t Axis, size_t N>
    206 KFR_INTRINSIC vec<T, N> get_elements(const std::array<std::array<T, N1>, N2>& CMT_RESTRICT self,
    207                                      const shape<2>& index, const axis_params<Axis, N>&)
    208 {
    209     const T* CMT_RESTRICT const data = self.front().data() + index.front() * N1 + index.back();
    210     if constexpr (Axis == 1)
    211     {
    212         return read<N>(data);
    213     }
    214     else
    215     {
    216         return gather_stride<N>(data, N1);
    217     }
    218 }
    219 
    220 template <typename T, size_t N1, size_t N2, index_t Axis, size_t N>
    221 KFR_INTRINSIC void set_elements(std::array<std::array<T, N1>, N2>& CMT_RESTRICT self, const shape<2>& index,
    222                                 const axis_params<Axis, N>&, const identity<vec<T, N>>& val)
    223 {
    224     T* CMT_RESTRICT data = self.front().data() + index.front() * N1 + index.back();
    225     if constexpr (Axis == 1)
    226     {
    227         write(data, val);
    228     }
    229     else
    230     {
    231         scatter_stride(data, val, N1);
    232     }
    233 }
    234 
    235 TEST(tensor_slice)
    236 {
    237     tensor<double, 3> t1{
    238         { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } },
    239         { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } },
    240         { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } },
    241     };
    242     CHECK(t1 == trender(truncate(counter(0.0, 9, 3, 1), shape{ 3, 3, 3 })));
    243 
    244     CHECK(trender(slice(t1, shape{ 1, 1, 1 }, shape{ 2, 2, 2 })) ==
    245           trender(truncate(counter(13.0, 9, 3, 1), shape{ 2, 2, 2 })));
    246 }
    247 
    248 TEST(scalars)
    249 {
    250     CHECK(trender(scalar(3)) == tensor<int, 0>{});
    251     CHECK(trender(scalar(3)).to_string() == "3");
    252 }
    253 
    254 TEST(tensor_lambda)
    255 {
    256     CHECK(trender(lambda<float, 2>([](shape<2> idx) -> float { return 1 + idx[1] + 3 * idx[0]; }),
    257                   shape{ 3, 3 }) ==
    258           tensor<float, 2>{
    259               { 1, 2, 3 },
    260               { 4, 5, 6 },
    261               { 7, 8, 9 },
    262           });
    263 
    264     CHECK(trender(truncate(lambda<float, 2>([](shape<2> idx) -> float { return 1 + idx[1] + 3 * idx[0]; }),
    265                            shape{ 3, 3 })) ==
    266           tensor<float, 2>{
    267               { 1, 2, 3 },
    268               { 4, 5, 6 },
    269               { 7, 8, 9 },
    270           });
    271 }
    272 
    273 TEST(tensor_expressions2)
    274 {
    275     auto aa = std::array<std::array<double, 2>, 2>{ { { { 1, 2 } }, { { 3, 4 } } } };
    276     static_assert(expression_traits<decltype(aa)>::dims == 2);
    277     CHECK(expression_traits<decltype(aa)>::get_shape(aa) == shape{ 2, 2 });
    278     CHECK(get_elements(aa, { 1, 1 }, axis_params<1, 1>{}) == vec{ 4. });
    279     CHECK(get_elements(aa, { 1, 0 }, axis_params<1, 2>{}) == vec{ 3., 4. });
    280 
    281     static_assert(expression_traits<decltype(1234.f)>::dims == 0);
    282     CHECK(expression_traits<decltype(1234.f)>::get_shape(1234.f) == shape{});
    283     CHECK(get_elements(1234.f, {}, axis_params<0, 3>{}) == vec{ 1234.f, 1234.f, 1234.f });
    284 
    285     process(aa, 123.45f);
    286 
    287     CHECK(aa ==
    288           std::array<std::array<double, 2>, 2>{ { { { 123.45f, 123.45f } }, { { 123.45f, 123.45f } } } });
    289 
    290     auto a = std::array<double, 2>{ { -5.f, +5.f } };
    291 
    292     process(aa, a);
    293 
    294     CHECK(aa == std::array<std::array<double, 2>, 2>{ { { { -5., +5. } }, { { -5., +5. } } } });
    295 }
    296 
    297 TEST(tensor_counter)
    298 {
    299     std::array<double, 6> x;
    300 
    301     process(x, counter(0.0, 0.5));
    302 
    303     CHECK(x == std::array<double, 6>{ { 0.0, 0.5, 1.0, 1.5, 2.0, 2.5 } });
    304 
    305     std::array<std::array<double, 4>, 3> y;
    306 
    307     process(y, counter(100.0, 1.0, 10.0));
    308 
    309     CHECK(y == std::array<std::array<double, 4>, 3>{ {
    310                    { { 100.0, 110.0, 120.0, 130.0 } },
    311                    { { 101.0, 111.0, 121.0, 131.0 } },
    312                    { { 102.0, 112.0, 122.0, 132.0 } },
    313                } });
    314 }
    315 namespace tests
    316 {
    317 TEST(tensor_dims)
    318 {
    319     tensor<double, 6> t12{ shape{ 2, 3, 4, 5, 6, 7 } };
    320 
    321     process(t12, counter(0, 1, 10, 100, 1000, 10000, 100000));
    322 
    323     auto t1 = t12(1, 2, 3, tall(), 5, 6);
    324     CHECK(render(t1) == univector<double>{ 650321, 651321, 652321, 653321, 654321 });
    325 
    326     CHECK(t12.reduce(std::plus<>{}, 0) == 1648888920);
    327 }
    328 } // namespace tests
    329 
    330 TEST(vec_from_cvals)
    331 {
    332     CHECK(make_vector(csizes<1, 2, 3, 4>) == make_vector<size_t>(1, 2, 3, 4));
    333     CHECK(make_vector(cconcat(cvalseq<index_t, 2, 0, 0>, cvalseq<index_t, 1, 1>,
    334                               cvalseq<index_t, 2, 0, 0>)) == make_vector<size_t>(0, 0, 1, 0, 0));
    335 }
    336 
    337 TEST(xfunction_test)
    338 {
    339     auto f = expression_function{ expression_with_arguments{ 3.f, 4.f }, std::plus<>{} };
    340     float v;
    341     process(v, f);
    342     CHECK(v == 7.f);
    343     static_assert(std::is_same_v<decltype(f), expression_function<std::plus<>, float, float>>);
    344 
    345     auto f2 = expression_function{ expression_with_arguments{ 10.f, std::array{ 1.f, 2.f, 3.f, 4.f, 5.f } },
    346                                    std::plus<>{} };
    347     std::array<float, 5> v2;
    348     process(v2, f2);
    349     CHECK(v2 == std::array{ 11.f, 12.f, 13.f, 14.f, 15.f });
    350 
    351     auto f3 = scalar(10.f) + std::array{ 1.f, 2.f, 3.f, 4.f, 5.f };
    352     std::array<float, 5> v3;
    353     process(v3, f3);
    354     CHECK(v3 == std::array{ 11.f, 12.f, 13.f, 14.f, 15.f });
    355 
    356     auto f4 = scalar(0) +
    357               std::array<std::array<float, 1>, 5>{
    358                   { { { 100.f } }, { { 200.f } }, { { 300.f } }, { { 400.f } }, { { 500.f } } }
    359               } +
    360               std::array{ 1.f, 2.f, 3.f, 4.f, 5.f };
    361     std::array<std::array<float, 5>, 5> v4;
    362 
    363     CHECK(expression_traits<decltype(f4)>::get_shape(f4) == shape{ 5, 5 });
    364     process(v4, f4);
    365     CHECK(v4 == std::array<std::array<float, 5>, 5>{ { { { 101.f, 102.f, 103.f, 104.f, 105.f } },
    366                                                        { { 201.f, 202.f, 203.f, 204.f, 205.f } },
    367                                                        { { 301.f, 302.f, 303.f, 304.f, 305.f } },
    368                                                        { { 401.f, 402.f, 403.f, 404.f, 405.f } },
    369                                                        { { 501.f, 502.f, 503.f, 504.f, 505.f } } } });
    370 }
    371 
    372 TEST(xfunction_test2)
    373 {
    374     CHECK(trender(abs(tensor<float, 2>{ { 1, -2 }, { -1, 3 } })) == tensor<float, 2>{ { 1, 2 }, { 1, 3 } });
    375     CHECK(trender(min(tensor<float, 2>{ { 1, -2 }, { -1, 3 } }, tensor<float, 2>{ { 0, 3 }, { 2, 1 } })) ==
    376           tensor<float, 2>{ { 0, -2 }, { -1, 1 } });
    377 }
    378 
    379 template <typename Type, index_t Dims>
    380 KFR_FUNCTION expression_counter<Type, Dims> debug_counter(uint64_t scale = 10)
    381 {
    382     expression_counter<Type, Dims> result;
    383     result.start = 0;
    384     uint64_t val = 1;
    385     for (size_t i = 0; i < Dims; i++)
    386     {
    387         result.steps[Dims - 1 - i] = val;
    388         val *= scale;
    389     }
    390     return result;
    391 }
    392 
    393 static std::string nl = R"(
    394 )";
    395 
    396 TEST(tensor_tostring)
    397 {
    398     CHECK(as_string(shape{}) == "shape{}");
    399     CHECK(as_string(shape{ 1, 2, 3 }) == "shape{1, 2, 3}");
    400 
    401     tensor<f32x2, 1> t0(shape<1>{ 3 });
    402     t0(0) = vec{ 1, 2 };
    403     t0(1) = vec{ 3, 4 };
    404     t0(2) = vec{ -1, 1000 };
    405     CHECK(t0.to_string<fmt_t<f32x2, 'f', 0, 0>>() == "{{1, 2}, {3, 4}, {-1, 1000}}");
    406 
    407     tensor<float, 1> t1(shape<1>{ 60 });
    408     t1 = debug_counter<float, 1>();
    409     CHECK(nl + t1.to_string<fmt_t<float, 'f', 2, 0>>(12, 0) + nl == R"(
    410 { 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
    411  12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
    412  24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
    413  36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
    414  48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59}
    415 )");
    416 
    417     tensor<float, 2> t2(shape<2>{ 12, 5 });
    418     t2 = debug_counter<float, 2>();
    419     CHECK(nl + t2.to_string<fmt_t<float, 'f', 3, 0>>(16, 0) + nl == R"(
    420 {{  0,   1,   2,   3,   4},
    421  { 10,  11,  12,  13,  14},
    422  { 20,  21,  22,  23,  24},
    423  { 30,  31,  32,  33,  34},
    424  { 40,  41,  42,  43,  44},
    425  { 50,  51,  52,  53,  54},
    426  { 60,  61,  62,  63,  64},
    427  { 70,  71,  72,  73,  74},
    428  { 80,  81,  82,  83,  84},
    429  { 90,  91,  92,  93,  94},
    430  {100, 101, 102, 103, 104},
    431  {110, 111, 112, 113, 114}}
    432 )");
    433 
    434     tensor<float, 3> t3(shape<3>{ 3, 4, 5 });
    435     t3 = debug_counter<float, 3>();
    436     CHECK(nl + t3.to_string<fmt_t<float, 'f', 4, 0>>(16, 0) + nl == R"(
    437 {{{   0,    1,    2,    3,    4},
    438   {  10,   11,   12,   13,   14},
    439   {  20,   21,   22,   23,   24},
    440   {  30,   31,   32,   33,   34}},
    441  {{ 100,  101,  102,  103,  104},
    442   { 110,  111,  112,  113,  114},
    443   { 120,  121,  122,  123,  124},
    444   { 130,  131,  132,  133,  134}},
    445  {{ 200,  201,  202,  203,  204},
    446   { 210,  211,  212,  213,  214},
    447   { 220,  221,  222,  223,  224},
    448   { 230,  231,  232,  233,  234}}}
    449 )");
    450 
    451     tensor<float, 4> t4(shape<4>{ 3, 2, 2, 5 });
    452     t4 = debug_counter<float, 4>();
    453     CHECK(nl + t4.to_string<fmt_t<float, 'f', 5, 0>>(16, 0) + nl == R"(
    454 {{{{    0,     1,     2,     3,     4},
    455    {   10,    11,    12,    13,    14}},
    456   {{  100,   101,   102,   103,   104},
    457    {  110,   111,   112,   113,   114}}},
    458  {{{ 1000,  1001,  1002,  1003,  1004},
    459    { 1010,  1011,  1012,  1013,  1014}},
    460   {{ 1100,  1101,  1102,  1103,  1104},
    461    { 1110,  1111,  1112,  1113,  1114}}},
    462  {{{ 2000,  2001,  2002,  2003,  2004},
    463    { 2010,  2011,  2012,  2013,  2014}},
    464   {{ 2100,  2101,  2102,  2103,  2104},
    465    { 2110,  2111,  2112,  2113,  2114}}}}
    466 )");
    467 
    468     tensor<float, 2> t5(shape<2>{ 10, 1 });
    469     t5 = debug_counter<float, 2>();
    470     CHECK(nl + t5.to_string<fmt_t<float, 'f', -1, 0>>(12, 1) + nl == R"(
    471 {{0}, {10}, {20}, {30}, {40}, {50}, {60}, {70}, {80}, {90}}
    472 )");
    473 }
    474 
    475 template <typename T, index_t dims1, index_t dims2>
    476 static void test_reshape_body(const tensor<T, dims1>& t1, const tensor<T, dims2>& t2)
    477 {
    478     CHECK(t1.reshape_may_copy(t2.shape(), true) == t2);
    479 
    480     cforeach(csizeseq<dims2>,
    481              [&](auto x)
    482              {
    483                  constexpr index_t axis = val_of(decltype(x)());
    484                  ::testo::scope s(
    485                      as_string("axis = ", axis, " shape = (", t1.shape(), ") -> (", t2.shape(), ")"));
    486                  CHECK(trender<1, axis>(reshape(t1, t2.shape())) == t2);
    487                  CHECK(trender<2, axis>(reshape(t1, t2.shape())) == t2);
    488                  CHECK(trender<4, axis>(reshape(t1, t2.shape())) == t2);
    489                  CHECK(trender<8, axis>(reshape(t1, t2.shape())) == t2);
    490              });
    491 }
    492 
    493 static void test_reshape() {}
    494 
    495 template <typename T, index_t dims1, index_t... dims>
    496 static void test_reshape(const tensor<T, dims1>& t1, const tensor<T, dims>&... ts)
    497 {
    498     cforeach(std::make_tuple((&ts)...),
    499              [&](auto t2)
    500              {
    501                  test_reshape_body(t1, *t2);
    502                  test_reshape_body(*t2, t1);
    503              });
    504 
    505     test_reshape(ts...);
    506 }
    507 
    508 TEST(expression_reshape)
    509 {
    510     std::array<float, 12> x;
    511     process(reshape(x, shape{ 3, 4 }), expression_counter<float, 2>{ 0, { 10, 1 } });
    512     CHECK(x == std::array<float, 12>{ { 0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23 } });
    513 
    514     test_reshape(tensor<float, 1>{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }, //
    515                  tensor<float, 2>{ { 0, 1, 2, 3, 4, 5 }, { 6, 7, 8, 9, 10, 11 } },
    516                  tensor<float, 2>{ { 0, 1, 2, 3 }, { 4, 5, 6, 7 }, { 8, 9, 10, 11 } },
    517                  tensor<float, 2>{ { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 }, { 9, 10, 11 } },
    518                  tensor<float, 2>{ { 0, 1 }, { 2, 3 }, { 4, 5 }, { 6, 7 }, { 8, 9 }, { 10, 11 } },
    519                  tensor<float, 2>{
    520                      { 0 }, { 1 }, { 2 }, { 3 }, { 4 }, { 5 }, { 6 }, { 7 }, { 8 }, { 9 }, { 10 }, { 11 } });
    521 
    522     test_reshape(tensor<float, 1>{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11 }, //
    523                  tensor<float, 3>{ { { 0, 1 }, { 2, 3 }, { 4, 5 } }, { { 6, 7 }, { 8, 9 }, { 10, 11 } } },
    524                  tensor<float, 4>{ { { { 0 }, { 1 } }, { { 2 }, { 3 } }, { { 4 }, { 5 } } },
    525                                    { { { 6 }, { 7 } }, { { 8 }, { 9 } }, { { 10 }, { 11 } } } });
    526 
    527     test_reshape(
    528         tensor<float, 1>{
    529             0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23 }, //
    530         tensor<float, 3>{ { { 0, 1 }, { 2, 3 }, { 4, 5 } },
    531                           { { 6, 7 }, { 8, 9 }, { 10, 11 } },
    532                           { { 12, 13 }, { 14, 15 }, { 16, 17 } },
    533                           { { 18, 19 }, { 20, 21 }, { 22, 23 } } },
    534         tensor<float, 4>{
    535             { { { 0, 1 }, { 2, 3 } }, { { 4, 5 }, { 6, 7 } }, { { 8, 9 }, { 10, 11 } } },
    536             { { { 12, 13 }, { 14, 15 } }, { { 16, 17 }, { 18, 19 } }, { { 20, 21 }, { 22, 23 } } } });
    537 }
    538 
    539 } // namespace CMT_ARCH_NAME
    540 
// Disabled code: everything up to the matching #endif is excluded from the
// build by `#if 0`. It consists of small extern "C" wrappers, each exporting a
// single library operation. `__declspec(dllexport)` is a Microsoft-specific
// attribute, so the block would not build unchanged on other compilers.
// NOTE(review): the `assembly_test*` names suggest these exist to inspect the
// generated machine code for each operation - confirm before removing.
#if 0
shape<4> sh{ 2, 3, 4, 5 };

extern "C" __declspec(dllexport) bool assembly_test1(shape<4>& x)
{
    return kfr::internal_generic::increment_indices(x, shape<4>(0), sh);
}

extern "C" __declspec(dllexport) bool assembly_test2(std::array<std::array<double, 2>, 2>& aa,
                                                     std::array<double, 2>& a)
{
    return process(aa, a).front() > 0;
}

extern "C" __declspec(dllexport) bool assembly_test3(std::array<double, 16>& x)
{
    return process(x, 0.5).front() > 0;
}

extern "C" __declspec(dllexport) bool assembly_test4(std::array<double, 16>& x)
{
    return process(x, counter(1000.0, 1.0)).front() > 0;
}

extern "C" __declspec(dllexport) bool assembly_test5(const tensor<double, 3>& x)
{
    return process(x, counter(1000.0, 1.0, 2.0, 3.0)).front() > 0;
}

extern "C" __declspec(dllexport) bool assembly_test6(const tensor<double, 2>& x)
{
    return process(x, counter(1000.0, 1.0, 2.0)).front() > 0;
}

extern "C" __declspec(dllexport) bool assembly_test7(const tensor<double, 2>& x)
{
    return process(x, 12345.).front() > 0;
}

extern "C" __declspec(dllexport) index_t assembly_test8_2(const shape<2>& x, const shape<2>& y)
{
    return x.dot(y);
}

extern "C" __declspec(dllexport) index_t assembly_test8_4(const shape<4>& x, const shape<4>& y)
{
    return x.dot(y);
}

extern "C" __declspec(dllexport) void assembly_test9(int64_t* dst, size_t stride)
{
    scatter_stride(dst, enumerate(vec_shape<int64_t, 8>()), stride);
}
constexpr inline index_t rank = 1;
extern "C" __declspec(dllexport) void assembly_test10(tensor<double, rank>& t12,
                                                      const expression_counter<double, rank>& ctr)
{
    process(t12, ctr);
}
extern "C" __declspec(dllexport) void assembly_test11(f64x2& x, u64x2 y) { x = y; }

extern "C" __declspec(dllexport) void assembly_test12(
    std::array<std::array<uint32_t, 4>, 4>& x,
    const expression_function<std::plus<>, std::array<std::array<uint32_t, 1>, 4>&,
                              std::array<std::array<uint32_t, 4>, 1>&>& y)
{
    process(x, y);
}

extern "C" __declspec(dllexport) void assembly_test13(const tensor<float, 1>& x, const tensor<float, 1>& y)
{
    process(x, y * 0.5f);
}

// Two-dimensional std::array alias with the row count given first.
template <typename T, size_t N1, size_t N2>
using array2d = std::array<std::array<T, N2>, N1>;

extern "C" __declspec(dllexport) void assembly_test14(std::array<float, 32>& x,
                                                      const std::array<float, 32>& y)
{
    process(x, reverse(y));
}

extern "C" __declspec(dllexport) void assembly_test15(array2d<float, 32, 32>& x,
                                                      const array2d<float, 32, 32>& y)
{
    process(x, reverse(y));
}

// The a/b variants below differ only in the explicit template arguments that
// are passed to process (<8, 0> versus <2, 1>).
extern "C" __declspec(dllexport) void assembly_test16a(array2d<double, 8, 2>& x,
                                                       const array2d<double, 8, 2>& y)
{
    process<8, 0>(x, y * y);
}
extern "C" __declspec(dllexport) void assembly_test16b(array2d<double, 8, 2>& x,
                                                       const array2d<double, 8, 2>& y)
{
    process<2, 1>(x, y * y);
}

extern "C" __declspec(dllexport) void assembly_test17a(const tensor<double, 2>& x, const tensor<double, 2>& y)
{
    expression_function ysqr = expression_function{ expression_with_arguments{ y }, fn::sqr{} };
    process<8, 0>(x, ysqr);
}
extern "C" __declspec(dllexport) void assembly_test17b(const tensor<double, 2>& x, const tensor<double, 2>& y)
{
    expression_function ysqr = expression_function{ expression_with_arguments{ y }, fn::sqr{} };
    process<2, 1>(x, ysqr);
}

extern "C" __declspec(dllexport) void assembly_test18a(const tensor<double, 2>& x, const tensor<double, 2>& y)
{
    expression_function ysqr = expression_function{ expression_with_arguments{ y }, fn::sqr{} };
    process<8, 0>(fixshape(x, fixed_shape<8, 2>), fixshape(ysqr, fixed_shape<8, 2>));
}
extern "C" __declspec(dllexport) void assembly_test18b(const tensor<double, 2>& x, const tensor<double, 2>& y)
{
    expression_function ysqr = expression_function{ expression_with_arguments{ y }, fn::sqr{} };
    process<2, 1>(fixshape(x, fixed_shape<8, 2>), fixshape(ysqr, fixed_shape<8, 2>));
}

extern "C" __declspec(dllexport) void assembly_test19(const tensor<double, 2>& x,
                                                      const expression_reshape<tensor<double, 2>, 2>& y)
{
    process(x, y);
}

extern "C" __declspec(dllexport) shape<2> assembly_test20_2(const shape<2>& x, size_t fl)
{
    return x.from_flat(fl);
}
extern "C" __declspec(dllexport) shape<4> assembly_test20_4(const shape<4>& x, size_t fl)
{
    return x.from_flat(fl);
}

// Same body as assembly_test20_4 above.
extern "C" __declspec(dllexport) shape<4> assembly_test21(const shape<4>& x, size_t fl)
{
    return x.from_flat(fl);
}
extern "C" __declspec(dllexport) float assembly_test22(const std::array<float, 440>& x,
                                                       const std::array<float, 440>& y)
{
    return dotproduct(x, y);
}
extern "C" __declspec(dllexport) float assembly_test23(const std::array<float, 440>& x) { return rms(x); }
#endif
    689 
    690 struct val
    691 {
    692 };
    693 template <>
    694 struct expression_traits<val> : expression_traits_defaults
    695 {
    696     using value_type             = float;
    697     constexpr static size_t dims = 0;
    698     constexpr static shape<dims> get_shape(const val&) { return {}; }
    699     constexpr static shape<dims> get_shape() { return {}; }
    700 };
    701 
    702 inline namespace CMT_ARCH_NAME
    703 {
    704 val rvint_func() { return val{}; }
    705 val& lvint_func()
    706 {
    707     static val v;
    708     return v;
    709 }
    710 TEST(expression_with_arguments)
    711 {
    712     expression_function fn1 = expression_function{ expression_with_arguments{ rvint_func() }, fn::add{} };
    713     static_assert(std::is_same_v<decltype(fn1)::nth<0>, val>);
    714 
    715     expression_function fn2 = expression_function{ expression_with_arguments{ lvint_func() }, fn::add{} };
    716     static_assert(std::is_same_v<decltype(fn2)::nth<0>, val&>);
    717 
    718     expression_function fn3 =
    719         expression_function{ expression_with_arguments{ std::as_const(lvint_func()) }, fn::add{} };
    720     static_assert(std::is_same_v<decltype(fn3)::nth<0>, const val&>);
    721 }
    722 
    723 TEST(slices)
    724 {
    725     const auto _ = std::nullopt;
    726     tensor<float, 1> t1{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
    727     CHECK(t1(tstart(3)) == tensor<float, 1>{ 3, 4, 5, 6, 7, 8, 9 });
    728     CHECK(t1(tstop(3)) == tensor<float, 1>{ 0, 1, 2 });
    729     CHECK(t1(trange(3, 7)) == tensor<float, 1>{ 3, 4, 5, 6 });
    730 
    731     CHECK(t1(tstart(10)) == tensor<float, 1>{});
    732     CHECK(t1(tstop(0)) == tensor<float, 1>{});
    733     CHECK(t1(trange(7, 3)) == tensor<float, 1>{});
    734 
    735     CHECK(t1(tstart(-2)) == tensor<float, 1>{ 8, 9 });
    736     CHECK(t1(trange(-7, -4)) == tensor<float, 1>{ 3, 4, 5 });
    737     CHECK(t1(tall()) == tensor<float, 1>{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
    738 
    739     CHECK(t1(trange(3, _)) == tensor<float, 1>{ 3, 4, 5, 6, 7, 8, 9 });
    740     CHECK(t1(trange(_, 7)) == tensor<float, 1>{ 0, 1, 2, 3, 4, 5, 6 });
    741 
    742     CHECK(t1(trange(_, _, 2)) == tensor<float, 1>{ 0, 2, 4, 6, 8 });
    743     CHECK(t1(trange(_, _, 5)) == tensor<float, 1>{ 0, 5 });
    744     CHECK(t1(trange(_, _, 12)) == tensor<float, 1>{ 0 });
    745     CHECK(t1(trange(1, _, 2)) == tensor<float, 1>{ 1, 3, 5, 7, 9 });
    746     CHECK(t1(trange(1, _, 5)) == tensor<float, 1>{ 1, 6 });
    747     CHECK(t1(trange(1, _, 12)) == tensor<float, 1>{ 1 });
    748 
    749     CHECK(t1(tstep(2))(tstep(2)) == tensor<float, 1>{ 0, 4, 8 });
    750     CHECK(t1(tstep(2))(tstep(2))(tstep(2)) == tensor<float, 1>{ 0, 8 });
    751     CHECK(t1(tstep(2))(tstep(3)) == tensor<float, 1>{ 0, 6 });
    752 
    753     CHECK(t1(trange(_, _, -1)) == tensor<float, 1>{ 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 });
    754     CHECK(t1(trange(5, _, -1)) == tensor<float, 1>{ 5, 4, 3, 2, 1, 0 });
    755     CHECK(t1(trange(1, 0, -1)) == tensor<float, 1>{ 1 });
    756 
    757     CHECK(t1(trange(3, 3 + 12, 0)) == tensor<float, 1>{ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 });
    758 }
    759 
    760 TEST(complex_tensors)
    761 {
    762     tensor<complex<float>, 1> t1{
    763         complex<float>(0, -1),
    764     };
    765     CHECK(trender(expression_function{ expression_with_arguments{ t1, complex<float>(0, 1) }, fn::mul{} }) ==
    766           tensor<complex<float>, 1>{ complex<float>(1, 0) });
    767     CHECK(trender(expression_function{ expression_with_arguments{ t1, complex<float>(1, 0) }, fn::mul{} }) ==
    768           tensor<complex<float>, 1>{ complex<float>(0, -1) });
    769     CHECK(trender(expression_function{ expression_with_arguments{ t1, complex<float>(0, -1) }, fn::mul{} }) ==
    770           tensor<complex<float>, 1>{ complex<float>(-1, 0) });
    771     CHECK(trender(expression_function{ expression_with_arguments{ t1, complex<float>(-1, 0) }, fn::mul{} }) ==
    772           tensor<complex<float>, 1>{ complex<float>(0, 1) });
    773 }
    774 
    775 TEST(from_ilist)
    776 {
    777     tensor<float, 1> t1{ 10, 20, 30, 40 };
    778     CHECK(t1 == tensor<float, 1>(shape{ 4 }, { 10, 20, 30, 40 }));
    779 
    780     tensor<float, 2> t2{ { 10, 20 }, { 30, 40 } };
    781     CHECK(t2 == tensor<float, 2>(shape{ 2, 2 }, { 10, 20, 30, 40 }));
    782 
    783     tensor<float, 2> t3{ { 10, 20 } };
    784     CHECK(t3 == tensor<float, 2>(shape{ 1, 2 }, { 10, 20 }));
    785 
    786     tensor<float, 3> t4{ { { 10, 20 }, { 30, 40 } }, { { 50, 60 }, { 70, 80 } } };
    787     CHECK(t4 == tensor<float, 3>(shape{ 2, 2, 2 }, { 10, 20, 30, 40, 50, 60, 70, 80 }));
    788 }
    789 
    790 TEST(sharing_data)
    791 {
    792     tensor<int, 2> t{ { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } };
    793     auto t2  = t; // share data
    794     t2(0, 0) = 10;
    795     CHECK(t == tensor<int, 2>{ { 10, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } });
    796     auto t3 = t(0, tall());
    797     CHECK(t3 == tensor<int, 1>{ 10, 2, 3 });
    798     t3 *= 10;
    799     CHECK(t3 == tensor<int, 1>{ 100, 20, 30 });
    800     CHECK(t == tensor<int, 2>{ { 100, 20, 30 }, { 4, 5, 6 }, { 7, 8, 9 } });
    801     t(trange(0, 2), trange(0, 2)) = 0;
    802     CHECK(t == tensor<int, 2>{ { 0, 0, 30 }, { 0, 0, 6 }, { 7, 8, 9 } });
    803 }
    804 
    805 TEST(tensor_from_container)
    806 {
    807     std::vector<int> a{ 1, 2, 3 };
    808     auto t = tensor_from_container(a);
    809     CHECK(t.shape() == shape{ 3 });
    810     CHECK(t == tensor<int, 1>{ 1, 2, 3 });
    811 }
    812 
    813 } // namespace CMT_ARCH_NAME
    814 
    815 template <typename T, index_t Size>
    816 struct identity_matrix
    817 {
    818 };
    819 
    820 template <typename T, index_t Size>
    821 struct expression_traits<identity_matrix<T, Size>> : expression_traits_defaults
    822 {
    823     using value_type             = T;
    824     constexpr static size_t dims = 2;
    825     constexpr static shape<2> get_shape(const identity_matrix<T, Size>& self) { return { Size, Size }; }
    826     constexpr static shape<2> get_shape() { return { Size, Size }; }
    827 };
    828 
    829 template <typename T, index_t Size, index_t Axis, size_t N>
    830 vec<T, N> get_elements(const identity_matrix<T, Size>& self, const shape<2>& index,
    831                        const axis_params<Axis, N>& sh)
    832 {
    833     return select(indices<0>(index, sh) == indices<1>(index, sh), 1, 0);
    834 }
    835 
    836 inline namespace CMT_ARCH_NAME
    837 {
    838 
    839 TEST(identity_matrix)
    840 {
    841     CHECK(trender(identity_matrix<float, 3>{}) == tensor<float, 2>{ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } });
    842 }
    843 
    844 } // namespace CMT_ARCH_NAME
    845 
    846 } // namespace kfr
    847 CMT_PRAGMA_MSVC(warning(pop))