Skip to content

Commit

Permalink
review
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Sep 7, 2020
1 parent 0ca08ca commit a797d66
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,56 +137,59 @@ object Analyzer {
*/
def rewritePlan(plan: LogicalPlan, rewritePlanMap: Map[LogicalPlan, LogicalPlan])
  : (LogicalPlan, Seq[(Attribute, Attribute)]) = {
  // Collects (old attribute, new attribute) pairs: first the ones reported by the rewritten
  // children, then (for a resolved `plan`) the ones produced by rewriting `plan` itself.
  val attrMapping = new mutable.ArrayBuffer[(Attribute, Attribute)]()
  val newChildren = plan.children.map { child =>
    // If not, we'd rewrite child plan recursively until we find the
    // conflict node or reach the leaf node.
    val (newChild, childAttrMapping) = rewritePlan(child, rewritePlanMap)
    attrMapping ++= childAttrMapping.filter { case (oldAttr, _) =>
      // `attrMapping` is not only used to replace the attributes of the current `plan`,
      // but also to be propagated to the parent plans of the current `plan`. Therefore,
      // the `oldAttr` must be part of either `plan.references` (so that it can be used to
      // replace attributes of the current `plan`) or `plan.outputSet` (so that it can be
      // used by those parent plans).
      if (plan.resolved) {
        (plan.outputSet ++ plan.references).contains(oldAttr)
      } else {
        // For an unresolved node only its already-resolved references are consulted;
        // `plan.outputSet` is not used on this branch.
        plan.references.filter(_.resolved).contains(oldAttr)
      }
    }
    newChild
  }

  // Use the replacement registered for `plan` when there is one, otherwise keep `plan`;
  // either way attach the rewritten children.
  val newPlan = rewritePlanMap.getOrElse(plan, plan).withNewChildren(newChildren)

  // Every old expression ID must map to exactly one new expression ID.
  assert(!attrMapping.groupBy(_._1.exprId)
    .exists(_._2.map(_._2.exprId).distinct.length > 1),
    "Found duplicate rewrite attributes")

  val attributeRewrites = AttributeMap(attrMapping)
  // Using attrMapping from the children plans to rewrite their parent node.
  // Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
  val p = newPlan.transformExpressions {
    case a: Attribute =>
      updateAttr(a, attributeRewrites)
    case s: SubqueryExpression =>
      s.withNewPlan(updateOuterReferencesInSubquery(s.plan, attributeRewrites))
  }
  if (plan.resolved) {
    // Record the output attributes whose expression ID changed, so that callers
    // (the parent plans) receive them in the returned mapping.
    attrMapping ++= plan.output.zip(p.output)
      .filter { case (a1, a2) => a1.exprId != a2.exprId }
  }
  p -> attrMapping
}

private def updateAttr(attr: Attribute, attrMap: AttributeMap[Attribute]): Attribute = {
  if (attr.resolved) {
    // Keep `attr` itself and only swap in the expression ID of its mapped replacement;
    // when `attrMap` has no entry, `getOrElse` falls back to `attr`, so the ID is unchanged.
    val exprId = attrMap.getOrElse(attr, attr).exprId
    attr.withExprId(exprId)
  } else {
    // Unresolved attributes are returned untouched; the map lookup is deliberately
    // performed only on the resolved branch.
    attr
  }
}

/**
Expand Down Expand Up @@ -2699,8 +2702,7 @@ class Analyzer(
if (missingExpr.nonEmpty) {
extractedExprBuffer += ne
}
// alias will be cleaned in the rule CleanupAliases
ne
ne.toAttribute
case e: Expression if e.foldable =>
e // No need to create an attribute reference if it will be evaluated as a Literal.
case e: Expression =>
Expand Down Expand Up @@ -2831,7 +2833,7 @@ class Analyzer(
val windowOps =
groupedWindowExpressions.foldLeft(child) {
case (last, ((partitionSpec, orderSpec, _), windowExpressions)) =>
Window(windowExpressions.toSeq, partitionSpec, orderSpec, last)
Window(windowExpressions, partitionSpec, orderSpec, last)
}

// Finally, we create a Project to output windowOps's output
Expand All @@ -2853,8 +2855,8 @@ class Analyzer(
// a resolved Aggregate will not have Window Functions.
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
Expand All @@ -2871,7 +2873,7 @@ class Analyzer(
// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
sys.error(s"Unexpected operator in scalar subquery: $lp")
}

val resultMap = evalPlan(plan)
val resultMap = evalPlan(plan).mapValues { _.transform {
case a: Alias => a.newInstance() // Assigns a new `ExprId`
}
}

// By convention, the scalar subquery result is the leftmost field.
resultMap.get(plan.output.head.exprId) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,15 +236,16 @@ object LogicalPlanIntegrity {
* with one of reference attributes, e.g., `a#1 + 1 AS a#1`.
*/
def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = {
  // Only resolved nodes are inspected; each yields `true` when none of the expression IDs
  // it assigns through aliases collides with an expression ID of its input attributes.
  plan.collect { case p if p.resolved =>
    val inputExprIds = p.inputSet.filter(_.resolved).map(_.exprId).toSet
    val newExprIds = p.expressions.filter(_.resolved).flatMap { e =>
      e.collect {
        // Only accepts the case of aliases renaming foldable expressions, e.g.,
        // `FoldablePropagation` generates this renaming pattern.
        case a: Alias if !a.child.foldable => a.exprId
      }
    }.toSet
    inputExprIds.intersect(newExprIds).isEmpty
  }.forall(identity)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest {
val query = expand.where(a1.isNotNull).select(a1, a2).analyze
val optimized = Optimize.execute(query)
val correctExpand = expand.copy(projections = Seq(
Seq(Literal(null), c2),
Seq(c1, Literal(null))))
Seq(Literal(null), Literal(2)),
Seq(Literal(1), Literal(null))))
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
comparePlans(optimized, correctAnswer)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class LogicalPlanIntegritySuite extends PlanTest {
val Seq(a, b) = t.output
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")())))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId))))
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")())))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId))))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))
Expand Down

0 comments on commit a797d66

Please sign in to comment.