kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit d2903072f93a2a245ea69e1a052a22c731148ad9
parent 64af38a71dc1c83a2df4d71ad95572813a150ec6
Author: [email protected] <[email protected]>
Date:   Thu,  6 Oct 2016 01:46:41 +0300

New function: padded for padding expression

Diffstat:
Minclude/kfr/base/basic_expressions.hpp | 53+++++++++++++++++++++++++++++++++++++++++++++++++++++
Mtests/expression_test.cpp | 16++++++++++++++++
2 files changed, 69 insertions(+), 0 deletions(-)

diff --git a/include/kfr/base/basic_expressions.hpp b/include/kfr/base/basic_expressions.hpp @@ -412,6 +412,59 @@ CMT_INLINE internal::expression_adjacent<Fn, E1> adjacent(Fn&& fn, E1&& e1) namespace internal { +template <typename E> +struct expression_padded : expression<E> +{ + using value_type = value_type_of<E>; + + CMT_INLINE constexpr static size_t size() noexcept { return infinite_size; } + + expression_padded(value_type fill_value, E&& e) + : fill_value(fill_value), input_size(e.size()), expression<E>(std::forward<E>(e)) + { + } + + template <size_t N> + vec<value_type, N> operator()(cinput_t cinput, size_t index, vec_t<value_type, N> y) const + { + if (index >= input_size) + { + return fill_value; + } + else if (index + N <= input_size) + { + return this->argument_first(cinput, index, y); + } + else + { + vec<value_type, N> x; + for (size_t i = 0; i < N; i++) + { + if (index + i < input_size) + x.data()[i] = this->argument_first(cinput, index + i, vec_t<value_type, 1>())[0]; + else + x.data()[i] = fill_value; + } + return x; + } + } + value_type fill_value; + const size_t input_size; +}; +} + +/** + * @brief Returns infinite template expression that pads e with fill_value (default value = 0) + */ +template <typename E, typename T = value_type_of<E>> +internal::expression_padded<E> padded(E&& e, const T& fill_value = T(0)) +{ + static_assert(is_input_expression<E>::value, "E must be an input expression"); + return internal::expression_padded<E>(fill_value, std::forward<E>(e)); +} + +namespace internal +{ template <typename... E> struct multioutput : output_expression { diff --git a/tests/expression_test.cpp b/tests/expression_test.cpp @@ -43,6 +43,22 @@ TEST(adjacent) CHECK(v1[19] == 342); } +TEST(padded) +{ + static_assert(is_infinite<decltype(padded(counter()))>::value, ""); + static_assert(is_infinite<decltype(padded(truncate(counter(), 100)))>::value, ""); + + univector<int, 21> v1 = padded(truncate(counter(), 6), -1); + CHECK(v1[0] == 0); + CHECK(v1[1] == 1); + CHECK(v1[2] == 2); + CHECK(v1[3] == 3); + CHECK(v1[4] == 4); + CHECK(v1[5] == 5); + CHECK(v1[6] == -1); + CHECK(v1[20] == -1); +} + TEST(rebind) { auto c_minus_two = counter() - 2;