From a6f27bf738df97b3274679ad8918a1c10249c849 Mon Sep 17 00:00:00 2001 From: Daniel Arndt Date: Fri, 24 Mar 2023 13:17:55 +0000 Subject: [PATCH] Pass local_accessor directly instead --- core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp | 41 +++++++++---------- core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp | 8 ++-- core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp | 8 ++-- 3 files changed, 27 insertions(+), 30 deletions(-) diff --git a/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp b/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp index 1a3350cedc..4bdedc64e1 100644 --- a/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp +++ b/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp @@ -37,7 +37,7 @@ inline constexpr bool use_shuffle_based_algorithm = namespace SYCLReduction { template std::enable_if_t> workgroup_reduction( - sycl::nd_item& item, sycl::local_ptr local_mem, + sycl::nd_item& item, sycl::local_accessor local_mem, sycl::device_ptr results_ptr, sycl::global_ptr device_accessible_result_ptr, const unsigned int value_count, const ReducerType& final_reducer, @@ -109,7 +109,7 @@ std::enable_if_t> workgroup_reduction( template std::enable_if_t> workgroup_reduction( - sycl::nd_item& item, sycl::local_ptr local_mem, + sycl::nd_item& item, sycl::local_accessor local_mem, ValueType local_value, sycl::device_ptr results_ptr, sycl::global_ptr device_accessible_result_ptr, const ReducerType& final_reducer, bool final, unsigned int max_size) { @@ -271,8 +271,8 @@ class ParallelReduce, instance.scratch_flags(sizeof(unsigned int))); auto reduction_lambda_factory = - [&](sycl::local_accessor local_mem, - sycl::local_accessor num_teams_done, + [&](sycl::local_accessor local_mem, + sycl::local_accessor num_teams_done, sycl::device_ptr results_ptr) { const auto begin = policy.begin(); @@ -304,9 +304,8 @@ class ParallelReduce, item.barrier(sycl::access::fence_space::local_space); SYCLReduction::workgroup_reduction<>( - item, local_mem.get_pointer(), results_ptr, - device_accessible_result_ptr, value_count, reducer, false, - std::min(size, wgroup_size)); + item, local_mem, results_ptr, device_accessible_result_ptr, + value_count, reducer, false, std::min(size, wgroup_size)); if (local_id == 0) { sycl::atomic_ref, } SYCLReduction::workgroup_reduction<>( - item, local_mem.get_pointer(), results_ptr, + item, local_mem, results_ptr, device_accessible_result_ptr, value_count, reducer, true, std::min(n_wgroups, wgroup_size)); } @@ -346,7 +345,7 @@ class ParallelReduce, } SYCLReduction::workgroup_reduction<>( - item, local_mem.get_pointer(), local_value, results_ptr, + item, local_mem, local_value, results_ptr, device_accessible_result_ptr, reducer, false, std::min(size, wgroup_size)); @@ -370,7 +369,7 @@ class ParallelReduce, } SYCLReduction::workgroup_reduction<>( - item, local_mem.get_pointer(), local_value, results_ptr, + item, local_mem, local_value, results_ptr, device_accessible_result_ptr, reducer, true, std::min(n_wgroups, wgroup_size)); } @@ -380,7 +379,7 @@ class ParallelReduce, }; auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) { - sycl::local_accessor num_teams_done(1, cgh); + sycl::local_accessor num_teams_done(1, cgh); auto dummy_reduction_lambda = reduction_lambda_factory({1, cgh}, num_teams_done, nullptr); @@ -421,7 +420,7 @@ class ParallelReduce, wgroup_size - 1) / wgroup_size; - sycl::local_accessor local_mem( + sycl::local_accessor local_mem( sycl::range<1>(wgroup_size) * std::max(value_count, 1u), cgh); cgh.depends_on(memcpy_events); @@ -608,9 +607,9 @@ class ParallelReduce 1) { auto n_wgroups = (size + wgroup_size - 1) / wgroup_size; auto parallel_reduce_event = q.submit([&](sycl::handler& cgh) { - sycl::local_accessor local_mem( + sycl::local_accessor local_mem( sycl::range<1>(wgroup_size) * std::max(value_count, 1u), cgh); - sycl::local_accessor num_teams_done(1, cgh); + sycl::local_accessor num_teams_done(1, cgh); const BarePolicy bare_policy = m_policy; @@ -652,9 +651,8 @@ class ParallelReduce( - item, local_mem.get_pointer(), results_ptr, - device_accessible_result_ptr, value_count, reducer, false, - std::min(size, wgroup_size)); + item, local_mem, results_ptr, device_accessible_result_ptr, + value_count, reducer, false, std::min(size, wgroup_size)); if (local_id == 0) { sycl::atomic_ref( - item, local_mem.get_pointer(), results_ptr, - device_accessible_result_ptr, value_count, reducer, true, - std::min(n_wgroups, wgroup_size)); + item, local_mem, results_ptr, device_accessible_result_ptr, + value_count, reducer, true, std::min(n_wgroups, wgroup_size)); } } else { value_type local_value; @@ -695,7 +692,7 @@ class ParallelReduce( - item, local_mem.get_pointer(), local_value, results_ptr, + item, local_mem, local_value, results_ptr, device_accessible_result_ptr, reducer, false, std::min(size, wgroup_size)); @@ -719,7 +716,7 @@ class ParallelReduce( - item, local_mem.get_pointer(), local_value, results_ptr, + item, local_mem, local_value, results_ptr, device_accessible_result_ptr, reducer, true, std::min(n_wgroups, wgroup_size)); } diff --git a/core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp b/core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp index 3bd25b1f23..5176c0f14e 100644 --- a/core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp +++ b/core/src/SYCL/Kokkos_SYCL_Parallel_Scan.hpp @@ -31,7 +31,7 @@ namespace Impl { // total sum. template void workgroup_scan(sycl::nd_item item, const FunctorType& final_reducer, - sycl::local_ptr local_mem, + sycl::local_accessor local_mem, ValueType& local_value, unsigned int global_range) { // subgroup scans auto sg = item.get_sub_group(); @@ -136,7 +136,7 @@ class ParallelScanSYCLBase { q.get_device() .template get_info() .front(); - sycl::local_accessor local_mem( + sycl::local_accessor local_mem( sycl::range<1>((wgroup_size + min_subgroup_size - 1) / min_subgroup_size), cgh); @@ -160,8 +160,8 @@ class ParallelScanSYCLBase { else reducer.init(&local_value); - workgroup_scan<>(item, reducer, local_mem.get_pointer(), - local_value, wgroup_size); + workgroup_scan<>(item, reducer, local_mem, local_value, + wgroup_size); if (n_wgroups > 1 && local_id == wgroup_size - 1) group_results[item.get_group_linear_id()] = diff --git a/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp b/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp index 62a41fe91f..c1a3133428 100644 --- a/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp +++ b/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp @@ -668,7 +668,7 @@ class ParallelReduce( - item, local_mem.get_pointer(), results_ptr, + item, local_mem, results_ptr, device_accessible_result_ptr, value_count, reducer, false, std::min(size, item.get_local_range()[0] * @@ -696,7 +696,7 @@ class ParallelReduce( - item, local_mem.get_pointer(), results_ptr, + item, local_mem, results_ptr, device_accessible_result_ptr, value_count, reducer, true, std::min(n_wgroups, item.get_local_range()[0] * @@ -716,7 +716,7 @@ class ParallelReduce( - item, local_mem.get_pointer(), local_value, results_ptr, + item, local_mem, local_value, results_ptr, device_accessible_result_ptr, reducer, false, std::min(size, item.get_local_range()[0] * @@ -742,7 +742,7 @@ class ParallelReduce( - item, local_mem.get_pointer(), local_value, results_ptr, + item, local_mem, local_value, results_ptr, device_accessible_result_ptr, reducer, true, std::min(n_wgroups, item.get_local_range()[0] * item.get_local_range()[1]));