commit 2277d66d82e235a341b7de1a4adf71ab55736706
parent c00f56817b53c261e30e1e67d420ba0b9c605617
Author: [email protected] <[email protected]>
Date: Wed, 9 Nov 2016 13:13:34 +0300
FFT based zero latency convolution
Diffstat:
1 file changed, 103 insertions(+), 0 deletions(-)
diff --git a/include/kfr/dft/convolution.hpp b/include/kfr/dft/convolution.hpp
@@ -86,5 +86,108 @@ CMT_FUNC univector<T> autocorrelate(const univector<T, Tag1>& src)
result = result.slice(result.size() / 2);
return result;
}
+
+/// @brief Streaming convolution of an input signal with a fixed impulse response, computed with
+/// real FFTs on blocks of `block_size` samples.
+///
+/// The impulse response is cut into `block_size`-long partitions; the spectrum of each partition
+/// is stored in `ir_segments`. `segments` is a ring buffer with the spectra of the most recent
+/// input blocks (`position` is the slot of the block currently being filled). Output samples are
+/// written for every input sample consumed, including samples of a block that is not yet complete.
+template <typename T>
+class convolve_filter : public filter<T>
+{
+public:
+    /// @brief Creates a filter able to hold an impulse response of up to @p size samples.
+    ///
+    /// The impulse response starts out as all zeros; call set_data() to load one.
+    /// @param size       capacity of the impulse response, in samples
+    /// @param block_size requested partition length; rounded with next_poweroftwo()
+    explicit convolve_filter(size_t size, size_t block_size = 1024)
+        // NOTE: inside the initializers `size` and `block_size` name the constructor parameters,
+        // not the members, so the rounding has to be repeated wherever the rounded value is needed.
+        // Initializers are listed in member declaration order.
+        : fft(2 * next_poweroftwo(block_size)), temp(fft.temp_size),
+          segments(segment_count(size, next_poweroftwo(block_size))),
+          ir_segments(segment_count(size, next_poweroftwo(block_size))), size(size),
+          block_size(next_poweroftwo(block_size)), input_position(0), position(0)
+    {
+        // Without this the filter would be processed with empty buffers if set_data() is not called.
+        allocate_buffers();
+    }
+
+    /// @brief Creates a filter and loads @p data as its impulse response.
+    /// @param data       impulse response samples
+    /// @param block_size requested partition length; rounded with next_poweroftwo()
+    explicit convolve_filter(const univector<T>& data, size_t block_size = 1024)
+        : fft(2 * next_poweroftwo(block_size)), temp(fft.temp_size),
+          segments(segment_count(data.size(), next_poweroftwo(block_size))),
+          ir_segments(segment_count(data.size(), next_poweroftwo(block_size))), size(data.size()),
+          block_size(next_poweroftwo(block_size)), input_position(0), position(0)
+    {
+        set_data(data);
+    }
+
+    /// @brief Replaces the impulse response.
+    ///
+    /// The number of partitions is fixed at construction: samples of @p data beyond
+    /// `ir_segments.size() * block_size` are ignored, and partitions that lie entirely past the end
+    /// of @p data are set to zero.
+    /// @param data impulse response samples
+    void set_data(const univector<T>& data)
+    {
+        allocate_buffers();
+
+        univector<T> input(fft.size);
+        // The partition spectra are pre-scaled by 1 / fft.size, presumably because the transform
+        // pair is unnormalised and this saves scaling every output block (confirm against the
+        // dft_plan_real documentation).
+        const T ifftsize = reciprocal(T(fft.size));
+        for (size_t i = 0; i < ir_segments.size(); i++)
+        {
+            const size_t offset = i * block_size;
+            if (offset >= data.size())
+            {
+                // No samples left for this partition: never slice past the end of `data`.
+                process(ir_segments[i], zeros());
+                continue;
+            }
+            // One partition of the impulse response, zero-padded to the FFT length.
+            input = padded(data.slice(offset, block_size));
+
+            fft.execute(ir_segments[i], input, temp, dft_pack_format::Perm);
+            process(ir_segments[i], ir_segments[i] * ifftsize);
+        }
+    }
+
+protected:
+    /// @brief Evaluates the first @p size samples of @p src into a temporary buffer and filters it.
+    void process_expression(T* dest, const expression_pointer<T>& src, size_t size) final
+    {
+        univector<T> input = truncate(src, size);
+        process_buffer(dest, input.data(), input.size());
+    }
+
+    /// @brief Filters @p size samples from @p input into @p output.
+    ///
+    /// Input is consumed in chunks that never cross a block boundary. After every chunk the
+    /// (partially filled) current block is transformed again and a fresh output chunk is written.
+    void process_buffer(T* output, const T* input, size_t size) final
+    {
+        size_t processed = 0;
+        while (processed < size)
+        {
+            // Take as many samples as still fit into the current block.
+            const size_t processing = std::min(size - processed, block_size - input_position);
+            internal::builtin_memcpy(saved_input.data() + input_position, input + processed,
+                                     processing * sizeof(T));
+
+            // Spectrum of the current block; the part not yet received is still zero because
+            // saved_input is cleared whenever a block completes.
+            process(scratch, padded(saved_input));
+            fft.execute(segments[position], scratch, temp, dft_pack_format::Perm);
+
+            if (input_position == 0)
+            {
+                // Start of a block: the older input blocks do not change until the block is
+                // complete, so their combined contribution is accumulated once into premul
+                // (presumably premul += ir_segments[i] * segments[n]; see fft_multiply_accumulate).
+                process(premul, zeros());
+                for (size_t i = 1; i < segments.size(); i++)
+                {
+                    const size_t n = (position + i) % segments.size();
+                    fft_multiply_accumulate(premul, ir_segments[i], segments[n], dft_pack_format::Perm);
+                }
+            }
+            // Add the contribution of the current block to the precomputed sum
+            // (presumably cscratch = premul + ir_segments[0] * segments[position]).
+            fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position],
+                                    dft_pack_format::Perm);
+
+            // Complex-to-real transform of the accumulated spectrum into scratch.
+            fft.execute(scratch, cscratch, temp, dft_pack_format::Perm);
+
+            // Emit only the samples that belong to this chunk, plus the tail saved from the
+            // previous block.
+            process(make_univector(output + processed, processing),
+                    scratch.slice(input_position) + overlap.slice(input_position));
+
+            input_position += processing;
+            if (input_position == block_size)
+            {
+                // Block complete: start a new one.
+                input_position = 0;
+                process(saved_input, zeros());
+
+                // The second half of scratch overlaps the next block's output.
+                internal::builtin_memcpy(overlap.data(), scratch.data() + block_size, block_size * sizeof(T));
+
+                // Step the ring buffer backwards so that (position + i) addresses the block
+                // received i blocks ago.
+                position = position > 0 ? position - 1 : segments.size() - 1;
+            }
+
+            processed += processing;
+        }
+    }
+
+    const dft_plan_real<T> fft;                     // real FFT plan of length 2 * block_size
+    univector<u8> temp;                             // scratch memory required by fft.execute
+    std::vector<univector<complex<T>>> segments;    // ring buffer of input block spectra
+    std::vector<univector<complex<T>>> ir_segments; // spectra of the impulse response partitions
+    const size_t size;                              // impulse response capacity given at construction
+    const size_t block_size;                        // partition length (rounded with next_poweroftwo)
+    size_t input_position;                          // samples already stored in the current block
+    univector<T> saved_input;                       // time-domain samples of the current block
+    univector<complex<T>> premul;                   // accumulated contribution of the older blocks
+    univector<complex<T>> cscratch;                 // spectrum handed to the complex-to-real transform
+    univector<T> scratch;                           // time-domain work buffer, 2 * block_size samples
+    univector<T> overlap;                           // tail of the previous block's result
+    size_t position;                                // ring buffer slot of the current block
+
+private:
+    /// Number of `block`-long partitions needed for `length` samples; at least one, so that
+    /// process_buffer always has a valid `segments[position]` even for an empty impulse response.
+    static size_t segment_count(size_t length, size_t block)
+    {
+        const size_t count = (length + block - 1) / block;
+        return count == 0 ? 1 : count;
+    }
+
+    /// Sizes every per-partition and working buffer. Buffers that already have the right size are
+    /// left untouched, so the filter state survives a later set_data() call.
+    void allocate_buffers()
+    {
+        for (size_t i = 0; i < ir_segments.size(); i++)
+        {
+            segments[i].resize(block_size);
+            ir_segments[i].resize(block_size, 0);
+        }
+        saved_input.resize(block_size, 0);
+        scratch.resize(block_size * 2);
+        premul.resize(block_size, 0);
+        cscratch.resize(block_size);
+        overlap.resize(block_size, 0);
+    }
+};
}
CMT_PRAGMA_GNU(GCC diagnostic pop)