Skip to content

Commit

Permalink
Review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
szehon-ho committed Aug 24, 2023
1 parent ea88ab5 commit 8319e43
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 229 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,13 @@ case class KeyGroupedPartitioning(
// We'll need to find leaf attributes from the partition expressions first.
val attributes = expressions.flatMap(_.collectLeaves())

// Supported only when at least one cluster key has an associated partition expression key
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
// and only if every partition expression contains exactly one partition key.
expressions.forall(_.collectLeaves().size == 1)
if (SQLConf.get.getConf(
SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)) {
requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
expressions.forall(_.collectLeaves().size == 1)
} else {
attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
}

case _ =>
Expand Down Expand Up @@ -705,33 +708,29 @@ case class KeyGroupedShuffleSpec(
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution) =>
distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
isPartitioningCompatible(otherPartitioning)
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
InternalRowComparableWrapper(left, partitioning.expressions)
.equals(InternalRowComparableWrapper(right, partitioning.expressions))
}
case ShuffleSpecCollection(specs) =>
specs.exists(isCompatibleWith)
case _ => false
}

def isPartitioningCompatible(otherPartitioning: KeyGroupedPartitioning): Boolean = {
val clusterKeySize = keyPositions.size
val joinKeyPositions = keyPositions.map(_.nonEmpty)
partitioning.partitionValues.zip(otherPartitioning.partitionValues)
.forall {
case (left, right) =>
val leftTypes = partitioning.expressions.map(_.dataType)
val leftVals = left.toSeq(leftTypes).take(clusterKeySize).toArray
val newLeft = new GenericInternalRow(leftVals)

val rightTypes = partitioning.expressions.map(_.dataType)
val rightVals = right.toSeq(rightTypes).take(clusterKeySize).toArray
val newRight = new GenericInternalRow(rightVals)

InternalRowComparableWrapper(newLeft, partitioning.expressions.take(clusterKeySize))
.equals(InternalRowComparableWrapper(
newRight, partitioning.expressions.take(clusterKeySize)))
KeyGroupedShuffleSpec.project(left, partitioning.expressions, joinKeyPositions)
.equals(
KeyGroupedShuffleSpec.project(right, partitioning.expressions, joinKeyPositions))
}
}

// Whether the partition keys (i.e., partition expressions) are compatible between this and the
// `other` spec.
// other spec.
def areKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
partitionExpressionsCompatible(other) &&
KeyGroupedShuffleSpec.keyPositionsCompatible(
Expand All @@ -740,15 +739,23 @@ case class KeyGroupedShuffleSpec(
}

// Whether the partition keys (i.e., partition expressions) that also are in the set of
// cluster keys are compatible between this and the 'other' spec.
def areClusterPartitionKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
// join keys are compatible between this and the other spec.
def areJoinKeysCompatible(other: KeyGroupedShuffleSpec): Boolean = {
partitionExpressionsCompatible(other) &&
KeyGroupedShuffleSpec.keyPositionsCompatible(
keyPositions.filter(_.nonEmpty),
other.keyPositions.filter(_.nonEmpty)
)
}

/**
 * Whether a partitioning can be created from this spec.
 *
 * Requires the V2 bucketing shuffle flag to be enabled in the active [[SQLConf]], and every
 * partition expression to be a plain [[AttributeReference]] (the only form supported for now).
 */
override def canCreatePartitioning: Boolean = {
  val shuffleEnabled = SQLConf.get.v2BucketingShuffleEnabled
  // The expression check is only evaluated when the flag is on.
  shuffleEnabled && partitioning.expressions.forall {
    case _: AttributeReference => true
    case _ => false
  }
}

/**
 * Builds a [[KeyGroupedPartitioning]] over the given clustering expressions, reusing the
 * partition count and partition values of this spec's own partitioning.
 *
 * @param clustering expressions to use as the keys of the new partitioning
 * @return the newly constructed partitioning
 */
override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
  val partitionCount = partitioning.numPartitions
  val values = partitioning.partitionValues
  KeyGroupedPartitioning(clustering, partitionCount, values)
}

private def partitionExpressionsCompatible(other: KeyGroupedShuffleSpec): Boolean = {
val left = partitioning.expressions
val right = other.partitioning.expressions
Expand All @@ -757,14 +764,6 @@ case class KeyGroupedShuffleSpec(
case (l, r) => KeyGroupedShuffleSpec.isExpressionCompatible(l, r)
}
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}
}

object KeyGroupedShuffleSpec {
Expand All @@ -783,6 +782,19 @@ object KeyGroupedShuffleSpec {
left.intersect(right).nonEmpty
}
}

/**
 * Projects `row` onto the positions flagged in `joinKeyPositions` and wraps the result so it
 * can be compared by value.
 *
 * Expressions and row values are each paired with `joinKeyPositions` by index; only entries
 * whose flag is `true` are kept, in their original order.
 *
 * @param row              the row to project
 * @param expressions      expressions describing the row's fields; their data types are used
 *                         to read the values out of `row`
 * @param joinKeyPositions one flag per position, `true` where the position is kept
 * @return a wrapper over the projected row together with the kept expressions
 */
def project(
    row: InternalRow,
    expressions: Seq[Expression],
    joinKeyPositions: Seq[Boolean]): InternalRowComparableWrapper = {
  val keptExprs = expressions.zip(joinKeyPositions).collect { case (expr, true) => expr }
  val allValues = row.toSeq(expressions.map(_.dataType))
  val keptValues = allValues.zip(joinKeyPositions).collect { case (value, true) => value }
  InternalRowComparableWrapper(InternalRow.fromSeq(keptValues), keptExprs)
}
}

case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1520,7 +1520,7 @@ object SQLConf {
)
.version("4.0.0")
.booleanConf
.createWithDefault(true)
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
Expand Down Expand Up @@ -4921,13 +4921,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingPartiallyClusteredDistributionEnabled: Boolean =
getConf(SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED)

/** Current value of the [[SQLConf.V2_BUCKETING_SHUFFLE_ENABLED]] entry. */
def v2BucketingShuffleEnabled: Boolean =
  getConf(SQLConf.V2_BUCKETING_SHUFFLE_ENABLED)

/** Current value of the [[SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS]] entry. */
def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
  getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
Expand Down
Loading

0 comments on commit 8319e43

Please sign in to comment.