diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala index 693a5a4e75443..efbd7101d6915 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala @@ -18,17 +18,33 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.SchemaUtils /** * Resolves different children of Union to a common set of columns. */ object ResolveUnion extends Rule[LogicalPlan] { + + private[catalyst] def makeUnionOutput(children: Seq[LogicalPlan]): Seq[Attribute] = { + children.map(_.output).transpose.map { attrs => + val firstAttr = attrs.head + val nullable = attrs.exists(_.nullable) + val newDt = attrs.map(_.dataType).reduce(StructType.merge) + if (firstAttr.dataType == newDt) { + firstAttr.withNullability(nullable) + } else { + AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)( + NamedExpression.newExprId, firstAttr.qualifier) + } + } + } + private def unionTwoSides( left: LogicalPlan, right: LogicalPlan, @@ -68,7 +84,8 @@ object ResolveUnion extends Rule[LogicalPlan] { } else { left } - Union(leftChild, rightChild) + val unionOutput = makeUnionOutput(Seq(leftChild, rightChild)) + Union(leftChild, rightChild, unionOutput) } // Check column name duplication @@ -88,13 +105,17 @@ object ResolveUnion extends Rule[LogicalPlan] { } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp { - case e if !e.childrenResolved => e + case p if !p.childrenResolved => p - case Union(children, byName, allowMissingCol) if byName => + case Union(children, byName, allowMissingCol, _) if byName => val union = children.reduceLeft { (left, right) => checkColumnNames(left, right) unionTwoSides(left, right, allowMissingCol) } CombineUnions(union) + + case u @ Union(children, _, _, unionOutput) + if u.allChildrenCompatible && unionOutput.isEmpty => + u.copy(unionOutput = makeUnionOutput(children)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index deaa49bf423b1..f71e2c5a3c505 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -359,7 +359,8 @@ object TypeCoercion { s -> Nil } else { val attrMapping = s.children.head.output.zip(newChildren.head.output) - s.copy(children = newChildren) -> attrMapping + val newOutput = ResolveUnion.makeUnionOutput(newChildren) + s.copy(children = newChildren, unionOutput = newOutput) -> attrMapping } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f27c5a26741d4..507381699642f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -592,7 +592,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper val rewrites = buildRewrites(u.children.head, child) Project(projectList.map(pushToRight(_, rewrites)), child) } - u.copy(children = newFirstChild +: newOtherChildren) + val newChildren = newFirstChild +: newOtherChildren + u.copy(children = newChildren, unionOutput = ResolveUnion.makeUnionOutput(newChildren)) } else { p } @@ -967,17 +968,20 @@ object CombineUnions extends Rule[LogicalPlan] { // rules (by position and by name) could cause incorrect results. while (stack.nonEmpty) { stack.pop() match { - case Distinct(Union(children, byName, allowMissingCol)) + case Distinct(Union(children, byName, allowMissingCol, _)) if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol => stack.pushAll(children.reverse) - case Union(children, byName, allowMissingCol) + case Union(children, byName, allowMissingCol, _) if byName == topByName && allowMissingCol == topAllowMissingCol => stack.pushAll(children.reverse) case child => flattened += child } } - union.copy(children = flattened.toSeq) + union.copy( + children = flattened, + unionOutput = ResolveUnion.makeUnionOutput(flattened) + ) } } @@ -1689,7 +1693,8 @@ object RewriteExceptAll extends Rule[LogicalPlan] { val newColumnRight = Alias(Literal(-1L), "vcol")() val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left) val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right) - val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan) + val unionOutput = ResolveUnion.makeUnionOutput(Seq(modifiedLeftPlan, modifiedRightPlan)) + val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan, output = unionOutput) val aggSumCol = Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")() val aggOutputColumns = left.output ++ Seq(aggSumCol) @@ -1753,7 +1758,12 @@ object RewriteIntersectAll extends Rule[LogicalPlan] { val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left) val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right) - val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols) + val unionOutput = ResolveUnion.makeUnionOutput( + Seq(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols)) + val unionPlan = Union( + leftPlanWithAddedVirtualCols, + rightPlanWithAddedVirtualCols, + output = unionOutput) // Expressions to compute count and minimum of both the counts. val vCol1AggrExpr = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 9173b2d8032b5..a168dcd7a83f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -463,10 +463,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { sys.error(s"Unexpected operator in scalar subquery: $lp") } - val resultMap = evalPlan(plan).mapValues { _.transform { - case a: Alias => a.newInstance() // Assigns a new `ExprId` - } - } + val resultMap = evalPlan(plan) // By convention, the scalar subquery result is the leftmost field. resultMap.get(plan.output.head.exprId) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index cefc97a95d223..eee9fb7cff5a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -218,19 +218,18 @@ object LogicalPlanIntegrity { /** * Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`, - * this method checks if the same `ExprId` refers to a semantically-equal attribute - * in a plan output. + * this method checks if the same `ExprId` refers to attributes having the same data type + * in plan output. */ def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = { - val allOutputAttrs = plan.collect { case p if canGetOutputAttrs(p) => + val exprIds = plan.collect { case p if canGetOutputAttrs(p) => // NOTE: we still need to filter resolved expressions here because the output of // some resolved logical plans can have unresolved references, // e.g., outer references in `ExistenceJoin`. - p.output.filter(_.resolved).map(_.canonicalized.asInstanceOf[Attribute]) - } - val groupedAttrsByExprId = allOutputAttrs - .flatten.groupBy(_.exprId).values.map(_.distinct) - groupedAttrsByExprId.forall(_.length == 1) + p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) } + }.flatten + val groupedDataTypesByExprId = exprIds.groupBy(_._1).values.map(_.distinct) + groupedDataTypesByExprId.forall(_.length == 1) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 01f6f96437d39..b301b9e37b649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -213,8 +213,13 @@ case class Except( /** Factory for constructing new `Union` nodes. */ object Union { + + def apply(left: LogicalPlan, right: LogicalPlan, output: Seq[Attribute]): Union = { + Union(left :: right :: Nil, unionOutput = output) + } + def apply(left: LogicalPlan, right: LogicalPlan): Union = { - Union (left :: right :: Nil) + Union(left :: right :: Nil) } } @@ -229,7 +234,8 @@ object Union { case class Union( children: Seq[LogicalPlan], byName: Boolean = false, - allowMissingCol: Boolean = false) extends LogicalPlan { + allowMissingCol: Boolean = false, + unionOutput: Seq[Attribute] = Seq.empty) extends LogicalPlan { assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.") override def maxRows: Option[Long] = { @@ -256,8 +262,7 @@ case class Union( AttributeSet.fromAttributeSets(children.map(_.outputSet)).size } - // updating nullability to make all the children consistent - override def output: Seq[Attribute] = { + private def makeUnionOutput(): Seq[Attribute] = { children.map(_.output).transpose.map { attrs => val firstAttr = attrs.head val nullable = attrs.exists(_.nullable) @@ -271,17 +276,27 @@ case class Union( } } - override lazy val resolved: Boolean = { + // updating nullability to make all the children consistent + override def output: Seq[Attribute] = { + assert(unionOutput.nonEmpty, "Union should have at least a single column") + unionOutput + } + + def allChildrenCompatible: Boolean = { // allChildrenCompatible needs to be evaluated after childrenResolved - def allChildrenCompatible: Boolean = - children.tail.forall( child => - // compare the attribute number with the first child - child.output.length == children.head.output.length && + childrenResolved && children.tail.forall { child => + // compare the attribute number with the first child + child.output.length == children.head.output.length && // compare the data types with the first child child.output.zip(children.head.output).forall { case (l, r) => l.dataType.sameType(r.dataType) - }) - children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible + } + } + } + + override lazy val resolved: Boolean = { + children.length > 1 && !(byName || allowMissingCol) && allChildrenCompatible && + unionOutput.nonEmpty } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index d5991ff10ce6c..2351523f8f41c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -73,7 +73,7 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter { Union(Project(Seq(Alias(left, "l")()), relation), Project(Seq(Alias(right, "r")()), relation)) val (l, r) = analyzer.execute(plan).collect { - case Union(Seq(child1, child2), _, _) => (child1.output.head, child2.output.head) + case Union(Seq(child1, child2), _, _, _) => (child1.output.head, child2.output.head) }.head assert(l.dataType === expectedType) assert(r.dataType === expectedType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala index 5c7ad0067a456..42bd4012b0a63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnionSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -51,7 +53,8 @@ class ResolveUnionSuite extends AnalysisTest { val analyzed1 = analyzer.execute(union1) val projected1 = Project(Seq(table2.output(3), table2.output(0), table2.output(1), table2.output(2)), table2) - val expected1 = Union(table1 :: projected1 :: Nil) + val expectedOutput = Seq('i.int, 'u.decimal, 'b.byte, 'd.double) + val expected1 = Union(table1 :: projected1 :: Nil, unionOutput = expectedOutput) comparePlans(analyzed1, expected1) // Allow missing column @@ -60,7 +63,7 @@ class ResolveUnionSuite extends AnalysisTest { val nullAttr1 = Alias(Literal(null, ByteType), "b")() val projected2 = Project(Seq(table2.output(3), table2.output(0), nullAttr1, table2.output(2)), table3) - val expected2 = Union(table1 :: projected2 :: Nil) + val expected2 = Union(table1 :: projected2 :: Nil, unionOutput = expectedOutput) comparePlans(analyzed2, expected2) // Allow missing column + Allow missing column @@ -69,7 +72,7 @@ class ResolveUnionSuite extends AnalysisTest { val nullAttr2 = Alias(Literal(null, DoubleType), "d")() val projected3 = Project(Seq(table2.output(3), table2.output(0), nullAttr1, nullAttr2), table4) - val expected3 = Union(table1 :: projected2 :: projected3 :: Nil) + val expected3 = Union(table1 :: projected2 :: projected3 :: Nil, unionOutput = expectedOutput) comparePlans(analyzed3, expected3) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 87b9aea80c823..0c4b5dc5a73cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -227,7 +227,7 @@ class Dataset[T] private[sql]( val plan = queryExecution.analyzed match { case c: Command => LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect())) - case u @ Union(children, _, _) if children.forall(_.isInstanceOf[Command]) => + case u @ Union(children, _, _, _) if children.forall(_.isInstanceOf[Command]) => LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect())) case _ => queryExecution.analyzed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index ba3d83714c302..abea4a7a8eaff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -692,7 +692,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.GlobalLimit(IntegerLiteral(limit), child) => execution.GlobalLimitExec(limit, planLater(child)) :: Nil case union: logical.Union => - execution.UnionExec(union.children.map(planLater)) :: Nil + execution.UnionExec(union.children.map(planLater), union.unionOutput) :: Nil case g @ logical.Generate(generator, _, outer, _, _, child) => execution.GenerateExec( generator, g.requiredChildOutput, outer, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1f70fde3f7654..740564f120e1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -634,21 +634,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) * If we change how this is implemented physically, we'd need to update * [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]]. */ -case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { - // updating nullability to make all the children consistent - override def output: Seq[Attribute] = { - children.map(_.output).transpose.map { attrs => - val firstAttr = attrs.head - val nullable = attrs.exists(_.nullable) - val newDt = attrs.map(_.dataType).reduce(StructType.merge) - if (firstAttr.dataType == newDt) { - firstAttr.withNullability(nullable) - } else { - AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)( - firstAttr.exprId, firstAttr.qualifier) - } - } - } +case class UnionExec(children: Seq[SparkPlan], unionOutput: Seq[Attribute]) extends SparkPlan { + + override def output: Seq[Attribute] = unionOutput protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala index b4cb7e3fce3cf..4ba25db374147 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlannerSuite.scala @@ -39,7 +39,7 @@ class SparkPlannerSuite extends SharedSparkSession { planLater(child) :: planLater(NeverPlanned) :: Nil case u: Union => planned += 1 - UnionExec(u.children.map(planLater)) :: planLater(NeverPlanned) :: Nil + UnionExec(u.children.map(planLater), u.unionOutput) :: planLater(NeverPlanned) :: Nil case LocalRelation(output, data, _) => planned += 1 LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil