
[SPARK-47430][SQL] Support GROUP BY for MapType #45549

Closed

Commits (49 total; the diff below reflects 46 of them):
a081649  initial working version (stefankandic, Feb 8, 2024)
1441549  add golden files (stefankandic, Feb 11, 2024)
1be06e3  add map sort to other languages (stefankandic, Feb 14, 2024)
249e903  fix typoes (stefankandic, Feb 28, 2024)
aaae883  fix scalastyle issue (stefankandic, Feb 28, 2024)
acaf95e  add proto golden files (stefankandic, Feb 28, 2024)
5619fdb  fix python function call (stefankandic, Feb 28, 2024)
7754c14  fix ci errors (stefankandic, Feb 29, 2024)
f0ebf5d  fix ci checks (stefankandic, Feb 29, 2024)
1f78167  Optimized map-sort by switching to array sorting (stevomitric, Mar 12, 2024)
a5eb480  Potential tests fix (stevomitric, Mar 13, 2024)
9497f99  Potential tests fix 2 (stevomitric, Mar 13, 2024)
5e38220  Allowed group by expression with Maps (stevomitric, Mar 14, 2024)
03a752d  replaced map data type with arrays in test (stevomitric, Mar 14, 2024)
b80afed  Added codegen for map ordering (stevomitric, Mar 17, 2024)
5e7a033  Removed TODOs and changed parmIndex to ordinal (stevomitric, Mar 17, 2024)
ab70f1e  Shortened map sort function and added more docs (stevomitric, Mar 18, 2024)
e79d65c  updated map_sort test suite (stevomitric, Mar 18, 2024)
28d6f70  Added map normalization and import cleanup (stevomitric, Mar 18, 2024)
a435355  Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunction… (stevomitric, Mar 18, 2024)
c9901d0  Update sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunction… (stevomitric, Mar 18, 2024)
da6a710  docs fix (stevomitric, Mar 18, 2024)
81008c2  Updated codegen and removed once test-case (stevomitric, Mar 19, 2024)
86b29c5  Update python/pyspark/sql/functions/builtin.py (stevomitric, Mar 19, 2024)
c08ab6c  Updated 'select.show' to give more info in map_sort desc (stevomitric, Mar 19, 2024)
31a797c  Restructured docs, removed unused variable and refactored code (stevomitric, Mar 19, 2024)
69e3b48  Removed map_sort function but left the MapSort expression (stevomitric, Mar 21, 2024)
51ab204  Merge branch 'master' into stevomitric/map-expr (stevomitric, Mar 21, 2024)
8d9ac51  aditional erasions (stevomitric, Mar 21, 2024)
2951bcc  removed ExpressionDescription (stevomitric, Mar 21, 2024)
0fc3c6a  Moved ordering outside of comapre function (stevomitric, Mar 21, 2024)
0c7d21a  Removed oredering type (stevomitric, Mar 21, 2024)
8af9c42  Merge branch 'stevomitric/map-expr' into stevomitric/map-group-by (stevomitric, Mar 22, 2024)
671dabe  Removed second parameter from MapSort expression invocation (stevomitric, Mar 22, 2024)
7cc928a  Fix scalastyles issues (stevomitric, Mar 22, 2024)
2e27026  Merge remote-tracking branch 'upstream/master' into stevomitric/map-g… (stevomitric, Mar 22, 2024)
d84b2b5  Updated normalised functionality (stevomitric, Mar 22, 2024)
9f1035d  Refactored aggregation rule in optimizer (stevomitric, Mar 22, 2024)
d184d48  Removed GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE error and fixed a test (stevomitric, Mar 25, 2024)
04d68cc  Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expr… (stevomitric, Mar 25, 2024)
ebb3325  added scala stripMargin identation control (stevomitric, Mar 25, 2024)
185f7f1  Replaced comparison in array with genCompElementsAt (stevomitric, Mar 25, 2024)
3137f6a  Refactor optimizer rule for InsertMapSortInGroupingExpressions (stevomitric, Mar 25, 2024)
14fdcd2  Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expr… (stevomitric, Mar 26, 2024)
7fe7b7e  Refactored code-gen and separated optimizer rule in separate batch (stevomitric, Mar 26, 2024)
8076045  refactored tests (stevomitric, Mar 26, 2024)
3c573e0  Regenerated sql-error-conditions.md (stevomitric, Mar 26, 2024)
c6050c0  Removed a test that checks for Map as an invalid grouping type (stevomitric, Mar 26, 2024)
3eac76c  Fixed map-group-by test (stevomitric, Mar 27, 2024)
common/utils/src/main/resources/error/error-classes.json (0 additions, 6 deletions)
@@ -1373,12 +1373,6 @@
],
"sqlState" : "42805"
},
"GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE" : {
"message" : [
"The expression <sqlExpr> cannot be used as a grouping expression because its data type <dataType> is not an orderable data type."
],
"sqlState" : "42822"
},
"HLL_INVALID_INPUT_SKETCH_BUFFER" : {
"message" : [
"Invalid call to <function>; only valid HLL sketch buffers are supported as inputs (such as those produced by the `hll_sketch_agg` function)."
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -193,15 +193,6 @@ object ExprUtils extends QueryErrorsBase {
messageParameters = Map("sqlExpr" -> expr.sql))
}

// Check if the data type of expr is orderable.
if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
expr.failAnalysis(
errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
Review comment (Contributor): we should also remove the error from error-classes.json
Reply (Contributor Author): + modified a test inside AnalysisErrorSuite.scala that uses it
messageParameters = Map(
"sqlExpr" -> toSQLExpr(expr),
"dataType" -> toSQLType(expr.dataType)))
}

if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
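With this analysis-time check gone, grouping by a map-typed column is no longer rejected up front. A hedged sketch of the user-visible change (assumes a Spark build containing this patch; spark is a SparkSession and ts a temp view with a MAP column m, both names illustrative):

    // Before this patch, analysis failed with GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE,
    // the error class deleted above. With the check removed, the query resolves,
    // and the optimizer rule below (InsertMapSortInGroupingExpressions) makes the
    // grouping deterministic.
    spark.sql("SELECT m, count(*) FROM ts GROUP BY m").show()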
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -660,13 +660,8 @@ class CodegenContext extends Logging {
case NullType => "0"
case array: ArrayType =>
val elementType = array.elementType
val elementA = freshName("elementA")
val isNullA = freshName("isNullA")
val elementB = freshName("elementB")
val isNullB = freshName("isNullB")
val compareFunc = freshName("compareArray")
val minLength = freshName("minLength")
val jt = javaType(elementType)
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
@@ -679,22 +674,7 @@
int lengthB = b.numElements();
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < $minLength; i++) {
boolean $isNullA = a.isNullAt(i);
boolean $isNullB = b.isNullAt(i);
if ($isNullA && $isNullB) {
// Nothing
} else if ($isNullA) {
return -1;
} else if ($isNullB) {
return 1;
} else {
$jt $elementA = ${getValue("a", elementType, "i")};
$jt $elementB = ${getValue("b", elementType, "i")};
int comp = ${genComp(elementType, elementA, elementB)};
if (comp != 0) {
return comp;
}
}
${genCompElementsAt("a", "b", "i", elementType)}
}

if (lengthA < lengthB) {
@@ -722,12 +702,71 @@
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case map: MapType =>
val compareFunc = freshName("compareMapData")
val funcCode = genCompMapData(map.keyType, map.valueType, compareFunc)
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType)
}

private def genCompMapData(
keyType: DataType,
valueType: DataType,
compareFunc: String): String = {
s"""
|public int $compareFunc(MapData a, MapData b) {
| int lengthA = a.numElements();
| int lengthB = b.numElements();
| ArrayData keyArrayA = a.keyArray();
| ArrayData valueArrayA = a.valueArray();
| ArrayData keyArrayB = b.keyArray();
| ArrayData valueArrayB = b.valueArray();
| int minLength = (lengthA > lengthB) ? lengthB : lengthA;
| for (int i = 0; i < minLength; i++) {
| ${genCompElementsAt("keyArrayA", "keyArrayB", "i", keyType)}
| ${genCompElementsAt("valueArrayA", "valueArrayB", "i", valueType)}
| }
|
| if (lengthA < lengthB) {
| return -1;
| } else if (lengthA > lengthB) {
| return 1;
| }
| return 0;
|}
""".stripMargin
}

private def genCompElementsAt(arrayA: String, arrayB: String, i: String,
elementType : DataType): String = {
val elementA = freshName("elementA")
val isNullA = freshName("isNullA")
val elementB = freshName("elementB")
val isNullB = freshName("isNullB")
val jt = javaType(elementType);
s"""
|boolean $isNullA = $arrayA.isNullAt($i);
|boolean $isNullB = $arrayB.isNullAt($i);
|if ($isNullA && $isNullB) {
| // Nothing
|} else if ($isNullA) {
| return -1;
|} else if ($isNullB) {
| return 1;
|} else {
| $jt $elementA = ${getValue(arrayA, elementType, i)};
| $jt $elementB = ${getValue(arrayB, elementType, i)};
| int comp = ${genComp(elementType, elementA, elementB)};
| if (comp != 0) {
| return comp;
| }
|}
""".stripMargin
}
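Taken together, genCompMapData and genCompElementsAt emit a Java comparator that walks the key and value arrays in lockstep (the maps have already been key-sorted by MapSort) and falls back to comparing lengths. A minimal Scala sketch of that ordering, with null handling elided; compareMaps is an illustrative helper, not part of the patch:

    // Compare two maps represented as key-sorted entry sequences: at each index
    // compare the keys, then the values; if one map is a strict prefix of the
    // other, the shorter map sorts first.
    def compareMaps[K, V](a: Seq[(K, V)], b: Seq[(K, V)])
        (implicit kOrd: Ordering[K], vOrd: Ordering[V]): Int = {
      val minLength = math.min(a.length, b.length)
      var i = 0
      while (i < minLength) {
        val keyComp = kOrd.compare(a(i)._1, b(i)._1)
        if (keyComp != 0) return keyComp
        val valueComp = vOrd.compare(a(i)._2, b(i)._2)
        if (valueComp != 0) return valueComp
        i += 1
      }
      java.lang.Integer.compare(a.length, b.length)
    }

    // compareMaps(Seq("a" -> 1), Seq("a" -> 1, "b" -> 2)) < 0  // prefix sorts first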

/**
* Generates code for greater of two expressions.
*
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InsertMapSortInGroupingExpressions.scala (new file)
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.MapSort
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.types.MapType

/**
* Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
* in the correct order before grouping:
* SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case a @ Aggregate(groupingExpr, _, _) =>
val newGrouping = groupingExpr.map { expr =>
if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) {
MapSort(expr)
} else {
expr
}
}
a.copy(groupingExpressions = newGrouping)
}
}
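With this rule in place, map grouping works end to end. A hedged local-mode example (not taken from the PR itself; assumes a build containing this patch):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Two maps with the same entries in different insertion order: once the
    // optimizer wraps the grouping key in MapSort, both rows land in one group.
    val df = Seq(Map("b" -> 2, "a" -> 1), Map("a" -> 1, "b" -> 2)).toDF("m")
    df.groupBy($"m").count().show()  // expected: a single group with count 2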
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformValues, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -98,9 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case FloatType | DoubleType => true
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
case ArrayType(et, _) => needNormalize(et)
// Currently MapType is not comparable and analyzer should fail earlier if this case happens.
case _: MapType =>
throw SparkException.internalError("grouping/join/window partition keys cannot be map type.")
case MapType(_, vt, _) => needNormalize(vt)
case _ => false
}

@@ -144,6 +142,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
val function = normalize(lv)
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))

case _ if expr.dataType.isInstanceOf[MapType] =>
val MapType(kt, vt, containsNull) = expr.dataType
val keys = NamedLambdaVariable("arg", kt, containsNull)
val values = NamedLambdaVariable("arg", vt, containsNull)
val function = normalize(values)
Review comment (Contributor): why don't we normalize map keys?
Reply (stevomitric, Contributor Author, Mar 25, 2024): Do we need to normalize keys? Would a conflict arise for a map that looks like this: Map(0.0 -> 1, -0.0 -> 2) as we don't support duplicate keys and -0.0 would be reduced to 0.0. Aren't keys normalized on map creation?
Reply (Contributor): You are right, this should be handled when creating the map. We should fix ArrayBasedMapBuilder to handle floating points well.
Reply (stevomitric, Contributor Author): Created a separate PR for map normalization.
KnownFloatingPointNormalized(TransformValues(expr,
LambdaFunction(function, Seq(keys, values))))

case _ => throw SparkException.internalError(s"fail to normalize $expr")
}
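The new MapType branch mirrors the ArrayType branch above it: values are normalized through TransformValues, while keys are deliberately left alone (per the review thread, duplicate-key cases such as Map(0.0 -> 1, -0.0 -> 2) are resolved at map creation, in ArrayBasedMapBuilder). A small plain-Scala check of why value normalization matters; this snippet is illustrative, not from the patch:

    // 0.0 and -0.0 compare equal but carry different bit patterns, so an
    // unnormalized -0.0 map value would make two otherwise-equal maps hash
    // and binary-compare differently as grouping keys.
    assert(0.0 == -0.0)
    assert(java.lang.Double.doubleToRawLongBits(0.0) !=
      java.lang.Double.doubleToRawLongBits(-0.0))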

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -243,6 +243,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveRedundantAliases,
RemoveNoopOperators) :+
Batch("InsertMapSortInGroupingExpressions", Once,
InsertMapSortInGroupingExpressions) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{MapType, StructType}
import org.apache.spark.sql.types.StructType


abstract class LogicalPlan
@@ -348,23 +348,6 @@ object LogicalPlanIntegrity {
}.flatten
}

/**
* Validate that the grouping key types in Aggregate plans are valid.
* Returns an error message if the check fails, or None if it succeeds.
*/
def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
plan.collectFirst {
case a @ Aggregate(groupingExprs, _, _) =>
val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
if (badExprs.nonEmpty) {
Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
s"for plan:\n ${a.treeString}")
} else {
None
}
}.flatten
}

/**
* Validate that the aggregation expressions in Aggregate plans are valid.
* Returns an error message if the check fails, or None if it succeeds.
@@ -417,7 +400,6 @@
.orElse(LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan))
.orElse(LogicalPlanIntegrity.validateSchemaOutput(previousPlan, currentPlan))
.orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan))
.orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan))
.orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan))
.map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" +
s"\nPrevious plan: ${previousPlan.treeString}")
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -126,6 +126,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
"org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
"org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -17,8 +17,6 @@

package org.apache.spark.sql.catalyst.analysis

import org.scalatest.Assertions._

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -28,7 +26,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Max}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.{AsOfJoinDirection, Cross, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
import org.apache.spark.sql.errors.DataTypeErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -59,32 +56,6 @@ private[sql] case class UngroupableData(data: Map[Int, Int]) {
def getData: Map[Int, Int] = data
}

private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {

override def sqlType: DataType = MapType(IntegerType, IntegerType)

override def serialize(ungroupableData: UngroupableData): MapData = {
val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
new ArrayBasedMapData(keyArray, valueArray)
}

override def deserialize(datum: Any): UngroupableData = {
datum match {
case data: MapData =>
val keyArray = data.keyArray().array
val valueArray = data.valueArray().array
assert(keyArray.length == valueArray.length)
val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]]
UngroupableData(mapData)
}
}

override def userClass: Class[UngroupableData] = classOf[UngroupableData]

private[spark] override def asNullable: UngroupableUDT = this
}

case class TestFunction(
children: Seq[Expression],
inputTypes: Seq[AbstractDataType])
@@ -1005,8 +976,7 @@ class AnalysisErrorSuite extends AnalysisTest with DataTypeErrorsBase {
}

test("check grouping expression data types") {
def checkDataType(
dataType: DataType, shouldSuccess: Boolean, dataTypeMsg: String = ""): Unit = {
def checkDataType(dataType: DataType): Unit = {
val plan =
Aggregate(
AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil,
@@ -1015,18 +985,7 @@
AttributeReference("a", dataType)(exprId = ExprId(2)),
AttributeReference("b", IntegerType)(exprId = ExprId(1))))

if (shouldSuccess) {
assertAnalysisSuccess(plan, true)
} else {
assertAnalysisErrorClass(
inputPlan = plan,
expectedErrorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
expectedMessageParameters = Map(
"sqlExpr" -> "\"a\"",
"dataType" -> dataTypeMsg
)
)
}
assertAnalysisSuccess(plan, true)
}

val supportedDataTypes = Seq(
@@ -1036,6 +995,10 @@
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
DateType, TimestampType,
ArrayType(IntegerType),
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", MapType(StringType, LongType), nullable = true),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", StringType, nullable = true),
@@ -1044,20 +1007,7 @@
.add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
new GroupableUDT())
supportedDataTypes.foreach { dataType =>
checkDataType(dataType, shouldSuccess = true)
}

val unsupportedDataTypes = Seq(
MapType(StringType, LongType),
new StructType()
.add("f1", FloatType, nullable = true)
.add("f2", MapType(StringType, LongType), nullable = true),
new UngroupableUDT())
val expectedDataTypeParameters =
Seq("\"MAP<STRING, BIGINT>\"", "\"STRUCT<f1: FLOAT, f2: MAP<STRING, BIGINT>>\"")
unsupportedDataTypes.zip(expectedDataTypeParameters).foreach {
case (dataType, dataTypeMsg) =>
checkDataType(dataType, shouldSuccess = false, dataTypeMsg)
checkDataType(dataType)
}
}
