Add weaker memory fence in custom allreduce synchronize.
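The signal flags were previously accessed through volatile-qualified pointers, with __threadfence_system() ordering the reduction writes against the flag writes. This commit drops the volatile qualifier and performs the flag accesses through inline PTX instead: plain volatile loads and stores where only the access itself matters, and release/acquire operations where a peer must also observe earlier writes. A minimal, self-contained sketch of that release/acquire pairing follows; the publish_flag and wait_flag helpers are illustrative names only and are not part of the diff.

#include <cstdint>

// Producer: write the payload, then raise the flag with a release store so the
// payload write cannot become visible after the flag write (system scope).
__device__ void publish_flag(uint32_t* payload, uint32_t* flag, uint32_t value) {
  *payload = value;
  asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(1u), "l"(flag) : "memory");
}

// Consumer: spin on an acquire load of the flag; once it reads non-zero, the
// producer's payload write is guaranteed to be visible to this thread.
__device__ uint32_t wait_flag(uint32_t* payload, uint32_t* flag) {
  uint32_t observed;
  do {
    asm volatile("ld.acquire.sys.global.u32 %0, [%1];"
                 : "=r"(observed)
                 : "l"(flag)
                 : "memory");
  } while (observed == 0);
  return *payload;
}

In the diff this pattern appears as st_flag_release() on the peer's end flag in end_sync() and ld_flag_acquire_global() on the local start flag in start_sync(), while the flag resets and the remaining spin waits use the relaxed volatile variants.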
戚余航 committed Sep 13, 2024
1 parent 6821020 commit 4cdb581
Showing 1 changed file with 60 additions and 13 deletions.
73 changes: 60 additions & 13 deletions csrc/custom_all_reduce.cuh
@@ -33,7 +33,42 @@ struct Signal {

struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };

struct __align__(16) RankSignals { volatile Signal* signals[8]; };
// The volatile qualifier is no longer used; flag loads and stores go through
// inline PTX instead (volatile and acquire/release variants below).
// struct __align__(16) RankSignals { volatile Signal* signals[8]; };
struct __align__(16) RankSignals { Signal* signals[8]; };

////////////////////////////////////////////////////////////////////////////////////////////////////

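// Release store, system scope: prior memory accesses by this thread are made
// visible before the flag value, so a peer that observes the flag with an
// acquire load also observes those accesses.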
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr)
{
asm volatile("st.release.sys.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

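// Acquire load from global memory, system scope: later loads and stores by
// this thread cannot be reordered before it, so data published before a
// matching release store is visible once the flag is seen.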
static inline __device__ uint32_t ld_flag_acquire_global(uint32_t* flag_addr)
{
uint32_t flag;
asm volatile("ld.acquire.global.sys.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}

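// Volatile store to global memory: never cached or optimized away, but
// relaxed, i.e. it gives no ordering guarantees for other memory accesses.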
static inline __device__ void st_flag_volatile_global(uint32_t const& flag, uint32_t* flag_addr)
{
asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

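// Same as above, but through a generic address.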
static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr)
{
asm volatile("st.volatile.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
}

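// Volatile load from global memory: always re-reads the flag (suitable for
// spin waiting), with no ordering guarantees for other memory accesses.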
static inline __device__ uint32_t ld_flag_volatile_global(uint32_t* flag_addr)
{
uint32_t flag;
asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
return flag;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// like std::array, but aligned
template <typename T, int sz>
@@ -128,16 +163,22 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
DINLINE void start_sync(const RankSignals& sg, Signal* self_sg,
int rank) {
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// origin: self_sg->end[blockIdx.x][threadIdx.x] = 0;
// new:
st_flag_volatile_global(0, &self_sg->end[blockIdx.x][threadIdx.x]);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// original: sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// new:
st_flag_volatile(1, &sg.signals[threadIdx.x]->start[blockIdx.x][rank]);
// wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x]);
// origin: while (!self_sg->start[blockIdx.x][threadIdx.x]);
// New:
while (!ld_flag_acquire_global(&self_sg->start[blockIdx.x][threadIdx.x]));
}
__syncthreads();
}
@@ -146,7 +187,7 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
DINLINE void end_sync(const RankSignals& sg, Signal* self_sg,
int rank) {
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
@@ -155,13 +196,19 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
// the memory model.
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
self_sg->start[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// original: sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1;
// new:
st_flag_release(1, &sg.signals[threadIdx.x]->end[blockIdx.x][rank]);
// reset flag for next time
// origin: self_sg->start[blockIdx.x][threadIdx.x] = 0;
// New:
st_flag_volatile_global(0, &self_sg->start[blockIdx.x][threadIdx.x]);
// wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x]);
// origin: while (!self_sg->end[blockIdx.x][threadIdx.x]);
// New:
while (!ld_flag_volatile_global(&self_sg->end[blockIdx.x][threadIdx.x]));
}
if constexpr (!final_sync) __syncthreads();
}
@@ -179,7 +226,7 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) {
template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData* _dp, RankSignals sg,
volatile Signal* self_sg, T* __restrict__ result,
Signal* self_sg, T* __restrict__ result,
int rank, int size) {
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
@@ -196,14 +243,14 @@ __global__ void __launch_bounds__(512, 1)
}

template <typename P>
DINLINE P* get_tmp_buf(volatile Signal* sg) {
DINLINE P* get_tmp_buf(Signal* sg) {
return (P*)(((Signal*)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData* _dp, RankSignals sg,
volatile Signal* self_sg, T* __restrict__ result,
Signal* self_sg, T* __restrict__ result,
int rank, int size) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
