[SQL][SPARK-44647] Support SPJ where join keys are less than cluster keys

### What changes were proposed in this pull request?
- Add new conf spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled (see the usage sketch below).
- Change key compatibility checks in EnsureRequirements: remove the check that all partition keys must be in the join keys, so that isKeyCompatible can be true in this case.
- Change BatchScanExec/DataSourceV2Relation to group splits by join keys (previously they were grouped only by partition values).
- Implement partiallyClustered skew handling:
  - Group only the replicate side (now by join key as well).
  - Add a sort over the grouped partitions by join key at the end, because grouping the non-replicate side can leave the partition ordering out of order.
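
As a minimal usage sketch (hypothetical catalog, table, and column names; assumes both tables are bucketed by `(id, dept)` through a DSv2 source that reports KeyGroupedPartitioning):

```scala
// Hypothetical example: t1 and t2 are both partitioned by (id, dept),
// but the join condition uses only id -- a subset of the partition keys.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set(
  "spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled",
  "true")

// With the conf enabled, planning can group input splits by id alone and
// run a storage-partitioned join with no shuffle on either side.
spark.sql(
  """
    |SELECT t1.id, t1.dept, t2.amount
    |FROM testcat.ns.t1 t1 JOIN testcat.ns.t2 t2 ON t1.id = t2.id
  """.stripMargin).explain()
```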

### Why are the changes needed?
- Support Storage Partition Join in cases where the join condition does not contain all the partition keys, but only some of them.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
- Added tests in KeyGroupedPartitioningSuite.
- Found two problems, to be addressed in separate PRs:
  - #37886 made another change so that we have to select all join keys, otherwise the DSv2 scan does not report KeyGroupedPartitioning and SPJ does not get triggered. Need to see how to relax this.
  - https://issues.apache.org/jira/browse/SPARK-44641 was found when testing this change. This PR refactors some of that code to add group-by-join-key, but does not change the underlying logic, so the issue continues to exist. Hopefully it will also get fixed in another way.
szehon-ho committed Aug 24, 2023
1 parent ce6b5f3 commit ea88ab5
Showing 6 changed files with 850 additions and 338 deletions.
@@ -344,7 +344,11 @@ case class KeyGroupedPartitioning(
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))

// Support only when all cluster keys have an associated partition expression key
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
// and all partition expressions contain only a single partition key.
expressions.forall(_.collectLeaves().size == 1)
}

case _ =>
@@ -701,46 +705,83 @@ case class KeyGroupedShuffleSpec(
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
InternalRowComparableWrapper(left, partitioning.expressions)
.equals(InternalRowComparableWrapper(right, partitioning.expressions))
}
isPartitioningCompatible(otherPartitioning)
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ => false
}

def isPartitioningCompatible(otherPartitioning: KeyGroupedPartitioning): Boolean = {
val clusterKeySize = keyPositions.size
partitioning.partitionValues.zip(otherPartitioning.partitionValues)
.forall {
case (left, right) =>
val leftTypes = partitioning.expressions.map(_.dataType)
val leftVals = left.toSeq(leftTypes).take(clusterKeySize).toArray
val newLeft = new GenericInternalRow(leftVals)

val rightTypes = partitioning.expressions.map(_.dataType)
val rightVals = right.toSeq(rightTypes).take(clusterKeySize).toArray
val newRight = new GenericInternalRow(rightVals)

InternalRowComparableWrapper(newLeft, partitioning.expressions.take(clusterKeySize))
.equals(InternalRowComparableWrapper(
newRight, partitioning.expressions.take(clusterKeySize)))
}
}

// Whether the partition keys (i.e., partition expressions) are compatible between this and the
// `other` spec.
def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
val expressions = partitioning.expressions
val otherExpressions = other.partitioning.expressions

expressions.length == otherExpressions.length && {
val otherKeyPositions = other.keyPositions
keyPositions.zip(otherKeyPositions).forall { case (left, right) =>
left.intersect(right).nonEmpty
partitionExpressionsCompatible(other) &&
KeyGroupedShuffleSpec.keyPositionsCompatible(
keyPositions, other.keyPositions
)
}

// Whether the partition keys (i.e., partition expressions) that are also in the set of
// cluster keys are compatible between this and the `other` spec.
def areClusterPartitionKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
partitionExpressionsCompatible(other) &&
KeyGroupedShuffleSpec.keyPositionsCompatible(
keyPositions.filter(_.nonEmpty),
other.keyPositions.filter(_.nonEmpty)
)
}

private def partitionExpressionsCompatible(other: KeyGroupedShuffleSpec): Boolean = {
val left = partitioning.expressions
val right = other.partitioning.expressions
left.length == right.length &&
left.zip(right).forall {
case (l, r) => KeyGroupedShuffleSpec.isExpressionCompatible(l, r)
}
} && expressions.zip(otherExpressions).forall {
case (l, r) => isExpressionCompatible(l, r)
}
}

private def isExpressionCompatible(left: Expression, right: Expression): Boolean =
override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}

object KeyGroupedShuffleSpec {

def isExpressionCompatible(left: Expression, right: Expression): Boolean =
(left, right) match {
case (_: LeafExpression, _: LeafExpression) => true
case (left: TransformExpression, right: TransformExpression) =>
left.isSameFunction(right)
case _ => false
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
def keyPositionsCompatible(left: Seq[mutable.BitSet], right: Seq[mutable.BitSet]): Boolean = {
left.length == right.length &&
left.zip(right).forall { case (left, right) =>
left.intersect(right).nonEmpty
}
}
}
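
For reference, the key-positions check factored out above reduces to this self-contained sketch:

```scala
import scala.collection.mutable

// The two sides are compatible when, position by position, their sets of
// cluster-key indices overlap.
def keyPositionsCompatible(
    left: Seq[mutable.BitSet], right: Seq[mutable.BitSet]): Boolean =
  left.length == right.length &&
    left.zip(right).forall { case (l, r) => l.intersect(r).nonEmpty }

// Expression 0 maps to cluster-key positions {0} vs {0, 2}: they overlap.
// Expression 1 maps to {1} on both sides: they overlap. So: compatible.
assert(keyPositionsCompatible(
  Seq(mutable.BitSet(0), mutable.BitSet(1)),
  Seq(mutable.BitSet(0, 2), mutable.BitSet(1))))
```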

@@ -1500,7 +1500,7 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_SHUFFLE_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.shuffle.enabled")
.doc("During a storage-partitioned join, whether to allow to shuffle only one side." +
"When only one side is KeyGroupedPartitioning, if the conditions are met, spark will " +
@@ -1510,6 +1510,18 @@
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
.doc("Whether to allow storage-partition join in the case where join keys are" +
"a subset of the partition keys of the source tables. At planning time, " +
"Spark will group the partitions by only those keys that are in the join keys." +
"This is currently enabled only if spark.sql.sources.v2.bucketing.pushPartValues.enabled " +
"is also enabled."
)
.version("4.0.0")
.booleanConf
.createWithDefault(true)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4909,8 +4921,13 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)

<<<<<<< HEAD
def v2BucketingShuffleEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)
=======
def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
>>>>>>> c431abb6dd5 ([SQL][SPARK-44647] Support SPJ where join keys are less than cluster keys)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
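
A sketch of how a caller might guard on the new conf (an assumption for illustration; the actual gating lives in the planner, and per the conf doc above pushPartValues must also be enabled):

```scala
import org.apache.spark.sql.internal.SQLConf

// Assumption: SPJ on a join-key subset applies only when partition values
// are also pushed down, as stated in the conf's documentation.
val conf = SQLConf.get
val allowJoinKeySubset =
  conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys &&
    conf.v2BucketingPushPartValuesEnabled
```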
@@ -58,6 +58,22 @@ case class BatchScanExec(

@transient override lazy val inputPartitions: Seq[InputPartition] = batch.planInputPartitions()

@transient override lazy val groupedPartitions: Option[Seq[(InternalRow, Seq[InputPartition])]] =
// Early check if we actually need to materialize the input partitions.
keyGroupedPartitioning match {
case Some(_) => groupPartitions(inputPartitions, spjParams.partitionGroupByPositions)
case _ => None
}

@transient lazy val groupedCommonPartValues: Option[Seq[(InternalRow, Int)]] = {
if (spjParams.replicatePartitions) {
groupCommonPartitionsByJoinKey(
spjParams.commonPartitionValues, spjParams.partitionGroupByPositions)
} else {
spjParams.commonPartitionValues
}
}

@transient private lazy val filteredPartitions: Seq[Seq[InputPartition]] = {
val dataSourceFilters = runtimeFilters.flatMap {
case DynamicPruningExpression(e) => DataSourceV2Strategy.translateRuntimeFilterV2(e)
@@ -101,7 +117,7 @@ case class BatchScanExec(
"partition values that are not present in the original partitioning.")
}

groupPartitions(newPartitions).get.map(_._2)
groupPartitions(newPartitions, spjParams.partitionGroupByPositions).get.map(_._2)

case _ =>
// no validation is needed as the data source did not report any specific partitioning
@@ -133,9 +149,8 @@
// return an empty RDD with 1 partition if dynamic filtering removed the only split
sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
var finalPartitions = filteredPartitions

outputPartitioning match {
val finalPartitions = outputPartitioning match {
case p: KeyGroupedPartitioning =>
if (conf.v2BucketingPushPartValuesEnabled &&
conf.v2BucketingPartiallyClusteredDistributionEnabled) {
@@ -144,25 +159,42 @@
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"is enabled")

val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
groupSplits = true).get
// In the case where we replicate partitions, we have grouped
// the partitions by the join key if they differ
val groupByExpressions =
if (spjParams.replicatePartitions && spjParams.partitionGroupByPositions.isDefined) {
project(p.expressions, spjParams.partitionGroupByPositions.get,
"partition expressions")
} else {
p.expressions
}

// In the case where we replicate partitions, we need to also group
// the partitions by the join key if they differ
val groupedPartitions = if (spjParams.replicatePartitions) {
groupPartitions(filteredPartitions.map(_.head),
spjParams.partitionGroupByPositions,
groupSplits = true).get
} else {
groupPartitions(filteredPartitions.map(_.head), groupSplits = true).get
}

// This means the input partitions are not grouped by partition values. We'll need to
// check `groupByPartitionValues` and decide whether to group and replicate splits
// within a partition.
if (spjParams.commonPartitionValues.isDefined &&
spjParams.applyPartialClustering) {
// A mapping from the common partition values to how many splits the partition
// should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
// should contain. Note this no longer maintains the partition key ordering.
val commonPartValuesMap = groupedCommonPartValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
.map(t => (InternalRowComparableWrapper(t._1, groupByExpressions), t._2))
.toMap
val nestGroupedPartitions = groupedPartitions.map {
case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the superset.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, p.expressions))
.get(InternalRowComparableWrapper(partValue, groupByExpressions))
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

@@ -177,48 +209,50 @@
// sides of a join will have the same number of partitions & splits.
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
}
(InternalRowComparableWrapper(partValue, p.expressions), newSplits)
(InternalRowComparableWrapper(partValue, groupByExpressions), newSplits)
}

// In the case that there are fewer join keys than partition keys, we need to sort
// the common partition values by join key
val sortedCommonPositionByJoinKey =
if (spjParams.partitionGroupByPositions.isDefined) {
sortCommonPartitionsByJoinKey(
groupedCommonPartValues,
p.expressions,
spjParams.partitionGroupByPositions.get,
spjParams.replicatePartitions)
} else {
groupedCommonPartValues
}

// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
finalPartitions = spjParams.commonPartitionValues.get.flatMap {
sortedCommonPositionByJoinKey.get.flatMap {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions),
InternalRowComparableWrapper(partValue, groupByExpressions),
Seq.fill(numSplits)(Seq.empty))
}
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = groupedPartitions.map { case (row, parts) =>
InternalRowComparableWrapper(row, p.expressions) -> parts
}.toMap

// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
// could exist duplicated partition values, as partition grouping is not done
// at the beginning and postponed to this method. It is important to use unique
// partition values here so that grouped partitions won't get duplicated.
finalPartitions = p.uniquePartitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
}
fillEmptyInputPartitions(
groupedPartitions, p.uniquePartitionValues, p.expressions)
}
} else {
val partitionMapping = finalPartitions.map { parts =>
val row = parts.head.asInstanceOf[HasPartitionKey].partitionKey()
InternalRowComparableWrapper(row, p.expressions) -> parts
}.toMap
finalPartitions = p.partitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
}
val partitionMapping = filteredPartitions.map(p =>
(p.head.asInstanceOf[HasPartitionKey].partitionKey(), p))
fillEmptyInputPartitions(partitionMapping,
p.partitionValues, p.expressions)
}

case _ =>
case _ => filteredPartitions
}

new DataSourceRDD(
@@ -228,6 +262,45 @@
rdd
}
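
Taken in isolation, the padding applied on the partially-clustered path above behaves like this plain-Scala sketch:

```scala
// Each split of a grouped partition becomes its own sub-partition, then the
// list is padded with empty ones up to the common split count so both join
// sides end up with the same number of partitions.
val splits = Seq("split-1", "split-2")
val numSplits = 4
val padded = splits.map(Seq(_)).padTo(numSplits, Seq.empty[String])
assert(padded == Seq(Seq("split-1"), Seq("split-2"), Seq(), Seq()))
```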

/**
 * Fill in empty partitions for expected partition values that are missing on this side,
 * so that it matches the other side of the join.
 *
 * @param partitions mapping of partition values to input partitions grouped so far
 * @param expectedPartValues expected partition values
 * @param partExpressions partition expressions
 * @return partition lists ordered by the expected partition values; values that exist
 *         in expectedPartValues but not in the original mapping are filled with empty seqs.
 */
def fillEmptyInputPartitions(partitions: Seq[(InternalRow, Seq[InputPartition])],
expectedPartValues: Seq[InternalRow],
partExpressions: Seq[Expression]): Seq[Seq[InputPartition]] = {

// Handle case where we have fewer join keys than partition keys
// by grouping partition values by join key
val groupedExpectedPartValues = groupPartitionsByJoinKey(
expectedPartValues, partExpressions, spjParams.partitionGroupByPositions)

val projectedExpression = if (spjParams.partitionGroupByPositions.isDefined) {
project(partExpressions, spjParams.partitionGroupByPositions.get, "partition expressions")
} else {
partExpressions
}

val partitionMapping = partitions.map { case (row, parts) =>
projectGroupingKeyIfNecessary(row,
partExpressions,
spjParams.partitionGroupByPositions) -> parts
}.toMap

groupedExpectedPartValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, projectedExpression),
Seq.empty)
}
}

override def keyGroupedPartitioning: Option[Seq[Expression]] =
spjParams.keyGroupedPartitioning

@@ -254,11 +327,13 @@
case class StoragePartitionJoinParams(
keyGroupedPartitioning: Option[Seq[Expression]] = None,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
partitionGroupByPositions: Option[Seq[Boolean]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false) {
override def equals(other: Any): Boolean = other match {
case other: StoragePartitionJoinParams =>
this.commonPartitionValues == other.commonPartitionValues &&
this.partitionGroupByPositions == other.partitionGroupByPositions &&
this.replicatePartitions == other.replicatePartitions &&
this.applyPartialClustering == other.applyPartialClustering
case _ =>
Expand All @@ -267,6 +342,7 @@ case class StoragePartitionJoinParams(

override def hashCode(): Int = Objects.hashCode(
commonPartitionValues: Option[Seq[(InternalRow, Int)]],
partitionGroupByPositions: Option[Seq[Boolean]],
applyPartialClustering: java.lang.Boolean,
replicatePartitions: java.lang.Boolean)
}
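
A simplified model of `fillEmptyInputPartitions` defined above (hypothetical element types, for illustration only):

```scala
// Every expected partition value gets a slot; values this side has no data
// for are filled with an empty split list.
def fillMissing[K, P](grouped: Map[K, Seq[P]], expected: Seq[K]): Seq[Seq[P]] =
  expected.map(k => grouped.getOrElse(k, Seq.empty))

// This side has splits only for keys 1 and 3; the common side expects 1, 2, 3.
assert(fillMissing(Map(1 -> Seq("a"), 3 -> Seq("c")), Seq(1, 2, 3)) ==
  Seq(Seq("a"), Seq.empty, Seq("c")))
```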