Skip to content

Commit

Permalink
[SPARK-41468][SQL] Fix PlanExpression handling in EquivalentExpressions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
#36012 already added a check to avoid adding expressions containing `PlanExpression`s to `EquivalentExpressions` as those expressions might cause NPE on executors. But, for some reason, the check is still missing from `getExprState()` where we check the presence of an experssion in the equivalence map.

This PR:
- adds the check to `getExprState()`
- moves the check from `updateExprTree()` to `addExprTree()` so as to run it only once.

### Why are the changes needed?
To avoid exceptions like:
```
org.apache.spark.SparkException: Task failed while writing rows.
        at org.apache.spark.sql.errors.QueryExecutionErrors$.taskFailedWhileWritingRowsError(QueryExecutionErrors.scala:642)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:348)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$21(FileFormatWriter.scala:256)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
        at org.apache.spark.scheduler.Task.run(Task.scala:136)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
        at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.lang.NullPointerException
        at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.$anonfun$doCanonicalize$1(InMemoryTableScanExec.scala:51)
        at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
        at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
        at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
        at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
        at scala.collection.TraversableLike.map(TraversableLike.scala:286)
        at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
        at scala.collection.AbstractTraversable.map(Traversable.scala:108)
        at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.doCanonicalize(InMemoryTableScanExec.scala:51)
        at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.doCanonicalize(InMemoryTableScanExec.scala:30)
        ...
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:541)
        at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:850)
        at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:814)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:542)
        at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:541)
        at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized$lzycompute(subquery.scala:72)
        at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized(subquery.scala:71)
        ...
        at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized(Expression.scala:261)
        at org.apache.spark.sql.catalyst.expressions.Expression.semanticHash(Expression.scala:278)
        at org.apache.spark.sql.catalyst.expressions.ExpressionEquals.hashCode(EquivalentExpressions.scala:226)
        at scala.runtime.Statics.anyHash(Statics.java:122)
        at scala.collection.mutable.HashTable$HashUtils.elemHashCode(HashTable.scala:416)
        at scala.collection.mutable.HashTable$HashUtils.elemHashCode$(HashTable.scala:416)
        at scala.collection.mutable.HashMap.elemHashCode(HashMap.scala:44)
        at scala.collection.mutable.HashTable.findEntry(HashTable.scala:136)
        at scala.collection.mutable.HashTable.findEntry$(HashTable.scala:135)
        at scala.collection.mutable.HashMap.findEntry(HashMap.scala:44)
        at scala.collection.mutable.HashMap.get(HashMap.scala:74)
        at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.getExprState(EquivalentExpressions.scala:180)
        at org.apache.spark.sql.catalyst.expressions.SubExprEvaluationRuntime.replaceWithProxy(SubExprEvaluationRuntime.scala:78)
        at org.apache.spark.sql.catalyst.expressions.SubExprEvaluationRuntime.$anonfun$proxyExpressions$3(SubExprEvaluationRuntime.scala:109)
        at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286)
        at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36)
        at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33)
        at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38)
        at scala.collection.TraversableLike.map(TraversableLike.scala:286)
        at scala.collection.TraversableLike.map$(TraversableLike.scala:279)
        at scala.collection.AbstractTraversable.map(Traversable.scala:108)
        at org.apache.spark.sql.catalyst.expressions.SubExprEvaluationRuntime.proxyExpressions(SubExprEvaluationRuntime.scala:109)
        at org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection.<init>(InterpretedUnsafeProjection.scala:40)
        at org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection$.createProjection(InterpretedUnsafeProjection.scala:112)
        at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createInterpretedObject(Projection.scala:127)
        at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createInterpretedObject(Projection.scala:119)
        at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:56)
        at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:150)
        at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160)
        at org.apache.spark.sql.execution.ProjectExec.$anonfun$doExecute$1(basicPhysicalOperators.scala:95)
        at org.apache.spark.sql.execution.ProjectExec.$anonfun$doExecute$1$adapted(basicPhysicalOperators.scala:94)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
        at org.apache.spark.rdd.UnionRDD.compute(UnionRDD.scala:106)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
        at org.apache.spark.rdd.CoalescedRDD.$anonfun$compute$1(CoalescedRDD.scala:99)
        at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486)
        at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492)
        at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
        at org.apache.spark.sql.execution.datasources.FileFormatDataWriter.writeWithIterator(FileFormatDataWriter.scala:91)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$executeTask$1(FileFormatWriter.scala:331)
        at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1538)
        at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:338)
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Existing UTs.

Closes #39010 from peter-toth/SPARK-41468-fix-planexpressions-in-equivalentexpressions.

Authored-by: Peter Toth <peter.toth@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
peter-toth authored and cloud-fan committed Dec 12, 2022
1 parent 63226ca commit 1b2d700
Showing 1 changed file with 23 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -142,28 +142,37 @@ class EquivalentExpressions {
case _ => Nil
}

private def supportedExpression(e: Expression) = {
!e.exists {
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
case _: LambdaVariable => true

// `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
// can cause error like NPE.
case _: PlanExpression[_] => Utils.isInRunningSparkTask

case _ => false
}
}

/**
* Adds the expression to this data structure recursively. Stops if a matching expression
* is found. That is, if `expr` has already been added, its children are not added.
*/
def addExprTree(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = {
updateExprTree(expr, map)
if (supportedExpression(expr)) {
updateExprTree(expr, map)
}
}

private def updateExprTree(
expr: Expression,
map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap,
useCount: Int = 1): Unit = {
val skip = useCount == 0 ||
expr.isInstanceOf[LeafExpression] ||
// `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
// loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
expr.exists(_.isInstanceOf[LambdaVariable]) ||
// `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
// can cause error like NPE.
(expr.exists(_.isInstanceOf[PlanExpression[_]]) && Utils.isInRunningSparkTask)
val skip = useCount == 0 || expr.isInstanceOf[LeafExpression]

if (!skip && !updateExprInMap(expr, map, useCount)) {
val uc = useCount.signum
Expand All @@ -177,7 +186,11 @@ class EquivalentExpressions {
* equivalent expressions.
*/
def getExprState(e: Expression): Option[ExpressionStats] = {
equivalenceMap.get(ExpressionEquals(e))
if (supportedExpression(e)) {
equivalenceMap.get(ExpressionEquals(e))
} else {
None
}
}

// Exposed for testing.
Expand Down

0 comments on commit 1b2d700

Please sign in to comment.