From 8319e4301288251335567963308a0b137504ec58 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 21 Aug 2023 17:39:28 -0700 Subject: [PATCH] Review comment --- .../plans/physical/partitioning.scala | 68 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 6 +- .../datasources/v2/BatchScanExec.scala | 177 +++++++++++------- .../v2/DataSourceV2ScanExecBase.scala | 164 ++++------------ .../exchange/EnsureRequirements.scala | 18 +- .../KeyGroupedPartitioningSuite.scala | 2 + 6 files changed, 206 insertions(+), 229 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 2ce01eb56cda9..2c993388ee9ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -345,10 +345,13 @@ case class KeyGroupedPartitioning( // We'll need to find leaf attributes from the partition expressions first. val attributes = expressions.flatMap(_.collectLeaves()) - // Support only when all cluster key have an associated partition expression key - requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && - // and if all partition expression contain only a single partition key. - expressions.forall(_.collectLeaves().size == 1) + if (SQLConf.get.getConf( + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)) { + requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) && + expressions.forall(_.collectLeaves().size == 1) + } else { + attributes.forall(x => requiredClustering.exists(_.semanticEquals(x))) + } } case _ => @@ -705,33 +708,29 @@ case class KeyGroupedShuffleSpec( case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) => distribution.clustering.length == otherDistribution.clustering.length && numPartitions == other.numPartitions && areKeysCompatible(otherSpec) && - isPartitioningCompatible(otherPartitioning) + partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall { + case (left, right) => + InternalRowComparableWrapper(left, partitioning.expressions) + .equals(InternalRowComparableWrapper(right, partitioning.expressions)) + } case ShuffleSpecCollection(specs) => specs.exists(isCompatibleWith) case _ => false } def isPartitioningCompatible(otherPartitioning: KeyGroupedPartitioning): Boolean = { - val clusterKeySize = keyPositions.size + val joinKeyPositions = keyPositions.map(_.nonEmpty) partitioning.partitionValues.zip(otherPartitioning.partitionValues) .forall { case (left, right) => - val leftTypes = partitioning.expressions.map(_.dataType) - val leftVals = left.toSeq(leftTypes).take(clusterKeySize).toArray - val newLeft = new GenericInternalRow(leftVals) - - val rightTypes = partitioning.expressions.map(_.dataType) - val rightVals = right.toSeq(rightTypes).take(clusterKeySize).toArray - val newRight = new GenericInternalRow(rightVals) - - InternalRowComparableWrapper(newLeft, partitioning.expressions.take(clusterKeySize)) - .equals(InternalRowComparableWrapper( - newRight, partitioning.expressions.take(clusterKeySize))) + KeyGroupedShuffleSpec.project(left, partitioning.expressions, joinKeyPositions) + .equals( + KeyGroupedShuffleSpec.project(right, partitioning.expressions, joinKeyPositions)) } } // Whether the partition keys (i.e., partition expressions) are compatible between this and the - // `other` spec. 
+ // other spec. def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = { partitionExpressionsCompatible(other) && KeyGroupedShuffleSpec.keyPositionsCompatible( @@ -740,8 +739,8 @@ case class KeyGroupedShuffleSpec( } // Whether the partition keys (i.e., partition expressions) that also are in the set of - // cluster keys are compatible between this and the 'other' spec. - def areClusterPartitionKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = { + // join keys are compatible between this and the other spec. + def areJoinKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = { partitionExpressionsCompatible(other) && KeyGroupedShuffleSpec.keyPositionsCompatible( keyPositions.filter(_.nonEmpty), @@ -749,6 +748,14 @@ case class KeyGroupedShuffleSpec( ) } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && + // Only support partition expressions are AttributeReference for now + partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + } + private def partitionExpressionsCompatible(other: KeyGroupedShuffleSpec): Boolean = { val left = partitioning.expressions val right = other.partitioning.expressions @@ -757,14 +764,6 @@ case class KeyGroupedShuffleSpec( case (l, r) => KeyGroupedShuffleSpec.isExpressionCompatible(l, r) } } - - override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && - // Only support partition expressions are AttributeReference for now - partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) - - override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) - } } object KeyGroupedShuffleSpec { @@ -783,6 +782,19 @@ object KeyGroupedShuffleSpec { left.intersect(right).nonEmpty } } + + def project(row: InternalRow, expressions: Seq[Expression], + joinKeyPositions: Seq[Boolean]): InternalRowComparableWrapper = { + val projectedExprs = expressions.zip(joinKeyPositions) + .filter(_._2) + .map(_._1) + val projectedVals = row.toSeq(expressions.map(_.dataType)) + .zip(joinKeyPositions) + .filter(_._2) + .map(_._1) + val newRow = InternalRow.fromSeq(projectedVals) + InternalRowComparableWrapper(newRow, projectedExprs) + } } case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c1962e1346c5e..4bcfdf7b57375 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1520,7 +1520,7 @@ object SQLConf { ) .version("4.0.0") .booleanConf - .createWithDefault(true) + .createWithDefault(false) val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") @@ -4921,13 +4921,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingPartiallyClusteredDistributionEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED) -<<<<<<< HEAD def v2BucketingShuffleEnabled: Boolean = getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED) -======= + def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = 
    getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
->>>>>>> c431abb6dd5 ([SQL][SPARK-44647] Support SPJ where join keys are less than cluster keys)
 
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index c3a4c0ba27346..fe933aa3714b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.mutable
+
 import com.google.common.base.Objects
 
 import org.apache.spark.SparkException
@@ -61,16 +63,60 @@ case class BatchScanExec(
   @transient override lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] =
     // Early check if we actually need to materialize the input partitions.
     keyGroupedPartitioning match {
-      case Some(_) => groupPartitions(inputPartitions, spjParams.partitionGroupByPositions)
+      case Some(_) => groupPartitions(inputPartitions, spjParams.joinKeyPositions)
       case _ => None
     }
 
   @transient lazy val groupedCommonPartValues: Option[Seq[(InternalRow, Int)]] = {
-    if (spjParams.replicatePartitions) {
-      groupCommonPartitionsByJoinKey(
-        spjParams.commonPartitionValues, spjParams.partitionGroupByPositions)
-    } else {
-      spjParams.commonPartitionValues
+    keyGroupedPartitioning match {
+      case Some(_) if spjParams.joinKeyPositions.isDefined =>
+        commonPartitionValuesByJoinKey(group = spjParams.replicatePartitions)
+      case _ => spjParams.commonPartitionValues
+    }
+  }
+
+  /**
+   * Handle common partition values. These need special handling in the case where there are
+   * fewer join keys than partition keys.
+   *
+   * @param group whether to group the partitions or not
+   * @return a sorted sequence of join key values projected from the common partition values,
+   *         with the aggregated numSplits of all common partition values that share those
+   *         join key values
+   */
+  private def commonPartitionValuesByJoinKey(group: Boolean): Option[Seq[(InternalRow, Int)]] = {
+    for {
+      expressions <- keyGroupedPartitioning
+      commonValues <- spjParams.commonPartitionValues
+    } yield {
+
+      if (group) {
+        // We need to group the common partition values and sort them by the join keys
+        val groupedMap = new mutable.HashMap[InternalRowComparableWrapper, Int]
+        commonValues
+          .map(p => (partitionRow(p._1, expressions, spjParams.joinKeyPositions), p._2))
+          .foreach(p => groupedMap.put(p._1, groupedMap.getOrElse(p._1, 0) + p._2)
+        )
+        val groupedPartitions = groupedMap.map(r => (r._1.row, r._2)).toSeq
+
+        val projectedExpressions = projectIfNecessary(
+          expressions.map(_.dataType),
+          spjParams.joinKeyPositions,
+          "partition expressions"
+        )
+        val ordering: Ordering[(InternalRow, Int)] = RowOrdering
+          .createNaturalAscendingOrdering(projectedExpressions).on(_._1)
+        groupedPartitions.sorted(ordering)
+
+      } else {
+        // We still sort the other side by the join keys
+        val orders: Seq[SortOrder] = expressions.zipWithIndex.collect {
+          case (e, i) if spjParams.joinKeyPositions.get(i) =>
+            SortOrder(BoundReference(i, e.dataType, nullable = true), Ascending)
+        }
+        val ordering: Ordering[(InternalRow, Int)] = RowOrdering.create(orders, Seq.empty).on(_._1)
+        commonValues.sorted(ordering)
+      }
     }
   }
 
@@ -117,7 +163,7 @@ case class BatchScanExec(
             "partition values that are not present in the original partitioning.")
         }
 
-        groupPartitions(newPartitions, spjParams.partitionGroupByPositions).get.map(_._2)
+        groupPartitions(newPartitions).get.map(_._2)
 
       case _ =>
         // no validation is needed as the data source did not report any specific partitioning
@@ -132,12 +178,30 @@ case class BatchScanExec(
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
       case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
+
         // We allow duplicated partition values if
         // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
         val newPartValues = spjParams.commonPartitionValues.get.flatMap {
           case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
         }
-        k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
+
+        // We need to project to the join keys if there are fewer join keys than partition keys
+        // and we are on the replicate side of a partially-clustered join
+        val newExpressions = if (groupByJoinKeys) {
+          projectIfNecessary(k.expressions, spjParams.joinKeyPositions, "partition exprs")
+        } else {
+          k.expressions
+        }
+
+        val finalPartValues = if (groupByJoinKeys) {
+          newPartValues.map(r => projectRow(r, newExpressions, spjParams.joinKeyPositions))
+        } else {
+          newPartValues
+        }
+
+        k.copy(expressions = newExpressions,
+          numPartitions = newPartValues.length,
+          partitionValues = finalPartValues)
       case p => p
     }
   }
@@ -149,7 +213,6 @@ case class BatchScanExec(
       // return an empty RDD with 1 partition if dynamic filtering removed the only split
       sparkContext.parallelize(Array.empty[InternalRow], 1)
     } else {
-
       val finalPartitions = outputPartitioning match {
         case p: KeyGroupedPartitioning =>
           if (conf.v2BucketingPushPartValuesEnabled &&
              conf.v2BucketingPartiallyClusteredDistributionEnabled) {
            assert(filteredPartitions.forall(_.size == 1),
              "Expect partitions to be not grouped when " +
              s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
              "is enabled")
 
-            // In the case where we replicate partitions, we have grouped
-            // the partitions by the join 
key if they differ - val groupByExpressions = - if (spjParams.replicatePartitions && spjParams.partitionGroupByPositions.isDefined) { - project(p.expressions, spjParams.partitionGroupByPositions.get, - "partition expressions") - } else { - p.expressions - } - - // In the case where we replicate partitions, we need to also group - // the partitions by the join key if they differ + // In the case where we replicate partitions, we group + // the partitions further by the join key if they differ val groupedPartitions = if (spjParams.replicatePartitions) { groupPartitions(filteredPartitions.map(_.head), - spjParams.partitionGroupByPositions, - groupSplits = true).get + spjParams.joinKeyPositions, groupSplits = true).get } else { groupPartitions(filteredPartitions.map(_.head), groupSplits = true).get } @@ -185,16 +237,16 @@ case class BatchScanExec( if (spjParams.commonPartitionValues.isDefined && spjParams.applyPartialClustering) { // A mapping from the common partition values to how many splits the partition - // should contain. Note this no longer maintain the partition key ordering. + // should contain. val commonPartValuesMap = groupedCommonPartValues .get - .map(t => (InternalRowComparableWrapper(t._1, groupByExpressions), t._2)) + .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2)) .toMap val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. val numSplits = commonPartValuesMap - .get(InternalRowComparableWrapper(partValue, groupByExpressions)) + .get(InternalRowComparableWrapper(partValue, p.expressions)) assert(numSplits.isDefined, s"Partition value $partValue does not exist in " + "common partition values from Spark plan") @@ -209,29 +261,16 @@ case class BatchScanExec( // sides of a join will have the same number of partitions & splits. splits.map(Seq(_)).padTo(numSplits.get, Seq.empty) } - (InternalRowComparableWrapper(partValue, groupByExpressions), newSplits) + (InternalRowComparableWrapper(partValue, p.expressions), newSplits) } - // In the case that join keys are less than partition keys, we need to sort - // the common partition values by join key - val sortedCommonPositionByJoinKey = - if (spjParams.partitionGroupByPositions.isDefined) { - sortCommonPartitionsByJoinKey( - groupedCommonPartValues, - p.expressions, - spjParams.partitionGroupByPositions.get, - spjParams.replicatePartitions) - } else { - groupedCommonPartValues - } - // Now fill missing partition keys with empty partitions val partitionMapping = nestGroupedPartitions.toMap - sortedCommonPositionByJoinKey.get.flatMap { + groupedCommonPartValues.get.flatMap { case (partValue, numSplits) => // Use empty partition for those partition values that are not present. partitionMapping.getOrElse( - InternalRowComparableWrapper(partValue, groupByExpressions), + InternalRowComparableWrapper(partValue, p.expressions), Seq.fill(numSplits)(Seq.empty)) } } else { @@ -248,8 +287,23 @@ case class BatchScanExec( } else { val partitionMapping = filteredPartitions.map(p => (p.head.asInstanceOf[HasPartitionKey].partitionKey(), p)) - fillEmptyInputPartitions(partitionMapping, - p.partitionValues, p.expressions) + val finalPartitionMapping = if (groupByJoinKeys) { + partitionMapping.map { + case (k, v) => (projectRow(k, p.expressions, spjParams.joinKeyPositions), v) + } + } else { + partitionMapping + } + + // If we project by join keys, we may get duplicate partition values. 
+            // It is important to use unique partition values here so that
+            // grouped partitions won't get duplicated.
+            val expectedPartitions = if (groupByJoinKeys) {
+              p.uniquePartitionValues
+            } else {
+              p.partitionValues
+            }
+            fillEmptyInputPartitions(finalPartitionMapping, expectedPartitions, p.expressions)
           }
 
         case _ => filteredPartitions
@@ -276,33 +330,26 @@ case class BatchScanExec(
       expectedPartValues: Seq[InternalRow],
       partExpressions: Seq[Expression]): Seq[Seq[InputPartition]] = {
 
-    // Handle case where we have fewer join keys than partition keys
-    // by grouping partition values by join key
-    val groupedExpectedPartValues = groupPartitionsByJoinKey(
-      expectedPartValues, partExpressions, spjParams.partitionGroupByPositions)
-
-    val projectedExpression = if (spjParams.partitionGroupByPositions.isDefined) {
-      project(partExpressions, spjParams.partitionGroupByPositions.get, "partition expressions")
-    } else {
-      partExpressions
-    }
-
     val partitionMapping = partitions.map { case (row, parts) =>
-      projectGroupingKeyIfNecessary(row,
-        partExpressions,
-        spjParams.partitionGroupByPositions) -> parts
+      InternalRowComparableWrapper(row, partExpressions) -> parts
     }.toMap
 
-    groupedExpectedPartValues.map { partValue =>
+    expectedPartValues.map { partValue =>
       // Use empty partition for those partition values that are not present
       partitionMapping.getOrElse(
-        InternalRowComparableWrapper(partValue, projectedExpression),
+        InternalRowComparableWrapper(partValue, partExpressions),
        Seq.empty)
     }
   }
 
-  override def keyGroupedPartitioning: Option[Seq[Expression]] =
-    spjParams.keyGroupedPartitioning
+  override def keyGroupedPartitioning: Option[Seq[Expression]] = spjParams.keyGroupedPartitioning
+
+  // If there are fewer join keys than partition keys, we may have to group the partitions by
+  // the join keys. The only case where we do not is when the join is partially clustered and
+  // we are on the non-replicated side.
+  def groupByJoinKeys: Boolean = spjParams.joinKeyPositions.isDefined &&
+    (!conf.v2BucketingPartiallyClusteredDistributionEnabled ||
+      spjParams.replicatePartitions)
+
 
   override def doCanonicalize(): BatchScanExec = {
     this.copy(
@@ -327,13 +374,13 @@ case class BatchScanExec(
 case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
-    partitionGroupByPositions: Option[Seq[Boolean]] = None,
+    joinKeyPositions: Option[Seq[Boolean]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
   override def equals(other: Any): Boolean = other match {
     case other: StoragePartitionJoinParams =>
       this.commonPartitionValues == other.commonPartitionValues &&
-        this.partitionGroupByPositions == other.partitionGroupByPositions &&
+        this.joinKeyPositions == other.joinKeyPositions &&
        this.replicatePartitions == other.replicatePartitions &&
        this.applyPartialClustering == other.applyPartialClustering
     case _ =>
@@ -342,7 +389,7 @@ case class StoragePartitionJoinParams(
   override def hashCode(): Int = Objects.hashCode(
     commonPartitionValues: Option[Seq[(InternalRow, Int)]],
-    partitionGroupByPositions: Option[Seq[Boolean]],
+    joinKeyPositions: Option[Seq[Boolean]],
     applyPartialClustering: java.lang.Boolean,
     replicatePartitions: java.lang.Boolean)
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala
index f0199ef0e7955..97e92b95d9a3c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExecBase.scala @@ -17,11 +17,9 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.mutable - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Expression, RowOrdering, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering, SortOrder} import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} @@ -136,7 +134,7 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { if (!SQLConf.get.v2BucketingEnabled) return None keyGroupedPartitioning.flatMap { expressions => - val results: Seq[(InternalRow, InputPartition)] = inputPartitions.takeWhile { + val results = inputPartitions.takeWhile { case _: HasPartitionKey => true case _ => false }.map(p => (p.asInstanceOf[HasPartitionKey].partitionKey(), p)) @@ -147,20 +145,12 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } else { // also sort the input partitions according to their partition key order. This ensures // a canonical order from both sides of a bucketed join, for example. - val partitionDataTypes = expressions.map(_.dataType) - val projectedDataTypes = if (groupSplits && partitionGroupByPositions.isDefined) { - project(partitionDataTypes, partitionGroupByPositions.get, "partition expressions") - } else { - partitionDataTypes - } - val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { - RowOrdering.createNaturalAscendingOrdering(projectedDataTypes).on(_._1) - } - - val partitions = if (groupSplits) { + val (finalPartitions, finalExpressions) = if (groupSplits) { // Group the splits by their partition value - results - .map(t => (projectGroupingKeyIfNecessary(t._1, + val projected = projectIfNecessary(expressions, + partitionGroupByPositions, "partition expressions") + val parts = results + .map(t => (partitionRow(t._1, expressions, partitionGroupByPositions), t._2)) @@ -169,121 +159,21 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { .map { case (key, s) => (key.row, s.map(_._2)) } + (parts, projected) } else { // No splits grouping, each split will become a separate Spark partition - results.map(t => (t._1, Seq(t._2))) + val parts = results.map(t => (t._1, Seq(t._2))) + (parts, expressions) } - Some(partitions.sorted(partitionOrdering)) - } - } - } - - def groupPartitionsByJoinKey( - partitions: Seq[InternalRow], - partitionExpressions: Seq[Expression], - partitionGroupByPositions: Option[Seq[Boolean]]): Seq[InternalRow] = { - if (partitionGroupByPositions.isDefined) { - val groupedPartitions = new mutable.HashSet[InternalRowComparableWrapper] - partitions.foreach { p => - groupedPartitions.add( - projectGroupingKey(p, partitionExpressions, partitionGroupByPositions.get)) - } - groupedPartitions.map(_.row).toSeq - } else { - partitions - } - } - - - /** - * Group the common partition values. 
- * - * Similar to groupPartitions, we only do this when - * - all input partitions implement [[HasPartitionKey]] - * - `keyGroupedPartitioning` is set - * - [[SQLConf.V2_BUCKETING_ENABLED]] is turned on (checked before the call) - * - we do not replicate partitions (checked before the call) - * - * In these cases, when join keys are less than partition keys, we need to group the common - * partition values by the join keys, and aggregate their value of numSplits. - * - * @param commonValuesOption common partition values, if v2.bucketing.pushPartValues enabled - * @param partitionGroupByPositionsOption position of join keys among partition keys - * @return a sorted sequence of join key values projected on the common partition values, - * with aggregate numSplits of all common partition values with those join key values - */ - def groupCommonPartitionsByJoinKey( - commonValuesOption: Option[Seq[(InternalRow, Int)]], - partitionGroupByPositionsOption: Option[Seq[Boolean]]): Option[Seq[(InternalRow, Int)]] = { - for { - expressions <- keyGroupedPartitioning - commonValues <- commonValuesOption - partitionGroupByPositions <- partitionGroupByPositionsOption - } yield { - val grouped = new mutable.HashMap[InternalRowComparableWrapper, Int] - commonValues.map(p => { - val key = projectGroupingKey(p._1, expressions, partitionGroupByPositions) - (key, p._2) - }).foreach(p => - grouped.put(p._1, grouped.getOrElse(p._1, 0) + p._2) - ) - grouped.map(r => (r._1.row, r._2)).toSeq - } - } - - /** - * Return common partition values, ordered by the join keys. - * - * This is needed in cases where we have fewer join keys than partition keys. - * @param commonValuesOption - * @param expressions - * @param partitionGroupByPositions - * @param grouped - * @return - */ - def sortCommonPartitionsByJoinKey( - commonValuesOption: Option[Seq[(InternalRow, Int)]], - expressions: Seq[Expression], - partitionGroupByPositions: Seq[Boolean], - grouped: Boolean): Option[Seq[(InternalRow, Int)]] = { - commonValuesOption.map { commonValues => - val ordering: Ordering[(InternalRow, Int)] = if (grouped) { - val projectedExpressions = project( - expressions.map(_.dataType), - partitionGroupByPositions, - "partition expressions" - ) - RowOrdering.createNaturalAscendingOrdering(projectedExpressions).on(_._1) - } else { - val orders: Seq[SortOrder] = expressions.zipWithIndex.collect { - case (e, i) if partitionGroupByPositions(i) => - SortOrder(BoundReference(i, e.dataType, nullable = true), Ascending) + val partitionOrdering: Ordering[(InternalRow, Seq[InputPartition])] = { + RowOrdering.createNaturalAscendingOrdering(finalExpressions.map(_.dataType)).on(_._1) } - RowOrdering.create(orders, Seq.empty).on(_._1) + Some(finalPartitions.sorted(partitionOrdering)) } - commonValues.sorted(ordering) - } - } - - def projectGroupingKeyIfNecessary(row: InternalRow, partitionExpressions: Seq[Expression], - partitionGroupByPositions: Option[Seq[Boolean]]): InternalRowComparableWrapper = { - if (partitionGroupByPositions.isDefined) { - projectGroupingKey(row, partitionExpressions, partitionGroupByPositions.get) - } else { - InternalRowComparableWrapper(row, partitionExpressions) } } - def projectGroupingKey(row: InternalRow, partitionExpressions: Seq[Expression], - partitionGroupByPositions: Seq[Boolean]): InternalRowComparableWrapper = { - val values = row.toSeq(partitionExpressions.map(_.dataType)) - val filteredValues = project(values, partitionGroupByPositions, "partition values") - val filteredExpressions = 
project(partitionExpressions, - partitionGroupByPositions, "partition expressions") - InternalRowComparableWrapper(InternalRow.fromSeq(filteredValues), filteredExpressions) - } - def project[T](values: Seq[T], positions: Seq[Boolean], desc: String): Seq[T] = { assert(values.size == positions.size, s"partition group-by positions map does not match $desc") @@ -301,6 +191,34 @@ trait DataSourceV2ScanExecBase extends LeafExecNode { } } + def projectRow(row: InternalRow, + projectedExpressions: Seq[Expression], + joinKeyPositions: Option[Seq[Boolean]]): InternalRow = { + if (joinKeyPositions.isDefined) { + val values = row.toSeq(projectedExpressions.map(_.dataType)) + val filteredValues = projectIfNecessary(values, joinKeyPositions, + "partition values") + InternalRow.fromSeq(filteredValues) + } else { + row + } + } + + def partitionRow(row: InternalRow, + partitionExpressions: Seq[Expression], + joinKeyPositions: Option[Seq[Boolean]]): InternalRowComparableWrapper = { + if (joinKeyPositions.isDefined) { + val values = row.toSeq(partitionExpressions.map(_.dataType)) + val filteredValues = projectIfNecessary(values, joinKeyPositions, + "partition values") + val projectedExpressions = projectIfNecessary(partitionExpressions, joinKeyPositions, + "partition expressions") + InternalRowComparableWrapper(InternalRow.fromSeq(filteredValues), projectedExpressions) + } else { + InternalRowComparableWrapper(row, partitionExpressions) + } + } + override def outputOrdering: Seq[SortOrder] = { // when multiple partitions are grouped together, ordering inside partitions is not preserved val partitioningPreservesOrdering = groupedPartitions.forall(_.forall(_._2.length <= 1)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 6791749e0faa6..c7a46efa80f25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -386,7 +386,7 @@ case class EnsureRequirements( logInfo("Pushing common partition values for storage-partitioned join") // if (group by partition key..) 
, check if partition key overlap if (conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { - isCompatible = leftSpec.areClusterPartitionKeysCompatible(rightSpec) + isCompatible = leftSpec.areJoinKeysCompatible(rightSpec) } else { isCompatible = leftSpec.areKeysCompatible(rightSpec) } @@ -508,15 +508,15 @@ case class EnsureRequirements( // only need to check left spec spec, as it should be compatible // with right spec wrt key positions - val partitionGroupPositions = getPartitionGroupPositions(leftSpec.keyPositions, + val joinKeyPositions = getJoinKeyPositions(leftSpec.keyPositions, conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) // Now we need to push-down the common partition key to the scan in each child newLeft = populateStoragePartitionJoinParams( - left, mergedPartValues, partitionGroupPositions, applyPartialClustering, + left, mergedPartValues, joinKeyPositions, applyPartialClustering, replicateLeftSide) newRight = populateStoragePartitionJoinParams( - right, mergedPartValues, partitionGroupPositions, applyPartialClustering, + right, mergedPartValues, joinKeyPositions, applyPartialClustering, replicateRightSide) } } @@ -524,9 +524,9 @@ case class EnsureRequirements( if (isCompatible) Some(Seq(newLeft, newRight)) else None } - // Given keyPositions (join key positions), return a sequence of position of partition keys, + // Given keyPositions, return a sequence of position of partition keys that are join keys, // to group similar partition values before executing storage-partition join - private def getPartitionGroupPositions(keyPositions: Seq[mutable.BitSet], + private def getJoinKeyPositions(keyPositions: Seq[mutable.BitSet], allowJoinKeysSubsetOfPartitionKeys: Boolean): Option[Seq[Boolean]] = { if (allowJoinKeysSubsetOfPartitionKeys) { Some(keyPositions.map(_.nonEmpty)) @@ -550,21 +550,21 @@ case class EnsureRequirements( private def populateStoragePartitionJoinParams( plan: SparkPlan, values: Seq[(InternalRow, Int)], - partitionGroupByPositions: Option[Seq[Boolean]], + joinKeyPositions: Option[Seq[Boolean]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => scan.copy( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), - partitionGroupByPositions = partitionGroupByPositions, + joinKeyPositions = joinKeyPositions, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => node.mapChildren(child => populateStoragePartitionJoinParams( - child, values, partitionGroupByPositions, applyPartialClustering, replicatePartitions)) + child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 3fd5521376f11..88f78d1771a08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1305,6 +1305,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Seq(true, false).foreach { pushDownValues => Seq(true, false).foreach { partiallyClustered => Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString, @@ -1443,6 
+1444,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Seq(true, false).foreach { pushDownValues => Seq(true, false).foreach { partiallyClustered => Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushDownValues.toString,
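
The projection-and-grouping idea behind this patch (see KeyGroupedShuffleSpec.project, partitionRow, and
commonPartitionValuesByJoinKey above) is easiest to see stripped of Spark's InternalRow machinery. Below is a
minimal, self-contained Scala sketch of that idea only: keep the partition-key columns flagged as join keys,
then merge partition values whose projected join-key values coincide and sum their numSplits. The object and
method names are illustrative and do not exist in the patch or in Spark's API.

// Standalone sketch; partition values are modeled as plain Seq[Any] instead of InternalRow.
object JoinKeyProjectionSketch {

  // Keep only the columns whose position corresponds to a join key.
  def project(partValue: Seq[Any], joinKeyPositions: Seq[Boolean]): Seq[Any] = {
    require(partValue.length == joinKeyPositions.length,
      "partition value arity must match the join-key position mask")
    partValue.zip(joinKeyPositions).collect { case (v, true) => v }
  }

  // Group common partition values by their projected join-key values and aggregate numSplits,
  // analogous to what the grouped (non-replicated) side of the join does.
  def groupByJoinKey(
      commonPartValues: Seq[(Seq[Any], Int)],
      joinKeyPositions: Seq[Boolean]): Seq[(Seq[Any], Int)] = {
    commonPartValues
      .groupBy { case (partValue, _) => project(partValue, joinKeyPositions) }
      .map { case (projected, entries) => (projected, entries.map(_._2).sum) }
      .toSeq
  }

  def main(args: Array[String]): Unit = {
    // Partition keys: (store_id, dept_id); join key: only store_id.
    val joinKeyPositions = Seq(true, false)
    val commonPartValues = Seq(
      (Seq(1, 10), 2),  // (store_id = 1, dept_id = 10) carries 2 splits
      (Seq(1, 20), 3),  // same store_id, different dept_id
      (Seq(2, 10), 1))
    // store_id = 1 collapses into one group with 5 splits; store_id = 2 keeps 1 split.
    groupByJoinKey(commonPartValues, joinKeyPositions)
      .sortBy(_._1.mkString)
      .foreach { case (key, numSplits) => println(s"join key $key -> $numSplits splits") }
  }
}

In the real code the projected rows are additionally wrapped in InternalRowComparableWrapper so that grouping
and equality follow the projected partition expressions' data types rather than plain Scala equality.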