From a797d66b764d4ffef28f31349d3ec125adfd1402 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 4 Sep 2020 21:14:44 +0900 Subject: [PATCH] review --- .../sql/catalyst/analysis/Analyzer.scala | 90 ++++++++++--------- .../sql/catalyst/optimizer/subquery.scala | 5 +- .../catalyst/plans/logical/LogicalPlan.scala | 17 ++-- .../optimizer/FoldablePropagationSuite.scala | 4 +- .../logical/LogicalPlanIntegritySuite.scala | 2 +- 5 files changed, 62 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1f2975d432078..0cc414381520a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -137,56 +137,59 @@ object Analyzer { */ def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan]) : (LogicalPlan, Seq[(Attribute, Attribute)]) = { - if (plan.resolved) { - val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() - val newChildren = plan.children.map { child => - // If not, we'd rewrite child plan recursively until we find the - // conflict node or reach the leaf node. - val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap) - attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => - // `attrMapping` is not only used to replace the attributes of the current `plan`, - // but also to be propagated to the parent plans of the current `plan`. Therefore, - // the `oldAttr` must be part of either `plan.references` (so that it can be used to - // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be - // used by those parent plans). + val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]() + val newChildren = plan.children.map { child => + // If not, we'd rewrite child plan recursively until we find the + // conflict node or reach the leaf node. + val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap) + attrMapping ++= childAttrMapping.filter { case (oldAttr, _) => + // `attrMapping` is not only used to replace the attributes of the current `plan`, + // but also to be propagated to the parent plans of the current `plan`. Therefore, + // the `oldAttr` must be part of either `plan.references` (so that it can be used to + // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be + // used by those parent plans). + if (plan.resolved) { (plan.outputSet ++ plan.references).contains(oldAttr) + } else { + plan.references.filter(_.resolved).contains(oldAttr) } - newChild } + newChild + } - val newPlan = if (rewritePlanMap.contains(plan)) { - rewritePlanMap(plan).withNewChildren(newChildren) - } else { - plan.withNewChildren(newChildren) - } + val newPlan = if (rewritePlanMap.contains(plan)) { + rewritePlanMap(plan).withNewChildren(newChildren) + } else { + plan.withNewChildren(newChildren) + } - assert(!attrMapping.groupBy(_._1.exprId) - .exists(_._2.map(_._2.exprId).distinct.length > 1), - "Found duplicate rewrite attributes") + assert(!attrMapping.groupBy(_._1.exprId) + .exists(_._2.map(_._2.exprId).distinct.length > 1), + "Found duplicate rewrite attributes") - val attributeRewrites = AttributeMap(attrMapping) - // Using attrMapping from the children plans to rewrite their parent node. - // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. - val p = newPlan.transformExpressions { - case a: Attribute => - updateAttr(a, attributeRewrites) - case s: SubqueryExpression => - s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites)) - } + val attributeRewrites = AttributeMap(attrMapping) + // Using attrMapping from the children plans to rewrite their parent node. + // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes. + val p = newPlan.transformExpressions { + case a: Attribute => + updateAttr(a, attributeRewrites) + case s: SubqueryExpression => + s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites)) + } + if (plan.resolved) { attrMapping ++= plan.output.zip(p.output) .filter { case (a1, a2) => a1.exprId != a2.exprId } - p -> attrMapping - } else { - // Just passes through unresolved nodes - plan.mapChildren { - rewritePlan(_, rewritePlanMap)._1 - } -> Nil } + p -> attrMapping } private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = { - val exprId = attrMap.getOrElse(attr, attr).exprId - attr.withExprId(exprId) + if (attr.resolved) { + val exprId = attrMap.getOrElse(attr, attr).exprId + attr.withExprId(exprId) + } else { + attr + } } /** @@ -2699,8 +2702,7 @@ class Analyzer( if (missingExpr.nonEmpty) { extractedExprBuffer += ne } - // alias will be cleaned in the rule CleanupAliases - ne + ne.toAttribute case e: Expression if e.foldable => e // No need to create an attribute reference if it will be evaluated as a Literal. case e: Expression => @@ -2831,7 +2833,7 @@ class Analyzer( val windowOps = groupedWindowExpressions.foldLeft(child) { case (last, ((partitionSpec, orderSpec, _), windowExpressions)) => - Window(windowExpressions.toSeq, partitionSpec, orderSpec, last) + Window(windowExpressions, partitionSpec, orderSpec, last) } // Finally, we create a Project to output windowOps's output @@ -2853,8 +2855,8 @@ class Analyzer( // a resolved Aggregate will not have Window Functions. case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) if child.resolved && - hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + hasWindowFunction(aggregateExprs) && + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) @@ -2871,7 +2873,7 @@ class Analyzer( // Aggregate without Having clause. case a @ Aggregate(groupingExprs, aggregateExprs, child) if hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 642989928b7df..aaef7a49a5472 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -458,7 +458,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { sys.error(s"Unexpected operator in scalar subquery: $lp") } - val resultMap = evalPlan(plan) + val resultMap = evalPlan(plan).mapValues { _.transform { + case a: Alias => a.newInstance() // Assigns a new `ExprId` + } + } // By convention, the scalar subquery result is the leftmost field. resultMap.get(plan.output.head.exprId) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index e209e98e5ab53..9f9d028d88366 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -236,15 +236,16 @@ object LogicalPlanIntegrity { * with one of reference attributes, e.g., `a#1 + 1 AS a#1`. */ def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = { - plan.map { p => - p.expressions.filter(_.resolved).forall { e => - val namedExprs = e.collect { - case ne: NamedExpression if !ne.isInstanceOf[LeafExpression] => ne + plan.collect { case p if p.resolved => + val inputExprIds = p.inputSet.filter(_.resolved).map(_.exprId).toSet + val newExprIds = p.expressions.filter(_.resolved).flatMap { e => + e.collect { + // Only accepts the case of aliases renaming foldable expressions, e.g., + // `FoldablePropagation` generates this renaming pattern. + case a: Alias if !a.child.foldable => a.exprId } - namedExprs.forall { ne => - !ne.references.filter(_.resolved).map(_.exprId).exists(_ == ne.exprId) - } - } + }.toSet + inputExprIds.intersect(newExprIds).isEmpty }.forall(identity) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 0d48ecb31cfa4..99411819cc182 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest { val query = expand.where(a1.isNotNull).select(a1, a2).analyze val optimized = Optimize.execute(query) val correctExpand = expand.copy(projections = Seq( - Seq(Literal(null), c2), - Seq(c1, Literal(null)))) + Seq(Literal(null), Literal(2)), + Seq(Literal(1), Literal(null)))) val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala index 6f342b8d94379..87d487dbe1ac8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala @@ -43,7 +43,7 @@ class LogicalPlanIntegritySuite extends PlanTest { val Seq(a, b) = t.output assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")()))) assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId)))) - assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId)))) + assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId)))) assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")()))) assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId)))) assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))