basic_expressions.cpp (5446B)
1 /** 2 * KFR (https://www.kfrlib.com) 3 * Copyright (C) 2016-2023 Dan Cazarin 4 * See LICENSE.txt for details 5 */ 6 7 #include <kfr/base/basic_expressions.hpp> 8 #include <kfr/base/simd_expressions.hpp> 9 #include <kfr/base/univector.hpp> 10 #include <kfr/io/tostring.hpp> 11 12 namespace kfr 13 { 14 inline namespace CMT_ARCH_NAME 15 { 16 17 TEST(linspace) 18 { 19 testo::epsilon_scope<> eps(10); 20 CHECK_EXPRESSION(linspace(0.0, 1.0, 5, true, ctrue), { 0.0, 0.25, 0.50, 0.75, 1.0 }); 21 CHECK_EXPRESSION(linspace(0.0, 1.0, 4, false, ctrue), { 0.0, 0.25, 0.50, 0.75 }); 22 CHECK(get_shape(linspace(0.0, 1.0, 5, true, cfalse)) == shape{ infinite_size }); 23 CHECK_EXPRESSION(linspace(0.0, 1.0, 4, false, ctrue), { 0.0, 0.25, 0.50, 0.75 }); 24 CHECK_EXPRESSION(symmlinspace(3.0, 4, ctrue), { -3.0, -1.00, 1.00, 3.00 }); 25 26 CHECK_EXPRESSION(linspace(1, 21, 4, false, ctrue), { 1, 6, 11, 16 }); 27 CHECK_EXPRESSION(linspace(1, 21, 4, true, ctrue), { 1, 7.66666667f, 14.3333333f, 21 }); 28 } 29 30 TEST(counter_shape) 31 { 32 CHECK(get_shape(1) == shape{}); 33 CHECK(get_shape(counter()) == shape{ infinite_size }); 34 CHECK(get_shape(counter() + 1) == shape{ infinite_size }); 35 CHECK(get_shape(counter(0, 1, 1)) == shape{ infinite_size, infinite_size }); 36 } 37 38 TEST(pack) 39 { 40 static_assert(std::is_same_v<vec<f32x2, 1>, std::invoke_result_t<fn::reverse, vec<f32x2, 1>>>); 41 const univector<float, 21> v1 = 1 + counter(); 42 const univector<float, 21> v2 = v1 * 11; 43 44 CHECK_EXPRESSION(pack(v1, v2), 21, [](float i) { return f32x2{ 1 + i, (1 + i) * 11 }; }); 45 46 CHECK_EXPRESSION(bind_expression(fn::reverse(), pack(v1, v2)), 21, 47 [](float i) { 48 return f32x2{ (1 + i) * 11, 1 + i }; 49 }); 50 } 51 52 TEST(adjacent) 53 { 54 CHECK_EXPRESSION(adjacent(fn::mul(), counter()), infinite_size, 55 [](size_t i) { return i > 0 ? i * (i - 1) : 0; }); 56 } 57 58 TEST(dimensions) 59 { 60 static_assert(expression_dims<decltype(scalar(0))> == 0); 61 static_assert(expression_dims<decltype(dimensions<1>(scalar(0)))> == 1); 62 63 static_assert(get_shape<decltype(scalar(0))>() == shape{}); 64 static_assert(get_shape<decltype(dimensions<1>(scalar(0)))>() == shape{ infinite_size }); 65 static_assert(get_shape<decltype(dimensions<2>(dimensions<1>(scalar(0))))>() == 66 shape{ infinite_size, infinite_size }); 67 CHECK_EXPRESSION(truncate(dimensions<1>(scalar(1)), 10), { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }); 68 } 69 70 TEST(padded) 71 { 72 static_assert(is_infinite<decltype(padded(counter()))>, ""); 73 static_assert(is_infinite<decltype(padded(truncate(counter(), 100)))>, ""); 74 75 CHECK_EXPRESSION(padded(truncate(counter(), 6), -1), infinite_size, 76 [](size_t i) { return i >= 6 ? -1 : i; }); 77 78 CHECK_EXPRESSION(padded(truncate(counter(), 0), -1), infinite_size, [](size_t i) { return -1; }); 79 80 CHECK_EXPRESSION(padded(truncate(counter(), 501), -1), infinite_size, 81 [](size_t i) { return i >= 501 ? -1 : i; }); 82 } 83 84 TEST(concatenate) 85 { 86 CHECK_EXPRESSION(concatenate(truncate(counter(5, 0), 5), truncate(counter(10, 0), 5)), 87 { 5, 5, 5, 5, 5, 10, 10, 10, 10, 10 }); 88 } 89 90 #ifndef CMT_COMPILER_IS_MSVC 91 // The following test causes ICE in recent MSVC 92 TEST(rebind) 93 { 94 auto c_minus_two = counter() - 2; 95 auto four_minus_c = rebind(c_minus_two, 4, counter()); 96 CHECK_EXPRESSION(counter(), infinite_size, [](size_t i) { return i; }); 97 CHECK_EXPRESSION(c_minus_two, infinite_size, [](size_t i) { return i - 2; }); 98 CHECK_EXPRESSION(four_minus_c, infinite_size, [](size_t i) { return 4 - i; }); 99 } 100 #endif 101 102 TEST(test_arg_access) 103 { 104 univector<float> v1(10); 105 v1 = counter(); 106 auto e1 = std::move(v1) + 10; 107 std::get<0>(e1.args)[0] = 100; 108 std::get<1>(e1.args) = 1; 109 110 CHECK_EXPRESSION(e1, 10, [](size_t i) { return (i == 0 ? 100 : i) + 1; }); 111 } 112 113 TEST(size_calc) 114 { 115 auto a = counter(); 116 CHECK(get_shape(a) == shape{ infinite_size }); 117 auto b = slice(counter(), 100); 118 CHECK(get_shape(b) == shape{ infinite_size }); 119 auto c = slice(counter(), 100, 1000); 120 CHECK(get_shape(c) == shape{ 1000 }); 121 auto d = slice(c, 100); 122 CHECK(get_shape(d) == shape{ 900 }); 123 } 124 125 TEST(reverse_expression) 126 { 127 CHECK_EXPRESSION(reverse(truncate(counter(), 21)), 21, [](size_t i) { return 20 - i; }); 128 } 129 130 TEST(sequence) 131 { 132 CHECK_EXPRESSION(sequence(0, 0.5f, 1, 0.5f), infinite_size, 133 [](size_t i) { 134 return std::array<float, 4>{ 0, 0.5f, 1, 0.5f }[i % 4]; 135 }); 136 } 137 138 TEST(assign_expression) 139 { 140 univector<float> f = truncate(counter(0, 1), 10); 141 f *= 10; 142 CHECK_EXPRESSION(f, { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 }); 143 144 univector<float> a = truncate(counter(0, 1), 10); 145 univector<float> b = truncate(counter(100, 1), 10); 146 pack(a, b) *= broadcast<2>(10.f); 147 CHECK_EXPRESSION(a, { 0, 10, 20, 30, 40, 50, 60, 70, 80, 90 }); 148 CHECK_EXPRESSION(b, { 1000, 1010, 1020, 1030, 1040, 1050, 1060, 1070, 1080, 1090 }); 149 150 static_assert(std::is_same_v<std::common_type_t<f32x2x2, f32x2x2>, f32x2x2>); 151 static_assert( 152 std::is_same_v<std::common_type_t<vec<vec<double, 2>, 1>, vec<double, 2>>, vec<vec<double, 2>, 1>>); 153 } 154 155 TEST(trace) { render(trace(counter()), 44); } 156 157 TEST(get_element) { CHECK(get_element(counter(0, 1, 10, 100), { 1, 2, 3 }) == 321); } 158 159 } // namespace CMT_ARCH_NAME 160 } // namespace kfr