From d57164a5cab92e8821563e776c19a424721adce3 Mon Sep 17 00:00:00 2001
From: Stevo Mitric
Date: Wed, 27 Mar 2024 18:54:54 +0800
Subject: [PATCH] [SPARK-47430][SQL] Support GROUP BY for MapType

### What changes were proposed in this pull request?
Changes proposed in this PR include:
- Relaxed checks that prevent aggregating of map types
- Added a new optimizer rule that uses the `MapSort` expression proposed in [this PR](https://github.com/apache/spark/pull/45639)
- Created codegen that compares two sorted maps

### Why are the changes needed?
Adds new functionality: GROUP BY on map types

### Does this PR introduce _any_ user-facing change?
Yes, the ability to use `GROUP BY` on a `MapType` column

### How was this patch tested?
With new UTs

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45549 from stevomitric/stevomitric/map-group-by.

Lead-authored-by: Stevo Mitric
Co-authored-by: Stefan Kandic
Co-authored-by: Stevo Mitric
Signed-off-by: Wenchen Fan
---
 .../main/resources/error/error-classes.json   |  6 --
 docs/sql-error-conditions.md                  |  6 --
 .../sql/catalyst/expressions/ExprUtils.scala  |  9 ---
 .../expressions/codegen/CodeGenerator.scala   | 81 ++++++++++++++-----
 .../InsertMapSortInGroupingExpressions.scala  | 45 +++++++++++
 .../optimizer/NormalizeFloatingNumbers.scala  | 14 +++-
 .../sql/catalyst/optimizer/Optimizer.scala    |  2 +
 .../catalyst/plans/logical/LogicalPlan.scala  | 20 +----
 .../sql/catalyst/rules/RuleIdCollection.scala |  1 +
 .../analysis/AnalysisErrorSuite.scala         | 64 ++-------------
 .../catalyst/optimizer/OptimizerSuite.scala   | 24 ------
 .../spark/sql/DataFrameAggregateSuite.scala   | 46 ++++++++---
 12 files changed, 161 insertions(+), 157 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 717d5e6631ec1..4aad6f68b03b5 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -1373,12 +1373,6 @@
     ],
     "sqlState" : "42805"
   },
-  "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE" : {
-    "message" : [
-      "The expression <sqlExpr> cannot be used as a grouping expression because its data type <dataType> is not an orderable data type."
-    ],
-    "sqlState" : "42822"
-  },
   "HLL_INVALID_INPUT_SKETCH_BUFFER" : {
     "message" : [
       "Invalid call to <function>; only valid HLL sketch buffers are supported as inputs (such as those produced by the `hll_sketch_agg` function)."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index b05a8d1ff61eb..33e10ed7b5e39 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -852,12 +852,6 @@ GROUP BY `<index>` refers to an expression `<aggExpr>` that contains an aggregat
 
 GROUP BY position `<index>` is not in select list (valid range is [1, `<size>`]).
 
-### GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE
-
-[SQLSTATE: 42822](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
-
-The expression `<sqlExpr>` cannot be used as a grouping expression because its data type `<dataType>` is not an orderable data type.
- ### HLL_INVALID_INPUT_SKETCH_BUFFER [SQLSTATE: 22546](sql-error-conditions-sqlstates.html#class-22-data-exception) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 2bbe730d4cfb8..eaf10973e71de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -193,15 +193,6 @@ object ExprUtils extends QueryErrorsBase { messageParameters = Map("sqlExpr" -> expr.sql)) } - // Check if the data type of expr is orderable. - if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) { - expr.failAnalysis( - errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE", - messageParameters = Map( - "sqlExpr" -> toSQLExpr(expr), - "dataType" -> toSQLType(expr.dataType))) - } - if (!expr.deterministic) { // This is just a sanity check, our analysis rule PullOutNondeterministic should // already pull out those nondeterministic expressions and evaluate them in diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 46349a7faf03d..dfe07a443a230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -660,13 +660,8 @@ class CodegenContext extends Logging { case NullType => "0" case array: ArrayType => val elementType = array.elementType - val elementA = freshName("elementA") - val isNullA = freshName("isNullA") - val elementB = freshName("elementB") - val isNullB = freshName("isNullB") val compareFunc = freshName("compareArray") val minLength = freshName("minLength") - val jt = javaType(elementType) val funcCode: String = s""" public int $compareFunc(ArrayData a, ArrayData b) { @@ -679,22 +674,7 @@ class CodegenContext extends Logging { int lengthB = b.numElements(); int $minLength = (lengthA > lengthB) ? 
lengthB : lengthA; for (int i = 0; i < $minLength; i++) { - boolean $isNullA = a.isNullAt(i); - boolean $isNullB = b.isNullAt(i); - if ($isNullA && $isNullB) { - // Nothing - } else if ($isNullA) { - return -1; - } else if ($isNullB) { - return 1; - } else { - $jt $elementA = ${getValue("a", elementType, "i")}; - $jt $elementB = ${getValue("b", elementType, "i")}; - int comp = ${genComp(elementType, elementA, elementB)}; - if (comp != 0) { - return comp; - } - } + ${genCompElementsAt("a", "b", "i", elementType)} } if (lengthA < lengthB) { @@ -722,12 +702,71 @@ class CodegenContext extends Logging { } """ s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + case map: MapType => + val compareFunc = freshName("compareMapData") + val funcCode = genCompMapData(map.keyType, map.valueType, compareFunc) + s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) case _ => throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType) } + private def genCompMapData( + keyType: DataType, + valueType: DataType, + compareFunc: String): String = { + s""" + |public int $compareFunc(MapData a, MapData b) { + | int lengthA = a.numElements(); + | int lengthB = b.numElements(); + | ArrayData keyArrayA = a.keyArray(); + | ArrayData valueArrayA = a.valueArray(); + | ArrayData keyArrayB = b.keyArray(); + | ArrayData valueArrayB = b.valueArray(); + | int minLength = (lengthA > lengthB) ? lengthB : lengthA; + | for (int i = 0; i < minLength; i++) { + | ${genCompElementsAt("keyArrayA", "keyArrayB", "i", keyType)} + | ${genCompElementsAt("valueArrayA", "valueArrayB", "i", valueType)} + | } + | + | if (lengthA < lengthB) { + | return -1; + | } else if (lengthA > lengthB) { + | return 1; + | } + | return 0; + |} + """.stripMargin + } + + private def genCompElementsAt(arrayA: String, arrayB: String, i: String, + elementType : DataType): String = { + val elementA = freshName("elementA") + val isNullA = freshName("isNullA") + val elementB = freshName("elementB") + val isNullB = freshName("isNullB") + val jt = javaType(elementType); + s""" + |boolean $isNullA = $arrayA.isNullAt($i); + |boolean $isNullB = $arrayB.isNullAt($i); + |if ($isNullA && $isNullB) { + | // Nothing + |} else if ($isNullA) { + | return -1; + |} else if ($isNullB) { + | return 1; + |} else { + | $jt $elementA = ${getValue(arrayA, elementType, i)}; + | $jt $elementB = ${getValue(arrayB, elementType, i)}; + | int comp = ${genComp(elementType, elementA, elementB)}; + | if (comp != 0) { + | return comp; + | } + |} + """.stripMargin + } + /** * Generates code for greater of two expressions. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala new file mode 100644 index 0000000000000..3d883fa9d9ae2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.MapSort
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
+import org.apache.spark.sql.types.MapType
+
+/**
+ * Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
+ * in the correct order before grouping:
+ *   SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
+ *   SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
+ */
+object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
+    case a @ Aggregate(groupingExpr, _, _) =>
+      val newGrouping = groupingExpr.map { expr =>
+        if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) {
+          MapSort(expr)
+        } else {
+          expr
+        }
+      }
+      a.copy(groupingExpressions = newGrouping)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index f946fe76bde4d..2fcc689b9df2b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformValues, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -98,9 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
     case ArrayType(et, _) => needNormalize(et)
-    // Currently MapType is not comparable and analyzer should fail earlier if this case happens.
- case _: MapType => - throw SparkException.internalError("grouping/join/window partition keys cannot be map type.") + case MapType(_, vt, _) => needNormalize(vt) case _ => false } @@ -144,6 +142,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { val function = normalize(lv) KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv)))) + case _ if expr.dataType.isInstanceOf[MapType] => + val MapType(kt, vt, containsNull) = expr.dataType + val keys = NamedLambdaVariable("arg", kt, containsNull) + val values = NamedLambdaVariable("arg", vt, containsNull) + val function = normalize(values) + KnownFloatingPointNormalized(TransformValues(expr, + LambdaFunction(function, Seq(keys, values)))) + case _ => throw SparkException.internalError(s"fail to normalize $expr") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ffb0e3a73389f..f0f86de39e0d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -243,6 +243,8 @@ abstract class Optimizer(catalogManager: CatalogManager) CollapseProject, RemoveRedundantAliases, RemoveNoopOperators) :+ + Batch("InsertMapSortInGroupingExpressions", Once, + InsertMapSortInGroupingExpressions) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e1121d1f9026e..b989233da6740 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{MapType, StructType} +import org.apache.spark.sql.types.StructType abstract class LogicalPlan @@ -348,23 +348,6 @@ object LogicalPlanIntegrity { }.flatten } - /** - * Validate that the grouping key types in Aggregate plans are valid. - * Returns an error message if the check fails, or None if it succeeds. - */ - def validateGroupByTypes(plan: LogicalPlan): Option[String] = { - plan.collectFirst { - case a @ Aggregate(groupingExprs, _, _) => - val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString) - if (badExprs.nonEmpty) { - Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " + - s"for plan:\n ${a.treeString}") - } else { - None - } - }.flatten - } - /** * Validate that the aggregation expressions in Aggregate plans are valid. * Returns an error message if the check fails, or None if it succeeds. 
@@ -417,7 +400,6 @@ object LogicalPlanIntegrity { .orElse(LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan)) .orElse(LogicalPlanIntegrity.validateSchemaOutput(previousPlan, currentPlan)) .orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan)) - .orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan)) .orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan)) .map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" + s"\nPrevious plan: ${previousPlan.treeString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index 08f728da2e9dd..778d56788e89e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -126,6 +126,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" :: "org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" :: "org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" :: + "org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" :: "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" :: "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" :: "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" :: diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 18729934ea9bb..5a30081872710 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.scalatest.Assertions._ - import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -28,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Cross, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -59,32 +56,6 @@ private[sql] case class UngroupableData(data: Map[Int, Int]) { def getData: Map[Int, Int] = data } -private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { - - override def sqlType: DataType = MapType(IntegerType, IntegerType) - - override def serialize(ungroupableData: UngroupableData): MapData = { - val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq) - val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq) - new ArrayBasedMapData(keyArray, valueArray) - } - - override def deserialize(datum: Any): UngroupableData = { - datum match { - case data: MapData => - val keyArray = data.keyArray().array - val valueArray = data.valueArray().array - assert(keyArray.length == valueArray.length) - val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] - 
UngroupableData(mapData) - } - } - - override def userClass: Class[UngroupableData] = classOf[UngroupableData] - - private[spark] override def asNullable: UngroupableUDT = this -} - case class TestFunction( children: Seq[Expression], inputTypes: Seq[AbstractDataType]) @@ -984,8 +955,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { } test("check grouping expression data types") { - def checkDataType( - dataType: DataType, shouldSuccess: Boolean, dataTypeMsg: String = ""): Unit = { + def checkDataType(dataType: DataType): Unit = { val plan = Aggregate( AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil, @@ -994,18 +964,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { AttributeReference("a", dataType)(exprId = ExprId(2)), AttributeReference("b", IntegerType)(exprId = ExprId(1)))) - if (shouldSuccess) { - assertAnalysisSuccess(plan, true) - } else { - assertAnalysisErrorClass( - inputPlan = plan, - expectedErrorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE", - expectedMessageParameters = Map( - "sqlExpr" -> "\"a\"", - "dataType" -> dataTypeMsg - ) - ) - } + assertAnalysisSuccess(plan, true) } val supportedDataTypes = Seq( @@ -1015,6 +974,10 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), + MapType(StringType, LongType), + new StructType() + .add("f1", FloatType, nullable = true) + .add("f2", MapType(StringType, LongType), nullable = true), new StructType() .add("f1", FloatType, nullable = true) .add("f2", StringType, nullable = true), @@ -1023,20 +986,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase { .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), new GroupableUDT()) supportedDataTypes.foreach { dataType => - checkDataType(dataType, shouldSuccess = true) - } - - val unsupportedDataTypes = Seq( - MapType(StringType, LongType), - new StructType() - .add("f1", FloatType, nullable = true) - .add("f2", MapType(StringType, LongType), nullable = true), - new UngroupableUDT()) - val expectedDataTypeParameters = - Seq("\"MAP\"", "\"STRUCT>\"") - unsupportedDataTypes.zip(expectedDataTypeParameters).foreach { - case (dataType, dataTypeMsg) => - checkDataType(dataType, shouldSuccess = false, dataTypeMsg) + checkDataType(dataType) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala index 590fb323000b9..48cdbbe7be539 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerSuite.scala @@ -103,30 +103,6 @@ class OptimizerSuite extends PlanTest { assert(message1.contains("are dangling")) } - test("Optimizer per rule validation catches invalid grouping types") { - val analyzed = LocalRelation(Symbol("a").map(IntegerType, IntegerType)) - .select(Symbol("a")).analyze - - /** - * A dummy optimizer rule for testing that invalid grouping types are not allowed. 
- */ - object InvalidGroupingType extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - Aggregate(plan.output, plan.output, plan) - } - } - - val optimizer = new SimpleTestOptimizer() { - override def defaultBatches: Seq[Batch] = - Batch("test", FixedPoint(1), - InvalidGroupingType) :: Nil - } - val message1 = intercept[SparkException] { - optimizer.execute(analyzed) - }.getMessage - assert(message1.contains("cannot be of type Map")) - } - test("Optimizer per rule validation catches invalid aggregation expressions") { val analyzed = LocalRelation(Symbol("a").long, Symbol("b").long) .select(Symbol("a"), Symbol("b")).analyze diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 21d7156a62b36..7c81ac7244878 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -2155,8 +2155,8 @@ class DataFrameAggregateSuite extends QueryTest ) } - test("SPARK-46536 Support GROUP BY CalendarIntervalType") { - val numRows = 50 + private def assertAggregateOnDataframe(df: DataFrame, + expected: Int, aggregateColumn: String): Unit = { val configurations = Seq( Seq.empty[(String, String)], // hash aggregate is used by default Seq(SQLConf.CODEGEN_FACTORY_MODE.key -> "NO_CODEGEN", @@ -2168,22 +2168,46 @@ class DataFrameAggregateSuite extends QueryTest Seq("spark.sql.test.forceApplySortAggregate" -> "true") ) + for (conf <- configurations) { + withSQLConf(conf: _*) { + assert(createAggregate(df).count() == expected) + } + } + + def createAggregate(df: DataFrame): DataFrame = df.groupBy(aggregateColumn).agg(count("*")) + } + + test("SPARK-47430 Support GROUP BY MapType") { + val numRows = 50 + + val dfSameInt = (0 until numRows) + .map(_ => Tuple1(Map(1 -> 1))) + .toDF("m0") + assertAggregateOnDataframe(dfSameInt, 1, "m0") + + val dfSameFloat = (0 until numRows) + .map(i => Tuple1(Map(if (i % 2 == 0) 1 -> 0.0 else 1 -> -0.0 ))) + .toDF("m0") + assertAggregateOnDataframe(dfSameFloat, 1, "m0") + + val dfDifferent = (0 until numRows) + .map(i => Tuple1(Map(i -> i))) + .toDF("m0") + assertAggregateOnDataframe(dfDifferent, numRows, "m0") + } + + test("SPARK-46536 Support GROUP BY CalendarIntervalType") { + val numRows = 50 + val dfSame = (0 until numRows) .map(_ => Tuple1(new CalendarInterval(1, 2, 3))) .toDF("c0") + assertAggregateOnDataframe(dfSame, 1, "c0") val dfDifferent = (0 until numRows) .map(i => Tuple1(new CalendarInterval(i, i, i))) .toDF("c0") - - for (conf <- configurations) { - withSQLConf(conf: _*) { - assert(createAggregate(dfSame).count() == 1) - assert(createAggregate(dfDifferent).count() == numRows) - } - } - - def createAggregate(df: DataFrame): DataFrame = df.groupBy("c0").agg(count("*")) + assertAggregateOnDataframe(dfDifferent, numRows, "c0") } test("SPARK-46779: Group by subquery with a cached relation") {
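
To make the user-facing change concrete, here is a minimal, hypothetical spark-shell sketch (not part of the patch). The `spark` session, the sample data, and the column name `m` are assumptions for illustration only.

```scala
// Hypothetical usage sketch: grouping by a map column, which previously failed
// analysis with GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE.
import org.apache.spark.sql.functions.count
import spark.implicits._

val df = Seq(
  Map("a" -> 1, "b" -> 2),
  Map("b" -> 2, "a" -> 1), // same entries, different insertion order
  Map("c" -> 3)
).toDF("m")

// InsertMapSortInGroupingExpressions rewrites the grouping key to map_sort(m),
// so the first two rows land in the same group and the result has two groups.
df.groupBy("m").agg(count("*")).show()
```

The same scenario is exercised by the new `SPARK-47430 Support GROUP BY MapType` test above, under hash, object-hash, and sort aggregation.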
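The ordering produced by the new `compareMapData` codegen can also be read off a plain-Scala reference sketch. The version below is illustrative only: it assumes both maps were already key-sorted by `InsertMapSortInGroupingExpressions`, it models keys as non-nullable (as they are in Spark SQL maps) while values may be null, and the names `MapCompareSketch`, `compareMaps`, and `compareElement` are made up.

```scala
// Illustrative sketch (not part of the patch): reference semantics of the
// generated compareMapData, with a map's keys and values modeled as the
// parallel sequences returned by MapData.keyArray()/valueArray().
object MapCompareSketch {

  // Null-aware comparison of a single value slot: both null compare equal,
  // a null value sorts before a non-null one.
  private def compareElement[T](a: Option[T], b: Option[T])(cmp: (T, T) => Int): Int =
    (a, b) match {
      case (None, None)       => 0
      case (None, _)          => -1
      case (_, None)          => 1
      case (Some(x), Some(y)) => cmp(x, y)
    }

  // Compare key i, then value i, for every shared position; if all shared
  // positions are equal, the shorter map sorts first, the same tie-break
  // the generated Java code applies.
  def compareMaps[K, V](
      keysA: IndexedSeq[K], valuesA: IndexedSeq[Option[V]],
      keysB: IndexedSeq[K], valuesB: IndexedSeq[Option[V]])(
      keyCmp: (K, K) => Int, valueCmp: (V, V) => Int): Int = {
    val minLength = math.min(keysA.length, keysB.length)
    var i = 0
    while (i < minLength) {
      val keyComp = keyCmp(keysA(i), keysB(i))
      if (keyComp != 0) return keyComp
      val valueComp = compareElement(valuesA(i), valuesB(i))(valueCmp)
      if (valueComp != 0) return valueComp
      i += 1
    }
    java.lang.Integer.compare(keysA.length, keysB.length)
  }
}
```

For example, `map(1 -> 1)` compares greater than `map(1 -> null)` because a non-null value sorts after a null one, and a map that is a strict prefix of another sorts first because ties are broken by length.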