Skip to content

Commit

Permalink
Sketch basic filter logic
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Jan 1, 2024
1 parent bcaeaa8 commit 7d87de6
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 21 deletions.
164 changes: 143 additions & 21 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1855,6 +1855,99 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows,
return (*out_num_rows) > 0;
}

Status JoinResidualFilter::FilterMatchBitVector(
const ExecBatch& keypayload_batch, int batch_start_row, int num_batch_rows,
int bit_match, const uint8_t* match_bitvector, const uint32_t* key_ids,
bool no_duplicate_keys, arrow::util::TempVectorStack* temp_stack,
int* num_passing_ids, uint16_t* passing_batch_row_ids,
uint32_t* passing_key_ids_maybe_null) {
ARROW_DCHECK(filter_ != literal(true));
*num_passing_ids = 0;
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
return Status::OK();
}
auto materialize_batch_ids_buf =
arrow::util::TempVectorHolder<uint16_t>(temp_stack, minibatch_size_);
auto materialize_key_ids_buf =
arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);
auto materialize_payload_ids_buf =
arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);

JoinMatchIterator match_iterator;
match_iterator.SetLookupResult(num_batch_rows, batch_start_row, match_bitvector,
key_ids, no_duplicate_keys, key_to_payload_);
int num_matches_next = 0;
uint32_t row_id_last = std::numeric_limits<uint16_t>::max() + 1;
while (match_iterator.GetNextBatch(minibatch_size_, &num_matches_next,
materialize_batch_ids_buf.mutable_data(),
materialize_key_ids_buf.mutable_data(),
materialize_payload_ids_buf.mutable_data())) {
int num_filtered = 0;
RETURN_NOT_OK(FilterMatchRowIds(
keypayload_batch, num_matches_next, materialize_batch_ids_buf.mutable_data(),
materialize_key_ids_buf.mutable_data(),
materialize_payload_ids_buf.mutable_data(), passing_key_ids_maybe_null, false,
temp_stack, &num_filtered));
for (int ifiltered = 0; ifiltered < num_filtered; ++ifiltered) {
if (materialize_batch_ids_buf.mutable_data()[ifiltered] == row_id_last) {
continue;
}
passing_batch_row_ids[*num_passing_ids] =
materialize_batch_ids_buf.mutable_data()[ifiltered];
if (passing_key_ids_maybe_null) {
passing_key_ids_maybe_null[*num_passing_ids] =
materialize_key_ids_buf.mutable_data()[ifiltered];
}
row_id_last = materialize_batch_ids_buf.mutable_data()[ifiltered];
++(*num_passing_ids);
}
}
return Status::OK();
}

Status JoinResidualFilter::FilterMatchRowIds(const ExecBatch& keypayload_batch,
int num_batch_rows, uint16_t* batch_row_ids,
uint32_t* key_ids, uint32_t* payload_ids,
bool output_key_ids, bool output_payload_ids,
arrow::util::TempVectorStack* temp_stack,
int* num_passing_rows) {
ARROW_DCHECK(filter_ != literal(true));
*num_passing_rows = 0;
if (filter_.IsNullLiteral() || filter_ == literal(false)) {
return Status::OK();
}
ARROW_ASSIGN_OR_RAISE(Datum mask, EvalFilter());
if (mask.is_scalar()) {
const auto& mask_scalar = mask.scalar_as<BooleanScalar>();
if (mask_scalar.is_valid && mask_scalar.value) {
*num_passing_rows = num_batch_rows;
return Status::OK();
} else {
return Status::OK();
}
}
ARROW_DCHECK_EQ(mask.array()->offset, 0);
ARROW_DCHECK_EQ(mask.array()->length, static_cast<int64_t>(num_batch_rows));
const uint8_t* validity =
mask.array()->buffers[0] ? mask.array()->buffers[0]->data() : nullptr;
const uint8_t* comparisons = mask.array()->buffers[1]->data();
for (int irow = 0; irow < num_batch_rows; ++irow) {
bool is_valid = !validity || bit_util::GetBit(validity, irow);
bool is_cmp_true = bit_util::GetBit(comparisons, irow);
if (is_valid && is_cmp_true) {
batch_row_ids[*num_passing_rows] = batch_row_ids[irow];
if (output_key_ids) {
key_ids[*num_passing_rows] = key_ids[irow];
}
if (output_payload_ids) {
payload_ids[*num_passing_rows] = payload_ids[irow];
}
++(*num_passing_rows);
}
}
return Status::OK();
}

void JoinProbeProcessor::Init(int num_key_columns, JoinType join_type,
SwissTableForJoin* hash_table,
std::vector<JoinResultMaterialize*> materialize,
Expand Down Expand Up @@ -1893,6 +1986,8 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
auto hashes_buf = arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size);
auto match_bitvector_buf = arrow::util::TempVectorHolder<uint8_t>(
temp_stack, static_cast<uint32_t>(bit_util::BytesForBits(minibatch_size)));
auto filtered_bitvector_buf = arrow::util::TempVectorHolder<uint8_t>(
temp_stack, static_cast<uint32_t>(bit_util::BytesForBits(minibatch_size)));
auto key_ids_buf = arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size);
auto materialize_batch_ids_buf =
arrow::util::TempVectorHolder<uint16_t>(temp_stack, minibatch_size);
Expand Down Expand Up @@ -1923,33 +2018,48 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
if (join_type_ == JoinType::LEFT_SEMI || join_type_ == JoinType::LEFT_ANTI ||
join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) {
int num_passing_ids = 0;
arrow::util::bit_util::bits_to_indexes(
(join_type_ == JoinType::LEFT_ANTI) ? 0 : 1, hardware_flags,
minibatch_size_next, match_bitvector_buf.mutable_data(), &num_passing_ids,
materialize_batch_ids_buf.mutable_data());
int bit_match = join_type_ == JoinType::LEFT_ANTI ? 0 : 1;
if (!residual_filter_) {
arrow::util::bit_util::bits_to_indexes(
bit_match, hardware_flags, minibatch_size_next,
match_bitvector_buf.mutable_data(), &num_passing_ids,
materialize_batch_ids_buf.mutable_data());
if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) {
// For right-semi, right-anti joins: collect key ids of passing rows.
//
for (int i = 0; i < num_passing_ids; ++i) {
uint16_t id = materialize_batch_ids_buf.mutable_data()[i];
materialize_key_ids_buf.mutable_data()[i] = key_ids_buf.mutable_data()[id];
}
} else {
// For left-semi, left-anti joins: add base batch row index.
//
for (int i = 0; i < num_passing_ids; ++i) {
materialize_batch_ids_buf.mutable_data()[i] +=
static_cast<uint16_t>(minibatch_start);
}
}
} else {
bool no_duplicate_keys = (hash_table_->key_to_payload() == nullptr);
RETURN_NOT_OK(residual_filter_->FilterMatchBitVector(
keypayload_batch, minibatch_start, minibatch_size_next, bit_match,
match_bitvector_buf.mutable_data(), key_ids_buf.mutable_data(),
no_duplicate_keys, temp_stack, &num_passing_ids,
materialize_batch_ids_buf.mutable_data(),
join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI
? materialize_key_ids_buf.mutable_data()
: NULLPTR));
}

// For right-semi, right-anti joins: update has-match flags for the rows
// in hash table.
//
if (join_type_ == JoinType::RIGHT_SEMI || join_type_ == JoinType::RIGHT_ANTI) {
for (int i = 0; i < num_passing_ids; ++i) {
uint16_t id = materialize_batch_ids_buf.mutable_data()[i];
key_ids_buf.mutable_data()[i] = key_ids_buf.mutable_data()[id];
}
// For right-semi, right-anti joins: update has-match flags for the rows
// in hash table.
hash_table_->UpdateHasMatchForKeys(thread_id, num_passing_ids,
key_ids_buf.mutable_data());
materialize_key_ids_buf.mutable_data());
} else {
// For left-semi, left-anti joins: call materialize using match
// bit-vector.
// row ids.
//

// Add base batch row index.
//
for (int i = 0; i < num_passing_ids; ++i) {
materialize_batch_ids_buf.mutable_data()[i] +=
static_cast<uint16_t>(minibatch_start);
}

RETURN_NOT_OK(materialize_[thread_id]->AppendProbeOnly(
keypayload_batch, num_passing_ids, materialize_batch_ids_buf.mutable_data(),
[&](ExecBatch batch) {
Expand All @@ -1972,6 +2082,15 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
materialize_batch_ids_buf.mutable_data(),
materialize_key_ids_buf.mutable_data(),
materialize_payload_ids_buf.mutable_data())) {
if (residual_filter_) {
RETURN_NOT_OK(residual_filter_->FilterMatchRowIds(
keypayload_batch, num_matches_next,
materialize_batch_ids_buf.mutable_data(),
materialize_key_ids_buf.mutable_data(),
materialize_payload_ids_buf.mutable_data(), true,
!(no_duplicate_keys || no_payload_columns), temp_stack, &num_matches_next));
// TODO: Index to bit vector.
}
const uint16_t* materialize_batch_ids = materialize_batch_ids_buf.mutable_data();
const uint32_t* materialize_key_ids = materialize_key_ids_buf.mutable_data();
const uint32_t* materialize_payload_ids =
Expand Down Expand Up @@ -2003,6 +2122,9 @@ Status JoinProbeProcessor::OnNextBatch(int64_t thread_id,
// the other side of the join.
//
if (join_type_ == JoinType::LEFT_OUTER || join_type_ == JoinType::FULL_OUTER) {
if (residual_filter_) {
// TODO: and match bit vector.
}
int num_passing_ids = 0;
arrow::util::bit_util::bits_to_indexes(
/*bit_to_search=*/0, hardware_flags, minibatch_size_next,
Expand Down
28 changes: 28 additions & 0 deletions cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,33 @@ class JoinMatchIterator {
int current_match_for_row_;
};

class JoinResidualFilter {
public:
Status FilterMatchBitVector(const ExecBatch& keypayload_batch, int batch_start_row,
int num_batch_rows, int bit_match,
const uint8_t* match_bitvector, const uint32_t* key_ids,
bool no_duplicate_keys,
arrow::util::TempVectorStack* temp_stack,
int* num_passing_ids, uint16_t* passing_batch_row_ids,
uint32_t* passing_key_ids_maybe_null);

Status FilterMatchRowIds(const ExecBatch& keypayload_batch, int num_batch_rows,
uint16_t* batch_row_ids, uint32_t* key_ids,
uint32_t* payload_ids, bool output_key_ids,
bool output_payload_ids,
arrow::util::TempVectorStack* temp_stack,
int* num_passing_rows);

private:
Result<Datum> EvalFilter() { return Datum(); }

private:
// int64_t hardware_flags_;
int minibatch_size_;
Expression filter_;
const uint32_t* key_to_payload_;
};

// Implements entire processing of a probe side exec batch,
// provided the join hash table is already built and available.
//
Expand All @@ -760,6 +787,7 @@ class JoinProbeProcessor {
JoinType join_type_;

SwissTableForJoin* hash_table_;
JoinResidualFilter* residual_filter_;
// One element per thread
//
std::vector<JoinResultMaterialize*> materialize_;
Expand Down

0 comments on commit 7d87de6

Please sign in to comment.