commit 2046a9e9e6235e2c6d86da0a65888733f1858122
parent 2d457a951599ad878866b800a6a5ea099c793eef
Author: [email protected] <[email protected]>
Date: Tue, 8 Nov 2016 13:43:54 +0300
FIR filter with external state
Diffstat:
2 files changed, 85 insertions(+), 36 deletions(-)
diff --git a/include/kfr/dsp/fir.hpp b/include/kfr/dsp/fir.hpp
@@ -28,10 +28,8 @@
#include "../base/basic_expressions.hpp"
#include "../base/memory.hpp"
#include "../base/reduce.hpp"
-#include "../base/sin_cos.hpp"
#include "../base/univector.hpp"
#include "../base/vec.hpp"
-#include "window.hpp"
namespace kfr
{
@@ -39,20 +37,66 @@ namespace kfr
template <typename T, size_t Size>
using fir_taps = univector<T, Size>;
+template <size_t tapcount, typename T> // state of a short FIR filter: tapcount coefficients plus a delay line
+struct short_fir_state
+{
+ template <size_t N>
+ short_fir_state(const univector<T, N>& taps) // builds the state from N coefficients
+ : taps(widen<tapcount>(read<N>(taps.data()), T(0))), delayline(0) // presumably widen<> pads up to tapcount with T(0) -- TODO confirm
+ {
+ }
+ template <size_t N>
+ short_fir_state(const univector<const T, N>& taps) // same as above, for coefficients stored as const T
+ : taps(widen<tapcount>(read<N>(taps.data()), T(0))), delayline(0) // delay line is initialised from 0
+ {
+ }
+ vec<T, tapcount> taps; // filter coefficients
+ mutable vec<T, tapcount - 1> delayline; // mutable: expression_short_fir assigns it even when it only holds a const reference to this state
+};
+
+template <typename T> // state of a FIR filter whose tap count is only known at run time
+struct fir_state
+{
+ fir_state(const array_ref<const T>& taps) // builds the state from the given coefficients
+ : taps(taps.size()), delayline(taps.size(), T(0)), delayline_cursor(0) // delay line has one T(0) slot per tap; cursor starts at 0
+ {
+ this->taps = reverse(make_univector(taps.data(), taps.size())); // coefficients are stored via reverse(); presumably in reversed order -- TODO confirm
+ }
+ univector_dyn<T> taps; // coefficients as stored by the constructor
+ mutable univector_dyn<T> delayline; // mutable: expression_fir writes to it from a const operator()
+ mutable size_t delayline_cursor; // position in delayline that expression_fir reads and writes back
+};
+
namespace internal
{
-template <size_t tapcount, typename T, typename E1, KFR_ARCH_DEP>
+
+template <typename T, bool stateless> // primary template (used when stateless is false): owns a copy of the state
+struct state_holder
+{
+ state_holder() = delete; // a holder always has to be given a state
+ state_holder(const state_holder&) = default;
+ state_holder(state_holder&&) = default;
+ constexpr state_holder(const T& state) noexcept : s(state) {} // copies state into s; NOTE(review): declared noexcept although copying T may allocate -- confirm
+ T s; // the state, held by value
+};
+
+template <typename T> // specialization for stateless == true: refers to a state object owned elsewhere
+struct state_holder<T, true>
+{
+ state_holder() = delete; // a holder always has to be bound to a state
+ state_holder(const state_holder&) = default; // copies refer to the same state object
+ state_holder(state_holder&&) = default;
+ constexpr state_holder(const T& state) noexcept : s(state) {} // binds the reference; no copy of the state is made
+ const T& s; // the referenced state must outlive this holder
+};
+
+template <size_t tapcount, typename T, typename E1, bool stateless = false, KFR_ARCH_DEP> // stateless == true: use an external short_fir_state by reference instead of an owned copy
struct expression_short_fir : expression<E1>
{
using value_type = T;
- static_assert(is_poweroftwo(tapcount), "tapcount must be a power of two");
- expression_short_fir(E1&& e1, const array_ref<T>& taps)
- : expression<E1>(std::forward<E1>(e1)), taps(read<tapcount>(taps.data())), delayline(0)
- {
- }
- expression_short_fir(E1&& e1, const array_ref<const T>& taps)
- : expression<E1>(std::forward<E1>(e1)), taps(read<tapcount>(taps.data())), delayline(0)
+ expression_short_fir(E1&& e1, const short_fir_state<tapcount, T>& state) // takes the input expression and the filter state
+ : expression<E1>(std::forward<E1>(e1)), state(state) // state_holder copies the state or keeps a reference to it, depending on stateless
{
}
template <size_t N>
@@ -60,48 +104,42 @@ struct expression_short_fir : expression<E1>
{
vec<T, N> in = this->argument_first(cinput, index, x);
- vec<T, N> out = in * taps[0];
- cfor(csize_t<1>(), csize_t<tapcount>(),
- [&](auto I) { out = out + concat_and_slice<tapcount - 1 - I, N>(delayline, in) * taps[I]; });
- delayline = concat_and_slice<N, tapcount - 1>(delayline, in);
+ vec<T, N> out = in * state.s.taps[0]; // contribution of coefficient 0 applied to the current input
+ cforeach(csizeseq_t<tapcount - 1, 1>(), [&](auto I) { // presumably visits I = 1 .. tapcount - 1 -- TODO confirm csizeseq_t semantics
+ out = out + concat_and_slice<tapcount - 1 - I, N>(state.s.delayline, in) * state.s.taps[I]; // adds coefficient I times a window taken from delay line + input
+ });
+ state.s.delayline = concat_and_slice<N, tapcount - 1>(state.s.delayline, in); // new delay line for the next call, taken from old delay line + input
return out;
}
- vec<T, tapcount> taps;
- mutable vec<T, tapcount - 1> delayline;
+ state_holder<short_fir_state<tapcount, T>, stateless> state; // owned copy or external reference; accessed as state.s
};
-template <typename T, typename E1, KFR_ARCH_DEP>
+template <typename T, typename E1, bool stateless = false, KFR_ARCH_DEP> // stateless == true: use an external fir_state by reference instead of an owned copy
struct expression_fir : expression<E1>
{
using value_type = T;
- expression_fir(E1&& e1, const array_ref<const T>& taps)
- : expression<E1>(std::forward<E1>(e1)), taps(taps.size()), delayline(taps.size(), T(0)),
- delayline_cursor(0)
- {
- this->taps = reverse(make_univector(taps.data(), taps.size()));
- }
+ expression_fir(E1&& e1, const fir_state<T>& state) : expression<E1>(std::forward<E1>(e1)), state(state) {} // state_holder copies the state or keeps a reference to it
+
template <size_t N>
CMT_INLINE vec<T, N> operator()(cinput_t cinput, size_t index, vec_t<T, N> x) const
{
- const size_t tapcount = taps.size();
+ const size_t tapcount = state.s.taps.size(); // number of coefficients in the state
const vec<T, N> input = this->argument_first(cinput, index, x);
vec<T, N> output;
- size_t cursor = delayline_cursor;
+ size_t cursor = state.s.delayline_cursor; // work on a local copy of the cursor; written back after the loop
CMT_LOOP_NOUNROLL
for (size_t i = 0; i < N; i++)
{
- delayline.ringbuf_write(cursor, input[i]);
- output[i] = dotproduct(taps, delayline.slice(cursor) /*, tapcount - cursor*/) +
- dotproduct(taps.slice(tapcount - cursor), delayline /*, cursor*/);
+ state.s.delayline.ringbuf_write(cursor, input[i]); // stores the sample; presumably also advances cursor with wrap-around -- TODO confirm
+ output[i] = dotproduct(state.s.taps, state.s.delayline.slice(cursor) /*, tapcount - cursor*/) + // presumably the part of the delay line from cursor to its end
+ dotproduct(state.s.taps.slice(tapcount - cursor), state.s.delayline /*, cursor*/); // presumably the wrapped-around part before cursor
}
- delayline_cursor = cursor;
+ state.s.delayline_cursor = cursor; // persist the cursor for the next call (the member is mutable)
return output;
}
- univector_dyn<T> taps;
- mutable univector_dyn<T> delayline;
- mutable size_t delayline_cursor;
+ state_holder<fir_state<T>, stateless> state; // owned copy or external reference; accessed as state.s
};
}
@@ -117,16 +155,27 @@ CMT_INLINE internal::expression_fir<T, E1> fir(E1&& e1, const univector<T, Tag>&
}
/**
+ * @brief Returns template expression that applies FIR filter to the input, using externally stored state
+ * @param state FIR filter state; the expression keeps a reference to it, so it must outlive the expression
+ * @param e1 an input expression
+ */
+template <typename T, typename E1>
+CMT_INLINE internal::expression_fir<T, E1, true> fir(fir_state<T>& state, E1&& e1)
+{
+ return internal::expression_fir<T, E1, true>(std::forward<E1>(e1), state);
+}
+
+/**
* @brief Returns template expression that applies FIR filter to the input (count of coefficients must be in
* range 2..32)
* @param e1 an input expression
* @param taps coefficients for the FIR filter
*/
template <typename T, size_t TapCount, typename E1>
-CMT_INLINE internal::expression_short_fir<TapCount, T, E1> short_fir(E1&& e1,
- const univector<T, TapCount>& taps)
+CMT_INLINE internal::expression_short_fir<next_poweroftwo(TapCount), T, E1> short_fir( // expression is sized by next_poweroftwo(TapCount), presumably TapCount rounded up to a power of two
+ E1&& e1, const univector<T, TapCount>& taps)
{
static_assert(TapCount >= 2 && TapCount <= 32, "Use short_fir only for small FIR filters");
- return internal::expression_short_fir<TapCount, T, E1>(std::forward<E1>(e1), taps.ref());
+ return internal::expression_short_fir<next_poweroftwo(TapCount), T, E1>(std::forward<E1>(e1), taps); // taps converts implicitly to short_fir_state through its templated constructor
}
}
diff --git a/include/kfr/dsp/fracdelay.hpp b/include/kfr/dsp/fracdelay.hpp
@@ -36,6 +36,6 @@ CMT_INLINE internal::expression_short_fir<2, T, E1> fracdelay(E1&& e1, T delay)
if (delay < 0)
delay = 0;
univector<T, 2> taps({ 1 - delay, delay });
- return internal::expression_short_fir<2, T, E1>(std::forward<E1>(e1), taps.ref());
+ return internal::expression_short_fir<2, T, E1>(std::forward<E1>(e1), taps);
}
}