Skip to content

Commit

Permalink
[oneDPL][ranges] + fixes in ranges::search_n implementation; + a test…
Browse files Browse the repository at this point in the history
… for ranges::search_n
  • Loading branch information
MikeDvorskiy committed Mar 27, 2024
1 parent a5bf788 commit dac7c2a
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 6 deletions.
6 changes: 3 additions & 3 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ __pattern_search_n_impl(_Tag __tag, _ExecutionPolicy&& __exec, _R&& __r,
{
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});

auto __pred_2 = [__pred, __proj, __value](auto&& __val1, auto&& __val2) { return __pred(__proj(__val1), __val2);};
auto __pred_2 = [__pred, __proj](auto&& __val1, auto&& __val2) { return __pred(__proj(__val1), __val2);};

auto __res = oneapi::dpl::__internal::__pattern_search_n(std::forward<_ExecutionPolicy>(__exec),
auto __res = oneapi::dpl::__internal::__pattern_search_n(__tag, std::forward<_ExecutionPolicy>(__exec),
std::ranges::begin(__r), std::ranges::begin(__r) + __r.size(), __count, __value, __pred_2);

return std::ranges::borrowed_subrange_t<_R>(__res, __res + __count);
return std::ranges::borrowed_subrange_t<_R>(__res, __res == std::ranges::end(__r) ? __res : __res + __count);
}

template<typename _IsVector, typename _ExecutionPolicy, typename _R, typename _T, typename _Pred, typename _Proj>
Expand Down
2 changes: 2 additions & 0 deletions include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ struct search_n_fn
}
}; //search_n_fn

inline constexpr search_n_fn search_n;

} //ranges

#endif //_ONEDPL___cplusplus >= 202002L
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ __pattern_search_n(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _
auto __idx = oneapi::dpl::__internal::__ranges::__pattern_search_n(__tag, std::forward<_ExecutionPolicy>(__exec),
oneapi::dpl::views::all_read(::std::forward<_R>(__r)), __count, __value, __pred_2);

return std::ranges::borrowed_subrange_t<_R>(__r.begin() + __idx, __r.begin() + __idx + __count);
auto __end = (__idx == __r.size() ? __r.begin() + __idx : __r.begin() + __idx + __count);
return std::ranges::borrowed_subrange_t<_R>(__r.begin() + __idx, __end);
}
#endif //_ONEDPL___cplusplus >= 202002L

Expand Down
1 change: 1 addition & 0 deletions test/parallel_api/ranges/std_ranges.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ main()
test_range_algo{}(oneapi::dpl::ranges::adjacent_find, std::ranges::adjacent_find, pred_2, proj);

test_range_algo<data_in_in>{}(oneapi::dpl::ranges::search, std::ranges::search, pred_2, proj);
test_range_algo<data_in_val_n>{}(oneapi::dpl::ranges::search_n, std::ranges::search_n, pred_2, proj);

#endif //_ENABLE_STD_RANGES_TESTING

Expand Down
35 changes: 33 additions & 2 deletions test/parallel_api/ranges/std_ranges_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

#include "support/utils.h"

#define _ENABLE_STD_RANGES_TESTING (_ONEDPL___cplusplus >= 202002L)
#if _ENABLE_STD_RANGES_TESTING

#include <oneapi/dpl/ranges>
Expand Down Expand Up @@ -49,7 +48,8 @@ enum TestDataMode
data_in,
data_in_out,
data_in_in,
data_in_in_out
data_in_in_out,
data_in_val_n,
};

template<typename Container, TestDataMode Ranges = data_in, bool RetTypeCheck = true>
Expand Down Expand Up @@ -161,6 +161,37 @@ struct test
}
}

template<typename Policy, typename Algo, typename Checker, typename FunctorOrVal, typename Proj = std::identity,
typename Transform = std::identity>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_val_n>
operator()(Policy&& exec, Algo algo, Checker checker, FunctorOrVal f, Proj proj = {}, Transform tr = {})
{
constexpr int max_n = 10;
int data[max_n] = {0, 1, 2, 5, 5, 5, 6, 7, 8, 9};
int expected[max_n] = {0, 1, 2, 5, 5, 5, 6, 7, 8, 9};
int val = 5, n = 3;

auto expected_view = tr(std::ranges::subrange(expected, expected + max_n));
auto expected_res = checker(expected_view, n, val, f, proj);
{
Container cont(exec, data, max_n);
typename Container::type& A = cont();

auto res = algo(exec, tr(A), n, val, f, proj);

//check result
if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), n, val, f, proj))>, "Wrong return type");

auto bres = ret_in_val(expected_res, expected_view.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres, (std::string("wrong return value from algo with ranges: ") + typeid(Algo).name()).c_str());
}

//check result
EXPECT_EQ_N(expected, data, max_n, (std::string("wrong effect algo with ranges: ")
+ typeid(Algo).name() + typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
}

private:

template<typename, typename = void>
Expand Down
2 changes: 2 additions & 0 deletions test/support/test_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@
#endif
#endif //!defined(_ENABLE_RANGES_TESTING)

#define _ENABLE_STD_RANGES_TESTING (_ONEDPL___cplusplus >= 202002L)

#define TEST_HAS_NO_INT128
#define _PSTL_TEST_COMPLEX_NON_FLOAT_AVAILABLE (_MSVC_STL_VERSION < 143)

Expand Down

0 comments on commit dac7c2a

Please sign in to comment.