commit 4a08fc6b325b6fcfa8d99c30fe24b8a12bd9b806
parent a1e3d299ba8f30e7bcd871b893fbf979b58310db
Author: [email protected] <[email protected]>
Date: Thu, 24 Nov 2022 07:11:44 +0000
Fix get_arg: incorrect VecAxis
Diffstat:
2 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -540,39 +540,41 @@ KFR_INTRINSIC void end_pass_args(const expression_with_arguments<Args...>& self,
(end_pass(std::get<idx>(self.args), start, stop), ...);
}
-template <index_t outdims, typename Fn, typename... Args, index_t Axis, size_t N, index_t Dims, size_t idx,
+template <index_t outdims, typename Fn, typename... Args, index_t VecAxis, size_t N, index_t Dims, size_t idx,
typename Traits = expression_traits<typename expression_function<Fn, Args...>::template nth<idx>>>
KFR_MEM_INTRINSIC vec<typename Traits::value_type, N> get_arg(const expression_function<Fn, Args...>& self,
const shape<Dims>& index,
- const axis_params<Axis, N>& sh, csize_t<idx>)
+ const axis_params<VecAxis, N>& sh, csize_t<idx>)
{
if constexpr (Traits::dims == 0)
{
- return repeat<N>(get_elements(std::get<idx>(self.args), {}, axis_params<Axis, 1>{}));
+ return repeat<N>(get_elements(std::get<idx>(self.args), {}, axis_params<0, 1>{}));
}
else
{
- auto indices = internal_generic::adapt<Traits::dims>(index, self.getmask(csize<idx>));
- constexpr index_t last_dim = Traits::get_shape().back();
+ constexpr size_t NewVecAxis = Traits::dims - (Dims - VecAxis);
+ auto indices = internal_generic::adapt<Traits::dims>(index, self.getmask(csize<idx>));
+ constexpr index_t last_dim = Traits::get_shape().back();
if constexpr (last_dim != undefined_size)
{
constexpr index_t last_dim_pot = prev_poweroftwo(last_dim);
return repeat<N / std::min(last_dim_pot, static_cast<index_t>(N))>(
get_elements(std::get<idx>(self.args), indices,
- axis_params<Axis, std::min(last_dim_pot, static_cast<index_t>(N))>{}));
+ axis_params<NewVecAxis, std::min(last_dim_pot, static_cast<index_t>(N))>{}));
}
else
{
if constexpr (sizeof...(Args) > 1 && N > 1)
{
if (CMT_UNLIKELY(self.masks[idx].back() == 0))
- return get_elements(std::get<idx>(self.args), indices, axis_params<Axis, 1>{}).front();
+ return get_elements(std::get<idx>(self.args), indices, axis_params<NewVecAxis, 1>{})
+ .front();
else
- return get_elements(std::get<idx>(self.args), indices, sh);
+ return get_elements(std::get<idx>(self.args), indices, axis_params<NewVecAxis, N>{});
}
else
{
- return get_elements(std::get<idx>(self.args), indices, sh);
+ return get_elements(std::get<idx>(self.args), indices, axis_params<NewVecAxis, N>{});
}
}
}
diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp
@@ -873,6 +873,7 @@ template <typename T, index_t NDims, index_t Axis, size_t N>
KFR_INTRINSIC vec<T, N> get_elements(const tensor<T, NDims>& self, const shape<NDims>& index,
const axis_params<Axis, N>&)
{
+ static_assert(Axis < NDims || NDims == 0);
const T* data = self.data() + self.calc_index(index);
if constexpr (NDims == 0)
{
@@ -891,6 +892,7 @@ template <typename T, index_t NDims, index_t Axis, size_t N>
KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>& index,
const axis_params<Axis, N>&, const identity<vec<T, N>>& value)
{
+ static_assert(Axis < NDims || NDims == 0);
T* data = self.data() + self.calc_index(index);
if constexpr (NDims == 0)
{