
Commit

move the plan rewrite methods to QueryPlan
cloud-fan committed Sep 3, 2020
1 parent a6114d8 commit 76cf567
Showing 4 changed files with 88 additions and 127 deletions.
Analyzer.scala
@@ -123,127 +123,6 @@ object AnalysisContext {
}
}

object Analyzer {

/**
* Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
* This method also updates all the related references in the `plan` accordingly.
*
* @param plan to rewrite
* @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
* @return a rewritten plan and updated references related to a root node of
* the given `plan` for rewriting it.
*/
def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
: (LogicalPlan, Seq[(Attribute, Attribute)]) = {
if (plan.resolved) {
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
val newChildren = plan.children.map { child =>
// If not, we'd rewrite child plan recursively until we find the
// conflict node or reach the leaf node.
val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
// `attrMapping` is not only used to replace the attributes of the current `plan`,
// but also to be propagated to the parent plans of the current `plan`. Therefore,
// the `oldAttr` must be part of either `plan.references` (so that it can be used to
// replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
// used by those parent plans).
(plan.outputSet ++ plan.references).contains(oldAttr)
}
newChild
}

val newPlan = if (rewritePlanMap.contains(plan)) {
rewritePlanMap(plan).withNewChildren(newChildren)
} else {
plan.withNewChildren(newChildren)
}

assert(!attrMapping.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMapping)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
val p = newPlan.transformExpressions {
case a: Attribute =>
updateAttr(a, attributeRewrites)
case s: SubqueryExpression =>
s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
}
attrMapping ++= plan.output.zip(p.output)
.filter { case (a1, a2) => a1.exprId != a2.exprId }
p -> attrMapping
} else {
// Just passes through unresolved nodes
plan.mapChildren {
rewritePlan(_, rewritePlanMap)._1
} -> Nil
}
}

private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
val exprId = attrMap.getOrElse(attr, attr).exprId
attr.withExprId(exprId)
}

/**
* The outer plan may have old references and the function below updates the
* outer references to refer to the new attributes.
*
* For example (SQL):
* {{{
* SELECT * FROM t1
* INTERSECT
* SELECT * FROM t1
* WHERE EXISTS (SELECT 1
* FROM t2
* WHERE t1.c1 = t2.c1)
* }}}
* Plan before resolveReference rule.
* 'Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- 'Project [*]
* +- Filter exists#257 [c1#245]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#245) = c1#251)
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#245,c2#246] parquet
* Plan after the resolveReference rule.
* Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- Project [c1#259, c2#260]
* +- Filter exists#257 [c1#259]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#259) = c1#251) => Updated
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
*/
private def updateOuterReferencesInSubquery(
plan: LogicalPlan,
attrMap: AttributeMap[Attribute]): LogicalPlan = {
AnalysisHelper.allowInvokingTransformsInAnalyzer {
plan transformDown { case currentFragment =>
currentFragment transformExpressions {
case OuterReference(a: Attribute) =>
OuterReference(updateAttr(a, attrMap))
case s: SubqueryExpression =>
s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
}
}
}
}
}

/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
* [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -1376,7 +1255,7 @@ class Analyzer(
if (conflictPlans.isEmpty) {
right
} else {
- Analyzer.rewritePlan(right, conflictPlans.toMap)._1
+ right.rewriteWithPlanMapping(conflictPlans.toMap)
}
}

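To make the change above concrete, here is a minimal sketch (not part of the commit) of what the new call does at a site like this: `conflictPlans` maps each conflicting node to a fresh copy, and `rewriteWithPlanMapping` replaces those nodes inside `right` while remapping every reference to their old expression IDs. The sketch assumes the Catalyst DSL used in the test suites is in scope.

    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
    import org.apache.spark.sql.types.IntegerType

    // Both sides of a hypothetical self-join read the same relation, so they share expr IDs.
    val rel = LocalRelation(AttributeReference("id", IntegerType)())
    val left = rel.select(rel.output.head)
    val right = rel.select(rel.output.head)
    // Map the conflicting node to a fresh copy with new expression IDs and rewrite the
    // right side; the Project above `rel` is patched to reference the new IDs.
    val conflictPlans: Map[LogicalPlan, LogicalPlan] = Map(rel -> rel.newInstance())
    val dedupedRight = right.rewriteWithPlanMapping(conflictPlans)
    assert(left.output.map(_.exprId).intersect(dedupedRight.output.map(_.exprId)).isEmpty)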
TypeCoercion.scala
@@ -364,7 +364,7 @@ object TypeCoercion {

if (rewritePlanMap.nonEmpty) {
assert(!plan.fastEquals(newPlan))
- Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
+ newPlan.rewriteWithPlanMapping(rewritePlanMap.toMap, _.resolved)
} else {
plan
}
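The second argument here is new: `canGetOutput` tells `rewriteWithPlanMapping` for which nodes it may call `.output` when computing the old-to-new attribute mapping, and type coercion passes `_.resolved` because the tree it rewrites can still contain unresolved operators. Below is a conceptual, self-contained sketch of the kind of rewrite `WidenSetOperationTypes` performs; it is not the rule's actual code, and the widened type is hard-coded for illustration.

    import scala.collection.mutable
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
    import org.apache.spark.sql.types.DecimalType

    val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
    val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
    val union = t1.union(t2)
    val query = union.select(union.output.head)    // references t1's original expr ID
    val widerType = DecimalType(11, 0)             // stand-in for the type the rule computes

    // Wrap each branch in a Project that casts it to the wider type and record old -> new.
    val rewritePlanMap = mutable.ArrayBuffer.empty[(LogicalPlan, LogicalPlan)]
    union.children.foreach { child =>
      val casted = Project(child.output.map(a => Alias(Cast(a, widerType), a.name)()), child)
      rewritePlanMap += child -> casted
    }
    // `_.resolved` guards the output comparison: while coercion runs, other parts of the
    // tree may still be unresolved and cannot safely report their output attributes.
    val rewritten = query.rewriteWithPlanMapping(rewritePlanMap.toMap, _.resolved)
    assert(rewritten.missingInput.isEmpty)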
QueryPlan.scala
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
@@ -168,6 +170,85 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}.toSeq
}


/**
* Rewrites this plan tree based on the given plan mappings from old plan nodes to new nodes.
* This method also updates all the related references in this plan tree accordingly, in case
* the replaced node has different output expr ID than the old node.
*/
def rewriteWithPlanMapping(
planMapping: Map[PlanType, PlanType],
canGetOutput: PlanType => Boolean = _ => true): PlanType = {
def internalRewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
if (planMapping.contains(plan)) {
val newPlan = planMapping(plan)
val attrMapping = if (canGetOutput(plan) && canGetOutput(newPlan)) {
plan.output.zip(newPlan.output).filter {
case (a1, a2) => a1.exprId != a2.exprId
}
} else {
Nil
}
newPlan -> attrMapping
} else {
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
val newPlan = plan.mapChildren { child =>
val (newChild, childAttrMapping) = internalRewrite(child)
attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
// `attrMapping` is not only used to replace the attributes of the current `plan`,
// but also to be propagated to the parent plans of the current `plan`. Therefore,
// the `oldAttr` must be part of either `plan.references` (so that it can be used to
// replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
// used by those parent plans).
(plan.outputSet ++ plan.references).contains(oldAttr)
}
newChild
}

if (attrMapping.isEmpty) {
newPlan -> Nil
} else {
assert(!attrMapping.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMapping)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
newPlan.transformExpressions {
case a: AttributeReference =>
updateAttr(a, attributeRewrites)
case pe: PlanExpression[PlanType] =>
pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
} -> attrMapping
}
}
}
internalRewrite(this)._1
}

private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
val exprId = attrMap.getOrElse(attr, attr).exprId
attr.withExprId(exprId)
}

/**
* The outer plan may have old references and the function below updates the
* outer references to refer to the new attributes.
*/
private def updateOuterReferencesInSubquery(
plan: PlanType,
attrMap: AttributeMap[Attribute]): PlanType = {
plan.transformDown { case currentFragment =>
currentFragment.transformExpressions {
case OuterReference(a: AttributeReference) =>
OuterReference(updateAttr(a, attrMap))
case pe: PlanExpression[PlanType] =>
pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
}
}
}

lazy val schema: StructType = StructType.fromAttributes(output)

/** Returns the output schema in the tree format. */
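The SQL walk-through that previously lived in Analyzer.updateOuterReferencesInSubquery's scaladoc (the INTERSECT/EXISTS example above) still describes what this method does. As a minimal, hand-built sketch (assumptions: Catalyst DSL in scope; the plan is constructed directly rather than produced by the analyzer), replacing a relation must also repoint the OuterReference inside a correlated subquery:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Exists, Literal, OuterReference}
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
    import org.apache.spark.sql.types.IntegerType

    val t1 = LocalRelation(AttributeReference("c1", IntegerType)())
    val t2 = LocalRelation(AttributeReference("c1", IntegerType)())
    // WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.c1 = t2.c1), with the correlation expressed
    // as an OuterReference to t1's attribute.
    val correlated = t2
      .where(OuterReference(t1.output.head) === t2.output.head)
      .select(Literal(1).as("one"))
    val plan = t1.where(Exists(correlated))

    // Replace t1 with a fresh copy; the rewrite must follow the subquery and point the
    // OuterReference at the new expression ID.
    val fresh = t1.newInstance()
    val rewritten = plan.rewriteWithPlanMapping(Map[LogicalPlan, LogicalPlan](t1 -> fresh))
    val outerIds = rewritten.expressions.collect { case e: Exists => e }.flatMap { ex =>
      ex.plan.flatMap(_.expressions.flatMap(_.collect {
        case OuterReference(a: AttributeReference) => a.exprId
      }))
    }
    assert(outerIds == Seq(fresh.output.head.exprId))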
TypeCoercionSuite.scala
@@ -1419,12 +1419,13 @@ class TypeCoercionSuite extends AnalysisTest {
test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
- val p1 = t1.select(t1.output.head)
- val p2 = t2.select(t2.output.head)
+ val p1 = t1.select(t1.output.head).as("p1")
+ val p2 = t2.select(t2.output.head).as("p2")
val union = p1.union(p2)
- val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+ val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
assert(wp1.isInstanceOf[Project])
assert(wp1.missingInput.isEmpty)
// The attribute `p1.output.head` should be replaced in the root `Project`.
assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
assert(wp2.isInstanceOf[Aggregate])
assert(wp2.missingInput.isEmpty)
