diff --git a/core/src/Kokkos_Parallel_Reduce.hpp b/core/src/Kokkos_Parallel_Reduce.hpp index 2a3c39cbba..7fe539c4c6 100644 --- a/core/src/Kokkos_Parallel_Reduce.hpp +++ b/core/src/Kokkos_Parallel_Reduce.hpp @@ -1437,6 +1437,12 @@ template <> struct implements_new_reduce_interface : std::true_type {}; #endif +#ifdef KOKKOS_ENABLE_OPENMPTARGET +template <> +struct implements_new_reduce_interface + : std::true_type {}; +#endif + #ifdef KOKKOS_ENABLE_CUDA template <> struct implements_new_reduce_interface : std::true_type {}; diff --git a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Range.hpp b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Range.hpp index 9153402596..e12240208e 100644 --- a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Range.hpp +++ b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Range.hpp @@ -25,36 +25,33 @@ namespace Kokkos { namespace Impl { -template -class ParallelReduce, ReducerType, +template +class ParallelReduce, Kokkos::Experimental::OpenMPTarget> { private: - using Policy = Kokkos::RangePolicy; + using Policy = Kokkos::RangePolicy; + using FunctorType = typename CombinedFunctorReducerType::functor_type; + using ReducerType = typename CombinedFunctorReducerType::reducer_type; using WorkTag = typename Policy::work_tag; - using ReducerTypeFwd = - std::conditional_t::value, - FunctorType, ReducerType>; - using Analysis = Impl::FunctorAnalysis; + using pointer_type = typename ReducerType::pointer_type; + using reference_type = typename ReducerType::reference_type; - using pointer_type = typename Analysis::pointer_type; - using reference_type = typename Analysis::reference_type; - - static constexpr int HasJoin = + static constexpr int FunctorHasJoin = Impl::FunctorAnalysis::has_join_member_function; - static constexpr int UseReducer = is_reducer::value; - static constexpr int IsArray = std::is_pointer::value; + static constexpr int UseReducer = + !std::is_same_v; + static constexpr int IsArray = std::is_pointer::value; using ParReduceSpecialize = - ParallelReduceSpecialize; + ParallelReduceSpecialize; - const FunctorType m_functor; + const CombinedFunctorReducerType m_functor_reducer; const Policy m_policy; - const ReducerType m_reducer; const pointer_type m_result_ptr; bool m_result_ptr_on_device; const int m_result_ptr_num_elems; @@ -62,68 +59,53 @@ class ParallelReduce, ReducerType, public: void execute() const { - if constexpr (HasJoin) { + const FunctorType& functor = m_functor_reducer.get_functor(); + if constexpr (FunctorHasJoin) { // Enter this loop if the Functor has a init-join. - ParReduceSpecialize::execute_init_join(m_functor, m_policy, m_result_ptr, + ParReduceSpecialize::execute_init_join(functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if constexpr (UseReducer) { // Enter this loop if the Functor is a reducer type. - ParReduceSpecialize::execute_reducer(m_functor, m_policy, m_result_ptr, + ParReduceSpecialize::execute_reducer(functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if constexpr (IsArray) { // Enter this loop if the reduction is on an array and the routine is // templated over the size of the array. if (m_result_ptr_num_elems <= 2) { ParReduceSpecialize::template execute_array( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 4) { ParReduceSpecialize::template execute_array( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 8) { ParReduceSpecialize::template execute_array( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 16) { ParReduceSpecialize::template execute_array( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 32) { ParReduceSpecialize::template execute_array( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else { Kokkos::abort("array reduction length must be <= 32"); } } else { // This loop handles the basic scalar reduction. ParReduceSpecialize::template execute_array( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } } template - ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy, - const ViewType& arg_result_view, - std::enable_if_t::value && - !Kokkos::is_reducer::value, - void*> = nullptr) - : m_functor(arg_functor), + ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer, + const Policy& arg_policy, const ViewType& arg_result_view) + : m_functor_reducer(arg_functor_reducer), m_policy(arg_policy), - m_reducer(InvalidType()), m_result_ptr(arg_result_view.data()), m_result_ptr_on_device( MemorySpaceAccess::accessible), m_result_ptr_num_elems(arg_result_view.size()) {} - - ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy, - const ReducerType& reducer) - : m_functor(arg_functor), - m_policy(arg_policy), - m_reducer(reducer), - m_result_ptr(reducer.view().data()), - m_result_ptr_on_device( - MemorySpaceAccess::accessible), - m_result_ptr_num_elems(reducer.view().size()) {} }; } // namespace Impl diff --git a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Team.hpp b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Team.hpp index 39d452864a..417a53505f 100644 --- a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Team.hpp +++ b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Team.hpp @@ -432,113 +432,92 @@ parallel_reduce(const Impl::TeamVectorRangeBoundariesStruct< namespace Impl { -template -class ParallelReduce, - ReducerType, Kokkos::Experimental::OpenMPTarget> { +template +class ParallelReduce, + Kokkos::Experimental::OpenMPTarget> { private: using Policy = Kokkos::Impl::TeamPolicyInternal; + using FunctorType = typename CombinedFunctorReducerType::functor_type; + using ReducerType = typename CombinedFunctorReducerType::reducer_type; using WorkTag = typename Policy::work_tag; using Member = typename Policy::member_type; - using ReducerTypeFwd = - std::conditional_t::value, - FunctorType, ReducerType>; - using WorkTagFwd = - std::conditional_t::value, WorkTag, - void>; - using Analysis = Impl::FunctorAnalysis; - - using pointer_type = typename Analysis::pointer_type; - using reference_type = typename Analysis::reference_type; - using value_type = typename Analysis::value_type; + + using pointer_type = typename ReducerType::pointer_type; + using reference_type = typename ReducerType::reference_type; + using value_type = typename ReducerType::value_type; bool m_result_ptr_on_device; const int m_result_ptr_num_elems; - static constexpr int HasJoin = + static constexpr int FunctorHasJoin = Impl::FunctorAnalysis::has_join_member_function; - static constexpr int UseReducer = is_reducer::value; - static constexpr int IsArray = std::is_pointer::value; + static constexpr int UseReducer = + !std::is_same_v; + static constexpr int IsArray = std::is_pointer::value; using ParReduceSpecialize = - ParallelReduceSpecialize; + ParallelReduceSpecialize; - const FunctorType m_functor; + const CombinedFunctorReducerType m_functor_reducer; const Policy m_policy; - const ReducerType m_reducer; const pointer_type m_result_ptr; const size_t m_shmem_size; public: void execute() const { - if constexpr (HasJoin) { - ParReduceSpecialize::execute_init_join(m_functor, m_policy, m_result_ptr, + const FunctorType& functor = m_functor_reducer.get_functor(); + if constexpr (FunctorHasJoin) { + ParReduceSpecialize::execute_init_join(functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if constexpr (UseReducer) { - ParReduceSpecialize::execute_reducer(m_functor, m_policy, m_result_ptr, + ParReduceSpecialize::execute_reducer(functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if constexpr (IsArray) { if (m_result_ptr_num_elems <= 2) { ParReduceSpecialize::template execute_array<2>( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 4) { ParReduceSpecialize::template execute_array<4>( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 8) { ParReduceSpecialize::template execute_array<8>( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 16) { ParReduceSpecialize::template execute_array<16>( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else if (m_result_ptr_num_elems <= 32) { ParReduceSpecialize::template execute_array<32>( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } else { Kokkos::abort("array reduction length must be <= 32"); } } else { ParReduceSpecialize::template execute_array<1>( - m_functor, m_policy, m_result_ptr, m_result_ptr_on_device); + functor, m_policy, m_result_ptr, m_result_ptr_on_device); } } template - ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy, - const ViewType& arg_result, - std::enable_if_t::value && - !Kokkos::is_reducer::value, - void*> = nullptr) + ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer, + const Policy& arg_policy, const ViewType& arg_result) : m_result_ptr_on_device( MemorySpaceAccess::accessible), m_result_ptr_num_elems(arg_result.size()), - m_functor(arg_functor), + m_functor_reducer(arg_functor_reducer), m_policy(arg_policy), - m_reducer(InvalidType()), m_result_ptr(arg_result.data()), - m_shmem_size(arg_policy.scratch_size(0) + arg_policy.scratch_size(1) + - FunctorTeamShmemSize::value( - arg_functor, arg_policy.team_size())) {} - - ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy, - const ReducerType& reducer) - : m_result_ptr_on_device( - MemorySpaceAccess::accessible), - m_result_ptr_num_elems(reducer.view().size()), - m_functor(arg_functor), - m_policy(arg_policy), - m_reducer(reducer), - m_result_ptr(reducer.view().data()), - m_shmem_size(arg_policy.scratch_size(0) + arg_policy.scratch_size(1) + - FunctorTeamShmemSize::value( - arg_functor, arg_policy.team_size())) {} + m_shmem_size( + arg_policy.scratch_size(0) + arg_policy.scratch_size(1) + + FunctorTeamShmemSize::value( + arg_functor_reducer.get_functor(), arg_policy.team_size())) {} }; } // namespace Impl diff --git a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelScan_Range.hpp b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelScan_Range.hpp index 1900260e2a..e9a52f8e21 100644 --- a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelScan_Range.hpp +++ b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelScan_Range.hpp @@ -41,7 +41,8 @@ class ParallelScan, using pointer_type = typename Analysis::pointer_type; using reference_type = typename Analysis::reference_type; - const FunctorType m_functor; + const CombinedFunctorReducer + m_functor_reducer; const Policy m_policy; value_type* m_result_ptr; @@ -75,10 +76,12 @@ class ParallelScan, idx_type nteams = n_chunks > 512 ? 512 : n_chunks; idx_type team_size = 128; - FunctorType a_functor(m_functor); -#pragma omp target teams distribute map(to : a_functor) num_teams(nteams) + auto a_functor_reducer = m_functor_reducer; +#pragma omp target teams distribute map(to \ + : a_functor_reducer) num_teams(nteams) for (idx_type team_id = 0; team_id < n_chunks; ++team_id) { - typename Analysis::Reducer final_reducer(a_functor); + const typename Analysis::Reducer& reducer = + a_functor_reducer.get_reducer(); #pragma omp parallel num_threads(team_size) { const idx_type local_offset = team_id * chunk_size; @@ -87,16 +90,18 @@ class ParallelScan, for (idx_type i = 0; i < chunk_size; ++i) { const idx_type idx = local_offset + i; value_type val; - final_reducer.init(&val); - if (idx < N) call_with_tag(a_functor, idx, val, false); + reducer.init(&val); + if (idx < N) + call_with_tag(a_functor_reducer.get_functor(), idx, val, + false); element_values(team_id, i) = val; } #pragma omp barrier if (omp_get_thread_num() == 0) { value_type sum; - final_reducer.init(&sum); + reducer.init(&sum); for (idx_type i = 0; i < chunk_size; ++i) { - final_reducer.join(&sum, &element_values(team_id, i)); + reducer.join(&sum, &element_values(team_id, i)); element_values(team_id, i) = sum; } chunk_values(team_id) = sum; @@ -105,9 +110,9 @@ class ParallelScan, if (omp_get_thread_num() == 0) { if (Kokkos::atomic_fetch_add(&count(), 1) == n_chunks - 1) { value_type sum; - final_reducer.init(&sum); + reducer.init(&sum); for (idx_type i = 0; i < n_chunks; ++i) { - final_reducer.join(&sum, &chunk_values(i)); + reducer.join(&sum, &chunk_values(i)); chunk_values(i) = sum; } } @@ -115,11 +120,12 @@ class ParallelScan, } } -#pragma omp target teams distribute map(to \ - : a_functor) num_teams(nteams) \ +#pragma omp target teams distribute map(to \ + : a_functor_reducer) num_teams(nteams) \ thread_limit(team_size) for (idx_type team_id = 0; team_id < n_chunks; ++team_id) { - typename Analysis::Reducer final_reducer(a_functor); + const typename Analysis::Reducer& reducer = + a_functor_reducer.get_reducer(); #pragma omp parallel num_threads(team_size) { const idx_type local_offset = team_id * chunk_size; @@ -127,7 +133,7 @@ class ParallelScan, if (team_id > 0) offset_value = chunk_values(team_id - 1); else - final_reducer.init(&offset_value); + reducer.init(&offset_value); #pragma omp for for (idx_type i = 0; i < chunk_size; ++i) { @@ -145,12 +151,13 @@ class ParallelScan, } else local_offset_value += offset_value; #else - final_reducer.join(&local_offset_value, &offset_value); + reducer.join(&local_offset_value, &offset_value); #endif } else local_offset_value = offset_value; if (idx < N) - call_with_tag(a_functor, idx, local_offset_value, true); + call_with_tag(a_functor_reducer.get_functor(), idx, + local_offset_value, true); if (idx == N - 1 && m_result_ptr_device_accessible) *m_result_ptr = local_offset_value; } @@ -184,7 +191,7 @@ class ParallelScan, ParallelScan(const FunctorType& arg_functor, const Policy& arg_policy, pointer_type arg_result_ptr = nullptr, bool arg_result_ptr_device_accessible = false) - : m_functor(arg_functor), + : m_functor_reducer(arg_functor, typename Analysis::Reducer{arg_functor}), m_policy(arg_policy), m_result_ptr(arg_result_ptr), m_result_ptr_device_accessible(arg_result_ptr_device_accessible) {} @@ -227,7 +234,7 @@ class ParallelScanWithTotal, base_t::impl_execute(element_values, chunk_values, count); if (!base_t::m_result_ptr_device_accessible) { - const int size = base_t::Analysis::value_size(base_t::m_functor); + const int size = base_t::m_functor_reducer.get_reducer().value_size(); DeepCopy( base_t::m_result_ptr, chunk_values.data() + (n_chunks - 1), size); } diff --git a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_Parallel_MDRange.hpp b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_Parallel_MDRange.hpp index 251ca20b44..41e62ce6e6 100644 --- a/core/src/OpenMPTarget/Kokkos_OpenMPTarget_Parallel_MDRange.hpp +++ b/core/src/OpenMPTarget/Kokkos_OpenMPTarget_Parallel_MDRange.hpp @@ -411,32 +411,28 @@ class ParallelFor, namespace Kokkos { namespace Impl { -template -class ParallelReduce, ReducerType, +template +class ParallelReduce, Kokkos::Experimental::OpenMPTarget> { private: - using Policy = Kokkos::MDRangePolicy; + using Policy = Kokkos::MDRangePolicy; + using FunctorType = typename CombinedFunctorReducerType::functor_type; + using ReducerType = typename CombinedFunctorReducerType::reducer_type; using WorkTag = typename Policy::work_tag; using Member = typename Policy::member_type; using Index = typename Policy::index_type; - using ReducerConditional = - std::conditional::value, - FunctorType, ReducerType>; - using ReducerTypeFwd = typename ReducerConditional::type; - using Analysis = Impl::FunctorAnalysis; + using pointer_type = typename ReducerType::pointer_type; + using reference_type = typename ReducerType::reference_type; - using pointer_type = typename Analysis::pointer_type; - using reference_type = typename Analysis::reference_type; - - static constexpr bool UseReducer = is_reducer::value; + static constexpr bool UseReducer = + !std::is_same_v; const pointer_type m_result_ptr; - const FunctorType m_functor; + const CombinedFunctorReducerType m_functor_reducer; const Policy m_policy; - const ReducerType m_reducer; using ParReduceCopy = ParallelReduceCopy; @@ -444,36 +440,20 @@ class ParallelReduce, ReducerType, public: inline void execute() const { - execute_tile( - m_functor, m_policy, m_result_ptr); + execute_tile( + m_functor_reducer.get_functor(), m_policy, m_result_ptr); } template - inline ParallelReduce( - const FunctorType& arg_functor, Policy arg_policy, - const ViewType& arg_result_view, - std::enable_if_t::value && - !Kokkos::is_reducer::value, - void*> = NULL) + inline ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer, + Policy arg_policy, const ViewType& arg_result_view) : m_result_ptr(arg_result_view.data()), - m_functor(arg_functor), + m_functor_reducer(arg_functor_reducer), m_policy(arg_policy), - m_reducer(InvalidType()), m_result_ptr_on_device( MemorySpaceAccess::accessible) {} - inline ParallelReduce(const FunctorType& arg_functor, Policy arg_policy, - const ReducerType& reducer) - : m_result_ptr(reducer.view().data()), - m_functor(arg_functor), - m_policy(arg_policy), - m_reducer(reducer), - m_result_ptr_on_device( - MemorySpaceAccess::accessible) {} - template inline std::enable_if_t execute_tile(const FunctorType& functor, const Policy& policy, @@ -540,10 +520,13 @@ reduction(+:result) // FIXME_OPENMPTARGET: Unable to separate directives and their companion // loops which leads to code duplication for different reduction types. if constexpr (UseReducer) { -#pragma omp declare reduction( \ - custom:ValueType \ - : OpenMPTargetReducerWrapper ::join(omp_out, omp_in)) \ - initializer(OpenMPTargetReducerWrapper ::init(omp_priv)) +#pragma omp declare reduction( \ + custom:ValueType \ + : OpenMPTargetReducerWrapper ::join( \ + omp_out, omp_in)) \ + initializer( \ + OpenMPTargetReducerWrapper ::init( \ + omp_priv)) #pragma omp target teams distribute parallel for collapse(3) map(to \ : functor) \