From 61aecfcecd169fec4b94bc2ac152a4e8e0d03bf1 Mon Sep 17 00:00:00 2001 From: Stefan Kandic Date: Tue, 30 Jan 2024 15:34:30 +0100 Subject: [PATCH] add check for determinism in shouldPushFilter method --- .../datasources/DataSourceUtils.scala | 22 +++++++++++-------- .../datasources/FileSourceStrategy.scala | 5 ++--- .../PruneFileSourcePartitions.scala | 4 ++-- .../datasources/v2/FileScanBuilder.scala | 2 +- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 7d250edb93f59..8edb8ba51d828 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -286,15 +286,19 @@ object DataSourceUtils extends PredicateHelper { * @param expression The filter expression to be evaluated. * @return A boolean indicating whether the filter should be pushed down or not. */ - def shouldPushFilter(expression: Expression): Boolean = expression match { - case attr: AttributeReference => - attr.dataType match { - // don't push down filters for string columns with non-default collation - // as it could lead to incorrect results - case st: StringType => st.isDefaultCollation - case _ => true - } + def shouldPushFilter(expression: Expression): Boolean = { + def shouldPushFilterRecursive(expression: Expression): Boolean = expression match { + case attr: AttributeReference => + attr.dataType match { + // don't push down filters for string columns with non-default collation + // as it could lead to incorrect results + case st: StringType => st.isDefaultCollation + case _ => true + } + + case _ => expression.children.forall(shouldPushFilterRecursive) + } - case _ => expression.children.forall(shouldPushFilter) + expression.deterministic && shouldPushFilterRecursive(expression) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index c7fa1494eb7b1..9ab3b08648f82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -160,11 +160,10 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) - val deterministicFiltersToPush = filters - .filter(_.deterministic) + val filtersToPush = filters .filter(f => DataSourceUtils.shouldPushFilter(f)) val normalizedFilters = DataSourceStrategy.normalizeExprs( - deterministicFiltersToPush, l.output) + filtersToPush, l.output) val partitionColumns = l.resolve( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 84f30e82180bf..1f2b0a5509a1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -63,8 +63,8 @@ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] { _)) if filters.nonEmpty && fsRelation.partitionSchema.nonEmpty => val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f) - && DataSourceUtils.shouldPushFilter(f)), + filters.filter(f => DataSourceUtils.shouldPushFilter(f) && + !SubqueryExpression.hasSubquery(f)), logicalRelation.output) val (partitionKeyFilters, _) = DataSourceUtils .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index cfd1c47ccce9c..569be4e387b07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -71,7 +71,7 @@ abstract class FileScanBuilder( override def pushFilters(filters: Seq[Expression]): Seq[Expression] = { val (filtersToPush, filtersToIgnore) = filters - .partition(f => f.deterministic && DataSourceUtils.shouldPushFilter(f)) + .partition(f => DataSourceUtils.shouldPushFilter(f)) val (partitionFilters, dataFilters) = DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filtersToPush) this.partitionFilters = partitionFilters