diff --git a/core/src/HIP/Kokkos_HIP_Instance.cpp b/core/src/HIP/Kokkos_HIP_Instance.cpp index 67d650ad49..28c9c1cb6a 100644 --- a/core/src/HIP/Kokkos_HIP_Instance.cpp +++ b/core/src/HIP/Kokkos_HIP_Instance.cpp @@ -236,25 +236,45 @@ Kokkos::HIP::size_type *HIPInternal::scratch_flags(const std::size_t size) { return m_scratchFlags; } -Kokkos::HIP::size_type *HIPInternal::scratch_functor( - const std::size_t size) const { +Kokkos::HIP::size_type *HIPInternal::stage_functor_for_execution( + void const *driver, std::size_t const size) const { if (verify_is_initialized("scratch_functor") && m_scratchFunctorSize < size) { m_scratchFunctorSize = size; using Record = Kokkos::Impl::SharedAllocationRecord; + using RecordHost = + Kokkos::Impl::SharedAllocationRecord; - if (m_scratchFunctor) + if (m_scratchFunctor) { Record::decrement(Record::get_record(m_scratchFunctor)); + RecordHost::decrement(RecordHost::get_record(m_scratchFunctorHost)); + } Record *const r = Record::allocate(Kokkos::HIPSpace(), "Kokkos::InternalScratchFunctor", m_scratchFunctorSize); + RecordHost *const r_host = RecordHost::allocate( + Kokkos::HIPHostPinnedSpace(), "Kokkos::InternalScratchFunctorHost", + m_scratchFunctorSize); Record::increment(r); + RecordHost::increment(r_host); - m_scratchFunctor = reinterpret_cast(r->data()); + m_scratchFunctor = reinterpret_cast(r->data()); + m_scratchFunctorHost = reinterpret_cast(r_host->data()); } + // When using HSA_XNACK=1, it is necessary to copy the driver to the host to + // ensure that the driver is not destroyed before the computation is done. + // Without this fix, all the atomic tests fail. It is not obvious that this + // problem is limited to HSA_XNACK=1 even if all the tests pass when + // HSA_XNACK=0. That's why we always copy the driver. + KOKKOS_IMPL_HIP_SAFE_CALL(hipStreamSynchronize(m_stream)); + std::memcpy(m_scratchFunctorHost, driver, size); + KOKKOS_IMPL_HIP_SAFE_CALL(hipMemcpyAsync(m_scratchFunctor, + m_scratchFunctorHost, size, + hipMemcpyDefault, m_stream)); + return m_scratchFunctor; } @@ -318,8 +338,10 @@ void HIPInternal::finalize() { RecordHIP::decrement(RecordHIP::get_record(m_scratchFlags)); RecordHIP::decrement(RecordHIP::get_record(m_scratchSpace)); - if (m_scratchFunctorSize > 0) + if (m_scratchFunctorSize > 0) { RecordHIP::decrement(RecordHIP::get_record(m_scratchFunctor)); + RecordHIP::decrement(RecordHIP::get_record(m_scratchFunctorHost)); + } } for (int i = 0; i < m_n_team_scratch; ++i) { diff --git a/core/src/HIP/Kokkos_HIP_Instance.hpp b/core/src/HIP/Kokkos_HIP_Instance.hpp index 7fcd499cfb..06fab84b56 100644 --- a/core/src/HIP/Kokkos_HIP_Instance.hpp +++ b/core/src/HIP/Kokkos_HIP_Instance.hpp @@ -87,9 +87,11 @@ class HIPInternal { std::size_t m_scratchFlagsCount = 0; mutable std::size_t m_scratchFunctorSize = 0; - size_type *m_scratchSpace = nullptr; - size_type *m_scratchFlags = nullptr; - mutable size_type *m_scratchFunctor = nullptr; + size_type *m_scratchSpace = nullptr; + size_type *m_scratchFlags = nullptr; + mutable size_type *m_scratchFunctor = nullptr; + mutable size_type *m_scratchFunctorHost = nullptr; + inline static std::mutex scratchFunctorMutex; hipStream_t m_stream = nullptr; uint32_t m_instance_id = @@ -133,9 +135,10 @@ class HIPInternal { HIPInternal() = default; // Resizing of reduction related scratch spaces - size_type *scratch_space(const std::size_t size); - size_type *scratch_flags(const std::size_t size); - size_type *scratch_functor(const std::size_t size) const; + size_type *scratch_space(std::size_t const size); + size_type *scratch_flags(std::size_t const size); + size_type *stage_functor_for_execution(void const *driver, + std::size_t const size) const; uint32_t impl_get_instance_id() const noexcept; int acquire_team_scratch_space(); // Resizing of team level 1 scratch diff --git a/core/src/HIP/Kokkos_HIP_KernelLaunch.hpp b/core/src/HIP/Kokkos_HIP_KernelLaunch.hpp index 0a3e6b108a..8bf5d7f394 100644 --- a/core/src/HIP/Kokkos_HIP_KernelLaunch.hpp +++ b/core/src/HIP/Kokkos_HIP_KernelLaunch.hpp @@ -377,12 +377,11 @@ struct HIPParallelLaunchKernelInvoker lock(HIPInternal::scratchFunctorMutex); DriverType *driver_ptr = reinterpret_cast( - hip_instance->scratch_functor(sizeof(DriverType))); - - hipMemcpyAsync(driver_ptr, &driver, sizeof(DriverType), hipMemcpyDefault, - hip_instance->m_stream); - + hip_instance->stage_functor_for_execution( + reinterpret_cast(&driver), sizeof(DriverType))); (base_t::get_kernel_func())<<m_stream>>>( driver_ptr); }