Skip to content

Commit

Permalink
Some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Jan 11, 2024
1 parent 595f96f commit a687bb0
Showing 1 changed file with 230 additions and 84 deletions.
314 changes: 230 additions & 84 deletions cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1939,7 +1939,7 @@ class ResidualFilterCaseRunner {

ASSERT_OK_AND_ASSIGN(auto result,
DeclarationToExecBatches(std::move(join), parallel));
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, expected);
AssertExecBatchesEqualIgnoringOrder(result.schema, expected, result.batches);
}
}

Expand Down Expand Up @@ -1974,7 +1974,7 @@ class ResidualFilterCaseRunner {
if (i != 0) {
ss << ", ";
}
ss << left_output[i].ToString();
ss << both_output[i].ToString();
}
ss << ")";
return ss.str();
Expand All @@ -1984,19 +1984,17 @@ class ResidualFilterCaseRunner {
TEST(HashJoin, ResidualFilter) {
BatchesWithSchema input_left;
input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([
[1, 6, "alpha"],
[2, 5, "beta"],
[3, 4, "alpha"]
])")};
[1, 6, "alpha"],
[2, 5, "beta"],
[3, 4, "alpha"]])")};
input_left.schema =
schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())});

BatchesWithSchema input_right;
input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([
[5, 11, "alpha"],
[2, 12, "beta"],
[4, 16, "alpha"]
])")};
[5, 11, "alpha"],
[2, 12, "beta"],
[4, 16, "alpha"]])")};
input_right.schema =
schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())});

Expand All @@ -2008,10 +2006,10 @@ TEST(HashJoin, ResidualFilter) {

runner.Run(JoinType::FULL_OUTER, {"l_str"}, {"r_str"}, std::move(filter),
{ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([
[1, 6, "alpha", 4, 16, "alpha"],
[1, 6, "alpha", 5, 11, "alpha"],
[2, 5, "beta", 2, 12, "beta"],
[3, 4, "alpha", 4, 16, "alpha"]])")});
[1, 6, "alpha", 4, 16, "alpha"],
[1, 6, "alpha", 5, 11, "alpha"],
[2, 5, "beta", 2, 12, "beta"],
[3, 4, "alpha", 4, 16, "alpha"]])")});
}

TEST(HashJoin, TrivialResidualFilter) {
Expand All @@ -2028,14 +2026,12 @@ TEST(HashJoin, TrivialResidualFilter) {

BatchesWithSchema input_left;
input_left.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([
[1, "alpha"]
])")};
[1, "alpha"]])")};
input_left.schema = schema({field("l1", int32()), field("l_str", utf8())});

BatchesWithSchema input_right;
input_right.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([
[1, "alpha"]
])")};
[1, "alpha"]])")};
input_right.schema = schema({field("r1", int32()), field("r_str", utf8())});

ResidualFilterCaseRunner runner{std::move(input_left), std::move(input_right)};
Expand All @@ -2055,29 +2051,41 @@ TEST(HashJoin, FineGrainedResidualFilter) {
std::shared_ptr<Schema> left, right;
std::vector<int> left_output, right_output;

std::vector<FieldRef> LeftOutput() const {
std::vector<FieldRef> output;
for (int i : left_output) {
output.push_back(FieldRef(i));
};
std::vector<FieldRef> LeftOutput(JoinType join_type) const {
if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) {
return {};
}
std::vector<FieldRef> output(left_output.size());
std::transform(left_output.begin(), left_output.end(), output.begin(),
[](int i) { return i; });
return output;
}

std::vector<FieldRef> RightOutput() const {
std::vector<FieldRef> output;
for (int i : right_output) {
output.push_back(FieldRef(i));
};
std::vector<FieldRef> RightOutput(JoinType join_type) const {
if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) {
return {};
}
std::vector<FieldRef> output(right_output.size());
std::transform(right_output.begin(), right_output.end(), output.begin(),
[](int i) { return i; });
return output;
}

ExecBatch Project(const ExecBatch& batch) const {
ExecBatch Project(JoinType join_type, const ExecBatch& batch) const {
std::vector<Datum> values;
for (int i : left_output) {
values.push_back(batch[i]);
if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) {
for (int i : left_output) {
values.push_back(batch[i]);
}
}
for (int i : right_output) {
values.push_back(batch[left_output.size() + i]);
if (join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI) {
int left_size =
join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI
? 0
: left->num_fields();
for (int i : right_output) {
values.push_back(batch[left_size + i]);
}
}
return {std::move(values), batch.length};
}
Expand All @@ -2090,74 +2098,212 @@ TEST(HashJoin, FineGrainedResidualFilter) {

BatchesWithSchema left;
left.batches = {ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([
[null, null, "payload"],
[null, 0, "payload"],
[null, 42, "payload"],
["left_only", null, "payload"],
["left_only", 0, "payload"],
["left_only", 42, "payload"],
["both1", null, "payload"],
["both1", 0, "payload"],
["both1", 42, "payload"],
["both2", null, "payload"],
["both2", 0, "payload"],
["both2", 42, "payload"]])")};
[null, null, "payload"],
[null, 0, "payload"],
[null, 42, "payload"],
["left_only", null, "payload"],
["left_only", 0, "payload"],
["left_only", 42, "payload"],
["both1", null, "payload"],
["both1", 0, "payload"],
["both1", 42, "payload"],
["both2", null, "payload"],
["both2", 0, "payload"],
["both2", 42, "payload"]])")};
left.schema = schema(
{field("l_key", utf8()), field("l_filter", int32()), field("l_payload", utf8())});

BatchesWithSchema right;
right.batches = {ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([
[null, null, "payload"],
[null, 0, "payload"],
[null, 42, "payload"],
["both1", null, "payload"],
["both1", 0, "payload"],
["both1", 42, "payload"],
["both2", null, "payload"],
["both2", 0, "payload"],
["both2", 42, "payload"],
["right_only", null, "payload"],
["right_only", 0, "payload"],
["right_only", 42, "payload"]])")};
[null, null, "payload"],
[null, 0, "payload"],
[null, 42, "payload"],
["both1", null, "payload"],
["both1", 0, "payload"],
["both1", 42, "payload"],
["both2", null, "payload"],
["both2", 0, "payload"],
["both2", 42, "payload"],
["right_only", null, "payload"],
["right_only", 0, "payload"],
["right_only", 42, "payload"]])")};
right.schema = schema(
{field("r_key", utf8()), field("r_filter", int32()), field("r_payload", utf8())});

const ResidualFilterCaseRunner runner{std::move(left), std::move(right)};
JoinSchema join_schema{left.schema, right.schema};
std::vector<JoinSchema::Projector> projectors{
join_schema.GetProjector({0, 1, 2}, {0, 1, 2}), // Output all.
join_schema.GetProjector({0, 1}, {0, 1}), // Output key columns only.
join_schema.GetProjector({0, 1, 2}, {0, 1}), // Output left payload only.
join_schema.GetProjector({0, 1}, {0, 1, 2}), // Output right payload only.
join_schema.GetProjector({0, 1, 2}, {0, 1, 2})}; // Output all.};
join_schema.GetProjector({0, 1, 2}, {0, 1, 2}), // Output all.
join_schema.GetProjector({0}, {0}), // Output key columns only.
join_schema.GetProjector({1}, {1}), // Output filter columns only.
join_schema.GetProjector({2}, {2})}; // Output payload columns only.

const ResidualFilterCaseRunner runner{std::move(left), std::move(right)};

{
// Literal true.
Expression filter = literal(true);
std::vector<FieldRef> left_keys{"l_key", "l_filter"}, right_keys{"r_key", "r_filter"};
{
// Inner join.
JoinType join_type = JoinType::INNER;
auto expected =
ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([
["both1", 0, "payload", "both1", 0, "payload"],
["both1", 42, "payload", "both1", 42, "payload"],
["both2", 0, "payload", "both2", 0, "payload"],
["both2", 42, "payload", "both2", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Left outer join.
JoinType join_type = JoinType::LEFT_OUTER;
auto expected =
ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([
[null, null, "payload", null, null, null],
[null, 0, "payload", null, null, null],
[null, 42, "payload", null, null, null],
["left_only", null, "payload", null, null, null],
["left_only", 0, "payload", null, null, null],
["left_only", 42, "payload", null, null, null],
["both1", null, "payload", null, null, null],
["both2", null, "payload", null, null, null],
["both1", 0, "payload", "both1", 0, "payload"],
["both1", 42, "payload", "both1", 42, "payload"],
["both2", 0, "payload", "both2", 0, "payload"],
["both2", 42, "payload", "both2", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Right outer join.
JoinType join_type = JoinType::RIGHT_OUTER;
auto expected =
ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()},
R"([
["both1", 0, "payload", "both1", 0, "payload"],
["both1", 42, "payload", "both1", 42, "payload"],
["both2", 0, "payload", "both2", 0, "payload"],
["both2", 42, "payload", "both2", 42, "payload"]])");
{
// for (const auto& projector : projectors) {
// runner.Run(JoinType::INNER, {"l_key", "l_filter"}, {"r_key", "r_filter"},
// projector.LeftOutput(), projector.RightOutput(), filter,
// {projector.Project(expected)});
// }
// Output all.
// Output all.
runner.Run(JoinType::INNER, {"l_key", "l_filter"}, {"r_key", "r_filter"}, filter,
{ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()},
R"([
["both1", 0, "payload", "both1", 0, "payload"],
["both1", 42, "payload", "both1", 42, "payload"],
["both2", 0, "payload", "both2", 0, "payload"],
["both2", 42, "payload", "both2", 42, "payload"]])")});
ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([
["both1", 0, "payload", "both1", 0, "payload"],
["both1", 42, "payload", "both1", 42, "payload"],
["both2", 0, "payload", "both2", 0, "payload"],
["both2", 42, "payload", "both2", 42, "payload"],
[null, null, null, null, null, "payload"],
[null, null, null, null, 0, "payload"],
[null, null, null, null, 42, "payload"],
[null, null, null, "both1", null, "payload"],
[null, null, null, "both2", null, "payload"],
[null, null, null, "right_only", null, "payload"],
[null, null, null, "right_only", 0, "payload"],
[null, null, null, "right_only", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Full outer join.
JoinType join_type = JoinType::FULL_OUTER;
auto expected =
ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([
[null, null, "payload", null, null, null],
[null, 0, "payload", null, null, null],
[null, 42, "payload", null, null, null],
["left_only", null, "payload", null, null, null],
["left_only", 0, "payload", null, null, null],
["left_only", 42, "payload", null, null, null],
["both1", null, "payload", null, null, null],
["both2", null, "payload", null, null, null],
["both1", 0, "payload", "both1", 0, "payload"],
["both1", 42, "payload", "both1", 42, "payload"],
["both2", 0, "payload", "both2", 0, "payload"],
["both2", 42, "payload", "both2", 42, "payload"],
[null, null, null, null, null, "payload"],
[null, null, null, null, 0, "payload"],
[null, null, null, null, 42, "payload"],
[null, null, null, "both1", null, "payload"],
[null, null, null, "both2", null, "payload"],
[null, null, null, "right_only", null, "payload"],
[null, null, null, "right_only", 0, "payload"],
[null, null, null, "right_only", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Left semi join.
JoinType join_type = JoinType::LEFT_SEMI;
auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([
["both1", 0, "payload"],
["both1", 42, "payload"],
["both2", 0, "payload"],
["both2", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Left anti join.
JoinType join_type = JoinType::LEFT_ANTI;
auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([
[null, null, "payload"],
[null, 0, "payload"],
[null, 42, "payload"],
["left_only", null, "payload"],
["left_only", 0, "payload"],
["left_only", 42, "payload"],
["both1", null, "payload"],
["both2", null, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Right semi join.
JoinType join_type = JoinType::RIGHT_SEMI;
auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([
["both1", 0, "payload"],
["both1", 42, "payload"],
["both2", 0, "payload"],
["both2", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}

{
// Right anti join.
JoinType join_type = JoinType::RIGHT_ANTI;
auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([
[null, null, "payload"],
[null, 0, "payload"],
[null, 42, "payload"],
["both1", null, "payload"],
["both2", null, "payload"],
["right_only", null, "payload"],
["right_only", 0, "payload"],
["right_only", 42, "payload"]])");
for (const auto& projector : projectors) {
runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type),
projector.RightOutput(join_type), filter,
{projector.Project(join_type, expected)});
}
}
}
Expand Down

0 comments on commit a687bb0

Please sign in to comment.