From 231c63a245192045d77cc74e1d085986fda056fd Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 12 Dec 2022 18:14:24 +0800 Subject: [PATCH] [SPARK-41468][SQL] Fix PlanExpression handling in EquivalentExpressions ### What changes were proposed in this pull request? https://github.com/apache/spark/pull/36012 already added a check to avoid adding expressions containing `PlanExpression`s to `EquivalentExpressions` as those expressions might cause NPE on executors. But, for some reason, the check is still missing from `getExprState()` where we check the presence of an experssion in the equivalence map. This PR: - adds the check to `getExprState()` - moves the check from `updateExprTree()` to `addExprTree()` so as to run it only once. ### Why are the changes needed? To avoid exceptions like: ``` org.apache.spark.SparkException: Task failed while writing rows. at org.apache.spark.sql.errors.QueryExecutionErrors$.taskFailedWhileWritingRowsError(QueryExecutionErrors.scala:642) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:348) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$write$21(FileFormatWriter.scala:256) at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90) at org.apache.spark.scheduler.Task.run(Task.scala:136) at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548) at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504) at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551) at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128) at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628) at java.base/java.lang.Thread.run(Thread.java:834) Caused by: java.lang.NullPointerException at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.$anonfun$doCanonicalize$1(InMemoryTableScanExec.scala:51) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286) at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62) at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49) at scala.collection.TraversableLike.map(TraversableLike.scala:286) at scala.collection.TraversableLike.map$(TraversableLike.scala:279) at scala.collection.AbstractTraversable.map(Traversable.scala:108) at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.doCanonicalize(InMemoryTableScanExec.scala:51) at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.doCanonicalize(InMemoryTableScanExec.scala:30) ... at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:541) at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:850) at org.apache.spark.sql.execution.SubqueryExec.doCanonicalize(basicPhysicalOperators.scala:814) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized$lzycompute(QueryPlan.scala:542) at org.apache.spark.sql.catalyst.plans.QueryPlan.canonicalized(QueryPlan.scala:541) at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized$lzycompute(subquery.scala:72) at org.apache.spark.sql.execution.ScalarSubquery.preCanonicalized(subquery.scala:71) ... at org.apache.spark.sql.catalyst.expressions.Expression.canonicalized(Expression.scala:261) at org.apache.spark.sql.catalyst.expressions.Expression.semanticHash(Expression.scala:278) at org.apache.spark.sql.catalyst.expressions.ExpressionEquals.hashCode(EquivalentExpressions.scala:226) at scala.runtime.Statics.anyHash(Statics.java:122) at scala.collection.mutable.HashTable$HashUtils.elemHashCode(HashTable.scala:416) at scala.collection.mutable.HashTable$HashUtils.elemHashCode$(HashTable.scala:416) at scala.collection.mutable.HashMap.elemHashCode(HashMap.scala:44) at scala.collection.mutable.HashTable.findEntry(HashTable.scala:136) at scala.collection.mutable.HashTable.findEntry$(HashTable.scala:135) at scala.collection.mutable.HashMap.findEntry(HashMap.scala:44) at scala.collection.mutable.HashMap.get(HashMap.scala:74) at org.apache.spark.sql.catalyst.expressions.EquivalentExpressions.getExprState(EquivalentExpressions.scala:180) at org.apache.spark.sql.catalyst.expressions.SubExprEvaluationRuntime.replaceWithProxy(SubExprEvaluationRuntime.scala:78) at org.apache.spark.sql.catalyst.expressions.SubExprEvaluationRuntime.$anonfun$proxyExpressions$3(SubExprEvaluationRuntime.scala:109) at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:286) at scala.collection.IndexedSeqOptimized.foreach(IndexedSeqOptimized.scala:36) at scala.collection.IndexedSeqOptimized.foreach$(IndexedSeqOptimized.scala:33) at scala.collection.mutable.WrappedArray.foreach(WrappedArray.scala:38) at scala.collection.TraversableLike.map(TraversableLike.scala:286) at scala.collection.TraversableLike.map$(TraversableLike.scala:279) at scala.collection.AbstractTraversable.map(Traversable.scala:108) at org.apache.spark.sql.catalyst.expressions.SubExprEvaluationRuntime.proxyExpressions(SubExprEvaluationRuntime.scala:109) at org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection.(InterpretedUnsafeProjection.scala:40) at org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection$.createProjection(InterpretedUnsafeProjection.scala:112) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createInterpretedObject(Projection.scala:127) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.createInterpretedObject(Projection.scala:119) at org.apache.spark.sql.catalyst.expressions.CodeGeneratorWithInterpretedFallback.createObject(CodeGeneratorWithInterpretedFallback.scala:56) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:150) at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160) at org.apache.spark.sql.execution.ProjectExec.$anonfun$doExecute$1(basicPhysicalOperators.scala:95) at org.apache.spark.sql.execution.ProjectExec.$anonfun$doExecute$1$adapted(basicPhysicalOperators.scala:94) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877) at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877) at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365) at org.apache.spark.rdd.RDD.iterator(RDD.scala:329) at org.apache.spark.rdd.UnionRDD.compute(UnionRDD.scala:106) at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365) at org.apache.spark.rdd.RDD.iterator(RDD.scala:329) at org.apache.spark.rdd.CoalescedRDD.$anonfun$compute$1(CoalescedRDD.scala:99) at scala.collection.Iterator$$anon$11.nextCur(Iterator.scala:486) at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:492) at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460) at org.apache.spark.sql.execution.datasources.FileFormatDataWriter.writeWithIterator(FileFormatDataWriter.scala:91) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.$anonfun$executeTask$1(FileFormatWriter.scala:331) at org.apache.spark.util.Utils$.tryWithSafeFinallyAndFailureCallbacks(Utils.scala:1538) at org.apache.spark.sql.execution.datasources.FileFormatWriter$.executeTask(FileFormatWriter.scala:338) ``` ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Existing UTs. Closes #39010 from peter-toth/SPARK-41468-fix-planexpressions-in-equivalentexpressions. Authored-by: Peter Toth Signed-off-by: Wenchen Fan (cherry picked from commit 1b2d7001f2924738b61609a5399ebc152969b5c8) Signed-off-by: Wenchen Fan --- .../expressions/EquivalentExpressions.scala | 33 +++++++++++++------ 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 2bbde304c2815..3ffd9f9d88750 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -142,6 +142,20 @@ class EquivalentExpressions { case _ => Nil } + private def supportedExpression(e: Expression) = { + !e.exists { + // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the + // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. + case _: LambdaVariable => true + + // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, + // can cause error like NPE. + case _: PlanExpression[_] => Utils.isInRunningSparkTask + + case _ => false + } + } + /** * Adds the expression to this data structure recursively. Stops if a matching expression * is found. That is, if `expr` has already been added, its children are not added. @@ -149,21 +163,16 @@ class EquivalentExpressions { def addExprTree( expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap): Unit = { - updateExprTree(expr, map) + if (supportedExpression(expr)) { + updateExprTree(expr, map) + } } private def updateExprTree( expr: Expression, map: mutable.HashMap[ExpressionEquals, ExpressionStats] = equivalenceMap, useCount: Int = 1): Unit = { - val skip = useCount == 0 || - expr.isInstanceOf[LeafExpression] || - // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the - // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning. - expr.exists(_.isInstanceOf[LambdaVariable]) || - // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor, - // can cause error like NPE. - (expr.exists(_.isInstanceOf[PlanExpression[_]]) && Utils.isInRunningSparkTask) + val skip = useCount == 0 || expr.isInstanceOf[LeafExpression] if (!skip && !updateExprInMap(expr, map, useCount)) { val uc = useCount.signum @@ -177,7 +186,11 @@ class EquivalentExpressions { * equivalent expressions. */ def getExprState(e: Expression): Option[ExpressionStats] = { - equivalenceMap.get(ExpressionEquals(e)) + if (supportedExpression(e)) { + equivalenceMap.get(ExpressionEquals(e)) + } else { + None + } } // Exposed for testing.