Skip to content

Commit

Permalink
CatBoost converter (#392)
Browse files Browse the repository at this point in the history
* catboost converter

* requirements updated

* fixes

* binclass fix

Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com>
  • Loading branch information
monkey0head and wenbingl authored Jun 8, 2020
1 parent ed01799 commit 6b46e34
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 1 deletion.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ONNXMLTools enables you to convert models from different machine learning toolki
* libsvm
* XGBoost
* H2O
* CatBoost
<p>Pytorch has its builtin ONNX exporter check <a href="https://pytorch.org/docs/stable/onnx.html">here</a> for details</p>

## Install
Expand All @@ -31,7 +32,7 @@ pip install git+https://github.com/onnx/onnxmltools
If you choose to install `onnxmltools` from its source code, you must set the environment variable `ONNX_ML=1` before installing the `onnx` package.

## Dependencies
This package relies on ONNX, NumPy, and ProtoBuf. If you are converting a model from scikit-learn, Core ML, Keras, LightGBM, SparkML, XGBoost, H2O or LibSVM, you will need an environment with the respective package installed from the list below:
This package relies on ONNX, NumPy, and ProtoBuf. If you are converting a model from scikit-learn, Core ML, Keras, LightGBM, SparkML, XGBoost, H2O, CatBoost or LibSVM, you will need an environment with the respective package installed from the list below:
1. scikit-learn
2. CoreMLTools
3. Keras (version 2.0.8 or higher) with the corresponding Tensorflow version
Expand All @@ -40,6 +41,7 @@ This package relies on ONNX, NumPy, and ProtoBuf. If you are converting a model
6. XGBoost (scikit-learn interface)
7. libsvm
8. H2O
9. CatBoost

ONNXMLTools has been tested with Python **3.5**, **3.6**, and **3.7**.
Version 1.6.1 is the latest version supporting Python 2.7.
Expand Down
1 change: 1 addition & 0 deletions onnxmltools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .convert import convert_tensorflow
from .convert import convert_xgboost
from .convert import convert_h2o
from .convert import convert_catboost

from .utils import load_model
from .utils import save_model
1 change: 1 addition & 0 deletions onnxmltools/convert/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from .main import convert_tensorflow
from .main import convert_xgboost
from .main import convert_h2o
from .main import convert_catboost
11 changes: 11 additions & 0 deletions onnxmltools/convert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def convert_libsvm(model, name=None, initial_types=None, doc_string='', target_o
custom_conversion_functions, custom_shape_calculators)


def convert_catboost(model, name=None, initial_types=None, doc_string='', target_opset=None):
try:
from catboost.utils import convert_to_onnx_object
except ImportError:
raise RuntimeError('CatBoost is not installed or needs to be updated. '
'Please install/upgrade CatBoost to use this feature.')

return convert_to_onnx_object(model, export_parameters={'onnx_doc_string': doc_string, 'onnx_graph_name': name},
initial_types=initial_types, target_opset=target_opset)


def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
if not utils.lightgbm_installed():
Expand Down
3 changes: 3 additions & 0 deletions onnxmltools/utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ def convert_model(model, name, input_types):
model, prefix = convert_lightgbm(model, name, input_types), "LightGbm"
else:
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
elif model.__class__.__name__.startswith("CatBoost"):
from onnxmltools.convert import convert_catboost
model, prefix = convert_catboost(model, name, input_types), "CatBoost"
elif isinstance(model, BaseEstimator):
from onnxmltools.convert import convert_sklearn
model, prefix = convert_sklearn(model, name, input_types), "Sklearn"
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ scipy
svm
wheel
xgboost
catboost
flake8
61 changes: 61 additions & 0 deletions tests/catboost/test_CatBoost_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Tests for CatBoostRegressor and CatBoostClassifier converter.
"""
import unittest
import numpy
import warnings
import catboost

from sklearn.datasets import make_regression, make_classification
from onnxmltools.convert import convert_catboost
from onnxmltools.utils import dump_data_and_model, dump_single_regression, dump_multiple_classification


class TestCatBoost(unittest.TestCase):
def test_catboost_regressor(self):
X, y = make_regression(n_samples=100, n_features=4, random_state=0)
catboost_model = catboost.CatBoostRegressor(task_type='CPU', loss_function='RMSE',
n_estimators=10, verbose=0)
dump_single_regression(catboost_model)

catboost_model.fit(X.astype(numpy.float32), y)
catboost_onnx = convert_catboost(catboost_model, name='CatBoostRegression',
doc_string='test regression')
self.assertTrue(catboost_onnx is not None)
dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostReg-Dec4")

def test_catboost_bin_classifier(self):
import onnxruntime
from distutils.version import StrictVersion

if StrictVersion(onnxruntime.__version__) >= StrictVersion('1.3.0'):
X, y = make_classification(n_samples=100, n_features=4, random_state=0)
catboost_model = catboost.CatBoostClassifier(task_type='CPU', loss_function='CrossEntropy',
n_estimators=10, verbose=0)
catboost_model.fit(X.astype(numpy.float32), y)

catboost_onnx = convert_catboost(catboost_model, name='CatBoostBinClassification',
doc_string='test binary classification')
self.assertTrue(catboost_onnx is not None)
dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostBinClass")

else:
warnings.warn('Converted CatBoost models for binary classification work with onnxruntime version 1.3.0 or '
'a newer one')

def test_catboost_multi_classifier(self):
X, y = make_classification(n_samples=10, n_informative=8, n_classes=3, random_state=0)
catboost_model = catboost.CatBoostClassifier(task_type='CPU', loss_function='MultiClass',
n_estimators=100, verbose=0)

dump_multiple_classification(catboost_model)

catboost_model.fit(X.astype(numpy.float32), y)
catboost_onnx = convert_catboost(catboost_model, name='CatBoostMultiClassification',
doc_string='test multiclass classification')
self.assertTrue(catboost_onnx is not None)
dump_data_and_model(X.astype(numpy.float32), catboost_model, catboost_onnx, basename="CatBoostMultiClass")


if __name__ == "__main__":
unittest.main()

0 comments on commit 6b46e34

Please sign in to comment.