From 65389802776fb9cfbf4a5855ed7b5f149155e1a3 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 13 Oct 2022 22:09:22 +0800
Subject: [PATCH] [SPARK-40382][SQL] Group distinct aggregate expressions by
 semantically equivalent children in `RewriteDistinctAggregates`

### What changes were proposed in this pull request?

In `RewriteDistinctAggregates`, when grouping aggregate expressions by function children, treat children that are semantically equivalent as the same.

### Why are the changes needed?

This PR reduces the number of projections in the Expand operator when there are multiple distinct aggregations with superficially different children. In some cases, it eliminates the need for an Expand operator entirely.

Example: In the following query, the Expand operator creates 3\*n rows (where n is the number of incoming rows) because it has a projection for each of the function children `b + 1`, `1 + b` and `c`.

```
create or replace temp view v1 as
select * from values
(1, 2, 3.0),
(1, 3, 4.0),
(2, 4, 2.5),
(2, 3, 1.0)
v1(a, b, c);

select a, count(distinct b + 1), avg(distinct 1 + b) filter (where c > 0), sum(c)
from v1
group by a;
```

The Expand operator has three projections (each producing a row for each incoming row):

```
[a#87, null, null, 0, null, UnscaledValue(c#89)],  <== projection #1 (for regular aggregation)
[a#87, (b#88 + 1), null, 1, null, null],           <== projection #2 (for distinct aggregation of b + 1)
[a#87, null, (1 + b#88), 2, (c#89 > 0.0), null]],  <== projection #3 (for distinct aggregation of 1 + b)
```

In reality, the Expand needs only one projection for `1 + b` and `b + 1`, because they are semantically equivalent. With the proposed change, the Expand operator's projections look like this:

```
[a#67, null, 0, null, UnscaledValue(c#69)],  <== projection #1 (for regular aggregations)
[a#67, (b#68 + 1), 1, (c#69 > 0.0), null]],  <== projection #2 (for distinct aggregation on b + 1 and 1 + b)
```

With one fewer projection, Expand produces 2\*n rows instead of 3\*n rows, but still produces the correct result. In the case where all distinct aggregates have semantically equivalent children, the Expand operator is not needed at all.

The benchmark code is in the JIRA ticket (SPARK-40382).

Before the PR:

```
distinct aggregates:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
all semantically equivalent                       14721          14859         195          5.7         175.5       1.0X
some semantically equivalent                      14569          14572           5          5.8         173.7       1.0X
none semantically equivalent                      14408          14488         113          5.8         171.8       1.0X
```

After the PR:

```
distinct aggregates:                      Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------------------------------
all semantically equivalent                        3658           3692          49         22.9          43.6       1.0X
some semantically equivalent                       9124           9214         127          9.2         108.8       0.4X
none semantically equivalent                      14601          14777         250          5.7         174.1       0.3X
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit tests.

Closes #37825 from bersprockets/rewritedistinct_issue.
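As an illustrative aside (not part of the patch): "semantically equivalent" here means equal after Catalyst canonicalization, which is what `ExpressionSet` keys on. A minimal sketch of the behavior this change relies on, assuming `spark-catalyst` is on the classpath:

```
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val b = AttributeReference("b", IntegerType)()
val e1 = Add(b, Literal(1)) // b + 1
val e2 = Add(Literal(1), b) // 1 + b

// Canonicalization normalizes the operand order of commutative operators,
// so the two expressions compare as semantically equal...
assert(e1.semanticEquals(e2))
// ...and ExpressionSet de-duplicates them to a single entry, which is why the
// rewrite can group both distinct aggregates into one Expand projection.
assert(ExpressionSet(Seq(e1, e2)).size == 1)
```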
Authored-by: Bruce Robbins
Signed-off-by: Wenchen Fan
---
 .../optimizer/RewriteDistinctAggregates.scala | 10 +++---
 .../RewriteDistinctAggregatesSuite.scala      | 33 ++++++++++++++++++
 .../spark/sql/execution/SparkStrategies.scala |  6 ++--
 .../sql/execution/aggregate/AggUtils.scala    |  9 +++--
 .../spark/sql/DataFrameAggregateSuite.scala   | 34 +++++++++++++++++++
 .../spark/sql/execution/PlannerSuite.scala    |  4 +++
 6 files changed, 87 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 1a58f45c07a29..3a35c08d594a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -220,7 +220,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+      val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
       if (unfoldableChildren.nonEmpty) {
         // Only expand the unfoldable children
         unfoldableChildren
@@ -231,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         // count(distinct 1) will be explained to count(1) after the rewrite function.
         // Generally, the distinct aggregateFunction should not run
         // foldable TypeCheck for the first child.
-        e.aggregateFunction.children.take(1).toSet
+        ExpressionSet(e.aggregateFunction.children.take(1))
       }
     }
 
@@ -254,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
       // Setup unique distinct aggregate children.
      val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+      val distinctAggChildAttrMap = distinctAggChildren.map { e =>
+        e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
+      }
       val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
       // Setup all the filters in distinct aggregate.
       val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect {
@@ -292,7 +294,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
             af
           } else {
             patchAggregateFunctionChildren(af) { x =>
-              distinctAggChildAttrLookup.get(x)
+              distinctAggChildAttrLookup.get(x.canonicalized)
             }
           }
           val newCondition = if (condition.isDefined) {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 6e66c91b8a89a..cb4771dd92f80 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -75,4 +75,37 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       .analyze
     checkRewrite(RewriteDistinctAggregates(input))
   }
+
+  test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy($"a")(
+        countDistinct($"b" + $"c").as("agg1"),
+        countDistinct($"c" + $"b").as("agg2"),
+        max($"c").as("agg3"))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, LocalRelation(_, _, _)) =>
+      case _ => fail(s"Plan is not as expected:\n$rewrite")
+    }
+  }
+
+  test("SPARK-40382: reduce multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy($"a")(
+        countDistinct($"b" + $"c" + $"d").as("agg1"),
+        countDistinct($"d" + $"c" + $"b").as("agg2"),
+        countDistinct($"b" + $"c").as("agg3"),
+        countDistinct($"c" + $"b").as("agg4"),
+        max($"c").as("agg5"))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, Aggregate(_, _, e: Expand)) =>
+        assert(e.projections.size == 3)
+      case _ => fail(s"Plan is not rewritten:\n$rewrite")
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c64a123e3a78c..03e722a86fb21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
 
-        if (functionsWithDistinct.map(
-            _.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) {
+        val distinctAggChildSets = functionsWithDistinct.map { ae =>
+          ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable))
+        }.distinct
+        if (distinctAggChildSets.length > 1) {
           // This is a sanity check. We should not reach here when we have multiple distinct
           // column sets. Our `RewriteDistinctAggregates` should take care this case.
           throw new IllegalStateException(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 0849ab59f64d0..579a00c7996f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -219,14 +219,17 @@ object AggUtils {
     }
 
     // 3. Create an Aggregate operator for partial aggregation (for distinct)
-    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
+    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
+      distinctAttributes)
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
       // been evaluated. At here, we need to replace original children
       // to AttributeReferences.
       case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
-        aggregateFunction.transformDown(distinctColumnAttributeLookup)
-          .asInstanceOf[AggregateFunction]
+        aggregateFunction.transformDown {
+          case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
+            distinctColumnAttributeLookup(e.canonicalized)
+        }.asInstanceOf[AggregateFunction]
       case agg =>
         throw new IllegalArgumentException(
           "Non-distinct aggregate is found in functionsWithDistinct " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 90e2acfe5d688..54911d2a6fb61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1485,6 +1485,40 @@ class DataFrameAggregateSuite extends QueryTest
     val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
     checkAnswer(df, Row(2, 3, 1))
   }
+
+  test("SPARK-40382: Distinct aggregation expression grouping by semantic equivalence") {
+    Seq(
+      (1, 1, 3),
+      (1, 2, 3),
+      (1, 2, 3),
+      (2, 1, 1),
+      (2, 2, 5)
+    ).toDF("k", "c1", "c2").createOrReplaceTempView("df")
+
+    // all distinct aggregation children are semantically equivalent
+    val res1 = sql(
+      """select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res1, Row(1, 5, 2.5, 2) :: Row(2, 5, 2.5, 2) :: Nil)
+
+    // some distinct aggregation children are semantically equivalent
+    val res2 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 2 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res2, Row(1, 7, 3.5, 1) :: Row(2, 7, 3.5, 2) :: Nil)
+
+    // no distinct aggregation children are semantically equivalent
+    val res3 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 3 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil)
+  }
 }
 
 case class B(c: Option[Double])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c7bd12c86a4d1..8c5a09cb1890d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       // 2 distinct columns with different order
       val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
       assertNoExpand(query3.queryExecution.executedPlan)
+
+      // SPARK-40382: 1 distinct expression with cosmetic differences
+      val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
+      assertNoExpand(query4.queryExecution.executedPlan)
     }
   }
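Note for reviewers (illustrative, not part of the patch): a hypothetical spark-shell session to observe the effect locally. With this change, the two distinct children below are semantically equivalent, so the physical plan should contain no Expand node.

```
// Hypothetical spark-shell check: the distinct children `b + 1` and `1 + b`
// canonicalize to the same expression, so explain() should show no Expand.
spark.range(10).selectExpr("id % 2 AS a", "id AS b").createOrReplaceTempView("t")
spark.sql("SELECT a, count(DISTINCT b + 1), sum(DISTINCT 1 + b) FROM t GROUP BY a").explain()
```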