Skip to content

Commit

Permalink
Reverse indices tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 11, 2024
1 parent 4ea1465 commit cbdce2f
Showing 1 changed file with 178 additions and 68 deletions.
246 changes: 178 additions & 68 deletions cpp/src/arrow/compute/kernels/vector_placement_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,108 +28,218 @@ namespace arrow::compute {
// ----------------------------------------------------------------------
// ReverseIndices tests

TEST(ReverseIndices, Invalid) {
namespace {

Result<Datum> ReverseIndices(const Datum& indices, int64_t output_length,
std::shared_ptr<DataType> output_type) {
ReverseIndicesOptions options{output_length, std::move(output_type)};
return ReverseIndices(indices, options);
}

} // namespace

TEST(ReverseIndices, InvalidOutputType) {
{
ARROW_SCOPED_TRACE("Output type float");
auto indices = ArrayFromJSON(int32(), "[]");
ASSERT_RAISES_WITH_MESSAGE(
Invalid, "Invalid: Output type of reverse_indices must be integer, got float",
ReverseIndices(indices, 0, float32()));
}
{
ARROW_SCOPED_TRACE("Output type string");
auto indices = ArrayFromJSON(int32(), "[]");
ReverseIndicesOptions options{0, utf8()};
ASSERT_RAISES_WITH_MESSAGE(
Invalid, "Invalid: Output type of reverse_indices must be integer, got string",
CallFunction("reverse_indices", {indices}, &options));
ReverseIndices(indices, 0, utf8()));
}
}

namespace {

static const std::vector<std::shared_ptr<DataType>> kIntegerTypes = {
int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()};

} // namespace

TEST(ReverseIndices, DefaultOptions) {
{
ARROW_SCOPED_TRACE("Default options values");
ReverseIndicesOptions options;
ASSERT_EQ(options.output_length, -1);
ASSERT_EQ(options.output_type, nullptr);
}
{
auto indices = ArrayFromJSON(int32(), "[0]");
ReverseIndicesOptions options;
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
ASSERT_EQ(result.length(), 1);
ASSERT_EQ(result.type()->id(), Type::INT32);
ARROW_SCOPED_TRACE("Default options semantics");
for (const auto& input_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
auto indices = ArrayFromJSON(input_type, "[0]");
ASSERT_OK_AND_ASSIGN(Datum result, ReverseIndices(indices));
AssertDatumsEqual(result, indices);
}
}
}

TEST(ReverseIndices, Basic) {
template <typename ArrowType>
class TestReverseIndicesSmallOutputType : public ::testing::Test {
protected:
using CType = typename TypeTraits<ArrowType>::CType;

std::shared_ptr<DataType> type_singleton() {
return TypeTraits<ArrowType>::type_singleton();
}

void JustEnoughOutputType() {
auto output_type = type_singleton();
ReverseIndicesOptions options{1, output_type};
int64_t input_length = static_cast<int64_t>(std::numeric_limits<CType>::max());
auto expected = ConstantArrayGenerator::Numeric<ArrowType>(
1, static_cast<CType>(input_length - 1));
for (const auto& input_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
auto indices = ConstantArrayGenerator::Zeroes(input_length, input_type);
ASSERT_OK_AND_ASSIGN(Datum result, ReverseIndices(indices, options));
AssertDatumsEqual(expected, result);
}
}

void InsufficientOutputType() {
auto output_type = type_singleton();
int64_t input_length = static_cast<int64_t>(std::numeric_limits<CType>::max()) + 1;
for (const auto& input_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
auto indices = ConstantArrayGenerator::Zeroes(input_length, int64());
ReverseIndicesOptions options{1, output_type};
ASSERT_RAISES_WITH_MESSAGE(
Invalid,
"Invalid: Output type " + output_type->ToString() +
" of reverse_indices is insufficient to store indices of length " +
std::to_string(input_length),
ReverseIndices(indices, options));
}
}
};

using SmallOutputTypes = ::testing::Types<UInt8Type, UInt16Type, Int8Type, Int16Type>;
TYPED_TEST_SUITE(TestReverseIndicesSmallOutputType, SmallOutputTypes);

TYPED_TEST(TestReverseIndicesSmallOutputType, JustEnoughOutputType) {
this->JustEnoughOutputType();
}

TYPED_TEST(TestReverseIndicesSmallOutputType, InsufficientOutputType) {
this->InsufficientOutputType();
}

namespace {

template <typename InputString, typename InputShapeFunc>
void TestReverseIndices(const InputString& indices_str, int64_t output_length,
const std::string& expected_str,
InputShapeFunc&& input_shape_func, bool validity_must_be_null) {
for (const auto& input_type : kIntegerTypes) {
auto indices = input_shape_func(input_type, indices_str);
ARROW_SCOPED_TRACE("Input type: " + input_type->ToString());
for (const auto& output_type : kIntegerTypes) {
ARROW_SCOPED_TRACE("Output type: " + output_type->ToString());
auto expected = ArrayFromJSON(output_type, expected_str);
ASSERT_OK_AND_ASSIGN(Datum result,
ReverseIndices(indices, output_length, output_type));
AssertDatumsEqual(expected, result);
if (validity_must_be_null) {
ASSERT_FALSE(result.array()->HasValidityBitmap());
}
}
}
}

void TestReverseIndices(const std::string& indices_str,
const std::vector<std::string>& indices_chunked_str,
int64_t output_length, const std::string& expected_str,
bool validity_must_be_null = false) {
{
auto indices = ArrayFromJSON(int32(), "[9, 7, 5, 3, 1, 0, 2, 4, 6, 8]");
auto expected = ArrayFromJSON(int8(), "[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]");
ReverseIndicesOptions options{10, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Array");
TestReverseIndices(indices_str, output_length, expected_str, ArrayFromJSON,
validity_must_be_null);
}
{
auto indices = ArrayFromJSON(int32(), "[1, 2]");
auto expected = ArrayFromJSON(int8(), "[null, 0, 1, null, null, null, null]");
ReverseIndicesOptions options{7, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Chunked");
TestReverseIndices(indices_chunked_str, output_length, expected_str,
ChunkedArrayFromJSON, validity_must_be_null);
}
}

} // namespace

TEST(ReverseIndices, Basic) {
{
auto indices = ArrayFromJSON(int32(), "[1, 2]");
auto expected = ArrayFromJSON(int8(), "[]");
ReverseIndicesOptions options{0, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Basic");
auto indices = "[9, 7, 5, 3, 1, 0, 2, 4, 6, 8]";
std::vector<std::string> indices_chunked{
"[]", "[9, 7, 5, 3, 1]", "[0]", "[2, 4, 6]", "[8]", "[]"};
int64_t output_length = 10;
auto expected = "[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]";
TestReverseIndices(indices, indices_chunked, output_length, expected,
/*validity_must_be_null=*/true);
}
{
auto indices = ArrayFromJSON(int32(), "[1, 0]");
auto expected = ArrayFromJSON(int8(), "[1]");
ReverseIndicesOptions options{1, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Empty output");
auto indices = "[1, 2]";
std::vector<std::string> indices_chunked{"[]", "[1]", "[]", "[]", "[2]", "[]"};
int64_t output_length = 0;
auto expected = "[]";
TestReverseIndices(indices, indices_chunked, output_length, expected,
/*validity_must_be_null=*/true);
}
{
auto indices = ArrayFromJSON(int32(), "[1, 2]");
auto expected = ArrayFromJSON(int8(), "[null]");
ReverseIndicesOptions options{1, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Output less than input");
auto indices = "[1, 0]";
std::vector<std::string> indices_chunked{"[]", "[]", "[]", "[1, 0]"};
int64_t output_length = 1;
auto expected = "[1]";
TestReverseIndices(indices, indices_chunked, output_length, expected,
/*validity_must_be_null=*/true);
}
{
auto indices = ArrayFromJSON(int32(), "[]");
auto expected = ArrayFromJSON(int8(), "[null, null, null, null, null, null, null]");
ReverseIndicesOptions options{7, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Output greater than input");
auto indices = "[1, 2]";
std::vector<std::string> indices_chunked{"[]", "[1]", "[]", "[2]"};
int64_t output_length = 7;
auto expected = "[null, 0, 1, null, null, null, null]";
TestReverseIndices(indices, indices_chunked, output_length, expected);
}
}

TEST(ReverseIndices, Overflow) {
{
auto indices = ConstantArrayGenerator::Zeroes(127, int8());
auto expected = ArrayFromJSON(int8(), "[126]");
ReverseIndicesOptions options{1, int8()};
ASSERT_OK_AND_ASSIGN(Datum result,
CallFunction("reverse_indices", {indices}, &options));
AssertDatumsEqual(expected, result);
ARROW_SCOPED_TRACE("Input all null");
auto indices = "[null, null]";
std::vector<std::string> indices_chunked{"[]", "[null]", "[]", "[null]"};
int64_t output_length = 1;
auto expected = "[null]";
TestReverseIndices(indices, indices_chunked, output_length, expected);
}
{
ARROW_SCOPED_TRACE("Output all null");
auto indices = "[1, 2]";
std::vector<std::string> indices_chunked{"[]", "[1]", "[]", "[2]"};
int64_t output_length = 1;
auto expected = "[null]";
TestReverseIndices(indices, indices_chunked, output_length, expected);
}
{
auto indices = ConstantArrayGenerator::Zeroes(128, int8());
ReverseIndicesOptions options{1, int8()};
ASSERT_RAISES_WITH_MESSAGE(Invalid,
"Invalid: Output type int8 of reverse_indices is "
"insufficient to store indices of length 128",
CallFunction("reverse_indices", {indices}, &options));
ARROW_SCOPED_TRACE("Empty input output null");
auto indices = "[]";
std::vector<std::string> indices_chunked{"[]", "[]", "[]", "[]"};
int64_t output_length = 7;
auto expected = "[null, null, null, null, null, null, null]";
TestReverseIndices(indices, indices_chunked, output_length, expected);
}
{
ASSERT_OK_AND_ASSIGN(auto indices, MakeArrayOfNull(int8(), 128));
auto expected = ArrayFromJSON(int8(), "[null]");
ReverseIndicesOptions options{1, int8()};
ASSERT_RAISES_WITH_MESSAGE(Invalid,
"Invalid: Output type int8 of reverse_indices is "
"insufficient to store indices of length 128",
CallFunction("reverse_indices", {indices}, &options));
ARROW_SCOPED_TRACE("Input duplicated indices");
auto indices = "[1, 2, 3, 1, 2, 3, 1, 2, 3]";
std::vector<std::string> indices_chunked{"[]", "[1, 2]", "[3, 1, 2, 3, 1]",
"[]", "[2]", "[3]"};
int64_t output_length = 5;
auto expected = "[null, 6, 7, 8, null]";
TestReverseIndices(indices, indices_chunked, output_length, expected);
}
}

Expand Down

0 comments on commit cbdce2f

Please sign in to comment.