Skip to content

Commit

Permalink
Revert "make bilinear interpolate stable. (#48644)" (#49307)
Browse files Browse the repository at this point in the history
This reverts commit e1e8bf7.
  • Loading branch information
2742195759 authored Dec 27, 2022
1 parent a953395 commit 17ec162
Showing 1 changed file with 14 additions and 28 deletions.
42 changes: 14 additions & 28 deletions paddle/phi/kernels/gpu/interpolate_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h"

DECLARE_bool(cudnn_deterministic);

namespace phi {

template <typename T>
Expand Down Expand Up @@ -1039,12 +1037,6 @@ static void Interpolate2DCUDABwd(
#endif

if (optimize_flag & is_nchw) {
if (FLAGS_cudnn_deterministic) {
VLOG(2)
<< "Run grad kernel of bilinear interpolate 2d with single thread.";
config.block_per_grid = 1;
config.thread_per_block = 1;
}
KeBilinearInterpBwShareMemory<T><<<config.block_per_grid,
config.thread_per_block,
0,
Expand All @@ -1063,27 +1055,21 @@ static void Interpolate2DCUDABwd(
} else if (!optimize_flag & is_nchw) {
const int num_kernels = n * c * out_h * out_w;
const int num_threads = std::min(dev_ctx.GetMaxThreadsPerBlock(), 1024);
int block_per_grid = backends::gpu::DivUp(num_kernels, num_threads);
int thread_per_block = num_threads;
if (FLAGS_cudnn_deterministic) {
VLOG(2)
<< "Run grad kernel of bilinear interpolate 2d with single thread.";
block_per_grid = 1;
thread_per_block = 1;
}
KeBilinearInterpNCHWBw<T>
<<<block_per_grid, thread_per_block, 0, dev_ctx.stream()>>>(
input_grad_data,
in_h,
in_w,
out_h,
out_w,
n,
c,
ratio_h,
ratio_w,
output_grad_data,
align_type_value);
<<<backends::gpu::DivUp(num_kernels, num_threads),
num_threads,
0,
dev_ctx.stream()>>>(input_grad_data,
in_h,
in_w,
out_h,
out_w,
n,
c,
ratio_h,
ratio_w,
output_grad_data,
align_type_value);
} else {
int64_t cw = c * out_w;
auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw);
Expand Down

0 comments on commit 17ec162

Please sign in to comment.