[SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++ #30673

Closed · wants to merge 1 commit
Changes from all commits
@@ -143,6 +143,28 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

case _ => throw new IllegalStateException(s"fail to normalize $expr")
}

val FLOAT_NORMALIZER: Any => Any = (input: Any) => {
Member: Not sure, can this just be a def?

Contributor Author: This is stateless, and being a val is more efficient.

val f = input.asInstanceOf[Float]
if (f.isNaN) {
Member: I think this check isn't necessary, as NaN won't equal -0.0f, so it will be returned on line 154 anyway. Or am I missing that there are different NaNs and this is normalizing them too?

Contributor Author: This is copied from the existing code. NaN is not a single value, and we need to make sure all NaN values have the same binary representation in Spark's unsafe row format. (An illustrative sketch follows this block of added code.)

Float.NaN
} else if (f == -0.0f) {
0.0f
} else {
f
}
}

val DOUBLE_NORMALIZER: Any => Any = (input: Any) => {
val d = input.asInstanceOf[Double]
if (d.isNaN) {
Double.NaN
} else if (d == -0.0d) {
0.0d
} else {
d
}
}
}
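To illustrate the reply above: several different bit patterns are valid NaNs, and they all satisfy isNaN even though their raw bits differ. A minimal, self-contained sketch (not part of the PR; the object and method names are made up) of how the normalization collapses them to one canonical representation:

import java.lang.{Float => JFloat}

object NaNBitPatternSketch {
  // Same logic as FLOAT_NORMALIZER above, inlined for illustration.
  def normalize(f: Float): Float =
    if (f.isNaN) Float.NaN else if (f == -0.0f) 0.0f else f

  def main(args: Array[String]): Unit = {
    // A quiet NaN with a non-canonical payload: still NaN, but different raw bits.
    val oddNaN = JFloat.intBitsToFloat(0x7fc01234)
    assert(oddNaN.isNaN)
    assert(JFloat.floatToRawIntBits(oddNaN) != JFloat.floatToRawIntBits(Float.NaN))
    // After normalization, every NaN shares one bit pattern and -0.0f becomes 0.0f,
    // so byte-wise comparisons of unsafe rows behave consistently.
    assert(JFloat.floatToRawIntBits(normalize(oddNaN)) == JFloat.floatToRawIntBits(Float.NaN))
    assert(JFloat.floatToRawIntBits(normalize(-0.0f)) == JFloat.floatToRawIntBits(0.0f))
  }
}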

case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes {
@@ -152,27 +174,8 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType))

private lazy val normalizer: Any => Any = child.dataType match {
case FloatType => (input: Any) => {
val f = input.asInstanceOf[Float]
if (f.isNaN) {
Float.NaN
} else if (f == -0.0f) {
0.0f
} else {
f
}
}

case DoubleType => (input: Any) => {
val d = input.asInstanceOf[Double]
if (d.isNaN) {
Double.NaN
} else if (d == -0.0d) {
0.0d
} else {
d
}
}
case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
}

override def nullSafeEval(input: Any): Any = {
@@ -22,6 +22,7 @@ import java.util

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.XxHash64Function
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers.{DOUBLE_NORMALIZER, FLOAT_NORMALIZER}
import org.apache.spark.sql.types._

// A helper class for HyperLogLogPlusPlus.
@@ -88,7 +89,12 @@ class HyperLogLogPlusPlusHelper(relativeSD: Double) extends Serializable {
*
* Variable names in the HLL++ paper match variable names in the code.
*/
def update(buffer: InternalRow, bufferOffset: Int, value: Any, dataType: DataType): Unit = {
def update(buffer: InternalRow, bufferOffset: Int, _value: Any, dataType: DataType): Unit = {
val value = dataType match {
case FloatType => FLOAT_NORMALIZER.apply(_value)
case DoubleType => DOUBLE_NORMALIZER.apply(_value)
case _ => _value
}
// Create the hashed value 'x'.
val x = XxHash64Function.hash(value, dataType, 42L)

@@ -554,4 +554,94 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false)
checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
}

test("SPARK-32110: compare special double/float values in array") {
Contributor Author: The new tests here pass before this PR; I'm adding them to prove that nested 0.0/-0.0 is handled correctly. CodegenContext.genComp is very conservative and only takes the shortcut when both sides are in unsafe format and are binary-equal. 0.0 and -0.0 are not binary-equal, so the comparison falls back to the element-by-element path (see the sketch after this test).

def createUnsafeDoubleArray(d: Double): Literal = {
Literal(UnsafeArrayData.fromPrimitiveArray(Array(d)), ArrayType(DoubleType))
}
def createSafeDoubleArray(d: Double): Literal = {
Literal(new GenericArrayData(Array(d)), ArrayType(DoubleType))
}
def createUnsafeFloatArray(d: Double): Literal = {
Literal(UnsafeArrayData.fromPrimitiveArray(Array(d.toFloat)), ArrayType(FloatType))
}
def createSafeFloatArray(d: Double): Literal = {
Literal(new GenericArrayData(Array(d.toFloat)), ArrayType(FloatType))
}
def checkExpr(
exprBuilder: (Expression, Expression) => Expression,
left: Double,
right: Double,
expected: Any): Unit = {
// test double
checkEvaluation(
exprBuilder(createUnsafeDoubleArray(left), createUnsafeDoubleArray(right)), expected)
checkEvaluation(
exprBuilder(createUnsafeDoubleArray(left), createSafeDoubleArray(right)), expected)
checkEvaluation(
exprBuilder(createSafeDoubleArray(left), createSafeDoubleArray(right)), expected)
// test float
checkEvaluation(
exprBuilder(createUnsafeFloatArray(left), createUnsafeFloatArray(right)), expected)
checkEvaluation(
exprBuilder(createUnsafeFloatArray(left), createSafeFloatArray(right)), expected)
checkEvaluation(
exprBuilder(createSafeFloatArray(left), createSafeFloatArray(right)), expected)
}

checkExpr(EqualTo, Double.NaN, Double.NaN, true)
checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
checkExpr(EqualTo, 0.0, -0.0, true)
checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
checkExpr(GreaterThan, 0.0, -0.0, false)
}
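As the comment above notes, the binary shortcut in CodegenContext.genComp cannot apply to 0.0 vs -0.0. A small illustrative sketch (not from the PR) of the underlying fact: the two values are numerically equal but their IEEE 754 bit patterns differ, so byte-wise equality of the unsafe representations proves nothing and the comparison must inspect the elements.

import java.lang.{Double => JDouble}

object SignedZeroBitsSketch {
  def main(args: Array[String]): Unit = {
    // Numerically equal ...
    assert(0.0 == -0.0)
    // ... but the raw bit patterns differ: -0.0 has the sign bit set.
    assert(JDouble.doubleToRawLongBits(0.0) != JDouble.doubleToRawLongBits(-0.0))
    assert(JDouble.doubleToRawLongBits(-0.0) == 0x8000000000000000L)
  }
}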

test("SPARK-32110: compare special double/float values in struct") {
def createUnsafeDoubleRow(d: Double): Literal = {
val dt = new StructType().add("d", "double")
val converter = UnsafeProjection.create(dt)
val unsafeRow = converter.apply(InternalRow(d))
Literal(unsafeRow, dt)
}
def createSafeDoubleRow(d: Double): Literal = {
Literal(InternalRow(d), new StructType().add("d", "double"))
}
def createUnsafeFloatRow(d: Double): Literal = {
val dt = new StructType().add("f", "float")
val converter = UnsafeProjection.create(dt)
val unsafeRow = converter.apply(InternalRow(d.toFloat))
Literal(unsafeRow, dt)
}
def createSafeFloatRow(d: Double): Literal = {
Literal(InternalRow(d.toFloat), new StructType().add("f", "float"))
}
def checkExpr(
exprBuilder: (Expression, Expression) => Expression,
left: Double,
right: Double,
expected: Any): Unit = {
// test double
checkEvaluation(
exprBuilder(createUnsafeDoubleRow(left), createUnsafeDoubleRow(right)), expected)
checkEvaluation(
exprBuilder(createUnsafeDoubleRow(left), createSafeDoubleRow(right)), expected)
checkEvaluation(
exprBuilder(createSafeDoubleRow(left), createSafeDoubleRow(right)), expected)
// test float
checkEvaluation(
exprBuilder(createUnsafeFloatRow(left), createUnsafeFloatRow(right)), expected)
checkEvaluation(
exprBuilder(createUnsafeFloatRow(left), createSafeFloatRow(right)), expected)
checkEvaluation(
exprBuilder(createSafeFloatRow(left), createSafeFloatRow(right)), expected)
}

checkExpr(EqualTo, Double.NaN, Double.NaN, true)
checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false)
checkExpr(EqualTo, 0.0, -0.0, true)
checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true)
checkExpr(GreaterThan, Double.NaN, Double.NaN, false)
checkExpr(GreaterThan, 0.0, -0.0, false)
}
}
@@ -17,14 +17,15 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import java.lang.{Double => JDouble}
import java.util.Random

import scala.collection.mutable

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow}
import org.apache.spark.sql.types.{DataType, IntegerType}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType}

class HyperLogLogPlusPlusSuite extends SparkFunSuite {

@@ -153,4 +154,25 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite {
// Check if the buffers are equal.
assert(buffer2 == buffer1a, "Buffers should be equal")
}

test("SPARK-32110: add 0.0 and -0.0") {
val (hll, input, buffer) = createEstimator(0.05, DoubleType)
input.setDouble(0, 0.0)
hll.update(buffer, input)
input.setDouble(0, -0.0)
hll.update(buffer, input)
evaluateEstimate(hll, buffer, 1);
}
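For context, a hedged end-to-end sketch of what this unit test exercises at the API level (assumes a local SparkSession; the object name is made up). approx_count_distinct is the user-facing aggregate backed by HyperLogLogPlusPlus, so with the normalization in this PR, 0.0 and -0.0 should count as a single distinct value:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.approx_count_distinct

object ApproxCountDistinctSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("hll-sketch").getOrCreate()
    import spark.implicits._
    // 0.0 and -0.0 differ bit-wise; without normalization they hash differently,
    // so HyperLogLog++ would see two distinct values.
    Seq(0.0, -0.0).toDF("col")
      .select(approx_count_distinct($"col"))
      .show()  // expected to print 1 with this change
    spark.stop()
  }
}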

test("SPARK-32110: add NaN") {
Contributor Author: This test passes before this PR, as our hash implementation returns the same hash code for all NaN values. I'm adding it just to make the set of test cases complete. (A complementary sketch follows the end of this suite.)

val (hll, input, buffer) = createEstimator(0.05, DoubleType)
input.setDouble(0, Double.NaN)
hll.update(buffer, input)
val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
assert(JDouble.isNaN(specialNaN))
assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
input.setDouble(0, specialNaN)
hll.update(buffer, input)
evaluateEstimate(hll, buffer, 1);
}
}
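As mentioned in the comment on the NaN test above, a complementary sketch (illustrative only; it relies on that comment's claim that the hash canonicalizes NaN, and on the XxHash64Function.hash call signature used elsewhere in this diff) that checks the hash function directly instead of going through the estimator:

import java.lang.{Double => JDouble}

import org.apache.spark.sql.catalyst.expressions.XxHash64Function
import org.apache.spark.sql.types.DoubleType

object NaNHashSketch {
  def main(args: Array[String]): Unit = {
    val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
    // Two NaNs with different raw bit patterns ...
    assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
    // ... are expected to produce the same 64-bit hash, which is why the estimate
    // in the test above stays at 1 even without this PR's normalization.
    assert(XxHash64Function.hash(Double.NaN, DoubleType, 42L) ==
      XxHash64Function.hash(specialNaN, DoubleType, 42L))
  }
}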