Skip to content

Commit

Permalink
fix silu double grad prim
Browse files Browse the repository at this point in the history
  • Loading branch information
cyber-pioneer committed Aug 15, 2023
1 parent 68d6cf7 commit 94f7ee0
Showing 1 changed file with 2 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -439,16 +439,15 @@ void silu_double_grad(const Tensor& x,
const Tensor& grad_x_grad,
Tensor* grad_x,
Tensor* grad_out_grad) {
auto sigmoid = out / x;
auto sigmoid = 1 / (1 + exp<T>(-x));
auto tmp1 = 1 - sigmoid;
auto tmp2 = 1 + tmp1 * x;
if (grad_out_grad) {
auto ddout = grad_x_grad * sigmoid * tmp2;
set_output<T>(ddout, grad_out_grad);
}
if (grad_x) {
auto dx =
sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - x * sigmoid)) * tmp1;
auto dx = sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - out)) * tmp1;
set_output<T>(dx, grad_x);
}
}
Expand Down

0 comments on commit 94f7ee0

Please sign in to comment.