Skip to content

Commit

Permalink
AVX2 vectorization for very large bitsets (#4422)
Browse files Browse the repository at this point in the history
Co-authored-by: Stephan T. Lavavej <stl@nuwen.net>
  • Loading branch information
AlexGuteniev and StephanTLavavej authored Feb 29, 2024
1 parent a6c2a72 commit 8b081e2
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 24 deletions.
2 changes: 2 additions & 0 deletions benchmarks/src/bitset_to_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ namespace {

BENCHMARK(BM_bitset_to_string<15, char>);
BENCHMARK(BM_bitset_to_string<64, char>);
BENCHMARK(BM_bitset_to_string<512, char>);
BENCHMARK(BM_bitset_to_string_large_single<char>);
BENCHMARK(BM_bitset_to_string<7, wchar_t>);
BENCHMARK(BM_bitset_to_string<64, wchar_t>);
BENCHMARK(BM_bitset_to_string<512, wchar_t>);
BENCHMARK(BM_bitset_to_string_large_single<wchar_t>);

BENCHMARK_MAIN();
92 changes: 92 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2169,6 +2169,17 @@ __declspec(noalias) size_t

#ifndef _M_ARM64EC
namespace {
__m256i __forceinline _Bitset_to_string_1_step_avx(const uint32_t _Val, const __m256i _Px0, const __m256i _Px1) {
const __m128i _Vx0 = _mm_cvtsi32_si128(_Val);
const __m128i _Vx1 = _mm_shuffle_epi8(_Vx0, _mm_set_epi32(0x00000000, 0x01010101, 0x02020202, 0x03030303));
const __m256i _Vx2 = _mm256_castsi128_si256(_Vx1);
const __m256i _Vx3 = _mm256_permutevar8x32_epi32(_Vx2, _mm256_set_epi32(3, 3, 2, 2, 1, 1, 0, 0));
const __m256i _Msk = _mm256_and_si256(_Vx3, _mm256_set1_epi64x(0x0102040810204080));
const __m256i _Ex0 = _mm256_cmpeq_epi8(_Msk, _mm256_setzero_si256());
const __m256i _Ex1 = _mm256_blendv_epi8(_Px1, _Px0, _Ex0);
return _Ex1;
}

__m128i __forceinline _Bitset_to_string_1_step(const uint16_t _Val, const __m128i _Px0, const __m128i _Px1) {
const __m128i _Vx0 = _mm_cvtsi32_si128(_Val);
const __m128i _Vx1 = _mm_unpacklo_epi8(_Vx0, _Vx0);
Expand All @@ -2180,6 +2191,18 @@ namespace {
return _Ex1;
}

__m256i __forceinline _Bitset_to_string_2_step_avx(const uint16_t _Val, const __m256i _Px0, const __m256i _Px1) {
const __m128i _Vx0 = _mm_cvtsi32_si128(_Val);
const __m128i _Vx1 = _mm_shuffle_epi8(_Vx0, _mm_set_epi32(0x00000000, 0x00000000, 0x01010101, 0x01010101));
const __m256i _Vx2 = _mm256_castsi128_si256(_Vx1);
const __m256i _Vx3 = _mm256_permute4x64_epi64(_Vx2, _MM_SHUFFLE(1, 1, 0, 0));
const __m256i _Msk = _mm256_and_si256(
_Vx3, _mm256_set_epi64x(0x0001000200040008, 0x0010002000400080, 0x0001000200040008, 0x0010002000400080));
const __m256i _Ex0 = _mm256_cmpeq_epi16(_Msk, _mm256_setzero_si256());
const __m256i _Ex1 = _mm256_blendv_epi8(_Px1, _Px0, _Ex0);
return _Ex1;
}

__m128i __forceinline _Bitset_to_string_2_step(const uint8_t _Val, const __m128i _Px0, const __m128i _Px1) {
const __m128i _Vx = _mm_set1_epi16(_Val);
const __m128i _Msk = _mm_and_si128(_Vx, _mm_set_epi64x(0x0001000200040008, 0x0010002000400080));
Expand All @@ -2195,6 +2218,38 @@ extern "C" {
__declspec(noalias) void __stdcall __std_bitset_to_string_1(
char* const _Dest, const void* _Src, size_t _Size_bits, const char _Elem0, const char _Elem1) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2() && _Size_bits >= 256) {
const __m256i _Px0 = _mm256_broadcastb_epi8(_mm_cvtsi32_si128(_Elem0));
const __m256i _Px1 = _mm256_broadcastb_epi8(_mm_cvtsi32_si128(_Elem1));
if (_Size_bits >= 32) {
char* _Pos = _Dest + _Size_bits;
_Size_bits &= 0x1F;
char* const _Stop_at = _Dest + _Size_bits;
do {
uint32_t _Val;
memcpy(&_Val, _Src, 4);
const __m256i _Elems = _Bitset_to_string_1_step_avx(_Val, _Px0, _Px1);
_Pos -= 32;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Pos), _Elems);
_Advance_bytes(_Src, 4);
} while (_Pos != _Stop_at);
}

if (_Size_bits > 0) {
__assume(_Size_bits < 32);
uint32_t _Val = 0;
memcpy(&_Val, _Src, (_Size_bits + 7) / 8);
const __m256i _Elems = _Bitset_to_string_1_step_avx(_Val, _Px0, _Px1);
char _Tmp[32];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Tmp), _Elems);
const char* const _Tmpd = _Tmp + (32 - _Size_bits);
memcpy(_Dest, _Tmpd, _Size_bits);
}

_mm256_zeroupper(); // TRANSITION, DevCom-10331414
return;
}

if (_Use_sse2()) {
const __m128i _Px0 = _mm_set1_epi8(_Elem0 ^ _Elem1);
const __m128i _Px1 = _mm_set1_epi8(_Elem1);
Expand Down Expand Up @@ -2241,6 +2296,43 @@ __declspec(noalias) void __stdcall __std_bitset_to_string_1(
__declspec(noalias) void __stdcall __std_bitset_to_string_2(
wchar_t* const _Dest, const void* _Src, size_t _Size_bits, const wchar_t _Elem0, const wchar_t _Elem1) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2() && _Size_bits >= 256) {
const __m256i _Px0 = _mm256_broadcastw_epi16(_mm_cvtsi32_si128(_Elem0));
const __m256i _Px1 = _mm256_broadcastw_epi16(_mm_cvtsi32_si128(_Elem1));

if (_Size_bits >= 16) {
wchar_t* _Pos = _Dest + _Size_bits;
_Size_bits &= 0xF;
wchar_t* const _Stop_at = _Dest + _Size_bits;
do {
uint16_t _Val;
memcpy(&_Val, _Src, 2);
const __m256i _Elems = _Bitset_to_string_2_step_avx(_Val, _Px0, _Px1);
_Pos -= 16;
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Pos), _Elems);
_Advance_bytes(_Src, 2);
} while (_Pos != _Stop_at);
}

if (_Size_bits > 0) {
__assume(_Size_bits < 16);
uint16_t _Val;
if (_Size_bits > 8) {
memcpy(&_Val, _Src, 2);
} else {
_Val = *reinterpret_cast<const uint8_t*>(_Src);
}
const __m256i _Elems = _Bitset_to_string_2_step_avx(_Val, _Px0, _Px1);
wchar_t _Tmp[16];
_mm256_storeu_si256(reinterpret_cast<__m256i*>(_Tmp), _Elems);
const wchar_t* const _Tmpd = _Tmp + (16 - _Size_bits);
memcpy(_Dest, _Tmpd, _Size_bits * 2);
}

_mm256_zeroupper(); // TRANSITION, DevCom-10331414
return;
}

if (_Use_sse2()) {
const __m128i _Px0 = _mm_set1_epi16(_Elem0 ^ _Elem1);
const __m128i _Px1 = _mm_set1_epi16(_Elem1);
Expand Down
63 changes: 39 additions & 24 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <random>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#if _HAS_CXX20
Expand Down Expand Up @@ -474,6 +475,43 @@ void test_one_container() {
test_two_containers<Container, list<int>>();
}

template <size_t N>
bool test_randomized_bitset(mt19937_64& gen) {
string str;
wstring wstr;
str.reserve(N);
wstr.reserve(N);

while (str.size() != N) {
uint64_t random_value = gen();

for (int bits = 0; bits < 64 && str.size() != N; ++bits) {
const auto character = '0' + (random_value & 1);
str.push_back(static_cast<char>(character));
wstr.push_back(static_cast<wchar_t>(character));
random_value >>= 1;
}
}

const bitset<N> b(str);

assert(b.to_string() == str);
assert(b.template to_string<wchar_t>() == wstr);

return true;
}

template <size_t Base, size_t... Vals>
void test_randomized_bitset_base(index_sequence<Vals...>, mt19937_64& gen) {
bool ignored[] = {test_randomized_bitset<Base + Vals>(gen)...};
(void) ignored;
}

template <size_t Base, size_t Count>
void test_randomized_bitset_base_count(mt19937_64& gen) {
test_randomized_bitset_base<Base>(make_index_sequence<Count>{}, gen);
}

void test_bitset(mt19937_64& gen) {
assert(bitset<0>(0x0ULL).to_string() == "");
assert(bitset<0>(0xFEDCBA9876543210ULL).to_string() == "");
Expand Down Expand Up @@ -515,30 +553,7 @@ void test_bitset(mt19937_64& gen) {
assert(bitset<75>(0xFEDCBA9876543210ULL).to_string<char32_t>()
== U"000000000001111111011011100101110101001100001110110010101000011001000010000"); // not vectorized

{
constexpr size_t N = 2048;

string str;
wstring wstr;
str.reserve(N);
wstr.reserve(N);

while (str.size() != N) {
uint64_t random_value = gen();

for (int bits = 0; bits < 64; ++bits) {
const auto character = '0' + (random_value & 1);
str.push_back(static_cast<char>(character));
wstr.push_back(static_cast<wchar_t>(character));
random_value >>= 1;
}
}

const bitset<N> b(str);

assert(b.to_string() == str);
assert(b.to_string<wchar_t>() == wstr);
}
test_randomized_bitset_base_count<512 - 5, 32 + 10>(gen);
}

void test_various_containers() {
Expand Down

0 comments on commit 8b081e2

Please sign in to comment.