Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vectorize replace 🎭 #4554

Merged
merged 18 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions benchmarks/src/replace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ const char src[] =
"euismod eros, ut posuere ligula ullamcorper id. Nullam aliquam malesuada est at dignissim. Pellentesque finibus "
"sagittis libero nec bibendum. Phasellus dolor ipsum, finibus quis turpis quis, mollis interdum felis.";

// Benchmark: in-place std::replace over a vector<T> holding the characters of `src`
// (terminating null included), turning every 'm' into 'w'.
// The working copy is reset from the pristine copy at the start of each iteration so that
// every iteration sees the same input.
template <class T>
void r(benchmark::State& state) {
    const std::vector<T> pristine(std::begin(src), std::end(src));
    std::vector<T> working(std::size(src));

    const T needle{'m'};
    const T substitute{'w'};

    for (auto _ : state) {
        working = pristine;
        std::replace(working.begin(), working.end(), needle, substitute);
    }
}

template <class T>
void rc(benchmark::State& state) {
const std::vector<T> a(std::begin(src), std::end(src));
Expand All @@ -58,6 +69,9 @@ void rc_if(benchmark::State& state) {
}
}

// Plain std::replace (benchmark `r`), registered for 4-byte and 8-byte element types.
BENCHMARK(r<std::uint32_t>);
BENCHMARK(r<std::uint64_t>);

BENCHMARK(rc<std::uint8_t>);
BENCHMARK(rc<std::uint16_t>);
BENCHMARK(rc<std::uint32_t>);
Expand Down
68 changes: 68 additions & 0 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ __declspec(noalias) _Min_max_8i __stdcall __std_minmax_8i(const void* _First, co
__declspec(noalias) _Min_max_8u __stdcall __std_minmax_8u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_f __stdcall __std_minmax_f(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_d __stdcall __std_minmax_d(const void* _First, const void* _Last) noexcept;

// TRANSITION, DevCom-10610477
// In-place replace kernels for 4-byte and 8-byte elements: each element of [_First, _Last) that
// compares equal to _Old_val is overwritten with _New_val.
__declspec(noalias) void __stdcall __std_replace_4(
void* _First, void* _Last, uint32_t _Old_val, uint32_t _New_val) noexcept;
__declspec(noalias) void __stdcall __std_replace_8(
void* _First, void* _Last, uint64_t _Old_val, uint64_t _New_val) noexcept;
} // extern "C"

_STD_BEGIN
Expand Down Expand Up @@ -180,6 +186,24 @@ _Ty1* __std_find_first_of_trivial(
}
}

// Dispatches a replace over the contiguous range [_First, _Last) to the extern "C" kernel that matches
// the element width, converting both values to that kernel's unsigned integer parameter type.
template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
    _Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
    if constexpr (is_pointer_v<_Ty>) {
        // Pointer elements: values are converted with reinterpret_cast, kernel chosen by target bitness.
#ifdef _WIN64
        const auto _Old_bits = reinterpret_cast<uint64_t>(_Old_val);
        const auto _New_bits = reinterpret_cast<uint64_t>(_New_val);
        ::__std_replace_8(_First, _Last, _Old_bits, _New_bits);
#else // ^^^ defined(_WIN64) / !defined(_WIN64) vvv
        const auto _Old_bits = reinterpret_cast<uint32_t>(_Old_val);
        const auto _New_bits = reinterpret_cast<uint32_t>(_New_val);
        ::__std_replace_4(_First, _Last, _Old_bits, _New_bits);
#endif // ^^^ !defined(_WIN64) ^^^
    } else if constexpr (sizeof(_Ty) == 8) {
        const auto _Old_bits = static_cast<uint64_t>(_Old_val);
        const auto _New_bits = static_cast<uint64_t>(_New_val);
        ::__std_replace_8(_First, _Last, _Old_bits, _New_bits);
    } else if constexpr (sizeof(_Ty) == 4) {
        const auto _Old_bits = static_cast<uint32_t>(_Old_val);
        const auto _New_bits = static_cast<uint32_t>(_New_val);
        ::__std_replace_4(_First, _Last, _Old_bits, _New_bits);
    } else {
        // Only 4- and 8-byte elements have kernels.
        static_assert(_Always_false<_Ty>, "Unexpected size");
    }
}

// find_first_of vectorization is likely to be a win after this size (in elements)
_INLINE_VAR constexpr ptrdiff_t _Threshold_find_first_of = 16;

Expand All @@ -188,6 +212,17 @@ template <class _It1, class _It2, class _Pr>
_INLINE_VAR constexpr bool _Vector_alg_in_find_first_of_is_safe =
_Equal_memcmp_is_safe<_It1, _It2, _Pr> // can replace value comparison with bitwise comparison
&& sizeof(_Iter_value_t<_It1>) <= 2; // pcmpestri compatible size

// Can we activate the vector algorithms for replace?
// True when the find-style safety check passes for the iterator and the searched value type _Ty1,
// and the iterator's value type is at least 4 bytes wide.
template <class _Iter, class _Ty1>
constexpr bool _Vector_alg_in_replace_is_safe = _Vector_alg_in_find_is_safe<_Iter, _Ty1> // can search for the value
&& sizeof(_Iter_value_t<_Iter>) >= 4; // avx masked op compatible size

// Can we activate the vector algorithms for replace when the new value may have a different type than the old value?
template <class _Iter, class _Ty1, class _Ty2>
constexpr bool _Vector_alg_in_replace_with_maybe_other_type_is_safe =
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
_Vector_alg_in_replace_is_safe<_Iter, _Ty1> // can search and replace
&& _Vector_alg_in_find_is_safe_elem<_Ty2, _Iter_value_t<_Iter>>; // replacement fits
_STD_END
#endif // _USE_STD_VECTOR_ALGORITHMS

Expand Down Expand Up @@ -3807,6 +3842,22 @@ _CONSTEXPR20 void replace(const _FwdIt _First, const _FwdIt _Last, const _Ty& _O
_STD _Adl_verify_range(_First, _Last);
auto _UFirst = _STD _Get_unwrapped(_First);
const auto _ULast = _STD _Get_unwrapped(_Last);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Vector_alg_in_replace_is_safe<_FwdIt, _Ty>) {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
#if _HAS_CXX20
if (!_STD is_constant_evaluated())
#endif // _HAS_CXX20
{
if (_STD _Could_compare_equal_to_value_type<_FwdIt>(_Oldval)) {
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
_STD _Replace_vectorized(_STD _To_address(_UFirst), _STD _To_address(_ULast), _Oldval, _Newval);
}

return;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _UFirst != _ULast; ++_UFirst) {
if (*_UFirst == _Oldval) {
*_UFirst = _Newval;
Expand Down Expand Up @@ -3860,6 +3911,23 @@ namespace ranges {
_STL_INTERNAL_STATIC_ASSERT(indirectly_writable<_It, const _Ty2&>);
_STL_INTERNAL_STATIC_ASSERT(indirect_binary_predicate<ranges::equal_to, projected<_It, _Pj>, const _Ty1*>);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (is_same_v<_Pj, identity> && sized_sentinel_for<_Se, _It>
&& _Vector_alg_in_replace_with_maybe_other_type_is_safe<_It, _Ty1, _Ty2>) {
if (!_STD is_constant_evaluated()) {
const auto _Count = _Last - _First;

if (_STD _Could_compare_equal_to_value_type<_It>(_Oldval)) {
const auto _First_ptr = _STD to_address(_First);
const auto _Last_ptr = _First_ptr + _Count;
_STD _Replace_vectorized(_First_ptr, _Last_ptr, _Oldval, _Newval);
}

return _First + _Count;
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (; _First != _Last; ++_First) {
if (_STD invoke(_Proj, *_First) == _Oldval) {
*_First = _Newval;
Expand Down
40 changes: 22 additions & 18 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -5830,30 +5830,34 @@ struct _Vector_alg_in_find_is_safe_object_pointers<_Ty1*, _Ty2*>
// either _Ty1 is the same as _Ty2 (ignoring cv-qualifiers), or one of the two is void
disjunction<is_same<remove_cv_t<_Ty1>, remove_cv_t<_Ty2>>, is_void<_Ty1>, is_void<_Ty2>>> {};

// Can we activate the vector algorithms for a value of type _Ty and container elements of type _Elem?
template <class _Ty, class _Elem>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe_elem = disjunction_v<
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
#ifdef __cpp_lib_byte
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // defined(__cpp_lib_byte)
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;

// Can we activate the vector algorithms for find/count?
template <class _Iter, class _Ty, class _Elem = _Iter_value_t<_Iter>>
template <class _Iter, class _Ty>
_INLINE_VAR constexpr bool _Vector_alg_in_find_is_safe =
// The iterator must be contiguous so we can get raw pointers.
_Iterator_is_contiguous<_Iter>
// The iterator must not be volatile.
&& !_Iterator_is_volatile<_Iter>
// And one of the following conditions must be met:
&& disjunction_v<
#ifdef __cpp_lib_byte
// We're finding a std::byte in a range of std::byte.
conjunction<is_same<_Ty, byte>, is_same<_Elem, byte>>,
#endif // defined(__cpp_lib_byte)
// We're finding an integer in a range of integers.
// This case is the one that requires careful runtime handling in _Could_compare_equal_to_value_type.
conjunction<is_integral<_Ty>, is_integral<_Elem>>,
// We're finding an (object or function) pointer in a range of pointers of the same type.
conjunction<is_pointer<_Ty>, is_same<_Ty, _Elem>>,
// We're finding a nullptr in a range of (object or function) pointers.
conjunction<is_same<_Ty, nullptr_t>, is_pointer<_Elem>>,
// We're finding an object pointer in a range of object pointers, and:
// - One of the pointer types is a cv void*.
// - One of the pointer types is a cv1 U* and the other is a cv2 U*.
_Vector_alg_in_find_is_safe_object_pointers<_Ty, _Elem>>;
// The elements and the value are of certain matching types.
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
&& _Vector_alg_in_find_is_safe_elem<_Ty, _Iter_value_t<_Iter>>;

template <class _InIt, class _Ty>
_NODISCARD constexpr bool _Could_compare_equal_to_value_type(const _Ty& _Val) {
Expand Down
71 changes: 71 additions & 0 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2322,6 +2322,77 @@ namespace {

extern "C" {

// Replaces, in place, every 32-bit element of [_First, _Last) that equals _Old_val with _New_val.
//
// When AVX2 is available, whole 32-byte vectors (8 elements) are handled with a compare followed by a
// masked store, so only matching elements are written. The remaining whole elements (at most 7) are
// handled by one masked load / masked store. Otherwise a scalar loop is used.
// The byte length of the range is expected to be a multiple of 4; any stray low bits are ignored by the
// vector path's length masks.
__declspec(noalias) void __stdcall __std_replace_4(
    void* _First, void* const _Last, const uint32_t _Old_val, const uint32_t _New_val) noexcept {
#ifndef _M_ARM64EC // x86 intrinsics are compiled out for ARM64EC, as done for other functions in this file
    if (_Use_avx2()) {
        const __m256i _Comparand   = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_Old_val));
        const __m256i _Replacement = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(_New_val));
        const size_t _Full_length  = _Byte_length(_First, _Last);

        // The main loop covers the largest prefix that is a whole number of 32-byte vectors.
        void* _Stop_at = _First;
        _Advance_bytes(_Stop_at, _Full_length & ~size_t{0x1F});

        while (_First != _Stop_at) {
            const __m256i _Data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First));
            const __m256i _Mask = _mm256_cmpeq_epi32(_Comparand, _Data);
            _mm256_maskstore_epi32(reinterpret_cast<int*>(_First), _Mask, _Replacement);

            _Advance_bytes(_First, 32);
        }

        // 0x1C keeps the whole 4-byte elements left over after the main loop (4..28 bytes).
        if (const size_t _Tail_length = _Full_length & 0x1C; _Tail_length != 0) {
            // NOTE(review): _Avx2_tail_mask_32(n) presumably selects the first n 32-bit lanes - confirm.
            const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail_length >> 2);
            const __m256i _Data      = _mm256_maskload_epi32(reinterpret_cast<const int*>(_First), _Tail_mask);
            // Lanes not selected by the load read as zero and would match when _Old_val == 0;
            // AND-ing with the tail mask keeps the store inside the range.
            const __m256i _Mask = _mm256_and_si256(_mm256_cmpeq_epi32(_Comparand, _Data), _Tail_mask);
            _mm256_maskstore_epi32(reinterpret_cast<int*>(_First), _Mask, _Replacement);
        }

        _mm256_zeroupper(); // clear upper YMM halves so the caller's SSE code pays no transition penalty
    } else
#endif // !defined(_M_ARM64EC)
    {
        for (auto _Cur = reinterpret_cast<uint32_t*>(_First); _Cur != _Last; ++_Cur) {
            if (*_Cur == _Old_val) {
                *_Cur = _New_val;
            }
        }
    }
}

// Replaces, in place, every 64-bit element of [_First, _Last) that equals _Old_val with _New_val.
//
// When AVX2 is available, whole 32-byte vectors (4 elements) are handled with a compare followed by a
// masked store, so only matching elements are written. The remaining whole elements (at most 3) are
// handled by one masked load / masked store. Otherwise a scalar loop is used.
// The byte length of the range is expected to be a multiple of 8; any stray low bits are ignored by the
// vector path's length masks.
__declspec(noalias) void __stdcall __std_replace_8(
    void* _First, void* const _Last, const uint64_t _Old_val, const uint64_t _New_val) noexcept {
#ifndef _M_ARM64EC // x86 intrinsics are compiled out for ARM64EC, as done for other functions in this file
    if (_Use_avx2()) {
#ifdef _WIN64
        const __m256i _Comparand   = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_Old_val));
        const __m256i _Replacement = _mm256_broadcastq_epi64(_mm_cvtsi64_si128(_New_val));
#else // ^^^ defined(_WIN64) / !defined(_WIN64), workaround, _mm_cvtsi64_si128 does not compile vvv
        const __m256i _Comparand   = _mm256_set1_epi64x(_Old_val);
        const __m256i _Replacement = _mm256_set1_epi64x(_New_val);
#endif // ^^^ !defined(_WIN64) ^^^
        const size_t _Full_length = _Byte_length(_First, _Last);

        // The main loop covers the largest prefix that is a whole number of 32-byte vectors.
        void* _Stop_at = _First;
        _Advance_bytes(_Stop_at, _Full_length & ~size_t{0x1F});

        while (_First != _Stop_at) {
            const __m256i _Data = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(_First));
            const __m256i _Mask = _mm256_cmpeq_epi64(_Comparand, _Data);
            _mm256_maskstore_epi64(reinterpret_cast<long long*>(_First), _Mask, _Replacement);

            _Advance_bytes(_First, 32);
        }

        // 0x18 keeps the whole 8-byte elements left over after the main loop (8..24 bytes).
        if (const size_t _Tail_length = _Full_length & 0x18; _Tail_length != 0) {
            // The mask helper counts 32-bit lanes, hence bytes >> 2: each 64-bit element spans two lanes.
            // NOTE(review): _Avx2_tail_mask_32(n) presumably selects the first n 32-bit lanes - confirm.
            const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail_length >> 2);
            const __m256i _Data = _mm256_maskload_epi64(reinterpret_cast<const long long*>(_First), _Tail_mask);
            // Lanes not selected by the load read as zero and would match when _Old_val == 0;
            // AND-ing with the tail mask keeps the store inside the range.
            const __m256i _Mask = _mm256_and_si256(_mm256_cmpeq_epi64(_Comparand, _Data), _Tail_mask);
            _mm256_maskstore_epi64(reinterpret_cast<long long*>(_First), _Mask, _Replacement);
        }

        _mm256_zeroupper(); // clear upper YMM halves so the caller's SSE code pays no transition penalty
    } else
#endif // !defined(_M_ARM64EC)
    {
        for (auto _Cur = reinterpret_cast<uint64_t*>(_First); _Cur != _Last; ++_Cur) {
            if (*_Cur == _Old_val) {
                *_Cur = _New_val;
            }
        }
    }
}

__declspec(noalias) void __stdcall __std_bitset_to_string_1(
char* const _Dest, const void* _Src, size_t _Size_bits, const char _Elem0, const char _Elem1) noexcept {
#ifndef _M_ARM64EC
Expand Down
49 changes: 49 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,45 @@ namespace test_mismatch_sizes_and_alignments {
}
} // namespace test_mismatch_sizes_and_alignments

// Reference replace: a plain element-by-element walk over [first, last) that assigns new_val to
// every element comparing equal to old_val.
template <class FwdIt, class T>
void last_known_good_replace(FwdIt first, FwdIt last, const T old_val, const T new_val) {
    while (first != last) {
        if (*first == old_val) {
            *first = new_val;
        }

        ++first;
    }
}

// Checks one replace scenario: the library's replace (and, in C++20 mode, ranges::replace) applied to a
// copy of `input` must produce the same contents as the reference loop.
template <class T>
void test_case_replace(const vector<T>& input, T old_val, T new_val) {
    // Expected contents, computed by the reference implementation.
    auto replaced_expected = input;
    last_known_good_replace(replaced_expected.begin(), replaced_expected.end(), old_val, new_val);

    // Iterator-pair overload.
    auto replaced_actual = input;
    replace(replaced_actual.begin(), replaced_actual.end(), old_val, new_val);
    assert(replaced_expected == replaced_actual);

#if _HAS_CXX20
    // Range overload.
    auto replaced_actual_r = input;
    ranges::replace(replaced_actual_r, old_val, new_val);
    assert(replaced_expected == replaced_actual_r);
#endif // _HAS_CXX20
}

// Runs test_case_replace for element type T on inputs of every length from 0 to dataCount, growing the
// input by one random element at a time. Elements and the old/new values are all drawn from [0, 9].
template <class T>
void test_replace(mt19937_64& gen) {
    // uniform_int_distribution is not specified for character types, so 1-byte T draws int instead.
    using TD = conditional_t<sizeof(T) == 1, int, T>;
    uniform_int_distribution<TD> dis(0, 9);

    // The two draws are made in separate statements: as function arguments of a single call their
    // relative order would be unspecified, so one seed could yield different cases per compiler.
    const auto run_case = [&](const vector<T>& current) {
        const T old_val = static_cast<T>(dis(gen));
        const T new_val = static_cast<T>(dis(gen));
        test_case_replace(current, old_val, new_val);
    };

    vector<T> input;
    input.reserve(dataCount);

    run_case(input); // empty range
    for (size_t i = 0; i != dataCount; ++i) {
        input.push_back(static_cast<T>(dis(gen)));
        run_case(input);
    }
}

template <class BidIt>
void last_known_good_reverse(BidIt first, BidIt last) {
for (; first != last && first != --last; ++first) {
Expand Down Expand Up @@ -728,6 +767,16 @@ void test_vector_algorithms(mt19937_64& gen) {
test_mismatch_sizes_and_alignments::test<int>();
test_mismatch_sizes_and_alignments::test<long long>();

test_replace<char>(gen);
test_replace<signed char>(gen);
test_replace<unsigned char>(gen);
test_replace<short>(gen);
test_replace<unsigned short>(gen);
test_replace<int>(gen);
test_replace<unsigned int>(gen);
test_replace<long long>(gen);
test_replace<unsigned long long>(gen);
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved

test_reverse<char>(gen);
test_reverse<signed char>(gen);
test_reverse<unsigned char>(gen);
Expand Down