Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-33036][SQL] Refactor RewriteCorrelatedScalarSubquery code to replace exprIds in a bottom-up manner #29913

Closed
Closed
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -338,20 +338,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
/**
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
* the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
* for the subqueries and rewrite references in the given `expression`.
* This method returns extracted subqueries and the corresponding `exprId`s and these values
* will be used later in `constructLeftJoins` for building the child plan that
* returns subquery output with the `exprId`s.
* the given collector. The expression is rewritten and returned.
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
subqueries: ArrayBuffer[ScalarSubquery]): E = {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
val newExprId = NamedExpression.newExprId
subqueries += s -> newExprId
s.plan.output.head.withExprId(newExprId)
subqueries += s
s.plan.output.head
}
newExpression.asInstanceOf[E]
}
Expand Down Expand Up @@ -512,19 +507,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

/**
* Construct a new child plan by left joining the given subqueries to a base plan.
* This method returns the child plan and an attribute mapping
* for the updated `ExprId`s of subqueries. If the non-empty mapping returned,
* this rule will rewrite subquery references in a parent plan based on it.
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
val newChild = subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
val origOutput = query.output.head

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
if (resultWithZeroTups.isEmpty) {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
currentChild.output :+ origOutput,
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
Expand All @@ -544,12 +543,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {

if (havingNode.isEmpty) {
// CASE 2: Subquery with no HAVING clause
val subqueryResultExpr =
Alias(If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)()
subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
Project(
currentChild.output :+
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = newExprId),
currentChild.output :+ subqueryResultExpr,
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
Expand All @@ -576,7 +576,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = newExprId)
origOutput.name)()

subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))

Project(
currentChild.output :+ caseExpr,
Expand All @@ -587,6 +589,20 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
}
}
}
(newChild, AttributeMap(subqueryAttrMapping.toSeq))
}

private def updateAttrs[E <: Expression](
exprs: Seq[E],
attrMap: AttributeMap[Attribute]): Seq[E] = {
if (attrMap.nonEmpty) {
val newExprs = exprs.map { _.transform {
case a: AttributeReference if attrMap.contains(a) => attrMap(a)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not a big deal: it's more efficient to write case a: AttributeReference => attrMap.getOrElse(a, a)

}}
newExprs.asInstanceOf[Seq[E]]
} else {
exprs
}
}

/**
Expand All @@ -595,36 +611,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
}
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
val newAgg = Aggregate(newGrouping, newExprs, newChild)
val attrMapping = a.output.zip(newAgg.output)
newAgg -> attrMapping
} else {
a -> Nil
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
val newProj = Project(newExprs, newChild)
val attrMapping = p.output.zip(newProj.output)
newProj -> attrMapping
} else {
p -> Nil
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val rewriteCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
val newCondition = updateAttrs(Seq(rewriteCondition), subqueryAttrMapping).head
val newProj = Project(f.output, Filter(newCondition, newChild))
val attrMapping = f.output.zip(newProj.output)
newProj -> attrMapping
} else {
Expand Down