diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1d2e48301ea98..4516c71bbc514 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -123,127 +123,6 @@ object AnalysisContext {
   }
 }
 
-object Analyzer {
-
-  /**
-   * Rewrites a given `plan` recursively based on rewrite mappings from old plans to new ones.
-   * This method also updates all the related references in the `plan` accordingly.
-   *
-   * @param plan to rewrite
-   * @param rewritePlanMap has mappings from old plans to new ones for the given `plan`.
-   * @return a rewritten plan and updated references related to a root node of
-   *         the given `plan` for rewriting it.
-   */
-  def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
-    : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
-    if (plan.resolved) {
-      val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
-      val newChildren = plan.children.map { child =>
-        // If not, we'd rewrite child plan recursively until we find the
-        // conflict node or reach the leaf node.
-        val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
-        attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
-          // `attrMapping` is not only used to replace the attributes of the current `plan`,
-          // but also to be propagated to the parent plans of the current `plan`. Therefore,
-          // the `oldAttr` must be part of either `plan.references` (so that it can be used to
-          // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
-          // used by those parent plans).
-          (plan.outputSet ++ plan.references).contains(oldAttr)
-        }
-        newChild
-      }
-
-      val newPlan = if (rewritePlanMap.contains(plan)) {
-        rewritePlanMap(plan).withNewChildren(newChildren)
-      } else {
-        plan.withNewChildren(newChildren)
-      }
-
-      assert(!attrMapping.groupBy(_._1.exprId)
-        .exists(_._2.map(_._2.exprId).distinct.length > 1),
-        "Found duplicate rewrite attributes")
-
-      val attributeRewrites = AttributeMap(attrMapping)
-      // Using attrMapping from the children plans to rewrite their parent node.
-      // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
-      val p = newPlan.transformExpressions {
-        case a: Attribute =>
-          updateAttr(a, attributeRewrites)
-        case s: SubqueryExpression =>
-          s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
-      }
-      attrMapping ++= plan.output.zip(p.output)
-        .filter { case (a1, a2) => a1.exprId != a2.exprId }
-      p -> attrMapping
-    } else {
-      // Just passes through unresolved nodes
-      plan.mapChildren {
-        rewritePlan(_, rewritePlanMap)._1
-      } -> Nil
-    }
-  }
-
-  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
-    val exprId = attrMap.getOrElse(attr, attr).exprId
-    attr.withExprId(exprId)
-  }
-
-  /**
-   * The outer plan may have old references and the function below updates the
-   * outer references to refer to the new attributes.
-   *
-   * For example (SQL):
-   * {{{
-   *   SELECT * FROM t1
-   *   INTERSECT
-   *   SELECT * FROM t1
-   *   WHERE EXISTS (SELECT 1
-   *                 FROM t2
-   *                 WHERE t1.c1 = t2.c1)
-   * }}}
-   * Plan before resolveReference rule.
-   *    'Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- 'Project [*]
-   *       +- Filter exists#257 [c1#245]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#245) = c1#251)
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#245,c2#246] parquet
-   * Plan after the resolveReference rule.
-   *    Intersect
-   *    :- Project [c1#245, c2#246]
-   *    :  +- SubqueryAlias t1
-   *    :     +- Relation[c1#245,c2#246] parquet
-   *    +- Project [c1#259, c2#260]
-   *       +- Filter exists#257 [c1#259]
-   *          :  +- Project [1 AS 1#258]
-   *          :     +- Filter (outer(c1#259) = c1#251) => Updated
-   *          :        +- SubqueryAlias t2
-   *          :           +- Relation[c1#251,c2#252] parquet
-   *          +- SubqueryAlias t1
-   *             +- Relation[c1#259,c2#260] parquet => Outer plan's attributes are rewritten.
-   */
-  private def updateOuterReferencesInSubquery(
-      plan: LogicalPlan,
-      attrMap: AttributeMap[Attribute]): LogicalPlan = {
-    AnalysisHelper.allowInvokingTransformsInAnalyzer {
-      plan transformDown { case currentFragment =>
-        currentFragment transformExpressions {
-          case OuterReference(a: Attribute) =>
-            OuterReference(updateAttr(a, attrMap))
-          case s: SubqueryExpression =>
-            s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attrMap))
-        }
-      }
-    }
-  }
-}
-
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
  * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
@@ -1376,7 +1255,14 @@ class Analyzer(
       if (conflictPlans.isEmpty) {
         right
       } else {
-        Analyzer.rewritePlan(right, conflictPlans.toMap)._1
+        val planMapping = conflictPlans.toMap
+        right.transformUpWithNewOutput {
+          case oldPlan =>
+            val newPlanOpt = planMapping.get(oldPlan)
+            newPlanOpt.map { newPlan =>
+              newPlan -> oldPlan.output.zip(newPlan.output)
+            }.getOrElse(oldPlan -> Nil)
+        }
       }
     }
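
The hunk above replaces the recursive `Analyzer.rewritePlan` call with the new generic `transformUpWithNewOutput` API when de-duplicating conflicting attributes between join sides. The de-duplication this code path performs can be observed from the user-facing API; a minimal sketch, assuming a running `SparkSession` named `spark` (not part of the patch):

```scala
// A self-join is the classic case where the right-hand side must be rewritten,
// since both sides initially carry identical expr IDs.
val df = spark.range(3)
val joined = df.join(df, "id")
// In the analyzed plan, the right side's `id` attribute has a fresh expr ID,
// and all references above it were rewritten accordingly.
println(joined.queryExecution.analyzed.treeString)
```
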
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 861eddedc0e1b..deaa49bf423b1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -329,50 +329,43 @@ object TypeCoercion {
   object WidenSetOperationTypes extends TypeCoercionRule {
 
     override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = {
-      val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
-      val newPlan = plan resolveOperatorsUp {
+      plan resolveOperatorsUpWithNewOutput {
        case s @ Except(left, right, isAll)
            if s.childrenResolved &&
              left.output.length == right.output.length && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            Except(newChildren.head._1, newChildren.last._1, isAll)
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Except(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
        case s @ Intersect(left, right, isAll)
            if s.childrenResolved &&
              left.output.length == right.output.length && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(left :: right :: Nil)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            Intersect(newChildren.head._1, newChildren.last._1, isAll)
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(left :: right :: Nil)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            assert(newChildren.length == 2)
+            val attrMapping = left.output.zip(newChildren.head.output)
+            Intersect(newChildren.head, newChildren.last, isAll) -> attrMapping
           }
 
        case s: Union if s.childrenResolved && !s.byName &&
            s.children.forall(_.output.length == s.children.head.output.length) && !s.resolved =>
-          val newChildren = buildNewChildrenWithWiderTypes(s.children)
-          if (newChildren.nonEmpty) {
-            rewritePlanMap ++= newChildren
-            s.copy(children = newChildren.map(_._1))
+          val newChildren: Seq[LogicalPlan] = buildNewChildrenWithWiderTypes(s.children)
+          if (newChildren.isEmpty) {
+            s -> Nil
           } else {
-            s
+            val attrMapping = s.children.head.output.zip(newChildren.head.output)
+            s.copy(children = newChildren) -> attrMapping
           }
       }
-
-      if (rewritePlanMap.nonEmpty) {
-        assert(!plan.fastEquals(newPlan))
-        Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
-      } else {
-        plan
-      }
     }
 
     /** Build new children with the widest types for each attribute among all the children */
-    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
-      : Seq[(LogicalPlan, LogicalPlan)] = {
+    private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
       require(children.forall(_.output.length == children.head.output.length))
 
       // Get a sequence of data types, each of which is the widest type of this specific attribute
@@ -408,16 +401,13 @@ object TypeCoercion {
     }
 
     /** Given a plan, add an extra project on top to widen some columns' data types. */
-    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType])
-      : (LogicalPlan, LogicalPlan) = {
+    private def widenTypes(plan: LogicalPlan, targetTypes: Seq[DataType]): LogicalPlan = {
       val casted = plan.output.zip(targetTypes).map {
         case (e, dt) if e.dataType != dt =>
-          val alias = Alias(Cast(e, dt), e.name)(exprId = e.exprId)
-          alias -> alias.newInstance()
-        case (e, _) =>
-          e -> e
-      }.unzip
-      Project(casted._1, plan) -> Project(casted._2, plan)
+          Alias(Cast(e, dt, Some(SQLConf.get.sessionLocalTimeZone)), e.name)()
+        case (e, _) => e
+      }
+      Project(casted, plan)
     }
   }
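
With `WidenSetOperationTypes` rewritten on top of `resolveOperatorsUpWithNewOutput`, each case simply returns the rebuilt set operation together with its old-to-new output pairing, and the side-channel `rewritePlanMap` disappears. The widening behavior itself is unchanged and can be observed from SQL; a minimal sketch, assuming a running `SparkSession` named `spark` (not part of the patch):

```scala
// The DECIMAL(10, 0) side is wrapped in a Project that casts `v` to
// DECIMAL(11, 0), the common wider type of the two union children.
val df = spark.sql(
  """SELECT CAST(1 AS DECIMAL(10, 0)) AS v
    |UNION ALL
    |SELECT CAST(2 AS DECIMAL(11, 0)) AS v""".stripMargin)
println(df.queryExecution.analyzed.treeString)
df.printSchema() // v: decimal(11,0)
```
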
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 7133fb231d672..fed5df69580ee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.plans
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode, TreeNodeTag}
@@ -168,6 +170,89 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
     }.toSeq
   }
 
+  /**
+   * A variant of `transformUp`, which takes care of the case that the rule replaces a plan node
+   * with a new one that has different output expr IDs, by updating the attribute references in
+   * the parent nodes accordingly.
+   *
+   * @param rule the function to transform plan nodes, and return new nodes with attributes mapping
+   *             from old attributes to new attributes. The attribute mapping will be used to
+   *             rewrite attribute references in the parent nodes.
+   * @param skipCond a boolean condition to indicate if we can skip transforming a plan node to save
+   *                 time.
+   */
+  def transformUpWithNewOutput(
+      rule: PartialFunction[PlanType, (PlanType, Seq[(Attribute, Attribute)])],
+      skipCond: PlanType => Boolean = _ => false): PlanType = {
+    def rewrite(plan: PlanType): (PlanType, Seq[(Attribute, Attribute)]) = {
+      if (skipCond(plan)) {
+        plan -> Nil
+      } else {
+        val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
+        var newPlan = plan.mapChildren { child =>
+          val (newChild, childAttrMapping) = rewrite(child)
+          attrMapping ++= childAttrMapping
+          newChild
+        }
+
+        val attrMappingForCurrentPlan = attrMapping.filter {
+          // The `attrMappingForCurrentPlan` is used to replace the attributes of the
+          // current `plan`, so the `oldAttr` must be part of `plan.references`.
+          case (oldAttr, _) => plan.references.contains(oldAttr)
+        }
+
+        val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
+          rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
+        }
+        newPlan = planAfterRule
+
+        if (attrMappingForCurrentPlan.nonEmpty) {
+          assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
+            .exists(_._2.map(_._2.exprId).distinct.length > 1),
+            "Found duplicate rewrite attributes")
+
+          val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
+          // Using attrMapping from the children plans to rewrite their parent node.
+          // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
+          newPlan = newPlan.transformExpressions {
+            case a: AttributeReference =>
+              updateAttr(a, attributeRewrites)
+            case pe: PlanExpression[PlanType] =>
+              pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attributeRewrites))
+          }
+        }
+
+        attrMapping ++= newAttrMapping.filter {
+          case (a1, a2) => a1.exprId != a2.exprId
+        }
+        newPlan -> attrMapping
+      }
+    }
+    rewrite(this)._1
+  }
+
+  private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
+    val exprId = attrMap.getOrElse(attr, attr).exprId
+    attr.withExprId(exprId)
+  }
+
+  /**
+   * The outer plan may have old references and the function below updates the
+   * outer references to refer to the new attributes.
+   */
+  private def updateOuterReferencesInSubquery(
+      plan: PlanType,
+      attrMap: AttributeMap[Attribute]): PlanType = {
+    plan.transformDown { case currentFragment =>
+      currentFragment.transformExpressions {
+        case OuterReference(a: AttributeReference) =>
+          OuterReference(updateAttr(a, attrMap))
+        case pe: PlanExpression[PlanType] =>
+          pe.withNewPlan(updateOuterReferencesInSubquery(pe.plan, attrMap))
+      }
+    }
+  }
+
   lazy val schema: StructType = StructType.fromAttributes(output)
 
   /** Returns the output schema in the tree format. */
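
The core of the patch is the `transformUpWithNewOutput` traversal above: children are rewritten first, their old-to-new attribute pairs are collected, the rule is applied to the current node, and only then are the node's own attribute references remapped, deliberately using mappings from its children and never from siblings. The following self-contained toy model mirrors that shape; `Attr`, `Plan`, and every other name in it are invented for illustration and are not Catalyst classes, and it omits the patch's `plan.references` filter, duplicate-mapping assertion, and subquery handling:

```scala
object TransformUpWithNewOutputSketch {

  final case class Attr(name: String, exprId: Long)

  // A node produces `output` attributes and may reference child attributes.
  final case class Plan(
      name: String,
      output: Seq[Attr],
      references: Seq[Attr],
      children: Seq[Plan])

  type Rule = PartialFunction[Plan, (Plan, Seq[(Attr, Attr)])]

  def transformUpWithNewOutput(plan: Plan, rule: Rule): Plan =
    rewrite(plan, rule)._1

  private def rewrite(plan: Plan, rule: Rule): (Plan, Seq[(Attr, Attr)]) = {
    // 1. Rewrite children first (bottom-up), collecting old -> new pairs.
    val childResults = plan.children.map(rewrite(_, rule))
    val childMapping = childResults.flatMap(_._2)
    val withNewChildren = plan.copy(children = childResults.map(_._1))

    // 2. Apply the rule to the current node, if it matches.
    val (afterRule, newMapping) =
      rule.applyOrElse(withNewChildren, (p: Plan) => p -> Seq.empty[(Attr, Attr)])

    // 3. Remap this node's attributes using only its children's mappings.
    //    (In Catalyst the output follows from the rewritten expressions; the
    //    toy model updates it directly.)
    val remap = childMapping.toMap
    val remapped = afterRule.copy(
      output = afterRule.output.map(a => remap.getOrElse(a, a)),
      references = afterRule.references.map(a => remap.getOrElse(a, a)))

    // 4. Propagate the children's mappings and this node's own mapping upward.
    remapped -> (childMapping ++ newMapping)
  }

  def main(args: Array[String]): Unit = {
    val v1 = Attr("v", 1)
    val v2 = Attr("v", 2) // fresh expr ID produced by the rule
    val leaf = Plan("Relation", Seq(v1), Nil, Nil)
    val root = Plan("Project", Seq(v1), Seq(v1), Seq(leaf))

    // The rule swaps the leaf's output for a fresh expr ID; the parent
    // Project's reference v#1 is rewritten to v#2 automatically.
    val rule: Rule = {
      case p @ Plan("Relation", _, _, _) => p.copy(output = Seq(v2)) -> Seq(v1 -> v2)
    }
    println(transformUpWithNewOutput(root, rule))
  }
}
```

In the real implementation, the mapping used to rewrite a node is first filtered down to `plan.references`, conflicting rewrites for the same expr ID fail the "Found duplicate rewrite attributes" assertion, and `PlanExpression` subqueries are walked so that `OuterReference`s pick up the new expr IDs.
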
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
index 9404a809b453c..30447db1acc04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/AnalysisHelper.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
 
 import org.apache.spark.sql.catalyst.analysis.CheckAnalysis
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, TreeNode}
 import org.apache.spark.util.Utils
@@ -120,6 +120,19 @@ trait AnalysisHelper extends QueryPlan[LogicalPlan] { self: LogicalPlan =>
     }
   }
 
+  /**
+   * A variant of `transformUpWithNewOutput`, which skips touching already analyzed plan.
+   */
+  def resolveOperatorsUpWithNewOutput(
+      rule: PartialFunction[LogicalPlan, (LogicalPlan, Seq[(Attribute, Attribute)])])
+    : LogicalPlan = {
+    if (!analyzed) {
+      transformUpWithNewOutput(rule, skipCond = _.analyzed)
+    } else {
+      self
+    }
+  }
+
   /**
    * Recursively transforms the expressions of a tree, skipping nodes that have already
    * been analyzed.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 1af562fd1a061..7b80de908fa08 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -1419,12 +1419,13 @@ class TypeCoercionSuite extends AnalysisTest {
   test("SPARK-32638: corrects references when adding aliases in WidenSetOperationTypes") {
     val t1 = LocalRelation(AttributeReference("v", DecimalType(10, 0))())
     val t2 = LocalRelation(AttributeReference("v", DecimalType(11, 0))())
-    val p1 = t1.select(t1.output.head)
-    val p2 = t2.select(t2.output.head)
+    val p1 = t1.select(t1.output.head).as("p1")
+    val p2 = t2.select(t2.output.head).as("p2")
     val union = p1.union(p2)
-    val wp1 = widenSetOperationTypes(union.select(p1.output.head))
+    val wp1 = widenSetOperationTypes(union.select(p1.output.head, $"p2.v"))
     assert(wp1.isInstanceOf[Project])
-    assert(wp1.missingInput.isEmpty)
+    // The attribute `p1.output.head` should be replaced in the root `Project`.
+    assert(wp1.expressions.forall(_.find(_ == p1.output.head).isEmpty))
     val wp2 = widenSetOperationTypes(Aggregate(Nil, sum(p1.output.head).as("v") :: Nil, union))
     assert(wp2.isInstanceOf[Aggregate])
     assert(wp2.missingInput.isEmpty)