Skip to content

Commit

Permalink
[GLUTEN-7450][CORE] Improve CollectRewriteRule for Velox
Browse files Browse the repository at this point in the history
  • Loading branch information
beliefer committed Oct 9, 2024
1 parent dfbf226 commit 0dda3c3
Showing 1 changed file with 21 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet
import org.apache.gluten.utils.LogicalPlanSelector

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{And, Coalesce, Expression, IsNotNull, Literal, WindowExpression}
import org.apache.spark.sql.catalyst.expressions.{And, Coalesce, Expression, IsNotNull, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Window}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, AVERAGE, WINDOW, WINDOW_EXPRESSION}
import org.apache.spark.sql.types.ArrayType

import scala.reflect.{classTag, ClassTag}
Expand All @@ -37,44 +38,21 @@ import scala.reflect.{classTag, ClassTag}
// Logical-plan rule that swaps Spark's CollectList / CollectSet aggregate functions for
// VeloxCollectList / VeloxCollectSet. Every rewrite is gated on a `has[...]` check
// (presumably "is the Velox implementation available" -- `has` is defined outside this
// excerpt, confirm there).
//
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped, so
// lines from the earlier and the later revision are interleaved. Adjacent near-duplicate
// lines should be read as "before" / "after", not as one compilable body.
case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
import CollectRewriteRule._
override def apply(plan: LogicalPlan): LogicalPlan = LogicalPlanSelector.maybe(spark, plan) {
// Bottom-up pass that applies the list rewrite and then the set rewrite to every node.
val out = plan.transformUp {
case node =>
val out = replaceCollectSet(replaceCollectList(node))
out
}
// When the pass changed nothing, the original plan instance is returned.
if (out.fastEquals(plan)) {
// NOTE(review): as written, this guard returns the plan untouched only when VeloxCollectSet
// is missing AND VeloxCollectList is present. If the intent is "nothing to rewrite when
// neither implementation is available", the second operand probably needs a `!` -- confirm.
if (!has[VeloxCollectSet] && has[VeloxCollectList]) {
return plan
}
out
}

// Replaces the aggregate function of every AggregateExpression wrapping a CollectList with
// VeloxCollectList over the same child; the other AggregateExpression fields are kept by copy.
private def replaceCollectList(node: LogicalPlan): LogicalPlan = {
node.transformExpressions {
case func @ AggregateExpression(l: CollectList, _, _, _, _) if has[VeloxCollectList] =>
func.copy(VeloxCollectList(l.child))
}
}

private def replaceCollectSet(node: LogicalPlan): LogicalPlan = {
// 1. Replace null result from VeloxCollectSet with empty array to align with
// vanilla Spark.
// 2. Filter out null inputs from VeloxCollectSet to align with vanilla Spark.
//
// Since https://github.com/apache/incubator-gluten/pull/4805
node match {
// NOTE(review): AVERAGE looks out of place in a rule about collect aggregates. The
// `case agg: Aggregate` branch below suggests the AGGREGATE tree pattern may have been
// intended; with AVERAGE, an Aggregate that has no avg() might be pruned away and never
// rewritten -- confirm against Spark's TreePattern definitions.
plan.transformWithPruning(_.containsAnyPattern(AVERAGE, WINDOW)) {
case agg: Aggregate =>
agg.transformExpressions {
case ToVeloxCollectSet(newAggFunc) =>
val out = ensureNonNull(newAggFunc)
out
// Only expression trees containing an AGGREGATE_EXPRESSION are visited; the extractor
// already yields the finished replacement expression.
agg.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) {
case ToVeloxCollect(newAggFunc) =>
newAggFunc
}
case w: Window =>
w.transformExpressions {
// Here ensureNonNull wraps the whole WindowExpression, not just the aggregate inside it.
case func @ WindowExpression(ToVeloxCollectSet(newAggFunc), _) =>
val out = ensureNonNull(func.copy(newAggFunc))
out
// Only expression trees containing a WINDOW_EXPRESSION are visited. The match is on the
// aggregate expression itself, so any ensureNonNull wrapping done by the extractor ends up
// inside the WindowExpression -- NOTE(review): confirm that placement is intended.
w.transformExpressionsWithPruning(_.containsPattern(WINDOW_EXPRESSION)) {
case ToVeloxCollect(newAggFunc) =>
newAggFunc
}
// Nodes that are neither Aggregate nor Window are returned unchanged.
case other => other

}
}
}
Expand All @@ -88,13 +66,18 @@ object CollectRewriteRule {
out
}

// Extractor that recognises an AggregateExpression over CollectList or CollectSet and yields
// the Velox-based replacement expression, or None for any other expression.
//
// NOTE(review): this listing looks like a rendered diff with the +/- markers stripped; the
// `ToVeloxCollectSet` / `aggFunc` / `newAggFunc` lines appear to be the earlier spelling of
// the `ToVeloxCollect` / `ae` / `newAggExpr` lines next to them.
private object ToVeloxCollectSet {
private object ToVeloxCollect {
def unapply(expr: Expression): Option[Expression] = expr match {
case aggFunc @ AggregateExpression(s: CollectSet, _, _, filter, _) if has[VeloxCollectSet] =>
case ae @ AggregateExpression(s: CollectList, _, _, filter, _) if has[VeloxCollectList] =>
// AND the existing aggregate filter (if there is one) with IsNotNull(child), so rows whose
// input is null do not reach the Velox aggregate.
val newFilter = (filter ++ Some(IsNotNull(s.child))).reduceOption(And)
val newAggExpr =
ae.copy(aggregateFunction = VeloxCollectList(s.child), filter = newFilter)
// The list variant is returned as-is; it is not passed through ensureNonNull.
Some(newAggExpr)
case ae @ AggregateExpression(s: CollectSet, _, _, filter, _) if has[VeloxCollectSet] =>
// Same null-input filter as in the CollectList branch.
val newFilter = (filter ++ Some(IsNotNull(s.child))).reduceOption(And)
val newAggFunc =
aggFunc.copy(aggregateFunction = VeloxCollectSet(s.child), filter = newFilter)
Some(newAggFunc)
val newAggExpr =
ae.copy(aggregateFunction = VeloxCollectSet(s.child), filter = newFilter)
// Only the set variant is wrapped with ensureNonNull (defined outside this excerpt).
Some(ensureNonNull(newAggExpr))
// Not a collect aggregate handled here (or the matching `has[...]` guard failed).
case _ => None
}
}
Expand Down

0 comments on commit 0dda3c3

Please sign in to comment.