diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index ebc40e9279137..9f96f5beab3b2 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -149,7 +149,7 @@ jobs: catalyst, hive-thriftserver - >- streaming, sql-kafka-0-10, streaming-kafka-0-10, - mllib-local, mllib, + mllib-local, mllib-common, mllib, yarn, mesos, kubernetes, hadoop-cloud, spark-ganglia-lgpl, connect, protobuf # Here, we split Hive and SQL tests into some of slow ones and the rest of them. diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index a15d07cf59958..bf7c256fc94ff 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -56,7 +56,7 @@ public KVTypeInfo(Class type) { KVIndex idx = m.getAnnotation(KVIndex.class); if (idx != null) { checkIndex(idx, indices); - Preconditions.checkArgument(m.getParameterTypes().length == 0, + Preconditions.checkArgument(m.getParameterCount() == 0, "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); m.setAccessible(true); indices.put(idx.value(), idx); diff --git a/connector/connect/client/jvm/pom.xml b/connector/connect/client/jvm/pom.xml index 7606795f8203a..ac4b1655f5ea0 100644 --- a/connector/connect/client/jvm/pom.xml +++ b/connector/connect/client/jvm/pom.xml @@ -62,6 +62,18 @@ + + org.apache.spark + spark-mllib-common_${scala.binary.version} + ${project.version} + provided + + + com.google.guava + guava + + + com.google.protobuf protobuf-java diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala new file mode 100644 index 0000000000000..144a10641c758 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.param.{ParamMap, ParamPair} +import org.apache.spark.sql.Dataset + +/** + * Abstract class for estimators that fit models to data. + */ +abstract class Estimator[M <: Model[M]] extends PipelineStage { + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset + * input dataset + * @param firstParamPair + * the first param pair, overrides embedded params + * @param otherParamPairs + * other param pairs. These values override any specified in this Estimator's embedded + * ParamMap. 
+ * @return + * fitted model + */ + @Since("3.5.0") + @varargs + def fit( + dataset: Dataset[_], + firstParamPair: ParamPair[_], + otherParamPairs: ParamPair[_]*): M = { + val map = new ParamMap() + .put(firstParamPair) + .put(otherParamPairs: _*) + fit(dataset, map) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset + * input dataset + * @param paramMap + * Parameter map. These values override any specified in this Estimator's embedded ParamMap. + * @return + * fitted model + */ + @Since("3.5.0") + def fit(dataset: Dataset[_], paramMap: ParamMap): M = { + copy(paramMap).fit(dataset) + } + + /** + * Fits a model to the input data. + */ + @Since("3.5.0") + def fit(dataset: Dataset[_]): M + + /** + * Fits multiple models to the input data with multiple sets of parameters. The default + * implementation uses a for loop on each parameter map. Subclasses could override this to + * optimize multi-model training. + * + * @param dataset + * input dataset + * @param paramMaps + * An array of parameter maps. These values override any specified in this Estimator's + * embedded ParamMap. + * @return + * fitted models, matching the input parameter maps + */ + @Since("3.5.0") + def fit(dataset: Dataset[_], paramMaps: Seq[ParamMap]): Seq[M] = { + paramMaps.map(fit(dataset, _)) + } + + @Since("3.5.0") + override def copy(extra: ParamMap): Estimator[M] +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala new file mode 100644 index 0000000000000..a5d6aa1a0795c --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Model.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.param.ParamMap + +/** + * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. + * + * @tparam M + * model type + */ +abstract class Model[M <: Model[M]] extends Transformer { + + /** + * The parent estimator that produced this model. + * @note + * For ensembles' component Models, this value can be null. + */ + @transient var parent: Estimator[M] = _ + + /** + * Sets the parent of this model (Java API). + */ + @Since("3.5.0") + def setParent(parent: Estimator[M]): M = { + this.parent = parent + this.asInstanceOf[M] + } + + /** Indicates whether this [[Model]] has a corresponding parent. 
*/ + @Since("3.5.0") + def hasParent: Boolean = parent != null + + @Since("3.5.0") + override def copy(extra: ParamMap): M +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala new file mode 100644 index 0000000000000..cebbcd167ce34 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.internal.Logging +import org.apache.spark.ml.param.{ParamMap, Params} +import org.apache.spark.sql.types.StructType + +/** + * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. + */ +abstract class PipelineStage extends Params with Logging { + + /** + * Check transform validity and derive the output schema from the input schema. + * + * We check validity for interactions between parameters during `transformSchema` and raise an + * exception if any parameter value is invalid. Parameter value checks which do not depend on + * other parameters are handled by `Param.validate()`. + * + * Typical implementation should first conduct verification on schema change and parameter + * validity, including complex parameter interaction checks. + */ + def transformSchema(schema: StructType): StructType + + /** + * :: DeveloperApi :: + * + * Derives the output schema from the input schema and parameters, optionally with logging. + * + * This should be optimistic. If it is unclear whether the schema will be valid, then it should + * be assumed valid until proven otherwise. + */ + @DeveloperApi + protected def transformSchema(schema: StructType, logging: Boolean): StructType = { + if (logging) { + logDebug(s"Input schema: ${schema.json}") + } + val outputSchema = transformSchema(schema) + if (logging) { + logDebug(s"Expected output schema: ${outputSchema.json}") + } + outputSchema + } + + override def copy(extra: ParamMap): PipelineStage +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Predictor.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Predictor.scala new file mode 100644 index 0000000000000..517d5e060f531 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.param._ +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Abstraction for prediction problems (regression and classification). It accepts all NumericType + * labels and will automatically cast it to DoubleType in `fit()`. If this predictor supports + * weights, it accepts all NumericType weights, which will be automatically casted to DoubleType + * in `fit()`. + * + * @tparam FeaturesType + * Type of features. E.g., `VectorUDT` for vector features. + * @tparam Learner + * Specialization of this class. If you subclass this type, use this type parameter to specify + * the concrete type. + * @tparam M + * Specialization of [[PredictionModel]]. If you subclass this type, use this type parameter to + * specify the concrete type for the corresponding model. + */ +abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] + with PredictorParams { + + /** @group setParam */ + @Since("3.5.0") + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + + /** @group setParam */ + @Since("3.5.0") + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + + /** @group setParam */ + @Since("3.5.0") + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + @Since("3.5.0") + override def fit(dataset: Dataset[_]): M = { + // TODO: should send the id of the input dataset and the latest params to the server, + // then invoke the 'fit' method of the remote predictor + throw new NotImplementedError + } + + @Since("3.5.0") + override def copy(extra: ParamMap): Learner + + /** + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by `validateAndTransformSchema()`. This workaround is needed since SQL has + * different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + private[ml] def featuresDataType: DataType = new VectorUDT + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema, fitting = true, featuresDataType) + } +} + +/** + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType + * Type of features. E.g., `VectorUDT` for vector features. + * @tparam M + * Specialization of [[PredictionModel]]. If you subclass this type, use this type parameter to + * specify the concrete type for the corresponding model. 
+ */ +abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] + with PredictorParams { + + /** @group setParam */ + @Since("3.5.0") + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + /** @group setParam */ + @Since("3.5.0") + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** Returns the number of features the model was trained on. If unknown, returns -1 */ + @Since("3.5.0") + def numFeatures: Int = -1 + + /** + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by `validateAndTransformSchema()`. This workaround is needed since SQL has + * different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + protected def featuresDataType: DataType = new VectorUDT + + @Since("3.5.0") + override def transformSchema(schema: StructType): StructType = { + var outputSchema = validateAndTransformSchema(schema, fitting = false, featuresDataType) + if ($(predictionCol).nonEmpty) { + outputSchema = SchemaUtils.updateNumeric(outputSchema, $(predictionCol)) + } + outputSchema + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling `predict`, and storing the + * predictions as a new column [[predictionCol]]. + * + * @param dataset + * input dataset + * @return + * transformed dataset with [[predictionCol]] of type `Double` + */ + @Since("3.5.0") + override def transform(dataset: Dataset[_]): DataFrame = { + transformSchema(dataset.schema, logging = true) + if ($(predictionCol).nonEmpty) { + transformImpl(dataset) + } else { + this.logWarning( + s"$uid: Predictor.transform() does nothing" + + " because no output columns were set.") + dataset.toDF + } + } + + protected def transformImpl(dataset: Dataset[_]): DataFrame = { + // TODO: should send the id of the input dataset and the latest params to the server, + // then invoke the 'transform' method of the remote model + throw new NotImplementedError + } + + /** + * Predict label for the given features. This method is used to implement `transform()` and + * output [[predictionCol]]. + */ + @Since("3.5.0") + def predict(features: FeaturesType): Double +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Transformer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Transformer.scala new file mode 100644 index 0000000000000..4eebf031b90de --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.Since +import org.apache.spark.internal.Logging +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types._ + +/** + * Abstract class for transformers that transform one dataset into another. + */ +abstract class Transformer extends PipelineStage { + + /** + * Transforms the dataset with optional parameters + * @param dataset + * input dataset + * @param firstParamPair + * the first param pair, overwrite embedded params + * @param otherParamPairs + * other param pairs, overwrite embedded params + * @return + * transformed dataset + */ + @Since("3.5.0") + @varargs + def transform( + dataset: Dataset[_], + firstParamPair: ParamPair[_], + otherParamPairs: ParamPair[_]*): DataFrame = { + val map = new ParamMap() + .put(firstParamPair) + .put(otherParamPairs: _*) + transform(dataset, map) + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset + * input dataset + * @param paramMap + * additional parameters, overwrite embedded params + * @return + * transformed dataset + */ + @Since("3.5.0") + def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = { + this.copy(paramMap).transform(dataset) + } + + /** + * Transforms the input dataset. + */ + @Since("3.5.0") + def transform(dataset: Dataset[_]): DataFrame + + @Since("3.5.0") + override def copy(extra: ParamMap): Transformer +} + +/** + * Abstract class for transformers that take one input column, apply transformation, and output + * the result as a new column. + */ +abstract class UnaryTransformer[IN: TypeTag, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]] + extends Transformer + with HasInputCol + with HasOutputCol + with Logging { + + /** @group setParam */ + @Since("3.5.0") + def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + + /** @group setParam */ + @Since("3.5.0") + def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] + + /** + * Creates the transform function using the given param map. The input param map already takes + * account of the embedded param map. So the param values should be determined solely by the + * input param map. + */ + protected def createTransformFunc: IN => OUT + + /** + * Returns the data type of the output column. + */ + @Since("3.5.0") + protected def outputDataType: DataType + + /** + * Validates the input type. Throw an exception if it is invalid. 
+ */ + protected def validateInputType(inputType: DataType): Unit = {} + + @Since("3.5.0") + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + validateInputType(inputType) + if (schema.fieldNames.contains($(outputCol))) { + throw new IllegalArgumentException(s"Output column ${$(outputCol)} already exists.") + } + val outputFields = schema.fields :+ + StructField($(outputCol), outputDataType, nullable = false) + StructType(outputFields) + } + + override def transform(dataset: Dataset[_]): DataFrame = { + // TODO: should send the id of the input dataset and the latest params to the server, + // then invoke the 'transform' method of the remote model + throw new NotImplementedError + } + + @Since("3.5.0") + override def copy(extra: ParamMap): T = defaultCopy(extra) +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/Classifier.scala new file mode 100644 index 0000000000000..9adf49866b47f --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{PredictionModel, Predictor} +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType + +/** + * Single-label binary or multiclass classification. Classes are indexed {0, 1, ..., numClasses - + * 1}. + * + * @tparam FeaturesType + * Type of input features. E.g., `Vector` + * @tparam E + * Concrete Estimator type + * @tparam M + * Concrete Model type + */ +abstract class Classifier[ + FeaturesType, + E <: Classifier[FeaturesType, E, M], + M <: ClassificationModel[FeaturesType, M]] + extends Predictor[FeaturesType, E, M] + with ClassifierParams { + + @Since("3.5.0") + def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E] + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * Model produced by a [[Classifier]]. Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType + * Type of input features. 
E.g., `Vector` + * @tparam M + * Concrete Model type + */ +abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] + with ClassifierParams { + + /** @group setParam */ + @Since("3.5.0") + def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + + /** Number of classes (values which the label can take). */ + @Since("3.5.0") + def numClasses: Int + + @Since("3.5.0") + override def transformSchema(schema: StructType): StructType = { + var outputSchema = super.transformSchema(schema) + if ($(predictionCol).nonEmpty) { + outputSchema = SchemaUtils.updateNumValues(schema, $(predictionCol), numClasses) + } + if ($(rawPredictionCol).nonEmpty) { + outputSchema = + SchemaUtils.updateAttributeGroupSize(outputSchema, $(rawPredictionCol), numClasses) + } + outputSchema + } + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type `Double` + * - raw predictions (confidences) as [[rawPredictionCol]] of type `Vector`. + * + * @param dataset + * input dataset + * @return + * transformed dataset + */ + @Since("3.5.0") + override def transform(dataset: Dataset[_]): DataFrame = { + // TODO: should send the id of the input dataset and the latest params to the server, + // then invoke the 'transform' method of the remote model + throw new NotImplementedError + } + + final override def transformImpl(dataset: Dataset[_]): DataFrame = + throw new UnsupportedOperationException(s"transformImpl is not supported in $getClass") + + /** + * Predict label for the given features. This method is used to implement `transform()` and + * output [[predictionCol]]. + * + * This default implementation for classification predicts the index of the maximum value from + * `predictRaw()`. + */ + @Since("3.5.0") + override def predict(features: FeaturesType): Double = { + // TODO: should send the vector to the server, + // then invoke the 'predict' method of the remote model + + // Note: Subclass may need to override this, since the result + // maybe adjusted by param like `thresholds`. + throw new NotImplementedError + } + + /** + * Raw prediction for each possible label. The meaning of a "raw" prediction may vary between + * algorithms, but it intuitively gives a measure of confidence in each possible label (where + * larger = more confident). This internal method is used to implement `transform()` and output + * [[rawPredictionCol]]. + * + * @return + * vector where element i is the raw prediction for label i. This raw prediction may be any + * real number, where a larger value indicates greater confidence for that label. + */ + @Since("3.5.0") + def predictRaw(features: FeaturesType): Vector = { + // TODO: should send the vector to the server, + // then invoke the 'predictRaw' method of the remote model + throw new NotImplementedError + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala new file mode 100644 index 0000000000000..e4db8a047fa27 --- /dev/null +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.Since +import org.apache.spark.ml.linalg.Vector +import org.apache.spark.ml.util.SchemaUtils +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.types.StructType + +/** + * Single-label binary or multiclass classifier which can output class conditional probabilities. + * + * @tparam FeaturesType + * Type of input features. E.g., `Vector` + * @tparam E + * Concrete Estimator type + * @tparam M + * Concrete Model type + */ +abstract class ProbabilisticClassifier[ + FeaturesType, + E <: ProbabilisticClassifier[FeaturesType, E, M], + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends Classifier[FeaturesType, E, M] + with ProbabilisticClassifierParams { + + /** @group setParam */ + @Since("3.5.0") + def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] + + /** @group setParam */ + @Since("3.5.0") + def setThresholds(value: Array[Double]): E = set(thresholds, value).asInstanceOf[E] +} + +/** + * Model produced by a [[ProbabilisticClassifier]]. Classes are indexed {0, 1, ..., numClasses - + * 1}. + * + * @tparam FeaturesType + * Type of input features. E.g., `Vector` + * @tparam M + * Concrete Model type + */ +abstract class ProbabilisticClassificationModel[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends ClassificationModel[FeaturesType, M] + with ProbabilisticClassifierParams { + + /** @group setParam */ + @Since("3.5.0") + def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + + /** @group setParam */ + @Since("3.5.0") + def setThresholds(value: Array[Double]): M = { + require( + value.length == numClasses, + this.getClass.getSimpleName + + ".setThresholds() called with non-matching numClasses and thresholds.length." + + s" numClasses=$numClasses, but thresholds has length ${value.length}") + set(thresholds, value).asInstanceOf[M] + } + + @Since("3.5.0") + override def transformSchema(schema: StructType): StructType = { + var outputSchema = super.transformSchema(schema) + if ($(probabilityCol).nonEmpty) { + outputSchema = + SchemaUtils.updateAttributeGroupSize(outputSchema, $(probabilityCol), numClasses) + } + outputSchema + } + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type `Double` + * - raw predictions (confidences) as [[rawPredictionCol]] of type `Vector` + * - probability of each class as [[probabilityCol]] of type `Vector`. 
+ * + * @param dataset + * input dataset + * @return + * transformed dataset + */ + @Since("3.5.0") + override def transform(dataset: Dataset[_]): DataFrame = { + // TODO: should send the id of the input dataset and the latest params to the server, + // then invoke the 'transform' method of the remote model + throw new NotImplementedError + } + + /** + * Predict the probability of each class given the features. These predictions are also called + * class conditional probabilities. + * + * This internal method is used to implement `transform()` and output [[probabilityCol]]. + * + * @return + * Estimated class conditional probabilities + */ + @Since("3.5.0") + def predictProbability(features: FeaturesType): Vector = { + // TODO: should send the vector to the server, + // then invoke the 'predictProbability' method of the remote model + throw new NotImplementedError + } +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ad921bcc4e3f8..193eb4faaaba2 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -24,8 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Stable import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.DataTypeProtoConverter +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.StructType /** @@ -531,10 +533,8 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging */ @scala.annotation.varargs def textFile(paths: String*): Dataset[String] = { - // scalastyle:off throwerror - // TODO: this method can be supported and should be included in the client API. - throw new NotImplementedError() - // scalastyle:on throwerror + assertNoSpecifiedSchema("textFile") + text(paths: _*).select("value").as(StringEncoder) } private def assertSourceFormatSpecified(): Unit = { @@ -556,6 +556,15 @@ class DataFrameReader private[sql] (sparkSession: SparkSession) extends Logging } } + /** + * A convenient function for schema validation in APIs. + */ + private def assertNoSpecifiedSchema(operation: String): Unit = { + if (userSpecifiedSchema.nonEmpty) { + throw QueryCompilationErrors.userSpecifiedSchemaUnsupportedError(operation) + } + } + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala index 29c2e89c53779..729ee9ed6a09f 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala @@ -4003,6 +4003,16 @@ object functions { */ def array_compact(column: Column): Column = Column.fn("array_compact", column) + /** + * Returns an array containing value as well as all elements from array. 
The new element is + * positioned at the beginning of the array. + * + * @group collection_funcs + * @since 3.5.0 + */ + def array_prepend(column: Column, element: Any): Column = + Column.fn("array_prepend", column, lit(element)) + /** * Removes duplicate values from the array. * @group collection_funcs diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 5aa5500116d8a..605b15123c670 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -25,6 +25,7 @@ import io.grpc.StatusRuntimeException import java.util.Properties import org.apache.commons.io.FileUtils import org.apache.commons.io.output.TeeOutputStream +import org.apache.commons.lang3.{JavaVersion, SystemUtils} import org.scalactic.TolerantNumerics import org.apache.spark.SPARK_VERSION @@ -55,6 +56,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { } test("eager execution of sql") { + assume(IntegrationTestUtils.isSparkHiveJarAvailable) withTable("test_martin") { // Fails, because table does not exist. assertThrows[StatusRuntimeException] { @@ -161,6 +163,26 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { } } + test("textFile") { + val testDataPath = java.nio.file.Paths + .get( + IntegrationTestUtils.sparkHome, + "connector", + "connect", + "common", + "src", + "test", + "resources", + "query-tests", + "test-data", + "people.txt") + .toAbsolutePath + val result = spark.read.textFile(testDataPath.toString).collect() + val expected = Array("Michael, 29", "Andy, 30", "Justin, 19") + assert(result.length == 3) + assert(result === expected) + } + test("write table") { withTable("myTable") { val df = spark.range(10).limit(3) @@ -182,16 +204,18 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { } test("write jdbc") { - val url = "jdbc:derby:memory:1234" - val table = "t1" - try { - spark.range(10).write.jdbc(url = s"$url;create=true", table, new Properties()) - val result = spark.read.jdbc(url = url, table, new Properties()).collect() - assert(result.length == 10) - } finally { - // clean up - assertThrows[StatusRuntimeException] { - spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect() + if (SystemUtils.isJavaVersionAtLeast(JavaVersion.JAVA_9)) { + val url = "jdbc:derby:memory:1234" + val table = "t1" + try { + spark.range(10).write.jdbc(url = s"$url;create=true", table, new Properties()) + val result = spark.read.jdbc(url = url, table, new Properties()).collect() + assert(result.length == 10) + } finally { + // clean up + assertThrows[StatusRuntimeException] { + spark.read.jdbc(url = s"$url;drop=true", table, new Properties()).collect() + } } } } @@ -227,6 +251,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper { // TODO (SPARK-42519): Revisit this test after we can set configs. // e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) test("writeTo with create") { + assume(IntegrationTestUtils.isSparkHiveJarAvailable) withTable("myTableV2") { // Failed to create as Hive support is required. 
spark.range(3).writeTo("myTableV2").create() diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala index 3c7e1fdeee645..95d6fddc97caa 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala @@ -1714,6 +1714,10 @@ class PlanGenerationTestSuite fn.array_distinct(fn.col("e")) } + functionTest("array_prepend") { + fn.array_prepend(fn.col("e"), lit(1)) + } + functionTest("array_intersect") { fn.array_intersect(fn.col("e"), fn.array(lit(10), lit(4))) } diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala index 97d130421a242..a2b4762f0a96b 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala @@ -174,7 +174,6 @@ object CheckConnectJvmClientCompatibility { ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.callUDF"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.unwrap_udt"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.udaf"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.broadcast"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedlit"), ProblemFilters.exclude[Problem]("org.apache.spark.sql.functions.typedLit"), diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala index f27ea614a7eb8..a98f7e9c13b37 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connect.client.util import java.io.File +import java.nio.file.{Files, Paths} import scala.util.Properties.versionNumberString @@ -27,14 +28,15 @@ object IntegrationTestUtils { // System properties used for testing and debugging private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client" - private[sql] lazy val scalaDir = { - val version = versionNumberString.split('.') match { + private[sql] lazy val scalaVersion = { + versionNumberString.split('.') match { case Array(major, minor, _*) => major + "." 
+ minor case _ => versionNumberString } - "scala-" + version } + private[sql] lazy val scalaDir = s"scala-$scalaVersion" + private[sql] lazy val sparkHome: String = { if (!(sys.props.contains("spark.test.home") || sys.env.contains("SPARK_HOME"))) { fail("spark.test.home or SPARK_HOME is not set.") @@ -49,6 +51,12 @@ object IntegrationTestUtils { // scalastyle:on println private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace() + private[sql] lazy val isSparkHiveJarAvailable: Boolean = { + val filePath = s"$sparkHome/assembly/target/$scalaDir/jars/" + + s"spark-hive_$scalaVersion-${org.apache.spark.SPARK_VERSION}.jar" + Files.exists(Paths.get(filePath)) + } + /** * Find a jar in the Spark project artifacts. It requires a build first (e.g. build/sbt package, * build/mvn clean install -DskipTests) so that this method can find the jar in the target diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala index beae5bfa27e2a..d1a34603f48cf 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala @@ -62,6 +62,18 @@ object SparkConnectServerUtils { "connector/connect/server", "spark-connect-assembly", "spark-connect").getCanonicalPath + val catalogImplementation = if (IntegrationTestUtils.isSparkHiveJarAvailable) { + "hive" + } else { + // scalastyle:off println + println( + "Will start Spark Connect server with `spark.sql.catalogImplementation=in-memory`, " + + "some tests that rely on Hive will be ignored. If you don't want to skip them:\n" + + "1. Test with maven: run `build/mvn install -DskipTests -Phive` before testing\n" + + "2. Test with sbt: run test with `-Phive` profile") + // scalastyle:on println + "in-memory" + } val builder = Process( Seq( "bin/spark-submit", @@ -72,7 +84,7 @@ object SparkConnectServerUtils { "--conf", "spark.sql.catalog.testcat=org.apache.spark.sql.connect.catalog.InMemoryTableCatalog", "--conf", - "spark.sql.catalogImplementation=hive", + s"spark.sql.catalogImplementation=$catalogImplementation", "--class", "org.apache.spark.sql.connect.SimpleSparkConnectService", jar), diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 2118f8e4823ee..da0f974a74906 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -272,6 +272,9 @@ message ExecutePlanResponse { // The metrics observed during the execution of the query plan. repeated ObservedMetrics observed_metrics = 6; + // (Optional) The Spark schema. This field is available when `collect` is called. + DataType schema = 7; + // A SQL command returns an opaque Relation that can be directly used as input for the next // call. message SqlCommandResult { @@ -413,6 +416,11 @@ message AddArtifactsRequest { // User context UserContext user_context = 2; + // Provides optional information about the client sending the request. This field + // can be used for language or version specific information and is only intended for + // logging purposes and will not be interpreted by the server. 
+ optional string client_type = 6; + // A chunk of an Artifact. message ArtifactChunk { // Data chunk. diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto index 69451e7b76eef..aba965082ea2a 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -63,6 +63,7 @@ message Relation { MapPartitions map_partitions = 28; CollectMetrics collect_metrics = 29; Parse parse = 30; + GroupMap group_map = 31; // NA functions NAFill fill_na = 90; @@ -788,6 +789,17 @@ message MapPartitions { CommonInlineUserDefinedFunction func = 2; } +message GroupMap { + // (Required) Input relation for Group Map API: apply, applyInPandas. + Relation input = 1; + + // (Required) Expressions for grouping keys. + repeated Expression grouping_expressions = 2; + + // (Required) Input user-defined function. + CommonInlineUserDefinedFunction func = 3; +} + // Collect arbitrary (named) metrics from a dataset. message CollectMetrics { // (Required) The input relation. diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala index c30ea8c830136..28ddbe844d445 100644 --- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala +++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala @@ -335,6 +335,7 @@ object DataTypeProtoConverter { .setType("udt") .setPythonClass(pyudt.pyUDT) .setSqlType(toConnectProtoType(pyudt.sqlType)) + .setSerializedPythonClass(pyudt.serializedPyClass) .build()) .build() diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain new file mode 100644 index 0000000000000..539e1eaf767cc --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_array_prepend.explain @@ -0,0 +1,2 @@ +Project [array_prepend(e#0, 1) AS array_prepend(e, 1)#0] ++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.json b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.json new file mode 100644 index 0000000000000..ededeb015a227 --- /dev/null +++ b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.json @@ -0,0 +1,29 @@ +{ + "common": { + "planId": "1" + }, + "project": { + "input": { + "common": { + "planId": "0" + }, + "localRelation": { + "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "array_prepend", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "e" + } + }, { + "literal": { + "integer": 1 + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.proto.bin 
b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.proto.bin new file mode 100644 index 0000000000000..837710597e7b6 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_array_prepend.proto.bin differ diff --git a/connector/connect/server/pom.xml b/connector/connect/server/pom.xml index 079d07db362c1..4d8e082a2db57 100644 --- a/connector/connect/server/pom.xml +++ b/connector/connect/server/pom.xml @@ -93,6 +93,18 @@ + + org.apache.spark + spark-mllib_${scala.binary.version} + ${project.version} + provided + + + com.google.guava + guava + + + org.apache.spark spark-catalyst_${scala.binary.version} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index a057bd8d6c1e5..c8fdaa6641ab3 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -30,6 +30,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanResponse, SqlCommand} import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult import org.apache.spark.connect.proto.Parse.ParseFormat +import org.apache.spark.ml.{functions => MLFunctions} import org.apache.spark.sql.{Column, Dataset, Encoders, SparkSession} import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier} import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, ParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar} @@ -38,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint} +import org.apache.spark.sql.catalyst.plans.logical.{CollectMetrics, CommandResult, Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Project, Sample, Sort, SubqueryAlias, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE @@ -116,6 +117,8 @@ class SparkConnectPlanner(val session: SparkSession) { transformRepartitionByExpression(rel.getRepartitionByExpression) case proto.Relation.RelTypeCase.MAP_PARTITIONS => transformMapPartitions(rel.getMapPartitions) + case proto.Relation.RelTypeCase.GROUP_MAP => + transformGroupMap(rel.getGroupMap) case proto.Relation.RelTypeCase.COLLECT_METRICS => transformCollectMetrics(rel.getCollectMetrics) case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse) @@ -494,6 +497,18 @@ class SparkConnectPlanner(val session: SparkSession) { } } + private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = { 
+ val pythonUdf = transformPythonUDF(rel.getFunc) + val cols = + rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr))) + + Dataset + .ofRows(session, transformRelation(rel.getInput)) + .groupBy(cols: _*) + .flatMapGroupsInPandas(pythonUdf) + .logicalPlan + } + private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = { Dataset .ofRows(session, transformRelation(rel.getInput)) @@ -675,16 +690,36 @@ class SparkConnectPlanner(val session: SparkSession) { } val attributes = structType.toAttributes val proj = UnsafeProjection.create(attributes, attributes) - val relation = logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq) + val data = rows.map(proj) if (schema == null) { - relation + logical.LocalRelation(attributes, data.map(_.copy()).toSeq) } else { - Dataset - .ofRows(session, logicalPlan = relation) - .toDF(schema.names: _*) - .to(schema) + def udtToSqlType(dt: DataType): DataType = dt match { + case udt: UserDefinedType[_] => udt.sqlType + case StructType(fields) => + val newFields = fields.map { case StructField(name, dataType, nullable, metadata) => + StructField(name, udtToSqlType(dataType), nullable, metadata) + } + StructType(newFields) + case ArrayType(elementType, containsNull) => + ArrayType(udtToSqlType(elementType), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType(udtToSqlType(keyType), udtToSqlType(valueType), valueContainsNull) + case _ => dt + } + + val sqlTypeOnlySchema = udtToSqlType(schema).asInstanceOf[StructType] + + val project = Dataset + .ofRows(session, logicalPlan = logical.LocalRelation(attributes)) + .toDF(sqlTypeOnlySchema.names: _*) + .to(sqlTypeOnlySchema) .logicalPlan + .asInstanceOf[Project] + + val proj = UnsafeProjection.create(project.projectList, project.child.output) + logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq) } } else { if (schema == null) { @@ -1187,10 +1222,51 @@ class SparkConnectPlanner(val session: SparkSession) { None } + // ML-specific functions + case "vector_to_array" if fun.getArgumentsCount == 2 => + val expr = transformExpression(fun.getArguments(0)) + val dtype = transformExpression(fun.getArguments(1)) match { + case Literal(s, StringType) if s != null => s.toString + case other => + throw InvalidPlanInput( + s"dtype in vector_to_array should be a literal string, but got $other") + } + dtype match { + case "float64" => + Some(transformUnregisteredUDF(MLFunctions.vectorToArrayUdf, Seq(expr))) + case "float32" => + Some(transformUnregisteredUDF(MLFunctions.vectorToArrayFloatUdf, Seq(expr))) + case other => + throw InvalidPlanInput(s"Unsupported dtype: $other. Valid values: float64, float32.") + } + + case "array_to_vector" if fun.getArgumentsCount == 1 => + val expr = transformExpression(fun.getArguments(0)) + Some(transformUnregisteredUDF(MLFunctions.arrayToVectorUdf, Seq(expr))) + case _ => None } } + /** + * There are some built-in yet not registered UDFs, for example, 'ml.function.array_to_vector'. + * This method is to convert them to ScalaUDF expressions. 
+ */ + private def transformUnregisteredUDF( + fun: org.apache.spark.sql.expressions.UserDefinedFunction, + exprs: Seq[Expression]): ScalaUDF = { + val f = fun.asInstanceOf[org.apache.spark.sql.expressions.SparkUserDefinedFunction] + ScalaUDF( + function = f.f, + dataType = f.dataType, + children = exprs, + inputEncoders = f.inputEncoders, + outputEncoder = f.outputEncoder, + udfName = f.name, + nullable = f.nullable, + udfDeterministic = f.deterministic) + } + private def transformAlias(alias: proto.Expression.Alias): NamedExpression = { if (alias.getNameCount == 1) { val metadata = if (alias.hasMetadata() && alias.getMetadata.nonEmpty) { diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 104d840ed52bd..335b871d499be 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -28,6 +28,7 @@ import org.apache.spark.connect.proto.{ExecutePlanRequest, ExecutePlanResponse} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE import org.apache.spark.sql.connect.planner.SparkConnectPlanner @@ -60,6 +61,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp // Extract the plan from the request and convert it to a logical plan val planner = new SparkConnectPlanner(session) val dataframe = Dataset.ofRows(session, planner.transformRelation(request.getPlan.getRoot)) + responseObserver.onNext( + SparkConnectStreamHandler.sendSchemaToResponse(request.getSessionId, dataframe.schema)) processAsArrowBatches(request.getSessionId, dataframe, responseObserver) responseObserver.onNext( SparkConnectStreamHandler.sendMetricsToResponse(request.getSessionId, dataframe)) @@ -203,6 +206,15 @@ object SparkConnectStreamHandler { } } + def sendSchemaToResponse(sessionId: String, schema: StructType): ExecutePlanResponse = { + // Send the Spark data type + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setSchema(DataTypeProtoConverter.toConnectProtoType(schema)) + .build() + } + def sendMetricsToResponse(sessionId: String, rows: DataFrame): ExecutePlanResponse = { // Send a last batch with the metrics ExecutePlanResponse diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index e2aecaaea8602..c36ba76f98451 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -160,18 +160,22 @@ class SparkConnectServiceSuite extends SharedSparkSession { assert(done) // 4 Partitions + Metrics - assert(responses.size == 5) + assert(responses.size == 6) + + // Make sure the first response is schema only + val head = responses.head + 
assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) // Make sure the last response is metrics only val last = responses.last - assert(last.hasMetrics && !last.hasArrowBatch) + assert(last.hasMetrics && !last.hasSchema && !last.hasArrowBatch) val allocator = new RootAllocator() // Check the 'data' batches var expectedId = 0L var previousEId = 0.0d - responses.dropRight(1).foreach { response => + responses.tail.dropRight(1).foreach { response => assert(response.hasArrowBatch) val batch = response.getArrowBatch assert(batch.getData != null) @@ -347,11 +351,15 @@ class SparkConnectServiceSuite extends SharedSparkSession { // The current implementation is expected to be blocking. This is here to make sure it is. assert(done) - assert(responses.size == 6) + assert(responses.size == 7) + + // Make sure the first response is schema only + val head = responses.head + assert(head.hasSchema && !head.hasArrowBatch && !head.hasMetrics) // Make sure the last response is observed metrics only val last = responses.last - assert(last.getObservedMetricsCount == 1 && !last.hasArrowBatch) + assert(last.getObservedMetricsCount == 1 && !last.hasSchema && !last.hasArrowBatch) val observedMetricsList = last.getObservedMetricsList.asScala val observedMetric = observedMetricsList.head diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index 6a42158f5876a..291276c198144 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -38,6 +38,17 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + + override def excluded: Seq[String] = Seq( + "scan with aggregate push-down: COVAR_POP with DISTINCT", + "scan with aggregate push-down: COVAR_SAMP with DISTINCT", + "scan with aggregate push-down: CORR with DISTINCT", + "scan with aggregate push-down: CORR without DISTINCT", + "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT", + "scan with aggregate push-down: REGR_SLOPE with DISTINCT", + "scan with aggregate push-down: REGR_R2 with DISTINCT", + "scan with aggregate push-down: REGR_SXY with DISTINCT") + override val catalogName: String = "db2" override val namespaceOpt: Option[String] = Some("DB2INST1") override val db = new DatabaseOnDocker { @@ -97,23 +108,4 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { } override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) - - testOffset() - testLimitAndOffset() - testPaging() - - testVarPop() - testVarPop(true) - testVarSamp() - testVarSamp(true) - testStddevPop() - testStddevPop(true) - testStddevSamp() - testStddevSamp(true) - testCovarPop() - testCovarSamp() - testRegrIntercept() - testRegrSlope() - testRegrR2() - testRegrSXY() } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala index f0e98fc2722b0..f53dc1d5f6da7 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala @@ -68,7 +68,4 @@ class DB2NamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceT } override val supportsDropSchemaCascade: Boolean = false - - testListNamespaces() - testDropNamespaces() } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index 6b8d62f8f7b1d..107e28d1b3828 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -39,8 +39,27 @@ import org.apache.spark.tags.DockerTest @DockerTest class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { - override val catalogName: String = "mssql" + override def excluded: Seq[String] = Seq( + "simple scan with OFFSET", + "simple scan with LIMIT and OFFSET", + "simple scan with paging: top N and OFFSET", + "scan with aggregate push-down: VAR_POP with DISTINCT", + "scan with aggregate push-down: COVAR_POP with DISTINCT", + "scan with aggregate push-down: COVAR_POP without DISTINCT", + "scan with aggregate push-down: COVAR_SAMP with DISTINCT", + "scan with aggregate push-down: COVAR_SAMP without DISTINCT", + "scan with aggregate push-down: CORR with DISTINCT", + "scan with aggregate push-down: CORR without DISTINCT", + "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT", + "scan with aggregate push-down: REGR_INTERCEPT without DISTINCT", + "scan with aggregate push-down: REGR_SLOPE with DISTINCT", + "scan with aggregate push-down: REGR_SLOPE without DISTINCT", + "scan with aggregate push-down: REGR_R2 with DISTINCT", + "scan with aggregate push-down: REGR_R2 without DISTINCT", + "scan with aggregate push-down: REGR_SXY with DISTINCT", + "scan with aggregate push-down: REGR_SXY without DISTINCT") + override val catalogName: String = "mssql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") @@ -97,13 +116,4 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD assert(msg.contains("UpdateColumnNullability is not supported")) } - - testVarPop() - testVarPop(true) - testVarSamp() - testVarSamp(true) - testStddevPop() - testStddevPop(true) - testStddevSamp() - testStddevSamp(true) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index aa8dac266380a..b0a2d37e465ac 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -70,7 +70,4 @@ class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNa override val supportsSchemaComment: Boolean = false override val supportsDropSchemaCascade: Boolean = false - - testListNamespaces() - testDropNamespaces() } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index 41a42e21f44d5..789dfeddc214c 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -37,6 +37,27 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + + override def excluded: Seq[String] = Seq( + "scan with aggregate push-down: VAR_POP with DISTINCT", + "scan with aggregate push-down: VAR_SAMP with DISTINCT", + "scan with aggregate push-down: STDDEV_POP with DISTINCT", + "scan with aggregate push-down: STDDEV_SAMP with DISTINCT", + "scan with aggregate push-down: COVAR_POP with DISTINCT", + "scan with aggregate push-down: COVAR_POP without DISTINCT", + "scan with aggregate push-down: COVAR_SAMP with DISTINCT", + "scan with aggregate push-down: COVAR_SAMP without DISTINCT", + "scan with aggregate push-down: CORR with DISTINCT", + "scan with aggregate push-down: CORR without DISTINCT", + "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT", + "scan with aggregate push-down: REGR_INTERCEPT without DISTINCT", + "scan with aggregate push-down: REGR_SLOPE with DISTINCT", + "scan with aggregate push-down: REGR_SLOPE without DISTINCT", + "scan with aggregate push-down: REGR_R2 with DISTINCT", + "scan with aggregate push-down: REGR_R2 without DISTINCT", + "scan with aggregate push-down: REGR_SXY with DISTINCT", + "scan with aggregate push-down: REGR_SXY without DISTINCT") + override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:8.0.31") @@ -124,13 +145,4 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest override def supportListIndexes: Boolean = true override def indexOptions: String = "KEY_BLOCK_SIZE=10" - - testOffset() - testLimitAndOffset() - testPaging() - - testVarPop() - testVarSamp() - testStddevPop() - testStddevSamp() } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala index b73e2b8fd23ca..0974a86fe9b83 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -68,9 +68,6 @@ class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespac override val supportsDropSchemaRestrict: Boolean = false - testListNamespaces() - testDropNamespaces() - test("Create or remove comment of namespace unsupported") { val e1 = intercept[AnalysisException] { catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index a810602652766..f9923ef9e1c10 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -56,6 +56,20 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { + + override def excluded: Seq[String] = Seq( + "scan with aggregate push-down: VAR_POP with DISTINCT", + "scan with aggregate push-down: VAR_SAMP with DISTINCT", + "scan with aggregate push-down: STDDEV_POP with DISTINCT", + "scan with aggregate push-down: STDDEV_SAMP with DISTINCT", + "scan with aggregate push-down: COVAR_POP with DISTINCT", + "scan with aggregate push-down: COVAR_SAMP with DISTINCT", + "scan with aggregate push-down: CORR with DISTINCT", + "scan with aggregate push-down: REGR_INTERCEPT with DISTINCT", + "scan with aggregate push-down: REGR_SLOPE with DISTINCT", + "scan with aggregate push-down: REGR_R2 with DISTINCT", + "scan with aggregate push-down: REGR_SXY with DISTINCT") + override val catalogName: String = "oracle" override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new DatabaseOnDocker { @@ -105,20 +119,4 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes } override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) - - testOffset() - testLimitAndOffset() - testPaging() - - testVarPop() - testVarSamp() - testStddevPop() - testStddevSamp() - testCovarPop() - testCovarSamp() - testCorr() - testRegrIntercept() - testRegrSlope() - testRegrR2() - testRegrSXY() } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala index b3e9d19a10f38..a365a1c4e82e4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala @@ -52,6 +52,9 @@ import org.apache.spark.tags.DockerTest */ @DockerTest class OracleNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + + override def excluded: Seq[String] = Seq("listNamespaces: basic behavior", "Drop namespace") + override val db = new DatabaseOnDocker { lazy override val imageName = sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:21.3.0") diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 4065dbcc036f6..4742764021bf5 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -90,31 +90,4 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT override def supportsIndex: Boolean = true override def indexOptions: String = "FILLFACTOR=70" - - testOffset() - testLimitAndOffset() - testPaging() - - testVarPop() - testVarPop(true) - testVarSamp() - testVarSamp(true) - testStddevPop() - testStddevPop(true) - testStddevSamp() - testStddevSamp(true) - testCovarPop() - testCovarPop(true) - testCovarSamp() - testCovarSamp(true) - testCorr() - testCorr(true) - testRegrIntercept() - testRegrIntercept(true) - testRegrSlope() - 
testRegrSlope(true) - testRegrR2() - testRegrR2(true) - testRegrSXY() - testRegrSXY(true) } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index 8c525717758c3..cf7266e67e325 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -55,7 +55,4 @@ class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNames override def builtinNamespaces: Array[Array[String]] = Array(Array("information_schema"), Array("pg_catalog"), Array("public")) - - testListNamespaces() - testDropNamespaces() } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index d3f17187a3754..b7c6e0aff20a7 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -55,83 +55,79 @@ private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerInte def supportsDropSchemaRestrict: Boolean = true - def testListNamespaces(): Unit = { - test("listNamespaces: basic behavior") { - val commentMap = if (supportsSchemaComment) { - Map("comment" -> "test comment") - } else { - Map.empty[String, String] - } - catalog.createNamespace(Array("foo"), commentMap.asJava) - assert(catalog.listNamespaces().map(_.toSet).toSet === - listNamespaces(Array("foo")).map(_.toSet).toSet) - assert(catalog.listNamespaces(Array("foo")) === Array()) - assert(catalog.namespaceExists(Array("foo")) === true) - - if (supportsSchemaComment) { - val logAppender = new LogAppender("catalog comment") - withLogAppender(logAppender) { - catalog.alterNamespace(Array("foo"), NamespaceChange - .setProperty("comment", "comment for foo")) - catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) - } - val createCommentWarning = logAppender.loggingEvents - .filter(_.getLevel == Level.WARN) - .map(_.getMessage.getFormattedMessage) - .exists(_.contains("catalog comment")) - assert(createCommentWarning === false) + test("listNamespaces: basic behavior") { + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.listNamespaces().map(_.toSet).toSet === + listNamespaces(Array("foo")).map(_.toSet).toSet) + assert(catalog.listNamespaces(Array("foo")) === Array()) + assert(catalog.namespaceExists(Array("foo")) === true) + + if (supportsSchemaComment) { + val logAppender = new LogAppender("catalog comment") + withLogAppender(logAppender) { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) } + val createCommentWarning = logAppender.loggingEvents + .filter(_.getLevel == Level.WARN) + .map(_.getMessage.getFormattedMessage) + .exists(_.contains("catalog comment")) + assert(createCommentWarning === false) + } - if (supportsDropSchemaRestrict) { - 
catalog.dropNamespace(Array("foo"), cascade = false) - } else { - catalog.dropNamespace(Array("foo"), cascade = true) - } - assert(catalog.namespaceExists(Array("foo")) === false) - assert(catalog.listNamespaces() === builtinNamespaces) - val e = intercept[AnalysisException] { - catalog.listNamespaces(Array("foo")) - } - checkError(e, - errorClass = "SCHEMA_NOT_FOUND", - parameters = Map("schemaName" -> "`foo`")) + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } + assert(catalog.namespaceExists(Array("foo")) === false) + assert(catalog.listNamespaces() === builtinNamespaces) + val e = intercept[AnalysisException] { + catalog.listNamespaces(Array("foo")) } + checkError(e, + errorClass = "SCHEMA_NOT_FOUND", + parameters = Map("schemaName" -> "`foo`")) } - def testDropNamespaces(): Unit = { - test("Drop namespace") { - val ident1 = Identifier.of(Array("foo"), "tab") - // Drop empty namespace without cascade - val commentMap = if (supportsSchemaComment) { - Map("comment" -> "test comment") - } else { - Map.empty[String, String] - } - catalog.createNamespace(Array("foo"), commentMap.asJava) - assert(catalog.namespaceExists(Array("foo")) === true) - if (supportsDropSchemaRestrict) { + test("Drop namespace") { + val ident1 = Identifier.of(Array("foo"), "tab") + // Drop empty namespace without cascade + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } + assert(catalog.namespaceExists(Array("foo")) === false) + + // Drop non empty namespace without cascade + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.createTable(ident1, schema, Array.empty[Transform], emptyProps) + if (supportsDropSchemaRestrict) { + intercept[NonEmptyNamespaceException] { catalog.dropNamespace(Array("foo"), cascade = false) - } else { - catalog.dropNamespace(Array("foo"), cascade = true) } - assert(catalog.namespaceExists(Array("foo")) === false) + } - // Drop non empty namespace without cascade - catalog.createNamespace(Array("foo"), commentMap.asJava) + // Drop non empty namespace with cascade + if (supportsDropSchemaCascade) { assert(catalog.namespaceExists(Array("foo")) === true) - catalog.createTable(ident1, schema, Array.empty[Transform], emptyProps) - if (supportsDropSchemaRestrict) { - intercept[NonEmptyNamespaceException] { - catalog.dropNamespace(Array("foo"), cascade = false) - } - } - - // Drop non empty namespace with cascade - if (supportsDropSchemaCascade) { - assert(catalog.namespaceExists(Array("foo")) === true) - catalog.dropNamespace(Array("foo"), cascade = true) - assert(catalog.namespaceExists(Array("foo")) === false) - } + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) } } } diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 97ee338509031..85b0b807932aa 100644 --- 
a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -314,14 +314,13 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - private def limitPushed(df: DataFrame, limit: Int): Boolean = { + private def checkLimitPushed(df: DataFrame, limit: Option[Int]): Unit = { df.queryExecution.optimizedPlan.collect { case relation: DataSourceV2ScanRelation => relation.scan match { case v1: V1ScanWrapper => - return v1.pushedDownOperators.limit == Some(limit) + assert(v1.pushedDownOperators.limit == limit) } } - false } private def checkColumnPruned(df: DataFrame, col: String): Unit = { @@ -354,7 +353,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + " LIMIT 2") checkSamplePushed(df3) - assert(limitPushed(df3, 2)) + checkLimitPushed(df3, Some(2)) checkColumnPruned(df3, "col1") assert(df3.collect().length <= 2) @@ -362,7 +361,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2") checkSamplePushed(df4) - assert(limitPushed(df4, 2)) + checkLimitPushed(df4, Some(2)) checkColumnPruned(df4, "col1") assert(df4.collect().length <= 2) @@ -371,7 +370,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") checkSamplePushed(df5) checkFilterPushed(df5) - assert(limitPushed(df5, 2)) + checkLimitPushed(df5, Some(2)) assert(df5.collect().length <= 2) // sample + filter + limit + column pruning @@ -381,7 +380,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") checkSamplePushed(df6) checkFilterPushed(df6, false) - assert(!limitPushed(df6, 2)) + checkLimitPushed(df6, None) checkColumnPruned(df6, "col1") assert(df6.collect().length <= 2) @@ -390,7 +389,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu // only limit is pushed down because in this test sample is after limit val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5) checkSamplePushed(df7, false) - assert(limitPushed(df7, 2)) + checkLimitPushed(df7, Some(2)) // sample + filter // Push down order is sample -> filter -> limit @@ -422,7 +421,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu test("simple scan with LIMIT") { val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1") - assert(limitPushed(df, 1)) + checkLimitPushed(df, Some(1)) val rows = df.collect() assert(rows.length === 1) assert(rows(0).getString(0) === "amy") @@ -434,7 +433,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu Seq(NullOrdering.values()).flatten.foreach { nullOrdering => val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." 
+ s"${caseConvert("employee")} WHERE dept > 0 ORDER BY salary $nullOrdering LIMIT 1") - assert(limitPushed(df1, 1)) + checkLimitPushed(df1, Some(1)) checkSortRemoved(df1) val rows1 = df1.collect() assert(rows1.length === 1) @@ -444,7 +443,7 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 ORDER BY bonus DESC $nullOrdering LIMIT 1") - assert(limitPushed(df2, 1)) + checkLimitPushed(df2, Some(1)) checkSortRemoved(df2) val rows2 = df2.collect() assert(rows2.length === 1) @@ -454,60 +453,54 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu } } - protected def testOffset(): Unit = { - test("simple scan with OFFSET") { - val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + - s"${caseConvert("employee")} WHERE dept > 0 OFFSET 4") - checkOffsetPushed(df, Some(4)) - val rows = df.collect() - assert(rows.length === 1) - assert(rows(0).getString(0) === "jen") - assert(rows(0).getDecimal(1) === new java.math.BigDecimal("12000.00")) - assert(rows(0).getDouble(2) === 1200d) - } + test("simple scan with OFFSET") { + val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 OFFSET 4") + checkOffsetPushed(df, Some(4)) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "jen") + assert(rows(0).getDecimal(1) === new java.math.BigDecimal("12000.00")) + assert(rows(0).getDouble(2) === 1200d) } - protected def testLimitAndOffset(): Unit = { - test("simple scan with LIMIT and OFFSET") { - val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + - s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1 OFFSET 2") - assert(limitPushed(df, 3)) - checkOffsetPushed(df, Some(2)) - val rows = df.collect() - assert(rows.length === 1) - assert(rows(0).getString(0) === "cathy") - assert(rows(0).getDecimal(1) === new java.math.BigDecimal("9000.00")) - assert(rows(0).getDouble(2) === 1200d) - } + test("simple scan with LIMIT and OFFSET") { + val df = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 LIMIT 1 OFFSET 2") + checkLimitPushed(df, Some(3)) + checkOffsetPushed(df, Some(2)) + val rows = df.collect() + assert(rows.length === 1) + assert(rows(0).getString(0) === "cathy") + assert(rows(0).getDecimal(1) === new java.math.BigDecimal("9000.00")) + assert(rows(0).getDouble(2) === 1200d) } - protected def testPaging(): Unit = { - test("simple scan with paging: top N and OFFSET") { - Seq(NullOrdering.values()).flatten.foreach { nullOrdering => - val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + - s"${caseConvert("employee")}" + - s" WHERE dept > 0 ORDER BY salary $nullOrdering, bonus LIMIT 1 OFFSET 2") - assert(limitPushed(df1, 3)) - checkOffsetPushed(df1, Some(2)) - checkSortRemoved(df1) - val rows1 = df1.collect() - assert(rows1.length === 1) - assert(rows1(0).getString(0) === "david") - assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("10000.00")) - assert(rows1(0).getDouble(2) === 1300d) - - val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." 
+ - s"${caseConvert("employee")}" + - s" WHERE dept > 0 ORDER BY salary DESC $nullOrdering, bonus LIMIT 1 OFFSET 2") - assert(limitPushed(df2, 3)) - checkOffsetPushed(df2, Some(2)) - checkSortRemoved(df2) - val rows2 = df2.collect() - assert(rows2.length === 1) - assert(rows2(0).getString(0) === "amy") - assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00")) - assert(rows2(0).getDouble(2) === 1000d) - } + test("simple scan with paging: top N and OFFSET") { + Seq(NullOrdering.values()).flatten.foreach { nullOrdering => + val df1 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + + s"${caseConvert("employee")}" + + s" WHERE dept > 0 ORDER BY salary $nullOrdering, bonus LIMIT 1 OFFSET 2") + checkLimitPushed(df1, Some(3)) + checkOffsetPushed(df1, Some(2)) + checkSortRemoved(df1) + val rows1 = df1.collect() + assert(rows1.length === 1) + assert(rows1(0).getString(0) === "david") + assert(rows1(0).getDecimal(1) === new java.math.BigDecimal("10000.00")) + assert(rows1(0).getDouble(2) === 1300d) + + val df2 = sql(s"SELECT name, salary, bonus FROM $catalogAndNamespace." + + s"${caseConvert("employee")}" + + s" WHERE dept > 0 ORDER BY salary DESC $nullOrdering, bonus LIMIT 1 OFFSET 2") + checkLimitPushed(df2, Some(3)) + checkOffsetPushed(df2, Some(2)) + checkSortRemoved(df2) + val rows2 = df2.collect() + assert(rows2.length === 1) + assert(rows2(0).getString(0) === "amy") + assert(rows2(0).getDecimal(1) === new java.math.BigDecimal("10000.00")) + assert(rows2(0).getDouble(2) === 1000d) } } @@ -536,9 +529,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu private def withOrWithout(isDistinct: Boolean): String = if (isDistinct) "with" else "without" - protected def testVarPop(isDistinct: Boolean = false): Unit = { + Seq(true, false).foreach { isDistinct => val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: VAR_POP ${withOrWithout(isDistinct)} DISTINCT") { + val withOrWithout = if (isDistinct) "with" else "without" + + test(s"scan with aggregate push-down: VAR_POP $withOrWithout DISTINCT") { val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) @@ -550,14 +545,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 2500.0) assert(row(2).getDouble(0) === 0.0) } - } - protected def testVarSamp(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: VAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: VAR_SAMP $withOrWithout DISTINCT") { val df = sql( s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." 
+ - s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "VAR_SAMP") @@ -567,14 +559,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 5000.0) assert(row(2).isNullAt(0)) } - } - protected def testStddevPop(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: STDDEV_POP ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: STDDEV_POP $withOrWithout DISTINCT") { val df = sql( s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." + - s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "STDDEV_POP") @@ -584,14 +573,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 50.0) assert(row(2).getDouble(0) === 0.0) } - } - protected def testStddevSamp(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: STDDEV_SAMP ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: STDDEV_SAMP $withOrWithout DISTINCT") { val df = sql( s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + - s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "STDDEV_SAMP") @@ -601,14 +587,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 70.71067811865476) assert(row(2).isNullAt(0)) } - } - protected def testCovarPop(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: COVAR_POP ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: COVAR_POP $withOrWithout DISTINCT") { val df = sql( s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + - s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "COVAR_POP") @@ -618,14 +601,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 2500.0) assert(row(2).getDouble(0) === 0.0) } - } - protected def testCovarSamp(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: COVAR_SAMP ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: COVAR_SAMP $withOrWithout DISTINCT") { val df = sql( s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ - s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "COVAR_SAMP") @@ -635,14 +615,11 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 5000.0) assert(row(2).isNullAt(0)) } - } - protected def testCorr(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: CORR ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: CORR $withOrWithout DISTINCT") { val df = sql( s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." + - s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") checkFilterPushed(df) checkAggregateRemoved(df) checkAggregatePushed(df, "CORR") @@ -652,11 +629,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 1.0) assert(row(2).isNullAt(0)) } - } - protected def testRegrIntercept(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: REGR_INTERCEPT ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: REGR_INTERCEPT $withOrWithout DISTINCT") { val df = sql( s"SELECT REGR_INTERCEPT(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -669,11 +643,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 0.0) assert(row(2).isNullAt(0)) } - } - protected def testRegrSlope(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: REGR_SLOPE ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: REGR_SLOPE $withOrWithout DISTINCT") { val df = sql( s"SELECT REGR_SLOPE(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -686,11 +657,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 1.0) assert(row(2).isNullAt(0)) } - } - protected def testRegrR2(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: REGR_R2 ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: REGR_R2 $withOrWithout DISTINCT") { val df = sql( s"SELECT REGR_R2(${distinct}bonus, bonus) FROM $catalogAndNamespace." + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") @@ -703,11 +671,8 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu assert(row(1).getDouble(0) === 1.0) assert(row(2).isNullAt(0)) } - } - protected def testRegrSXY(isDistinct: Boolean = false): Unit = { - val distinct = if (isDistinct) "DISTINCT " else "" - test(s"scan with aggregate push-down: REGR_SXY ${withOrWithout(isDistinct)} DISTINCT") { + test(s"scan with aggregate push-down: REGR_SXY $withOrWithout DISTINCT") { val df = sql( s"SELECT REGR_SXY(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 34026083bb9a6..0433747822391 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1063,6 +1063,28 @@ ], "sqlState" : "42903" }, + "INVALID_WRITE_DISTRIBUTION" : { + "message" : [ + "The requested write distribution is invalid." + ], + "subClass" : { + "PARTITION_NUM_AND_SIZE" : { + "message" : [ + "The partition number and advisory partition size can't be specified at the same time." + ] + }, + "PARTITION_NUM_WITH_UNSPECIFIED_DISTRIBUTION" : { + "message" : [ + "The number of partitions can't be specified with unspecified distribution." + ] + }, + "PARTITION_SIZE_WITH_UNSPECIFIED_DISTRIBUTION" : { + "message" : [ + "The advisory partition size can't be specified with unspecified distribution." + ] + } + } + }, "LOCATION_ALREADY_EXISTS" : { "message" : [ "Cannot name the managed table as , as its associated location already exists. Please pick a different table name, or remove the existing location first." @@ -2931,11 +2953,6 @@ "Unsupported data type ." ] }, - "_LEGACY_ERROR_TEMP_1178" : { - "message" : [ - "The number of partitions can't be specified with unspecified distribution. Invalid writer requirements detected." - ] - }, "_LEGACY_ERROR_TEMP_1181" : { "message" : [ "Stream-stream join without equality predicate is not supported." diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 0ee0dc6ae6016..2d4624828a94d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -103,11 +103,8 @@ class TaskInfo( // finishTime should be set larger than 0, otherwise "finished" below will return false. 
assert(time > 0) finishTime = time - if (state == TaskState.FAILED) { - failed = true - } else if (state == TaskState.KILLED) { - killed = true - } + failed = state == TaskState.FAILED + killed = state == TaskState.KILLED } private[spark] def launchSucceeded(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala index 27198039fdbaa..ff12f643497d0 100644 --- a/core/src/test/scala/org/apache/spark/SparkFunSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkFunSuite.scala @@ -31,7 +31,8 @@ import org.apache.logging.log4j._ import org.apache.logging.log4j.core.{LogEvent, Logger, LoggerContext} import org.apache.logging.log4j.core.appender.AbstractAppender import org.apache.logging.log4j.core.config.Property -import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, Failed, Outcome} +import org.scalactic.source.Position +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, Failed, Outcome, Tag} import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite import org.apache.spark.deploy.LocalSparkCluster @@ -137,6 +138,19 @@ abstract class SparkFunSuite java.nio.file.Paths.get(sparkHome, first +: more: _*) } + // subclasses can override this to exclude certain tests by name + // useful when inheriting a test suite but do not want to run all tests in it + protected def excluded: Seq[String] = Seq.empty + + override protected def test(testName: String, testTags: Tag*)(testBody: => Any) + (implicit pos: Position): Unit = { + if (excluded.contains(testName)) { + ignore(s"$testName (excluded)")(testBody) + } else { + super.test(testName, testTags: _*)(testBody) + } + } + /** * Note: this method doesn't support `BeforeAndAfter`. You must use `BeforeAndAfterEach` to * set up and tear down resources. 
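As a rough illustration of the `excluded` hook added to SparkFunSuite above: a suite that inherits shared tests can now opt out of individual tests by name, instead of the parent trait exposing per-test registration methods. This is only a sketch; the suite name and test bodies below are hypothetical and not part of this patch (the test names are borrowed from the JDBC suites earlier in the diff).

  package org.apache.spark

  // Hypothetical example, not part of this patch: a suite that opts out of
  // inherited tests by name via the new `excluded` hook.
  class ExampleExcludedTestsSuite extends SparkFunSuite {

    // Names listed here are registered via ignore(...) by the overridden test(),
    // so they show up in reports as "<name> (excluded)" instead of running.
    override protected def excluded: Seq[String] = Seq(
      "simple scan with OFFSET")

    test("simple scan with LIMIT") {
      assert(1 + 1 === 2) // runs normally
    }

    test("simple scan with OFFSET") {
      fail("never executed; this test is excluded by name")
    }
  }

The testOffset()/testVarPop()-style registration helpers removed from the V2JDBC suites above are replaced by exactly this pattern: the tests are always registered in the shared trait, and each database suite overrides excluded to skip the ones its backend cannot support.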
diff --git a/dev/deps/spark-deps-hadoop-2-hive-2.3 b/dev/deps/spark-deps-hadoop-2-hive-2.3 index d9edb110f48a2..e3d588d36cd7b 100644 --- a/dev/deps/spark-deps-hadoop-2-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-2-hive-2.3 @@ -223,9 +223,9 @@ objenesis/3.2//objenesis-3.2.jar okhttp/3.12.12//okhttp-3.12.12.jar okio/1.15.0//okio-1.15.0.jar opencsv/2.3//opencsv-2.3.jar -orc-core/1.8.2/shaded-protobuf/orc-core-1.8.2-shaded-protobuf.jar -orc-mapreduce/1.8.2/shaded-protobuf/orc-mapreduce-1.8.2-shaded-protobuf.jar -orc-shims/1.8.2//orc-shims-1.8.2.jar +orc-core/1.8.3/shaded-protobuf/orc-core-1.8.3-shaded-protobuf.jar +orc-mapreduce/1.8.3/shaded-protobuf/orc-mapreduce-1.8.3-shaded-protobuf.jar +orc-shims/1.8.3//orc-shims-1.8.3.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 5c17fff4d3789..fd32245ec2865 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -210,9 +210,9 @@ opencsv/2.3//opencsv-2.3.jar opentracing-api/0.33.0//opentracing-api-0.33.0.jar opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar opentracing-util/0.33.0//opentracing-util-0.33.0.jar -orc-core/1.8.2/shaded-protobuf/orc-core-1.8.2-shaded-protobuf.jar -orc-mapreduce/1.8.2/shaded-protobuf/orc-mapreduce-1.8.2-shaded-protobuf.jar -orc-shims/1.8.2//orc-shims-1.8.2.jar +orc-core/1.8.3/shaded-protobuf/orc-core-1.8.3-shaded-protobuf.jar +orc-mapreduce/1.8.3/shaded-protobuf/orc-mapreduce-1.8.3-shaded-protobuf.jar +orc-shims/1.8.3//orc-shims-1.8.3.jar oro/2.0.8//oro-2.0.8.jar osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar paranamer/2.8//paranamer-2.8.jar diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 751f0687f2c8e..c31a9362cd7fb 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -271,12 +271,36 @@ def __hash__(self): ], ) +mllib_local = Module( + name="mllib-local", + dependencies=[tags, core], + source_file_regexes=[ + "mllib/local", + ], + sbt_test_goals=[ + "mllib-local/test", + ], +) + + +mllib_common = Module( + name="mllib-common", + dependencies=[tags, mllib_local, sql], + source_file_regexes=[ + "mllib/common", + ], + sbt_test_goals=[ + "mllib-common/test", + ], +) + connect = Module( name="connect", - dependencies=[hive], + dependencies=[hive, mllib_common], source_file_regexes=[ "connector/connect", ], + build_profile_flags=["-Pconnect"], sbt_test_goals=[ "connect/test", "connect-client-jvm/test", @@ -358,24 +382,12 @@ def __hash__(self): ) -mllib_local = Module( - name="mllib-local", - dependencies=[tags, core], - source_file_regexes=[ - "mllib-local", - ], - sbt_test_goals=[ - "mllib-local/test", - ], -) - - mllib = Module( name="mllib", - dependencies=[mllib_local, streaming, sql], + dependencies=[mllib_local, mllib_common, streaming, sql], source_file_regexes=[ "data/mllib/", - "mllib/", + "mllib/core/", ], sbt_test_goals=[ "mllib/test", @@ -501,48 +513,6 @@ def __hash__(self): ], ) -pyspark_connect = Module( - name="pyspark-connect", - dependencies=[pyspark_sql, connect], - source_file_regexes=["python/pyspark/sql/connect"], - python_test_goals=[ - # doctests - "pyspark.sql.connect.catalog", - "pyspark.sql.connect.conf", - "pyspark.sql.connect.group", - "pyspark.sql.connect.session", - "pyspark.sql.connect.window", - "pyspark.sql.connect.column", - "pyspark.sql.connect.readwriter", - "pyspark.sql.connect.dataframe", - 
"pyspark.sql.connect.functions", - # unittests - "pyspark.sql.tests.connect.test_client", - "pyspark.sql.tests.connect.test_connect_plan", - "pyspark.sql.tests.connect.test_connect_basic", - "pyspark.sql.tests.connect.test_connect_function", - "pyspark.sql.tests.connect.test_connect_column", - "pyspark.sql.tests.connect.test_parity_datasources", - "pyspark.sql.tests.connect.test_parity_catalog", - "pyspark.sql.tests.connect.test_parity_conf", - "pyspark.sql.tests.connect.test_parity_serde", - "pyspark.sql.tests.connect.test_parity_functions", - "pyspark.sql.tests.connect.test_parity_group", - "pyspark.sql.tests.connect.test_parity_dataframe", - "pyspark.sql.tests.connect.test_parity_types", - "pyspark.sql.tests.connect.test_parity_column", - "pyspark.sql.tests.connect.test_parity_readwriter", - "pyspark.sql.tests.connect.test_parity_udf", - "pyspark.sql.tests.connect.test_parity_pandas_udf", - "pyspark.sql.tests.connect.test_parity_pandas_map", - "pyspark.sql.tests.connect.test_parity_arrow_map", - ], - excluded_python_implementations=[ - "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and - # they aren't available there - ], -) - pyspark_resource = Module( name="pyspark-resource", dependencies=[pyspark_core], @@ -769,6 +739,58 @@ def __hash__(self): ], ) + +pyspark_connect = Module( + name="pyspark-connect", + dependencies=[pyspark_sql, pyspark_ml, connect], + source_file_regexes=[ + "python/pyspark/sql/connect", + "python/pyspark/ml/connect", + ], + python_test_goals=[ + # sql doctests + "pyspark.sql.connect.catalog", + "pyspark.sql.connect.conf", + "pyspark.sql.connect.group", + "pyspark.sql.connect.session", + "pyspark.sql.connect.window", + "pyspark.sql.connect.column", + "pyspark.sql.connect.readwriter", + "pyspark.sql.connect.dataframe", + "pyspark.sql.connect.functions", + # sql unittests + "pyspark.sql.tests.connect.test_client", + "pyspark.sql.tests.connect.test_connect_plan", + "pyspark.sql.tests.connect.test_connect_basic", + "pyspark.sql.tests.connect.test_connect_function", + "pyspark.sql.tests.connect.test_connect_column", + "pyspark.sql.tests.connect.test_parity_datasources", + "pyspark.sql.tests.connect.test_parity_catalog", + "pyspark.sql.tests.connect.test_parity_conf", + "pyspark.sql.tests.connect.test_parity_serde", + "pyspark.sql.tests.connect.test_parity_functions", + "pyspark.sql.tests.connect.test_parity_group", + "pyspark.sql.tests.connect.test_parity_dataframe", + "pyspark.sql.tests.connect.test_parity_types", + "pyspark.sql.tests.connect.test_parity_column", + "pyspark.sql.tests.connect.test_parity_readwriter", + "pyspark.sql.tests.connect.test_parity_udf", + "pyspark.sql.tests.connect.test_parity_pandas_udf", + "pyspark.sql.tests.connect.test_parity_pandas_map", + "pyspark.sql.tests.connect.test_parity_arrow_map", + "pyspark.sql.tests.connect.test_parity_pandas_grouped_map", + # ml doctests + "pyspark.ml.connect.functions", + # ml unittests + "pyspark.ml.tests.connect.test_connect_function", + ], + excluded_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and + # they aren't available there + ], +) + + pyspark_errors = Module( name="pyspark-errors", dependencies=[], diff --git a/dev/sparktestsupport/utils.py b/dev/sparktestsupport/utils.py index 6b190eb5ab27a..5c270d0948eca 100755 --- a/dev/sparktestsupport/utils.py +++ b/dev/sparktestsupport/utils.py @@ -112,22 +112,25 @@ def determine_modules_to_test(changed_modules, deduplicated=True): >>> sorted([x.name for x in 
determine_modules_to_test([modules.sql])]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', - 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', - 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'mllib', 'mllib-common', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', + 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql', + 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sparkr, modules.sql], deduplicated=False)]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'connect', 'docker-integration-tests', 'examples', 'hive', 'hive-thriftserver', - 'mllib', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', - 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql', 'sql-kafka-0-10'] + 'mllib', 'mllib-common', 'protobuf', 'pyspark-connect', 'pyspark-ml', 'pyspark-mllib', + 'pyspark-pandas', 'pyspark-pandas-slow', 'pyspark-sql', 'repl', 'sparkr', 'sql', + 'sql-kafka-0-10'] >>> sorted([x.name for x in determine_modules_to_test( ... [modules.sql, modules.core], deduplicated=False)]) ... # doctest: +NORMALIZE_WHITESPACE ['avro', 'catalyst', 'connect', 'core', 'docker-integration-tests', 'examples', 'graphx', - 'hive', 'hive-thriftserver', 'mllib', 'mllib-local', 'protobuf', 'pyspark-connect', - 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', 'pyspark-pandas-slow', - 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'repl', 'root', 'sparkr', 'sql', - 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', 'streaming-kinesis-asl'] + 'hive', 'hive-thriftserver', 'mllib', 'mllib-common', 'mllib-local', 'protobuf', + 'pyspark-connect', 'pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-pandas', + 'pyspark-pandas-slow', 'pyspark-resource', 'pyspark-sql', 'pyspark-streaming', 'repl', + 'root', 'sparkr', 'sql', 'sql-kafka-0-10', 'streaming', 'streaming-kafka-0-10', + 'streaming-kinesis-asl'] """ modules_to_test = set() for module in changed_modules: diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index d44639227665d..9b7c469246165 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -7,6 +7,8 @@ + + {{ page.title }} - Spark {{site.SPARK_VERSION_SHORT}} Documentation {% if page.description %} @@ -17,16 +19,13 @@ {% endif %} - - - - + + + + + @@ -34,96 +33,118 @@ - + - + + + {% endif %} -
   {% if page.url contains "/ml" or page.url contains "/sql" or page.url contains "migration-guide.html" %}
     {% if page.url contains "migration-guide.html" %}
@@ -147,23 +168,27 @@
       {{ page.title }}
     {% else %}
-      {% if page.displayTitle %}
-        {{ page.displayTitle }}
-      {% else %}
-        {{ page.title }}
+      {% if page.url != "/" %}
+        {% if page.displayTitle %}
+          {{ page.displayTitle }}
+        {% else %}
+          {{ page.title }}
+        {% endif %}
       {% endif %}
-      {{ content }}
     {% endif %}