Skip to content

Commit

Permalink
[SPARK-26021][SQL] replace minus zero with zero in Platform.putDouble…
Browse files Browse the repository at this point in the history
…/Float

GROUP BY treats -0.0 and 0.0 as different values, which is unlike Hive's behavior.
In addition current behavior with codegen is unpredictable (see example in JIRA ticket).

## What changes were proposed in this pull request?

In Platform.putDouble/Float(), check whether the value is -0.0 and, if so, replace it with 0.0.
This is used by UnsafeRow so it won't have -0.0 values.

## How was this patch tested?

Added tests

Closes #23043 from adoron/adoron-spark-26021-replace-minus-zero-with-zero.

Authored-by: Alon Doron <adoron@palantir.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 0ec7b99)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
Alon Doron authored and cloud-fan committed Nov 23, 2018
1 parent 8705a9d commit d63ab5a
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 13 deletions.
10 changes: 10 additions & 0 deletions common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ public static float getFloat(Object object, long offset) {
}

public static void putFloat(Object object, long offset, float value) {
  // Canonicalize the float before writing so equal values always share one
  // bit pattern inside UnsafeRow (SPARK-26021):
  //  - every NaN payload collapses to the canonical Float.NaN
  //  - -0.0f is rewritten as +0.0f ("value == 0.0f" is also true for +0.0f,
  //    where the assignment is a no-op)
  if (value != value) {
    // NaN is the only value for which (value != value) holds.
    value = Float.NaN;
  } else if (value == 0.0f) {
    value = 0.0f;
  }
  _UNSAFE.putFloat(object, offset, value);
}

Expand All @@ -128,6 +133,11 @@ public static double getDouble(Object object, long offset) {
}

public static void putDouble(Object object, long offset, double value) {
  // Canonicalize the double before writing so equal values always share one
  // bit pattern inside UnsafeRow (SPARK-26021):
  //  - every NaN payload collapses to the canonical Double.NaN
  //  - -0.0d is rewritten as +0.0d ("value == 0.0d" is also true for +0.0d,
  //    where the assignment is a no-op)
  if (value != value) {
    // NaN is the only value for which (value != value) holds.
    value = Double.NaN;
  } else if (value == 0.0d) {
    value = 0.0d;
  }
  _UNSAFE.putDouble(object, offset, value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,18 @@ public void heapMemoryReuse() {
Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7);
Assert.assertEquals(obj3, onheap4.getBaseObject());
}

// SPARK-26021: Platform.putDouble/putFloat must rewrite -0.0 as +0.0 so that
// UnsafeRow never stores the negative-zero bit pattern.
@Test
public void writeMinusZeroIsReplacedWithZero() {
  byte[] doubleBuf = new byte[Double.BYTES];
  byte[] floatBuf = new byte[Float.BYTES];
  Platform.putDouble(doubleBuf, Platform.BYTE_ARRAY_OFFSET, -0.0d);
  Platform.putFloat(floatBuf, Platform.BYTE_ARRAY_OFFSET, -0.0f);

  // Compare raw bit patterns: 0.0 == -0.0 is true under ==, so a plain
  // equality assert could never detect a surviving sign bit.
  long doubleBits = Double.doubleToLongBits(
      Platform.getDouble(doubleBuf, Platform.BYTE_ARRAY_OFFSET));
  int floatBits = Float.floatToIntBits(
      Platform.getFloat(floatBuf, Platform.BYTE_ARRAY_OFFSET));

  Assert.assertEquals(Double.doubleToLongBits(0.0d), doubleBits);
  Assert.assertEquals(Float.floatToIntBits(0.0f), floatBits);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,6 @@ public void setLong(int ordinal, long value) {
public void setDouble(int ordinal, double value) {
  // Validate the ordinal and clear its null bit before writing the value.
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  // Collapse every NaN payload to the single canonical NaN bit pattern.
  value = Double.isNaN(value) ? Double.NaN : value;
  Platform.putDouble(baseObject, getFieldOffset(ordinal), value);
}

Expand Down Expand Up @@ -255,9 +252,6 @@ public void setByte(int ordinal, byte value) {
public void setFloat(int ordinal, float value) {
  // Validate the ordinal and clear its null bit before writing the value.
  assertIndexIsValid(ordinal);
  setNotNullAt(ordinal);
  // Collapse every NaN payload to the single canonical NaN bit pattern.
  value = Float.isNaN(value) ? Float.NaN : value;
  Platform.putFloat(baseObject, getFieldOffset(ordinal), value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,16 +199,10 @@ protected final void writeLong(long offset, long value) {
}

protected final void writeFloat(long offset, float value) {
  // Normalize all NaN payloads to the canonical NaN before writing to the buffer.
  value = Float.isNaN(value) ? Float.NaN : value;
  Platform.putFloat(getBuffer(), offset, value);
}

protected final void writeDouble(long offset, double value) {
  // Normalize all NaN payloads to the canonical NaN before writing to the buffer.
  value = Double.isNaN(value) ? Double.NaN : value;
  Platform.putDouble(getBuffer(), offset, value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -727,4 +727,18 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
"type: GroupBy]"))
}

test("SPARK-26021: Double and Float 0.0/-0.0 should be equal when grouping") {
  val colName = "i"
  // -0.0 must land in the same group as 0.0 for both floating-point types,
  // so each aggregation should produce exactly one row counting all 3 values.
  val doubleRows = Seq(0.0d, -0.0d, 0.0d).toDF(colName).groupBy(colName).count().collect()
  val floatRows = Seq(0.0f, -0.0f, 0.0f).toDF(colName).groupBy(colName).count().collect()

  assert(doubleRows.length == 1)
  assert(floatRows.length == 1)
  // 0.0 == -0.0 is true under ==, so use compare() to also verify the
  // grouping key's sign bit was normalized to +0.0.
  assert(java.lang.Double.compare(doubleRows(0).getDouble(0), 0.0d) == 0)
  assert(java.lang.Float.compare(floatRows(0).getFloat(0), 0.0f) == 0)
  assert(doubleRows(0).getLong(1) == 3)
  assert(floatRows(0).getLong(1) == 3)
}
}
5 changes: 4 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ object QueryTest {
def prepareRow(row: Row): Row = {
Row.fromSeq(row.toSeq.map {
case null => null
case d: java.math.BigDecimal => BigDecimal(d)
case bd: java.math.BigDecimal => BigDecimal(bd)
// Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+
case seq: Seq[_] => seq.map {
case b: java.lang.Byte => b.byteValue
Expand All @@ -303,6 +303,9 @@ object QueryTest {
// Convert array to Seq for easy equality check.
case b: Array[_] => b.toSeq
case r: Row => prepareRow(r)
// spark treats -0.0 as 0.0
case d: Double if d == -0.0d => 0.0d
case f: Float if f == -0.0f => 0.0f
case o => o
})
}
Expand Down

0 comments on commit d63ab5a

Please sign in to comment.