Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimization of depthwise_conv2d grad #46332

Merged
merged 1 commit into from
Sep 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions paddle/phi/kernels/gpu/depthwise_conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(
int offset = in_offset + h_in * input_width + w_in;
T in_data = input_data[offset];
if (fuse_relu_before_conv) {
value += weight[weight_offset] * T(max(0.0f, double(in_data)));
value += weight[weight_offset] *
T(max(0.0f, static_cast<double>(in_data)));
} else {
value += weight[weight_offset] * in_data;
}
Expand Down Expand Up @@ -228,7 +229,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
T in_data = input_data[offset];
const T* weight = filter_data + weight_offset * output_channels + c_out;
if (fuse_relu_before_conv) {
value += weight[0] * T(max(0.0f, double(in_data)));
value += weight[0] * T(max(0.0f, static_cast<double>(in_data)));
} else {
value += weight[0] * in_data;
}
Expand Down Expand Up @@ -281,7 +282,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
int offset = in_offset + h_in * input_width + w_in;
if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] *
T(max(0.0f, double(input_data[offset])));
T(max(0.0f, static_cast<double>(input_data[offset])));
} else {
value += r_weight[h_f * c_filter + w_f] * input_data[offset];
}
Expand Down Expand Up @@ -337,7 +338,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC(
in_offset + (h_in * input_width + w_in) * input_channels + c_in;
if (fuse_relu_before_conv) {
value += r_weight[h_f * c_filter + w_f] *
T(max(0.0, double(input_data[offset])));
T(max(0.0, static_cast<double>(input_data[offset])));
} else {
value += r_weight[h_f * c_filter + w_f] * input_data[offset];
}
Expand Down Expand Up @@ -880,7 +881,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
image_wk;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
T(max(0.0f, double(input_data[input_id])));
T(max(0.0f, static_cast<double>(input_data[input_id])));
} else {
s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] *
input_data[input_id];
Expand All @@ -891,7 +892,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW(
}

T val = BlockReduceSum(s);
platform::CudaAtomicAdd(&filter_grad_data[gbid], val);
if (threadIdx.y == 0 && threadIdx.x == 0) filter_grad_data[gbid] = val;
}

template <typename T, bool fuse_relu_before_conv>
Expand Down Expand Up @@ -941,7 +942,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC(
kernel_id / filter_multiplier;
if (fuse_relu_before_conv) {
s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
T(max(0.0f, double(input_data[input_id])));
T(max(0.0f, static_cast<double>(input_data[input_id])));
} else {
s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] *
input_data[input_id];
Expand Down Expand Up @@ -1013,7 +1014,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
T s(0);
if (fuse_relu_before_conv) {
s = output_grad_data[output_id] *
T(max(0.0f, double(input_data[input_id])));
T(max(0.0f, static_cast<double>(input_data[input_id])));
} else {
s = output_grad_data[output_id] * input_data[input_id];
}
Expand Down