Skip to content

Commit

Permalink
Merge branch 'master' into SPARK-49789
Browse files Browse the repository at this point in the history
  • Loading branch information
ashahid committed Jan 14, 2025
2 parents fd43fbe + 1fd8362 commit 5dd4f5e
Show file tree
Hide file tree
Showing 61 changed files with 4,535 additions and 819 deletions.
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,24 @@
},
"sqlState" : "56K00"
},
"CONNECT_ML" : {
"message" : [
"Generic Spark Connect ML error."
],
"subClass" : {
"ATTRIBUTE_NOT_ALLOWED" : {
"message" : [
"<attribute> is not allowed to be accessed."
]
},
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
]
}
},
"sqlState" : "XX000"
},
"CONVERSION_INVALID_INPUT" : {
"message" : [
"The value <str> (<fmt>) cannot be converted to <targetType> because it is malformed. Correct the value as per the syntax, or change its format. Use <suggestion> to tolerate malformed input and return NULL instead."
Expand Down
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_legacy_mode_classification",
"pyspark.ml.tests.connect.test_legacy_mode_pipeline",
"pyspark.ml.tests.connect.test_legacy_mode_tuning",
"pyspark.ml.tests.test_classification",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
Expand Down Expand Up @@ -1106,6 +1107,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_connect_classification",
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_connect_spark_ml_classification",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
Expand Down
4 changes: 2 additions & 2 deletions docs/sql-ref-syntax-aux-describe-table.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ to return the metadata pertaining to a partition or column respectively.
"num_buckets": <num_buckets>,
"bucket_columns": ["<col_name>"],
"sort_columns": ["<col_name>"],
"created_time": "<timestamp_ISO-8601>",
"created_time": "<yyyy-MM-dd'T'HH:mm:ss'Z'>",
"created_by": "<created_by>",
"last_access": "<timestamp_ISO-8601>",
"last_access": "<yyyy-MM-dd'T'HH:mm:ss'Z'>",
"partition_provider": "<partition_provider>"
}
```
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML estimators.
# To add support for a new estimator, register it here.
org.apache.spark.ml.classification.LogisticRegression
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML non-model transformers.
# To add support for a new transformer, register it here.
org.apache.spark.ml.feature.VectorAssembler
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Summary
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
Expand All @@ -28,7 +29,7 @@ import org.apache.spark.sql.types.DoubleType
/**
* Abstraction for multiclass classification results for a given model.
*/
private[classification] trait ClassificationSummary extends Serializable {
private[classification] trait ClassificationSummary extends Summary with Serializable {

/**
* Dataframe output by the model's `transform` method.
Expand Down
9 changes: 7 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ package org.apache.spark.ml.param

import java.lang.reflect.Modifier
import java.util.{List => JList}
import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import org.json4s._
import org.json4s.jackson.JsonMethods._
Expand All @@ -45,9 +45,14 @@ import org.apache.spark.util.ArrayImplicits._
* See [[ParamValidators]] for factory methods for common validation functions.
* @tparam T param value type
*/
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
class Param[T: ClassTag](
val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
extends Serializable {

// Spark Connect ML needs the type information of T, which is erased at compile time.
// Use a ClassTag to preserve the T type.
val paramValueClassTag = implicitly[ClassTag[T]]

def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
this(parent.uid, name, doc, isValid)

Expand Down
28 changes: 28 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/Summary.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.util

import org.apache.spark.annotation.Since

/**
 * Marker trait for ML summaries.
 *
 * It declares no members. All summaries should extend this trait in order to
 * support connect.
 */
@Since("4.0.0")
private[spark] trait Summary
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.List;

import org.apache.spark.ml.util.Identifiable$;
import scala.reflect.ClassTag;

/**
* A subclass of Params for testing.
Expand Down Expand Up @@ -110,7 +111,7 @@ private void init() {
ParamValidators.inRange(0.0, 1.0));
List<String> validStrings = Arrays.asList("a", "b");
myStringParam_ = new Param<>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
ParamValidators.inArray(validStrings), ClassTag.apply(String.class));
myDoubleArrayParam_ =
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");

Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
HasSolver,
HasParallelism,
)
from pyspark.ml.remote.util import try_remote_attribute_relation
from pyspark.ml.tree import (
_DecisionTreeModel,
_DecisionTreeParams,
Expand Down Expand Up @@ -336,6 +337,7 @@ class _ClassificationSummary(JavaWrapper):

@property
@since("3.1.0")
@try_remote_attribute_relation
def predictions(self) -> DataFrame:
"""
Dataframe outputted by the model's `transform` method.
Expand Down Expand Up @@ -521,6 +523,7 @@ def scoreCol(self) -> str:
return self._call_java("scoreCol")

@property
@try_remote_attribute_relation
def roc(self) -> DataFrame:
"""
Returns the receiver operating characteristic (ROC) curve,
Expand All @@ -546,6 +549,7 @@ def areaUnderROC(self) -> float:

@property
@since("3.1.0")
@try_remote_attribute_relation
def pr(self) -> DataFrame:
"""
Returns the precision-recall curve, which is a Dataframe
Expand All @@ -556,6 +560,7 @@ def pr(self) -> DataFrame:

@property
@since("3.1.0")
@try_remote_attribute_relation
def fMeasureByThreshold(self) -> DataFrame:
"""
Returns a dataframe with two fields (threshold, F-Measure) curve
Expand All @@ -565,6 +570,7 @@ def fMeasureByThreshold(self) -> DataFrame:

@property
@since("3.1.0")
@try_remote_attribute_relation
def precisionByThreshold(self) -> DataFrame:
"""
Returns a dataframe with two fields (threshold, precision) curve.
Expand All @@ -575,6 +581,7 @@ def precisionByThreshold(self) -> DataFrame:

@property
@since("3.1.0")
@try_remote_attribute_relation
def recallByThreshold(self) -> DataFrame:
"""
Returns a dataframe with two fields (threshold, recall) curve.
Expand Down
16 changes: 16 additions & 0 deletions python/pyspark/ml/remote/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
76 changes: 76 additions & 0 deletions python/pyspark/ml/remote/proto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, TYPE_CHECKING, List

import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.plan import LogicalPlan

if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient


class TransformerRelation(LogicalPlan):
    """Logical plan describing the application of an ML transformer to a child plan.

    The transformer is either a cached model, referenced through an
    ``ObjectRef`` whose id is ``name`` (``is_model=True``), or a non-model
    transformer such as VectorAssembler, described by an ``MlOperator``
    built from ``name`` and ``uid`` (``is_model=False``).
    """

    def __init__(
        self,
        child: Optional["LogicalPlan"],
        name: str,
        ml_params: pb2.MlParams,
        uid: str = "",
        is_model: bool = True,
    ) -> None:
        super().__init__(child)
        self._name, self._uid = name, uid
        self._ml_params = ml_params
        self._is_model = is_model

    def plan(self, session: "SparkConnectClient") -> pb2.Relation:
        """Build the ``ml_relation.transform`` proto relation for this plan.

        The child plan is used as the transform input, so it must be set.
        """
        assert self._child is not None
        relation = self._create_proto_relation()
        # Every field written below lives under the same transform sub-message.
        transform = relation.ml_relation.transform
        transform.input.CopyFrom(self._child.plan(session))

        if not self._is_model:
            operator = pb2.MlOperator(
                name=self._name, uid=self._uid, type=pb2.MlOperator.TRANSFORMER
            )
            transform.transformer.CopyFrom(operator)
        else:
            transform.obj_ref.CopyFrom(pb2.ObjectRef(id=self._name))

        # Params are optional; leave the field unset when none were supplied.
        if self._ml_params is None:
            return relation
        transform.params.CopyFrom(self._ml_params)
        return relation


class AttributeRelation(LogicalPlan):
    """Logical plan used in ML for an attribute of an instance (a model or a
    summary) whose value is a DataFrame.

    The instance is identified by ``ref_id`` and the attribute is described
    by the list of ``Fetch.Method`` messages in ``methods``.
    """

    def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None:
        # This plan has no child plan.
        super().__init__(None)
        self._methods = methods
        self._ref_id = ref_id

    def plan(self, session: "SparkConnectClient") -> pb2.Relation:
        """Build the ``ml_relation.fetch`` proto relation for this attribute."""
        relation = self._create_proto_relation()
        fetch = relation.ml_relation.fetch
        object_ref = pb2.ObjectRef(id=self._ref_id)
        fetch.obj_ref.CopyFrom(object_ref)
        fetch.methods.extend(self._methods)
        return relation
Loading

0 comments on commit 5dd4f5e

Please sign in to comment.