Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Sep 27, 2020
1 parent 93cd9e6 commit c78e517
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,33 @@
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.SchemaUtils

/**
* Resolves different children of Union to a common set of columns.
*/
object ResolveUnion extends Rule[LogicalPlan] {

private[catalyst] def makeUnionOutput(children: Seq[LogicalPlan]): Seq[Attribute] = {
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)(
NamedExpression.newExprId, firstAttr.qualifier)
}
}
}

private def unionTwoSides(
left: LogicalPlan,
right: LogicalPlan,
Expand Down Expand Up @@ -68,7 +84,8 @@ object ResolveUnion extends Rule[LogicalPlan] {
} else {
left
}
Union(leftChild, rightChild)
val unionOutput = makeUnionOutput(Seq(leftChild, rightChild))
Union(leftChild, rightChild, unionOutput)
}

// Check column name duplication
Expand All @@ -88,13 +105,17 @@ object ResolveUnion extends Rule[LogicalPlan] {
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
case e if !e.childrenResolved => e
case p if !p.childrenResolved => p

case Union(children, byName, allowMissingCol) if byName =>
case Union(children, byName, allowMissingCol, _) if byName =>
val union = children.reduceLeft { (left, right) =>
checkColumnNames(left, right)
unionTwoSides(left, right, allowMissingCol)
}
CombineUnions(union)

case u @ Union(children, _, _, unionOutput)
if u.allChildrenCompatible && unionOutput.isEmpty =>
u.copy(unionOutput = makeUnionOutput(children))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,8 @@ object TypeCoercion {
s -> Nil
} else {
val attrMapping = s.children.head.output.zip(newChildren.head.output)
s.copy(children = newChildren) -> attrMapping
val newOutput = ResolveUnion.makeUnionOutput(newChildren)
s.copy(children = newChildren, unionOutput = newOutput) -> attrMapping
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
val rewrites = buildRewrites(u.children.head, child)
Project(projectList.map(pushToRight(_, rewrites)), child)
}
u.copy(children = newFirstChild +: newOtherChildren)
val newChildren = newFirstChild +: newOtherChildren
u.copy(children = newChildren, unionOutput = ResolveUnion.makeUnionOutput(newChildren))
} else {
p
}
Expand Down Expand Up @@ -967,17 +968,20 @@ object CombineUnions extends Rule[LogicalPlan] {
// rules (by position and by name) could cause incorrect results.
while (stack.nonEmpty) {
stack.pop() match {
case Distinct(Union(children, byName, allowMissingCol))
case Distinct(Union(children, byName, allowMissingCol, _))
if flattenDistinct && byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
case Union(children, byName, allowMissingCol)
case Union(children, byName, allowMissingCol, _)
if byName == topByName && allowMissingCol == topAllowMissingCol =>
stack.pushAll(children.reverse)
case child =>
flattened += child
}
}
union.copy(children = flattened.toSeq)
union.copy(
children = flattened,
unionOutput = ResolveUnion.makeUnionOutput(flattened)
)
}
}

Expand Down Expand Up @@ -1689,7 +1693,8 @@ object RewriteExceptAll extends Rule[LogicalPlan] {
val newColumnRight = Alias(Literal(-1L), "vcol")()
val modifiedLeftPlan = Project(Seq(newColumnLeft) ++ left.output, left)
val modifiedRightPlan = Project(Seq(newColumnRight) ++ right.output, right)
val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan)
val unionOutput = ResolveUnion.makeUnionOutput(Seq(modifiedLeftPlan, modifiedRightPlan))
val unionPlan = Union(modifiedLeftPlan, modifiedRightPlan, output = unionOutput)
val aggSumCol =
Alias(AggregateExpression(Sum(unionPlan.output.head.toAttribute), Complete, false), "sum")()
val aggOutputColumns = left.output ++ Seq(aggSumCol)
Expand Down Expand Up @@ -1753,7 +1758,12 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
val leftPlanWithAddedVirtualCols = Project(Seq(trueVcol1, nullVcol2) ++ left.output, left)
val rightPlanWithAddedVirtualCols = Project(Seq(nullVcol1, trueVcol2) ++ right.output, right)

val unionPlan = Union(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols)
val unionOutput = ResolveUnion.makeUnionOutput(
Seq(leftPlanWithAddedVirtualCols, rightPlanWithAddedVirtualCols))
val unionPlan = Union(
leftPlanWithAddedVirtualCols,
rightPlanWithAddedVirtualCols,
output = unionOutput)

// Expressions to compute count and minimum of both the counts.
val vCol1AggrExpr =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,10 +463,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
sys.error(s"Unexpected operator in scalar subquery: $lp")
}

val resultMap = evalPlan(plan).mapValues { _.transform {
case a: Alias => a.newInstance() // Assigns a new `ExprId`
}
}
val resultMap = evalPlan(plan)

// By convention, the scalar subquery result is the leftmost field.
resultMap.get(plan.output.head.exprId) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,19 +218,18 @@ object LogicalPlanIntegrity {

/**
* Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`,
* this method checks if the same `ExprId` refers to a semantically-equal attribute
* in a plan output.
* this method checks if the same `ExprId` refers to attributes having the same data type
* in plan output.
*/
def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = {
val allOutputAttrs = plan.collect { case p if canGetOutputAttrs(p) =>
val exprIds = plan.collect { case p if canGetOutputAttrs(p) =>
// NOTE: we still need to filter resolved expressions here because the output of
// some resolved logical plans can have unresolved references,
// e.g., outer references in `ExistenceJoin`.
p.output.filter(_.resolved).map(_.canonicalized.asInstanceOf[Attribute])
}
val groupedAttrsByExprId = allOutputAttrs
.flatten.groupBy(_.exprId).values.map(_.distinct)
groupedAttrsByExprId.forall(_.length == 1)
p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) }
}.flatten
val groupedDataTypesByExprId = exprIds.groupBy(_._1).values.map(_.distinct)
groupedDataTypesByExprId.forall(_.length == 1)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,13 @@ case class Except(

/** Factory for constructing new `Union` nodes. */
object Union {

def apply(left: LogicalPlan, right: LogicalPlan, output: Seq[Attribute]): Union = {
Union(left :: right :: Nil, unionOutput = output)
}

def apply(left: LogicalPlan, right: LogicalPlan): Union = {
Union (left :: right :: Nil)
Union(left :: right :: Nil)
}
}

Expand All @@ -229,7 +234,8 @@ object Union {
case class Union(
children: Seq[LogicalPlan],
byName: Boolean = false,
allowMissingCol: Boolean = false) extends LogicalPlan {
allowMissingCol: Boolean = false,
unionOutput: Seq[Attribute] = Seq.empty) extends LogicalPlan {
assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.")

override def maxRows: Option[Long] = {
Expand All @@ -256,8 +262,7 @@ case class Union(
AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
}

// updating nullability to make all the children consistent
override def output: Seq[Attribute] = {
private def makeUnionOutput(): Seq[Attribute] = {
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
Expand All @@ -271,17 +276,27 @@ case class Union(
}
}

override lazy val resolved: Boolean = {
// updating nullability to make all the children consistent
override def output: Seq[Attribute] = {
assert(unionOutput.nonEmpty, "Union should have at least a single column")
unionOutput
}

def allChildrenCompatible: Boolean = {
// allChildrenCompatible needs to be evaluated after childrenResolved
def allChildrenCompatible: Boolean =
children.tail.forall( child =>
// compare the attribute number with the first child
child.output.length == children.head.output.length &&
childrenResolved && children.tail.forall { child =>
// compare the attribute number with the first child
child.output.length == children.head.output.length &&
// compare the data types with the first child
child.output.zip(children.head.output).forall {
case (l, r) => l.dataType.sameType(r.dataType)
})
children.length > 1 && !(byName || allowMissingCol) && childrenResolved && allChildrenCompatible
}
}
}

override lazy val resolved: Boolean = {
children.length > 1 && !(byName || allowMissingCol) && allChildrenCompatible &&
unionOutput.nonEmpty
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class DecimalPrecisionSuite extends AnalysisTest with BeforeAndAfter {
Union(Project(Seq(Alias(left, "l")()), relation),
Project(Seq(Alias(right, "r")()), relation))
val (l, r) = analyzer.execute(plan).collect {
case Union(Seq(child1, child2), _, _) => (child1.output.head, child2.output.head)
case Union(Seq(child1, child2), _, _, _) => (child1.output.head, child2.output.head)
}.head
assert(l.dataType === expectedType)
assert(r.dataType === expectedType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
Expand Down Expand Up @@ -51,7 +53,8 @@ class ResolveUnionSuite extends AnalysisTest {
val analyzed1 = analyzer.execute(union1)
val projected1 =
Project(Seq(table2.output(3), table2.output(0), table2.output(1), table2.output(2)), table2)
val expected1 = Union(table1 :: projected1 :: Nil)
val expectedOutput = Seq('i.int, 'u.decimal, 'b.byte, 'd.double)
val expected1 = Union(table1 :: projected1 :: Nil, unionOutput = expectedOutput)
comparePlans(analyzed1, expected1)

// Allow missing column
Expand All @@ -60,7 +63,7 @@ class ResolveUnionSuite extends AnalysisTest {
val nullAttr1 = Alias(Literal(null, ByteType), "b")()
val projected2 =
Project(Seq(table2.output(3), table2.output(0), nullAttr1, table2.output(2)), table3)
val expected2 = Union(table1 :: projected2 :: Nil)
val expected2 = Union(table1 :: projected2 :: Nil, unionOutput = expectedOutput)
comparePlans(analyzed2, expected2)

// Allow missing column + Allow missing column
Expand All @@ -69,7 +72,7 @@ class ResolveUnionSuite extends AnalysisTest {
val nullAttr2 = Alias(Literal(null, DoubleType), "d")()
val projected3 =
Project(Seq(table2.output(3), table2.output(0), nullAttr1, nullAttr2), table4)
val expected3 = Union(table1 :: projected2 :: projected3 :: Nil)
val expected3 = Union(table1 :: projected2 :: projected3 :: Nil, unionOutput = expectedOutput)
comparePlans(analyzed3, expected3)
}
}
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class Dataset[T] private[sql](
val plan = queryExecution.analyzed match {
case c: Command =>
LocalRelation(c.output, withAction("command", queryExecution)(_.executeCollect()))
case u @ Union(children, _, _) if children.forall(_.isInstanceOf[Command]) =>
case u @ Union(children, _, _, _) if children.forall(_.isInstanceOf[Command]) =>
LocalRelation(u.output, withAction("command", queryExecution)(_.executeCollect()))
case _ =>
queryExecution.analyzed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.GlobalLimit(IntegerLiteral(limit), child) =>
execution.GlobalLimitExec(limit, planLater(child)) :: Nil
case union: logical.Union =>
execution.UnionExec(union.children.map(planLater)) :: Nil
execution.UnionExec(union.children.map(planLater), union.unionOutput) :: Nil
case g @ logical.Generate(generator, _, outer, _, _, child) =>
execution.GenerateExec(
generator, g.requiredChildOutput, outer,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -634,21 +634,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
* If we change how this is implemented physically, we'd need to update
* [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]].
*/
case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan {
// updating nullability to make all the children consistent
override def output: Seq[Attribute] = {
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
if (firstAttr.dataType == newDt) {
firstAttr.withNullability(nullable)
} else {
AttributeReference(firstAttr.name, newDt, nullable, firstAttr.metadata)(
firstAttr.exprId, firstAttr.qualifier)
}
}
}
case class UnionExec(children: Seq[SparkPlan], unionOutput: Seq[Attribute]) extends SparkPlan {

override def output: Seq[Attribute] = unionOutput

protected override def doExecute(): RDD[InternalRow] =
sparkContext.union(children.map(_.execute()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class SparkPlannerSuite extends SharedSparkSession {
planLater(child) :: planLater(NeverPlanned) :: Nil
case u: Union =>
planned += 1
UnionExec(u.children.map(planLater)) :: planLater(NeverPlanned) :: Nil
UnionExec(u.children.map(planLater), u.unionOutput) :: planLater(NeverPlanned) :: Nil
case LocalRelation(output, data, _) =>
planned += 1
LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
Expand Down

0 comments on commit c78e517

Please sign in to comment.