From c3f17a522897d99298d4034890bccae266f2b7c2 Mon Sep 17 00:00:00 2001 From: wakun Date: Fri, 10 Mar 2023 22:12:08 +0800 Subject: [PATCH] [CARMEL-6608] Increase bucket table scan partitions (#1269) --- .../apache/spark/sql/execution/DataSourceScanExec.scala | 7 +++++-- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 6 +----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index cfa8e52a91574..9478fc1e77ea0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -955,8 +955,11 @@ case class FileSourceScanExec( var partitionNumber = FilePartition. minPartitionNumberBySpecifiedSize(fsRelation.sparkSession, selectedPartitions, originSize) - if (partitionNumber > conf.autoBucketedScanMaxPartitions) { - partitionNumber = conf.autoBucketedScanMaxPartitions + val maxBucketedScanMaxPartitions = + Math.max(conf.autoBucketedScanMaxPartitions, relation.bucketSpec.get.numBuckets) + + if (partitionNumber > maxBucketedScanMaxPartitions) { + partitionNumber = maxBucketedScanMaxPartitions FilePartition.maxSplitBytesBySpecifiedNumber( fsRelation.sparkSession, selectedPartitions, partitionNumber) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index da1a20d988aad..6170bfbe7602a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -6694,11 +6694,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark assert(scan.get.optionalBucketSet.isEmpty) assert(scan.get.disableBucketedScan) - if (maxPartitions > 1) { - assert(scan.get.inputRDD.getNumPartitions > 1) - } else { - assert(scan.get.inputRDD.getNumPartitions == 1) - } + assert(scan.get.inputRDD.getNumPartitions == 2) } } }