[SPARK-32638][SQL][FOLLOWUP] Move the plan rewriting methods to QueryPlan #29643

Closed · wants to merge 2 commits
@@ -123,127 +123,6 @@ object AnalysisContext {
}
}

object Analyzer {

/**
* Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
* This method also updates all the related references in the `plan` accordingly.
*
* @param plan to rewrite
* @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
* @return a rewritten plan and updated references related to a root node of
* the given `plan` for rewriting it.
*/
def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
: (LogicalPlan, Seq[(Attribute, Attribute)]) = {
if (plan.resolved) {
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
val newChildren = plan.children.map { child =>
// If not, we'd rewrite child plan recursively until we find the
// conflict node or reach the leaf node.
val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
// `attrMapping` is not only used to replace the attributes of the current `plan`,
// but also to be propagated to the parent plans of the current `plan`. Therefore,
// the `oldAttr` must be part of either `plan.references` (so that it can be used to
// replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
// used by those parent plans).
(plan.outputSet ++ plan.references).contains(oldAttr)
}
newChild
}

val newPlan = if (rewritePlanMap.contains(plan)) {
rewritePlanMap(plan).withNewChildren(newChildren)
} else {
plan.withNewChildren(newChildren)
}

assert(!attrMapping.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMapping)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
val p = newPlan.transformExpressions {
case a: Attribute =>
updateAttr(a, attributeRewrites)
case s: SubqueryExpression =>
s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
}
attrMapping ++= plan.output.zip(p.output)
.filter { case (a1, a2) => a1.exprId != a2.exprId }
p -> attrMapping
} else {
// Just passes through unresolved nodes
plan.mapChildren {
Contributor Author: This means that we won't replace attributes in an unresolved plan, which is not sufficient. See the updated test: https://github.com/apache/spark/pull/29643/files#diff-01ecdd038c5c2f53f38118912210fef8R1425

Member: Good catch! In this unresolved plan, there might be other resolved and replaced attributes.

Member: Ah, I see. Nice catch.

rewritePlan(_, rewritePlanMap)._1
} -> Nil
}
}

private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
val exprId = attrMap.getOrElse(attr, attr).exprId
attr.withExprId(exprId)
}

/**
* The outer plan may have old references and the function below updates the
* outer references to refer to the new attributes.
*
* For example (SQL):
Contributor Author: The example here is not useful at all. The first sentence already explains the reason very well, while the query plan example is hard to read.

* {{{
* SELECT * FROM t1
* INTERSECT
* SELECT * FROM t1
* WHERE EXISTS (SELECT 1
* FROM t2
* WHERE t1.c1 = t2.c1)
* }}}
* Plan before resolveReference rule.
* 'Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- 'Project [*]
* +- Filter exists#257 [c1#245]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#245) = c1#251)
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#245,c2#246] parquet
* Plan after the resolveReference rule.
* Intersect
* :- Project [c1#245, c2#246]
* : +- SubqueryAlias t1
* : +- Relation[c1#245,c2#246] parquet
* +- Project [c1#259, c2#260]
* +- Filter exists#257 [c1#259]
* : +- Project [1 AS 1#258]
* : +- Filter (outer(c1#259) = c1#251) => Updated
* : +- SubqueryAlias t2
* : +- Relation[c1#251,c2#252] parquet
* +- SubqueryAlias t1
* +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
*/
private def updateOuterReferencesInSubquery(
plan: LogicalPlan,
attrMap: AttributeMap[Attribute]): LogicalPlan = {
AnalysisHelper.allowInvokingTransformsInAnalyzer {
plan transformDown { case currentFragment =>
currentFragment transformExpressions {
case OuterReference(a: Attribute) =>
OuterReference(updateAttr(a, attrMap))
case s: SubqueryExpression =>
s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
}
}
}
}
}

/**
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
* [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -1376,7 +1255,14 @@ class Analyzer(
if (conflictPlans.isEmpty) {
right
} else {
Analyzer.rewritePlan(right, conflictPlans.toMap)._1
val planMapping = conflictPlans.toMap
right.transformUpWithNewOutput {
case oldPlan =>
val newPlanOpt = planMapping.get(oldPlan)
newPlanOpt.map { newPlan =>
newPlan -> oldPlan.output.zip(newPlan.output)
}.getOrElse(oldPlan -> Nil)
}
}
}
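
For illustration, here is a minimal sketch (not part of this diff) of what the rewritten dedup call above does, assuming the catalyst DSL; the relation, attribute names, and expr IDs in the comments are illustrative rather than taken from the PR.

```scala
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

// Both sides of a self join initially expose the same expr ID, say a#1.
val a = AttributeReference("a", IntegerType)()
val conflicting = LocalRelation(a)
val right: LogicalPlan = conflicting.select(a)

// conflictPlans maps the old relation to a re-instanced copy with a fresh expr ID, say a#2.
val renewed = LocalRelation(a.newInstance())
val planMapping: Map[LogicalPlan, LogicalPlan] = Map(conflicting -> renewed)

// transformUpWithNewOutput swaps in the new relation and, using the returned
// attribute pairs, rewrites the Project above it to reference a#2 instead of a#1.
val dedupedRight = right.transformUpWithNewOutput {
  case oldPlan if planMapping.contains(oldPlan) =>
    val newPlan = planMapping(oldPlan)
    newPlan -> oldPlan.output.zip(newPlan.output)
}
```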

@@ -329,50 +329,43 @@ object TypeCoercion {
object WidenSetOperationTypes extends TypeCoercionRule {

override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
val newPlan = plan resolveOperatorsUp {
plan resolveOperatorsUpWithNewOutput {
case s @ Except(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
if (newChildren.nonEmpty) {
rewritePlanMap ++= newChildren
Except(newChildren.head._1, newChildren.last._1, isAll)
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
if (newChildren.isEmpty) {
s -> Nil
} else {
s
assert(newChildren.length == 2)
val attrMapping = left.output.zip(newChildren.head.output)
Except(newChildren.head, newChildren.last, isAll) -> attrMapping
}

case s @ Intersect(left, right, isAll) if s.childrenResolved &&
left.output.length == right.output.length && !s.resolved =>
val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
if (newChildren.nonEmpty) {
rewritePlanMap ++= newChildren
Intersect(newChildren.head._1, newChildren.last._1, isAll)
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
if (newChildren.isEmpty) {
s -> Nil
} else {
s
assert(newChildren.length == 2)
val attrMapping = left.output.zip(newChildren.head.output)
Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
}

case s: Union if s.childrenResolved && !s.byName &&
s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
val newChildren = buildNewChildrenWithWiderTypes(s.children)
if (newChildren.nonEmpty) {
rewritePlanMap ++= newChildren
s.copy(children = newChildren.map(_._1))
val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
if (newChildren.isEmpty) {
s -> Nil
} else {
s
val attrMapping = s.children.head.output.zip(newChildren.head.output)
s.copy(children = newChildren) -> attrMapping
}
}

if (rewritePlanMap.nonEmpty) {
assert(!plan.fastEquals(newPlan))
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
} else {
plan
}
}
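
As a minimal sketch (not part of this diff) of the mapping this rule now returns: when a Union's children disagree on decimal precision, the narrower side gets a Project of casts with fresh expr IDs, and the mapping pairs the original left output with that Project's output. The construction below assumes the catalyst DSL, and the expr ID numbers in the comments are illustrative.

```scala
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.types.DecimalType

val v1 = AttributeReference("v", DecimalType(10, 0))()   // left child output, say v#1
val v2 = AttributeReference("v", DecimalType(11, 0))()   // right child output, say v#2
val t1 = LocalRelation(v1)
val t2 = LocalRelation(v2)
val union = t1.union(t2)

// Roughly what buildNewChildrenWithWiderTypes produces for the narrower side:
// a Project of casts whose alias carries a fresh expr ID, say v#3.
val widenedLeft = Project(Seq(Alias(Cast(v1, DecimalType(11, 0)), "v")()), t1)

// The mapping returned alongside the rewritten Union: (v#1 -> v#3). Any parent of the
// Union that still references v#1 (e.g. a Project or Aggregate above it) is rewritten
// by resolveOperatorsUpWithNewOutput to reference v#3 instead.
val attrMapping = union.children.head.output.zip(widenedLeft.output)
```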

/** Build new children with the widest types for each attribute among all the children */
private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
: Seq[(LogicalPlan, LogicalPlan)] = {
private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
require(children.forall(_.output.length == children.head.output.length))

// Get a sequence of data types, each of which is the widest type of this specific attribute
Expand Down Expand Up @@ -408,16 +401,13 @@ object TypeCoercion {
}

/** Given a plan, add an extra project on top to widen some columns' data types. */
private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
: (LogicalPlan, LogicalPlan) = {
private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
val casted = plan.output.zip(targetTypes).map {
case (e, dt) if e.dataType != dt =>
val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
alias -> alias.newInstance()
case (e, _) =>
e -> e
}.unzip
Project(casted._1, plan) -> Project(casted._2, plan)
Alias(Cast(e, dt, Some(SQLConf.get.sessionLocalTimeZone)), e.name)()
Member: Out of curiosity, why do we need to set the time zone here?

Contributor Author: Otherwise WidenSetOperationTypes will return an invalid attribute mapping (an unresolved Alias with an unresolved Cast) when calling resolveOperatorsUpWithNewOutput.

case (e, _) => e
}
Project(casted, plan)
}
}
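
On the time zone question above, a minimal illustration of why the explicit Cast time zone matters for casts that need one (string to timestamp here): without it the Cast, and therefore the Alias on top of it, stays unresolved, which would make the attribute mapping handed to resolveOperatorsUpWithNewOutput unusable. The attribute name is an assumption for illustration.

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StringType, TimestampType}

val c = AttributeReference("c", StringType)()

// A string-to-timestamp cast needs a time zone; without one the cast stays unresolved,
// and so does the alias built on top of it.
val noTz = Alias(Cast(c, TimestampType), "c")()
assert(!noTz.resolved)

// With the session-local time zone the alias resolves eagerly, so its attribute is safe
// to use in the mapping returned to resolveOperatorsUpWithNewOutput.
val withTz = Alias(Cast(c, TimestampType, Some(SQLConf.get.sessionLocalTimeZone)), "c")()
assert(withTz.resolved)
```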

@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
@@ -168,6 +170,89 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}.toSeq
}

/**
* A variant of `transformUp`, which takes care of the case that the rule replaces a plan node
* with a new one that has different output expr IDs, by updating the attribute references in
* the parent nodes accordingly.
*
* @param rule the function to transform plan nodes, and return new nodes with attributes mapping
Member: Just a question: why do we need to return the attribute mapping from old to new? Can we just detect whether the output of the new plan differs from that of the old plan, and create the mapping inside transformUpWithNewOutput?

Contributor Author: Because that's too hard. For example, WidenSetOperationTypes returns the attribute mapping based on the replaced children, not the node itself, because the node itself may not be resolved yet, while for the self-join dedup we return the attribute mapping based on the current node.

* from old attributes to new attributes. The attribute mapping will be used to
* rewrite attribute references in the parent nodes.
* @param skipCond a boolean condition to indicate if we can skip transforming a plan node to save
* time.
*/
def transformUpWithNewOutput(
rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])],
skipCond: PlanType => Boolean = _ => false): PlanType = {
def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
if (skipCond(plan)) {
plan -> Nil
} else {
val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
var newPlan = plan.mapChildren { child =>
val (newChild, childAttrMapping) = rewrite(child)
attrMapping ++= childAttrMapping
newChild
}

val attrMappingForCurrentPlan = attrMapping.filter {
// The `attrMappingForCurrentPlan` is used to replace the attributes of the
// current `plan`, so the `oldAttr` must be part of `plan.references`.
case (oldAttr, _) => plan.references.contains(oldAttr)
Contributor: Shall we skip this if the child is not resolved? Although that would break the one-shot rewrite idea. The reason is that calling .references on an unresolved plan is dangerous: the plan may use child.outputSet as its references.

Contributor Author: Can we add a base trait for plans that override references with child.outputSet? Then we can match on that trait here and skip calling references.

Contributor Author: Ideally, a plan should determine its references from its own expressions, not from its child's output attributes.

Contributor: Adding a new base trait sounds good.

Contributor: I sent a PR (#40154) for it.

}

val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
}
newPlan = planAfterRule

if (attrMappingForCurrentPlan.nonEmpty) {
assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
newPlan = newPlan.transformExpressions {
case a: AttributeReference =>
updateAttr(a, attributeRewrites)
case pe: PlanExpression[PlanType] =>
pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
}
}

attrMapping ++= newAttrMapping.filter {
case (a1, a2) => a1.exprId != a2.exprId
}
newPlan -> attrMapping
}
}
rewrite(this)._1
}
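
A minimal usage sketch (not part of this diff) of the contract described above: the rule returns the rewritten node together with old-to-new attribute pairs, and transformUpWithNewOutput carries those pairs upward to patch attribute references in every ancestor that uses them. The plan shape, names, and expr IDs in the comments are assumptions for illustration.

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val rel = LocalRelation(a)
// Filter(a#1 > 0, Project(a#1, rel)): both ancestors reference the relation's expr ID.
val plan: LogicalPlan = Filter(GreaterThan(a, Literal(0)), Project(Seq(a), rel))

// The rule only touches the leaf, returning a copy with a fresh expr ID (say a#2) plus
// the old-to-new pairs; the mapping is propagated upward so both the Project and the
// Filter end up referencing a#2, and it is only applied where the old attribute is
// actually referenced.
val rewritten = plan.transformUpWithNewOutput {
  case r: LocalRelation if r.fastEquals(rel) =>
    val fresh = LocalRelation(r.output.map(_.newInstance()): _*)
    fresh -> r.output.zip(fresh.output)
}
```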

private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
val exprId = attrMap.getOrElse(attr, attr).exprId
attr.withExprId(exprId)
}

/**
* The outer plan may have old references and the function below updates the
* outer references to refer to the new attributes.
*/
private def updateOuterReferencesInSubquery(
plan: PlanType,
attrMap: AttributeMap[Attribute]): PlanType = {
plan.transformDown { case currentFragment =>
currentFragment.transformExpressions {
case OuterReference(a: AttributeReference) =>
OuterReference(updateAttr(a, attrMap))
case pe: PlanExpression[PlanType] =>
pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
}
}
}

lazy val schema: StructType = StructType.fromAttributes(output)

/** Returns the output schema in the tree format. */
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
import org.apache.spark.util.Utils
@@ -120,6 +120,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
}
}

/**
* A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
*/
def resolveOperatorsUpWithNewOutput(
rule: PartialFunction[LogicalPlan, (LogicalPlan, Seq[(Attribute, Attribute)])])
: LogicalPlan = {
if (!analyzed) {
transformUpWithNewOutput(rule, skipCond = _.analyzed)
} else {
self
}
}

/**
* Recursively transforms the expressions of a tree, skipping nodes that have already
* been analyzed.
@@ -1419,12 +1419,13 @@ class TypeCoercionSuite extends AnalysisTest {
test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
val p1 = t1.select(t1.output.head)
val p2 = t2.select(t2.output.head)
val p1 = t1.select(t1.output.head).as("p1")
val p2 = t2.select(t2.output.head).as("p2")
val union = p1.union(p2)
val wp1 = widenSetOperationTypes(union.select(p1.output.head))
val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
assert(wp1.isInstanceOf[Project])
assert(wp1.missingInput.isEmpty)
// The attribute `p1.output.head` should be replaced in the root `Project`.
assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
assert(wp2.isInstanceOf[Aggregate])
assert(wp2.missingInput.isEmpty)