Skip to content

Commit

Permalink
fix: Comet should not translate try_sum to native sum expression
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Apr 16, 2024
1 parent 421f0e0 commit ee349ad
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
inputs: Seq[Attribute],
binding: Boolean): Option[AggExpr] = {
aggExpr.aggregateFunction match {
case s @ Sum(child, _) if sumDataTypeSupported(s.dataType) =>
case s @ Sum(child, evalMode)
if sumDataTypeSupported(s.dataType) &&
evalMode == EvalMode.LEGACY =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(s.dataType)

Expand All @@ -220,7 +222,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
} else {
None
}
case s @ Average(child, _) if avgDataTypeSupported(s.dataType) =>
case s @ Average(child, evalMode)
if avgDataTypeSupported(s.dataType) &&
evalMode == EvalMode.LEGACY =>
val childExpr = exprToProto(child, inputs, binding)
val dataType = serializeDataType(s.dataType)

Expand Down
17 changes: 16 additions & 1 deletion spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

package org.apache.comet.exec

import java.time.{Duration, Period}

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.util.Random
Expand All @@ -38,7 +40,7 @@ import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, SQLExecuti
import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastNestedLoopJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.functions.{date_add, expr, sum}
import org.apache.spark.sql.functions.{col, date_add, expr, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.unsafe.types.UTF8String
Expand All @@ -58,6 +60,19 @@ class CometExecSuite extends CometTestBase {
}
}

test("try_sum should return null if overflow happens before merging") {
val longDf = Seq(Long.MaxValue, Long.MaxValue, 2).toDF("v")
val yearMonthDf = Seq(Int.MaxValue, Int.MaxValue, 2)
.map(Period.ofMonths)
.toDF("v")
val dayTimeDf = Seq(106751991L, 106751991L, 2L)
.map(Duration.ofDays)
.toDF("v")
Seq(longDf, yearMonthDf, dayTimeDf).foreach { df =>
checkSparkAnswer(df.repartitionByRange(2, col("v")).selectExpr("try_sum(v)"))
}
}

test("Fix corrupted AggregateMode when transforming plan parameters") {
withParquetTable((0 until 5).map(i => (i, i + 1)), "table") {
val df = sql("SELECT * FROM table").groupBy($"_1").agg(sum("_2"))
Expand Down

0 comments on commit ee349ad

Please sign in to comment.