From 7fd0d081539369a21128517ce4c995d472b4b876 Mon Sep 17 00:00:00 2001
From: Szehon Ho
Date: Thu, 18 Jul 2024 18:27:52 -0700
Subject: [PATCH] [SPARK-48949][SQL] SPJ: Runtime partition filtering

### What changes were proposed in this pull request?

Introduce runtime partition filtering for SPJ. During planning, we already have the list of partition values on both sides in order to plan the tasks, so we can filter out partition values based on the join type.

### Why are the changes needed?

For some common join types (INNER, LEFT, RIGHT), we have an opportunity to greatly reduce the data scanned in SPJ. For example, a small table joined with a larger table on the partition key can prune out most of the large table's partitions. This is similar in concept to DPP, but while DPP relies on heuristics, this filtering is exact, because SPJ planning already requires listing out both sides' partition values.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests in KeyGroupedPartitioningSuite.
---
 .../util/InternalRowComparableWrapper.scala | 30 +-
 .../apache/spark/sql/internal/SQLConf.scala | 11 +
 ...nternalRowComparableWrapperBenchmark.scala | 2 +-
 .../datasources/v2/BatchScanExec.scala | 7 +-
 .../exchange/EnsureRequirements.scala | 82 +++--
 .../KeyGroupedPartitioningSuite.scala | 345 ++++++++++++++----
 6 files changed, 368 insertions(+), 109 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala index 90e3bdcd082cd..d2bdad2d880de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapper.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering} -import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition} import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.NonFateSharingCache @@ -85,22 +84,25 @@ object InternalRowComparableWrapper { } def mergePartitions( - leftPartitioning: KeyGroupedPartitioning, - rightPartitioning: KeyGroupedPartitioning, - partitionExpression: Seq[Expression]): Seq[InternalRow] = { + leftPartitioning: Seq[InternalRow], + rightPartitioning: Seq[InternalRow], + partitionExpression: Seq[Expression], + intersect: Boolean = false): Seq[InternalRowComparableWrapper] = { val partitionDataTypes = partitionExpression.map(_.dataType) - val partitionsSet = new mutable.HashSet[InternalRowComparableWrapper] - leftPartitioning.partitionValues + val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] + leftPartitioning .map(new InternalRowComparableWrapper(_, partitionDataTypes)) - .foreach(partition => partitionsSet.add(partition)) - rightPartitioning.partitionValues + .foreach(partition => leftPartitionSet.add(partition)) + val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper] + rightPartitioning .map(new InternalRowComparableWrapper(_, partitionDataTypes)) - .foreach(partition => partitionsSet.add(partition)) - // SPARK-41471: We keep to order of partitions to make sure the order of - // partitions is deterministic in different case.
- val partitionOrdering: Ordering[InternalRow] = { - RowOrdering.createNaturalAscendingOrdering(partitionDataTypes) + .foreach(partition => rightPartitionSet.add(partition)) + + val result = if (intersect) { + leftPartitionSet.intersect(rightPartitionSet) + } else { + leftPartitionSet.union(rightPartitionSet) } - partitionsSet.map(_.row).toSeq.sorted(partitionOrdering) + result.toSeq } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f50eb9b121589..ac5aa0f6bbdc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1635,6 +1635,17 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_PARTITION_FILTER_ENABLED = + buildConf("spark.sql.sources.v2.bucketing.partition.filter.enabled") + .doc(s"Whether to filter partitions when running storage-partition join. " + + s"When enabled, partitions without matches on the other side can be omitted for " + + s"scanning, if allowed by the join type. This config requires both " + + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + + s"enabled.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala index cc28e85525162..f3dd232129e8b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/InternalRowComparableWrapperBenchmark.scala @@ -61,7 +61,7 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase { val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions) val merged = InternalRowComparableWrapper.mergePartitions( - leftPartitioning, rightPartitioning, expressions) + leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions) assert(merged.size == bucketNum) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 997576a396d20..6a502a44fad58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -200,7 +200,12 @@ case class BatchScanExec( .get .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) => + val filteredGroupedPartitions = finalGroupedPartitions.filter { + case (partValues, _) => + commonPartValuesMap.keySet.contains( + InternalRowComparableWrapper(partValues, partExpressions)) + } + val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. 
val numSplits = commonPartValuesMap .get(InternalRowComparableWrapper(partValue, partExpressions)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 0470aacd4f823..90287c2028467 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -429,8 +429,19 @@ case class EnsureRequirements( // expressions val partitionExprs = leftSpec.partitioning.expressions - var mergedPartValues = InternalRowComparableWrapper - .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs) + // in case of compatible but not identical partition expressions, we apply 'reduce' + // transforms to group one side's partitions as well as the common partition values + val leftReducers = leftSpec.reducers(rightSpec) + val leftParts = reducePartValues(leftSpec.partitioning.partitionValues, + partitionExprs, + leftReducers) + val rightReducers = rightSpec.reducers(leftSpec) + val rightParts = reducePartValues(rightSpec.partitioning.partitionValues, + partitionExprs, + rightReducers) + + // merge values on both sides + var mergedPartValues = mergePartitions(leftParts, rightParts, partitionExprs, joinType) .map(v => (v, 1)) logInfo(log"After merging, there are " + @@ -525,23 +536,6 @@ case class EnsureRequirements( } } - // in case of compatible but not identical partition expressions, we apply 'reduce' - // transforms to group one side's partitions as well as the common partition values - val leftReducers = leftSpec.reducers(rightSpec) - val rightReducers = rightSpec.reducers(leftSpec) - - if (leftReducers.isDefined || rightReducers.isDefined) { - mergedPartValues = reduceCommonPartValues(mergedPartValues, - leftSpec.partitioning.expressions, - leftReducers) - mergedPartValues = reduceCommonPartValues(mergedPartValues, - rightSpec.partitioning.expressions, - rightReducers) - val rowOrdering = RowOrdering - .createNaturalAscendingOrdering(partitionExprs.map(_.dataType)) - mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) - } - // Now we need to push-down the common partition information to the scan in each child newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, leftReducers, applyPartialClustering, replicateLeftSide) @@ -602,15 +596,15 @@ case class EnsureRequirements( child, joinKeyPositions)) } - private def reduceCommonPartValues( - commonPartValues: Seq[(InternalRow, Int)], + private def reducePartValues( + partValues: Seq[InternalRow], expressions: Seq[Expression], reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { - case Some(reducers) => commonPartValues.groupBy { case (row, _) => + case Some(reducers) => partValues.map { row => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) - }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq - case _ => commonPartValues + }.distinct.map(_.row) + case _ => partValues } } @@ -651,6 +645,46 @@ case class EnsureRequirements( } } + /** + * Merge and sort partitions values for SPJ and optionally enable partition filtering. + * Both sides must have + * matching partition expressions. 
+ * @param leftPartitioning left side partition values + * @param rightPartitioning right side partition values + * @param partitionExpression partition expressions + * @param joinType join type for optional partition filtering + * @return merged and sorted partition values + */ + private def mergePartitions( + leftPartitioning: Seq[InternalRow], + rightPartitioning: Seq[InternalRow], + partitionExpression: Seq[Expression], + joinType: JoinType): Seq[InternalRow] = { + + val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) { + joinType match { + case Inner => InternalRowComparableWrapper.mergePartitions( + leftPartitioning, rightPartitioning, partitionExpression, intersect = true) + case LeftOuter => leftPartitioning.map( + InternalRowComparableWrapper(_, partitionExpression)) + case RightOuter => rightPartitioning.map( + InternalRowComparableWrapper(_, partitionExpression)) + case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning, + rightPartitioning, partitionExpression) + } + } else { + InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning, + partitionExpression) + } + + // SPARK-41471: We keep to order of partitions to make sure the order of + // partitions is deterministic in different case. + val partitionOrdering: Ordering[InternalRow] = { + RowOrdering.createNaturalAscendingOrdering(partitionExpression.map(_.dataType)) + } + merged.map(_.row).sorted(partitionOrdering) + } + def apply(plan: SparkPlan): SparkPlan = { val newPlan = plan.transformUp { case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 5e5453b4cd500..99d99fede8485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -667,11 +667,12 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(5, 30.0, cast('2023-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(("true", 10), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 8), (false, true, 3), (true, false, 10), (false, false, 5)).foreach { + case (partial, filter, expected) => withSQLConf( - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> partial.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { @@ -692,6 +693,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-42038: partially clustered: with different partition keys and missing keys on " + "left-hand side") { val items_partitions = Array(identity("id")) @@ -715,11 +717,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(5, 30.0, cast('2023-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(("true", 9), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 
3), (false, true, 2), (true, false, 9), (false, false, 5)).foreach { + case(partial, filter, expected) => withSQLConf( SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partial.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { @@ -759,11 +763,13 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { s"(5, 30.0, cast('2023-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(("true", 6), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 2), (false, true, 2), (true, false, 6), (false, false, 5)).foreach { + case (partial, filter, expected) => withSQLConf( SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partial.toString) { val df = createJoinTestDF(Seq("id" -> "item_id")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (pushDownValues) { @@ -802,12 +808,14 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { // In a left-outer join, and when the left side has larger stats, partially clustered // distribution should kick in and pick the right hand side to replicate partitions. Seq(true, false).foreach { pushDownValues => - Seq(("true", 7), ("false", 5)).foreach { - case (enable, expected) => + Seq((true, true, 5), (false, true, 3), (true, false, 7), (false, false, 5)).foreach { + case (partial, filter, expected) => withSQLConf( SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> false.toString, SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> enable) { + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partial.toString) { val df = createJoinTestDF( Seq("id" -> "item_id", "arrive_time" -> "time"), joinType = "LEFT") val shuffles = collectShuffles(df.queryExecution.executedPlan) @@ -815,7 +823,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { assert(shuffles.isEmpty, "should not contain any shuffle") val scans = collectScans(df.queryExecution.executedPlan) assert(scans.forall(_.inputRDD.partitions.length == expected), - s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") + s"Expected $expected but got ${scans.head.inputRDD.partitions.length}") } else { assert(shuffles.nonEmpty, "should contain shuffle when not pushing down partition values") @@ -1336,62 +1344,71 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "(2, 'ww', cast('2020-01-01' as timestamp))") Seq(true, false).foreach { pushDownValues => - Seq(true, false).foreach { partiallyClustered => - Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => - - withSQLConf( - SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, - 
SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> - partiallyClustered.toString, - SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> - allowJoinKeysSubsetOfPartitionKeys.toString) { - val df = sql( - s""" - |${selectWithMergeJoinHint("t1", "t2")} - |t1.id AS id, t1.data AS t1data, t2.data AS t2data - |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 - |ON t1.id = t2.id ORDER BY t1.id, t1data, t2data - |""".stripMargin) - val shuffles = collectShuffles(df.queryExecution.executedPlan) - if (allowJoinKeysSubsetOfPartitionKeys) { - assert(shuffles.isEmpty, "SPJ should be triggered") - } else { - assert(shuffles.nonEmpty, "SPJ should not be triggered") - } - - val scans = collectScans(df.queryExecution.executedPlan) - .map(_.inputRDD.partitions.length) - - (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { - // SPJ and partially-clustered - case (true, true) => assert(scans == Seq(8, 8)) - // SPJ and not partially-clustered - case (true, false) => assert(scans == Seq(4, 4)) - // No SPJ - case _ => assert(scans == Seq(5, 4)) + Seq(true, false).foreach { filter => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> filter.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.id AS id, t1.data AS t1data, t2.data AS t2data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.id = t2.id ORDER BY t1.id, t1data, t2data + |""".stripMargin) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scannedPartitions = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered, filter) match { + // SPJ, partially-clustered, with filter + case (true, true, true) => assert(scannedPartitions == Seq(6, 6)) + + // SPJ, partially-clustered, no filter + case (true, true, false) => assert(scannedPartitions == Seq(8, 8)) + + // SPJ and not partially-clustered, with filter + case (true, false, true) => assert(scannedPartitions == Seq(2, 2)) + + // SPJ and not partially-clustered, no filter + case (true, false, false) => assert(scannedPartitions == Seq(4, 4)) + + // No SPJ + case _ => assert(scannedPartitions == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) } - - checkAnswer(df, Seq( - Row(2, "bb", "ww"), - Row(2, "cc", "ww"), - Row(3, "dd", "xx"), - Row(3, "dd", "xx"), - Row(3, "dd", "xx"), - Row(3, "dd", "xx"), - Row(3, "dd", "yy"), - Row(3, "dd", "yy"), - Row(3, "dd", 
"yy"), - Row(3, "dd", "yy"), - Row(3, "ee", "xx"), - Row(3, "ee", "xx"), - Row(3, "ee", "xx"), - Row(3, "ee", "xx"), - Row(3, "ee", "yy"), - Row(3, "ee", "yy"), - Row(3, "ee", "yy"), - Row(3, "ee", "yy") - )) } } } @@ -2176,4 +2193,194 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(3, "bb", 10.0, 19.5))) } } + + test("SPARK-48949: test partition filters inner join") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 41.0, cast('2020-01-03' as timestamp)), " + + s"(3, 'bb', 42.0, cast('2020-01-04' as timestamp)), " + + s"(4, 'cc', 43.5, cast('2020-01-05' as timestamp)), " + + s"(5, 'cc', 44.5, cast('2020-01-15' as timestamp)), " + + s"(6, 'dd', 45.5, cast('2020-02-07' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp)), " + + s"(7, 46.5, cast('2020-02-08' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, + Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) + ) + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 2)) + } + } + + test("SPARK-48949: test partition filters with no matches") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(4, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + assert(df.collect().isEmpty, "should return no results") + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 0)) + } + } + + test("SPARK-48949: test partition filters with right outer") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, 
cast('2020-01-01' as timestamp)), " + + s"(4, 42.0, cast('2020-01-02' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = "RIGHT OUTER") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + + checkAnswer(df, + Seq(Row(null, null, null, 42.0), + Row(null, null, null, 44.0), + Row(1, "aa", 40.0, 40.0)) + ) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 3)) + } + } + + test("SPARK-48949: test partition filters with full outer") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 40.0, cast('2020-01-01' as timestamp)), " + + s"(4, 42.0, cast('2020-01-02' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = "FULL OUTER") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + + checkAnswer(df, + Seq(Row(null, null, null, 42.0), + Row(null, null, null, 44.0), + Row(0, "aa", 39.0, null), + Row(1, "aa", 40.0, 40.0)) + ) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 4)) + } + } + + test("SPARK-48949: test partition filters with left outer") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 'aa', 38.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 39.0, cast('2020-01-02' as timestamp)), " + + s"(4, 'aa', 40.0, cast('2020-01-02' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(4, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp))") + + withSQLConf(SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id"), joinType = "LEFT OUTER") + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + + checkAnswer(df, + Seq(Row(0, "aa", 38.0, null), + Row(1, "aa", 39.0, null), + Row(4, "aa", 40.0, 42.0)) + ) + + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 3)) + } + } + + test("SPARK-48949: test partition filters with compatible transforms") { + val items_partitions = Array(bucket(8, "id")) + createTable(items, itemsColumns, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(0, 
'aa', 39.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 41.0, cast('2020-01-03' as timestamp)), " + + s"(3, 'bb', 42.0, cast('2020-01-04' as timestamp)), " + + s"(4, 'cc', 43.5, cast('2020-01-05' as timestamp)), " + + s"(5, 'cc', 44.5, cast('2020-01-15' as timestamp)), " + + s"(6, 'dd', 45.5, cast('2020-02-07' as timestamp))") + + val purchases_partitions = Array(bucket(4, "item_id")) + createTable(purchases, purchasesColumns, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(5, 44.0, cast('2020-01-15' as timestamp)), " + + s"(7, 46.5, cast('2020-02-08' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "should not add shuffle for both sides of the join") + checkAnswer(df, + Seq(Row(1, "aa", 40.0, 42.0), Row(5, "cc", 44.5, 44.0)) + ) + val scans = collectScans(df.queryExecution.executedPlan) + assert(scans.forall(_.inputRDD.partitions.length == 2)) + } + } }
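
Note for reviewers: below is a simplified, self-contained Scala sketch of the merge rule introduced in `EnsureRequirements.mergePartitions` above. It is illustrative only; plain Scala collections and `Int` bucket ids stand in for `InternalRowComparableWrapper` and Catalyst's `JoinType`, and the object and method names here are not part of the patch. With partition filtering enabled, INNER keeps only partition values present on both sides, LEFT/RIGHT OUTER keep all values of the outer side, and any other join type (e.g. FULL OUTER) falls back to the union of both sides, which is also the behavior when the flag is disabled.

```scala
// Illustrative sketch only; not part of the patch.
object PartitionFilterSketch {
  sealed trait JoinType
  case object Inner extends JoinType
  case object LeftOuter extends JoinType
  case object RightOuter extends JoinType
  case object FullOuter extends JoinType

  // Mirrors the join-type handling of the new mergePartitions helper,
  // using Int bucket ids in place of partition value rows.
  def mergePartitions(
      left: Seq[Int],
      right: Seq[Int],
      joinType: JoinType,
      partitionFilterEnabled: Boolean): Seq[Int] = {
    val (l, r) = (left.toSet, right.toSet)
    val merged = if (partitionFilterEnabled) {
      joinType match {
        case Inner => l.intersect(r) // scan only partitions present on both sides
        case LeftOuter => l          // every left partition must still produce output rows
        case RightOuter => r         // every right partition must still produce output rows
        case _ => l.union(r)         // e.g. FULL OUTER: no partition may be dropped
      }
    } else {
      l.union(r)                     // previous behavior: union of both sides
    }
    // Keep a deterministic order, analogous to the SPARK-41471 sorting.
    merged.toSeq.sorted
  }

  def main(args: Array[String]): Unit = {
    val itemsBuckets = Seq(0, 1, 2, 3, 4, 5, 6) // bucket values seen on the items side
    val purchasesBuckets = Seq(1, 5, 7)         // bucket values seen on the purchases side
    // INNER join with filtering plans only buckets 1 and 5, analogous to the
    // partitions.length == 2 assertion in the new inner-join test.
    println(mergePartitions(itemsBuckets, purchasesBuckets, Inner, partitionFilterEnabled = true))
    // FULL OUTER keeps the union of both sides: buckets 0 through 7.
    println(mergePartitions(itemsBuckets, purchasesBuckets, FullOuter, partitionFilterEnabled = true))
  }
}
```

In the actual change, this pruning is gated by the new `spark.sql.sources.v2.bucketing.partition.filter.enabled` flag together with the existing `V2_BUCKETING_ENABLED` and `V2_BUCKETING_PUSH_PART_VALUES_ENABLED` configs, as described in the SQLConf change above.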