diff --git a/cpp/src/arrow/compute/kernels/vector_placement_test.cc b/cpp/src/arrow/compute/kernels/vector_placement_test.cc index 4e30bcf4a963b..09433e5bd45c6 100644 --- a/cpp/src/arrow/compute/kernels/vector_placement_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_placement_test.cc @@ -17,6 +17,7 @@ #include +#include "arrow/array/concatenate.h" #include "arrow/chunked_array.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/kernels/test_util.h" @@ -25,6 +26,15 @@ namespace arrow::compute { +namespace { + +static const std::vector> kIntegerTypes = { + int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()}; + +using SmallOutputTypes = ::testing::Types; + +} // namespace + // ---------------------------------------------------------------------- // ReverseIndices tests @@ -36,6 +46,42 @@ Result ReverseIndices(const Datum& indices, int64_t output_length, return ReverseIndices(indices, options); } +template +void TestReverseIndices(const InputString& indices_str, int64_t output_length, + const std::string& expected_str, + InputShapeFunc&& input_shape_func, bool validity_must_be_null) { + for (const auto& input_type : kIntegerTypes) { + auto indices = input_shape_func(input_type, indices_str); + ARROW_SCOPED_TRACE("Input type: " + input_type->ToString()); + for (const auto& output_type : kIntegerTypes) { + ARROW_SCOPED_TRACE("Output type: " + output_type->ToString()); + auto expected = ArrayFromJSON(output_type, expected_str); + ASSERT_OK_AND_ASSIGN(Datum result, + ReverseIndices(indices, output_length, output_type)); + AssertDatumsEqual(expected, result); + if (validity_must_be_null) { + ASSERT_FALSE(result.array()->HasValidityBitmap()); + } + } + } +} + +void TestReverseIndices(const std::string& indices_str, + const std::vector& indices_chunked_str, + int64_t output_length, const std::string& expected_str, + bool validity_must_be_null = false) { + { + ARROW_SCOPED_TRACE("Array"); + TestReverseIndices(indices_str, output_length, expected_str, ArrayFromJSON, + validity_must_be_null); + } + { + ARROW_SCOPED_TRACE("Chunked"); + TestReverseIndices(indices_chunked_str, output_length, expected_str, + ChunkedArrayFromJSON, validity_must_be_null); + } +} + } // namespace TEST(ReverseIndices, InvalidOutputType) { @@ -44,24 +90,17 @@ TEST(ReverseIndices, InvalidOutputType) { auto indices = ArrayFromJSON(int32(), "[]"); ASSERT_RAISES_WITH_MESSAGE( Invalid, "Invalid: Output type of reverse_indices must be integer, got float", - ReverseIndices(indices, 0, float32())); + ReverseIndices(indices, /*output_length=*/0, /*output_type=*/float32())); } { ARROW_SCOPED_TRACE("Output type string"); auto indices = ArrayFromJSON(int32(), "[]"); ASSERT_RAISES_WITH_MESSAGE( Invalid, "Invalid: Output type of reverse_indices must be integer, got string", - ReverseIndices(indices, 0, utf8())); + ReverseIndices(indices, /*output_length=*/0, /*output_type=*/utf8())); } } -namespace { - -static const std::vector> kIntegerTypes = { - int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()}; - -} // namespace - TEST(ReverseIndices, DefaultOptions) { { ARROW_SCOPED_TRACE("Default options values"); @@ -91,14 +130,14 @@ class TestReverseIndicesSmallOutputType : public ::testing::Test { void JustEnoughOutputType() { auto output_type = type_singleton(); - ReverseIndicesOptions options{1, output_type}; int64_t input_length = static_cast(std::numeric_limits::max()); auto expected = ConstantArrayGenerator::Numeric( - 1, static_cast(input_length - 1)); + /*size=*/1, /*value=*/static_cast(input_length - 1)); for (const auto& input_type : kIntegerTypes) { ARROW_SCOPED_TRACE("Input type: " + input_type->ToString()); auto indices = ConstantArrayGenerator::Zeroes(input_length, input_type); - ASSERT_OK_AND_ASSIGN(Datum result, ReverseIndices(indices, options)); + ASSERT_OK_AND_ASSIGN(Datum result, + ReverseIndices(indices, /*output_length=*/1, output_type)); AssertDatumsEqual(expected, result); } } @@ -109,18 +148,16 @@ class TestReverseIndicesSmallOutputType : public ::testing::Test { for (const auto& input_type : kIntegerTypes) { ARROW_SCOPED_TRACE("Input type: " + input_type->ToString()); auto indices = ConstantArrayGenerator::Zeroes(input_length, int64()); - ReverseIndicesOptions options{1, output_type}; ASSERT_RAISES_WITH_MESSAGE( Invalid, "Invalid: Output type " + output_type->ToString() + " of reverse_indices is insufficient to store indices of length " + std::to_string(input_length), - ReverseIndices(indices, options)); + ReverseIndices(indices, /*output_length=*/1, output_type)); } } }; -using SmallOutputTypes = ::testing::Types; TYPED_TEST_SUITE(TestReverseIndicesSmallOutputType, SmallOutputTypes); TYPED_TEST(TestReverseIndicesSmallOutputType, JustEnoughOutputType) { @@ -131,46 +168,6 @@ TYPED_TEST(TestReverseIndicesSmallOutputType, InsufficientOutputType) { this->InsufficientOutputType(); } -namespace { - -template -void TestReverseIndices(const InputString& indices_str, int64_t output_length, - const std::string& expected_str, - InputShapeFunc&& input_shape_func, bool validity_must_be_null) { - for (const auto& input_type : kIntegerTypes) { - auto indices = input_shape_func(input_type, indices_str); - ARROW_SCOPED_TRACE("Input type: " + input_type->ToString()); - for (const auto& output_type : kIntegerTypes) { - ARROW_SCOPED_TRACE("Output type: " + output_type->ToString()); - auto expected = ArrayFromJSON(output_type, expected_str); - ASSERT_OK_AND_ASSIGN(Datum result, - ReverseIndices(indices, output_length, output_type)); - AssertDatumsEqual(expected, result); - if (validity_must_be_null) { - ASSERT_FALSE(result.array()->HasValidityBitmap()); - } - } - } -} - -void TestReverseIndices(const std::string& indices_str, - const std::vector& indices_chunked_str, - int64_t output_length, const std::string& expected_str, - bool validity_must_be_null = false) { - { - ARROW_SCOPED_TRACE("Array"); - TestReverseIndices(indices_str, output_length, expected_str, ArrayFromJSON, - validity_must_be_null); - } - { - ARROW_SCOPED_TRACE("Chunked"); - TestReverseIndices(indices_chunked_str, output_length, expected_str, - ChunkedArrayFromJSON, validity_must_be_null); - } -} - -} // namespace - TEST(ReverseIndices, Basic) { { ARROW_SCOPED_TRACE("Basic"); @@ -246,6 +243,144 @@ TEST(ReverseIndices, Basic) { // ---------------------------------------------------------------------- // Permute tests +namespace { + +Result Permute(const Datum& values, const Datum& indices, int64_t output_length) { + PermuteOptions options{output_length}; + return Permute(values, indices, options); +} + +void TestPermuteAAA(const Array& values, const Array& indices, int64_t output_length, + const Array& expected) { + ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length)); + AssertDatumsEqual(expected, result); +} + +// void TestPermuteAAA(const std::shared_ptr& values_type, +// const std::string& values_str, const std::string& indices_str, +// int64_t output_length, const std::string& expected_str) { +// auto values = ArrayFromJSON(values_type, values_str); +// auto expected = ArrayFromJSON(values_type, expected_str); +// for (const auto& indices_type : kIntegerTypes) { +// ARROW_SCOPED_TRACE("Indices type: " + indices_type->ToString()); +// auto indices = ArrayFromJSON(indices_type, indices_str); +// TestPermuteAAA(*values, *indices, output_length, *expected); +// } +// } + +void TestPermuteCAC(const std::shared_ptr& values, const Array& indices, + int64_t output_length, + const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length)); + AssertDatumsEqual(expected, result); +} + +void TestPermuteCACWithArray(const std::shared_ptr& values, + const std::shared_ptr& indices, int64_t output_length, + const std::shared_ptr& expected) { + // PermuteAAA(Concat(V, V, V), I') == Concat(PermuteCAC([V, V, V], I')) + // where + // V = values + // I = indices + // I' = Concat(I + 2 * output_length, I, I + output_length) + auto values3 = ArrayVector{values, values, values}; + ASSERT_OK_AND_ASSIGN(auto concat_values3, Concatenate(values3)); + auto chunked_values3 = std::make_shared(values3); + + std::shared_ptr concat_indices3; + { + auto double_length = + MakeScalar(indices->type(), 2 * values->length())); + auto zero = MakeScalar(indices->type(), 0); + auto length = MakeScalar(indices->type(), static_cast(values->length())); + ASSERT_OK_AND_ASSIGN(auto indices_prefix, Add(indices, *double_length)); + ASSERT_OK_AND_ASSIGN(auto indices_middle, Add(indices, *zero)); + ASSERT_OK_AND_ASSIGN(auto indices_suffix, Add(indices, *length)); + auto indices3 = ArrayVector{ + indices_prefix.make_array(), + indices_middle.make_array(), + indices_suffix.make_array(), + }; + ASSERT_OK_AND_ASSIGN(concat_indices3, Concatenate(indices3)); + } + + ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length)); + AssertDatumsEqual(expected, result); +} + +// void TestPermuteACA(const Array& values, const std::shared_ptr& indices, +// int64_t output_length, const Array& expected) { +// ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length)); +// AssertDatumsEqual(expected, result); +// } + +} // namespace + +TEST(Permute, Invalid) { + { + ARROW_SCOPED_TRACE("Length mismatch"); + auto values = ArrayFromJSON(int32(), "[0, 1]"); + auto indices = ArrayFromJSON(int32(), "[0]"); + ASSERT_RAISES_WITH_MESSAGE( + Invalid, + "Invalid: Input and indices of permute must have the same length, got 2 and 1", + Permute(values, indices)); + } + { + ARROW_SCOPED_TRACE("Invalid input type"); + auto values = ArrayFromJSON(int32(), "[0]"); + auto indices = ArrayFromJSON(utf8(), R"(["a"])"); + ASSERT_RAISES_WITH_MESSAGE( + Invalid, "Invalid: Indices of permute must be of integer type, got string", + Permute(values, indices)); + } +} + +TEST(Permute, DefaultOptions) { + { + ARROW_SCOPED_TRACE("Default options values"); + PermuteOptions options; + ASSERT_EQ(options.output_length, -1); + } + { + ARROW_SCOPED_TRACE("Default options semantics"); + auto values = ArrayFromJSON(utf8(), R"(["a"])"); + for (const auto& indices_type : kIntegerTypes) { + ARROW_SCOPED_TRACE("Indices type: " + indices_type->ToString()); + auto indices = ArrayFromJSON(indices_type, "[0]"); + ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices)); + AssertDatumsEqual(result, values); + } + } +} + +template +class TestPermuteSmallIndicesTypes : public ::testing::Test { + protected: + using CType = typename TypeTraits::CType; + + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); + } + + void MaxIntegerIndex() { + auto values = ArrayFromJSON(utf8(), R"(["a"])"); + auto indices_type = type_singleton(); + int64_t max_integer = static_cast(std::numeric_limits::max()); + auto indices = ConstantArrayGenerator::Numeric(1, max_integer - 1); + ASSERT_OK_AND_ASSIGN(auto expected_prefix_nulls, + MakeArrayOfNull(utf8(), max_integer - 1)); + auto expected_suffix_value = ConstantArrayGenerator::String(/*size=*/1, "a"); + ASSERT_OK_AND_ASSIGN(auto expected, + Concatenate({expected_prefix_nulls, expected_suffix_value})); + TestPermuteAAA(*values, *indices, /*output_length=*/max_integer, *expected); + } +}; + +TYPED_TEST_SUITE(TestPermuteSmallIndicesTypes, SmallOutputTypes); + +TYPED_TEST(TestPermuteSmallIndicesTypes, MaxIntegerIndex) { this->MaxIntegerIndex(); } + TEST(Permute, Basic) { { auto values = ArrayFromJSON(int64(), "[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]");