[SPARK-32741][SQL] Check if the same ExprId refers to the unique attribute in logical plans #29585
Changes from all commits
@@ -338,15 +338,20 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
 object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
   /**
    * Extract all correlated scalar subqueries from an expression. The subqueries are collected using
-   * the given collector. The expression is rewritten and returned.
+   * the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`s
+   * for the subqueries and rewrites references in the given `expression`.
+   * This method returns the extracted subqueries and their corresponding `exprId`s; these values
+   * are used later in `constructLeftJoins` to build the child plan that
+   * returns the subquery output with those `exprId`s.
    */
   private def extractCorrelatedScalarSubqueries[E <: Expression](
       expression: E,
-      subqueries: ArrayBuffer[ScalarSubquery]): E = {
+      subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
     val newExpression = expression transform {
       case s: ScalarSubquery if s.children.nonEmpty =>
-        subqueries += s
-        s.plan.output.head
+        val newExprId = NamedExpression.newExprId
+        subqueries += s -> newExprId
+        s.plan.output.head.withExprId(newExprId)
     }
     newExpression.asInstanceOf[E]
   }

Reviewer: We need more comments to explain what's going on here: we will rewrite the subqueries later and reference their results here by a new attribute.

Reviewer: BTW, I feel it's more natural to replace attributes after the subquery rewriting, instead of generating the expr id here.

Author: Yeah, that refactoring looks reasonable, and the current logic (top-down rewriting) looks a bit weird. But I want to keep the current logic in this PR to avoid introducing new bugs. Is it okay to do it in a separate PR?

Author: Added comments there.

Reviewer: If it's top-down now, let's defer the change and fix it in a new PR.
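For orientation, here is a minimal end-to-end example of the kind of query this rule rewrites. The sketch assumes a SparkSession named `spark` and hypothetical tables t1(a INT) and t2(a INT, b INT); neither is part of this diff.

    // The correlated scalar subquery below is what extractCorrelatedScalarSubqueries
    // pulls out. After this change, the rewritten plan refers to the subquery result
    // through a freshly generated exprId rather than reusing the exprId of the
    // subquery plan's own output attribute.
    val df = spark.sql(
      "SELECT a, (SELECT MAX(b) FROM t2 WHERE t2.a = t1.a) AS max_b FROM t1")
    df.explain(true) // the optimized plan shows the left outer join built by constructLeftJoins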
@@ -510,16 +515,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    */
   private def constructLeftJoins(
       child: LogicalPlan,
-      subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
+      subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
     subqueries.foldLeft(child) {
-      case (currentChild, ScalarSubquery(query, conditions, _)) =>
+      case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
         val origOutput = query.output.head

         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
         if (resultWithZeroTups.isEmpty) {
           // CASE 1: Subquery guaranteed not to have the COUNT bug
           Project(
-            currentChild.output :+ origOutput,
+            currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
             Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
         } else {
           // Subquery might have the COUNT bug. Add appropriate corrections.

Reviewer: Why do we need a new expr id here?

Author: Rewritten expressions in a parent node […]
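The COUNT bug mentioned in CASE 1, sketched with a hypothetical query (same assumed session and tables as above):

    // COUNT over an empty correlated group must return 0, not NULL, so a plain
    // left outer join is not enough: unmatched t1 rows would yield NULL.
    // evalSubqueryOnZeroTups computes the zero-tuple result (here: 0), and the
    // branches below wrap the joined value in an If/CaseWhen correction that now
    // carries the newly generated exprId.
    val cnt = spark.sql(
      "SELECT a, (SELECT COUNT(*) FROM t2 WHERE t2.a = t1.a) AS cnt FROM t1")
    cnt.explain(true)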
@@ -544,7 +549,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
           Alias(
             If(IsNull(alwaysTrueRef),
               resultWithZeroTups.get,
-              aggValRef), origOutput.name)(exprId = origOutput.exprId),
+              aggValRef), origOutput.name)(exprId = newExprId),
           Join(currentChild,
             Project(query.output :+ alwaysTrueExpr, query),
             LeftOuter, conditions.reduceOption(And), JoinHint.NONE))

Author: NOTE: Found by the integrity check: https://github.com/apache/spark/pull/29585/files#diff-27c76f96a7b2733ecfd6f46a1716e153R238

Reviewer: Can we just let […]?

Author: We might be able to do so, but we need more logic so that a parent node […]. In the current master, the parent node just refers to an attribute with the expr ID of the original output […]. If necessary, I will try this approach tomorrow, so please let me know.
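For intuition, the ill-formed shape the integrity check flagged here, sketched with hypothetical exprIds: the correction expression references the aggregate value a#1 and was itself aliased as a#1, i.e. the `a#1 + 1 AS a#1` pattern that checkIfSameExprIdNotReused (linked above) rejects.

    // Before (ill-formed):  Project [..., If(IsNull(alwaysTrue#3), 0, a#1) AS a#1]
    //                         +- Join LeftOuter ... => right side outputs a#1
    // After (this PR):      Project [..., If(IsNull(alwaysTrue#3), 0, a#1) AS a#2]
    //                         +- Join LeftOuter ... => right side still outputs a#1
    // transformUpWithNewOutput then rewrites parent references from a#1 to a#2.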
@@ -571,7 +576,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
             (IsNull(alwaysTrueRef), resultWithZeroTups.get),
             (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
             aggValRef),
-          origOutput.name)(exprId = origOutput.exprId)
+          origOutput.name)(exprId = newExprId)

         Project(
           currentChild.output :+ caseExpr,

Author: NOTE: Found by the integrity check: https://github.com/apache/spark/pull/29585/files#diff-27c76f96a7b2733ecfd6f46a1716e153R238
@@ -588,36 +593,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
   /**
    * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
    * subqueries.
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
     case a @ Aggregate(grouping, expressions, child) =>
-      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
       val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
         // We currently only allow correlated subqueries in an aggregate if they are part of the
         // grouping expressions. As a result we need to replace all the scalar subqueries in the
         // grouping expressions by their result.
         val newGrouping = grouping.map { e =>
-          subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
+          subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
         }
-        Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        val attrMapping = a.output.zip(newAgg.output)
+        newAgg -> attrMapping
       } else {
-        a
+        a -> Nil
       }
     case p @ Project(expressions, child) =>
-      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
       val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
-        Project(newExpressions, constructLeftJoins(child, subqueries))
+        val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
+        val attrMapping = p.output.zip(newProj.output)
+        newProj -> attrMapping
       } else {
-        p
+        p -> Nil
      }
     case f @ Filter(condition, child) =>
-      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
       val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
       if (subqueries.nonEmpty) {
-        Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+        val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+        val attrMapping = f.output.zip(newProj.output)
+        newProj -> attrMapping
       } else {
-        f
+        f -> Nil
       }
   }
 }
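For reference, a minimal sketch of the transformUpWithNewOutput contract this change relies on. The toy rule below is hypothetical (not part of this diff): each case returns the rewritten node together with (old attribute -> new attribute) pairs, and the framework propagates the mapping so ancestor nodes keep referencing the right, freshly generated exprIds.

    import org.apache.spark.sql.catalyst.expressions.Alias
    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

    // Re-alias every Project output under a fresh exprId and report the mapping,
    // mirroring the `newProj -> attrMapping` pattern used above.
    def reAliasProjects(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
      case p @ Project(projectList, child) =>
        val newProj = Project(projectList.map(e => Alias(e, e.name)()), child)
        newProj -> p.output.zip(newProj.output)
    }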
@@ -203,3 +203,73 @@ abstract class BinaryNode extends LogicalPlan {
 abstract class OrderPreservingUnaryNode extends UnaryNode {
   override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
+
+object LogicalPlanIntegrity {

Reviewer: Shall we check the physical plan as well?

Author: Currently: spark/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala, lines 363 to 367 in a1e459e […]. Is it okay to replace the code above with […]?

Reviewer: I think we in general won't change output in a physical plan?

Author: Ah, I see. Nice suggestion. Having looked around the related code, I think it is not easy to catch all the ill-formed cases as you said, but IMO the most common patterns that cause duplicate […]

+
+  private def canGetOutputAttrs(p: LogicalPlan): Boolean = {
+    p.resolved && !p.expressions.exists { e =>
+      e.collectFirst {
+        // We cannot call `output` in plans with a `ScalarSubquery` expr having no column,
+        // so we filter them out in advance.
+        case s: ScalarSubquery if s.plan.schema.fields.isEmpty => true
+      }.isDefined
+    }
+  }

Reviewer: What will happen if we can […]?

Author: Some existing tests failed with the assertion below: spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala, line 229 in e7d9a24.
+
+  /**
+   * Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`,
+   * this method checks if the same `ExprId` refers to attributes having the same data type
+   * in the plan output.
+   */
+  def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = {
+    val exprIds = plan.collect { case p if canGetOutputAttrs(p) =>
+      // NOTE: we still need to filter for resolved expressions here because the output of
+      // some resolved logical plans can have unresolved references,
+      // e.g., outer references in `ExistenceJoin`.
+      p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) }
+    }.flatten
+
+    val ignoredExprIds = plan.collect {
+      // NOTE: `Union` currently reuses input `ExprId`s for output references, but we cannot
+      // simply modify the code to assign new `ExprId`s in `Union#output` because
+      // that modification would be a breaking change (see SPARK-32741 (#29585)).
+      // So, this check just ignores the `exprId`s of `Union` output.
+      case u: Union if u.resolved => u.output.map(_.exprId)
+    }.flatten.toSet
+
+    val groupedDataTypesByExprId = exprIds.filterNot { case (exprId, _) =>
+      ignoredExprIds.contains(exprId)
+    }.groupBy(_._1).values.map(_.distinct)
+
+    groupedDataTypesByExprId.forall(_.length == 1)
+  }
+
+  /**
+   * This method checks that reference `ExprId`s are not reused when assigning a new `ExprId`.
+   * For example, it returns false if plan transformers create an alias having the same `ExprId`
+   * as one of its reference attributes, e.g., `a#1 + 1 AS a#1`.
+   */
+  def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = {
+    plan.collect { case p if p.resolved =>
+      p.expressions.forall {
+        case a: Alias =>
+          // Even if a plan is resolved, `a.references` can return unresolved references,
+          // e.g., in `Grouping`/`GroupingID`, so we need to filter them out and
+          // check that the `exprId` of the `Alias` does not appear
+          // among the reference `exprId`s.
+          !a.references.filter(_.resolved).map(_.exprId).exists(_ == a.exprId)
+        case _ =>
+          true
+      }
+    }.forall(identity)
+  }
+
+  /**
+   * This method checks if the same `ExprId` refers to a unique attribute in a plan tree.
+   * Some plan transformers (e.g., `RemoveNoopOperators`) rewrite logical
+   * plans based on this assumption.
+   */
+  def checkIfExprIdsAreGloballyUnique(plan: LogicalPlan): Boolean = {
+    checkIfSameExprIdNotReused(plan) && hasUniqueExprIdsForOutput(plan)
+  }
+}
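A minimal usage sketch of the new checks, in the same Catalyst DSL style the test suite below uses (the plans here are illustrative, not from the suite):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.expressions.Alias
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlanIntegrity}

    val t = LocalRelation('a.int, 'b.int)
    val Seq(a, b) = t.output

    // Well-formed: the alias receives a fresh exprId.
    assert(LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(
      t.select(Alias(a + b, "ab")())))
    // Ill-formed: the alias reuses `a`'s exprId, i.e. `a#1 + b#2 AS ab#1`,
    // so checkIfSameExprIdNotReused (and thus the combined check) fails.
    assert(!LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(
      t.select(Alias(a + b, "ab")(exprId = a.exprId))))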
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types.LongType
+
+class LogicalPlanIntegritySuite extends PlanTest {
+  import LogicalPlanIntegrity._
+
+  case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
+    override val analyzed = true
+  }
+
+  test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") {
+    val t = LocalRelation('a.int, 'b.int)
+    assert(hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output)))
+    assert(!hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.zipWithIndex.map {
+      case (a, i) => AttributeReference(s"c$i", LongType)(a.exprId)
+    })))
+  }

Reviewer: Why does it break the check?

Author: That's because the same `exprId`s then refer to attributes with a different data type (`LongType` instead of the original `IntegerType`).

Author: This is the same case with […]. Then, the integrity check fails as follows: […] In this case, we need to reassign these IDs like this: […]
+
+  test("Checks if reference ExprIds are not reused when assigning a new ExprId") {
+    val t = LocalRelation('a.int, 'b.int)
+    val Seq(a, b) = t.output
+    assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")())))
+    assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId))))
+    assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
+    assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")())))
+    assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId))))
+    assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))
+  }
+}
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.execution.adaptive

-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils

@@ -54,4 +54,10 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
       }
     }
   }
+
+  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
+    !Utils.isTesting || (plan.resolved &&
+      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
+  }
 }

Author: NOTE: I noticed that we forgot to implement it here.
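For context, RuleExecutor invokes isPlanIntegral after each effective rule application and fails fast when it returns false. A minimal sketch of a custom executor using the same hook (the object name and the empty batch list are hypothetical):

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity}
    import org.apache.spark.sql.catalyst.rules.RuleExecutor

    // Enforce the new exprId integrity check after every rule, mirroring the
    // AQEOptimizer override above (minus the test-only and AQE-specific guards).
    object MyOptimizer extends RuleExecutor[LogicalPlan] {
      override protected def batches: Seq[Batch] = Nil // real rules elided

      override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
        LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
      }
    }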
@@ -36,7 +36,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.Range
 import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
-import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
 import org.apache.spark.sql.execution.command.ExplainCommand
@@ -47,7 +46,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.streaming.util.{BlockOnStopSourceProvider, StreamManualClock}
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
 import org.apache.spark.util.Utils

 class StreamSuite extends StreamTest {
@@ -1268,7 +1267,7 @@ class StreamSuite extends StreamTest {
 }

 abstract class FakeSource extends StreamSourceProvider {
-  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+  private val fakeSchema = StructType(StructField("a", LongType) :: Nil)

   override def sourceSchema(
       spark: SQLContext,

Author: NOTE: This ill-formed definition was found by the integrity check: https://github.com/apache/spark/pull/29585/files#diff-27c76f96a7b2733ecfd6f46a1716e153R224
@@ -1290,7 +1289,7 @@ class FakeDefaultSource extends FakeSource {
     new Source {
       private var offset = -1L

-      override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+      override def schema: StructType = StructType(StructField("a", LongType) :: Nil)

       override def getOffset: Option[Offset] = {
         if (offset >= 10) {