diff --git a/R-package/src/Makevars.in b/R-package/src/Makevars.in
index f199544a32b1..f03bbc73f169 100644
--- a/R-package/src/Makevars.in
+++ b/R-package/src/Makevars.in
@@ -68,6 +68,7 @@ OBJECTS= \
   $(PKGROOT)/src/tree/updater_quantile_hist.o \
   $(PKGROOT)/src/tree/updater_refresh.o \
   $(PKGROOT)/src/tree/updater_sync.o \
+  $(PKGROOT)/src/tree/hist/param.o \
   $(PKGROOT)/src/linear/linear_updater.o \
   $(PKGROOT)/src/linear/updater_coordinate.o \
   $(PKGROOT)/src/linear/updater_shotgun.o \
diff --git a/R-package/src/Makevars.win b/R-package/src/Makevars.win
index 2e7f9811300b..9f4d0d5f3a8b 100644
--- a/R-package/src/Makevars.win
+++ b/R-package/src/Makevars.win
@@ -68,6 +68,7 @@ OBJECTS= \
   $(PKGROOT)/src/tree/updater_quantile_hist.o \
   $(PKGROOT)/src/tree/updater_refresh.o \
   $(PKGROOT)/src/tree/updater_sync.o \
+  $(PKGROOT)/src/tree/hist/param.o \
   $(PKGROOT)/src/linear/linear_updater.o \
   $(PKGROOT)/src/linear/updater_coordinate.o \
   $(PKGROOT)/src/linear/updater_shotgun.o \
diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h
index 4ca5b9f7e78c..6d2b54f84c17 100644
--- a/include/xgboost/linalg.h
+++ b/include/xgboost/linalg.h
@@ -574,7 +574,9 @@
 template <typename Container, typename... S,
           std::enable_if_t<!common::detail::IsSpan<Container>::value && !std::is_pointer_v<Container>> * = nullptr>
 auto MakeTensorView(Context const *ctx, Container &data, S &&...shape) {  // NOLINT
-  using T = typename Container::value_type;
+  using T = std::conditional_t<std::is_const_v<Container>,
+                               std::add_const_t<typename Container::value_type>,
+                               typename Container::value_type>;
   std::size_t in_shape[sizeof...(S)];
   detail::IndexToArr(in_shape, std::forward<S>(shape)...);
   return TensorView<T, sizeof...(S)>{data, in_shape, ctx->gpu_id};
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index 489ef2396542..e52ce1f662bf 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -81,11 +81,11 @@ void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) {
 /*!
  * \brief Increment hist as dst += add in range [begin, end)
  */
-void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) {
-  double* pdst = reinterpret_cast<double*>(dst.data());
+void IncrementHist(GHistRow dst, ConstGHistRow add, std::size_t begin, std::size_t end) {
+  double *pdst = reinterpret_cast<double *>(dst.data());
   const double *padd = reinterpret_cast<const double *>(add.data());
 
-  for (size_t i = 2 * begin; i < 2 * end; ++i) {
+  for (std::size_t i = 2 * begin; i < 2 * end; ++i) {
     pdst[i] += padd[i];
   }
 }
@@ -207,18 +207,23 @@ void RowsWiseBuildHistKernel(Span<GradientPair const> gpair,
   const size_t size = row_indices.Size();
   const size_t *rid = row_indices.begin;
-  auto const *pgh = reinterpret_cast<const float *>(gpair.data());
+  auto const *p_gpair = reinterpret_cast<const float *>(gpair.data());
   const BinIdxType *gradient_index = gmat.index.data();
 
   auto const &row_ptr = gmat.row_ptr.data();
   auto base_rowid = gmat.base_rowid;
-  const uint32_t *offsets = gmat.index.Offset();
-  auto get_row_ptr = [&](size_t ridx) {
+  uint32_t const *offsets = gmat.index.Offset();
+  // There's no feature-based compression if missing value is present.
+  if (kAnyMissing) {
+    CHECK(!offsets);
+  } else {
+    CHECK(offsets);
+  }
+
+  auto get_row_ptr = [&](bst_row_t ridx) {
     return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
   };
-  auto get_rid = [&](size_t ridx) {
-    return kFirstPage ? ridx : (ridx - base_rowid);
-  };
+  auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
 
   const size_t n_features =
       get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]);
@@ -228,7 +233,7 @@
   // So we need to multiply each row-index/bin-index by 2
   // to work with gradient pairs as a singe row FP array
 
-  for (size_t i = 0; i < size; ++i) {
+  for (std::size_t i = 0; i < size; ++i) {
     const size_t icol_start =
         kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
     const size_t icol_end =
@@ -246,7 +251,7 @@
           kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
                       : icol_start_prefetch + n_features;
 
-      PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
+      PREFETCH_READ_T0(p_gpair + two * rid[i + Prefetch::kPrefetchOffset]);
       for (size_t j = icol_start_prefetch; j < icol_end_prefetch;
            j += Prefetch::GetPrefetchStep<BinIdxType>()) {
         PREFETCH_READ_T0(gradient_index + j);
@@ -255,12 +260,12 @@
       const BinIdxType *gr_index_local = gradient_index + icol_start;
 
       // The trick with pgh_t buffer helps the compiler to generate faster binary.
-      const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]};
+      const float pgh_t[] = {p_gpair[idx_gh], p_gpair[idx_gh + 1]};
       for (size_t j = 0; j < row_size; ++j) {
-        const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
-                                        (kAnyMissing ? 0 : offsets[j]));
+        const uint32_t idx_bin =
+            two * (static_cast<uint32_t>(gr_index_local[j]) + (kAnyMissing ? 0 : offsets[j]));
         auto hist_local = hist_data + idx_bin;
-        *(hist_local) += pgh_t[0];
+        *(hist_local) += pgh_t[0];
         *(hist_local + 1) += pgh_t[1];
       }
     }
@@ -281,12 +286,10 @@ void ColsWiseBuildHistKernel(Span<GradientPair const> gpair,
   auto const &row_ptr = gmat.row_ptr.data();
   auto base_rowid = gmat.base_rowid;
   const uint32_t *offsets = gmat.index.Offset();
-  auto get_row_ptr = [&](size_t ridx) {
+  auto get_row_ptr = [&](bst_row_t ridx) {
     return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
   };
-  auto get_rid = [&](size_t ridx) {
-    return kFirstPage ? ridx : (ridx - base_rowid);
-  };
+  auto get_rid = [&](bst_row_t ridx) { return kFirstPage ? ridx : (ridx - base_rowid); };
 
   const size_t n_features = gmat.cut.Ptrs().size() - 1;
   const size_t n_columns = n_features;
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index fd364b8ac5a3..12db898a98cf 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -362,6 +362,7 @@ bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(std::size_t begin, std::size_t
 }
 
 using GHistRow = Span<GradientPairPrecise>;
+using ConstGHistRow = Span<GradientPairPrecise const>;
 
 /*!
  * \brief fill a histogram by zeros
@@ -371,7 +372,7 @@ void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end);
 /*!
  * \brief Increment hist as dst += add in range [begin, end)
  */
-void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end);
+void IncrementHist(GHistRow dst, ConstGHistRow add, std::size_t begin, std::size_t end);
 
 /*!
  * \brief Copy hist from src to dst in range [begin, end)
diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h
index 0247e4dccbf5..9c74838474fa 100644
--- a/src/common/threading_utils.h
+++ b/src/common/threading_utils.h
@@ -136,7 +136,7 @@ class BlockedSpace2d {
 // Wrapper to implement nested parallelism with simple omp parallel for
 template <typename Func>
 void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
-  const size_t num_blocks_in_space = space.Size();
+  std::size_t n_blocks_in_space = space.Size();
   CHECK_GE(nthreads, 1);
 
   dmlc::OMPException exc;
@@ -144,11 +144,10 @@ void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
   {
     exc.Run([&]() {
       size_t tid = omp_get_thread_num();
-      size_t chunck_size =
-          num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
+      size_t chunck_size = n_blocks_in_space / nthreads + !!(n_blocks_in_space % nthreads);
 
-      size_t begin = chunck_size * tid;
-      size_t end = std::min(begin + chunck_size, num_blocks_in_space);
+      std::size_t begin = chunck_size * tid;
+      std::size_t end = std::min(begin + chunck_size, n_blocks_in_space);
       for (auto i = begin; i < end; i++) {
         func(space.GetFirstDimension(i), space.GetRange(i));
       }
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 4fb857e06e38..8e6ea9ac3c22 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -65,7 +65,7 @@ class HistEvaluator {
    * pseudo-category for missing value but here we just do a complete scan to avoid
    * making specialized histogram bin.
    */
-  void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
+  void EnumerateOneHot(common::HistogramCuts const &cut, common::ConstGHistRow hist,
                        bst_feature_t fidx, bst_node_t nidx,
                        TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                        SplitEntry *p_best) const {
@@ -143,7 +143,7 @@
    */
   template <int d_step>
   void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
-                     common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
+                     common::ConstGHistRow hist, bst_feature_t fidx, bst_node_t nidx,
                      TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                      SplitEntry *p_best) {
     static_assert(d_step == +1 || d_step == -1, "Invalid step.");
@@ -214,7 +214,7 @@
   // Returns the sum of gradients corresponding to the data points that contains
   // a non-missing value for the particular feature fid.
   template <int d_step>
-  GradStats EnumerateSplit(common::HistogramCuts const &cut, const common::GHistRow &hist,
+  GradStats EnumerateSplit(common::HistogramCuts const &cut, common::ConstGHistRow hist,
                            bst_feature_t fidx, bst_node_t nidx,
                            TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
                            SplitEntry *p_best) const {
@@ -510,7 +510,7 @@ class HistMultiEvaluator {
   template <bst_bin_t d_step>
   bool EnumerateSplit(common::HistogramCuts const &cut, bst_feature_t fidx,
-                      common::Span<common::GHistRow const> hist,
+                      common::Span<common::ConstGHistRow> hist,
                       linalg::VectorView<GradientPairPrecise const> parent_sum, double parent_gain,
                       SplitEntryContainer<std::vector<GradientPairPrecise>> *p_best) const {
     auto const &cut_ptr = cut.Ptrs();
@@ -651,7 +651,7 @@ class HistMultiEvaluator {
       auto entry = &tloc_candidates[n_threads * nidx_in_set + tidx];
       auto best = &entry->split;
       auto parent_sum = stats_.Slice(entry->nid, linalg::All());
-      std::vector<common::GHistRow> node_hist;
+      std::vector<common::ConstGHistRow> node_hist;
       for (auto t_hist : hist) {
         node_hist.push_back((*t_hist)[entry->nid]);
       }
diff --git a/src/tree/hist/param.cc b/src/tree/hist/param.cc
new file mode 100644
index 000000000000..64d1de5f4b29
--- /dev/null
+++ b/src/tree/hist/param.cc
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2021-2023, XGBoost Contributors
+ */
+#include "param.h"
+
+#include <string>  // for string
+
+#include "../../collective/communicator-inl.h"  // for GetRank, Broadcast
+#include "xgboost/json.h"                       // for Object, Json
+#include "xgboost/tree_model.h"                 // for RegTree
+
+namespace xgboost::tree {
+DMLC_REGISTER_PARAMETER(HistMakerTrainParam);
+
+void HistMakerTrainParam::CheckTreesSynchronized(RegTree const* local_tree) const {
+  if (!this->debug_synchronize) {
+    return;
+  }
+
+  std::string s_model;
+  Json model{Object{}};
+  int rank = collective::GetRank();
+  if (rank == 0) {
+    local_tree->SaveModel(&model);
+  }
+  Json::Dump(model, &s_model, std::ios::binary);
+  collective::Broadcast(&s_model, 0);
+
+  RegTree ref_tree{};  // rank 0 tree
+  auto j_ref_tree = Json::Load(StringView{s_model});
+  ref_tree.LoadModel(j_ref_tree);
+  CHECK(*local_tree == ref_tree);
+}
+}  // namespace xgboost::tree
diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h
new file mode 100644
index 000000000000..3dfbf68e1eb2
--- /dev/null
+++ b/src/tree/hist/param.h
@@ -0,0 +1,20 @@
+/**
+ * Copyright 2021-2023, XGBoost Contributors
+ */
+#pragma once
+#include "xgboost/parameter.h"
+#include "xgboost/tree_model.h"  // for RegTree
+
+namespace xgboost::tree {
+struct HistMakerTrainParam : public XGBoostParameter<HistMakerTrainParam> {
+  bool debug_synchronize;
+  void CheckTreesSynchronized(RegTree const* local_tree) const;
+
+  // declare parameters
+  DMLC_DECLARE_PARAMETER(HistMakerTrainParam) {
+    DMLC_DECLARE_FIELD(debug_synchronize)
+        .set_default(false)
+        .describe("Check if all distributed tree are identical after tree construction.");
+  }
+};
+}  // namespace xgboost::tree
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index 78506305faa0..fe9f681cffd2 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -11,17 +11,17 @@
 #include "../common/random.h"
 #include "../data/gradient_index.h"
 #include "common_row_partitioner.h"
-#include "constraints.h"
 #include "driver.h"
 #include "hist/evaluate_splits.h"
 #include "hist/histogram.h"
+#include "hist/param.h"
 #include "hist/sampler.h"  // for SampleGradient
-#include "param.h"
+#include "param.h"  // for HistMakerTrainParam
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 #include "xgboost/json.h"
 #include "xgboost/linalg.h"
-#include "xgboost/task.h"  // for ObjInfo
+#include "xgboost/task.h"   // for ObjInfo
 #include "xgboost/tree_model.h"
 #include "xgboost/tree_updater.h"  // for TreeUpdater
@@ -43,6 +43,7 @@ auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
 class GloablApproxBuilder {
  protected:
   TrainParam const *param_;
+  HistMakerTrainParam const *hist_param_{nullptr};
   std::shared_ptr<common::ColumnSampler> col_sampler_;
   HistEvaluator evaluator_;
   HistogramBuilder histogram_builder_;
@@ -169,10 +170,12 @@ class GloablApproxBuilder {
   }
 
  public:
-  explicit GloablApproxBuilder(TrainParam const *param, MetaInfo const &info, Context const *ctx,
+  explicit GloablApproxBuilder(TrainParam const *param, HistMakerTrainParam const *hist_param,
+                               MetaInfo const &info, Context const *ctx,
                                std::shared_ptr<common::ColumnSampler> column_sampler,
                                ObjInfo const *task, common::Monitor *monitor)
       : param_{param},
+        hist_param_{hist_param},
        col_sampler_{std::move(column_sampler)},
        evaluator_{ctx, param_, info, col_sampler_},
        ctx_{ctx},
@@ -260,6 +263,7 @@ class GlobalApproxUpdater : public TreeUpdater {
   std::shared_ptr<common::ColumnSampler> column_sampler_ =
       std::make_shared<common::ColumnSampler>();
   ObjInfo const *task_;
+  HistMakerTrainParam hist_param_;
 
  public:
  explicit GlobalApproxUpdater(Context const *ctx, ObjInfo const *task)
@@ -267,9 +271,15 @@
       : TreeUpdater(ctx), task_{task} {
    monitor_.Init(__func__);
  }
-  void Configure(Args const &) override {}
-  void LoadConfig(Json const &) override {}
-  void SaveConfig(Json *) const override {}
+  void Configure(Args const &args) override { hist_param_.UpdateAllowUnknown(args); }
+  void LoadConfig(Json const &in) override {
+    auto const &config = get<Object const>(in);
+    FromJson(config.at("hist_train_param"), &hist_param_);
+  }
+  void SaveConfig(Json *p_out) const override {
+    auto &out = *p_out;
+    out["hist_train_param"] = ToJson(hist_param_);
+  }
 
   void InitData(TrainParam const &param, HostDeviceVector<GradientPair> const *gpair,
                 linalg::Matrix<GradientPair> *sampled) {
@@ -284,8 +294,8 @@
   void Update(TrainParam const *param, HostDeviceVector<GradientPair> *gpair, DMatrix *m,
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree *> &trees) override {
-    pimpl_ = std::make_unique<GloablApproxBuilder>(param, m->Info(), ctx_, column_sampler_, task_,
-                                                   &monitor_);
+    pimpl_ = std::make_unique<GloablApproxBuilder>(param, &hist_param_, m->Info(), ctx_,
+                                                   column_sampler_, task_, &monitor_);
 
     linalg::Matrix<GradientPair> h_gpair;
     // Obtain the hessian values for weighted sketching
@@ -300,6 +310,7 @@
     std::size_t t_idx = 0;
     for (auto p_tree : trees) {
       this->pimpl_->UpdateTree(m, s_gpair, hess, p_tree, &out_position[t_idx]);
+      hist_param_.CheckTreesSynchronized(p_tree);
       ++t_idx;
     }
   }
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index e2a863e3d2d2..ce53d50d00fd 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -30,6 +30,7 @@
 #include "gpu_hist/gradient_based_sampler.cuh"
 #include "gpu_hist/histogram.cuh"
 #include "gpu_hist/row_partitioner.cuh"
+#include "hist/param.h"
 #include "param.h"
 #include "split_evaluator.h"
 #include "updater_gpu_common.cuh"
@@ -48,20 +49,6 @@ namespace xgboost::tree {
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 #endif  // !defined(GTEST_TEST)
 
-// training parameters specific to this algorithm
-struct GPUHistMakerTrainParam : public XGBoostParameter<GPUHistMakerTrainParam> {
-  bool debug_synchronize;
-  // declare parameters
-  DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
-    DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe(
-        "Check if all distributed tree are identical after tree construction.");
-  }
-};
-#if !defined(GTEST_TEST)
-DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
-#endif  // !defined(GTEST_TEST)
-
 /**
  * \struct  DeviceHistogramStorage
 *
@@ -767,12 +754,12 @@ class GPUHistMaker : public TreeUpdater {
 
   void LoadConfig(Json const& in) override {
     auto const& config = get<Object const>(in);
-    FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_);
+    FromJson(config.at("hist_train_param"), &this->hist_maker_param_);
     initialised_ = false;
   }
   void SaveConfig(Json* p_out) const override {
     auto& out = *p_out;
-    out["gpu_hist_train_param"] = ToJson(hist_maker_param_);
+    out["hist_train_param"] = ToJson(hist_maker_param_);
   }
 
   ~GPUHistMaker() {  // NOLINT
@@ -790,9 +777,7 @@ class GPUHistMaker : public TreeUpdater {
     for (xgboost::RegTree* tree : trees) {
       this->UpdateTree(param, gpair, dmat, tree, &out_position[t_idx]);
-      if (hist_maker_param_.debug_synchronize) {
-        this->CheckTreesSynchronized(tree);
-      }
+      hist_maker_param_.CheckTreesSynchronized(tree);
       ++t_idx;
     }
     dh::safe_cuda(cudaGetLastError());
@@ -876,7 +861,7 @@ class GPUHistMaker : public TreeUpdater {
  private:
   bool initialised_{false};
 
-  GPUHistMakerTrainParam hist_maker_param_;
+  HistMakerTrainParam hist_maker_param_;
 
   DMatrix* p_last_fmat_{nullptr};
   RegTree const* p_last_tree_{nullptr};
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 22b715dc7db7..63aaf27f6f6d 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -4,18 +4,17 @@
  * \brief use quantized feature values to construct a tree
  * \author Philip Cho, Tianqi Checn, Egor Smirnov
  */
-#include <algorithm>    // for max, copy, transform
-#include <cstddef>      // for size_t
-#include <cstdint>      // for uint32_t, int32_t
-#include <memory>       // for unique_ptr, allocator, make_unique, shared_ptr
-#include <numeric>      // for accumulate
-#include <ostream>      // for basic_ostream, char_traits, operator<<
-#include <utility>      // for move, swap
-#include <vector>       // for vector
+#include <algorithm>  // for max, copy, transform
+#include <cstddef>    // for size_t
+#include <cstdint>    // for uint32_t, int32_t
+#include <memory>     // for unique_ptr, allocator, make_unique, shared_ptr
+#include <numeric>    // for accumulate
+#include <ostream>    // for basic_ostream, char_traits, operator<<
+#include <utility>    // for move, swap
+#include <vector>     // for vector
 
 #include "../collective/aggregator.h"         // for GlobalSum
 #include "../collective/communicator-inl.h"   // for Allreduce, IsDistributed
-#include "../collective/communicator.h"       // for Operation
 #include "../common/hist_util.h"              // for HistogramCuts, HistCollection
 #include "../common/linalg_op.h"              // for begin, cbegin, cend
 #include "../common/random.h"                 // for ColumnSampler
@@ -24,12 +23,12 @@
 #include "../common/transform_iterator.h"     // for IndexTransformIter, MakeIndexTransformIter
 #include "../data/gradient_index.h"           // for GHistIndexMatrix
 #include "common_row_partitioner.h"           // for CommonRowPartitioner
-#include "dmlc/omp.h"                         // for omp_get_thread_num
 #include "dmlc/registry.h"                    // for DMLC_REGISTRY_FILE_TAG
 #include "driver.h"                           // for Driver
 #include "hist/evaluate_splits.h"             // for HistEvaluator, HistMultiEvaluator, UpdatePre...
 #include "hist/expand_entry.h"                // for MultiExpandEntry, CPUExpandEntry
 #include "hist/histogram.h"                   // for HistogramBuilder, ConstructHistSpace
+#include "hist/param.h"                       // for HistMakerTrainParam
 #include "hist/sampler.h"                     // for SampleGradient
 #include "param.h"                            // for TrainParam, SplitEntryContainer, GradStats
 #include "xgboost/base.h"                     // for GradientPairInternal, GradientPair, bst_targ...
@@ -117,6 +116,7 @@ class MultiTargetHistBuilder {
  private:
   common::Monitor *monitor_{nullptr};
   TrainParam const *param_{nullptr};
+  HistMakerTrainParam const *hist_param_{nullptr};
   std::shared_ptr<common::ColumnSampler> col_sampler_;
   std::unique_ptr<HistMultiEvaluator> evaluator_;
   // Histogram builder for each target.
@@ -306,10 +306,12 @@ class MultiTargetHistBuilder {
 
  public:
   explicit MultiTargetHistBuilder(Context const *ctx, MetaInfo const &info, TrainParam const *param,
+                                  HistMakerTrainParam const *hist_param,
                                   std::shared_ptr<common::ColumnSampler> column_sampler,
                                   ObjInfo const *task, common::Monitor *monitor)
       : monitor_{monitor},
         param_{param},
+        hist_param_{hist_param},
         col_sampler_{std::move(column_sampler)},
         evaluator_{std::make_unique<HistMultiEvaluator>(ctx, info, param, col_sampler_)},
         ctx_{ctx},
@@ -331,10 +333,14 @@
   }
 };
 
-class HistBuilder {
+/**
+ * @brief Tree updater for single-target trees.
+ */
+class HistUpdater {
  private:
   common::Monitor *monitor_;
   TrainParam const *param_;
+  HistMakerTrainParam const *hist_param_{nullptr};
   std::shared_ptr<common::ColumnSampler> col_sampler_;
   std::unique_ptr<HistEvaluator> evaluator_;
   std::vector<CommonRowPartitioner> partitioner_;
@@ -349,14 +355,14 @@
   Context const *ctx_{nullptr};
 
  public:
-  explicit HistBuilder(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
-                       TrainParam const *param, DMatrix const *fmat, ObjInfo const *task,
-                       common::Monitor *monitor)
+  explicit HistUpdater(Context const *ctx, std::shared_ptr<common::ColumnSampler> column_sampler,
+                       TrainParam const *param, HistMakerTrainParam const *hist_param,
+                       DMatrix const *fmat, ObjInfo const *task, common::Monitor *monitor)
       : monitor_{monitor},
         param_{param},
+        hist_param_{hist_param},
         col_sampler_{std::move(column_sampler)},
-        evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(),
-                                                   col_sampler_)},
+        evaluator_{std::make_unique<HistEvaluator>(ctx, param, fmat->Info(), col_sampler_)},
         p_last_fmat_(fmat),
         histogram_builder_{new HistogramBuilder},
         task_{task},
@@ -529,7 +535,7 @@ class HistBuilder {
                           std::vector<bst_node_t> *p_out_position) {
     monitor_->Start(__func__);
     if (!task_->UpdateTreeLeaf()) {
-      monitor_->Stop(__func__);
+      monitor_->Stop(__func__);
       return;
     }
     for (auto const &part : partitioner_) {
@@ -541,20 +547,27 @@
 /*! \brief construct a tree using quantized feature values */
 class QuantileHistMaker : public TreeUpdater {
-  std::unique_ptr<HistBuilder> p_impl_{nullptr};
+  std::unique_ptr<HistUpdater> p_impl_{nullptr};
   std::unique_ptr<MultiTargetHistBuilder> p_mtimpl_{nullptr};
   std::shared_ptr<common::ColumnSampler> column_sampler_ =
       std::make_shared<common::ColumnSampler>();
   common::Monitor monitor_;
   ObjInfo const *task_{nullptr};
+  HistMakerTrainParam hist_param_;
 
  public:
  explicit QuantileHistMaker(Context const *ctx, ObjInfo const *task)
       : TreeUpdater{ctx}, task_{task} {}
-  void Configure(const Args &) override {}
-  void LoadConfig(Json const &) override {}
-  void SaveConfig(Json *) const override {}
+  void Configure(Args const &args) override { hist_param_.UpdateAllowUnknown(args); }
+  void LoadConfig(Json const &in) override {
+    auto const &config = get<Object const>(in);
+    FromJson(config.at("hist_train_param"), &hist_param_);
+  }
+  void SaveConfig(Json *p_out) const override {
+    auto &out = *p_out;
+    out["hist_train_param"] = ToJson(hist_param_);
+  }
 
   [[nodiscard]] char const *Name() const override { return "grow_quantile_histmaker"; }
@@ -562,15 +575,17 @@
               common::Span<HostDeviceVector<bst_node_t>> out_position,
               const std::vector<RegTree *> &trees) override {
     if (trees.front()->IsMultiTarget()) {
+      CHECK(hist_param_.GetInitialised());
       CHECK(param->monotone_constraints.empty()) << "monotone constraint" << MTNotImplemented();
       if (!p_mtimpl_) {
         this->p_mtimpl_ = std::make_unique<MultiTargetHistBuilder>(
-            ctx_, p_fmat->Info(), param, column_sampler_, task_, &monitor_);
+            ctx_, p_fmat->Info(), param, &hist_param_, column_sampler_, task_, &monitor_);
       }
     } else {
+      CHECK(hist_param_.GetInitialised());
       if (!p_impl_) {
-        p_impl_ =
-            std::make_unique<HistBuilder>(ctx_, column_sampler_, param, p_fmat, task_, &monitor_);
+        p_impl_ = std::make_unique<HistUpdater>(ctx_, column_sampler_, param, &hist_param_, p_fmat,
+                                                task_, &monitor_);
       }
     }
@@ -601,6 +616,8 @@
         UpdateTree(&monitor_, h_sample_out, p_impl_.get(), p_fmat, param, h_out_position,
                    *tree_it);
       }
+
+      hist_param_.CheckTreesSynchronized(*tree_it);
     }
   }
diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py
index 66c6058a5660..23bfc2d234ed 100644
--- a/tests/test_distributed/test_with_dask/test_with_dask.py
+++ b/tests/test_distributed/test_with_dask/test_with_dask.py
@@ -1459,6 +1459,7 @@ def run_updater_test(
     tree_method: str,
 ) -> None:
     params["tree_method"] = tree_method
+    params["debug_synchronize"] = True
    params = dataset.set_params(params)
    # It doesn't make sense to distribute a completely
    # empty dataset.
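
Since debug_synchronize is now registered through the shared HistMakerTrainParam (picked up by Configure via UpdateAllowUnknown), it can be switched on from Python the same way the Dask test above does. A minimal single-process sketch; the synthetic data and training setup are illustrative only, while debug_synchronize and tree_method are the real parameter names:

    import numpy as np
    import xgboost as xgb

    # Illustrative toy regression data; any DMatrix works here.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(256, 8))
    y = X.sum(axis=1)
    dtrain = xgb.DMatrix(X, label=y)

    # debug_synchronize asks the hist/approx/gpu_hist updaters to verify that all
    # workers built an identical tree after each update; with a single worker the
    # check trivially passes, so it is primarily useful for distributed training.
    params = {"tree_method": "hist", "debug_synchronize": True}
    booster = xgb.train(params, dtrain, num_boost_round=10)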