diff --git a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h index a2945584ccc..3bf92fe1eee 100644 --- a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h @@ -274,12 +274,12 @@ __pattern_search_n_impl(_Tag __tag, _ExecutionPolicy&& __exec, _R&& __r, { static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{}); - auto __pred_2 = [__pred, __proj, __value](auto&& __val1, auto&& __val2) { return __pred(__proj(__val1), __val2);}; + auto __pred_2 = [__pred, __proj](auto&& __val1, auto&& __val2) { return __pred(__proj(__val1), __val2);}; - auto __res = oneapi::dpl::__internal::__pattern_search_n(std::forward<_ExecutionPolicy>(__exec), + auto __res = oneapi::dpl::__internal::__pattern_search_n(__tag, std::forward<_ExecutionPolicy>(__exec), std::ranges::begin(__r), std::ranges::begin(__r) + __r.size(), __count, __value, __pred_2); - return std::ranges::borrowed_subrange_t<_R>(__res, __res + __count); + return std::ranges::borrowed_subrange_t<_R>(__res, __res == std::ranges::end(__r) ? __res : __res + __count); } template diff --git a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h index 5e8ffa3223e..0bd4dd3dffc 100644 --- a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h @@ -210,6 +210,8 @@ struct search_n_fn } }; //search_n_fn +inline constexpr search_n_fn search_n; + } //ranges #endif //_ONEDPL___cplusplus >= 202002L diff --git a/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h index 209a844f2ed..390b953e81d 100644 --- a/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h +++ b/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h @@ -356,7 +356,8 @@ __pattern_search_n(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _ auto __idx = oneapi::dpl::__internal::__ranges::__pattern_search_n(__tag, std::forward<_ExecutionPolicy>(__exec), oneapi::dpl::views::all_read(::std::forward<_R>(__r)), __count, __value, __pred_2); - return std::ranges::borrowed_subrange_t<_R>(__r.begin() + __idx, __r.begin() + __idx + __count); + auto __end = (__idx == __r.size() ? __r.begin() + __idx : __r.begin() + __idx + __count); + return std::ranges::borrowed_subrange_t<_R>(__r.begin() + __idx, __end); } #endif //_ONEDPL___cplusplus >= 202002L diff --git a/test/parallel_api/ranges/std_ranges.pass.cpp b/test/parallel_api/ranges/std_ranges.pass.cpp index f8b9093a553..8e8c6e3a334 100644 --- a/test/parallel_api/ranges/std_ranges.pass.cpp +++ b/test/parallel_api/ranges/std_ranges.pass.cpp @@ -50,6 +50,7 @@ main() test_range_algo{}(oneapi::dpl::ranges::adjacent_find, std::ranges::adjacent_find, pred_2, proj); test_range_algo{}(oneapi::dpl::ranges::search, std::ranges::search, pred_2, proj); + test_range_algo{}(oneapi::dpl::ranges::search_n, std::ranges::search_n, pred_2, proj); #endif //_ENABLE_STD_RANGES_TESTING diff --git a/test/parallel_api/ranges/std_ranges_test.h b/test/parallel_api/ranges/std_ranges_test.h index 1aa777ad831..d632ff65b39 100644 --- a/test/parallel_api/ranges/std_ranges_test.h +++ b/test/parallel_api/ranges/std_ranges_test.h @@ -19,7 +19,6 @@ #include "support/utils.h" -#define _ENABLE_STD_RANGES_TESTING (_ONEDPL___cplusplus >= 202002L) #if _ENABLE_STD_RANGES_TESTING #include @@ -49,7 +48,8 @@ enum TestDataMode data_in, data_in_out, data_in_in, - data_in_in_out + data_in_in_out, + data_in_val_n, }; template @@ -161,6 +161,37 @@ struct test } } + template + std::enable_if_t && Ranges == data_in_val_n> + operator()(Policy&& exec, Algo algo, Checker checker, FunctorOrVal f, Proj proj = {}, Transform tr = {}) + { + constexpr int max_n = 10; + int data[max_n] = {0, 1, 2, 5, 5, 5, 6, 7, 8, 9}; + int expected[max_n] = {0, 1, 2, 5, 5, 5, 6, 7, 8, 9}; + int val = 5, n = 3; + + auto expected_view = tr(std::ranges::subrange(expected, expected + max_n)); + auto expected_res = checker(expected_view, n, val, f, proj); + { + Container cont(exec, data, max_n); + typename Container::type& A = cont(); + + auto res = algo(exec, tr(A), n, val, f, proj); + + //check result + if constexpr(RetTypeCheck) + static_assert(std::is_same_v, "Wrong return type"); + + auto bres = ret_in_val(expected_res, expected_view.begin()) == ret_in_val(res, tr(A).begin()); + EXPECT_TRUE(bres, (std::string("wrong return value from algo with ranges: ") + typeid(Algo).name()).c_str()); + } + + //check result + EXPECT_EQ_N(expected, data, max_n, (std::string("wrong effect algo with ranges: ") + + typeid(Algo).name() + typeid(decltype(tr(std::declval()()))).name()).c_str()); + } + private: template diff --git a/test/support/test_config.h b/test/support/test_config.h index 37c6d433404..968bb065b26 100644 --- a/test/support/test_config.h +++ b/test/support/test_config.h @@ -113,6 +113,8 @@ #endif #endif //!defined(_ENABLE_RANGES_TESTING) +#define _ENABLE_STD_RANGES_TESTING (_ONEDPL___cplusplus >= 202002L) + #define TEST_HAS_NO_INT128 #define _PSTL_TEST_COMPLEX_NON_FLOAT_AVAILABLE (_MSVC_STL_VERSION < 143)