Odd even test spj uneven buckets #3

Open · wants to merge 8 commits into base: spj-uneven-buckets
@@ -0,0 +1,16 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
* A 'reducer' for output of user-defined functions.
*
* A user-defined function f_source(x) is 'reducible' on another user-defined function f_target(x)
* if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x.
* @param <T> function output type
* @since 4.0.0
*/
@Evolving
public interface Reducer<T> {
T reduce(T arg1);
}
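For illustration, a minimal sketch of what a catalog could supply as a Reducer; the name BucketReducer and the modulo logic are assumptions for this example, under the premise that the bucket function assigns ids as hash(x) % numBuckets:

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Hypothetical reducer: maps a bucket id computed with sourceBuckets buckets to the id
// the same key would get with targetBuckets buckets. Valid when targetBuckets divides
// sourceBuckets (e.g. 8 -> 4), since (hash % 8) % 4 == hash % 4.
case class BucketReducer(sourceBuckets: Int, targetBuckets: Int) extends Reducer[Int] {
  require(sourceBuckets % targetBuckets == 0, "targetBuckets must divide sourceBuckets")
  override def reduce(bucketId: Int): Int = bucketId % targetBuckets
}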
@@ -0,0 +1,25 @@
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;
import scala.Option;

/**
* Base class for user-defined functions that can be 'reduced' on another function.
*
* A function f_source(x) is 'reducible' on another function f_target(x) if
* there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
*
* @since 4.0.0
*/
@Evolving
public interface ReducibleFunction<T, A> extends ScalarFunction<T> {

/**
* If this function is 'reducible' on another function, return the {@link Reducer} function.
* @param other other function
* @param thisArgument argument for this function instance
* @param otherArgument argument for other function instance
* @return a reducer if this function is reducible on the other function, None otherwise
*/
Option<Reducer<A>> reducer(ReducibleFunction<?, ?> other, Option<?> thisArgument, Option<?> otherArgument);
}
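Building on the BucketReducer sketch above, a possible ReducibleFunction implementation could look like the following. MyBucket, its hashing, and the assumption that the Option arguments carry each instance's bucket count (as TransformExpression passes numBucketsOpt below) are illustrative, not part of this patch:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
import org.apache.spark.sql.types.{DataType, DataTypes}

// Hypothetical bucket function: my_bucket(numBuckets, col) = hash(col) % numBuckets.
// It is reducible onto another my_bucket instance whenever the other bucket count
// divides this one's, in which case it returns the BucketReducer sketched above.
case class MyBucket(numBuckets: Int) extends ReducibleFunction[Int, Int] {
  override def name(): String = "my_bucket"
  override def inputTypes(): Array[DataType] = Array(DataTypes.LongType)
  override def resultType(): DataType = DataTypes.IntegerType
  override def produceResult(input: InternalRow): Int =
    Math.floorMod(input.getLong(0), numBuckets.toLong).toInt

  override def reducer(
      other: ReducibleFunction[_, _],
      thisArgument: Option[_],
      otherArgument: Option[_]): Option[Reducer[Int]] = {
    (other, thisArgument, otherArgument) match {
      case (_: MyBucket, Some(thisBuckets: Int), Some(otherBuckets: Int))
          if thisBuckets % otherBuckets == 0 =>
        Some(BucketReducer(thisBuckets, otherBuckets))
      case _ => None
    }
  }
}

With such a function, a table bucketed into 8 buckets and one bucketed into 4 buckets on the same key satisfy r(f_source(x)) = f_target(x), which is what the planner changes below exploit.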
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ReducibleFunction}
import org.apache.spark.sql.types.DataType

/**
@@ -54,6 +54,31 @@ case class TransformExpression(
false
}

/**
* Whether this [[TransformExpression]]'s function is compatible with the `other`
* [[TransformExpression]]'s function.
*
* This is true if both are the same function, or if both are instances of
* [[ReducibleFunction]] and there exists a [[Reducer]] r(x) such that r(t1(x)) = t2(x),
* or r(t2(x)) = t1(x), for all input x, where t1 and t2 are the two functions.
*
* @param other the transform expression to compare to
* @return true if compatible, false if not
*/
def isCompatible(other: TransformExpression): Boolean = {
if (isSameFunction(other)) {
true
} else {
(function, other.function) match {
case (f: ReducibleFunction[Any, Any] @unchecked,
o: ReducibleFunction[Any, Any] @unchecked) =>
val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt)
val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt)
reducer.isDefined || otherReducer.isDefined
case _ => false
}
}
}

override def dataType: DataType = function.resultType()

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -635,6 +636,22 @@ trait ShuffleSpec {
*/
def createPartitioning(clustering: Seq[Expression]): Partitioning =
throw SparkUnsupportedOperationException()

/**
* Returns a sequence of [[Reducer]]s for the partition expressions of this shuffle spec,
* relative to the partition expressions of another shuffle spec.
* <p>
* A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
* 'reducible' on the corresponding partition expression function of the other shuffle spec.
* <p>
* If a value is returned, there must be one Option of [[Reducer]] per partition expression.
* A None element indicates that the particular partition expression is not reducible
* on the corresponding expression of the other shuffle spec.
* <p>
* Returning None for the whole result indicates that none of the partition expressions
* can be reduced on the corresponding expressions of the other shuffle spec.
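* <p>
* For illustration only (an assumed scenario, not part of this contract): for a spec
* partitioned by (bucket(8, a), days(b)) compared against a spec partitioned by
* (bucket(4, a), days(b)), a catalog whose bucket function is reducible from 8 buckets
* to 4 could return Some(Seq(Some(reducer), None)), i.e. only the first expression is
* reduced while days(b) is left untouched.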
*/
def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None
}

case object SinglePartitionShuffleSpec extends ShuffleSpec {
@@ -829,20 +846,60 @@ case class KeyGroupedShuffleSpec(
}
}

override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
// Only support partition expressions are AttributeReference for now
partitioning.expressions.forall(_.isInstanceOf[AttributeReference])

override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues)
}

override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = {
other match {
case otherSpec: KeyGroupedShuffleSpec =>
val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map {
case (e1: TransformExpression, e2: TransformExpression)
if e1.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked]
&& e2.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] =>
e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer(
e2.function.asInstanceOf[ReducibleFunction[Any, Any]],
e1.numBucketsOpt.map(a => a.asInstanceOf[Any]),
e2.numBucketsOpt.map(a => a.asInstanceOf[Any]))
case (_, _) => None
}

// Return None if none of the partition expressions needs reducing.
if (results.forall(_.isEmpty)) None else Some(results)
case _ => None
}
}

private def isExpressionCompatible(left: Expression, right: Expression): Boolean =
(left, right) match {
case (_: LeafExpression, _: LeafExpression) => true
case (left: TransformExpression, right: TransformExpression) =>
if (SQLConf.get.v2BucketingPushPartValuesEnabled &&
!SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
SQLConf.get.v2BucketingAllowCompatibleTransforms) {
left.isCompatible(right)
} else {
left.isSameFunction(right)
}
case _ => false
}
}

object KeyGroupedShuffleSpec {
  def reducePartitionValue(
      row: InternalRow,
      expressions: Seq[Expression],
      reducers: Seq[Option[Reducer[Any]]]): InternalRowComparableWrapper = {
    val partitionVals = row.toSeq(expressions.map(_.dataType))
    val reducedRow = partitionVals.zip(reducers).map {
      case (v, Some(reducer)) => reducer.reduce(v)
      case (v, _) => v
    }.toArray
    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
  }
}
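A rough usage sketch of reducePartitionValue, reusing the hypothetical BucketReducer from the earlier sketch; the attribute name and the bucket values are made up:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedShuffleSpec
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.types.IntegerType

// Partition values 7 and 3 differ before reduction, but with an 8 -> 4 reducer both
// collapse to bucket 3, so their input splits end up grouped into the same partition.
val exprs = Seq(AttributeReference("bucket_col", IntegerType)())
val reducers = Seq(Some(BucketReducer(8, 4).asInstanceOf[Reducer[Any]]))
val left = KeyGroupedShuffleSpec.reducePartitionValue(InternalRow(7), exprs, reducers)
val right = KeyGroupedShuffleSpec.reducePartitionValue(InternalRow(3), exprs, reducers)
assert(left == right)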

@@ -1537,6 +1537,18 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS =
buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled")
.doc("Whether to allow storage-partition join in the case where the partition transforms" +
"are compatible but not identical. This config requires both " +
s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"to be disabled."
)
.version("4.0.0")
.booleanConf
.createWithDefault(false)

val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
.doc("The maximum number of buckets allowed.")
.version("2.4.0")
@@ -5201,6 +5213,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)

def v2BucketingAllowCompatibleTransforms: Boolean =
getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)

def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
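
As a usage sketch, the combination checked by isExpressionCompatible above could be enabled as follows, given an existing SparkSession named spark. The first three key names are the pre-existing v2 bucketing configs assumed from SQLConf; only the last one is added by this patch:

// Enable storage-partition join across compatible (but not identical) bucket transforms.
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "false")
spark.conf.set("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled", "true")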

@@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
import org.apache.spark.sql.connector.catalog.Table
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.connector.read._
import org.apache.spark.util.ArrayImplicits._

@@ -164,6 +165,18 @@ case class BatchScanExec(
(groupedParts, expressions)
}

// Also re-group the partitions if we are reducing compatible partition expressions
val finalGroupedPartitions = spjParams.reducers match {
case Some(reducers) =>
val result = groupedPartitions.groupBy { case (row, _) =>
KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
}.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
partExpressions.map(_.dataType))
result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
case _ => groupedPartitions
}

// When partially clustered, the input partitions are not grouped by partition
// values. Here we'll need to check `commonPartitionValues` and decide how to group
// and replicate splits within a partition.
@@ -174,7 +187,7 @@
.get
.map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
.toMap
val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
// `commonPartValuesMap` should contain the part value since it's the super set.
val numSplits = commonPartValuesMap
.get(InternalRowComparableWrapper(partValue, partExpressions))
@@ -207,7 +220,7 @@
} else {
// either `commonPartitionValues` is not defined, or it is defined but
// `applyPartialClustering` is false.
val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
InternalRowComparableWrapper(partValue, partExpressions) -> splits
}.toMap

@@ -259,6 +272,7 @@ case class StoragePartitionJoinParams(
keyGroupedPartitioning: Option[Seq[Expression]] = None,
joinKeyPositions: Option[Seq[Int]] = None,
commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
reducers: Option[Seq[Option[Reducer[Any]]]] = None,
applyPartialClustering: Boolean = false,
replicatePartitions: Boolean = false) {
override def equals(other: Any): Boolean = other match {
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
import org.apache.spark.sql.connector.catalog.functions.Reducer
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
@@ -505,11 +506,28 @@ case class EnsureRequirements(
}
}

// In the case of compatible but not identical partition expressions, we apply 'reduce'
// transforms to group one side's partitions as well as the common partition values.
val leftReducers = leftSpec.reducers(rightSpec)
val rightReducers = rightSpec.reducers(leftSpec)

if (leftReducers.isDefined || rightReducers.isDefined) {
mergedPartValues = reduceCommonPartValues(mergedPartValues,
leftSpec.partitioning.expressions,
leftReducers)
mergedPartValues = reduceCommonPartValues(mergedPartValues,
rightSpec.partitioning.expressions,
rightReducers)
val rowOrdering = RowOrdering
.createNaturalAscendingOrdering(partitionExprs.map(_.dataType))
mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
}

// Now we need to push-down the common partition information to the scan in each child
newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions,
leftReducers, applyPartialClustering, replicateLeftSide)
newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions,
rightReducers, applyPartialClustering, replicateRightSide)
}
}

@@ -527,25 +545,38 @@
joinType == LeftAnti || joinType == LeftOuter
}

// Populate the common partition information down to the scan nodes
private def populateCommonPartitionInfo(
plan: SparkPlan,
values: Seq[(InternalRow, Int)],
joinKeyPositions: Option[Seq[Int]],
reducers: Option[Seq[Option[Reducer[Any]]]],
applyPartialClustering: Boolean,
replicatePartitions: Boolean): SparkPlan = plan match {
case scan: BatchScanExec =>
scan.copy(
spjParams = scan.spjParams.copy(
commonPartitionValues = Some(values),
joinKeyPositions = joinKeyPositions,
reducers = reducers,
applyPartialClustering = applyPartialClustering,
replicatePartitions = replicatePartitions
)
)
case node =>
node.mapChildren(child => populateCommonPartitionInfo(
child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions))
}

  private def reduceCommonPartValues(
      commonPartValues: Seq[(InternalRow, Int)],
      expressions: Seq[Expression],
      reducers: Option[Seq[Option[Reducer[Any]]]]): Seq[(InternalRow, Int)] = {
    reducers match {
      case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
        KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
      }.map { case (wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq
      case _ => commonPartValues
    }
  }
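  // For illustration of the merge above (hypothetical values): with an 8 -> 4 bucket
  // reducer, common partition values (bucket=7, 10 splits) and (bucket=3, 5 splits)
  // reduce to the same key and are merged into a single entry (bucket=3, 15 splits).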

/**