[SPARK-42548][SQL] Add ReferenceAllColumns to skip rewriting attributes
### What changes were proposed in this pull request?

Add a new trait, `ReferenceAllColumns`, that overrides `references` with the output of its children. Nodes that mix in this trait can then be skipped when rewriting attributes in `transformUpWithNewOutput`.

### Why are the changes needed?

There are two reasons for this new trait (see the sketch after the list):

1. it's dangerous to call `references` on an unresolved plan whose references all come from its children
2. it's unnecessary to rewrite the attributes of such a plan, since they all come from its children
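
For illustration, a minimal sketch of an operator adopting the trait (`MyAllColumnsNode` is a hypothetical name; the real migrations, e.g. `ScriptTransformation` and `ObjectConsumer`, are in the diff below):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}

// Before this PR, an operator consuming every column of its child needed a
// hand-written override such as:
//   @transient override lazy val references: AttributeSet = AttributeSet(child.output)
// Mixing in ReferenceAllColumns derives `references` lazily from the children's
// output instead, and transformUpWithNewOutput skips the node when rewriting.
case class MyAllColumnsNode(output: Seq[Attribute], child: LogicalPlan)
  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
  override protected def withNewChildInternal(newChild: LogicalPlan): MyAllColumnsNode =
    copy(child = newChild)
}
```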

### Does this PR introduce _any_ user-facing change?

No, but it prevents a potential bug.

### How was this patch tested?

Added a test and passed CI.

Closes apache#40154 from ulysses-you/references.

Authored-by: ulysses-you <ulyssesyou18@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
ulysses-you authored and cloud-fan committed Feb 28, 2023
1 parent 3320725 commit db0e822
Showing 6 changed files with 81 additions and 32 deletions.
@@ -297,21 +297,28 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
newChild
}

val attrMappingForCurrentPlan = attrMapping.filter {
// The `attrMappingForCurrentPlan` is used to replace the attributes of the
// current `plan`, so the `oldAttr` must be part of `plan.references`.
case (oldAttr, _) => plan.references.contains(oldAttr)
}

if (attrMappingForCurrentPlan.nonEmpty) {
assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
newPlan = newPlan.rewriteAttrs(attributeRewrites)
plan match {
case _: ReferenceAllColumns[_] =>
// It's dangerous to call `references` on an unresolved `ReferenceAllColumns`, and
// it's unnecessary to rewrite its attributes, since they all come from its children.

case _ =>
val attrMappingForCurrentPlan = attrMapping.filter {
// The `attrMappingForCurrentPlan` is used to replace the attributes of the
// current `plan`, so the `oldAttr` must be part of `plan.references`.
case (oldAttr, _) => plan.references.contains(oldAttr)
}

if (attrMappingForCurrentPlan.nonEmpty) {
assert(!attrMappingForCurrentPlan.groupBy(_._1.exprId)
.exists(_._2.map(_._2.exprId).distinct.length > 1),
"Found duplicate rewrite attributes")

val attributeRewrites = AttributeMap(attrMappingForCurrentPlan)
// Using attrMapping from the children plans to rewrite their parent node.
// Note that we shouldn't rewrite a node using attrMapping from its sibling nodes.
newPlan = newPlan.rewriteAttrs(attributeRewrites)
}
}

val (planAfterRule, newAttrMapping) = CurrentOrigin.withOrigin(origin) {
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.plans

import org.apache.spark.sql.catalyst.expressions.AttributeSet

/**
* A trait that overrides `references` with the output of its children.
*
* It's unnecessary to rewrite the attributes of a `ReferenceAllColumns` node, since all of
* its references come from its children.
*
* Note: currently the only place that uses this is [[QueryPlan.transformUpWithNewOutput]].
*/
trait ReferenceAllColumns[PlanType <: QueryPlan[PlanType]] { self: QueryPlan[PlanType] =>

@transient
override final lazy val references: AttributeSet = AttributeSet(children.flatMap(_.outputSet))
}
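
A minimal usage sketch of the new trait (`PassThrough` and `PassThroughDemo` are hypothetical names, mirroring the `FakeReferenceAllColumns` helper added to the test suite further below):

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, UnaryNode}
import org.apache.spark.sql.types.IntegerType

// Hypothetical pass-through node for illustration only.
case class PassThrough(child: LogicalPlan)
  extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
  override def output: Seq[Attribute] = child.output
  override protected def withNewChildInternal(newChild: LogicalPlan): PassThrough =
    copy(child = newChild)
}

object PassThroughDemo {
  def main(args: Array[String]): Unit = {
    val rel = LocalRelation(AttributeReference("a", IntegerType)())
    // `references` is derived from the child's output set with no hand-written
    // override, so an old->new attribute rewrite can never change this node.
    assert(PassThrough(rel).references == rel.outputSet)
  }
}
```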
@@ -17,7 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns

/**
* Transforms the input by forking and running the specified script.
@@ -30,10 +31,7 @@ case class ScriptTransformation(
script: String,
output: Seq[Attribute],
child: LogicalPlan,
ioschema: ScriptInputOutputSchema) extends UnaryNode {
@transient
override lazy val references: AttributeSet = AttributeSet(child.output)

ioschema: ScriptInputOutputSchema) extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
override protected def withNewChildInternal(newChild: LogicalPlan): ScriptTransformation =
copy(child = newChild)
}
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
@@ -64,13 +65,8 @@ trait ObjectProducer extends LogicalPlan {
* A trait for logical operators that consumes domain objects as input.
* The output of its child must be a single-field row containing the input object.
*/
trait ObjectConsumer extends UnaryNode {
trait ObjectConsumer extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
assert(child.output.length == 1)

// This operator always need all columns of its child, even it doesn't reference to.
@transient
override lazy val references: AttributeSet = child.outputSet

def inputObjAttr: Attribute = child.output.head
}

@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.internal.SQLConf
@@ -1740,6 +1741,16 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
}
}
}

test("SPARK-32638: Add ReferenceAllColumns to skip rewriting attributes") {
val t1 = LocalRelation(AttributeReference("c", DecimalType(1, 0))())
val t2 = LocalRelation(AttributeReference("c", DecimalType(2, 0))())
// The select(star) below is unresolved, so computing `references` on the
// wrapping node would throw; transformUpWithNewOutput must skip it instead.
val unresolved = t1.union(t2).select(UnresolvedStar(None))
val referenceAllColumns = FakeReferenceAllColumns(unresolved)
val wp1 = widenSetOperationTypes(referenceAllColumns.select(t1.output.head))
assert(wp1.isInstanceOf[Project])
// The widened attribute mapping still propagates past the skipped node, so
// t1's original (un-widened) attribute no longer appears in the Project.
assert(wp1.expressions.forall(!_.exists(_ == t1.output.head)))
}
}


@@ -1798,3 +1809,10 @@ object TypeCoercionSuite {
copy(left = newLeft, right = newRight)
}
}

case class FakeReferenceAllColumns(child: LogicalPlan)
extends UnaryNode with ReferenceAllColumns[LogicalPlan] {
override def output: Seq[Attribute] = child.output
override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
copy(child = newChild)
}
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.python.BatchIterator
@@ -58,13 +59,8 @@ trait ObjectProducerExec extends SparkPlan {
/**
* Physical version of `ObjectConsumer`.
*/
trait ObjectConsumerExec extends UnaryExecNode {
trait ObjectConsumerExec extends UnaryExecNode with ReferenceAllColumns[SparkPlan] {
assert(child.output.length == 1)

// This operator always need all columns of its child, even it doesn't reference to.
@transient
override lazy val references: AttributeSet = child.outputSet

def inputObjectType: DataType = child.output.head.dataType
}

