diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
index 9098fe2e3..9b7fb3cb8 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala
@@ -39,6 +39,43 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
  */
 class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
+  test("Ensure traversed operators during finding first partial aggregation are all native") {
+    withTable("lineitem", "part") {
+      withSQLConf(
+        CometConf.COMET_ENABLED.key -> "true",
+        CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true") {
+
+        sql(
+          "CREATE TABLE lineitem(l_extendedprice DOUBLE, l_quantity DOUBLE, l_partkey STRING) USING PARQUET")
+        sql("INSERT INTO TABLE lineitem VALUES (1.0, 1.0, '1')")
+
+        sql(
+          "CREATE TABLE part(p_partkey STRING, p_brand STRING, p_container STRING) USING PARQUET")
+        sql("INSERT INTO TABLE part VALUES ('1', 'Brand#23', 'MED BOX')")
+
+        val df = sql("""select
+            sum(l_extendedprice) / 7.0 as avg_yearly
+          from
+            lineitem,
+            part
+          where
+            p_partkey = l_partkey
+            and p_brand = 'Brand#23'
+            and p_container = 'MED BOX'
+            and l_quantity < (
+              select
+                0.2 * avg(l_quantity)
+              from
+                lineitem
+              where
+                l_partkey = p_partkey
+            )""")
+        checkAnswer(df, Row(null))
+      }
+    }
+  }
+
   test("SUM decimal supports emit.first") {
     withSQLConf(
       SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> EliminateSorts.ruleName,