commit 3e97d16af133075c579274d5bb7722a7e528c9ba
parent 0623d0d73c72b57ed50a0499a9646af05b5f1489
Author: [email protected] <[email protected]>
Date: Wed, 7 Sep 2016 15:32:00 +0300
Cross correlation of two vectors
Diffstat:
2 files changed, 31 insertions(+), 3 deletions(-)
diff --git a/include/kfr/dft/conv.hpp b/include/kfr/dft/conv.hpp
@@ -55,7 +55,26 @@ KFR_INTRIN univector<T> convolve(const univector<T, Tag1>& src1, const univector
plan.execute(src2padded, src2padded, temp);
src1padded = src1padded * src2padded;
plan.execute(src1padded, src1padded, temp, true);
- return truncate(real(src1padded), src1.size() + src2.size() - 1) / T(size);
+ const T invsize = reciprocal<T>(size);
+ return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
+}
+
+template <typename T, size_t Tag1, size_t Tag2>
+KFR_INTRIN univector<T> correlate(const univector<T, Tag1>& src1, const univector<T, Tag2>& src2)
+{
+ const size_t size = next_poweroftwo(src1.size() + src2.size() - 1);
+ univector<complex<T>> src1padded = src1;
+ univector<complex<T>> src2padded = reverse(src2);
+ src1padded.resize(size, 0);
+ src2padded.resize(size, 0);
+ dft_plan<T> plan(size);
+ univector<u8> temp(plan.temp_size);
+ plan.execute(src1padded, src1padded, temp);
+ plan.execute(src2padded, src2padded, temp);
+ src1padded = src1padded * src2padded;
+ plan.execute(src1padded, src1padded, temp, true);
+ const T invsize = reciprocal<T>(size);
+ return truncate(real(src1padded), src1.size() + src2.size() - 1) * invsize;
}
}
#pragma clang diagnostic pop
diff --git a/tests/dft_test.cpp b/tests/dft_test.cpp
@@ -21,10 +21,19 @@ constexpr ctypes_t<float> float_types{};
TEST(test_convolve)
{
univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
- univector<fbase, 5> b({ 0.25, 0.5, 1.0, 0.5, 0.25 });
+ univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
univector<fbase> c = convolve(a, b);
CHECK(c.size() == 9);
- CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 5., 7.5, 8.5, 7.75, 3.5, 1.25 })) < 0.0001);
+ CHECK(rms(c - univector<fbase>({ 0.25, 1., 2.75, 2.5, 3.75, 3.5, 1.5, -4., 7.5 })) < 0.0001);
+}
+
+TEST(test_correlate)
+{
+ univector<fbase, 5> a({ 1, 2, 3, 4, 5 });
+ univector<fbase, 5> b({ 0.25, 0.5, 1.0, -2.0, 1.5 });
+ univector<fbase> c = correlate(a, b);
+ CHECK(c.size() == 9);
+ CHECK(rms(c - univector<fbase>({ 1.5, 1., 1.5, 2.5, 3.75, -4., 7.75, 3.5, 1.25 })) < 0.0001);
}
#ifdef CMT_ARCH_ARM