Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Feb 13, 2024
1 parent bbe59b3 commit 75e906f
Show file tree
Hide file tree
Showing 5 changed files with 462 additions and 231 deletions.
74 changes: 60 additions & 14 deletions cpp/src/arrow/util/byte_stream_split_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

#include "arrow/util/endian.h"
#include "arrow/util/simd.h"
#include "arrow/util/small_vector.h"
#include "arrow/util/ubsan.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>

Expand All @@ -35,6 +37,9 @@ namespace arrow::util::internal {
// SIMD implementations
//

// TODO have all decode and encode routines take an explicit width? This would simplify
// testing and benchmarking quite a bit...

#if defined(ARROW_HAVE_SSE4_2)
template <int kNumStreams>
void ByteStreamSplitDecodeSse2(const uint8_t* data, int64_t num_values, int64_t stride,
Expand Down Expand Up @@ -672,14 +677,24 @@ inline void DoMergeStreams(const uint8_t** src_streams, int width, int64_t nvalu

template <int kNumStreams>
void ByteStreamSplitEncodeScalar(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
uint8_t* out) {
std::array<uint8_t*, kNumStreams> dest_streams;
for (int stream = 0; stream < kNumStreams; ++stream) {
dest_streams[stream] = &output_buffer_raw[stream * num_values];
dest_streams[stream] = &out[stream * num_values];
}
DoSplitStreams(raw_values, kNumStreams, num_values, dest_streams.data());
}

inline void ByteStreamSplitEncodeScalarDynamic(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* out) {
::arrow::internal::SmallVector<uint8_t*, 16> dest_streams;
dest_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
dest_streams[stream] = &out[stream * num_values];
}
DoSplitStreams(raw_values, width, num_values, dest_streams.data());
}

template <int kNumStreams>
void ByteStreamSplitDecodeScalar(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
Expand All @@ -690,26 +705,57 @@ void ByteStreamSplitDecodeScalar(const uint8_t* data, int64_t num_values, int64_
DoMergeStreams(src_streams.data(), kNumStreams, num_values, out);
}

template <int kNumStreams>
void inline ByteStreamSplitEncode(const uint8_t* raw_values, const int64_t num_values,
uint8_t* output_buffer_raw) {
inline void ByteStreamSplitDecodeScalarDynamic(const uint8_t* data, int width,
int64_t num_values, int64_t stride,
uint8_t* out) {
::arrow::internal::SmallVector<const uint8_t*, 16> src_streams;
src_streams.resize(width);
for (int stream = 0; stream < width; ++stream) {
src_streams[stream] = &data[stream * stride];
}
DoMergeStreams(src_streams.data(), width, num_values, out);
}

inline void ByteStreamSplitEncode(const uint8_t* raw_values, int width,
const int64_t num_values, uint8_t* out) {
#if defined(ARROW_HAVE_SIMD_SPLIT)
return ByteStreamSplitEncodeSimd<kNumStreams>(raw_values, num_values,
output_buffer_raw);
#define ByteStreamSplitEncodePerhapsSimd ByteStreamSplitEncodeSimd
#else
return ByteStreamSplitEncodeScalar<kNumStreams>(raw_values, num_values,
output_buffer_raw);
#define ByteStreamSplitEncodePerhapsSimd ByteStreamSplitEncodeScalar
#endif
switch (width) {
case 2:
return ByteStreamSplitEncodeScalar<2>(raw_values, num_values, out);
case 4:
return ByteStreamSplitEncodePerhapsSimd<4>(raw_values, num_values, out);
case 8:
return ByteStreamSplitEncodePerhapsSimd<8>(raw_values, num_values, out);
case 16:
return ByteStreamSplitEncodeScalar<16>(raw_values, num_values, out);
}
return ByteStreamSplitEncodeScalarDynamic(raw_values, width, num_values, out);
#undef ByteStreamSplitEncodePerhapsSimd
}

template <int kNumStreams>
void inline ByteStreamSplitDecode(const uint8_t* data, int64_t num_values, int64_t stride,
uint8_t* out) {
inline void ByteStreamSplitDecode(const uint8_t* data, int width, int64_t num_values,
int64_t stride, uint8_t* out) {
#if defined(ARROW_HAVE_SIMD_SPLIT)
return ByteStreamSplitDecodeSimd<kNumStreams>(data, num_values, stride, out);
#define ByteStreamSplitDecodePerhapsSimd ByteStreamSplitDecodeSimd
#else
return ByteStreamSplitDecodeScalar<kNumStreams>(data, num_values, stride, out);
#define ByteStreamSplitDecodePerhapsSimd ByteStreamSplitDecodeScalar
#endif
switch (width) {
case 2:
return ByteStreamSplitDecodeScalar<2>(data, num_values, stride, out);
case 4:
return ByteStreamSplitDecodePerhapsSimd<4>(data, num_values, stride, out);
case 8:
return ByteStreamSplitDecodePerhapsSimd<8>(data, num_values, stride, out);
case 16:
return ByteStreamSplitDecodeScalar<16>(data, num_values, stride, out);
}
return ByteStreamSplitDecodeScalarDynamic(data, width, num_values, stride, out);
#undef ByteStreamSplitDecodePerhapsSimd
}

} // namespace arrow::util::internal
100 changes: 76 additions & 24 deletions cpp/src/arrow/util/byte_stream_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,29 +63,12 @@ class TestByteStreamSplitSpecialized : public ::testing::Test {
public:
static constexpr int kWidth = static_cast<int>(sizeof(T));

using EncodeFunc = NamedFunc<std::function<decltype(ByteStreamSplitEncode<kWidth>)>>;
using DecodeFunc = NamedFunc<std::function<decltype(ByteStreamSplitDecode<kWidth>)>>;
using EncodeFunc = NamedFunc<std::function<decltype(ByteStreamSplitEncode)>>;
using DecodeFunc = NamedFunc<std::function<decltype(ByteStreamSplitDecode)>>;

void SetUp() override {
encode_funcs_.push_back({"reference", &ReferenceEncode});
encode_funcs_.push_back({"scalar", &ByteStreamSplitEncodeScalar<kWidth>});
decode_funcs_.push_back({"scalar", &ByteStreamSplitDecodeScalar<kWidth>});
#if defined(ARROW_HAVE_SIMD_SPLIT)
encode_funcs_.push_back({"simd", &ByteStreamSplitEncodeSimd<kWidth>});
decode_funcs_.push_back({"simd", &ByteStreamSplitDecodeSimd<kWidth>});
#endif
#if defined(ARROW_HAVE_SSE4_2)
encode_funcs_.push_back({"sse2", &ByteStreamSplitEncodeSse2<kWidth>});
decode_funcs_.push_back({"sse2", &ByteStreamSplitDecodeSse2<kWidth>});
#endif
#if defined(ARROW_HAVE_AVX2)
encode_funcs_.push_back({"avx2", &ByteStreamSplitEncodeAvx2<kWidth>});
decode_funcs_.push_back({"avx2", &ByteStreamSplitDecodeAvx2<kWidth>});
#endif
#if defined(ARROW_HAVE_AVX512)
encode_funcs_.push_back({"avx512", &ByteStreamSplitEncodeAvx512<kWidth>});
decode_funcs_.push_back({"avx512", &ByteStreamSplitDecodeAvx512<kWidth>});
#endif
decode_funcs_ = MakeDecodeFuncs();
encode_funcs_ = MakeEncodeFuncs();
}

void TestRoundtrip(int64_t num_values) {
Expand All @@ -98,12 +81,12 @@ class TestByteStreamSplitSpecialized : public ::testing::Test {
for (const auto& encode_func : encode_funcs_) {
ARROW_SCOPED_TRACE("encode_func = ", encode_func);
encoded.assign(encoded.size(), 0);
encode_func.func(reinterpret_cast<const uint8_t*>(input.data()), num_values,
encode_func.func(reinterpret_cast<const uint8_t*>(input.data()), kWidth, num_values,
encoded.data());
for (const auto& decode_func : decode_funcs_) {
ARROW_SCOPED_TRACE("decode_func = ", decode_func);
decoded.assign(decoded.size(), T{});
decode_func.func(encoded.data(), num_values, /*stride=*/num_values,
decode_func.func(encoded.data(), kWidth, num_values, /*stride=*/num_values,
reinterpret_cast<uint8_t*>(decoded.data()));
ASSERT_EQ(decoded, input);
}
Expand All @@ -129,7 +112,8 @@ class TestByteStreamSplitSpecialized : public ::testing::Test {
int64_t offset = 0;
while (offset < num_values) {
auto chunk_size = std::min<int64_t>(num_values - offset, chunk_size_dist(gen));
decode_func.func(encoded.data() + offset, chunk_size, /*stride=*/num_values,
decode_func.func(encoded.data() + offset, kWidth, chunk_size,
/*stride=*/num_values,
reinterpret_cast<uint8_t*>(decoded.data() + offset));
offset += chunk_size;
}
Expand All @@ -156,6 +140,74 @@ class TestByteStreamSplitSpecialized : public ::testing::Test {
return input;
}

template <bool kSimdImplemented = (kWidth == 4 || kWidth == 8)>
static std::vector<DecodeFunc> MakeDecodeFuncs() {
std::vector<DecodeFunc> funcs;
funcs.push_back({"scalar", &ByteStreamSplitDecodeScalarDynamic});
funcs.push_back(
{"scalar", DynamicWidthDecodeFromStatic(&ByteStreamSplitDecodeScalar<kWidth>)});
#if defined(ARROW_HAVE_SIMD_SPLIT)
if constexpr (kSimdImplemented) {
funcs.push_back(
{"simd", DynamicWidthDecodeFromStatic(&ByteStreamSplitDecodeSimd<kWidth>)});
#if defined(ARROW_HAVE_SSE4_2)
funcs.push_back(
{"sse2", DynamicWidthDecodeFromStatic(&ByteStreamSplitDecodeSse2<kWidth>)});
#endif
#if defined(ARROW_HAVE_AVX2)
funcs.push_back(
{"avx2", DynamicWidthDecodeFromStatic(&ByteStreamSplitDecodeAvx2<kWidth>)});
#endif
#if defined(ARROW_HAVE_AVX512)
funcs.push_back(
{"avx512", DynamicWidthDecodeFromStatic(&ByteStreamSplitDecodeAvx512<kWidth>)});
#endif
}
#endif // defined(ARROW_HAVE_SIMD_SPLIT)
return funcs;
}

template <bool kSimdImplemented = (kWidth == 4 || kWidth == 8)>
static std::vector<EncodeFunc> MakeEncodeFuncs() {
std::vector<EncodeFunc> funcs;
funcs.push_back({"reference", &ReferenceByteStreamSplitEncode});
funcs.push_back({"reference", &ByteStreamSplitEncodeScalarDynamic});
funcs.push_back(
{"scalar", DynamicWidthEncodeFromStatic(&ByteStreamSplitEncodeScalar<kWidth>)});
#if defined(ARROW_HAVE_SIMD_SPLIT)
if constexpr (kSimdImplemented) {
funcs.push_back(
{"simd", DynamicWidthEncodeFromStatic(&ByteStreamSplitEncodeSimd<kWidth>)});
#if defined(ARROW_HAVE_SSE4_2)
funcs.push_back(
{"sse2", DynamicWidthEncodeFromStatic(&ByteStreamSplitEncodeSse2<kWidth>)});
#endif
#if defined(ARROW_HAVE_AVX2)
funcs.push_back(
{"avx2", DynamicWidthEncodeFromStatic(&ByteStreamSplitEncodeAvx2<kWidth>)});
#endif
#if defined(ARROW_HAVE_AVX512)
funcs.push_back(
{"avx512", DynamicWidthEncodeFromStatic(&ByteStreamSplitEncodeAvx512<kWidth>)});
#endif
}
#endif // defined(ARROW_HAVE_SIMD_SPLIT)
return funcs;
}

static std::function<decltype(ByteStreamSplitDecode)> DynamicWidthDecodeFromStatic(
std::function<decltype(ByteStreamSplitDecodeScalar<1>)> wrapped) {
return [wrapped](const uint8_t* data, int width, int64_t num_values, int64_t stride,
uint8_t* out) { wrapped(data, num_values, stride, out); };
}

static std::function<decltype(ByteStreamSplitEncode)> DynamicWidthEncodeFromStatic(
std::function<decltype(ByteStreamSplitEncodeScalar<1>)> wrapped) {
return [wrapped](const uint8_t* data, int width, int64_t num_values, uint8_t* out) {
wrapped(data, num_values, out);
};
}

std::vector<EncodeFunc> encode_funcs_;
std::vector<DecodeFunc> decode_funcs_;

Expand Down
Loading

0 comments on commit 75e906f

Please sign in to comment.