[SPARK-44647][SQL] Support SPJ where join keys are less than cluster keys #42306

Closed · wants to merge 5 commits
@@ -355,7 +355,14 @@ case class KeyGroupedPartitioning(
} else {
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))

if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// check that all join keys (required clustering keys) are contained in the partitioning
requiredClustering.forall(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
Member:
this should be guaranteed currently - it might be better to have this invariant check somewhere else like when constructing a KeyGroupedPartitioning, but OK to leave it here for now
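
For illustration, the invariant the reviewer refers to could be enforced at construction time roughly like this (a sketch only; the placement inside the KeyGroupedPartitioning class body is hypothetical and not part of this diff):

  require(expressions.forall(_.collectLeaves().size == 1),
    "Each partition expression of KeyGroupedPartitioning must reference exactly one leaf attribute")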

} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
}

case _ =>
@@ -364,8 +371,20 @@
}
}

override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec =
KeyGroupedShuffleSpec(this, distribution)
override def createShuffleSpec(distribution: ClusteredDistribution): ShuffleSpec = {
val result = KeyGroupedShuffleSpec(this, distribution)
if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
// If allowing join keys to be subset of clustering keys, we should create a new
// `KeyGroupedPartitioning` here that is grouped on the join keys instead, and use that as
// the returned shuffle spec.
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
partitionValues, originalPartitionValues)
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
} else {
result
}
}
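
To illustrate the projection above with plain Scala collections (hypothetical cluster keys, and Set[Int] standing in for the bit sets held by keyPositions): with cluster keys (a, b, c) and a join on (a, c), only the non-empty entries of keyPositions mark cluster keys that participate in the join.

  // keyPositions: for each partition expression, the positions of matching join keys.
  val keyPositions: Seq[Set[Int]] = Seq(Set(0), Set.empty, Set(1))
  val joinKeyPositions = keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
  // joinKeyPositions == Seq(0, 2): the returned shuffle spec is re-grouped on (a, c) only.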

lazy val uniquePartitionValues: Seq[InternalRow] = {
partitionValues
@@ -378,8 +397,25 @@
object KeyGroupedPartitioning {
def apply(
expressions: Seq[Expression],
partitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
KeyGroupedPartitioning(expressions, partitionValues.size, partitionValues, partitionValues)
projectionPositions: Seq[Int],
partitionValues: Seq[InternalRow],
originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
val projectedExpressions = projectionPositions.map(expressions(_))
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
originalPartitionValues.map(project(expressions, projectionPositions, _))

KeyGroupedPartitioning(projectedExpressions, projectedPartitionValues.length,
projectedPartitionValues, projectedOriginalPartitionValues)
}

def project(
expressions: Seq[Expression],
positions: Seq[Int],
input: InternalRow): InternalRow = {
val projectedValues: Array[Any] = positions.map(i => input.get(i, expressions(i).dataType))
.toArray
new GenericInternalRow(projectedValues)
}
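
A short usage sketch of project (the attributes and values are hypothetical; the types are existing catalyst classes):

  import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow}
  import org.apache.spark.sql.types.{IntegerType, StringType}
  import org.apache.spark.unsafe.types.UTF8String

  // Cluster keys (a: int, b: string) and a partition value (1, "x").
  val a = AttributeReference("a", IntegerType)()
  val b = AttributeReference("b", StringType)()
  val partValue = new GenericInternalRow(Array[Any](1, UTF8String.fromString("x")))
  // Projecting onto position 1 keeps only the value of `b`.
  val projected = KeyGroupedPartitioning.project(Seq(a, b), Seq(1), partValue)
  // projected.getUTF8String(0).toString == "x"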

def supportsExpressions(expressions: Seq[Expression]): Boolean = {
@@ -672,9 +708,18 @@ case class HashShuffleSpec(
override def numPartitions: Int = partitioning.numPartitions
}

/**
* [[ShuffleSpec]] created by [[KeyGroupedPartitioning]].
*
* @param partitioning key grouped partitioning
* @param distribution distribution
* @param joinKeyPositions positions of join keys among cluster keys.
* This is set if joining on a subset of cluster keys is allowed.
*/
case class KeyGroupedShuffleSpec(
partitioning: KeyGroupedPartitioning,
distribution: ClusteredDistribution) extends ShuffleSpec {
distribution: ClusteredDistribution,
joinKeyPositions: Option[Seq[Int]] = None) extends ShuffleSpec {
Member:
we can add some comments for KeyGroupedShuffleSpec to explain what this is for, otherwise it's a bit hard to understand.

Contributor (Author):
Added comments, please check and suggest if it can be improved.


/**
* A sequence where each element is a set of positions of the partition expression to the cluster
@@ -709,7 +754,7 @@ case class KeyGroupedShuffleSpec(
// 3.3 each pair of partition expressions at the same index must share compatible
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
@@ -1530,6 +1530,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS =
buildConf("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled")
.doc("Whether to allow storage-partition join in the case where join keys are" +
"a subset of the partition keys of the source tables. At planning time, " +
"Spark will group the partitions by only those keys that are in the join keys." +
s"This is currently enabled only if ${REQUIRE_ALL_CLUSTER_KEYS_FOR_DISTRIBUTION.key} " +
"is false."
)
.version("4.0.0")
.booleanConf
.createWithDefault(false)
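
For reference, a minimal sketch of turning the new behavior on in a session (only the key added in this diff is shown; SPJ itself still needs the existing v2 bucketing flags, and `spark` is assumed to be an existing SparkSession):

  spark.conf.set("spark.sql.sources.v2.bucketing.allowJoinKeysSubsetOfPartitionKeys.enabled", "true")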

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -4936,6 +4948,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingShuffleEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)

@@ -120,7 +120,12 @@ case class BatchScanExec(
val newPartValues = spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
}
k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
val expressions = spjParams.joinKeyPositions match {
case Some(projectionPositions) => projectionPositions.map(i => k.expressions(i))
case _ => k.expressions
}
k.copy(expressions = expressions, numPartitions = newPartValues.length,
partitionValues = newPartValues)
case p => p
}
}
@@ -132,14 +137,29 @@
// return an empty RDD with 1 partition if dynamic filtering removed the only split
sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
var finalPartitions = filteredPartitions

outputPartitioning match {
val finalPartitions = outputPartitioning match {
case p: KeyGroupedPartitioning =>
val groupedPartitions = filteredPartitions.map(splits => {
assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
(splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
})
assert(spjParams.keyGroupedPartitioning.isDefined)
val expressions = spjParams.keyGroupedPartitioning.get

// Re-group the input partitions if we are projecting on a subset of join keys
val (groupedPartitions, partExpressions) = spjParams.joinKeyPositions match {
case Some(projectPositions) =>
val projectedExpressions = projectPositions.map(i => expressions(i))
val parts = filteredPartitions.flatten.groupBy(part => {
val row = part.asInstanceOf[HasPartitionKey].partitionKey()
val projectedRow = KeyGroupedPartitioning.project(
expressions, projectPositions, row)
InternalRowComparableWrapper(projectedRow, projectedExpressions)
}).map { case (wrapper, splits) => (wrapper.row, splits) }.toSeq
(parts, projectedExpressions)
case _ =>
val groupedParts = filteredPartitions.map(splits => {
assert(splits.nonEmpty && splits.head.isInstanceOf[HasPartitionKey])
(splits.head.asInstanceOf[HasPartitionKey].partitionKey(), splits)
})
(groupedParts, expressions)
}
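
The re-grouping above can be pictured with plain collections (hypothetical partition keys, not Spark types): partitions keyed by (a, b), joined only on a (projection position 0), collapse into one group per distinct a value.

  val partKeys = Seq((1, 1), (1, 2), (2, 1))
  val regrouped = partKeys.groupBy { case (a, _) => a }
  // regrouped == Map(1 -> Seq((1, 1), (1, 2)), 2 -> Seq((2, 1))): two grouped partitions instead of three.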

// When partially clustered, the input partitions are not grouped by partition
// values. Here we'll need to check `commonPartitionValues` and decide how to group
Expand All @@ -149,12 +169,12 @@ case class BatchScanExec(
// should contain.
val commonPartValuesMap = spjParams.commonPartitionValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
.toMap
val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, p.expressions))
.get(InternalRowComparableWrapper(partValue, partExpressions))
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

@@ -169,37 +189,37 @@
// sides of a join will have the same number of partitions & splits.
splits.map(Seq(_)).padTo(numSplits.get, Seq.empty)
}
(InternalRowComparableWrapper(partValue, p.expressions), newSplits)
(InternalRowComparableWrapper(partValue, partExpressions), newSplits)
}
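
The padding in this partially-clustered branch behaves like the following plain-Scala sketch (hypothetical split names and a common split count of 4):

  val splits = Seq("s1", "s2")
  val padded = splits.map(Seq(_)).padTo(4, Seq.empty[String])
  // padded == Seq(Seq("s1"), Seq("s2"), Seq(), Seq()): both join sides end up with the
  // same number of (possibly empty) partitions.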

// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
finalPartitions = spjParams.commonPartitionValues.get.flatMap {
spjParams.commonPartitionValues.get.flatMap {
case (partValue, numSplits) =>
// Use empty partition for those partition values that are not present.
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions),
InternalRowComparableWrapper(partValue, partExpressions),
Seq.fill(numSplits)(Seq.empty))
}
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, p.expressions) -> splits
InternalRowComparableWrapper(partValue, partExpressions) -> splits
}.toMap

// In case `commonPartitionValues` is not defined (e.g., SPJ is not used), there
// could exist duplicated partition values, as partition grouping is not done
// at the beginning and postponed to this method. It is important to use unique
// partition values here so that grouped partitions won't get duplicated.
finalPartitions = p.uniquePartitionValues.map { partValue =>
p.uniquePartitionValues.map { partValue =>
// Use empty partition for those partition values that are not present
partitionMapping.getOrElse(
InternalRowComparableWrapper(partValue, p.expressions), Seq.empty)
InternalRowComparableWrapper(partValue, partExpressions), Seq.empty)
}
}

case _ =>
case _ => filteredPartitions
}

new DataSourceRDD(
@@ -234,6 +254,7 @@ case class BatchScanExec(

case class StoragePartitionJoinParams(
keyGroupedPartitioning: Option[Seq[Expression]] = None,
joinKeyPositions: Option[Seq[Int]] = None,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false) {
@@ -247,6 +268,7 @@ case class StoragePartitionJoinParams(
}

override def hashCode(): Int = Objects.hashCode(
joinKeyPositions: Option[Seq[Int]],
commonPartitionValues: Option[Seq[(InternalRow, Int)]],
applyPartialClustering: java.lang.Boolean,
replicatePartitions: java.lang.Boolean)
@@ -380,7 +380,8 @@ case class EnsureRequirements(
val rightSpec = specs(1)

var isCompatible = false
if (!conf.v2BucketingPushPartValuesEnabled) {
if (!conf.v2BucketingPushPartValuesEnabled &&
!conf.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
isCompatible = leftSpec.isCompatibleWith(rightSpec)
} else {
logInfo("Pushing common partition values for storage-partitioned join")
@@ -505,10 +506,10 @@
}

// Now we need to push-down the common partition key to the scan in each child
newLeft = populatePartitionValues(
left, mergedPartValues, applyPartialClustering, replicateLeftSide)
newRight = populatePartitionValues(
right, mergedPartValues, applyPartialClustering, replicateRightSide)
newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions,
applyPartialClustering, replicateLeftSide)
newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions,
applyPartialClustering, replicateRightSide)
}
}

@@ -530,19 +531,21 @@
private def populatePartitionValues(
plan: SparkPlan,
values: Seq[(InternalRow, Int)],
joinKeyPositions: Option[Seq[Int]],
applyPartialClustering: Boolean,
replicatePartitions: Boolean): SparkPlan = plan match {
case scan: BatchScanExec =>
scan.copy(
spjParams = scan.spjParams.copy(
commonPartitionValues = Some(values),
joinKeyPositions = joinKeyPositions,
applyPartialClustering = applyPartialClustering,
replicatePartitions = replicatePartitions
)
)
case node =>
node.mapChildren(child => populatePartitionValues(
child, values, applyPartialClustering, replicatePartitions))
child, values, joinKeyPositions, applyPartialClustering, replicatePartitions))
}

/**