Skip to content

Commit

Permalink
[SPARK-32741][SQL] Check if the same ExprId refers to the unique attr…
Browse files Browse the repository at this point in the history
…ibute in logical plans

### What changes were proposed in this pull request?

Some plan transformations (e.g., `RemoveNoopOperators`) implicitly assume the same `ExprId` refers to the unique attribute. But, `RuleExecutor` does not check this integrity between logical plan transformations. So, this PR intends to add this check in `isPlanIntegral` of `Analyzer`/`Optimizer`.

This PR comes from the talk with cloud-fan viirya in #29485 (comment)

### Why are the changes needed?

For better logical plan integrity checking.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #29585 from maropu/PlanIntegrityTest.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
  • Loading branch information
maropu committed Sep 30, 2020
1 parent cc06266 commit 3a299aa
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils

/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
Expand Down Expand Up @@ -136,6 +137,10 @@ class Analyzer(

private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog

override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
}

override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)

// Only for tests.
Expand Down Expand Up @@ -2777,8 +2782,8 @@ class Analyzer(
// a resolved Aggregate will not have Window Functions.
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
Expand All @@ -2795,7 +2800,7 @@ class Analyzer(
// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
Expand All @@ -44,9 +43,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
// Currently we check after the execution of each rule if a plan:
// - is still resolved
// - only host special expressions in supported operators
// - has globally-unique attribute IDs
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || (plan.resolved &&
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty)
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
}

override protected val excludedOnceBatches: Set[String] =
Expand Down Expand Up @@ -1585,14 +1586,14 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
* Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
*/
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Deduplicate(keys, child) if !child.isStreaming =>
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case d @ Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
Alias(new First(attr).toAggregateExpression(), attr.name)()
}
}
// SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
Expand All @@ -1601,7 +1602,9 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
Aggregate(nonemptyKeys, aggCols, child)
val newAgg = Aggregate(nonemptyKeys, aggCols, child)
val attrMapping = d.output.zip(newAgg.output)
newAgg -> attrMapping
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,15 +338,20 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
/**
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
* the given collector. The expression is rewritten and returned.
* the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
* for the subqueries and rewrite references in the given `expression`.
* This method returns extracted subqueries and the corresponding `exprId`s and these values
* will be used later in `constructLeftJoins` for building the child plan that
* returns subquery output with the `exprId`s.
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[ScalarSubquery]): E = {
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
s.plan.output.head
val newExprId = NamedExpression.newExprId
subqueries += s -> newExprId
s.plan.output.head.withExprId(newExprId)
}
newExpression.asInstanceOf[E]
}
Expand Down Expand Up @@ -510,16 +515,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
val origOutput = query.output.head

val resultWithZeroTups = evalSubqueryOnZeroTups(query)
if (resultWithZeroTups.isEmpty) {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ origOutput,
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
Expand All @@ -544,7 +549,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = origOutput.exprId),
aggValRef), origOutput.name)(exprId = newExprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
Expand All @@ -571,7 +576,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = origOutput.exprId)
origOutput.name)(exprId = newExprId)

Project(
currentChild.output :+ caseExpr,
Expand All @@ -588,36 +593,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
}
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = a.output.zip(newAgg.output)
newAgg -> attrMapping
} else {
a
a -> Nil
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
Project(newExpressions, constructLeftJoins(child, subqueries))
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = p.output.zip(newProj.output)
newProj -> attrMapping
} else {
p
p -> Nil
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val attrMapping = f.output.zip(newProj.output)
newProj -> attrMapping
} else {
f
f -> Nil
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,73 @@ abstract class BinaryNode extends LogicalPlan {
abstract class OrderPreservingUnaryNode extends UnaryNode {
override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

object LogicalPlanIntegrity {

private def canGetOutputAttrs(p: LogicalPlan): Boolean = {
p.resolved && !p.expressions.exists { e =>
e.collectFirst {
// We cannot call `output` in plans with a `ScalarSubquery` expr having no column,
// so, we filter out them in advance.
case s: ScalarSubquery if s.plan.schema.fields.isEmpty => true
}.isDefined
}
}

/**
* Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`,
* this method checks if the same `ExprId` refers to attributes having the same data type
* in plan output.
*/
def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = {
val exprIds = plan.collect { case p if canGetOutputAttrs(p) =>
// NOTE: we still need to filter resolved expressions here because the output of
// some resolved logical plans can have unresolved references,
// e.g., outer references in `ExistenceJoin`.
p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) }
}.flatten

val ignoredExprIds = plan.collect {
// NOTE: `Union` currently reuses input `ExprId`s for output references, but we cannot
// simply modify the code for assigning new `ExprId`s in `Union#output` because
// the modification will make breaking changes (See SPARK-32741(#29585)).
// So, this check just ignores the `exprId`s of `Union` output.
case u: Union if u.resolved => u.output.map(_.exprId)
}.flatten.toSet

val groupedDataTypesByExprId = exprIds.filterNot { case (exprId, _) =>
ignoredExprIds.contains(exprId)
}.groupBy(_._1).values.map(_.distinct)

groupedDataTypesByExprId.forall(_.length == 1)
}

/**
* This method checks if reference `ExprId`s are not reused when assigning a new `ExprId`.
* For example, it returns false if plan transformers create an alias having the same `ExprId`
* with one of reference attributes, e.g., `a#1 + 1 AS a#1`.
*/
def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = {
plan.collect { case p if p.resolved =>
p.expressions.forall {
case a: Alias =>
// Even if a plan is resolved, `a.references` can return unresolved references,
// e.g., in `Grouping`/`GroupingID`, so we need to filter out them and
// check if the same `exprId` in `Alias` does not exist
// among reference `exprId`s.
!a.references.filter(_.resolved).map(_.exprId).exists(_ == a.exprId)
case _ =>
true
}
}.forall(identity)
}

/**
* This method checks if the same `ExprId` refers to an unique attribute in a plan tree.
* Some plan transformers (e.g., `RemoveNoopOperators`) rewrite logical
* plans based on this assumption.
*/
def checkIfExprIdsAreGloballyUnique(plan: LogicalPlan): Boolean = {
checkIfSameExprIdNotReused(plan) && hasUniqueExprIdsForOutput(plan)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest {
val query = expand.where(a1.isNotNull).select(a1, a2).analyze
val optimized = Optimize.execute(query)
val correctExpand = expand.copy(projections = Seq(
Seq(Literal(null), c2),
Seq(c1, Literal(null))))
Seq(Literal(null), Literal(2)),
Seq(Literal(1), Literal(null))))
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
comparePlans(optimized, correctAnswer)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.LongType

class LogicalPlanIntegritySuite extends PlanTest {
import LogicalPlanIntegrity._

case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
override val analyzed = true
}

test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") {
val t = LocalRelation('a.int, 'b.int)
assert(hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output)))
assert(!hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.zipWithIndex.map {
case (a, i) => AttributeReference(s"c$i", LongType)(a.exprId)
})))
}

test("Checks if reference ExprIds are not reused when assigning a new ExprId") {
val t = LocalRelation('a.int, 'b.int)
val Seq(a, b) = t.output
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")())))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId))))
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")())))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId))))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -54,4 +54,10 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
}
}
}

override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || (plan.resolved &&
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
import org.apache.spark.sql.execution.command.ExplainCommand
Expand All @@ -47,7 +46,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.streaming.util.{BlockOnStopSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.util.Utils

class StreamSuite extends StreamTest {
Expand Down Expand Up @@ -1268,7 +1267,7 @@ class StreamSuite extends StreamTest {
}

abstract class FakeSource extends StreamSourceProvider {
private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
private val fakeSchema = StructType(StructField("a", LongType) :: Nil)

override def sourceSchema(
spark: SQLContext,
Expand All @@ -1290,7 +1289,7 @@ class FakeDefaultSource extends FakeSource {
new Source {
private var offset = -1L

override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
override def schema: StructType = StructType(StructField("a", LongType) :: Nil)

override def getOffset: Option[Offset] = {
if (offset >= 10) {
Expand Down

0 comments on commit 3a299aa

Please sign in to comment.