From 606b41ea22a516342f02e18597068257bc9392f4 Mon Sep 17 00:00:00 2001
From: yungcero <133906218+yungcero@users.noreply.github.com>
Date: Mon, 2 Oct 2023 04:03:14 -0600
Subject: [PATCH] Fixes GBT Classifier Issue #648 (#653)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Check if base_score is available and it is a string type convert it to float (#637)

Signed-off-by: Donald Tolley
Signed-off-by: Xavier Dupre
Co-authored-by: Donald Tolley
Signed-off-by: James Cao

* signed (#639)

Signed-off-by: Xavier Dupre
Signed-off-by: James Cao

* Bump ONNX 1.14.1 in CI pipelines (#644)

* verify onnx 1.14.1 rc2

Signed-off-by: jcwchen

* Bump ONNX 1.14.1

Signed-off-by: jcwchen

---------

Signed-off-by: jcwchen
Signed-off-by: James Cao

* fix (dev): Working start to address issue #648. This will help enable saving and reading of models from Spark, a requirement for GBTClassifier tree conversion

Signed-off-by: James Cao

* feat: Allow conversions of SparkML models to ONNX using cluster mode

Signed-off-by: James Cao

* fix: fix bug that did not fully create temp paths

Signed-off-by: James Cao

* fix: reformat style

Signed-off-by: James Cao

* fix: Fixed formatting style to pass ruff tests

Signed-off-by: James Cao

---------

Signed-off-by: Donald Tolley
Signed-off-by: Xavier Dupre
Signed-off-by: James Cao
Signed-off-by: jcwchen
Co-authored-by: Xavier Dupré
Co-authored-by: Donald Tolley
Co-authored-by: Chun-Wei Chen
---
 onnxmltools/convert/sparkml/convert.py |  1 +
 .../tree_ensemble_common.py             | 73 +++++++++++++++----
 requirements-dev.txt                    |  1 +
 3 files changed, 62 insertions(+), 13 deletions(-)

diff --git a/onnxmltools/convert/sparkml/convert.py b/onnxmltools/convert/sparkml/convert.py
index abd350c0..32ea48fc 100644
--- a/onnxmltools/convert/sparkml/convert.py
+++ b/onnxmltools/convert/sparkml/convert.py
@@ -5,6 +5,7 @@
 from ..common.onnx_ex import get_maximum_opset_supported
 from ..common._topology import convert_topology
 from ._parse import parse_sparkml
+from . import operator_converters
 
 
 def convert(
diff --git a/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py b/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py
index 9d95cb42..2e19c7c3 100644
--- a/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py
+++ b/onnxmltools/convert/sparkml/operator_converters/tree_ensemble_common.py
@@ -4,6 +4,7 @@
 import os
 import time
 import numpy
+import re
 
 from pyspark.sql import SparkSession
 
@@ -47,19 +48,65 @@ def sparkml_tree_dataset_to_sklearn(tree_df, is_classifier):
 
 
 def save_read_sparkml_model_data(spark: SparkSession, model):
-    tdir = tempfile.tempdir
-    if tdir is None:
-        local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(
-            spark._jsc.sc().conf()
-        )
-        tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(
-            local_dir, "onnx"
-        ).getAbsolutePath()
-        if tdir is None:
-            raise FileNotFoundError(
-                "Unable to create a temporary directory for model '{}'"
-                ".".format(type(model).__name__)
-            )
+    # Get the value of spark.master
+    spark_mode = spark.conf.get("spark.master")
+
+    # Check the value of spark.master using regular expression
+    if "spark://" in spark_mode and (
+        "localhost" not in spark_mode or "127.0.0.1" not in spark_mode
+    ):
+        dfs_key = "ONNX_DFS_PATH"
+        try:
+            dfs_path = spark.conf.get("ONNX_DFS_PATH")
+        except Exception:
+            raise ValueError(
+                "Configuration property '{}' does not exist for SparkSession. \
+                Please set this variable to a root distributed file system path to allow \
+                for saving and reading of spark models in cluster mode. \
+                You can set this in your SparkConfig \
+                by setting sparkBuilder.config(ONNX_DFS_PATH, dfs_path)".format(
+                    dfs_key
+                )
+            )
+        if dfs_path is None:
+            # If dfs_path is not specified, throw an error message
+            # dfs_path arg is required for cluster mode
+            raise ValueError(
+                "Argument dfs_path is required for saving model '{}' in cluster mode. \
+                You can set this in your SparkConfig by \
+                setting sparkBuilder.config(ONNX_DFS_PATH, dfs_path)".format(
+                    type(model).__name__
+                )
+            )
+        else:
+            # Check that the dfs_path is a valid distributed file system path
+            # This can be hdfs, wabs, s3, etc.
+            if re.match(r"^[a-zA-Z]+://", dfs_path) is None:
+                raise ValueError(
+                    "Argument dfs_path '{}' is not a valid distributed path".format(
+                        dfs_path
+                    )
+                )
+            else:
+                # If dfs_path is specified, save the model to a tmp directory
+                # The dfs_path will be the root of the /tmp
+                tdir = os.path.join(dfs_path, "tmp/onnx")
+    else:
+        # If spark.master is not set or set to local, save the model to a local path.
+        tdir = tempfile.tempdir
+        if tdir is None:
+            local_dir = spark._jvm.org.apache.spark.util.Utils.getLocalDir(
+                spark._jsc.sc().conf()
+            )
+            tdir = spark._jvm.org.apache.spark.util.Utils.createTempDir(
+                local_dir, "onnx"
+            ).getAbsolutePath()
+        if tdir is None:
+            raise FileNotFoundError(
+                "Unable to create a temporary directory for model '{}'"
+                ".".format(type(model).__name__)
+            )
+
     path = os.path.join(tdir, type(model).__name__ + "_" + str(time.time()))
     model.write().overwrite().save(path)
     df = spark.read.parquet(os.path.join(path, "data"))
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 80d2026b..1345c554 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -17,3 +17,4 @@ scikit-learn>=1.2.0
 scipy
 wheel
 xgboost==1.7.5
+onnxruntime
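Not part of the patch itself: a minimal usage sketch of the cluster-mode conversion path this change enables. The master URL, HDFS root, data path, and feature columns below are illustrative assumptions, not values taken from the PR; only the ONNX_DFS_PATH property and the convert_sparkml entry point come from onnxmltools.

# Sketch only: converting a SparkML GBTClassifier pipeline to ONNX against a
# remote master. All paths, the cluster URL, and column names are placeholders.
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import GBTClassifier
from onnxmltools import convert_sparkml
from onnxmltools.convert.common.data_types import FloatTensorType

# Cluster mode: a spark:// master plus an ONNX_DFS_PATH root the converter can
# write temporary model data under (hdfs://, s3://, wasb://, ...).
spark = (
    SparkSession.builder
    .master("spark://example-master:7077")
    .config("ONNX_DFS_PATH", "hdfs://namenode:9000/onnx")
    .getOrCreate()
)

# Train a small GBT pipeline on an assumed dataset with columns f0, f1, f2, label.
train_df = spark.read.parquet("hdfs://namenode:9000/data/train.parquet")
assembler = VectorAssembler(inputCols=["f0", "f1", "f2"], outputCol="features")
gbt = GBTClassifier(labelCol="label", featuresCol="features", maxIter=10)
model = Pipeline(stages=[assembler, gbt]).fit(train_df)

# Declare one float input per raw pipeline input column.
initial_types = [
    ("f0", FloatTensorType([None, 1])),
    ("f1", FloatTensorType([None, 1])),
    ("f2", FloatTensorType([None, 1])),
]
onnx_model = convert_sparkml(model, "gbt pipeline", initial_types, spark_session=spark)

with open("gbt_pipeline.onnx", "wb") as f:
    f.write(onnx_model.SerializeToString())

With a local master, ONNX_DFS_PATH can be left unset and the converter falls back to a local temporary directory, as in the else branch of save_read_sparkml_model_data above.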