diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index 72a90d17998d8..b376334f1e93c 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -29,6 +29,7 @@ namespace operators { using DataLayout = framework::DataLayout; enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; +#define ALIGN_BYTES 16 #define CHECK_CASE(i, flags, kernel_name, ...) \ if (i == flags) { \ @@ -56,8 +57,7 @@ __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { template __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W, int imsize, int groups, - int group_size, T* mean, T* var, - const DataLayout data_layout) { + int group_size, T* mean, T* var) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; @@ -68,13 +68,10 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W, T x_mean = 0, x_var = 0; for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { T val; - if (data_layout == DataLayout::kNCHW) { - val = x[(bid * C + ccid) * imsize + imid]; - } else { - int hid = imid / W; - int wid = imid % W; - val = x[(bid * H + hid) * W * C + wid * C + ccid]; - } + int hid = imid / W; + int wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid]; + x_mean += val; x_var += val * val; } @@ -84,6 +81,85 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W, CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var); } +template +__device__ __forceinline__ void ThreadReduce(const T* input, int size, + const int offset, AccT* mean, + AccT* var) { + using VecT = kps::details::VectorType; + int tid = threadIdx.x; + if (offset > 0) { + input -= offset; + size += offset; + if (tid >= offset) { + AccT temp = input[tid]; + *mean += temp; + *var += temp * temp; + } + size -= blockDim.x; + input += blockDim.x; + } + int remain = size % (VecSize * blockDim.x); + + T ins[VecSize]; + VecT* ins_vec = reinterpret_cast(&ins); + + // vector part + for (; VecSize * tid < (size - remain); tid += blockDim.x) { + *ins_vec = reinterpret_cast(input)[tid]; + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + AccT temp = ins[i]; + *mean += temp; + *var += temp * temp; + } + } + + // scalar part + tid = size - remain + threadIdx.x; + for (; tid < size; tid += blockDim.x) { + AccT temp = input[tid]; + *mean += temp; + *var += temp * temp; + } +} + +template +__global__ void ScalarGetMeanAndVarNCHW(const T* x, T* mean, T* var, int size) { + int i = blockIdx.x; + T x_mean = 0, x_var = 0; + for (int j = threadIdx.x; j < size; j += blockDim.x) { + T val; + val = x[i * size + j]; + x_mean += val; + x_var += val * val; + } + x_mean /= size; + x_var /= size; + CudaAtomicAddWithWarp(&mean[i], x_mean); + CudaAtomicAddWithWarp(&var[i], x_var); +} + +template +__global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var, + int size) { + int i = blockIdx.x; + AccT x_mean = static_cast(0); + AccT x_var = static_cast(0); + const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T); + x += i * size; + ThreadReduce(x, size, input_offset, &x_mean, &x_var); + x_mean = kps::details::BlockXReduce>( + x_mean, kps::AddFunctor()); + x_var = kps::details::BlockXReduce>( + x_var, kps::AddFunctor()); + __syncthreads(); + if (threadIdx.x == 0) { + mean[i] = static_cast(x_mean / size); + var[i] = static_cast(x_var / size); + } +} + template __global__ void GroupNormForward(const T* x, const T* mean, const T* var, const T* scale, const T* bias, int N, int C, @@ -96,26 +172,34 @@ __global__ void GroupNormForward(const T* x, const T* mean, const T* var, int H = imsize / W; int ccid = gid * group_size + cid; if (ccid >= C) return; - T x_mean = mean[bid * groups + gid]; - T x_var = var[bid * groups + gid]; + auto ng = bid * groups + gid; + T x_mean = mean[ng]; + T x_var = var[ng]; x_var = x_var - x_mean * x_mean; - T var_inv = 1.0 / sqrt(x_var + epsilon); - if (cid == 0 && threadIdx.x == 0) real_var[bid * groups + gid] = x_var; + T var_inv = rsqrt(x_var + epsilon); + if (cid == 0 && threadIdx.x == 0) { + real_var[ng] = x_var; + } for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { T val; int hid, wid; + int index = (bid * C + ccid) * imsize + imid; if (data_layout == DataLayout::kNCHW) { - val = x[(bid * C + ccid) * imsize + imid]; + val = x[index]; } else { hid = imid / W; wid = imid % W; val = x[(bid * H + hid) * W * C + wid * C + ccid]; } val = (val - x_mean) * var_inv; - if (flags & kHasScale) val *= scale[gid * group_size + cid]; - if (flags & kHasBias) val += bias[gid * group_size + cid]; + if (flags & kHasScale) { + val *= scale[ccid]; + } + if (flags & kHasBias) { + val += bias[ccid]; + } if (data_layout == DataLayout::kNCHW) { - y[(bid * C + ccid) * imsize + imid] = val; + y[index] = val; } else { y[(bid * H + hid) * W * C + wid * C + ccid] = val; } @@ -182,16 +266,41 @@ class GroupNormKernel imsize *= x_dims[i]; } } + #ifdef __HIPCC__ int block_size = std::max(std::min(256, imsize), 64); #else int block_size = std::min(1024, imsize); #endif + dim3 grid(group_size, groups, x_dims[0]); dim3 threads(block_size, 1, 1); - GroupNormForwardGetMeanAndVar<<>>( - x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data, - temp_var_data, data_layout); + if (data_layout == DataLayout::kNCHW) { + using AccT = typename details::MPTypeTrait::Type; + constexpr int vec_size = sizeof(float4) / sizeof(T); + int size = group_size * imsize; + const int max_num_threads = 1024; + int max_block_size = std::min(size / vec_size, max_num_threads); + int block_size_nchw = 1; + while (block_size_nchw < max_block_size) { + block_size_nchw *= 2; + } + block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize); + dim3 grids(x_dims[0] * groups); + dim3 blocks(block_size_nchw); + if (size < vec_size) { + ScalarGetMeanAndVarNCHW<<>>( + x_data, mean_data, temp_var_data, size); + } else { + VectorizedGetMeanAndVarNCHW< + T, AccT, vec_size><<>>( + x_data, mean_data, temp_var_data, size); + } + } else { + GroupNormForwardGetMeanAndVar<<>>( + x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data, + temp_var_data); + } int flags = (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data,