Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize basic_string::rfind (the string needle overload) #5057

Merged
merged 6 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 41 additions & 15 deletions benchmarks/src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
#include <vector>

#include "lorem.hpp"
using namespace std::string_view_literals;
#include "skewed_allocator.hpp"

using namespace std::string_view_literals;

template <size_t Size, bool Last_is_different>
constexpr auto make_fill_pattern_array() {
Expand Down Expand Up @@ -48,12 +49,18 @@ constexpr data_and_pattern patterns[] = {
/* 5. Large, evil */ {fill_pattern_view<3000, false>, fill_pattern_view<20, true>},
};

template <class T>
using not_highly_aligned_basic_string = std::basic_string<T, std::char_traits<T>, not_highly_aligned_allocator<T>>;

using not_highly_aligned_string = not_highly_aligned_basic_string<char>;
using not_highly_aligned_wstring = not_highly_aligned_basic_string<wchar_t>;

void c_strstr(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const std::string haystack(src_haystack);
const std::string needle(src_needle);
const not_highly_aligned_string haystack(src_haystack);
const not_highly_aligned_string needle(src_needle);

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
Expand All @@ -68,8 +75,8 @@ void classic_search(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const std::vector<T> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T> needle(src_needle.begin(), src_needle.end());
const std::vector<T, not_highly_aligned_allocator<T>> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T, not_highly_aligned_allocator<T>> needle(src_needle.begin(), src_needle.end());

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
Expand All @@ -84,8 +91,8 @@ void ranges_search(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const std::vector<T> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T> needle(src_needle.begin(), src_needle.end());
const std::vector<T, not_highly_aligned_allocator<T>> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T, not_highly_aligned_allocator<T>> needle(src_needle.begin(), src_needle.end());

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
Expand All @@ -100,8 +107,8 @@ void search_default_searcher(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const std::vector<T> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T> needle(src_needle.begin(), src_needle.end());
const std::vector<T, not_highly_aligned_allocator<T>> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T, not_highly_aligned_allocator<T>> needle(src_needle.begin(), src_needle.end());

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
Expand Down Expand Up @@ -132,8 +139,8 @@ void classic_find_end(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const std::vector<T> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T> needle(src_needle.begin(), src_needle.end());
const std::vector<T, not_highly_aligned_allocator<T>> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T, not_highly_aligned_allocator<T>> needle(src_needle.begin(), src_needle.end());

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
Expand All @@ -148,8 +155,8 @@ void ranges_find_end(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const std::vector<T> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T> needle(src_needle.begin(), src_needle.end());
const std::vector<T, not_highly_aligned_allocator<T>> haystack(src_haystack.begin(), src_haystack.end());
const std::vector<T, not_highly_aligned_allocator<T>> needle(src_needle.begin(), src_needle.end());

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
Expand All @@ -159,6 +166,22 @@ void ranges_find_end(benchmark::State& state) {
}
}

template <class T>
void member_rfind(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
const auto& src_needle = patterns[static_cast<size_t>(state.range())].pattern;

const T haystack(src_haystack.begin(), src_haystack.end());
const T needle(src_needle.begin(), src_needle.end());

for (auto _ : state) {
benchmark::DoNotOptimize(haystack);
benchmark::DoNotOptimize(needle);
auto res = haystack.rfind(needle);
benchmark::DoNotOptimize(res);
}
}

void common_args(auto bm) {
bm->DenseRange(0, std::size(patterns) - 1, 1);
}
Expand All @@ -174,13 +197,16 @@ BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);

BENCHMARK(member_find<std::string>)->Apply(common_args);
BENCHMARK(member_find<std::wstring>)->Apply(common_args);
BENCHMARK(member_find<not_highly_aligned_string>)->Apply(common_args);
BENCHMARK(member_find<not_highly_aligned_wstring>)->Apply(common_args);

BENCHMARK(classic_find_end<std::uint8_t>)->Apply(common_args);
BENCHMARK(classic_find_end<std::uint16_t>)->Apply(common_args);

BENCHMARK(ranges_find_end<std::uint8_t>)->Apply(common_args);
BENCHMARK(ranges_find_end<std::uint16_t>)->Apply(common_args);

BENCHMARK(member_rfind<not_highly_aligned_string>)->Apply(common_args);
BENCHMARK(member_rfind<not_highly_aligned_wstring>)->Apply(common_args);

BENCHMARK_MAIN();
23 changes: 22 additions & 1 deletion stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,28 @@ constexpr size_t _Traits_rfind(_In_reads_(_Hay_size) const _Traits_ptr_t<_Traits
return static_cast<size_t>(-1);
}

for (auto _Match_try = _Haystack + (_STD min)(_Start_at, _Hay_size - _Needle_size);; --_Match_try) {
const size_t _Actual_start_at = (_STD min)(_Start_at, _Hay_size - _Needle_size);

#if _USE_STD_VECTOR_ALGORITHMS
if constexpr (_Is_implementation_handled_char_traits<_Traits> && sizeof(typename _Traits::char_type) <= 2) {
if (!_STD _Is_constant_evaluated()) {
// _Find_end_vectorized takes into accout the needle length when locates search start.
// As a potentially eearlier start position can be specified, need to take it into account,
// and pick between the maximum possible start position, and the specified one,
// and then add _Needle_size, so that it is subtracted back in _Find_end_vectorized.
StephanTLavavej marked this conversation as resolved.
Show resolved Hide resolved
const auto _End = _Haystack + _Actual_start_at + _Needle_size;
const auto _Ptr = _STD _Find_end_vectorized(_Haystack, _End, _Needle, _Needle_size);

if (_Ptr != _End) {
return static_cast<size_t>(_Ptr - _Haystack);
} else {
return static_cast<size_t>(-1);
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

for (auto _Match_try = _Haystack + _Actual_start_at;; --_Match_try) {
if (_Traits::eq(*_Match_try, *_Needle) && _Traits::compare(_Match_try, _Needle, _Needle_size) == 0) {
return static_cast<size_t>(_Match_try - _Haystack); // found a match
}
Expand Down
18 changes: 0 additions & 18 deletions stl/inc/algorithm
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@ const void* __stdcall __std_find_last_trivial_2(const void* _First, const void*
const void* __stdcall __std_find_last_trivial_4(const void* _First, const void* _Last, uint32_t _Val) noexcept;
const void* __stdcall __std_find_last_trivial_8(const void* _First, const void* _Last, uint64_t _Val) noexcept;

const void* __stdcall __std_find_end_1(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
const void* __stdcall __std_find_end_2(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;

__declspec(noalias) _Min_max_1i __stdcall __std_minmax_1i(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_1u __stdcall __std_minmax_1u(const void* _First, const void* _Last) noexcept;
__declspec(noalias) _Min_max_2i __stdcall __std_minmax_2i(const void* _First, const void* _Last) noexcept;
Expand Down Expand Up @@ -194,19 +189,6 @@ _Ty* _Find_last_vectorized(_Ty* const _First, _Ty* const _Last, const _TVal _Val
}
}

template <class _Ty1, class _Ty2>
_Ty1* _Find_end_vectorized(
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_1(_First1, _Last1, _First2, _Count2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_2(_First1, _Last1, _First2, _Count2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty, class _TVal1, class _TVal2>
__declspec(noalias) void _Replace_vectorized(
_Ty* const _First, _Ty* const _Last, const _TVal1 _Old_val, const _TVal2 _New_val) noexcept {
Expand Down
18 changes: 18 additions & 0 deletions stl/inc/xutility
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ const void* __stdcall __std_search_1(
const void* __stdcall __std_search_2(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;

const void* __stdcall __std_find_end_1(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;
const void* __stdcall __std_find_end_2(
const void* _First1, const void* _Last1, const void* _First2, size_t _Count2) noexcept;

const void* __stdcall __std_min_element_1(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_2(const void* _First, const void* _Last, bool _Signed) noexcept;
const void* __stdcall __std_min_element_4(const void* _First, const void* _Last, bool _Signed) noexcept;
Expand Down Expand Up @@ -248,6 +253,19 @@ _Ty1* _Search_vectorized(_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _F
}
}

template <class _Ty1, class _Ty2>
_Ty1* _Find_end_vectorized(
_Ty1* const _First1, _Ty1* const _Last1, _Ty2* const _First2, const size_t _Count2) noexcept {
_STL_INTERNAL_STATIC_ASSERT(sizeof(_Ty1) == sizeof(_Ty2));
if constexpr (sizeof(_Ty1) == 1) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_1(_First1, _Last1, _First2, _Count2)));
} else if constexpr (sizeof(_Ty1) == 2) {
return const_cast<_Ty1*>(static_cast<const _Ty1*>(::__std_find_end_2(_First1, _Last1, _First2, _Count2)));
} else {
_STL_INTERNAL_STATIC_ASSERT(false); // unexpected size
}
}

template <class _Ty>
_Ty* _Min_element_vectorized(_Ty* const _First, _Ty* const _Last) noexcept {
constexpr bool _Signed = is_signed_v<_Ty>;
Expand Down
22 changes: 22 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1347,6 +1347,25 @@ void test_case_string_find_str(const basic_string<T>& input_haystack, const basi
assert(expected == actual);
}

template <class T>
void test_case_string_rfind_str(const basic_string<T>& input_haystack, const basic_string<T>& input_needle) {
ptrdiff_t expected;
if (input_needle.empty()) {
expected = static_cast<ptrdiff_t>(input_haystack.size());
} else {
const auto expected_iter = last_known_good_find_end(
input_haystack.begin(), input_haystack.end(), input_needle.begin(), input_needle.end());

if (expected_iter != input_haystack.end()) {
expected = expected_iter - input_haystack.begin();
} else {
expected = -1;
}
}
const auto actual = static_cast<ptrdiff_t>(input_haystack.rfind(input_needle));
assert(expected == actual);
}

template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
Expand All @@ -1362,12 +1381,14 @@ void test_basic_string_dis(mt19937_64& gen, D& dis) {
test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);
test_case_string_rfind_str(input_haystack, input_needle);

for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);
test_case_string_rfind_str(input_haystack, input_needle);

// For large needles the chance of a match is low, so test a guaranteed match
if (input_haystack.size() > input_needle.size() * 2) {
Expand All @@ -1377,6 +1398,7 @@ void test_basic_string_dis(mt19937_64& gen, D& dis) {
temp.assign(overwritten_first, overwritten_first + static_cast<ptrdiff_t>(input_needle.size()));
copy(input_needle.begin(), input_needle.end(), overwritten_first);
test_case_string_find_str(input_haystack, input_needle);
test_case_string_rfind_str(input_haystack, input_needle);
copy(temp.begin(), temp.end(), overwritten_first);
}
}
Expand Down