From a687bb0a9365f83ff0a693b067fe8c7aca4ce38f Mon Sep 17 00:00:00 2001 From: zanmato1984 Date: Fri, 12 Jan 2024 00:27:26 +0800 Subject: [PATCH] Some tests --- cpp/src/arrow/acero/hash_join_node_test.cc | 314 +++++++++++++++------ 1 file changed, 230 insertions(+), 84 deletions(-) diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 87e06eeb4c6b5..8fb7ec0d88719 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -1939,7 +1939,7 @@ class ResidualFilterCaseRunner { ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(std::move(join), parallel)); - AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, expected); + AssertExecBatchesEqualIgnoringOrder(result.schema, expected, result.batches); } } @@ -1974,7 +1974,7 @@ class ResidualFilterCaseRunner { if (i != 0) { ss << ", "; } - ss << left_output[i].ToString(); + ss << both_output[i].ToString(); } ss << ")"; return ss.str(); @@ -1984,19 +1984,17 @@ class ResidualFilterCaseRunner { TEST(HashJoin, ResidualFilter) { BatchesWithSchema input_left; input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ - [1, 6, "alpha"], - [2, 5, "beta"], - [3, 4, "alpha"] - ])")}; + [1, 6, "alpha"], + [2, 5, "beta"], + [3, 4, "alpha"]])")}; input_left.schema = schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())}); BatchesWithSchema input_right; input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ - [5, 11, "alpha"], - [2, 12, "beta"], - [4, 16, "alpha"] - ])")}; + [5, 11, "alpha"], + [2, 12, "beta"], + [4, 16, "alpha"]])")}; input_right.schema = schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())}); @@ -2008,10 +2006,10 @@ TEST(HashJoin, ResidualFilter) { runner.Run(JoinType::FULL_OUTER, {"l_str"}, {"r_str"}, std::move(filter), {ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([ - [1, 6, "alpha", 4, 16, "alpha"], - [1, 6, "alpha", 5, 11, "alpha"], - [2, 5, "beta", 2, 12, "beta"], - [3, 4, "alpha", 4, 16, "alpha"]])")}); + [1, 6, "alpha", 4, 16, "alpha"], + [1, 6, "alpha", 5, 11, "alpha"], + [2, 5, "beta", 2, 12, "beta"], + [3, 4, "alpha", 4, 16, "alpha"]])")}); } TEST(HashJoin, TrivialResidualFilter) { @@ -2028,14 +2026,12 @@ TEST(HashJoin, TrivialResidualFilter) { BatchesWithSchema input_left; input_left.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ - [1, "alpha"] - ])")}; + [1, "alpha"]])")}; input_left.schema = schema({field("l1", int32()), field("l_str", utf8())}); BatchesWithSchema input_right; input_right.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ - [1, "alpha"] - ])")}; + [1, "alpha"]])")}; input_right.schema = schema({field("r1", int32()), field("r_str", utf8())}); ResidualFilterCaseRunner runner{std::move(input_left), std::move(input_right)}; @@ -2055,29 +2051,41 @@ TEST(HashJoin, FineGrainedResidualFilter) { std::shared_ptr left, right; std::vector left_output, right_output; - std::vector LeftOutput() const { - std::vector output; - for (int i : left_output) { - output.push_back(FieldRef(i)); - }; + std::vector LeftOutput(JoinType join_type) const { + if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) { + return {}; + } + std::vector output(left_output.size()); + std::transform(left_output.begin(), left_output.end(), output.begin(), + [](int i) { return i; }); return output; } - std::vector RightOutput() const { - std::vector output; - for (int i : right_output) { - output.push_back(FieldRef(i)); - }; + std::vector RightOutput(JoinType join_type) const { + if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) { + return {}; + } + std::vector output(right_output.size()); + std::transform(right_output.begin(), right_output.end(), output.begin(), + [](int i) { return i; }); return output; } - ExecBatch Project(const ExecBatch& batch) const { + ExecBatch Project(JoinType join_type, const ExecBatch& batch) const { std::vector values; - for (int i : left_output) { - values.push_back(batch[i]); + if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) { + for (int i : left_output) { + values.push_back(batch[i]); + } } - for (int i : right_output) { - values.push_back(batch[left_output.size() + i]); + if (join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI) { + int left_size = + join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI + ? 0 + : left->num_fields(); + for (int i : right_output) { + values.push_back(batch[left_size + i]); + } } return {std::move(values), batch.length}; } @@ -2090,74 +2098,212 @@ TEST(HashJoin, FineGrainedResidualFilter) { BatchesWithSchema left; left.batches = {ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ - [null, null, "payload"], - [null, 0, "payload"], - [null, 42, "payload"], - ["left_only", null, "payload"], - ["left_only", 0, "payload"], - ["left_only", 42, "payload"], - ["both1", null, "payload"], - ["both1", 0, "payload"], - ["both1", 42, "payload"], - ["both2", null, "payload"], - ["both2", 0, "payload"], - ["both2", 42, "payload"]])")}; + [null, null, "payload"], + [null, 0, "payload"], + [null, 42, "payload"], + ["left_only", null, "payload"], + ["left_only", 0, "payload"], + ["left_only", 42, "payload"], + ["both1", null, "payload"], + ["both1", 0, "payload"], + ["both1", 42, "payload"], + ["both2", null, "payload"], + ["both2", 0, "payload"], + ["both2", 42, "payload"]])")}; left.schema = schema( {field("l_key", utf8()), field("l_filter", int32()), field("l_payload", utf8())}); BatchesWithSchema right; right.batches = {ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ - [null, null, "payload"], - [null, 0, "payload"], - [null, 42, "payload"], - ["both1", null, "payload"], - ["both1", 0, "payload"], - ["both1", 42, "payload"], - ["both2", null, "payload"], - ["both2", 0, "payload"], - ["both2", 42, "payload"], - ["right_only", null, "payload"], - ["right_only", 0, "payload"], - ["right_only", 42, "payload"]])")}; + [null, null, "payload"], + [null, 0, "payload"], + [null, 42, "payload"], + ["both1", null, "payload"], + ["both1", 0, "payload"], + ["both1", 42, "payload"], + ["both2", null, "payload"], + ["both2", 0, "payload"], + ["both2", 42, "payload"], + ["right_only", null, "payload"], + ["right_only", 0, "payload"], + ["right_only", 42, "payload"]])")}; right.schema = schema( {field("r_key", utf8()), field("r_filter", int32()), field("r_payload", utf8())}); - const ResidualFilterCaseRunner runner{std::move(left), std::move(right)}; JoinSchema join_schema{left.schema, right.schema}; std::vector projectors{ - join_schema.GetProjector({0, 1, 2}, {0, 1, 2}), // Output all. - join_schema.GetProjector({0, 1}, {0, 1}), // Output key columns only. - join_schema.GetProjector({0, 1, 2}, {0, 1}), // Output left payload only. - join_schema.GetProjector({0, 1}, {0, 1, 2}), // Output right payload only. - join_schema.GetProjector({0, 1, 2}, {0, 1, 2})}; // Output all.}; + join_schema.GetProjector({0, 1, 2}, {0, 1, 2}), // Output all. + join_schema.GetProjector({0}, {0}), // Output key columns only. + join_schema.GetProjector({1}, {1}), // Output filter columns only. + join_schema.GetProjector({2}, {2})}; // Output payload columns only. + + const ResidualFilterCaseRunner runner{std::move(left), std::move(right)}; { // Literal true. Expression filter = literal(true); + std::vector left_keys{"l_key", "l_filter"}, right_keys{"r_key", "r_filter"}; { // Inner join. + JoinType join_type = JoinType::INNER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 0, "payload", "both1", 0, "payload"], + ["both1", 42, "payload", "both1", 42, "payload"], + ["both2", 0, "payload", "both2", 0, "payload"], + ["both2", 42, "payload", "both2", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left outer join. + JoinType join_type = JoinType::LEFT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "payload", null, null, null], + [null, 0, "payload", null, null, null], + [null, 42, "payload", null, null, null], + ["left_only", null, "payload", null, null, null], + ["left_only", 0, "payload", null, null, null], + ["left_only", 42, "payload", null, null, null], + ["both1", null, "payload", null, null, null], + ["both2", null, "payload", null, null, null], + ["both1", 0, "payload", "both1", 0, "payload"], + ["both1", 42, "payload", "both1", 42, "payload"], + ["both2", 0, "payload", "both2", 0, "payload"], + ["both2", 42, "payload", "both2", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right outer join. + JoinType join_type = JoinType::RIGHT_OUTER; auto expected = - ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, - R"([ - ["both1", 0, "payload", "both1", 0, "payload"], - ["both1", 42, "payload", "both1", 42, "payload"], - ["both2", 0, "payload", "both2", 0, "payload"], - ["both2", 42, "payload", "both2", 42, "payload"]])"); - { - // for (const auto& projector : projectors) { - // runner.Run(JoinType::INNER, {"l_key", "l_filter"}, {"r_key", "r_filter"}, - // projector.LeftOutput(), projector.RightOutput(), filter, - // {projector.Project(expected)}); - // } - // Output all. - // Output all. - runner.Run(JoinType::INNER, {"l_key", "l_filter"}, {"r_key", "r_filter"}, filter, - {ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, - R"([ - ["both1", 0, "payload", "both1", 0, "payload"], - ["both1", 42, "payload", "both1", 42, "payload"], - ["both2", 0, "payload", "both2", 0, "payload"], - ["both2", 42, "payload", "both2", 42, "payload"]])")}); + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 0, "payload", "both1", 0, "payload"], + ["both1", 42, "payload", "both1", 42, "payload"], + ["both2", 0, "payload", "both2", 0, "payload"], + ["both2", 42, "payload", "both2", 42, "payload"], + [null, null, null, null, null, "payload"], + [null, null, null, null, 0, "payload"], + [null, null, null, null, 42, "payload"], + [null, null, null, "both1", null, "payload"], + [null, null, null, "both2", null, "payload"], + [null, null, null, "right_only", null, "payload"], + [null, null, null, "right_only", 0, "payload"], + [null, null, null, "right_only", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Full outer join. + JoinType join_type = JoinType::FULL_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "payload", null, null, null], + [null, 0, "payload", null, null, null], + [null, 42, "payload", null, null, null], + ["left_only", null, "payload", null, null, null], + ["left_only", 0, "payload", null, null, null], + ["left_only", 42, "payload", null, null, null], + ["both1", null, "payload", null, null, null], + ["both2", null, "payload", null, null, null], + ["both1", 0, "payload", "both1", 0, "payload"], + ["both1", 42, "payload", "both1", 42, "payload"], + ["both2", 0, "payload", "both2", 0, "payload"], + ["both2", 42, "payload", "both2", 42, "payload"], + [null, null, null, null, null, "payload"], + [null, null, null, null, 0, "payload"], + [null, null, null, null, 42, "payload"], + [null, null, null, "both1", null, "payload"], + [null, null, null, "both2", null, "payload"], + [null, null, null, "right_only", null, "payload"], + [null, null, null, "right_only", 0, "payload"], + [null, null, null, "right_only", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left semi join. + JoinType join_type = JoinType::LEFT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 0, "payload"], + ["both1", 42, "payload"], + ["both2", 0, "payload"], + ["both2", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left anti join. + JoinType join_type = JoinType::LEFT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "payload"], + [null, 0, "payload"], + [null, 42, "payload"], + ["left_only", null, "payload"], + ["left_only", 0, "payload"], + ["left_only", 42, "payload"], + ["both1", null, "payload"], + ["both2", null, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right semi join. + JoinType join_type = JoinType::RIGHT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 0, "payload"], + ["both1", 42, "payload"], + ["both2", 0, "payload"], + ["both2", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right anti join. + JoinType join_type = JoinType::RIGHT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "payload"], + [null, 0, "payload"], + [null, 42, "payload"], + ["both1", null, "payload"], + ["both2", null, "payload"], + ["right_only", null, "payload"], + ["right_only", 0, "payload"], + ["right_only", 42, "payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); } } }