Skip to content

Commit

Permalink
Use CombinedReducer in HostIterateTile
Browse files Browse the repository at this point in the history
  • Loading branch information
masterleinad committed Mar 7, 2023
1 parent 0f7b7eb commit 2b035de
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 44 deletions.
13 changes: 5 additions & 8 deletions core/src/HPX/Kokkos_HPX.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,19 +1105,17 @@ class ParallelReduce<CombinedFunctorReducerType,
using pointer_type = typename ReducerType::pointer_type;
using value_type = typename ReducerType::value_type;
using reference_type = typename ReducerType::reference_type;
using iterate_type =
typename Kokkos::Impl::HostIterateTile<MDRangePolicy, FunctorType,
WorkTag, reference_type>;
using iterate_type = typename Kokkos::Impl::HostIterateTile<
MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>;

const iterate_type m_iter;
const Policy m_policy;
const CombinedFunctorReducerType m_functor_reducer;
const pointer_type m_result_ptr;
const bool m_force_synchronous;

public:
void setup() const {
const ReducerType &reducer = m_functor_reducer.get_reducer();
const ReducerType &reducer = m_iter.m_func.get_reducer();
const std::size_t value_size = reducer.value_size();
const int num_worker_threads = m_policy.space().concurrency();

Expand All @@ -1143,7 +1141,7 @@ class ParallelReduce<CombinedFunctorReducerType,

void finalize() const {
hpx_thread_buffer &buffer = m_iter.m_rp.space().impl_get_buffer();
ReducerType reducer = m_functor_reducer.get_reducer();
ReducerType reducer = m_iter.m_func.get_reducer();
const int num_worker_threads = m_policy.space().concurrency();
for (int i = 1; i < num_worker_threads; ++i) {
reducer.join(reinterpret_cast<pointer_type>(buffer.get(0)),
Expand Down Expand Up @@ -1175,9 +1173,8 @@ class ParallelReduce<CombinedFunctorReducerType,
template <class ViewType>
inline ParallelReduce(const CombinedFunctorReducerType &arg_functor_reducer,
MDRangePolicy arg_policy, const ViewType &arg_view)
: m_iter(arg_policy, arg_functor_reducer.get_functor()),
: m_iter(arg_policy, arg_functor_reducer),
m_policy(Policy(0, arg_policy.m_num_tiles).set_chunk_size(1)),
m_functor_reducer(arg_functor_reducer),
m_result_ptr(arg_view.data()),
m_force_synchronous(!arg_view.impl_track().has_record()) {}

Expand Down
28 changes: 13 additions & 15 deletions core/src/OpenMP/Kokkos_OpenMP_Parallel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,11 @@ class ParallelReduce<CombinedFunctorReducerType,
using value_type = typename ReducerType::value_type;
using reference_type = typename ReducerType::reference_type;

using iterate_type =
typename Kokkos::Impl::HostIterateTile<MDRangePolicy, FunctorType,
WorkTag, reference_type>;
using iterate_type = typename Kokkos::Impl::HostIterateTile<
MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>;

OpenMPInternal* m_instance;
const iterate_type m_iter;
const ReducerType m_reducer;
const pointer_type m_result_ptr;

inline void exec_range(const Member ibeg, const Member iend,
Expand All @@ -485,7 +483,8 @@ class ParallelReduce<CombinedFunctorReducerType,

public:
inline void execute() const {
const size_t pool_reduce_bytes = m_reducer.value_size();
const ReducerType& reducer = m_iter.m_func.get_reducer();
const size_t pool_reduce_bytes = reducer.value_size();

m_instance->acquire_lock();

Expand All @@ -504,11 +503,11 @@ class ParallelReduce<CombinedFunctorReducerType,
: pointer_type(
m_instance->get_thread_data(0)->pool_reduce_local());

reference_type update = m_reducer.init(ptr);
reference_type update = reducer.init(ptr);

ParallelReduce::exec_range(0, m_iter.m_rp.m_num_tiles, update);

m_reducer.final(ptr);
reducer.final(ptr);

m_instance->release_lock();

Expand All @@ -533,7 +532,7 @@ class ParallelReduce<CombinedFunctorReducerType,
if (data.pool_rendezvous()) data.pool_rendezvous_release();
}

reference_type update = m_reducer.init(
reference_type update = reducer.init(
reinterpret_cast<pointer_type>(data.pool_reduce_local()));

std::pair<int64_t, int64_t> range(0, 0);
Expand All @@ -554,15 +553,15 @@ class ParallelReduce<CombinedFunctorReducerType,
pointer_type(m_instance->get_thread_data(0)->pool_reduce_local());

for (int i = 1; i < pool_size; ++i) {
m_reducer.join(ptr,
reinterpret_cast<pointer_type>(
m_instance->get_thread_data(i)->pool_reduce_local()));
reducer.join(ptr,
reinterpret_cast<pointer_type>(
m_instance->get_thread_data(i)->pool_reduce_local()));
}

m_reducer.final(ptr);
reducer.final(ptr);

if (m_result_ptr) {
const int n = m_reducer.value_count();
const int n = reducer.value_count();

for (int j = 0; j < n; ++j) {
m_result_ptr[j] = ptr[j];
Expand All @@ -578,8 +577,7 @@ class ParallelReduce<CombinedFunctorReducerType,
ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer,
MDRangePolicy arg_policy, const ViewType& arg_view)
: m_instance(nullptr),
m_iter(arg_policy, arg_functor_reducer.get_functor()),
m_reducer(arg_functor_reducer.get_reducer()),
m_iter(arg_policy, arg_functor_reducer),
m_result_ptr(arg_view.data()) {
#ifdef KOKKOS_ENABLE_DEPRECATED_CODE_3
if (t_openmp_instance) {
Expand Down
16 changes: 7 additions & 9 deletions core/src/Serial/Kokkos_Serial_Parallel_MDRange.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,9 @@ class ParallelReduce<CombinedFunctorReducerType,
using value_type = typename ReducerType::value_type;
using reference_type = typename ReducerType::reference_type;

using iterate_type =
typename Kokkos::Impl::HostIterateTile<MDRangePolicy, FunctorType,
WorkTag, reference_type>;
using iterate_type = typename Kokkos::Impl::HostIterateTile<
MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>;
const iterate_type m_iter;
const ReducerType m_reducer;
const pointer_type m_result_ptr;

inline void exec(reference_type update) const {
Expand All @@ -98,7 +96,8 @@ class ParallelReduce<CombinedFunctorReducerType,
return 1024;
}
inline void execute() const {
const size_t pool_reduce_size = m_reducer.value_size();
const ReducerType& reducer = m_iter.m_func.get_reducer();
const size_t pool_reduce_size = reducer.value_size();
const size_t team_reduce_size = 0; // Never shrinks
const size_t team_shared_size = 0; // Never shrinks
const size_t thread_local_size = 0; // Never shrinks
Expand All @@ -118,19 +117,18 @@ class ParallelReduce<CombinedFunctorReducerType,
: pointer_type(
internal_instance->m_thread_team_data.pool_reduce_local());

reference_type update = m_reducer.init(ptr);
reference_type update = reducer.init(ptr);

this->exec(update);

m_reducer.final(ptr);
reducer.final(ptr);
}

template <class ViewType>
ParallelReduce(const CombinedFunctorReducerType& arg_functor_reducer,
const MDRangePolicy& arg_policy,
const ViewType& arg_result_view)
: m_iter(arg_policy, arg_functor_reducer.get_functor()),
m_reducer(arg_functor_reducer.get_reducer()),
: m_iter(arg_policy, arg_functor_reducer),
m_result_ptr(arg_result_view.data()) {
static_assert(Kokkos::is_view<ViewType>::value,
"Kokkos::Serial reduce result must be a View");
Expand Down
20 changes: 10 additions & 10 deletions core/src/Threads/Kokkos_Threads_Parallel_MDRange.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,10 @@ class ParallelReduce<CombinedFunctorReducerType,
using value_type = typename ReducerType::value_type;
using reference_type = typename ReducerType::reference_type;

using iterate_type =
typename Kokkos::Impl::HostIterateTile<MDRangePolicy, FunctorType,
WorkTag, reference_type>;
using iterate_type = typename Kokkos::Impl::HostIterateTile<
MDRangePolicy, CombinedFunctorReducerType, WorkTag, reference_type>;

const iterate_type m_iter;
const ReducerType m_reducer;
const pointer_type m_result_ptr;

inline void exec_range(const Member &ibeg, const Member &iend,
Expand All @@ -156,11 +154,12 @@ class ParallelReduce<CombinedFunctorReducerType,
const WorkRange range(Policy(0, num_tiles).set_chunk_size(1),
exec.pool_rank(), exec.pool_size());

const ReducerType &reducer = self.m_iter.m_func.get_reducer();
self.exec_range(
range.begin(), range.end(),
self.m_reducer.init(static_cast<pointer_type>(exec.reduce_memory())));
reducer.init(static_cast<pointer_type>(exec.reduce_memory())));

exec.fan_in_reduce(self.m_reducer);
exec.fan_in_reduce(reducer);
}

template <class Schedule>
Expand All @@ -178,6 +177,7 @@ class ParallelReduce<CombinedFunctorReducerType,

long work_index = exec.get_work_index();

// Initialize this thread's reduction value through the reducer stored in
// m_iter.m_func, the same way the static-schedule path and execute() do.
// The previous call went through self.m_reducer, which this change set
// stops declaring and initializing, and left the local `reducer` unused.
// NOTE(review): any later self.m_reducer use in this function (outside the
// visible hunk, e.g. a fan_in_reduce call) needs the same replacement --
// confirm against the full file.
const ReducerType &reducer = self.m_iter.m_func.get_reducer();
reference_type update =
    reducer.init(static_cast<pointer_type>(exec.reduce_memory()));
while (work_index != -1) {
Expand All @@ -192,7 +192,8 @@ class ParallelReduce<CombinedFunctorReducerType,

public:
inline void execute() const {
ThreadsExec::resize_scratch(m_reducer.value_size(), 0);
const ReducerType &reducer = m_iter.m_func.get_reducer();
ThreadsExec::resize_scratch(reducer.value_size(), 0);

ThreadsExec::start(&ParallelReduce::exec, this);

Expand All @@ -202,7 +203,7 @@ class ParallelReduce<CombinedFunctorReducerType,
const pointer_type data =
(pointer_type)ThreadsExec::root_reduce_scratch();

const unsigned n = m_reducer.value_count();
const unsigned n = reducer.value_count();
for (unsigned i = 0; i < n; ++i) {
m_result_ptr[i] = data[i];
}
Expand All @@ -213,8 +214,7 @@ class ParallelReduce<CombinedFunctorReducerType,
ParallelReduce(const CombinedFunctorReducerType &arg_functor_reducer,
const MDRangePolicy &arg_policy,
const ViewType &arg_result_view)
: m_iter(arg_policy, arg_functor_reducer.get_functor()),
m_reducer(arg_functor_reducer.get_reducer()),
: m_iter(arg_policy, arg_functor_reducer),
m_result_ptr(arg_result_view.data()) {
static_assert(Kokkos::is_view<ViewType>::value,
"Kokkos::Threads reduce result must be a View");
Expand Down
4 changes: 2 additions & 2 deletions core/src/impl/KokkosExp_Host_IterateTile.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2093,8 +2093,8 @@ struct HostIterateTile<RP, Functor, Tag, ValueType,
const bool full_tile = check_iteration_bounds(m_tiledims, m_offset);

Tile_Loop_Type<RP::rank, (RP::inner_direction == Iterate::Left), index_type,
Tag>::apply(val, m_func, full_tile, m_offset, m_rp.m_tile,
m_tiledims);
Tag>::apply(val, m_func.get_functor(), full_tile, m_offset,
m_rp.m_tile, m_tiledims);
}

#else
Expand Down

0 comments on commit 2b035de

Please sign in to comment.