Skip to content

Commit

Permalink
Merge pull request #6003 from masterleinad/fix_team_scratch_1_queues_…
Browse files Browse the repository at this point in the history
…sycl_cuda

Fix guards for using scratch space with SYCL
  • Loading branch information
dalg24 authored Mar 28, 2023
2 parents 54da8a2 + a798ac7 commit 8270db3
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 29 deletions.
59 changes: 41 additions & 18 deletions core/src/SYCL/Kokkos_SYCL_Instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,28 +134,48 @@ void SYCLInternal::initialize(const sycl::queue& q) {
desul::Impl::init_lock_arrays_sycl(*m_queue);
}
#endif
}

int SYCLInternal::acquire_team_scratch_space() {
// Grab the next scratch memory allocation. We must make sure that the last
// kernel using the allocation has completed, so we wait for the event that
// was registered with that kernel.
int current_team_scratch = desul::atomic_fetch_inc_mod(
&m_current_team_scratch, m_n_team_scratch - 1,
desul::MemoryOrderRelaxed(), desul::MemoryScopeDevice());

m_team_scratch_current_size = 0;
m_team_scratch_ptr = nullptr;
m_team_scratch_event[current_team_scratch].wait_and_throw();

return current_team_scratch;
}

sycl::device_ptr<void> SYCLInternal::resize_team_scratch_space(
std::int64_t bytes, bool force_shrink) {
if (m_team_scratch_current_size == 0) {
m_team_scratch_current_size = bytes;
m_team_scratch_ptr =
int scratch_pool_id, std::int64_t bytes, bool force_shrink) {
// Multiple ParallelFor/Reduce Teams can call this function at the same time
// and invalidate the m_team_scratch_ptr. We use a pool to avoid any race
// condition.
if (m_team_scratch_current_size[scratch_pool_id] == 0) {
m_team_scratch_current_size[scratch_pool_id] = bytes;
m_team_scratch_ptr[scratch_pool_id] =
Kokkos::kokkos_malloc<Experimental::SYCLDeviceUSMSpace>(
"Kokkos::Experimental::SYCLDeviceUSMSpace::TeamScratchMemory",
m_team_scratch_current_size);
m_team_scratch_current_size[scratch_pool_id]);
}
if ((bytes > m_team_scratch_current_size) ||
((bytes < m_team_scratch_current_size) && (force_shrink))) {
m_team_scratch_current_size = bytes;
m_team_scratch_ptr =
if ((bytes > m_team_scratch_current_size[scratch_pool_id]) ||
((bytes < m_team_scratch_current_size[scratch_pool_id]) &&
(force_shrink))) {
m_team_scratch_current_size[scratch_pool_id] = bytes;
m_team_scratch_ptr[scratch_pool_id] =
Kokkos::kokkos_realloc<Experimental::SYCLDeviceUSMSpace>(
m_team_scratch_ptr, m_team_scratch_current_size);
m_team_scratch_ptr[scratch_pool_id],
m_team_scratch_current_size[scratch_pool_id]);
}
return m_team_scratch_ptr;
return m_team_scratch_ptr[scratch_pool_id];
}

void SYCLInternal::register_team_scratch_event(int scratch_pool_id,
sycl::event event) {
m_team_scratch_event[scratch_pool_id] = event;
}

uint32_t SYCLInternal::impl_get_instance_id() const { return m_instance_id; }
Expand Down Expand Up @@ -187,11 +207,14 @@ void SYCLInternal::finalize() {
m_scratchFlagsCount = 0;
m_scratchFlags = nullptr;

if (m_team_scratch_current_size > 0)
Kokkos::kokkos_free<Kokkos::Experimental::SYCLDeviceUSMSpace>(
m_team_scratch_ptr);
m_team_scratch_current_size = 0;
m_team_scratch_ptr = nullptr;
for (int i = 0; i < m_n_team_scratch; ++i) {
if (m_team_scratch_current_size[i] > 0) {
Kokkos::kokkos_free<Kokkos::Experimental::SYCLDeviceUSMSpace>(
m_team_scratch_ptr[i]);
m_team_scratch_current_size[i] = 0;
m_team_scratch_ptr[i] = nullptr;
}
}

for (auto& usm_mem : m_indirectKernelMem) usm_mem.reset();
// guard erasing from all_queues
Expand Down
13 changes: 10 additions & 3 deletions core/src/SYCL/Kokkos_SYCL_Instance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ class SYCLInternal {

sycl::device_ptr<void> scratch_space(const std::size_t size);
sycl::device_ptr<void> scratch_flags(const std::size_t size);
sycl::device_ptr<void> resize_team_scratch_space(std::int64_t bytes,
int acquire_team_scratch_space();
sycl::device_ptr<void> resize_team_scratch_space(int scratch_pool_id,
std::int64_t bytes,
bool force_shrink = false);
void register_team_scratch_event(int scratch_pool_id, sycl::event event);

uint32_t impl_get_instance_id() const;
static int m_syclDev;
Expand All @@ -62,8 +65,12 @@ class SYCLInternal {
// mutex to access shared memory
mutable std::mutex m_mutexScratchSpace;

int64_t m_team_scratch_current_size = 0;
sycl::device_ptr<void> m_team_scratch_ptr = nullptr;
// Team Scratch Level 1 Space
static constexpr int m_n_team_scratch = 10;
mutable int64_t m_team_scratch_current_size[m_n_team_scratch] = {};
mutable sycl::device_ptr<void> m_team_scratch_ptr[m_n_team_scratch] = {};
mutable int m_current_team_scratch = 0;
mutable sycl::event m_team_scratch_event[m_n_team_scratch] = {};
mutable std::mutex m_team_scratch_mutex;

uint32_t m_instance_id = Kokkos::Tools::Experimental::Impl::idForInstance<
Expand Down
18 changes: 13 additions & 5 deletions core/src/SYCL/Kokkos_SYCL_Parallel_Team.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
// Only let one ParallelFor/Reduce modify the team scratch memory. The
// constructor acquires the mutex which is released in the destructor.
std::scoped_lock<std::mutex> m_scratch_lock;
int m_scratch_pool_id = -1;

template <typename FunctorWrapper>
sycl::event sycl_direct_launch(const Policy& policy,
Expand Down Expand Up @@ -445,17 +446,17 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,
inline void execute() const {
if (m_league_size == 0) return;

auto& space = *m_policy.space().impl_internal_space_instance();
Kokkos::Experimental::Impl::SYCLInternal::IndirectKernelMem&
indirectKernelMem = m_policy.space()
.impl_internal_space_instance()
->get_indirect_kernel_mem();
indirectKernelMem = space.get_indirect_kernel_mem();

auto functor_wrapper = Experimental::Impl::make_sycl_function_wrapper(
m_functor, indirectKernelMem);

sycl::event event = sycl_direct_launch(m_policy, functor_wrapper,
functor_wrapper.get_copy_event());
functor_wrapper.register_event(event);
space.register_team_scratch_event(m_scratch_pool_id, event);
}

ParallelFor(FunctorType const& arg_functor, Policy const& arg_policy)
Expand All @@ -481,9 +482,11 @@ class ParallelFor<FunctorType, Kokkos::TeamPolicy<Properties...>,

// Functor's reduce memory, team scan memory, and team shared memory depend
// upon team size.
auto& space = *m_policy.space().impl_internal_space_instance();
auto& space = *m_policy.space().impl_internal_space_instance();
m_scratch_pool_id = space.acquire_team_scratch_space();
m_global_scratch_ptr =
static_cast<sycl::device_ptr<char>>(space.resize_team_scratch_space(
m_scratch_pool_id,
static_cast<ptrdiff_t>(m_scratch_size[1]) * m_league_size));

if (static_cast<int>(space.m_maxShmemPerBlock) <
Expand Down Expand Up @@ -546,6 +549,7 @@ class ParallelReduce<CombinedFunctorReducerType,
// Only let one ParallelFor/Reduce modify the team scratch memory. The
// constructor acquires the mutex which is released in the destructor.
std::scoped_lock<std::mutex> m_scratch_lock;
int m_scratch_pool_id = -1;

template <typename PolicyType, typename FunctorWrapper,
typename ReducerWrapper>
Expand Down Expand Up @@ -831,6 +835,8 @@ class ParallelReduce<CombinedFunctorReducerType,
{functor_wrapper.get_copy_event(), reducer_wrapper.get_copy_event()});
functor_wrapper.register_event(event);
reducer_wrapper.register_event(event);

instance.register_team_scratch_event(m_scratch_pool_id, event);
}

private:
Expand All @@ -857,9 +863,11 @@ class ParallelReduce<CombinedFunctorReducerType,

// Functor's reduce memory, team scan memory, and team shared memory depend
// upon team size.
auto& space = *m_policy.space().impl_internal_space_instance();
auto& space = *m_policy.space().impl_internal_space_instance();
m_scratch_pool_id = space.acquire_team_scratch_space();
m_global_scratch_ptr =
static_cast<sycl::device_ptr<char>>(space.resize_team_scratch_space(
m_scratch_pool_id,
static_cast<ptrdiff_t>(m_scratch_size[1]) * m_league_size));

if (static_cast<int>(space.m_maxShmemPerBlock) <
Expand Down
3 changes: 0 additions & 3 deletions core/unit_test/sycl/TestSYCL_TeamScratchStreams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,6 @@ void sycl_queue_scratch_test(
} // namespace Impl

TEST(sycl, team_scratch_1_queues) {
#if defined(KOKKOS_ENABLE_SYCL) && !defined(KOKKOS_ARCH_INTEL_GPU)
GTEST_SKIP() << "skipping for SYCL+Cuda";
#endif
int N = 1000000;
int T = 10;
int M_base = 150;
Expand Down

0 comments on commit 8270db3

Please sign in to comment.