[SPARK-47430][SQL] Support GROUP BY for MapType
### What changes were proposed in this pull request?
Changes proposed in this PR include:
- Relaxed the checks that prevented aggregation of map types
- Added a new optimizer rule that applies the `MapSort` expression proposed in [this PR](#45639)
- Added codegen that compares two sorted maps

### Why are the changes needed?
Adds the ability to use GROUP BY on map types.

### Does this PR introduce _any_ user-facing change?
Yes, the ability to use GROUP BY on columns of `MapType`.
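
For illustration, a minimal sketch of what this enables (assumes a Spark build that includes this patch; the data and session setup are illustrative, not from this PR):

```scala
import org.apache.spark.sql.SparkSession

// Sketch: grouping by a map column, which previously failed analysis with
// GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE, now works.
val spark = SparkSession.builder().master("local[1]").getOrCreate()
val df = spark.sql(
  """SELECT COUNT(*) AS cnt
    |FROM VALUES (map(1, 'a')), (map(1, 'a')), (map(2, 'b')) AS t(m)
    |GROUP BY m""".stripMargin)
// Expect two rows: cnt = 2 for {1 -> "a"} and cnt = 1 for {2 -> "b"}.
df.show()
spark.stop()
```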

### How was this patch tested?
With new unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45549 from stevomitric/stevomitric/map-group-by.

Lead-authored-by: Stevo Mitric <stevo.mitric@databricks.com>
Co-authored-by: Stefan Kandic <stefan.kandic@databricks.com>
Co-authored-by: Stevo Mitric <stevomitric2000@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
3 people authored and cloud-fan committed Mar 27, 2024
1 parent b540cc5 commit d57164a
Showing 12 changed files with 161 additions and 157 deletions.
6 changes: 0 additions & 6 deletions common/utils/src/main/resources/error/error-classes.json
@@ -1373,12 +1373,6 @@
],
"sqlState" : "42805"
},
"GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE" : {
"message" : [
"The expression <sqlExpr> cannot be used as a grouping expression because its data type <dataType> is not an orderable data type."
],
"sqlState" : "42822"
},
"HLL_INVALID_INPUT_SKETCH_BUFFER" : {
"message" : [
"Invalid call to <function>; only valid HLL sketch buffers are supported as inputs (such as those produced by the `hll_sketch_agg` function)."
6 changes: 0 additions & 6 deletions docs/sql-error-conditions.md
@@ -852,12 +852,6 @@ GROUP BY `<index>` refers to an expression `<aggExpr>` that contains an aggregat

GROUP BY position `<index>` is not in select list (valid range is [1, `<size>`]).

### GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE

[SQLSTATE: 42822](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)

The expression `<sqlExpr>` cannot be used as a grouping expression because its data type `<dataType>` is not an orderable data type.

### HLL_INVALID_INPUT_SKETCH_BUFFER

[SQLSTATE: 22546](sql-error-conditions-sqlstates.html#class-22-data-exception)
@@ -193,15 +193,6 @@ object ExprUtils extends QueryErrorsBase {
messageParameters = Map("sqlExpr" -> expr.sql))
}

// Check if the data type of expr is orderable.
if (expr.dataType.existsRecursively(_.isInstanceOf[MapType])) {
expr.failAnalysis(
errorClass = "GROUP_EXPRESSION_TYPE_IS_NOT_ORDERABLE",
messageParameters = Map(
"sqlExpr" -> toSQLExpr(expr),
"dataType" -> toSQLType(expr.dataType)))
}

if (!expr.deterministic) {
// This is just a sanity check, our analysis rule PullOutNondeterministic should
// already pull out those nondeterministic expressions and evaluate them in
@@ -660,13 +660,8 @@ class CodegenContext extends Logging {
case NullType => "0"
case array: ArrayType =>
val elementType = array.elementType
val elementA = freshName("elementA")
val isNullA = freshName("isNullA")
val elementB = freshName("elementB")
val isNullB = freshName("isNullB")
val compareFunc = freshName("compareArray")
val minLength = freshName("minLength")
val jt = javaType(elementType)
val funcCode: String =
s"""
public int $compareFunc(ArrayData a, ArrayData b) {
@@ -679,22 +674,7 @@
int lengthB = b.numElements();
int $minLength = (lengthA > lengthB) ? lengthB : lengthA;
for (int i = 0; i < $minLength; i++) {
boolean $isNullA = a.isNullAt(i);
boolean $isNullB = b.isNullAt(i);
if ($isNullA && $isNullB) {
// Nothing
} else if ($isNullA) {
return -1;
} else if ($isNullB) {
return 1;
} else {
$jt $elementA = ${getValue("a", elementType, "i")};
$jt $elementB = ${getValue("b", elementType, "i")};
int comp = ${genComp(elementType, elementA, elementB)};
if (comp != 0) {
return comp;
}
}
${genCompElementsAt("a", "b", "i", elementType)}
}

if (lengthA < lengthB) {
@@ -722,12 +702,71 @@
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case map: MapType =>
val compareFunc = freshName("compareMapData")
val funcCode = genCompMapData(map.keyType, map.valueType, compareFunc)
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)"
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2)
case _ =>
throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType)
}

private def genCompMapData(
keyType: DataType,
valueType: DataType,
compareFunc: String): String = {
s"""
|public int $compareFunc(MapData a, MapData b) {
| int lengthA = a.numElements();
| int lengthB = b.numElements();
| ArrayData keyArrayA = a.keyArray();
| ArrayData valueArrayA = a.valueArray();
| ArrayData keyArrayB = b.keyArray();
| ArrayData valueArrayB = b.valueArray();
| int minLength = (lengthA > lengthB) ? lengthB : lengthA;
| for (int i = 0; i < minLength; i++) {
| ${genCompElementsAt("keyArrayA", "keyArrayB", "i", keyType)}
| ${genCompElementsAt("valueArrayA", "valueArrayB", "i", valueType)}
| }
|
| if (lengthA < lengthB) {
| return -1;
| } else if (lengthA > lengthB) {
| return 1;
| }
| return 0;
|}
""".stripMargin
}

private def genCompElementsAt(arrayA: String, arrayB: String, i: String,
elementType: DataType): String = {
val elementA = freshName("elementA")
val isNullA = freshName("isNullA")
val elementB = freshName("elementB")
val isNullB = freshName("isNullB")
val jt = javaType(elementType)
s"""
|boolean $isNullA = $arrayA.isNullAt($i);
|boolean $isNullB = $arrayB.isNullAt($i);
|if ($isNullA && $isNullB) {
| // Nothing
|} else if ($isNullA) {
| return -1;
|} else if ($isNullB) {
| return 1;
|} else {
| $jt $elementA = ${getValue(arrayA, elementType, i)};
| $jt $elementB = ${getValue(arrayB, elementType, i)};
| int comp = ${genComp(elementType, elementA, elementB)};
| if (comp != 0) {
| return comp;
| }
|}
""".stripMargin
}

/**
* Generates code for greater of two expressions.
*
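For reference, the ordering the generated Java above implements can be modeled in plain Scala roughly as follows. This is a simplified sketch, not code from this PR; it assumes both maps were already put into key order by `MapSort`, and it models a nullable value as an `Option`:

```scala
// Sketch of the generated comparator's semantics: compare key/value pairs of
// two sorted maps in lockstep (keys before values, nulls ordering first),
// then fall back to comparing lengths.
def compareSortedMaps[K, V](a: Seq[(K, Option[V])], b: Seq[(K, Option[V])])(
    implicit keyOrd: Ordering[K], valOrd: Ordering[V]): Int = {
  val minLength = math.min(a.length, b.length)
  var i = 0
  while (i < minLength) {
    val (keyA, valueA) = a(i)
    val (keyB, valueB) = b(i)
    val keyComp = keyOrd.compare(keyA, keyB)
    if (keyComp != 0) return keyComp
    val valueComp = (valueA, valueB) match {
      case (None, None)       => 0
      case (None, _)          => -1  // a null value sorts first
      case (_, None)          => 1
      case (Some(x), Some(y)) => valOrd.compare(x, y)
    }
    if (valueComp != 0) return valueComp
    i += 1
  }
  // All shared pairs are equal: the shorter map sorts first.
  a.length.compare(b.length)
}
```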
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.MapSort
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE
import org.apache.spark.sql.types.MapType

/**
* Adds MapSort to group expressions containing map columns, as the key/value pairs need to be
* in the correct order before grouping:
* SELECT COUNT(*) FROM TABLE GROUP BY map_column =>
* SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column)
*/
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
_.containsPattern(AGGREGATE), ruleId) {
case a @ Aggregate(groupingExpr, _, _) =>
val newGrouping = groupingExpr.map { expr =>
if (!expr.isInstanceOf[MapSort] && expr.dataType.isInstanceOf[MapType]) {
MapSort(expr)
} else {
expr
}
}
a.copy(groupingExpressions = newGrouping)
}
}
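Note the `!expr.isInstanceOf[MapSort]` guard, which avoids double-wrapping an expression that is already sorted. A quick way to observe the rewrite (a sketch assuming a build with this patch; the data is illustrative):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
val df = spark.sql(
  "SELECT COUNT(*) FROM VALUES (map(1, 'a')) AS t(m) GROUP BY m")
// The Aggregate in the optimized plan should group on map_sort(m), not m.
println(df.queryExecution.optimizedPlan)
spark.stop()
```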
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformValues, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -98,9 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case FloatType | DoubleType => true
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
case ArrayType(et, _) => needNormalize(et)
// Currently MapType is not comparable and analyzer should fail earlier if this case happens.
case _: MapType =>
throw SparkException.internalError("grouping/join/window partition keys cannot be map type.")
case MapType(_, vt, _) => needNormalize(vt)
case _ => false
}

@@ -144,6 +142,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
val function = normalize(lv)
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))

case _ if expr.dataType.isInstanceOf[MapType] =>
val MapType(kt, vt, containsNull) = expr.dataType
val keys = NamedLambdaVariable("arg", kt, containsNull)
val values = NamedLambdaVariable("arg", vt, containsNull)
val function = normalize(values)
KnownFloatingPointNormalized(TransformValues(expr,
LambdaFunction(function, Seq(keys, values))))

case _ => throw SparkException.internalError(s"fail to normalize $expr")
}

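For context on why `needNormalize` now recurses into map value types: equal floating-point values can have distinct bit patterns, and grouping keys are compared in binary form, so a map value like `-0.0` must be rewritten to `0.0` (via the `TransformValues` case above) before grouping. A small standalone illustration, not from this PR:

```scala
// 0.0 and -0.0 are numerically equal but differ in their bit patterns, so an
// un-normalized -0.0 inside a map value would break binary key comparison.
val bitsZero    = java.lang.Double.doubleToRawLongBits(0.0)
val bitsNegZero = java.lang.Double.doubleToRawLongBits(-0.0)
assert(0.0 == -0.0)              // numerically equal
assert(bitsZero != bitsNegZero)  // binary representations differ
```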
@@ -243,6 +243,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
CollapseProject,
RemoveRedundantAliases,
RemoveNoopOperators) :+
Batch("InsertMapSortInGroupingExpressions", Once,
InsertMapSortInGroupingExpressions) :+
// This batch must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, U
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.{MapType, StructType}
import org.apache.spark.sql.types.StructType


abstract class LogicalPlan
@@ -348,23 +348,6 @@ object LogicalPlanIntegrity {
}.flatten
}

/**
* Validate that the grouping key types in Aggregate plans are valid.
* Returns an error message if the check fails, or None if it succeeds.
*/
def validateGroupByTypes(plan: LogicalPlan): Option[String] = {
plan.collectFirst {
case a @ Aggregate(groupingExprs, _, _) =>
val badExprs = groupingExprs.filter(_.dataType.isInstanceOf[MapType]).map(_.toString)
if (badExprs.nonEmpty) {
Some(s"Grouping expressions ${badExprs.mkString(", ")} cannot be of type Map " +
s"for plan:\n ${a.treeString}")
} else {
None
}
}.flatten
}

/**
* Validate that the aggregation expressions in Aggregate plans are valid.
* Returns an error message if the check fails, or None if it succeeds.
@@ -417,7 +400,6 @@ object LogicalPlanIntegrity {
.orElse(LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan))
.orElse(LogicalPlanIntegrity.validateSchemaOutput(previousPlan, currentPlan))
.orElse(LogicalPlanIntegrity.validateNoDanglingReferences(currentPlan))
.orElse(LogicalPlanIntegrity.validateGroupByTypes(currentPlan))
.orElse(LogicalPlanIntegrity.validateAggregateExpressions(currentPlan))
.map(err => s"${err}\nPrevious schema:${previousPlan.output.mkString(", ")}" +
s"\nPrevious plan: ${previousPlan.treeString}")
@@ -126,6 +126,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
"org.apache.spark.sql.catalyst.optimizer.EliminateWindowPartitions" ::
"org.apache.spark.sql.catalyst.optimizer.InferWindowGroupLimit" ::
"org.apache.spark.sql.catalyst.optimizer.InsertMapSortInGroupingExpressions" ::
"org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
"org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::