From 16d9e86d97f7ccec95dfea7a3b8e46471b319c3b Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 18 Jun 2024 01:10:44 +0800 Subject: [PATCH] Done --- .../compute/row/compare_internal_avx2.cc | 36 ++++++---- cpp/src/arrow/compute/row/compare_test.cc | 66 +++++++++++++++---- 2 files changed, 75 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/compute/row/compare_internal_avx2.cc b/cpp/src/arrow/compute/row/compare_internal_avx2.cc index 82991a8f1d162..cbe7ce98cae7f 100644 --- a/cpp/src/arrow/compute/row/compare_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/compare_internal_avx2.cc @@ -252,6 +252,24 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( } } +namespace { + +inline __m256i UnsignedOffsetSafeGather32(int const* base, __m256i offset, + const int scale = 1) { + auto normalized_base = base + 0x80000000ull / sizeof(int); + __m256i normalized_offset = _mm256_sub_epi32(offset, _mm256_set1_epi32(0x80000000)); + return _mm256_i32gather_epi32(normalized_base, normalized_offset, 1); +} + +inline __m256i UnsignedOffsetSafeGather64(long long const* base, __m128i offset, + const int scale = 1) { + auto normalized_base = base + 0x80000000ull / sizeof(long long); + __m128i normalized_offset = _mm_sub_epi32(offset, _mm_set1_epi32(0x80000000)); + return _mm256_i32gather_epi64(normalized_base, normalized_offset, 1); +} + +} // namespace + template inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* right_base, __m256i irow_left, __m256i offset_right, @@ -282,11 +300,7 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r ARROW_DCHECK(false); } - // const int* normalized_right_base = (const int*)(right_base + 0x80000000ull); - // __m256i normalized_offset_right = - // _mm256_sub_epi32(offset_right, _mm256_set1_epi32(0x80000000)); - // __m256i right = _mm256_i32gather_epi32(normalized_right_base, normalized_offset_right, 1); - __m256i right = _mm256_i32gather_epi32(right_base, offset_right, 1); + __m256i right = UnsignedOffsetSafeGather32((int const*)right_base, offset_right, 1); if (column_width != sizeof(uint32_t)) { constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff; right = _mm256_and_si256(right, _mm256_set1_epi32(mask)); @@ -335,11 +349,7 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas ARROW_DCHECK(false); } - // const int* normalized_right_base = (const int*)(right_base + 0x80000000ull); - // __m256i normalized_offset_right = - // _mm256_sub_epi32(offset_right, _mm256_set1_epi32(0x80000000)); - // __m256i right = _mm256_i32gather_epi32(normalized_right_base, normalized_offset_right, 1); - __m256i right = _mm256_i32gather_epi32(right_base, offset_right, 1); + __m256i right = UnsignedOffsetSafeGather32((int const*)right_base, offset_right, 1); if (column_width != sizeof(uint32_t)) { constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff; right = _mm256_and_si256(right, _mm256_set1_epi32(mask)); @@ -376,9 +386,9 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig auto right_base_i64 = reinterpret_cast(right_base); __m256i right_lo = - _mm256_i32gather_epi64(right_base_i64, _mm256_castsi256_si128(offset_right), 1); - __m256i right_hi = _mm256_i32gather_epi64(right_base_i64, - _mm256_extracti128_si256(offset_right, 1), 1); + UnsignedOffsetSafeGather64(right_base_i64, _mm256_castsi256_si128(offset_right), 1); + __m256i right_hi = UnsignedOffsetSafeGather64( + right_base_i64, _mm256_extracti128_si256(offset_right, 1), 1); uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo)); uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi)); return result_lo | (static_cast(result_hi) << 32); diff --git a/cpp/src/arrow/compute/row/compare_test.cc b/cpp/src/arrow/compute/row/compare_test.cc index f43b7a7b477ff..ca6768eb37145 100644 --- a/cpp/src/arrow/compute/row/compare_test.cc +++ b/cpp/src/arrow/compute/row/compare_test.cc @@ -173,33 +173,44 @@ TEST(KeyCompare, CompareColumnsToRowsLarge) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } - // The idea of this case is to create a row table using one fixed length column and one + // The idea of this case is to create a row table using several fixed length columns one // var length column (so the row is hence var length and has offset buffer), with the // overall data size exceeding 2GB. Then compare each row with itself. constexpr int64_t two_gb = 2ll * 1024ll * 1024ll * 1024ll; // The compare function requires the row id of the left column to be uint16_t, hence the // number of rows. constexpr int64_t num_rows = std::numeric_limits::max() + 1; + const std::vector> fixed_length_types{uint64(), uint32()}; // The var length column should be a little smaller than 2GB to WAR the capacity // limitation in the var length builder. constexpr int32_t var_length = two_gb / num_rows - 1; - const int32_t fixed_length = uint32()->byte_width(); + auto row_size = + std::accumulate(fixed_length_types.begin(), fixed_length_types.end(), var_length, + [](int64_t acc, const std::shared_ptr& type) { + return acc + type->byte_width(); + }); // The overall size should be larger than 2GB. - ASSERT_GT((var_length + fixed_length) * num_rows, two_gb); + ASSERT_GT(row_size * num_rows, two_gb); MemoryPool* pool = default_memory_pool(); TempVectorStack stack; ASSERT_OK(stack.Init(pool, KeyCompare::CompareColumnsToRowsTempStackUsage(num_rows))); - // A fixed length array containing random numbers. - ASSERT_OK_AND_ASSIGN(auto column_fixed_length, - ::arrow::gen::Random(uint32())->Generate(num_rows)); - // A var length array containing 'X' repeated var_length times. - ASSERT_OK_AND_ASSIGN( - auto column_var_length, - ::arrow::gen::Constant(std::make_shared(std::string(var_length, 'X'))) - ->Generate(num_rows)); - ExecBatch batch({column_fixed_length, column_var_length}, num_rows); + std::vector columns; + { + // Several fixed length arrays containing random content. + for (const auto& type : fixed_length_types) { + ASSERT_OK_AND_ASSIGN(auto column, ::arrow::gen::Random(type)->Generate(num_rows)); + columns.push_back(std::move(column)); + } + // A var length array containing 'X' repeated var_length times. + ASSERT_OK_AND_ASSIGN(auto column_var_length, + ::arrow::gen::Constant( + std::make_shared(std::string(var_length, 'X'))) + ->Generate(num_rows)); + columns.push_back(std::move(column_var_length)); + } + ExecBatch batch(std::move(columns), num_rows); std::vector column_metadatas; ASSERT_OK(ColumnMetadatasFromExecBatch(batch, &column_metadatas)); @@ -231,16 +242,43 @@ TEST(KeyCompare, CompareColumnsToRowsLarge) { LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; { + // No selection, output no match row ids. uint32_t num_rows_no_match; std::vector row_ids_out(num_rows); - KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx, - &num_rows_no_match, row_ids_out.data(), + KeyCompare::CompareColumnsToRows( + num_rows, /*sel_left_maybe_null=*/NULLPTR, row_ids_left.data(), &ctx, + &num_rows_no_match, row_ids_out.data(), column_arrays, row_table, true, NULLPTR); + ASSERT_EQ(num_rows_no_match, 0); + } + + { + // With selection, output no match row ids. + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows); + std::vector selection_left(num_rows); + std::iota(selection_left.begin(), selection_left.end(), 0); + KeyCompare::CompareColumnsToRows(num_rows, selection_left.data(), row_ids_left.data(), + &ctx, &num_rows_no_match, row_ids_out.data(), column_arrays, row_table, true, NULLPTR); ASSERT_EQ(num_rows_no_match, 0); } { + // No selection, output match bit vector. + std::vector match_bitvector(BytesForBits(num_rows)); + KeyCompare::CompareColumnsToRows( + num_rows, /*sel_left_maybe_null=*/NULLPTR, row_ids_left.data(), &ctx, NULLPTR, + NULLPTR, column_arrays, row_table, true, match_bitvector.data()); + + ASSERT_EQ(arrow::internal::CountSetBits(match_bitvector.data(), 0, num_rows), + num_rows); + } + + { + // With selection, output match bit vector. std::vector match_bitvector(BytesForBits(num_rows)); + std::vector selection_left(num_rows); + std::iota(selection_left.begin(), selection_left.end(), 0); KeyCompare::CompareColumnsToRows(num_rows, NULLPTR, row_ids_left.data(), &ctx, NULLPTR, NULLPTR, column_arrays, row_table, true, match_bitvector.data());