From 2e51c67a721f2d449c71623ece67388723f9c382 Mon Sep 17 00:00:00 2001 From: Daniel Arndt Date: Mon, 24 Apr 2023 20:09:14 +0000 Subject: [PATCH] Explicitly cast to CombinedFunctorReducerType --- core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp | 34 +++++++++---------- core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp | 16 ++++----- 2 files changed, 25 insertions(+), 25 deletions(-) diff --git a/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp b/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp index 0e2fee1da8..a29e8010d8 100644 --- a/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp +++ b/core/src/SYCL/Kokkos_SYCL_Parallel_Reduce.hpp @@ -242,11 +242,11 @@ class ParallelReduce, const auto begin = policy.begin(); cgh.depends_on(memcpy_event); cgh.single_task([=]() { - const FunctorType& functor = - functor_reducer_wrapper.get_functor().get_functor(); - const ReducerType& reducer = - functor_reducer_wrapper.get_functor().get_reducer(); - reference_type update = reducer.init(results_ptr); + const CombinedFunctorReducerType& functor_reducer = + functor_reducer_wrapper.get_functor(); + const FunctorType& functor = functor_reducer.get_functor(); + const ReducerType& reducer = functor_reducer.get_reducer(); + reference_type update = reducer.init(results_ptr); if (size == 1) { if constexpr (std::is_void_v) functor(begin, update); @@ -285,10 +285,10 @@ class ParallelReduce, const auto global_id = wgroup_size * item.get_group_linear_id() * values_per_thread + local_id; - const FunctorType& functor = - functor_reducer_wrapper.get_functor().get_functor(); - const ReducerType& reducer = - functor_reducer_wrapper.get_functor().get_reducer(); + const CombinedFunctorReducerType& functor_reducer = + functor_reducer_wrapper.get_functor(); + const FunctorType& functor = functor_reducer.get_functor(); + const ReducerType& reducer = functor_reducer.get_reducer(); using index_type = typename Policy::index_type; const auto upper_bound = std::min( @@ -578,10 +578,10 @@ class ParallelReduce item) { const auto local_id = item.get_local_linear_id(); - const FunctorType& functor = - functor_reducer_wrapper.get_functor().get_functor(); - const ReducerType& reducer = - functor_reducer_wrapper.get_functor().get_reducer(); + const CombinedFunctorReducerType& functor_reducer = + functor_reducer_wrapper.get_functor(); + const FunctorType& functor = functor_reducer.get_functor(); + const ReducerType& reducer = functor_reducer.get_reducer(); // In the first iteration, we call functor to initialize the local // memory. Otherwise, the local memory is initialized with the diff --git a/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp b/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp index c11e37f9b8..b543e94a0b 100644 --- a/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp +++ b/core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp @@ -596,10 +596,10 @@ class ParallelReduce(sycl::range<2>(1, 1), sycl::range<2>(1, 1)), [=](sycl::nd_item<2> item) { - const FunctorType& functor = - functor_reducer_wrapper.get_functor().get_functor(); - const ReducerType& reducer = - functor_reducer_wrapper.get_functor().get_reducer(); + const CombinedFunctorReducerType& functor_reducer = + functor_reducer_wrapper.get_functor(); + const FunctorType& functor = functor_reducer.get_functor(); + const ReducerType& reducer = functor_reducer.get_reducer(); reference_type update = reducer.init(results_ptr); if (size == 1) { @@ -655,10 +655,10 @@ class ParallelReduce( local_mem[wgroup_size * std::max(value_count, 1u)]); const auto local_id = item.get_local_linear_id(); - const FunctorType& functor = - functor_reducer_wrapper.get_functor().get_functor(); - const ReducerType& reducer = - functor_reducer_wrapper.get_functor().get_reducer(); + const CombinedFunctorReducerType& functor_reducer = + functor_reducer_wrapper.get_functor(); + const FunctorType& functor = functor_reducer.get_functor(); + const ReducerType& reducer = functor_reducer.get_reducer(); if constexpr (ReducerType::static_value_size() == 0) { reference_type update =