kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 3e97d16af133075c579274d5bb7722a7e528c9ba
parent 0623d0d73c72b57ed50a0499a9646af05b5f1489
Author: [email protected] <[email protected]>
Date:   Wed,  7 Sep 2016 15:32:00 +0300

Cross correlation of two vectors

Diffstat:
Minclude/kfr/dft/conv.hpp | 21++++++++++++++++++++-
Mtests/dft_test.cpp | 13+++++++++++--
2 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/include/kfr/dft/conv.hpp b/include/kfr/dft/conv.hpp @@ -55,7 +55,26 @@ KFR_INTRIN univector<T> convolve(const univector<T, Tag1>& src1, const univector plan.execute(src2padded, src2padded, temp); src1padded = src1padded * src2padded; plan.execute(src1padded, src1padded, temp, true); - return truncate(real(src1padded), src1.size() + src2.size() - 1) / T(size); + const T invsize = reciprocal<T>(size); + return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; +} + +template <typename T, size_t Tag1, size_t Tag2> +KFR_INTRIN univector<T> correlate(const univector<T, Tag1>& src1, const univector<T, Tag2>& src2) +{ + const size_t size = next_poweroftwo(src1.size() + src2.size() - 1); + univector<complex<T>> src1padded = src1; + univector<complex<T>> src2padded = reverse(src2); + src1padded.resize(size, 0); + src2padded.resize(size, 0); + dft_plan<T> plan(size); + univector<u8> temp(plan.temp_size); + plan.execute(src1padded, src1padded, temp); + plan.execute(src2padded, src2padded, temp); + src1padded = src1padded * src2padded; + plan.execute(src1padded, src1padded, temp, true); + const T invsize = reciprocal<T>(size); + return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize; } } #pragma clang diagnostic pop diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp @@ -21,10 +21,19 @@ constexpr ctypes_t<float> float_types{}; TEST(test_convolve) { univector<fbase, 5> a({ 1, 2, 3, 4, 5 }); - univector<fbase, 5> b({ 0.25, 0.5, 1.0, 0.5, 0.25 }); + univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 }); univector<fbase> c = convolve(a, b); CHECK(c.size() == 9); - CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 5., 7.5, 8.5, 7.75, 3.5, 1.25 })) < 0.0001); + CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 })) < 0.0001); +} + +TEST(test_correlate) +{ + univector<fbase, 5> a({ 1, 2, 3, 4, 5 }); + univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 }); + univector<fbase> c = correlate(a, b); + CHECK(c.size() == 9); + CHECK(rms(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 })) < 0.0001); } #ifdef CMT_ARCH_ARM