kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 4a08fc6b325b6fcfa8d99c30fe24b8a12bd9b806
parent a1e3d299ba8f30e7bcd871b893fbf979b58310db
Author: [email protected] <[email protected]>
Date:   Thu, 24 Nov 2022 07:11:44 +0000

Fix get_arg: incorrect VecAxis

Diffstat:
Minclude/kfr/base/expression.hpp | 20+++++++++++---------
Minclude/kfr/base/tensor.hpp | 2++
2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp @@ -540,39 +540,41 @@ KFR_INTRINSIC void end_pass_args(const expression_with_arguments<Args...>& self, (end_pass(std::get<idx>(self.args), start, stop), ...); } -template <index_t outdims, typename Fn, typename... Args, index_t Axis, size_t N, index_t Dims, size_t idx, +template <index_t outdims, typename Fn, typename... Args, index_t VecAxis, size_t N, index_t Dims, size_t idx, typename Traits = expression_traits<typename expression_function<Fn, Args...>::template nth<idx>>> KFR_MEM_INTRINSIC vec<typename Traits::value_type, N> get_arg(const expression_function<Fn, Args...>& self, const shape<Dims>& index, - const axis_params<Axis, N>& sh, csize_t<idx>) + const axis_params<VecAxis, N>& sh, csize_t<idx>) { if constexpr (Traits::dims == 0) { - return repeat<N>(get_elements(std::get<idx>(self.args), {}, axis_params<Axis, 1>{})); + return repeat<N>(get_elements(std::get<idx>(self.args), {}, axis_params<0, 1>{})); } else { - auto indices = internal_generic::adapt<Traits::dims>(index, self.getmask(csize<idx>)); - constexpr index_t last_dim = Traits::get_shape().back(); + constexpr size_t NewVecAxis = Traits::dims - (Dims - VecAxis); + auto indices = internal_generic::adapt<Traits::dims>(index, self.getmask(csize<idx>)); + constexpr index_t last_dim = Traits::get_shape().back(); if constexpr (last_dim != undefined_size) { constexpr index_t last_dim_pot = prev_poweroftwo(last_dim); return repeat<N / std::min(last_dim_pot, static_cast<index_t>(N))>( get_elements(std::get<idx>(self.args), indices, - axis_params<Axis, std::min(last_dim_pot, static_cast<index_t>(N))>{})); + axis_params<NewVecAxis, std::min(last_dim_pot, static_cast<index_t>(N))>{})); } else { if constexpr (sizeof...(Args) > 1 && N > 1) { if (CMT_UNLIKELY(self.masks[idx].back() == 0)) - return get_elements(std::get<idx>(self.args), indices, axis_params<Axis, 1>{}).front(); + return get_elements(std::get<idx>(self.args), indices, axis_params<NewVecAxis, 1>{}) + .front(); else - return get_elements(std::get<idx>(self.args), indices, sh); + return get_elements(std::get<idx>(self.args), indices, axis_params<NewVecAxis, N>{}); } else { - return get_elements(std::get<idx>(self.args), indices, sh); + return get_elements(std::get<idx>(self.args), indices, axis_params<NewVecAxis, N>{}); } } } diff --git a/include/kfr/base/tensor.hpp b/include/kfr/base/tensor.hpp @@ -873,6 +873,7 @@ template <typename T, index_t NDims, index_t Axis, size_t N> KFR_INTRINSIC vec<T, N> get_elements(const tensor<T, NDims>& self, const shape<NDims>& index, const axis_params<Axis, N>&) { + static_assert(Axis < NDims || NDims == 0); const T* data = self.data() + self.calc_index(index); if constexpr (NDims == 0) { @@ -891,6 +892,7 @@ template <typename T, index_t NDims, index_t Axis, size_t N> KFR_INTRINSIC void set_elements(const tensor<T, NDims>& self, const shape<NDims>& index, const axis_params<Axis, N>&, const identity<vec<T, N>>& value) { + static_assert(Axis < NDims || NDims == 0); T* data = self.data() + self.calc_index(index); if constexpr (NDims == 0) {