From 3f2163554ff918d96527f92a6fd33c31dabc0d92 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Tue, 8 Sep 2020 10:22:08 +0900 Subject: [PATCH] Fix --- .../sql/catalyst/optimizer/Optimizer.scala | 49 ++---- .../sql/catalyst/optimizer/subquery.scala | 162 ++++++++---------- 2 files changed, 87 insertions(+), 124 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 86aa1f2cd61d9..d1aea17538859 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1580,36 +1580,25 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator. */ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]() - val newPlan = plan transform { - case Deduplicate(keys, child) if !child.isStreaming => - val keyExprIds = keys.map(_.exprId) - val aggCols = child.output.map { attr => - if (keyExprIds.contains(attr.exprId)) { - attr -> attr - } else { - val alias = Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) - alias -> alias.newInstance() - } - }.unzip - // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping - // aggregations by checking the number of grouping keys. The key difference here is that a - // global aggregation always returns at least one row even if there are no input rows. Here - // we append a literal when the grouping key list is empty so that the result aggregate - // operator is properly treated as a grouping aggregation. - val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys - val newAgg = Aggregate(nonemptyKeys, aggCols._1, child) - rewritePlanMap += newAgg -> Aggregate(nonemptyKeys, aggCols._2, child) - newAgg - } - - if (rewritePlanMap.nonEmpty) { - assert(!plan.fastEquals(newPlan)) - Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1 - } else { - plan - } + def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { + case d @ Deduplicate(keys, child) if !child.isStreaming => + val keyExprIds = keys.map(_.exprId) + val aggCols = child.output.map { attr => + if (keyExprIds.contains(attr.exprId)) { + attr + } else { + Alias(new First(attr).toAggregateExpression(), attr.name)() + } + } + // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping + // aggregations by checking the number of grouping keys. The key difference here is that a + // global aggregation always returns at least one row even if there are no input rows. Here + // we append a literal when the grouping key list is empty so that the result aggregate + // operator is properly treated as a grouping aggregation. + val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys + val newAgg = Aggregate(nonemptyKeys, aggCols, child) + val attrMapping = d.output.zip(newAgg.output) + newAgg -> attrMapping } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index aaef7a49a5472..8476fce2bfd93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{Analyzer, CleanupAliases} +import org.apache.spark.sql.catalyst.analysis.CleanupAliases import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -342,11 +342,12 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { */ private def extractCorrelatedScalarSubqueries[E <: Expression]( expression: E, - subqueries: ArrayBuffer[ScalarSubquery]): E = { + subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = { val newExpression = expression transform { case s: ScalarSubquery if s.children.nonEmpty => - subqueries += s - s.plan.output.head + val newExprId = NamedExpression.newExprId + subqueries += s -> newExprId + s.plan.output.head.withExprId(newExprId) } newExpression.asInstanceOf[E] } @@ -513,17 +514,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { */ private def constructLeftJoins( child: LogicalPlan, - subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, Seq[(LogicalPlan, LogicalPlan)]) = { - val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]() - val newPlan = subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(query, conditions, _)) => + subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = { + subqueries.foldLeft(child) { + case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) => val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) if (resultWithZeroTups.isEmpty) { // CASE 1: Subquery guaranteed not to have the COUNT bug Project( - currentChild.output :+ origOutput, + currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId), Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } else { // Subquery might have the COUNT bug. Add appropriate corrections. @@ -543,23 +543,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { if (havingNode.isEmpty) { // CASE 2: Subquery with no HAVING clause - val joinPlan = Join(currentChild, - Project(query.output :+ alwaysTrueExpr, query), - LeftOuter, conditions.reduceOption(And), JoinHint.NONE) - - def buildPlan(exprId: ExprId): LogicalPlan = { - Project( - currentChild.output :+ - Alias( - If(IsNull(alwaysTrueRef), - resultWithZeroTups.get, - aggValRef), origOutput.name)(exprId), - joinPlan) - } + Project( + currentChild.output :+ + Alias( + If(IsNull(alwaysTrueRef), + resultWithZeroTups.get, + aggValRef), origOutput.name)(exprId = newExprId), + Join(currentChild, + Project(query.output :+ alwaysTrueExpr, query), + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) - val newPlan = buildPlan(origOutput.exprId) - rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId) - newPlan } else { // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join. // Need to modify any operators below the join to pass through all columns @@ -575,85 +568,66 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { case op => sys.error(s"Unexpected operator $op in corelated subquery") } - val joinPlan = Join(currentChild, - Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), - LeftOuter, conditions.reduceOption(And), JoinHint.NONE) - - def buildPlan(exprId: ExprId): LogicalPlan = { - // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups - // WHEN NOT (original HAVING clause expr) THEN CAST(null AS ) - // ELSE (aggregate value) END AS (original column name) - val caseExpr = Alias(CaseWhen(Seq( - (IsNull(alwaysTrueRef), resultWithZeroTups.get), - (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), - aggValRef), - origOutput.name)(exprId) - - Project( - currentChild.output :+ caseExpr, - joinPlan) - } + // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups + // WHEN NOT (original HAVING clause expr) THEN CAST(null AS ) + // ELSE (aggregate value) END AS (original column name) + val caseExpr = Alias(CaseWhen(Seq( + (IsNull(alwaysTrueRef), resultWithZeroTups.get), + (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), + aggValRef), + origOutput.name)(exprId = newExprId) + + Project( + currentChild.output :+ caseExpr, + Join(currentChild, + Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot), + LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) - val newPlan = buildPlan(origOutput.exprId) - rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId) - newPlan } } } - - (newPlan, rewritePlanMap) } /** * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar * subqueries. */ - def apply(plan: LogicalPlan): LogicalPlan = { - val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]() - val newPlan = plan transform { - case a @ Aggregate(grouping, expressions, child) => - val subqueries = ArrayBuffer.empty[ScalarSubquery] - val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) - if (subqueries.nonEmpty) { - // We currently only allow correlated subqueries in an aggregate if they are part of the - // grouping expressions. As a result we need to replace all the scalar subqueries in the - // grouping expressions by their result. - val newGrouping = grouping.map { e => - subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e) - } - val (newChild, rewriteMap) = constructLeftJoins(child, subqueries) - rewritePlanMap ++= rewriteMap - Aggregate(newGrouping, newExpressions, newChild) - } else { - a - } - case p @ Project(expressions, child) => - val subqueries = ArrayBuffer.empty[ScalarSubquery] - val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) - if (subqueries.nonEmpty) { - val (newChild, rewriteMap) = constructLeftJoins(child, subqueries) - rewritePlanMap ++= rewriteMap - Project(newExpressions, newChild) - } else { - p + def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { + case a @ Aggregate(grouping, expressions, child) => + val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)] + val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + if (subqueries.nonEmpty) { + // We currently only allow correlated subqueries in an aggregate if they are part of the + // grouping expressions. As a result we need to replace all the scalar subqueries in the + // grouping expressions by their result. + val newGrouping = grouping.map { e => + subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e) } - case f @ Filter(condition, child) => - val subqueries = ArrayBuffer.empty[ScalarSubquery] - val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries) - if (subqueries.nonEmpty) { - val (newChild, rewriteMap) = constructLeftJoins(child, subqueries) - rewritePlanMap ++= rewriteMap - Project(f.output, Filter(newCondition, newChild)) - } else { - f - } - } - - if (rewritePlanMap.nonEmpty) { - assert(!plan.fastEquals(newPlan)) - Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1 - } else { - newPlan - } + val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries)) + val attrMapping = a.output.zip(newAgg.output) + newAgg -> attrMapping + } else { + a -> Nil + } + case p @ Project(expressions, child) => + val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)] + val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) + if (subqueries.nonEmpty) { + val newProj = Project(newExpressions, constructLeftJoins(child, subqueries)) + val attrMapping = p.output.zip(newProj.output) + newProj -> attrMapping + } else { + p -> Nil + } + case f @ Filter(condition, child) => + val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)] + val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries) + if (subqueries.nonEmpty) { + val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries))) + val attrMapping = f.output.zip(newProj.output) + newProj -> attrMapping + } else { + f -> Nil + } } }