From 9c59adc2e2024ccc7a02228487c9d278337ae007 Mon Sep 17 00:00:00 2001 From: ashahid Date: Fri, 15 Dec 2023 10:58:57 -0800 Subject: [PATCH] SPARK-45959. added new tests. Handled flattening of Project when done using dataFrame.select instead of withColumn api --- .../scala/org/apache/spark/sql/Dataset.scala | 4 +- .../sql/internal/EasilyFlattenable.scala | 142 +++++------------- 2 files changed, 39 insertions(+), 107 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ef7ea3ff74027..c4c30a55381b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1574,7 +1574,7 @@ class Dataset[T] private[sql]( case other => other } val newProjList = untypedCols.map(_.named) - (logicalPlan, newProjList, sparkSession.conf) match { + (logicalPlan, newProjList, id) match { case EasilyFlattenable(flattendPlan) if !this.isStreaming && !logicalPlan.getTagValue(LogicalPlan.SKIP_FLATTENING).getOrElse(false) => flattendPlan @@ -2956,7 +2956,7 @@ class Dataset[T] private[sql]( projectList.map(_.name), sparkSession.sessionState.conf.caseSensitiveAnalysis) withPlan( - (logicalPlan, projectList, sparkSession.conf) match { + (logicalPlan, projectList, id) match { case EasilyFlattenable(flattendPlan) if !this.isStreaming && !logicalPlan.getTagValue(LogicalPlan.SKIP_FLATTENING).getOrElse(false) => flattendPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/EasilyFlattenable.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/EasilyFlattenable.scala index c36ffd5af9aed..40a2279b60537 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/EasilyFlattenable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/EasilyFlattenable.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.internal -import scala.collection.mutable import scala.util.{Failure, Success, Try} -import org.apache.spark.sql.{Dataset, RuntimeConfig} +import org.apache.spark.sql.Dataset import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, NamedExpression, UserDefinedExpression, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression, UserDefinedExpression, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} -import org.apache.spark.sql.types.MetadataBuilder + private[sql] object EasilyFlattenable { @@ -34,17 +33,19 @@ private[sql] object EasilyFlattenable { val AddNewColumnsOnly, RemapOnly, Unknown = Value } - def unapply(tuple: (LogicalPlan, Seq[NamedExpression], RuntimeConfig)): Option[LogicalPlan] + def unapply(tuple: (LogicalPlan, Seq[NamedExpression], Long)): Option[LogicalPlan] = { - val (logicalPlan, newProjList, conf) = tuple + val (logicalPlan, newProjList, did) = tuple logicalPlan match { - case p @ Project(projList, child: LogicalPlan) if p.output.groupBy(_.name). - forall(_._2.size == 1) => + case p @ Project(projList, child: LogicalPlan) + if newProjList.flatMap(_.collectLeaves()).forall { + case ar: AttributeReference if ar.metadata.contains(Dataset.DATASET_ID_KEY) && + ar.metadata.getLong(Dataset.DATASET_ID_KEY) != did => false + case _ => true + } => val currentOutputAttribs = AttributeSet(p.output) - val currentDatasetIdOpt = p.getTagValue(Dataset.DATASET_ID_TAG).get.toSet.headOption - // In the new column list identify those Named Expressions which are just attributes and // hence pass thru val (passThruAttribs, tinkeredOrNewNamedExprs) = newProjList.partition { @@ -62,15 +63,15 @@ private[sql] object EasilyFlattenable { // case of new columns being added only val childOutput = child.output.map(_.name).toSet val attribsRemappedInProj = projList.flatMap(ne => ne match { - case _: AttributeReference => Seq.empty[(String, Alias)] + case _: AttributeReference => Seq.empty[(String, Expression)] - case al @ Alias(_, name) => if (childOutput.contains(name)) { - Seq(name -> al) + case Alias(expr, name) => if (childOutput.contains(name)) { + Seq(name -> expr) } else { - Seq.empty[(String, Alias)] + Seq.empty[(String, Expression)] } - case _ => Seq.empty[(String, Alias)] + case _ => Seq.empty[(String, Expression)] }).toMap if (tinkeredOrNewNamedExprs.exists(_.collectFirst { @@ -90,46 +91,30 @@ private[sql] object EasilyFlattenable { val remappedNewProjListResult = Try { newProjList.map { - case attr: AttributeReference => - val ne = projList.find( + case attr: AttributeReference => projList.find( _.toAttribute.canonicalized == attr.canonicalized).getOrElse(attr) - if (attr.metadata.contains(Dataset.DATASET_ID_KEY) && - currentDatasetIdOpt.contains(attr.metadata.getLong( - Dataset.DATASET_ID_KEY))) { - addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt) - } else { - ne - } case ua: UnresolvedAttribute => - projList.find(_.toAttribute.name.equalsIgnoreCase(ua.name)). + projList.find(_.toAttribute.name.equals(ua.name)). getOrElse(throw new UnsupportedOperationException("Not able to flatten" + s" unresolved attribute $ua")) case anyOtherExpr => (anyOtherExpr transformUp { - case attr: AttributeReference => val ne = - attribsRemappedInProj.get(attr.name).orElse( - projList.find( - _.toAttribute.canonicalized == attr.canonicalized).map { - case al: Alias => al - case x => x - }).getOrElse(attr) - if (attr.metadata.contains(Dataset.DATASET_ID_KEY) && - currentDatasetIdOpt.contains(attr.metadata.getLong( - Dataset.DATASET_ID_KEY))) { - addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt) - } else { - ne - } + case attr: AttributeReference => + attribsRemappedInProj.get(attr.name).orElse(projList.find( + _.toAttribute.canonicalized == attr.canonicalized).map { + case al: Alias => al.child + case x => x + }).getOrElse(attr) case u: UnresolvedAttribute => attribsRemappedInProj.get(u.name).orElse( - projList.find( _.toAttribute.name.equalsIgnoreCase(u.name)).map { - case al: Alias => al.child - case u: UnresolvedAttribute => - throw new UnsupportedOperationException("Not able to flatten" + - s" unresolved attribute $u") - case x => x + projList.find( _.toAttribute.name.equals(u.name)).map { + case al: Alias => al.child + case u: UnresolvedAttribute => + throw new UnsupportedOperationException("Not able to flatten" + + s" unresolved attribute $u") + case x => x }).getOrElse(throw new UnsupportedOperationException("Not able to flatten" + s" unresolved attribute $u")) }).asInstanceOf[NamedExpression] @@ -138,17 +123,7 @@ private[sql] object EasilyFlattenable { } remappedNewProjListResult match { case Success(remappedNewProjList) => - currentDatasetIdOpt.foreach(id => { - if (conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { - val dsIds = child.getTagValue(Dataset.DATASET_ID_TAG).getOrElse( - new mutable.HashSet[Long]) - dsIds.add(id) - child.setTagValue(Dataset.DATASET_ID_TAG, dsIds) - } - }) - val newProj = Project(remappedNewProjList, child) - - Option(newProj) + Option(Project(remappedNewProjList, child)) case Failure(_) => None } @@ -158,53 +133,31 @@ private[sql] object EasilyFlattenable { // case of renaming of columns val remappedNewProjListResult = Try { newProjList.map { - case attr: AttributeReference => val ne = projList.find( + case attr: AttributeReference => projList.find( _.toAttribute.canonicalized == attr.canonicalized).get - if (attr.metadata.contains(Dataset.DATASET_ID_KEY) && - currentDatasetIdOpt.contains(attr.metadata.getLong( - Dataset.DATASET_ID_KEY))) { - addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt) - } else { - ne - } + case ua: UnresolvedAttribute if ua.nameParts.size == 1 => - projList.find( _.toAttribute.name.equalsIgnoreCase(ua.name)). + projList.find( _.toAttribute.name.equals(ua.name)). getOrElse(throw new UnsupportedOperationException("Not able to flatten" + s" unresolved attribute $ua")) case al@Alias(ar: AttributeReference, name) => - val ne = projList.find(_.toAttribute.canonicalized == ar.canonicalized).map { + projList.find(_.toAttribute.canonicalized == ar.canonicalized).map { case alx : Alias => Alias(alx.child, name)(al.exprId, al.qualifier, al.explicitMetadata, al.nonInheritableMetadataKeys) case _: AttributeReference => al }.get - if (ar.metadata.contains(Dataset.DATASET_ID_KEY) && - currentDatasetIdOpt.contains(ar.metadata.getLong( - Dataset.DATASET_ID_KEY))) { - addDataFrameIdToCol(conf, ne, child, currentDatasetIdOpt) - } else { - ne - } + case x => throw new UnsupportedOperationException("Not able to flatten" + s" unresolved attribute $x") } } remappedNewProjListResult match { case Success(remappedNewProjList) => - currentDatasetIdOpt.foreach(id => { - if (conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) { - val dsIds = child.getTagValue(Dataset.DATASET_ID_TAG).getOrElse( - new mutable.HashSet[Long]) - dsIds.add(id) - child.setTagValue(Dataset.DATASET_ID_TAG, dsIds) - } - }) - val newProj = Project(remappedNewProjList, child) - - Option(newProj) + Option(Project(remappedNewProjList, child)) case Failure(_) => None } @@ -236,25 +189,4 @@ private[sql] object EasilyFlattenable { OpType.Unknown } } - - private def addDataFrameIdToCol( - conf: RuntimeConfig, - expr: NamedExpression, - logicalPlan: LogicalPlan, - childDatasetId: Option[Long]): NamedExpression = - if (conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) && childDatasetId.nonEmpty) { - val newExpr = expr transform { - case a: AttributeReference - => - val metadata = new MetadataBuilder() - .withMetadata(a.metadata) - .putLong(Dataset.DATASET_ID_KEY, childDatasetId.get) - .putLong(Dataset.COL_POS_KEY, logicalPlan.output.indexWhere(a.semanticEquals)) - .build() - a.withMetadata(metadata) - } - newExpr.asInstanceOf[NamedExpression] - } else { - expr - } }