[SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++
### What changes were proposed in this pull request?

Currently, Spark treats 0.0 and -0.0 as semantically equal, while still retaining the difference between them so that users can see -0.0 when displaying the data set.

The comparison expressions in Spark handle these special floating-point numbers and implement the correct semantics. However, Spark doesn't always use comparison expressions to compare values, so the special floating-point numbers must be normalized before comparison in these places:
1. GROUP BY
2. join keys
3. window partition keys

This PR fixes one more place that compares values without using comparison expressions: HyperLogLog++.
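
For context on why hashing and comparison can disagree: 0.0 and -0.0 compare as equal under IEEE 754 but have different bit patterns, and NaN has many valid encodings. A minimal standalone Scala sketch (editorial illustration, not part of the patch; the NaN constant is the one used in the new test below):

```scala
import java.lang.{Double => JDouble}

// Equal as values, distinct as bits: any bit-based hash or binary
// comparison sees 0.0 and -0.0 as two different values.
assert(0.0 == -0.0)
assert(JDouble.doubleToRawLongBits(0.0) != JDouble.doubleToRawLongBits(-0.0))

// IEEE 754 permits many NaN encodings; all behave as NaN semantically,
// but their raw bits differ from the canonical Double.NaN.
val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
assert(specialNaN.isNaN)
assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
```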

### Why are the changes needed?

Fix the query result: HyperLogLog++ hashes values directly, so without normalization 0.0 and -0.0, or NaNs with different bit patterns, are counted as distinct values.
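
As a concrete illustration (an assumed repro, not taken from the PR): `approx_count_distinct` is backed by HyperLogLog++, so an unpatched build can report 2 distinct values here instead of 1.

```scala
// Hypothetical repro against an existing SparkSession `spark`:
val df = spark.sql(
  "SELECT approx_count_distinct(d) FROM VALUES (0.0D), (-0.0D) AS t(d)")
df.show()  // expected: 1 with this fix; an unpatched build may print 2
```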

### Does this PR introduce _any_ user-facing change?

Yes, the result of HyperLogLog++ becomes correct now.

### How was this patch tested?

A new test case, plus a few more test cases that already passed before this PR, to improve test coverage.

Closes #30673 from cloud-fan/bug.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit 6fd2345)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
cloud-fan authored and dongjoon-hyun committed Dec 8, 2020
1 parent 54a73ab commit 1093c0f
Showing 4 changed files with 144 additions and 23 deletions.
NormalizeFloatingNumbers.scala
@@ -143,6 +143,28 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
+
+  val FLOAT_NORMALIZER: Any => Any = (input: Any) => {
+    val f = input.asInstanceOf[Float]
+    if (f.isNaN) {
+      Float.NaN
+    } else if (f == -0.0f) {
+      0.0f
+    } else {
+      f
+    }
+  }
+
+  val DOUBLE_NORMALIZER: Any => Any = (input: Any) => {
+    val d = input.asInstanceOf[Double]
+    if (d.isNaN) {
+      Double.NaN
+    } else if (d == -0.0d) {
+      0.0d
+    } else {
+      d
+    }
+  }
 }
 
 case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -152,27 +174,8 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType))
 
   private lazy val normalizer: Any => Any = child.dataType match {
-    case FloatType => (input: Any) => {
-      val f = input.asInstanceOf[Float]
-      if (f.isNaN) {
-        Float.NaN
-      } else if (f == -0.0f) {
-        0.0f
-      } else {
-        f
-      }
-    }
-
-    case DoubleType => (input: Any) => {
-      val d = input.asInstanceOf[Double]
-      if (d.isNaN) {
-        Double.NaN
-      } else if (d == -0.0d) {
-        0.0d
-      } else {
-        d
-      }
-    }
+    case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
+    case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
   }
 
   override def nullSafeEval(input: Any): Any = {
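
A quick check of what the extracted normalizers do (editorial sketch, not part of the diff; it relies only on the `FLOAT_NORMALIZER`/`DOUBLE_NORMALIZER` values added above):

```scala
import java.lang.{Double => JDouble}
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.DOUBLE_NORMALIZER

// -0.0 is rewritten to +0.0, down to the exact bit pattern.
val zero = DOUBLE_NORMALIZER(-0.0d).asInstanceOf[Double]
assert(JDouble.doubleToRawLongBits(zero) == JDouble.doubleToRawLongBits(0.0d))

// Any NaN encoding collapses to the canonical Double.NaN.
val nan = DOUBLE_NORMALIZER(JDouble.longBitsToDouble(0x7ff1234512345678L)).asInstanceOf[Double]
assert(JDouble.doubleToRawLongBits(nan) == JDouble.doubleToRawLongBits(Double.NaN))
```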
HyperLogLogPlusPlusHelper.scala
@@ -22,6 +22,7 @@ import java.util
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.XxHash64Function
+import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
 import org.apache.spark.sql.types._
 
 // A helper class for HyperLogLogPlusPlus.
@@ -88,7 +89,12 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
    *
    * Variable names in the HLL++ paper match variable names in the code.
    */
-  def update(buffer: InternalRow, bufferOffset: Int, value: Any, dataType: DataType): Unit = {
+  def update(buffer: InternalRow, bufferOffset: Int, _value: Any, dataType: DataType): Unit = {
+    val value = dataType match {
+      case FloatType => FLOAT_NORMALIZER.apply(_value)
+      case DoubleType => DOUBLE_NORMALIZER.apply(_value)
+      case _ => _value
+    }
     // Create the hashed value 'x'.
     val x = XxHash64Function.hash(value, dataType, 42L)
 
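
Why `update` must normalize before hashing: the hash is computed from the value's bits, so without normalization the two zeros land in different HLL++ registers and are counted as two distinct values. A standalone sketch (editorial, using the same `XxHash64Function` and seed as `update` above):

```scala
import org.apache.spark.sql.catalyst.expressions.XxHash64Function
import org.apache.spark.sql.types.DoubleType

val h1 = XxHash64Function.hash(0.0d, DoubleType, 42L)
val h2 = XxHash64Function.hash(-0.0d, DoubleType, 42L)
assert(h1 != h2)  // unnormalized, HLL++ would see two distinct values
```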
PredicateSuite.scala
@@ -554,4 +554,94 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false)
     checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
   }
+
+  test("SPARK-32110: compare special double/float values in array") {
+    def createUnsafeDoubleArray(d: Double): Literal = {
+      Literal(UnsafeArrayData.fromPrimitiveArray(Array(d)), ArrayType(DoubleType))
+    }
+    def createSafeDoubleArray(d: Double): Literal = {
+      Literal(new GenericArrayData(Array(d)), ArrayType(DoubleType))
+    }
+    def createUnsafeFloatArray(d: Double): Literal = {
+      Literal(UnsafeArrayData.fromPrimitiveArray(Array(d.toFloat)), ArrayType(FloatType))
+    }
+    def createSafeFloatArray(d: Double): Literal = {
+      Literal(new GenericArrayData(Array(d.toFloat)), ArrayType(FloatType))
+    }
+    def checkExpr(
+        exprBuilder: (Expression, Expression) => Expression,
+        left: Double,
+        right: Double,
+        expected: Any): Unit = {
+      // test double
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleArray(left), createUnsafeDoubleArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleArray(left), createSafeDoubleArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeDoubleArray(left), createSafeDoubleArray(right)), expected)
+      // test float
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatArray(left), createUnsafeFloatArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatArray(left), createSafeFloatArray(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeFloatArray(left), createSafeFloatArray(right)), expected)
+    }
+
+    checkExpr(EqualTo, Double.NaN, Double.NaN, true)
+    checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
+    checkExpr(EqualTo, 0.0, -0.0, true)
+    checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
+    checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
+    checkExpr(GreaterThan, 0.0, -0.0, false)
+  }
+
+  test("SPARK-32110: compare special double/float values in struct") {
+    def createUnsafeDoubleRow(d: Double): Literal = {
+      val dt = new StructType().add("d", "double")
+      val converter = UnsafeProjection.create(dt)
+      val unsafeRow = converter.apply(InternalRow(d))
+      Literal(unsafeRow, dt)
+    }
+    def createSafeDoubleRow(d: Double): Literal = {
+      Literal(InternalRow(d), new StructType().add("d", "double"))
+    }
+    def createUnsafeFloatRow(d: Double): Literal = {
+      val dt = new StructType().add("f", "float")
+      val converter = UnsafeProjection.create(dt)
+      val unsafeRow = converter.apply(InternalRow(d.toFloat))
+      Literal(unsafeRow, dt)
+    }
+    def createSafeFloatRow(d: Double): Literal = {
+      Literal(InternalRow(d.toFloat), new StructType().add("f", "float"))
+    }
+    def checkExpr(
+        exprBuilder: (Expression, Expression) => Expression,
+        left: Double,
+        right: Double,
+        expected: Any): Unit = {
+      // test double
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleRow(left), createUnsafeDoubleRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeDoubleRow(left), createSafeDoubleRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeDoubleRow(left), createSafeDoubleRow(right)), expected)
+      // test float
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatRow(left), createUnsafeFloatRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createUnsafeFloatRow(left), createSafeFloatRow(right)), expected)
+      checkEvaluation(
+        exprBuilder(createSafeFloatRow(left), createSafeFloatRow(right)), expected)
+    }
+
+    checkExpr(EqualTo, Double.NaN, Double.NaN, true)
+    checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
+    checkExpr(EqualTo, 0.0, -0.0, true)
+    checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
+    checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
+    checkExpr(GreaterThan, 0.0, -0.0, false)
+  }
 }
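
The tests deliberately cross "unsafe" (binary `UnsafeArrayData`/`UnsafeRow`) and "safe" (object-based `GenericArrayData`/`InternalRow`) operands, since equality on nested types must normalize special values regardless of the physical encoding. A standalone sketch of the two encodings (editorial illustration, not from the diff):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.GenericArrayData

// Same logical array, two physical layouts.
val unsafe = UnsafeArrayData.fromPrimitiveArray(Array(0.0, -0.0))
val safe = new GenericArrayData(Array[Any](0.0, -0.0))
assert(unsafe.getDouble(1) == safe.getDouble(1))  // element values agree
```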
HyperLogLogPlusPlusSuite.scala
@@ -17,14 +17,15 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import java.lang.{Double => JDouble}
 import java.util.Random
 
 import scala.collection.mutable
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow}
-import org.apache.spark.sql.types.{DataType, IntegerType}
+import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType}
 
 class HyperLogLogPlusPlusSuite extends SparkFunSuite {
 
@@ -153,4 +154,25 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
     // Check if the buffers are equal.
     assert(buffer2 == buffer1a, "Buffers should be equal")
   }
+
+  test("SPARK-32110: add 0.0 and -0.0") {
+    val (hll, input, buffer) = createEstimator(0.05, DoubleType)
+    input.setDouble(0, 0.0)
+    hll.update(buffer, input)
+    input.setDouble(0, -0.0)
+    hll.update(buffer, input)
+    evaluateEstimate(hll, buffer, 1);
+  }
+
+  test("SPARK-32110: add NaN") {
+    val (hll, input, buffer) = createEstimator(0.05, DoubleType)
+    input.setDouble(0, Double.NaN)
+    hll.update(buffer, input)
+    val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
+    assert(JDouble.isNaN(specialNaN))
+    assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
+    input.setDouble(0, specialNaN)
+    hll.update(buffer, input)
+    evaluateEstimate(hll, buffer, 1);
+  }
 }
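
For reference on the `0x7ff1234512345678L` constant: a 64-bit double is NaN exactly when its exponent bits are all ones and its mantissa is non-zero, so this payload yields a NaN that differs bit-wise from the canonical `Double.NaN`. A standalone check (editorial illustration, not part of the suite):

```scala
val bits = 0x7ff1234512345678L
assert(((bits >>> 52) & 0x7ffL) == 0x7ffL)  // exponent bits all ones
assert((bits & 0xfffffffffffffL) != 0L)     // non-zero mantissa => NaN, not infinity
assert(java.lang.Double.longBitsToDouble(bits).isNaN)
```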
