Fixes GBT Classifier Issue #648 (#653)
* Check if base_score is available and, if it is a string, convert it to float (#637) (see the sketch after the commit message)

Signed-off-by: Donald Tolley <tolleybot@gmail.com>
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
Co-authored-by: Donald Tolley <tolleybot@gmail.com>
Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* signed (#639)

Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* Bump ONNX 1.14.1 in CI pipelines (#644)

* verify onnx 1.14.1 rc2

Signed-off-by: jcwchen <jacky82226@gmail.com>

* Bump ONNX 1.14.1

Signed-off-by: jcwchen <jacky82226@gmail.com>

---------

Signed-off-by: jcwchen <jacky82226@gmail.com>
Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* fix (dev): Initial work to address issue #648. This enables saving and reading models from Spark, a requirement for GBTClassifier tree conversion

Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* feat: Allow conversions of SparkML models to ONNX using cluster mode

Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* fix: fix a bug where temporary paths were not fully created

Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* fix: reformat style

Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

* fix: Fixed formatting style to pass ruff tests

Signed-off-by: James Cao <james.cao@ironwoodcyber.com>

---------

Signed-off-by: Donald Tolley <tolleybot@gmail.com>
Signed-off-by: Xavier Dupre <xadupre@microsoft.com>
Signed-off-by: James Cao <james.cao@ironwoodcyber.com>
Signed-off-by: jcwchen <jacky82226@gmail.com>
Co-authored-by: Xavier Dupré <xadupre@users.noreply.github.com>
Co-authored-by: Donald Tolley <tolleybot@gmail.com>
Co-authored-by: Chun-Wei Chen <jacky82226@gmail.com>
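
As referenced in the first bullet above, the base_score fix amounts to a small type guard before the value is handed to the converter. A minimal sketch, assuming the attribute is read from a plain dictionary; the names here are illustrative and not the converter's actual variables:

# Hedged sketch: coerce base_score to float when it arrives as a string.
attrs = {"base_score": "0.5"}  # illustrative attribute dictionary
base_score = attrs.get("base_score")
if base_score is not None and isinstance(base_score, str):
    attrs["base_score"] = float(base_score)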
4 people authored Oct 2, 2023
1 parent 024a62f commit 606b41e
Showing 3 changed files with 62 additions and 13 deletions.
1 change: 1 addition & 0 deletions onnxmltools/convert/sparkml/convert.py
@@ -5,6 +5,7 @@
from ..common.onnx_ex import get_maximum_opset_supported
from ..common._topology import convert_topology
from ._parse import parse_sparkml
from . import operator_converters

Check failure on line 8 in onnxmltools/convert/sparkml/convert.py (GitHub Actions / ruff-format-check):
Ruff (F401) onnxmltools/convert/sparkml/convert.py:8:15: `.operator_converters` imported but unused
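
A likely way to address this annotation (not part of this commit) would be to mark the import as intentional, assuming it is kept only for its converter-registration side effect:

from . import operator_converters  # noqa: F401  # assumed: imported only for registration side effects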


def convert(
@@ -4,6 +4,7 @@
import os
import time
import numpy
import re
from pyspark.sql import SparkSession


@@ -47,19 +48,65 @@ def sparkml_tree_dataset_to_sklearn(tree_df, is_classifier):


def save_read_sparkml_model_data(spark: SparkSession, model):
    tdir = tempfile.tempdir
    if tdir is None:
        local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(
            spark._jsc.sc().conf()
        )
        tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(
            local_dir, "onnx"
        ).getAbsolutePath()
    if tdir is None:
        raise FileNotFoundError(
            "Unable to create a temporary directory for model '{}'"
            ".".format(type(model).__name__)
        )
    # Get the value of spark.master
    spark_mode = spark.conf.get("spark.master")

    # Cluster mode: spark.master is a spark:// URL that does not point at the local host
    if "spark://" in spark_mode and (
        "localhost" not in spark_mode and "127.0.0.1" not in spark_mode
    ):
        dfs_key = "ONNX_DFS_PATH"
        try:
            dfs_path = spark.conf.get(dfs_key)
        except Exception:
            raise ValueError(
                "Configuration property '{}' does not exist for SparkSession. "
                "Please set this variable to a root distributed file system path "
                "to allow saving and reading of spark models in cluster mode. "
                "You can set this in your SparkConfig by setting "
                "sparkBuilder.config(ONNX_DFS_PATH, dfs_path)".format(dfs_key)
            )
        if dfs_path is None:
            # dfs_path is required for cluster mode
            raise ValueError(
                "Argument dfs_path is required for saving model '{}' in cluster mode. "
                "You can set this in your SparkConfig by setting "
                "sparkBuilder.config(ONNX_DFS_PATH, dfs_path)".format(
                    type(model).__name__
                )
            )
        else:
            # Check that dfs_path is a valid distributed file system path
            # (hdfs, wasbs, s3, etc.)
            if re.match(r"^[a-zA-Z]+://", dfs_path) is None:
                raise ValueError(
                    "Argument dfs_path '{}' is not a valid distributed path".format(
                        dfs_path
                    )
                )
            else:
                # Save the model under a tmp directory rooted at dfs_path
                tdir = os.path.join(dfs_path, "tmp/onnx")
    else:
        # spark.master is not set or is local: save the model to a local path.
        tdir = tempfile.tempdir
        if tdir is None:
            local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(
                spark._jsc.sc().conf()
            )
            tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(
                local_dir, "onnx"
            ).getAbsolutePath()
        if tdir is None:
            raise FileNotFoundError(
                "Unable to create a temporary directory for model '{}'.".format(
                    type(model).__name__
                )
            )

path = os.path.join(tdir, type(model).__name__ + "_" + str(time.time()))
model.write().overwrite().save(path)
df = spark.read.parquet(os.path.join(path, "data"))
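
To show how the new cluster-mode branch above is meant to be exercised, here is a hedged usage sketch. The master URL, DFS root, and model are placeholders; only the ONNX_DFS_PATH key and the save_read_sparkml_model_data helper come from the diff.

# Hedged sketch: configure ONNX_DFS_PATH so models can be staged on a
# distributed file system when spark.master points at a remote cluster.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .master("spark://driver-host:7077")  # illustrative cluster URL
    .config("ONNX_DFS_PATH", "hdfs://namenode:8020/models")  # illustrative DFS root
    .getOrCreate()
)

# `model` would be a fitted SparkML model (for example a GBTClassifier stage);
# save_read_sparkml_model_data(spark, model) would then write it under
# <ONNX_DFS_PATH>/tmp/onnx and read the saved tree data back for conversion.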
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -17,3 +17,4 @@ scikit-learn>=1.2.0
scipy
wheel
xgboost==1.7.5
onnxruntime
