commit f10ea53776425ae4ac173bad6144b32908e91e4f
parent 55e7bb014cd21b1fb72d6a9e171023cef1656f0e
Author: d.levin256@gmail.com <d.levin256@gmail.com>
Date: Fri, 8 Dec 2023 15:59:37 +0000
matrix_transpose
Diffstat:
5 files changed, 570 insertions(+), 199 deletions(-)
diff --git a/include/kfr/base/transpose.hpp b/include/kfr/base/transpose.hpp
@@ -35,66 +35,62 @@ namespace kfr
{
inline namespace CMT_ARCH_NAME
{
-/// @brief Matrix transpose
-template <size_t group = 1, typename T, index_t Dims>
-void matrix_transpose(T* out, const T* in, shape<Dims> shape);
-
-/// @brief Matrix transpose (complex)
-template <size_t group = 1, typename T, index_t Dims>
-void matrix_transpose(complex<T>* out, const complex<T>* in, shape<Dims> shape);
namespace internal
{
-template <size_t group = 1, typename T, size_t N>
-void matrix_transpose_block_one(T* out, const T* in, size_t i, size_t stride)
+template <typename T, size_t N, index_t Dims>
+void matrix_transpose(vec<T, N>* out, const vec<T, N>* in, shape<Dims> tshape);
+
+template <typename T, size_t width, size_t N>
+void matrix_transpose_block_one(vec<T, N>* out, const vec<T, N>* in, size_t i, size_t stride)
{
- if constexpr (N == 1)
+ if constexpr (width == 1)
{
- write(out + group * i, kfr::read<group>(in + group * i));
+ write(ptr_cast<T>(out + i), kfr::read<N>(ptr_cast<T>(in + i)));
}
else
{
- vec<T, (group * N * N)> vi = read_group<N, N, group>(in + group * i, stride);
- vi = transpose<N, group>(vi);
- write_group<N, N, group>(out + group * i, stride, vi);
+ vec<T, (N * width * width)> vi = read_group<width, width, N>(ptr_cast<T>(in + i), stride);
+ vi = transpose<width, N>(vi);
+ write_group<width, width, N>(ptr_cast<T>(out + i), stride, vi);
}
}
-template <size_t group = 1, typename T, size_t N>
-void matrix_transpose_block_two(T* out, const T* in, size_t i, size_t j, size_t stride)
+template <typename T, size_t width, size_t N>
+void matrix_transpose_block_two(vec<T, N>* out, const vec<T, N>* in, size_t i, size_t j, size_t stride)
{
- if constexpr (N == 1)
+ if constexpr (width == 1)
{
- vec<T, group> vi = kfr::read<group>(in + group * i);
- vec<T, group> vj = kfr::read<group>(in + group * j);
- write(out + group * i, vj);
- write(out + group * j, vi);
+ vec<T, N> vi = kfr::read<N>(ptr_cast<T>(in + i));
+ vec<T, N> vj = kfr::read<N>(ptr_cast<T>(in + j));
+ write(ptr_cast<T>(out + i), vj);
+ write(ptr_cast<T>(out + j), vi);
}
else
{
- vec<T, (group * N * N)> vi = read_group<N, N, group>(in + group * i, stride);
- vec<T, (group * N * N)> vj = read_group<N, N, group>(in + group * j, stride);
- vi = transpose<N, group>(vi);
- vj = transpose<N, group>(vj);
- write_group<N, N, group>(out + group * i, stride, vj);
- write_group<N, N, group>(out + group * j, stride, vi);
+ vec<T, (N * width * width)> vi = read_group<width, width, N>(ptr_cast<T>(in + i), stride);
+ vec<T, (N * width * width)> vj = read_group<width, width, N>(ptr_cast<T>(in + j), stride);
+ vi = transpose<width, N>(vi);
+ vj = transpose<width, N>(vj);
+ write_group<width, width, N>(ptr_cast<T>(out + i), stride, vj);
+ write_group<width, width, N>(ptr_cast<T>(out + j), stride, vi);
}
}
-template <size_t group = 1, typename T>
-void matrix_transpose_square_small(T* out, const T* in, size_t n)
+template <typename T, size_t N>
+void matrix_transpose_square_small(vec<T, N>* out, const vec<T, N>* in, size_t n)
{
cswitch(csizeseq<6, 1>, n, // 1, 2, 3, 4, 5 or 6
[&](auto n_) CMT_INLINE_LAMBDA
{
constexpr size_t n = CMT_CVAL(n_);
- write(out, transpose<n, group>(kfr::read<n * n * group>(in)));
+ write(ptr_cast<T>(out), transpose<n, N>(kfr::read<n * n * N>(ptr_cast<T>(in))));
});
}
-template <size_t group = 1, typename T>
-void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
+template <typename T, size_t N>
+void matrix_transpose_square(vec<T, N>* out, const vec<T, N>* in, size_t n, size_t stride)
{
#if 1
constexpr size_t width = 4;
@@ -106,7 +102,7 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (; i < nw; i += width)
{
- matrix_transpose_block_one<group, T, width>(out, in, istridei, stride);
+ matrix_transpose_block_one<T, width>(out, in, istridei, stride);
size_t j = i + width;
size_t istridej = istridei + width;
@@ -114,7 +110,7 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (; j < nw; j += width)
{
- matrix_transpose_block_two<group, T, width>(out, in, istridej, jstridei, stride);
+ matrix_transpose_block_two<T, width>(out, in, istridej, jstridei, stride);
istridej += width;
jstridei += wstride;
}
@@ -124,7 +120,7 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (size_t ii = i; ii < i + width; ++ii)
{
- matrix_transpose_block_two<group, T, 1>(out, in, istridej, jstridei, stride);
+ matrix_transpose_block_two<T, 1>(out, in, istridej, jstridei, stride);
istridej += stride;
jstridei += 1;
}
@@ -137,11 +133,11 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (; i < n; ++i)
{
- matrix_transpose_block_one<group, T, 1>(out, in, i * stride + i, stride);
+ matrix_transpose_block_one<T, 1>(out, in, i * stride + i, stride);
CMT_LOOP_NOUNROLL
for (size_t j = i + 1; j < n; ++j)
{
- matrix_transpose_block_two<group, T, 1>(out, in, i * stride + j, j * stride + i, stride);
+ matrix_transpose_block_two<T, 1>(out, in, i * stride + j, j * stride + i, stride);
}
}
#else
@@ -152,13 +148,13 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (; i < nw; i += width)
{
- matrix_transpose_block_one<group, T, width>(out, in, i * stride + i, stride);
+ matrix_transpose_block_one<T, width>(out, in, i * stride + i, stride);
size_t j = i + width;
CMT_LOOP_NOUNROLL
for (; j < nw; j += width)
{
- matrix_transpose_block_two<group, T, width>(out, in, i * stride + j, j * stride + i, stride);
+ matrix_transpose_block_two<T, width>(out, in, i * stride + j, j * stride + i, stride);
}
CMT_LOOP_NOUNROLL
for (; j < n; ++j)
@@ -166,7 +162,7 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (size_t ii = i; ii < i + width; ++ii)
{
- matrix_transpose_block_two<group, T, 1>(out, in, ii * stride + j, j * stride + ii, stride);
+ matrix_transpose_block_two<T, 1>(out, in, ii * stride + j, j * stride + ii, stride);
}
}
}
@@ -174,91 +170,431 @@ void matrix_transpose_square(T* out, const T* in, size_t n, size_t stride)
CMT_LOOP_NOUNROLL
for (; i < n; ++i)
{
- matrix_transpose_block_one<group, T, 1>(out, in, i * stride + i, stride);
+ matrix_transpose_block_one<T, 1>(out, in, i * stride + i, stride);
CMT_LOOP_NOUNROLL
for (size_t j = i + 1; j < n; ++j)
{
- matrix_transpose_block_two<group, T, 1>(out, in, i * stride + j, j * stride + i, stride);
+ matrix_transpose_block_two<T, 1>(out, in, i * stride + j, j * stride + i, stride);
}
}
#endif
}
-template <size_t group = 1, typename T>
-void matrix_transpose_any(T* out, const T* in, size_t rows, size_t cols)
+template <typename T, size_t N>
+CMT_ALWAYS_INLINE void do_reverse(vec<T, N>* first, vec<T, N>* last)
{
- // 1. transpose square sub-matrix
- const size_t side = std::min(cols, rows);
- matrix_transpose_square<group>(out, in, side, cols);
+ constexpr size_t width = vector_capacity<T> / 4 / N;
+ for (; first + width - 1 < last - width; first += width, last -= width)
+ {
+ vec<T, (N * width)> a = read<N * width>(first);
+ vec<T, (N * width)> b = read<N * width>(last - width);
+ write(first, reverse<N>(b));
+ write(last - width, reverse<N>(a));
+ }
+ for (; first < last; first += 1, last -= 1)
+ {
+ vec<T, N> a = read<N>(ptr_cast<T>(first));
+ vec<T, N> b = read<N>(ptr_cast<T>(last - 1));
+ write(ptr_cast<T>(first), b);
+ write(ptr_cast<T>(last - 1), a);
+ }
+}
+
+template <typename T, size_t N>
+CMT_ALWAYS_INLINE void ranges_swap(vec<T, N>* x, vec<T, N>* y, size_t size)
+{
+ block_process(size, csizes<const_max(vector_capacity<T> / 4 / N, 2), 1>,
+ [x, y](size_t index, auto w) CMT_INLINE_LAMBDA
+ {
+ constexpr size_t width = CMT_CVAL(w);
+ vec<T, N* width> xx = read<N * width>(ptr_cast<T>(x + index));
+ vec<T, N* width> yy = read<N * width>(ptr_cast<T>(y + index));
+ write(ptr_cast<T>(x + index), yy);
+ write(ptr_cast<T>(y + index), xx);
+ });
+}
+
+template <typename T>
+CMT_ALWAYS_INLINE void do_swap(T* arr, size_t a, size_t b, size_t k)
+{
+ ranges_swap(arr + a, arr + b, k);
+}
+template <typename T>
+CMT_ALWAYS_INLINE void do_block_swap(T* arr, size_t k, size_t n)
+{
+ if (k == 0 || k == n)
+ return;
- if (cols > rows)
+ for (;;)
{
- // 2. copy remaining
- size_t remaining = cols - rows;
- if (in != out)
+ if (k == n - k)
{
- for (size_t r = 0; r < rows; ++r)
- {
- builtin_memcpy(out + group * (side + r * cols), in + group * (side + r * cols),
- group * remaining * sizeof(T));
- }
+ do_swap(arr, 0, n - k, k);
+ return;
+ }
+ else if (k < n - k)
+ {
+ do_swap(arr, 0, n - k, k);
+ n = n - k;
+ }
+ else
+ {
+ do_swap(arr, 0, k, n - k);
+ arr += n - k;
+ const size_t newk = 2 * k - n;
+ n = k;
+ k = newk;
}
+ }
+}
+
+template <typename T, size_t N>
+CMT_ALWAYS_INLINE void range_rotate(vec<T, N>* first, vec<T, N>* middle, vec<T, N>* last)
+{
+#ifndef KFR_T_REV
+ do_block_swap(first, middle - first, last - first);
+#else
+ do_reverse(first, middle);
+ do_reverse(middle, last);
+ do_reverse(first, last);
+#endif
+}
- // 3. shift rows
- auto* p = ptr_cast<vec<T, group>>(out) + side;
- for (size_t r = 0; r + 1 < rows; ++r)
+struct matrix_size
+{
+ size_t rows;
+ size_t cols;
+};
+
+template <typename T, size_t N>
+void matrix_transpose_copy(vec<T, N>* out, const vec<T, N>* in, matrix_size size, matrix_size done)
+{
+ if (size.cols != done.cols)
+ {
+ for (size_t r = 0; r < size.rows; ++r)
+ builtin_memcpy(out + r * size.cols + done.cols, //
+ in + r * size.cols + done.cols, //
+ (size.cols - done.cols) * N * sizeof(T));
+ }
+
+ for (size_t r = done.rows; r < size.rows; ++r)
+ builtin_memcpy(out + r * size.cols, //
+ in + r * size.cols, //
+ (size.cols) * N * sizeof(T));
+}
+
+template <typename T, size_t N>
+void matrix_transpose_shift_rows(vec<T, N>* out, size_t done, matrix_size size)
+{
+ const size_t remaining = size.cols - done;
+ vec<T, N>* p = out + done;
+ for (size_t r = 1; r < size.rows; ++r)
+ {
+ range_rotate(p, p + r * remaining, p + done + r * remaining);
+ p += done;
+ }
+}
+
+template <typename T, size_t N>
+void matrix_transpose_shift_cols(vec<T, N>* out, size_t done, matrix_size size)
+{
+ const size_t remaining = size.rows - done;
+ vec<T, N>* p = out + done * (size.cols - 1);
+ for (size_t c = size.cols - 1; c >= 1; --c)
+ {
+ range_rotate(p, p + done, p + done + c * remaining);
+ p -= done;
+ }
+}
+
+class matrix_cycles
+{
+public:
+ matrix_cycles(const matrix_cycles&) = delete;
+ matrix_cycles(matrix_cycles&&) = delete;
+ matrix_cycles& operator=(const matrix_cycles&) = delete;
+ matrix_cycles& operator=(matrix_cycles&&) = delete;
+
+ CMT_INLINE_MEMBER explicit matrix_cycles(shape<2> size) : size(size), flat_size(size.product())
+ {
+ size_t bits = (flat_size + 1) / 2;
+ size_t words = (bits + word_bits - 1) / word_bits;
+ if (words <= std::size(on_stack))
+ data = on_stack;
+ else
+ data = new word_t[words];
+ builtin_memset(data, 0, sizeof(word_t) * words);
+ }
+
+ ~matrix_cycles()
+ {
+ if (data != on_stack)
+ delete[] data;
+ }
+
+ size_t next_cycle_origin(size_t origin = 0)
+ {
+ for (; origin < (flat_size + 1) / 2; ++origin)
{
- std::rotate(p, p + remaining + r * remaining, p + side + remaining + r * remaining);
- p += side;
+ if (!test_and_set(origin))
+ return origin;
}
- // 4. transpose remainder
- matrix_transpose<group>(out + group * side * side, out + group * side * side,
- shape{ side, remaining });
+ return static_cast<size_t>(-1);
}
- else // if (cols < rows)
+
+ template <bool first_pass = true, typename Start, typename Iterate, typename Stop>
+ void iterate(size_t origin, Start&& start, Iterate&& iterate, Stop&& stop, bool skip_fixed = false)
{
- // 2. copy remaining
- size_t remaining = rows - cols;
- if (in != out)
+ shape<2> transposed_size = size.transpose();
+ size_t next = transposed_size.to_flat(size.from_flat(origin).transpose());
+ if (next == origin)
+ {
+ bool is_fixed = next != flat_size - 1 - next;
+ if (!(is_fixed && skip_fixed))
+ {
+ start(origin, flat_size - 1 - origin, is_fixed);
+ stop(next != flat_size - 1 - next);
+ }
+ }
+ else
{
- for (size_t r = 0; r < remaining; ++r)
+ size_t inv_next = flat_size - 1 - next;
+ size_t min_next = std::min(next, inv_next);
+ if (min_next == origin)
+ {
+ bool is_fixed = next == origin;
+ if (!(is_fixed && skip_fixed))
+ {
+ start(origin, flat_size - 1 - origin, next == origin);
+ stop(next == origin);
+ }
+ }
+ else
{
- builtin_memcpy(out + group * ((cols + r) * cols), in + group * ((cols + r) * cols),
- group * cols * sizeof(T));
+ start(origin, flat_size - 1 - origin, false);
+ for (;;)
+ {
+ if constexpr (first_pass)
+ {
+ set(min_next);
+ }
+ iterate(next, inv_next);
+ next = transposed_size.to_flat(size.from_flat(next).transpose());
+ inv_next = flat_size - 1 - next;
+ min_next = std::min(next, inv_next);
+ if (min_next == origin)
+ {
+ stop(next == origin);
+ break;
+ }
+ }
}
}
+ }
- // 3. transpose remainder
+private:
+ using word_t = uint32_t;
+ constexpr static size_t word_bits = sizeof(word_t) * 8;
+ shape<2> size;
+ size_t flat_size;
+ word_t* data;
+CMT_PRAGMA_GNU(GCC diagnostic push)
+CMT_PRAGMA_GNU(GCC diagnostic ignored "-Wattributes")
+ [[maybe_unused]] uint8_t cache_line__[64];
+CMT_PRAGMA_GNU(GCC diagnostic pop)
+ alignas(16) word_t on_stack[1024];
+
+ CMT_INLINE_MEMBER void set(size_t index)
+ {
+ word_t& word = data[index / word_bits];
+ word_t mask = 1u << (index % word_bits);
+ word |= mask;
+ }
+
+ CMT_INLINE_MEMBER bool test_and_set(size_t index)
+ {
+ word_t& word = data[index / word_bits];
+ word_t mask = 1u << (index % word_bits);
+ if (word & mask)
+ return true;
+ word |= mask;
+ return false;
+ }
+};
+
+template <typename T, size_t N, bool horizontal>
+CMT_INTRINSIC void matrix_merge_squares_fast(vec<T, N>* out, size_t side, size_t squares, matrix_size size,
+ size_t stride, cbool_t<horizontal>)
+{
+ if constexpr (!horizontal)
+ {
+ stride = stride * side;
+ }
+ for (size_t i = 0; i < side; ++i)
+ {
+ for (size_t j = i + 1; j < side; ++j)
+ {
+ size_t index1 = i * stride + j * side;
+ size_t index2 = j * stride + i * side;
+ ranges_swap(out + index1, out + index2, side);
+ }
+ }
+}
+
+static CMT_INTRINSIC size_t matrix_offset(size_t flat_index, size_t side, size_t stride1, size_t stride2)
+{
+ size_t i = flat_index / side;
+ size_t j = flat_index % side;
+ return i * stride1 + j * stride2;
+}
+
+template <typename T, size_t N, bool horizontal>
+void matrix_merge_squares(vec<T, N>* out, size_t side, size_t squares, matrix_size size, size_t stride,
+ cbool_t<horizontal>)
+{
+ if (side == squares)
+ {
+ return matrix_merge_squares_fast(out, side, squares, size, stride, cbool<horizontal>);
+ }
+ if constexpr (!horizontal)
+ {
+ stride = stride * side;
+ }
+ shape sh = horizontal ? shape{ squares, side } : shape{ side, squares };
+ size_t flat_side = sh[0];
+ matrix_cycles cycles(sh);
- matrix_transpose<group>(out + group * side * side, out + group * side * side,
- shape{ remaining, cols });
+ size_t origin = 0;
+ do
+ {
+ block_process(
+ side, csizes<const_max(2, vector_capacity<T> / 8 / N), 1>,
+ [&](size_t offset, auto width_)
+ {
+ constexpr size_t width = CMT_CVAL(width_);
- // 4. shift cols
- auto* p = ptr_cast<vec<T, group>>(out) + side * (cols - 1);
- for (size_t c = cols - 1; c >= 1;)
+ vec<T, width * N> temp;
+ vec<T, width * N> temp_inv;
+ size_t previous;
+ size_t previous_inv;
+ cycles.iterate(
+ origin,
+ [&](size_t origin, size_t origin_inv, bool /* fixed */) CMT_INLINE_LAMBDA
+ {
+#ifdef CMT_COMPILER_IS_MSVC
+ constexpr size_t width = CMT_CVAL(width_);
+#endif
+ temp = read<width * N>(
+ ptr_cast<T>(out + matrix_offset(origin, flat_side, stride, side) + offset));
+ temp_inv = read<width * N>(
+ ptr_cast<T>(out + matrix_offset(origin_inv, flat_side, stride, side) + offset));
+ previous = origin;
+ previous_inv = origin_inv;
+ },
+ [&](size_t current, size_t current_inv) CMT_INLINE_LAMBDA
+ {
+#ifdef CMT_COMPILER_IS_MSVC
+ constexpr size_t width = CMT_CVAL(width_);
+#endif
+ vec<T, (width * N)> val = read<width * N>(
+ ptr_cast<T>(out + matrix_offset(current, flat_side, stride, side) + offset));
+ vec<T, (width * N)> val_inv = read<width * N>(
+ ptr_cast<T>(out + matrix_offset(current_inv, flat_side, stride, side) + offset));
+ write(ptr_cast<T>(out + matrix_offset(previous, flat_side, stride, side) + offset),
+ val);
+ write(
+ ptr_cast<T>(out + matrix_offset(previous_inv, flat_side, stride, side) + offset),
+ val_inv);
+ previous = current;
+ previous_inv = current_inv;
+ },
+ [&](bool symmetric) CMT_INLINE_LAMBDA
+ {
+ if (!symmetric)
+ std::swap(temp, temp_inv);
+ write(ptr_cast<T>(out + matrix_offset(previous, flat_side, stride, side) + offset),
+ temp);
+ write(
+ ptr_cast<T>(out + matrix_offset(previous_inv, flat_side, stride, side) + offset),
+ temp_inv);
+ },
+ true);
+ });
+ origin = cycles.next_cycle_origin(origin + 1);
+ } while (origin != static_cast<size_t>(-1));
+}
+
+template <typename T, size_t N>
+void matrix_transpose_any(vec<T, N>* out, const vec<T, N>* in, matrix_size size)
+{
+ if (size.cols > size.rows)
+ {
+ // 1. transpose square sub-matrices
+ const size_t side = size.rows;
+ const size_t squares = size.cols / side;
+ for (size_t i = 0; i < squares; ++i)
{
- --c;
- std::rotate(p, p + side, p + (side + remaining + c * remaining));
- p -= side;
+ matrix_transpose_square(out + i * side, in + i * side, side, size.cols);
}
+ if (squares > 1)
+ matrix_merge_squares(out, side, squares, size, size.cols, ctrue);
+ const size_t done = side * squares;
+ if (in != out)
+ matrix_transpose_copy(out, in, size, { side, done });
+
+ const size_t remaining = size.cols - done;
+ if (remaining == 0)
+ return;
+
+ // 2. shift rows
+ matrix_transpose_shift_rows(out, done, size);
+
+ // 3. transpose remainder
+ internal::matrix_transpose(out + done * size.rows, out + done * size.rows,
+ shape{ size.rows, remaining });
+ }
+ else // if (cols <= rows)
+ {
+ // 1. transpose square sub-matrices
+ const size_t side = size.cols;
+ const size_t squares = size.rows / side;
+ for (size_t i = 0; i < squares; ++i)
+ {
+ matrix_transpose_square(out + i * side * side, in + i * side * side, side, size.cols);
+ }
+ if (squares > 1)
+ matrix_merge_squares(out, side, squares, size, size.cols, cfalse);
+ const size_t done = side * squares;
+ if (in != out)
+ matrix_transpose_copy(out, in, size, { done, side });
+
+ const size_t remaining = size.rows - done;
+ if (remaining == 0)
+ return;
+
+ // 2. transpose remainder
+ internal::matrix_transpose(out + done * size.cols, out + done * size.cols,
+ shape{ remaining, size.cols });
+
+ // 3. shift cols
+ matrix_transpose_shift_cols(out, done, size);
}
}
-template <size_t group = 1, typename T>
+template <typename T>
KFR_INTRINSIC void matrix_transpose_noop(T* out, const T* in, size_t total)
{
if (out == in)
return;
- builtin_memcpy(out, in, total * sizeof(T) * group);
+ builtin_memcpy(out, in, total * sizeof(T));
}
-} // namespace internal
-template <size_t group, typename T, index_t Dims>
-void matrix_transpose(T* out, const T* in, shape<Dims> tshape)
+template <typename T, size_t N, index_t Dims>
+void matrix_transpose(vec<T, N>* out, const vec<T, N>* in, shape<Dims> tshape)
{
if constexpr (Dims <= 1)
{
- return internal::matrix_transpose_noop<group>(out, in, tshape.product());
+ return internal::matrix_transpose_noop(out, in, tshape.product());
}
else if constexpr (Dims == 2)
{
@@ -266,35 +602,41 @@ void matrix_transpose(T* out, const T* in, shape<Dims> tshape)
const index_t cols = tshape[1];
if (cols == 1 || rows == 1)
{
- return internal::matrix_transpose_noop<group>(out, in, tshape.product());
+ return internal::matrix_transpose_noop(out, in, tshape.product());
}
// TODO: special cases for tall or wide matrices
if (cols == rows)
{
if (cols <= 6)
- return internal::matrix_transpose_square_small<group>(out, in, cols);
- return internal::matrix_transpose_square<group>(out, in, cols, cols);
+ return internal::matrix_transpose_square_small(out, in, cols);
+ return internal::matrix_transpose_square(out, in, cols, cols);
}
- return internal::matrix_transpose_any<group>(out, in, rows, cols);
+ return internal::matrix_transpose_any(out, in, { rows, cols });
}
else
{
shape<Dims - 1> x = tshape.template slice<0, Dims - 1>();
index_t xproduct = x.product();
index_t y = tshape.back();
- matrix_transpose<group>(out, in, shape<2>{ xproduct, y });
+ internal::matrix_transpose(out, in, shape<2>{ xproduct, y });
for (index_t i = 0; i < y; ++i)
{
- matrix_transpose<group>(out, out, x);
- out += group * xproduct;
+ internal::matrix_transpose(out, out, x);
+ out += xproduct;
}
}
}
+} // namespace internal
-template <size_t group, typename T, index_t Dims>
-void matrix_transpose(complex<T>* out, const complex<T>* in, shape<Dims> shape)
+/// @brief Matrix transpose.
+/// Accepts vec, complex and other compound types
+template <typename T, index_t Dims>
+void matrix_transpose(T* out, const T* in, shape<Dims> shape)
{
- return matrix_transpose<2 * group>(ptr_cast<T>(out), ptr_cast<T>(in), shape);
+ using U = typename compound_type_traits<T>::deep_subtype;
+ constexpr size_t width = compound_type_traits<T>::deep_width;
+ return internal::matrix_transpose<U, width, Dims>(ptr_cast<vec<U, width>>(out),
+ ptr_cast<vec<U, width>>(in), shape);
}
} // namespace CMT_ARCH_NAME
diff --git a/sources.cmake b/sources.cmake
@@ -492,6 +492,7 @@ set(
${PROJECT_SOURCE_DIR}/tests/unit/base/simd_expressions.cpp
${PROJECT_SOURCE_DIR}/tests/unit/base/std_ambiguities.cpp
${PROJECT_SOURCE_DIR}/tests/unit/base/tensor.cpp
+ ${PROJECT_SOURCE_DIR}/tests/unit/base/transpose.cpp
${PROJECT_SOURCE_DIR}/tests/unit/base/univector.cpp
${PROJECT_SOURCE_DIR}/tests/unit/dsp/biquad.cpp
${PROJECT_SOURCE_DIR}/tests/unit/dsp/biquad_design.cpp
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
@@ -79,8 +79,7 @@ if (KFR_ENABLE_ASMTEST)
COMMAND objconv -fyasm $<TARGET_FILE:asm_test>)
endif ()
-set(ALL_TESTS_CPP
- ${KFR_UNITTEST_SRC})
+set(ALL_TESTS_CPP ${KFR_UNITTEST_SRC})
if (KFR_ENABLE_DFT)
list(APPEND ALL_TESTS_CPP dft_test.cpp)
diff --git a/tests/unit/base/tensor.cpp b/tests/unit/base/tensor.cpp
@@ -840,107 +840,6 @@ TEST(identity_matrix)
CHECK(trender(identity_matrix<float, 3>{}) == tensor<float, 2>{ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, 1 } });
}
-template <typename T, bool Transposed = false>
-struct expression_test_matrix : public expression_traits_defaults
-{
- shape<2> matrix_shape;
- index_t mark;
- expression_test_matrix(index_t rows, size_t cols, index_t mark = 10000)
- : matrix_shape({ rows, cols }), mark(mark)
- {
- if constexpr (Transposed)
- std::swap(matrix_shape[0], matrix_shape[1]);
- }
-
- using value_type = T;
- constexpr static size_t dims = 2;
- constexpr static shape<2> get_shape(const expression_test_matrix& self) { return self.matrix_shape; }
- constexpr static shape<2> get_shape() { return {}; }
-
- template <index_t Axis, size_t N>
- friend vec<T, N> get_elements(const expression_test_matrix& self, shape<2> index,
- const axis_params<Axis, N>&)
- {
- shape<2> scale{ self.mark, 1 };
- if constexpr (Transposed)
- std::swap(scale[0], scale[1]);
- vec<T, N> result;
- for (size_t i = 0; i < N; ++i)
- {
- result[i] = index[0] * scale[0] + index[1] * scale[1];
- index[Axis] += 1;
- }
- return result;
- }
-};
-
-template <typename T>
-static void test_transpose(size_t rows, size_t cols, size_t mark = 10000)
-{
- tensor<T, 2> t = expression_test_matrix<T>(rows, cols, mark);
-
- tensor<T, 2> t2(shape<2>{ cols, rows });
- univector<T> tt(t.size());
- auto d = tensor<T, 2>(tt.data(), shape{ rows, cols }, nullptr);
- auto d2 = tensor<T, 2>(tt.data(), shape{ cols, rows }, nullptr);
- CHECK(d.data() == d2.data());
- d = expression_test_matrix<T>(rows, cols, mark);
- t2 = -1;
- matrix_transpose(t2.data(), t.data(), shape{ rows, cols });
-
- matrix_transpose(d2.data(), d.data(), shape{ rows, cols });
-
- testo::scope s(as_string("type=", type_name<T>(), " rows=", rows, " cols=", cols));
-
- auto erro = maxof(cabs(t2 - expression_test_matrix<T, true>(rows, cols, mark)));
- CHECK(erro == 0);
-
- auto erri = maxof(cabs(d2 - expression_test_matrix<T, true>(rows, cols, mark)));
- CHECK(erri == 0);
-}
-
-[[maybe_unused]] static void test_transpose_t(size_t rows, size_t cols, size_t mark = 10000)
-{
- test_transpose<float>(rows, cols, mark);
- test_transpose<double>(rows, cols, mark);
- test_transpose<complex<float>>(rows, cols, mark);
- test_transpose<complex<double>>(rows, cols, mark);
-}
-
-TEST(matrix_transpose)
-{
- for (int i = 1; i <= 100; ++i)
- {
- for (int j = 1; j <= 100; ++j)
- {
- test_transpose_t(i, j);
- }
- }
-
- univector<int, 24> x = counter();
- matrix_transpose(x.data(), x.data(), shape{ 2, 3, 4 });
- CHECK(x == univector<int, 24>{ 0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21,
- 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23 });
-
- univector<uint8_t, 120> x2 = counter();
- matrix_transpose(x2.data(), x2.data(), shape{ 2, 3, 4, 5 });
- CHECK(x2 == univector<uint8_t, 120>{ 0, 60, 20, 80, 40, 100, 5, 65, 25, 85, 45, 105, 10, 70, 30,
- 90, 50, 110, 15, 75, 35, 95, 55, 115, 1, 61, 21, 81, 41, 101,
- 6, 66, 26, 86, 46, 106, 11, 71, 31, 91, 51, 111, 16, 76, 36,
- 96, 56, 116, 2, 62, 22, 82, 42, 102, 7, 67, 27, 87, 47, 107,
- 12, 72, 32, 92, 52, 112, 17, 77, 37, 97, 57, 117, 3, 63, 23,
- 83, 43, 103, 8, 68, 28, 88, 48, 108, 13, 73, 33, 93, 53, 113,
- 18, 78, 38, 98, 58, 118, 4, 64, 24, 84, 44, 104, 9, 69, 29,
- 89, 49, 109, 14, 74, 34, 94, 54, 114, 19, 79, 39, 99, 59, 119 });
-
- tensor<int, 1> d{ shape{ 24 } };
- d = counter();
- tensor<int, 3> dd = d.reshape(shape{ 2, 3, 4 });
- tensor<int, 3> ddd = dd.transpose();
- CHECK(trender(ddd.flatten_may_copy()) == tensor<int, 1>{ 0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21,
- 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23 });
-}
-
} // namespace CMT_ARCH_NAME
} // namespace kfr
diff --git a/tests/unit/base/transpose.cpp b/tests/unit/base/transpose.cpp
@@ -0,0 +1,130 @@
+/**
+ * KFR (https://www.kfrlib.com)
+ * Copyright (C) 2016-2023 Dan Cazarin
+ * See LICENSE.txt for details
+ */
+
+#include <kfr/base/basic_expressions.hpp>
+#include <kfr/base/math_expressions.hpp>
+#include <kfr/base/reduce.hpp>
+#include <kfr/base/simd_expressions.hpp>
+#include <kfr/base/tensor.hpp>
+#include <kfr/io/tostring.hpp>
+#include <kfr/simd.hpp>
+
+CMT_PRAGMA_MSVC(warning(push))
+CMT_PRAGMA_MSVC(warning(disable : 5051))
+CMT_PRAGMA_MSVC(warning(disable : 4244))
+
+namespace kfr
+{
+
+inline namespace CMT_ARCH_NAME
+{
+template <typename T, bool Transposed = false>
+struct expression_test_matrix : public expression_traits_defaults
+{
+ shape<2> matrix_shape;
+ index_t mark;
+ expression_test_matrix(index_t rows, size_t cols, index_t mark = 10000)
+ : matrix_shape{ rows, cols }, mark(mark)
+ {
+ if constexpr (Transposed)
+ std::swap(matrix_shape[0], matrix_shape[1]);
+ }
+
+ using value_type = T;
+ constexpr static size_t dims = 2;
+ constexpr static shape<2> get_shape(const expression_test_matrix& self) { return self.matrix_shape; }
+ constexpr static shape<2> get_shape() { return {}; }
+
+ template <index_t Axis, size_t N>
+ friend vec<T, N> get_elements(const expression_test_matrix& self, shape<2> index,
+ const axis_params<Axis, N>&)
+ {
+ shape<2> scale{ self.mark, 1 };
+ if constexpr (Transposed)
+ std::swap(scale[0], scale[1]);
+ vec<T, N> result;
+ for (size_t i = 0; i < N; ++i)
+ {
+ result[i] = index[0] * scale[0] + index[1] * scale[1];
+ index[Axis] += 1;
+ }
+ return result;
+ }
+};
+
+template <typename T>
+static void test_transpose(size_t rows, size_t cols, size_t mark = 10000)
+{
+ tensor<T, 2> t = expression_test_matrix<T>(rows, cols, mark);
+
+ tensor<T, 2> t2(shape<2>{ cols, rows });
+ univector<T> tt(t.size());
+ auto d = tensor<T, 2>(tt.data(), shape{ rows, cols }, nullptr);
+ auto d2 = tensor<T, 2>(tt.data(), shape{ cols, rows }, nullptr);
+ CHECK(d.data() == d2.data());
+ d = expression_test_matrix<T>(rows, cols, mark);
+ t2 = -1;
+ matrix_transpose(t2.data(), t.data(), shape{ rows, cols });
+
+ matrix_transpose(d2.data(), d.data(), shape{ rows, cols });
+
+ testo::scope s(as_string("type=", type_name<T>(), " rows=", rows, " cols=", cols));
+
+ auto erro = maxof(cabs(t2 - expression_test_matrix<T, true>(rows, cols, mark)));
+ CHECK(erro == 0);
+
+ auto erri = maxof(cabs(d2 - expression_test_matrix<T, true>(rows, cols, mark)));
+ CHECK(erri == 0);
+}
+
+[[maybe_unused]] static void test_transpose_t(size_t rows, size_t cols, size_t mark = 10000)
+{
+ test_transpose<float>(rows, cols, mark);
+ test_transpose<double>(rows, cols, mark);
+ test_transpose<complex<float>>(rows, cols, mark);
+ test_transpose<complex<double>>(rows, cols, mark);
+}
+
+TEST(matrix_transpose)
+{
+ constexpr size_t limit = 100;
+
+ for (size_t i = 1; i <= limit; ++i)
+ {
+ for (size_t j = 1; j <= limit; ++j)
+ {
+ test_transpose_t(i, j);
+ }
+ }
+
+ univector<int, 24> x = counter();
+ matrix_transpose(x.data(), x.data(), shape{ 2, 3, 4 });
+ CHECK(x == univector<int, 24>{ 0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21,
+ 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23 });
+
+ univector<uint8_t, 120> x2 = counter();
+ matrix_transpose(x2.data(), x2.data(), shape{ 2, 3, 4, 5 });
+ CHECK(x2 == univector<uint8_t, 120>{ 0, 60, 20, 80, 40, 100, 5, 65, 25, 85, 45, 105, 10, 70, 30,
+ 90, 50, 110, 15, 75, 35, 95, 55, 115, 1, 61, 21, 81, 41, 101,
+ 6, 66, 26, 86, 46, 106, 11, 71, 31, 91, 51, 111, 16, 76, 36,
+ 96, 56, 116, 2, 62, 22, 82, 42, 102, 7, 67, 27, 87, 47, 107,
+ 12, 72, 32, 92, 52, 112, 17, 77, 37, 97, 57, 117, 3, 63, 23,
+ 83, 43, 103, 8, 68, 28, 88, 48, 108, 13, 73, 33, 93, 53, 113,
+ 18, 78, 38, 98, 58, 118, 4, 64, 24, 84, 44, 104, 9, 69, 29,
+ 89, 49, 109, 14, 74, 34, 94, 54, 114, 19, 79, 39, 99, 59, 119 });
+
+ tensor<int, 1> d{ shape{ 24 } };
+ d = counter();
+ tensor<int, 3> dd = d.reshape(shape{ 2, 3, 4 });
+ tensor<int, 3> ddd = dd.transpose();
+ CHECK(trender(ddd.flatten_may_copy()) == tensor<int, 1>{ 0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21,
+ 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23 });
+}
+
+} // namespace CMT_ARCH_NAME
+} // namespace kfr
+
+CMT_PRAGMA_MSVC(warning(pop))