diff --git a/cpp/src/arrow/compute/kernels/vector_placement_test.cc b/cpp/src/arrow/compute/kernels/vector_placement_test.cc index f1496b16ade35..01a4ab2ad246f 100644 --- a/cpp/src/arrow/compute/kernels/vector_placement_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_placement_test.cc @@ -23,6 +23,8 @@ #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/logging.h" namespace arrow::compute { @@ -34,6 +36,14 @@ static const std::vector> kSignedIntegerTypes = { static const std::vector> kIntegerTypes = { int8(), uint8(), int16(), uint16(), int32(), uint32(), int64(), uint64()}; +static const std::vector> kNumericTypes = { + uint8(), int8(), uint16(), int16(), uint32(), + int32(), uint64(), int64(), float32(), float64()}; + +static const std::vector> kNumericAndBaseBinaryTypes = { + uint8(), int8(), uint16(), int16(), uint32(), int32(), uint64(), + int64(), float32(), float64(), binary(), utf8(), large_binary(), large_utf8()}; + using SmallOutputTypes = ::testing::Types; } // namespace @@ -490,33 +500,6 @@ void DoTestPermute(const std::shared_ptr& values, DoTestPermuteForIndicesTypes(kIntegerTypes, values, indices, output_length, expected); } -// void TestPermute(const std::shared_ptr& values_type, -// const std::string& values_str, const std::string& indices_str, -// int64_t output_length, const std::string& expected_str) { -// auto values = ArrayFromJSON(values_type, values_str); -// auto expected = ArrayFromJSON(values_type, expected_str); -// for (const auto& indices_type : kIntegerTypes) { -// ARROW_SCOPED_TRACE("Indices type: " + indices_type->ToString()); -// auto indices = ArrayFromJSON(indices_type, indices_str); -// { -// ARROW_SCOPED_TRACE("AAA"); -// DoTestPermuteAAA(values, indices, output_length, expected); -// } -// { -// ARROW_SCOPED_TRACE("CAA"); -// DoTestPermuteCACWithArrays(values, indices, output_length, expected); -// } -// { -// ARROW_SCOPED_TRACE("ACA"); -// DoTestPermuteACCWithArrays(values, indices, output_length, expected); -// } -// { -// ARROW_SCOPED_TRACE("CCA"); -// DoTestPermuteCCCWithArrays(values, indices, output_length, expected); -// } -// } -// } - } // namespace TEST(Permute, Invalid) { @@ -928,4 +911,190 @@ TYPED_TEST(TestPermuteString, Basic) { } } -}; // namespace arrow::compute +// ---------------------------------------------------------------------- +// Test Permute using a hypothetical if-else special form. +// Also demonstrate how Permute can serve as a building block of implementing special +// forms. + +namespace { + +/// Execute an if-else expression using regular expression evaluation, as a reference. +Result ExecuteIfElseByExpr(const Expression& cond, const Expression& if_true, + const Expression& if_false, + const std::shared_ptr& schema, + const ExecBatch& input) { + auto if_else = call("if_else", {cond, if_true, if_false}); + ARROW_ASSIGN_OR_RAISE(auto bound, if_else.Bind(*schema)); + return ExecuteScalarExpression(bound, input); +} + +/// Execute an if-else expression in a special form fashion, in which Permute is used as a +/// building block. +Result ExecuteIfElseByPermute(const Expression& cond, const Expression& if_true, + const Expression& if_false, + const std::shared_ptr& schema, + const ExecBatch& input) { + for (const auto& column : input.values) { + DCHECK(column.is_array()); + } + + ARROW_ASSIGN_OR_RAISE(auto input_rb, input.ToRecordBatch(schema)); + + // 1. Evaluate "cond", getting a boolean array as a mask to branches. + ARROW_ASSIGN_OR_RAISE(auto bound_cond, cond.Bind(*schema)); + ARROW_ASSIGN_OR_RAISE(auto cond_datum, ExecuteScalarExpression(bound_cond, input)); + + // 2. Get indices of "true"s from the mask as the selection vector. + ARROW_ASSIGN_OR_RAISE(auto sel_if_true_datum, + CallFunction("indices_nonzero", {cond_datum})); + DCHECK(sel_if_true_datum.is_array()); + auto sel_if_true_array = sel_if_true_datum.make_array(); + + // 3. Take the "true" rows from input. + ARROW_ASSIGN_OR_RAISE(auto if_true_input_datum, + CallFunction("take", {input_rb, sel_if_true_datum})); + + // 4. Get indices of "false"es form the mas as the selection vector - by first inverting + // the mask and then getting the non-zero's indices. + ARROW_ASSIGN_OR_RAISE(auto invert_cond_datum, CallFunction("invert", {cond_datum})); + ARROW_ASSIGN_OR_RAISE(auto sel_if_false_datum, + CallFunction("indices_nonzero", {invert_cond_datum})); + DCHECK(sel_if_false_datum.is_array()); + auto sel_if_false_array = sel_if_false_datum.make_array(); + + // 5. Take the "false" rows from input. + ARROW_ASSIGN_OR_RAISE(auto if_false_input_datum, + CallFunction("take", {input_rb, sel_if_false_datum})); + + DCHECK_EQ(if_true_input_datum.kind(), Datum::RECORD_BATCH); + auto if_true_input_batch = ExecBatch(*if_true_input_datum.record_batch()); + + DCHECK_EQ(if_false_input_datum.kind(), Datum::RECORD_BATCH); + auto if_false_input_batch = ExecBatch(*if_false_input_datum.record_batch()); + + // 6. Evaluate "true" branch on the "true" rows. + ARROW_ASSIGN_OR_RAISE(auto bound_if_true, if_true.Bind(*schema)); + ARROW_ASSIGN_OR_RAISE(auto if_true_result_datum, + ExecuteScalarExpression(bound_if_true, if_true_input_batch)); + DCHECK(if_true_result_datum.is_array()); + auto if_true_result_array = if_true_result_datum.make_array(); + + // 7. Evaluate "false" branch on the "false" rows. + ARROW_ASSIGN_OR_RAISE(auto bound_if_false, if_false.Bind(*schema)); + ARROW_ASSIGN_OR_RAISE(auto if_false_result_datum, + ExecuteScalarExpression(bound_if_false, if_false_input_batch)); + DCHECK(if_false_result_datum.is_array()); + auto if_false_result_array = if_false_result_datum.make_array(); + + // 8. Combine the "true"/"false" results/selection vectors into chunked arrays. + auto result_ca = std::make_shared( + ArrayVector{if_true_result_array, if_false_result_array}); + auto sel_ca = + std::make_shared(ArrayVector{sel_if_true_array, sel_if_false_array}); + + // 9. Finally, permute the "true"/"false" results to their original positions in the + // input (according to the selection vectors). Note we didn't handle the rows with nulls + // in the mask, because Permute will fill nulls for these rows and this is equal to the + // null handling policy of if-else, which is pretty nice. + return Permute(/*values=*/result_ca, /*indices=*/sel_ca, + /*output_length=*/input.length); +} + +void DoTestIfElse(const Expression& cond, const Expression& if_true, + const Expression& if_false, const std::shared_ptr& schema, + const ExecBatch& input) { + ASSERT_OK_AND_ASSIGN(Datum result_by_expr, + ExecuteIfElseByExpr(cond, if_true, if_false, schema, input)); + ASSERT_TRUE(result_by_expr.is_array()); + ASSERT_OK_AND_ASSIGN(Datum result_by_permute, + ExecuteIfElseByPermute(cond, if_true, if_false, schema, input)); + ASSERT_TRUE(result_by_permute.is_chunked_array()); + ASSERT_OK_AND_ASSIGN(auto result_by_permute_concat, + Concatenate(result_by_permute.chunked_array()->chunks())); + + AssertDatumsEqual(result_by_expr, result_by_permute_concat); +} + +void DoTestIfElse(const Expression& cond, const Expression& if_true, + const Expression& if_false, const std::shared_ptr& schema, + const ExecBatch& input, const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(Datum result, + ExecuteIfElseByPermute(cond, if_true, if_false, schema, input)); + ASSERT_TRUE(result.is_chunked_array()); + ASSERT_OK_AND_ASSIGN(auto result_concat, Concatenate(result.chunked_array()->chunks())); + + AssertDatumsEqual(expected, result_concat); +} + +} // namespace + +TEST(Permute, IfElse) { + { + ARROW_SCOPED_TRACE("if (b != 0) then a / b else b"); + auto cond = call("not_equal", {field_ref("b"), literal(0)}); + auto if_true = call("divide", {field_ref("a"), field_ref("b")}); + auto if_false = field_ref("b"); + auto schema = arrow::schema({field("a", int32()), field("b", int32())}); + { + auto rb = RecordBatchFromJSON(schema, R"([ + [1, 1], + [2, 1], + [3, 0], + [4, 1], + [5, 1] + ])"); + auto input = ExecBatch(*rb); + + ASSERT_RAISES_WITH_MESSAGE( + Invalid, "Invalid: divide by zero", + ExecuteIfElseByExpr(cond, if_true, if_false, schema, input)); + + auto expected = ArrayFromJSON(int32(), "[1, 2, 0, 4, 5]"); + DoTestIfElse(cond, if_true, if_false, schema, input, expected); + } + } + { + ARROW_SCOPED_TRACE("if (a > b) then a else b"); + auto cond = call("greater", {field_ref("a"), field_ref("b")}); + auto if_true = field_ref("a"); + auto if_false = field_ref("b"); + constexpr int64_t length = 5; + for (const auto& type : kNumericTypes) { + ARROW_SCOPED_TRACE("Type " + type->ToString()); + auto schema = arrow::schema({field("a", type), field("b", type)}); + auto big = ArrayFromJSON(type, "[1, 2, 3, 4, 5]"); + auto small = ArrayFromJSON(type, "[0, 1, 2, 3, 4]"); + { + ARROW_SCOPED_TRACE("All true"); + auto input = + ExecBatch(*RecordBatch::Make(schema, length, {/*a=*/big, /*b=*/small})); + DoTestIfElse(cond, if_true, if_false, schema, input); + } + { + ARROW_SCOPED_TRACE("All false"); + auto input = + ExecBatch(*RecordBatch::Make(schema, length, {/*a=*/small, /*b=*/big})); + DoTestIfElse(cond, if_true, if_false, schema, input); + } + } + { + ARROW_SCOPED_TRACE("Random"); + auto rng = random::RandomArrayGenerator(42); + constexpr int64_t length = 1024; + constexpr int repeat = 10; + for (const auto& type : kNumericAndBaseBinaryTypes) { + ARROW_SCOPED_TRACE("Type " + type->ToString()); + auto schema = arrow::schema({field("a", type), field("b", type)}); + for (int i = 0; i < repeat; ++i) { + auto a = rng.ArrayOf(type, length, /*null_probability=*/0.2); + auto b = rng.ArrayOf(type, length, /*null_probability=*/0.2); + auto input = + ExecBatch(*RecordBatch::Make(schema, length, {std::move(a), std::move(b)})); + DoTestIfElse(cond, if_true, if_false, schema, input); + } + } + } + } +} + +} // namespace arrow::compute