Skip to content

Commit

Permalink
Reorg code to make the avx2 path more clear
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Aug 21, 2024
1 parent 47072a8 commit c002367
Show file tree
Hide file tree
Showing 3 changed files with 282 additions and 120 deletions.
260 changes: 152 additions & 108 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,138 +277,73 @@ Status RowArray::DecodeSelected(ResizableArrayData* output, int column_id,
uint32_t fixed_length = column_metadata.fixed_length;

#ifdef ARROW_HAVE_RUNTIME_AVX2
// Process fixed length columns
//
if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
num_rows_processed = DecodeFixedLength_avx2();
num_rows_processed =
DecodeFixedLength_avx2(output, num_rows_before, column_id, fixed_length,
num_rows_to_append / 2, row_ids + num_rows_processed);
}
#else
num_rows_processed =
DecodeFixedLength(output, num_rows_before, column_id, fixed_length,
num_rows_to_append / 2, row_ids + num_rows_processed);
#endif

int num_rows_to_append_next = num_rows_to_append - num_rows_processed;

switch (fixed_length) {
case 0:
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids + num_rows_processed,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
bit_util::SetBitTo(output->mutable_data(1),
num_rows_before + num_rows_processed + i, *ptr != 0);
});
break;
case 1:
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids + num_rows_processed,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
output->mutable_data(1)[num_rows_before + num_rows_processed + i] = *ptr;
});
break;
case 2:
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids + num_rows_processed,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
reinterpret_cast<uint16_t*>(
output->mutable_data(1))[num_rows_before + num_rows_processed + i] =
*reinterpret_cast<const uint16_t*>(ptr);
});
break;
case 4:
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids + num_rows_processed,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
reinterpret_cast<uint32_t*>(
output->mutable_data(1))[num_rows_before + num_rows_processed + i] =
*reinterpret_cast<const uint32_t*>(ptr);
});
break;
case 8:
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids + num_rows_processed,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
reinterpret_cast<uint64_t*>(
output->mutable_data(1))[num_rows_before + num_rows_processed + i] =
*reinterpret_cast<const uint64_t*>(ptr);
});
break;
default:
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
uint64_t* dst = reinterpret_cast<uint64_t*>(
output->mutable_data(1) +
num_bytes * (num_rows_before + num_rows_processed + i));
const uint64_t* src = reinterpret_cast<const uint64_t*>(ptr);
for (uint32_t word_id = 0;
word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) {
arrow::util::SafeStore<uint64_t>(dst + word_id,
arrow::util::SafeLoad(src + word_id));
}
});
break;
}
std::ignore = DecodeFixedLength(
output, num_rows_before + num_rows_processed, column_id, fixed_length,
num_rows_to_append - num_rows_processed, row_ids + num_rows_processed);
} else {
// Process offsets for varying length columns
//
#ifdef ARROW_HAVE_RUNTIME_AVX2
if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
num_rows_processed = DecodeOffsets_avx2();
num_rows_processed = DecodeOffsets_avx2(output, num_rows_before, column_id,
num_rows_to_append / 2, row_ids);
}
#else
num_rows_processed = DecodeOffsets(output, num_rows_before, column_id,
num_rows_to_append / 2, row_ids);
#endif

int num_rows_to_append_next = num_rows_to_append - num_rows_processed;

uint32_t* offsets = reinterpret_cast<uint32_t*>(output->mutable_data(1)) +
num_rows_before + num_rows_processed;
uint32_t sum = (num_rows_before + num_rows_processed == 0) ? 0 : offsets[0];
RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) { offsets[i] = num_bytes; });
for (int i = 0; i < num_rows_to_append_next; ++i) {
uint32_t length = offsets[i];
offsets[i] = sum;
sum += length;
}
offsets[num_rows_to_append] = sum;
std::ignore = DecodeOffsets(output, num_rows_before + num_rows_processed, column_id,
num_rows_to_append - num_rows_processed,
row_ids + num_rows_processed);

RETURN_NOT_OK(output->ResizeVaryingLengthBuffer());

// Process data for varying length columns
//
#ifdef ARROW_HAVE_RUNTIME_AVX2
if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
int num_var_length_rows_processed = DecodeVarLength_avx2();
int num_var_length_rows_processed = DecodeVarLength_avx2(
output, num_rows_before, column_id, num_rows_to_append / 2, row_ids);
DCHECK_EQ(num_var_length_rows_processed, num_rows_processed);
}
#else
int num_var_length_rows_processed = DecodeVarLength(
output, num_rows_before, column_id, num_rows_to_append / 2, row_ids);
DCHECK_EQ(num_var_length_rows_processed, num_rows_processed);
#endif

RowArrayAccessor::Visit(
rows_, column_id, num_rows_to_append_next, row_ids + num_rows_processed,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
uint64_t* dst = reinterpret_cast<uint64_t*>(
output->mutable_data(2) +
reinterpret_cast<const uint32_t*>(
output->mutable_data(1))[num_rows_before + num_rows_processed + i]);
const uint64_t* src = reinterpret_cast<const uint64_t*>(ptr);
for (uint32_t word_id = 0;
word_id < bit_util::CeilDiv(num_bytes, sizeof(uint64_t)); ++word_id) {
arrow::util::SafeStore<uint64_t>(dst + word_id,
arrow::util::SafeLoad(src + word_id));
}
});
std::ignore = DecodeVarLength(output, num_rows_before + num_rows_processed, column_id,
num_rows_to_append - num_rows_processed,
row_ids + num_rows_processed);
}

// Process nulls
//
#ifdef ARROW_HAVE_RUNTIME_AVX2
if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
int num_null_rows_processed = DecodeNulls_avx2();
int num_null_rows_processed = DecodeNulls_avx2(output, num_rows_before, column_id,
num_rows_to_append / 2, row_ids);
DCHECK_EQ(num_null_rows_processed, num_rows_processed);
}
#else
int num_null_rows_processed =
DecodeNulls(output, num_rows_before, column_id, num_rows_to_append / 2, row_ids);
DCHECK_EQ(num_null_rows_processed, num_rows_processed);
#endif
num_rows_to_append -= num_rows_processed;
RowArrayAccessor::VisitNulls(
rows_, column_id, num_rows_to_append, row_ids + num_rows_processed,
[&](int i, uint8_t value) {
bit_util::SetBitTo(output->mutable_data(0),
num_rows_before + num_rows_processed + i, value == 0);
});
std::ignore =
DecodeNulls(output, num_rows_before + num_rows_processed, column_id,
num_rows_to_append - num_rows_processed, row_ids + num_rows_processed);

return Status::OK();
}
Expand Down Expand Up @@ -487,6 +422,115 @@ void RowArray::DebugPrintToFile(const char* filename, bool print_sorted) const {
}
}

// Copies the values of one fixed-length column, for the rows listed in
// `row_ids`, into buffer 1 of `output` starting at element `output_start_row`.
//
// The copy strategy is picked from `fixed_length`:
//   - 0: each value becomes a single bit, set when the source byte is non-zero
//        (presumably boolean columns - confirm against the column metadata).
//   - 1, 2, 4, 8: each value is copied as one integer of that width.
//   - anything else: each value is copied in 8-byte words.
//
// Returns `num_rows_to_append`, i.e. the number of rows that were decoded.
//
int RowArray::DecodeFixedLength(ResizableArrayData* output, int output_start_row,
                                int column_id, uint32_t fixed_length,
                                int num_rows_to_append, const uint32_t* row_ids) const {
  // Copies every visited value as a single integer whose type is that of the
  // tag argument (only the tag's type is used, never its value).
  auto copy_values_as = [&](auto type_tag) {
    using T = decltype(type_tag);
    RowArrayAccessor::Visit(
        rows_, column_id, num_rows_to_append, row_ids,
        [&](int row, const uint8_t* src, uint32_t /*num_bytes*/) {
          T* typed_dst = reinterpret_cast<T*>(output->mutable_data(1));
          typed_dst[output_start_row + row] = *reinterpret_cast<const T*>(src);
        });
  };

  switch (fixed_length) {
    case 0:
      // Bit-packed output: one bit per row.
      RowArrayAccessor::Visit(
          rows_, column_id, num_rows_to_append, row_ids,
          [&](int row, const uint8_t* src, uint32_t /*num_bytes*/) {
            const bool bit_value = (*src != 0);
            bit_util::SetBitTo(output->mutable_data(1), output_start_row + row,
                               bit_value);
          });
      break;
    case 1:
      copy_values_as(uint8_t{});
      break;
    case 2:
      copy_values_as(uint16_t{});
      break;
    case 4:
      copy_values_as(uint32_t{});
      break;
    case 8:
      copy_values_as(uint64_t{});
      break;
    default:
      // Arbitrary width: copy word by word. The word count is rounded up, so
      // the last word may cover bytes beyond `num_bytes` in both source and
      // destination (presumably both buffers are padded for this - confirm).
      RowArrayAccessor::Visit(
          rows_, column_id, num_rows_to_append, row_ids,
          [&](int row, const uint8_t* src, uint32_t num_bytes) {
            auto* dst_words = reinterpret_cast<uint64_t*>(
                output->mutable_data(1) + num_bytes * (output_start_row + row));
            const auto* src_words = reinterpret_cast<const uint64_t*>(src);
            const auto num_words = bit_util::CeilDiv(num_bytes, sizeof(uint64_t));
            for (uint32_t word = 0; word < num_words; ++word) {
              arrow::util::SafeStore<uint64_t>(dst_words + word,
                                               arrow::util::SafeLoad(src_words + word));
            }
          });
      break;
  }

  return num_rows_to_append;
}

// Fills the 32-bit offsets (buffer 1 of `output`) of a varying-length column
// for the rows listed in `row_ids`, starting at element `output_start_row`.
//
// The running offset continues from the value already stored at
// `output_start_row`, or from 0 when `output_start_row` is 0. On return,
// entries [output_start_row, output_start_row + num_rows_to_append] are
// written; the last one holds the end offset of the final appended row.
//
// Returns `num_rows_to_append`, i.e. the number of rows that were decoded.
//
int RowArray::DecodeOffsets(ResizableArrayData* output, int output_start_row,
                            int column_id, int num_rows_to_append,
                            const uint32_t* row_ids) const {
  uint32_t* const out_offsets =
      reinterpret_cast<uint32_t*>(output->mutable_data(1)) + output_start_row;

  // Must be read before the first pass below overwrites out_offsets[0].
  uint32_t running_total = 0;
  if (output_start_row != 0) {
    running_total = out_offsets[0];
  }

  // First pass: temporarily store each row's length in its offset slot.
  RowArrayAccessor::Visit(rows_, column_id, num_rows_to_append, row_ids,
                          [&](int row, const uint8_t* /*ptr*/, uint32_t num_bytes) {
                            out_offsets[row] = num_bytes;
                          });

  // Second pass: turn the lengths into offsets (exclusive prefix sum).
  for (int row = 0; row < num_rows_to_append; ++row) {
    const uint32_t value_length = out_offsets[row];
    out_offsets[row] = running_total;
    running_total += value_length;
  }
  out_offsets[num_rows_to_append] = running_total;

  return num_rows_to_append;
}

// Copies the bytes of a varying-length column, for the rows listed in
// `row_ids`, into buffer 2 of `output`.
//
// The destination of each value is read from the 32-bit offsets in buffer 1 of
// `output` at index `output_start_row + i`, so those offsets must already be
// populated when this is called.
//
// Returns `num_rows_to_append`, i.e. the number of rows that were decoded.
//
int RowArray::DecodeVarLength(ResizableArrayData* output, int output_start_row,
                              int column_id, int num_rows_to_append,
                              const uint32_t* row_ids) const {
  RowArrayAccessor::Visit(
      rows_, column_id, num_rows_to_append, row_ids,
      [&](int row, const uint8_t* src, uint32_t num_bytes) {
        const uint32_t* out_offsets =
            reinterpret_cast<const uint32_t*>(output->mutable_data(1));
        auto* dst_words = reinterpret_cast<uint64_t*>(
            output->mutable_data(2) + out_offsets[output_start_row + row]);
        const auto* src_words = reinterpret_cast<const uint64_t*>(src);

        // Copy in 8-byte words. The word count is rounded up, so the last word
        // may cover bytes beyond `num_bytes` (presumably both buffers are
        // padded for this - confirm).
        const auto num_words = bit_util::CeilDiv(num_bytes, sizeof(uint64_t));
        for (uint32_t word = 0; word < num_words; ++word) {
          arrow::util::SafeStore<uint64_t>(dst_words + word,
                                           arrow::util::SafeLoad(src_words + word));
        }
      });

  return num_rows_to_append;
}

// Writes one bit per row listed in `row_ids` into buffer 0 of `output`,
// starting at bit `output_start_row`. The bit is set when the value reported
// by VisitNulls is 0 and cleared otherwise (presumably VisitNulls reports
// non-zero for null rows, making buffer 0 a validity bitmap - confirm).
//
// Returns `num_rows_to_append`, i.e. the number of rows that were decoded.
//
int RowArray::DecodeNulls(ResizableArrayData* output, int output_start_row, int column_id,
                          int num_rows_to_append, const uint32_t* row_ids) const {
  auto write_bit = [&](int row, uint8_t value) {
    const bool bit_value = (value == 0);
    bit_util::SetBitTo(output->mutable_data(0), output_start_row + row, bit_value);
  };
  RowArrayAccessor::VisitNulls(rows_, column_id, num_rows_to_append, row_ids, write_bit);

  return num_rows_to_append;
}

Status RowArrayMerge::PrepareForMerge(RowArray* target,
const std::vector<RowArray*>& sources,
std::vector<int64_t>* first_target_row_id,
Expand Down Expand Up @@ -1917,9 +1961,9 @@ bool JoinMatchIterator::GetNextBatch(int num_rows_max, int* out_num_rows,

namespace {

// Given match_bitvector identifies that there is a match for row[batch_start_row + i] in
// given input batch if bit match_bitvector[i] == passing_bit. Collect all the passing row
// ids according to the given match_bitvector.
// Given match_bitvector identifies that there is a match for row[batch_start_row + i]
// in given input batch if bit match_bitvector[i] == passing_bit. Collect all the
// passing row ids according to the given match_bitvector.
//
void CollectPassingBatchIds(int passing_bit, int64_t hardware_flags, int batch_start_row,
int num_batch_rows, const uint8_t* match_bitvector,
Expand Down Expand Up @@ -2046,8 +2090,8 @@ Status JoinResidualFilter::FilterLeftSemi(const ExecBatch& keypayload_batch,
auto match_payload_ids_buf =
arrow::util::TempVectorHolder<uint32_t>(temp_stack, minibatch_size_);

// Inner matching is necessary for non-trivial filter. Only until evaluating filter for
// all matches of the same row can we be sure that it's not passing (it could pass
// Inner matching is necessary for a non-trivial filter. Only after evaluating the
// filter for all matches of the same row can we be sure that it's not passing (it
// could pass earlier though).
//
JoinMatchIterator match_iterator;
Expand Down
Loading

0 comments on commit c002367

Please sign in to comment.