diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index ebc40e9279137..9f96f5beab3b2 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -149,7 +149,7 @@ jobs:
catalyst, hive-thriftserver
- >-
streaming, sql-kafka-0-10, streaming-kafka-0-10,
- mllib-local, mllib,
+ mllib-local, mllib-common, mllib,
yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl,
connect, protobuf
# Here, we split Hive and SQL tests into some of slow ones and the rest of them.
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
index a15d07cf59958..bf7c256fc94ff 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java
@@ -56,7 +56,7 @@ public KVTypeInfo(Class<?> type) {
KVIndex idx = m.getAnnotation(KVIndex.class);
if (idx != null) {
checkIndex(idx, indices);
- Preconditions.checkArgument(m.getParameterTypes().length == 0,
+ Preconditions.checkArgument(m.getParameterCount() == 0,
"Annotated method %s::%s should not have any parameters.", type.getName(), m.getName());
m.setAccessible(true);
indices.put(idx.value(), idx);
diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml
index 7606795f8203a..ac4b1655f5ea0 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -62,6 +62,18 @@
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib-common_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala
new file mode 100644
index 0000000000000..144a10641c758
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.param.{ParamMap, ParamPair}
+import org.apache.spark.sql.Dataset
+
+/**
+ * Abstract class for estimators that fit models to data.
+ */
+abstract class Estimator[M <: Model[M]] extends PipelineStage {
+
+ /**
+ * Fits a single model to the input data with optional parameters.
+ *
+ * @param dataset
+ * input dataset
+ * @param firstParamPair
+ * the first param pair, overrides embedded params
+ * @param otherParamPairs
+ * other param pairs. These values override any specified in this Estimator's embedded
+ * ParamMap.
+ * @return
+ * fitted model
+ */
+ @Since("3.5.0")
+ @varargs
+ def fit(
+ dataset: Dataset[_],
+ firstParamPair: ParamPair[_],
+ otherParamPairs: ParamPair[_]*): M = {
+ val map = new ParamMap()
+ .put(firstParamPair)
+ .put(otherParamPairs: _*)
+ fit(dataset, map)
+ }
+
+ /**
+ * Fits a single model to the input data with provided parameter map.
+ *
+ * @param dataset
+ * input dataset
+ * @param paramMap
+ * Parameter map. These values override any specified in this Estimator's embedded ParamMap.
+ * @return
+ * fitted model
+ */
+ @Since("3.5.0")
+ def fit(dataset: Dataset[_], paramMap: ParamMap): M = {
+ copy(paramMap).fit(dataset)
+ }
+
+ /**
+ * Fits a model to the input data.
+ */
+ @Since("3.5.0")
+ def fit(dataset: Dataset[_]): M
+
+ /**
+ * Fits multiple models to the input data with multiple sets of parameters. The default
+ * implementation uses a for loop on each parameter map. Subclasses could override this to
+ * optimize multi-model training.
+ *
+ * @param dataset
+ * input dataset
+ * @param paramMaps
+ * An array of parameter maps. These values override any specified in this Estimator's
+ * embedded ParamMap.
+ * @return
+ * fitted models, matching the input parameter maps
+ */
+ @Since("3.5.0")
+ def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = {
+ paramMaps.map(fit(dataset, _))
+ }
+
+ @Since("3.5.0")
+ override def copy(extra: ParamMap): Estimator[M]
+}
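For context on how the overloads above compose: the ParamPair variants fold their arguments into a single ParamMap and delegate to fit(dataset, paramMap), which copies the estimator with the extra params before calling the abstract fit(dataset). A caller-side sketch, assuming a hypothetical MyEstimator with a maxIter param (none of these names are part of this change):

    // Hypothetical usage; MyEstimator, MyModel, maxIter and trainDf are illustrative only.
    val est = new MyEstimator()
    // A param pair overrides the estimator's embedded ParamMap for this fit only.
    val m1: MyModel = est.fit(trainDf, est.maxIter -> 10)
    // Equivalent explicit form: copy the estimator with the map, then fit.
    val m2: MyModel = est.fit(trainDf, ParamMap(est.maxIter -> 10))
    // The Seq[ParamMap] overload fits one model per map.
    val models: Seq[MyModel] =
      est.fit(trainDf, Seq(ParamMap(est.maxIter -> 5), ParamMap(est.maxIter -> 10)))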
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala
new file mode 100644
index 0000000000000..a5d6aa1a0795c
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.param.ParamMap
+
+/**
+ * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]].
+ *
+ * @tparam M
+ * model type
+ */
+abstract class Model[M <: Model[M]] extends Transformer {
+
+ /**
+ * The parent estimator that produced this model.
+ * @note
+ * For ensembles' component Models, this value can be null.
+ */
+ @transient var parent: Estimator[M] = _
+
+ /**
+ * Sets the parent of this model (Java API).
+ */
+ @Since("3.5.0")
+ def setParent(parent: Estimator[M]): M = {
+ this.parent = parent
+ this.asInstanceOf[M]
+ }
+
+ /** Indicates whether this [[Model]] has a corresponding parent. */
+ @Since("3.5.0")
+ def hasParent: Boolean = parent != null
+
+ @Since("3.5.0")
+ override def copy(extra: ParamMap): M
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala
new file mode 100644
index 0000000000000..cebbcd167ce34
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.param.{ParamMap, Params}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]].
+ */
+abstract class PipelineStage extends Params with Logging {
+
+ /**
+ * Check transform validity and derive the output schema from the input schema.
+ *
+ * We check validity for interactions between parameters during `transformSchema` and raise an
+ * exception if any parameter value is invalid. Parameter value checks which do not depend on
+ * other parameters are handled by `Param.validate()`.
+ *
+ * A typical implementation should first verify the schema change and parameter validity,
+ * including complex parameter interaction checks.
+ */
+ def transformSchema(schema: StructType): StructType
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Derives the output schema from the input schema and parameters, optionally with logging.
+ *
+ * This should be optimistic. If it is unclear whether the schema will be valid, then it should
+ * be assumed valid until proven otherwise.
+ */
+ @DeveloperApi
+ protected def transformSchema(schema: StructType, logging: Boolean): StructType = {
+ if (logging) {
+ logDebug(s"Input schema: ${schema.json}")
+ }
+ val outputSchema = transformSchema(schema)
+ if (logging) {
+ logDebug(s"Expected output schema: ${outputSchema.json}")
+ }
+ outputSchema
+ }
+
+ override def copy(extra: ParamMap): PipelineStage
+}
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Predictor.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Predictor.scala
new file mode 100644
index 0000000000000..517d5e060f531
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.VectorUDT
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Abstraction for prediction problems (regression and classification). It accepts all NumericType
+ * labels and will automatically cast them to DoubleType in `fit()`. If this predictor supports
+ * weights, it accepts all NumericType weights, which will be automatically cast to DoubleType
+ * in `fit()`.
+ *
+ * @tparam FeaturesType
+ * Type of features. E.g., `VectorUDT` for vector features.
+ * @tparam Learner
+ * Specialization of this class. If you subclass this type, use this type parameter to specify
+ * the concrete type.
+ * @tparam M
+ * Specialization of [[PredictionModel]]. If you subclass this type, use this type parameter to
+ * specify the concrete type for the corresponding model.
+ */
+abstract class Predictor[
+ FeaturesType,
+ Learner <: Predictor[FeaturesType, Learner, M],
+ M <: PredictionModel[FeaturesType, M]]
+ extends Estimator[M]
+ with PredictorParams {
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
+
+ @Since("3.5.0")
+ override def fit(dataset: Dataset[_]): M = {
+ // TODO: should send the id of the input dataset and the latest params to the server,
+ // then invoke the 'fit' method of the remote predictor
+ throw new NotImplementedError
+ }
+
+ @Since("3.5.0")
+ override def copy(extra: ParamMap): Learner
+
+ /**
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by `validateAndTransformSchema()`. This workaround is needed since SQL has
+ * different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ private[ml] def featuresDataType: DataType = new VectorUDT
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true, featuresDataType)
+ }
+}
+
+/**
+ * Abstraction for a model for prediction tasks (regression and classification).
+ *
+ * @tparam FeaturesType
+ * Type of features. E.g., `VectorUDT` for vector features.
+ * @tparam M
+ * Specialization of [[PredictionModel]]. If you subclass this type, use this type parameter to
+ * specify the concrete type for the corresponding model.
+ */
+abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+ extends Model[M]
+ with PredictorParams {
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+
+ /** Returns the number of features the model was trained on. If unknown, returns -1 */
+ @Since("3.5.0")
+ def numFeatures: Int = -1
+
+ /**
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by `validateAndTransformSchema()`. This workaround is needed since SQL has
+ * different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ protected def featuresDataType: DataType = new VectorUDT
+
+ @Since("3.5.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol))
+ }
+ outputSchema
+ }
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], calling `predict`, and storing the
+ * predictions as a new column [[predictionCol]].
+ *
+ * @param dataset
+ * input dataset
+ * @return
+ * transformed dataset with [[predictionCol]] of type `Double`
+ */
+ @Since("3.5.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ transformSchema(dataset.schema, logging = true)
+ if ($(predictionCol).nonEmpty) {
+ transformImpl(dataset)
+ } else {
+ this.logWarning(
+ s"$uid: Predictor.transform() does nothing" +
+ " because no output columns were set.")
+ dataset.toDF
+ }
+ }
+
+ protected def transformImpl(dataset: Dataset[_]): DataFrame = {
+ // TODO: should send the id of the input dataset and the latest params to the server,
+ // then invoke the 'transform' method of the remote model
+ throw new NotImplementedError
+ }
+
+ /**
+ * Predict label for the given features. This method is used to implement `transform()` and
+ * output [[predictionCol]].
+ */
+ @Since("3.5.0")
+ def predict(features: FeaturesType): Double
+}
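A caller-side sketch of the column wiring described above (the Connect round-trip in transformImpl is still a TODO); `model` and the column names are hypothetical:

    // Hypothetical usage; `model` is some concrete PredictionModel subclass.
    val scored = model
      .setFeaturesCol("features")
      .setPredictionCol("prediction") // appended as a Double column by transform()
      .transform(testDf)
    // An empty prediction column turns transform() into a pass-through (it only logs a warning).
    val unchanged = model.setPredictionCol("").transform(testDf)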
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Transformer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Transformer.scala
new file mode 100644
index 0000000000000..4eebf031b90de
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.annotation.varargs
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types._
+
+/**
+ * Abstract class for transformers that transform one dataset into another.
+ */
+abstract class Transformer extends PipelineStage {
+
+ /**
+ * Transforms the dataset with optional parameters.
+ * @param dataset
+ * input dataset
+ * @param firstParamPair
+ * the first param pair, overrides embedded params
+ * @param otherParamPairs
+ * other param pairs, override embedded params
+ * @return
+ * transformed dataset
+ */
+ @Since("3.5.0")
+ @varargs
+ def transform(
+ dataset: Dataset[_],
+ firstParamPair: ParamPair[_],
+ otherParamPairs: ParamPair[_]*): DataFrame = {
+ val map = new ParamMap()
+ .put(firstParamPair)
+ .put(otherParamPairs: _*)
+ transform(dataset, map)
+ }
+
+ /**
+ * Transforms the dataset with provided parameter map as additional parameters.
+ * @param dataset
+ * input dataset
+ * @param paramMap
+ * additional parameters, overwrite embedded params
+ * @return
+ * transformed dataset
+ */
+ @Since("3.5.0")
+ def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = {
+ this.copy(paramMap).transform(dataset)
+ }
+
+ /**
+ * Transforms the input dataset.
+ */
+ @Since("3.5.0")
+ def transform(dataset: Dataset[_]): DataFrame
+
+ @Since("3.5.0")
+ override def copy(extra: ParamMap): Transformer
+}
+
+/**
+ * Abstract class for transformers that take one input column, apply a transformation, and output
+ * the result as a new column.
+ */
+abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+ extends Transformer
+ with HasInputCol
+ with HasOutputCol
+ with Logging {
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
+
+ /**
+ * Creates the transform function using the given param map. The input param map already takes
+ * account of the embedded param map. So the param values should be determined solely by the
+ * input param map.
+ */
+ protected def createTransformFunc: IN => OUT
+
+ /**
+ * Returns the data type of the output column.
+ */
+ @Since("3.5.0")
+ protected def outputDataType: DataType
+
+ /**
+ * Validates the input type. Throw an exception if it is invalid.
+ */
+ protected def validateInputType(inputType: DataType): Unit = {}
+
+ @Since("3.5.0")
+ override def transformSchema(schema: StructType): StructType = {
+ val inputType = schema($(inputCol)).dataType
+ validateInputType(inputType)
+ if (schema.fieldNames.contains($(outputCol))) {
+ throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.")
+ }
+ val outputFields = schema.fields :+
+ StructField($(outputCol), outputDataType, nullable = false)
+ StructType(outputFields)
+ }
+
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ // TODO: should send the id of the input dataset and the latest params to the server,
+ // then invoke the 'transform' method of the remote model
+ throw new NotImplementedError
+ }
+
+ @Since("3.5.0")
+ override def copy(extra: ParamMap): T = defaultCopy(extra)
+}
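As an illustration of the UnaryTransformer contract above (createTransformFunc, outputDataType, and the schema check that appends a non-nullable output column), a hedged sketch of a hypothetical subclass; it is not part of this change and assumes the usual Identifiable.randomUID helper is available on the client:

    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

    // Hypothetical transformer mapping a string column to its length.
    class StringLengthTransformer(override val uid: String)
        extends UnaryTransformer[String, Int, StringLengthTransformer] {

      def this() = this(Identifiable.randomUID("strLen"))

      // Pure function applied to every value of the input column.
      override protected def createTransformFunc: String => Int = _.length

      // The appended output column is a non-nullable IntegerType.
      override protected def outputDataType: DataType = IntegerType

      // Reject non-string inputs before anything is sent to the server.
      override protected def validateInputType(inputType: DataType): Unit =
        require(inputType == StringType, s"Input type must be StringType but got $inputType")
    }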
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
new file mode 100644
index 0000000000000..9adf49866b47f
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{PredictionModel, Predictor}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Single-label binary or multiclass classification. Classes are indexed {0, 1, ..., numClasses -
+ * 1}.
+ *
+ * @tparam FeaturesType
+ * Type of input features. E.g., `Vector`
+ * @tparam E
+ * Concrete Estimator type
+ * @tparam M
+ * Concrete Model type
+ */
+abstract class Classifier[
+ FeaturesType,
+ E <: Classifier[FeaturesType, E, M],
+ M <: ClassificationModel[FeaturesType, M]]
+ extends Predictor[FeaturesType, E, M]
+ with ClassifierParams {
+
+ @Since("3.5.0")
+ def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
+
+ // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * Model produced by a [[Classifier]]. Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType
+ * Type of input features. E.g., `Vector`
+ * @tparam M
+ * Concrete Model type
+ */
+abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
+ extends PredictionModel[FeaturesType, M]
+ with ClassifierParams {
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
+
+ /** Number of classes (values which the label can take). */
+ @Since("3.5.0")
+ def numClasses: Int
+
+ @Since("3.5.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(predictionCol).nonEmpty) {
+ outputSchema = SchemaUtils.updateNumValues(schema, $(predictionCol), numClasses)
+ }
+ if ($(rawPredictionCol).nonEmpty) {
+ outputSchema =
+ SchemaUtils.updateAttributeGroupSize(outputSchema, $(rawPredictionCol), numClasses)
+ }
+ outputSchema
+ }
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+ * parameters:
+ * - predicted labels as [[predictionCol]] of type `Double`
+ * - raw predictions (confidences) as [[rawPredictionCol]] of type `Vector`.
+ *
+ * @param dataset
+ * input dataset
+ * @return
+ * transformed dataset
+ */
+ @Since("3.5.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ // TODO: should send the id of the input dataset and the latest params to the server,
+ // then invoke the 'transform' method of the remote model
+ throw new NotImplementedError
+ }
+
+ final override def transformImpl(dataset: Dataset[_]): DataFrame =
+ throw new UnsupportedOperationException(s"transformImpl is not supported in $getClass")
+
+ /**
+ * Predict label for the given features. This method is used to implement `transform()` and
+ * output [[predictionCol]].
+ *
+ * This default implementation for classification predicts the index of the maximum value from
+ * `predictRaw()`.
+ */
+ @Since("3.5.0")
+ override def predict(features: FeaturesType): Double = {
+ // TODO: should send the vector to the server,
+ // then invoke the 'predict' method of the remote model
+
+ // Note: Subclasses may need to override this, since the result
+ // may be adjusted by params like `thresholds`.
+ throw new NotImplementedError
+ }
+
+ /**
+ * Raw prediction for each possible label. The meaning of a "raw" prediction may vary between
+ * algorithms, but it intuitively gives a measure of confidence in each possible label (where
+ * larger = more confident). This internal method is used to implement `transform()` and output
+ * [[rawPredictionCol]].
+ *
+ * @return
+ * vector where element i is the raw prediction for label i. This raw prediction may be any
+ * real number, where a larger value indicates greater confidence for that label.
+ */
+ @Since("3.5.0")
+ def predictRaw(features: FeaturesType): Vector = {
+ // TODO: should send the vector to the server,
+ // then invoke the 'predictRaw' method of the remote model
+ throw new NotImplementedError
+ }
+}
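To make the "index of the maximum value from `predictRaw()`" remark concrete, a hedged sketch of that default rule (in the Connect client the actual computation is intended to happen on the server):

    // Sketch only: raw(i) is the confidence for label i; the prediction is the argmax.
    def rawToPrediction(raw: Vector): Double = raw.argmax.toDouble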
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
new file mode 100644
index 0000000000000..e4db8a047fa27
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * Single-label binary or multiclass classifier which can output class conditional probabilities.
+ *
+ * @tparam FeaturesType
+ * Type of input features. E.g., `Vector`
+ * @tparam E
+ * Concrete Estimator type
+ * @tparam M
+ * Concrete Model type
+ */
+abstract class ProbabilisticClassifier[
+ FeaturesType,
+ E <: ProbabilisticClassifier[FeaturesType, E, M],
+ M <: ProbabilisticClassificationModel[FeaturesType, M]]
+ extends Classifier[FeaturesType, E, M]
+ with ProbabilisticClassifierParams {
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E]
+}
+
+/**
+ * Model produced by a [[ProbabilisticClassifier]]. Classes are indexed {0, 1, ..., numClasses -
+ * 1}.
+ *
+ * @tparam FeaturesType
+ * Type of input features. E.g., `Vector`
+ * @tparam M
+ * Concrete Model type
+ */
+abstract class ProbabilisticClassificationModel[
+ FeaturesType,
+ M <: ProbabilisticClassificationModel[FeaturesType, M]]
+ extends ClassificationModel[FeaturesType, M]
+ with ProbabilisticClassifierParams {
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
+
+ /** @group setParam */
+ @Since("3.5.0")
+ def setThresholds(value: Array[Double]): M = {
+ require(
+ value.length == numClasses,
+ this.getClass.getSimpleName +
+ ".setThresholds() called with non-matching numClasses and thresholds.length." +
+ s" numClasses=$numClasses, but thresholds has length ${value.length}")
+ set(thresholds, value).asInstanceOf[M]
+ }
+
+ @Since("3.5.0")
+ override def transformSchema(schema: StructType): StructType = {
+ var outputSchema = super.transformSchema(schema)
+ if ($(probabilityCol).nonEmpty) {
+ outputSchema =
+ SchemaUtils.updateAttributeGroupSize(outputSchema, $(probabilityCol), numClasses)
+ }
+ outputSchema
+ }
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+ * parameters:
+ * - predicted labels as [[predictionCol]] of type `Double`
+ * - raw predictions (confidences) as [[rawPredictionCol]] of type `Vector`
+ * - probability of each class as [[probabilityCol]] of type `Vector`.
+ *
+ * @param dataset
+ * input dataset
+ * @return
+ * transformed dataset
+ */
+ @Since("3.5.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
+ // TODO: should send the id of the input dataset and the latest params to the server,
+ // then invoke the 'transform' method of the remote model
+ throw new NotImplementedError
+ }
+
+ /**
+ * Predict the probability of each class given the features. These predictions are also called
+ * class conditional probabilities.
+ *
+ * This internal method is used to implement `transform()` and output [[probabilityCol]].
+ *
+ * @return
+ * Estimated class conditional probabilities
+ */
+ @Since("3.5.0")
+ def predictProbability(features: FeaturesType): Vector = {
+ // TODO: should send the vector to the server,
+ // then invoke the 'predictProbability' method of the remote model
+ throw new NotImplementedError
+ }
+}
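To make the `thresholds` interaction behind setThresholds concrete: in the reference (non-Connect) implementation, each class probability is divided by its threshold before taking the argmax, so lower thresholds favour their class. A hedged sketch of that rule, not part of this change:

    // Sketch only: scale probabilities by 1 / threshold, then take the argmax.
    // A threshold of 0.0 maps to positive infinity, so that class wins outright.
    def probabilityToPrediction(probability: Vector, thresholds: Array[Double]): Double = {
      val scaled = probability.toArray.zip(thresholds).map { case (p, t) =>
        if (t == 0.0) Double.PositiveInfinity else p / t
      }
      scaled.indexOf(scaled.max).toDouble
    }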
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ad921bcc4e3f8..193eb4faaaba2 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -24,8 +24,10 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Stable
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.StructType
/**
@@ -531,10 +533,8 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
*/
@scala.annotation.varargs
def textFile(paths: String*): Dataset[String] = {
- // scalastyle:off throwerror
- // TODO: this method can be supported and should be included in the client API.
- throw new NotImplementedError()
- // scalastyle:on throwerror
+ assertNoSpecifiedSchema("textFile")
+ text(paths: _*).select("value").as(StringEncoder)
}
private def assertSourceFormatSpecified(): Unit = {
@@ -556,6 +556,15 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging
}
}
+ /**
+ * A convenient function for schema validation in APIs.
+ */
+ private def assertNoSpecifiedSchema(operation: String): Unit = {
+ if (userSpecifiedSchema.nonEmpty) {
+ throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation)
+ }
+ }
+
///////////////////////////////////////////////////////////////////////////////////////
// Builder pattern config options
///////////////////////////////////////////////////////////////////////////////////////
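With the change above, `textFile` is now supported on the Connect client: it rejects a user-specified schema and returns the `value` column of `text()` as a `Dataset[String]`. A small usage sketch (the path is illustrative):

    // Each line of the input becomes one String element.
    val lines: Dataset[String] = spark.read.textFile("/tmp/people.txt")
    // Setting a schema first would now fail with a user-specified-schema error:
    // spark.read.schema("value STRING").textFile("/tmp/people.txt")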
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index 29c2e89c53779..729ee9ed6a09f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4003,6 +4003,16 @@ object functions {
*/
def array_compact(column: Column): Column = Column.fn("array_compact", column)
+ /**
+ * Returns an array containing `element` as well as all elements from `column`. The new element
+ * is positioned at the beginning of the array.
+ *
+ * @group collection_funcs
+ * @since 3.5.0
+ */
+ def array_prepend(column: Column, element: Any): Column =
+ Column.fn("array_prepend", column, lit(element))
+
/**
* Removes duplicate values from the array.
* @group collection_funcs
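A quick usage sketch of the new function (the DataFrame and column name are illustrative):

    // [1, 2, 3] prepended with 0 becomes [0, 1, 2, 3].
    df.select(array_prepend(col("xs"), 0))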
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 5aa5500116d8a..605b15123c670 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -25,6 +25,7 @@ import io.grpc.StatusRuntimeException
import java.util.Properties
import org.apache.commons.io.FileUtils
import org.apache.commons.io.output.TeeOutputStream
+import org.apache.commons.lang3.{JavaVersion, SystemUtils}
import org.scalactic.TolerantNumerics
import org.apache.spark.SPARK_VERSION
@@ -55,6 +56,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
}
test("eager execution of sql") {
+ assume(IntegrationTestUtils.isSparkHiveJarAvailable)
withTable("test_martin") {
// Fails, because table does not exist.
assertThrows[StatusRuntimeException] {
@@ -161,6 +163,26 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
}
}
+ test("textFile") {
+ val testDataPath = java.nio.file.Paths
+ .get(
+ IntegrationTestUtils.sparkHome,
+ "connector",
+ "connect",
+ "common",
+ "src",
+ "test",
+ "resources",
+ "query-tests",
+ "test-data",
+ "people.txt")
+ .toAbsolutePath
+ val result = spark.read.textFile(testDataPath.toString).collect()
+ val expected = Array("Michael, 29", "Andy, 30", "Justin, 19")
+ assert(result.length == 3)
+ assert(result === expected)
+ }
+
test("write table") {
withTable("myTable") {
val df = spark.range(10).limit(3)
@@ -182,16 +204,18 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
}
test("write jdbc") {
- val url = "jdbc:derby:memory:1234"
- val table = "t1"
- try {
- spark.range(10).write.jdbc(url = s"$url;create=true", table, new Properties())
- val result = spark.read.jdbc(url = url, table, new Properties()).collect()
- assert(result.length == 10)
- } finally {
- // clean up
- assertThrows[StatusRuntimeException] {
- spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect()
+ if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) {
+ val url = "jdbc:derby:memory:1234"
+ val table = "t1"
+ try {
+ spark.range(10).write.jdbc(url = s"$url;create=true", table, new Properties())
+ val result = spark.read.jdbc(url = url, table, new Properties()).collect()
+ assert(result.length == 10)
+ } finally {
+ // clean up
+ assertThrows[StatusRuntimeException] {
+ spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect()
+ }
}
}
}
@@ -227,6 +251,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
// TODO (SPARK-42519): Revisit this test after we can set configs.
// e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
test("writeTo with create") {
+ assume(IntegrationTestUtils.isSparkHiveJarAvailable)
withTable("myTableV2") {
// Failed to create as Hive support is required.
spark.range(3).writeTo("myTableV2").create()
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 3c7e1fdeee645..95d6fddc97caa 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1714,6 +1714,10 @@ class PlanGenerationTestSuite
fn.array_distinct(fn.col("e"))
}
+ functionTest("array_prepend") {
+ fn.array_prepend(fn.col("e"), lit(1))
+ }
+
functionTest("array_intersect") {
fn.array_intersect(fn.col("e"), fn.array(lit(10), lit(4)))
}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 97d130421a242..a2b4762f0a96b 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -174,7 +174,6 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"),
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"),
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index f27ea614a7eb8..a98f7e9c13b37 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connect.client.util
import java.io.File
+import java.nio.file.{Files, Paths}
import scala.util.Properties.versionNumberString
@@ -27,14 +28,15 @@ object IntegrationTestUtils {
// System properties used for testing and debugging
private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
- private[sql] lazy val scalaDir = {
- val version = versionNumberString.split('.') match {
+ private[sql] lazy val scalaVersion = {
+ versionNumberString.split('.') match {
case Array(major, minor, _*) => major + "." + minor
case _ => versionNumberString
}
- "scala-" + version
}
+ private[sql] lazy val scalaDir = s"scala-$scalaVersion"
+
private[sql] lazy val sparkHome: String = {
if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) {
fail("spark.test.home or SPARK_HOME is not set.")
@@ -49,6 +51,12 @@ object IntegrationTestUtils {
// scalastyle:on println
private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
+ private[sql] lazy val isSparkHiveJarAvailable: Boolean = {
+ val filePath = s"$sparkHome/assembly/target/$scalaDir/jars/" +
+ s"spark-hive_$scalaVersion-${org.apache.spark.SPARK_VERSION}.jar"
+ Files.exists(Paths.get(filePath))
+ }
+
/**
* Find a jar in the Spark project artifacts. It requires a build first (e.g. build/sbt package,
* build/mvn clean install -DskipTests) so that this method can find the jar in the target
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index beae5bfa27e2a..d1a34603f48cf 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -62,6 +62,18 @@ object SparkConnectServerUtils {
"connector/connect/server",
"spark-connect-assembly",
"spark-connect").getCanonicalPath
+ val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) {
+ "hive"
+ } else {
+ // scalastyle:off println
+ println(
+ "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " +
+ "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" +
+ "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" +
+ "2. Test with sbt: run test with `-Phive` profile")
+ // scalastyle:on println
+ "in-memory"
+ }
val builder = Process(
Seq(
"bin/spark-submit",
@@ -72,7 +84,7 @@ object SparkConnectServerUtils {
"--conf",
"spark.sql.catalog.testcat=org.apache.spark.sql.connect.catalog.InMemoryTableCatalog",
"--conf",
- "spark.sql.catalogImplementation=hive",
+ s"spark.sql.catalogImplementation=$catalogImplementation",
"--class",
"org.apache.spark.sql.connect.SimpleSparkConnectService",
jar),
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 2118f8e4823ee..da0f974a74906 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -272,6 +272,9 @@ message ExecutePlanResponse {
// The metrics observed during the execution of the query plan.
repeated ObservedMetrics observed_metrics = 6;
+ // (Optional) The Spark schema. This field is available when `collect` is called.
+ DataType schema = 7;
+
// A SQL command returns an opaque Relation that can be directly used as input for the next
// call.
message SqlCommandResult {
@@ -413,6 +416,11 @@ message AddArtifactsRequest {
// User context
UserContext user_context = 2;
+ // Provides optional information about the client sending the request. This field
+ // can be used for language or version specific information and is only intended for
+ // logging purposes and will not be interpreted by the server.
+ optional string client_type = 6;
+
// A chunk of an Artifact.
message ArtifactChunk {
// Data chunk.
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 69451e7b76eef..aba965082ea2a 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -63,6 +63,7 @@ message Relation {
MapPartitions map_partitions = 28;
CollectMetrics collect_metrics = 29;
Parse parse = 30;
+ GroupMap group_map = 31;
// NA functions
NAFill fill_na = 90;
@@ -788,6 +789,17 @@ message MapPartitions {
CommonInlineUserDefinedFunction func = 2;
}
+message GroupMap {
+ // (Required) Input relation for Group Map API: apply, applyInPandas.
+ Relation input = 1;
+
+ // (Required) Expressions for grouping keys.
+ repeated Expression grouping_expressions = 2;
+
+ // (Required) Input user-defined function.
+ CommonInlineUserDefinedFunction func = 3;
+}
+
// Collect arbitrary (named) metrics from a dataset.
message CollectMetrics {
// (Required) The input relation.
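For illustration, the new GroupMap relation above could be assembled with the generated protobuf builders; a hedged Scala sketch in which `input`, `groupingExpr` and `udf` are pre-built messages (all names are placeholders):

    // Sketch only; builder method names follow the usual protobuf-java conventions.
    val groupMap = proto.GroupMap
      .newBuilder()
      .setInput(input)                      // (Required) input relation
      .addGroupingExpressions(groupingExpr) // (Required) repeated grouping expressions
      .setFunc(udf)                         // (Required) the user-defined function
      .build()
    val relation = proto.Relation.newBuilder().setGroupMap(groupMap).build()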
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index c30ea8c830136..28ddbe844d445 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -335,6 +335,7 @@ object DataTypeProtoConverter {
.setType("udt")
.setPythonClass(pyudt.pyUDT)
.setSqlType(toConnectProtoType(pyudt.sqlType))
+ .setSerializedPythonClass(pyudt.serializedPyClass)
.build())
.build()
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
new file mode 100644
index 0000000000000..539e1eaf767cc
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain
@@ -0,0 +1,2 @@
+Project [array_prepend(e#0, 1) AS array_prepend(e, 1)#0]
++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.json b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.json
new file mode 100644
index 0000000000000..ededeb015a227
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.json
@@ -0,0 +1,29 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "project": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+ }
+ },
+ "expressions": [{
+ "unresolvedFunction": {
+ "functionName": "array_prepend",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "e"
+ }
+ }, {
+ "literal": {
+ "integer": 1
+ }
+ }]
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.proto.bin
new file mode 100644
index 0000000000000..837710597e7b6
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.proto.bin differ
diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml
index 079d07db362c1..4d8e082a2db57 100644
--- a/connector/connect/server/pom.xml
+++ b/connector/connect/server/pom.xml
@@ -93,6 +93,18 @@
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-catalyst_${scala.binary.version}</artifactId>
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a057bd8d6c1e5..c8fdaa6641ab3 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -30,6 +30,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
+import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -38,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
+import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
@@ -116,6 +117,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
transformMapPartitions(rel.getMapPartitions)
+ case proto.Relation.RelTypeCase.GROUP_MAP =>
+ transformGroupMap(rel.getGroupMap)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -494,6 +497,18 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
+ val pythonUdf = transformPythonUDF(rel.getFunc)
+ val cols =
+ rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .groupBy(cols: _*)
+ .flatMapGroupsInPandas(pythonUdf)
+ .logicalPlan
+ }
+
private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -675,16 +690,36 @@ class SparkConnectPlanner(val session: SparkSession) {
}
val attributes = structType.toAttributes
val proj = UnsafeProjection.create(attributes, attributes)
- val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
+ val data = rows.map(proj)
if (schema == null) {
- relation
+ logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
} else {
- Dataset
- .ofRows(session, logicalPlan = relation)
- .toDF(schema.names: _*)
- .to(schema)
+ def udtToSqlType(dt: DataType): DataType = dt match {
+ case udt: UserDefinedType[_] => udt.sqlType
+ case StructType(fields) =>
+ val newFields = fields.map { case StructField(name, dataType, nullable, metadata) =>
+ StructField(name, udtToSqlType(dataType), nullable, metadata)
+ }
+ StructType(newFields)
+ case ArrayType(elementType, containsNull) =>
+ ArrayType(udtToSqlType(elementType), containsNull)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapType(udtToSqlType(keyType), udtToSqlType(valueType), valueContainsNull)
+ case _ => dt
+ }
+
+ val sqlTypeOnlySchema = udtToSqlType(schema).asInstanceOf[StructType]
+
+ val project = Dataset
+ .ofRows(session, logicalPlan = logical.LocalRelation(attributes))
+ .toDF(sqlTypeOnlySchema.names: _*)
+ .to(sqlTypeOnlySchema)
.logicalPlan
+ .asInstanceOf[Project]
+
+ val proj = UnsafeProjection.create(project.projectList, project.child.output)
+ logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq)
}
} else {
if (schema == null) {
@@ -1187,10 +1222,51 @@ class SparkConnectPlanner(val session: SparkSession) {
None
}
+ // ML-specific functions
+ case "vector_to_array" if fun.getArgumentsCount == 2 =>
+ val expr = transformExpression(fun.getArguments(0))
+ val dtype = transformExpression(fun.getArguments(1)) match {
+ case Literal(s, StringType) if s != null => s.toString
+ case other =>
+ throw InvalidPlanInput(
+ s"dtype in vector_to_array should be a literal string, but got $other")
+ }
+ dtype match {
+ case "float64" =>
+ Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, Seq(expr)))
+ case "float32" =>
+ Some(transformUnregisteredUDF(MLFunctions.vectorToArrayFloatUdf, Seq(expr)))
+ case other =>
+ throw InvalidPlanInput(s"Unsupported dtype: $other. Valid values: float64, float32.")
+ }
+
+ case "array_to_vector" if fun.getArgumentsCount == 1 =>
+ val expr = transformExpression(fun.getArguments(0))
+ Some(transformUnregisteredUDF(MLFunctions.arrayToVectorUdf, Seq(expr)))
+
case _ => None
}
}
+ /**
+ * Some built-in UDFs, for example 'ml.function.array_to_vector', are not registered in the
+ * function registry. This method converts them to ScalaUDF expressions.
+ */
+ private def transformUnregisteredUDF(
+ fun: org.apache.spark.sql.expressions.UserDefinedFunction,
+ exprs: Seq[Expression]): ScalaUDF = {
+ val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction]
+ ScalaUDF(
+ function = f.f,
+ dataType = f.dataType,
+ children = exprs,
+ inputEncoders = f.inputEncoders,
+ outputEncoder = f.outputEncoder,
+ udfName = f.name,
+ nullable = f.nullable,
+ udfDeterministic = f.deterministic)
+ }
+
private def transformAlias(alias: proto.Expression.Alias): NamedExpression = {
if (alias.getNameCount == 1) {
val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
index 104d840ed52bd..335b871d499be 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala
@@ -28,6 +28,7 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
@@ -60,6 +61,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp
// Extract the plan from the request and convert it to a logical plan
val planner = new SparkConnectPlanner(session)
val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot))
+ responseObserver.onNext(
+ SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(request.getSessionId, dataframe, responseObserver)
responseObserver.onNext(
SparkConnectStreamHandler.sendMetricsToResponse(request.getSessionId, dataframe))
@@ -203,6 +206,15 @@ object SparkConnectStreamHandler {
}
}
+ def sendSchemaToResponse(sessionId: String, schema: StructType): ExecutePlanResponse = {
+ // Send the Spark data type
+ ExecutePlanResponse
+ .newBuilder()
+ .setSessionId(sessionId)
+ .setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
+ .build()
+ }
+
def sendMetricsToResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = {
// Send a last batch with the metrics
ExecutePlanResponse
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index e2aecaaea8602..c36ba76f98451 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -160,18 +160,22 @@ class SparkConnectServiceSuite extends SharedSparkSession {
assert(done)
// 4 Partitions + Metrics
- assert(responses.size == 5)
+ assert(responses.size == 6)
+
+ // Make sure the first response is schema only
+ val head = responses.head
+ assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
// Make sure the last response is metrics only
val last = responses.last
- assert(last.hasMetrics && !last.hasArrowBatch)
+ assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch)
val allocator = new RootAllocator()
// Check the 'data' batches
var expectedId = 0L
var previousEId = 0.0d
- responses.dropRight(1).foreach { response =>
+ responses.tail.dropRight(1).foreach { response =>
assert(response.hasArrowBatch)
val batch = response.getArrowBatch
assert(batch.getData != null)
@@ -347,11 +351,15 @@ class SparkConnectServiceSuite extends SharedSparkSession {
// The current implementation is expected to be blocking. This is here to make sure it is.
assert(done)
- assert(responses.size == 6)
+ assert(responses.size == 7)
+
+ // Make sure the first response is schema only
+ val head = responses.head
+ assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics)
// Make sure the last response is observed metrics only
val last = responses.last
- assert(last.getObservedMetricsCount == 1 && !last.hasArrowBatch)
+ assert(last.getObservedMetricsCount == 1 && !last.hasSchema && !last.hasArrowBatch)
val observedMetricsList = last.getObservedMetricsList.asScala
val observedMetric = observedMetricsList.head
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
index 6a42158f5876a..291276c198144 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala
@@ -38,6 +38,17 @@ import org.apache.spark.tags.DockerTest
*/
@DockerTest
class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
+
+ override def excluded: Seq[String] = Seq(
+ "scan with aggregate push-down: COVAR_POP with DISTINCT",
+ "scan with aggregate push-down: COVAR_SAMP with DISTINCT",
+ "scan with aggregate push-down: CORR with DISTINCT",
+ "scan with aggregate push-down: CORR without DISTINCT",
+ "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT",
+ "scan with aggregate push-down: REGR_SLOPE with DISTINCT",
+ "scan with aggregate push-down: REGR_R2 with DISTINCT",
+ "scan with aggregate push-down: REGR_SXY with DISTINCT")
+
override val catalogName: String = "db2"
override val namespaceOpt: Option[String] = Some("DB2INST1")
override val db = new DatabaseOnDocker {
@@ -97,23 +108,4 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
}
override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT)
-
- testOffset()
- testLimitAndOffset()
- testPaging()
-
- testVarPop()
- testVarPop(true)
- testVarSamp()
- testVarSamp(true)
- testStddevPop()
- testStddevPop(true)
- testStddevSamp()
- testStddevSamp(true)
- testCovarPop()
- testCovarSamp()
- testRegrIntercept()
- testRegrSlope()
- testRegrR2()
- testRegrSXY()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala
index f0e98fc2722b0..f53dc1d5f6da7 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala
@@ -68,7 +68,4 @@ class DB2NamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceT
}
override val supportsDropSchemaCascade: Boolean = false
-
- testListNamespaces()
- testDropNamespaces()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
index 6b8d62f8f7b1d..107e28d1b3828 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala
@@ -39,8 +39,27 @@ import org.apache.spark.tags.DockerTest
@DockerTest
class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
- override val catalogName: String = "mssql"
+ override def excluded: Seq[String] = Seq(
+ "simple scan with OFFSET",
+ "simple scan with LIMIT and OFFSET",
+ "simple scan with paging: top N and OFFSET",
+ "scan with aggregate push-down: VAR_POP with DISTINCT",
+ "scan with aggregate push-down: COVAR_POP with DISTINCT",
+ "scan with aggregate push-down: COVAR_POP without DISTINCT",
+ "scan with aggregate push-down: COVAR_SAMP with DISTINCT",
+ "scan with aggregate push-down: COVAR_SAMP without DISTINCT",
+ "scan with aggregate push-down: CORR with DISTINCT",
+ "scan with aggregate push-down: CORR without DISTINCT",
+ "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT",
+ "scan with aggregate push-down: REGR_INTERCEPT without DISTINCT",
+ "scan with aggregate push-down: REGR_SLOPE with DISTINCT",
+ "scan with aggregate push-down: REGR_SLOPE without DISTINCT",
+ "scan with aggregate push-down: REGR_R2 with DISTINCT",
+ "scan with aggregate push-down: REGR_R2 without DISTINCT",
+ "scan with aggregate push-down: REGR_SXY with DISTINCT",
+ "scan with aggregate push-down: REGR_SXY without DISTINCT")
+ override val catalogName: String = "mssql"
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME",
"mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04")
@@ -97,13 +116,4 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
assert(msg.contains("UpdateColumnNullability is not supported"))
}
-
- testVarPop()
- testVarPop(true)
- testVarSamp()
- testVarSamp(true)
- testStddevPop()
- testStddevPop(true)
- testStddevSamp()
- testStddevSamp(true)
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala
index aa8dac266380a..b0a2d37e465ac 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala
@@ -70,7 +70,4 @@ class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNa
override val supportsSchemaComment: Boolean = false
override val supportsDropSchemaCascade: Boolean = false
-
- testListNamespaces()
- testDropNamespaces()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
index 41a42e21f44d5..789dfeddc214c 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala
@@ -37,6 +37,27 @@ import org.apache.spark.tags.DockerTest
*/
@DockerTest
class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
+
+ override def excluded: Seq[String] = Seq(
+ "scan with aggregate push-down: VAR_POP with DISTINCT",
+ "scan with aggregate push-down: VAR_SAMP with DISTINCT",
+ "scan with aggregate push-down: STDDEV_POP with DISTINCT",
+ "scan with aggregate push-down: STDDEV_SAMP with DISTINCT",
+ "scan with aggregate push-down: COVAR_POP with DISTINCT",
+ "scan with aggregate push-down: COVAR_POP without DISTINCT",
+ "scan with aggregate push-down: COVAR_SAMP with DISTINCT",
+ "scan with aggregate push-down: COVAR_SAMP without DISTINCT",
+ "scan with aggregate push-down: CORR with DISTINCT",
+ "scan with aggregate push-down: CORR without DISTINCT",
+ "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT",
+ "scan with aggregate push-down: REGR_INTERCEPT without DISTINCT",
+ "scan with aggregate push-down: REGR_SLOPE with DISTINCT",
+ "scan with aggregate push-down: REGR_SLOPE without DISTINCT",
+ "scan with aggregate push-down: REGR_R2 with DISTINCT",
+ "scan with aggregate push-down: REGR_R2 without DISTINCT",
+ "scan with aggregate push-down: REGR_SXY with DISTINCT",
+ "scan with aggregate push-down: REGR_SXY without DISTINCT")
+
override val catalogName: String = "mysql"
override val db = new DatabaseOnDocker {
override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31")
@@ -124,13 +145,4 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest
override def supportListIndexes: Boolean = true
override def indexOptions: String = "KEY_BLOCK_SIZE=10"
-
- testOffset()
- testLimitAndOffset()
- testPaging()
-
- testVarPop()
- testVarSamp()
- testStddevPop()
- testStddevSamp()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala
index b73e2b8fd23ca..0974a86fe9b83 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala
@@ -68,9 +68,6 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac
override val supportsDropSchemaRestrict: Boolean = false
- testListNamespaces()
- testDropNamespaces()
-
test("Create or remove comment of namespace unsupported") {
val e1 = intercept[AnalysisException] {
catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava)
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
index a810602652766..f9923ef9e1c10 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala
@@ -56,6 +56,20 @@ import org.apache.spark.tags.DockerTest
*/
@DockerTest
class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
+
+ override def excluded: Seq[String] = Seq(
+ "scan with aggregate push-down: VAR_POP with DISTINCT",
+ "scan with aggregate push-down: VAR_SAMP with DISTINCT",
+ "scan with aggregate push-down: STDDEV_POP with DISTINCT",
+ "scan with aggregate push-down: STDDEV_SAMP with DISTINCT",
+ "scan with aggregate push-down: COVAR_POP with DISTINCT",
+ "scan with aggregate push-down: COVAR_SAMP with DISTINCT",
+ "scan with aggregate push-down: CORR with DISTINCT",
+ "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT",
+ "scan with aggregate push-down: REGR_SLOPE with DISTINCT",
+ "scan with aggregate push-down: REGR_R2 with DISTINCT",
+ "scan with aggregate push-down: REGR_SXY with DISTINCT")
+
override val catalogName: String = "oracle"
override val namespaceOpt: Option[String] = Some("SYSTEM")
override val db = new DatabaseOnDocker {
@@ -105,20 +119,4 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes
}
override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT)
-
- testOffset()
- testLimitAndOffset()
- testPaging()
-
- testVarPop()
- testVarSamp()
- testStddevPop()
- testStddevSamp()
- testCovarPop()
- testCovarSamp()
- testCorr()
- testRegrIntercept()
- testRegrSlope()
- testRegrR2()
- testRegrSXY()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala
index b3e9d19a10f38..a365a1c4e82e4 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala
@@ -52,6 +52,9 @@ import org.apache.spark.tags.DockerTest
*/
@DockerTest
class OracleNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest {
+
+ override def excluded: Seq[String] = Seq("listNamespaces: basic behavior", "Drop namespace")
+
override val db = new DatabaseOnDocker {
lazy override val imageName =
sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:21.3.0")
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
index 4065dbcc036f6..4742764021bf5 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala
@@ -90,31 +90,4 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT
override def supportsIndex: Boolean = true
override def indexOptions: String = "FILLFACTOR=70"
-
- testOffset()
- testLimitAndOffset()
- testPaging()
-
- testVarPop()
- testVarPop(true)
- testVarSamp()
- testVarSamp(true)
- testStddevPop()
- testStddevPop(true)
- testStddevSamp()
- testStddevSamp(true)
- testCovarPop()
- testCovarPop(true)
- testCovarSamp()
- testCovarSamp(true)
- testCorr()
- testCorr(true)
- testRegrIntercept()
- testRegrIntercept(true)
- testRegrSlope()
- testRegrSlope(true)
- testRegrR2()
- testRegrR2(true)
- testRegrSXY()
- testRegrSXY(true)
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala
index 8c525717758c3..cf7266e67e325 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala
@@ -55,7 +55,4 @@ class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNames
override def builtinNamespaces: Array[Array[String]] =
Array(Array("information_schema"), Array("pg_catalog"), Array("public"))
-
- testListNamespaces()
- testDropNamespaces()
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala
index d3f17187a3754..b7c6e0aff20a7 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala
@@ -55,83 +55,79 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte
def supportsDropSchemaRestrict: Boolean = true
- def testListNamespaces(): Unit = {
- test("listNamespaces: basic behavior") {
- val commentMap = if (supportsSchemaComment) {
- Map("comment" -> "test comment")
- } else {
- Map.empty[String, String]
- }
- catalog.createNamespace(Array("foo"), commentMap.asJava)
- assert(catalog.listNamespaces().map(_.toSet).toSet ===
- listNamespaces(Array("foo")).map(_.toSet).toSet)
- assert(catalog.listNamespaces(Array("foo")) === Array())
- assert(catalog.namespaceExists(Array("foo")) === true)
-
- if (supportsSchemaComment) {
- val logAppender = new LogAppender("catalog comment")
- withLogAppender(logAppender) {
- catalog.alterNamespace(Array("foo"), NamespaceChange
- .setProperty("comment", "comment for foo"))
- catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment"))
- }
- val createCommentWarning = logAppender.loggingEvents
- .filter(_.getLevel == Level.WARN)
- .map(_.getMessage.getFormattedMessage)
- .exists(_.contains("catalog comment"))
- assert(createCommentWarning === false)
+ test("listNamespaces: basic behavior") {
+ val commentMap = if (supportsSchemaComment) {
+ Map("comment" -> "test comment")
+ } else {
+ Map.empty[String, String]
+ }
+ catalog.createNamespace(Array("foo"), commentMap.asJava)
+ assert(catalog.listNamespaces().map(_.toSet).toSet ===
+ listNamespaces(Array("foo")).map(_.toSet).toSet)
+ assert(catalog.listNamespaces(Array("foo")) === Array())
+ assert(catalog.namespaceExists(Array("foo")) === true)
+
+ if (supportsSchemaComment) {
+ val logAppender = new LogAppender("catalog comment")
+ withLogAppender(logAppender) {
+ catalog.alterNamespace(Array("foo"), NamespaceChange
+ .setProperty("comment", "comment for foo"))
+ catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment"))
}
+ val createCommentWarning = logAppender.loggingEvents
+ .filter(_.getLevel == Level.WARN)
+ .map(_.getMessage.getFormattedMessage)
+ .exists(_.contains("catalog comment"))
+ assert(createCommentWarning === false)
+ }
- if (supportsDropSchemaRestrict) {
- catalog.dropNamespace(Array("foo"), cascade = false)
- } else {
- catalog.dropNamespace(Array("foo"), cascade = true)
- }
- assert(catalog.namespaceExists(Array("foo")) === false)
- assert(catalog.listNamespaces() === builtinNamespaces)
- val e = intercept[AnalysisException] {
- catalog.listNamespaces(Array("foo"))
- }
- checkError(e,
- errorClass = "SCHEMA_NOT_FOUND",
- parameters = Map("schemaName" -> "`foo`"))
+ if (supportsDropSchemaRestrict) {
+ catalog.dropNamespace(Array("foo"), cascade = false)
+ } else {
+ catalog.dropNamespace(Array("foo"), cascade = true)
+ }
+ assert(catalog.namespaceExists(Array("foo")) === false)
+ assert(catalog.listNamespaces() === builtinNamespaces)
+ val e = intercept[AnalysisException] {
+ catalog.listNamespaces(Array("foo"))
}
+ checkError(e,
+ errorClass = "SCHEMA_NOT_FOUND",
+ parameters = Map("schemaName" -> "`foo`"))
}
- def testDropNamespaces(): Unit = {
- test("Drop namespace") {
- val ident1 = Identifier.of(Array("foo"), "tab")
- // Drop empty namespace without cascade
- val commentMap = if (supportsSchemaComment) {
- Map("comment" -> "test comment")
- } else {
- Map.empty[String, String]
- }
- catalog.createNamespace(Array("foo"), commentMap.asJava)
- assert(catalog.namespaceExists(Array("foo")) === true)
- if (supportsDropSchemaRestrict) {
+ test("Drop namespace") {
+ val ident1 = Identifier.of(Array("foo"), "tab")
+ // Drop empty namespace without cascade
+ val commentMap = if (supportsSchemaComment) {
+ Map("comment" -> "test comment")
+ } else {
+ Map.empty[String, String]
+ }
+ catalog.createNamespace(Array("foo"), commentMap.asJava)
+ assert(catalog.namespaceExists(Array("foo")) === true)
+ if (supportsDropSchemaRestrict) {
+ catalog.dropNamespace(Array("foo"), cascade = false)
+ } else {
+ catalog.dropNamespace(Array("foo"), cascade = true)
+ }
+ assert(catalog.namespaceExists(Array("foo")) === false)
+
+ // Drop non empty namespace without cascade
+ catalog.createNamespace(Array("foo"), commentMap.asJava)
+ assert(catalog.namespaceExists(Array("foo")) === true)
+ catalog.createTable(ident1, schema, Array.empty[Transform], emptyProps)
+ if (supportsDropSchemaRestrict) {
+ intercept[NonEmptyNamespaceException] {
catalog.dropNamespace(Array("foo"), cascade = false)
- } else {
- catalog.dropNamespace(Array("foo"), cascade = true)
}
- assert(catalog.namespaceExists(Array("foo")) === false)
+ }
- // Drop non empty namespace without cascade
- catalog.createNamespace(Array("foo"), commentMap.asJava)
+ // Drop non empty namespace with cascade
+ if (supportsDropSchemaCascade) {
assert(catalog.namespaceExists(Array("foo")) === true)
- catalog.createTable(ident1, schema, Array.empty[Transform], emptyProps)
- if (supportsDropSchemaRestrict) {
- intercept[NonEmptyNamespaceException] {
- catalog.dropNamespace(Array("foo"), cascade = false)
- }
- }
-
- // Drop non empty namespace with cascade
- if (supportsDropSchemaCascade) {
- assert(catalog.namespaceExists(Array("foo")) === true)
- catalog.dropNamespace(Array("foo"), cascade = true)
- assert(catalog.namespaceExists(Array("foo")) === false)
- }
+ catalog.dropNamespace(Array("foo"), cascade = true)
+ assert(catalog.namespaceExists(Array("foo")) === false)
}
}
}
diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
index 97ee338509031..85b0b807932aa 100644
--- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
+++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala
@@ -314,14 +314,13 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}
- private def limitPushed(df: DataFrame, limit: Int): Boolean = {
+ private def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = {
df.queryExecution.optimizedPlan.collect {
case relation: DataSourceV2ScanRelation => relation.scan match {
case v1: V1ScanWrapper =>
- return v1.pushedDownOperators.limit == Some(limit)
+ assert(v1.pushedDownOperators.limit == limit)
}
}
- false
}
private def checkColumnPruned(df: DataFrame, col: String): Unit = {
@@ -354,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" +
" LIMIT 2")
checkSamplePushed(df3)
- assert(limitPushed(df3, 2))
+ checkLimitPushed(df3, Some(2))
checkColumnPruned(df3, "col1")
assert(df3.collect().length <= 2)
@@ -362,7 +361,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" +
" TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2")
checkSamplePushed(df4)
- assert(limitPushed(df4, 2))
+ checkLimitPushed(df4, Some(2))
checkColumnPruned(df4, "col1")
assert(df4.collect().length <= 2)
@@ -371,7 +370,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
" TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2")
checkSamplePushed(df5)
checkFilterPushed(df5)
- assert(limitPushed(df5, 2))
+ checkLimitPushed(df5, Some(2))
assert(df5.collect().length <= 2)
// sample + filter + limit + column pruning
@@ -381,7 +380,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
" TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2")
checkSamplePushed(df6)
checkFilterPushed(df6, false)
- assert(!limitPushed(df6, 2))
+ checkLimitPushed(df6, None)
checkColumnPruned(df6, "col1")
assert(df6.collect().length <= 2)
@@ -390,7 +389,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
// only limit is pushed down because in this test sample is after limit
val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5)
checkSamplePushed(df7, false)
- assert(limitPushed(df7, 2))
+ checkLimitPushed(df7, Some(2))
// sample + filter
// Push down order is sample -> filter -> limit
@@ -422,7 +421,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
test("simple scan with LIMIT") {
val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1")
- assert(limitPushed(df, 1))
+ checkLimitPushed(df, Some(1))
val rows = df.collect()
assert(rows.length === 1)
assert(rows(0).getString(0) === "amy")
@@ -434,7 +433,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
Seq(NullOrdering.values()).flatten.foreach { nullOrdering =>
val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 ORDER BY salary $nullOrdering LIMIT 1")
- assert(limitPushed(df1, 1))
+ checkLimitPushed(df1, Some(1))
checkSortRemoved(df1)
val rows1 = df1.collect()
assert(rows1.length === 1)
@@ -444,7 +443,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 ORDER BY bonus DESC $nullOrdering LIMIT 1")
- assert(limitPushed(df2, 1))
+ checkLimitPushed(df2, Some(1))
checkSortRemoved(df2)
val rows2 = df2.collect()
assert(rows2.length === 1)
@@ -454,60 +453,54 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
}
}
- protected def testOffset(): Unit = {
- test("simple scan with OFFSET") {
- val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 OFFSET 4")
- checkOffsetPushed(df, Some(4))
- val rows = df.collect()
- assert(rows.length === 1)
- assert(rows(0).getString(0) === "jen")
- assert(rows(0).getDecimal(1) === new java.math.BigDecimal("12000.00"))
- assert(rows(0).getDouble(2) === 1200d)
- }
+ test("simple scan with OFFSET") {
+ val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+ s"${caseConvert("employee")} WHERE dept > 0 OFFSET 4")
+ checkOffsetPushed(df, Some(4))
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "jen")
+ assert(rows(0).getDecimal(1) === new java.math.BigDecimal("12000.00"))
+ assert(rows(0).getDouble(2) === 1200d)
}
- protected def testLimitAndOffset(): Unit = {
- test("simple scan with LIMIT and OFFSET") {
- val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1 OFFSET 2")
- assert(limitPushed(df, 3))
- checkOffsetPushed(df, Some(2))
- val rows = df.collect()
- assert(rows.length === 1)
- assert(rows(0).getString(0) === "cathy")
- assert(rows(0).getDecimal(1) === new java.math.BigDecimal("9000.00"))
- assert(rows(0).getDouble(2) === 1200d)
- }
+ test("simple scan with LIMIT and OFFSET") {
+ val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+ s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1 OFFSET 2")
+ checkLimitPushed(df, Some(3))
+ checkOffsetPushed(df, Some(2))
+ val rows = df.collect()
+ assert(rows.length === 1)
+ assert(rows(0).getString(0) === "cathy")
+ assert(rows(0).getDecimal(1) === new java.math.BigDecimal("9000.00"))
+ assert(rows(0).getDouble(2) === 1200d)
}
- protected def testPaging(): Unit = {
- test("simple scan with paging: top N and OFFSET") {
- Seq(NullOrdering.values()).flatten.foreach { nullOrdering =>
- val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
- s"${caseConvert("employee")}" +
- s" WHERE dept > 0 ORDER BY salary $nullOrdering, bonus LIMIT 1 OFFSET 2")
- assert(limitPushed(df1, 3))
- checkOffsetPushed(df1, Some(2))
- checkSortRemoved(df1)
- val rows1 = df1.collect()
- assert(rows1.length === 1)
- assert(rows1(0).getString(0) === "david")
- assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
- assert(rows1(0).getDouble(2) === 1300d)
-
- val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
- s"${caseConvert("employee")}" +
- s" WHERE dept > 0 ORDER BY salary DESC $nullOrdering, bonus LIMIT 1 OFFSET 2")
- assert(limitPushed(df2, 3))
- checkOffsetPushed(df2, Some(2))
- checkSortRemoved(df2)
- val rows2 = df2.collect()
- assert(rows2.length === 1)
- assert(rows2(0).getString(0) === "amy")
- assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
- assert(rows2(0).getDouble(2) === 1000d)
- }
+ test("simple scan with paging: top N and OFFSET") {
+ Seq(NullOrdering.values()).flatten.foreach { nullOrdering =>
+ val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+ s"${caseConvert("employee")}" +
+ s" WHERE dept > 0 ORDER BY salary $nullOrdering, bonus LIMIT 1 OFFSET 2")
+ checkLimitPushed(df1, Some(3))
+ checkOffsetPushed(df1, Some(2))
+ checkSortRemoved(df1)
+ val rows1 = df1.collect()
+ assert(rows1.length === 1)
+ assert(rows1(0).getString(0) === "david")
+ assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+ assert(rows1(0).getDouble(2) === 1300d)
+
+ val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." +
+ s"${caseConvert("employee")}" +
+ s" WHERE dept > 0 ORDER BY salary DESC $nullOrdering, bonus LIMIT 1 OFFSET 2")
+ checkLimitPushed(df2, Some(3))
+ checkOffsetPushed(df2, Some(2))
+ checkSortRemoved(df2)
+ val rows2 = df2.collect()
+ assert(rows2.length === 1)
+ assert(rows2(0).getString(0) === "amy")
+ assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00"))
+ assert(rows2(0).getDouble(2) === 1000d)
}
}
@@ -536,9 +529,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
private def withOrWithout(isDistinct: Boolean): String = if (isDistinct) "with" else "without"
- protected def testVarPop(isDistinct: Boolean = false): Unit = {
+ Seq(true, false).foreach { isDistinct =>
val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: VAR_POP ${withOrWithout(isDistinct)} DISTINCT") {
+ val withOrWithout = if (isDistinct) "with" else "without"
+
+ test(s"scan with aggregate push-down: VAR_POP $withOrWithout DISTINCT") {
val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
@@ -550,14 +545,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 2500.0)
assert(row(2).getDouble(0) === 0.0)
}
- }
- protected def testVarSamp(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: VAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: VAR_SAMP $withOrWithout DISTINCT") {
val df = sql(
s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "VAR_SAMP")
@@ -567,14 +559,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 5000.0)
assert(row(2).isNullAt(0))
}
- }
- protected def testStddevPop(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: STDDEV_POP ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: STDDEV_POP $withOrWithout DISTINCT") {
val df = sql(
s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "STDDEV_POP")
@@ -584,14 +573,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 50.0)
assert(row(2).getDouble(0) === 0.0)
}
- }
- protected def testStddevSamp(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: STDDEV_SAMP ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: STDDEV_SAMP $withOrWithout DISTINCT") {
val df = sql(
s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "STDDEV_SAMP")
@@ -601,14 +587,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 70.71067811865476)
assert(row(2).isNullAt(0))
}
- }
- protected def testCovarPop(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: COVAR_POP ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: COVAR_POP $withOrWithout DISTINCT") {
val df = sql(
s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "COVAR_POP")
@@ -618,14 +601,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 2500.0)
assert(row(2).getDouble(0) === 0.0)
}
- }
- protected def testCovarSamp(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: COVAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: COVAR_SAMP $withOrWithout DISTINCT") {
val df = sql(
s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "COVAR_SAMP")
@@ -635,14 +615,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 5000.0)
assert(row(2).isNullAt(0))
}
- }
- protected def testCorr(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: CORR ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: CORR $withOrWithout DISTINCT") {
val df = sql(
s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
- s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
checkFilterPushed(df)
checkAggregateRemoved(df)
checkAggregatePushed(df, "CORR")
@@ -652,11 +629,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
- }
- protected def testRegrIntercept(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: REGR_INTERCEPT ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: REGR_INTERCEPT $withOrWithout DISTINCT") {
val df = sql(
s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
@@ -669,11 +643,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 0.0)
assert(row(2).isNullAt(0))
}
- }
- protected def testRegrSlope(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: REGR_SLOPE ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: REGR_SLOPE $withOrWithout DISTINCT") {
val df = sql(
s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
@@ -686,11 +657,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
- }
- protected def testRegrR2(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: REGR_R2 ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: REGR_R2 $withOrWithout DISTINCT") {
val df = sql(
s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
@@ -703,11 +671,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu
assert(row(1).getDouble(0) === 1.0)
assert(row(2).isNullAt(0))
}
- }
- protected def testRegrSXY(isDistinct: Boolean = false): Unit = {
- val distinct = if (isDistinct) "DISTINCT " else ""
- test(s"scan with aggregate push-down: REGR_SXY ${withOrWithout(isDistinct)} DISTINCT") {
+ test(s"scan with aggregate push-down: REGR_SXY $withOrWithout DISTINCT") {
val df = sql(
s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace." +
s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept")
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index 34026083bb9a6..0433747822391 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -1063,6 +1063,28 @@
],
"sqlState" : "42903"
},
+ "INVALID_WRITE_DISTRIBUTION" : {
+ "message" : [
+ "The requested write distribution is invalid."
+ ],
+ "subClass" : {
+ "PARTITION_NUM_AND_SIZE" : {
+ "message" : [
+ "The partition number and advisory partition size can't be specified at the same time."
+ ]
+ },
+ "PARTITION_NUM_WITH_UNSPECIFIED_DISTRIBUTION" : {
+ "message" : [
+ "The number of partitions can't be specified with unspecified distribution."
+ ]
+ },
+ "PARTITION_SIZE_WITH_UNSPECIFIED_DISTRIBUTION" : {
+ "message" : [
+ "The advisory partition size can't be specified with unspecified distribution."
+ ]
+ }
+ }
+ },
"LOCATION_ALREADY_EXISTS" : {
"message" : [
"Cannot name the managed table as , as its associated location already exists. Please pick a different table name, or remove the existing location first."
@@ -2931,11 +2953,6 @@
"Unsupported data type ."
]
},
- "_LEGACY_ERROR_TEMP_1178" : {
- "message" : [
- "The number of partitions can't be specified with unspecified distribution. Invalid writer requirements detected."
- ]
- },
"_LEGACY_ERROR_TEMP_1181" : {
"message" : [
"Stream-stream join without equality predicate is not supported."
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 0ee0dc6ae6016..2d4624828a94d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -103,11 +103,8 @@ class TaskInfo(
// finishTime should be set larger than 0, otherwise "finished" below will return false.
assert(time > 0)
finishTime = time
- if (state == TaskState.FAILED) {
- failed = true
- } else if (state == TaskState.KILLED) {
- killed = true
- }
+ failed = state == TaskState.FAILED
+ killed = state == TaskState.KILLED
}
private[spark] def launchSucceeded(): Unit = {
diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
index 27198039fdbaa..ff12f643497d0 100644
--- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala
@@ -31,7 +31,8 @@ import org.apache.logging.log4j._
import org.apache.logging.log4j.core.{LogEvent, Logger, LoggerContext}
import org.apache.logging.log4j.core.appender.AbstractAppender
import org.apache.logging.log4j.core.config.Property
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, Failed, Outcome}
+import org.scalactic.source.Position
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, Failed, Outcome, Tag}
import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
import org.apache.spark.deploy.LocalSparkCluster
@@ -137,6 +138,19 @@ abstract class SparkFunSuite
java.nio.file.Paths.get(sparkHome, first +: more: _*)
}
+ // Subclasses can override this to exclude tests by their exact names.
+ // This is useful when a suite inherits tests that it does not want to run.
+ protected def excluded: Seq[String] = Seq.empty
+
+ override protected def test(testName: String, testTags: Tag*)(testBody: => Any)
+ (implicit pos: Position): Unit = {
+ if (excluded.contains(testName)) {
+ ignore(s"$testName (excluded)")(testBody)
+ } else {
+ super.test(testName, testTags: _*)(testBody)
+ }
+ }
+
/**
* Note: this method doesn't support `BeforeAndAfter`. You must use `BeforeAndAfterEach` to
* set up and tear down resources.
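With the `excluded` hook in place, any suite extending SparkFunSuite can drop an inherited test without touching the parent trait, which is exactly how the JDBC integration suites above use it. A minimal, hypothetical subclass for illustration; the test names are made up:

    class FastOnlySuite extends SparkFunSuite {
      // Tests whose exact names appear here are registered as ignored instead of run.
      override def excluded: Seq[String] = Seq("slow path: full table scan")

      test("fast path: metadata only") { assert(1 + 1 == 2) }
      test("slow path: full table scan") { fail("never runs; reported as ignored") }
    }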
diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3
index d9edb110f48a2..e3d588d36cd7b 100644
--- a/dev/deps/spark-deps-hadoop-2-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-2-hive-2.3
@@ -223,9 +223,9 @@ objenesis/3.2//objenesis-3.2.jar
okhttp/3.12.12//okhttp-3.12.12.jar
okio/1.15.0//okio-1.15.0.jar
opencsv/2.3//opencsv-2.3.jar
-orc-core/1.8.2/shaded-protobuf/orc-core-1.8.2-shaded-protobuf.jar
-orc-mapreduce/1.8.2/shaded-protobuf/orc-mapreduce-1.8.2-shaded-protobuf.jar
-orc-shims/1.8.2//orc-shims-1.8.2.jar
+orc-core/1.8.3/shaded-protobuf/orc-core-1.8.3-shaded-protobuf.jar
+orc-mapreduce/1.8.3/shaded-protobuf/orc-mapreduce-1.8.3-shaded-protobuf.jar
+orc-shims/1.8.3//orc-shims-1.8.3.jar
oro/2.0.8//oro-2.0.8.jar
osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar
paranamer/2.8//paranamer-2.8.jar
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 5c17fff4d3789..fd32245ec2865 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -210,9 +210,9 @@ opencsv/2.3//opencsv-2.3.jar
opentracing-api/0.33.0//opentracing-api-0.33.0.jar
opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar
opentracing-util/0.33.0//opentracing-util-0.33.0.jar
-orc-core/1.8.2/shaded-protobuf/orc-core-1.8.2-shaded-protobuf.jar
-orc-mapreduce/1.8.2/shaded-protobuf/orc-mapreduce-1.8.2-shaded-protobuf.jar
-orc-shims/1.8.2//orc-shims-1.8.2.jar
+orc-core/1.8.3/shaded-protobuf/orc-core-1.8.3-shaded-protobuf.jar
+orc-mapreduce/1.8.3/shaded-protobuf/orc-mapreduce-1.8.3-shaded-protobuf.jar
+orc-shims/1.8.3//orc-shims-1.8.3.jar
oro/2.0.8//oro-2.0.8.jar
osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar
paranamer/2.8//paranamer-2.8.jar
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 751f0687f2c8e..c31a9362cd7fb 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -271,12 +271,36 @@ def __hash__(self):
],
)
+mllib_local = Module(
+ name="mllib-local",
+ dependencies=[tags, core],
+ source_file_regexes=[
+ "mllib/local",
+ ],
+ sbt_test_goals=[
+ "mllib-local/test",
+ ],
+)
+
+
+mllib_common = Module(
+ name="mllib-common",
+ dependencies=[tags, mllib_local, sql],
+ source_file_regexes=[
+ "mllib/common",
+ ],
+ sbt_test_goals=[
+ "mllib-common/test",
+ ],
+)
+
connect = Module(
name="connect",
- dependencies=[hive],
+ dependencies=[hive, mllib_common],
source_file_regexes=[
"connector/connect",
],
+ build_profile_flags=["-Pconnect"],
sbt_test_goals=[
"connect/test",
"connect-client-jvm/test",
@@ -358,24 +382,12 @@ def __hash__(self):
)
-mllib_local = Module(
- name="mllib-local",
- dependencies=[tags, core],
- source_file_regexes=[
- "mllib-local",
- ],
- sbt_test_goals=[
- "mllib-local/test",
- ],
-)
-
-
mllib = Module(
name="mllib",
- dependencies=[mllib_local, streaming, sql],
+ dependencies=[mllib_local, mllib_common, streaming, sql],
source_file_regexes=[
"data/mllib/",
- "mllib/",
+ "mllib/core/",
],
sbt_test_goals=[
"mllib/test",
@@ -501,48 +513,6 @@ def __hash__(self):
],
)
-pyspark_connect = Module(
- name="pyspark-connect",
- dependencies=[pyspark_sql, connect],
- source_file_regexes=["python/pyspark/sql/connect"],
- python_test_goals=[
- # doctests
- "pyspark.sql.connect.catalog",
- "pyspark.sql.connect.conf",
- "pyspark.sql.connect.group",
- "pyspark.sql.connect.session",
- "pyspark.sql.connect.window",
- "pyspark.sql.connect.column",
- "pyspark.sql.connect.readwriter",
- "pyspark.sql.connect.dataframe",
- "pyspark.sql.connect.functions",
- # unittests
- "pyspark.sql.tests.connect.test_client",
- "pyspark.sql.tests.connect.test_connect_plan",
- "pyspark.sql.tests.connect.test_connect_basic",
- "pyspark.sql.tests.connect.test_connect_function",
- "pyspark.sql.tests.connect.test_connect_column",
- "pyspark.sql.tests.connect.test_parity_datasources",
- "pyspark.sql.tests.connect.test_parity_catalog",
- "pyspark.sql.tests.connect.test_parity_conf",
- "pyspark.sql.tests.connect.test_parity_serde",
- "pyspark.sql.tests.connect.test_parity_functions",
- "pyspark.sql.tests.connect.test_parity_group",
- "pyspark.sql.tests.connect.test_parity_dataframe",
- "pyspark.sql.tests.connect.test_parity_types",
- "pyspark.sql.tests.connect.test_parity_column",
- "pyspark.sql.tests.connect.test_parity_readwriter",
- "pyspark.sql.tests.connect.test_parity_udf",
- "pyspark.sql.tests.connect.test_parity_pandas_udf",
- "pyspark.sql.tests.connect.test_parity_pandas_map",
- "pyspark.sql.tests.connect.test_parity_arrow_map",
- ],
- excluded_python_implementations=[
- "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
- # they aren't available there
- ],
-)
-
pyspark_resource = Module(
name="pyspark-resource",
dependencies=[pyspark_core],
@@ -769,6 +739,58 @@ def __hash__(self):
],
)
+
+pyspark_connect = Module(
+ name="pyspark-connect",
+ dependencies=[pyspark_sql, pyspark_ml, connect],
+ source_file_regexes=[
+ "python/pyspark/sql/connect",
+ "python/pyspark/ml/connect",
+ ],
+ python_test_goals=[
+ # sql doctests
+ "pyspark.sql.connect.catalog",
+ "pyspark.sql.connect.conf",
+ "pyspark.sql.connect.group",
+ "pyspark.sql.connect.session",
+ "pyspark.sql.connect.window",
+ "pyspark.sql.connect.column",
+ "pyspark.sql.connect.readwriter",
+ "pyspark.sql.connect.dataframe",
+ "pyspark.sql.connect.functions",
+ # sql unittests
+ "pyspark.sql.tests.connect.test_client",
+ "pyspark.sql.tests.connect.test_connect_plan",
+ "pyspark.sql.tests.connect.test_connect_basic",
+ "pyspark.sql.tests.connect.test_connect_function",
+ "pyspark.sql.tests.connect.test_connect_column",
+ "pyspark.sql.tests.connect.test_parity_datasources",
+ "pyspark.sql.tests.connect.test_parity_catalog",
+ "pyspark.sql.tests.connect.test_parity_conf",
+ "pyspark.sql.tests.connect.test_parity_serde",
+ "pyspark.sql.tests.connect.test_parity_functions",
+ "pyspark.sql.tests.connect.test_parity_group",
+ "pyspark.sql.tests.connect.test_parity_dataframe",
+ "pyspark.sql.tests.connect.test_parity_types",
+ "pyspark.sql.tests.connect.test_parity_column",
+ "pyspark.sql.tests.connect.test_parity_readwriter",
+ "pyspark.sql.tests.connect.test_parity_udf",
+ "pyspark.sql.tests.connect.test_parity_pandas_udf",
+ "pyspark.sql.tests.connect.test_parity_pandas_map",
+ "pyspark.sql.tests.connect.test_parity_arrow_map",
+ "pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
+ # ml doctests
+ "pyspark.ml.connect.functions",
+ # ml unittests
+ "pyspark.ml.tests.connect.test_connect_function",
+ ],
+ excluded_python_implementations=[
+ "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
+ # they aren't available there
+ ],
+)
+
+
pyspark_errors = Module(
name="pyspark-errors",
dependencies=[],
diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py
index 6b190eb5ab27a..5c270d0948eca 100755
--- a/dev/sparktestsupport/utils.py
+++ b/dev/sparktestsupport/utils.py
@@ -112,22 +112,25 @@ def determine_modules_to_test(changed_modules, deduplicated=True):
>>> sorted([x.name for x in determine_modules_to_test([modules.sql])])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver',
- 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas',
- 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10']
+ 'mllib', 'mllib-common', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib',
+ 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql',
+ 'sql-kafka-0-10']
>>> sorted([x.name for x in determine_modules_to_test(
... [modules.sparkr, modules.sql], deduplicated=False)])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver',
- 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas',
- 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10']
+ 'mllib', 'mllib-common', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib',
+ 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql',
+ 'sql-kafka-0-10']
>>> sorted([x.name for x in determine_modules_to_test(
... [modules.sql, modules.core], deduplicated=False)])
... # doctest: +NORMALIZE_WHITESPACE
['avro', 'catalyst', 'connect', 'core', 'docker-integration-tests', 'examples', 'graphx',
- 'hive', 'hive-thriftserver', 'mllib', 'mllib-local', 'protobuf', 'pyspark-connect',
- 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow',
- 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'repl', 'root', 'sparkr', 'sql',
- 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', 'streaming-kinesis-asl']
+ 'hive', 'hive-thriftserver', 'mllib', 'mllib-common', 'mllib-local', 'protobuf',
+ 'pyspark-connect', 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas',
+ 'pyspark-pandas-slow', 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'repl',
+ 'root', 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10',
+ 'streaming-kinesis-asl']
"""
modules_to_test = set()
for module in changed_modules:
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index d44639227665d..9b7c469246165 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -7,6 +7,8 @@
+
+
{{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation
{% if page.description %}
@@ -17,16 +19,13 @@
{% endif %}
-
-
-
-
+
+
+
+
+
@@ -34,96 +33,118 @@
-
+
-