Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Faster Treelite serialization #2263

Merged
merged 13 commits into from
Jun 17, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
- PR #2257: Update QN and LogisticRegression to use CumlArray
- PR #2259: Add CumlArray support to Naive Bayes
- PR #2252: Add benchmark for the Gram matrix prims
- PR #2263: Faster serialization for Treelite objects with RF
- PR #2264: Reduce build time for cuML by using make_blobs from libcuml++ interface
- PR #2269: Add docs targets to build.sh and fix python cuml.common docs
- PR #2271: Clarify doc for `_unique` default implementation in OneHotEncoder
Expand Down
2 changes: 1 addition & 1 deletion ci/gpu/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH_CACHED
export LD_LIBRARY_PATH_CACHED=""

logger "Install Treelite for GPU testing..."
python -m pip install -v treelite==0.91
python -m pip install -v treelite==0.92 treelite_runtime==0.92

cd $WORKSPACE

Expand Down
2 changes: 1 addition & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ set(CUML_INCLUDE_DIRECTORIES
${CUB_DIR}/src/cub
${SPDLOG_DIR}/src/spdlog/include
${TREELITE_DIR}/include
${TREELITE_DIR}/include/fmt
${TREELITE_DIR}/src/treelite/include
${RAFT_DIR}/cpp/include)

set(CUML_LINK_LIBRARIES
Expand Down
22 changes: 10 additions & 12 deletions cpp/cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -126,28 +126,26 @@ set(TREELITE_DIR ${CMAKE_CURRENT_BINARY_DIR}/treelite CACHE STRING
"Path to treelite install directory")
ExternalProject_Add(treelite
GIT_REPOSITORY https://github.com/dmlc/treelite.git
GIT_TAG 6fd01e4f1890950bbcf9b124da24e886751bffe6
GIT_TAG 0.92
PREFIX ${TREELITE_DIR}
CMAKE_ARGS -DBUILD_SHARED_LIBS=OFF
-DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DENABLE_PROTOBUF=ON
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=<INSTALL_DIR>
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH}
BUILD_BYPRODUCTS ${TREELITE_DIR}/lib/libtreelite.a
${TREELITE_DIR}/lib/libdmlc.a
${TREELITE_DIR}/lib/libtreelite_runtime.so
UPDATE_COMMAND ""
PATCH_COMMAND patch -p1 -N < ${CMAKE_CURRENT_SOURCE_DIR}/cmake/treelite_protobuf.patch || true)
-DENABLE_PROTOBUF=ON
BUILD_BYPRODUCTS ${TREELITE_DIR}/src/treelite-build/libtreelite_static.a
${TREELITE_DIR}/src/treelite-build/_deps/dmlccore-build/libdmlc.a
${TREELITE_DIR}/src/treelite-build/libtreelite_runtime.so
UPDATE_COMMAND "")
hcho3 marked this conversation as resolved.
Show resolved Hide resolved

add_library(dmlclib STATIC IMPORTED)
add_library(treelitelib STATIC IMPORTED)
add_library(treelite_runtimelib SHARED IMPORTED)

set_property(TARGET dmlclib PROPERTY
IMPORTED_LOCATION ${TREELITE_DIR}/lib/libdmlc.a)
IMPORTED_LOCATION ${TREELITE_DIR}/src/treelite-build/_deps/dmlccore-build/libdmlc.a)
set_property(TARGET treelitelib PROPERTY
IMPORTED_LOCATION ${TREELITE_DIR}/lib/libtreelite.a)
IMPORTED_LOCATION ${TREELITE_DIR}/src/treelite-build/libtreelite_static.a)
set_property(TARGET treelite_runtimelib PROPERTY
IMPORTED_LOCATION ${TREELITE_DIR}/lib/libtreelite_runtime.so)
IMPORTED_LOCATION ${TREELITE_DIR}/src/treelite-build/libtreelite_runtime.so)

add_dependencies(dmlclib treelite)
add_dependencies(treelitelib treelite)
Expand Down
5 changes: 1 addition & 4 deletions cpp/include/cuml/ensemble/randomforest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ void print_rf_detailed(const RandomForestMetaData<T, L>* forest);
template <class T, class L>
void build_treelite_forest(ModelHandle* model,
const RandomForestMetaData<T, L>* forest,
int num_features, int task_category,
std::vector<unsigned char>& data);

std::vector<unsigned char> save_model(ModelHandle model);
int num_features, int task_category);

ModelHandle concatenate_trees(std::vector<ModelHandle> treelite_handles);

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/decisiontree/decisiontree_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ void build_treelite_tree(TreeBuilderHandle tree_builder,
int num_output_group) {
int node_id = 0;
TREELITE_CHECK(TreeliteTreeBuilderCreateNode(tree_builder, node_id));
TREELITE_CHECK(TreeliteTreeBuilderSetRootNode(tree_builder, node_id));

std::queue<Node_ID_info<T, L>> cur_level_queue;
std::queue<Node_ID_info<T, L>> next_level_queue;
Expand Down Expand Up @@ -138,7 +137,7 @@ void build_treelite_tree(TreeBuilderHandle tree_builder,
TREELITE_CHECK(TreeliteTreeBuilderSetLeafNode(
tree_builder, q_node.unique_node_id, q_node.node.prediction));
} else {
std::vector<double> leaf_vector(num_output_group);
std::vector<float> leaf_vector(num_output_group);
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
for (int j = 0; j < num_output_group; j++) {
if (q_node.node.prediction == j) {
leaf_vector[j] = 1;
Expand All @@ -157,6 +156,7 @@ void build_treelite_tree(TreeBuilderHandle tree_builder,
// The cur_level_queue is empty here, as all the elements are already popped out.
cur_level_queue.swap(next_level_queue);
}
TREELITE_CHECK(TreeliteTreeBuilderSetRootNode(tree_builder, 0));
}

/**
Expand Down
127 changes: 55 additions & 72 deletions cpp/src/fil/fil.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -367,53 +367,36 @@ void check_params(const forest_params_t* params, bool dense) {
}
}

// tl_node_at is a checked version of tree[i]
inline const tl::Tree::Node& tl_node_at(const tl::Tree& tree, size_t i) {
ASSERT(i < tree.num_nodes, "node index out of range");
return tree[i];
}

int tree_root(const tl::Tree& tree) {
// find the root
int root = -1;
for (int i = 0; i < tree.num_nodes; ++i) {
if (tl_node_at(tree, i).is_root()) {
ASSERT(root == -1, "multi-root trees not supported");
root = i;
}
}
ASSERT(root != -1, "a tree must have a root");
return root;
return 0; // Treelite format assumes that the root is 0
}

int max_depth_helper(const tl::Tree& tree, const tl::Tree::Node& node,
int limit) {
if (node.is_leaf()) return 0;
int max_depth_helper(const tl::Tree& tree, int node_id, int limit) {
if (tree.IsLeaf(node_id)) return 0;
ASSERT(limit > 0,
"recursion depth limit reached, might be a cycle in the tree");
return 1 +
std::max(
max_depth_helper(tree, tl_node_at(tree, node.cleft()), limit - 1),
max_depth_helper(tree, tl_node_at(tree, node.cright()), limit - 1));
std::max(max_depth_helper(tree, tree.LeftChild(node_id), limit - 1),
max_depth_helper(tree, tree.RightChild(node_id), limit - 1));
}

inline int max_depth(const tl::Tree& tree) {
// trees of this depth aren't used, so it most likely means bad input data,
// e.g. cycles in the forest
const int DEPTH_LIMIT = 500;
int root_index = tree_root(tree);
typedef std::pair<const tl::Tree::Node*, int> pair_t;
typedef std::pair<int, int> pair_t;
std::stack<pair_t> stack;
stack.push(pair_t(&tl_node_at(tree, root_index), 0));
stack.push(pair_t(root_index, 0));
int max_depth = 0;
while (!stack.empty()) {
const pair_t& pair = stack.top();
const tl::Tree::Node* node = pair.first;
int node_id = pair.first;
int depth = pair.second;
stack.pop();
while (!node->is_leaf()) {
stack.push(pair_t(&tl_node_at(tree, node->cleft()), depth + 1));
node = &tl_node_at(tree, node->cright());
while (!tree.IsLeaf(node_id)) {
stack.push(pair_t(tree.LeftChild(node_id), depth + 1));
node_id = tree.RightChild(node_id);
depth++;
ASSERT(depth < DEPTH_LIMIT,
"depth limit reached, might be a cycle in the tree");
Expand All @@ -431,12 +414,12 @@ int max_depth(const tl::Model& model) {
}

inline void adjust_threshold(float* pthreshold, int* tl_left, int* tl_right,
bool* default_left, const tl::Tree::Node& node) {
bool* default_left, tl::Operator comparison_op) {
// in treelite (take left node if val [op] threshold),
// the meaning of the condition is reversed compared to FIL;
// thus, "<" in treelite corresponds to comparison ">=" used by FIL
// https://github.com/dmlc/treelite/blob/master/include/treelite/tree.h#L243
switch (node.comparison_op()) {
switch (comparison_op) {
case tl::Operator::kLT:
break;
case tl::Operator::kLE:
Expand Down Expand Up @@ -480,18 +463,18 @@ int find_class_label_from_one_hot(tl::tl_float* vector, int len) {
}

template <typename fil_node_t>
void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree::Node& tl_node,
const forest_params_t& forest_params) {
auto vec = tl_node.leaf_vector();
void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree& tl_tree,
int tl_node_id, const forest_params_t& forest_params) {
auto vec = tl_tree.LeafVector(tl_node_id);
switch (forest_params.leaf_payload_type) {
case leaf_value_t::INT_CLASS_LABEL:
ASSERT(vec.size() == forest_params.num_classes,
"inconsistent number of classes in treelite leaves");
fil_node->val.idx = find_class_label_from_one_hot(&vec[0], vec.size());
break;
case leaf_value_t::FLOAT_SCALAR:
fil_node->val.f = tl_node.leaf_value();
ASSERT(tl_node.leaf_vector().size() == 0,
fil_node->val.f = tl_tree.LeafValue(tl_node_id);
ASSERT(!tl_tree.HasLeafVector(tl_node_id),
"some but not all treelite leaves have leaf_vector()");
break;
default:
Expand All @@ -500,61 +483,61 @@ void tl2fil_leaf_payload(fil_node_t* fil_node, const tl::Tree::Node& tl_node,
}

void node2fil_dense(std::vector<dense_node_t>* pnodes, int root, int cur,
const tl::Tree& tree, const tl::Tree::Node& node,
const tl::Tree& tree, int node_id,
const forest_params_t& forest_params) {
if (node.is_leaf()) {
if (tree.IsLeaf(node_id)) {
hcho3 marked this conversation as resolved.
Show resolved Hide resolved
dense_node_init(&(*pnodes)[root + cur], val_t{.f = NAN}, NAN, 0, false,
true);
tl2fil_leaf_payload(&(*pnodes)[root + cur], node, forest_params);
tl2fil_leaf_payload(&(*pnodes)[root + cur], tree, node_id, forest_params);
return;
}

// inner node
ASSERT(node.split_type() == tl::SplitFeatureType::kNumerical,
ASSERT(tree.SplitType(node_id) == tl::SplitFeatureType::kNumerical,
"only numerical split nodes are supported");
int tl_left = node.cleft(), tl_right = node.cright();
bool default_left = node.default_left();
float threshold = node.threshold();
adjust_threshold(&threshold, &tl_left, &tl_right, &default_left, node);
int tl_left = tree.LeftChild(node_id), tl_right = tree.RightChild(node_id);
bool default_left = tree.DefaultLeft(node_id);
float threshold = tree.Threshold(node_id);
adjust_threshold(&threshold, &tl_left, &tl_right, &default_left,
tree.ComparisonOp(node_id));
dense_node_init(&(*pnodes)[root + cur], val_t{.f = 0}, threshold,
node.split_index(), default_left, false);
tree.SplitIndex(node_id), default_left, false);
int left = 2 * cur + 1;
node2fil_dense(pnodes, root, left, tree, tl_node_at(tree, tl_left),
forest_params);
node2fil_dense(pnodes, root, left + 1, tree, tl_node_at(tree, tl_right),
forest_params);
node2fil_dense(pnodes, root, left, tree, tl_left, forest_params);
node2fil_dense(pnodes, root, left + 1, tree, tl_right, forest_params);
}

void tree2fil_dense(std::vector<dense_node_t>* pnodes, int root,
const tl::Tree& tree,
const forest_params_t& forest_params) {
node2fil_dense(pnodes, root, 0, tree, tl_node_at(tree, tree_root(tree)),
forest_params);
node2fil_dense(pnodes, root, 0, tree, tree_root(tree), forest_params);
}

int tree2fil_sparse(std::vector<sparse_node_t>* pnodes, const tl::Tree& tree,
const forest_params_t& forest_params) {
typedef std::pair<const tl::Tree::Node*, int> pair_t;
typedef std::pair<int, int> pair_t;
std::stack<pair_t> stack;
int root = pnodes->size();
pnodes->push_back(sparse_node_t());
stack.push(pair_t(&tl_node_at(tree, tree_root(tree)), 0));
stack.push(pair_t(tree_root(tree), 0));
while (!stack.empty()) {
const pair_t& top = stack.top();
const tl::Tree::Node* node = top.first;
int node_id = top.first;
int cur = top.second;
stack.pop();

while (!node->is_leaf()) {
while (!tree.IsLeaf(node_id)) {
// inner node
ASSERT(node->split_type() == tl::SplitFeatureType::kNumerical,
ASSERT(tree.SplitType(node_id) == tl::SplitFeatureType::kNumerical,
"only numerical split nodes are supported");
// tl_left and tl_right are indices of the children in the treelite tree
// (stored as an array of nodes)
int tl_left = node->cleft(), tl_right = node->cright();
bool default_left = node->default_left();
float threshold = node->threshold();
adjust_threshold(&threshold, &tl_left, &tl_right, &default_left, *node);
int tl_left = tree.LeftChild(node_id),
tl_right = tree.RightChild(node_id);
bool default_left = tree.DefaultLeft(node_id);
float threshold = tree.Threshold(node_id);
adjust_threshold(&threshold, &tl_left, &tl_right, &default_left,
tree.ComparisonOp(node_id));

// reserve space for child nodes
// left is the offset of the left child node relative to the tree root
Expand All @@ -563,19 +546,20 @@ int tree2fil_sparse(std::vector<sparse_node_t>* pnodes, const tl::Tree& tree,
pnodes->push_back(sparse_node_t());
pnodes->push_back(sparse_node_t());
sparse_node_init_inline(&(*pnodes)[root + cur], val_t{.f = 0}, threshold,
node->split_index(), default_left, false, left);
tree.SplitIndex(node_id), default_left, false,
left);

// push child nodes into the stack
stack.push(pair_t(&tl_node_at(tree, tl_right), left + 1));
//stack.push(pair_t(&tl_node_at(tree, tl_left), left));
node = &tl_node_at(tree, tl_left);
stack.push(pair_t(tl_right, left + 1));
//stack.push(pair_t(tl_left, left));
node_id = tl_left;
cur = left;
}

// leaf node
sparse_node_init(&(*pnodes)[root + cur], val_t{.f = NAN}, NAN, 0, false,
true, 0);
tl2fil_leaf_payload(&(*pnodes)[root + cur], *node, forest_params);
tl2fil_leaf_payload(&(*pnodes)[root + cur], tree, node_id, forest_params);
}

return root;
Expand All @@ -584,11 +568,10 @@ int tree2fil_sparse(std::vector<sparse_node_t>* pnodes, const tl::Tree& tree,
size_t tl_leaf_vector_size(const tl::Model& model) {
const tl::Tree& tree = model.trees[0];
int node_key;
for (node_key = tree_root(tree); !tl_node_at(tree, node_key).is_leaf();
node_key = tl_node_at(tree, node_key).cright())
for (node_key = tree_root(tree); !tree.IsLeaf(node_key);
node_key = tree.RightChild(node_key))
;
const tl::Tree::Node& node = tl_node_at(tree, node_key);
if (node.has_leaf_vector()) return node.leaf_vector().size();
if (tree.HasLeafVector(node_key)) return tree.LeafVector(node_key).size();
return 0;
}

Expand Down Expand Up @@ -627,11 +610,11 @@ void tl2fil_common(forest_params_t* params, const tl::Model& model,
if (model.random_forest_flag) {
params->output = output_t(params->output | output_t::AVG);
}
if (param.pred_transform == "sigmoid") {
if (std::string(param.pred_transform) == "sigmoid") {
params->output = output_t(params->output | output_t::SIGMOID);
} else if (param.pred_transform != "identity") {
} else if (std::string(param.pred_transform) != "identity") {
ASSERT(false, "%s: unsupported treelite prediction transform",
param.pred_transform.c_str());
param.pred_transform);
}
params->num_trees = model.trees.size();
}
Expand Down
Loading