Skip to content

Commit

Permalink
SPARK-45959. Added new tests. Handled flattening of Project when done…
Browse files Browse the repository at this point in the history
… using DataFrame.select instead of the withColumn API
  • Loading branch information
ashahid committed Dec 15, 2023
1 parent 8f7a9bf commit 9c59adc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 107 deletions.
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1574,7 +1574,7 @@ class Dataset[T] private[sql](
case other => other
}
val newProjList = untypedCols.map(_.named)
(logicalPlan, newProjList, sparkSession.conf) match {
(logicalPlan, newProjList, id) match {
case EasilyFlattenable(flattendPlan) if !this.isStreaming &&
!logicalPlan.getTagValue(LogicalPlan.SKIP_FLATTENING).getOrElse(false) => flattendPlan

Expand Down Expand Up @@ -2956,7 +2956,7 @@ class Dataset[T] private[sql](
projectList.map(_.name),
sparkSession.sessionState.conf.caseSensitiveAnalysis)
withPlan(
(logicalPlan, projectList, sparkSession.conf) match {
(logicalPlan, projectList, id) match {
case EasilyFlattenable(flattendPlan) if !this.isStreaming &&
!logicalPlan.getTagValue(LogicalPlan.SKIP_FLATTENING).getOrElse(false) => flattendPlan

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@

package org.apache.spark.sql.internal

import scala.collection.mutable
import scala.util.{Failure, Success, Try}

import org.apache.spark.sql.{Dataset, RuntimeConfig}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, NamedExpression, UserDefinedExpression, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression, UserDefinedExpression, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.types.MetadataBuilder



private[sql] object EasilyFlattenable {
Expand All @@ -34,17 +33,19 @@ private[sql] object EasilyFlattenable {
val AddNewColumnsOnly, RemapOnly, Unknown = Value
}

def unapply(tuple: (LogicalPlan, Seq[NamedExpression], RuntimeConfig)): Option[LogicalPlan]
def unapply(tuple: (LogicalPlan, Seq[NamedExpression], Long)): Option[LogicalPlan]
= {
val (logicalPlan, newProjList, conf) = tuple
val (logicalPlan, newProjList, did) = tuple

logicalPlan match {
case p @ Project(projList, child: LogicalPlan) if p.output.groupBy(_.name).
forall(_._2.size == 1) =>
case p @ Project(projList, child: LogicalPlan)
if newProjList.flatMap(_.collectLeaves()).forall {
case ar: AttributeReference if ar.metadata.contains(Dataset.DATASET_ID_KEY) &&
ar.metadata.getLong(Dataset.DATASET_ID_KEY) != did => false
case _ => true
} =>
val currentOutputAttribs = AttributeSet(p.output)

val currentDatasetIdOpt = p.getTagValue(Dataset.DATASET_ID_TAG).get.toSet.headOption

// In the new column list identify those Named Expressions which are just attributes and
// hence pass thru
val (passThruAttribs, tinkeredOrNewNamedExprs) = newProjList.partition {
Expand All @@ -62,15 +63,15 @@ private[sql] object EasilyFlattenable {
// case of new columns being added only
val childOutput = child.output.map(_.name).toSet
val attribsRemappedInProj = projList.flatMap(ne => ne match {
case _: AttributeReference => Seq.empty[(String, Alias)]
case _: AttributeReference => Seq.empty[(String, Expression)]

case al @ Alias(_, name) => if (childOutput.contains(name)) {
Seq(name -> al)
case Alias(expr, name) => if (childOutput.contains(name)) {
Seq(name -> expr)
} else {
Seq.empty[(String, Alias)]
Seq.empty[(String, Expression)]
}

case _ => Seq.empty[(String, Alias)]
case _ => Seq.empty[(String, Expression)]
}).toMap

if (tinkeredOrNewNamedExprs.exists(_.collectFirst {
Expand All @@ -90,46 +91,30 @@ private[sql] object EasilyFlattenable {
val remappedNewProjListResult = Try {
newProjList.map {

case attr: AttributeReference =>
val ne = projList.find(
case attr: AttributeReference => projList.find(
_.toAttribute.canonicalized == attr.canonicalized).getOrElse(attr)
if (attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
currentDatasetIdOpt.contains(attr.metadata.getLong(
Dataset.DATASET_ID_KEY))) {
addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt)
} else {
ne
}

case ua: UnresolvedAttribute =>
projList.find(_.toAttribute.name.equalsIgnoreCase(ua.name)).
projList.find(_.toAttribute.name.equals(ua.name)).
getOrElse(throw new UnsupportedOperationException("Not able to flatten" +
s" unresolved attribute $ua"))

case anyOtherExpr =>
(anyOtherExpr transformUp {
case attr: AttributeReference => val ne =
attribsRemappedInProj.get(attr.name).orElse(
projList.find(
_.toAttribute.canonicalized == attr.canonicalized).map {
case al: Alias => al
case x => x
}).getOrElse(attr)
if (attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
currentDatasetIdOpt.contains(attr.metadata.getLong(
Dataset.DATASET_ID_KEY))) {
addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt)
} else {
ne
}
case attr: AttributeReference =>
attribsRemappedInProj.get(attr.name).orElse(projList.find(
_.toAttribute.canonicalized == attr.canonicalized).map {
case al: Alias => al.child
case x => x
}).getOrElse(attr)

case u: UnresolvedAttribute => attribsRemappedInProj.get(u.name).orElse(
projList.find( _.toAttribute.name.equalsIgnoreCase(u.name)).map {
case al: Alias => al.child
case u: UnresolvedAttribute =>
throw new UnsupportedOperationException("Not able to flatten" +
s" unresolved attribute $u")
case x => x
projList.find( _.toAttribute.name.equals(u.name)).map {
case al: Alias => al.child
case u: UnresolvedAttribute =>
throw new UnsupportedOperationException("Not able to flatten" +
s" unresolved attribute $u")
case x => x
}).getOrElse(throw new UnsupportedOperationException("Not able to flatten" +
s" unresolved attribute $u"))
}).asInstanceOf[NamedExpression]
Expand All @@ -138,17 +123,7 @@ private[sql] object EasilyFlattenable {
}
remappedNewProjListResult match {
case Success(remappedNewProjList) =>
currentDatasetIdOpt.foreach(id => {
if (conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
val dsIds = child.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(
new mutable.HashSet[Long])
dsIds.add(id)
child.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
})
val newProj = Project(remappedNewProjList, child)

Option(newProj)
Option(Project(remappedNewProjList, child))

case Failure(_) => None
}
Expand All @@ -158,53 +133,31 @@ private[sql] object EasilyFlattenable {
// case of renaming of columns
val remappedNewProjListResult = Try {
newProjList.map {
case attr: AttributeReference => val ne = projList.find(
case attr: AttributeReference => projList.find(
_.toAttribute.canonicalized == attr.canonicalized).get
if (attr.metadata.contains(Dataset.DATASET_ID_KEY) &&
currentDatasetIdOpt.contains(attr.metadata.getLong(
Dataset.DATASET_ID_KEY))) {
addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt)
} else {
ne
}


case ua: UnresolvedAttribute if ua.nameParts.size == 1 =>
projList.find( _.toAttribute.name.equalsIgnoreCase(ua.name)).
projList.find( _.toAttribute.name.equals(ua.name)).
getOrElse(throw new UnsupportedOperationException("Not able to flatten" +
s" unresolved attribute $ua"))

case al@Alias(ar: AttributeReference, name) =>
val ne = projList.find(_.toAttribute.canonicalized == ar.canonicalized).map {
projList.find(_.toAttribute.canonicalized == ar.canonicalized).map {

case alx : Alias => Alias(alx.child, name)(al.exprId, al.qualifier,
al.explicitMetadata, al.nonInheritableMetadataKeys)

case _: AttributeReference => al
}.get
if (ar.metadata.contains(Dataset.DATASET_ID_KEY) &&
currentDatasetIdOpt.contains(ar.metadata.getLong(
Dataset.DATASET_ID_KEY))) {
addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt)
} else {
ne
}

case x => throw new UnsupportedOperationException("Not able to flatten" +
s" unresolved attribute $x")
}
}
remappedNewProjListResult match {
case Success(remappedNewProjList) =>
currentDatasetIdOpt.foreach(id => {
if (conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
val dsIds = child.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(
new mutable.HashSet[Long])
dsIds.add(id)
child.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
})
val newProj = Project(remappedNewProjList, child)

Option(newProj)
Option(Project(remappedNewProjList, child))

case Failure(_) => None
}
Expand Down Expand Up @@ -236,25 +189,4 @@ private[sql] object EasilyFlattenable {
OpType.Unknown
}
}

private def addDataFrameIdToCol(
conf: RuntimeConfig,
expr: NamedExpression,
logicalPlan: LogicalPlan,
childDatasetId: Option[Long]): NamedExpression =
if (conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) && childDatasetId.nonEmpty) {
val newExpr = expr transform {
case a: AttributeReference
=>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.DATASET_ID_KEY, childDatasetId.get)
.putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals))
.build()
a.withMetadata(metadata)
}
newExpr.asInstanceOf[NamedExpression]
} else {
expr
}
}

0 comments on commit 9c59adc

Please sign in to comment.