Skip to content

Commit

Permalink
SPARK-47320: refactored the code to remove UnresolvedAttributeWithTag…
Browse files Browse the repository at this point in the history
…, instead marking the UnresolvedAttribute using a tag
  • Loading branch information
ashahid committed Mar 16, 2024
1 parent 4501ae5 commit 3b8383d
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
expr: Expression,
resolveColumnByName: Seq[String] => Option[Expression],
getAttrCandidates: () => Seq[Attribute],
resolveOnDatasetId: (Long, String) => Option[NamedExpression],
throws: Boolean,
includeLastResort: Boolean): Expression = {
def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) {
Expand All @@ -157,9 +156,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
}
matched(ordinal)

case u @ UnresolvedAttributeWithTag(attr, id) =>
resolveOnDatasetId(id, attr.name).getOrElse(attr)

case u @ UnresolvedAttribute(nameParts) =>
val result = withPosition(u) {
resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map {
Expand Down Expand Up @@ -456,7 +452,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
plan.resolve(nameParts, conf.resolver)
},
getAttrCandidates = () => plan.output,
resolveOnDatasetId = (_, _) => None,
throws = throws,
includeLastResort = includeLastResort)
}
Expand All @@ -482,57 +477,6 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
assert(q.children.length == 1)
q.children.head.output
},

resolveOnDatasetId = (datasetid: Long, name: String) => {
def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = {
var currentLp = lp
var depth = 0
while(true) {
if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetid))) {
return Option(currentLp, depth)
} else {
if (currentLp.children.size == 1) {
currentLp = currentLp.children.head
} else {
// leaf node or node is a binary node
return None
}
}
depth += 1
}
None
}

val binaryNodeOpt = q.collectFirst {
case bn: BinaryNode => bn
}

val resolveOnAttribs = binaryNodeOpt match {
case Some(bn) =>
val leftDefOpt = findUnaryNodeMatchingTagId(bn.left)
val rightDefOpt = findUnaryNodeMatchingTagId(bn.right)
(leftDefOpt, rightDefOpt) match {

case (None, Some((lp, _))) => lp.output

case (Some((lp, _)), None) => lp.output

case (Some((lp1, depth1)), Some((lp2, depth2))) =>
if (depth1 == depth2) {
q.children.head.output
} else if (depth1 < depth2) {
lp1.output
} else {
lp2.output
}

case _ => q.children.head.output
}

case _ => q.children.head.output
}
AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(name), conf.resolver)
},
throws = true,
includeLastResort = includeLastResort)
}
Expand Down Expand Up @@ -574,24 +518,90 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
case _ => e
}

private def resolveUsingDatasetId(
ua: UnresolvedAttribute,
left: LogicalPlan,
right: LogicalPlan,
datasetId: Long): Option[NamedExpression] = {
def findUnaryNodeMatchingTagId(lp: LogicalPlan): Option[(LogicalPlan, Int)] = {
var currentLp = lp
var depth = 0
while (true) {
if (currentLp.getTagValue(LogicalPlan.DATASET_ID_TAG).exists(_.contains(datasetId))) {
return Option(currentLp, depth)
} else {
if (currentLp.children.size == 1) {
currentLp = currentLp.children.head
} else {
// leaf node or node is a binary node
return None
}
}
depth += 1
}
None
}

val leftDefOpt = findUnaryNodeMatchingTagId(left)
val rightDefOpt = findUnaryNodeMatchingTagId(right)
val resolveOnAttribs = (leftDefOpt, rightDefOpt) match {

case (None, Some((lp, _))) => lp.output

case (Some((lp, _)), None) => lp.output

case (Some((lp1, depth1)), Some((lp2, depth2))) =>
if (depth1 == depth2) {
lp1.output
} else if (depth1 < depth2) {
lp1.output
} else {
lp2.output
}

case _ => Seq.empty
}

AttributeSeq.fromNormalOutput(resolveOnAttribs).resolve(Seq(ua.name), conf.resolver)
}

private def resolveDataFrameColumn(
u: UnresolvedAttribute,
q: Seq[LogicalPlan]): Option[NamedExpression] = {
val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
if (planIdOpt.isEmpty) return None
val planId = planIdOpt.get
logDebug(s"Extract plan_id $planId from $u")

val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, isMetadataAccess, q)
if (!matched) {
// Can not find the target plan node with plan id, e.g.
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2.a) <- illegal reference df2.a
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)

val attrWithDatasetIdOpt = u.getTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG)
val resolvedOpt = if (attrWithDatasetIdOpt.isDefined) {
val did = attrWithDatasetIdOpt.get
if (q.size == 1) {
val binaryNodeOpt = q.head.collectFirst {
case bn: BinaryNode => bn
}
binaryNodeOpt.flatMap(bn => resolveUsingDatasetId(u, bn.left, bn.right, did))
} else if (q.size == 2) {
resolveUsingDatasetId(u, q(0), q(1), did)
} else {
None
}
} else {
val planIdOpt = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
if (planIdOpt.isEmpty) {
None
} else {
val planId = planIdOpt.get
logDebug(s"Extract plan_id $planId from $u")
val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
val (resolved, matched) = resolveDataFrameColumnByPlanId(u, planId, isMetadataAccess, q)
if (!matched) {
// Can not find the target plan node with plan id, e.g.
// df1 = spark.createDataFrame([Row(a = 1, b = 2, c = 3)]])
// df2 = spark.createDataFrame([Row(a = 1, b = 2)]])
// df1.select(df2.a) <- illegal reference df2.a
throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
}
resolved
}
}
resolved
resolvedOpt
}

private def resolveDataFrameColumnByPlanId(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ case class UnresolvedAttribute(nameParts: Seq[String]) extends Attribute with Un
nameParts.length == 1 && nameParts.head.equalsIgnoreCase(token)
}
}

/*
case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long) extends Attribute with
Unevaluable {
def name: String = attribute.name
Expand Down Expand Up @@ -309,6 +309,8 @@ case class UnresolvedAttributeWithTag(attribute: Attribute, datasetId: Long) ext
def equalsIgnoreCase(token: String): Boolean = token.equalsIgnoreCase(attribute.name)
}
*/

object UnresolvedAttribute extends AttributeNameParser {
/**
* Creates an [[UnresolvedAttribute]], parsing segments separated by dots ('.').
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ object LogicalPlan {
private[spark] val PLAN_ID_TAG = TreeNodeTag[Long]("plan_id")
private[spark] val IS_METADATA_COL = TreeNodeTag[Unit]("is_metadata_col")
private[spark] val DATASET_ID_TAG = TreeNodeTag[mutable.HashSet[Long]]("dataset_id")
private[spark] val ATTRIBUTE_DATASET_ID_TAG = TreeNodeTag[Long]("dataset_id")
}

/**
Expand Down
22 changes: 18 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,11 @@ class Dataset[T] private[sql](
val rightLegWrong = isIncorrectlyResolved(attr, planPart1.right.outputSet,
rightTagIdMap.getOrElse(HashSet.empty[Long]))
if (!planPart1.outputSet.contains(attr) || leftLegWrong || rightLegWrong) {
UnresolvedAttributeWithTag(attr, attr.metadata.getLong(DATASET_ID_KEY))
val ua = UnresolvedAttribute(attr.name)
ua.copyTagsFrom(attr)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG,
attr.metadata.getLong(DATASET_ID_KEY))
ua
} else {
attr
}
Expand Down Expand Up @@ -1337,15 +1341,21 @@ class Dataset[T] private[sql](
joined.left.output(index)

case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
val ua = UnresolvedAttribute(a.name)
ua.copyTagsFrom(a)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, a.metadata.getLong(DATASET_ID_KEY))
ua
}
val rightAsOfExpr = rightAsOf.expr.transformUp {
case a: AttributeReference if other.logicalPlan.outputSet.contains(a) =>
val index = other.logicalPlan.output.indexWhere(_.exprId == a.exprId)
joined.right.output(index)

case a: AttributeReference if a.metadata.contains(Dataset.DATASET_ID_KEY) =>
UnresolvedAttributeWithTag(a, a.metadata.getLong(Dataset.DATASET_ID_KEY))
val ua = UnresolvedAttribute(a.name)
ua.copyTagsFrom(a)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, a.metadata.getLong(DATASET_ID_KEY))
ua
}
withPlan {
AsOfJoin(
Expand Down Expand Up @@ -1614,7 +1624,11 @@ class Dataset[T] private[sql](
case attr: AttributeReference if attr.metadata.contains(DATASET_ID_KEY) &&
(!inputForProj.contains(attr) ||
isIncorrectlyResolved(attr, inputForProj, HashSet(id))) =>
UnresolvedAttributeWithTag(attr, attr.metadata.getLong(DATASET_ID_KEY))
val ua = UnresolvedAttribute(attr.name)
ua.copyTagsFrom(attr)
ua.setTagValue(LogicalPlan.ATTRIBUTE_DATASET_ID_TAG, attr.metadata.getLong(DATASET_ID_KEY))
ua

}).asInstanceOf[NamedExpression])
Project(namedExprs, logicalPlan)
}
Expand Down

0 comments on commit 3b8383d

Please sign in to comment.