Skip to content

Commit

Permalink
Bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed Sep 28, 2020
1 parent 8b86bcf commit 4c8a81e
Show file tree
Hide file tree
Showing 192 changed files with 7,449 additions and 7,201 deletions.
14 changes: 7 additions & 7 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -2684,21 +2684,21 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF
writeLines(lines, jsonPath2)
df2 <- read.df(jsonPath2, "json")

unioned <- arrange(union(df, df2), df$age)
unioned <- arrange(union(df, df2), "age")
expect_is(unioned, "SparkDataFrame")
expect_equal(count(unioned), 6)
expect_equal(first(unioned)$name, "Michael")
expect_equal(count(arrange(suppressWarnings(union(df, df2)), df$age)), 6)
expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), df$age)), 6)
expect_equal(count(arrange(suppressWarnings(union(df, df2)), "age")), 6)
expect_equal(count(arrange(suppressWarnings(unionAll(df, df2)), "age")), 6)

df1 <- select(df2, "age", "name")
unioned1 <- arrange(unionByName(df1, df), df1$age)
unioned1 <- arrange(unionByName(df1, df), "age")
expect_is(unioned, "SparkDataFrame")
expect_equal(count(unioned), 6)
# Here, we test if 'Michael' in df is correctly mapped to the same name.
expect_equal(first(unioned)$name, "Michael")

unioned2 <- arrange(rbind(unioned, df, df2), df$age)
unioned2 <- arrange(rbind(unioned, df, df2), "age")
expect_is(unioned2, "SparkDataFrame")
expect_equal(count(unioned2), 12)
expect_equal(first(unioned2)$name, "Michael")
Expand All @@ -2723,12 +2723,12 @@ test_that("union(), unionByName(), rbind(), except(), and intersect() on a DataF
testthat::expect_error(unionByName(df2, select(df2, "age"), FALSE))
testthat::expect_error(unionByName(df2, select(df2, "age")))

excepted <- arrange(except(df, df2), desc(df$age))
excepted <- arrange(except(df, df2), desc(column("age")))
expect_is(unioned, "SparkDataFrame")
expect_equal(count(excepted), 2)
expect_equal(first(excepted)$name, "Justin")

intersected <- arrange(intersect(df, df2), df$age)
intersected <- arrange(intersect(df, df2), "age")
expect_is(unioned, "SparkDataFrame")
expect_equal(count(intersected), 1)
expect_equal(first(intersected)$name, "Andy")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,11 @@ class Analyzer(
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))

case oldVersion: Union
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.unionOutput.get.map(_.newInstance())
Seq((oldVersion, oldVersion.copy(unionOutput = Some(newOutput))))

case oldVersion: Expand
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val producedAttributes = oldVersion.producedAttributes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@

package org.apache.spark.sql.catalyst.analysis

import scala.collection.mutable

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, ExprId, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.rules.Rule
Expand All @@ -31,12 +33,15 @@ import org.apache.spark.sql.util.SchemaUtils
*/
object ResolveUnion extends Rule[LogicalPlan] {

private[catalyst] def makeUnionOutput(children: Seq[LogicalPlan]): Seq[Attribute] = {
def makeUnionOutput(children: Seq[LogicalPlan]): Seq[Attribute] = {
val seenExprIdSet = mutable.Map[ExprId, ExprId]()
children.map(_.output).transpose.map { attrs =>
val firstAttr = attrs.head
val nullable = attrs.exists(_.nullable)
val newDt = attrs.map(_.dataType).reduce(StructType.merge)
val newExprId = NamedExpression.newExprId
// If child's output has attributes having the same `exprId`, we needs to
// assign a unique `exprId` for them.
val newExprId = seenExprIdSet.getOrElseUpdate(firstAttr.exprId, NamedExpression.newExprId)
if (firstAttr.dataType == newDt) {
firstAttr.withExprId(newExprId).withNullability(nullable)
} else {
Expand Down Expand Up @@ -85,8 +90,13 @@ object ResolveUnion extends Rule[LogicalPlan] {
} else {
left
}
val unionOutput = makeUnionOutput(Seq(leftChild, rightChild))
Union(leftChild, rightChild, unionOutput)
val newUnion = Union(leftChild, rightChild)
if (newUnion.allChildrenCompatible) {
val unionOutput = makeUnionOutput(Seq(leftChild, rightChild))
newUnion.copy(unionOutput = Some(unionOutput))
} else {
newUnion
}
}

// Check column name duplication
Expand Down Expand Up @@ -117,6 +127,6 @@ object ResolveUnion extends Rule[LogicalPlan] {

case u @ Union(children, _, _, unionOutput)
if u.allChildrenCompatible && unionOutput.isEmpty =>
u.copy(unionOutput = makeUnionOutput(children))
u.copy(unionOutput = Some(makeUnionOutput(children)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -360,13 +360,14 @@ object TypeCoercion {
} else {
val attrMapping = s.children.head.output.zip(newChildren.head.output)
val newOutput = ResolveUnion.makeUnionOutput(newChildren)
s.copy(children = newChildren, unionOutput = newOutput) -> attrMapping
s.copy(children = newChildren, unionOutput = Some(newOutput)) -> attrMapping
}
}
}

/** Build new children with the widest types for each attribute among all the children */
private def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan]): Seq[LogicalPlan] = {
private[analysis] def buildNewChildrenWithWiderTypes(children: Seq[LogicalPlan])
: Seq[LogicalPlan] = {
require(children.forall(_.output.length == children.head.output.length))

// Get a sequence of data types, each of which is the widest type of this specific attribute
Expand Down Expand Up @@ -1120,6 +1121,8 @@ object TypeCoercion {
}

trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
import TypeCoercion.WidenSetOperationTypes

/**
* Applies any changes to [[AttributeReference]] data types that are made by the transform method
* to instances higher in the query tree.
Expand All @@ -1142,6 +1145,24 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
// Don't propagate types from unresolved children.
case q: LogicalPlan if !q.childrenResolved => q

case u: Union =>
if (u.unionOutput.isDefined) {
// If this type coercion rule changes the input types of `Union`, we need to
// update data types in `unionOutput` accordingly.
val newChildren = WidenSetOperationTypes.buildNewChildrenWithWiderTypes(u.children)
val newOutputTypes = newChildren.head.output.map(_.dataType)
if (!u.output.zip(newOutputTypes).forall { case (a, dt) => a.dataType.sameType(dt)}) {
val newOutput = u.output.map(_.asInstanceOf[AttributeReference]).zip(newOutputTypes).map {
case (a, dt) => a.copy(dataType = dt)(exprId = a.exprId, qualifier = a.qualifier)
}
u.copy(children = newChildren, unionOutput = Some(newOutput))
} else {
u
}
} else {
u
}

case q: LogicalPlan =>
val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
q transformExpressions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
}
val newChildren = newFirstChild +: newOtherChildren
val newOutput = ResolveUnion.makeUnionOutput(newChildren)
val newPlan = u.copy(children = newChildren, unionOutput = newOutput)
val newPlan = u.copy(children = newChildren, unionOutput = Some(newOutput))
val attrMapping = p.output.zip(newPlan.output).filter {
case (a1, a2) => a1.exprId != a2.exprId
}
Expand Down Expand Up @@ -683,7 +683,7 @@ object ColumnPruning extends Rule[LogicalPlan] {
Project(selected, p)
}
val prunedUnionOutput = u.output.filter(p.references.contains)
p.copy(child = u.copy(children = newChildren, unionOutput = prunedUnionOutput))
p.copy(child = u.copy(children = newChildren, unionOutput = Some(prunedUnionOutput)))
} else {
p
}
Expand Down Expand Up @@ -1804,9 +1804,7 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
projectMinPlan
)
val newPlan = Project(newLeftOutput, genRowPlan)
val attrMapping = i.output.zip(newPlan.output).filter {
case (a1, a2) => a1.exprId != a2.exprId
}
val attrMapping = i.output.zip(newPlan.output)
newPlan -> attrMapping
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,15 @@ object PropagateEmptyRelation extends Rule[LogicalPlan] with PredicateHelper wit
if (newChildren.isEmpty) {
empty(p)
} else {
val newPlan = if (newChildren.size > 1) Union(newChildren) else newChildren.head
val outputs = newPlan.output.zip(p.output)
// the original Union may produce different output attributes than the new one so we alias
// them if needed
if (outputs.forall { case (newAttr, oldAttr) => newAttr.exprId == oldAttr.exprId }) {
newPlan
if (newChildren.size > 1) {
p.copy(children = newChildren)
} else {
val outputAliases = outputs.map { case (newAttr, oldAttr) =>
val newPlan = newChildren.head
val outputAliases = p.output.zip(newPlan.output).map { case (oldAttr, newAttr) =>
val newExplicitMetadata =
if (oldAttr.metadata != newAttr.metadata) Some(oldAttr.metadata) else None
Alias(newAttr, oldAttr.name)(oldAttr.exprId, explicitMetadata = newExplicitMetadata)
Alias(newAttr, oldAttr.name)(
exprId = oldAttr.exprId, explicitMetadata = newExplicitMetadata)
}
Project(outputAliases, newPlan)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
if (inserts.length == 1) {
inserts.head
} else {
Union(inserts.toSeq)
Union(inserts.toSeq, unionOutput = Some(Nil))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,12 +220,23 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
rule.applyOrElse(newPlan, (plan: PlanType) => plan -> Nil)
}
newPlan = planAfterRule

attrMapping ++= newAttrMapping.filter {
val newValidAttrMapping = newAttrMapping.filter {
case (a1, a2) => a1.exprId != a2.exprId
}
newPlan -> attrMapping.toSeq
// Updates the `attrMapping` entries that are obsoleted by generated entries in `rule`.
// For example, `attrMapping` has a mapping entry 'id#1 -> id#2' and `rule`
// generates a new entry 'id#2 -> id#3'. In this case, we need to update
// the corresponding old entry from 'id#1 -> id#2' to '#id#1 -> #id#3'.
val updatedAttrMap = AttributeMap(newValidAttrMapping)
val transferAttrMapping = attrMapping.map {
case (a1, a2) => (a1, updatedAttrMap.getOrElse(a2, a2))
}
val newOtherAttrMapping = {
val existingAttrMappingSet = transferAttrMapping.map(_._2).toSet
newValidAttrMapping.filterNot { case (_, a) => existingAttrMappingSet.contains(a) }
}
planAfterRule -> (transferAttrMapping ++ newOtherAttrMapping).toSeq
}
}
rewrite(this)._1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ case class Except(
object Union {

def apply(left: LogicalPlan, right: LogicalPlan, output: Seq[Attribute]): Union = {
Union(left :: right :: Nil, unionOutput = output)
Union(left :: right :: Nil, unionOutput = Some(output))
}

def apply(left: LogicalPlan, right: LogicalPlan): Union = {
Expand All @@ -235,7 +235,7 @@ case class Union(
children: Seq[LogicalPlan],
byName: Boolean = false,
allowMissingCol: Boolean = false,
unionOutput: Seq[Attribute] = Seq.empty) extends LogicalPlan {
unionOutput: Option[Seq[Attribute]] = None) extends LogicalPlan {
assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.")

override def maxRows: Option[Long] = {
Expand All @@ -262,12 +262,13 @@ case class Union(
AttributeSet.fromAttributeSets(children.map(_.outputSet)).size
}

override def producedAttributes: AttributeSet = AttributeSet(unionOutput)
override def producedAttributes: AttributeSet =
if (unionOutput.isDefined) AttributeSet(unionOutput.get) else AttributeSet.empty

// updating nullability to make all the children consistent
override def output: Seq[Attribute] = {
assert(unionOutput.nonEmpty, "Union should have at least a single column")
unionOutput
assert(unionOutput.isDefined, "Union should have at least a single column")
unionOutput.get
}

lazy val allChildrenCompatible: Boolean = {
Expand All @@ -284,7 +285,7 @@ case class Union(

override lazy val resolved: Boolean = {
children.length > 1 && !(byName || allowMissingCol) && allChildrenCompatible &&
unionOutput.nonEmpty
unionOutput.isDefined
}

/**
Expand Down Expand Up @@ -319,7 +320,7 @@ case class Union(

override protected lazy val validConstraints: ExpressionSet = {
children
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
.map(child => rewriteConstraints(output, child.output, child.constraints))
.reduce(merge(_, _))
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
a.select(UnresolvedStar(None)).select($"a").union(b.select(UnresolvedStar(None)))
}

assertAnalysisSuccess(plan)
assertAnalysisSuccess(plan, maxIterations = Some(150))
}

test("check project's resolved") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ trait AnalysisTest extends PlanTest {

protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = Nil

private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
private def makeAnalyzer(caseSensitive: Boolean, maxIterations: Option[Int] = None): Analyzer = {
val conf = {
val sqlConf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
if (maxIterations.isDefined) {
sqlConf.setConf(SQLConf.ANALYZER_MAX_ITERATIONS, maxIterations.get)
}
sqlConf
}
val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf)
catalog.createDatabase(
CatalogDatabase("default", "", new URI("loc"), Map.empty),
Expand All @@ -52,15 +58,20 @@ trait AnalysisTest extends PlanTest {
}
}

protected def getAnalyzer(caseSensitive: Boolean) = {
if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
protected def getAnalyzer(caseSensitive: Boolean, maxIterations: Option[Int] = None) = {
if (maxIterations.isEmpty) {
if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
} else {
makeAnalyzer(caseSensitive, maxIterations)
}
}

protected def checkAnalysis(
inputPlan: LogicalPlan,
expectedPlan: LogicalPlan,
caseSensitive: Boolean = true): Unit = {
val analyzer = getAnalyzer(caseSensitive)
caseSensitive: Boolean = true,
maxIterations: Option[Int] = None): Unit = {
val analyzer = getAnalyzer(caseSensitive, maxIterations)
val actualPlan = analyzer.executeAndCheck(inputPlan, new QueryPlanningTracker)
comparePlans(actualPlan, expectedPlan)
}
Expand All @@ -75,8 +86,9 @@ trait AnalysisTest extends PlanTest {

protected def assertAnalysisSuccess(
inputPlan: LogicalPlan,
caseSensitive: Boolean = true): Unit = {
val analyzer = getAnalyzer(caseSensitive)
caseSensitive: Boolean = true,
maxIterations: Option[Int] = None): Unit = {
val analyzer = getAnalyzer(caseSensitive, maxIterations)
val analysisAttempt = analyzer.execute(inputPlan)
try analyzer.checkAnalysis(analysisAttempt) catch {
case a: AnalysisException =>
Expand All @@ -94,8 +106,9 @@ trait AnalysisTest extends PlanTest {
protected def assertAnalysisError(
inputPlan: LogicalPlan,
expectedErrors: Seq[String],
caseSensitive: Boolean = true): Unit = {
val analyzer = getAnalyzer(caseSensitive)
caseSensitive: Boolean = true,
maxIterations: Option[Int] = None): Unit = {
val analyzer = getAnalyzer(caseSensitive, maxIterations)
val e = intercept[AnalysisException] {
analyzer.checkAnalysis(analyzer.execute(inputPlan))
}
Expand Down
Loading

0 comments on commit 4c8a81e

Please sign in to comment.