diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 4af306859988..fe22fb2b5d2c 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -196,6 +196,14 @@ class MetaInfo { */ bool IsVerticalFederated() const; + /*! + * \brief A convenient method to check if the MetaInfo should contain labels. + * + * Normally we assume labels are available everywhere. The only exception is in vertical federated + * learning where labels are only available on worker 0. + */ + bool ShouldHaveLabels() const; + private: void SetInfoFromHost(Context const& ctx, StringView key, Json arr); void SetInfoFromCUDA(Context const& ctx, StringView key, Json arr); diff --git a/src/collective/aggregator.h b/src/collective/aggregator.h index fe7b6593093e..b33ca28ef575 100644 --- a/src/collective/aggregator.h +++ b/src/collective/aggregator.h @@ -31,18 +31,16 @@ namespace collective { * @param buffer The buffer storing the results. * @param size The size of the buffer. * @param function The function used to calculate the results. - * @param args Arguments to the function. */ -template -void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& function, - Args&&... args) { +template +void ApplyWithLabels(MetaInfo const& info, void* buffer, size_t size, Function&& function) { if (info.IsVerticalFederated()) { // We assume labels are only available on worker 0, so the calculation is done there and result // broadcast to other workers. std::string message; if (collective::GetRank() == 0) { try { - std::forward(function)(std::forward(args)...); + std::forward(function)(); } catch (dmlc::Error& e) { message = e.what(); } @@ -55,7 +53,7 @@ void ApplyWithLabels(MetaInfo const& info, T* buffer, size_t size, Function&& fu LOG(FATAL) << &message[0]; } } else { - std::forward(function)(std::forward(args)...); + std::forward(function)(); } } diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index a99ed4f10936..f97003d1d3c4 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -45,20 +45,18 @@ HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, b if (!use_sorted) { HostSketchContainer container(max_bins, m->Info().feature_types.ConstHostSpan(), reduced, - HostSketchContainer::UseGroup(info), - m->Info().IsColumnSplit(), n_threads); + HostSketchContainer::UseGroup(info), n_threads); for (auto const& page : m->GetBatches()) { container.PushRowPage(page, info, hessian); } - container.MakeCuts(&out); + container.MakeCuts(m->Info(), &out); } else { SortedSketchContainer container{max_bins, m->Info().feature_types.ConstHostSpan(), reduced, - HostSketchContainer::UseGroup(info), - m->Info().IsColumnSplit(), n_threads}; + HostSketchContainer::UseGroup(info), n_threads}; for (auto const& page : m->GetBatches()) { container.PushColPage(page, info, hessian); } - container.MakeCuts(&out); + container.MakeCuts(m->Info(), &out); } return out; diff --git a/src/common/quantile.cc b/src/common/quantile.cc index aaf271934474..60626052c61c 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -6,6 +6,7 @@ #include #include +#include "../collective/aggregator.h" #include "../collective/communicator-inl.h" #include "../data/adapter.h" #include "categorical.h" @@ -18,13 +19,12 @@ template SketchContainerImpl::SketchContainerImpl(std::vector columns_size, int32_t max_bins, Span feature_types, - bool use_group, bool col_split, + bool use_group, int32_t n_threads) : feature_types_(feature_types.cbegin(), feature_types.cend()), columns_size_{std::move(columns_size)}, max_bins_{max_bins}, use_group_ind_{use_group}, - col_split_{col_split}, n_threads_{n_threads} { monitor_.Init(__func__); CHECK_NE(columns_size_.size(), 0); @@ -202,10 +202,10 @@ void SketchContainerImpl::GatherSketchInfo( } template -void SketchContainerImpl::AllreduceCategories() { +void SketchContainerImpl::AllreduceCategories(MetaInfo const& info) { auto world_size = collective::GetWorldSize(); auto rank = collective::GetRank(); - if (world_size == 1 || col_split_) { + if (world_size == 1 || info.IsColumnSplit()) { return; } @@ -273,6 +273,7 @@ void SketchContainerImpl::AllreduceCategories() { template void SketchContainerImpl::AllReduce( + MetaInfo const& info, std::vector *p_reduced, std::vector* p_num_cuts) { monitor_.Start(__func__); @@ -281,7 +282,7 @@ void SketchContainerImpl::AllReduce( collective::Allreduce(&n_columns, 1); CHECK_EQ(n_columns, sketches_.size()) << "Number of columns differs across workers"; - AllreduceCategories(); + AllreduceCategories(info); auto& num_cuts = *p_num_cuts; CHECK_EQ(num_cuts.size(), 0); @@ -292,10 +293,7 @@ void SketchContainerImpl::AllReduce( // Prune the intermediate num cuts for synchronization. std::vector global_column_size(columns_size_); - if (!col_split_) { - collective::Allreduce(global_column_size.data(), - global_column_size.size()); - } + collective::GlobalSum(info, &global_column_size); ParallelFor(sketches_.size(), n_threads_, [&](size_t i) { int32_t intermediate_num_cuts = static_cast( @@ -316,7 +314,7 @@ void SketchContainerImpl::AllReduce( }); auto world = collective::GetWorldSize(); - if (world == 1 || col_split_) { + if (world == 1 || info.IsColumnSplit()) { monitor_.Stop(__func__); return; } @@ -382,11 +380,11 @@ auto AddCategories(std::set const &categories, HistogramCuts *cuts) { } template -void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { +void SketchContainerImpl::MakeCuts(MetaInfo const& info, HistogramCuts* cuts) { monitor_.Start(__func__); std::vector reduced; std::vector num_cuts; - this->AllReduce(&reduced, &num_cuts); + this->AllReduce(info, &reduced, &num_cuts); cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f); std::vector final_summaries(reduced.size()); @@ -443,8 +441,8 @@ template class SketchContainerImpl>; HostSketchContainer::HostSketchContainer(int32_t max_bins, common::Span ft, std::vector columns_size, bool use_group, - bool col_split, int32_t n_threads) - : SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} { + int32_t n_threads) + : SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} { monitor_.Init(__func__); ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) { auto n_bins = std::min(static_cast(max_bins_), columns_size_[i]); diff --git a/src/common/quantile.h b/src/common/quantile.h index a19b4bbb0d01..f8d347112772 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -789,7 +789,6 @@ class SketchContainerImpl { std::vector columns_size_; int32_t max_bins_; bool use_group_ind_{false}; - bool col_split_; int32_t n_threads_; bool has_categorical_{false}; Monitor monitor_; @@ -802,7 +801,7 @@ class SketchContainerImpl { * \param use_group whether is assigned to group to data instance. */ SketchContainerImpl(std::vector columns_size, int32_t max_bins, - common::Span feature_types, bool use_group, bool col_split, + common::Span feature_types, bool use_group, int32_t n_threads); static bool UseGroup(MetaInfo const &info) { @@ -829,7 +828,7 @@ class SketchContainerImpl { std::vector *p_sketches_scan, std::vector *p_global_sketches); // Merge sketches from all workers. - void AllReduce(std::vector *p_reduced, + void AllReduce(MetaInfo const& info, std::vector *p_reduced, std::vector *p_num_cuts); template @@ -883,11 +882,11 @@ class SketchContainerImpl { /* \brief Push a CSR matrix. */ void PushRowPage(SparsePage const &page, MetaInfo const &info, Span hessian = {}); - void MakeCuts(HistogramCuts* cuts); + void MakeCuts(MetaInfo const& info, HistogramCuts* cuts); private: // Merge all categories from other workers. - void AllreduceCategories(); + void AllreduceCategories(MetaInfo const& info); }; class HostSketchContainer : public SketchContainerImpl> { @@ -896,8 +895,7 @@ class HostSketchContainer : public SketchContainerImpl ft, - std::vector columns_size, bool use_group, bool col_split, - int32_t n_threads); + std::vector columns_size, bool use_group, int32_t n_threads); template void PushAdapterBatch(Batch const &batch, size_t base_rowid, MetaInfo const &info, float missing); @@ -993,9 +991,9 @@ class SortedSketchContainer : public SketchContainerImpl ft, - std::vector columns_size, bool use_group, bool col_split, + std::vector columns_size, bool use_group, int32_t n_threads) - : SketchContainerImpl{columns_size, max_bins, ft, use_group, col_split, n_threads} { + : SketchContainerImpl{columns_size, max_bins, ft, use_group, n_threads} { monitor_.Init(__func__); sketches_.resize(columns_size.size()); size_t i = 0; diff --git a/src/data/data.cc b/src/data/data.cc index 694bc48b99d8..9f85e7db28a2 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -774,6 +774,10 @@ bool MetaInfo::IsVerticalFederated() const { return collective::IsFederated() && IsColumnSplit(); } +bool MetaInfo::ShouldHaveLabels() const { + return !IsVerticalFederated() || collective::GetRank() == 0; +} + using DMatrixThreadLocal = dmlc::ThreadLocalStore>; diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 1bf75591541b..3a473122a0ef 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -213,7 +213,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, SyncFeatureType(&h_ft); p_sketch.reset(new common::HostSketchContainer{ batch_param_.max_bin, h_ft, column_sizes, !proxy->Info().group_ptr_.empty(), - proxy->Info().IsColumnSplit(), ctx_.Threads()}); + ctx_.Threads()}); } HostAdapterDispatch(proxy, [&](auto const& batch) { proxy->Info().num_nonzero_ = batch_nnz[i]; @@ -228,7 +228,7 @@ void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing, CHECK_EQ(accumulated_rows, Info().num_row_); CHECK(p_sketch); - p_sketch->MakeCuts(&cuts); + p_sketch->MakeCuts(Info(), &cuts); } if (!h_ft.empty()) { CHECK_EQ(h_ft.size(), n_features); diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc index 32fda9ef17b2..b195dffd793a 100644 --- a/src/objective/adaptive.cc +++ b/src/objective/adaptive.cc @@ -99,44 +99,40 @@ void UpdateTreeLeafHost(Context const* ctx, std::vector const& posit auto h_predt = linalg::MakeTensorView(ctx, predt.ConstHostSpan(), info.num_row_, predt.Size() / info.num_row_); - if (!info.IsVerticalFederated() || collective::GetRank() == 0) { - // loop over each leaf - common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) { - auto nidx = h_node_idx[k]; - CHECK(tree[nidx].IsLeaf()); - CHECK_LT(k + 1, h_node_ptr.size()); - size_t n = h_node_ptr[k + 1] - h_node_ptr[k]; - auto h_row_set = common::Span{ridx}.subspan(h_node_ptr[k], n); - - auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx)); - auto h_weights = linalg::MakeVec(&info.weights_); - - auto iter = common::MakeIndexTransformIter([&](size_t i) -> float { - auto row_idx = h_row_set[i]; - return h_labels(row_idx) - h_predt(row_idx, group_idx); + collective::ApplyWithLabels( + info, static_cast(quantiles.data()), quantiles.size() * sizeof(float), [&] { + // loop over each leaf + common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) { + auto nidx = h_node_idx[k]; + CHECK(tree[nidx].IsLeaf()); + CHECK_LT(k + 1, h_node_ptr.size()); + size_t n = h_node_ptr[k + 1] - h_node_ptr[k]; + auto h_row_set = common::Span{ridx}.subspan(h_node_ptr[k], n); + + auto h_labels = info.labels.HostView().Slice(linalg::All(), IdxY(info, group_idx)); + auto h_weights = linalg::MakeVec(&info.weights_); + + auto iter = common::MakeIndexTransformIter([&](size_t i) -> float { + auto row_idx = h_row_set[i]; + return h_labels(row_idx) - h_predt(row_idx, group_idx); + }); + auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float { + auto row_idx = h_row_set[i]; + return h_weights(row_idx); + }); + + float q{0}; + if (info.weights_.Empty()) { + q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size()); + } else { + q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it); + } + if (std::isnan(q)) { + CHECK(h_row_set.empty()); + } + quantiles.at(k) = q; + }); }); - auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float { - auto row_idx = h_row_set[i]; - return h_weights(row_idx); - }); - - float q{0}; - if (info.weights_.Empty()) { - q = common::Quantile(ctx, alpha, iter, iter + h_row_set.size()); - } else { - q = common::WeightedQuantile(ctx, alpha, iter, iter + h_row_set.size(), w_it); - } - if (std::isnan(q)) { - CHECK(h_row_set.empty()); - } - quantiles.at(k) = q; - }); - } - - if (info.IsVerticalFederated()) { - collective::Broadcast(static_cast(quantiles.data()), quantiles.size() * sizeof(float), - 0); - } UpdateLeafValues(&quantiles, nidx, info, learning_rate, p_tree); } diff --git a/src/objective/quantile_obj.cu b/src/objective/quantile_obj.cu index b34f37ff93a5..f94b5edf0494 100644 --- a/src/objective/quantile_obj.cu +++ b/src/objective/quantile_obj.cu @@ -36,7 +36,7 @@ class QuantileRegression : public ObjFunction { bst_target_t Targets(MetaInfo const& info) const override { auto const& alpha = param_.quantile_alpha.Get(); CHECK_EQ(alpha.size(), alpha_.Size()) << "The objective is not yet configured."; - if (!info.IsVerticalFederated() || collective::GetRank() == 0) { + if (info.ShouldHaveLabels()) { CHECK_EQ(info.labels.Shape(1), 1) << "Multi-target is not yet supported by the quantile loss."; } diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 3cd32ea0cd56..4771cc9bfc93 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -73,7 +73,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { auto hess = Span{hessian}; ContainerType sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(), - column_size, false, false, AllThreadsForTest()); + column_size, false, AllThreadsForTest()); if (use_column) { for (auto const& page : m->GetBatches()) { @@ -86,7 +86,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { } HistogramCuts distributed_cuts; - sketch_distributed.MakeCuts(&distributed_cuts); + sketch_distributed.MakeCuts(m->Info(), &distributed_cuts); // Generate cuts for single node environment collective::Finalize(); @@ -94,7 +94,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; }); m->Info().num_row_ = world * rows; ContainerType sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(), - column_size, false, false, AllThreadsForTest()); + column_size, false, AllThreadsForTest()); m->Info().num_row_ = rows; for (auto rank = 0; rank < world; ++rank) { @@ -117,7 +117,7 @@ void DoTestDistributedQuantile(size_t rows, size_t cols) { } HistogramCuts single_node_cuts; - sketch_on_single_node.MakeCuts(&single_node_cuts); + sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts); auto const& sptrs = single_node_cuts.Ptrs(); auto const& dptrs = distributed_cuts.Ptrs(); @@ -205,7 +205,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { HistogramCuts distributed_cuts; { ContainerType sketch_distributed(n_bins, m->Info().feature_types.ConstHostSpan(), - column_size, false, true, AllThreadsForTest()); + column_size, false, AllThreadsForTest()); std::vector hessian(rows, 1.0); auto hess = Span{hessian}; @@ -219,7 +219,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { } } - sketch_distributed.MakeCuts(&distributed_cuts); + sketch_distributed.MakeCuts(m->Info(), &distributed_cuts); } // Generate cuts for single node environment @@ -228,7 +228,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { HistogramCuts single_node_cuts; { ContainerType sketch_on_single_node(n_bins, m->Info().feature_types.ConstHostSpan(), - column_size, false, false, AllThreadsForTest()); + column_size, false, AllThreadsForTest()); std::vector hessian(rows, 1.0); auto hess = Span{hessian}; @@ -242,7 +242,7 @@ void DoTestColSplitQuantile(size_t rows, size_t cols) { } } - sketch_on_single_node.MakeCuts(&single_node_cuts); + sketch_on_single_node.MakeCuts(m->Info(), &single_node_cuts); } auto const& sptrs = single_node_cuts.Ptrs();