fix prim abs_grad nan bug (PaddlePaddle#57687)
cyber-pioneer authored and jiahy0825 committed Oct 16, 2023
1 parent 479a093 · commit 18e7447
Showing 5 changed files with 6 additions and 3 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -48,3 +48,4 @@
 - reshape
 - erf
 - tanh
+- sign
5 changes: 2 additions & 3 deletions
@@ -585,9 +585,8 @@ void sigmoid_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
 template <typename T>
 void abs_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   if (x_grad) {
-    auto abs_tmp = abs<T>(x);
-    auto divide_tmp = divide<T>(x, abs_tmp);
-    set_output<T>(out_grad * divide_tmp, x_grad);
+    auto sign_tmp = sign<T>(x);
+    set_output<T>(out_grad * sign_tmp, x_grad);
   }
 }
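
Why this fixes the NaN: for x ≠ 0, x / |x| equals sign(x), but at x == 0 the old decomposition evaluates 0 / 0, which propagates NaN into the gradient; sign(0) is defined as 0, so the new rule yields the conventional subgradient of |x| at zero instead. A minimal sketch of the two formulas in NumPy (illustrative only, not the Paddle kernels; the names old_grad/new_grad are mine):

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.0])
out_grad = np.ones_like(x)

# Old rule: out_grad * x / |x| -- evaluates 0/0 at x == 0 and produces NaN.
with np.errstate(invalid="ignore"):
    old_grad = out_grad * (x / np.abs(x))

# New rule: out_grad * sign(x) -- sign(0) == 0 keeps the gradient finite.
new_grad = out_grad * np.sign(x)

print(old_grad)  # [-1. nan  1.]
print(new_grad)  # [-1.  0.  1.]
```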

1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -294,6 +294,7 @@
     'diag_grad',
     'trace_grad',
     'flip',
+    'sign',
 ]


1 change: 1 addition & 0 deletions paddle/fluid/primitive/primitive.yaml
@@ -50,3 +50,4 @@
 - tanh
 - full
 - cast
+- sign
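
For reference, the sign primitive registered above maps negatives to -1, zero to 0, and positives to +1; a quick check of that convention (assuming a local Paddle install):

```python
import paddle

x = paddle.to_tensor([-3.0, 0.0, 5.0])
print(paddle.sign(x))  # [-1., 0., 1.] -- sign(0) == 0 is what keeps abs_grad finite
```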
1 change: 1 addition & 0 deletions paddle/fluid/primitive/rule/vjp/manual/manual_vjp.cc
@@ -55,6 +55,7 @@ std::vector<std::vector<paddle::Tensor>> reshape_vjp(
   if (paddle::prim::StaticCompositeContext::Instance().IsBwdPrimEnabled() &&
       !need_skip) {
     FLAGS_tensor_operants_mode = "static";
+    VLOG(4) << "Call PIR Decomposed backward op reshape_grad";
     paddle::Tensor* x_grad = !stop_gradients[0][0] ? &vjp_res[0][0] : nullptr;

     details::reshape_grad<LazyTensor>(xshape, out_grad, x_grad);
