Skip to content

Commit

Permalink
Hummingbird integration with lgbm converter (#404)
Browse files Browse the repository at this point in the history
* Hummingbird integration with lgbm
Add additional tests for Hummingbird

* Address flake8 errors

* we get hummingbird directly from github

* manually install torch

* roll back previous case

* skip HB tests for older versions of ORT

* fix reason and turn onnx into onnxruntime

* Addressing wenbingl comments

* fix to the git link

* add -f option to get hb

* add finx_links for pytorch in requirements
remove hummingbird_installed from utils
point to the onnxconveter-common branch with the hummingb_installed code

* Update convert.py

trigger pipeline

* point to the actual onnxconverter-common repository
  • Loading branch information
interesaaat authored Jul 8, 2020
1 parent 6b46e34 commit 8cf85d7
Show file tree
Hide file tree
Showing 6 changed files with 332 additions and 8 deletions.
26 changes: 23 additions & 3 deletions onnxmltools/convert/lightgbm/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from onnxconverter_common.onnx_ex import get_maximum_opset_supported
import onnx
from ..common._topology import convert_topology
from ..common.utils import hummingbird_installed
from ._parse import parse_lightgbm, WrappedBooster

# Invoke the registration of all our converters and shape calculators
Expand All @@ -17,7 +18,7 @@

def convert(model, name=None, initial_types=None, doc_string='', target_opset=None,
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
custom_shape_calculators=None):
custom_shape_calculators=None, without_onnx_ml=False):
'''
This function produces an equivalent ONNX model of the given lightgbm model.
The supported lightgbm modules are listed below.
Expand All @@ -35,11 +36,17 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
produced model. If ONNXMLTools cannot find a compatible ONNX python package, an error may be thrown.
:param custom_conversion_functions: a dictionary for specifying the user customized conversion function
:param custom_shape_calculators: a dictionary for specifying the user customized shape calculator
:param without_onnx_ml: whether to generate a model composed by ONNX operators only, or to allow the converter
to use ONNX-ML operators as well.
:return: An ONNX model (type: ModelProto) which is equivalent to the input lightgbm model
'''
if initial_types is None:
raise ValueError('Initial types are required. See usage of convert(...) in '
'onnxmltools.convert.lightgbm.convert for details')
if without_onnx_ml and not hummingbird_installed():
raise RuntimeError(
'Hummingbird is not installed. Please install hummingbird to use this feature: pip install hummingbird-ml'
)
if isinstance(model, lightgbm.Booster):
model = WrappedBooster(model)
if name is None:
Expand All @@ -48,5 +55,18 @@ def convert(model, name=None, initial_types=None, doc_string='', target_opset=No
target_opset = target_opset if target_opset else get_maximum_opset_supported()
topology = parse_lightgbm(model, initial_types, target_opset, custom_conversion_functions, custom_shape_calculators)
topology.compile()
onnx_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)
return onnx_model
onnx_ml_model = convert_topology(topology, name, doc_string, target_opset, targeted_onnx)

if without_onnx_ml:
from hummingbird.ml import convert
from hummingbird.ml import constants

extra_config = {}
extra_config[constants.ONNX_INITIAL_TYPES] = initial_types
extra_config[constants.ONNX_OUTPUT_MODEL_NAME] = name
extra_config[constants.ONNX_TARGET_OPSET] = target_opset
onnx_model = convert(onnx_ml_model, "onnx", extra_config=extra_config)

return onnx_model

return onnx_ml_model
5 changes: 3 additions & 2 deletions onnxmltools/convert/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,14 @@ def convert_catboost(model, name=None, initial_types=None, doc_string='', target


def convert_lightgbm(model, name=None, initial_types=None, doc_string='', target_opset=None,
targeted_onnx=onnx.__version__, custom_conversion_functions=None, custom_shape_calculators=None):
targeted_onnx=onnx.__version__, custom_conversion_functions=None,
custom_shape_calculators=None, without_onnx_ml=False):
if not utils.lightgbm_installed():
raise RuntimeError('lightgbm is not installed. Please install lightgbm to use this feature.')

from .lightgbm.convert import convert
return convert(model, name, initial_types, doc_string, target_opset, targeted_onnx,
custom_conversion_functions, custom_shape_calculators)
custom_conversion_functions, custom_shape_calculators, without_onnx_ml)


def convert_sklearn(model, name=None, initial_types=None, doc_string='', target_opset=None,
Expand Down
6 changes: 3 additions & 3 deletions onnxmltools/utils/tests_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def dump_data_and_model(data, model, onnx=None, basename="model", folder=None,
return names


def convert_model(model, name, input_types):
def convert_model(model, name, input_types, without_onnx_ml=False):
"""
Runs the appropriate conversion method.
Expand All @@ -201,15 +201,15 @@ def convert_model(model, name, input_types):
from sklearn.base import BaseEstimator
if model.__class__.__name__.startswith("LGBM"):
from onnxmltools.convert import convert_lightgbm
model, prefix = convert_lightgbm(model, name, input_types), "LightGbm"
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm"
elif model.__class__.__name__.startswith("XGB"):
from onnxmltools.convert import convert_xgboost
model, prefix = convert_xgboost(model, name, input_types), "XGB"
elif model.__class__.__name__ == 'Booster':
import lightgbm
if isinstance(model, lightgbm.Booster):
from onnxmltools.convert import convert_lightgbm
model, prefix = convert_lightgbm(model, name, input_types), "LightGbm"
model, prefix = convert_lightgbm(model, name, input_types, without_onnx_ml=without_onnx_ml), "LightGbm"
else:
raise RuntimeError("Unable to convert model of type '{0}'.".format(type(model)))
elif model.__class__.__name__.startswith("CatBoost"):
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-f https://download.pytorch.org/whl/torch_stable.html
codecov
coremltools
cython
Expand All @@ -17,3 +18,4 @@ wheel
xgboost
catboost
flake8
hummingbird-ml
Loading

0 comments on commit 8cf85d7

Please sign in to comment.