From 0ce72675f4dd755b2696eb6597850b21df647bb8 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Tue, 12 Mar 2024 21:57:35 +0800 Subject: [PATCH] GH-20339: [C++] Add residual filter support to swiss join (#39487) ### Rationale for this change Add residual filter support to swiss join. ### What changes are included in this PR? 1. Added class `JoinResidualFilter` as a centralized structure to evaluate residual filter in swiss join. It has various flavors of filtering for various join types. Zero-overhead is guaranteed for trivial filters (literal true and sometimes literal false/null). More detailed explanation in code comments. 2. Tuned the structure of swiss join main body (`JoinProbeProcessor::OnNextBatch`) to better cope with `JoinResidualFilter` calls. ### Are these changes tested? Legacy UTs (`HashJoin.Random`, `HashJoin.ResidualFilter` and `HashJoin.TrivialResidualFilter`) cover part of this change. New fine-grained residual filter cases added as well. ### Are there any user-facing changes? No. * Closes: #20339 Lead-authored-by: zanmato Co-authored-by: zanmato1984 Co-authored-by: Ruoxi Sun Co-authored-by: Rossi Sun Signed-off-by: Weston Pace --- cpp/src/arrow/acero/hash_join_benchmark.cc | 197 +++- cpp/src/arrow/acero/hash_join_node.cc | 6 +- cpp/src/arrow/acero/hash_join_node_test.cc | 1187 ++++++++++++++++++-- cpp/src/arrow/acero/swiss_join.cc | 606 ++++++++-- cpp/src/arrow/acero/swiss_join_internal.h | 157 ++- 5 files changed, 1975 insertions(+), 178 deletions(-) diff --git a/cpp/src/arrow/acero/hash_join_benchmark.cc b/cpp/src/arrow/acero/hash_join_benchmark.cc index 9be4bed606553..993c0b9a705b4 100644 --- a/cpp/src/arrow/acero/hash_join_benchmark.cc +++ b/cpp/src/arrow/acero/hash_join_benchmark.cc @@ -51,6 +51,10 @@ struct BenchmarkSettings { double null_percentage = 0.0; double cardinality = 1.0; // Proportion of distinct keys in build side double selectivity = 1.0; // Probability of a match for a given row + int var_length_min = 2; // Minimal length of any var length types + int var_length_max = 20; // Maximum length of any var length types + + Expression residual_filter = literal(true); }; class JoinBenchmark { @@ -79,8 +83,8 @@ class JoinBenchmark { build_metadata["null_probability"] = std::to_string(settings.null_percentage); build_metadata["min"] = std::to_string(min_build_value); build_metadata["max"] = std::to_string(max_build_value); - build_metadata["min_length"] = "2"; - build_metadata["max_length"] = "20"; + build_metadata["min_length"] = settings.var_length_min; + build_metadata["max_length"] = settings.var_length_max; std::unordered_map probe_metadata; probe_metadata["null_probability"] = std::to_string(settings.null_percentage); @@ -126,10 +130,9 @@ class JoinBenchmark { stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size; schema_mgr_ = std::make_unique(); - Expression filter = literal(true); DCHECK_OK(schema_mgr_->Init(settings.join_type, *l_batches_with_schema.schema, left_keys, *r_batches_with_schema.schema, right_keys, - filter, "l_", "r_")); + settings.residual_filter, "l_", "r_")); if (settings.use_basic_implementation) { join_ = *HashJoinImpl::MakeBasic(); @@ -158,7 +161,7 @@ class JoinBenchmark { DCHECK_OK(join_->Init( &ctx_, settings.join_type, settings.num_threads, &(schema_mgr_->proj_maps[0]), - &(schema_mgr_->proj_maps[1]), std::move(key_cmp), std::move(filter), + &(schema_mgr_->proj_maps[1]), std::move(key_cmp), settings.residual_filter, std::move(register_task_group_callback), std::move(start_task_group_callback), [](int64_t, ExecBatch) { return Status::OK(); }, [](int64_t) { return Status::OK(); })); @@ -308,6 +311,60 @@ static void BM_HashJoinBasic_NullPercentage(benchmark::State& st) { HashJoinBasicBenchmarkImpl(st, settings); } + +template +static void BM_HashJoinBasic_TrivialResidualFilter(benchmark::State& st, + JoinType join_type, + Expression residual_filter, + Args&&...) { + BenchmarkSettings settings; + settings.join_type = join_type; + settings.build_payload_types = {binary()}; + settings.probe_payload_types = {binary()}; + + settings.use_basic_implementation = st.range(0); + + settings.num_build_batches = 1024; + settings.num_probe_batches = 1024; + + // Let payload column length from 1 to 100. + settings.var_length_min = 1; + settings.var_length_max = 100; + + settings.residual_filter = std::move(residual_filter); + + HashJoinBasicBenchmarkImpl(st, settings); +} + +template +static void BM_HashJoinBasic_ComplexResidualFilter(benchmark::State& st, + JoinType join_type, Args&&...) { + BenchmarkSettings settings; + settings.join_type = join_type; + settings.build_payload_types = {binary()}; + settings.probe_payload_types = {binary()}; + + settings.use_basic_implementation = st.range(0); + + settings.num_build_batches = 1024; + settings.num_probe_batches = 1024; + + // Let payload column length from 1 to 100. + settings.var_length_min = 1; + settings.var_length_max = 100; + + // Create filter referring payload columns from both sides. + // binary_length(probe_payload) + binary_length(build_payload) <= 2 * selectivity + settings.selectivity = static_cast(st.range(1)) / 100.0; + using arrow::compute::call; + using arrow::compute::field_ref; + settings.residual_filter = + call("less_equal", {call("plus", {call("binary_length", {field_ref("lp0")}), + call("binary_length", {field_ref("rp0")})}), + literal(2 * settings.selectivity)}); + + HashJoinBasicBenchmarkImpl(st, settings); +} #endif std::vector hashtable_krows = benchmark::CreateRange(1, 4096, 8); @@ -435,6 +492,136 @@ BENCHMARK(BM_HashJoinBasic_BuildParallelism) BENCHMARK(BM_HashJoinBasic_NullPercentage) ->ArgNames({"Null Percentage"}) ->DenseRange(0, 100, 10); + +const char* use_basic_argname = "Use basic"; +std::vector use_basic_arg = benchmark::CreateDenseRange(0, 1, 1); + +std::vector trivial_residual_filter_argnames = {use_basic_argname}; +std::vector> trivial_residual_filter_args = {use_basic_arg}; + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Inner/Literal(true)", + JoinType::INNER, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Semi/Literal(true)", + JoinType::LEFT_SEMI, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Semi/Literal(true)", + JoinType::RIGHT_SEMI, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Anti/Literal(true)", + JoinType::LEFT_ANTI, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Anti/Literal(true)", + JoinType::RIGHT_ANTI, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Outer/Literal(true)", + JoinType::LEFT_OUTER, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Outer/Literal(true)", + JoinType::RIGHT_OUTER, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Full Outer/Literal(true)", + JoinType::FULL_OUTER, literal(true)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Inner/Literal(false)", + JoinType::INNER, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Semi/Literal(false)", + JoinType::LEFT_SEMI, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Semi/Literal(false)", + JoinType::RIGHT_SEMI, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Anti/Literal(false)", + JoinType::LEFT_ANTI, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Anti/Literal(false)", + JoinType::RIGHT_ANTI, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Left Outer/Literal(false)", + JoinType::LEFT_OUTER, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Right Outer/Literal(false)", + JoinType::RIGHT_OUTER, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_TrivialResidualFilter, "Full Outer/Literal(false)", + JoinType::FULL_OUTER, literal(false)) + ->ArgNames(trivial_residual_filter_argnames) + ->ArgsProduct(trivial_residual_filter_args); + +std::vector complex_residual_filter_argnames = {use_basic_argname, + "Selectivity"}; +std::vector> complex_residual_filter_args = { + use_basic_arg, benchmark::CreateDenseRange(0, 100, 20)}; + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Inner", JoinType::INNER) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Semi", + JoinType::LEFT_SEMI) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Semi", + JoinType::RIGHT_SEMI) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Anti", + JoinType::LEFT_ANTI) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Anti", + JoinType::RIGHT_ANTI) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Left Outer", + JoinType::LEFT_OUTER) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Right Outer", + JoinType::RIGHT_OUTER) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); + +BENCHMARK_CAPTURE(BM_HashJoinBasic_ComplexResidualFilter, "Full Outer", + JoinType::FULL_OUTER) + ->ArgNames(complex_residual_filter_argnames) + ->ArgsProduct(complex_residual_filter_args); #else BENCHMARK_CAPTURE(BM_HashJoinBasic_KeyTypes, "{int32}", {int32()}) diff --git a/cpp/src/arrow/acero/hash_join_node.cc b/cpp/src/arrow/acero/hash_join_node.cc index 254dad361ff87..c0179fd160e4e 100644 --- a/cpp/src/arrow/acero/hash_join_node.cc +++ b/cpp/src/arrow/acero/hash_join_node.cc @@ -740,13 +740,11 @@ class HashJoinNode : public ExecNode, public TracedNode { // Create hash join implementation object // SwissJoin does not support: // a) 64-bit string offsets - // b) residual predicates - // c) dictionaries + // b) dictionaries // bool use_swiss_join; #if ARROW_LITTLE_ENDIAN - use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() && - !schema_mgr->HasLargeBinary(); + use_swiss_join = !schema_mgr->HasDictionaries() && !schema_mgr->HasLargeBinary(); #else use_swiss_join = false; #endif diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 58551f4eca00a..63969d9a3ed4b 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -1893,58 +1893,147 @@ TEST(HashJoin, CheckHashJoinNodeOptionsValidation) { } } -TEST(HashJoin, ResidualFilter) { - for (bool parallel : {false, true}) { - SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); - - BatchesWithSchema input_left; - input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ - [1, 6, "alpha"], - [2, 5, "beta"], - [3, 4, "alpha"] - ])")}; - input_left.schema = - schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())}); - - BatchesWithSchema input_right; - input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ - [5, 11, "alpha"], - [2, 12, "beta"], - [4, 16, "alpha"] - ])")}; - input_right.schema = - schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())}); +class ResidualFilterCaseRunner { + public: + ResidualFilterCaseRunner(BatchesWithSchema left_input, BatchesWithSchema right_input) + : left_input_(std::move(left_input)), right_input_(std::move(right_input)) {} + + void Run(JoinType join_type, std::vector left_keys, + std::vector right_keys, Expression filter, + const std::vector& expected) const { + RunInternal(HashJoinNodeOptions{join_type, std::move(left_keys), + std::move(right_keys), std::move(filter)}, + expected); + } + + void Run(JoinType join_type, std::vector left_keys, + std::vector right_keys, std::vector left_output, + std::vector right_output, Expression filter, + const std::vector& expected) const { + RunInternal(HashJoinNodeOptions{join_type, std::move(left_keys), + std::move(right_keys), std::move(left_output), + std::move(right_output), std::move(filter)}, + expected); + } + + private: + void RunInternal(const HashJoinNodeOptions& options, + const std::vector& expected) const { + auto join_type_str = JoinTypeString(options.join_type); + auto join_cond_str = + JoinConditionString(options.left_keys, options.right_keys, options.filter); + auto output_str = OutputString(options.left_output, options.right_output); + for (bool parallel : {false, true}) { + auto parallel_str = parallel ? "parallel" : "serial"; + ARROW_SCOPED_TRACE(join_type_str + " " + join_cond_str + " " + output_str + " " + + parallel_str); - Declaration left{ - "source", - SourceNodeOptions{input_left.schema, input_left.gen(parallel, /*slow=*/false)}}; - Declaration right{ - "source", - SourceNodeOptions{input_right.schema, input_right.gen(parallel, /*slow=*/false)}}; + Declaration left{"source", + SourceNodeOptions{left_input_.schema, + left_input_.gen(parallel, /*slow=*/false)}}; + Declaration right{"source", + SourceNodeOptions{right_input_.schema, + right_input_.gen(parallel, /*slow=*/false)}}; - Expression mul = call("multiply", {field_ref("l1"), field_ref("l2")}); - Expression combination = call("add", {mul, field_ref("r1")}); - Expression residual_filter = less_equal(combination, field_ref("r2")); + Declaration join{"hashjoin", {std::move(left), std::move(right)}, options}; - HashJoinNodeOptions join_opts{ - JoinType::FULL_OUTER, - /*left_keys=*/{"l_str"}, - /*right_keys=*/{"r_str"}, std::move(residual_filter), "l_", "r_"}; + ASSERT_OK_AND_ASSIGN(auto result, + DeclarationToExecBatches(std::move(join), parallel)); + AssertExecBatchesEqualIgnoringOrder(result.schema, expected, result.batches); + } + } - Declaration join{"hashjoin", {std::move(left), std::move(right)}, join_opts}; + private: + BatchesWithSchema left_input_; + BatchesWithSchema right_input_; - ASSERT_OK_AND_ASSIGN(auto result, - DeclarationToExecBatches(std::move(join), parallel)); + private: + static std::string JoinTypeString(JoinType t) { + switch (t) { + case JoinType::LEFT_SEMI: + return "LEFT_SEMI"; + case JoinType::RIGHT_SEMI: + return "RIGHT_SEMI"; + case JoinType::LEFT_ANTI: + return "LEFT_ANTI"; + case JoinType::RIGHT_ANTI: + return "RIGHT_ANTI"; + case JoinType::INNER: + return "INNER"; + case JoinType::LEFT_OUTER: + return "LEFT_OUTER"; + case JoinType::RIGHT_OUTER: + return "RIGHT_OUTER"; + case JoinType::FULL_OUTER: + return "FULL_OUTER"; + } + ARROW_DCHECK(false); + return ""; + } + + static std::string JoinConditionString(const std::vector& left_keys, + const std::vector& right_keys, + const Expression& filter) { + ARROW_DCHECK(left_keys.size() > 0); + ARROW_DCHECK(left_keys.size() == right_keys.size()); + std::stringstream ss; + ss << "on ("; + for (size_t i = 0; i < left_keys.size(); ++i) { + ss << left_keys[i].ToString() << " = " << right_keys[i].ToString() << " and "; + } + ss << filter.ToString(); + ss << ")"; + return ss.str(); + } + + static std::string OutputString(const std::vector& left_output, + const std::vector& right_output) { + std::vector both_output; + both_output.reserve(left_output.size() + right_output.size()); + both_output.insert(both_output.end(), left_output.begin(), left_output.end()); + both_output.insert(both_output.end(), right_output.begin(), right_output.end()); + std::stringstream ss; + ss << "output ("; + for (size_t i = 0; i < both_output.size(); ++i) { + if (i != 0) { + ss << ", "; + } + ss << both_output[i].ToString(); + } + ss << ")"; + return ss.str(); + } +}; - std::vector expected = { - ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([ - [1, 6, "alpha", 4, 16, "alpha"], - [1, 6, "alpha", 5, 11, "alpha"], - [2, 5, "beta", 2, 12, "beta"], - [3, 4, "alpha", 4, 16, "alpha"]])")}; +TEST(HashJoin, ResidualFilter) { + BatchesWithSchema input_left; + input_left.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ + [1, 6, "alpha"], + [2, 5, "beta"], + [3, 4, "alpha"]])")}; + input_left.schema = + schema({field("l1", int32()), field("l2", int32()), field("l_str", utf8())}); - AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, expected); - } + BatchesWithSchema input_right; + input_right.batches = {ExecBatchFromJSON({int32(), int32(), utf8()}, R"([ + [5, 11, "alpha"], + [2, 12, "beta"], + [4, 16, "alpha"]])")}; + input_right.schema = + schema({field("r1", int32()), field("r2", int32()), field("r_str", utf8())}); + + const ResidualFilterCaseRunner runner{std::move(input_left), std::move(input_right)}; + + Expression mul = call("multiply", {field_ref("l1"), field_ref("l2")}); + Expression combination = call("add", {mul, field_ref("r1")}); + Expression filter = less_equal(combination, field_ref("r2")); + + runner.Run(JoinType::FULL_OUTER, {"l_str"}, {"r_str"}, std::move(filter), + {ExecBatchFromJSON({int32(), int32(), utf8(), int32(), int32(), utf8()}, R"([ + [1, 6, "alpha", 4, 16, "alpha"], + [1, 6, "alpha", 5, 11, "alpha"], + [2, 5, "beta", 2, 12, "beta"], + [3, 4, "alpha", 4, 16, "alpha"]])")}); } TEST(HashJoin, TrivialResidualFilter) { @@ -1959,47 +2048,993 @@ TEST(HashJoin, TrivialResidualFilter) { std::vector expected_strings = {expected_true, expected_false}; std::vector filters = {always_true, always_false}; + BatchesWithSchema input_left; + input_left.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ + [1, "alpha"]])")}; + input_left.schema = schema({field("l1", int32()), field("l_str", utf8())}); + + BatchesWithSchema input_right; + input_right.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ + [1, "alpha"]])")}; + input_right.schema = schema({field("r1", int32()), field("r_str", utf8())}); + + ResidualFilterCaseRunner runner{std::move(input_left), std::move(input_right)}; + for (size_t test_id = 0; test_id < 2; test_id++) { - for (bool parallel : {false, true}) { - SCOPED_TRACE(parallel ? "parallel/merged" : "serial"); + runner.Run(JoinType::INNER, {"l_str"}, {"r_str"}, filters[test_id], + {ExecBatchFromJSON({int32(), utf8(), int32(), utf8()}, + expected_strings[test_id])}); + } +} - BatchesWithSchema input_left; - input_left.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ - [1, "alpha"] - ])")}; - input_left.schema = schema({field("l1", int32()), field("l_str", utf8())}); +TEST(HashJoin, FineGrainedResidualFilter) { + struct JoinSchema { + std::shared_ptr left, right; - BatchesWithSchema input_right; - input_right.batches = {ExecBatchFromJSON({int32(), utf8()}, R"([ - [1, "alpha"] - ])")}; - input_right.schema = schema({field("r1", int32()), field("r_str", utf8())}); + struct Projector { + std::shared_ptr left, right; + std::vector left_output, right_output; - auto exec_ctx = std::make_unique( - default_memory_pool(), - parallel ? arrow::internal::GetCpuThreadPool() : nullptr); + std::vector LeftOutput(JoinType join_type) const { + if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) { + return {}; + } + std::vector output(left_output.size()); + std::transform(left_output.begin(), left_output.end(), output.begin(), + [](int i) { return i; }); + return output; + } - Declaration left{ - "source", - SourceNodeOptions{input_left.schema, input_left.gen(parallel, /*slow=*/false)}}; - Declaration right{"source", - SourceNodeOptions{input_right.schema, - input_right.gen(parallel, /*slow=*/false)}}; + std::vector RightOutput(JoinType join_type) const { + if (join_type == JoinType::LEFT_SEMI || join_type == JoinType::LEFT_ANTI) { + return {}; + } + std::vector output(right_output.size()); + std::transform(right_output.begin(), right_output.end(), output.begin(), + [](int i) { return i; }); + return output; + } - HashJoinNodeOptions join_opts{ - JoinType::INNER, - /*left_keys=*/{"l_str"}, - /*right_keys=*/{"r_str"}, filters[test_id], "l_", "r_"}; + ExecBatch Project(JoinType join_type, const ExecBatch& batch) const { + std::vector values; + if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) { + for (int i : left_output) { + values.push_back(batch[i]); + } + } + if (join_type != JoinType::LEFT_SEMI && join_type != JoinType::LEFT_ANTI) { + int left_size = + join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI + ? 0 + : left->num_fields(); + for (int i : right_output) { + values.push_back(batch[left_size + i]); + } + } + return {std::move(values), batch.length}; + } + }; - Declaration join{"hashjoin", {std::move(left), std::move(right)}, join_opts}; + Projector GetProjector(std::vector left_output, std::vector right_output) { + return Projector{left, right, std::move(left_output), std::move(right_output)}; + } + }; + + BatchesWithSchema left; + left.batches = {ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload"], + [null, 0, "l_payload"], + [null, 42, "l_payload"], + ["left_only", null, "l_payload"], + ["left_only", 0, "l_payload"], + ["left_only", 42, "l_payload"], + ["both1", null, "l_payload"], + ["both1", 0, "l_payload"], + ["both1", 42, "l_payload"], + ["both2", null, "l_payload"], + ["both2", 0, "l_payload"], + ["both2", 42, "l_payload"]])")}; + left.schema = schema( + {field("l_key", utf8()), field("l_filter", int32()), field("l_payload", utf8())}); + + BatchesWithSchema right; + right.batches = {ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "r_payload"], + [null, 0, "r_payload"], + [null, 42, "r_payload"], + ["both1", null, "r_payload"], + ["both1", 0, "r_payload"], + ["both1", 42, "r_payload"], + ["both2", null, "r_payload"], + ["both2", 0, "r_payload"], + ["both2", 42, "r_payload"], + ["right_only", null, "r_payload"], + ["right_only", 0, "r_payload"], + ["right_only", 42, "r_payload"]])")}; + right.schema = schema( + {field("r_key", utf8()), field("r_filter", int32()), field("r_payload", utf8())}); + + JoinSchema join_schema{left.schema, right.schema}; + std::vector projectors{ + join_schema.GetProjector({0, 1, 2}, {0, 1, 2}), // Output all. + join_schema.GetProjector({0}, {0}), // Output key columns only. + join_schema.GetProjector({1}, {1}), // Output filter columns only. + join_schema.GetProjector({2}, {2})}; // Output payload columns only. + + const ResidualFilterCaseRunner runner{std::move(left), std::move(right)}; - ASSERT_OK_AND_ASSIGN(auto result, - DeclarationToExecBatches(std::move(join), parallel)); + { + // Literal true and scalar true. + for (Expression filter : {literal(true), equal(literal(1), literal(1))}) { + std::vector left_keys{"l_key", "l_filter"}, + right_keys{"r_key", "r_filter"}; + { + // Inner join. + JoinType join_type = JoinType::INNER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } - std::vector expected = {ExecBatchFromJSON( - {int32(), utf8(), int32(), utf8()}, expected_strings[test_id])}; + { + // Left outer join. + JoinType join_type = JoinType::LEFT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } - AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches, expected); + { + // Right outer join. + JoinType join_type = JoinType::RIGHT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Full outer join. + JoinType join_type = JoinType::FULL_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left semi join. + JoinType join_type = JoinType::LEFT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 0, "l_payload"], + ["both1", 42, "l_payload"], + ["both2", 0, "l_payload"], + ["both2", 42, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left anti join. + JoinType join_type = JoinType::LEFT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload"], + [null, 0, "l_payload"], + [null, 42, "l_payload"], + ["left_only", null, "l_payload"], + ["left_only", 0, "l_payload"], + ["left_only", 42, "l_payload"], + ["both1", null, "l_payload"], + ["both2", null, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right semi join. + JoinType join_type = JoinType::RIGHT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 0, "r_payload"], + ["both1", 42, "r_payload"], + ["both2", 0, "r_payload"], + ["both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right anti join. + JoinType join_type = JoinType::RIGHT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "r_payload"], + [null, 0, "r_payload"], + [null, 42, "r_payload"], + ["both1", null, "r_payload"], + ["both2", null, "r_payload"], + ["right_only", null, "r_payload"], + ["right_only", 0, "r_payload"], + ["right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + } + } + + { + // Literal false, null, and scalar false, null. + for (Expression filter : + {literal(false), literal(NullScalar()), equal(literal(0), literal(1)), + equal(literal(1), literal(NullScalar()))}) { + std::vector left_keys{"l_key", "l_filter"}, + right_keys{"r_key", "r_filter"}; + { + // Inner join. + JoinType join_type = JoinType::INNER; + auto expected = ExecBatchFromJSON( + {utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left outer join. + JoinType join_type = JoinType::LEFT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", null, null, null], + ["both1", 42, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both2", 0, "l_payload", null, null, null], + ["both2", 42, "l_payload", null, null, null]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right outer join. + JoinType join_type = JoinType::RIGHT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both1", 0, "r_payload"], + [null, null, null, "both1", 42, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "both2", 0, "r_payload"], + [null, null, null, "both2", 42, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Full outer join. + JoinType join_type = JoinType::FULL_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", null, null, null], + ["both1", 42, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both2", 0, "l_payload", null, null, null], + ["both2", 42, "l_payload", null, null, null], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both1", 0, "r_payload"], + [null, null, null, "both1", 42, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "both2", 0, "r_payload"], + [null, null, null, "both2", 42, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left semi join. + JoinType join_type = JoinType::LEFT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left anti join. + JoinType join_type = JoinType::LEFT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload"], + [null, 0, "l_payload"], + [null, 42, "l_payload"], + ["left_only", null, "l_payload"], + ["left_only", 0, "l_payload"], + ["left_only", 42, "l_payload"], + ["both1", null, "l_payload"], + ["both1", 0, "l_payload"], + ["both1", 42, "l_payload"], + ["both2", null, "l_payload"], + ["both2", 0, "l_payload"], + ["both2", 42, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right semi join. + JoinType join_type = JoinType::RIGHT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right anti join. + JoinType join_type = JoinType::RIGHT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "r_payload"], + [null, 0, "r_payload"], + [null, 42, "r_payload"], + ["both1", null, "r_payload"], + ["both1", 0, "r_payload"], + ["both1", 42, "r_payload"], + ["both2", null, "r_payload"], + ["both2", 0, "r_payload"], + ["both2", 42, "r_payload"], + ["right_only", null, "r_payload"], + ["right_only", 0, "r_payload"], + ["right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + } + } + + { + // Non-trivial filters referring left columns only. + for (Expression filter : {equal(field_ref("l_filter"), literal(42)), + not_equal(literal(0), field_ref("l_filter"))}) { + std::vector left_keys{"l_key"}, right_keys{"r_key"}; + { + // Inner join. + JoinType join_type = JoinType::INNER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 42, "l_payload", "both1", null, "r_payload"], + ["both1", 42, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", null, "r_payload"], + ["both2", 42, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left outer join. + JoinType join_type = JoinType::LEFT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both2", 0, "l_payload", null, null, null], + ["both1", 42, "l_payload", "both1", null, "r_payload"], + ["both1", 42, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", null, "r_payload"], + ["both2", 42, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right outer join. + JoinType join_type = JoinType::RIGHT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 42, "l_payload", "both1", null, "r_payload"], + ["both1", 42, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", null, "r_payload"], + ["both2", 42, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Full outer join. + JoinType join_type = JoinType::FULL_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both2", 0, "l_payload", null, null, null], + ["both1", 42, "l_payload", "both1", null, "r_payload"], + ["both1", 42, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", null, "r_payload"], + ["both2", 42, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left semi join. + JoinType join_type = JoinType::LEFT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 42, "l_payload"], + ["both2", 42, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left anti join. + JoinType join_type = JoinType::LEFT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload"], + [null, 0, "l_payload"], + [null, 42, "l_payload"], + ["left_only", null, "l_payload"], + ["left_only", 0, "l_payload"], + ["left_only", 42, "l_payload"], + ["both1", null, "l_payload"], + ["both1", 0, "l_payload"], + ["both2", null, "l_payload"], + ["both2", 0, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right semi join. + JoinType join_type = JoinType::RIGHT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", null, "r_payload"], + ["both1", 0, "r_payload"], + ["both1", 42, "r_payload"], + ["both2", null, "r_payload"], + ["both2", 0, "r_payload"], + ["both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right anti join. + JoinType join_type = JoinType::RIGHT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "r_payload"], + [null, 0, "r_payload"], + [null, 42, "r_payload"], + ["right_only", null, "r_payload"], + ["right_only", 0, "r_payload"], + ["right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + } + } + + { + // Non-trivial filters referring right columns only. + for (Expression filter : {equal(field_ref("r_filter"), literal(42)), + not_equal(literal(0), field_ref("r_filter"))}) { + std::vector left_keys{"l_key"}, right_keys{"r_key"}; + { + // Inner join. + JoinType join_type = JoinType::INNER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", null, "l_payload", "both1", 42, "r_payload"], + ["both1", 0, "l_payload", "both1", 42, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", null, "l_payload", "both2", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left outer join. + JoinType join_type = JoinType::LEFT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", "both1", 42, "r_payload"], + ["both1", 0, "l_payload", "both1", 42, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", null, "l_payload", "both2", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right outer join. + JoinType join_type = JoinType::RIGHT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", null, "l_payload", "both1", 42, "r_payload"], + ["both1", 0, "l_payload", "both1", 42, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", null, "l_payload", "both2", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both1", 0, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "both2", 0, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Full outer join. + JoinType join_type = JoinType::FULL_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", "both1", 42, "r_payload"], + ["both1", 0, "l_payload", "both1", 42, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", null, "l_payload", "both2", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 42, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both1", 0, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "both2", 0, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left semi join. + JoinType join_type = JoinType::LEFT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", null, "l_payload"], + ["both1", 0, "l_payload"], + ["both1", 42, "l_payload"], + ["both2", null, "l_payload"], + ["both2", 0, "l_payload"], + ["both2", 42, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left anti join. + JoinType join_type = JoinType::LEFT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload"], + [null, 0, "l_payload"], + [null, 42, "l_payload"], + ["left_only", null, "l_payload"], + ["left_only", 0, "l_payload"], + ["left_only", 42, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right semi join. + JoinType join_type = JoinType::RIGHT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 42, "r_payload"], + ["both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right anti join. + JoinType join_type = JoinType::RIGHT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "r_payload"], + [null, 0, "r_payload"], + [null, 42, "r_payload"], + ["both1", null, "r_payload"], + ["both1", 0, "r_payload"], + ["both2", null, "r_payload"], + ["both2", 0, "r_payload"], + ["right_only", null, "r_payload"], + ["right_only", 0, "r_payload"], + ["right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + } + } + + { + // Non-trivial filters referring both left and right columns. + for (Expression filter : + {equal(field_ref("l_filter"), field_ref("r_filter")), + equal(call("subtract", {field_ref("l_filter"), field_ref("r_filter")}), + literal(0))}) { + std::vector left_keys{"l_key"}, right_keys{"r_key"}; + { + // Inner join. + JoinType join_type = JoinType::INNER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left outer join. + JoinType join_type = JoinType::LEFT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right outer join. + JoinType join_type = JoinType::RIGHT_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Full outer join. + JoinType join_type = JoinType::FULL_OUTER; + auto expected = + ExecBatchFromJSON({utf8(), int32(), utf8(), utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload", null, null, null], + [null, 0, "l_payload", null, null, null], + [null, 42, "l_payload", null, null, null], + ["left_only", null, "l_payload", null, null, null], + ["left_only", 0, "l_payload", null, null, null], + ["left_only", 42, "l_payload", null, null, null], + ["both1", null, "l_payload", null, null, null], + ["both2", null, "l_payload", null, null, null], + ["both1", 0, "l_payload", "both1", 0, "r_payload"], + ["both1", 42, "l_payload", "both1", 42, "r_payload"], + ["both2", 0, "l_payload", "both2", 0, "r_payload"], + ["both2", 42, "l_payload", "both2", 42, "r_payload"], + [null, null, null, null, null, "r_payload"], + [null, null, null, null, 0, "r_payload"], + [null, null, null, null, 42, "r_payload"], + [null, null, null, "both1", null, "r_payload"], + [null, null, null, "both2", null, "r_payload"], + [null, null, null, "right_only", null, "r_payload"], + [null, null, null, "right_only", 0, "r_payload"], + [null, null, null, "right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left semi join. + JoinType join_type = JoinType::LEFT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 0, "l_payload"], + ["both1", 42, "l_payload"], + ["both2", 0, "l_payload"], + ["both2", 42, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Left anti join. + JoinType join_type = JoinType::LEFT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "l_payload"], + [null, 0, "l_payload"], + [null, 42, "l_payload"], + ["left_only", null, "l_payload"], + ["left_only", 0, "l_payload"], + ["left_only", 42, "l_payload"], + ["both1", null, "l_payload"], + ["both2", null, "l_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right semi join. + JoinType join_type = JoinType::RIGHT_SEMI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + ["both1", 0, "r_payload"], + ["both1", 42, "r_payload"], + ["both2", 0, "r_payload"], + ["both2", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } + + { + // Right anti join. + JoinType join_type = JoinType::RIGHT_ANTI; + auto expected = ExecBatchFromJSON({utf8(), int32(), utf8()}, R"([ + [null, null, "r_payload"], + [null, 0, "r_payload"], + [null, 42, "r_payload"], + ["both1", null, "r_payload"], + ["both2", null, "r_payload"], + ["right_only", null, "r_payload"], + ["right_only", 0, "r_payload"], + ["right_only", 42, "r_payload"]])"); + for (const auto& projector : projectors) { + runner.Run(join_type, left_keys, right_keys, projector.LeftOutput(join_type), + projector.RightOutput(join_type), filter, + {projector.Project(join_type, expected)}); + } + } } } } diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 2f79ed299bb70..68b0e37b01aa9 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -1085,10 +1085,30 @@ void SwissTableForJoin::UpdateHasMatchForKeys(int64_t thread_id, int num_ids, if (num_ids == 0 || !bit_vector) { return; } + for (int ikey = 0; ikey < num_ids; ++ikey) { + // Mark payloads corresponding to this key in hash table as having a match. + // + uint32_t key_id = key_ids[ikey]; + uint32_t first_payload_for_key = key_to_payload() ? key_to_payload()[key_id] : key_id; + uint32_t last_payload_for_key = + key_to_payload() ? key_to_payload()[key_id + 1] - 1 : key_id; + for (uint32_t ipayload = first_payload_for_key; ipayload <= last_payload_for_key; + ++ipayload) { + bit_util::SetBit(bit_vector, ipayload); + } + } +} + +void SwissTableForJoin::UpdateHasMatchForPayloads(int64_t thread_id, int num_ids, + const uint32_t* payload_ids) { + uint8_t* bit_vector = local_has_match(thread_id); + if (num_ids == 0 || !bit_vector) { + return; + } for (int i = 0; i < num_ids; ++i) { - // Mark row in hash table as having a match + // Mark payload in hash table as having a match. // - bit_util::SetBit(bit_vector, key_ids[i]); + bit_util::SetBit(bit_vector, payload_ids[i]); } } @@ -1123,29 +1143,6 @@ uint32_t SwissTableForJoin::payload_id_to_key_id(uint32_t payload_id) const { return static_cast(first_greater - entries) - 1; } -void SwissTableForJoin::payload_ids_to_key_ids(int num_rows, const uint32_t* payload_ids, - uint32_t* key_ids) const { - if (num_rows == 0) { - return; - } - if (no_duplicate_keys_) { - memcpy(key_ids, payload_ids, num_rows * sizeof(uint32_t)); - return; - } - - const uint32_t* entries = key_to_payload(); - uint32_t key_id = payload_id_to_key_id(payload_ids[0]); - key_ids[0] = key_id; - for (int i = 1; i < num_rows; ++i) { - ARROW_DCHECK(payload_ids[i] > payload_ids[i - 1]); - while (entries[key_id + 1] <= payload_ids[i]) { - ++key_id; - ARROW_DCHECK(key_id < num_keys()); - } - key_ids[i] = key_id; - } -} - Status SwissTableForJoinBuild::Init(SwissTableForJoin* target, int dop, int64_t num_rows, bool reject_duplicate_keys, bool no_payload, const std::vector& key_types, @@ -1581,6 +1578,10 @@ Status JoinResultMaterialize::AppendProbeOnly(const ExecBatch& key_and_payload, int num_rows_to_append, const uint16_t* row_ids, int* num_rows_appended) { + if (num_rows_to_append == 0) { + *num_rows_appended = 0; + return Status::OK(); + } num_rows_to_append = std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); if (HasProbeOutput()) { @@ -1607,6 +1608,10 @@ Status JoinResultMaterialize::AppendBuildOnly(int num_rows_to_append, const uint32_t* key_ids, const uint32_t* payload_ids, int* num_rows_appended) { + if (num_rows_to_append == 0) { + *num_rows_appended = 0; + return Status::OK(); + } num_rows_to_append = std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); if (HasProbeOutput()) { @@ -1634,6 +1639,10 @@ Status JoinResultMaterialize::Append(const ExecBatch& key_and_payload, int num_rows_to_append, const uint16_t* row_ids, const uint32_t* key_ids, const uint32_t* payload_ids, int* num_rows_appended) { + if (num_rows_to_append == 0) { + *num_rows_appended = 0; + return Status::OK(); + } num_rows_to_append = std::min(ExecBatchBuilder::num_rows_max() - num_rows_, num_rows_to_append); if (HasProbeOutput()) { @@ -1791,7 +1800,7 @@ void JoinMatchIterator::SetLookupResult(int num_batch_rows, int start_batch_row, bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows, uint16_t* batch_row_ids, uint32_t* key_ids, - uint32_t* payload_ids) { + uint32_t* payload_ids, int row_id_to_skip) { *out_num_rows = 0; if (no_duplicate_keys_) { @@ -1816,7 +1825,8 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows, // matches to output. // while (current_row_ < num_batch_rows_ && *out_num_rows < num_rows_max) { - if (!bit_util::GetBit(batch_has_match_, current_row_)) { + if (!bit_util::GetBit(batch_has_match_, current_row_) || + current_row_ == row_id_to_skip) { ++current_row_; current_match_for_row_ = 0; continue; @@ -1855,14 +1865,415 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows, return (*out_num_rows) > 0; } +namespace { + +// Given match_bitvector identifies that there is a match for row[batch_start_row + i] in +// given input batch if bit match_bitvector[i] == passing_bit. Collect all the passing row +// ids according to the given match_bitvector. +// +void CollectPassingBatchIds(int passing_bit, int64_t hardware_flags, int batch_start_row, + int num_batch_rows, const uint8_t* match_bitvector, + int* num_passing_ids, uint16_t* passing_batch_row_ids) { + arrow::util::bit_util::bits_to_indexes(passing_bit, hardware_flags, num_batch_rows, + match_bitvector, num_passing_ids, + passing_batch_row_ids); + // Add base batch row index. + // + for (int i = 0; i < *num_passing_ids; ++i) { + passing_batch_row_ids[i] += static_cast(batch_start_row); + } +} + +} // namespace + +void JoinResidualFilter::Init(Expression filter, QueryContext* ctx, MemoryPool* pool, + int64_t hardware_flags, + const HashJoinProjectionMaps* probe_schemas, + const HashJoinProjectionMaps* build_schemas, + SwissTableForJoin* hash_table) { + filter_ = std::move(filter); + ctx_ = ctx; + pool_ = pool; + hardware_flags_ = hardware_flags; + probe_schemas_ = probe_schemas; + build_schemas_ = build_schemas; + hash_table_ = hash_table; + + { + probe_filter_to_key_and_payload_.resize( + probe_schemas_->num_cols(HashJoinProjection::FILTER)); + int num_key_cols = probe_schemas_->num_cols(HashJoinProjection::KEY); + auto to_key = + probe_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + auto to_payload = + probe_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + for (int i = 0; static_cast(i) < probe_filter_to_key_and_payload_.size(); + ++i) { + if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) { + probe_filter_to_key_and_payload_[i] = idx; + } else if (idx = to_payload.get(i); idx != SchemaProjectionMap::kMissingField) { + probe_filter_to_key_and_payload_[i] = idx + num_key_cols; + } else { + ARROW_DCHECK(false); + } + } + } + + { + int num_columns = build_schemas_->num_cols(HashJoinProjection::FILTER); + auto to_key = + build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + auto to_payload = + build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + for (int i = 0; i < num_columns; ++i) { + if (to_key.get(i) != SchemaProjectionMap::kMissingField) { + num_build_keys_referred_++; + } else if (to_payload.get(i) != SchemaProjectionMap::kMissingField) { + num_build_payloads_referred_++; + } else { + ARROW_DCHECK(false); + } + } + } +} + +void JoinResidualFilter::OnBuildFinished() { + minibatch_size_ = hash_table_->keys()->swiss_table()->minibatch_size(); + build_keys_ = hash_table_->keys()->keys(); + build_payloads_ = hash_table_->payloads(); + key_to_payload_ = hash_table_->key_to_payload(); +} + +void JoinResidualFilter::InitFilterBitVector(int num_batch_rows, + uint8_t* filter_bitvector) { + std::memset(filter_bitvector, 0, bit_util::BytesForBits(num_batch_rows)); +} + +void JoinResidualFilter::UpdateFilterBitVector(int batch_start_row, int num_batch_rows, + const uint16_t* batch_row_ids, + uint8_t* filter_bitvector) { + for (int i = 0; i < num_batch_rows; ++i) { + int bit_idx = batch_row_ids[i] - batch_start_row; + bit_util::SetBitTo(filter_bitvector, bit_idx, 1); + } +} + +Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch, + int batch_start_row, int num_batch_rows, + const uint8_t* match_bitvector, + const uint32_t* key_ids, bool no_duplicate_keys, + arrow::util::TempVectorStack* temp_stack, + int* num_passing_ids, + uint16_t* passing_batch_row_ids) const { + if (filter_ == literal(true)) { + CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows, + match_bitvector, num_passing_ids, passing_batch_row_ids); + return Status::OK(); + } + + *num_passing_ids = 0; + if (filter_.IsNullLiteral() || filter_ == literal(false)) { + return Status::OK(); + } + + if (num_build_keys_referred_ == 0 && num_build_payloads_referred_ == 0) { + // If filter refers no column in the right table, then we can directly filter on the + // left rows without inner matching and materializing the right rows. + // + CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows, + match_bitvector, num_passing_ids, passing_batch_row_ids); + return FilterOneBatch(keypayload_batch, *num_passing_ids, passing_batch_row_ids, + /*key_ids_maybe_null=*/NULLPTR, + /*payload_ids_maybe_null=*/NULLPTR, + /*output_key_ids=*/false, + /*output_payload_ids=*/false, temp_stack, num_passing_ids); + } + + auto match_batch_row_ids_buf = + arrow::util::TempVectorHolder(temp_stack, minibatch_size_); + auto match_key_ids_buf = + arrow::util::TempVectorHolder(temp_stack, minibatch_size_); + auto match_payload_ids_buf = + arrow::util::TempVectorHolder(temp_stack, minibatch_size_); + + // Inner matching is necessary for non-trivial filter. Only until evaluating filter for + // all matches of the same row can we be sure that it's not passing (it could pass + // earlier though). + // + JoinMatchIterator match_iterator; + match_iterator.SetLookupResult(num_batch_rows, batch_start_row, match_bitvector, + key_ids, no_duplicate_keys, key_to_payload_); + int num_matches_next = 0; + // Used to not only collect distinct row ids, but also skip unecessary matches in the + // next batch. + // + int row_id_last = JoinMatchIterator::kInvalidRowId; + while (match_iterator.GetNextBatch(minibatch_size_, &num_matches_next, + match_batch_row_ids_buf.mutable_data(), + match_key_ids_buf.mutable_data(), + match_payload_ids_buf.mutable_data(), row_id_last)) { + int num_passing = 0; + RETURN_NOT_OK(FilterOneBatch( + keypayload_batch, num_matches_next, match_batch_row_ids_buf.mutable_data(), + match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data(), + /*output_key_ids=*/false, + /*output_payload_ids=*/false, temp_stack, &num_passing)); + // There may be multiple passing of a row in batch. Collect distinct row ids. + // + for (int ipassing = 0; ipassing < num_passing; ++ipassing) { + if (match_batch_row_ids_buf.mutable_data()[ipassing] == row_id_last) { + continue; + } + row_id_last = passing_batch_row_ids[*num_passing_ids] = + match_batch_row_ids_buf.mutable_data()[ipassing]; + ++(*num_passing_ids); + } + } + + return Status::OK(); +} + +Status JoinResidualFilter::FilterLeftAnti(const ExecBatch& keypayload_batch, + int batch_start_row, int num_batch_rows, + const uint8_t* match_bitvector, + const uint32_t* key_ids, bool no_duplicate_keys, + arrow::util::TempVectorStack* temp_stack, + int* num_passing_ids, + uint16_t* passing_batch_row_ids) const { + if (filter_ == literal(true)) { + CollectPassingBatchIds(0, hardware_flags_, batch_start_row, num_batch_rows, + match_bitvector, num_passing_ids, passing_batch_row_ids); + return Status::OK(); + } + + // Do FilterLeftSemi first. + // + *num_passing_ids = 0; + int num_semi_passing_ids = 0; + auto semi_passing_batch_row_ids = + arrow::util::TempVectorHolder(temp_stack, num_batch_rows); + RETURN_NOT_OK(FilterLeftSemi(keypayload_batch, batch_start_row, num_batch_rows, + match_bitvector, key_ids, no_duplicate_keys, temp_stack, + &num_semi_passing_ids, + semi_passing_batch_row_ids.mutable_data())); + + // Then collect non-passing row ids of FilterLeftSemi. + // + int isemi = 0; + for (int irow = batch_start_row; irow < batch_start_row + num_batch_rows; ++irow) { + while (isemi < num_semi_passing_ids && + semi_passing_batch_row_ids.mutable_data()[isemi] < irow) { + ++isemi; + } + if (isemi == num_semi_passing_ids || + semi_passing_batch_row_ids.mutable_data()[isemi] != irow) { + passing_batch_row_ids[*num_passing_ids] = static_cast(irow); + ++(*num_passing_ids); + } + } + + return Status::OK(); +} + +Status JoinResidualFilter::FilterRightSemiAnti( + int64_t thread_id, const ExecBatch& keypayload_batch, int batch_start_row, + int num_batch_rows, const uint8_t* match_bitvector, const uint32_t* key_ids, + bool no_duplicate_keys, arrow::util::TempVectorStack* temp_stack) const { + if (filter_.IsNullLiteral() || filter_ == literal(false)) { + return Status::OK(); + } + + int num_matching_ids = 0; + if (filter_ == literal(true)) { + auto match_relative_batch_ids_buf = + arrow::util::TempVectorHolder(temp_stack, num_batch_rows); + auto match_key_ids_buf = + arrow::util::TempVectorHolder(temp_stack, num_batch_rows); + + arrow::util::bit_util::bits_to_indexes(1, hardware_flags_, num_batch_rows, + match_bitvector, &num_matching_ids, + match_relative_batch_ids_buf.mutable_data()); + // Collect key ids of passing rows. + // + for (int i = 0; i < num_matching_ids; ++i) { + uint16_t id = match_relative_batch_ids_buf.mutable_data()[i]; + match_key_ids_buf.mutable_data()[i] = key_ids[id]; + } + + hash_table_->UpdateHasMatchForKeys(thread_id, num_matching_ids, + match_key_ids_buf.mutable_data()); + return Status::OK(); + } + + auto match_batch_row_ids_buf = + arrow::util::TempVectorHolder(temp_stack, minibatch_size_); + auto match_key_ids_buf = + arrow::util::TempVectorHolder(temp_stack, minibatch_size_); + auto match_payload_ids_buf = + arrow::util::TempVectorHolder(temp_stack, minibatch_size_); + + // Inner matching is necessary for non-trivial filter. Because even for the same row + // with same matching key, the filter results could vary for different payloads. + // + JoinMatchIterator match_iterator; + match_iterator.SetLookupResult(num_batch_rows, batch_start_row, match_bitvector, + key_ids, no_duplicate_keys, key_to_payload_); + while (match_iterator.GetNextBatch( + minibatch_size_, &num_matching_ids, match_batch_row_ids_buf.mutable_data(), + match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data())) { + int num_filtered = 0; + RETURN_NOT_OK(FilterOneBatch( + keypayload_batch, num_matching_ids, match_batch_row_ids_buf.mutable_data(), + match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data(), + /*output_key_ids=*/false, + /*output_payload_ids=*/true, temp_stack, &num_filtered)); + hash_table_->UpdateHasMatchForPayloads(thread_id, num_filtered, + match_payload_ids_buf.mutable_data()); + } + + return Status::OK(); +} + +Status JoinResidualFilter::FilterInner( + const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t* batch_row_ids, + uint32_t* key_ids, uint32_t* payload_ids_maybe_null, bool output_payload_ids, + arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const { + if (filter_ == literal(true)) { + *num_passing_rows = num_batch_rows; + return Status::OK(); + } + + *num_passing_rows = 0; + if (filter_.IsNullLiteral() || filter_ == literal(false)) { + return Status::OK(); + } + + return FilterOneBatch( + keypayload_batch, num_batch_rows, batch_row_ids, key_ids, payload_ids_maybe_null, + /*output_key_ids=*/true, output_payload_ids, temp_stack, num_passing_rows); +} + +Status JoinResidualFilter::FilterOneBatch(const ExecBatch& keypayload_batch, + int num_batch_rows, uint16_t* batch_row_ids, + uint32_t* key_ids_maybe_null, + uint32_t* payload_ids_maybe_null, + bool output_key_ids, bool output_payload_ids, + arrow::util::TempVectorStack* temp_stack, + int* num_passing_rows) const { + // Caller must do shortcuts for trivial filter. + ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) && + filter_ != literal(false)); + ARROW_DCHECK(!output_key_ids || key_ids_maybe_null); + ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null); + + *num_passing_rows = 0; + ARROW_ASSIGN_OR_RAISE(Datum mask, + EvalFilter(keypayload_batch, num_batch_rows, batch_row_ids, + key_ids_maybe_null, payload_ids_maybe_null)); + if (mask.is_scalar()) { + const auto& mask_scalar = mask.scalar_as(); + if (mask_scalar.is_valid && mask_scalar.value) { + *num_passing_rows = num_batch_rows; + } + return Status::OK(); + } + + ARROW_DCHECK_EQ(mask.array()->offset, 0); + ARROW_DCHECK_EQ(mask.array()->length, static_cast(num_batch_rows)); + const uint8_t* validity = + mask.array()->buffers[0] ? mask.array()->buffers[0]->data() : nullptr; + const uint8_t* comparisons = mask.array()->buffers[1]->data(); + for (int irow = 0; irow < num_batch_rows; ++irow) { + bool is_valid = !validity || bit_util::GetBit(validity, irow); + bool is_cmp_true = bit_util::GetBit(comparisons, irow); + if (is_valid && is_cmp_true) { + batch_row_ids[*num_passing_rows] = batch_row_ids[irow]; + if (output_key_ids) { + key_ids_maybe_null[*num_passing_rows] = key_ids_maybe_null[irow]; + } + if (output_payload_ids) { + payload_ids_maybe_null[*num_passing_rows] = payload_ids_maybe_null[irow]; + } + ++(*num_passing_rows); + } + } + + return Status::OK(); +} + +Result JoinResidualFilter::EvalFilter( + const ExecBatch& keypayload_batch, int num_batch_rows, const uint16_t* batch_row_ids, + const uint32_t* key_ids_maybe_null, const uint32_t* payload_ids_maybe_null) const { + ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) && + filter_ != literal(false)); + + ARROW_ASSIGN_OR_RAISE( + ExecBatch input, + MaterializeFilterInput(keypayload_batch, num_batch_rows, batch_row_ids, + key_ids_maybe_null, payload_ids_maybe_null)); + return ExecuteScalarExpression(filter_, input, ctx_->exec_context()); +} + +Result JoinResidualFilter::MaterializeFilterInput( + const ExecBatch& keypayload_batch, int num_batch_rows, const uint16_t* batch_row_ids, + const uint32_t* key_ids_maybe_null, const uint32_t* payload_ids_maybe_null) const { + ExecBatch out; + out.length = num_batch_rows; + out.values.resize(probe_filter_to_key_and_payload_.size() + num_build_keys_referred_ + + num_build_payloads_referred_); + + if (probe_filter_to_key_and_payload_.size() > 0) { + ExecBatchBuilder probe_batch_builder; + RETURN_NOT_OK(probe_batch_builder.AppendSelected( + pool_, keypayload_batch, num_batch_rows, batch_row_ids, + static_cast(probe_filter_to_key_and_payload_.size()), + probe_filter_to_key_and_payload_.data())); + ExecBatch probe_batch = probe_batch_builder.Flush(); + ARROW_DCHECK(probe_batch.values.size() == probe_filter_to_key_and_payload_.size()); + for (size_t i = 0; i < probe_batch.values.size(); ++i) { + out.values[i] = std::move(probe_batch.values[i]); + } + } + + if (num_build_keys_referred_ > 0 || num_build_payloads_referred_ > 0) { + ARROW_DCHECK(num_build_keys_referred_ == 0 || key_ids_maybe_null); + ARROW_DCHECK(num_build_payloads_referred_ == 0 || payload_ids_maybe_null); + + int num_build_cols = build_schemas_->num_cols(HashJoinProjection::FILTER); + auto to_key = + build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY); + auto to_payload = + build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD); + for (int i = 0; i < num_build_cols; ++i) { + ResizableArrayData column_data; + column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i), pool_, + bit_util::Log2(num_batch_rows)); + if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) { + RETURN_NOT_OK(build_keys_->DecodeSelected(&column_data, idx, num_batch_rows, + key_ids_maybe_null, pool_)); + } else if (idx = to_payload.get(i); idx != SchemaProjectionMap::kMissingField) { + RETURN_NOT_OK(build_payloads_->DecodeSelected(&column_data, idx, num_batch_rows, + payload_ids_maybe_null, pool_)); + } else { + ARROW_DCHECK(false); + } + out.values[probe_filter_to_key_and_payload_.size() + i] = column_data.array_data(); + } + } + + return out; +} + void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table, + JoinResidualFilter* residual_filter, std::vector materialize, const std::vector* cmp, OutputBatchFn output_batch_fn) { num_key_columns_ = num_key_columns; join_type_ = join_type; hash_table_ = hash_table; + residual_filter_ = residual_filter; materialize_.resize(materialize.size()); for (size_t i = 0; i < materialize.size(); ++i) { materialize_[i] = materialize[i]; @@ -1875,6 +2286,7 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch, arrow::util::TempVectorStack* temp_stack, std::vector* temp_column_arrays) { + bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr); const SwissTable* swiss_table = hash_table_->keys()->swiss_table(); int64_t hardware_flags = swiss_table->hardware_flags(); int minibatch_size = swiss_table->minibatch_size(); @@ -1900,6 +2312,8 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, arrow::util::TempVectorHolder(temp_stack, minibatch_size); auto materialize_payload_ids_buf = arrow::util::TempVectorHolder(temp_stack, minibatch_size); + auto filter_bitvector_buf = arrow::util::TempVectorHolder( + temp_stack, static_cast(bit_util::BytesForBits(minibatch_size))); for (int minibatch_start = 0; minibatch_start < num_rows;) { uint32_t minibatch_size_next = std::min(minibatch_size, num_rows - minibatch_start); @@ -1923,33 +2337,29 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI || join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { int num_passing_ids = 0; - arrow::util::bit_util::bits_to_indexes( - (join_type_ == JoinType::LEFT_ANTI) ? 0 : 1, hardware_flags, - minibatch_size_next, match_bitvector_buf.mutable_data(), &num_passing_ids, - materialize_batch_ids_buf.mutable_data()); - - // For right-semi, right-anti joins: update has-match flags for the rows - // in hash table. - // - if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) { - for (int i = 0; i < num_passing_ids; ++i) { - uint16_t id = materialize_batch_ids_buf.mutable_data()[i]; - key_ids_buf.mutable_data()[i] = key_ids_buf.mutable_data()[id]; - } - hash_table_->UpdateHasMatchForKeys(thread_id, num_passing_ids, - key_ids_buf.mutable_data()); + if (join_type_ == JoinType::LEFT_SEMI) { + RETURN_NOT_OK(residual_filter_->FilterLeftSemi( + keypayload_batch, minibatch_start, minibatch_size_next, + match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(), + no_duplicate_keys, temp_stack, &num_passing_ids, + materialize_batch_ids_buf.mutable_data())); + } else if (join_type_ == JoinType::LEFT_ANTI) { + RETURN_NOT_OK(residual_filter_->FilterLeftAnti( + keypayload_batch, minibatch_start, minibatch_size_next, + match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(), + no_duplicate_keys, temp_stack, &num_passing_ids, + materialize_batch_ids_buf.mutable_data())); } else { - // For left-semi, left-anti joins: call materialize using match - // bit-vector. - // + RETURN_NOT_OK(residual_filter_->FilterRightSemiAnti( + thread_id, keypayload_batch, minibatch_start, minibatch_size_next, + match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(), + no_duplicate_keys, temp_stack)); + } - // Add base batch row index. + if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI) { + // For left-semi, left-anti joins: call materialize using match + // row ids. // - for (int i = 0; i < num_passing_ids; ++i) { - materialize_batch_ids_buf.mutable_data()[i] += - static_cast(minibatch_start); - } - RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), [&](ExecBatch batch) { @@ -1961,30 +2371,46 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, // Since every hash table lookup for an input row might have multiple // matches we use a helper class that implements enumerating all of them. // - bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr); - bool no_payload_columns = (hash_table_->payloads() == nullptr); JoinMatchIterator match_iterator; match_iterator.SetLookupResult( minibatch_size_next, minibatch_start, match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(), no_duplicate_keys, hash_table_->key_to_payload()); int num_matches_next; + bool use_filter_bitvector = residual_filter_->NeedFilterBitVector(join_type_); + if (use_filter_bitvector) { + residual_filter_->InitFilterBitVector(minibatch_size_next, + filter_bitvector_buf.mutable_data()); + } while (match_iterator.GetNextBatch(minibatch_size, &num_matches_next, materialize_batch_ids_buf.mutable_data(), materialize_key_ids_buf.mutable_data(), materialize_payload_ids_buf.mutable_data())) { + RETURN_NOT_OK(residual_filter_->FilterInner( + keypayload_batch, num_matches_next, materialize_batch_ids_buf.mutable_data(), + materialize_key_ids_buf.mutable_data(), + materialize_payload_ids_buf.mutable_data(), !no_duplicate_keys, temp_stack, + &num_matches_next)); + const uint16_t* materialize_batch_ids = materialize_batch_ids_buf.mutable_data(); const uint32_t* materialize_key_ids = materialize_key_ids_buf.mutable_data(); const uint32_t* materialize_payload_ids = - no_duplicate_keys || no_payload_columns - ? materialize_key_ids_buf.mutable_data() - : materialize_payload_ids_buf.mutable_data(); + no_duplicate_keys ? materialize_key_ids_buf.mutable_data() + : materialize_payload_ids_buf.mutable_data(); + + // For filtered result, update filter bit-vector. + // + if (use_filter_bitvector) { + residual_filter_->UpdateFilterBitVector(minibatch_start, num_matches_next, + materialize_batch_ids, + filter_bitvector_buf.mutable_data()); + } // For right-outer, full-outer joins we need to update has-match flags // for the rows in hash table. // if (join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER) { - hash_table_->UpdateHasMatchForKeys(thread_id, num_matches_next, - materialize_key_ids); + hash_table_->UpdateHasMatchForPayloads(thread_id, num_matches_next, + materialize_payload_ids); } // Call materialize for resulting id tuples pointing to matching pairs @@ -2004,17 +2430,11 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id, // if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) { int num_passing_ids = 0; - arrow::util::bit_util::bits_to_indexes( - /*bit_to_search=*/0, hardware_flags, minibatch_size_next, - match_bitvector_buf.mutable_data(), &num_passing_ids, - materialize_batch_ids_buf.mutable_data()); - - // Add base batch row index. - // - for (int i = 0; i < num_passing_ids; ++i) { - materialize_batch_ids_buf.mutable_data()[i] += - static_cast(minibatch_start); - } + CollectPassingBatchIds(0, hardware_flags, minibatch_start, minibatch_size_next, + use_filter_bitvector ? filter_bitvector_buf.mutable_data() + : match_bitvector_buf.mutable_data(), + &num_passing_ids, + materialize_batch_ids_buf.mutable_data()); RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly( keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(), @@ -2099,8 +2519,12 @@ class SwissJoin : public HashJoinImpl { materialize[i] = &local_states_[i].materialize; } + residual_filter_.Init(std::move(filter), ctx_, pool_, hardware_flags_, proj_map_left, + proj_map_right, &hash_table_); + probe_processor_.Init(proj_map_left->num_cols(HashJoinProjection::KEY), join_type_, - &hash_table_, materialize, &key_cmp_, output_batch_callback_); + &hash_table_, &residual_filter_, materialize, &key_cmp_, + output_batch_callback_); InitTaskGroups(); @@ -2180,9 +2604,11 @@ class SwissJoin : public HashJoinImpl { // const HashJoinProjectionMaps* schema = schema_[1]; bool reject_duplicate_keys = - join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI; + (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI) && + residual_filter_.NumBuildPayloadsReferred() == 0; bool no_payload = - reject_duplicate_keys || schema->num_cols(HashJoinProjection::PAYLOAD) == 0; + reject_duplicate_keys || (schema->num_cols(HashJoinProjection::PAYLOAD) == 0 && + residual_filter_.NumBuildPayloadsReferred() == 0); std::vector key_types; for (int i = 0; i < schema->num_cols(HashJoinProjection::KEY); ++i) { @@ -2302,6 +2728,8 @@ class SwissJoin : public HashJoinImpl { } hash_table_ready_.store(true); + residual_filter_.OnBuildFinished(); + return build_finished_callback_(thread_id); } @@ -2364,24 +2792,25 @@ class SwissJoin : public HashJoinImpl { static_cast(mini_batch_start + mini_batch_size_next - 1)); int num_output_rows = 0; for (uint32_t key_id = first_key_id; key_id <= last_key_id; ++key_id) { - if (bit_util::GetBit(hash_table_.has_match(), key_id) == bit_to_output) { - uint32_t first_payload_for_key = - std::max(static_cast(mini_batch_start), - hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id] - : key_id); - uint32_t last_payload_for_key = std::min( - static_cast(mini_batch_start + mini_batch_size_next - 1), - hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id + 1] - 1 - : key_id); - uint32_t num_payloads_for_key = - last_payload_for_key - first_payload_for_key + 1; - for (uint32_t i = 0; i < num_payloads_for_key; ++i) { - key_ids_buf.mutable_data()[num_output_rows + i] = key_id; - payload_ids_buf.mutable_data()[num_output_rows + i] = - first_payload_for_key + i; + uint32_t first_payload_for_key = std::max( + static_cast(mini_batch_start), + hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id] : key_id); + uint32_t last_payload_for_key = std::min( + static_cast(mini_batch_start + mini_batch_size_next - 1), + hash_table_.key_to_payload() ? hash_table_.key_to_payload()[key_id + 1] - 1 + : key_id); + uint32_t num_payloads_for_key = last_payload_for_key - first_payload_for_key + 1; + uint32_t num_payloads_match = 0; + for (uint32_t i = 0; i < num_payloads_for_key; ++i) { + uint32_t payload = first_payload_for_key + i; + if (bit_util::GetBit(hash_table_.has_match(), payload) == bit_to_output) { + key_ids_buf.mutable_data()[num_output_rows + num_payloads_match] = key_id; + payload_ids_buf.mutable_data()[num_output_rows + num_payloads_match] = + payload; + num_payloads_match++; } - num_output_rows += num_payloads_for_key; } + num_output_rows += num_payloads_match; } if (num_output_rows > 0) { @@ -2524,6 +2953,7 @@ class SwissJoin : public HashJoinImpl { SwissTableForJoin hash_table_; JoinProbeProcessor probe_processor_; + JoinResidualFilter residual_filter_; SwissTableForJoinBuild hash_table_build_; AccumulationQueue build_side_batches_; diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h index 6403b7a655e96..aa36a61109274 100644 --- a/cpp/src/arrow/acero/swiss_join_internal.h +++ b/cpp/src/arrow/acero/swiss_join_internal.h @@ -367,7 +367,13 @@ class SwissTableForJoin { friend class SwissTableForJoinBuild; public: + // Update all payloads corresponding to the given keys as having a match. + // void UpdateHasMatchForKeys(int64_t thread_id, int num_rows, const uint32_t* key_ids); + // Update the given payloads as having a match. + // + void UpdateHasMatchForPayloads(int64_t thread_id, int num_rows, + const uint32_t* payload_ids); void MergeHasMatch(); const SwissTableWithKeys* keys() const { return &map_; } @@ -385,10 +391,6 @@ class SwissTableForJoin { } uint32_t payload_id_to_key_id(uint32_t payload_id) const; - // Input payload ids must form an increasing sequence. - // - void payload_ids_to_key_ids(int num_rows, const uint32_t* payload_ids, - uint32_t* key_ids) const; private: uint8_t* local_has_match(int64_t thread_id); @@ -397,8 +399,10 @@ class SwissTableForJoin { int dop_; struct ThreadLocalState { + // Bit-vector for keeping track of whether each payload in the hash table had a match. std::vector has_match; }; + // Bit-vector for keeping track of whether each payload in the hash table had a match. std::vector local_states_; std::vector has_match_; @@ -714,8 +718,20 @@ class JoinMatchIterator { void SetLookupResult(int num_batch_rows, int start_batch_row, const uint8_t* batch_has_match, const uint32_t* key_ids, bool no_duplicate_keys, const uint32_t* key_to_payload); + // Get the next batch of matching rows by outputting the batch row ids, key ids and + // payload ids. If the row_id_to_skip is not kInvalidRowId, then the row with that id + // will be skipped. This is useful for left-anti and left-semi joins, where we can + // safely skip the subsequent matchings of the row that already has a match in the + // previous batch. + // bool GetNextBatch(int num_rows_max, int* out_num_rows, uint16_t* batch_row_ids, - uint32_t* key_ids, uint32_t* payload_ids); + uint32_t* key_ids, uint32_t* payload_ids, + int row_id_to_skip = kInvalidRowId); + + // The row id that will never exist in an ExecBatch. Used to indicate that there is no + // row to skip. + // + static constexpr uint32_t kInvalidRowId = std::numeric_limits::max() + 1; private: int num_batch_rows_; @@ -736,6 +752,135 @@ class JoinMatchIterator { int current_match_for_row_; }; +// Implement the residual filter support used when processing the probe side exec batches. +// There are four filtering patterns, each with a corresponding public FilterXXX method: +// - LeftSemi and LeftAnti, each for its co-naming join type, opposite to each other. +// - RightSemiAnti for both right-semi and right-anti joins: they have the same filtering +// logic and differ only in the scanning phase. +// - Inner for inner joins and the inner part of outer joins: caller should take care of +// filtering the outer part. +// All the public Filter* methods have zero-cost shortcut for trivial filter. +// +class JoinResidualFilter { + public: + void Init(Expression filter, QueryContext* ctx, MemoryPool* pool, + int64_t hardware_flags, const HashJoinProjectionMaps* probe_schemas, + const HashJoinProjectionMaps* build_schemas, SwissTableForJoin* hash_table); + + void OnBuildFinished(); + + int NumBuildKeysReferred() const { return num_build_keys_referred_; } + int NumBuildPayloadsReferred() const { return num_build_payloads_referred_; } + + // Left-outer and full-outer joins can result in a different bit-vector than the one of + // probing the hash table if the residual filter is not a literal true. If so, caller + // should setup a bit-vector for filtering properly and call `UpdateFilterBitVector` + // accordingly. + // + bool NeedFilterBitVector(JoinType join_type) const { + return (join_type == JoinType::LEFT_OUTER || join_type == JoinType::FULL_OUTER) && + filter_ != literal(true); + } + + // Init the bit-vector for filtering. Caller should make sure the bit-vector has enough + // size for a particular probe side batch. + // + void InitFilterBitVector(int num_batch_rows, uint8_t* filter_bitvector); + + // Update the bit-vector for filtering according to the given batch row ids. + // + void UpdateFilterBitVector(int batch_start_row, int num_batch_rows, + const uint16_t* batch_row_ids, uint8_t* filter_bitvector); + + // Left row is passing if filter evaluates true. Output all the passing row ids in + // the input batch. Like the left-semi join semantic, each passing row is output only + // once. + // Zero-overhead shortcut guarantee for trivial filter. + // + Status FilterLeftSemi(const ExecBatch& keypayload_batch, int batch_start_row, + int num_batch_rows, const uint8_t* match_bitvector, + const uint32_t* key_ids, bool no_duplicate_keys, + arrow::util::TempVectorStack* temp_stack, int* num_passing_ids, + uint16_t* passing_batch_row_ids) const; + + // Logically the opposite of FilterLeftSemi. Output all the passing row ids in the input + // batch. Like the left-anti join semantic, each passing row is output only once. + // Zero-overhead shortcut guarantee for trivial filter. + // + Status FilterLeftAnti(const ExecBatch& keypayload_batch, int batch_start_row, + int num_batch_rows, const uint8_t* match_bitvector, + const uint32_t* key_ids, bool no_duplicate_keys, + arrow::util::TempVectorStack* temp_stack, int* num_passing_ids, + uint16_t* passing_batch_row_ids) const; + + // Right row is passing if filter evaluates true. Mark a match for all the passing + // payload ids in the hash table. This applies for both right-semi and right-anti joins: + // they differ in scanning phase. + // Zero-overhead shortcut guarantee for trivial filter. + // + Status FilterRightSemiAnti(int64_t thread_id, const ExecBatch& keypayload_batch, + int batch_start_row, int num_batch_rows, + const uint8_t* match_bitvector, const uint32_t* key_ids, + bool no_duplicate_keys, + arrow::util::TempVectorStack* temp_stack) const; + + // For a given batch of an inner match (an inner-join or the inner part of an + // outer-join), row is passing if filter evaluates true. Does not do any outer filtering + // because this method is usually called within a inner match loop, which doesn't have + // the full scope of outer join. This requires caller to handle the outer part properly. + // All batch_row_ids, key_ids and payload_ids_maybe_null are input and output, this is + // for efficient shortcut. + // Zero-overhead shortcut guarantee for trivial filter. + // + Status FilterInner(const ExecBatch& keypayload_batch, int num_batch_rows, + uint16_t* batch_row_ids, uint32_t* key_ids, + uint32_t* payload_ids_maybe_null, bool output_payload_ids, + arrow::util::TempVectorStack* temp_stack, + int* num_passing_rows) const; + + private: + // Evaluates the filter for a given batch of matching rows, and outputs the passing + // rows. Always introduces overhead of materialization and evaluation, so caller must do + // shortcut properly for trivial filters. + // + Status FilterOneBatch(const ExecBatch& keypayload_batch, int num_batch_rows, + uint16_t* batch_row_ids, uint32_t* key_ids_maybe_null, + uint32_t* payload_ids_maybe_null, bool output_key_ids, + bool output_payload_ids, arrow::util::TempVectorStack* temp_stack, + int* num_passing_rows) const; + + Result EvalFilter(const ExecBatch& keypayload_batch, int num_batch_rows, + const uint16_t* batch_row_ids, + const uint32_t* key_ids_maybe_null, + const uint32_t* payload_ids_maybe_null) const; + + Result MaterializeFilterInput(const ExecBatch& keypayload_batch, + int num_batch_rows, + const uint16_t* batch_row_ids, + const uint32_t* key_ids_maybe_null, + const uint32_t* payload_ids_maybe_null) const; + + private: + Expression filter_; + + QueryContext* ctx_; + MemoryPool* pool_; + int64_t hardware_flags_; + + const HashJoinProjectionMaps* probe_schemas_; + const HashJoinProjectionMaps* build_schemas_; + + SwissTableForJoin* hash_table_; + std::vector probe_filter_to_key_and_payload_; + int num_build_keys_referred_ = 0; + int num_build_payloads_referred_ = 0; + + int minibatch_size_; + const RowArray* build_keys_; + const RowArray* build_payloads_; + const uint32_t* key_to_payload_; +}; + // Implements entire processing of a probe side exec batch, // provided the join hash table is already built and available. // @@ -744,6 +889,7 @@ class JoinProbeProcessor { using OutputBatchFn = std::function; void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table, + JoinResidualFilter* residual_filter, std::vector materialize, const std::vector* cmp, OutputBatchFn output_batch_fn); Status OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch, @@ -760,6 +906,7 @@ class JoinProbeProcessor { JoinType join_type_; SwissTableForJoin* hash_table_; + JoinResidualFilter* residual_filter_; // One element per thread // std::vector materialize_;