Skip to content

Commit

Permalink
[SPARK-6451][SQL] supported code generation for CombineSum
Browse files Browse the repository at this point in the history
Author: Venkata Ramana Gollamudi <ramana.gollamudi@huawei.com>

Closes apache#5138 from gvramana/sum_fix_codegen and squashes the following commits:

95f5fe4 [Venkata Ramana Gollamudi] rebase merge changes
12f45a5 [Venkata Ramana Gollamudi] Combined and added code generations tests as per comment
d6a76ac [Venkata Ramana Gollamudi] added support for codegeneration for CombineSum and tests
  • Loading branch information
gvramana authored and marmbrus committed Apr 9, 2015
1 parent 9418280 commit 7d7384c
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ case class GeneratedAggregate(
// but really, common sub expression elimination would be better....
val zero = Cast(Literal(0), calcType)
val updateFunction = Coalesce(
Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
Add(
Coalesce(currentSum :: zero :: Nil),
Cast(expr, calcType)
) :: currentSum :: zero :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Expand All @@ -109,6 +112,45 @@ case class GeneratedAggregate(

AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)

case cs @ CombineSum(expr) =>
  // Type in which the running sum is accumulated: fixed-precision decimals are widened to
  // DecimalType.Unlimited, every other input is summed in its own type. The match result
  // must be bound to calcType; as a bare statement it is discarded and the widening is lost.
  val calcType =
    expr.dataType match {
      case DecimalType.Fixed(_, _) =>
        DecimalType.Unlimited
      case _ =>
        expr.dataType
    }

  // Aggregation buffer slot; it starts out as null and stays null until a non-null
  // partial sum has been seen.
  val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
  val initialValue = Literal.create(null, calcType)

  // Coalesce avoids double calculation...
  // but really, common sub expression elimination would be better....
  val zero = Cast(Literal(0), calcType)
  // If we're evaluating UnscaledValue(x), we can do the null check on x directly, since its
  // UnscaledValue will be null if and only if x is null; helps with Average on decimals
  val actualExpr = expr match {
    case UnscaledValue(e) => e
    case _ => expr
  }
  // partial sum result can be null only when no input rows present
  // A null partial sum therefore leaves the buffer untouched; a non-null one is added to
  // the buffer, treating a null buffer as zero.
  val updateFunction = If(
    IsNotNull(actualExpr),
    Coalesce(
      Add(
        Coalesce(currentSum :: zero :: Nil),
        Cast(expr, calcType)) :: currentSum :: zero :: Nil),
    currentSum)

  // Fixed-precision decimals were accumulated in the widened type, so cast the final value
  // back to the type the aggregate declares; other types are returned as accumulated.
  val result =
    expr.dataType match {
      case DecimalType.Fixed(_, _) =>
        Cast(currentSum, cs.dataType)
      case _ => currentSum
    }

  AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)


case a @ Average(expr) =>
val calcType =
expr.dataType match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
Expand Down
92 changes: 89 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -102,14 +103,99 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT ABS(2.5)"),
Row(2.5))
}

test("aggregation with codegen") {
  // Remember the current codegen setting so it can be restored even if an assertion fails.
  val originalValue = conf.codegenEnabled
  setConf(SQLConf.CODEGEN_ENABLED, "true")
  try {
    sql("SELECT key FROM testData GROUP BY key").collect()
    // Prepare a table that we can group some rows.
    table("testData")
      .unionAll(table("testData"))
      .unionAll(table("testData"))
      .registerTempTable("testData3x")

    // Runs `sqlText`, fails unless the executed plan contains a GeneratedAggregate node,
    // and then compares the query output with `expectedResults`.
    def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
      val df = sql(sqlText)
      // First, check if we have GeneratedAggregate.
      val hasGeneratedAgg = df.queryExecution.executedPlan.collect {
        case generatedAgg: GeneratedAggregate => generatedAgg
      }.nonEmpty
      if (!hasGeneratedAgg) {
        fail(
          s"""
            |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
            |${df.queryExecution.simpleString}
          """.stripMargin)
      }
      // Then, check results.
      checkAnswer(df, expectedResults)
    }

    try {
      // Just to group rows.
      testCodeGen(
        "SELECT key FROM testData3x GROUP BY key",
        (1 to 100).map(Row(_)))
      // COUNT
      testCodeGen(
        "SELECT key, count(value) FROM testData3x GROUP BY key",
        (1 to 100).map(i => Row(i, 3)))
      testCodeGen(
        "SELECT count(key) FROM testData3x",
        Row(300) :: Nil)
      // COUNT DISTINCT ON int
      testCodeGen(
        "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
        (1 to 100).map(i => Row(i.toString, 1)))
      testCodeGen(
        "SELECT count(distinct key) FROM testData3x",
        Row(100) :: Nil)
      // SUM
      testCodeGen(
        "SELECT value, sum(key) FROM testData3x GROUP BY value",
        (1 to 100).map(i => Row(i.toString, 3 * i)))
      testCodeGen(
        "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
        Row(5050 * 3, 5050 * 3.0) :: Nil)
      // AVERAGE
      testCodeGen(
        "SELECT value, avg(key) FROM testData3x GROUP BY value",
        (1 to 100).map(i => Row(i.toString, i)))
      testCodeGen(
        "SELECT avg(key) FROM testData3x",
        Row(50.5) :: Nil)
      // MAX
      testCodeGen(
        "SELECT value, max(key) FROM testData3x GROUP BY value",
        (1 to 100).map(i => Row(i.toString, i)))
      testCodeGen(
        "SELECT max(key) FROM testData3x",
        Row(100) :: Nil)
      // Some combinations.
      testCodeGen(
        """
          |SELECT
          | value,
          | sum(key),
          | max(key),
          | avg(key),
          | count(key),
          | count(distinct key)
          |FROM testData3x
          |GROUP BY value
        """.stripMargin,
        (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1)))
      testCodeGen(
        "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x",
        Row(100, 50.5, 300, 100) :: Nil)
      // Aggregate with Code generation handling all null values
      testCodeGen(
        "SELECT sum('a'), avg('a'), count(null) FROM testData",
        Row(0, null, 0) :: Nil)
    } finally {
      // The temp table exists once this inner block is entered, so always drop it here.
      dropTempTable("testData3x")
    }
  } finally {
    // Restore the caller's codegen setting regardless of how the test body exited.
    setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
  }
}

test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),
Expand Down

0 comments on commit 7d7384c

Please sign in to comment.