Skip to content

Commit

Permalink
Do our best to avoid overflow in the avg() function.
Browse files Browse the repository at this point in the history
  • Loading branch information
egraldlo committed Jun 5, 2014
1 parent b77c19b commit 1153f75
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,10 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
override def toString = s"AVG($child)"

override def asPartial: SplitEvaluation = {
val partialSum = Alias(Sum(child), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
val castedSum = Cast(Sum(partialSum.toAttribute), dataType)
val castedCount = Cast(Sum(partialCount.toAttribute), dataType)
val partialSum = Alias(Sum(Cast(child, dataType)), "PartialSum")()
val partialCount = Alias(Cast(Count(child), dataType), "PartialCount")()
val castedSum = Sum(partialSum.toAttribute)
val castedCount = Sum(partialCount.toAttribute)

SplitEvaluation(
Divide(castedSum, castedCount),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ class SQLQuerySuite extends QueryTest {
2.0)
}

// AVG must stay correct when the Int inputs of a group are close to
// Int.MaxValue: the b = 1 group is expected to average 2147483645.0.
test("average overflow test") {
  val groupedAverages = sql("SELECT AVG(a),b FROM testData1 group by b")
  val expectedRows = Seq(
    (2147483645.0, 1),
    (2.0, 2))
  checkAnswer(groupedAverages, expectedRows)
}

test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
Expand Down
11 changes: 11 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ object TestData {
(1 to 100).map(i => TestData(i, i.toString)))
testData.registerAsTable("testData")

// Fixture for the AVG overflow test. The rows with b = 1 carry values just
// below Int.MaxValue (2147483647), so adding them as Ints would overflow;
// the rows with b = 2 are small control values.
case class TestData1(a: Int, b: Int)
val testData1: SchemaRDD = {
  val rows = Seq(
    TestData1(2147483644, 1),
    TestData1(1, 2),
    TestData1(2147483645, 1),
    TestData1(2, 2),
    TestData1(2147483646, 1),
    TestData1(3, 2))
  TestSQLContext.sparkContext.parallelize(rows)
}
testData1.registerAsTable("testData1")

case class TestData2(a: Int, b: Int)
val testData2: SchemaRDD =
TestSQLContext.sparkContext.parallelize(
Expand Down

0 comments on commit 1153f75

Please sign in to comment.