commit 351dc220e12a298c300b9e39531159fb3a606bb2
parent 967642436de0a9d580b7f6bbe4964758fe5fbb25
Author: [email protected] <[email protected]>
Date: Mon, 12 Sep 2016 20:16:04 +0300
Operators for complex numbers of different types
Diffstat:
3 files changed, 94 insertions(+), 4 deletions(-)
diff --git a/include/kfr/base/complex.hpp b/include/kfr/base/complex.hpp
@@ -97,6 +97,48 @@ struct complex
{
return (make_vector(x) / make_vector(y))[0];
}
+
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator+(const complex& x, const U& y)
+ {
+ return static_cast<C>(x) + static_cast<C>(y);
+ }
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator-(const complex& x, const U& y)
+ {
+ return static_cast<C>(x) - static_cast<C>(y);
+ }
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator*(const complex& x, const U& y)
+ {
+ return static_cast<C>(x) * static_cast<C>(y);
+ }
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator/(const complex& x, const U& y)
+ {
+ return static_cast<C>(x) / static_cast<C>(y);
+ }
+
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator+(const U& x, const complex& y)
+ {
+ return static_cast<C>(x) + static_cast<C>(y);
+ }
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator-(const U& x, const complex& y)
+ {
+ return static_cast<C>(x) - static_cast<C>(y);
+ }
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator*(const U& x, const complex& y)
+ {
+ return static_cast<C>(x) * static_cast<C>(y);
+ }
+ template <typename U, KFR_ENABLE_IF(is_number<U>::value), typename C = common_type<complex, U>>
+ KFR_INTRIN friend C operator/(const U& x, const complex& y)
+ {
+ return static_cast<C>(x) / static_cast<C>(y);
+ }
KFR_INTRIN friend complex operator-(const complex& x) { return (-make_vector(x))[0]; }
};
#endif
@@ -610,3 +652,22 @@ KFR_INTRIN internal::expression_function<fn::csqrt, E1> csqrt(E1&& x)
return { fn::csqrt(), std::forward<E1>(x) };
}
}
+
+namespace std
+{
+template <typename T1, typename T2>
+struct common_type<kfr::complex<T1>, kfr::complex<T2>>
+{
+ using type = kfr::complex<typename common_type<T1, T2>::type>;
+};
+template <typename T1, typename T2>
+struct common_type<kfr::complex<T1>, T2>
+{
+ using type = kfr::complex<typename common_type<T1, T2>::type>;
+};
+template <typename T1, typename T2>
+struct common_type<T1, kfr::complex<T2>>
+{
+ using type = kfr::complex<typename common_type<T1, T2>::type>;
+};
+}
diff --git a/include/kfr/base/expression.hpp b/include/kfr/base/expression.hpp
@@ -37,6 +37,9 @@
namespace kfr
{
+template <typename>
+struct complex;
+
constexpr size_t infinite_size = static_cast<size_t>(-1);
constexpr inline size_t size_add(size_t x, size_t y)
@@ -276,12 +279,32 @@ struct expression_scalar : input_expression
}
};
-template <typename T>
-using arg_impl = conditional<is_number<T>::value || is_vec<T>::value,
- expression_scalar<subtype<decay<T>>, compound_type_traits<decay<T>>::width>, T>;
+template <typename, typename T, typename = void>
+struct arg_impl
+{
+ using type = T;
+};
+
+template <typename T1, typename T2>
+struct arg_impl<T1, T2, void_t<enable_if<is_number<T1>::value>>>
+{
+ using type = expression_scalar<T1>;
+};
+
+template <typename T1, typename T2>
+struct arg_impl<complex<T1>, T2>
+{
+ using type = expression_scalar<complex<T1>>;
+};
+
+template <typename T1, typename T2, size_t N>
+struct arg_impl<vec<T1, N>, T2>
+{
+ using type = expression_scalar<T1, N>;
+};
template <typename T>
-using arg = internal::arg_impl<T>;
+using arg = typename internal::arg_impl<decay<T>, T>::type;
template <typename Fn, typename... Args>
struct expression_function : expression<arg<Args>...>
diff --git a/tests/complex_test.cpp b/tests/complex_test.cpp
@@ -195,6 +195,12 @@ TEST(static_tests)
assert_is_same<ftype<complex<i64>>, complex<f64>>();
assert_is_same<ftype<vec<complex<i32>, 4>>, vec<complex<f32>, 4>>();
assert_is_same<ftype<vec<complex<i64>, 8>>, vec<complex<f64>, 8>>();
+
+ assert_is_same<kfr::internal::arg<int>, kfr::internal::expression_scalar<int, 1>>();
+ assert_is_same<kfr::internal::arg<complex<int>>,
+ kfr::internal::expression_scalar<kfr::complex<int>, 1>>();
+
+ assert_is_same<common_type<complex<int>, double>, complex<double>>();
}
int main(int argc, char** argv)