SPARK-45959: Reworked and simplified the code. Instead of collapsing the Project before analysis, the collapse now happens after analysis.
ashahid committed Dec 15, 2023
1 parent 9a316c1 commit cfb2b04
Showing 4 changed files with 17 additions and 60 deletions.
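Before the per-file diffs, a quick illustration of what is being moved: each chained select/withColumn wraps the previous plan in a fresh Project node, and EarlyCollapseProject merges such adjacent Projects. This patch relocates that merge from Dataset construction (pre-analysis) into QueryExecution (post-analysis). A minimal sketch for observing the effect, assuming a local SparkSession; the object name CollapseDemo and the sample data are invented for illustration:

import org.apache.spark.sql.SparkSession

object CollapseDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("collapse-demo").getOrCreate()
  import spark.implicits._

  // Two stacked Projects: Project(c) on top of Project(a, b, (a + b) AS c).
  val df = Seq((1, 2), (3, 4)).toDF("a", "b")
    .withColumn("c", $"a" + $"b")
    .select($"c")

  // The raw logical plan keeps the stacked shape; after this change the
  // collapse should be visible in the analyzed plan instead.
  println(df.queryExecution.logical.treeString)
  println(df.queryExecution.analyzed.treeString)

  spark.stop()
}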
19 changes: 3 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -59,7 +59,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{EarlyCollapseProject, SQLConf}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
@@ -1573,14 +1573,7 @@ class Dataset[T] private[sql](
 
       case other => other
     }
-    val newProjList = untypedCols.map(_.named)
-    (logicalPlan, newProjList, id) match {
-      case EarlyCollapseProject(flattendPlan) if !this.isStreaming &&
-        !logicalPlan.getTagValue(LogicalPlan.SKIP_EARLY_PROJECT_COLLAPSE).getOrElse(false) =>
-        flattendPlan
-
-      case _ => Project(newProjList, logicalPlan)
-    }
+    Project(untypedCols.map(_.named), logicalPlan)
   }
 }

@@ -2957,13 +2950,7 @@ class Dataset[T] private[sql](
       projectList.map(_.name),
       sparkSession.sessionState.conf.caseSensitiveAnalysis)
     withPlan(
-      (logicalPlan, projectList, id) match {
-        case EarlyCollapseProject(flattendPlan) if !this.isStreaming &&
-          !logicalPlan.getTagValue(LogicalPlan.SKIP_EARLY_PROJECT_COLLAPSE).getOrElse(false) =>
-          flattendPlan
-
-        case _ => Project(projectList, logicalPlan)
-      }
+      Project(projectList, logicalPlan)
     )
   }
 
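With the match removed, select and withColumns now always build a plain Project. Note that the !this.isStreaming and SKIP_EARLY_PROJECT_COLLAPSE guards also leave Dataset; the tag check resurfaces inside the extractor itself (see the EarlyCollapseProject hunks below). A hedged sketch of opting a plan out of the collapse via that tag, assuming it is a TreeNodeTag[Boolean] as the getOrElse(false) usage suggests; optOut is an invented helper, not part of the patch:

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Tagging a Project makes the extractor's guard
// (!p.getTagValue(LogicalPlan.SKIP_EARLY_PROJECT_COLLAPSE).getOrElse(false))
// reject the match, so the tagged plan is left un-collapsed.
def optOut(plan: LogicalPlan): LogicalPlan = {
  plan.setTagValue(LogicalPlan.SKIP_EARLY_PROJECT_COLLAPSE, true)
  plan
}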
sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
 import org.apache.spark.sql.execution.exchange.EnsureRequirements
 import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
 import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, WatermarkPropagator}
-import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.{EarlyCollapseProject, SQLConf}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.Utils
@@ -85,7 +85,11 @@ class QueryExecution(
   lazy val analyzed: LogicalPlan = {
     val plan = executePhase(QueryPlanningTracker.ANALYSIS) {
       // We can't clone `logical` here, which will reset the `_analyzed` flag.
-      sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
+      val analyzedPlan = sparkSession.sessionState.analyzer.executeAndCheck(logical, tracker)
+      analyzedPlan match {
+        case EarlyCollapseProject(collapsedPlan) => collapsedPlan
+        case _ => analyzedPlan
+      }
     }
     tracker.setAnalyzed(plan)
     plan
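The wiring above works because EarlyCollapseProject is used as a Scala extractor: an object whose unapply(plan: LogicalPlan): Option[LogicalPlan] returns a rewritten plan when it fires and None otherwise, letting the match fall through to the original plan. A self-contained sketch of the same shape; DropIdentityProject and its trivial rule (drop a Project that selects exactly its child's output) are invented for illustration, not the patch's logic:

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

object DropIdentityProject {
  def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
    // Fire only when the projection is a no-op over its child.
    case Project(projectList, child) if projectList == child.output => Some(child)
    case _ => None
  }
}

// Mirrors the hook added to QueryExecution.analyzed above.
def postProcess(analyzedPlan: LogicalPlan): LogicalPlan = analyzedPlan match {
  case DropIdentityProject(rewritten) => rewritten
  case _ => analyzedPlan
}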
sql/core/src/main/scala/org/apache/spark/sql/internal/EarlyCollapseProject.scala
@@ -19,29 +19,23 @@ package org.apache.spark.sql.internal
 
 import scala.util.{Failure, Success, Try}
 
-import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression, UserDefinedExpression, WindowExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
 
 
 private[sql] object EarlyCollapseProject {
   object OpType extends Enumeration {
     type OpType = Value
     val AddNewColumnsOnly, RemapOnly, Unknown = Value
   }
 
-  def unapply(tuple: (LogicalPlan, Seq[NamedExpression], Long)): Option[LogicalPlan] = {
-    val (logicalPlan, newProjList, did) = tuple
+  def unapply(logicalPlan: LogicalPlan): Option[LogicalPlan] = {
 
     logicalPlan match {
-      case p@Project(projList, child: LogicalPlan)
-        if newProjList.flatMap(_.collectLeaves()).forall {
-          case ar: AttributeReference if ar.metadata.contains(Dataset.DATASET_ID_KEY) &&
-            ar.metadata.getLong(Dataset.DATASET_ID_KEY) != did => false
-          case _ => true
-        } =>
+      case Project(newProjList, p @ Project(projList, child)) if !p.getTagValue(
+        LogicalPlan.SKIP_EARLY_PROJECT_COLLAPSE).getOrElse(false)
+        =>
         val currentOutputAttribs = AttributeSet(p.output)
 
         // In the new column list identify those Named Expressions which are just attributes and
@@ -80,9 +74,6 @@ private[sql] object EarlyCollapseProject {
               case ex: AggregateExpression => ex
               case ex: WindowExpression => ex
               case ex: UserDefinedExpression => ex
-              case u: UnresolvedAttribute if u.nameParts.size != 1 => u
-              case u: UnresolvedFunction if u.nameParts.size == 1 & u.nameParts.head == "struct" =>
-                u
             }.nonEmpty)) {
             None
           } else {
@@ -91,10 +82,6 @@ private[sql] object EarlyCollapseProject {
               case attr: AttributeReference => projList.find(
                 _.toAttribute.canonicalized == attr.canonicalized).getOrElse(attr)
 
-              case ua: UnresolvedAttribute => projList.find(_.toAttribute.name.equals(ua.name)).
-                getOrElse(throw new UnsupportedOperationException("Not able to flatten" +
-                  s" unresolved attribute $ua"))
-
               case anyOtherExpr =>
                 (anyOtherExpr transformUp {
                   case attr: AttributeReference =>
@@ -103,24 +90,11 @@ private[sql] object EarlyCollapseProject {
                       case al: Alias => al.child
                       case x => x
                     }).getOrElse(attr)
-
-                  case u: UnresolvedAttribute => attribsRemappedInProj.get(u.name).orElse(
-                    projList.find(_.toAttribute.name.equals(u.name)).map {
-                      case al: Alias => al.child
-
-                      case u: UnresolvedAttribute =>
-                        throw new UnsupportedOperationException("Not able to flatten" +
-                          s" unresolved attribute $u")
-
-                      case x => x
-                    }).getOrElse(throw new UnsupportedOperationException("Not able to flatten" +
-                      s" unresolved attribute $u"))
                 }).asInstanceOf[NamedExpression]
             }
           }
           remappedNewProjListResult match {
-            case Success(remappedNewProjList) =>
-              Option(Project(remappedNewProjList, child))
+            case Success(remappedNewProjList) => Option(Project(remappedNewProjList, child))
 
             case Failure(_) => None
           }
@@ -133,10 +107,6 @@ private[sql] object EarlyCollapseProject {
               case attr: AttributeReference => projList.find(
                 _.toAttribute.canonicalized == attr.canonicalized).get
 
-              case ua: UnresolvedAttribute if ua.nameParts.size == 1 =>
-                projList.find(_.toAttribute.name.equals(ua.name)).
-                  getOrElse(throw new UnsupportedOperationException("Not able to flatten" +
-                    s" unresolved attribute $ua"))
 
               case al@Alias(ar: AttributeReference, name) =>
                 projList.find(_.toAttribute.canonicalized == ar.canonicalized).map {
@@ -146,14 +116,10 @@ private[sql] object EarlyCollapseProject {
 
                   case _: AttributeReference => al
                 }.get
-
-              case x => throw new UnsupportedOperationException("Not able to flatten" +
-                s" unresolved attribute $x")
             }
           }
           remappedNewProjListResult match {
-            case Success(remappedNewProjList) =>
-              Option(Project(remappedNewProjList, child))
+            case Success(remappedNewProjList) => Option(Project(remappedNewProjList, child))
 
             case Failure(_) => None
           }
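The heart of the flattening above is alias substitution: attribute references in the outer projection are replaced by the inner projection's defining expressions. An Alias contributes its child when nested inside a larger expression, and is kept whole when referenced bare, preserving its name and exprId. A stripped-down sketch of that idea, matching on exprId rather than the patch's canonicalized comparison and omitting its guards for aggregates, windows and UDFs; NaiveCollapse is an invented name, not the patch's implementation:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

object NaiveCollapse {
  def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
    case Project(outer, Project(inner, child)) =>
      // Map each inner output attribute to the expression that produced it.
      val byExprId = inner.map(ne => ne.toAttribute.exprId -> ne).toMap
      val merged = outer.map {
        // A bare reference keeps the inner entry (Alias included) as-is,
        // so the result is still a NamedExpression with the same exprId.
        case ar: AttributeReference => byExprId.getOrElse(ar.exprId, ar)
        // Inside a larger expression, inline the alias's child expression.
        case other => other.transformUp {
          case ar: AttributeReference =>
            byExprId.get(ar.exprId) match {
              case Some(a: Alias) => a.child
              case Some(ne) => ne.toAttribute
              case None => ar
            }
        }.asInstanceOf[NamedExpression]
      }
      Some(Project(merged, child))
    case _ => None
  }
}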
sql/core/src/test/scala/org/apache/spark/sql/EarlyCollapseProjectSuite.scala
@@ -252,7 +252,7 @@ class EarlyCollapseProjectSuite extends QueryTest
     (newDfOpt, newDfUnopt)
   }
 
-  private def collectNodes(df: DataFrame): Seq[LogicalPlan] = df.queryExecution.logical.collect {
+  private def collectNodes(df: DataFrame): Seq[LogicalPlan] = df.logicalPlan.collect {
     case l => l
   }
 }
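A hedged sketch of the kind of check such a helper supports; the DataFrame chain and the expectation are invented, and assume a SparkSession named spark in scope:

import org.apache.spark.sql.catalyst.plans.logical.Project
import spark.implicits._

val df = spark.range(3).toDF("id")
  .withColumn("a", $"id" + 1)
  .withColumn("b", $"a" * 2)
val projects = df.queryExecution.analyzed.collect { case p: Project => p }
// With the post-analysis collapse in place one would expect a single
// Project here; without it, roughly one per chained withColumn.
println(s"Project nodes after analysis: ${projects.length}")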