diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6e1f371b1a2b5..77a6631b250e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -48,6 +48,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.Utils /** * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]]. @@ -136,6 +137,10 @@ class Analyzer( private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog + override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { + !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan) + } + override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts) // Only for tests. @@ -2777,8 +2782,8 @@ class Analyzer( // a resolved Aggregate will not have Window Functions. case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child)) if child.resolved && - hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + hasWindowFunction(aggregateExprs) && + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) @@ -2795,7 +2800,7 @@ class Analyzer( // Aggregate without Having clause. 
case a @ Aggregate(groupingExprs, aggregateExprs, child) if hasWindowFunction(aggregateExprs) && - a.expressions.forall(_.resolved) => + a.expressions.forall(_.resolved) => val (windowExpressions, aggregateExpressions) = extract(aggregateExprs) // Create an Aggregate operator to evaluate aggregation functions. val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 94970740d8d91..f2360150e47b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} import org.apache.spark.sql.catalyst.expressions._ @@ -44,9 +43,11 @@ abstract class Optimizer(catalogManager: CatalogManager) // Currently we check after the execution of each rule if a plan: // - is still resolved // - only host special expressions in supported operators + // - has globally-unique attribute IDs override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { !Utils.isTesting || (plan.resolved && - plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty) + plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) } override protected val excludedOnceBatches: Set[String] = @@ -1585,14 +1586,14 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator. 
*/ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Deduplicate(keys, child) if !child.isStreaming => + def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { + case d @ Deduplicate(keys, child) if !child.isStreaming => val keyExprIds = keys.map(_.exprId) val aggCols = child.output.map { attr => if (keyExprIds.contains(attr.exprId)) { attr } else { - Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId) + Alias(new First(attr).toAggregateExpression(), attr.name)() } } // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping @@ -1601,7 +1602,9 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] { // we append a literal when the grouping key list is empty so that the result aggregate // operator is properly treated as a grouping aggregation. val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys - Aggregate(nonemptyKeys, aggCols, child) + val newAgg = Aggregate(nonemptyKeys, aggCols, child) + val attrMapping = d.output.zip(newAgg.output) + newAgg -> attrMapping } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 7b696912aa465..a168dcd7a83f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -338,15 +338,20 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { /** * Extract all correlated scalar subqueries from an expression. The subqueries are collected using - * the given collector. The expression is rewritten and returned. + * the given collector. 
To avoid the reuse of `exprId`s, this method generates new `exprId` + * for the subqueries and rewrites references in the given `expression`. + * This method returns extracted subqueries and the corresponding `exprId`s and these values + * will be used later in `constructLeftJoins` for building the child plan that + * returns subquery output with the `exprId`s. */ private def extractCorrelatedScalarSubqueries[E <: Expression]( expression: E, - subqueries: ArrayBuffer[ScalarSubquery]): E = { + subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = { val newExpression = expression transform { case s: ScalarSubquery if s.children.nonEmpty => - subqueries += s - s.plan.output.head + val newExprId = NamedExpression.newExprId + subqueries += s -> newExprId + s.plan.output.head.withExprId(newExprId) } newExpression.asInstanceOf[E] } @@ -510,16 +515,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { */ private def constructLeftJoins( child: LogicalPlan, - subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = { + subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = { subqueries.foldLeft(child) { - case (currentChild, ScalarSubquery(query, conditions, _)) => + case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) => val origOutput = query.output.head val resultWithZeroTups = evalSubqueryOnZeroTups(query) if (resultWithZeroTups.isEmpty) { // CASE 1: Subquery guaranteed not to have the COUNT bug Project( - currentChild.output :+ origOutput, + currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId), Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) } else { // Subquery might have the COUNT bug. Add appropriate corrections. 
@@ -544,7 +549,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { Alias( If(IsNull(alwaysTrueRef), resultWithZeroTups.get, - aggValRef), origOutput.name)(exprId = origOutput.exprId), + aggValRef), origOutput.name)(exprId = newExprId), Join(currentChild, Project(query.output :+ alwaysTrueExpr, query), LeftOuter, conditions.reduceOption(And), JoinHint.NONE)) @@ -571,7 +576,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { (IsNull(alwaysTrueRef), resultWithZeroTups.get), (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))), aggValRef), - origOutput.name)(exprId = origOutput.exprId) + origOutput.name)(exprId = newExprId) Project( currentChild.output :+ caseExpr, @@ -588,36 +593,46 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar * subqueries. */ - def apply(plan: LogicalPlan): LogicalPlan = plan transform { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { case a @ Aggregate(grouping, expressions, child) => - val subqueries = ArrayBuffer.empty[ScalarSubquery] + val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)] val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) if (subqueries.nonEmpty) { // We currently only allow correlated subqueries in an aggregate if they are part of the // grouping expressions. As a result we need to replace all the scalar subqueries in the // grouping expressions by their result. 
val newGrouping = grouping.map { e => - subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e) + // `constructLeftJoins` exposes each subquery result only under its fresh `ExprId`, + // so group by that re-aliased attribute rather than the original subquery output. + subqueries.collectFirst { case (s, newExprId) if s.semanticEquals(e) => + s.plan.output.head.withExprId(newExprId) + }.getOrElse(e) } - Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries)) + val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries)) + val attrMapping = a.output.zip(newAgg.output) + newAgg -> attrMapping } else { - a + a -> Nil } case p @ Project(expressions, child) => - val subqueries = ArrayBuffer.empty[ScalarSubquery] + val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)] val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries)) if (subqueries.nonEmpty) { - Project(newExpressions, constructLeftJoins(child, subqueries)) + val newProj = Project(newExpressions, constructLeftJoins(child, subqueries)) + val attrMapping = p.output.zip(newProj.output) + newProj -> attrMapping } else { - p + p -> Nil } case f @ Filter(condition, child) => - val subqueries = ArrayBuffer.empty[ScalarSubquery] + val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)] val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries) if (subqueries.nonEmpty) { - Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries))) + val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries))) + val attrMapping = f.output.zip(newProj.output) + newProj -> attrMapping } else { - f + f -> Nil } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 96c550616065a..48dfc5fd57e63 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -203,3 +203,73 @@ abstract class BinaryNode 
extends LogicalPlan { abstract class OrderPreservingUnaryNode extends UnaryNode { override final def outputOrdering: Seq[SortOrder] = child.outputOrdering } + +object LogicalPlanIntegrity { + + private def canGetOutputAttrs(p: LogicalPlan): Boolean = { + p.resolved && !p.expressions.exists { e => + e.collectFirst { + // We cannot call `output` in plans with a `ScalarSubquery` expr having no column, + // so we filter them out in advance. + case s: ScalarSubquery if s.plan.schema.fields.isEmpty => true + }.isDefined + } + } + + /** + * Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`, + * this method checks if the same `ExprId` refers to attributes having the same data type + * in plan output. + */ + def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = { + val exprIds = plan.collect { case p if canGetOutputAttrs(p) => + // NOTE: we still need to filter resolved expressions here because the output of + // some resolved logical plans can have unresolved references, + // e.g., outer references in `ExistenceJoin`. + p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) } + }.flatten + + val ignoredExprIds = plan.collect { + // NOTE: `Union` currently reuses input `ExprId`s for output references, but we cannot + // simply modify the code for assigning new `ExprId`s in `Union#output` because + // the modification will make breaking changes (See SPARK-32741(#29585)). + // So, this check just ignores the `exprId`s of `Union` output. + case u: Union if u.resolved => u.output.map(_.exprId) + }.flatten.toSet + + val groupedDataTypesByExprId = exprIds.filterNot { case (exprId, _) => + ignoredExprIds.contains(exprId) + }.groupBy(_._1).values.map(_.distinct) + + groupedDataTypesByExprId.forall(_.length == 1) + } + + /** + * This method checks if reference `ExprId`s are not reused when assigning a new `ExprId`. 
+ * For example, it returns false if plan transformers create an alias having the same `ExprId` + * as one of its reference attributes, e.g., `a#1 + 1 AS a#1`. + */ + def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = { + plan.collect { case p if p.resolved => + p.expressions.forall { + case a: Alias => + // Even if a plan is resolved, `a.references` can return unresolved references, + // e.g., in `Grouping`/`GroupingID`, so we need to filter them out and + // check if the same `exprId` in `Alias` does not exist + // among reference `exprId`s. + !a.references.filter(_.resolved).map(_.exprId).exists(_ == a.exprId) + case _ => + true + } + }.forall(identity) + } + + /** + * This method checks if the same `ExprId` refers to a unique attribute in a plan tree. + * Some plan transformers (e.g., `RemoveNoopOperators`) rewrite logical + * plans based on this assumption. + */ + def checkIfExprIdsAreGloballyUnique(plan: LogicalPlan): Boolean = { + checkIfSameExprIdNotReused(plan) && hasUniqueExprIdsForOutput(plan) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index fe43e8e288673..92e4fa345e2ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest { val query = expand.where(a1.isNotNull).select(a1, a2).analyze val optimized = Optimize.execute(query) val correctExpand = expand.copy(projections = Seq( - Seq(Literal(null), c2), - Seq(c1, Literal(null)))) + Seq(Literal(null), Literal(2)), + Seq(Literal(1), Literal(null)))) val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala new file mode 100644 index 0000000000000..6f342b8d94379 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanIntegritySuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types.LongType + +class LogicalPlanIntegritySuite extends PlanTest { + import LogicalPlanIntegrity._ + + case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode { + override val analyzed = true + } + + test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") { + val t = LocalRelation('a.int, 'b.int) + assert(hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output))) + assert(!hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.zipWithIndex.map { + case (a, i) => AttributeReference(s"c$i", LongType)(a.exprId) + }))) + } + + test("Checks if reference ExprIds are not reused when assigning a new ExprId") { + val t = LocalRelation('a.int, 'b.int) + val Seq(a, b) = t.output + assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")()))) + assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId)))) + assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId)))) + assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")()))) + assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId)))) + assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId)))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index c82b264a600ef..0170f8b2f71c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.sql.execution.adaptive -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils @@ -54,4 +54,10 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { } } } + + override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { + !Utils.isTesting || (plan.resolved && + plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9f3ff1a6708e4..8797e5ad64149 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -36,7 +36,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2} -import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan} import org.apache.spark.sql.execution.command.ExplainCommand @@ -47,7 +46,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.streaming.util.{BlockOnStopSourceProvider, StreamManualClock} -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, 
StructField, StructType} import org.apache.spark.util.Utils class StreamSuite extends StreamTest { @@ -1268,7 +1267,7 @@ class StreamSuite extends StreamTest { } abstract class FakeSource extends StreamSourceProvider { - private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + private val fakeSchema = StructType(StructField("a", LongType) :: Nil) override def sourceSchema( spark: SQLContext, @@ -1290,7 +1289,7 @@ class FakeDefaultSource extends FakeSource { new Source { private var offset = -1L - override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + override def schema: StructType = StructType(StructField("a", LongType) :: Nil) override def getOffset: Option[Offset] = { if (offset >= 10) {