Commit 9bcc4e0

Fix

maropu committed Sep 8, 2020
1 parent b3524b6
Showing 2 changed files with 81 additions and 119 deletions.
@@ -1580,36 +1580,25 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
* Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
*/
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = plan transform {
case Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr -> attr
} else {
val alias = Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
alias -> alias.newInstance()
}
}.unzip
// SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
// aggregations by checking the number of grouping keys. The key difference here is that a
// global aggregation always returns at least one row even if there are no input rows. Here
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
val newAgg = Aggregate(nonemptyKeys, aggCols._1, child)
rewritePlanMap += newAgg -> Aggregate(nonemptyKeys, aggCols._2, child)
newAgg
}

if (rewritePlanMap.nonEmpty) {
assert(!plan.fastEquals(newPlan))
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
} else {
plan
}
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case d @ Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)()
}
}
// SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
// aggregations by checking the number of grouping keys. The key difference here is that a
// global aggregation always returns at least one row even if there are no input rows. Here
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
val newAgg = Aggregate(nonemptyKeys, aggCols, child)
val attrMapping = d.output.zip(newAgg.output)
newAgg -> attrMapping
}
}
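
For illustration, a minimal sketch of the behavior ReplaceDeduplicateWithAggregate implements, assuming a local Spark 3.x session; the demo object name is made up and not part of this commit. A batch dropDuplicates over a key subset should come out of the optimizer as an Aggregate that groups by the key and keeps first(...) of the remaining columns, and with the change above the rule hands the old-to-new attribute pairs to transformUpWithNewOutput instead of collecting a rewritePlanMap for Analyzer.rewritePlan.

import org.apache.spark.sql.SparkSession

// Hypothetical demo object, not part of the commit.
object DeduplicateRewriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("dedup-demo").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (1, "b"), (2, "c")).toDF("k", "v")

    // Batch deduplication on a key subset; the rule above should rewrite the
    // logical Deduplicate into an Aggregate grouped by k with a first(v) aggregate.
    val deduped = df.dropDuplicates("k")
    println(deduped.queryExecution.optimizedPlan.treeString)

    spark.stop()
  }
}

Running the program prints the optimized plan, which should contain an Aggregate node rather than a Deduplicate node.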

@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{Analyzer, CleanupAliases}
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -513,9 +513,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, Seq[(LogicalPlan, LogicalPlan)]) = {
val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = subqueries.foldLeft(child) {
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
val origOutput = query.output.head

@@ -543,23 +542,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
val joinPlan = Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE)

def buildPlan(exprId: ExprId): LogicalPlan = {
Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId),
joinPlan)
}
Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = origOutput.exprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

val newPlan = buildPlan(origOutput.exprId)
rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
newPlan
} else {
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
// Need to modify any operators below the join to pass through all columns
@@ -575,85 +567,66 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
case op => sys.error(s"Unexpected operator $op in corelated subquery")
}

val joinPlan = Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE)

def buildPlan(exprId: ExprId): LogicalPlan = {
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId)

Project(
currentChild.output :+ caseExpr,
joinPlan)
}
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
// ELSE (aggregate value) END AS (original column name)
val caseExpr = Alias(CaseWhen(Seq(
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = origOutput.exprId)

Project(
currentChild.output :+ caseExpr,
Join(currentChild,
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

val newPlan = buildPlan(origOutput.exprId)
rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
newPlan
}
}
}

(newPlan, rewritePlanMap)
}
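
For orientation, a hedged, self-contained sketch of the two shapes constructLeftJoins handles, assuming a local Spark session; the table and object names are illustrative and not part of the diff. CASE 2 is a correlated scalar subquery with no HAVING clause, and CASE 3 adds a HAVING clause that is pulled above the left outer join.

import org.apache.spark.sql.SparkSession

// Hypothetical demo, assuming a local Spark session.
object ScalarSubqueryRewriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("subquery-demo").getOrCreate()
    import spark.implicits._

    Seq((1, 10), (2, 20)).toDF("k", "v").createOrReplaceTempView("t1")
    Seq((1, 100), (1, 200)).toDF("k", "v").createOrReplaceTempView("t2")

    // CASE 2: no HAVING clause; expected to become a left outer join against the
    // aggregated subquery plus a Project that restores the original output column.
    val noHaving = spark.sql(
      "SELECT k, (SELECT MAX(v) FROM t2 WHERE t2.k = t1.k) AS max_v FROM t1")
    println(noHaving.queryExecution.optimizedPlan.treeString)

    // CASE 3: HAVING clause; expected to surface as the CASE WHEN expression
    // built above, on top of the same left outer join.
    val withHaving = spark.sql(
      "SELECT k, (SELECT MAX(v) FROM t2 WHERE t2.k = t1.k HAVING MAX(v) > 150) AS big_v FROM t1")
    println(withHaving.queryExecution.optimizedPlan.treeString)

    spark.stop()
  }
}

Note how the new code keeps origOutput.exprId on the Alias in both cases, so constructLeftJoins no longer needs to return a rewrite map; the enclosing apply computes the attribute mapping from the operator's old and new output.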

/**
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = {
val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = plan transform {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
rewritePlanMap ++= rewriteMap
Aggregate(newGrouping, newExpressions, newChild)
} else {
a
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
rewritePlanMap ++= rewriteMap
Project(newExpressions, newChild)
} else {
p
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
rewritePlanMap ++= rewriteMap
Project(f.output, Filter(newCondition, newChild))
} else {
f
}
}

if (rewritePlanMap.nonEmpty) {
assert(!plan.fastEquals(newPlan))
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
} else {
newPlan
}
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = a.output.zip(newAgg.output)
newAgg -> attrMapping
} else {
a -> Nil
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = p.output.zip(newProj.output)
newProj -> attrMapping
} else {
p -> Nil
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val attrMapping = f.output.zip(newProj.output)
newProj -> attrMapping
} else {
f -> Nil
}
}
}
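
One more hedged sketch (illustrative names, assuming a local Spark session, not part of the diff): a correlated scalar subquery in a WHERE clause exercises the Filter branch above, which rebuilds the Filter on top of the left outer join and wraps it in Project(f.output, ...) so the operator's visible output stays the same while transformUpWithNewOutput propagates any new ExprIds to referencing ancestors.

import org.apache.spark.sql.SparkSession

// Hypothetical demo, assuming a local Spark session.
object FilterSubqueryRewriteDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("filter-subquery-demo").getOrCreate()
    import spark.implicits._

    Seq((1, 10), (2, 20)).toDF("k", "v").createOrReplaceTempView("t1")
    Seq((1, 100), (1, 200)).toDF("k", "v").createOrReplaceTempView("t2")

    // The scalar subquery in the WHERE clause goes through the Filter case of the
    // rule: the subquery is turned into a left outer join under the Filter, and a
    // Project keeps only t1's original columns.
    val q = spark.sql(
      "SELECT k, v FROM t1 WHERE v < (SELECT MAX(v) FROM t2 WHERE t2.k = t1.k)")
    println(q.queryExecution.optimizedPlan.treeString)

    spark.stop()
  }
}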
