Skip to content

Commit

Permalink
Merge pull request #5954 from masterleinad/pass_functor_analysis_to_p…
Browse files Browse the repository at this point in the history
…arallel_reduce_ompt

Convert OpenMPTarget ParallelReduce and ParallelScan
  • Loading branch information
dalg24 authored Mar 8, 2023
2 parents 9fe93d4 + 42abe36 commit 3707be7
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 160 deletions.
6 changes: 6 additions & 0 deletions core/src/Kokkos_Parallel_Reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,12 @@ template <>
struct implements_new_reduce_interface<Kokkos::Threads> : std::true_type {};
#endif

#ifdef KOKKOS_ENABLE_OPENMPTARGET
template <>
struct implements_new_reduce_interface<Kokkos::Experimental::OpenMPTarget>
: std::true_type {};
#endif

#ifdef KOKKOS_ENABLE_CUDA
template <>
struct implements_new_reduce_interface<Kokkos::Cuda> : std::true_type {};
Expand Down
74 changes: 28 additions & 46 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Range.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,105 +25,87 @@
namespace Kokkos {
namespace Impl {

template <class FunctorType, class ReducerType, class... Traits>
class ParallelReduce<FunctorType, Kokkos::RangePolicy<Traits...>, ReducerType,
template <class CombinedFunctorReducerType, class... Traits>
class ParallelReduce<CombinedFunctorReducerType, Kokkos::RangePolicy<Traits...>,
Kokkos::Experimental::OpenMPTarget> {
private:
using Policy = Kokkos::RangePolicy<Traits...>;
using Policy = Kokkos::RangePolicy<Traits...>;
using FunctorType = typename CombinedFunctorReducerType::functor_type;
using ReducerType = typename CombinedFunctorReducerType::reducer_type;

using WorkTag = typename Policy::work_tag;

using ReducerTypeFwd =
std::conditional_t<std::is_same<InvalidType, ReducerType>::value,
FunctorType, ReducerType>;
using Analysis = Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE,
Policy, ReducerTypeFwd>;
using pointer_type = typename ReducerType::pointer_type;
using reference_type = typename ReducerType::reference_type;

using pointer_type = typename Analysis::pointer_type;
using reference_type = typename Analysis::reference_type;

static constexpr int HasJoin =
static constexpr int FunctorHasJoin =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, Policy,
FunctorType>::has_join_member_function;
static constexpr int UseReducer = is_reducer<ReducerType>::value;
static constexpr int IsArray = std::is_pointer<reference_type>::value;
static constexpr int UseReducer =
!std::is_same_v<FunctorType, typename ReducerType::functor_type>;
static constexpr int IsArray = std::is_pointer<reference_type>::value;

using ParReduceSpecialize =
ParallelReduceSpecialize<FunctorType, Policy, ReducerType, pointer_type,
typename Analysis::value_type>;
ParallelReduceSpecialize<FunctorType, Policy,
typename ReducerType::functor_type, pointer_type,
typename ReducerType::value_type>;

const FunctorType m_functor;
const CombinedFunctorReducerType m_functor_reducer;
const Policy m_policy;
const ReducerType m_reducer;
const pointer_type m_result_ptr;
bool m_result_ptr_on_device;
const int m_result_ptr_num_elems;
using TagType = typename Policy::work_tag;

public:
void execute() const {
if constexpr (HasJoin) {
const FunctorType& functor = m_functor_reducer.get_functor();
if constexpr (FunctorHasJoin) {
// Enter this loop if the Functor has a init-join.
ParReduceSpecialize::execute_init_join(m_functor, m_policy, m_result_ptr,
ParReduceSpecialize::execute_init_join(functor, m_policy, m_result_ptr,
m_result_ptr_on_device);
} else if constexpr (UseReducer) {
// Enter this loop if the Functor is a reducer type.
ParReduceSpecialize::execute_reducer(m_functor, m_policy, m_result_ptr,
ParReduceSpecialize::execute_reducer(functor, m_policy, m_result_ptr,
m_result_ptr_on_device);
} else if constexpr (IsArray) {
// Enter this loop if the reduction is on an array and the routine is
// templated over the size of the array.
if (m_result_ptr_num_elems <= 2) {
ParReduceSpecialize::template execute_array<TagType, 2>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 4) {
ParReduceSpecialize::template execute_array<TagType, 4>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 8) {
ParReduceSpecialize::template execute_array<TagType, 8>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 16) {
ParReduceSpecialize::template execute_array<TagType, 16>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 32) {
ParReduceSpecialize::template execute_array<TagType, 32>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else {
Kokkos::abort("array reduction length must be <= 32");
}
} else {
// This loop handles the basic scalar reduction.
ParReduceSpecialize::template execute_array<TagType, 1>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
}
}

template <class ViewType>
ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy,
const ViewType& arg_result_view,
std::enable_if_t<Kokkos::is_view<ViewType>::value &&
!Kokkos::is_reducer<ReducerType>::value,
void*> = nullptr)
: m_functor(arg_functor),
ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer,
const Policy& arg_policy, const ViewType& arg_result_view)
: m_functor_reducer(arg_functor_reducer),
m_policy(arg_policy),
m_reducer(InvalidType()),
m_result_ptr(arg_result_view.data()),
m_result_ptr_on_device(
MemorySpaceAccess<Kokkos::Experimental::OpenMPTargetSpace,
typename ViewType::memory_space>::accessible),
m_result_ptr_num_elems(arg_result_view.size()) {}

ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy,
const ReducerType& reducer)
: m_functor(arg_functor),
m_policy(arg_policy),
m_reducer(reducer),
m_result_ptr(reducer.view().data()),
m_result_ptr_on_device(
MemorySpaceAccess<Kokkos::Experimental::OpenMPTargetSpace,
typename ReducerType::result_view_type::
memory_space>::accessible),
m_result_ptr_num_elems(reducer.view().size()) {}
};

} // namespace Impl
Expand Down
91 changes: 35 additions & 56 deletions core/src/OpenMPTarget/Kokkos_OpenMPTarget_ParallelReduce_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,113 +432,92 @@ parallel_reduce(const Impl::TeamVectorRangeBoundariesStruct<

namespace Impl {

template <class FunctorType, class ReducerType, class... Properties>
class ParallelReduce<FunctorType, Kokkos::TeamPolicy<Properties...>,
ReducerType, Kokkos::Experimental::OpenMPTarget> {
template <class CombinedFunctorReducerType, class... Properties>
class ParallelReduce<CombinedFunctorReducerType,
Kokkos::TeamPolicy<Properties...>,
Kokkos::Experimental::OpenMPTarget> {
private:
using Policy =
Kokkos::Impl::TeamPolicyInternal<Kokkos::Experimental::OpenMPTarget,
Properties...>;
using FunctorType = typename CombinedFunctorReducerType::functor_type;
using ReducerType = typename CombinedFunctorReducerType::reducer_type;

using WorkTag = typename Policy::work_tag;
using Member = typename Policy::member_type;
using ReducerTypeFwd =
std::conditional_t<std::is_same<InvalidType, ReducerType>::value,
FunctorType, ReducerType>;
using WorkTagFwd =
std::conditional_t<std::is_same<InvalidType, ReducerType>::value, WorkTag,
void>;
using Analysis = Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE,
Policy, ReducerTypeFwd>;

using pointer_type = typename Analysis::pointer_type;
using reference_type = typename Analysis::reference_type;
using value_type = typename Analysis::value_type;

using pointer_type = typename ReducerType::pointer_type;
using reference_type = typename ReducerType::reference_type;
using value_type = typename ReducerType::value_type;

bool m_result_ptr_on_device;
const int m_result_ptr_num_elems;

static constexpr int HasJoin =
static constexpr int FunctorHasJoin =
Impl::FunctorAnalysis<Impl::FunctorPatternInterface::REDUCE, Policy,
FunctorType>::has_join_member_function;
static constexpr int UseReducer = is_reducer<ReducerType>::value;
static constexpr int IsArray = std::is_pointer<reference_type>::value;
static constexpr int UseReducer =
!std::is_same_v<FunctorType, typename ReducerType::functor_type>;
static constexpr int IsArray = std::is_pointer<reference_type>::value;

using ParReduceSpecialize =
ParallelReduceSpecialize<FunctorType, Policy, ReducerType, pointer_type,
typename Analysis::value_type>;
ParallelReduceSpecialize<FunctorType, Policy,
typename ReducerType::functor_type, pointer_type,
typename ReducerType::value_type>;

const FunctorType m_functor;
const CombinedFunctorReducerType m_functor_reducer;
const Policy m_policy;
const ReducerType m_reducer;
const pointer_type m_result_ptr;
const size_t m_shmem_size;

public:
void execute() const {
if constexpr (HasJoin) {
ParReduceSpecialize::execute_init_join(m_functor, m_policy, m_result_ptr,
const FunctorType& functor = m_functor_reducer.get_functor();
if constexpr (FunctorHasJoin) {
ParReduceSpecialize::execute_init_join(functor, m_policy, m_result_ptr,
m_result_ptr_on_device);
} else if constexpr (UseReducer) {
ParReduceSpecialize::execute_reducer(m_functor, m_policy, m_result_ptr,
ParReduceSpecialize::execute_reducer(functor, m_policy, m_result_ptr,
m_result_ptr_on_device);
} else if constexpr (IsArray) {
if (m_result_ptr_num_elems <= 2) {
ParReduceSpecialize::template execute_array<2>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 4) {
ParReduceSpecialize::template execute_array<4>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 8) {
ParReduceSpecialize::template execute_array<8>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 16) {
ParReduceSpecialize::template execute_array<16>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else if (m_result_ptr_num_elems <= 32) {
ParReduceSpecialize::template execute_array<32>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
} else {
Kokkos::abort("array reduction length must be <= 32");
}
} else {
ParReduceSpecialize::template execute_array<1>(
m_functor, m_policy, m_result_ptr, m_result_ptr_on_device);
functor, m_policy, m_result_ptr, m_result_ptr_on_device);
}
}

template <class ViewType>
ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy,
const ViewType& arg_result,
std::enable_if_t<Kokkos::is_view<ViewType>::value &&
!Kokkos::is_reducer<ReducerType>::value,
void*> = nullptr)
ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer,
const Policy& arg_policy, const ViewType& arg_result)
: m_result_ptr_on_device(
MemorySpaceAccess<Kokkos::Experimental::OpenMPTargetSpace,
typename ViewType::memory_space>::accessible),
m_result_ptr_num_elems(arg_result.size()),
m_functor(arg_functor),
m_functor_reducer(arg_functor_reducer),
m_policy(arg_policy),
m_reducer(InvalidType()),
m_result_ptr(arg_result.data()),
m_shmem_size(arg_policy.scratch_size(0) + arg_policy.scratch_size(1) +
FunctorTeamShmemSize<FunctorType>::value(
arg_functor, arg_policy.team_size())) {}

ParallelReduce(const FunctorType& arg_functor, const Policy& arg_policy,
const ReducerType& reducer)
: m_result_ptr_on_device(
MemorySpaceAccess<Kokkos::Experimental::OpenMPTargetSpace,
typename ReducerType::result_view_type::
memory_space>::accessible),
m_result_ptr_num_elems(reducer.view().size()),
m_functor(arg_functor),
m_policy(arg_policy),
m_reducer(reducer),
m_result_ptr(reducer.view().data()),
m_shmem_size(arg_policy.scratch_size(0) + arg_policy.scratch_size(1) +
FunctorTeamShmemSize<FunctorType>::value(
arg_functor, arg_policy.team_size())) {}
m_shmem_size(
arg_policy.scratch_size(0) + arg_policy.scratch_size(1) +
FunctorTeamShmemSize<FunctorType>::value(
arg_functor_reducer.get_functor(), arg_policy.team_size())) {}
};

} // namespace Impl
Expand Down
Loading

0 comments on commit 3707be7

Please sign in to comment.