Skip to content

Commit

Permalink
more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Aug 29, 2024
1 parent cab4b25 commit 5616f8c
Showing 1 changed file with 44 additions and 23 deletions.
67 changes: 44 additions & 23 deletions cpp/src/arrow/acero/aggregate_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -215,29 +215,50 @@ TEST(ScalarAggregateNode, AnyAll) {
// when min_count != 0.
std::shared_ptr<Schema> in_schema = schema({field("not_used", int32())});
std::shared_ptr<Schema> out_schema = schema({field("agg_out", boolean())});
std::vector<ExecBatch> batches{
ExecBatchFromJSON({int32()}, "[[42], [42], [42], [42]]")};
for (auto& func_name : {"any", "all"}) {
SCOPED_TRACE(func_name);
std::vector<Aggregate> aggregates = {
Aggregate(func_name,
std::make_shared<compute::ScalarAggregateOptions>(/*skip_nulls=*/false,
/*min_count=*/2),
FieldRef("literal_true"))};

// And a projection to make the input including a Scalar Boolean
Declaration plan = Declaration::Sequence(
{{"exec_batch_source", ExecBatchSourceNodeOptions(in_schema, batches)},
{"project", ProjectNodeOptions({literal(true)}, {"literal_true"})},
{"aggregate", AggregateNodeOptions(aggregates)}});

ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches,
DeclarationToExecBatches(plan));

ExecBatch expected_batch = ExecBatchFromJSON({boolean()}, "[[true]]");

AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch},
out_batches.batches);
struct AnyAllCase {
std::string batches_json;
Expression literal;
std::string expected_json;
bool skip_nulls = false;
uint32_t min_count = 2;
};
std::vector<AnyAllCase> cases{
{"[[42], [42], [42], [42]]", literal(true), "[[true]]"},
{"[[42], [42], [42], [42]]", literal(false), "[[false]]"},
{"[[42], [42], [42], [42]]", literal(BooleanScalar{}), "[[null]]"},
{"[[42]]", literal(true), "[[null]]"},
{"[[42], [42], [42]]", literal(true), "[[true]]"},
{"[[42], [42], [42]]", literal(true), "[[null]]", /*skip_nulls=*/false,
/*min_count=*/4},
{"[[42], [42], [42], [42]]", literal(BooleanScalar{}), "[[null]]",
/*skip_nulls=*/true},
};
for (const AnyAllCase& any_all_case : cases) {
for (const std::string& func_name : {"any", "all"}) {
std::vector<ExecBatch> batches{
ExecBatchFromJSON({int32()}, any_all_case.batches_json)};
std::vector<Aggregate> aggregates = {
Aggregate("any",
std::make_shared<compute::ScalarAggregateOptions>(
/*skip_nulls=*/any_all_case.skip_nulls,
/*min_count=*/any_all_case.min_count),
FieldRef("literal"))};

// And a projection to make the input including a Scalar Boolean
Declaration plan = Declaration::Sequence(
{{"exec_batch_source", ExecBatchSourceNodeOptions(in_schema, batches)},
{"project", ProjectNodeOptions({any_all_case.literal}, {"literal"})},
{"aggregate", AggregateNodeOptions(aggregates)}});

ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches,
DeclarationToExecBatches(plan));

ExecBatch expected_batch =
ExecBatchFromJSON({boolean()}, any_all_case.expected_json);

AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch},
out_batches.batches);
}
}
}

Expand Down

0 comments on commit 5616f8c

Please sign in to comment.