Skip to content

Commit

Permalink
Finish impl
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Jan 2, 2024
1 parent cccf9e1 commit 8bfb8d7
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 41 deletions.
6 changes: 2 additions & 4 deletions cpp/src/arrow/acero/hash_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -740,13 +740,11 @@ class HashJoinNode : public ExecNode, public TracedNode {
// Create hash join implementation object
// SwissJoin does not support:
// a) 64-bit string offsets
// b) residual predicates
// c) dictionaries
// b) dictionaries
//
bool use_swiss_join;
#if ARROW_LITTLE_ENDIAN
use_swiss_join = (filter == literal(true)) && !schema_mgr->HasDictionaries() &&
!schema_mgr->HasLargeBinary();
use_swiss_join = !schema_mgr->HasDictionaries() && !schema_mgr->HasLargeBinary();
#else
use_swiss_join = false;
#endif
Expand Down
131 changes: 102 additions & 29 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,68 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows,
return (*out_num_rows) > 0;
}

void JoinResidualFilter::Init(Expression filter, int minibatch_size, QueryContext* ctx,
MemoryPool* pool, int64_t hardware_flags,
const HashJoinProjectionMaps* probe_schemas,
const HashJoinProjectionMaps* build_schemas) {
filter_ = std::move(filter);
if (filter_ == literal(true)) {
return;
}

minibatch_size_ = minibatch_size;
ctx_ = ctx;
pool_ = pool;
hardware_flags_ = hardware_flags;
probe_schemas_ = probe_schemas;
build_schemas_ = build_schemas;

{
probe_filter_to_key_and_payload_.resize(
probe_schemas_->num_cols(HashJoinProjection::FILTER));
int num_key_cols = probe_schemas_->num_cols(HashJoinProjection::KEY);
auto to_key =
probe_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
auto to_payload =
probe_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
for (int i = 0; static_cast<size_t>(i) < probe_filter_to_key_and_payload_.size();
++i) {
if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) {
probe_filter_to_key_and_payload_[i] = idx;
} else if (idx = to_payload.get(i); idx != SchemaProjectionMap::kMissingField) {
probe_filter_to_key_and_payload_[i] = idx + num_key_cols;
} else {
DCHECK(false);
}
}
}

{
int num_columns = build_schemas_->num_cols(HashJoinProjection::FILTER);
auto to_key =
build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
auto to_payload =
build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
for (int i = 0; i < num_columns; ++i) {
if (to_key.get(i) != SchemaProjectionMap::kMissingField) {
num_build_keys_referred_++;
} else if (to_payload.get(i) != SchemaProjectionMap::kMissingField) {
num_build_payloads_referred_++;
} else {
DCHECK(false);
}
}
}
}

void JoinResidualFilter::SetBuildSide(const RowArray* build_keys,
const RowArray* build_payloads,
const uint32_t* key_to_payload) {
build_keys_ = build_keys;
build_payloads_ = build_payloads;
key_to_payload_ = key_to_payload;
}

Status JoinResidualFilter::FilterMatchBitVector(
const ExecBatch& keypayload_batch, int batch_start_row, int num_batch_rows,
int bit_match, const uint8_t* match_bitvector, const uint32_t* key_ids,
Expand All @@ -1867,7 +1929,10 @@ Status JoinResidualFilter::FilterMatchBitVector(
return Status::OK();
}

if (build_filter_to_key_.empty() && build_filter_to_payload_.empty()) {
if (num_build_keys_referred_ == 0 && num_build_payloads_referred_ == 0) {
// If filter refers no column in right table,
// TODO
//
arrow::util::bit_util::bits_to_indexes(bit_match, hardware_flags_, num_batch_rows,
match_bitvector, num_passing_ids,
passing_batch_row_ids);
Expand Down Expand Up @@ -2001,8 +2066,8 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
const uint32_t* key_ids_maybe_null, const uint32_t* payload_ids_maybe_null) const {
ExecBatch out;
out.length = num_batch_rows;
out.values.resize(probe_filter_to_key_and_payload_.size() +
build_filter_to_key_.size() + build_filter_to_payload_.size());
out.values.resize(probe_filter_to_key_and_payload_.size() + num_build_keys_referred_ +
num_build_payloads_referred_);

if (probe_filter_to_key_and_payload_.size() > 0) {
ExecBatchBuilder probe_batch_builder;
Expand All @@ -2017,31 +2082,29 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(
}
}

if (build_filter_to_key_.size() > 0) {
ARROW_DCHECK(key_ids_maybe_null);
for (size_t i = 0; i < build_filter_to_key_.size(); ++i) {
int key_idx = build_filter_to_key_[i];
ResizableArrayData build_key;
build_key.Init(build_schemas_->data_type(HashJoinProjection::KEY, key_idx), pool_,
bit_util::Log2(num_batch_rows));
RETURN_NOT_OK(build_keys_->DecodeSelected(&build_key, key_idx, num_batch_rows,
key_ids_maybe_null, pool_));
out.values[probe_filter_to_key_and_payload_.size() + i] = build_key.array_data();
}
}

if (build_filter_to_payload_.size() > 0) {
ARROW_DCHECK(payload_ids_maybe_null);
for (size_t i = 0; i < build_filter_to_payload_.size(); ++i) {
int payload_idx = build_filter_to_payload_[i];
ResizableArrayData build_payload;
build_payload.Init(
build_schemas_->data_type(HashJoinProjection::PAYLOAD, payload_idx), pool_,
bit_util::Log2(num_batch_rows));
RETURN_NOT_OK(build_payloads_->DecodeSelected(
&build_payload, payload_idx, num_batch_rows, payload_ids_maybe_null, pool_));
out.values[probe_filter_to_key_and_payload_.size() + build_filter_to_key_.size() +
i] = build_payload.array_data();
if (num_build_keys_referred_ > 0 || num_build_payloads_referred_ > 0) {
ARROW_DCHECK(num_build_keys_referred_ == 0 || key_ids_maybe_null);
ARROW_DCHECK(num_build_payloads_referred_ == 0 || payload_ids_maybe_null);

int num_build_cols = build_schemas_->num_cols(HashJoinProjection::FILTER);
auto to_key =
build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::KEY);
auto to_payload =
build_schemas_->map(HashJoinProjection::FILTER, HashJoinProjection::PAYLOAD);
for (int i = 0; i < num_build_cols; ++i) {
ResizableArrayData column_data;
column_data.Init(build_schemas_->data_type(HashJoinProjection::FILTER, i), pool_,
bit_util::Log2(num_batch_rows));
if (auto idx = to_key.get(i); idx != SchemaProjectionMap::kMissingField) {
RETURN_NOT_OK(build_keys_->DecodeSelected(&column_data, idx, num_batch_rows,
key_ids_maybe_null, pool_));
} else if (idx = to_payload.get(i); idx != SchemaProjectionMap::kMissingField) {
RETURN_NOT_OK(build_payloads_->DecodeSelected(&column_data, idx, num_batch_rows,
payload_ids_maybe_null, pool_));
} else {
ARROW_DCHECK(false);
}
out.values[probe_filter_to_key_and_payload_.size() + i] = column_data.array_data();
}
}

Expand All @@ -2050,12 +2113,14 @@ Result<ExecBatch> JoinResidualFilter::MaterializeFilterInput(

void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type,
SwissTableForJoin* hash_table,
JoinResidualFilter* residual_filter,
std::vector<JoinResultMaterialize*> materialize,
const std::vector<JoinKeyCmp>* cmp,
OutputBatchFn output_batch_fn) {
num_key_columns_ = num_key_columns;
join_type_ = join_type;
hash_table_ = hash_table;
residual_filter_ = residual_filter;
materialize_.resize(materialize.size());
for (size_t i = 0; i < materialize.size(); ++i) {
materialize_[i] = materialize[i];
Expand Down Expand Up @@ -2321,8 +2386,13 @@ class SwissJoin : public HashJoinImpl {
materialize[i] = &local_states_[i].materialize;
}

int minibatch_size = hash_table_.keys()->swiss_table()->minibatch_size();
residual_filter_.Init(std::move(filter), minibatch_size, ctx_, pool_, hardware_flags_,
proj_map_left, proj_map_right);

probe_processor_.Init(proj_map_left->num_cols(HashJoinProjection::KEY), join_type_,
&hash_table_, materialize, &key_cmp_, output_batch_callback_);
&hash_table_, &residual_filter_, materialize, &key_cmp_,
output_batch_callback_);

InitTaskGroups();

Expand Down Expand Up @@ -2525,6 +2595,9 @@ class SwissJoin : public HashJoinImpl {
}
hash_table_ready_.store(true);

residual_filter_.SetBuildSide(hash_table_.keys()->keys(), hash_table_.payloads(),
hash_table_.key_to_payload());

return build_finished_callback_(thread_id);
}

Expand Down
26 changes: 18 additions & 8 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,13 @@ class JoinMatchIterator {

class JoinResidualFilter {
public:
void Init(Expression filter, int minibatch_size, QueryContext* ctx, MemoryPool* pool,
int64_t hardware_flags, const HashJoinProjectionMaps* probe_schemas,
const HashJoinProjectionMaps* build_schemas);

void SetBuildSide(const RowArray* build_keys, const RowArray* build_payloads,
const uint32_t* key_to_payload);

bool IsTrivial() const { return filter_ == literal(true); }

Status FilterMatchBitVector(const ExecBatch& keypayload_batch, int batch_start_row,
Expand Down Expand Up @@ -768,21 +775,23 @@ class JoinResidualFilter {
const uint32_t* payload_ids_maybe_null) const;

private:
Expression filter_;
int minibatch_size_;

QueryContext* ctx_;
int64_t hardware_flags_;
MemoryPool* pool_;
int minibatch_size_;
// const HashJoinProjectionMaps* probe_schemas_;
int64_t hardware_flags_;

const HashJoinProjectionMaps* probe_schemas_;
const HashJoinProjectionMaps* build_schemas_;
Expression filter_;

std::vector<int> probe_filter_to_key_and_payload_;
std::vector<int> build_filter_to_key_;
std::vector<int> build_filter_to_payload_;
int num_build_keys_referred_ = 0;
int num_build_payloads_referred_ = 0;

const uint32_t* key_to_payload_;
const RowArray* build_keys_;
const RowArray* build_payloads_;
const uint32_t* key_to_payload_;
};

// Implements entire processing of a probe side exec batch,
Expand All @@ -793,6 +802,7 @@ class JoinProbeProcessor {
using OutputBatchFn = std::function<Status(int64_t, ExecBatch)>;

void Init(int num_key_columns, JoinType join_type, SwissTableForJoin* hash_table,
JoinResidualFilter* residual_filter,
std::vector<JoinResultMaterialize*> materialize,
const std::vector<JoinKeyCmp>* cmp, OutputBatchFn output_batch_fn);
Status OnNextBatch(int64_t thread_id, const ExecBatch& keypayload_batch,
Expand All @@ -809,7 +819,7 @@ class JoinProbeProcessor {
JoinType join_type_;

SwissTableForJoin* hash_table_;
const JoinResidualFilter* residual_filter_;
JoinResidualFilter* residual_filter_;
// One element per thread
//
std::vector<JoinResultMaterialize*> materialize_;
Expand Down

0 comments on commit 8bfb8d7

Please sign in to comment.