-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-47430][SQL] Support GROUP BY for MapType #45549
Changes from 43 commits
a081649
1441549
1be06e3
249e903
aaae883
acaf95e
5619fdb
7754c14
f0ebf5d
1f78167
a5eb480
9497f99
5e38220
03a752d
b80afed
5e7a033
ab70f1e
e79d65c
28d6f70
a435355
c9901d0
da6a710
81008c2
86b29c5
c08ab6c
31a797c
69e3b48
51ab204
8d9ac51
2951bcc
0fc3c6a
0c7d21a
8af9c42
671dabe
7cc928a
2e27026
d84b2b5
9f1035d
d184d48
04d68cc
ebb3325
185f7f1
3137f6a
14fdcd2
7fe7b7e
8076045
3c573e0
c6050c0
3eac76c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -660,13 +660,8 @@ class CodegenContext extends Logging { | |
case NullType => "0" | ||
case array: ArrayType => | ||
val elementType = array.elementType | ||
val elementA = freshName("elementA") | ||
val isNullA = freshName("isNullA") | ||
val elementB = freshName("elementB") | ||
val isNullB = freshName("isNullB") | ||
val compareFunc = freshName("compareArray") | ||
val minLength = freshName("minLength") | ||
val jt = javaType(elementType) | ||
val funcCode: String = | ||
s""" | ||
public int $compareFunc(ArrayData a, ArrayData b) { | ||
|
@@ -679,22 +674,7 @@ class CodegenContext extends Logging { | |
int lengthB = b.numElements(); | ||
int $minLength = (lengthA > lengthB) ? lengthB : lengthA; | ||
for (int i = 0; i < $minLength; i++) { | ||
boolean $isNullA = a.isNullAt(i); | ||
boolean $isNullB = b.isNullAt(i); | ||
if ($isNullA && $isNullB) { | ||
// Nothing | ||
} else if ($isNullA) { | ||
return -1; | ||
} else if ($isNullB) { | ||
return 1; | ||
} else { | ||
$jt $elementA = ${getValue("a", elementType, "i")}; | ||
$jt $elementB = ${getValue("b", elementType, "i")}; | ||
int comp = ${genComp(elementType, elementA, elementB)}; | ||
if (comp != 0) { | ||
return comp; | ||
} | ||
} | ||
${genCompElementsAt("a", "b", "i", elementType)} | ||
} | ||
|
||
if (lengthA < lengthB) { | ||
|
@@ -722,12 +702,76 @@ class CodegenContext extends Logging { | |
} | ||
""" | ||
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" | ||
case map: MapType => | ||
val compareFunc = freshName("compareMapData") | ||
val funcCode = genCompMapData(map.keyType, map.valueType, compareFunc) | ||
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" | ||
case other if other.isInstanceOf[AtomicType] => s"$c1.compare($c2)" | ||
case udt: UserDefinedType[_] => genComp(udt.sqlType, c1, c2) | ||
case _ => | ||
throw QueryExecutionErrors.cannotGenerateCodeForIncomparableTypeError("compare", dataType) | ||
} | ||
|
||
/**
 * Builds the Java source of a function `compareFunc(MapData a, MapData b)` that returns an int
 * ordering of two maps.
 *
 * The generated function walks the first `min(lengthA, lengthB)` entries by index. At each
 * index it first compares the keys and then the values, using the snippet produced by
 * `genCompElementsAt`, and returns as soon as one of those comparisons is non-zero. When the
 * common prefix is equal, the map with fewer entries orders first (-1 / 1), and 0 is returned
 * when the lengths are equal as well.
 *
 * NOTE(review): the comparison is purely positional (entry i of `a` against entry i of `b`),
 * so it presumably relies on both maps having their entries in a canonical order beforehand
 * -- TODO confirm against callers.
 *
 * @param keyType data type of the map keys
 * @param valueType data type of the map values
 * @param compareFunc name to give the generated Java function
 * @return the Java source code of the generated comparison function
 */
private def genCompMapData( | ||
keyType: DataType, | ||
valueType: DataType, | ||
compareFunc : String): String = { | ||
stevomitric marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// Fresh identifiers for the Java local variables declared inside the generated function.
val keyArrayA = freshName("keyArrayA") | ||
val keyArrayB = freshName("keyArrayB") | ||
val valueArrayA = freshName("valueArrayA") | ||
val valueArrayB = freshName("valueArrayB") | ||
val minLength = freshName("minLength") | ||
s""" | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|public int $compareFunc(MapData a, MapData b) { | ||
| int lengthA = a.numElements(); | ||
| int lengthB = b.numElements(); | ||
| ArrayData $keyArrayA = a.keyArray(); | ||
| ArrayData $valueArrayA = a.valueArray(); | ||
| ArrayData $keyArrayB = b.keyArray(); | ||
| ArrayData $valueArrayB = b.valueArray(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do the above 4 variables need to use |
||
| int $minLength = (lengthA > lengthB) ? lengthB : lengthA; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto |
||
| for (int i = 0; i < $minLength; i++) { | ||
| ${genCompElementsAt(keyArrayA, keyArrayB, "i", keyType)} | ||
| ${genCompElementsAt(valueArrayA, valueArrayB, "i", valueType)} | ||
| } | ||
| | ||
| if (lengthA < lengthB) { | ||
| return -1; | ||
| } else if (lengthA > lengthB) { | ||
| return 1; | ||
| } | ||
| return 0; | ||
|} | ||
""".stripMargin | ||
} |
|
||
/**
 * Generates a Java snippet that compares the elements at index `i` of two `ArrayData`
 * variables.
 *
 * Null handling in the emitted code: when both elements are null nothing happens and execution
 * falls through; when only the element of `arrayA` is null the snippet returns -1; when only
 * the element of `arrayB` is null it returns 1. Otherwise both values are read with `getValue`,
 * compared with `genComp`, and the result is returned if it is non-zero.
 *
 * The snippet contains `return` statements, so it has to be embedded in the body of a generated
 * function that returns an int.
 *
 * @param arrayA name of the Java variable holding the first array
 * @param arrayB name of the Java variable holding the second array
 * @param i Java expression used as the element index in both arrays
 * @param elementType data type of the elements being compared
 * @return the Java source code of the comparison snippet
 */
private def genCompElementsAt(arrayA: String, arrayB: String, i: String, | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elementType : DataType): String = { | ||
// Fresh identifiers keep the locals of this snippet distinct when it is emitted more than once.
val elementA = freshName("elementA") | ||
val isNullA = freshName("isNullA") | ||
val elementB = freshName("elementB") | ||
val isNullB = freshName("isNullB") | ||
val jt = javaType(elementType); | ||
s""" | ||
|boolean $isNullA = $arrayA.isNullAt($i); | ||
|boolean $isNullB = $arrayB.isNullAt($i); | ||
|if ($isNullA && $isNullB) { | ||
| // Nothing | ||
|} else if ($isNullA) { | ||
| return -1; | ||
|} else if ($isNullB) { | ||
| return 1; | ||
|} else { | ||
| $jt $elementA = ${getValue(arrayA, elementType, i)}; | ||
| $jt $elementB = ${getValue(arrayB, elementType, i)}; | ||
| int comp = ${genComp(elementType, elementA, elementB)}; | ||
| if (comp != 0) { | ||
| return comp; | ||
| } | ||
|} | ||
""".stripMargin | ||
} |
|
||
/** | ||
* Generates code for greater of two expressions. | ||
* | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.optimizer | ||
|
||
import org.apache.spark.sql.catalyst.expressions.MapSort | ||
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} | ||
import org.apache.spark.sql.catalyst.rules.Rule | ||
import org.apache.spark.sql.catalyst.trees.TreePattern.AGGREGATE | ||
import org.apache.spark.sql.types.MapType | ||
|
||
/** | ||
* Adds MapSort to group expressions containing map columns, as the key/value pairs need to be | ||
* in the correct order before grouping: | ||
* SELECT COUNT(*) FROM TABLE GROUP BY map_column => | ||
* SELECT COUNT(*) FROM TABLE GROUP BY map_sort(map_column) | ||
*/ | ||
object InsertMapSortInGroupingExpressions extends Rule[LogicalPlan] {
  // Traversal is pruned to plans that contain the AGGREGATE tree pattern.
  override def apply(plan: LogicalPlan): LogicalPlan =
    plan.transformWithPruning(_.containsPattern(AGGREGATE), ruleId) {
      case agg @ Aggregate(groupingExprs, _, _) =>
        // Wrap every map-typed grouping expression in MapSort. Expressions that already are a
        // MapSort, or that are not map-typed, are passed through unchanged.
        val sortedGrouping = groupingExprs.map {
          case alreadySorted: MapSort => alreadySorted
          case candidate =>
            candidate.dataType match {
              case _: MapType => MapSort(candidate)
              case _ => candidate
            }
        }
        agg.copy(groupingExpressions = sortedGrouping)
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
package org.apache.spark.sql.catalyst.optimizer | ||
|
||
import org.apache.spark.SparkException | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformValues, UnaryExpression} | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} | ||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys | ||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window} | ||
|
@@ -98,9 +98,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
case FloatType | DoubleType => true | ||
case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) | ||
case ArrayType(et, _) => needNormalize(et) | ||
// Currently MapType is not comparable and analyzer should fail earlier if this case happens. | ||
case _: MapType => | ||
throw SparkException.internalError("grouping/join/window partition keys cannot be map type.") | ||
case MapType(_, vt, _) => needNormalize(vt) | ||
cloud-fan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
case _ => false | ||
} | ||
|
||
|
@@ -144,6 +142,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
val function = normalize(lv) | ||
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv)))) | ||
|
||
case _ if expr.dataType.isInstanceOf[MapType] => | ||
val MapType(kt, vt, containsNull) = expr.dataType | ||
val keys = NamedLambdaVariable("arg", kt, containsNull) | ||
val values = NamedLambdaVariable("arg", vt, containsNull) | ||
val function = normalize(values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why don't we normalize map keys? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to normalize keys? Would a conflict arise for a map that looks like this: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right, this should be handled when creating the map. We should fix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created a separate PR for map normalization. |
||
KnownFloatingPointNormalized(TransformValues(expr, | ||
LambdaFunction(function, Seq(keys, values)))) | ||
|
||
case _ => throw SparkException.internalError(s"fail to normalize $expr") | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -244,7 +244,9 @@ abstract class Optimizer(catalogManager: CatalogManager) | |
RemoveRedundantAliases, | ||
RemoveNoopOperators) :+ | ||
// This batch must be executed after the `RewriteSubquery` batch, which creates joins. | ||
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ | ||
Batch("NormalizeFloatingNumbers", Once, | ||
InsertMapSortInGroupingExpressions, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we create a new batch for this rule? |
||
NormalizeFloatingNumbers) :+ | ||
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) | ||
|
||
// remove any batches with no rules. this may happen when subclasses do not add optional rules. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should also remove the error from
error-classes.json
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+ modified a test inside
AnalysisErrorSuite.scala
that uses it