Skip to content

Commit

Permalink
reduce_sum default fp32 can avoid return inf when the sum value large…
Browse files Browse the repository at this point in the history
… than 65504
  • Loading branch information
thisjiang committed Jul 5, 2021
1 parent 3cf8a27 commit 86cb6ac
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def _update_list(self):
# fp16 is slower than fp32, though fp16 is supported.
'lookup_table',
'lookup_table_v2',
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum',
}

# This set contains two types of ops. All ops supported fp16 calculation. One
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/dygraph/amp/auto_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
'sigmoid_cross_entropy_with_logits',
'cross_entropy',
'cross_entropy2',
# default fp32 can avoid return inf when the sum value large than 65504
'reduce_sum',
}

Expand Down

1 comment on commit 86cb6ac

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.