Add compatibility with Spark 3.2 and 3.3
viirya committed Apr 22, 2024
1 parent d1257ea commit b4c698c
Showing 4 changed files with 281 additions and 8 deletions.
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.shims

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.physical.Partitioning

trait ShimCometBroadcastHashJoinExec {

/**
 * Returns the expressions that are used for hash partitioning, covering both `HashPartitioning`
 * and `CoalescedHashPartitioning`. They share the same trait `HashPartitioningLike` since Spark
 * 3.4, but Spark 3.2/3.3 has neither `HashPartitioningLike` nor `CoalescedHashPartitioning`.
*
* TODO: remove after dropping Spark 3.2 and 3.3 support.
*/
def getHashPartitioningLikeExpressions(partitioning: Partitioning): Seq[Expression] = {
partitioning.getClass.getDeclaredMethods
.filter(_.getName == "expressions")
.flatMap(_.invoke(partitioning).asInstanceOf[Seq[Expression]])
}
}
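
For context, a minimal usage sketch of the shim above (not part of this commit; the object and value names are invented for illustration): on a Spark 3.2/3.3 classpath the reflective lookup resolves `HashPartitioning.expressions`, and on 3.4+ the same call also covers `CoalescedHashPartitioning` without compiling against it.

import org.apache.comet.shims.ShimCometBroadcastHashJoinExec

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.types.IntegerType

object ShimUsageSketch extends ShimCometBroadcastHashJoinExec {
  def demo(): Unit = {
    val key = AttributeReference("id", IntegerType)()
    val partitioning = HashPartitioning(Seq(key), numPartitions = 10)
    // The `expressions` accessor is resolved reflectively, so this code compiles against
    // Spark 3.2/3.3 while still handling `CoalescedHashPartitioning` at runtime on 3.4+.
    assert(getHashPartitioningLikeExpressions(partitioning) == Seq(key))
  }
}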
23 changes: 15 additions & 8 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -33,10 +33,11 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expre
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateMode}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioningLike, Partitioning, PartitioningCollection, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
+import org.apache.spark.sql.comet.plans.PartitioningPreservingUnaryExecNode
import org.apache.spark.sql.comet.util.Utils
-import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, PartitioningPreservingUnaryExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -49,6 +50,7 @@ import com.google.common.base.Objects

import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException}
import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.shims.ShimCometBroadcastHashJoinExec

/**
* A Comet physical operator
@@ -705,7 +707,8 @@ case class CometBroadcastHashJoinExec(
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
-    extends CometBinaryExec {
+    extends CometBinaryExec
+    with ShimCometBroadcastHashJoinExec {

// The following logic of `outputPartitioning` is copied from Spark `BroadcastHashJoinExec`.
protected lazy val streamedPlan: SparkPlan = buildSide match {
@@ -717,7 +720,9 @@ case class CometBroadcastHashJoinExec(
joinType match {
case _: InnerLike if conf.broadcastHashJoinOutputPartitioningExpandLimit > 0 =>
streamedPlan.outputPartitioning match {
-          case h: HashPartitioningLike => expandOutputPartitioning(h)
+          case h: HashPartitioning => expandOutputPartitioning(h)
+          case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
+            expandOutputPartitioning(h)
case c: PartitioningCollection => expandOutputPartitioning(c)
case other => other
}
@@ -756,7 +761,9 @@ case class CometBroadcastHashJoinExec(
private def expandOutputPartitioning(
partitioning: PartitioningCollection): PartitioningCollection = {
PartitioningCollection(partitioning.partitionings.flatMap {
-      case h: HashPartitioningLike => expandOutputPartitioning(h).partitionings
+      case h: HashPartitioning => expandOutputPartitioning(h).partitionings
+      case h: Expression if h.getClass.getName.contains("CoalescedHashPartitioning") =>
+        expandOutputPartitioning(h).partitionings
case c: PartitioningCollection => Seq(expandOutputPartitioning(c))
case other => Seq(other)
})
@@ -769,7 +776,7 @@ case class CometBroadcastHashJoinExec(
// Seq("a", "b", "c"), Seq("a", "b", "y"), Seq("a", "x", "c"), Seq("a", "x", "y").
// The expanded expressions are returned as PartitioningCollection.
private def expandOutputPartitioning(
-      partitioning: HashPartitioningLike): PartitioningCollection = {
+      partitioning: Partitioning with Expression): PartitioningCollection = {
val maxNumCombinations = conf.broadcastHashJoinOutputPartitioningExpandLimit
var currentNumCombinations = 0

@@ -791,8 +798,8 @@
}

PartitioningCollection(
-      generateExprCombinations(partitioning.expressions, Nil)
-        .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[HashPartitioningLike]))
+      generateExprCombinations(getHashPartitioningLikeExpressions(partitioning), Nil)
+        .map(exprs => partitioning.withNewChildren(exprs).asInstanceOf[Partitioning]))
}

override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
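
The recurring pattern in the hunks above is matching `CoalescedHashPartitioning` by class name, since referencing the class directly would not compile on Spark 3.2/3.3. A standalone sketch of that dispatch (not part of this commit; the object and method names are invented):

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.types.IntegerType

object NameMatchSketch {
  // Match the concrete HashPartitioning first, then fall back to a class-name check so that
  // CoalescedHashPartitioning (Spark 3.4+) is recognized without a compile-time reference.
  def describe(p: Partitioning): String = p match {
    case _: HashPartitioning => "hash"
    case other if other.getClass.getName.contains("CoalescedHashPartitioning") => "coalesced-hash"
    case _ => "other"
  }

  def demo(): Unit = {
    val key = AttributeReference("id", IntegerType)()
    assert(describe(HashPartitioning(Seq(key), 8)) == "hash")
  }
}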
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet.plans

import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.internal.SQLConf

/**
* A trait that provides functionality to handle aliases in the `outputExpressions`.
*/
trait AliasAwareOutputExpression extends SQLConfHelper {
// `SQLConf.EXPRESSION_PROJECTION_CANDIDATE_LIMIT` is Spark 3.4+ only.
// Use a default value for now.
protected val aliasCandidateLimit = 100
protected def outputExpressions: Seq[NamedExpression]

/**
 * This method can be used to strip an expression that does not affect the result, for example:
 * strip the part of an expression that is ordering-agnostic for output ordering.
*/
protected def strip(expr: Expression): Expression = expr

// Build an `Expression` -> `Attribute` alias map.
  // There can be multiple aliases defined for the same expression, but it doesn't make sense to
  // store more than `aliasCandidateLimit` attributes for an expression. In those cases the old
  // logic handled only the last alias, so we need to make sure that we give precedence to it.
  // If the `outputExpressions` contain simple attributes, we need to add those to the map as well.
@transient
private lazy val aliasMap = {
val aliases = mutable.Map[Expression, mutable.ArrayBuffer[Attribute]]()
outputExpressions.reverse.foreach {
case a @ Alias(child, _) =>
val buffer =
aliases.getOrElseUpdate(strip(child).canonicalized, mutable.ArrayBuffer.empty)
if (buffer.size < aliasCandidateLimit) {
buffer += a.toAttribute
}
case _ =>
}
outputExpressions.foreach {
case a: Attribute if aliases.contains(a.canonicalized) =>
val buffer = aliases(a.canonicalized)
if (buffer.size < aliasCandidateLimit) {
buffer += a
}
case _ =>
}
aliases
}

protected def hasAlias: Boolean = aliasMap.nonEmpty

/**
* Return a stream of expressions in which the original expression is projected with `aliasMap`.
*/
protected def projectExpression(expr: Expression): Stream[Expression] = {
val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
multiTransformDown(expr) {
// Mapping with aliases
case e: Expression if aliasMap.contains(e.canonicalized) =>
aliasMap(e.canonicalized).toSeq ++ (if (e.containsChild.nonEmpty) Seq(e) else Seq.empty)

      // Prune if we encounter an attribute that we can't map and that is not in the output set.
      // This prune propagates up to the closest `multiTransformDown()` call, which returns
      // `Stream.empty` there.
case a: Attribute if !outputSet.contains(a) => Seq.empty
}
}

// Copied from Spark 3.4+ to make it available in Spark 3.2+.
def multiTransformDown(expr: Expression)(
rule: PartialFunction[Expression, Seq[Expression]]): Stream[Expression] = {

    // We could return `Seq(expr)` when the `rule` doesn't apply, and handle both
    // - the case where the rule doesn't apply, and
    // - the case where the rule returns a one-element `Seq(originalNode)`
    // together. However, the returned `Seq` can be a `Stream`, and unfortunately there doesn't
    // seem to be a way to match on a one-element stream without eagerly computing the tail's
    // head. That would contradict the purpose of taking only the necessary elements from the
    // alternatives, i.e. the "multiTransformDown is lazy" test case in `TreeNodeSuite` would fail.
    // Note that this behaviour also has a downside: we can only mark the rule ineffective on the
    // original node if the rule didn't match.
var ruleApplied = true
val afterRules = CurrentOrigin.withOrigin(expr.origin) {
rule.applyOrElse(
expr,
(_: Expression) => {
ruleApplied = false
Seq.empty
})
}

val afterRulesStream = if (afterRules.isEmpty) {
if (ruleApplied) {
// If the rule returned with empty alternatives then prune
Stream.empty
} else {
// If the rule was not applied then keep the original node
Stream(expr)
}
} else {
// If the rule was applied then use the returned alternatives
afterRules.toStream.map { afterRule =>
if (expr fastEquals afterRule) {
expr
} else {
afterRule.copyTagsFrom(expr)
afterRule
}
}
}

afterRulesStream.flatMap { afterRule =>
if (afterRule.containsChild.nonEmpty) {
generateCartesianProduct(afterRule.children.map(c => () => multiTransformDown(c)(rule)))
.map(afterRule.withNewChildren)
} else {
Stream(afterRule)
}
}
}

def generateCartesianProduct[T](elementSeqs: Seq[() => Seq[T]]): Stream[Seq[T]] = {
elementSeqs.foldRight(Stream(Seq.empty[T]))((elements, elementTails) =>
for {
elementTail <- elementTails
element <- elements()
} yield element +: elementTail)
}
}
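
A small sketch of the trait in isolation (not part of this commit; the object and value names are invented): an attribute that is aliased in `outputExpressions` is projected to its alias, and `generateCartesianProduct` enumerates alternative combinations lazily, varying the first position fastest.

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.comet.plans.AliasAwareOutputExpression
import org.apache.spark.sql.types.IntegerType

object AliasAwareSketch extends AliasAwareOutputExpression {
  private val id = AttributeReference("id", IntegerType)()
  private val asA = Alias(id, "a")()

  // The node outputs `id` under the alias `a`.
  override protected def outputExpressions: Seq[NamedExpression] = Seq(asA)

  def demo(): Unit = {
    // `id` is rewritten to the aliased attribute `a`.
    assert(projectExpression(id).toList == List(asA.toAttribute))

    // Combinations are produced lazily; callers can `.take(n)` without materializing the rest.
    val combos = generateCartesianProduct(Seq(() => Seq("a", "x"), () => Seq("b", "y")))
    assert(combos.take(2).toList == List(Seq("a", "b"), Seq("x", "b")))
  }
}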
@@ -0,0 +1,76 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet.plans

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.execution.UnaryExecNode

/**
* A trait that handles aliases in the `outputExpressions` to produce `outputPartitioning` that
* satisfies distribution requirements.
*
 * This is copied from Spark's `PartitioningPreservingUnaryExecNode`, which is only available in
 * Spark 3.4+, as a workaround to make the functionality available on Spark 3.2 and 3.3.
*/
trait PartitioningPreservingUnaryExecNode extends UnaryExecNode with AliasAwareOutputExpression {
final override def outputPartitioning: Partitioning = {
val partitionings: Seq[Partitioning] = if (hasAlias) {
flattenPartitioning(child.outputPartitioning).flatMap {
case e: Expression =>
// We need unique partitionings but if the input partitioning is
// `HashPartitioning(Seq(id + id))` and we have `id -> a` and `id -> b` aliases then after
// the projection we have 4 partitionings:
// `HashPartitioning(Seq(a + a))`, `HashPartitioning(Seq(a + b))`,
// `HashPartitioning(Seq(b + a))`, `HashPartitioning(Seq(b + b))`, but
// `HashPartitioning(Seq(a + b))` is the same as `HashPartitioning(Seq(b + a))`.
val partitioningSet = mutable.Set.empty[Expression]
projectExpression(e)
.filter(e => partitioningSet.add(e.canonicalized))
.take(aliasCandidateLimit)
.asInstanceOf[Stream[Partitioning]]
case o => Seq(o)
}
} else {
      // Filter valid partitionings (i.e. those that only reference output attributes of the
      // current plan node)
val outputSet = AttributeSet(outputExpressions.map(_.toAttribute))
flattenPartitioning(child.outputPartitioning).filter {
case e: Expression => e.references.subsetOf(outputSet)
case _ => true
}
}
partitionings match {
case Seq() => UnknownPartitioning(child.outputPartitioning.numPartitions)
case Seq(p) => p
case ps => PartitioningCollection(ps)
}
}

private def flattenPartitioning(partitioning: Partitioning): Seq[Partitioning] = {
partitioning match {
case PartitioningCollection(childPartitionings) =>
childPartitionings.flatMap(flattenPartitioning)
case rest =>
rest +: Nil
}
}
}
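
To make the final step concrete, a standalone sketch (not part of this commit; the object and helper names are invented) of the same flatten-and-rewrap behaviour: nested `PartitioningCollection`s are flattened, a single surviving partitioning is returned as-is, and an empty result degrades to `UnknownPartitioning`.

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.types.IntegerType

object PartitioningSketch {
  // Mirrors the private `flattenPartitioning` above.
  private def flatten(p: Partitioning): Seq[Partitioning] = p match {
    case PartitioningCollection(children) => children.flatMap(flatten)
    case other => Seq(other)
  }

  // Mirrors the final `partitionings match { ... }` step of `outputPartitioning`.
  private def rewrap(ps: Seq[Partitioning], numPartitions: Int): Partitioning = ps match {
    case Seq() => UnknownPartitioning(numPartitions)
    case Seq(p) => p
    case many => PartitioningCollection(many)
  }

  def demo(): Unit = {
    val a = AttributeReference("a", IntegerType)()
    val b = AttributeReference("b", IntegerType)()
    val nested = PartitioningCollection(
      Seq(HashPartitioning(Seq(a), 4), PartitioningCollection(Seq(HashPartitioning(Seq(b), 4)))))
    assert(flatten(nested) == Seq(HashPartitioning(Seq(a), 4), HashPartitioning(Seq(b), 4)))
    assert(rewrap(flatten(nested), 4).isInstanceOf[PartitioningCollection])
  }
}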
