diff --git a/cpp/src/arrow/chunked_array.cc b/cpp/src/arrow/chunked_array.cc index 12937406e7800..c36b736d5d5df 100644 --- a/cpp/src/arrow/chunked_array.cc +++ b/cpp/src/arrow/chunked_array.cc @@ -86,7 +86,7 @@ Result> ChunkedArray::MakeEmpty( return std::make_shared(std::move(new_chunks)); } -bool ChunkedArray::Equals(const ChunkedArray& other) const { +bool ChunkedArray::Equals(const ChunkedArray& other, const EqualOptions& opts) const { if (length_ != other.length()) { return false; } @@ -102,9 +102,9 @@ bool ChunkedArray::Equals(const ChunkedArray& other) const { // the underlying data independently of the chunk size. return internal::ApplyBinaryChunked( *this, other, - [](const Array& left_piece, const Array& right_piece, - int64_t ARROW_ARG_UNUSED(position)) { - if (!left_piece.Equals(right_piece)) { + [&](const Array& left_piece, const Array& right_piece, + int64_t ARROW_ARG_UNUSED(position)) { + if (!left_piece.Equals(right_piece, opts)) { return Status::Invalid("Unequal piece"); } return Status::OK(); @@ -129,14 +129,15 @@ bool mayHaveNaN(const arrow::DataType& type) { } // namespace -bool ChunkedArray::Equals(const std::shared_ptr& other) const { +bool ChunkedArray::Equals(const std::shared_ptr& other, + const EqualOptions& opts) const { if (!other) { return false; } if (this == other.get() && !mayHaveNaN(*type_)) { return true; } - return Equals(*other.get()); + return Equals(*other.get(), opts); } bool ChunkedArray::ApproxEquals(const ChunkedArray& other, diff --git a/cpp/src/arrow/chunked_array.h b/cpp/src/arrow/chunked_array.h index 6ec7d11ac839d..5d300861d85c2 100644 --- a/cpp/src/arrow/chunked_array.h +++ b/cpp/src/arrow/chunked_array.h @@ -152,9 +152,11 @@ class ARROW_EXPORT ChunkedArray { /// /// Two chunked arrays can be equal only if they have equal datatypes. /// However, they may be equal even if they have different chunkings. - bool Equals(const ChunkedArray& other) const; + bool Equals(const ChunkedArray& other, + const EqualOptions& opts = EqualOptions::Defaults()) const; /// \brief Determine if two chunked arrays are equal. - bool Equals(const std::shared_ptr& other) const; + bool Equals(const std::shared_ptr& other, + const EqualOptions& opts = EqualOptions::Defaults()) const; /// \brief Determine if two chunked arrays approximately equal bool ApproxEquals(const ChunkedArray& other, const EqualOptions& = EqualOptions::Defaults()) const; diff --git a/cpp/src/arrow/testing/CMakeLists.txt b/cpp/src/arrow/testing/CMakeLists.txt index d5332405964ba..59825f0bf227a 100644 --- a/cpp/src/arrow/testing/CMakeLists.txt +++ b/cpp/src/arrow/testing/CMakeLists.txt @@ -19,4 +19,5 @@ arrow_install_all_headers("arrow/testing") if(ARROW_BUILD_TESTS) add_arrow_test(random_test) + add_arrow_test(gtest_util_test) endif() diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index a6dc1d59c67a9..5ef1820d5b581 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -145,42 +145,46 @@ void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool } void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual, - bool check_metadata) { + bool check_metadata, const EqualOptions& options) { AssertTsSame(expected, actual, [&](const RecordBatch& expected, const RecordBatch& actual) { - return expected.Equals(actual, check_metadata); + return expected.Equals(actual, check_metadata, options); }); } -void AssertBatchesApproxEqual(const RecordBatch& expected, const RecordBatch& actual) { +void AssertBatchesApproxEqual(const RecordBatch& expected, const RecordBatch& actual, + const EqualOptions& options) { AssertTsSame(expected, actual, [&](const RecordBatch& expected, const RecordBatch& actual) { - return expected.ApproxEquals(actual); + return expected.ApproxEquals(actual, options); }); } -void AssertChunkedEqual(const ChunkedArray& expected, const ChunkedArray& actual) { +void AssertChunkedEqual(const ChunkedArray& expected, const ChunkedArray& actual, + const EqualOptions& options) { ASSERT_EQ(expected.num_chunks(), actual.num_chunks()) << "# chunks unequal"; - if (!actual.Equals(expected)) { + if (!actual.Equals(expected, options)) { std::stringstream diff; for (int i = 0; i < actual.num_chunks(); ++i) { auto c1 = actual.chunk(i); auto c2 = expected.chunk(i); diff << "# chunk " << i << std::endl; - ARROW_IGNORE_EXPR(c1->Equals(c2, EqualOptions().diff_sink(&diff))); + ARROW_IGNORE_EXPR(c1->Equals(c2, options.diff_sink(&diff))); } FAIL() << diff.str(); } } -void AssertChunkedEqual(const ChunkedArray& actual, const ArrayVector& expected) { - AssertChunkedEqual(ChunkedArray(expected, actual.type()), actual); +void AssertChunkedEqual(const ChunkedArray& actual, const ArrayVector& expected, + const EqualOptions& options) { + AssertChunkedEqual(ChunkedArray(expected, actual.type()), actual, options); } -void AssertChunkedEquivalent(const ChunkedArray& expected, const ChunkedArray& actual) { +void AssertChunkedEquivalent(const ChunkedArray& expected, const ChunkedArray& actual, + const EqualOptions& options) { // XXX: AssertChunkedEqual in gtest_util.h does not permit the chunk layouts // to be different - if (!actual.Equals(expected)) { + if (!actual.Equals(expected, options)) { std::stringstream pp_expected; std::stringstream pp_actual; ::arrow::PrettyPrintOptions options(/*indent=*/2); @@ -321,21 +325,23 @@ ASSERT_EQUAL_IMPL(Field, Field, "fields") ASSERT_EQUAL_IMPL(Schema, Schema, "schemas") #undef ASSERT_EQUAL_IMPL -void AssertDatumsEqual(const Datum& expected, const Datum& actual, bool verbose) { +void AssertDatumsEqual(const Datum& expected, const Datum& actual, bool verbose, + const EqualOptions& options) { ASSERT_EQ(expected.kind(), actual.kind()) << "expected:" << expected.ToString() << " got:" << actual.ToString(); switch (expected.kind()) { case Datum::SCALAR: - AssertScalarsEqual(*expected.scalar(), *actual.scalar(), verbose); + AssertScalarsEqual(*expected.scalar(), *actual.scalar(), verbose, options); break; case Datum::ARRAY: { auto expected_array = expected.make_array(); auto actual_array = actual.make_array(); - AssertArraysEqual(*expected_array, *actual_array, verbose); + AssertArraysEqual(*expected_array, *actual_array, verbose, options); } break; case Datum::CHUNKED_ARRAY: - AssertChunkedEquivalent(*expected.chunked_array(), *actual.chunked_array()); + AssertChunkedEquivalent(*expected.chunked_array(), *actual.chunked_array(), + options); break; default: // TODO: Implement better print @@ -479,7 +485,7 @@ Result> PrintArrayDiff(const ChunkedArray& expected, } void AssertTablesEqual(const Table& expected, const Table& actual, bool same_chunk_layout, - bool combine_chunks) { + bool combine_chunks, const EqualOptions& options) { ASSERT_EQ(expected.num_columns(), actual.num_columns()); if (combine_chunks) { @@ -487,13 +493,13 @@ void AssertTablesEqual(const Table& expected, const Table& actual, bool same_chu ASSERT_OK_AND_ASSIGN(auto new_expected, expected.CombineChunks(pool)); ASSERT_OK_AND_ASSIGN(auto new_actual, actual.CombineChunks(pool)); - AssertTablesEqual(*new_expected, *new_actual, false, false); + AssertTablesEqual(*new_expected, *new_actual, false, false, options); return; } if (same_chunk_layout) { for (int i = 0; i < actual.num_columns(); ++i) { - AssertChunkedEqual(*expected.column(i), *actual.column(i)); + AssertChunkedEqual(*expected.column(i), *actual.column(i), options); } } else { std::stringstream ss; @@ -533,17 +539,18 @@ void CompareBatchWith(const RecordBatch& left, const RecordBatch& right, } void CompareBatch(const RecordBatch& left, const RecordBatch& right, - bool compare_metadata) { + bool compare_metadata, const EqualOptions& options) { return CompareBatchWith( left, right, compare_metadata, - [](const Array& left, const Array& right) { return left.Equals(right); }); + [&](const Array& left, const Array& right) { return left.Equals(right, options); }); } void ApproxCompareBatch(const RecordBatch& left, const RecordBatch& right, - bool compare_metadata) { - return CompareBatchWith( - left, right, compare_metadata, - [](const Array& left, const Array& right) { return left.ApproxEquals(right); }); + bool compare_metadata, const EqualOptions& options) { + return CompareBatchWith(left, right, compare_metadata, + [&](const Array& left, const Array& right) { + return left.ApproxEquals(right, options); + }); } std::shared_ptr TweakValidityBit(const std::shared_ptr& array, diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 641aae5a5e2e4..916067d85b753 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -221,18 +221,22 @@ ARROW_TESTING_EXPORT void AssertScalarsEqual( ARROW_TESTING_EXPORT void AssertScalarsApproxEqual( const Scalar& expected, const Scalar& actual, bool verbose = false, const EqualOptions& options = TestingEqualOptions()); -ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected, - const RecordBatch& actual, - bool check_metadata = false); -ARROW_TESTING_EXPORT void AssertBatchesApproxEqual(const RecordBatch& expected, - const RecordBatch& actual); -ARROW_TESTING_EXPORT void AssertChunkedEqual(const ChunkedArray& expected, - const ChunkedArray& actual); -ARROW_TESTING_EXPORT void AssertChunkedEqual(const ChunkedArray& actual, - const ArrayVector& expected); +ARROW_TESTING_EXPORT void AssertBatchesEqual( + const RecordBatch& expected, const RecordBatch& actual, bool check_metadata = false, + const EqualOptions& options = TestingEqualOptions()); +ARROW_TESTING_EXPORT void AssertBatchesApproxEqual( + const RecordBatch& expected, const RecordBatch& actual, + const EqualOptions& options = TestingEqualOptions()); +ARROW_TESTING_EXPORT void AssertChunkedEqual( + const ChunkedArray& expected, const ChunkedArray& actual, + const EqualOptions& options = TestingEqualOptions()); +ARROW_TESTING_EXPORT void AssertChunkedEqual( + const ChunkedArray& actual, const ArrayVector& expected, + const EqualOptions& options = TestingEqualOptions()); // Like ChunkedEqual, but permits different chunk layout -ARROW_TESTING_EXPORT void AssertChunkedEquivalent(const ChunkedArray& expected, - const ChunkedArray& actual); +ARROW_TESTING_EXPORT void AssertChunkedEquivalent( + const ChunkedArray& expected, const ChunkedArray& actual, + const EqualOptions& options = TestingEqualOptions()); ARROW_TESTING_EXPORT void AssertChunkedApproxEquivalent( const ChunkedArray& expected, const ChunkedArray& actual, const EqualOptions& options = TestingEqualOptions()); @@ -277,12 +281,13 @@ ARROW_TESTING_EXPORT void AssertSchemaNotEqual(const std::shared_ptr& lh ARROW_TESTING_EXPORT Result> PrintArrayDiff( const ChunkedArray& expected, const ChunkedArray& actual); -ARROW_TESTING_EXPORT void AssertTablesEqual(const Table& expected, const Table& actual, - bool same_chunk_layout = true, - bool flatten = false); +ARROW_TESTING_EXPORT void AssertTablesEqual( + const Table& expected, const Table& actual, bool same_chunk_layout = true, + bool flatten = false, const EqualOptions& options = TestingEqualOptions()); -ARROW_TESTING_EXPORT void AssertDatumsEqual(const Datum& expected, const Datum& actual, - bool verbose = false); +ARROW_TESTING_EXPORT void AssertDatumsEqual( + const Datum& expected, const Datum& actual, bool verbose = false, + const EqualOptions& options = TestingEqualOptions()); ARROW_TESTING_EXPORT void AssertDatumsApproxEqual( const Datum& expected, const Datum& actual, bool verbose = false, const EqualOptions& options = TestingEqualOptions()); @@ -296,12 +301,13 @@ void AssertNumericDataEqual(const C_TYPE* raw_data, } } -ARROW_TESTING_EXPORT void CompareBatch(const RecordBatch& left, const RecordBatch& right, - bool compare_metadata = true); +ARROW_TESTING_EXPORT void CompareBatch( + const RecordBatch& left, const RecordBatch& right, bool compare_metadata = true, + const EqualOptions& options = TestingEqualOptions()); -ARROW_TESTING_EXPORT void ApproxCompareBatch(const RecordBatch& left, - const RecordBatch& right, - bool compare_metadata = true); +ARROW_TESTING_EXPORT void ApproxCompareBatch( + const RecordBatch& left, const RecordBatch& right, bool compare_metadata = true, + const EqualOptions& options = TestingEqualOptions()); // Check if the padding of the buffers of the array is zero. // Also cause valgrind warnings if the padding bytes are uninitialized. diff --git a/cpp/src/arrow/testing/gtest_util_test.cc b/cpp/src/arrow/testing/gtest_util_test.cc new file mode 100644 index 0000000000000..14c17a972aa06 --- /dev/null +++ b/cpp/src/arrow/testing/gtest_util_test.cc @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/array.h" +#include "arrow/array/builder_decimal.h" +#include "arrow/datum.h" +#include "arrow/record_batch.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +// Test basic cases for contains NaN. +class TestAssertContainsNaN : public ::testing::Test {}; + +TEST_F(TestAssertContainsNaN, BatchesEqual) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + + auto expected = RecordBatchFromJSON(schema, + R"([{"a": 3, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": 4}, + {"a": NaN, "b": 6}, + {"a": 2, "b": 5}, + {"a": 1, "b": NaN}, + {"a": 1, "b": 3} + ])"); + auto actual = RecordBatchFromJSON(schema, + R"([{"a": 3, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": 4}, + {"a": NaN, "b": 6}, + {"a": 2, "b": 5}, + {"a": 1, "b": NaN}, + {"a": 1, "b": 3} + ])"); + ASSERT_BATCHES_EQUAL(*expected, *actual); + AssertBatchesApproxEqual(*expected, *actual); +} + +TEST_F(TestAssertContainsNaN, TableEqual) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + + auto expected = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": NaN, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": NaN}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"}); + auto actual = TableFromJSON(schema, {R"([{"a": null, "b": 5}, + {"a": NaN, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": NaN}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"}); + ASSERT_TABLES_EQUAL(*expected, *actual); +} + +TEST_F(TestAssertContainsNaN, ArrayEqual) { + auto expected = ArrayFromJSON(float64(), "[0, 1, 2, NaN]"); + auto actual = ArrayFromJSON(float64(), "[0, 1, 2, NaN]"); + AssertArraysEqual(*expected, *actual); +} + +TEST_F(TestAssertContainsNaN, ChunkedEqual) { + auto expected = ChunkedArrayFromJSON(float64(), { + "[null, 1]", + "[3, NaN, 2]", + "[NaN]", + }); + + auto actual = ChunkedArrayFromJSON(float64(), { + "[null, 1]", + "[3, NaN, 2]", + "[NaN]", + }); + AssertChunkedEqual(*expected, *actual); +} + +TEST_F(TestAssertContainsNaN, DatumEqual) { + // scalar + auto expected_scalar = ScalarFromJSON(float64(), "NaN"); + auto actual_scalar = ScalarFromJSON(float64(), "NaN"); + AssertDatumsEqual(expected_scalar, actual_scalar); + + // array + auto expected_array = ArrayFromJSON(float64(), "[3, NaN, 2, 1, 5]"); + auto actual_array = ArrayFromJSON(float64(), "[3, NaN, 2, 1, 5]"); + AssertDatumsEqual(expected_array, actual_array); + + // chunked array + auto expected_chunked = ChunkedArrayFromJSON(float64(), { + "[null, 1]", + "[3, NaN, 2]", + "[NaN]", + }); + + auto actual_chunked = ChunkedArrayFromJSON(float64(), { + "[null, 1]", + "[3, NaN, 2]", + "[NaN]", + }); + AssertDatumsEqual(expected_chunked, actual_chunked); +} + +} // namespace arrow