commit 8450ed6fb4749d0e03e39da510638bbd0c328404
parent aab2dcf5f8b445ee066dd982687189915d2dcc05
Author: [email protected] <[email protected]>
Date: Wed, 30 Nov 2022 03:24:26 +0000
Normal distribution for random number generator
Diffstat:
2 files changed, 149 insertions(+), 0 deletions(-)
diff --git a/include/kfr/base/random.hpp b/include/kfr/base/random.hpp
@@ -25,6 +25,9 @@
*/
#pragma once
+#include "../math/log_exp.hpp"
+#include "../math/sin_cos.hpp"
+#include "../math/sqrt.hpp"
#include "random_bits.hpp"
#include "state_holder.hpp"
@@ -76,6 +79,20 @@ KFR_INTRINSIC vec<T, N> random_range(random_state& state, T min, T max)
return (tmp * (max - min) + min) >> typebits<T>::bits;
}
+template <size_t N, typename T>
+KFR_INTRINSIC vec<T, N> random_normal(random_state& state, T mu, T sigma)
+{
+ static_assert(std::is_floating_point_v<T>, "random_normal requires floating point type");
+
+ constexpr size_t M = align_up(N, 2); // round up to 2
+
+ vec<T, M> u = random_uniform<T, M>(state);
+
+ vec<T, M / 2> mag = sigma * sqrt(T(-2.0) * log(even(u)));
+ vec<T, M> z = dup(mag) * cossin(c_pi<T, 2> * dupodd(u)) + mu;
+ return slice<0, N>(z);
+}
+
template <typename T, index_t Dims, bool Reference = false>
struct expression_random_uniform : expression_traits_defaults
{
@@ -120,6 +137,29 @@ struct expression_random_range : expression_traits_defaults
}
};
+template <typename T, index_t Dims, bool Reference = false>
+struct expression_random_normal : expression_traits_defaults
+{
+ using value_type = T;
+ constexpr static size_t dims = Dims;
+ constexpr static shape<dims> get_shape(const expression_random_normal&)
+ {
+ return shape<dims>(infinite_size);
+ }
+ constexpr static shape<dims> get_shape() { return shape<dims>(infinite_size); }
+
+ mutable state_holder<random_state, Reference> state;
+ T sigma{ 1 };
+ T mu{ 0 };
+
+ template <size_t N, index_t VecAxis>
+ friend KFR_INTRINSIC vec<T, N> get_elements(const expression_random_normal& self, shape<Dims>,
+ axis_params<VecAxis, N>)
+ {
+ return random_normal<N, T>(*self.state, self.mu, self.sigma);
+ }
+};
+
/// @brief Returns expression that returns pseudorandom values. Copies the given generator
template <typename T, index_t Dims = 1>
KFR_FUNCTION expression_random_uniform<T, Dims> gen_random_uniform(const random_state& state)
@@ -170,6 +210,33 @@ KFR_FUNCTION expression_random_range<T, Dims> gen_random_range(T min, T max)
}
#endif
+/// @brief Returns expression that returns pseudorandom values from normal (gaussian) distribution. Copies the
+/// given generator
+template <typename T, index_t Dims = 1>
+KFR_FUNCTION expression_random_normal<T, Dims> gen_random_normal(const random_state& state, T sigma = 1,
+ T mu = 0)
+{
+ return { {}, state, sigma, mu };
+}
+
+/// @brief Returns expression that returns pseudorandom values from normal (gaussian) distribution. References
+/// the given generator. Use std::ref(gen) to force this overload
+template <typename T, index_t Dims = 1>
+KFR_FUNCTION expression_random_normal<T, Dims, true> gen_random_normal(
+ std::reference_wrapper<random_state> state, T sigma = 1, T mu = 0)
+{
+ return { {}, state, sigma, mu };
+}
+
+#ifndef KFR_DISABLE_READCYCLECOUNTER
+/// @brief Returns expression that returns pseudorandom values from normal (gaussian) distribution
+template <typename T, index_t Dims = 1>
+KFR_FUNCTION expression_random_normal<T, Dims> gen_random_normal(T sigma = 1, T mu = 0)
+{
+ return { {}, random_init(), sigma, mu };
+}
+#endif
+
} // namespace CMT_ARCH_NAME
} // namespace kfr
diff --git a/tests/unit/base/random.cpp b/tests/unit/base/random.cpp
@@ -72,5 +72,87 @@ TEST(gen_random_range)
CHECK(maxof(v) <= fbase(1.0));
// println(mean(v));
}
+
+template <size_t Bins, typename E, typename TCount = uint32_t>
+struct expression_histogram : public expression_with_traits<E>
+{
+ size_t size;
+ using vector_type = univector<TCount, Bins == 0 ? tag_dynamic_vector : Bins>;
+ mutable vector_type values{};
+
+ using expression_with_traits<E>::expression_with_traits;
+
+ KFR_MEM_INTRINSIC expression_histogram(E&& e, size_t steps) : expression_with_traits<E>{ std::forward<E>(e) }
+ {
+ if constexpr (Bins == 0)
+ {
+ values = vector_type(steps, 0);
+ }
+ }
+
+ KFR_MEM_INTRINSIC TCount operator[](size_t n) const
+ {
+ KFR_LOGIC_CHECK(n < values.size() - 2, "n is outside histogram size");
+ return values[1 + n];
+ }
+ KFR_MEM_INTRINSIC TCount below() const { return values.front(); }
+ KFR_MEM_INTRINSIC TCount above() const { return values.back(); }
+ KFR_MEM_INTRINSIC univector_ref<const TCount> histogram() const
+ {
+ return values.slice(1, values.size());
+ }
+
+ using value_type = typename expression_with_traits<E>::value_type;
+
+ template <index_t Axis, size_t N>
+ friend KFR_INTRINSIC vec<value_type, N> get_elements(const expression_histogram& self,
+ const shape<expression_with_traits<E>::dims>& index,
+ const axis_params<Axis, N>& sh)
+ {
+ vec<value_type, N> v = get_elements(self.first(), index, sh);
+ for (size_t i = 0; i < N; ++i)
+ {
+ int64_t n = 1 + std::floor(v[i] * (self.values.size() - 2));
+ ++self.values[clamp(n, 0, self.values.size() - 1)];
+ }
+ return v;
+ }
+};
+
+template <typename E, typename TCount = uint32_t>
+KFR_INTRINSIC expression_histogram<0, E, TCount> histogram(E&& expr, size_t bins)
+{
+ return { std::forward<E>(expr), bins };
+}
+
+template <size_t Bins, typename E, typename TCount = uint32_t>
+KFR_INTRINSIC expression_histogram<Bins, E, TCount> histogram(E&& expr)
+{
+ return { std::forward<E>(expr), Bins };
+}
+
+TEST(random_normal)
+{
+ random_state gen = random_init(1, 2, 3, 4);
+ vec<fbase, 12> r = random_normal<12, fbase>(gen, 0.0, 1.0);
+ println(r);
+ r = random_normal<12, fbase>(gen, 0.0, 1.0);
+ vec<fbase, 11> r2 = random_normal<11, fbase>(gen, 0.0, 1.0);
+ println(r2);
+
+ expression_histogram h = histogram<20>(gen_random_normal<double>() * 0.15 + 0.5);
+ render(truncate(h, 1000));
+ println(h.below());
+ println(h.histogram());
+ println(h.above());
+ render(truncate(h, 10000));
+ println(h.below());
+ println(h.histogram());
+ println(h.above());
+ render(truncate(h, 100000));
+ println(h.below());
+ println(h.histogram());
+ println(h.above());
+}
} // namespace CMT_ARCH_NAME
} // namespace kfr