Skip to content

Commit

Permalink
Done
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Jun 17, 2024
1 parent 3c62e1f commit 16d9e86
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 27 deletions.
36 changes: 23 additions & 13 deletions cpp/src/arrow/compute/row/compare_internal_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,24 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2(
}
}

namespace {

inline __m256i UnsignedOffsetSafeGather32(int const* base, __m256i offset,
const int scale = 1) {
auto normalized_base = base + 0x80000000ull / sizeof(int);
__m256i normalized_offset = _mm256_sub_epi32(offset, _mm256_set1_epi32(0x80000000));
return _mm256_i32gather_epi32(normalized_base, normalized_offset, 1);
}

inline __m256i UnsignedOffsetSafeGather64(long long const* base, __m128i offset,
const int scale = 1) {
auto normalized_base = base + 0x80000000ull / sizeof(long long);
__m128i normalized_offset = _mm_sub_epi32(offset, _mm_set1_epi32(0x80000000));
return _mm256_i32gather_epi64(normalized_base, normalized_offset, 1);
}

} // namespace

template <int column_width>
inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* right_base,
__m256i irow_left, __m256i offset_right,
Expand Down Expand Up @@ -282,11 +300,7 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r
ARROW_DCHECK(false);
}

// const int* normalized_right_base = (const int*)(right_base + 0x80000000ull);
// __m256i normalized_offset_right =
// _mm256_sub_epi32(offset_right, _mm256_set1_epi32(0x80000000));
// __m256i right = _mm256_i32gather_epi32(normalized_right_base, normalized_offset_right, 1);
__m256i right = _mm256_i32gather_epi32(right_base, offset_right, 1);
__m256i right = UnsignedOffsetSafeGather32((int const*)right_base, offset_right, 1);
if (column_width != sizeof(uint32_t)) {
constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff;
right = _mm256_and_si256(right, _mm256_set1_epi32(mask));
Expand Down Expand Up @@ -335,11 +349,7 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas
ARROW_DCHECK(false);
}

// const int* normalized_right_base = (const int*)(right_base + 0x80000000ull);
// __m256i normalized_offset_right =
// _mm256_sub_epi32(offset_right, _mm256_set1_epi32(0x80000000));
// __m256i right = _mm256_i32gather_epi32(normalized_right_base, normalized_offset_right, 1);
__m256i right = _mm256_i32gather_epi32(right_base, offset_right, 1);
__m256i right = UnsignedOffsetSafeGather32((int const*)right_base, offset_right, 1);
if (column_width != sizeof(uint32_t)) {
constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff;
right = _mm256_and_si256(right, _mm256_set1_epi32(mask));
Expand Down Expand Up @@ -376,9 +386,9 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig
auto right_base_i64 =
reinterpret_cast<const arrow::util::int64_for_gather_t*>(right_base);
__m256i right_lo =
_mm256_i32gather_epi64(right_base_i64, _mm256_castsi256_si128(offset_right), 1);
__m256i right_hi = _mm256_i32gather_epi64(right_base_i64,
_mm256_extracti128_si256(offset_right, 1), 1);
UnsignedOffsetSafeGather64(right_base_i64, _mm256_castsi256_si128(offset_right), 1);
__m256i right_hi = UnsignedOffsetSafeGather64(
right_base_i64, _mm256_extracti128_si256(offset_right, 1), 1);
uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo));
uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi));
return result_lo | (static_cast<uint64_t>(result_hi) << 32);
Expand Down
66 changes: 52 additions & 14 deletions cpp/src/arrow/compute/row/compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,33 +173,44 @@ TEST(KeyCompare, CompareColumnsToRowsLarge) {
GTEST_SKIP() << "Test only works on 64-bit platforms";
}

// The idea of this case is to create a row table using one fixed length column and one
// The idea of this case is to create a row table using several fixed length columns one
// var length column (so the row is hence var length and has offset buffer), with the
// overall data size exceeding 2GB. Then compare each row with itself.
constexpr int64_t two_gb = 2ll * 1024ll * 1024ll * 1024ll;
// The compare function requires the row id of the left column to be uint16_t, hence the
// number of rows.
constexpr int64_t num_rows = std::numeric_limits<uint16_t>::max() + 1;
const std::vector<std::shared_ptr<DataType>> fixed_length_types{uint64(), uint32()};
// The var length column should be a little smaller than 2GB to WAR the capacity
// limitation in the var length builder.
constexpr int32_t var_length = two_gb / num_rows - 1;
const int32_t fixed_length = uint32()->byte_width();
auto row_size =
std::accumulate(fixed_length_types.begin(), fixed_length_types.end(), var_length,
[](int64_t acc, const std::shared_ptr<DataType>& type) {
return acc + type->byte_width();
});
// The overall size should be larger than 2GB.
ASSERT_GT((var_length + fixed_length) * num_rows, two_gb);
ASSERT_GT(row_size * num_rows, two_gb);

MemoryPool* pool = default_memory_pool();
TempVectorStack stack;
ASSERT_OK(stack.Init(pool, KeyCompare::CompareColumnsToRowsTempStackUsage(num_rows)));

// A fixed length array containing random numbers.
ASSERT_OK_AND_ASSIGN(auto column_fixed_length,
::arrow::gen::Random(uint32())->Generate(num_rows));
// A var length array containing 'X' repeated var_length times.
ASSERT_OK_AND_ASSIGN(
auto column_var_length,
::arrow::gen::Constant(std::make_shared<BinaryScalar>(std::string(var_length, 'X')))
->Generate(num_rows));
ExecBatch batch({column_fixed_length, column_var_length}, num_rows);
std::vector<Datum> columns;
{
// Several fixed length arrays containing random content.
for (const auto& type : fixed_length_types) {
ASSERT_OK_AND_ASSIGN(auto column, ::arrow::gen::Random(type)->Generate(num_rows));
columns.push_back(std::move(column));
}
// A var length array containing 'X' repeated var_length times.
ASSERT_OK_AND_ASSIGN(auto column_var_length,
::arrow::gen::Constant(
std::make_shared<BinaryScalar>(std::string(var_length, 'X')))
->Generate(num_rows));
columns.push_back(std::move(column_var_length));
}
ExecBatch batch(std::move(columns), num_rows);

std::vector<KeyColumnMetadata> column_metadatas;
ASSERT_OK(ColumnMetadatasFromExecBatch(batch, &column_metadatas));
Expand Down Expand Up @@ -231,16 +242,43 @@ TEST(KeyCompare, CompareColumnsToRowsLarge) {
LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack};

{
// No selection, output no match row ids.
uint32_t num_rows_no_match;
std::vector<uint16_t> row_ids_out(num_rows);
KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx,
&num_rows_no_match, row_ids_out.data(),
KeyCompare::CompareColumnsToRows(
num_rows, /*sel_left_maybe_null=*/NULLPTR, row_ids_left.data(), &ctx,
&num_rows_no_match, row_ids_out.data(), column_arrays, row_table, true, NULLPTR);
ASSERT_EQ(num_rows_no_match, 0);
}

{
// With selection, output no match row ids.
uint32_t num_rows_no_match;
std::vector<uint16_t> row_ids_out(num_rows);
std::vector<uint16_t> selection_left(num_rows);
std::iota(selection_left.begin(), selection_left.end(), 0);
KeyCompare::CompareColumnsToRows(num_rows, selection_left.data(), row_ids_left.data(),
&ctx, &num_rows_no_match, row_ids_out.data(),
column_arrays, row_table, true, NULLPTR);
ASSERT_EQ(num_rows_no_match, 0);
}

{
// No selection, output match bit vector.
std::vector<uint8_t> match_bitvector(BytesForBits(num_rows));
KeyCompare::CompareColumnsToRows(
num_rows, /*sel_left_maybe_null=*/NULLPTR, row_ids_left.data(), &ctx, NULLPTR,
NULLPTR, column_arrays, row_table, true, match_bitvector.data());

ASSERT_EQ(arrow::internal::CountSetBits(match_bitvector.data(), 0, num_rows),
num_rows);
}

{
// With selection, output match bit vector.
std::vector<uint8_t> match_bitvector(BytesForBits(num_rows));
std::vector<uint16_t> selection_left(num_rows);
std::iota(selection_left.begin(), selection_left.end(), 0);
KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx,
NULLPTR, NULLPTR, column_arrays, row_table, true,
match_bitvector.data());
Expand Down

0 comments on commit 16d9e86

Please sign in to comment.