-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-32110][SQL] normalize special floating numbers in HyperLogLog++ #30673
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -143,6 +143,28 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
|
||
case _ => throw new IllegalStateException(s"fail to normalize $expr") | ||
} | ||
|
||
val FLOAT_NORMALIZER: Any => Any = (input: Any) => { | ||
val f = input.asInstanceOf[Float] | ||
if (f.isNaN) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this check isn't necessary, as NaN won't equal -0.0f, so it will be returned on line 154 anyway. Or am I missing that there are different NaNs and this is normalizing them too? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is copied from the existing code. NaN is not a single value and we need to make sure all NaN values have the same binary representation in Spark unsafe row. |
||
Float.NaN | ||
} else if (f == -0.0f) { | ||
0.0f | ||
} else { | ||
f | ||
} | ||
} | ||
|
||
val DOUBLE_NORMALIZER: Any => Any = (input: Any) => { | ||
val d = input.asInstanceOf[Double] | ||
if (d.isNaN) { | ||
Double.NaN | ||
} else if (d == -0.0d) { | ||
0.0d | ||
} else { | ||
d | ||
} | ||
} | ||
} | ||
|
||
case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with ExpectsInputTypes { | ||
|
@@ -152,27 +174,8 @@ case class NormalizeNaNAndZero(child: Expression) extends UnaryExpression with E | |
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(FloatType, DoubleType)) | ||
|
||
private lazy val normalizer: Any => Any = child.dataType match { | ||
case FloatType => (input: Any) => { | ||
val f = input.asInstanceOf[Float] | ||
if (f.isNaN) { | ||
Float.NaN | ||
} else if (f == -0.0f) { | ||
0.0f | ||
} else { | ||
f | ||
} | ||
} | ||
|
||
case DoubleType => (input: Any) => { | ||
val d = input.asInstanceOf[Double] | ||
if (d.isNaN) { | ||
Double.NaN | ||
} else if (d == -0.0d) { | ||
0.0d | ||
} else { | ||
d | ||
} | ||
} | ||
case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER | ||
case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER | ||
} | ||
|
||
override def nullSafeEval(input: Any): Any = { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -554,4 +554,94 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { | |
checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false) | ||
checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false) | ||
} | ||
|
||
test("SPARK-32110: compare special double/float values in array") { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The new tests here pass before this PR. I'm adding them to prove that nested 0.0/-0.0 is fine. |
||
def createUnsafeDoubleArray(d: Double): Literal = { | ||
Literal(UnsafeArrayData.fromPrimitiveArray(Array(d)), ArrayType(DoubleType)) | ||
} | ||
def createSafeDoubleArray(d: Double): Literal = { | ||
Literal(new GenericArrayData(Array(d)), ArrayType(DoubleType)) | ||
} | ||
def createUnsafeFloatArray(d: Double): Literal = { | ||
Literal(UnsafeArrayData.fromPrimitiveArray(Array(d.toFloat)), ArrayType(FloatType)) | ||
} | ||
def createSafeFloatArray(d: Double): Literal = { | ||
Literal(new GenericArrayData(Array(d.toFloat)), ArrayType(FloatType)) | ||
} | ||
def checkExpr( | ||
exprBuilder: (Expression, Expression) => Expression, | ||
left: Double, | ||
right: Double, | ||
expected: Any): Unit = { | ||
// test double | ||
checkEvaluation( | ||
exprBuilder(createUnsafeDoubleArray(left), createUnsafeDoubleArray(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createUnsafeDoubleArray(left), createSafeDoubleArray(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createSafeDoubleArray(left), createSafeDoubleArray(right)), expected) | ||
// test float | ||
checkEvaluation( | ||
exprBuilder(createUnsafeFloatArray(left), createUnsafeFloatArray(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createUnsafeFloatArray(left), createSafeFloatArray(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createSafeFloatArray(left), createSafeFloatArray(right)), expected) | ||
} | ||
|
||
checkExpr(EqualTo, Double.NaN, Double.NaN, true) | ||
checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false) | ||
checkExpr(EqualTo, 0.0, -0.0, true) | ||
checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true) | ||
checkExpr(GreaterThan, Double.NaN, Double.NaN, false) | ||
checkExpr(GreaterThan, 0.0, -0.0, false) | ||
} | ||
|
||
test("SPARK-32110: compare special double/float values in struct") { | ||
def createUnsafeDoubleRow(d: Double): Literal = { | ||
val dt = new StructType().add("d", "double") | ||
val converter = UnsafeProjection.create(dt) | ||
val unsafeRow = converter.apply(InternalRow(d)) | ||
Literal(unsafeRow, dt) | ||
} | ||
def createSafeDoubleRow(d: Double): Literal = { | ||
Literal(InternalRow(d), new StructType().add("d", "double")) | ||
} | ||
def createUnsafeFloatRow(d: Double): Literal = { | ||
val dt = new StructType().add("f", "float") | ||
val converter = UnsafeProjection.create(dt) | ||
val unsafeRow = converter.apply(InternalRow(d.toFloat)) | ||
Literal(unsafeRow, dt) | ||
} | ||
def createSafeFloatRow(d: Double): Literal = { | ||
Literal(InternalRow(d.toFloat), new StructType().add("f", "float")) | ||
} | ||
def checkExpr( | ||
exprBuilder: (Expression, Expression) => Expression, | ||
left: Double, | ||
right: Double, | ||
expected: Any): Unit = { | ||
// test double | ||
checkEvaluation( | ||
exprBuilder(createUnsafeDoubleRow(left), createUnsafeDoubleRow(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createUnsafeDoubleRow(left), createSafeDoubleRow(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createSafeDoubleRow(left), createSafeDoubleRow(right)), expected) | ||
// test float | ||
checkEvaluation( | ||
exprBuilder(createUnsafeFloatRow(left), createUnsafeFloatRow(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createUnsafeFloatRow(left), createSafeFloatRow(right)), expected) | ||
checkEvaluation( | ||
exprBuilder(createSafeFloatRow(left), createSafeFloatRow(right)), expected) | ||
} | ||
|
||
checkExpr(EqualTo, Double.NaN, Double.NaN, true) | ||
checkExpr(EqualTo, Double.NaN, Double.PositiveInfinity, false) | ||
checkExpr(EqualTo, 0.0, -0.0, true) | ||
checkExpr(GreaterThan, Double.NaN, Double.PositiveInfinity, true) | ||
checkExpr(GreaterThan, Double.NaN, Double.NaN, false) | ||
checkExpr(GreaterThan, 0.0, -0.0, false) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,14 +17,15 @@ | |
|
||
package org.apache.spark.sql.catalyst.expressions.aggregate | ||
|
||
import java.lang.{Double => JDouble} | ||
import java.util.Random | ||
|
||
import scala.collection.mutable | ||
|
||
import org.apache.spark.SparkFunSuite | ||
import org.apache.spark.sql.catalyst.InternalRow | ||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificInternalRow} | ||
import org.apache.spark.sql.types.{DataType, IntegerType} | ||
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType} | ||
|
||
class HyperLogLogPlusPlusSuite extends SparkFunSuite { | ||
|
||
|
@@ -153,4 +154,25 @@ class HyperLogLogPlusPlusSuite extends SparkFunSuite { | |
// Check if the buffers are equal. | ||
assert(buffer2 == buffer1a, "Buffers should be equal") | ||
} | ||
|
||
test("SPARK-32110: add 0.0 and -0.0") { | ||
val (hll, input, buffer) = createEstimator(0.05, DoubleType) | ||
input.setDouble(0, 0.0) | ||
hll.update(buffer, input) | ||
input.setDouble(0, -0.0) | ||
hll.update(buffer, input) | ||
evaluateEstimate(hll, buffer, 1); | ||
} | ||
|
||
test("SPARK-32110: add NaN") { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This test passes before this PR, as our hash implementation returns the same hash code for all NaN values. I'm adding it just to make the test cases complete. |
||
val (hll, input, buffer) = createEstimator(0.05, DoubleType) | ||
input.setDouble(0, Double.NaN) | ||
hll.update(buffer, input) | ||
val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L) | ||
assert(JDouble.isNaN(specialNaN)) | ||
assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN)) | ||
input.setDouble(0, specialNaN) | ||
hll.update(buffer, input) | ||
evaluateEstimate(hll, buffer, 1); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure, can this just be a
def
? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is stateless and being a val is more efficient.