From 41c34ae0e5d0f7bb58174c42806fcddddebae19a Mon Sep 17 00:00:00 2001
From: cyber-pioneer
Date: Mon, 9 Oct 2023 07:07:11 +0000
Subject: [PATCH] fix prim abs_grad nan bug

---
 paddle/fluid/prim/api/api.yaml                            | 1 +
 .../prim/api/composite_backward/composite_backward_api.h  | 5 ++---
 paddle/fluid/primitive/codegen/gen.py                     | 1 +
 paddle/fluid/primitive/primitive.yaml                     | 1 +
 paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc      | 1 +
 5 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index ec3bd5741371e..5a1a6e335abeb 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -48,3 +48,4 @@
 - reshape
 - erf
 - tanh
+- sign
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 53369e956d7b8..64c431b3d237f 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -585,9 +585,8 @@ void sigmoid_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto abs_tmp = abs<T>(x);
-    auto divide_tmp = divide<T>(x, abs_tmp);
-    set_output<T>(out_grad * divide_tmp, x_grad);
+    auto sign_tmp = sign<T>(x);
+    set_output<T>(out_grad * sign_tmp, x_grad);
   }
 }
 
diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py
index da9e12fa817c5..88f0209eb59d6 100644
--- a/paddle/fluid/primitive/codegen/gen.py
+++ b/paddle/fluid/primitive/codegen/gen.py
@@ -294,6 +294,7 @@
     'diag_grad',
     'trace_grad',
     'flip',
+    'sign',
 ]
diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml
index ccf9673bafba0..794f1121da679 100644
--- a/paddle/fluid/primitive/primitive.yaml
+++ b/paddle/fluid/primitive/primitive.yaml
@@ -50,3 +50,4 @@
 - tanh
 - full
 - cast
+- sign
diff --git a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
index 838b83d5d533b..6b3b1050448ef 100644
--- a/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
+++ b/paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
@@ -55,6 +55,7 @@ std::vector<std::vector<paddle::Tensor>> reshape_vjp(
   if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() &&
       !need_skip) {
     FLAGS_tensor_operants_mode = "static";
+    VLOG(4) << "Call PIR Decomposed backward op reshape_grad";
    paddle::Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;
    details::reshape_grad(xshape, out_grad, x_grad);
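
Note (illustrative, not part of the patch): the old composite rule computed the abs gradient as out_grad * x / |x|, which evaluates 0/0 = NaN at x == 0, while sign(x) returns 0 there, so the gradient stays finite. A minimal NumPy sketch of the two formulas, assuming float inputs and a ones-like upstream gradient:

import numpy as np

x = np.array([-2.0, 0.0, 3.0])
out_grad = np.ones_like(x)

# Old rule: out_grad * x / |x|  ->  0/0 produces NaN at x == 0.
with np.errstate(divide="ignore", invalid="ignore"):
    old_grad = out_grad * x / np.abs(x)

# New rule: out_grad * sign(x)  ->  sign(0) == 0, so no NaN.
new_grad = out_grad * np.sign(x)

print(old_grad)  # [-1. nan  1.]
print(new_grad)  # [-1.  0.  1.]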