Enable SwissJoin for hash joins that carry a residual filter.

* hash_join_node.cc: a non-trivial filter no longer rules out SwissJoin.
* swiss_join.cc, swiss_join_internal.h: add JoinResidualFilter::Init and
  JoinResidualFilter::SetBuildSide, replace the build_filter_to_key_ /
  build_filter_to_payload_ index vectors with counts of the build-side key and
  payload columns the filter refers to, and hand the residual filter to
  JoinProbeProcessor::Init.

diff --git a/cpp/src/arrow/acero/hash_join_node.cc b/cpp/src/arrow/acero/hash_join_node.cc
index 254dad361ff87..c0179fd160e4e 100644
--- a/cpp/src/arrow/acero/hash_join_node.cc
+++ b/cpp/src/arrow/acero/hash_join_node.cc
@@ -740,13 +740,11 @@ class HashJoinNode : public ExecNode, public TracedNode {
     // Create hash join implementation object
     // SwissJoin does not support:
     // a) 64-bit string offsets
-    // b) residual predicates
-    // c) dictionaries
+    // b) dictionaries
     //
     bool use_swiss_join;
 #if ARROW_LITTLE_ENDIAN
-    use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() &&
-                     !schema_mgr->HasLargeBinary();
+    use_swiss_join = !schema_mgr->HasDictionaries() && !schema_mgr->HasLargeBinary();
 #else
     use_swiss_join = false;
 #endif
diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc
index c730767bb0c62..ad8e44c585e7c 100644
--- a/cpp/src/arrow/acero/swiss_join.cc
+++ b/cpp/src/arrow/acero/swiss_join.cc
@@ -1855,6 +1855,68 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows,
   return (*out_num_rows) > 0;
 }
 
+void JoinResidualFilter::Init(Expression filter, int minibatch_size, QueryContext* ctx,
+                              MemoryPool* pool, int64_t hardware_flags,
+                              const HashJoinProjectionMaps* probe_schemas,
+                              const HashJoinProjectionMaps* build_schemas) {
+  filter_ = std::move(filter);
+  if (filter_ == literal(true)) {
+    return;
+  }
+
+  minibatch_size_ = minibatch_size;
+  ctx_ = ctx;
+  pool_ = pool;
+  hardware_flags_ = hardware_flags;
+  probe_schemas_ = probe_schemas;
+  build_schemas_ = build_schemas;
+
+  {
+    probe_filter_to_key_and_payload_.resize(
+        probe_schemas_->num_cols(HashJoinProjection::FILTER));
+    int num_key_cols = probe_schemas_->num_cols(HashJoinProjection::KEY);
+    auto to_key =
+        probe_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
+    auto to_payload =
+        probe_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+    for (int i = 0; static_cast<size_t>(i) < probe_filter_to_key_and_payload_.size();
+         ++i) {
+      if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) {
+        probe_filter_to_key_and_payload_[i] = idx;
+      } else if (idx = to_payload.get(i); idx != SchemaProjectionMap::kMissingField) {
+        probe_filter_to_key_and_payload_[i] = idx + num_key_cols;
+      } else {
+        DCHECK(false);
+      }
+    }
+  }
+
+  {
+    int num_columns = build_schemas_->num_cols(HashJoinProjection::FILTER);
+    auto to_key =
+        build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
+    auto to_payload =
+        build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+    for (int i = 0; i < num_columns; ++i) {
+      if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
+        num_build_keys_referred_++;
+      } else if (to_payload.get(i) != SchemaProjectionMap::kMissingField) {
+        num_build_payloads_referred_++;
+      } else {
+        DCHECK(false);
+      }
+    }
+  }
+}
+
+void JoinResidualFilter::SetBuildSide(const RowArray* build_keys,
+                                      const RowArray* build_payloads,
+                                      const uint32_t* key_to_payload) {
+  build_keys_ = build_keys;
+  build_payloads_ = build_payloads;
+  key_to_payload_ = key_to_payload;
+}
+
 Status JoinResidualFilter::FilterMatchBitVector(
     const ExecBatch& keypayload_batch, int batch_start_row, int num_batch_rows,
     int bit_match, const uint8_t* match_bitvector, const uint32_t* key_ids,
@@ -1867,7 +1929,10 @@ Status JoinResidualFilter::FilterMatchBitVector(
     return Status::OK();
   }
 
-  if (build_filter_to_key_.empty() && build_filter_to_payload_.empty()) {
+  if (num_build_keys_referred_ == 0 && num_build_payloads_referred_ == 0) {
+    // If the filter refers to no column in the right (build) table, start by
+    // converting the match bit vector into batch row ids.
+    // TODO: document the rest of this branch.
     arrow::util::bit_util::bits_to_indexes(bit_match, hardware_flags_, num_batch_rows,
                                            match_bitvector, num_passing_ids,
                                            passing_batch_row_ids);
@@ -2001,8 +2066,8 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
     const uint32_t* key_ids_maybe_null, const uint32_t* payload_ids_maybe_null) const {
   ExecBatch out;
   out.length = num_batch_rows;
-  out.values.resize(probe_filter_to_key_and_payload_.size() +
-                    build_filter_to_key_.size() + build_filter_to_payload_.size());
+  out.values.resize(probe_filter_to_key_and_payload_.size() + num_build_keys_referred_ +
+                    num_build_payloads_referred_);
 
   if (probe_filter_to_key_and_payload_.size() > 0) {
     ExecBatchBuilder probe_batch_builder;
@@ -2017,31 +2082,29 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
     }
   }
 
-  if (build_filter_to_key_.size() > 0) {
-    ARROW_DCHECK(key_ids_maybe_null);
-    for (size_t i = 0; i < build_filter_to_key_.size(); ++i) {
-      int key_idx = build_filter_to_key_[i];
-      ResizableArrayData build_key;
-      build_key.Init(build_schemas_->data_type(HashJoinProjection::KEY, key_idx), pool_,
-                     bit_util::Log2(num_batch_rows));
-      RETURN_NOT_OK(build_keys_->DecodeSelected(&build_key, key_idx, num_batch_rows,
-                                                key_ids_maybe_null, pool_));
-      out.values[probe_filter_to_key_and_payload_.size() + i] = build_key.array_data();
-    }
-  }
-
-  if (build_filter_to_payload_.size() > 0) {
-    ARROW_DCHECK(payload_ids_maybe_null);
-    for (size_t i = 0; i < build_filter_to_payload_.size(); ++i) {
-      int payload_idx = build_filter_to_payload_[i];
-      ResizableArrayData build_payload;
-      build_payload.Init(
-          build_schemas_->data_type(HashJoinProjection::PAYLOAD, payload_idx), pool_,
-          bit_util::Log2(num_batch_rows));
-      RETURN_NOT_OK(build_payloads_->DecodeSelected(
-          &build_payload, payload_idx, num_batch_rows, payload_ids_maybe_null, pool_));
-      out.values[probe_filter_to_key_and_payload_.size() + build_filter_to_key_.size() +
-                 i] = build_payload.array_data();
+  if (num_build_keys_referred_ > 0 || num_build_payloads_referred_ > 0) {
+    ARROW_DCHECK(num_build_keys_referred_ == 0 || key_ids_maybe_null);
+    ARROW_DCHECK(num_build_payloads_referred_ == 0 || payload_ids_maybe_null);
+
+    int num_build_cols = build_schemas_->num_cols(HashJoinProjection::FILTER);
+    auto to_key =
+        build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
+    auto to_payload =
+        build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
+    for (int i = 0; i < num_build_cols; ++i) {
+      ResizableArrayData column_data;
+      column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i), pool_,
+                       bit_util::Log2(num_batch_rows));
+      if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) {
+        RETURN_NOT_OK(build_keys_->DecodeSelected(&column_data, idx, num_batch_rows,
+                                                  key_ids_maybe_null, pool_));
+      } else if (idx = to_payload.get(i); idx != SchemaProjectionMap::kMissingField) {
+        RETURN_NOT_OK(build_payloads_->DecodeSelected(&column_data, idx, num_batch_rows,
+                                                      payload_ids_maybe_null, pool_));
+      } else {
+        ARROW_DCHECK(false);
+      }
+      out.values[probe_filter_to_key_and_payload_.size() + i] = column_data.array_data();
     }
   }
 
@@ -2050,12 +2113,14 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
 
 void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type,
                               SwissTableForJoin* hash_table,
+                              JoinResidualFilter* residual_filter,
                               std::vector materialize,
                               const std::vector* cmp,
                               OutputBatchFn output_batch_fn) {
   num_key_columns_ = num_key_columns;
   join_type_ = join_type;
   hash_table_ = hash_table;
+  residual_filter_ = residual_filter;
   materialize_.resize(materialize.size());
   for (size_t i = 0; i < materialize.size(); ++i) {
     materialize_[i] = materialize[i];
@@ -2321,8 +2386,13 @@ class SwissJoin : public HashJoinImpl {
       materialize[i] = &local_states_[i].materialize;
     }
 
+    int minibatch_size = hash_table_.keys()->swiss_table()->minibatch_size();
+    residual_filter_.Init(std::move(filter), minibatch_size, ctx_, pool_, hardware_flags_,
+                          proj_map_left, proj_map_right);
+
     probe_processor_.Init(proj_map_left->num_cols(HashJoinProjection::KEY), join_type_,
-                          &hash_table_, materialize, &key_cmp_, output_batch_callback_);
+                          &hash_table_, &residual_filter_, materialize, &key_cmp_,
+                          output_batch_callback_);
 
     InitTaskGroups();
 
@@ -2525,6 +2595,9 @@ class SwissJoin : public HashJoinImpl {
     }
     hash_table_ready_.store(true);
 
+    residual_filter_.SetBuildSide(hash_table_.keys()->keys(), hash_table_.payloads(),
+                                  hash_table_.key_to_payload());
+
     return build_finished_callback_(thread_id);
   }
 
diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h
index 4d16b6310700e..6bcddfca9d828 100644
--- a/cpp/src/arrow/acero/swiss_join_internal.h
+++ b/cpp/src/arrow/acero/swiss_join_internal.h
@@ -738,6 +738,13 @@ class JoinMatchIterator {
 
 class JoinResidualFilter {
  public:
+  void Init(Expression filter, int minibatch_size, QueryContext* ctx, MemoryPool* pool,
+            int64_t hardware_flags, const HashJoinProjectionMaps* probe_schemas,
+            const HashJoinProjectionMaps* build_schemas);
+
+  void SetBuildSide(const RowArray* build_keys, const RowArray* build_payloads,
+                    const uint32_t* key_to_payload);
+
   bool IsTrivial() const { return filter_ == literal(true); }
 
   Status FilterMatchBitVector(const ExecBatch& keypayload_batch, int batch_start_row,
@@ -768,21 +775,23 @@ class JoinResidualFilter {
                                            const uint32_t* payload_ids_maybe_null) const;
 
  private:
+  Expression filter_;
+  int minibatch_size_;
+
   QueryContext* ctx_;
-  int64_t hardware_flags_;
   MemoryPool* pool_;
-  int minibatch_size_;
-  // const HashJoinProjectionMaps* probe_schemas_;
+  int64_t hardware_flags_;
+
+  const HashJoinProjectionMaps* probe_schemas_;
   const HashJoinProjectionMaps* build_schemas_;
 
-  Expression filter_;
   std::vector<int> probe_filter_to_key_and_payload_;
-  std::vector<int> build_filter_to_key_;
-  std::vector<int> build_filter_to_payload_;
+  int num_build_keys_referred_ = 0;
+  int num_build_payloads_referred_ = 0;
 
-  const uint32_t* key_to_payload_;
   const RowArray* build_keys_;
   const RowArray* build_payloads_;
+  const uint32_t* key_to_payload_;
 };
 
 // Implements entire processing of a probe side exec batch,
@@ -793,6 +802,7 @@ class JoinProbeProcessor {
   using OutputBatchFn = std::function;
 
   void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table,
+            JoinResidualFilter* residual_filter,
             std::vector materialize,
             const std::vector* cmp, OutputBatchFn output_batch_fn);
   Status OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch,
@@ -809,7 +819,7 @@ class JoinProbeProcessor {
   JoinType join_type_;
 
   SwissTableForJoin* hash_table_;
-  const JoinResidualFilter* residual_filter_;
+  JoinResidualFilter* residual_filter_;
   // One element per thread
   //
   std::vector materialize_;