From 2f9af35a7e24c3a3ebb9607d9e0e1abca1149af3 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 7 Mar 2024 18:57:08 +0800 Subject: [PATCH 1/2] Support graphviz plot for multi-target tree. --- include/xgboost/tree_model.h | 5 +- src/tree/tree_model.cc | 154 +++++++++++------- .../cpp/tree/test_multi_target_tree_model.cc | 37 ++++- 3 files changed, 133 insertions(+), 63 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index 4c475da2ea29..32b93c5cacaf 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -1,5 +1,5 @@ /** - * Copyright 2014-2023 by Contributors + * Copyright 2014-2024, XGBoost Contributors * \file tree_model.h * \brief model structure for tree * \author Tianqi Chen @@ -688,6 +688,9 @@ class RegTree : public Model { } return (*this)[nidx].DefaultLeft(); } + [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const { + return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx); + } [[nodiscard]] bool IsRoot(bst_node_t nidx) const { if (IsMultiTarget()) { return nidx == kRoot; diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index f18b519264a0..b3e49fce2022 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,5 +1,5 @@ /** - * Copyright 2015-2023, XGBoost Contributors + * Copyright 2015-2024, XGBoost Contributors * \file tree_model.cc * \brief model structure for tree */ @@ -8,6 +8,7 @@ #include #include +#include // for array #include #include #include @@ -15,7 +16,7 @@ #include #include "../common/categorical.h" -#include "../common/common.h" // for EscapeU8 +#include "../common/common.h" // for EscapeU8 #include "../predictor/predict_fn.h" #include "io_utils.h" // for GetElem #include "param.h" @@ -31,26 +32,50 @@ namespace tree { DMLC_REGISTER_PARAMETER(TrainParam); } +namespace { +template +std::enable_if_t, std::string> ToStr(Float value) { + int32_t constexpr kFloatMaxPrecision = std::numeric_limits::max_digits10; + static_assert(std::is_floating_point::value, + "Use std::to_string instead for non-floating point values."); + std::stringstream ss; + ss << std::setprecision(kFloatMaxPrecision) << value; + return ss.str(); +} + +template +std::string ToStr(linalg::VectorView value, bst_target_t limit) { + int32_t constexpr kFloatMaxPrecision = std::numeric_limits::max_digits10; + static_assert(std::is_floating_point::value, + "Use std::to_string instead for non-floating point values."); + std::stringstream ss; + ss << std::setprecision(kFloatMaxPrecision); + if (value.Size() == 1) { + ss << value(0); + return ss.str(); + } + CHECK_GE(limit, 2); + auto n = std::min(static_cast(value.Size() - 1), limit - 1); + ss << "["; + for (std::size_t i = 0; i < n; ++i) { + ss << value(i) << ", "; + } + if (value.Size() > limit) { + ss << "..., "; + } + ss << value(value.Size() - 1) << "]"; + return ss.str(); +} +} // namespace /*! * \brief Base class for dump model implementation, modeling closely after code generator. */ class TreeGenerator { protected: - static int32_t constexpr kFloatMaxPrecision = - std::numeric_limits::max_digits10; FeatureMap const& fmap_; std::stringstream ss_; bool const with_stats_; - template - static std::string ToStr(Float value) { - static_assert(std::is_floating_point::value, - "Use std::to_string instead for non-floating point values."); - std::stringstream ss; - ss << std::setprecision(kFloatMaxPrecision) << value; - return ss.str(); - } - static std::string Tabs(uint32_t n) { std::string res; for (uint32_t i = 0; i < n; ++i) { @@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator { kLeafTemplate, {{"{tabs}", SuperT::Tabs(depth)}, {"{nid}", std::to_string(nid)}, - {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{leaf}", ToStr(tree[nid].LeafValue())}, {"{stats}", with_stats_ ? SuperT::Match(kStatTemplate, - {{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); + {{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); return result; } @@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator { static std::string const kQuantitiveTemplate = "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; auto cond = tree[nid].SplitCond(); - return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth); } std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { auto cond = tree[nid].SplitCond(); static std::string const kNodeTemplate = "{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}"; - return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth); } std::string Categorical(RegTree const &tree, int32_t nid, @@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator { static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}"; std::string const result = SuperT::Match( kStatTemplate, - {{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, - {"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); + {{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)}, + {"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}}); return result; } @@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator { std::string result = SuperT::Match( kLeafTemplate, {{"{nid}", std::to_string(nid)}, - {"{leaf}", SuperT::ToStr(tree[nid].LeafValue())}, + {"{leaf}", ToStr(tree[nid].LeafValue())}, {"{stat}", with_stats_ ? SuperT::Match( kStatTemplate, {{"{sum_hess}", - SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); + ToStr(tree.Stat(nid).sum_hess)}}) : ""}}); return result; } @@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator { R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("missing": {missing})I"; bst_float cond = tree[nid].SplitCond(); - return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth); } std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override { @@ -477,7 +502,7 @@ class JsonGenerator : public TreeGenerator { R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I" R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I" R"I("missing": {missing})I"; - return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth); + return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth); } std::string NodeStat(RegTree const& tree, int32_t nid) const override { @@ -485,8 +510,8 @@ class JsonGenerator : public TreeGenerator { R"S(, "gain": {loss_chg}, "cover": {sum_hess})S"; auto result = SuperT::Match( kStatTemplate, - {{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)}, - {"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}); + {{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)}, + {"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}}); return result; } @@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator { protected: template - std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const { + std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const { static std::string const kEdgeTemplate = " {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n"; // Is this the default child for missing value? - bool is_missing = tree[nid].DefaultChild() == child; + bool is_missing = tree.DefaultChild(nidx) == child; std::string branch; if (is_categorical) { branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""}; @@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator { } std::string buffer = SuperT::Match(kEdgeTemplate, - {{"{nid}", std::to_string(nid)}, + {{"{nid}", std::to_string(nidx)}, {"{child}", std::to_string(child)}, {"{color}", is_missing ? param_.yes_color : param_.no_color}, {"{branch}", branch}}); @@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator { // Only indicator is different, so we combine all different node types into this // function. - std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override { - auto split_index = tree[nid].SplitIndex(); - auto cond = tree[nid].SplitCond(); + std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override { + auto split_index = tree.SplitIndex(nidx); + auto cond = tree.SplitCond(nidx); static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n"; bool has_less = (split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator; std::string result = - SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)}, + SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)}, {"{fname}", GetFeatureName(fmap_, split_index)}, {"{<}", has_less ? "<" : ""}, - {"{cond}", has_less ? SuperT::ToStr(cond) : ""}, + {"{cond}", has_less ? ToStr(cond) : ""}, {"{params}", param_.condition_node_params}}); - result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); - result += BuildEdge(tree, nid, tree[nid].RightChild(), false); + result += BuildEdge(tree, nidx, tree.LeftChild(nidx), true); + result += BuildEdge(tree, nidx, tree.RightChild(nidx), false); return result; }; - std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override { + std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override { static std::string const kLabelTemplate = " {nid} [ label=\"{fname}:{cond}\" {params}]\n"; - auto cats = GetSplitCategories(tree, nid); + auto cats = GetSplitCategories(tree, nidx); auto cats_str = PrintCatsAsSet(cats); - auto split_index = tree[nid].SplitIndex(); + auto split_index = tree.SplitIndex(nidx); std::string result = - SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)}, + SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)}, {"{fname}", GetFeatureName(fmap_, split_index)}, {"{cond}", cats_str}, {"{params}", param_.condition_node_params}}); - result += BuildEdge(tree, nid, tree[nid].LeftChild(), true); - result += BuildEdge(tree, nid, tree[nid].RightChild(), false); + result += BuildEdge(tree, nidx, tree.LeftChild(nidx), true); + result += BuildEdge(tree, nidx, tree.RightChild(nidx), false); return result; } - std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override { - static std::string const kLeafTemplate = - " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; - auto result = SuperT::Match(kLeafTemplate, { - {"{nid}", std::to_string(nid)}, - {"{leaf-value}", ToStr(tree[nid].LeafValue())}, - {"{params}", param_.leaf_node_params}}); - return result; - }; + std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override { + static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n"; + // hardcoded limit to avoid dumping long arrays into dot graph. + bst_target_t constexpr kLimit{3}; + if (tree.IsMultiTarget()) { + auto value = tree.GetMultiTargetTree()->LeafValue(nidx); + auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, + {"{leaf-value}", ToStr(value, kLimit)}, + {"{params}", param_.leaf_node_params}}); + return result; + } else { + auto value = tree[nidx].LeafValue(); + auto result = SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)}, + {"{leaf-value}", ToStr(value)}, + {"{params}", param_.leaf_node_params}}); + return result; + } + } - std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override { - if (tree[nid].IsLeaf()) { - return this->LeafNode(tree, nid, depth); + std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override { + if (tree.IsLeaf(nidx)) { + return this->LeafNode(tree, nidx, depth); } static std::string const kNodeTemplate = "{parent}\n{left}\n{right}"; - auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical - ? this->Categorical(tree, nid, depth) - : this->PlainNode(tree, nid, depth); + auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical + ? this->Categorical(tree, nidx, depth) + : this->PlainNode(tree, nidx, depth); auto result = SuperT::Match( kNodeTemplate, {{"{parent}", node}, - {"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)}, - {"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}}); + {"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)}, + {"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}}); return result; } @@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot") constexpr bst_node_t RegTree::kRoot; std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { - CHECK(!IsMultiTarget()); + if (format != "dot") { + MTNotImplemented(); + } std::unique_ptr builder{TreeGenerator::Create(format, fmap, with_stats)}; builder->BuildTree(*this); diff --git a/tests/cpp/tree/test_multi_target_tree_model.cc b/tests/cpp/tree/test_multi_target_tree_model.cc index 550b8837c1cd..0b5745a20781 100644 --- a/tests/cpp/tree/test_multi_target_tree_model.cc +++ b/tests/cpp/tree/test_multi_target_tree_model.cc @@ -1,5 +1,5 @@ /** - * Copyright 2023 by XGBoost Contributors + * Copyright 2023-2024, XGBoost Contributors */ #include #include // for Context @@ -7,16 +7,23 @@ #include // for RegTree namespace xgboost { -TEST(MultiTargetTree, JsonIO) { +namespace { +auto MakeTreeForTest() { bst_target_t n_targets{3}; bst_feature_t n_features{4}; RegTree tree{n_targets, n_features}; - ASSERT_TRUE(tree.IsMultiTarget()); + CHECK(tree.IsMultiTarget()); linalg::Vector base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()}; linalg::Vector right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()}; tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(), left_weight.HostView(), right_weight.HostView()); + return tree; +} +} // namespace + +TEST(MultiTargetTree, JsonIO) { + auto tree = MakeTreeForTest(); ASSERT_EQ(tree.NumNodes(), 3); ASSERT_EQ(tree.NumTargets(), 3); ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3); @@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) { loaded.SaveModel(&jtree1); check_jtree(jtree1, tree); } + +TEST(MultiTargetTree, DumpDot) { + auto tree = MakeTreeForTest(); + auto n_features = tree.NumFeatures(); + FeatureMap fmap; + for (bst_feature_t f = 0; f < n_features; ++f) { + auto name = "feat_" + std::to_string(f); + fmap.PushBack(f, name.c_str(), "q"); + } + auto str = tree.DumpModel(fmap, true, "dot"); + ASSERT_NE(str.find("leaf=[2, 3, 4]"), std::string::npos); + ASSERT_NE(str.find("leaf=[3, 4, 5]"), std::string::npos); + + { + bst_target_t n_targets{4}; + bst_feature_t n_features{4}; + RegTree tree{n_targets, n_features}; + linalg::Vector weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()}; + tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(), + weight.HostView(), weight.HostView()); + auto str = tree.DumpModel(fmap, true, "dot"); + ASSERT_NE(str.find("leaf=[1, 2, ..., 4]"), std::string::npos); + } +} } // namespace xgboost From 9c87144339673bc6b2cc93d1e5fc4f3d6e4595a6 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 8 Mar 2024 13:33:29 +0800 Subject: [PATCH 2/2] Fix error message. --- src/tree/tree_model.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index b3e49fce2022..45834cc7755e 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -767,8 +767,8 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot") constexpr bst_node_t RegTree::kRoot; std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const { - if (format != "dot") { - MTNotImplemented(); + if (this->IsMultiTarget() && format != "dot") { + LOG(FATAL) << format << " tree dump " << MTNotImplemented(); } std::unique_ptr builder{TreeGenerator::Create(format, fmap, with_stats)}; builder->BuildTree(*this);