More refactoring to take advantage of collective aggregators #9081

Merged · 4 commits · Apr 25, 2023
Changes from 3 commits
8 changes: 8 additions & 0 deletions include/xgboost/data.h
@@ -196,6 +196,14 @@ class MetaInfo {
*/
bool IsVerticalFederated() const;

/*!
* \brief A convenient method to check if the MetaInfo should contain labels.
*
* Normally we assume labels are available everywhere. The only exception is in vertical federated
* learning where labels are only available on worker 0.
*/
bool ShouldHaveLabels() const;

private:
void SetInfoFromHost(Context const& ctx, StringView key, Json arr);
void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr);
10 changes: 4 additions & 6 deletions src/collective/aggregator.h
@@ -31,18 +31,16 @@ namespace collective {
* @param buffer The buffer storing the results.
* @param size The size of the buffer.
* @param function The function used to calculate the results.
* @param args Arguments to the function.
*/
template <typename Function, typename T, typename... Args>
void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& function,
Args&&... args) {
template <typename Function>
void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) {
Member:
Is type erasure necessary? (void*)
Contributor (author):
When doing the broadcast we don't need to know the type, only the size (https://github.com/dmlc/xgboost/blob/master/src/collective/communicator-inl.h#L128), so here we don't really care about the type either.

if (info.IsVerticalFederated()) {
    // We assume labels are only available on worker 0, so the calculation is done there and the
    // result is broadcast to the other workers.
std::string message;
if (collective::GetRank() == 0) {
try {
std::forward<Function>(function)(std::forward<Args>(args)...);
std::forward<Function>(function)();
} catch (dmlc::Error& e) {
message = e.what();
}
@@ -55,7 +53,7 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu
LOG(FATAL) << &message[0];
}
} else {
std::forward<Function>(function)(std::forward<Args>(args)...);
std::forward<Function>(function)();
}
}
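To make the reviewer exchange above concrete, here is a minimal usage sketch of the new signature, shaped after the adaptive.cc call site later in this PR. `ComputeQuantilesFromLabels` is a hypothetical stand-in for the label-dependent computation; the point is that `ApplyWithLabels` only needs the buffer's address and byte count to relay the result, never its element type.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: fills `quantiles` using labels, which are only
// available on worker 0 under vertical federated learning.
void ComputeQuantilesFromLabels(MetaInfo const& info, std::vector<float>* quantiles);

void Example(MetaInfo const& info, std::vector<float>* p_quantiles) {
  auto& quantiles = *p_quantiles;
  collective::ApplyWithLabels(
      info, static_cast<void*>(quantiles.data()),
      quantiles.size() * sizeof(float), [&] {
        // Runs only where labels exist; ApplyWithLabels then broadcasts the
        // filled bytes to the workers that skipped the computation.
        ComputeQuantilesFromLabels(info, &quantiles);
      });
}
```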

10 changes: 4 additions & 6 deletions src/common/hist_util.cc
@@ -45,20 +45,18 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b

if (!use_sorted) {
HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->Info().IsColumnSplit(), n_threads);
HostSketchContainer::UseGroup(info), n_threads);
for (auto const& page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
container.MakeCuts(m->Info(), &out);
} else {
SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced,
HostSketchContainer::UseGroup(info),
m->Info().IsColumnSplit(), n_threads};
HostSketchContainer::UseGroup(info), n_threads};
for (auto const& page : m->GetBatches<SortedCSCPage>()) {
container.PushColPage(page, info, hessian);
}
container.MakeCuts(&out);
container.MakeCuts(m->Info(), &out);
}

return out;
26 changes: 12 additions & 14 deletions src/common/quantile.cc
@@ -6,6 +6,7 @@
#include <limits>
#include <utility>

#include "../collective/aggregator.h"
#include "../collective/communicator-inl.h"
#include "../data/adapter.h"
#include "categorical.h"
@@ -18,13 +19,12 @@ template <typename WQSketch>
SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> columns_size,
int32_t max_bins,
Span<FeatureType const> feature_types,
bool use_group, bool col_split,
bool use_group,
int32_t n_threads)
: feature_types_(feature_types.cbegin(), feature_types.cend()),
columns_size_{std::move(columns_size)},
max_bins_{max_bins},
use_group_ind_{use_group},
col_split_{col_split},
n_threads_{n_threads} {
monitor_.Init(__func__);
CHECK_NE(columns_size_.size(), 0);
@@ -202,10 +202,10 @@ void SketchContainerImpl<WQSketch>::GatherSketchInfo(
}

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllreduceCategories() {
void SketchContainerImpl<WQSketch>::AllreduceCategories(MetaInfo const& info) {
auto world_size = collective::GetWorldSize();
auto rank = collective::GetRank();
if (world_size == 1 || col_split_) {
if (world_size == 1 || info.IsColumnSplit()) {
return;
}

@@ -273,6 +273,7 @@

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::AllReduce(
MetaInfo const& info,
std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t>* p_num_cuts) {
monitor_.Start(__func__);
@@ -281,7 +282,7 @@
collective::Allreduce<collective::Operation::kMax>(&n_columns, 1);
CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers";

AllreduceCategories();
AllreduceCategories(info);

auto& num_cuts = *p_num_cuts;
CHECK_EQ(num_cuts.size(), 0);
@@ -292,10 +293,7 @@

// Prune the intermediate num cuts for synchronization.
std::vector<bst_row_t> global_column_size(columns_size_);
if (!col_split_) {
collective::Allreduce<collective::Operation::kSum>(global_column_size.data(),
global_column_size.size());
}
collective::GlobalSum(info, &global_column_size);

ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(
@@ -316,7 +314,7 @@
});

auto world = collective::GetWorldSize();
if (world == 1 || col_split_) {
if (world == 1 || info.IsColumnSplit()) {
monitor_.Stop(__func__);
return;
}
@@ -382,11 +380,11 @@ auto AddCategories(std::set<float> const &categories, HistogramCuts *cuts) {
}

template <typename WQSketch>
void SketchContainerImpl<WQSketch>::MakeCuts(HistogramCuts* cuts) {
void SketchContainerImpl<WQSketch>::MakeCuts(MetaInfo const& info, HistogramCuts* cuts) {
monitor_.Start(__func__);
std::vector<typename WQSketch::SummaryContainer> reduced;
std::vector<int32_t> num_cuts;
this->AllReduce(&reduced, &num_cuts);
this->AllReduce(info, &reduced, &num_cuts);

cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
std::vector<typename WQSketch::SummaryContainer> final_summaries(reduced.size());
@@ -443,8 +441,8 @@ template class SketchContainerImpl<WXQuantileSketch<float, float>>;

HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group,
bool col_split, int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
monitor_.Init(__func__);
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
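For context on the hunk above: the explicit `col_split_` guard around the `Allreduce` moves behind `collective::GlobalSum`. A plausible shape for that helper, inferred from the guard it replaces rather than quoted from the actual src/collective/aggregator.h:

```cpp
// Inferred sketch, not the literal implementation. Under a column split each
// worker owns a distinct subset of columns, so summing per-column sizes
// across workers would be wrong; the old call site skipped the allreduce in
// that case, and GlobalSum presumably centralizes that check.
template <typename T>
void GlobalSum(MetaInfo const& info, std::vector<T>* values) {
  if (info.IsColumnSplit()) {
    return;
  }
  collective::Allreduce<collective::Operation::kSum>(values->data(), values->size());
}
```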
16 changes: 7 additions & 9 deletions src/common/quantile.h
@@ -789,7 +789,6 @@ class SketchContainerImpl {
std::vector<bst_row_t> columns_size_;
int32_t max_bins_;
bool use_group_ind_{false};
bool col_split_;
int32_t n_threads_;
bool has_categorical_{false};
Monitor monitor_;
@@ -802,7 +801,7 @@
* \param use_group whether group information is assigned to the data instances.
*/
SketchContainerImpl(std::vector<bst_row_t> columns_size, int32_t max_bins,
common::Span<FeatureType const> feature_types, bool use_group, bool col_split,
common::Span<FeatureType const> feature_types, bool use_group,
int32_t n_threads);

static bool UseGroup(MetaInfo const &info) {
@@ -829,7 +828,7 @@
std::vector<bst_row_t> *p_sketches_scan,
std::vector<typename WQSketch::Entry> *p_global_sketches);
// Merge sketches from all workers.
void AllReduce(std::vector<typename WQSketch::SummaryContainer> *p_reduced,
void AllReduce(MetaInfo const& info, std::vector<typename WQSketch::SummaryContainer> *p_reduced,
std::vector<int32_t> *p_num_cuts);

template <typename Batch, typename IsValid>
@@ -883,11 +882,11 @@
/* \brief Push a CSR matrix. */
void PushRowPage(SparsePage const &page, MetaInfo const &info, Span<float const> hessian = {});

void MakeCuts(HistogramCuts* cuts);
void MakeCuts(MetaInfo const& info, HistogramCuts* cuts);

private:
// Merge all categories from other workers.
void AllreduceCategories();
void AllreduceCategories(MetaInfo const& info);
};

class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
@@ -896,8 +895,7 @@ class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, fl

public:
HostSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group, bool col_split,
int32_t n_threads);
std::vector<size_t> columns_size, bool use_group, int32_t n_threads);

template <typename Batch>
void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing);
@@ -993,9 +991,9 @@ class SortedSketchContainer : public SketchContainerImpl<WXQuantileSketch<float,

public:
explicit SortedSketchContainer(int32_t max_bins, common::Span<FeatureType const> ft,
std::vector<size_t> columns_size, bool use_group, bool col_split,
std::vector<size_t> columns_size, bool use_group,
int32_t n_threads)
: SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} {
: SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} {
monitor_.Init(__func__);
sketches_.resize(columns_size.size());
size_t i = 0;
4 changes: 4 additions & 0 deletions src/data/data.cc
@@ -774,6 +774,10 @@ bool MetaInfo::IsVerticalFederated() const {
return collective::IsFederated() && IsColumnSplit();
}

bool MetaInfo::ShouldHaveLabels() const {
return !IsVerticalFederated() || collective::GetRank() == 0;
}

using DMatrixThreadLocal =
dmlc::ThreadLocalStore<std::map<DMatrix const *, XGBAPIThreadLocalEntry>>;

4 changes: 2 additions & 2 deletions src/data/iterative_dmatrix.cc
@@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
SyncFeatureType(&h_ft);
p_sketch.reset(new common::HostSketchContainer{
batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(),
proxy->Info().IsColumnSplit(), ctx_.Threads()});
ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
@@ -228,7 +228,7 @@
CHECK_EQ(accumulated_rows, Info().num_row_);

CHECK(p_sketch);
p_sketch->MakeCuts(&cuts);
p_sketch->MakeCuts(Info(), &cuts);
}
if (!h_ft.empty()) {
CHECK_EQ(h_ft.size(), n_features);
70 changes: 33 additions & 37 deletions src/objective/adaptive.cc
@@ -99,44 +99,40 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& posit
auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_,
predt.Size() / info.num_row_);

if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
auto nidx = h_node_idx[k];
CHECK(tree[nidx].IsLeaf());
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);

auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
auto h_weights = linalg::MakeVec(&info.weights_);

auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_labels(row_idx) - h_predt(row_idx, group_idx);
collective::ApplyWithLabels(
info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float), [&] {
// loop over each leaf
common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
auto nidx = h_node_idx[k];
CHECK(tree[nidx].IsLeaf());
CHECK_LT(k + 1, h_node_ptr.size());
size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);

auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx));
auto h_weights = linalg::MakeVec(&info.weights_);

auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_labels(row_idx) - h_predt(row_idx, group_idx);
});
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_weights(row_idx);
});

float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
}
quantiles.at(k) = q;
});
});
auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
auto row_idx = h_row_set[i];
return h_weights(row_idx);
});

float q{0};
if (info.weights_.Empty()) {
q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size());
} else {
q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it);
}
if (std::isnan(q)) {
CHECK(h_row_set.empty());
}
quantiles.at(k) = q;
});
}

if (info.IsVerticalFederated()) {
collective::Broadcast(static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float),
0);
}

UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree);
}
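The interleaved old/new rendering above is hard to scan, so here is the net effect of the hunk as a before/after sketch. `ComputeLeafQuantiles` is a hypothetical stand-in for the per-leaf loop shown in the diff:

```cpp
// Before: the caller guarded the label-dependent computation and did the
// broadcast by hand.
if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
  ComputeLeafQuantiles(&quantiles);
}
if (info.IsVerticalFederated()) {
  collective::Broadcast(static_cast<void*>(quantiles.data()),
                        quantiles.size() * sizeof(float), 0);
}

// After: ApplyWithLabels owns the rank check, the error relay, and the
// broadcast; the caller supplies only the computation as a lambda.
collective::ApplyWithLabels(
    info, static_cast<void*>(quantiles.data()), quantiles.size() * sizeof(float),
    [&] { ComputeLeafQuantiles(&quantiles); });
```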
2 changes: 1 addition & 1 deletion src/objective/quantile_obj.cu
@@ -36,7 +36,7 @@ class QuantileRegression : public ObjFunction {
bst_target_t Targets(MetaInfo const& info) const override {
auto const& alpha = param_.quantile_alpha.Get();
CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured.";
if (!info.IsVerticalFederated() || collective::GetRank() == 0) {
if (info.ShouldHaveLabels()) {
CHECK_EQ(info.labels.Shape(1), 1)
<< "Multi-target is not yet supported by the quantile loss.";
}