-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-40382][SQL] Group distinct aggregate expressions by semantically equivalent children in RewriteDistinctAggregates
#37825
Conversation
#34953 take a look this pr? |
7943f07
to
84adc8b
Compare
@@ -1451,6 +1451,22 @@ class DataFrameAggregateSuite extends QueryTest | |||
val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id") | |||
checkAnswer(df, Row(2, 3, 1)) | |||
} | |||
|
|||
test("SPARK-40382: All distinct aggregation children are semantically equivalent") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test succeeds without the changes to RewriteDistinctAggregates. It's just a sanity test to check that the grouping by semantic equivalence doesn't break this case.
If I understand that PR correctly (and I may not), it is related but orthogonal. One PR doesn't preclude the other. The PR you reference is making the |
84adc8b
to
cd5693f
Compare
// This is an n^2 operation, where n is the number of distinct aggregate children, but it | ||
// happens only once every time this rule is called. | ||
val funcChildrenLookup = funcChildren.map { e => | ||
(e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems the code lead to find out itself.
...alyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
Outdated
Show resolved
Hide resolved
// This could potentially reduce the number of distinct | ||
// aggregate groups, and therefore reduce the number of | ||
// projections in Expand (or eliminate the need for Expand) | ||
val a = reduceDistinctAggregateGroups(aOrig) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we just canonicalize the function inputs when group by them? e.g. change e.aggregateFunction.children.filter(!_.foldable).toSet
to ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I am working on it, just working through some small complications.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made the change to use ExpressionSet
and also commented on some of the issues.
I still prefer 'sanitizing' each original function child to use the first semantically equivalent child, in essence creating a new set of "original" children, as it bypasses some complexities (in particular the one where we may lose some of the original children as keys when we group by ExpressionSet
).
cd5693f
to
278d060
Compare
@@ -291,7 +298,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
val naf = if (af.children.forall(_.foldable)) { | |||
af | |||
} else { | |||
patchAggregateFunctionChildren(af) { x => | |||
patchAggregateFunctionChildren(af) { x1 => | |||
val x = funcChildrenLookup.getOrElse(x1, x1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here's one of the complications, and my solution is somewhat brittle.
When grouping by ExpressionSet
, in the case where there are superficially different children, we don't get all of the original children in the keys of distinctAggGroups
. This is because multiple ExpressionSet
s may have the same baseSet but different originals, and groupBy
chooses only one ExpressionSet
to represent the group's key (which is what want: we want groupBy
to group by semantically equivalent children).
However, because distinctAggGroups
is missing some original children in its keys, distinctAggChildAttrLookup
is also missing some original children in its keys.
To bridge this gap, I used funcChildrenLookup
. This data structure maps each original function child to the first semantically equivalent original function child. funcChildrenLookup
will translate the original function child into the key (hopefully) expected by distinctAggChildAttrLookup
. The brittleness is this: this code depends, at the very least, on which ExpressionSet
is chosen by groupBy
as the winner.
In the first version of my PR, I modified the Aggregate (if needed) so there are no superfically different function children, thus there is no complexity when performing the groupings and the patching. I find it bit more straightforward to reason about.
// the original list of aggregate expressions had multiple distinct groups and, if so, | ||
// patch that list so we have only one distinct group. | ||
if (funcChildrenLookup.keySet.size > funcChildrenLookup.values.toSet.size) { | ||
val patchedAggExpressions = a.aggregateExpressions.map { e => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the second complexity.
It seems a little complex. |
...alyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
Show resolved
Hide resolved
Thanks, I will take a look. This is reference to the fall-through case, where we discover there is really only a single distinct group, correct? |
Yes. |
@@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
|
|||
// Setup unique distinct aggregate children. | |||
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure if this is necessary, but it's better to use ExpressionSet(distinctAggGroups.keySet.flatten).toSeq
, instead of calling .distinct
on Seq[Expression]
// We may have one distinct group only because we grouped using ExpressionSet. | ||
// To prevent SparkStrategies from complaining during sanity check, we need to check whether | ||
// the original list of aggregate expressions had multiple distinct groups and, if so, | ||
// patch that list so we have only one distinct group. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite understand this. Shall we use ExpressionSet
to fix issues in SparkStrategies
as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we use ExpressionSet to fix issues in SparkStrategies as well?
Looking...
4713ba5
to
da23c38
Compare
RewriteDistinctAggregates
RewriteDistinctAggregates
@@ -560,7 +562,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { | |||
// [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct | |||
// aggregates have different column expressions. | |||
val distinctExpressions = | |||
functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable) | |||
functionsWithDistinct.flatMap( | |||
_.aggregateFunction.children.filterNot(_.foldable)).distinct |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My implementation here has an odd effect in the case where all child sets are semantically equivalent but cosmetically different, e.g.:
explain select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
from v1
group by k;
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[k#87], functions=[sum(distinct (c1#88 + 1)#99), avg(distinct (1 + c1#88)#100), count(distinct (1 + C1#88)#101)])
+- Exchange hashpartitioning(k#87, 200), ENSURE_REQUIREMENTS, [plan_id=136]
+- HashAggregate(keys=[k#87], functions=[partial_sum(distinct (c1#88 + 1)#99), partial_avg(distinct (1 + c1#88)#100), partial_count(distinct (1 + C1#88)#101)])
+- HashAggregate(keys=[k#87, (c1#88 + 1)#99, (1 + c1#88)#100, (1 + C1#88)#101], functions=[])
+- Exchange hashpartitioning(k#87, (c1#88 + 1)#99, (1 + c1#88)#100, (1 + C1#88)#101, 200), ENSURE_REQUIREMENTS, [plan_id=132]
+- HashAggregate(keys=[k#87, (c1#88 + 1) AS (c1#88 + 1)#99, (1 + c1#88) AS (1 + c1#88)#100, (1 + C1#88) AS (1 + C1#88)#101], functions=[])
+- LocalTableScan [k#87, c1#88]
The grouping keys in the first aggregate should include the children of the distinct aggregations, and they do. But because I kept the children as cosmetically different (I no longer patch them in RewriteDistinctAggregates
when handling the fall-through case), the grouping keys now include each cosmetic variation (c1 + 1
, 1 + c1
, and 1 + C1
). If I remove one cosmetic variation, the final aggregate gets an error (because one of the aggregation expressions will refer to attributes that were not output in previous plan nodes).
My earlier implementation (where I patch the aggregate expressions in the fall-through case so there are no more cosmetic variations) doesn't have this oddity:
explain select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
from v1
group by k;
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[k#8], functions=[sum(distinct (c1#9 + 1)#20), avg(distinct (c1#9 + 1)#20), count(distinct (c1#9 + 1)#20)])
+- Exchange hashpartitioning(k#8, 200), ENSURE_REQUIREMENTS, [plan_id=30]
+- HashAggregate(keys=[k#8], functions=[partial_sum(distinct (c1#9 + 1)#20), partial_avg(distinct (c1#9 + 1)#20), partial_count(distinct (c1#9 + 1)#20)])
+- HashAggregate(keys=[k#8, (c1#9 + 1)#20], functions=[])
+- Exchange hashpartitioning(k#8, (c1#9 + 1)#20, 200), ENSURE_REQUIREMENTS, [plan_id=26]
+- HashAggregate(keys=[k#8, (c1#9 + 1) AS (c1#9 + 1)#20], functions=[])
+- LocalTableScan [k#8, c1#9]
Also my earlier implementation seems about 22% faster for the case where all child sets are semantically equivalent but cosmetically different. I assume because the rows output from the first physical aggregation are narrower (but I have not dug down too deep on this).
@@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
|
|||
// Setup unique distinct aggregate children. | |||
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct | |||
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) | |||
val distinctAggChildAttrMap = distinctAggChildren.map { e => | |||
ExpressionSet(Seq(e)) -> AttributeReference(e.sql, e.dataType, nullable = true)() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about e.canonicalized
instead of ExpressionSet(Seq(e))
?
@@ -560,7 +562,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { | |||
// [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct | |||
// aggregates have different column expressions. | |||
val distinctExpressions = | |||
functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I'm a little confused. Why do we change here? Since all children are semantically equivalent, we can just pick the first distinct function. If we need to look up the child later, we should make sure it uses ExpressionSet
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, I'm a little confused.
Ah yes, possibly I was too. I had not read all of planAggregateWithOneDistinct
yet, and I see the creation of rewrittenDistinctFunctions
, where I can possibly take advantage of semantic equivalence.
867b5b3
to
e19b16f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Can you rebase to the latest master branch to retrigger github action jobs?
Will do. I still have this as WIP because I don't understand the impact, if any, of passing attributes created from the non-rewritten distinct aggregate expressions to the xxxAggregateExec constructor as |
e19b16f
to
a524eb9
Compare
RewriteDistinctAggregates
RewriteDistinctAggregates
@cloud-fan I rebased and hopefully the tests will now run. Also I put [WIP] in the correct place in the title, which I will remove once I finished looking at the impact of using those attributes (as I mentioned above). |
@@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { | |||
|
|||
// Setup unique distinct aggregate children. | |||
val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct | |||
val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair) | |||
val distinctAggChildAttrMap = distinctAggChildren.map { e => | |||
e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems we can update expressionAttributePair
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
expressionAttributePair
is used in two other places, though, for regular aggregate children and filter expressions where the key does not need to be canonicalized.
RewriteDistinctAggregates
RewriteDistinctAggregates
…ction children to Spark strategies
a524eb9
to
f7d29df
Compare
I removed the WIP designation from the title. I don't see any negative impact of passing attributes created from the non-rewritten distinct aggregate expressions to the xxxAggregateExec constructor. |
thanks, merging to master! |
Thanks for all the help! |
…abled ### What changes were proposed in this pull request? This PR is a followup of #37825, that change the types in the test relation to make the tests pass with ANSI enalbed. ### Why are the changes needed? To recover the test coverage. Currently it fails with ANSI mode on: https://github.com/apache/spark/actions/runs/3246829492/jobs/5326051798#step:9:20487 ``` [info] - SPARK-40382: eliminate multiple distinct groups due to superficial differences *** FAILED *** (5 milliseconds) [info] org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE] Cannot resolve "(b + c)" due to data type mismatch: the binary operator requires the input type ("NUMERIC" or "INTERVAL DAY TO SECOND" or "INTERVAL YEAR TO MONTH" or "INTERVAL"), not "STRING". [info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:68) [info] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:223) [info] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:210) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:295) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1(TreeNode.scala:294) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1$adapted(TreeNode.scala:294) [info] at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) [info] at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) [info] at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:294) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1(TreeNode.scala:2 ``` ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually ran the tests locally. Closes #38250 from HyukjinKwon/SPARK-40382-followup. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
…ly equivalent children in `RewriteDistinctAggregates` ### What changes were proposed in this pull request? In `RewriteDistinctAggregates`, when grouping aggregate expressions by function children, treat children that are semantically equivalent as the same. ### Why are the changes needed? This PR will reduce the number of projections in the Expand operator when there are multiple distinct aggregations with superficially different children. In some cases, it will eliminate the need for an Expand operator. Example: In the following query, the Expand operator creates 3\*n rows (where n is the number of incoming rows) because it has a projection for each of function children `b + 1`, `1 + b` and `c`. ``` create or replace temp view v1 as select * from values (1, 2, 3.0), (1, 3, 4.0), (2, 4, 2.5), (2, 3, 1.0) v1(a, b, c); select a, count(distinct b + 1), avg(distinct 1 + b) filter (where c > 0), sum(c) from v1 group by a; ``` The Expand operator has three projections (each producing a row for each incoming row): ``` [a#87, null, null, 0, null, UnscaledValue(c#89)], <== projection #1 (for regular aggregation) [a#87, (b#88 + 1), null, 1, null, null], <== projection #2 (for distinct aggregation of b + 1) [a#87, null, (1 + b#88), 2, (c#89 > 0.0), null]], <== projection #3 (for distinct aggregation of 1 + b) ``` In reality, the Expand only needs one projection for `1 + b` and `b + 1`, because they are semantically equivalent. With the proposed change, the Expand operator's projections look like this: ``` [a#67, null, 0, null, UnscaledValue(c#69)], <== projection #1 (for regular aggregations) [a#67, (b#68 + 1), 1, (c#69 > 0.0), null]], <== projection #2 (for distinct aggregation on b + 1 and 1 + b) ``` With one less projection, Expand produces 2\*n rows instead of 3\*n rows, but still produces the correct result. In the case where all distinct aggregates have semantically equivalent children, the Expand operator is not needed at all. Benchmark code in the JIRA (SPARK-40382). Before the PR: ``` distinct aggregates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ all semantically equivalent 14721 14859 195 5.7 175.5 1.0X some semantically equivalent 14569 14572 5 5.8 173.7 1.0X none semantically equivalent 14408 14488 113 5.8 171.8 1.0X ``` After the PR: ``` distinct aggregates: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ all semantically equivalent 3658 3692 49 22.9 43.6 1.0X some semantically equivalent 9124 9214 127 9.2 108.8 0.4X none semantically equivalent 14601 14777 250 5.7 174.1 0.3X ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests. Closes apache#37825 from bersprockets/rewritedistinct_issue. Authored-by: Bruce Robbins <bersprockets@gmail.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
…abled ### What changes were proposed in this pull request? This PR is a followup of apache#37825, that change the types in the test relation to make the tests pass with ANSI enalbed. ### Why are the changes needed? To recover the test coverage. Currently it fails with ANSI mode on: https://github.com/apache/spark/actions/runs/3246829492/jobs/5326051798#step:9:20487 ``` [info] - SPARK-40382: eliminate multiple distinct groups due to superficial differences *** FAILED *** (5 milliseconds) [info] org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.BINARY_OP_WRONG_TYPE] Cannot resolve "(b + c)" due to data type mismatch: the binary operator requires the input type ("NUMERIC" or "INTERVAL DAY TO SECOND" or "INTERVAL YEAR TO MONTH" or "INTERVAL"), not "STRING". [info] at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:68) [info] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:223) [info] at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:210) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:295) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1(TreeNode.scala:294) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1$adapted(TreeNode.scala:294) [info] at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) [info] at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) [info] at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.foreachUp(TreeNode.scala:294) [info] at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$foreachUp$1(TreeNode.scala:2 ``` ### Does this PR introduce _any_ user-facing change? No, test-only. ### How was this patch tested? Manually ran the tests locally. Closes apache#38250 from HyukjinKwon/SPARK-40382-followup. Authored-by: Hyukjin Kwon <gurwls223@apache.org> Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
What changes were proposed in this pull request?
In
RewriteDistinctAggregates
, when grouping aggregate expressions by function children, treat children that are semantically equivalent as the same.Why are the changes needed?
This PR will reduce the number of projections in the Expand operator when there are multiple distinct aggregations with superficially different children. In some cases, it will eliminate the need for an Expand operator.
Example: In the following query, the Expand operator creates 3*n rows (where n is the number of incoming rows) because it has a projection for each of function children
b + 1
,1 + b
andc
.The Expand operator has three projections (each producing a row for each incoming row):
In reality, the Expand only needs one projection for
1 + b
andb + 1
, because they are semantically equivalent.With the proposed change, the Expand operator's projections look like this:
With one less projection, Expand produces 2*n rows instead of 3*n rows, but still produces the correct result.
In the case where all distinct aggregates have semantically equivalent children, the Expand operator is not needed at all.
Benchmark code in the JIRA (SPARK-40382).
Before the PR:
After the PR:
Does this PR introduce any user-facing change?
No.
How was this patch tested?
New unit tests.