From a5a6fc0582b90b619f5ec732ca87165c83b519ee Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 10 Mar 2022 17:41:09 -0800
Subject: [PATCH 01/13] First attempt

---
 .../optimizer/RewriteDistinctAggregates.scala | 27 +++++++++++++++++--
 1 file changed, 25 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 1a58f45c07a29..29cb0fb7e6e11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -218,9 +218,27 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
+    val funcChildren = distinctAggs.flatMap { e =>
+      e.aggregateFunction.children.filter(!_.foldable)
+    }
+
+    // For each function child, find the first instance that is semantically equivalent.
+    // E.g., assume funcChildren is the following three expressions:
+    // [('a + 1), (1 + 'a), 'b]
+    // then we want the map to be:
+    // Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
+    // That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
+    // This is an n^2 operation, where n is the number of distinct aggregate children, but it
+    // happens only once per invocation of this rule.
+    val funcChildrenLookup = funcChildren.map { e =>
+      (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
+    }.toMap
+
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).map { fc =>
+        funcChildrenLookup.getOrElse(fc, fc)
+      }.toSet
       if (unfoldableChildren.nonEmpty) {
         // Only expand the unfoldable children
         unfoldableChildren
@@ -292,7 +310,12 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           af
         } else {
           patchAggregateFunctionChildren(af) { x =>
-            distinctAggChildAttrLookup.get(x)
+            // x might not exactly match any key of distinctAggChildAttrLookup
+            // only because `distinctAggChildAttrLookup`'s keys have been de-duped
+            // based on semantic equivalence. So we need to translate x to the
+            // semantic equivalent that we are actually using.
+            val x2 = funcChildrenLookup(x)
+            distinctAggChildAttrLookup.get(x2)
           }
         }
         val newCondition = if (condition.isDefined) {

From 0a109d98b246080ebedf764fbd2b9d511bd8f032 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Fri, 11 Mar 2022 12:26:13 -0800
Subject: [PATCH 02/13] Fix for foldables

---
 .../sql/catalyst/optimizer/RewriteDistinctAggregates.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 29cb0fb7e6e11..6bf694a50727a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -314,7 +314,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
             // only because `distinctAggChildAttrLookup`'s keys have been de-duped
             // based on semantic equivalence. So we need to translate x to the
             // semantic equivalent that we are actually using.
-            val x2 = funcChildrenLookup(x)
+            val x2 = funcChildrenLookup.getOrElse(x, x)
             distinctAggChildAttrLookup.get(x2)
           }
         }
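Before patch 03 builds on it, here is a minimal, self-contained sketch of the technique that patches 01 and 02 introduce: map every distinct aggregate child to the first semantically equivalent expression in the list, so cosmetically different children such as (a + 1) and (1 + a) collapse into one distinct group. This sketch is not part of the patch series; it assumes only spark-catalyst on the classpath, and the object name FuncChildrenLookupSketch is invented for illustration.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object FuncChildrenLookupSketch extends App {
      val a = AttributeReference("a", IntegerType)()
      val b = AttributeReference("b", IntegerType)()

      // Children of several DISTINCT aggregates, differing only cosmetically:
      // (a + 1), (1 + a), and b.
      val funcChildren: Seq[Expression] = Seq(Add(a, Literal(1)), Add(Literal(1), a), b)

      // The same n^2 construction as funcChildrenLookup in patch 01: each child
      // maps to the first instance in the list that is semantically equivalent.
      val funcChildrenLookup: Map[Expression, Expression] = funcChildren.map { e =>
        (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
      }.toMap

      // (a + 1) and (1 + a) now share one representative, so grouping distinct
      // aggregates by their mapped children yields a single group.
      assert(funcChildrenLookup(Add(a, Literal(1))) == funcChildrenLookup(Add(Literal(1), a)))
      assert(funcChildrenLookup(b) == b)
    }

Note that "first instance wins" makes the choice of representative depend on the order of the aggregate list, and children that never entered the map (e.g. foldable ones) must fall back to themselves, which is exactly the getOrElse fix in patch 02.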
From 38f1f6a8aaa7137a6b4baac27bf7939111769ed2 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 17 Mar 2022 17:28:41 -0700
Subject: [PATCH 03/13] Update

---
 .../optimizer/RewriteDistinctAggregates.scala | 43 +++++++++++++++----
 .../RewriteDistinctAggregatesSuite.scala      | 37 ++++++++++++++++
 2 files changed, 72 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 6bf694a50727a..8ee3df27aa734 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -253,8 +253,42 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
     }
 
+    // Get the count of aggregation groups, which takes into account
+    // even superficial differences in the function children.
+    val distinctAggGroupsCount = aggExpressions.filter(_.isDistinct).map { e =>
+      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+      if (unfoldableChildren.nonEmpty) {
+        unfoldableChildren
+      } else {
+        e.aggregateFunction.children.take(1).toSet
+      }
+    }.toSet.size
+
+    def patchAggregateFunctionChildren(
+        af: AggregateFunction)(
+        attrs: Expression => Option[Expression]): AggregateFunction = {
+      val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+      af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+    }
+
     // Aggregation strategy can handle queries with a single distinct group without filter clause.
-    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
+    if (distinctAggGroups.size == 1 && distinctAggGroupsCount > 1
+        && !distinctAggs.exists(_.filter.isDefined)) {
+      // We have multiple groups only because of
+      // superficial differences. Make them the same so that SparkStrategies
+      // doesn't complain during its sanity check. That is, if we have an aggList of:
+      // [count(distinct b + 1), sum(distinct 1 + b), sum(c)]
+      // change it to:
+      // [count(distinct b + 1), sum(distinct b + 1), sum(c)]
+      // so that we have distinct aggregations over only one expression.
+      val patchedAggExpressions = a.aggregateExpressions.map { e =>
+        e.transformDown {
+          case e: Expression =>
+            funcChildrenLookup.getOrElse(e, e)
+        }.asInstanceOf[NamedExpression]
+      }
+      a.copy(aggregateExpressions = patchedAggExpressions)
+    } else if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
@@ -263,13 +297,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
       val groupByAttrs = groupByMap.map(_._2)
 
-      def patchAggregateFunctionChildren(
-          af: AggregateFunction)(
-          attrs: Expression => Option[Expression]): AggregateFunction = {
-        val newChildren = af.children.map(c => attrs(c).getOrElse(c))
-        af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-      }
-
       // Setup unique distinct aggregate children.
       val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
       val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 6e66c91b8a89a..7d1975f1d3b86 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -75,4 +75,41 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
       .analyze
     checkRewrite(RewriteDistinctAggregates(input))
   }
+
+  test("eliminate multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy('a)(
+        countDistinct('b + 'c).as('agg1),
+        countDistinct('c + 'b).as('agg2),
+        max('c).as('agg3))
+      .analyze
+
+    val expected = testRelation
+      .groupBy('a)(
+        countDistinct('b + 'c).as('agg1),
+        countDistinct('b + 'c).as('agg2),
+        max('c).as('agg3))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    comparePlans(expected, rewrite)
+  }
+
+  test("reduce multiple distinct groups due to superficial differences") {
+    val input = testRelation
+      .groupBy('a)(
+        countDistinct('b + 'c + 'd).as('agg1),
+        countDistinct('d + 'c + 'b).as('agg2),
+        countDistinct('b + 'c).as('agg4),
+        countDistinct('c + 'b).as('agg4),
+        max('c).as('agg5))
+      .analyze
+
+    val rewrite = RewriteDistinctAggregates(input)
+    rewrite match {
+      case Aggregate(_, _, Aggregate(_, _, e: Expand)) =>
+        assert(e.projections.size == 3)
+      case _ => fail(s"Plan is not rewritten:\n$rewrite")
+    }
+  }
 }
From 3fa3588a1c52ad5012437777fd1f88043078a466 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Fri, 18 Mar 2022 13:13:10 -0700
Subject: [PATCH 04/13] Update test

---
 .../optimizer/RewriteDistinctAggregatesSuite.scala | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 7d1975f1d3b86..7b9056f83e887 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -84,15 +84,11 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
         max('c).as('agg3))
       .analyze
 
-    val expected = testRelation
-      .groupBy('a)(
-        countDistinct('b + 'c).as('agg1),
-        countDistinct('b + 'c).as('agg2),
-        max('c).as('agg3))
-      .analyze
-
     val rewrite = RewriteDistinctAggregates(input)
-    comparePlans(expected, rewrite)
+    rewrite match {
+      case Aggregate(_, _, LocalRelation(_, _, _)) =>
+      case _ => fail(s"Plan is not as expected:\n$rewrite")
+    }
   }
 
   test("reduce multiple distinct groups due to superficial differences") {
From 4a40f910f4fd44a526d5959fc255e5702aa29151 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Fri, 2 Sep 2022 17:19:59 -0700
Subject: [PATCH 05/13] Simplify

---
 .../optimizer/RewriteDistinctAggregates.scala | 104 ++++++++----------
 1 file changed, 43 insertions(+), 61 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 8ee3df27aa734..de660339d0fb2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -213,32 +213,17 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
   }
 
-  def rewrite(a: Aggregate): Aggregate = {
+  def rewrite(aRaw: Aggregate): Aggregate = {
+    // Make children of distinct aggregations the same if they are different
+    // only because of superficial differences.
+    val a = getSanitizedAggregate(aRaw)
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
-    val funcChildren = distinctAggs.flatMap { e =>
-      e.aggregateFunction.children.filter(!_.foldable)
-    }
-
-    // For each function child, find the first instance that is semantically equivalent.
-    // E.g., assume funcChildren is the following three expressions:
-    // [('a + 1), (1 + 'a), 'b]
-    // then we want the map to be:
-    // Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
-    // That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
-    // This is an n^2 operation, where n is the number of distinct aggregate children, but it
-    // happens only once per invocation of this rule.
-    val funcChildrenLookup = funcChildren.map { e =>
-      (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
-    }.toMap
-
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).map { fc =>
-        funcChildrenLookup.getOrElse(fc, fc)
-      }.toSet
+      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
       if (unfoldableChildren.nonEmpty) {
         // Only expand the unfoldable children
         unfoldableChildren
@@ -253,42 +238,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
     }
 
-    // Get the count of aggregation groups, which takes into account
-    // even superficial differences in the function children.
-    val distinctAggGroupsCount = aggExpressions.filter(_.isDistinct).map { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
-      if (unfoldableChildren.nonEmpty) {
-        unfoldableChildren
-      } else {
-        e.aggregateFunction.children.take(1).toSet
-      }
-    }.toSet.size
-
-    def patchAggregateFunctionChildren(
-        af: AggregateFunction)(
-        attrs: Expression => Option[Expression]): AggregateFunction = {
-      val newChildren = af.children.map(c => attrs(c).getOrElse(c))
-      af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-    }
-
     // Aggregation strategy can handle queries with a single distinct group without filter clause.
-    if (distinctAggGroups.size == 1 && distinctAggGroupsCount > 1
-        && !distinctAggs.exists(_.filter.isDefined)) {
-      // We have multiple groups only because of
-      // superficial differences. Make them the same so that SparkStrategies
-      // doesn't complain during its sanity check. That is, if we have an aggList of:
-      // [count(distinct b + 1), sum(distinct 1 + b), sum(c)]
-      // change it to:
-      // [count(distinct b + 1), sum(distinct b + 1), sum(c)]
-      // so that we have distinct aggregations over only one expression.
-      val patchedAggExpressions = a.aggregateExpressions.map { e =>
-        e.transformDown {
-          case e: Expression =>
-            funcChildrenLookup.getOrElse(e, e)
-        }.asInstanceOf[NamedExpression]
-      }
-      a.copy(aggregateExpressions = patchedAggExpressions)
-    } else if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
+    if (distinctAggGroups.size > 1 || distinctAggs.exists(_.filter.isDefined)) {
       // Create the attributes for the grouping id and the group by clause.
       val gid = AttributeReference("gid", IntegerType, nullable = false)()
       val groupByMap = a.groupingExpressions.collect {
@@ -337,12 +288,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           af
         } else {
           patchAggregateFunctionChildren(af) { x =>
-            // x might not exactly match any key of distinctAggChildAttrLookup
-            // only because `distinctAggChildAttrLookup`'s keys have been de-duped
-            // based on semantic equivalence. So we need to translate x to the
-            // semantic equivalent that we are actually using.
-            val x2 = funcChildrenLookup.getOrElse(x, x)
-            distinctAggChildAttrLookup.get(x2)
+            distinctAggChildAttrLookup.get(x)
           }
         }
         val newCondition = if (condition.isDefined) {
@@ -463,6 +409,42 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }}
   }
 
+  private def patchAggregateFunctionChildren(
+      af: AggregateFunction)(
+      attrs: Expression => Option[Expression]): AggregateFunction = {
+    val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+    af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+  }
+
+  private def getSanitizedAggregate(a: Aggregate): Aggregate = {
+    val aggExpressions = collectAggregateExprs(a)
+    val distinctAggs = aggExpressions.filter(_.isDistinct)
+
+    val funcChildren = distinctAggs.flatMap { e =>
+      e.aggregateFunction.children.filter(!_.foldable)
+    }
+
+    // For each function child, find the first instance that is semantically equivalent.
+    // E.g., assume funcChildren is the following three expressions:
+    // [('a + 1), (1 + 'a), 'b]
+    // then we want the map to be:
+    // Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
+    // That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
+    // This is an n^2 operation, where n is the number of distinct aggregate children, but it
+    // happens only once per invocation of this rule.
+    val funcChildrenLookup = funcChildren.map { e =>
+      (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
+    }.toMap
+
+    val patchedAggExpressions = a.aggregateExpressions.map { e =>
+      e.transformDown {
+        case e: Expression =>
+          funcChildrenLookup.getOrElse(e, e)
+      }.asInstanceOf[NamedExpression]
+    }
+    a.copy(aggregateExpressions = patchedAggExpressions)
+  }
+
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
   private def expressionAttributePair(e: Expression) =
From 165f558590dc46ffe50ebcbc9fcc9cd81f801c96 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 4 Sep 2022 16:42:15 -0700
Subject: [PATCH 06/13] Rename

---
 .../optimizer/RewriteDistinctAggregates.scala | 33 ++++++++++++-------
 1 file changed, 22 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index de660339d0fb2..feef5056230ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -213,10 +213,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
   }
 
-  def rewrite(aRaw: Aggregate): Aggregate = {
+  def rewrite(aOrig: Aggregate): Aggregate = {
     // Make children of distinct aggregations the same if they are different
-    // only because of superficial differences.
-    val a = getSanitizedAggregate(aRaw)
+    // only because of superficial reasons, e.g.:
+    // "1 + col1" vs "col1 + 1", both become "1 + col1"
+    // or
+    // "col1" vs "Col1", both become "col1"
+    val a = reduceDistinctAggregateGroups(aOrig)
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
@@ -248,6 +251,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
       val groupByAttrs = groupByMap.map(_._2)
 
+      def patchAggregateFunctionChildren(
+          af: AggregateFunction)(
+          attrs: Expression => Option[Expression]): AggregateFunction = {
+        val newChildren = af.children.map(c => attrs(c).getOrElse(c))
+        af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
+      }
+
       // Setup unique distinct aggregate children.
       val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
       val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
@@ -419,14 +429,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }}
   }
 
-  private def patchAggregateFunctionChildren(
-      af: AggregateFunction)(
-      attrs: Expression => Option[Expression]): AggregateFunction = {
-    val newChildren = af.children.map(c => attrs(c).getOrElse(c))
-    af.withNewChildren(newChildren).asInstanceOf[AggregateFunction]
-  }
-
-  private def getSanitizedAggregate(a: Aggregate): Aggregate = {
+  private def reduceDistinctAggregateGroups(a: Aggregate): Aggregate = {
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
@@ -436,6 +439,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
     }.toMap
 
+    val funcChildrenPatched = funcChildren.map { e =>
+      funcChildrenLookup.getOrElse(e, e)
+    }
+
+    if (funcChildren.distinct.size == funcChildrenPatched.distinct.size) {
+      return a
+    }
+
     val patchedAggExpressions = a.aggregateExpressions.map { e =>
       e.transformDown {
         case e: Expression =>
From 27dcffe22f4f18003fdd8e493f677f5324a11b55 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Wed, 7 Sep 2022 12:07:20 -0700
Subject: [PATCH 07/13] Update comments

---
 .../optimizer/RewriteDistinctAggregates.scala | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index feef5056230ae..d969587b015bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -214,11 +214,14 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
   }
 
   def rewrite(aOrig: Aggregate): Aggregate = {
-    // Make children of distinct aggregations the same if they are different
-    // only because of superficial reasons, e.g.:
-    // "1 + col1" vs "col1 + 1", both become "1 + col1"
+    // Make children of distinct aggregations the same if they are only
+    // different due to superficial reasons, e.g.:
+    // "1 + col1" vs "col1 + 1", both should become "1 + col1"
     // or
-    // "col1" vs "Col1", both become "col1"
+    // "col1" vs "Col1", both should become "col1"
+    // This could potentially reduce the number of distinct
+    // aggregate groups, and therefore reduce the number of
+    // projections in Expand (or eliminate the need for Expand).
     val a = reduceDistinctAggregateGroups(aOrig)
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
@@ -408,6 +411,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
       Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
     } else {
+      // It's possible we avoided rewriting the plan to use Expand only because
+      // reduceDistinctAggregateGroups reduced the number of distinct aggregate groups
+      // from > 1 to 1. To prevent SparkStrategies from complaining during sanity check,
+      // we use the potentially patched Aggregate returned by reduceDistinctAggregateGroups.
       a
     }
   }
From 484ca8ed206775b4e8bc7ea016c95d2e386d4a9c Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Wed, 7 Sep 2022 15:00:12 -0700
Subject: [PATCH 08/13] Replace Symbol usage with $"" in new unit tests

---
 .../RewriteDistinctAggregatesSuite.scala | 20 +++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index 7b9056f83e887..a523b27eca337 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -78,10 +78,10 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
 
   test("eliminate multiple distinct groups due to superficial differences") {
     val input = testRelation
-      .groupBy('a)(
-        countDistinct('b + 'c).as('agg1),
-        countDistinct('c + 'b).as('agg2),
-        max('c).as('agg3))
+      .groupBy($"a")(
+        countDistinct($"b" + $"c").as("agg1"),
+        countDistinct($"c" + $"b").as("agg2"),
+        max($"c").as("agg3"))
       .analyze
 
     val rewrite = RewriteDistinctAggregates(input)
@@ -93,12 +93,12 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
 
   test("reduce multiple distinct groups due to superficial differences") {
     val input = testRelation
-      .groupBy('a)(
-        countDistinct('b + 'c + 'd).as('agg1),
-        countDistinct('d + 'c + 'b).as('agg2),
-        countDistinct('b + 'c).as('agg4),
-        countDistinct('c + 'b).as('agg4),
-        max('c).as('agg5))
+      .groupBy($"a")(
+        countDistinct($"b" + $"c" + $"d").as("agg1"),
+        countDistinct($"d" + $"c" + $"b").as("agg2"),
+        countDistinct($"b" + $"c").as("agg3"),
+        countDistinct($"c" + $"b").as("agg4"),
+        max($"c").as("agg5"))
       .analyze
 
     val rewrite = RewriteDistinctAggregates(input)
From 208fe82c9f643f0572e9f7ff10e80f16f568f8a9 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 8 Sep 2022 11:41:32 -0700
Subject: [PATCH 09/13] Update tests

---
 .../RewriteDistinctAggregatesSuite.scala      |  4 ++--
 .../spark/sql/DataFrameAggregateSuite.scala   | 16 ++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
index a523b27eca337..cb4771dd92f80 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregatesSuite.scala
@@ -76,7 +76,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     checkRewrite(RewriteDistinctAggregates(input))
   }
 
-  test("eliminate multiple distinct groups due to superficial differences") {
+  test("SPARK-40382: eliminate multiple distinct groups due to superficial differences") {
     val input = testRelation
       .groupBy($"a")(
         countDistinct($"b" + $"c").as("agg1"),
@@ -91,7 +91,7 @@ class RewriteDistinctAggregatesSuite extends PlanTest {
     }
   }
 
-  test("reduce multiple distinct groups due to superficial differences") {
+  test("SPARK-40382: reduce multiple distinct groups due to superficial differences") {
    val input = testRelation
      .groupBy($"a")(
        countDistinct($"b" + $"c" + $"d").as("agg1"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 90e2acfe5d688..ff2efe790396a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1485,6 +1485,22 @@ class DataFrameAggregateSuite extends QueryTest
     val df = Seq(1).toDF("id").groupBy(Stream($"id" + 1, $"id" + 2): _*).sum("id")
     checkAnswer(df, Row(2, 3, 1))
   }
+
+  test("SPARK-40382: All distinct aggregation children are semantically equivalent") {
+    val df = Seq(
+      (1, 1, 1),
+      (1, 2, 3),
+      (1, 2, 3),
+      (2, 1, 1),
+      (2, 2, 5)
+    ).toDF("k", "c1", "c2")
+    val res1 = df.groupBy("k")
+      .agg(sum("c1"), countDistinct($"c2" + 1), sum_distinct(lit(1) + $"c2"))
+    checkAnswer(res1, Row(1, 5, 2, 6) :: Row(2, 3, 2, 8) :: Nil)
+
+    val res2 = df.selectExpr("count(distinct C2)", "count(distinct c2)")
+    checkAnswer(res2, Row(3, 3) :: Nil)
+  }
 }
 
 case class B(c: Option[Double])
From 882cdaab19763571bb5851e51aeab50aa189308b Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 25 Sep 2022 17:22:59 -0700
Subject: [PATCH 10/13] Update

---
 .../optimizer/RewriteDistinctAggregates.scala | 80 ++++++-------------
 .../spark/sql/DataFrameAggregateSuite.scala   | 38 ++++++---
 2 files changed, 54 insertions(+), 64 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index d969587b015bc..e722d50d8a5d6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -213,23 +213,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     case a: Aggregate if mayNeedtoRewrite(a) => rewrite(a)
   }
 
-  def rewrite(aOrig: Aggregate): Aggregate = {
-    // Make children of distinct aggregations the same if they are only
-    // different due to superficial reasons, e.g.:
-    // "1 + col1" vs "col1 + 1", both should become "1 + col1"
-    // or
-    // "col1" vs "Col1", both should become "col1"
-    // This could potentially reduce the number of distinct
-    // aggregate groups, and therefore reduce the number of
-    // projections in Expand (or eliminate the need for Expand).
-    val a = reduceDistinctAggregateGroups(aOrig)
+  def rewrite(a: Aggregate): Aggregate = {
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
+    val funcChildren = distinctAggs.flatMap { e =>
+      e.aggregateFunction.children.filter(!_.foldable)
+    }
+    val funcChildrenLookup = funcChildren.map { e =>
+      (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
+    }.toMap
+
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
-      val unfoldableChildren = e.aggregateFunction.children.filter(!_.foldable).toSet
+      val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
       if (unfoldableChildren.nonEmpty) {
         // Only expand the unfoldable children
         unfoldableChildren
@@ -300,7 +298,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         val naf = if (af.children.forall(_.foldable)) {
           af
         } else {
-          patchAggregateFunctionChildren(af) { x =>
+          patchAggregateFunctionChildren(af) { x1 =>
+            val x = funcChildrenLookup.getOrElse(x1, x1)
             distinctAggChildAttrLookup.get(x)
           }
         }
         val newCondition = if (condition.isDefined) {
@@ -411,11 +410,21 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
       Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
     } else {
-      // It's possible we avoided rewriting the plan to use Expand only because
-      // reduceDistinctAggregateGroups reduced the number of distinct aggregate groups
-      // from > 1 to 1. To prevent SparkStrategies from complaining during sanity check,
-      // we use the potentially patched Aggregate returned by reduceDistinctAggregateGroups.
-      a
+      // We may have one distinct group only because we grouped using ExpressionSet.
+      // To prevent SparkStrategies from complaining during sanity check, we need to check whether
+      // the original list of aggregate expressions had multiple distinct groups and, if so,
+      // patch that list so we have only one distinct group.
+      if (funcChildrenLookup.keySet.size > funcChildrenLookup.values.toSet.size) {
+        val patchedAggExpressions = a.aggregateExpressions.map { e =>
+          e.transformDown {
+            case e: Expression =>
+              funcChildrenLookup.getOrElse(e, e)
+          }.asInstanceOf[NamedExpression]
+        }
+        a.copy(aggregateExpressions = patchedAggExpressions)
+      } else {
+        a
+      }
     }
   }
@@ -426,43 +435,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }}
   }
 
-  private def reduceDistinctAggregateGroups(a: Aggregate): Aggregate = {
-    val aggExpressions = collectAggregateExprs(a)
-    val distinctAggs = aggExpressions.filter(_.isDistinct)
-
-    val funcChildren = distinctAggs.flatMap { e =>
-      e.aggregateFunction.children.filter(!_.foldable)
-    }
-
-    // For each function child, find the first instance that is semantically equivalent.
-    // E.g., assume funcChildren is the following three expressions:
-    // [('a + 1), (1 + 'a), 'b]
-    // then we want the map to be:
-    // Map(('a + 1) -> ('a + 1), (1 + 'a) -> ('a + 1), 'b -> 'b)
-    // That is, both ('a + 1) and (1 + 'a) map to ('a + 1).
-    // This is an n^2 operation, where n is the number of distinct aggregate children, but it
-    // happens only once per invocation of this rule.
-    val funcChildrenLookup = funcChildren.map { e =>
-      (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
-    }.toMap
-
-    val funcChildrenPatched = funcChildren.map { e =>
-      funcChildrenLookup.getOrElse(e, e)
-    }
-
-    if (funcChildren.distinct.size == funcChildrenPatched.distinct.size) {
-      return a
-    }
-
-    val patchedAggExpressions = a.aggregateExpressions.map { e =>
-      e.transformDown {
-        case e: Expression =>
-          funcChildrenLookup.getOrElse(e, e)
-      }.asInstanceOf[NamedExpression]
-    }
-    a.copy(aggregateExpressions = patchedAggExpressions)
-  }
-
   private def nullify(e: Expression) = Literal.create(null, e.dataType)
 
   private def expressionAttributePair(e: Expression) =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index ff2efe790396a..54911d2a6fb61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1486,20 +1486,38 @@ class DataFrameAggregateSuite extends QueryTest
     checkAnswer(df, Row(2, 3, 1))
   }
 
-  test("SPARK-40382: All distinct aggregation children are semantically equivalent") {
-    val df = Seq(
-      (1, 1, 1),
+  test("SPARK-40382: Distinct aggregation expression grouping by semantic equivalence") {
+    Seq(
+      (1, 1, 3),
       (1, 2, 3),
       (1, 2, 3),
       (2, 1, 1),
       (2, 2, 5)
-    ).toDF("k", "c1", "c2")
-    val res1 = df.groupBy("k")
-      .agg(sum("c1"), countDistinct($"c2" + 1), sum_distinct(lit(1) + $"c2"))
-    checkAnswer(res1, Row(1, 5, 2, 6) :: Row(2, 3, 2, 8) :: Nil)
-
-    val res2 = df.selectExpr("count(distinct C2)", "count(distinct c2)")
-    checkAnswer(res2, Row(3, 3) :: Nil)
+    ).toDF("k", "c1", "c2").createOrReplaceTempView("df")
+
+    // all distinct aggregation children are semantically equivalent
+    val res1 = sql(
+      """select k, sum(distinct c1 + 1), avg(distinct 1 + c1), count(distinct 1 + C1)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res1, Row(1, 5, 2.5, 2) :: Row(2, 5, 2.5, 2) :: Nil)
+
+    // some distinct aggregation children are semantically equivalent
+    val res2 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 2 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res2, Row(1, 7, 3.5, 1) :: Row(2, 7, 3.5, 2) :: Nil)
+
+    // no distinct aggregation children are semantically equivalent
+    val res3 = sql(
+      """select k, sum(distinct c1 + 2), avg(distinct 3 + c1), count(distinct c2)
+        |from df
+        |group by k
+        |""".stripMargin)
+    checkAnswer(res3, Row(1, 7, 4.5, 1) :: Row(2, 7, 4.5, 2) :: Nil)
   }
 }
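A small standalone illustration of why the switch to ExpressionSet above (and to ExpressionSet keys in the next patch) makes the grouping robust: ExpressionSet compares its members by their canonicalized form, so two sets built from cosmetically different children are equal and collapse into a single groupBy key. This is not part of the patch series; it assumes spark-catalyst on the classpath, and the object name ExpressionSetKeySketch is invented for illustration.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object ExpressionSetKeySketch extends App {
      val b = AttributeReference("b", IntegerType)()
      val c = AttributeReference("c", IntegerType)()

      // Grouping keys for count(distinct b + c) and count(distinct c + b).
      val key1 = ExpressionSet(Seq(Add(b, c)))
      val key2 = ExpressionSet(Seq(Add(c, b)))
      assert(key1 == key2)

      // Consequently groupBy sees one key, i.e. one distinct aggregate group,
      // mirroring what rewrite() now does with unfoldableChildren.
      val groups = Seq(Add(b, c), Add(c, b)).groupBy(e => ExpressionSet(Seq(e)))
      assert(groups.size == 1)
    }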
From 9938252d65861651601cef2db24ea12fa5a1ce16 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Mon, 26 Sep 2022 17:59:25 -0700
Subject: [PATCH 11/13] Use ExpressionSet as key for various distinct aggregate child maps

---
 .../optimizer/RewriteDistinctAggregates.scala | 24 ++++++++++---------
 1 file changed, 13 insertions(+), 11 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index e722d50d8a5d6..6e7e71bd4a541 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -218,13 +218,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     val aggExpressions = collectAggregateExprs(a)
     val distinctAggs = aggExpressions.filter(_.isDistinct)
 
-    val funcChildren = distinctAggs.flatMap { e =>
-      e.aggregateFunction.children.filter(!_.foldable)
-    }
-    val funcChildrenLookup = funcChildren.map { e =>
-      (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
-    }.toMap
-
     // Extract distinct aggregate expressions.
     val distinctAggGroups = aggExpressions.filter(_.isDistinct).groupBy { e =>
       val unfoldableChildren = ExpressionSet(e.aggregateFunction.children.filter(!_.foldable))
@@ -238,7 +231,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         // count(distinct 1) will be explained to count(1) after the rewrite function.
         // Generally, the distinct aggregateFunction should not run
         // foldable TypeCheck for the first child.
-        e.aggregateFunction.children.take(1).toSet
+        ExpressionSet(e.aggregateFunction.children.take(1))
       }
     }
@@ -261,7 +254,9 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
       // Setup unique distinct aggregate children.
       val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
-      val distinctAggChildAttrMap = distinctAggChildren.map(expressionAttributePair)
+      val distinctAggChildAttrMap = distinctAggChildren.map { e =>
+        ExpressionSet(Seq(e)) -> AttributeReference(e.sql, e.dataType, nullable = true)()
+      }
       val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
       // Setup all the filters in distinct aggregate.
       val (distinctAggFilters, distinctAggFilterAttrs, maxConds) = distinctAggs.collect {
@@ -299,8 +294,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           af
         } else {
           patchAggregateFunctionChildren(af) { x1 =>
-            val x = funcChildrenLookup.getOrElse(x1, x1)
-            distinctAggChildAttrLookup.get(x)
+            val es = ExpressionSet(Seq(x1))
+            distinctAggChildAttrLookup.get(es)
           }
         }
         val newCondition = if (condition.isDefined) {
@@ -414,6 +409,13 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       // To prevent SparkStrategies from complaining during sanity check, we need to check whether
       // the original list of aggregate expressions had multiple distinct groups and, if so,
      // patch that list so we have only one distinct group.
+      val funcChildren = distinctAggs.flatMap { e =>
+        e.aggregateFunction.children.filter(!_.foldable)
+      }
+      val funcChildrenLookup = funcChildren.map { e =>
+        (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
+      }.toMap
+
       if (funcChildrenLookup.keySet.size > funcChildrenLookup.values.toSet.size) {
         val patchedAggExpressions = a.aggregateExpressions.map { e =>
           e.transformDown {
From f53136d27bc39cbd7550feb160174dca3bfbd536 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Sun, 2 Oct 2022 18:37:26 -0700
Subject: [PATCH 12/13] Handle case of one distinct grouping with superficially different function children to Spark strategies

---
 .../optimizer/RewriteDistinctAggregates.scala | 23 +------------------
 .../spark/sql/execution/SparkStrategies.scala |  9 +++++---
 2 files changed, 7 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 6e7e71bd4a541..0ddb9c0dff53f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -405,28 +405,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       }
       Aggregate(groupByAttrs, patchedAggExpressions, firstAggregate)
     } else {
-      // We may have one distinct group only because we grouped using ExpressionSet.
-      // To prevent SparkStrategies from complaining during sanity check, we need to check whether
-      // the original list of aggregate expressions had multiple distinct groups and, if so,
-      // patch that list so we have only one distinct group.
-      val funcChildren = distinctAggs.flatMap { e =>
-        e.aggregateFunction.children.filter(!_.foldable)
-      }
-      val funcChildrenLookup = funcChildren.map { e =>
-        (e, funcChildren.find(fc => e.semanticEquals(fc)).getOrElse(e))
-      }.toMap
-
-      if (funcChildrenLookup.keySet.size > funcChildrenLookup.values.toSet.size) {
-        val patchedAggExpressions = a.aggregateExpressions.map { e =>
-          e.transformDown {
-            case e: Expression =>
-              funcChildrenLookup.getOrElse(e, e)
-          }.asInstanceOf[NamedExpression]
-        }
-        a.copy(aggregateExpressions = patchedAggExpressions)
-      } else {
-        a
-      }
+      a
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c64a123e3a78c..7ea72961bd149 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -527,8 +527,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val (functionsWithDistinct, functionsWithoutDistinct) =
           aggregateExpressions.partition(_.isDistinct)
 
-        if (functionsWithDistinct.map(
-            _.aggregateFunction.children.filterNot(_.foldable).toSet).distinct.length > 1) {
+        val distinctAggChildSets = functionsWithDistinct.map { ae =>
+          ExpressionSet(ae.aggregateFunction.children.filterNot(_.foldable))
+        }.distinct
+        if (distinctAggChildSets.length > 1) {
           // This is a sanity check. We should not reach here when we have multiple distinct
           // column sets. Our `RewriteDistinctAggregates` should take care this case.
          throw new IllegalStateException(
@@ -560,7 +562,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           // [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct
           // aggregates have different column expressions.
           val distinctExpressions =
-            functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable)
+            functionsWithDistinct.flatMap(
+              _.aggregateFunction.children.filterNot(_.foldable)).distinct
           val normalizedNamedDistinctExpressions = distinctExpressions.map { e =>
             // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here
             // because `distinctExpressions` is not extracted during logical phase.

From f7d29df9ac7541c5fe727a6fa037fd9e6a3d9a07 Mon Sep 17 00:00:00 2001
From: Bruce Robbins
Date: Thu, 6 Oct 2022 16:48:05 -0700
Subject: [PATCH 13/13] Update

---
 .../catalyst/optimizer/RewriteDistinctAggregates.scala   | 7 +++----
 .../org/apache/spark/sql/execution/SparkStrategies.scala | 3 +--
 .../apache/spark/sql/execution/aggregate/AggUtils.scala  | 9 ++++++---
 .../org/apache/spark/sql/execution/PlannerSuite.scala    | 4 ++++
 4 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
index 0ddb9c0dff53f..3a35c08d594a0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala
@@ -255,7 +255,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       // Setup unique distinct aggregate children.
       val distinctAggChildren = distinctAggGroups.keySet.flatten.toSeq.distinct
       val distinctAggChildAttrMap = distinctAggChildren.map { e =>
-        ExpressionSet(Seq(e)) -> AttributeReference(e.sql, e.dataType, nullable = true)()
+        e.canonicalized -> AttributeReference(e.sql, e.dataType, nullable = true)()
       }
       val distinctAggChildAttrs = distinctAggChildAttrMap.map(_._2)
       // Setup all the filters in distinct aggregate.
@@ -293,9 +293,8 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
           af
         } else {
-          patchAggregateFunctionChildren(af) { x1 =>
-            val es = ExpressionSet(Seq(x1))
-            distinctAggChildAttrLookup.get(es)
+          patchAggregateFunctionChildren(af) { x =>
+            distinctAggChildAttrLookup.get(x.canonicalized)
           }
         }
         val newCondition = if (condition.isDefined) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 7ea72961bd149..03e722a86fb21 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -562,8 +562,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           // [COUNT(DISTINCT bar), COUNT(DISTINCT foo)] is disallowed because those two distinct
           // aggregates have different column expressions.
           val distinctExpressions =
-            functionsWithDistinct.flatMap(
-              _.aggregateFunction.children.filterNot(_.foldable)).distinct
+            functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable)
           val normalizedNamedDistinctExpressions = distinctExpressions.map { e =>
             // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here
             // because `distinctExpressions` is not extracted during logical phase.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 0849ab59f64d0..579a00c7996f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -219,14 +219,17 @@ object AggUtils {
     }
 
     // 3. Create an Aggregate operator for partial aggregation (for distinct)
-    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions, distinctAttributes)
+    val distinctColumnAttributeLookup = CUtils.toMap(distinctExpressions.map(_.canonicalized),
+      distinctAttributes)
     val rewrittenDistinctFunctions = functionsWithDistinct.map {
       // Children of an AggregateFunction with DISTINCT keyword has already
      // been evaluated. At here, we need to replace original children
      // to AttributeReferences.
       case agg @ AggregateExpression(aggregateFunction, mode, true, _, _) =>
-        aggregateFunction.transformDown(distinctColumnAttributeLookup)
-          .asInstanceOf[AggregateFunction]
+        aggregateFunction.transformDown {
+          case e: Expression if distinctColumnAttributeLookup.contains(e.canonicalized) =>
+            distinctColumnAttributeLookup(e.canonicalized)
+        }.asInstanceOf[AggregateFunction]
       case agg =>
         throw new IllegalArgumentException(
           "Non-distinct aggregate is found in functionsWithDistinct " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c7bd12c86a4d1..8c5a09cb1890d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -95,6 +95,10 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
       // 2 distinct columns with different order
       val query3 = sql("SELECT corr(DISTINCT j, k), count(DISTINCT k, j) FROM v GROUP BY i")
       assertNoExpand(query3.queryExecution.executedPlan)
+
+      // SPARK-40382: 1 distinct expression with cosmetic differences
+      val query4 = sql("SELECT sum(DISTINCT j), max(DISTINCT J) FROM v GROUP BY i")
+      assertNoExpand(query4.queryExecution.executedPlan)
     }
   }
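To close the loop, here is a standalone sketch of the lookup technique the final patch settles on in AggUtils and RewriteDistinctAggregates: key the child-to-attribute map by Expression.canonicalized and probe it with e.canonicalized inside transformDown, so a semantically equal but syntactically different child still finds its attribute. This is not part of the series; the object name CanonicalizedLookupSketch and the stand-in attribute are invented for illustration, assuming spark-catalyst on the classpath.

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.IntegerType

    object CanonicalizedLookupSketch extends App {
      val c2 = AttributeReference("c2", IntegerType)()
      // Stand-in for the attribute the planner creates for the distinct child.
      val attr = AttributeReference("(c2 + 1)", IntegerType)()

      // Keyed by canonicalized form, the way the patched AggUtils builds
      // its distinctColumnAttributeLookup.
      val lookup: Map[Expression, Attribute] =
        Map(Add(c2, Literal(1)).canonicalized -> attr)

      // A cosmetically different occurrence, (1 + c2), still hits the map.
      val patched = Add(Literal(1), c2).transformDown {
        case e: Expression if lookup.contains(e.canonicalized) => lookup(e.canonicalized)
      }
      assert(patched == attr)
    }

The design choice is the same one the earlier funcChildrenLookup approximated with an n^2 scan: canonicalization gives a cheap, order-independent equivalence key, so the de-duplication falls out of ordinary Map semantics.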