From dda56e3380428d5f86271f1c909172d7fe07cc1a Mon Sep 17 00:00:00 2001
From: zanmato
Date: Sun, 7 Jan 2024 01:53:52 -0800
Subject: [PATCH] Refine structure

---
 cpp/src/arrow/acero/swiss_join.cc         | 379 +++++++++++++---------
 cpp/src/arrow/acero/swiss_join_internal.h |  45 ++-
 2 files changed, 260 insertions(+), 164 deletions(-)

diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc
index 6f6c72de01af1..4a0ff7e67f3e5 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -1853,6 +1853,23 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows,
   return (*out_num_rows) > 0;
 }
 
+namespace {
+
+void CollectPassingBatchIds(int passing_bit, int64_t hardware_flags, int batch_start_row,
+                            int num_batch_rows, const uint8_t* match_bitvector,
+                            int* num_passing_ids, uint16_t* passing_batch_row_ids) {
+  arrow::util::bit_util::bits_to_indexes(passing_bit, hardware_flags, num_batch_rows,
+                                         match_bitvector, num_passing_ids,
+                                         passing_batch_row_ids);
+  // Add base batch row index.
+  //
+  for (int i = 0; i < *num_passing_ids; ++i) {
+    passing_batch_row_ids[i] += static_cast<uint16_t>(batch_start_row);
+  }
+}
+
+}  // namespace
+
 void JoinResidualFilter::Init(Expression filter, QueryContext* ctx, MemoryPool* pool,
                               int64_t hardware_flags,
                               const HashJoinProjectionMaps* probe_schemas,
@@ -1922,7 +1939,11 @@ Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch,
                                           arrow::util::TempVectorStack* temp_stack,
                                           int* num_passing_ids,
                                           uint16_t* passing_batch_row_ids) const {
-  ARROW_DCHECK(filter_ != literal(true));
+  if (filter_ == literal(true)) {
+    CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows,
+                           match_bitvector, num_passing_ids, passing_batch_row_ids);
+    return Status::OK();
+  }
 
   *num_passing_ids = 0;
   if (filter_.IsNullLiteral() || filter_ == literal(false)) {
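A note on the new helper above: arrow::util::bit_util::bits_to_indexes is Arrow's vectorized bit-vector scan, and CollectPassingBatchIds merely rebases its minibatch-relative output ids onto absolute batch row ids. A minimal scalar sketch of the same computation, assuming Arrow's LSB-first bit order (the function below is illustrative, not part of the patch):

    #include <cstdint>
    #include <vector>

    // Select row ids whose bit equals `passing_bit` (1 = matched, 0 = unmatched),
    // then shift them by the minibatch's base row, as CollectPassingBatchIds does.
    std::vector<uint16_t> CollectPassingIdsScalar(int passing_bit, int batch_start_row,
                                                  int num_batch_rows,
                                                  const uint8_t* match_bitvector) {
      std::vector<uint16_t> ids;
      for (int i = 0; i < num_batch_rows; ++i) {
        int bit = (match_bitvector[i / 8] >> (i % 8)) & 1;
        if (bit == passing_bit) {
          ids.push_back(static_cast<uint16_t>(batch_start_row + i));
        }
      }
      return ids;
    }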
@@ -1933,29 +1954,22 @@ Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch,
     // If filter refers no column in the right table, then we can directly filter on the
     // left rows without inner matching and materializing the right rows.
     //
-    arrow::util::bit_util::bits_to_indexes(1, hardware_flags_, num_batch_rows,
-                                           match_bitvector, num_passing_ids,
-                                           passing_batch_row_ids);
-
-    // Add base batch row index.
-    //
-    for (int i = 0; i < *num_passing_ids; ++i) {
-      passing_batch_row_ids[i] += static_cast<uint16_t>(batch_start_row);
-    }
-
-    RETURN_NOT_OK(FilterInner(keypayload_batch, *num_passing_ids, passing_batch_row_ids,
-                              /*key_ids_maybe_null=*/NULLPTR,
-                              /*payload_ids_maybe_null=*/NULLPTR,
-                              /*output_key_ids=*/false, /*output_payload_ids=*/false,
-                              temp_stack, num_passing_ids));
+    CollectPassingBatchIds(1, hardware_flags_, batch_start_row, num_batch_rows,
+                           match_bitvector, num_passing_ids, passing_batch_row_ids);
+    RETURN_NOT_OK(
+        FilterOneBatch(keypayload_batch, *num_passing_ids, passing_batch_row_ids,
+                       /*key_ids_maybe_null=*/NULLPTR,
+                       /*payload_ids_maybe_null=*/NULLPTR,
+                       /*output_key_ids=*/false,
+                       /*output_payload_ids=*/false, temp_stack, num_passing_ids));
     return Status::OK();
   }
 
-  auto materialize_batch_ids_buf =
+  auto match_batch_ids_buf =
       arrow::util::TempVectorHolder<uint16_t>(temp_stack, minibatch_size_);
-  auto materialize_key_ids_buf =
+  auto match_key_ids_buf =
       arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);
-  auto materialize_payload_ids_buf =
+  auto match_payload_ids_buf =
       arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);
 
   JoinMatchIterator match_iterator;
@@ -1964,23 +1978,62 @@ Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch,
   int num_matches_next = 0;
   int row_id_to_skip = JoinMatchIterator::kInvalidRowId;
   while (match_iterator.GetNextBatch(
-      minibatch_size_, &num_matches_next, materialize_batch_ids_buf.mutable_data(),
-      materialize_key_ids_buf.mutable_data(), materialize_payload_ids_buf.mutable_data(),
+      minibatch_size_, &num_matches_next, match_batch_ids_buf.mutable_data(),
+      match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data(),
       row_id_to_skip)) {
     int num_filtered = 0;
-    RETURN_NOT_OK(FilterInner(
-        keypayload_batch, num_matches_next, materialize_batch_ids_buf.mutable_data(),
-        materialize_key_ids_buf.mutable_data(),
-        materialize_payload_ids_buf.mutable_data(), /*output_key_ids=*/false,
+    RETURN_NOT_OK(FilterOneBatch(
+        keypayload_batch, num_matches_next, match_batch_ids_buf.mutable_data(),
+        match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data(),
+        /*output_key_ids=*/false,
         /*output_payload_ids=*/false, temp_stack, &num_filtered));
 
     // There may be multiple matches for a row in batch. Collect distinct row ids.
     //
     for (int ifiltered = 0; ifiltered < num_filtered; ++ifiltered) {
-      if (materialize_batch_ids_buf.mutable_data()[ifiltered] == row_id_to_skip) {
+      if (match_batch_ids_buf.mutable_data()[ifiltered] == row_id_to_skip) {
         continue;
       }
       row_id_to_skip = passing_batch_row_ids[*num_passing_ids] =
-          materialize_batch_ids_buf.mutable_data()[ifiltered];
+          match_batch_ids_buf.mutable_data()[ifiltered];
       ++(*num_passing_ids);
     }
   }
 
   return Status::OK();
 }
 
+Status JoinResidualFilter::FilterLeftAnti(const ExecBatch& keypayload_batch,
+                                          int batch_start_row, int num_batch_rows,
+                                          const uint8_t* match_bitvector,
+                                          const uint32_t* key_ids, bool no_duplicate_keys,
+                                          arrow::util::TempVectorStack* temp_stack,
+                                          int* num_passing_ids,
+                                          uint16_t* passing_batch_row_ids) const {
+  if (filter_ == literal(true)) {
+    CollectPassingBatchIds(0, hardware_flags_, batch_start_row, num_batch_rows,
+                           match_bitvector, num_passing_ids, passing_batch_row_ids);
+    return Status::OK();
+  }
+
+  *num_passing_ids = 0;
+  int num_matching_ids = 0;
+  auto matching_batch_row_ids =
+      arrow::util::TempVectorHolder<uint16_t>(temp_stack, num_batch_rows);
+  RETURN_NOT_OK(FilterLeftSemi(keypayload_batch, batch_start_row, num_batch_rows,
+                               match_bitvector, key_ids, no_duplicate_keys, temp_stack,
+                               &num_matching_ids, matching_batch_row_ids.mutable_data()));
+
+  // Collect no match row ids.
+  //
+  int imatch = 0;
+  for (int irow = batch_start_row; irow < batch_start_row + num_batch_rows; ++irow) {
+    while (imatch < num_matching_ids &&
+           matching_batch_row_ids.mutable_data()[imatch] < irow) {
+      ++imatch;
+    }
+    if (imatch == num_matching_ids ||
+        matching_batch_row_ids.mutable_data()[imatch] != irow) {
+      passing_batch_row_ids[*num_passing_ids] = static_cast<uint16_t>(irow);
+      ++(*num_passing_ids);
+    }
+  }
+
+  return Status::OK();
+}
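The merge-style scan at the end of FilterLeftAnti relies on FilterLeftSemi returning row ids in ascending order, so the complement can be taken in one pass without any bit-vector. A standalone sketch of that scan over plain vectors (illustrative only; assumes the matching ids are sorted):

    #include <cstdint>
    #include <vector>

    // Emit every row id in [batch_start_row, batch_start_row + num_batch_rows)
    // that does not appear in the sorted list of matching ids.
    std::vector<uint16_t> ComplementOfSortedIds(
        int batch_start_row, int num_batch_rows,
        const std::vector<uint16_t>& matching_ids) {
      std::vector<uint16_t> anti_ids;
      size_t imatch = 0;
      for (int irow = batch_start_row; irow < batch_start_row + num_batch_rows; ++irow) {
        while (imatch < matching_ids.size() && matching_ids[imatch] < irow) {
          ++imatch;
        }
        if (imatch == matching_ids.size() || matching_ids[imatch] != irow) {
          anti_ids.push_back(static_cast<uint16_t>(irow));
        }
      }
      return anti_ids;
    }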
@@ -1991,59 +2044,90 @@ Status JoinResidualFilter::FilterRightSemi(
     const ExecBatch& keypayload_batch, int batch_start_row, int num_batch_rows,
     const uint8_t* match_bitvector, const uint32_t* key_ids, bool no_duplicate_keys,
-    arrow::util::TempVectorStack* temp_stack,
-    OutputPayloadIdsCallback output_payload_ids) const {
-  ARROW_DCHECK(filter_ != literal(true));
+    arrow::util::TempVectorStack* temp_stack, OnMatchBatch on_match_batch) const {
+  ARROW_DCHECK(on_match_batch);
 
   if (filter_.IsNullLiteral() || filter_ == literal(false)) {
     return Status::OK();
   }
 
-  auto materialize_batch_ids_buf =
+  int num_matching_ids = 0;
+  if (filter_ == literal(true)) {
+    auto match_relative_batch_ids_buf =
+        arrow::util::TempVectorHolder<uint16_t>(temp_stack, num_batch_rows);
+    auto match_key_ids_buf =
+        arrow::util::TempVectorHolder<uint32_t>(temp_stack, num_batch_rows);
+
+    arrow::util::bit_util::bits_to_indexes(1, hardware_flags_, num_batch_rows,
+                                           match_bitvector, &num_matching_ids,
+                                           match_relative_batch_ids_buf.mutable_data());
+    // Collect key ids of passing rows.
+    //
+    for (int i = 0; i < num_matching_ids; ++i) {
+      uint16_t id = match_relative_batch_ids_buf.mutable_data()[i];
+      match_key_ids_buf.mutable_data()[i] = key_ids[id];
+    }
+
+    on_match_batch(num_matching_ids, /*batch_row_ids=*/NULLPTR,
+                   match_key_ids_buf.mutable_data(),
+                   /*payload_ids=*/NULLPTR);
+    return Status::OK();
+  }
+
+  auto match_batch_ids_buf =
       arrow::util::TempVectorHolder<uint16_t>(temp_stack, minibatch_size_);
-  auto materialize_key_ids_buf =
+  auto match_key_ids_buf =
       arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);
-  auto materialize_payload_ids_buf =
+  auto match_payload_ids_buf =
       arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);
 
   JoinMatchIterator match_iterator;
   match_iterator.SetLookupResult(num_batch_rows, batch_start_row, match_bitvector,
                                  key_ids, no_duplicate_keys, key_to_payload_);
-  int num_matches_next = 0;
-  while (match_iterator.GetNextBatch(minibatch_size_, &num_matches_next,
-                                     materialize_batch_ids_buf.mutable_data(),
-                                     materialize_key_ids_buf.mutable_data(),
-                                     materialize_payload_ids_buf.mutable_data())) {
+  while (match_iterator.GetNextBatch(
+      minibatch_size_, &num_matching_ids, match_batch_ids_buf.mutable_data(),
+      match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data())) {
     int num_filtered = 0;
-    RETURN_NOT_OK(FilterInner(
-        keypayload_batch, num_matches_next, materialize_batch_ids_buf.mutable_data(),
-        materialize_key_ids_buf.mutable_data(),
-        materialize_payload_ids_buf.mutable_data(), /*output_key_ids=*/false,
-        /*output_payload_ids=*/true, temp_stack, &num_filtered));
-    // Output payload ids of passing rows.
-    //
-    output_payload_ids(num_filtered, materialize_payload_ids_buf.mutable_data());
+    RETURN_NOT_OK(FilterOneBatch(
+        keypayload_batch, num_matching_ids, match_batch_ids_buf.mutable_data(),
+        match_key_ids_buf.mutable_data(), match_payload_ids_buf.mutable_data(),
+        /*output_key_ids=*/false,
+        /*output_payload_ids=*/true, temp_stack, &num_filtered, on_match_batch));
   }
 
   return Status::OK();
 }
 
-Status JoinResidualFilter::FilterInner(const ExecBatch& keypayload_batch,
-                                       int num_batch_rows, uint16_t* batch_row_ids,
-                                       uint32_t* key_ids_maybe_null,
-                                       uint32_t* payload_ids_maybe_null,
-                                       bool output_key_ids, bool output_payload_ids,
-                                       arrow::util::TempVectorStack* temp_stack,
-                                       int* num_passing_rows) const {
-  ARROW_DCHECK(filter_ != literal(true));
-  ARROW_DCHECK(!output_key_ids || key_ids_maybe_null);
-  ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null);
+Status JoinResidualFilter::FilterInner(
+    const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t* batch_row_ids,
+    uint32_t* key_ids, uint32_t* payload_ids_maybe_null, bool output_payload_ids,
+    arrow::util::TempVectorStack* temp_stack, int* num_passing_rows) const {
+  if (filter_ == literal(true)) {
+    *num_passing_rows = num_batch_rows;
+    return Status::OK();
+  }
 
   *num_passing_rows = 0;
   if (filter_.IsNullLiteral() || filter_ == literal(false)) {
     return Status::OK();
   }
 
+  return FilterOneBatch(
+      keypayload_batch, num_batch_rows, batch_row_ids, key_ids, payload_ids_maybe_null,
+      /*output_key_ids=*/true, output_payload_ids, temp_stack, num_passing_rows);
+}
+
+Status JoinResidualFilter::FilterOneBatch(
+    const ExecBatch& keypayload_batch, int num_batch_rows, uint16_t* batch_row_ids,
+    uint32_t* key_ids_maybe_null, uint32_t* payload_ids_maybe_null, bool output_key_ids,
+    bool output_payload_ids, arrow::util::TempVectorStack* temp_stack,
+    int* num_passing_rows, OnMatchBatch on_match_batch) const {
+  ARROW_DCHECK(!filter_.IsNullLiteral() && filter_ != literal(true) &&
+               filter_ != literal(false));
+  ARROW_DCHECK(!output_key_ids || key_ids_maybe_null);
+  ARROW_DCHECK(!output_payload_ids || payload_ids_maybe_null);
+
+  *num_passing_rows = 0;
   ARROW_ASSIGN_OR_RAISE(Datum mask, EvalFilter(keypayload_batch, num_batch_rows,
                                                batch_row_ids, key_ids_maybe_null,
                                                payload_ids_maybe_null));
@@ -2051,10 +2135,12 @@ Status JoinResidualFilter::FilterInner(const ExecBatch& keypayload_batch,
     const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
     if (mask_scalar.is_valid && mask_scalar.value) {
       *num_passing_rows = num_batch_rows;
-      return Status::OK();
-    } else {
-      return Status::OK();
     }
+    if (on_match_batch) {
+      on_match_batch(*num_passing_rows, batch_row_ids, key_ids_maybe_null,
+                     payload_ids_maybe_null);
+    }
+    return Status::OK();
   }
 
   ARROW_DCHECK_EQ(mask.array()->offset, 0);
@@ -2077,6 +2163,11 @@
     }
   }
 
+  if (on_match_batch) {
+    on_match_batch(*num_passing_rows, batch_row_ids, key_ids_maybe_null,
+                   payload_ids_maybe_null);
+  }
+
   return Status::OK();
 }
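The OnMatchBatch callback is what lets a single FilterOneBatch implementation serve both the materializing paths and the right-semi/anti path that only flips has-match flags. A toy consumer wired through the same std::function shape (the HasMatchSet type below is a stand-in, not code from the patch):

    #include <cstdint>
    #include <functional>
    #include <unordered_set>

    using OnMatchBatch = std::function<void(int num_rows, const uint16_t* batch_row_ids,
                                            const uint32_t* key_ids_maybe_null,
                                            const uint32_t* payload_ids_maybe_null)>;

    // Accumulates the payload ids of rows that survived the filter, mimicking
    // what UpdateHasMatchForPayloads records inside the hash table.
    struct HasMatchSet {
      std::unordered_set<uint32_t> matched_payloads;

      OnMatchBatch AsCallback() {
        return [this](int num_rows, const uint16_t*, const uint32_t*,
                      const uint32_t* payload_ids_maybe_null) {
          for (int i = 0; i < num_rows; ++i) {
            matched_payloads.insert(payload_ids_maybe_null[i]);
          }
        };
      }
    };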
@@ -2165,6 +2256,7 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
                                        const ExecBatch& keypayload_batch,
                                        arrow::util::TempVectorStack* temp_stack,
                                        std::vector<KeyColumnArray>* temp_column_arrays) {
+  bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr);
   const SwissTable* swiss_table = hash_table_->keys()->swiss_table();
   int64_t hardware_flags = swiss_table->hardware_flags();
   int minibatch_size = swiss_table->minibatch_size();
@@ -2190,8 +2282,6 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
         arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size);
     auto materialize_payload_ids_buf =
         arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size);
-    auto materialize_no_match_batch_ids_buf =
-        arrow::util::TempVectorHolder<uint16_t>(temp_stack, minibatch_size);
     auto filtered_bitvector_buf = arrow::util::TempVectorHolder<uint8_t>(
         temp_stack, static_cast<uint32_t>(bit_util::BytesForBits(minibatch_size)));
 
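Hoisting no_duplicate_keys to the top of OnNextBatch works because it only restates how the build side was encoded: when build keys are unique, key_to_payload() is null and a key id doubles as the payload row id; otherwise key_to_payload() holds, per key id, the start of that key's payload row range. A hedged sketch of that lookup convention (prefix-sum layout as I read JoinMatchIterator's usage; the helper name is mine):

    #include <cstdint>
    #include <utility>

    // Returns the [begin, end) payload row range for `key_id`.
    std::pair<uint32_t, uint32_t> PayloadRange(const uint32_t* key_to_payload,
                                               uint32_t key_id) {
      if (key_to_payload == nullptr) {
        return {key_id, key_id + 1};  // unique keys: payload id == key id
      }
      return {key_to_payload[key_id], key_to_payload[key_id + 1]};
    }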
@@ -2217,71 +2307,29 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
       if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI ||
           join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) {
         int num_passing_ids = 0;
-        const uint16_t* materialize_batch_ids = materialize_batch_ids_buf.mutable_data();
-        if (residual_filter_->IsTrivial()) {
-          int bit_match = (join_type_ == JoinType::LEFT_ANTI) ? 0 : 1;
-          arrow::util::bit_util::bits_to_indexes(
-              bit_match, hardware_flags, minibatch_size_next,
-              match_bitvector_buf.mutable_data(), &num_passing_ids,
-              materialize_batch_ids_buf.mutable_data());
-          if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) {
-            // For right-semi, right-anti joins: collect key ids of passing rows.
-            //
-            for (int i = 0; i < num_passing_ids; ++i) {
-              uint16_t id = materialize_batch_ids_buf.mutable_data()[i];
-              materialize_key_ids_buf.mutable_data()[i] = key_ids_buf.mutable_data()[id];
-            }
-            // For right-semi, right-anti joins: update has-match flags for the rows
-            // in hash table.
-            hash_table_->UpdateHasMatchForKeys(thread_id, num_passing_ids,
-                                               materialize_key_ids_buf.mutable_data());
-          } else {
-            // For left-semi, left-anti joins: add base batch row index.
-            //
-            for (int i = 0; i < num_passing_ids; ++i) {
-              materialize_batch_ids_buf.mutable_data()[i] +=
-                  static_cast<uint16_t>(minibatch_start);
-            }
-          }
+        if (join_type_ == JoinType::LEFT_SEMI) {
+          RETURN_NOT_OK(residual_filter_->FilterLeftSemi(
+              keypayload_batch, minibatch_start, minibatch_size_next,
+              match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(),
+              no_duplicate_keys, temp_stack, &num_passing_ids,
+              materialize_batch_ids_buf.mutable_data()));
+        } else if (join_type_ == JoinType::LEFT_ANTI) {
+          RETURN_NOT_OK(residual_filter_->FilterLeftAnti(
+              keypayload_batch, minibatch_start, minibatch_size_next,
+              match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(),
+              no_duplicate_keys, temp_stack, &num_passing_ids,
+              materialize_batch_ids_buf.mutable_data()));
         } else {
-          bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr);
-          if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI) {
-            RETURN_NOT_OK(residual_filter_->FilterLeftSemi(
-                keypayload_batch, minibatch_start, minibatch_size_next,
-                match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(),
-                no_duplicate_keys, temp_stack, &num_passing_ids,
-                materialize_batch_ids_buf.mutable_data()));
-            if (join_type_ == JoinType::LEFT_ANTI) {
-              // For left-anti join: collect no match row ids.
-              //
-              int num_no_passing_ids = 0;
-              int imatch = 0;
-              for (int irow = minibatch_start;
-                   irow < minibatch_start + static_cast<int>(minibatch_size_next); ++irow) {
-                while (imatch < num_passing_ids &&
-                       materialize_batch_ids_buf.mutable_data()[imatch] < irow) {
-                  ++imatch;
-                }
-                if (imatch == num_passing_ids ||
-                    materialize_batch_ids_buf.mutable_data()[imatch] != irow) {
-                  materialize_no_match_batch_ids_buf.mutable_data()[num_no_passing_ids++] =
-                      static_cast<uint16_t>(irow);
-                }
-              }
-              num_passing_ids = num_no_passing_ids;
-              materialize_batch_ids = materialize_no_match_batch_ids_buf.mutable_data();
-            }
-          } else {
-            auto update_has_match = [thread_id, this](int num_passing_ids,
-                                                      const uint32_t* payload_ids) {
-              hash_table_->UpdateHasMatchForPayloads(thread_id, num_passing_ids,
-                                                     payload_ids);
-            };
-            RETURN_NOT_OK(residual_filter_->FilterRightSemi(
-                keypayload_batch, minibatch_start, minibatch_size_next,
-                match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(),
-                no_duplicate_keys, temp_stack, update_has_match));
-          }
+          RETURN_NOT_OK(residual_filter_->FilterRightSemi(
+              keypayload_batch, minibatch_start, minibatch_size_next,
+              match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(),
+              no_duplicate_keys, temp_stack,
+              [thread_id, this](int num_passing_ids, const uint16_t*,
+                                const uint32_t* key_ids_maybe_null,
+                                const uint32_t* payload_ids_maybe_null) {
+                UpdateHasMatch(thread_id, num_passing_ids, key_ids_maybe_null,
+                               payload_ids_maybe_null);
+              }));
         }
 
         if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI) {
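A condensed restatement of the dispatch above: the left variants produce probe-side row ids that are materialized immediately, while the right variants produce no probe output here and instead mark build rows through the has-match callback, to be harvested later when the hash table is scanned. Illustrative only, not patch code:

    // Which side of the join the probe phase emits for, per join kind.
    enum class JoinKind { kLeftSemi, kLeftAnti, kRightSemiOrAnti };

    bool EmitsProbeRowsDuringProbe(JoinKind kind) {
      switch (kind) {
        case JoinKind::kLeftSemi:
        case JoinKind::kLeftAnti:
          return true;   // passing probe row ids go straight to materialization
        case JoinKind::kRightSemiOrAnti:
          return false;  // only has-match flags in the hash table are updated
      }
      return false;
    }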
@@ -2289,7 +2337,7 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
           // row ids.
           //
           RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
-              keypayload_batch, num_passing_ids, materialize_batch_ids,
+              keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(),
               [&](ExecBatch batch) {
                 return output_batch_fn_(thread_id, std::move(batch));
               }));
@@ -2299,47 +2347,49 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
       // Since every hash table lookup for an input row might have multiple
      // matches we use a helper class that implements enumerating all of them.
       //
-      bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr);
       JoinMatchIterator match_iterator;
       match_iterator.SetLookupResult(
           minibatch_size_next, minibatch_start, match_bitvector_buf.mutable_data(),
           key_ids_buf.mutable_data(), no_duplicate_keys, hash_table_->key_to_payload());
-      if (!residual_filter_->IsTrivial() &&
-          (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER)) {
+      int num_matches_next;
+      bool use_filtered_bitvector =
+          residual_filter_->NeedToUpdateMatchBitVector(join_type_);
+      // For filtered result, initialize match bit-vector to all zeros (no match).
+      //
+      if (use_filtered_bitvector) {
         std::memset(filtered_bitvector_buf.mutable_data(), 0,
                     bit_util::BytesForBits(minibatch_size_next));
       }
 
-      int num_matches_next;
       while (match_iterator.GetNextBatch(minibatch_size, &num_matches_next,
                                          materialize_batch_ids_buf.mutable_data(),
                                          materialize_key_ids_buf.mutable_data(),
                                          materialize_payload_ids_buf.mutable_data())) {
-        if (!residual_filter_->IsTrivial()) {
-          RETURN_NOT_OK(residual_filter_->FilterInner(
-              keypayload_batch, num_matches_next,
-              materialize_batch_ids_buf.mutable_data(),
-              materialize_key_ids_buf.mutable_data(),
-              materialize_payload_ids_buf.mutable_data(), /*output_key_ids=*/true,
-              !no_duplicate_keys, temp_stack, &num_matches_next));
-          if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) {
-            for (int i = 0; i < num_matches_next; ++i) {
-              int bit_idx = materialize_batch_ids_buf.mutable_data()[i] - minibatch_start;
-              bit_util::SetBitTo(filtered_bitvector_buf.mutable_data(), bit_idx, 1);
-            }
-          }
-        }
+        RETURN_NOT_OK(residual_filter_->FilterInner(
+            keypayload_batch, num_matches_next, materialize_batch_ids_buf.mutable_data(),
+            materialize_key_ids_buf.mutable_data(),
+            materialize_payload_ids_buf.mutable_data(), !no_duplicate_keys, temp_stack,
+            &num_matches_next));
+
+        const uint16_t* materialize_batch_ids = materialize_batch_ids_buf.mutable_data();
         const uint32_t* materialize_key_ids = materialize_key_ids_buf.mutable_data();
         const uint32_t* materialize_payload_ids =
             no_duplicate_keys ? materialize_key_ids_buf.mutable_data()
                               : materialize_payload_ids_buf.mutable_data();
 
+        // For filtered result, update match bit-vector.
+        //
+        if (use_filtered_bitvector) {
+          UpdateMatchBitVector(minibatch_start, num_matches_next,
+                               filtered_bitvector_buf.mutable_data(), num_matches_next,
+                               materialize_batch_ids);
+        }
+
         // For right-outer, full-outer joins we need to update has-match flags
         // for the rows in hash table.
         //
         if (join_type_ == JoinType::RIGHT_OUTER || join_type_ == JoinType::FULL_OUTER) {
-          hash_table_->UpdateHasMatchForPayloads(thread_id, num_matches_next,
-                                                 materialize_payload_ids);
+          UpdateHasMatch(thread_id, num_matches_next, /*key_ids_maybe_null=*/NULLPTR,
+                         materialize_payload_ids);
         }
 
         // Call materialize for resulting id tuples pointing to matching pairs
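Why left/full outer joins need the separate filtered bit-vector: with a residual filter present, a probe row is unmatched only if every one of its key matches fails the filter, which is exactly the set of bits that remain zero after all surviving matches have been replayed into the vector. The same logic over plain vectors (illustrative sketch):

    #include <vector>

    // Rows with no surviving match after the residual filter; `surviving_row_ids`
    // may contain duplicates, as one probe row can have several filtered matches.
    std::vector<int> UnmatchedAfterFilter(int num_rows,
                                          const std::vector<int>& surviving_row_ids) {
      std::vector<bool> has_filtered_match(num_rows, false);
      for (int row_id : surviving_row_ids) {
        has_filtered_match[row_id] = true;  // UpdateMatchBitVector's role in the patch
      }
      std::vector<int> unmatched;
      for (int row_id = 0; row_id < num_rows; ++row_id) {
        if (!has_filtered_match[row_id]) {
          unmatched.push_back(row_id);
        }
      }
      return unmatched;
    }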
@@ -2359,20 +2409,12 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
       //
       if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) {
         int num_passing_ids = 0;
-        const uint8_t* match_bitvector = residual_filter_->IsTrivial()
-                                             ? match_bitvector_buf.mutable_data()
-                                             : filtered_bitvector_buf.mutable_data();
-        arrow::util::bit_util::bits_to_indexes(
-            /*bit_to_search=*/0, hardware_flags, minibatch_size_next, match_bitvector,
+        CollectPassingBatchIds(
+            0, hardware_flags, minibatch_start, minibatch_size_next,
+            use_filtered_bitvector ? filtered_bitvector_buf.mutable_data()
+                                   : match_bitvector_buf.mutable_data(),
             &num_passing_ids, materialize_batch_ids_buf.mutable_data());
-
-        // Add base batch row index.
-        //
-        for (int i = 0; i < num_passing_ids; ++i) {
-          materialize_batch_ids_buf.mutable_data()[i] +=
-              static_cast<uint16_t>(minibatch_start);
-        }
 
         RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
             keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(),
             [&](ExecBatch batch) {
@@ -2400,6 +2442,27 @@ Status JoinProbeProcessor::OnFinished() {
   return Status::OK();
 }
 
+void JoinProbeProcessor::UpdateHasMatch(int64_t thread_id, int num_rows,
+                                        const uint32_t* key_ids_maybe_null,
+                                        const uint32_t* payload_ids_maybe_null) {
+  ARROW_DCHECK(key_ids_maybe_null || payload_ids_maybe_null);
+  if (payload_ids_maybe_null) {
+    hash_table_->UpdateHasMatchForPayloads(thread_id, num_rows, payload_ids_maybe_null);
+  } else {
+    hash_table_->UpdateHasMatchForKeys(thread_id, num_rows, key_ids_maybe_null);
+  }
+}
+
+void JoinProbeProcessor::UpdateMatchBitVector(int batch_start_row, int num_batch_rows,
+                                              uint8_t* match_bitvector,
+                                              int num_passing_rows,
+                                              const uint16_t* batch_ids) {
+  for (int i = 0; i < num_passing_rows; ++i) {
+    int bit_idx = batch_ids[i] - batch_start_row;
+    bit_util::SetBitTo(match_bitvector, bit_idx, 1);
+  }
+}
+
 class SwissJoin : public HashJoinImpl {
  public:
   Status Init(QueryContext* ctx, JoinType join_type, size_t num_threads,
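UpdateHasMatch's contract is worth spelling out: at least one id array must be non-null, and payload ids take precedence when both are supplied, presumably because payload rows are the finer-grained unit to flag. A minimal restatement of that dispatch (hypothetical harness, not patch code):

    #include <cassert>
    #include <cstdint>

    enum class Flagged { kPayloads, kKeys };

    Flagged DispatchUpdateHasMatch(const uint32_t* key_ids_maybe_null,
                                   const uint32_t* payload_ids_maybe_null) {
      assert(key_ids_maybe_null || payload_ids_maybe_null);
      return payload_ids_maybe_null ? Flagged::kPayloads : Flagged::kKeys;
    }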
diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h
index 40b2c5a9867fb..f1da3c8b1a26b 100644
--- a/cpp/src/arrow/acero/swiss_join_internal.h
+++ b/cpp/src/arrow/acero/swiss_join_internal.h
@@ -761,31 +761,49 @@ class JoinResidualFilter {
   void SetBuildSide(int minibatch_size, const RowArray* build_keys,
                     const RowArray* build_payloads, const uint32_t* key_to_payload);
 
-  bool IsTrivial() const { return filter_ == literal(true); }
+  bool NeedToUpdateMatchBitVector(JoinType join_type) const {
+    return (join_type == JoinType::LEFT_OUTER || join_type == JoinType::FULL_OUTER) &&
+           filter_ != literal(true);
+  }
 
   int NumBuildKeysReferred() const { return num_build_keys_referred_; }
   int NumBuildPayloadsReferred() const { return num_build_payloads_referred_; }
 
+  using OnMatchBatch =
+      std::function<void(int, const uint16_t*, const uint32_t*, const uint32_t*)>;
+
   Status FilterLeftSemi(const ExecBatch& keypayload_batch, int batch_start_row,
                         int num_batch_rows, const uint8_t* match_bitvector,
                         const uint32_t* key_ids, bool no_duplicate_keys,
                         arrow::util::TempVectorStack* temp_stack, int* num_passing_ids,
                         uint16_t* passing_batch_row_ids) const;
 
-  using OutputPayloadIdsCallback = std::function<void(int, const uint32_t*)>;
+  Status FilterLeftAnti(const ExecBatch& keypayload_batch, int batch_start_row,
+                        int num_batch_rows, const uint8_t* match_bitvector,
+                        const uint32_t* key_ids, bool no_duplicate_keys,
+                        arrow::util::TempVectorStack* temp_stack, int* num_passing_ids,
+                        uint16_t* passing_batch_row_ids) const;
+
   Status FilterRightSemi(const ExecBatch& keypayload_batch, int batch_start_row,
                          int num_batch_rows, const uint8_t* match_bitvector,
                          const uint32_t* key_ids, bool no_duplicate_keys,
                          arrow::util::TempVectorStack* temp_stack,
-                         OutputPayloadIdsCallback output_payload_ids) const;
+                         OnMatchBatch on_match_batch) const;
 
   Status FilterInner(const ExecBatch& keypayload_batch, int num_batch_rows,
-                     uint16_t* batch_row_ids, uint32_t* key_ids_maybe_null,
-                     uint32_t* payload_ids_maybe_null, bool output_key_ids,
-                     bool output_payload_ids, arrow::util::TempVectorStack* temp_stack,
+                     uint16_t* batch_row_ids, uint32_t* key_ids,
+                     uint32_t* payload_ids_maybe_null, bool output_payload_ids,
+                     arrow::util::TempVectorStack* temp_stack,
                      int* num_passing_rows) const;
 
  private:
+  Status FilterOneBatch(const ExecBatch& keypayload_batch, int num_batch_rows,
+                        uint16_t* batch_row_ids, uint32_t* key_ids_maybe_null,
+                        uint32_t* payload_ids_maybe_null, bool output_key_ids,
+                        bool output_payload_ids, arrow::util::TempVectorStack* temp_stack,
+                        int* num_passing_rows, OnMatchBatch on_match_batch = {}) const;
+
   Result<Datum> EvalFilter(const ExecBatch& keypayload_batch, int num_batch_rows,
                            const uint16_t* batch_row_ids,
                            const uint32_t* key_ids_maybe_null,
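The renamed predicate makes the bit-vector policy explicit, rather than leaving callers to combine IsTrivial() with the join type themselves. Restated over plain flags to expose the truth table (illustrative, not patch code):

    // True only for joins that must emit unmatched probe rows (left/full outer)
    // and only when a residual filter can veto a key match after lookup.
    bool NeedFilteredBitVector(bool is_left_or_full_outer, bool filter_is_trivial) {
      return is_left_or_full_outer && !filter_is_trivial;
    }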
@@ -837,6 +855,21 @@ class JoinProbeProcessor {
   //
   Status OnFinished();
 
+ private:
+  // For right-* and full-outer joins: we need to update has-match flags
+  // for the rows in hash table.
+  //
+  void UpdateHasMatch(int64_t thread_id, int num_passing_ids,
+                      const uint32_t* key_ids_maybe_null,
+                      const uint32_t* payload_ids_maybe_null);
+
+  // For left-outer and full-outer joins: we need to update match bit-vector if
+  // the residual filter is not a literal true.
+  //
+  void UpdateMatchBitVector(int batch_start_row, int num_batch_rows,
+                            uint8_t* match_bitvector, int num_passing_rows,
+                            const uint16_t* batch_ids);
+
  private:
   int num_key_columns_;
   JoinType join_type_;
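One detail from the JoinResidualFilter declarations above: the defaulted `OnMatchBatch on_match_batch = {}` parameter of the private FilterOneBatch leans on std::function's boolean conversion, so an empty callback cleanly skips the notification branches guarded by `if (on_match_batch)`. Minimal demonstration:

    #include <functional>
    #include <iostream>

    int main() {
      std::function<void(int)> cb;  // default-constructed: empty, converts to false
      std::cout << static_cast<bool>(cb) << "\n";  // prints 0
      cb = [](int n) { std::cout << n << "\n"; };
      if (cb) cb(42);  // prints 42
      return 0;
    }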