Skip to content

Commit

Permalink
WIP permute tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 11, 2024
1 parent 034d3b7 commit 9f93e5c
Showing 1 changed file with 190 additions and 55 deletions.
245 changes: 190 additions & 55 deletions cpp/src/arrow/compute/kernels/vector_placement_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <gtest/gtest.h>

#include "arrow/array/concatenate.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/kernels/test_util.h"
Expand All @@ -25,6 +26,15 @@

namespace arrow::compute {

namespace {

// Integer types that the tests in this file loop over, both as the type of the
// indices input and as the requested output type.
const std::vector<std::shared_ptr<DataType>> kIntegerTypes = {
    int8(),  uint8(),  int16(), uint16(),
    int32(), uint32(), int64(), uint64()};

// Type list used to instantiate the typed test suites in this file.
using SmallOutputTypes =
    ::testing::Types<UInt8Type, UInt16Type, Int8Type, Int16Type>;

}  // namespace

// ----------------------------------------------------------------------
// ReverseIndices tests

Expand All @@ -36,6 +46,42 @@ Result<Datum> ReverseIndices(const Datum& indices, int64_t output_length,
return ReverseIndices(indices, options);
}

template <typename InputString, typename InputShapeFunc>
void TestReverseIndices(const InputString& indices_str, int64_t output_length,
const std::string& expected_str,
InputShapeFunc&& input_shape_func, bool validity_must_be_null) {
for (const auto& input_type : kIntegerTypes) {
auto indices = input_shape_func(input_type, indices_str);
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
for (const auto& output_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Output type: " + output_type->ToString());
auto expected = ArrayFromJSON(output_type, expected_str);
ASSERT_OK_AND_ASSIGN(Datum result,
ReverseIndices(indices, output_length, output_type));
AssertDatumsEqual(expected, result);
if (validity_must_be_null) {
ASSERT_FALSE(result.array()->HasValidityBitmap());
}
}
}
}

void TestReverseIndices(const std::string& indices_str,
const std::vector<std::string>& indices_chunked_str,
int64_t output_length, const std::string& expected_str,
bool validity_must_be_null = false) {
{
ARROW_SCOPED_TRACE("Array");
TestReverseIndices(indices_str, output_length, expected_str, ArrayFromJSON,
validity_must_be_null);
}
{
ARROW_SCOPED_TRACE("Chunked");
TestReverseIndices(indices_chunked_str, output_length, expected_str,
ChunkedArrayFromJSON, validity_must_be_null);
}
}

} // namespace

TEST(ReverseIndices, InvalidOutputType) {
Expand All @@ -44,24 +90,17 @@ TEST(ReverseIndices, InvalidOutputType) {
auto indices = ArrayFromJSON(int32(), "[]");
ASSERT_RAISES_WITH_MESSAGE(
Invalid, "Invalid: Output type of reverse_indices must be integer, got float",
ReverseIndices(indices, 0, float32()));
ReverseIndices(indices, /*output_length=*/0, /*output_type=*/float32()));
}
{
ARROW_SCOPED_TRACE("Output type string");
auto indices = ArrayFromJSON(int32(), "[]");
ASSERT_RAISES_WITH_MESSAGE(
Invalid, "Invalid: Output type of reverse_indices must be integer, got string",
ReverseIndices(indices, 0, utf8()));
ReverseIndices(indices, /*output_length=*/0, /*output_type=*/utf8()));
}
}

namespace {

// Integer types that the tests loop over as indices input / output types.
// NOTE(review): this is textually identical to the kIntegerTypes definition
// near the top of this view; it looks like the pre-move copy shown by the
// diff -- confirm that only one definition remains in the real file.
static const std::vector<std::shared_ptr<DataType>> kIntegerTypes = {
int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()};

} // namespace

TEST(ReverseIndices, DefaultOptions) {
{
ARROW_SCOPED_TRACE("Default options values");
Expand Down Expand Up @@ -91,14 +130,14 @@ class TestReverseIndicesSmallOutputType : public ::testing::Test {

// Verifies that ReverseIndices succeeds when the output type is just wide
// enough: the input holds max(CType) all-zero indices and the single expected
// output value is max(CType) - 1.
// NOTE(review): presumably the last input position is the one recorded for
// output slot 0 -- confirm against the reverse_indices kernel documentation.
void JustEnoughOutputType() {
  auto output_type = type_singleton();
  int64_t input_length = static_cast<int64_t>(std::numeric_limits<CType>::max());
  auto expected = ConstantArrayGenerator::Numeric<ArrowType>(
      /*size=*/1, /*value=*/static_cast<CType>(input_length - 1));
  for (const auto& input_type : kIntegerTypes) {
    ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
    auto indices = ConstantArrayGenerator::Zeroes(input_length, input_type);
    // Go through the (indices, output_length, output_type) helper, as the
    // other tests in this file do, instead of a hand-built options object.
    ASSERT_OK_AND_ASSIGN(Datum result,
                         ReverseIndices(indices, /*output_length=*/1, output_type));
    AssertDatumsEqual(expected, result);
  }
}
Expand All @@ -109,18 +148,16 @@ class TestReverseIndicesSmallOutputType : public ::testing::Test {
for (const auto& input_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
auto indices = ConstantArrayGenerator::Zeroes(input_length, int64());
ReverseIndicesOptions options{1, output_type};
ASSERT_RAISES_WITH_MESSAGE(
Invalid,
"Invalid: Output type " + output_type->ToString() +
" of reverse_indices is insufficient to store indices of length " +
std::to_string(input_length),
ReverseIndices(indices, options));
ReverseIndices(indices, /*output_length=*/1, output_type));
}
}
};

// NOTE(review): SmallOutputTypes is also declared near the top of this view;
// this line looks like the pre-move copy shown by the diff -- confirm that only
// one declaration remains in the real file.
using SmallOutputTypes = ::testing::Types<UInt8Type, UInt16Type, Int8Type, Int16Type>;
// Registers the TestReverseIndicesSmallOutputType fixture with the
// SmallOutputTypes type list.
TYPED_TEST_SUITE(TestReverseIndicesSmallOutputType, SmallOutputTypes);

TYPED_TEST(TestReverseIndicesSmallOutputType, JustEnoughOutputType) {
Expand All @@ -131,46 +168,6 @@ TYPED_TEST(TestReverseIndicesSmallOutputType, InsufficientOutputType) {
this->InsufficientOutputType();
}

namespace {

// Checks ReverseIndices for every (input type, output type) pair drawn from
// kIntegerTypes: builds the input with `input_shape_func`, compares the result
// with the array parsed from `expected_str`, and, when `validity_must_be_null`
// is set, requires that the result has no validity bitmap.
// NOTE(review): textually identical to the TestReverseIndices template earlier
// in this view; it looks like the pre-move copy shown by the diff -- confirm.
template <typename InputString, typename InputShapeFunc>
void TestReverseIndices(const InputString& indices_str, int64_t output_length,
const std::string& expected_str,
InputShapeFunc&& input_shape_func, bool validity_must_be_null) {
for (const auto& input_type : kIntegerTypes) {
auto indices = input_shape_func(input_type, indices_str);
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
for (const auto& output_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Output type: " + output_type->ToString());
auto expected = ArrayFromJSON(output_type, expected_str);
ASSERT_OK_AND_ASSIGN(Datum result,
ReverseIndices(indices, output_length, output_type));
AssertDatumsEqual(expected, result);
if (validity_must_be_null) {
ASSERT_FALSE(result.array()->HasValidityBitmap());
}
}
}
}

// Runs the templated TestReverseIndices twice: with indices built by
// ArrayFromJSON from `indices_str`, then with indices built by
// ChunkedArrayFromJSON from `indices_chunked_str`.
// NOTE(review): textually identical to the overload earlier in this view; it
// looks like the pre-move copy shown by the diff -- confirm.
void TestReverseIndices(const std::string& indices_str,
const std::vector<std::string>& indices_chunked_str,
int64_t output_length, const std::string& expected_str,
bool validity_must_be_null = false) {
{
ARROW_SCOPED_TRACE("Array");
TestReverseIndices(indices_str, output_length, expected_str, ArrayFromJSON,
validity_must_be_null);
}
{
ARROW_SCOPED_TRACE("Chunked");
TestReverseIndices(indices_chunked_str, output_length, expected_str,
ChunkedArrayFromJSON, validity_must_be_null);
}
}

} // namespace

TEST(ReverseIndices, Basic) {
{
ARROW_SCOPED_TRACE("Basic");
Expand Down Expand Up @@ -246,6 +243,144 @@ TEST(ReverseIndices, Basic) {
// ----------------------------------------------------------------------
// Permute tests

namespace {

// Test convenience overload: wraps `output_length` in a PermuteOptions and
// forwards to the options-taking Permute.
Result<Datum> Permute(const Datum& values, const Datum& indices, int64_t output_length) {
  return Permute(values, indices, PermuteOptions{output_length});
}

void TestPermuteAAA(const Array& values, const Array& indices, int64_t output_length,
const Array& expected) {
ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length));
AssertDatumsEqual(expected, result);
}

// void TestPermuteAAA(const std::shared_ptr<DataType>& values_type,
// const std::string& values_str, const std::string& indices_str,
// int64_t output_length, const std::string& expected_str) {
// auto values = ArrayFromJSON(values_type, values_str);
// auto expected = ArrayFromJSON(values_type, expected_str);
// for (const auto& indices_type : kIntegerTypes) {
// ARROW_SCOPED_TRACE("Indices type: " + indices_type->ToString());
// auto indices = ArrayFromJSON(indices_type, indices_str);
// TestPermuteAAA(*values, *indices, output_length, *expected);
// }
// }

void TestPermuteCAC(const std::shared_ptr<ChunkedArray>& values, const Array& indices,
int64_t output_length,
const std::shared_ptr<ChunkedArray>& expected) {
ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length));
AssertDatumsEqual(expected, result);
}

// Builds a tripled version of the inputs (values repeated three times, both
// concatenated and chunked, plus three offset copies of the indices) for the
// identity sketched below, then permutes and compares against `expected`.
//
// NOTE(review): work in progress. concat_values3, chunked_values3 and
// concat_indices3 are constructed but the final assertion still permutes the
// original `values` / `indices`. TODO: wire the tripled inputs into the check.
void TestPermuteCACWithArray(const std::shared_ptr<Array>& values,
                             const std::shared_ptr<Array>& indices, int64_t output_length,
                             const std::shared_ptr<Array>& expected) {
  // PermuteAAA(Concat(V, V, V), I') == Concat(PermuteCAC([V, V, V], I'))
  // where
  //   V = values
  //   I = indices
  //   I' = Concat(I + 2 * output_length, I, I + output_length)
  // NOTE(review): the offsets computed below use values->length(), not
  // output_length as written in the formula above -- confirm which is intended.
  auto values3 = ArrayVector{values, values, values};
  ASSERT_OK_AND_ASSIGN(auto concat_values3, Concatenate(values3));
  auto chunked_values3 = std::make_shared<ChunkedArray>(values3);

  std::shared_ptr<Array> concat_indices3;
  {
    // The offset scalars are created with the same type as the indices.
    auto double_length = MakeScalar(indices->type(), 2 * values->length());
    auto zero = MakeScalar(indices->type(), 0);
    auto length = MakeScalar(indices->type(), static_cast<int64_t>(values->length()));
    ASSERT_OK_AND_ASSIGN(auto indices_prefix, Add(indices, *double_length));
    ASSERT_OK_AND_ASSIGN(auto indices_middle, Add(indices, *zero));
    ASSERT_OK_AND_ASSIGN(auto indices_suffix, Add(indices, *length));
    auto indices3 = ArrayVector{
        indices_prefix.make_array(),
        indices_middle.make_array(),
        indices_suffix.make_array(),
    };
    ASSERT_OK_AND_ASSIGN(concat_indices3, Concatenate(indices3));
  }

  ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length));
  AssertDatumsEqual(expected, result);
}

// void TestPermuteACA(const Array& values, const std::shared_ptr<ChunkedArray>& indices,
// int64_t output_length, const Array& expected) {
// ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices, output_length));
// AssertDatumsEqual(expected, result);
// }

} // namespace

// Error paths of Permute: each scope feeds one malformed input pair and pins
// the exact Invalid status message.
TEST(Permute, Invalid) {
  {
    ARROW_SCOPED_TRACE("Length mismatch");
    // Two values but only one index.
    const auto two_values = ArrayFromJSON(int32(), "[0, 1]");
    const auto one_index = ArrayFromJSON(int32(), "[0]");
    const std::string want_message =
        "Invalid: Input and indices of permute must have the same length, got 2 and 1";
    ASSERT_RAISES_WITH_MESSAGE(Invalid, want_message, Permute(two_values, one_index));
  }
  {
    ARROW_SCOPED_TRACE("Invalid input type");
    // Indices given as strings instead of integers.
    const auto one_value = ArrayFromJSON(int32(), "[0]");
    const auto string_index = ArrayFromJSON(utf8(), R"(["a"])");
    const std::string want_message =
        "Invalid: Indices of permute must be of integer type, got string";
    ASSERT_RAISES_WITH_MESSAGE(Invalid, want_message, Permute(one_value, string_index));
  }
}

// Covers default-constructed PermuteOptions: first the default field value,
// then the behavior of Permute when no options are passed.
TEST(Permute, DefaultOptions) {
  {
    ARROW_SCOPED_TRACE("Default options values");
    PermuteOptions options;
    ASSERT_EQ(options.output_length, -1);
  }
  {
    ARROW_SCOPED_TRACE("Default options semantics");
    // A single value with index 0 is expected to come back unchanged, for
    // every integer indices type.
    auto values = ArrayFromJSON(utf8(), R"(["a"])");
    for (const auto& indices_type : kIntegerTypes) {
      ARROW_SCOPED_TRACE("Indices type: " + indices_type->ToString());
      auto indices = ArrayFromJSON(indices_type, "[0]");
      ASSERT_OK_AND_ASSIGN(Datum result, Permute(values, indices));
      // Expected value first, matching every other AssertDatumsEqual call in
      // this file.
      AssertDatumsEqual(/*expected=*/values, /*actual=*/result);
    }
  }
}

// Typed fixture that exercises Permute with indices of the Arrow integer type
// `ArrowType`.
template <typename ArrowType>
class TestPermuteSmallIndicesTypes : public ::testing::Test {
 protected:
  using CType = typename TypeTraits<ArrowType>::CType;

  // DataType instance corresponding to ArrowType.
  std::shared_ptr<DataType> type_singleton() {
    return TypeTraits<ArrowType>::type_singleton();
  }

  // Permutes a single value "a" with the single index max(CType) - 1 and an
  // output length of max(CType); the expected output is max(CType) - 1 nulls
  // followed by "a".
  void MaxIntegerIndex() {
    auto values = ArrayFromJSON(utf8(), R"(["a"])");
    int64_t max_integer = static_cast<int64_t>(std::numeric_limits<CType>::max());
    // Explicit cast: the index value is computed as int64_t but stored as
    // CType.
    auto indices = ConstantArrayGenerator::Numeric<ArrowType>(
        /*size=*/1, /*value=*/static_cast<CType>(max_integer - 1));
    ASSERT_OK_AND_ASSIGN(auto expected_prefix_nulls,
                         MakeArrayOfNull(utf8(), max_integer - 1));
    auto expected_suffix_value = ConstantArrayGenerator::String(/*size=*/1, "a");
    ASSERT_OK_AND_ASSIGN(auto expected,
                         Concatenate({expected_prefix_nulls, expected_suffix_value}));
    TestPermuteAAA(*values, *indices, /*output_length=*/max_integer, *expected);
  }
};

// Registers the TestPermuteSmallIndicesTypes fixture with the SmallOutputTypes
// type list.
TYPED_TEST_SUITE(TestPermuteSmallIndicesTypes, SmallOutputTypes);

// Delegates to the fixture's MaxIntegerIndex() for each type in the list.
TYPED_TEST(TestPermuteSmallIndicesTypes, MaxIntegerIndex) { this->MaxIntegerIndex(); }

TEST(Permute, Basic) {
{
auto values = ArrayFromJSON(int64(), "[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]");
Expand Down

0 comments on commit 9f93e5c

Please sign in to comment.