Skip to content

Commit

Permalink
Fix conditional compilation logic for runtime-selected AVX2 functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jul 13, 2023
1 parent 5aef169 commit 241a5b4
Show file tree
Hide file tree
Showing 22 changed files with 52 additions and 78 deletions.
22 changes: 11 additions & 11 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,15 @@ function(ADD_ARROW_BENCHMARK REL_TEST_NAME)
${ARG_UNPARSED_ARGUMENTS})
endfunction()

macro(append_avx2_src SRC)
macro(append_runtime_avx2_src SRC)
if(ARROW_HAVE_RUNTIME_AVX2)
list(APPEND ARROW_SRCS ${SRC})
set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON)
set_source_files_properties(${SRC} PROPERTIES COMPILE_FLAGS ${ARROW_AVX2_FLAG})
endif()
endmacro()

macro(append_avx512_src SRC)
macro(append_runtime_avx512_src SRC)
if(ARROW_HAVE_RUNTIME_AVX512)
list(APPEND ARROW_SRCS ${SRC})
set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON)
Expand Down Expand Up @@ -254,8 +254,8 @@ if(ARROW_JEMALLOC)
PROPERTIES SKIP_UNITY_BUILD_INCLUSION ON)
endif()

append_avx2_src(util/bpacking_avx2.cc)
append_avx512_src(util/bpacking_avx512.cc)
append_runtime_avx2_src(util/bpacking_avx2.cc)
append_runtime_avx512_src(util/bpacking_avx512.cc)

if(ARROW_HAVE_NEON)
list(APPEND ARROW_SRCS util/bpacking_neon.cc)
Expand Down Expand Up @@ -425,11 +425,11 @@ list(APPEND
compute/row/row_internal.cc
compute/util.cc)

append_avx2_src(compute/key_hash_avx2.cc)
append_avx2_src(compute/key_map_avx2.cc)
append_avx2_src(compute/row/compare_internal_avx2.cc)
append_avx2_src(compute/row/encode_internal_avx2.cc)
append_avx2_src(compute/util_avx2.cc)
append_runtime_avx2_src(compute/key_hash_avx2.cc)
append_runtime_avx2_src(compute/key_map_avx2.cc)
append_runtime_avx2_src(compute/row/compare_internal_avx2.cc)
append_runtime_avx2_src(compute/row/encode_internal_avx2.cc)
append_runtime_avx2_src(compute/util_avx2.cc)

if(ARROW_COMPUTE)
# Include the remaining kernels
Expand Down Expand Up @@ -464,8 +464,8 @@ if(ARROW_COMPUTE)
compute/kernels/vector_select_k.cc
compute/kernels/vector_sort.cc)

append_avx2_src(compute/kernels/aggregate_basic_avx2.cc)
append_avx512_src(compute/kernels/aggregate_basic_avx512.cc)
append_runtime_avx2_src(compute/kernels/aggregate_basic_avx2.cc)
append_runtime_avx512_src(compute/kernels/aggregate_basic_avx512.cc)
endif()

if(ARROW_FILESYSTEM)
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/acero/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ add_custom_target(arrow_acero)

arrow_install_all_headers("arrow/acero")

macro(append_acero_avx2_src SRC)
macro(append_acero_runtime_avx2_src SRC)
if(ARROW_HAVE_RUNTIME_AVX2)
list(APPEND ARROW_ACERO_SRCS ${SRC})
set_source_files_properties(${SRC} PROPERTIES SKIP_PRECOMPILE_HEADERS ON)
Expand Down Expand Up @@ -56,8 +56,8 @@ set(ARROW_ACERO_SRCS
union_node.cc
util.cc)

append_acero_avx2_src(bloom_filter_avx2.cc)
append_acero_avx2_src(swiss_join_avx2.cc)
append_acero_runtime_avx2_src(bloom_filter_avx2.cc)
append_acero_runtime_avx2_src(swiss_join_avx2.cc)

set(ARROW_ACERO_SHARED_LINK_LIBS)
set(ARROW_ACERO_STATIC_LINK_LIBS)
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/acero/bloom_filter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ void BlockedBloomFilter::InsertImp(int64_t num_rows, const T* hashes) {
void BlockedBloomFilter::Insert(int64_t hardware_flags, int64_t num_rows,
const uint32_t* hashes) {
int64_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = Insert_avx2(num_rows, hashes);
}
Expand All @@ -134,7 +134,7 @@ void BlockedBloomFilter::Insert(int64_t hardware_flags, int64_t num_rows,
void BlockedBloomFilter::Insert(int64_t hardware_flags, int64_t num_rows,
const uint64_t* hashes) {
int64_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = Insert_avx2(num_rows, hashes);
}
Expand Down Expand Up @@ -181,7 +181,7 @@ void BlockedBloomFilter::Find(int64_t hardware_flags, int64_t num_rows,
bool enable_prefetch) const {
int64_t num_processed = 0;

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (!(enable_prefetch && UsePrefetch()) &&
(hardware_flags & arrow::internal::CpuInfo::AVX2)) {
num_processed = Find_avx2(num_rows, hashes, result_bit_vector);
Expand All @@ -202,7 +202,7 @@ void BlockedBloomFilter::Find(int64_t hardware_flags, int64_t num_rows,
bool enable_prefetch) const {
int64_t num_processed = 0;

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (!(enable_prefetch && UsePrefetch()) &&
(hardware_flags & arrow::internal::CpuInfo::AVX2)) {
num_processed = Find_avx2(num_rows, hashes, result_bit_vector);
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/acero/bloom_filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

#pragma once

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
#include <immintrin.h>
#endif

#include <atomic>
#include <cstdint>
#include <memory>

#include "arrow/acero/partition_util.h"
#include "arrow/acero/util.h"
#include "arrow/memory_pool.h"
Expand Down Expand Up @@ -203,7 +204,7 @@ class ARROW_ACERO_EXPORT BlockedBloomFilter {

void SingleFold(int num_folds);

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
inline __m256i mask_avx2(__m256i hash) const;
inline __m256i block_id_avx2(__m256i hash) const;
int64_t Insert_avx2(int64_t num_rows, const uint32_t* hashes);
Expand Down
5 changes: 1 addition & 4 deletions cpp/src/arrow/acero/bloom_filter_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,13 @@
// under the License.

#include <immintrin.h>

#include "arrow/acero/bloom_filter.h"
#include "arrow/util/bit_util.h"

namespace arrow {
namespace acero {

#if defined(ARROW_HAVE_AVX2)

inline __m256i BlockedBloomFilter::mask_avx2(__m256i hash) const {
// AVX2 translation of mask() method
//
Expand Down Expand Up @@ -132,7 +131,5 @@ int64_t BlockedBloomFilter::Insert_avx2(int64_t num_rows, const uint64_t* hashes
return InsertImp_avx2(num_rows, hashes);
}

#endif

} // namespace acero
} // namespace arrow
4 changes: 0 additions & 4 deletions cpp/src/arrow/acero/swiss_join_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
namespace arrow {
namespace acero {

#if defined(ARROW_HAVE_AVX2)

template <class PROCESS_8_VALUES_FN>
int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int num_rows,
const uint32_t* row_ids,
Expand Down Expand Up @@ -191,7 +189,5 @@ int RowArrayAccessor::VisitNulls_avx2(const RowTableImpl& rows, int column_id,
return num_rows - (num_rows % unroll);
}

#endif

} // namespace acero
} // namespace arrow
2 changes: 1 addition & 1 deletion cpp/src/arrow/acero/swiss_join_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class RowArrayAccessor {
const uint32_t* row_ids, PROCESS_VALUE_FN process_value_fn);

private:
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
// This is equivalent to Visit method, but processing 8 rows at a time in a
// loop.
// Returns the number of processed rows, which may be less than requested (up
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/key_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ void Hashing32::HashVarLen(int64_t hardware_flags, bool combine_hashes, uint32_t
const uint32_t* offsets, const uint8_t* concatenated_keys,
uint32_t* hashes, uint32_t* hashes_temp_for_combine) {
uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = HashVarLen_avx2(combine_hashes, num_rows, offsets, concatenated_keys,
hashes, hashes_temp_for_combine);
Expand All @@ -255,7 +255,7 @@ void Hashing32::HashVarLen(int64_t hardware_flags, bool combine_hashes, uint32_t
const uint64_t* offsets, const uint8_t* concatenated_keys,
uint32_t* hashes, uint32_t* hashes_temp_for_combine) {
uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = HashVarLen_avx2(combine_hashes, num_rows, offsets, concatenated_keys,
hashes, hashes_temp_for_combine);
Expand Down Expand Up @@ -361,7 +361,7 @@ void Hashing32::HashFixed(int64_t hardware_flags, bool combine_hashes, uint32_t
}

uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags & arrow::internal::CpuInfo::AVX2) {
num_processed = HashFixedLen_avx2(combine_hashes, num_rows, length, keys, hashes,
hashes_temp_for_combine);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/key_hash.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#pragma once

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
#include <immintrin.h>
#endif

Expand Down Expand Up @@ -115,7 +115,7 @@ class ARROW_EXPORT Hashing32 {
static void HashInt(bool combine_hashes, uint32_t num_keys, uint64_t length_key,
const uint8_t* keys, uint32_t* hashes);

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
static inline __m256i Avalanche_avx2(__m256i hash);
static inline __m256i CombineHashesImp_avx2(__m256i previous_hash, __m256i hash);
template <bool T_COMBINE_HASHES>
Expand Down
4 changes: 0 additions & 4 deletions cpp/src/arrow/compute/key_hash_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
namespace arrow {
namespace compute {

#if defined(ARROW_HAVE_AVX2)

inline __m256i Hashing32::Avalanche_avx2(__m256i hash) {
hash = _mm256_xor_si256(hash, _mm256_srli_epi32(hash, 15));
hash = _mm256_mullo_epi32(hash, _mm256_set1_epi32(PRIME32_2));
Expand Down Expand Up @@ -315,7 +313,5 @@ uint32_t Hashing32::HashVarLen_avx2(bool combine_hashes, uint32_t num_rows,
}
}

#endif

} // namespace compute
} // namespace arrow
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/key_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void SwissTable::extract_group_ids(const int num_keys, const uint16_t* optional_

// Optimistically use simplified lookup involving only a start block to find
// a single group id candidate for every input.
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
int num_group_id_bytes = num_group_id_bits / 8;
if ((hardware_flags_ & arrow::internal::CpuInfo::AVX2) && !optional_selection) {
num_processed = extract_group_ids_avx2(num_keys, hashes, local_slots, out_group_ids,
Expand Down Expand Up @@ -301,7 +301,7 @@ void SwissTable::early_filter(const int num_keys, const uint32_t* hashes,
// Optimistically use simplified lookup involving only a start block to find
// a single group id candidate for every input.
int num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (hardware_flags_ & arrow::internal::CpuInfo::AVX2) {
if (log_blocks_ <= 4) {
num_processed = early_filter_imp_avx2_x32(num_keys, hashes, out_match_bitvector,
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/key_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class ARROW_EXPORT SwissTable {
//
void early_filter_imp(const int num_keys, const uint32_t* hashes,
uint8_t* out_match_bitvector, uint8_t* out_local_slots) const;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
int early_filter_imp_avx2_x8(const int num_hashes, const uint32_t* hashes,
uint8_t* out_match_bitvector,
uint8_t* out_local_slots) const;
Expand Down
4 changes: 0 additions & 4 deletions cpp/src/arrow/compute/key_map_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
namespace arrow {
namespace compute {

#if defined(ARROW_HAVE_AVX2)

// This is more or less translation of equivalent scalar code, adjusted for a
// different instruction set (e.g. missing leading zero count instruction).
//
Expand Down Expand Up @@ -412,7 +410,5 @@ int SwissTable::extract_group_ids_avx2(const int num_keys, const uint32_t* hashe
return num_keys - (num_keys % unroll);
}

#endif

} // namespace compute
} // namespace arrow
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/row/compare_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void KeyCompare::NullUpdateColumnToRow(uint32_t id_col, uint32_t num_rows_to_com
return;
}
uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (ctx->has_avx2()) {
num_processed = NullUpdateColumnToRow_avx2(use_selection, id_col, num_rows_to_compare,
sel_left_maybe_null, left_to_right_map,
Expand Down Expand Up @@ -130,7 +130,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row,
const RowTableImpl& rows,
uint8_t* match_bytevector) {
uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (ctx->has_avx2()) {
num_processed = CompareBinaryColumnToRow_avx2(
use_selection, offset_within_row, num_rows_to_compare, sel_left_maybe_null,
Expand Down Expand Up @@ -297,7 +297,7 @@ void KeyCompare::CompareVarBinaryColumnToRow(uint32_t id_varbinary_col,
const RowTableImpl& rows,
uint8_t* match_bytevector) {
uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (ctx->has_avx2()) {
num_processed = CompareVarBinaryColumnToRow_avx2(
use_selection, is_first_varbinary_col, id_varbinary_col, num_rows_to_compare,
Expand All @@ -313,7 +313,7 @@ void KeyCompare::CompareVarBinaryColumnToRow(uint32_t id_varbinary_col,
void KeyCompare::AndByteVectors(LightContext* ctx, uint32_t num_elements,
uint8_t* bytevector_A, const uint8_t* bytevector_B) {
uint32_t num_processed = 0;
#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (ctx->has_avx2()) {
num_processed = AndByteVectors_avx2(num_elements, bytevector_A, bytevector_B);
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/row/compare_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class ARROW_EXPORT KeyCompare {
static void AndByteVectors(LightContext* ctx, uint32_t num_elements,
uint8_t* bytevector_A, const uint8_t* bytevector_B);

#if defined(ARROW_HAVE_AVX2)
#if defined(ARROW_HAVE_RUNTIME_AVX2)

template <bool use_selection>
static uint32_t NullUpdateColumnToRowImp_avx2(
Expand Down
4 changes: 0 additions & 4 deletions cpp/src/arrow/compute/row/compare_internal_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
namespace arrow {
namespace compute {

#if defined(ARROW_HAVE_AVX2)

inline __m256i set_first_n_bytes_avx2(int n) {
constexpr uint64_t kByteSequence0To7 = 0x0706050403020100ULL;
constexpr uint64_t kByteSequence8To15 = 0x0f0e0d0c0b0a0908ULL;
Expand Down Expand Up @@ -670,7 +668,5 @@ uint32_t KeyCompare::CompareVarBinaryColumnToRow_avx2(
return num_rows_to_compare;
}

#endif

} // namespace compute
} // namespace arrow
Loading

0 comments on commit 241a5b4

Please sign in to comment.