Skip to content

Commit

Permalink
Rebase.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 15, 2023
1 parent 31373a9 commit 941c51e
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 18 deletions.
17 changes: 8 additions & 9 deletions src/common/stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
#ifndef XGBOOST_COMMON_STATS_H_
#define XGBOOST_COMMON_STATS_H_
#include <algorithm>
#include <iterator>
#include <iterator> // for distance
#include <limits>
#include <vector>

#include "algorithm.h" // for StableSort
#include "common.h" // AssertGPUSupport, OptionalWeights
#include "optional_weight.h" // OptionalWeights
#include "transform_iterator.h" // MakeIndexTransformIter
Expand All @@ -30,7 +31,7 @@ namespace common {
* \return The result of interpolation.
*/
template <typename Iter>
float Quantile(double alpha, Iter const& begin, Iter const& end) {
float Quantile(Context const* ctx, double alpha, Iter const& begin, Iter const& end) {
CHECK(alpha >= 0 && alpha <= 1);
auto n = static_cast<double>(std::distance(begin, end));
if (n == 0) {
Expand All @@ -43,9 +44,8 @@ float Quantile(double alpha, Iter const& begin, Iter const& end) {
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
} else {
XGBOOST_PARALLEL_STABLE_SORT(
sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
}

auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
Expand Down Expand Up @@ -76,7 +76,7 @@ float Quantile(double alpha, Iter const& begin, Iter const& end) {
* weighted quantile with interpolation.
*/
template <typename Iter, typename WeightIter>
float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter w_begin) {
float WeightedQuantile(Context const* ctx, double alpha, Iter begin, Iter end, WeightIter w_begin) {
auto n = static_cast<double>(std::distance(begin, end));
if (n == 0) {
return std::numeric_limits<float>::quiet_NaN();
Expand All @@ -87,9 +87,8 @@ float WeightedQuantile(double alpha, Iter begin, Iter end, WeightIter w_begin) {
std::stable_sort(sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
} else {
XGBOOST_PARALLEL_STABLE_SORT(
sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
StableSort(ctx, sorted_idx.begin(), sorted_idx.end(),
[&](std::size_t l, std::size_t r) { return *(begin + l) < *(begin + r); });
}

auto val = [&](size_t i) { return *(begin + sorted_idx[i]); };
Expand Down
4 changes: 2 additions & 2 deletions src/objective/adaptive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit

float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(alpha, iter, iter + h_row_set.size());
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
Expand Down
16 changes: 9 additions & 7 deletions tests/cpp/common/test_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,34 @@
namespace xgboost {
namespace common {
TEST(Stats, Quantile) {
Context ctx;
{
linalg::Tensor<float, 1> arr({20.f, 0.f, 15.f, 50.f, 40.f, 0.f, 35.f}, {7}, Context::kCpuId);
std::vector<size_t> index{0, 2, 3, 4, 6};
auto h_arr = arr.HostView();
auto beg = MakeIndexTransformIter([&](size_t i) { return h_arr(index[i]); });
auto end = beg + index.size();
auto q = Quantile(0.40f, beg, end);
auto q = Quantile(&ctx, 0.40f, beg, end);
ASSERT_EQ(q, 26.0);

q = Quantile(0.20f, beg, end);
q = Quantile(&ctx, 0.20f, beg, end);
ASSERT_EQ(q, 16.0);

q = Quantile(0.10f, beg, end);
q = Quantile(&ctx, 0.10f, beg, end);
ASSERT_EQ(q, 15.0);
}

{
std::vector<float> vec{1., 2., 3., 4., 5.};
auto beg = MakeIndexTransformIter([&](size_t i) { return vec[i]; });
auto end = beg + vec.size();
auto q = Quantile(0.5f, beg, end);
auto q = Quantile(&ctx, 0.5f, beg, end);
ASSERT_EQ(q, 3.);
}
}

TEST(Stats, WeightedQuantile) {
Context ctx;
linalg::Tensor<float, 1> arr({1.f, 2.f, 3.f, 4.f, 5.f}, {5}, Context::kCpuId);
linalg::Tensor<float, 1> weight({1.f, 1.f, 1.f, 1.f, 1.f}, {5}, Context::kCpuId);

Expand All @@ -47,13 +49,13 @@ TEST(Stats, WeightedQuantile) {
auto end = beg + arr.Size();
auto w = MakeIndexTransformIter([&](size_t i) { return h_weight(i); });

auto q = WeightedQuantile(0.50f, beg, end, w);
auto q = WeightedQuantile(&ctx, 0.50f, beg, end, w);
ASSERT_EQ(q, 3);

q = WeightedQuantile(0.0, beg, end, w);
q = WeightedQuantile(&ctx, 0.0, beg, end, w);
ASSERT_EQ(q, 1);

q = WeightedQuantile(1.0, beg, end, w);
q = WeightedQuantile(&ctx, 1.0, beg, end, w);
ASSERT_EQ(q, 5);
}

Expand Down

0 comments on commit 941c51e

Please sign in to comment.