From 9520087409d5bd7e6a2651dacf2c295d564d5559 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 11 Sep 2023 11:18:46 -0700 Subject: [PATCH] [SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys ### What changes were proposed in this pull request? - Add new conf spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled - Change key compatibility checks in EnsureRequirements. Remove checks where all partition keys must be in join keys to allow isKeyCompatible = true in this case (if this flag is enabled) - Change BatchScanExec/DataSourceV2Relation to group splits by join keys if they differ from partition keys (previously grouped only by partition values). Do same for all auxiliary data structure, like commonPartValues. - Implement partiallyClustered skew-handling. - Group only the replicate side (now by join key as well), replicate by the total size of other-side partitions that share the join key. - add an additional sort for partitions based on join key, as when we group the replicate side, partition ordering becomes out of order from the non-replicate side. ### Why are the changes needed? - Support Storage Partition Join in cases where the join condition does not contain all the partition keys, but just some of them ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? -Added tests in KeyGroupedPartitioningSuite -Found two existing problems, will address in separate PR: - Because of https://github.com/apache/spark/pull/37886 we have to select all join keys to trigger SPJ in this case, otherwise DSV2 scan does not report KeyGroupedPartitioning and SPJ does not get triggered. Need to see how to relax this. - https://issues.apache.org/jira/browse/SPARK-44641 was found when testing this change. This pr refactors some of those code to add group-by-join-key, but doesnt change the underlying logic, so issue continues to exist. Hopefully this will also get fixed in another way. Closes #42306 from szehon-ho/spj_attempt_master. Authored-by: Szehon Ho Signed-off-by: Dongjoon Hyun --- .../plans/physical/partitioning.scala | 59 +++- .../apache/spark/sql/internal/SQLConf.scala | 15 + .../datasources/v2/BatchScanExec.scala | 56 ++-- .../exchange/EnsureRequirements.scala | 15 +- .../KeyGroupedPartitioningSuite.scala | 265 +++++++++++++++++- 5 files changed, 378 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0be4a61f27587..a61bd3b7324be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -355,7 +355,14 @@ case class KeyGroupedPartitioning( } else { // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // check that all join keys (required clustering keys) contained in partitioning + requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } case _ => @@ -364,8 +371,20 @@ case class KeyGroupedPartitioning( } } - override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = - KeyGroupedShuffleSpec(this, distribution) + override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = { + val result = KeyGroupedShuffleSpec(this, distribution) + if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + // If allowing join keys to be subset of clustering keys, we should create a new + // `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as + // the returned shuffle spec. + val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2) + val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions, + partitionValues, originalPartitionValues) + result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions)) + } else { + result + } + } lazy val uniquePartitionValues: Seq[InternalRow] = { partitionValues @@ -378,8 +397,25 @@ case class KeyGroupedPartitioning( object KeyGroupedPartitioning { def apply( expressions: Seq[Expression], - partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { - KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues) + projectionPositions: Seq[Int], + partitionValues: Seq[InternalRow], + originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = { + val projectedExpressions = projectionPositions.map(expressions(_)) + val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _)) + val projectedOriginalPartitionValues = + originalPartitionValues.map(project(expressions, projectionPositions, _)) + + KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length, + projectedPartitionValues, projectedOriginalPartitionValues) + } + + def project( + expressions: Seq[Expression], + positions: Seq[Int], + input: InternalRow): InternalRow = { + val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType)) + .toArray + new GenericInternalRow(projectedValues) } def supportsExpressions(expressions: Seq[Expression]): Boolean = { @@ -672,9 +708,18 @@ case class HashShuffleSpec( override def numPartitions: Int = partitioning.numPartitions } +/** + * [[ShuffleSpec]] created by [[KeyGroupedPartitioning]]. + * + * @param partitioning key grouped partitioning + * @param distribution distribution + * @param joinKeyPosition position of join keys among cluster keys. + * This is set if joining on a subset of cluster keys is allowed. + */ case class KeyGroupedShuffleSpec( partitioning: KeyGroupedPartitioning, - distribution: ClusteredDistribution) extends ShuffleSpec { + distribution: ClusteredDistribution, + joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec { /** * A sequence where each element is a set of positions of the partition expression to the cluster @@ -709,7 +754,7 @@ case class KeyGroupedShuffleSpec( // 3.3 each pair of partition expressions at the same index must share compatible // transform functions. // 4. the partition values from both sides are following the same order. - case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => + case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 8c8b33921e321..49a4b0bf98bb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1530,6 +1530,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS = + buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled") + .doc("Whether to allow storage-partition join in the case where join keys are" + + "a subset of the partition keys of the source tables. At planning time, " + + "Spark will group the partitions by only those keys that are in the join keys." + + s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " + + "is false." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -4936,6 +4948,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingShuffleEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED) + def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 932ac0f5a1b15..094a7b20808ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -120,7 +120,12 @@ case class BatchScanExec( val newPartValues = spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => Seq.fill(numSplits)(partValue) } - k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues) + val expressions = spjParams.joinKeyPositions match { + case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i)) + case _ => k.expressions + } + k.copy(expressions = expressions, numPartitions = newPartValues.length, + partitionValues = newPartValues) case p => p } } @@ -132,14 +137,29 @@ case class BatchScanExec( // return an empty RDD with 1 partition if dynamic filtering removed the only split sparkContext.parallelize(Array.empty[InternalRow], 1) } else { - var finalPartitions = filteredPartitions - - outputPartitioning match { + val finalPartitions = outputPartitioning match { case p: KeyGroupedPartitioning => - val groupedPartitions = filteredPartitions.map(splits => { - assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) - (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) - }) + assert(spjParams.keyGroupedPartitioning.isDefined) + val expressions = spjParams.keyGroupedPartitioning.get + + // Re-group the input partitions if we are projecting on a subset of join keys + val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match { + case Some(projectPositions) => + val projectedExpressions = projectPositions.map(i => expressions(i)) + val parts = filteredPartitions.flatten.groupBy(part => { + val row = part.asInstanceOf[HasPartitionKey].partitionKey() + val projectedRow = KeyGroupedPartitioning.project( + expressions, projectPositions, row) + InternalRowComparableWrapper(projectedRow, projectedExpressions) + }).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq + (parts, projectedExpressions) + case _ => + val groupedParts = filteredPartitions.map(splits => { + assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey]) + (splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits) + }) + (groupedParts, expressions) + } // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group @@ -149,12 +169,12 @@ case class BatchScanExec( // should contain. val commonPartValuesMap = spjParams.commonPartitionValues .get - .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) + .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, p.expressions)) + .get(InternalRowComparableWrapper(partValue, partExpressions)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -169,37 +189,37 @@ case class BatchScanExec( // sides of a join will have the same number of partitions & splits. splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - (InternalRowComparableWrapper(partValue, p.expressions), newSplits) + (InternalRowComparableWrapper(partValue, partExpressions), newSplits) } // Now fill missing partition keys with empty partitions val partitionMapping = nestGroupedPartitions.toMap - finalPartitions = spjParams.commonPartitionValues.get.flatMap { + spjParams.commonPartitionValues.get.flatMap { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), + InternalRowComparableWrapper(partValue, partExpressions), Seq.fill(numSplits)(Seq.empty)) } } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. val partitionMapping = groupedPartitions.map { case (partValue, splits) => - InternalRowComparableWrapper(partValue, p.expressions) -> splits + InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap // In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there // could exist duplicated partition values, as partition grouping is not done // at the beginning and postponed to this method. It is important to use unique // partition values here so that grouped partitions won't get duplicated. - finalPartitions = p.uniquePartitionValues.map { partValue => + p.uniquePartitionValues.map { partValue => // Use empty partition for those partition values that are not present partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, p.expressions), Seq.empty) + InternalRowComparableWrapper(partValue, partExpressions), Seq.empty) } } - case _ => + case _ => filteredPartitions } new DataSourceRDD( @@ -234,6 +254,7 @@ case class BatchScanExec( case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, + joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { @@ -247,6 +268,7 @@ case class StoragePartitionJoinParams( } override def hashCode(): Int = Objects.hashCode( + joinKeyPositions: Option[Seq[Int]], commonPartitionValues: Option[Seq[(InternalRow, Int)]], applyPartialClustering: java.lang.Boolean, replicatePartitions: java.lang.Boolean) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f8e6fd1d0167f..8552c950f6776 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -380,7 +380,8 @@ case class EnsureRequirements( val rightSpec = specs(1) var isCompatible = false - if (!conf.v2BucketingPushPartValuesEnabled) { + if (!conf.v2BucketingPushPartValuesEnabled && + !conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { isCompatible = leftSpec.isCompatibleWith(rightSpec) } else { logInfo("Pushing common partition values for storage-partitioned join") @@ -505,10 +506,10 @@ case class EnsureRequirements( } // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues( - left, mergedPartValues, applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues( - right, mergedPartValues, applyPartialClustering, replicateRightSide) + newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, + applyPartialClustering, replicateLeftSide) + newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, + applyPartialClustering, replicateRightSide) } } @@ -530,19 +531,21 @@ case class EnsureRequirements( private def populatePartitionValues( plan: SparkPlan, values: Seq[(InternalRow, Int)], + joinKeyPositions: Option[Seq[Int]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), + joinKeyPositions = joinKeyPositions, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => node.mapChildren(child => populatePartitionValues( - child, values, applyPartialClustering, replicatePartitions)) + child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index b22aba61aabd8..ffd1c8e31e919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -98,14 +98,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val catalystDistribution = physical.ClusteredDistribution( Seq(TransformExpression(YearsFunction, Seq(attr("ts"))))) val partitionValues = Seq(50, 51, 52).map(v => InternalRow.fromSeq(Seq(v))) + val projectedPositions = catalystDistribution.clustering.indices checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) // multiple group keys should work too as long as partition keys are subset of them df = sql(s"SELECT count(*) FROM testcat.ns.$table GROUP BY id, ts") checkQueryPlan(df, catalystDistribution, - physical.KeyGroupedPartitioning(catalystDistribution.clustering, partitionValues)) + physical.KeyGroupedPartitioning(catalystDistribution.clustering, projectedPositions, + partitionValues, partitionValues)) } test("non-clustered distribution: no partition") { @@ -1276,4 +1279,262 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-44647: test join key is subset of cluster key " + + "with push values and partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-01' as timestamp)), " + + "(2, 'cc', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'dd', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp)), " + + "(3, 'ee', cast('2020-01-01' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(4, 'zz', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'yy', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(3, 'xx', cast('2020-01-01' as timestamp)), " + + "(2, 'ww', cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS id, t1.data AS t1data, t2.data AS t2data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.id = t2.id ORDER BY t1.id, t1data, t2data") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(8, 8)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(4, 4)) + // No SPJ + case _ => assert(scans == Seq(5, 4)) + } + + checkAnswer(df, Seq( + Row(2, "bb", "ww"), + Row(2, "cc", "ww"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "xx"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "dd", "yy"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "xx"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy"), + Row(3, "ee", "yy") + )) + } + } + } + } + } + + test("SPARK-44647: test join key is the second cluster key") { + val table1 = "tab1e1" + val table2 = "table2" + val partition = Array(identity("id"), identity("data")) + createTable(table1, schema, partition) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(1, 'aa', cast('2020-01-01' as timestamp)), " + + "(2, 'bb', cast('2020-01-02' as timestamp)), " + + "(3, 'cc', cast('2020-01-03' as timestamp))") + + createTable(table2, schema, partition) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(4, 'aa', cast('2020-01-01' as timestamp)), " + + "(5, 'bb', cast('2020-01-02' as timestamp)), " + + "(6, 'cc', cast('2020-01-03' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + + val df = sql("SELECT t1.id AS t1id, t2.id as t2id, t1.data AS data " + + s"FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 " + + "ON t1.data = t2.data ORDER BY t1id, t1id, data") + + checkAnswer(df, Seq(Row(1, 4, "aa"), Row(2, 5, "bb"), Row(3, 6, "cc"))) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (pushDownValues, allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true, true) => assert(scans == Seq(3, 3)) + // non-SPJ or SPJ/partially-clustered + case _ => assert(scans == Seq(3, 3)) + } + } + } + } + } + } + + test("SPARK-44647: test join key is the second partition key and a transform") { + val items_partitions = Array(bucket(8, "id"), days("arrive_time")) + createTable(items, items_schema, items_partitions) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + s"(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + s"(1, 'aa', 41.0, cast('2020-01-15' as timestamp)), " + + s"(2, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + s"(2, 'bb', 10.5, cast('2020-01-01' as timestamp)), " + + s"(3, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + val purchases_partitions = Array(bucket(8, "item_id"), days("time")) + createTable(purchases, purchases_schema, purchases_partitions) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + s"(1, 42.0, cast('2020-01-01' as timestamp)), " + + s"(1, 44.0, cast('2020-01-15' as timestamp)), " + + s"(1, 45.0, cast('2020-01-15' as timestamp)), " + + s"(2, 11.0, cast('2020-01-01' as timestamp)), " + + s"(3, 19.5, cast('2020-02-01' as timestamp))") + + Seq(true, false).foreach { pushDownValues => + Seq(true, false).foreach { partiallyClustered => + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString) { + val df = sql("SELECT id, name, i.price as purchase_price, " + + "p.item_id, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.arrive_time = p.time " + + "ORDER BY id, purchase_price, p.item_id, sale_price") + + // Currently SPJ for case where join key not same as partition key + // only supported when push-part-values enabled + val shuffles = collectShuffles(df.queryExecution.executedPlan) + if (allowJoinKeysSubsetOfPartitionKeys) { + assert(shuffles.isEmpty, "SPJ should be triggered") + } else { + assert(shuffles.nonEmpty, "SPJ should not be triggered") + } + + val scans = collectScans(df.queryExecution.executedPlan) + .map(_.inputRDD.partitions.length) + (allowJoinKeysSubsetOfPartitionKeys, partiallyClustered) match { + // SPJ and partially-clustered + case (true, true) => assert(scans == Seq(5, 5)) + // SPJ and not partially-clustered + case (true, false) => assert(scans == Seq(3, 3)) + // No SPJ + case _ => assert(scans == Seq(4, 4)) + } + + checkAnswer(df, + Seq( + Row(1, "aa", 40.0, 1, 42.0), + Row(1, "aa", 40.0, 2, 11.0), + Row(1, "aa", 41.0, 1, 44.0), + Row(1, "aa", 41.0, 1, 45.0), + Row(2, "bb", 10.0, 1, 42.0), + Row(2, "bb", 10.0, 2, 11.0), + Row(2, "bb", 10.5, 1, 42.0), + Row(2, "bb", 10.5, 2, 11.0), + Row(3, "cc", 15.5, 3, 19.5) + ) + ) + } + } + } + } + } + + test("SPARK-44647: shuffle one side and join keys are less than partition keys") { + val items_partitions = Array(identity("id"), identity("name")) + createTable(items, items_schema, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchases_schema, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + Seq(true, false).foreach { pushdownValues => + Seq(true, false).foreach { partiallyClustered => + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key + -> partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = sql("SELECT id, name, i.price as purchase_price, p.price as sale_price " + + s"FROM testcat.ns.$items i JOIN testcat.ns.$purchases p " + + "ON i.id = p.item_id ORDER BY id, purchase_price, sale_price") + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } + } + } }