kfr

Fast, modern C++ DSP framework, FFT, Sample Rate Conversion, FIR/IIR/Biquad Filters (SSE, AVX, AVX-512, ARM NEON)
Log | Files | Refs | README

commit 2277d66d82e235a341b7de1a4adf71ab55736706
parent c00f56817b53c261e30e1e67d420ba0b9c605617
Author: [email protected] <[email protected]>
Date:   Wed,  9 Nov 2016 13:13:34 +0300

FFT based zero latency convolution

Diffstat:
Minclude/kfr/dft/convolution.hpp | 103+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 103 insertions(+), 0 deletions(-)

diff --git a/include/kfr/dft/convolution.hpp b/include/kfr/dft/convolution.hpp @@ -86,5 +86,108 @@ CMT_FUNC univector<T> autocorrelate(const univector<T, Tag1>& src) result = result.slice(result.size() / 2); return result; } + +template <typename T> +class convolve_filter : public filter<T> +{ +public: + explicit convolve_filter(size_t size, size_t block_size = 1024) + : size(size), block_size(block_size), fft(2 * next_poweroftwo(block_size)), temp(fft.temp_size), + segments((size + block_size - 1) / block_size) + { + } + explicit convolve_filter(const univector<T>& data, size_t block_size = 1024) + : size(data.size()), block_size(next_poweroftwo(block_size)), fft(2 * next_poweroftwo(block_size)), + temp(fft.temp_size), + segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), + ir_segments((data.size() + next_poweroftwo(block_size) - 1) / next_poweroftwo(block_size)), + input_position(0), position(0) + { + set_data(data); + } + void set_data(const univector<T>& data) + { + univector<T> input(fft.size); + const T ifftsize = reciprocal(T(fft.size)); + for (size_t i = 0; i < ir_segments.size(); i++) + { + segments[i].resize(block_size); + ir_segments[i].resize(block_size, 0); + input = padded(data.slice(i * block_size, block_size)); + + fft.execute(ir_segments[i], input, temp, dft_pack_format::Perm); + process(ir_segments[i], ir_segments[i] * ifftsize); + } + saved_input.resize(block_size, 0); + scratch.resize(block_size * 2); + premul.resize(block_size, 0); + cscratch.resize(block_size); + overlap.resize(block_size, 0); + } + +protected: + void process_expression(T* dest, const expression_pointer<T>& src, size_t size) final + { + univector<T> input = truncate(src, size); + process_buffer(dest, input.data(), input.size()); + } + void process_buffer(T* output, const T* input, size_t size) final + { + size_t processed = 0; + while (processed < size) + { + const size_t processing = std::min(size - processed, block_size - input_position); + internal::builtin_memcpy(saved_input.data() + input_position, input + processed, + processing * sizeof(T)); + + process(scratch, padded(saved_input)); + fft.execute(segments[position], scratch, temp, dft_pack_format::Perm); + + if (input_position == 0) + { + process(premul, zeros()); + for (size_t i = 1; i < segments.size(); i++) + { + const size_t n = (position + i) % segments.size(); + fft_multiply_accumulate(premul, ir_segments[i], segments[n], dft_pack_format::Perm); + } + } + fft_multiply_accumulate(cscratch, premul, ir_segments[0], segments[position], + dft_pack_format::Perm); + + fft.execute(scratch, cscratch, temp, dft_pack_format::Perm); + + process(make_univector(output + processed, processing), + scratch.slice(input_position) + overlap.slice(input_position)); + + input_position += processing; + if (input_position == block_size) + { + input_position = 0; + process(saved_input, zeros()); + + internal::builtin_memcpy(overlap.data(), scratch.data() + block_size, block_size * sizeof(T)); + + position = position > 0 ? position - 1 : segments.size() - 1; + } + + processed += processing; + } + } + + const dft_plan_real<T> fft; + univector<u8> temp; + std::vector<univector<complex<T>>> segments; + std::vector<univector<complex<T>>> ir_segments; + const size_t size; + const size_t block_size; + size_t input_position; + univector<T> saved_input; + univector<complex<T>> premul; + univector<complex<T>> cscratch; + univector<T> scratch; + univector<T> overlap; + size_t position; +}; } CMT_PRAGMA_GNU(GCC diagnostic pop)