SPARK-47320. Modified the code to ensure that, for unambiguous attributes resolved using datasetId in a top-level join, the behaviour remains unchanged regardless of the value of the flag spark.sql.analyzer.failAmbiguousSelfJoin.
ashahid committed Mar 29, 2024
1 parent 3619857 commit 03149d5
Showing 4 changed files with 388 additions and 309 deletions.
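For context, the behaviour being pinned down can be sketched as follows. This is a minimal, hypothetical example (not taken from the commit), assuming a running SparkSession named spark; column references created through a Dataset carry that Dataset's id in their metadata, which is what the analyzer uses to pick a join leg:

// Hypothetical illustration; `spark` is an existing SparkSession.
val df = spark.range(5).toDF("value")
val filtered = df.filter(df("value") > 2)
// df("value") is tagged with df's dataset id. In this top-level join the
// reference is unambiguous by dataset id, so per this commit it must resolve
// identically whether spark.sql.analyzer.failAmbiguousSelfJoin is true or false.
val joined = df.join(filtered, df("value") === filtered("value"))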
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier}
import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.MetadataBuilder

trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {

@@ -518,6 +519,15 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
case _ => e
}

private def stripColumnReferenceMetadata(a: AttributeReference): AttributeReference = {
val metadataWithoutId = new MetadataBuilder()
.withMetadata(a.metadata)
.remove(LogicalPlan.DATASET_ID_KEY)
.remove(LogicalPlan.COL_POS_KEY)
.build()
a.withMetadata(metadataWithoutId)
}
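For context, a small sketch of the metadata this helper strips. The keys match the constants added to LogicalPlan in this commit, the values are illustrative, and the tagging itself is done on the Dataset side (see addDataFrameIdToCol below):

import org.apache.spark.sql.types.MetadataBuilder

// Build metadata the way a column reference gets tagged...
val tagged = new MetadataBuilder()
  .putLong("__dataset_id", 42L)   // LogicalPlan.DATASET_ID_KEY; value illustrative
  .putLong("__col_position", 0L)  // LogicalPlan.COL_POS_KEY
  .build()
// ...and remove the bookkeeping keys again, as stripColumnReferenceMetadata does.
val untagged = new MetadataBuilder()
  .withMetadata(tagged)
  .remove("__dataset_id")
  .remove("__col_position")
  .build()
assert(!untagged.contains("__dataset_id") && !untagged.contains("__col_position"))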

private def resolveUsingDatasetId(
ua: UnresolvedAttribute,
left: LogicalPlan,
@@ -527,7 +537,8 @@
var currentLp = lp
var depth = 0
while (true) {
if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetId))) {
if (currentLp.getTagValue(LogicalPlan.DATASET_RESOLUTION_TAG).exists(
_.contains(datasetId))) {
return Option(currentLp, depth)
} else {
if (currentLp.children.size == 1) {
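The diff view truncates the rest of this walk; its overall shape is roughly the following hedged reconstruction (the name findByDatasetId is ours, not the commit's):

// Sketch: descend through single-child nodes from a join leg, returning the
// first plan tagged with the target dataset id together with its depth.
private def findByDatasetId(start: LogicalPlan, datasetId: Long): Option[(LogicalPlan, Int)] = {
  var current = start
  var depth = 0
  while (true) {
    if (current.getTagValue(LogicalPlan.DATASET_RESOLUTION_TAG)
        .exists(_.contains(datasetId))) {
      return Some((current, depth))
    }
    if (current.children.size != 1) {
      return None // fan-out or leaf: stop searching this leg
    }
    current = current.children.head
    depth += 1
  }
  None // unreachable; satisfies the declared result type
}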
@@ -550,39 +561,61 @@

case (Some((lp, _)), None) => lp.output

case (Some((lp1, depth1)), Some((lp2, depth2))) =>
if (depth1 == depth2) {
lp1.output
} else if (depth1 < depth2) {
lp1.output
} else {
lp2.output
}
case (Some((lp1, depth1)), Some((lp2, depth2))) => if (depth1 == depth2) {
Seq.empty
} else if (depth1 < depth2) {
lp1.output
} else {
lp2.output
}

case _ => Seq.empty
}

AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(ua.name), conf.resolver)
if (resolveOnAttribs.isEmpty) {
None
} else {
AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(ua.name), conf.resolver)
}
}

private def resolveDataFrameColumn(
u: UnresolvedAttribute,
q: Seq[LogicalPlan]): Option[NamedExpression] = {

val attrWithDatasetIdOpt = u.getTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG)
val resolvedOpt = if (attrWithDatasetIdOpt.isDefined) {
val did = attrWithDatasetIdOpt.get
if (q.size == 1) {
val binaryNodeOpt = q.head.collectFirst {
case bn: BinaryNode => bn
val origAttrOpt = u.getTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG)
val resolvedOptWithDatasetId = if (origAttrOpt.isDefined) {
val md = origAttrOpt.get.metadata
if (md.contains(LogicalPlan.DATASET_ID_KEY)) {
val did = md.getLong(LogicalPlan.DATASET_ID_KEY)
val resolved = if (q.size == 1) {
val binaryNodeOpt = q.head.collectFirst {
case bn: BinaryNode => bn
}
binaryNodeOpt.flatMap(bn => resolveUsingDatasetId(u, bn.left, bn.right, did))
} else if (q.size == 2) {
resolveUsingDatasetId(u, q(0), q(1), did)
} else {
None
}
if (resolved.isEmpty) {
if (conf.getConf(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED)) {
origAttrOpt
} else {
origAttrOpt.map(stripColumnReferenceMetadata)
}
} else {
resolved
}
binaryNodeOpt.flatMap(bn => resolveUsingDatasetId(u, bn.left, bn.right, did))
} else if (q.size == 2) {
resolveUsingDatasetId(u, q(0), q(1), did)
} else {
None
origAttrOpt
}
} else {
None
}
val resolvedOpt = if (resolvedOptWithDatasetId.isDefined) {
resolvedOptWithDatasetId
}
else {
val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
if (planIdOpt.isEmpty) {
None
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -200,8 +200,10 @@ object LogicalPlan {
// to the old code path.
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
private[spark] val ATTRIBUTE_DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id")
private[spark] val DATASET_RESOLUTION_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
private[spark] val UNRESOLVED_ATTRIBUTE_MD_TAG = TreeNodeTag[AttributeReference]("orig-attr")
private[spark] val DATASET_ID_KEY = "__dataset_id"
private[spark] val COL_POS_KEY = "__col_position"
}

/**
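These tags are the plumbing for the resolution path above. A short sketch of how they are written and read elsewhere in this commit, assuming a LogicalPlan named plan and a dataset id did:

import scala.collection.mutable

// Written when a Dataset wraps a plan (see Dataset.scala below):
val ids = plan.getTagValue(LogicalPlan.DATASET_RESOLUTION_TAG)
  .getOrElse(new mutable.HashSet[Long])
ids.add(did)
plan.setTagValue(LogicalPlan.DATASET_RESOLUTION_TAG, ids)

// Read during column resolution (see ColumnResolutionHelper.scala above):
val taggedHere = plan.getTagValue(LogicalPlan.DATASET_RESOLUTION_TAG)
  .exists(_.contains(did))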
30 changes: 16 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}

import scala.annotation.varargs
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashSet}
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
@@ -48,7 +49,7 @@ import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern
import org.apache.spark.sql.catalyst.trees.{TreeNodeTag, TreePattern}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils}
import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
@@ -72,9 +73,9 @@ import org.apache.spark.util.Utils

private[sql] object Dataset {
val curId = new java.util.concurrent.atomic.AtomicLong()
val DATASET_ID_KEY = "__dataset_id"
val COL_POS_KEY = "__col_position"
val DATASET_ID_TAG = LogicalPlan.DATASET_ID_TAG
val DATASET_ID_KEY = LogicalPlan.DATASET_ID_KEY
val COL_POS_KEY = LogicalPlan.COL_POS_KEY
val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")

def apply[T: Encoder](sparkSession: SparkSession, logicalPlan: LogicalPlan): Dataset[T] = {
val dataset = new Dataset(sparkSession, logicalPlan, implicitly[Encoder[T]])
@@ -228,6 +229,9 @@ class Dataset[T] private[sql](
dsIds.add(id)
plan.setTagValue(Dataset.DATASET_ID_TAG, dsIds)
}
val dsIds = plan.getTagValue(Dataset.DATASET_ID_TAG).getOrElse(new HashSet[Long])
dsIds.add(id)
plan.setTagValue(LogicalPlan.DATASET_RESOLUTION_TAG, dsIds)
plan
}

@@ -1177,8 +1181,8 @@ class Dataset[T] private[sql](
Join(logicalPlan, right.logicalPlan,
JoinType(joinType), None, JoinHint.NONE)).queryExecution.analyzed.asInstanceOf[Join]

val leftTagIdMap = planPart1.left.getTagValue(LogicalPlan.DATASET_ID_TAG)
val rightTagIdMap = planPart1.right.getTagValue(LogicalPlan.DATASET_ID_TAG)
val leftTagIdMap = planPart1.left.getTagValue(LogicalPlan.DATASET_RESOLUTION_TAG)
val rightTagIdMap = planPart1.right.getTagValue(LogicalPlan.DATASET_RESOLUTION_TAG)

val joinExprsRectified = joinExprs.map(_.expr transformUp {
case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) =>
@@ -1190,8 +1194,7 @@ class Dataset[T] private[sql](
if (!planPart1.outputSet.contains(attr) || leftLegWrong || rightLegWrong) {
val ua = UnresolvedAttribute(Seq(attr.name))
ua.copyTagsFrom(attr)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG,
attr.metadata.getLong(DATASET_ID_KEY))
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, attr)
ua
} else {
attr
@@ -1340,7 +1343,7 @@ class Dataset[T] private[sql](
case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
val ua = UnresolvedAttribute(Seq(a.name))
ua.copyTagsFrom(a)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, a.metadata.getLong(DATASET_ID_KEY))
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, a)
ua
}
val rightAsOfExpr = rightAsOf.expr.transformUp {
Expand All @@ -1351,7 +1354,7 @@ class Dataset[T] private[sql](
case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
val ua = UnresolvedAttribute(Seq(a.name))
ua.copyTagsFrom(a)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, a.metadata.getLong(DATASET_ID_KEY))
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, a)
ua
}
withPlan {
@@ -1525,8 +1528,8 @@ class Dataset[T] private[sql](
// `DetectAmbiguousSelfJoin` will remove it.
private def addDataFrameIdToCol(expr: NamedExpression): NamedExpression = {
val newExpr = expr transform {
case a: AttributeReference
if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
case a: AttributeReference =>
// if sparkSession.conf.get(SQLConf.FAIL_AMBIGUOUS_SELF_JOIN_ENABLED) =>
val metadata = new MetadataBuilder()
.withMetadata(a.metadata)
.putLong(Dataset.DATASET_ID_KEY, id)
@@ -1623,9 +1626,8 @@
isIncorrectlyResolved(attr, inputForProj, HashSet(id))) =>
val ua = UnresolvedAttribute(Seq(attr.name))
ua.copyTagsFrom(attr)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, attr.metadata.getLong(DATASET_ID_KEY))
ua.setTagValue(LogicalPlan.UNRESOLVED_ATTRIBUTE_MD_TAG, attr)
ua

}).asInstanceOf[NamedExpression])
Project(namedExprs, logicalPlan)
}
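Taken together, the change pins down behaviour like the following (hedged, illustrative; spark is an existing SparkSession):

val a = spark.range(3).toDF("k")
val b = a.filter(a("k") > 0)
// a's dataset id tags the left leg at depth 0 and appears on the right leg only
// deeper (under the filter), so a("k") is unambiguous by dataset id and resolves
// to the left leg regardless of the flag's value.
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
val joined = a.join(b, a("k") === b("k"))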
(Diff for the fourth changed file did not load and is not shown.)