Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[optional] support loading and saving models with protobuf #908

Merged
merged 1 commit into from
Oct 19, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ env:
- TASK=if-else
- TASK=sdist PYTHON_VERSION=3.4
- TASK=bdist PYTHON_VERSION=3.5
- TASK=proto
- TASK=gpu METHOD=source
- TASK=gpu METHOD=pip

Expand All @@ -38,6 +39,8 @@ matrix:
env: TASK=pylint
- os: osx
env: TASK=check-docs
- os: osx
env: TASK=proto

before_install:
- test -n $CC && unset CC
Expand Down
14 changes: 13 additions & 1 deletion .travis/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,24 @@ if [[ ${TASK} == "if-else" ]]; then
conda create -q -n test-env python=$PYTHON_VERSION numpy
source activate test-env
mkdir build && cd build && cmake .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf convert_model_language=cpp convert_model=../../src/boosting/gbdt_prediction.cpp && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR/build && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=predict.conf output_result=ifelse.pred && python test.py || exit -1
exit 0
fi

if [[ ${TASK} == "proto" ]]; then
conda create -q -n test-env python=$PYTHON_VERSION numpy
source activate test-env
mkdir build && cd build && cmake .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf && ../../lightgbm config=predict.conf output_result=origin.pred || exit -1
cd $TRAVIS_BUILD_DIR && git clone https://github.com/google/protobuf && cd protobuf && ./autogen.sh && ./configure && make && sudo make install && sudo ldconfig
cd $TRAVIS_BUILD_DIR/build && rm -rf * && cmake -DUSE_PROTO=ON .. && make lightgbm || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && ../../lightgbm config=train.conf model_format=proto && ../../lightgbm config=predict.conf output_result=proto.pred model_format=proto || exit -1
cd $TRAVIS_BUILD_DIR/tests/cpp_test && python test.py || exit -1
exit 0
fi

conda create -q -n test-env python=$PYTHON_VERSION numpy nose scipy scikit-learn pandas matplotlib pytest
source activate test-env

Expand Down
20 changes: 18 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,24 @@ file(GLOB SOURCES
src/treelearner/*.cpp
)

add_executable(lightgbm src/main.cpp ${SOURCES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES})
if (USE_PROTO)
find_package(Protobuf REQUIRED)
PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS proto/model.proto)
include_directories(${PROTOBUF_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
SET(PROTO_FILES src/proto/gbdt_model_proto.cpp ${PROTO_HDRS} ${PROTO_SRCS})
else()
include_directories(src/proto/not_implemented)
SET(PROTO_FILES src/proto/not_implemented/gbdt_model_proto.cpp)
endif(USE_PROTO)

add_executable(lightgbm src/main.cpp ${SOURCES} ${PROTO_FILES})
add_library(_lightgbm SHARED src/c_api.cpp src/lightgbm_R.cpp ${SOURCES} ${PROTO_FILES})

if (USE_PROTO)
TARGET_LINK_LIBRARIES(lightgbm ${PROTOBUF_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${PROTOBUF_LIBRARIES})
endif(USE_PROTO)

if(MSVC)
set_target_properties(_lightgbm PROPERTIES OUTPUT_NAME "lib_lightgbm")
Expand Down
14 changes: 14 additions & 0 deletions docs/Parameters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,20 @@ IO Parameters

- file name of prediction result in ``prediction`` task

- ``model_format``, default=\ ``text``, type=string

- format to save and load model.

- ``text``, use text string.

- ``proto``, use protocol buffer binary format.

- save multiple formats by joining them with comma, like ``text,proto``, in this case, ``model_format`` will be add as suffix after ``output_model``.

- not support loading with multiple formats.

- Note: you need to cmake with -DUSE_PROTO=ON to use this parameter.

- ``is_pre_partition``, default=\ ``false``, type=bool

- used for parallel learning (not include feature parallel)
Expand Down
31 changes: 20 additions & 11 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <LightGBM/meta.h>
#include <LightGBM/config.h>
#include "model.pb.h"

#include <vector>
#include <string>
Expand Down Expand Up @@ -166,7 +167,7 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Save model to file
* \param num_used_model Number of model that want to save, -1 means save all
* \param num_iterations Number of model that want to save, -1 means save all
* \param is_finish Is training finished or not
* \param filename Filename that want to save to
* \return true if succeeded
Expand All @@ -175,7 +176,7 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Save model to string
* \param num_used_model Number of model that want to save, -1 means save all
* \param num_iterations Number of model that want to save, -1 means save all
* \return Non-empty string if succeeded
*/
virtual std::string SaveModelToString(int num_iterations) const = 0;
Expand All @@ -187,6 +188,20 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual bool LoadModelFromString(const std::string& model_str) = 0;

/*!
* \brief Save model with protobuf
* \param num_iterations Number of model that want to save, -1 means save all
* \param filename Filename that want to save to
*/
virtual void SaveModelToProto(int num_iteration, const char* filename) const = 0;

/*!
* \brief Restore from a serialized protobuf file
* \param filename Filename that want to restore from
* \return true if succeeded
*/
virtual bool LoadModelFromProto(const char* filename) = 0;

/*!
* \brief Calculate feature importances
* \param num_iteration Number of model that want to use for feature importance, -1 means use all
Expand Down Expand Up @@ -251,23 +266,17 @@ class LIGHTGBM_EXPORT Boosting {
/*! \brief Disable copy */
Boosting(const Boosting&) = delete;

static bool LoadFileToBoosting(Boosting* boosting, const char* filename);
static bool LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename);

/*!
* \brief Create boosting object
* \param type Type of boosting
* \param format Format of model
* \param config config for boosting
* \param filename name of model file, if existing will continue to train from this model
* \return The boosting object
*/
static Boosting* CreateBoosting(const std::string& type, const char* filename);

/*!
* \brief Create boosting object from model file
* \param filename name of model file
* \return The boosting object
*/
static Boosting* CreateBoosting(const char* filename);
static Boosting* CreateBoosting(const std::string& type, const std::string& format, const char* filename);

};

Expand Down
3 changes: 2 additions & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ struct IOConfig: public ConfigBase {
std::string output_result = "LightGBM_predict_result.txt";
std::string convert_model = "gbdt_prediction.cpp";
std::string input_model = "";
std::string model_format = "text";
int verbosity = 1;
int num_iteration_predict = -1;
bool is_pre_partition = false;
Expand Down Expand Up @@ -446,7 +447,7 @@ struct ParameterAlias {
const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device",
"num_threads", "seed", "boosting_type", "objective", "data",
"output_model", "input_model", "output_result", "valid_data",
"output_model", "input_model", "output_result", "model_format", "valid_data",
"is_enable_sparse", "is_pre_partition", "is_training_metric",
"ndcg_eval_at", "min_data_in_leaf", "min_sum_hessian_in_leaf",
"num_leaves", "feature_fraction", "num_iterations",
Expand Down
10 changes: 10 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <LightGBM/meta.h>
#include <LightGBM/dataset.h>
#include "model.pb.h"

#include <string>
#include <vector>
Expand Down Expand Up @@ -31,6 +32,12 @@ class Tree {
*/
explicit Tree(const std::string& str);

/*!
* \brief Construtor, from a protobuf object
* \param model_tree Model protobuf object
*/
explicit Tree(const LightGBM::Model_Tree& model_tree);

~Tree();

/*!
Expand Down Expand Up @@ -165,6 +172,9 @@ class Tree {
/*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index) const;

/*! \brief Serialize this object to protobuf object*/
void ToProto(Model_Tree& model_tree) const;

inline static bool IsZero(double fval) {
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
return true;
Expand Down
33 changes: 33 additions & 0 deletions proto/model.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
syntax = "proto3";

package LightGBM;

message Model {
string name = 1;
uint32 num_class = 2;
uint32 num_tree_per_iteration = 3;
uint32 label_index = 4;
uint32 max_feature_idx = 5;
string objective = 6;
bool average_output = 7;
repeated string feature_names = 8;
repeated string feature_infos = 9;
message Tree {
uint32 num_leaves = 1;
uint32 num_cat = 2;
repeated uint32 split_feature = 3;
repeated double split_gain = 4;
repeated double threshold = 5;
repeated uint32 decision_type = 6;
repeated sint32 left_child = 7;
repeated sint32 right_child = 8;
repeated double leaf_value = 9;
repeated uint32 leaf_count = 10;
repeated double internal_value = 11;
repeated double internal_count = 12;
repeated sint32 cat_boundaries = 13;
repeated uint32 cat_threshold = 14;
double shrinkage = 15;
}
repeated Tree trees = 10;
}
21 changes: 20 additions & 1 deletion src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ void Application::InitTrain() {
// create boosting
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
// create objective function
objective_fun_.reset(
Expand All @@ -203,6 +204,22 @@ void Application::InitTrain() {
void Application::Train() {
Log::Info("Started training...");
boosting_->Train(config_.io_config.snapshot_freq, config_.io_config.output_model);
std::vector<std::string> model_formats = Common::Split(config_.io_config.model_format.c_str(), ',');
bool save_with_multiple_format = (model_formats.size() > 1);
for (auto model_format: model_formats) {
std::string save_file_name = config_.io_config.output_model;
if (save_with_multiple_format) {
// use suffix to distinguish different model format
save_file_name += "." + model_format;
}
if (model_format == std::string("text")) {
boosting_->SaveModelToFile(-1, save_file_name.c_str());
} else if (model_format == std::string("proto")) {
boosting_->SaveModelToProto(-1, save_file_name.c_str());
} else {
Log::Fatal("Unknown model format during saving: %s", model_format.c_str());
}
}
// convert model to if-else statement code
if (config_.convert_model_language == std::string("cpp")) {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
Expand All @@ -223,13 +240,15 @@ void Application::Predict() {

void Application::InitPredict() {
boosting_.reset(
Boosting::CreateBoosting(config_.io_config.input_model.c_str()));
Boosting::CreateBoosting("gbdt", config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
Log::Info("Finished initializing prediction");
}

void Application::ConvertModel() {
boosting_.reset(
Boosting::CreateBoosting(config_.boosting_type,
config_.io_config.model_format.c_str(),
config_.io_config.input_model.c_str()));
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}
Expand Down
46 changes: 21 additions & 25 deletions src/boosting/boosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,30 @@ std::string GetBoostingTypeFromModelFile(const char* filename) {
return type;
}

bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
bool Boosting::LoadFileToBoosting(Boosting* boosting, const std::string& format, const char* filename) {
if (boosting != nullptr) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
if (format == std::string("text")) {
TextReader<size_t> model_reader(filename, true);
model_reader.ReadAllLines();
std::stringstream str_buf;
for (auto& line : model_reader.Lines()) {
str_buf << line << '\n';
}
if (!boosting->LoadModelFromString(str_buf.str())) {
return false;
}
} else if (format == std::string("proto")) {
if (!boosting->LoadModelFromProto(filename)) {
return false;
}
} else {
Log::Fatal("Unknown model format during loading: %s", format.c_str());
}
if (!boosting->LoadModelFromString(str_buf.str()))
return false;
}
return true;
}

Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename) {
Boosting* Boosting::CreateBoosting(const std::string& type, const std::string& format, const char* filename) {
if (filename == nullptr || filename[0] == '\0') {
if (type == std::string("gbdt")) {
return new GBDT();
Expand All @@ -41,8 +50,7 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
}
} else {
std::unique_ptr<Boosting> ret;
auto type_in_file = GetBoostingTypeFromModelFile(filename);
if (type_in_file == std::string("tree")) {
if (format == std::string("proto") || GetBoostingTypeFromModelFile(filename) == std::string("tree")) {
if (type == std::string("gbdt")) {
ret.reset(new GBDT());
} else if (type == std::string("dart")) {
Expand All @@ -54,24 +62,12 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
} else {
Log::Fatal("unknown boosting type %s", type.c_str());
}
LoadFileToBoosting(ret.get(), filename);
LoadFileToBoosting(ret.get(), format, filename);
} else {
Log::Fatal("unknown submodel type in model file %s", filename);
Log::Fatal("unknown model format or submodel type in model file %s", filename);
}
return ret.release();
}
}

Boosting* Boosting::CreateBoosting(const char* filename) {
auto type = GetBoostingTypeFromModelFile(filename);
std::unique_ptr<Boosting> ret;
if (type == std::string("tree")) {
ret.reset(new GBDT());
} else {
Log::Fatal("unknown submodel type in model file %s", filename);
}
LoadFileToBoosting(ret.get(), filename);
return ret.release();
}

} // namespace LightGBM
1 change: 0 additions & 1 deletion src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
SaveModelToFile(-1, snapshot_out.c_str());
}
}
SaveModelToFile(-1, model_output_path.c_str());
}

double GBDT::BoostFromAverage() {
Expand Down
Loading