Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Sep 8, 2020
1 parent b3524b6 commit 3f21635
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1580,36 +1580,25 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
* Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
*/
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
// NOTE(review): two `apply` definitions follow back to back. This looks like the removed and
// added sides of a diff rendered without +/- markers -- confirm which one should remain.
//
// Variant 1: rewrites via `transform`, collecting (rewritten plan -> alternative plan) pairs in
// `rewritePlanMap` and handing them to `Analyzer.rewritePlan` at the end.
def apply(plan: LogicalPlan): LogicalPlan = {
val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = plan transform {
// The guard excludes Deduplicate nodes whose child is streaming.
case Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
// Builds two parallel column lists (split by `unzip`). Key columns appear unchanged in both.
// Every other column becomes First(attr) aliased with the original name and exprId in the
// first list, paired with a `newInstance()` copy of that alias in the second list.
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr -> attr
} else {
val alias = Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
alias -> alias.newInstance()
}
}.unzip
// SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
// aggregations by checking the number of grouping keys. The key difference here is that a
// global aggregation always returns at least one row even if there are no input rows. Here
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
// The returned Aggregate uses the first column list; the map entry pairs it with an
// otherwise identical Aggregate built from the second column list.
val newAgg = Aggregate(nonemptyKeys, aggCols._1, child)
rewritePlanMap += newAgg -> Aggregate(nonemptyKeys, aggCols._2, child)
newAgg
}

// With no recorded rewrites the input plan is returned as is; otherwise the first element of
// the `Analyzer.rewritePlan` result is returned.
if (rewritePlanMap.nonEmpty) {
assert(!plan.fastEquals(newPlan))
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
} else {
plan
}
// Variant 2: uses `transformUpWithNewOutput`; the case returns the replacement plan together
// with a list pairing each output attribute of the Deduplicate with the new Aggregate's output.
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case d @ Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
// Key columns pass through unchanged; other columns are wrapped in First(...) and aliased
// under the original name. No exprId is passed to Alias here -- presumably a fresh id is
// generated by default; TODO confirm.
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)()
}
}
// SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
// aggregations by checking the number of grouping keys. The key difference here is that a
// global aggregation always returns at least one row even if there are no input rows. Here
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
val newAgg = Aggregate(nonemptyKeys, aggCols, child)
// Positional pairing of the old output attributes with the new ones.
val attrMapping = d.output.zip(newAgg.output)
newAgg -> attrMapping
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{Analyzer, CleanupAliases}
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
Expand Down Expand Up @@ -342,11 +342,12 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
// NOTE(review): `subqueries` is declared twice below, first as a buffer of ScalarSubquery and
// then as a buffer of (ScalarSubquery, ExprId) pairs. This looks like the removed and added
// lines of a diff shown together -- confirm which signature applies.
subqueries: ArrayBuffer[ScalarSubquery]): E = {
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
val newExpression = expression transform {
// Only scalar subqueries that have children are matched.
case s: ScalarSubquery if s.children.nonEmpty =>
// Form 1: record the subquery and substitute the first output attribute of its plan.
subqueries += s
s.plan.output.head
// Form 2: obtain a new ExprId, record it alongside the subquery, and substitute the plan's
// first output attribute re-tagged with that id via `withExprId`.
val newExprId = NamedExpression.newExprId
subqueries += s -> newExprId
s.plan.output.head.withExprId(newExprId)
}
// The transformed tree is cast back to the caller's static expression type E.
newExpression.asInstanceOf[E]
}
Expand Down Expand Up @@ -513,17 +514,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, Seq[(LogicalPlan, LogicalPlan)]) = {
val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
val origOutput = query.output.head

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
if (resultWithZeroTups.isEmpty) {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ origOutput,
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
Expand All @@ -543,23 +543,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
val joinPlan = Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE)

def buildPlan(exprId: ExprId): LogicalPlan = {
Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId),
joinPlan)
}
Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = newExprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

val newPlan = buildPlan(origOutput.exprId)
rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
newPlan
} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
// Need to modify any operators below the join to pass through all columns
Expand All @@ -575,85 +568,66 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}

val joinPlan = Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE)

def buildPlan(exprId: ExprId): LogicalPlan = {
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId)

Project(
currentChild.output :+ caseExpr,
joinPlan)
}
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = newExprId)

Project(
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

val newPlan = buildPlan(origOutput.exprId)
rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
newPlan
}
}
}

(newPlan, rewritePlanMap)
}

/**
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = {
// NOTE(review): this span interleaves two versions of `apply` (one built on `transform` plus
// `rewritePlanMap`, one built on `transformUpWithNewOutput`). It looks like the removed and added
// sides of a diff shown without +/- markers -- confirm which lines belong to the intended version.
val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = plan transform {
// `transform` version, Aggregate case: subqueries are pulled out of the aggregate expressions,
// then joined onto the child by `constructLeftJoins`, which here also returns rewrite pairs.
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
rewritePlanMap ++= rewriteMap
Aggregate(newGrouping, newExpressions, newChild)
} else {
a
}
// `transform` version, Project case: same extraction, applied to the project list.
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
rewritePlanMap ++= rewriteMap
Project(newExpressions, newChild)
} else {
p
// `transformUpWithNewOutput` version starts here: each case yields the (possibly rewritten) plan
// paired with an attribute mapping, or with Nil when the node is left unchanged.
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
// NOTE(review): the replacement below is the subquery plan's original first output
// attribute, whereas the (ScalarSubquery, ExprId) form of extractCorrelatedScalarSubqueries
// substitutes an attribute carrying a new ExprId -- confirm both resolve to the same column.
val newGrouping = grouping.map { e =>
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
}
// `transform` version, Filter case: the rewritten Filter is wrapped in a Project over the
// original filter output.
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
rewritePlanMap ++= rewriteMap
Project(f.output, Filter(newCondition, newChild))
} else {
f
}
}

// `transform` version tail: when rewrite pairs were recorded they are passed to
// `Analyzer.rewritePlan` and the first element of its result is returned; otherwise `newPlan`.
if (rewritePlanMap.nonEmpty) {
assert(!plan.fastEquals(newPlan))
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
} else {
newPlan
}
// `transformUpWithNewOutput` version, remainder of the Aggregate case: the attribute mapping
// pairs the old and new outputs positionally.
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = a.output.zip(newAgg.output)
newAgg -> attrMapping
} else {
a -> Nil
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = p.output.zip(newProj.output)
newProj -> attrMapping
} else {
p -> Nil
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
// The Project restores the original filter output on top of the joined child.
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val attrMapping = f.output.zip(newProj.output)
newProj -> attrMapping
} else {
f -> Nil
}
}
}

0 comments on commit 3f21635

Please sign in to comment.