Skip to content

Commit

Permalink
test: test with the HalfFloatType from arrow (#503)
Browse files Browse the repository at this point in the history
Hi this PR adds an extra test case that tests `ArrowFloatToHalfFloat`
and `ArrowArrayViewGet{Double,Int,UInt}Unsafe` with the `HalfFloatType`
from arrow.


https://github.com/apache/arrow/blob/255dbf990c3d3e5fb1270a2a11efe0af2be195ab/cpp/src/arrow/type.h#L704-L713

And sorry that I didn't know there's a `HalfFloatType` in arrow-cpp when
I was doing #501 🥲.

~Sadly that we cannot simply append
`TestGetFromNumericArrayView<HalfFloatType>();` at the end of the
`ArrayViewTestGetNumeric` test suite because we have to convert floats
to half-floats using `ArrowFloatToHalfFloat` before calling
`builder.Append` (otherwise we'll get weird values back in the
subsequent `ArrowArrayViewGet{Double,Int,UInt}Unsafe` calls).~

Added some simple C++ magic and we can simply append
`TestGetFromNumericArrayView<HalfFloatType>();` at the end of the
`ArrayViewTestGetNumeric` test suite now.
  • Loading branch information
cocoa-xu authored Jun 4, 2024
1 parent 9410bd3 commit d917f29
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
24 changes: 20 additions & 4 deletions src/nanoarrow/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <gtest/gtest.h>
#include <cmath>
#include <cstdint>
#include <type_traits>

#include <arrow/array.h>
#include <arrow/array/builder_binary.h>
Expand Down Expand Up @@ -2423,6 +2424,20 @@ TEST(ArrayTest, ArrayViewTestSparseUnionGet) {
ArrowArrayRelease(&array);
}

template <
typename TypeClass, typename ValueType,
typename std::enable_if<std::is_same_v<TypeClass, HalfFloatType>, bool>::type = true>
auto transform_value(ValueType t) -> uint16_t {
return ArrowFloatToHalfFloat(t);
}

template <
typename TypeClass, typename ValueType,
typename std::enable_if<!std::is_same_v<TypeClass, HalfFloatType>, bool>::type = true>
auto transform_value(ValueType t) -> ValueType {
return t;
}

template <typename TypeClass>
void TestGetFromNumericArrayView() {
struct ArrowArray array;
Expand All @@ -2434,9 +2449,9 @@ void TestGetFromNumericArrayView() {

// Array with nulls
auto builder = NumericBuilder<TypeClass>();
ARROW_EXPECT_OK(builder.Append(1));
ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(1)));
ARROW_EXPECT_OK(builder.AppendNulls(2));
ARROW_EXPECT_OK(builder.Append(4));
ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(4)));
auto maybe_arrow_array = builder.Finish();
ARROW_EXPECT_OK(maybe_arrow_array);
auto arrow_array = maybe_arrow_array.ValueUnsafe();
Expand Down Expand Up @@ -2467,8 +2482,8 @@ void TestGetFromNumericArrayView() {

// Array without nulls (Arrow does not allocate the validity buffer)
builder = NumericBuilder<TypeClass>();
ARROW_EXPECT_OK(builder.Append(1));
ARROW_EXPECT_OK(builder.Append(2));
ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(1)));
ARROW_EXPECT_OK(builder.Append(transform_value<TypeClass>(2)));
maybe_arrow_array = builder.Finish();
ARROW_EXPECT_OK(maybe_arrow_array);
arrow_array = maybe_arrow_array.ValueUnsafe();
Expand Down Expand Up @@ -2504,6 +2519,7 @@ TEST(ArrayViewTest, ArrayViewTestGetNumeric) {
TestGetFromNumericArrayView<UInt32Type>();
TestGetFromNumericArrayView<DoubleType>();
TestGetFromNumericArrayView<FloatType>();
TestGetFromNumericArrayView<HalfFloatType>();
}

TEST(ArrayViewTest, ArrayViewTestGetFloat16) {
Expand Down
3 changes: 3 additions & 0 deletions src/nanoarrow/nanoarrow_testing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ class TestingJSONWriter {
}
break;

case NANOARROW_TYPE_HALF_FLOAT:
case NANOARROW_TYPE_FLOAT:
case NANOARROW_TYPE_DOUBLE: {
// JSON number to float_precision_ decimal places
Expand Down Expand Up @@ -2146,6 +2147,8 @@ class TestingJSONReader {
case NANOARROW_TYPE_UINT64:
return SetBufferInt<uint64_t, uint64_t>(data, buffer, error);

case NANOARROW_TYPE_HALF_FLOAT:
return SetBufferFloatingPoint<float>(data, buffer, error);
case NANOARROW_TYPE_FLOAT:
return SetBufferFloatingPoint<float>(data, buffer, error);
case NANOARROW_TYPE_DOUBLE:
Expand Down
4 changes: 2 additions & 2 deletions src/nanoarrow/utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,8 +547,8 @@ TEST(DecimalTest, DecimalRoundtripBitshiftTest) {
// https://github.com/apache/arrow/blob/main/go/arrow/float16/float16_test.go
TEST(HalfFloatTest, FloatAndHalfFloatRoundTrip) {
uint16_t cases_bits[] = {
0x8000, 0x7c00, 0xfc00, 0x3c00, 0x4000, 0xc000,
+0x0000, 0x5b8f, 0xdb8f, 0x48c8, 0xc8c8,
0x8000, 0x7c00, 0xfc00, 0x3c00, 0x4000, 0xc000,
0x0000, 0x5b8f, 0xdb8f, 0x48c8, 0xc8c8,
};
float cases_float[] = {
-0.0, INFINITY, -INFINITY, 1, 2, -2, 0, 241.875, -241.875, 9.5625, -9.5625,
Expand Down

0 comments on commit d917f29

Please sign in to comment.