[SPARK-48949][SQL] SPJ: Runtime partition filtering
    ### What changes were proposed in this pull request?

    Introduce runtime partition filtering for SPJ (storage-partitioned join). During planning we already have the list of partition values from both sides in order to plan the tasks, so we can filter out partition values that cannot produce matches, based on the join type.
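A minimal conceptual sketch of that rule (hypothetical helper, not the commit's code; the actual logic lives in `EnsureRequirements.mergePartitions` further down):

```scala
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}

// Given the partition keys seen on each side, keep only the keys that can
// still produce output for the given join type.
def filterByJoinType[T](left: Set[T], right: Set[T], joinType: JoinType): Set[T] =
  joinType match {
    case Inner      => left.intersect(right) // both sides must contain the partition
    case LeftOuter  => left                  // every left partition is preserved
    case RightOuter => right                 // every right partition is preserved
    case _          => left.union(right)     // fall back to scanning everything
  }
```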

    ### Why are the changes needed?

    In some common join types (INNER, LEFT OUTER, RIGHT OUTER), there is an opportunity to greatly reduce the data scanned in SPJ. For example, a small table joined with a larger table on the partition key can prune out most of the partitions of the large table.

    This is similar in spirit to dynamic partition pruning (DPP), but DPP relies on heuristics, whereas this filtering is exact: SPJ planning already requires listing the partition values of both sides.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    New tests in KeyGroupedPartitioningSuite.
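For illustration, enabling the feature might look like the following (a hedged sketch, not taken from the suite; the table names are made up, and the two prerequisite config keys are quoted from memory — only `spark.sql.sources.v2.bucketing.partition.filter.enabled` is introduced by this commit):

```scala
// Prerequisites: V2 bucketing and pushed partition values must be on
// (key names assumed from the existing confs referenced in SQLConf below).
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
// New in this commit: allow SPJ to drop partitions without matches.
spark.conf.set("spark.sql.sources.v2.bucketing.partition.filter.enabled", "true")

// With both tables partitioned on part_key, the INNER join lets the large
// table skip every partition whose key never appears in the small table.
val pruned = spark.table("cat.db.small_dim")
  .join(spark.table("cat.db.large_fact"), Seq("part_key"), "inner")
```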
szehon-ho committed Jul 19, 2024
1 parent ebf8da1 commit 7fd0d08
Showing 6 changed files with 368 additions and 109 deletions.
@@ -21,7 +21,6 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.NonFateSharingCache
@@ -85,22 +84,25 @@ object InternalRowComparableWrapper {
}

def mergePartitions(
leftPartitioning: KeyGroupedPartitioning,
rightPartitioning: KeyGroupedPartitioning,
partitionExpression: Seq[Expression]): Seq[InternalRow] = {
leftPartitioning: Seq[InternalRow],
rightPartitioning: Seq[InternalRow],
partitionExpression: Seq[Expression],
intersect: Boolean = false): Seq[InternalRowComparableWrapper] = {
val partitionDataTypes = partitionExpression.map(_.dataType)
val partitionsSet = new mutable.HashSet[InternalRowComparableWrapper]
leftPartitioning.partitionValues
val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
leftPartitioning
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.foreach(partition => partitionsSet.add(partition))
rightPartitioning.partitionValues
.foreach(partition => leftPartitionSet.add(partition))
val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
rightPartitioning
.map(new InternalRowComparableWrapper(_, partitionDataTypes))
.foreach(partition => partitionsSet.add(partition))
// SPARK-41471: We keep the order of partitions to make sure the order of
// partitions is deterministic in different cases.
val partitionOrdering: Ordering[InternalRow] = {
RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
.foreach(partition => rightPartitionSet.add(partition))

val result = if (intersect) {
leftPartitionSet.intersect(rightPartitionSet)
} else {
leftPartitionSet.union(rightPartitionSet)
}
partitionsSet.map(_.row).toSeq.sorted(partitionOrdering)
result.toSeq
}
}
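For context, a usage sketch of the reworked `mergePartitions` above (the import path, the attribute, and the literal rows are assumptions for illustration):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.types.IntegerType

val partExprs   = Seq(AttributeReference("pk", IntegerType)())
val leftValues  = Seq(InternalRow(1), InternalRow(2), InternalRow(3))
val rightValues = Seq(InternalRow(2), InternalRow(3), InternalRow(4))

// intersect = true keeps only partition values present on both sides, which is
// what the inner-join case of runtime partition filtering needs.
val common = InternalRowComparableWrapper
  .mergePartitions(leftValues, rightValues, partExprs, intersect = true)
  .map(_.row) // unordered; callers sort with RowOrdering where determinism matters
```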
@@ -1635,6 +1635,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_PARTITION_FILTER_ENABLED =
buildConf("spark.sql.sources.v2.bucketing.partition.filter.enabled")
.doc(s"Whether to filter partitions when running storage-partition join. " +
s"When enabled, partitions without matches on the other side can be omitted for " +
s"scanning, if allowed by the join type. This config requires both " +
s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
s"enabled.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -61,7 +61,7 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
val merged = InternalRowComparableWrapper.mergePartitions(
leftPartitioning, rightPartitioning, expressions)
leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions)
assert(merged.size == bucketNum)
}

@@ -200,7 +200,12 @@ case class BatchScanExec(
.get
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
.toMap
val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
val filteredGroupedPartitions = finalGroupedPartitions.filter {
case (partValues, _) =>
commonPartValuesMap.keySet.contains(
InternalRowComparableWrapper(partValues, partExpressions))
}
val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, partExpressions))
@@ -429,8 +429,19 @@ case class EnsureRequirements(
// expressions
val partitionExprs = leftSpec.partitioning.expressions

var mergedPartValues = InternalRowComparableWrapper
.mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs)
// in case of compatible but not identical partition expressions, we apply 'reduce'
// transforms to group one side's partitions as well as the common partition values
val leftReducers = leftSpec.reducers(rightSpec)
val leftParts = reducePartValues(leftSpec.partitioning.partitionValues,
partitionExprs,
leftReducers)
val rightReducers = rightSpec.reducers(leftSpec)
val rightParts = reducePartValues(rightSpec.partitioning.partitionValues,
partitionExprs,
rightReducers)

// merge values on both sides
var mergedPartValues = mergePartitions(leftParts, rightParts, partitionExprs, joinType)
.map(v => (v, 1))

logInfo(log"After merging, there are " +
@@ -525,23 +536,6 @@ }
}
}

// in case of compatible but not identical partition expressions, we apply 'reduce'
// transforms to group one side's partitions as well as the common partition values
val leftReducers = leftSpec.reducers(rightSpec)
val rightReducers = rightSpec.reducers(leftSpec)

if (leftReducers.isDefined || rightReducers.isDefined) {
mergedPartValues = reduceCommonPartValues(mergedPartValues,
leftSpec.partitioning.expressions,
leftReducers)
mergedPartValues = reduceCommonPartValues(mergedPartValues,
rightSpec.partitioning.expressions,
rightReducers)
val rowOrdering = RowOrdering
.createNaturalAscendingOrdering(partitionExprs.map(_.dataType))
mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
}

// Now we need to push-down the common partition information to the scan in each child
newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions,
leftReducers, applyPartialClustering, replicateLeftSide)
@@ -602,15 +596,15 @@ case class EnsureRequirements(
child, joinKeyPositions))
}

private def reduceCommonPartValues(
commonPartValues: Seq[(InternalRow, Int)],
private def reducePartValues(
partValues: Seq[InternalRow],
expressions: Seq[Expression],
reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
reducers match {
case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
case Some(reducers) => partValues.map { row =>
KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
}.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq
case _ => commonPartValues
}.distinct.map(_.row)
case _ => partValues
}
}

@@ -651,6 +645,46 @@ case class EnsureRequirements(
}
}

/**
* Merge and sort partition values for SPJ, and optionally apply partition filtering.
* Both sides must have matching partition expressions.
* @param leftPartitioning left side partition values
* @param rightPartitioning right side partition values
* @param partitionExpression partition expressions
* @param joinType join type for optional partition filtering
* @return merged and sorted partition values
*/
private def mergePartitions(
leftPartitioning: Seq[InternalRow],
rightPartitioning: Seq[InternalRow],
partitionExpression: Seq[Expression],
joinType: JoinType): Seq[InternalRow] = {

val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
joinType match {
case Inner => InternalRowComparableWrapper.mergePartitions(
leftPartitioning, rightPartitioning, partitionExpression, intersect = true)
case LeftOuter => leftPartitioning.map(
InternalRowComparableWrapper(_, partitionExpression))
case RightOuter => rightPartitioning.map(
InternalRowComparableWrapper(_, partitionExpression))
case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning,
rightPartitioning, partitionExpression)
}
} else {
InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning,
partitionExpression)
}

// SPARK-41471: We keep the order of partitions to make sure the order of
// partitions is deterministic in different cases.
val partitionOrdering: Ordering[InternalRow] = {
RowOrdering.createNaturalAscendingOrdering(partitionExpression.map(_.dataType))
}
merged.map(_.row).sorted(partitionOrdering)
}

def apply(plan: SparkPlan): SparkPlan = {
val newPlan = plan.transformUp {
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _)
