kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

basic_expressions.cpp (5446B)


      1 /**
      2  * KFR (https://www.kfrlib.com)
      3  * Copyright (C) 2016-2023 Dan Cazarin
      4  * See LICENSE.txt for details
      5  */
      6 
      7 #include <kfr/base/basic_expressions.hpp>
      8 #include <kfr/base/simd_expressions.hpp>
      9 #include <kfr/base/univector.hpp>
     10 #include <kfr/io/tostring.hpp>
     11 
     12 namespace kfr
     13 {
     14 inline namespace CMT_ARCH_NAME
     15 {
     16 
     17 TEST(linspace)
     18 {
     19     testo::epsilon_scope<> eps(10);
     20     CHECK_EXPRESSION(linspace(0.0, 1.0, 5, true, ctrue), { 0.0, 0.25, 0.50, 0.75, 1.0 });
     21     CHECK_EXPRESSION(linspace(0.0, 1.0, 4, false, ctrue), { 0.0, 0.25, 0.50, 0.75 });
     22     CHECK(get_shape(linspace(0.0, 1.0, 5, true, cfalse)) == shape{ infinite_size });
     23     CHECK_EXPRESSION(linspace(0.0, 1.0, 4, false, ctrue), { 0.0, 0.25, 0.50, 0.75 });
     24     CHECK_EXPRESSION(symmlinspace(3.0, 4, ctrue), { -3.0, -1.00, 1.00, 3.00 });
     25 
     26     CHECK_EXPRESSION(linspace(1, 21, 4, false, ctrue), { 1, 6, 11, 16 });
     27     CHECK_EXPRESSION(linspace(1, 21, 4, true, ctrue), { 1, 7.66666667f, 14.3333333f, 21 });
     28 }
     29 
     30 TEST(counter_shape)
     31 {
     32     CHECK(get_shape(1) == shape{});
     33     CHECK(get_shape(counter()) == shape{ infinite_size });
     34     CHECK(get_shape(counter() + 1) == shape{ infinite_size });
     35     CHECK(get_shape(counter(0, 1, 1)) == shape{ infinite_size, infinite_size });
     36 }
     37 
     38 TEST(pack)
     39 {
     40     static_assert(std::is_same_v<vec<f32x2, 1>, std::invoke_result_t<fn::reverse, vec<f32x2, 1>>>);
     41     const univector<float, 21> v1 = 1 + counter();
     42     const univector<float, 21> v2 = v1 * 11;
     43 
     44     CHECK_EXPRESSION(pack(v1, v2), 21, [](float i) { return f32x2{ 1 + i, (1 + i) * 11 }; });
     45 
     46     CHECK_EXPRESSION(bind_expression(fn::reverse(), pack(v1, v2)), 21,
     47                      [](float i) {
     48                          return f32x2{ (1 + i) * 11, 1 + i };
     49                      });
     50 }
     51 
     52 TEST(adjacent)
     53 {
     54     CHECK_EXPRESSION(adjacent(fn::mul(), counter()), infinite_size,
     55                      [](size_t i) { return i > 0 ? i * (i - 1) : 0; });
     56 }
     57 
     58 TEST(dimensions)
     59 {
     60     static_assert(expression_dims<decltype(scalar(0))> == 0);
     61     static_assert(expression_dims<decltype(dimensions<1>(scalar(0)))> == 1);
     62 
     63     static_assert(get_shape<decltype(scalar(0))>() == shape{});
     64     static_assert(get_shape<decltype(dimensions<1>(scalar(0)))>() == shape{ infinite_size });
     65     static_assert(get_shape<decltype(dimensions<2>(dimensions<1>(scalar(0))))>() ==
     66                   shape{ infinite_size, infinite_size });
     67     CHECK_EXPRESSION(truncate(dimensions<1>(scalar(1)), 10), { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 });
     68 }
     69 
     70 TEST(padded)
     71 {
     72     static_assert(is_infinite<decltype(padded(counter()))>, "");
     73     static_assert(is_infinite<decltype(padded(truncate(counter(), 100)))>, "");
     74 
     75     CHECK_EXPRESSION(padded(truncate(counter(), 6), -1), infinite_size,
     76                      [](size_t i) { return i >= 6 ? -1 : i; });
     77 
     78     CHECK_EXPRESSION(padded(truncate(counter(), 0), -1), infinite_size, [](size_t i) { return -1; });
     79 
     80     CHECK_EXPRESSION(padded(truncate(counter(), 501), -1), infinite_size,
     81                      [](size_t i) { return i >= 501 ? -1 : i; });
     82 }
     83 
     84 TEST(concatenate)
     85 {
     86     CHECK_EXPRESSION(concatenate(truncate(counter(5, 0), 5), truncate(counter(10, 0), 5)),
     87                      { 5, 5, 5, 5, 5, 10, 10, 10, 10, 10 });
     88 }
     89 
     90 #ifndef CMT_COMPILER_IS_MSVC
     91 // The following test causes ICE in recent MSVC
     92 TEST(rebind)
     93 {
     94     auto c_minus_two  = counter() - 2;
     95     auto four_minus_c = rebind(c_minus_two, 4, counter());
     96     CHECK_EXPRESSION(counter(), infinite_size, [](size_t i) { return i; });
     97     CHECK_EXPRESSION(c_minus_two, infinite_size, [](size_t i) { return i - 2; });
     98     CHECK_EXPRESSION(four_minus_c, infinite_size, [](size_t i) { return 4 - i; });
     99 }
    100 #endif
    101 
    102 TEST(test_arg_access)
    103 {
    104     univector<float> v1(10);
    105     v1                      = counter();
    106     auto e1                 = std::move(v1) + 10;
    107     std::get<0>(e1.args)[0] = 100;
    108     std::get<1>(e1.args)    = 1;
    109 
    110     CHECK_EXPRESSION(e1, 10, [](size_t i) { return (i == 0 ? 100 : i) + 1; });
    111 }
    112 
    113 TEST(size_calc)
    114 {
    115     auto a = counter();
    116     CHECK(get_shape(a) == shape{ infinite_size });
    117     auto b = slice(counter(), 100);
    118     CHECK(get_shape(b) == shape{ infinite_size });
    119     auto c = slice(counter(), 100, 1000);
    120     CHECK(get_shape(c) == shape{ 1000 });
    121     auto d = slice(c, 100);
    122     CHECK(get_shape(d) == shape{ 900 });
    123 }
    124 
    125 TEST(reverse_expression)
    126 {
    127     CHECK_EXPRESSION(reverse(truncate(counter(), 21)), 21, [](size_t i) { return 20 - i; });
    128 }
    129 
    130 TEST(sequence)
    131 {
    132     CHECK_EXPRESSION(sequence(0, 0.5f, 1, 0.5f), infinite_size,
    133                      [](size_t i) {
    134                          return std::array<float, 4>{ 0, 0.5f, 1, 0.5f }[i % 4];
    135                      });
    136 }
    137 
    138 TEST(assign_expression)
    139 {
    140     univector<float> f = truncate(counter(0, 1), 10);
    141     f *= 10;
    142     CHECK_EXPRESSION(f, { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 });
    143 
    144     univector<float> a = truncate(counter(0, 1), 10);
    145     univector<float> b = truncate(counter(100, 1), 10);
    146     pack(a, b) *= broadcast<2>(10.f);
    147     CHECK_EXPRESSION(a, { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 });
    148     CHECK_EXPRESSION(b, { 1000, 1010, 1020, 1030, 1040, 1050, 1060, 1070, 1080, 1090 });
    149 
    150     static_assert(std::is_same_v<std::common_type_t<f32x2x2, f32x2x2>, f32x2x2>);
    151     static_assert(
    152         std::is_same_v<std::common_type_t<vec<vec<double, 2>, 1>, vec<double, 2>>, vec<vec<double, 2>, 1>>);
    153 }
    154 
    155 TEST(trace) { render(trace(counter()), 44); }
    156 
    157 TEST(get_element) { CHECK(get_element(counter(0, 1, 10, 100), { 1, 2, 3 }) == 321); }
    158 
    159 } // namespace CMT_ARCH_NAME
    160 } // namespace kfr