diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 450e7bf951ebd..90f2f21c92b8d 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1031,7 +1031,7 @@ void batch_norm_grad(const Tensor& x, auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw; auto mean_temp2 = sum(tmp, reduce_axis, dtype, false); - auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2; + auto part2 = out_grad - mean_temp1 - x_sub_mean * mean_temp2; auto x_grad_data = part1 * part2; if (x.dtype() == phi::DataType::FLOAT16) {