diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala index d7299c511e15..5fbb7c4e235b 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/CollectRewriteRule.scala @@ -21,10 +21,11 @@ import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet import org.apache.gluten.utils.LogicalPlanSelector import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.expressions.{And, Coalesce, Expression, IsNotNull, Literal, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{And, Coalesce, Expression, IsNotNull, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Window} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE_EXPRESSION, AVERAGE, WINDOW, WINDOW_EXPRESSION} import org.apache.spark.sql.types.ArrayType import scala.reflect.{classTag, ClassTag} @@ -37,44 +38,21 @@ import scala.reflect.{classTag, ClassTag} case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] { import CollectRewriteRule._ override def apply(plan: LogicalPlan): LogicalPlan = LogicalPlanSelector.maybe(spark, plan) { - val out = plan.transformUp { - case node => - val out = replaceCollectSet(replaceCollectList(node)) - out - } - if (out.fastEquals(plan)) { + if (!has[VeloxCollectSet] && has[VeloxCollectList]) { return plan } - out - } - - private def replaceCollectList(node: LogicalPlan): LogicalPlan = { - node.transformExpressions { - case func @ AggregateExpression(l: CollectList, _, _, _, _) if has[VeloxCollectList] => - func.copy(VeloxCollectList(l.child)) - } - } - - private def replaceCollectSet(node: LogicalPlan): LogicalPlan = { - // 1. Replace null result from VeloxCollectSet with empty array to align with - // vanilla Spark. - // 2. Filter out null inputs from VeloxCollectSet to align with vanilla Spark. - // - // Since https://github.com/apache/incubator-gluten/pull/4805 - node match { + plan.transformWithPruning(_.containsAnyPattern(AVERAGE, WINDOW)) { case agg: Aggregate => - agg.transformExpressions { - case ToVeloxCollectSet(newAggFunc) => - val out = ensureNonNull(newAggFunc) - out + agg.transformExpressionsWithPruning(_.containsPattern(AGGREGATE_EXPRESSION)) { + case ToVeloxCollect(newAggFunc) => + newAggFunc } case w: Window => - w.transformExpressions { - case func @ WindowExpression(ToVeloxCollectSet(newAggFunc), _) => - val out = ensureNonNull(func.copy(newAggFunc)) - out + w.transformExpressionsWithPruning(_.containsPattern(WINDOW_EXPRESSION)) { + case ToVeloxCollect(newAggFunc) => + newAggFunc } - case other => other + } } } @@ -88,13 +66,18 @@ object CollectRewriteRule { out } - private object ToVeloxCollectSet { + private object ToVeloxCollect { def unapply(expr: Expression): Option[Expression] = expr match { - case aggFunc @ AggregateExpression(s: CollectSet, _, _, filter, _) if has[VeloxCollectSet] => + case ae @ AggregateExpression(s: CollectList, _, _, filter, _) if has[VeloxCollectList] => + val newFilter = (filter ++ Some(IsNotNull(s.child))).reduceOption(And) + val newAggExpr = + ae.copy(aggregateFunction = VeloxCollectList(s.child), filter = newFilter) + Some(newAggExpr) + case ae @ AggregateExpression(s: CollectSet, _, _, filter, _) if has[VeloxCollectSet] => val newFilter = (filter ++ Some(IsNotNull(s.child))).reduceOption(And) - val newAggFunc = - aggFunc.copy(aggregateFunction = VeloxCollectSet(s.child), filter = newFilter) - Some(newAggFunc) + val newAggExpr = + ae.copy(aggregateFunction = VeloxCollectSet(s.child), filter = newFilter) + Some(ensureNonNull(newAggExpr)) case _ => None } }