Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize group_norm op forward #39596

Merged
merged 8 commits into from
Mar 1, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 54 additions & 11 deletions paddle/fluid/operators/group_norm_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,10 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
T x_mean = 0, x_var = 0;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val;
if (data_layout == DataLayout::kNCHW) {
val = x[(bid * C + ccid) * imsize + imid];
} else {
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid];
}
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid];

Zjq9409 marked this conversation as resolved.
Show resolved Hide resolved
x_mean += val;
x_var += val * val;
}
Expand All @@ -84,6 +81,40 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
}

template <typename T, int BlockDim>
__global__ void GroupNormForwardGetMeanAndVarNCHW(
const T* x, int N, int C, int W, int imsize, int groups, int group_size,
T* mean, T* var, const DataLayout data_layout) {
T x_mean = 0, x_var = 0;
int i = blockIdx.x;
for (int j = threadIdx.x; j < group_size * imsize; j += blockDim.x) {
T val;
val = x[i * group_size * imsize + j];
x_mean += val;
x_var += val * val;
}
x_mean /= group_size * imsize;
x_var /= group_size * imsize;
if (blockDim.x <= 32) {
CudaAtomicAddWithWarp(&mean[blockIdx.x], x_mean);
Zjq9409 marked this conversation as resolved.
Show resolved Hide resolved
CudaAtomicAddWithWarp(&var[blockIdx.x], x_var);
} else {
typedef cub::BlockReduce<T, BlockDim> BlockReduce;

__shared__ typename BlockReduce::TempStorage mean_storage;
__shared__ typename BlockReduce::TempStorage var_storage;

__syncthreads();
auto mean_out = BlockReduce(mean_storage).Reduce(x_mean, cub::Sum());
auto var_out = BlockReduce(var_storage).Reduce(x_var, cub::Sum());
__syncthreads();
if (threadIdx.x == 0) {
mean[blockIdx.x] = mean_out;
var[blockIdx.x] = var_out;
}
}
}

template <typename T, int flags>
__global__ void GroupNormForward(const T* x, const T* mean, const T* var,
const T* scale, const T* bias, int N, int C,
Expand Down Expand Up @@ -113,7 +144,9 @@ __global__ void GroupNormForward(const T* x, const T* mean, const T* var,
}
val = (val - x_mean) * var_inv;
if (flags & kHasScale) val *= scale[gid * group_size + cid];
if (flags & kHasBias) val += bias[gid * group_size + cid];
if (flags & kHasBias) {
Zjq9409 marked this conversation as resolved.
Show resolved Hide resolved
val += bias[gid * group_size + cid];
}
if (data_layout == DataLayout::kNCHW) {
y[(bid * C + ccid) * imsize + imid] = val;
} else {
Expand Down Expand Up @@ -186,12 +219,22 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
int block_size = std::max(std::min(256, imsize), 64);
#else
int block_size = std::min(1024, imsize);
int block_size_nchw = std::min(1024, group_size * imsize);
#endif
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
temp_var_data, data_layout);
if (data_layout == DataLayout::kNCHW) {
GroupNormForwardGetMeanAndVarNCHW<
T,
1024><<<x_dims[0] * groups, block_size_nchw, 0, dev_ctx.stream()>>>(
x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
temp_var_data, data_layout);
} else {
GroupNormForwardGetMeanAndVar<T><<<grid, threads, 0, dev_ctx.stream()>>>(
x_data, x_dims[0], C, W, imsize, groups, group_size, mean_data,
temp_var_data, data_layout);
}

int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags, GroupNormForward, x_data, mean_data, temp_var_data,
Expand Down