commit 1f2d9979a26406e321d3ee2abf33c70081ce1587
parent 6d8923feb4f7b538289b58674c3b54a85932915c
Author: [email protected] <[email protected]>
Date: Mon, 1 Apr 2019 17:20:43 +0000
concatenate
Diffstat:
1 file changed, 71 insertions(+), 14 deletions(-)
diff --git a/include/kfr/base/basic_expressions.hpp b/include/kfr/base/basic_expressions.hpp
@@ -66,8 +66,8 @@ struct expression_convert : expression_with_arguments<E>
}
template <size_t N>
- friend KFR_INTRINSIC vec<To, N> get_elements(const expression_convert& self, cinput_t input,
- size_t index, vec_shape<To, N>)
+ friend KFR_INTRINSIC vec<To, N> get_elements(const expression_convert& self, cinput_t input, size_t index,
+ vec_shape<To, N>)
{
return self.argument_first(input, index, vec_shape<To, N>());
}
@@ -226,8 +226,8 @@ struct expression_slice : expression_with_arguments<E1>
{
}
template <size_t N>
- friend KFR_INTRINSIC vec<T, N> get_elements(const expression_slice& self, cinput_t cinput,
- size_t index, vec_shape<T, N> y)
+ friend KFR_INTRINSIC vec<T, N> get_elements(const expression_slice& self, cinput_t cinput, size_t index,
+ vec_shape<T, N> y)
{
return self.argument_first(cinput, index + self.start, y);
}
@@ -243,8 +243,8 @@ struct expression_reverse : expression_with_arguments<E1>
using T = value_type;
expression_reverse(E1&& e1) : expression_with_arguments<E1>(std::forward<E1>(e1)), expr_size(e1.size()) {}
template <size_t N>
- friend KFR_INTRINSIC vec<T, N> get_elements(const expression_reverse& self, cinput_t cinput,
- size_t index, vec_shape<T, N> y)
+ friend KFR_INTRINSIC vec<T, N> get_elements(const expression_reverse& self, cinput_t cinput, size_t index,
+ vec_shape<T, N> y)
{
return reverse(self.argument_first(cinput, self.expr_size - index - N, y));
}
@@ -275,7 +275,7 @@ struct expression_linspace<T, false> : input_expression
template <size_t N>
friend KFR_INTRINSIC vec<T, N> get_elements(const expression_linspace& self, cinput_t, size_t index,
- vec_shape<T, N> x)
+ vec_shape<T, N> x)
{
using TI = itype<T>;
return T(self.start) + (enumerate(x) + static_cast<T>(static_cast<TI>(index))) * T(self.offset);
@@ -306,7 +306,7 @@ struct expression_linspace<T, true> : input_expression
template <size_t N>
friend KFR_INTRINSIC vec<T, N> get_elements(const expression_linspace& self, cinput_t, size_t index,
- vec_shape<T, N> x)
+ vec_shape<T, N> x)
{
using TI = itype<T>;
return mix((enumerate(x) + static_cast<T>(static_cast<TI>(index))) * self.invsize, self.start,
@@ -344,7 +344,7 @@ public:
template <size_t N>
KFR_INTRINSIC friend vec<T, N> get_elements(const expression_sequence& self, cinput_t cinput,
- size_t index, vec_shape<T, N> y)
+ size_t index, vec_shape<T, N> y)
{
std::size_t sindex =
size_t(std::upper_bound(std::begin(self.segments), std::end(self.segments), index) - 1 -
@@ -368,7 +368,7 @@ public:
protected:
template <size_t N>
KFR_INTRINSIC friend vec<T, N> get_elements(const expression_sequence& self, cinput_t cinput,
- size_t index, size_t expr_index, vec_shape<T, N> y)
+ size_t index, size_t expr_index, vec_shape<T, N> y)
{
return cswitch(indicesfor_t<E...>(), expr_index,
[&](auto val) { return self.argument(cinput, val, index, y); },
@@ -391,7 +391,7 @@ struct expression_adjacent : expression_with_arguments<E>
template <size_t N>
KFR_INTRINSIC friend vec<T, N> get_elements(const expression_adjacent& self, cinput_t cinput,
- size_t index, vec_shape<T, N>)
+ size_t index, vec_shape<T, N>)
{
const vec<T, N> in = self.argument_first(cinput, index, vec_shape<T, N>());
const vec<T, N> delayed = insertleft(self.data, in);
@@ -486,7 +486,7 @@ struct expression_padded : expression_with_arguments<E>
template <size_t N>
KFR_INTRINSIC friend vec<value_type, N> get_elements(const expression_padded& self, cinput_t cinput,
- size_t index, vec_shape<value_type, N> y)
+ size_t index, vec_shape<value_type, N> y)
{
if (index >= self.input_size)
{
@@ -556,8 +556,8 @@ struct expression_pack : expression_with_arguments<E...>
using expression_with_arguments<E...>::size;
template <size_t N>
- friend KFR_INTRINSIC vec<T, N> get_elements(const expression_pack& self, cinput_t cinput,
- size_t index, vec_shape<T, N> y)
+ friend KFR_INTRINSIC vec<T, N> get_elements(const expression_pack& self, cinput_t cinput, size_t index,
+ vec_shape<T, N> y)
{
return self.call(cinput, fn::packtranspose(), index, y);
}
@@ -647,5 +647,62 @@ task_partition<OutExpr, InExpr> partition(OutExpr&& output, InExpr&& input, size
chunk_size, (size + chunk_size - 1) / chunk_size);
return result;
}
+
+namespace internal
+{
+
+template <typename E1, typename E2>
+struct concatenate_expression : expression_with_arguments<E1, E2>
+{
+ using value_type = common_type<value_type_of<E1>, value_type_of<E2>>;
+ using T = value_type;
+
+ KFR_MEM_INTRINSIC constexpr size_t size() const CMT_NOEXCEPT
+ {
+ return size_add(std::get<0>(this->args).size(), std::get<1>(this->args).size());
+ }
+ template <typename E1_, typename E2_>
+ concatenate_expression(E1_&& e1, E2_&& e2)
+ : expression_with_arguments<E1, E2>(std::forward<E1_>(e1), std::forward<E2_>(e2))
+ {
+ }
+
+ template <size_t N>
+ KFR_INTRINSIC friend vec<T, N> get_elements(const concatenate_expression& self, cinput_t cinput,
+ size_t index, vec_shape<T, N> y)
+ {
+ const size_t size0 = std::get<0>(self.args).size();
+ if (index >= size0)
+ {
+ return self.argument(cinput, csize<1>, index - size0, y);
+ }
+ else if (index + N <= size0)
+ {
+ return self.argument(cinput, csize<0>, index, y);
+ }
+ else // (index < size0) && (index + N > size0)
+ {
+ vec<T, N> result;
+ for (size_t i = 0; i < size0 - index; ++i)
+ {
+ result[i] = self.argument(cinput, csize<0>, index + i, vec_shape<T, 1>{})[0];
+ }
+ for (size_t i = size0 - index; i < N; ++i)
+ {
+ result[i] = self.argument(cinput, csize<1>, index + i, vec_shape<T, 1>{})[0];
+ }
+ return result;
+ }
+ }
+};
+} // namespace internal
+
+template <typename E1, typename E2,
+ KFR_ENABLE_IF(is_input_expression<E1>::value&& is_input_expression<E2>::value)>
+internal::concatenate_expression<E1, E2> concatenate(E1&& e1, E2&& e2)
+{
+ return { std::forward<E1>(e1), std::forward<E2>(e2) };
+}
+
} // namespace CMT_ARCH_NAME
} // namespace kfr