diff --git a/core/src/Cuda/Kokkos_Cuda_Parallel_Range.hpp b/core/src/Cuda/Kokkos_Cuda_Parallel_Range.hpp
index 620ef67927..904d1d670e 100644
--- a/core/src/Cuda/Kokkos_Cuda_Parallel_Range.hpp
+++ b/core/src/Cuda/Kokkos_Cuda_Parallel_Range.hpp
@@ -421,7 +421,8 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, Kokkos::Cuda> {
   //  (c) gridDim.x  <= blockDim.y * blockDim.y
   //  (d) gridDim.y  == gridDim.z == 1
 
-  const FunctorType m_functor;
+  const CombinedFunctorReducer<FunctorType, typename Analysis::Reducer>
+      m_functor_reducer;
   const Policy m_policy;
   word_size_type* m_scratch_space;
   size_type* m_scratch_flags;
@@ -433,23 +434,25 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, Kokkos::Cuda> {
   template <class TagType>
   __device__ inline std::enable_if_t<std::is_void<TagType>::value> exec_range(
       const Member& i, reference_type update, const bool final_result) const {
-    m_functor(i, update, final_result);
+    m_functor_reducer.get_functor()(i, update, final_result);
   }
 
   template <class TagType>
   __device__ inline std::enable_if_t<!std::is_void<TagType>::value> exec_range(
       const Member& i, reference_type update, const bool final_result) const {
-    m_functor(TagType(), i, update, final_result);
+    m_functor_reducer.get_functor()(TagType(), i, update, final_result);
   }
 
   //----------------------------------------
 
   __device__ inline void initial() const {
-    typename Analysis::Reducer final_reducer(m_functor);
+    const typename Analysis::Reducer& final_reducer =
+        m_functor_reducer.get_reducer();
 
     const integral_nonzero_constant<word_size_type, Analysis::StaticValueSize /
                                                         sizeof(word_size_type)>
-        word_count(Analysis::value_size(m_functor) / sizeof(word_size_type));
+        word_count(Analysis::value_size(m_functor_reducer.get_functor()) /
+                   sizeof(word_size_type));
 
     word_size_type* const shared_value =
         kokkos_impl_cuda_shared_memory<word_size_type>() +
@@ -485,11 +488,13 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, Kokkos::Cuda> {
   //----------------------------------------
 
   __device__ inline void final() const {
-    typename Analysis::Reducer final_reducer(m_functor);
+    const typename Analysis::Reducer& final_reducer =
+        m_functor_reducer.get_reducer();
 
     const integral_nonzero_constant<word_size_type, Analysis::StaticValueSize /
                                                         sizeof(word_size_type)>
-        word_count(Analysis::value_size(m_functor) / sizeof(word_size_type));
+        word_count(Analysis::value_size(m_functor_reducer.get_functor()) /
+                   sizeof(word_size_type));
 
     // Use shared memory as an exclusive scan: { 0 , value[0] , value[1] ,
     // value[2] , ... }
@@ -619,7 +624,7 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, Kokkos::Cuda> {
     if (nwork) {
       constexpr int GridMaxComputeCapability_2x = 0x0ffff;
 
-      const int block_size = local_block_size(m_functor);
+      const int block_size = local_block_size(m_functor_reducer.get_functor());
       KOKKOS_ASSERT(block_size > 0);
 
       const int grid_max =
@@ -639,13 +644,15 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, Kokkos::Cuda> {
 
       m_scratch_space =
           reinterpret_cast<word_size_type*>(cuda_internal_scratch_space(
-              m_policy.space(), Analysis::value_size(m_functor) * grid_x));
+              m_policy.space(),
+              Analysis::value_size(m_functor_reducer.get_functor()) * grid_x));
       m_scratch_flags =
           cuda_internal_scratch_flags(m_policy.space(), sizeof(size_type) * 1);
 
       dim3 grid(grid_x, 1, 1);
       dim3 block(1, block_size, 1);  // REQUIRED DIMENSIONS ( 1 , N , 1 )
-      const int shmem = Analysis::value_size(m_functor) * (block_size + 2);
+      const int shmem = Analysis::value_size(m_functor_reducer.get_functor()) *
+                        (block_size + 2);
 
 #ifdef KOKKOS_IMPL_DEBUG_CUDA_SERIAL_EXECUTION
       if (m_run_serial) {
@@ -670,7 +677,7 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>, Kokkos::Cuda> {
   }
 
   ParallelScan(const FunctorType& arg_functor, const Policy& arg_policy)
-      : m_functor(arg_functor),
+      : m_functor_reducer(arg_functor, typename Analysis::Reducer{arg_functor}),
         m_policy(arg_policy),
         m_scratch_space(nullptr),
         m_scratch_flags(nullptr),
@@ -728,7 +735,8 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
   //  (c) gridDim.x  <= blockDim.y * blockDim.y
   //  (d) gridDim.y  == gridDim.z == 1
 
-  const FunctorType m_functor;
+  const CombinedFunctorReducer<FunctorType, typename Analysis::Reducer>
+      m_functor_reducer;
   const Policy m_policy;
   word_size_type* m_scratch_space;
   size_type* m_scratch_flags;
@@ -743,23 +751,25 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
   template <class TagType>
   __device__ inline std::enable_if_t<std::is_void<TagType>::value> exec_range(
       const Member& i, reference_type update, const bool final_result) const {
-    m_functor(i, update, final_result);
+    m_functor_reducer.get_functor()(i, update, final_result);
   }
 
   template <class TagType>
   __device__ inline std::enable_if_t<!std::is_void<TagType>::value> exec_range(
       const Member& i, reference_type update, const bool final_result) const {
-    m_functor(TagType(), i, update, final_result);
+    m_functor_reducer.get_functor()(TagType(), i, update, final_result);
   }
 
   //----------------------------------------
 
   __device__ inline void initial() const {
-    typename Analysis::Reducer final_reducer(m_functor);
+    const typename Analysis::Reducer& final_reducer =
+        m_functor_reducer.get_reducer();
 
     const integral_nonzero_constant<word_size_type, Analysis::StaticValueSize /
                                                         sizeof(word_size_type)>
-        word_count(Analysis::value_size(m_functor) / sizeof(word_size_type));
+        word_count(Analysis::value_size(m_functor_reducer.get_functor()) /
+                   sizeof(word_size_type));
 
     word_size_type* const shared_value =
         kokkos_impl_cuda_shared_memory<word_size_type>() +
@@ -795,11 +805,12 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
   //----------------------------------------
 
   __device__ inline void final() const {
-    typename Analysis::Reducer final_reducer(m_functor);
+    const typename Analysis::Reducer& final_reducer =
+        m_functor_reducer.get_reducer();
 
     const integral_nonzero_constant<word_size_type, Analysis::StaticValueSize /
                                                         sizeof(word_size_type)>
-        word_count(Analysis::value_size(m_functor) / sizeof(word_size_type));
+        word_count(final_reducer.value_size() / sizeof(word_size_type));
 
     // Use shared memory as an exclusive scan: { 0 , value[0] , value[1] ,
     // value[2] , ... }
@@ -935,7 +946,7 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
     if (nwork) {
      enum { GridMaxComputeCapability_2x = 0x0ffff };
 
-      const int block_size = local_block_size(m_functor);
+      const int block_size = local_block_size(m_functor_reducer.get_functor());
       KOKKOS_ASSERT(block_size > 0);
 
       const int grid_max =
@@ -953,15 +964,17 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
       // How many block are really needed for this much work:
       const int grid_x = (nwork + work_per_block - 1) / work_per_block;
 
+      const typename Analysis::Reducer& final_reducer =
+          m_functor_reducer.get_reducer();
       m_scratch_space =
           reinterpret_cast<word_size_type*>(cuda_internal_scratch_space(
-              m_policy.space(), Analysis::value_size(m_functor) * grid_x));
+              m_policy.space(), final_reducer.value_size() * grid_x));
       m_scratch_flags =
           cuda_internal_scratch_flags(m_policy.space(), sizeof(size_type) * 1);
 
       dim3 grid(grid_x, 1, 1);
       dim3 block(1, block_size, 1);  // REQUIRED DIMENSIONS ( 1 , N , 1 )
-      const int shmem = Analysis::value_size(m_functor) * (block_size + 2);
+      const int shmem = final_reducer.value_size() * (block_size + 2);
 
 #ifdef KOKKOS_IMPL_DEBUG_CUDA_SERIAL_EXECUTION
       if (m_run_serial) {
@@ -982,7 +995,7 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
           m_policy.space()
               .impl_internal_space_instance());  // copy to device and execute
 
-      const int size = Analysis::value_size(m_functor);
+      const int size = final_reducer.value_size();
 #ifdef KOKKOS_IMPL_DEBUG_CUDA_SERIAL_EXECUTION
       if (m_run_serial)
         DeepCopy<HostSpace, CudaSpace, Cuda>(m_policy.space(), &m_returnvalue,
@@ -1003,7 +1016,7 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
   ParallelScanWithTotal(const FunctorType& arg_functor,
                         const Policy& arg_policy,
                         const ViewType& arg_result_view)
-      : m_functor(arg_functor),
+      : m_functor_reducer(arg_functor, typename Analysis::Reducer{arg_functor}),
         m_policy(arg_policy),
         m_scratch_space(nullptr),
         m_scratch_flags(nullptr),
diff --git a/core/src/Serial/Kokkos_Serial_Parallel_Range.hpp b/core/src/Serial/Kokkos_Serial_Parallel_Range.hpp
index 01089677a2..5840cc736d 100644
--- a/core/src/Serial/Kokkos_Serial_Parallel_Range.hpp
+++ b/core/src/Serial/Kokkos_Serial_Parallel_Range.hpp
@@ -154,7 +154,8 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
   using pointer_type   = typename Analysis::pointer_type;
   using reference_type = typename Analysis::reference_type;
 
-  const FunctorType m_functor;
+  const CombinedFunctorReducer<FunctorType, typename Analysis::Reducer>
+      m_functor_reducer;
   const Policy m_policy;
 
   template <class TagType>
@@ -162,7 +163,7 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
       reference_type update) const {
     const typename Policy::member_type e = m_policy.end();
     for (typename Policy::member_type i = m_policy.begin(); i < e; ++i) {
-      m_functor(i, update, true);
+      m_functor_reducer.get_functor()(i, update, true);
     }
   }
 
@@ -172,13 +173,15 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
     const TagType t{};
     const typename Policy::member_type e = m_policy.end();
     for (typename Policy::member_type i = m_policy.begin(); i < e; ++i) {
-      m_functor(t, i, update, true);
+      m_functor_reducer.get_functor()(t, i, update, true);
     }
   }
 
  public:
   inline void execute() const {
-    const size_t pool_reduce_size  = Analysis::value_size(m_functor);
+    const typename Analysis::Reducer& final_reducer =
+        m_functor_reducer.get_reducer();
+    const size_t pool_reduce_size  = final_reducer.value_size();
     const size_t team_reduce_size  = 0;  // Never shrinks
     const size_t team_shared_size  = 0;  // Never shrinks
     const size_t thread_local_size = 0;  // Never shrinks
@@ -191,8 +194,6 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
         pool_reduce_size, team_reduce_size, team_shared_size,
         thread_local_size);
 
-    typename Analysis::Reducer final_reducer(m_functor);
-
     reference_type update = final_reducer.init(pointer_type(
         internal_instance->m_thread_team_data.pool_reduce_local()));
 
@@ -200,7 +201,8 @@ class ParallelScan<FunctorType, Kokkos::RangePolicy<Traits...>,
   }
 
   inline ParallelScan(const FunctorType& arg_functor, const Policy& arg_policy)
-      : m_functor(arg_functor), m_policy(arg_policy) {}
+      : m_functor_reducer(arg_functor, typename Analysis::Reducer{arg_functor}),
+        m_policy(arg_policy) {}
 };
 
 /*--------------------------------------------------------------------------*/
@@ -218,7 +220,8 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
   using pointer_type   = typename Analysis::pointer_type;
   using reference_type = typename Analysis::reference_type;
 
-  const FunctorType m_functor;
+  const CombinedFunctorReducer<FunctorType, typename Analysis::Reducer>
+      m_functor_reducer;
   const Policy m_policy;
   const pointer_type m_result_ptr;
 
@@ -227,7 +230,7 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
       reference_type update) const {
     const typename Policy::member_type e = m_policy.end();
     for (typename Policy::member_type i = m_policy.begin(); i < e; ++i) {
-      m_functor(i, update, true);
+      m_functor_reducer.get_functor()(i, update, true);
     }
   }
 
@@ -237,13 +240,14 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
     const TagType t{};
     const typename Policy::member_type e = m_policy.end();
     for (typename Policy::member_type i = m_policy.begin(); i < e; ++i) {
-      m_functor(t, i, update, true);
+      m_functor_reducer.get_functor()(t, i, update, true);
     }
   }
 
  public:
   inline void execute() {
-    const size_t pool_reduce_size  = Analysis::value_size(m_functor);
+    const size_t pool_reduce_size =
+        m_functor_reducer.get_reducer().value_size();
     const size_t team_reduce_size  = 0;  // Never shrinks
     const size_t team_shared_size  = 0;  // Never shrinks
     const size_t thread_local_size = 0;  // Never shrinks
@@ -256,7 +260,8 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
         pool_reduce_size, team_reduce_size, team_shared_size,
         thread_local_size);
 
-    typename Analysis::Reducer final_reducer(m_functor);
+    const typename Analysis::Reducer& final_reducer =
+        m_functor_reducer.get_reducer();
 
     reference_type update = final_reducer.init(pointer_type(
         internal_instance->m_thread_team_data.pool_reduce_local()));
@@ -271,7 +276,7 @@ class ParallelScanWithTotal<FunctorType, Kokkos::RangePolicy<Traits...>,
   ParallelScanWithTotal(const FunctorType& arg_functor,
                         const Policy& arg_policy,
                         const ViewType& arg_result_view)
-      : m_functor(arg_functor),
+      : m_functor_reducer(arg_functor, typename Analysis::Reducer{arg_functor}),
        m_policy(arg_policy),
        m_result_ptr(arg_result_view.data()) {
    static_assert(
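
Note for reviewers: every call site in this patch assumes that CombinedFunctorReducer exposes get_functor() and get_reducer(), and that the reducer's value_size() replaces the old Analysis::value_size(m_functor) queries. Below is a minimal standalone sketch of that contract, assuming only the member names visible in the hunks above; the class layout and the toy SumReducer are illustrative stand-ins, not the actual Kokkos implementation.

#include <cstddef>
#include <iostream>

// Illustrative stand-in for Analysis::Reducer: it can initialize a partial
// result and report how many bytes one result occupies (value_size()).
template <class ValueType>
struct SumReducer {
  using value_type = ValueType;
  std::size_t value_size() const { return sizeof(ValueType); }
  void init(ValueType* dst) const { *dst = ValueType(); }
};

// Illustrative stand-in for the CombinedFunctorReducer wrapper: it bundles
// the user functor and the reducer into one member, so the ParallelScan
// classes above carry a single m_functor_reducer object and query either
// half through get_functor() / get_reducer().
template <class Functor, class Reducer>
class CombinedFunctorReducer {
 public:
  CombinedFunctorReducer(const Functor& f, const Reducer& r)
      : m_functor(f), m_reducer(r) {}
  const Functor& get_functor() const { return m_functor; }
  const Reducer& get_reducer() const { return m_reducer; }

 private:
  Functor m_functor;
  Reducer m_reducer;
};

// A scan functor with the (i, update, final_pass) signature the diff invokes.
struct ScanFunctor {
  void operator()(int i, long& update, bool final_pass) const {
    update += i;
    if (final_pass)
      std::cout << "prefix sum at " << i << " = " << update << '\n';
  }
};

int main() {
  CombinedFunctorReducer<ScanFunctor, SumReducer<long>> fr{ScanFunctor{},
                                                           SumReducer<long>{}};
  // Mirrors exec_range(): the functor is invoked through the wrapper.
  long update = 0;
  for (int i = 0; i < 5; ++i) fr.get_functor()(i, update, true);
  // Mirrors the scratch sizing: bytes per partial result come from the
  // reducer rather than from a free-standing Analysis query on the functor.
  std::cout << "value_size = " << fr.get_reducer().value_size() << '\n';
}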