Skip to content

Commit

Permalink
feat: add merge method to ultra::distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
morinim committed Jul 3, 2024
1 parent 9e30485 commit e2ad8a3
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/kernel/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class distribution
[[nodiscard]] T standard_deviation() const;
[[nodiscard]] T variance() const;

bool is_valid() const;
void merge(distribution<T>);

[[nodiscard]] bool is_valid() const;

public: // Serialization
bool load(std::istream &);
Expand Down
51 changes: 51 additions & 0 deletions src/kernel/distribution.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ const std::map<T, std::uintmax_t> &distribution<T>::seen() const
///
/// \f$H(X)=-\sum_{i=1}^n p(x_i) \dot log_b(p(x_i))\f$
///
/// \note
/// We use an offline algorithm
/// (https://en.wikipedia.org/wiki/Online_algorithm).
///
Expand Down Expand Up @@ -296,4 +297,54 @@ bool distribution<T>::is_valid() const
return true;
}

///
/// Updates the this distribution considering data from another distribution.
///
/// \param[in] d2 distribution to be merged with `*this`
///
/// \see
/// - https://math.stackexchange.com/q/453113
/// - https://stats.stackexchange.com/q/43159
///
template<ArithmeticFloatingType T>
void distribution<T>::merge(distribution<T> d2)
{
if (!d2.size())
return;

const auto max1(max());
const auto max2(d2.max());

const auto min1(min());
const auto min2(d2.min());

const auto size1(size());
const auto size2(d2.size());

const auto new_size(size1 + size2);

const auto mean1(mean());
const auto mean2(d2.mean());

const auto new_mean((size1*mean1 + size2*mean2) / new_size);

const auto variance1(variance());
const auto variance2(d2.variance());

const auto new_variance(
((variance1 + mean1*mean1)*size1
+ (variance2 + mean2*mean2)*size2) / new_size
- new_mean*new_mean);

mean_ = new_mean;
m2_ = new_variance * new_size;

max_ = std::max(max1, max2);
min_ = std::min(min1, min2);

size_ = new_size;

seen_.merge(d2.seen_);
}

#endif // include guard
78 changes: 78 additions & 0 deletions src/test/distribution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,84 @@ TEST_CASE("Base")
CHECK(e2 < d.entropy());
}

TEST_CASE("Merge")
{
using namespace ultra;

SUBCASE("Same distribution")
{
const std::vector<std::pair<double, unsigned>> elems =
{
{2.0, 1},
{4.0, 3},
{5.0, 2},
{7.0, 1},
{9.0, 1}
};

distribution<double> d;
for (const auto &e : elems)
for (unsigned n(e.second); n; --n)
d.add(e.first);

const auto mean_before(d.mean());
const auto variance_before(d.variance());
const auto min_before(d.min());
const auto max_before(d.max());

auto d2(d);

d.merge(std::move(d2));

CHECK(d.mean() == doctest::Approx(mean_before));
CHECK(d.min() == doctest::Approx(min_before));
CHECK(d.max() == doctest::Approx(max_before));
CHECK(d.variance() == doctest::Approx(variance_before));
}

SUBCASE("Single element distribution")
{
distribution<double> d1;
d1.add(-1.0);

distribution<double> d2;
d2.add(+1.0);

d1.merge(std::move(d2));

CHECK(d1.mean() == doctest::Approx(0.0));
CHECK(d1.variance() == doctest::Approx(1.0));
CHECK(d1.min() == doctest::Approx(-1.0));
CHECK(d1.max() == doctest::Approx(+1.0));
}

SUBCASE("General case")
{
distribution<double> d, d1, d2;

for (unsigned cycles(100); cycles; --cycles)
{
const auto elem(random::between(-1000.0, 1000.0));
d.add(elem);

if (cycles < 500)
d1.add(elem);
else
d2.add(elem);
}

CHECK(d1.min() >= -1000.0);
CHECK(d1.max() < 1000.0);

d1.merge(std::move(d2));

CHECK(d.mean() == doctest::Approx(d1.mean()));
CHECK(d.min() == doctest::Approx(d1.min()));
CHECK(d.max() == doctest::Approx(d1.max()));
CHECK(d.variance() == doctest::Approx(d1.variance()));
}
}

TEST_CASE("Serialization")
{
using namespace ultra;
Expand Down

0 comments on commit e2ad8a3

Please sign in to comment.