Implement hist evaluator for multi-target tree. #8908

Merged: 1 commit, Mar 14, 2023
22 changes: 12 additions & 10 deletions src/common/hist_util.h
@@ -7,23 +7,22 @@
#ifndef XGBOOST_COMMON_HIST_UTIL_H_
#define XGBOOST_COMMON_HIST_UTIL_H_

-#include <xgboost/data.h>
-
#include <algorithm>
#include <cstdint> // for uint32_t
#include <limits>
#include <map>
#include <memory>
#include <utility>
#include <vector>

#include "algorithm.h" // SegmentId
#include "categorical.h"
#include "common.h"
#include "quantile.h"
#include "row_set.h"
#include "threading_utils.h"
#include "timer.h"
#include "xgboost/base.h" // bst_feature_t, bst_bin_t
#include "xgboost/base.h" // for bst_feature_t, bst_bin_t
#include "xgboost/data.h"

namespace xgboost {
class GHistIndexMatrix;
@@ -392,15 +391,18 @@ class HistCollection {
}

// have we computed a histogram for i-th node?
-bool RowExists(bst_uint nid) const {
+[[nodiscard]] bool RowExists(bst_uint nid) const {
const uint32_t k_max = std::numeric_limits<uint32_t>::max();
return (nid < row_ptr_.size() && row_ptr_[nid] != k_max);
}

-// initialize histogram collection
-void Init(uint32_t nbins) {
-if (nbins_ != nbins) {
-nbins_ = nbins;
+/**
+ * \brief Initialize histogram collection.
+ *
+ * \param n_total_bins Number of bins across all features.
+ */
+void Init(std::uint32_t n_total_bins) {
+if (nbins_ != n_total_bins) {
+nbins_ = n_total_bins;
// quite expensive operation, so let's do this only once
data_.clear();
}
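The contract the new doc comment spells out is easy to show in isolation: Init only pays for a reallocation when the total bin count actually changes, so repeated calls with an unchanged layout are cheap. A minimal sketch under that assumption; HistogramStore and its double-based storage are illustrative stand-ins, not the real HistCollection.

#include <cstdint>
#include <vector>

// Sketch of the lazy-reset pattern: storage is cleared only when the
// total bin count changes, so repeated Init calls with the same layout
// skip the expensive reallocation. All names here are hypothetical.
class HistogramStore {
 public:
  void Init(std::uint32_t n_total_bins) {
    if (n_bins_ != n_total_bins) {
      n_bins_ = n_total_bins;
      data_.clear();  // quite expensive, so only done on a layout change
    }
  }

 private:
  std::uint32_t n_bins_{0};
  std::vector<double> data_;  // stands in for the per-node histogram storage
};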
21 changes: 12 additions & 9 deletions src/tree/common_row_partitioner.h
@@ -99,22 +99,25 @@ class CommonRowPartitioner {

void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
-for (size_t i = 0; i < nodes.size(); ++i) {
-const int32_t nid = nodes[i].nid;
-const bst_uint fid = tree[nid].SplitIndex();
-const bst_float split_pt = tree[nid].SplitCond();
-const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
-const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
+auto const& ptrs = gmat.cut.Ptrs();
+auto const& vals = gmat.cut.Values();
+
+for (std::size_t i = 0; i < nodes.size(); ++i) {
+bst_node_t const nid = nodes[i].nid;
+bst_feature_t const fid = tree[nid].SplitIndex();
+const float split_pt = tree[nid].SplitCond();
+const uint32_t lower_bound = ptrs[fid];
+const uint32_t upper_bound = ptrs[fid + 1];
bst_bin_t split_cond = -1;
// convert floating-point split_pt into corresponding bin_id
// split_cond = -1 indicates that split_pt is less than all known cut points
CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
for (auto bound = lower_bound; bound < upper_bound; ++bound) {
-if (split_pt == gmat.cut.Values()[bound]) {
-split_cond = static_cast<int32_t>(bound);
+if (split_pt == vals[bound]) {
+split_cond = static_cast<bst_bin_t>(bound);
}
}
-(*split_conditions).at(i) = split_cond;
+(*split_conditions)[i] = split_cond;
}
}

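The refactor above hoists gmat.cut.Ptrs() and gmat.cut.Values() out of the per-node loop and types the bin index as bst_bin_t. The lookup itself is simple enough to show standalone; the sketch below uses plain standard types, and SplitPointToBin is a hypothetical helper, not an xgboost function.

#include <cstdint>
#include <vector>

// Standalone sketch of the bin lookup done per node above: `ptrs` delimits
// each feature's slice of the global cut-value array `vals`, and the split
// condition is the index of the cut value equal to the node's split point.
// -1 means the split point is below every known cut point.
std::int32_t SplitPointToBin(std::vector<std::uint32_t> const& ptrs,
                             std::vector<float> const& vals,
                             std::uint32_t fidx, float split_pt) {
  std::int32_t split_cond = -1;
  for (std::uint32_t bound = ptrs[fidx]; bound < ptrs[fidx + 1]; ++bound) {
    // Cut values are the exact floats recorded during quantile sketching,
    // so equality comparison is intentional here.
    if (split_pt == vals[bound]) {
      split_cond = static_cast<std::int32_t>(bound);
    }
  }
  return split_cond;
}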
247 changes: 229 additions & 18 deletions src/tree/hist/evaluate_splits.h
@@ -4,22 +4,25 @@
#ifndef XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_
#define XGBOOST_TREE_HIST_EVALUATE_SPLITS_H_

-#include <algorithm>
-#include <cstddef> // for size_t
-#include <limits>
-#include <memory>
-#include <numeric>
-#include <utility>
-#include <vector>
-
-#include "../../common/categorical.h"
-#include "../../common/hist_util.h"
-#include "../../common/random.h"
-#include "../../data/gradient_index.h"
-#include "../constraints.h"
-#include "../param.h" // for TrainParam
-#include "../split_evaluator.h"
-#include "xgboost/context.h"
+#include <algorithm> // for copy
+#include <cstddef> // for size_t
+#include <limits> // for numeric_limits
+#include <memory> // for shared_ptr
+#include <numeric> // for accumulate
+#include <utility> // for move
+#include <vector> // for vector
+
+#include "../../common/categorical.h" // for CatBitField
+#include "../../common/hist_util.h" // for GHistRow, HistogramCuts
+#include "../../common/linalg_op.h" // for cbegin, cend, begin
+#include "../../common/random.h" // for ColumnSampler
+#include "../constraints.h" // for FeatureInteractionConstraintHost
+#include "../param.h" // for TrainParam
+#include "../split_evaluator.h" // for TreeEvaluator
+#include "expand_entry.h" // for MultiExpandEntry
+#include "xgboost/base.h" // for bst_node_t, bst_target_t, bst_feature_t
+#include "xgboost/context.h" // for Context
+#include "xgboost/linalg.h" // for Constants, Vector

namespace xgboost::tree {
template <typename ExpandEntry>
@@ -410,8 +413,6 @@ class HistEvaluator {
tree[candidate.nid].SplitIndex(), left_weight,
right_weight);

-auto max_node = std::max(left_child, tree[candidate.nid].RightChild());
-max_node = std::max(candidate.nid, max_node);
snode_.resize(tree.GetNodes().size());
snode_.at(left_child).stats = candidate.split.left_sum;
snode_.at(left_child).root_gain =
@@ -456,6 +457,216 @@
}
};

class HistMultiEvaluator {
std::vector<double> gain_;
linalg::Matrix<GradientPairPrecise> stats_;
TrainParam const *param_;
FeatureInteractionConstraintHost interaction_constraints_;
std::shared_ptr<common::ColumnSampler> column_sampler_;
Context const *ctx_;

private:
static double MultiCalcSplitGain(TrainParam const &param,
linalg::VectorView<GradientPairPrecise const> left_sum,
linalg::VectorView<GradientPairPrecise const> right_sum,
linalg::VectorView<float> left_weight,
linalg::VectorView<float> right_weight) {
CalcWeight(param, left_sum, left_weight);
CalcWeight(param, right_sum, right_weight);

auto left_gain = CalcGainGivenWeight(param, left_sum, left_weight);
auto right_gain = CalcGainGivenWeight(param, right_sum, right_weight);
return left_gain + right_gain;
}

template <bst_bin_t d_step>
bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
common::Span<common::GHistRow const> hist,
linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
auto const &cut_ptr = cut.Ptrs();
auto const &cut_val = cut.Values();
auto const &min_val = cut.MinValues();

auto sum = linalg::Empty<GradientPairPrecise>(ctx_, 2, hist.size());
auto left_sum = sum.Slice(0, linalg::All());
auto right_sum = sum.Slice(1, linalg::All());

bst_bin_t ibegin, iend;
if (d_step > 0) {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}
const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);

auto n_targets = hist.size();
auto weight = linalg::Empty<float>(ctx_, 2, n_targets);
auto left_weight = weight.Slice(0, linalg::All());
auto right_weight = weight.Slice(1, linalg::All());

for (bst_bin_t i = ibegin; i != iend; i += d_step) {
for (bst_target_t t = 0; t < n_targets; ++t) {
auto t_hist = hist[t];
auto t_p = parent_sum(t);
left_sum(t) += t_hist[i];
right_sum(t) = t_p - left_sum(t);
}

if (d_step > 0) {
auto split_pt = cut_val[i];
auto loss_chg =
MultiCalcSplitGain(*param_, right_sum, left_sum, right_weight, left_weight) -
parent_gain;
p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
float split_pt;
if (i == imin) {
split_pt = min_val[fidx];
} else {
split_pt = cut_val[i - 1];
}
auto loss_chg =
MultiCalcSplitGain(*param_, right_sum, left_sum, left_weight, right_weight) -
parent_gain;
p_best->Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
// Return true if there are missing values; doesn't handle floating-point error well.
if (d_step == +1) {
return !std::equal(linalg::cbegin(left_sum), linalg::cend(left_sum),
linalg::cbegin(parent_sum));
}
return false;
}

public:
void EvaluateSplits(RegTree const &tree, common::Span<const common::HistCollection *> hist,
common::HistogramCuts const &cut, std::vector<MultiExpandEntry> *p_entries) {
auto &entries = *p_entries;
std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(entries.size());

for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
auto nidx = entries[nidx_in_set].nid;
features[nidx_in_set] = column_sampler_->GetFeatureSet(tree.GetDepth(nidx));
}
CHECK(!features.empty());

std::int32_t n_threads = ctx_->Threads();
std::size_t const grain_size = std::max<std::size_t>(1, features.front()->Size() / n_threads);
common::BlockedSpace2d space(
entries.size(), [&](std::size_t nidx_in_set) { return features[nidx_in_set]->Size(); },
grain_size);

std::vector<MultiExpandEntry> tloc_candidates(n_threads * entries.size());
for (std::size_t i = 0; i < entries.size(); ++i) {
for (std::int32_t j = 0; j < n_threads; ++j) {
tloc_candidates[i * n_threads + j] = entries[i];
}
}
common::ParallelFor2d(space, n_threads, [&](std::size_t nidx_in_set, common::Range1d r) {
auto tidx = omp_get_thread_num();
auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
auto best = &entry->split;
auto parent_sum = stats_.Slice(entry->nid, linalg::All());
std::vector<common::GHistRow> node_hist;
for (auto t_hist : hist) {
node_hist.push_back((*t_hist)[entry->nid]);
}
auto features_set = features[nidx_in_set]->ConstHostSpan();

for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
auto fidx = features_set[fidx_in_set];
if (!interaction_constraints_.Query(entry->nid, fidx)) {
continue;
}
auto parent_gain = gain_[entry->nid];
bool missing =
this->EnumerateSplit<+1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
if (missing) {
this->EnumerateSplit<-1>(cut, fidx, node_hist, parent_sum, parent_gain, best);
}
}
});

for (std::size_t nidx_in_set = 0; nidx_in_set < entries.size(); ++nidx_in_set) {
for (auto tidx = 0; tidx < n_threads; ++tidx) {
entries[nidx_in_set].split.Update(tloc_candidates[n_threads * nidx_in_set + tidx].split);
}
}
}

linalg::Vector<float> InitRoot(linalg::VectorView<GradientPairPrecise const> root_sum) {
auto n_targets = root_sum.Size();
stats_ = linalg::Constant(ctx_, GradientPairPrecise{}, 1, n_targets);
gain_.resize(1);

linalg::Vector<float> weight({n_targets}, ctx_->gpu_id);
CalcWeight(*param_, root_sum, weight.HostView());
auto root_gain = CalcGainGivenWeight(*param_, root_sum, weight.HostView());
gain_.front() = root_gain;

auto h_stats = stats_.HostView();
std::copy(linalg::cbegin(root_sum), linalg::cend(root_sum), linalg::begin(h_stats));

return weight;
}

void ApplyTreeSplit(MultiExpandEntry const &candidate, RegTree *p_tree) {
auto n_targets = p_tree->NumTargets();
auto parent_sum = stats_.Slice(candidate.nid, linalg::All());

auto weight = linalg::Empty<float>(ctx_, 3, n_targets);
auto base_weight = weight.Slice(0, linalg::All());
CalcWeight(*param_, parent_sum, base_weight);

auto left_weight = weight.Slice(1, linalg::All());
auto left_sum =
linalg::MakeVec(candidate.split.left_sum.data(), candidate.split.left_sum.size());
CalcWeight(*param_, left_sum, param_->learning_rate, left_weight);

auto right_weight = weight.Slice(2, linalg::All());
auto right_sum =
linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);

p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
CHECK(p_tree->IsMultiTarget());
auto left_child = p_tree->LeftChild(candidate.nid);
CHECK_GT(left_child, candidate.nid);
auto right_child = p_tree->RightChild(candidate.nid);
CHECK_GT(right_child, candidate.nid);

std::size_t n_nodes = p_tree->Size();
gain_.resize(n_nodes);
gain_[left_child] = CalcGainGivenWeight(*param_, left_sum, left_weight);
gain_[right_child] = CalcGainGivenWeight(*param_, right_sum, right_weight);

if (n_nodes >= stats_.Shape(0)) {
stats_.Reshape(n_nodes * 2, stats_.Shape(1));
}
CHECK_EQ(stats_.Shape(1), n_targets);
auto left_sum_stat = stats_.Slice(left_child, linalg::All());
std::copy(candidate.split.left_sum.cbegin(), candidate.split.left_sum.cend(),
linalg::begin(left_sum_stat));
auto right_sum_stat = stats_.Slice(right_child, linalg::All());
std::copy(candidate.split.right_sum.cbegin(), candidate.split.right_sum.cend(),
linalg::begin(right_sum_stat));
}

explicit HistMultiEvaluator(Context const *ctx, MetaInfo const &info, TrainParam const *param,
std::shared_ptr<common::ColumnSampler> sampler)
: param_{param}, column_sampler_{std::move(sampler)}, ctx_{ctx} {
interaction_constraints_.Configure(*param, info.num_col_);
column_sampler_->Init(ctx, info.num_col_, info.feature_weights.HostVector(),
param_->colsample_bynode, param_->colsample_bylevel,
param_->colsample_bytree);
}
};

/**
* \brief CPU implementation of update prediction cache, which calculates the leaf value
* for the last tree and accumulates it to prediction vector.
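The new evaluator's split search rests on two pieces: MultiCalcSplitGain, which computes per-target child weights and sums the per-target gains, and EnumerateSplit, whose forward pass doubles as a missing-value probe (after scanning every bin, it reports missing data if the accumulated left sum still differs from the parent sum, and only then is the backward pass run). Under the plain second-order objective with only an L2 term lambda (no L1, no monotonic or interaction constraints), the per-target arithmetic reduces to a closed form. The sketch below shows that reduced form; CalcChildGain and SplitGain are hypothetical helper names, not the library's CalcWeight/CalcGainGivenWeight API.

#include <vector>

struct GradPair {
  double grad{0.0};
  double hess{0.0};
};

// Per-target arithmetic assumed by this sketch: with only an L2 term lambda,
// the optimal leaf weight per target is w_t = -G_t / (H_t + lambda), and the
// gain contributed by that target is -G_t * w_t = G_t^2 / (H_t + lambda).
// A multi-target node's gain is the sum over targets.
double CalcChildGain(std::vector<GradPair> const& sum, double lambda) {
  double gain = 0.0;
  for (auto const& gp : sum) {
    double w = -gp.grad / (gp.hess + lambda);  // per-target leaf weight
    gain += -gp.grad * w;                      // per-target gain contribution
  }
  return gain;
}

// Loss change for one candidate split, mirroring how EnumerateSplit
// subtracts parent_gain from the sum of the two child gains.
double SplitGain(std::vector<GradPair> const& left_sum,
                 std::vector<GradPair> const& right_sum,
                 double parent_gain, double lambda) {
  return CalcChildGain(left_sum, lambda) +
         CalcChildGain(right_sum, lambda) - parent_gain;
}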