diff --git a/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h b/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h index d0dd18298518a..ecf527cc589be 100644 --- a/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/atan2_grad_kernel_impl.h @@ -32,8 +32,12 @@ struct Atan2GradFunctor { float x1 = static_cast(x1_[idx]); float x2 = static_cast(x2_[idx]); float x = x1 * x1 + x2 * x2; - dx1_[idx] = static_cast(static_cast(dout_[idx]) * x2 / x); - dx2_[idx] = static_cast(-static_cast(dout_[idx]) * x1 / x); + if (dx1_) { + dx1_[idx] = static_cast(static_cast(dout_[idx]) * x2 / x); + } + if (dx2_) { + dx2_[idx] = static_cast(-static_cast(dout_[idx]) * x1 / x); + } } const T* x1_; @@ -56,8 +60,12 @@ struct Atan2GradFunctor { HOSTDEVICE void operator()(int64_t idx) const { auto x = x1_[idx] * x1_[idx] + x2_[idx] * x2_[idx]; - dx1_[idx] = dout_[idx] * x2_[idx] / x; - dx2_[idx] = -dout_[idx] * x1_[idx] / x; + if (dx1_) { + dx1_[idx] = dout_[idx] * x2_[idx] / x; + } + if (dx2_) { + dx2_[idx] = -dout_[idx] * x1_[idx] / x; + } } const double* x1_; @@ -81,9 +89,11 @@ void Atan2GradKernel(const Context& ctx, auto out_grad_data = out_grad.data(); auto* x_grad_data = - ctx.template Alloc(x_grad, size_t(x.numel() * sizeof(T))); + x_grad ? ctx.template Alloc(x_grad, size_t(x.numel() * sizeof(T))) + : nullptr; auto* y_grad_data = - ctx.template Alloc(y_grad, size_t(y.numel() * sizeof(T))); + y_grad ? ctx.template Alloc(y_grad, size_t(y.numel() * sizeof(T))) + : nullptr; phi::funcs::ForRange for_range(ctx, numel); phi::Atan2GradFunctor functor(