diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1d2e48301ea98..020102dbfa5e8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -123,127 +123,6 @@ object AnalysisContext {
   }
 }
 
-object Analyzer {
-
-  /**
-   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
-   * This method also updates all the related references in the `plan` accordingly.
-   *
-   * @param plan to rewrite
-   * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
-   * @return a rewritten plan and updated references related to a root node of
-   *         the given `plan` for rewriting it.
-   */
-  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
-    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-    if (plan.resolved) {
-      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-      val newChildren = plan.children.map { child =>
-        // If not, we'd rewrite child plan recursively until we find the
-        // conflict node or reach the leaf node.
-        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
-        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-          // `attrMapping` is not only used to replace the attributes of the current `plan`,
-          // but also to be propagated to the parent plans of the current `plan`. Therefore,
-          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-          // used by those parent plans).
-          (plan.outputSet ++ plan.references).contains(oldAttr)
-        }
-        newChild
-      }
-
-      val newPlan = if (rewritePlanMap.contains(plan)) {
-        rewritePlanMap(plan).withNewChildren(newChildren)
-      } else {
-        plan.withNewChildren(newChildren)
-      }
-
-      assert(!attrMapping.groupBy(_._1.exprId)
-        .exists(_._2.map(_._2.exprId).distinct.length > 1),
-        "Found duplicate rewrite attributes")
-
-      val attributeRewrites = AttributeMap(attrMapping)
-      // Using attrMapping from the children plans to rewrite their parent node.
-      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-      val p = newPlan.transformExpressions {
-        case a: Attribute =>
-          updateAttr(a, attributeRewrites)
-        case s: SubqueryExpression =>
-          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
-      }
-      attrMapping ++= plan.output.zip(p.output)
-        .filter { case (a1, a2) => a1.exprId != a2.exprId }
-      p -> attrMapping
-    } else {
-      // Just passes through unresolved nodes
-      plan.mapChildren {
-        rewritePlan(_, rewritePlanMap)._1
-      } -> Nil
-    }
-  }
-
-  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-    val exprId = attrMap.getOrElse(attr, attr).exprId
-    attr.withExprId(exprId)
-  }
-
-  /**
-   * The outer plan may have old references and the function below updates the
-   * outer references to refer to the new attributes.
-   *
-   * For example (SQL):
-   * {{{
-   *   SELECT * FROM t1
-   *   INTERSECT
-   *   SELECT * FROM t1
-   *   WHERE EXISTS (SELECT 1
-   *                 FROM t2
-   *                 WHERE t1.c1 = t2.c1)
-   * }}}
-   * Plan before resolveReference rule.
-   *    'Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- 'Project [*]
-   *       +- Filter exists#257 [c1#245]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#245) = c1#251)
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#245,c2#246] parquet
-   * Plan after the resolveReference rule.
-   *    Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- Project [c1#259, c2#260]
-   *       +- Filter exists#257 [c1#259]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
-   */
-  private def updateOuterReferencesInSubquery(
-      plan: LogicalPlan,
-      attrMap: AttributeMap[Attribute]): LogicalPlan = {
-    AnalysisHelper.allowInvokingTransformsInAnalyzer {
-      plan transformDown { case currentFragment =>
-        currentFragment transformExpressions {
-          case OuterReference(a: Attribute) =>
-            OuterReference(updateAttr(a, attrMap))
-          case s: SubqueryExpression =>
-            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
-        }
-      }
-    }
-  }
-}
-
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
  * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
  */
@@ -1376,7 +1255,7 @@ class Analyzer(
     if (conflictPlans.isEmpty) {
       right
     } else {
-      Analyzer.rewritePlan(right, conflictPlans.toMap)._1
+      right.rewriteWithPlanMapping(conflictPlans.toMap)
    }
  }
 
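For context on the call site above: `dedupRight` maps each conflicting node on the right side of a join to a copy with fresh expr IDs, and the new `QueryPlan.rewriteWithPlanMapping` (introduced below) replaces the old `Analyzer.rewritePlan(...)._1` call. A minimal standalone sketch of that usage; the relation and attribute here are hypothetical, not taken from the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

// A self-join-style conflict: the same relation (and thus the same expr IDs)
// appears on both sides, so the right side must be re-keyed to fresh expr IDs.
val a = AttributeReference("a", IntegerType)()
val rel = LocalRelation(a)
// `newInstance()` keeps the schema but mints new expr IDs.
val conflictPlans: Seq[(LogicalPlan, LogicalPlan)] = Seq(rel -> rel.newInstance())
// Equivalent to the old `Analyzer.rewritePlan(right, conflictPlans.toMap)._1`:
val dedupedRight = rel.rewriteWithPlanMapping(conflictPlans.toMap)
assert(dedupedRight.output.head.exprId != a.exprId)
```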
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 861eddedc0e1b..5a7d1a305b2d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -364,7 +364,7 @@ object TypeCoercion {
 
       if (rewritePlanMap.nonEmpty) {
         assert(!plan.fastEquals(newPlan))
-        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
+        newPlan.rewriteWithPlanMapping(rewritePlanMap.toMap, _.resolved)
       } else {
         plan
       }
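The second argument here is the `canGetOutput` guard: during type coercion the tree may still contain unresolved nodes whose `output` is not yet well-defined, so `_.resolved` restricts attribute-mapping collection to nodes that can safely report their output. A hedged sketch of the kind of rewrite `WidenSetOperationTypes` drives; the cast projection below is hand-built for illustration, whereas the real rule constructs it internally:

```scala
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.DecimalType

// decimal(10, 0) UNION decimal(11, 0): the narrower side gets wrapped in a
// Project that casts to the wider type, which mints a new expr ID.
val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
val v1 = t1.output.head
val widenedT1: LogicalPlan = t1.select(Alias(Cast(v1, DecimalType(11, 0)), "v")())
// Rewrite the plan; the parent Project referencing the old `v1` is updated.
// `_.resolved` skips output collection on any node that is not resolved yet.
val plan = t1.union(t2).select(v1)
val rewritten = plan.rewriteWithPlanMapping(
  Map[LogicalPlan, LogicalPlan](t1 -> widenedT1), _.resolved)
assert(rewritten.missingInput.isEmpty)
```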
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 7133fb231d672..215c178f62729 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
@@ -168,6 +170,85 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     }.toSeq
   }
 
+  /**
+   * Rewrites this plan tree based on the given plan mappings from old plan nodes to new ones.
+   * This method also updates all the related references in this plan tree accordingly, in case
+   * the replaced node has different output expr IDs than the old node.
+   */
+  def rewriteWithPlanMapping(
+      planMapping: Map[PlanType, PlanType],
+      canGetOutput: PlanType => Boolean = _ => true): PlanType = {
+    def internalRewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
+      if (planMapping.contains(plan)) {
+        val newPlan = planMapping(plan)
+        val attrMapping = if (canGetOutput(plan) && canGetOutput(newPlan)) {
+          plan.output.zip(newPlan.output).filter {
+            case (a1, a2) => a1.exprId != a2.exprId
+          }
+        } else {
+          Nil
+        }
+        newPlan -> attrMapping
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        val newPlan = plan.mapChildren { child =>
+          val (newChild, childAttrMapping) = internalRewrite(child)
+          attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
+            // `attrMapping` is not only used to replace the attributes of the current `plan`,
+            // but also to be propagated to the parent plans of the current `plan`. Therefore,
+            // the `oldAttr` must be part of either `plan.references` (so that it can be used to
+            // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
+            // used by those parent plans).
+            (plan.outputSet ++ plan.references).contains(oldAttr)
+          }
+          newChild
+        }
+
+        if (attrMapping.isEmpty) {
+          newPlan -> Nil
+        } else {
+          assert(!attrMapping.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+
+          val attributeRewrites = AttributeMap(attrMapping)
+          // Using attrMapping from the children plans to rewrite their parent node.
+          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+          newPlan.transformExpressions {
+            case a: AttributeReference =>
+              updateAttr(a, attributeRewrites)
+            case pe: PlanExpression[PlanType] =>
+              pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
+          } -> attrMapping
+        }
+      }
+    }
+    internalRewrite(this)._1
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: PlanType,
+      attrMap: AttributeMap[Attribute]): PlanType = {
+    plan.transformDown { case currentFragment =>
+      currentFragment.transformExpressions {
+        case OuterReference(a: AttributeReference) =>
+          OuterReference(updateAttr(a, attrMap))
+        case pe: PlanExpression[PlanType] =>
+          pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
+      }
+    }
+  }
+
   lazy val schema: StructType = StructType.fromAttributes(output)
 
   /** Returns the output schema in the tree format. */
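The `PlanExpression` branch is what keeps subqueries consistent: when a node's attributes are rewritten, any `OuterReference` inside a subquery expression must be updated too, mirroring the worked INTERSECT/EXISTS example in the comment removed from `Analyzer` above. A small hand-rolled sketch, assuming the Catalyst DSL; the plan shapes are hypothetical:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Exists, OuterReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

val outerCol = AttributeReference("c1", IntegerType)()
val innerCol = AttributeReference("c1", IntegerType)()
val outerRel = LocalRelation(outerCol)
val innerRel = LocalRelation(innerCol)
// WHERE EXISTS (SELECT ... WHERE outer(c1) = inner.c1)
val plan = outerRel.where(Exists(innerRel.where(OuterReference(outerCol) === innerCol)))
// Replace the outer relation with a fresh-expr-ID copy.
val rewritten = plan.rewriteWithPlanMapping(
  Map[LogicalPlan, LogicalPlan](outerRel -> outerRel.newInstance()))
// The outer reference inside the EXISTS plan now carries the new expr ID.
val newId = rewritten.children.head.output.head.exprId
val outerRefIds = rewritten.expressions.flatMap(_.collect {
  case Exists(subPlan, _, _) => subPlan.expressions.flatMap(_.collect {
    case OuterReference(a) => a.exprId
  })
}).flatten
assert(outerRefIds.forall(_ == newId))
```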
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 1af562fd1a061..7b80de908fa08 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1419,12 +1419,13 @@ class TypeCoercionSuite extends AnalysisTest {
   test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
     val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
     val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
-    val p1 = t1.select(t1.output.head)
-    val p2 = t2.select(t2.output.head)
+    val p1 = t1.select(t1.output.head).as("p1")
+    val p2 = t2.select(t2.output.head).as("p2")
     val union = p1.union(p2)
-    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
     assert(wp1.isInstanceOf[Project])
-    assert(wp1.missingInput.isEmpty)
+    // The attribute `p1.output.head` should be replaced in the root `Project`.
+    assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
     val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
     assert(wp2.isInstanceOf[Aggregate])
     assert(wp2.missingInput.isEmpty)
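A closing note on the test change: instead of the indirect `missingInput.isEmpty` check, the test now asserts directly that the pre-widening attribute object no longer appears anywhere in the root `Project`'s expressions. If other suites need the same pattern, a small helper along these lines would do; `containsStaleAttr` is hypothetical and not part of the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// True if any expression in the plan's root node still mentions the old
// attribute (same name, expr ID, and qualifier), i.e. it was not rewritten.
def containsStaleAttr(plan: LogicalPlan, stale: Attribute): Boolean =
  plan.expressions.exists(_.find(_ == stale).isDefined)
```

With it, the strengthened assertion above reads `assert(!containsStaleAttr(wp1, p1.output.head))`.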