Fix gradient tensor mutate in {adam/ftrl/rmprop/rmspropalex}_update #15768
Conversation
Changed the title: adam_update → {adam/ftrl}_update
@sxjscience @eric-haibin-lin @apeforest @larroy Could you please take a look? All the original tests pass, and I have added a test to check that only the expected variables are mutated.
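For readers following along, a minimal self-contained C++ sketch of the mutation-check idea (hypothetical names; the actual test added in this PR is written against MXNet's Python test suite): snapshot the inputs, run the update, and assert that only the weight changed.

#include <cassert>
#include <vector>

// Hypothetical optimizer step; a correct implementation treats grad as
// read-only and writes only to weight.
void fake_update(std::vector<float>* weight, const std::vector<float>& grad,
                 float lr) {
  for (size_t i = 0; i < weight->size(); ++i) {
    (*weight)[i] -= lr * grad[i];
  }
}

int main() {
  std::vector<float> weight = {1.0f, 2.0f};
  std::vector<float> grad = {0.1f, 0.2f};
  const std::vector<float> grad_before = grad;  // snapshot the input
  fake_update(&weight, grad, 0.01f);
  assert(grad == grad_before);  // grad must be untouched; only weight changes
  return 0;
}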
Changed the title: {adam/ftrl}_update → {adam/ftrl/rmprop/rmspropalex}_update
LGTM. Excellent job!
src/operator/optimizer_op-inl.h (outdated)
using namespace mshadow_op;

const DType rescaled_grad = rescale_grad * grad_data[i] +
                            wd * weight_data[i];
I find that we can actually simplify the code by adding an if statement here:
if (clip_gradient >= 0.0f) {
  rescaled_grad = clip::Map(rescaled_grad, clip_gradient);
}
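For context, a hedged sketch of how the combined form reads once the clip is folded in (names taken from the snippet above; rescaled_grad must be non-const for the reassignment):

DType rescaled_grad = rescale_grad * grad_data[i] + wd * weight_data[i];
if (clip_gradient >= 0.0f) {
  // the guard encodes the convention visible above: a negative
  // clip_gradient disables clipping
  rescaled_grad = clip::Map(rescaled_grad, clip_gradient);
}
// rescaled_grad then feeds the state updates; grad_data is never written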
Change all appearances of the pattern and it should be good to merge.
@sxjscience Done.
Thanks for the suggestions. Makes it more readable.
Fix gradient tensor mutate in {adam/ftrl/rmprop/rmspropalex}_update (apache#15768)

* update code to fix apache#15759
* add relevant test
* re-add the removed conditional dispatch
* fix grad mutate for ftrl_update
* add test for ftrl_update
* fix grad mutate for rmspropalex_update
* add test for rmspropalex_update
* use KERNEL_ASSIGN in RMSPropAlexUpdateKernel
* fix grad mutate for rmsprop_update
* add test for rmsprop_update
* add more optimizers for mutation test
* retrigger CI
* retrigger CI
* retrigger CI
* retrigger CI
* address comments
* refactor code
* retrigger CI
* retrigger CI
* retrigger CI
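The KERNEL_ASSIGN bullet above refers to MXNet's kernel-output macro, which writes a value according to the operator's write request (overwrite, accumulate, or no-op). A toy, self-contained stand-in to illustrate the semantics (the real macro lives in MXNet's src/operator/mxnet_op.h):

#include <cstdio>

// Toy stand-in for MXNet's OpReqType and KERNEL_ASSIGN; the real macro
// switches on the requested write mode in the same way.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

#define KERNEL_ASSIGN(out, req, val)      \
  {                                       \
    switch (req) {                        \
      case kNullOp:                       \
        break;                            \
      case kWriteTo:                      \
      case kWriteInplace:                 \
        (out) = (val);                    \
        break;                            \
      case kAddTo:                        \
        (out) += (val);                   \
        break;                            \
    }                                     \
  }

int main() {
  float weight = 1.0f;
  KERNEL_ASSIGN(weight, kAddTo, 0.5f);  // accumulate instead of overwrite
  std::printf("%f\n", weight);          // prints 1.500000
  return 0;
}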
@kshitij12345 @apeforest The test is failing for this change -
@Vikas-kum Thanks. I checked and found that the difference is -5.9604645e-08, which is lower than the default tolerance.
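As a quick sanity check (the 1e-5 tolerance below is an assumption for illustration, not necessarily the suite's actual default): -5.9604645e-08 is exactly -2^-24, half of float32 machine epsilon, so it sits far below any typical comparison tolerance.

#include <cmath>
#include <cstdio>

int main() {
  const float diff = -5.9604645e-08f;  // exactly -2^-24
  const float atol = 1e-5f;            // assumed tolerance for illustration
  std::printf("%d\n", std::fabs(diff) <= atol);  // prints 1 (within tolerance)
  return 0;
}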
So I tested the *_update ops, and it turns out that passing randn (which samples from a normal distribution) to a *_update op (e.g. adam_update) gives output that may contain NaNs.
Now, the way you've tested is to check whether the input and output are mutated after a *_update method is called. Does that take the NaNs into consideration?
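One plausible mechanism for those NaNs (an assumption, not confirmed in this thread): optimizer states such as adam_update's var are expected to be non-negative, and feeding normally distributed values means sqrt sees negative inputs.

#include <cmath>
#include <cstdio>

int main() {
  // adam-style denominator with a negative "var" entry, as randn can produce
  float mean = 0.1f, var = -0.3f, eps = 1e-8f;
  float step = mean / (std::sqrt(var) + eps);  // sqrt(-0.3f) is NaN
  std::printf("%f\n", step);                   // prints nan
  return 0;
}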
@ChaiBapchya |
Sure. I think it needs to be handled.
Description
Rescaling the gradient used to update the data of the grad tensor passed as input. Detailed in #15759.
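A minimal sketch of the bug and the fix under hypothetical names (not the actual MXNet kernel): keep the rescaled gradient in a local variable instead of writing it back into the gradient buffer.

#include <vector>

// Hypothetical simplified update illustrating the fix: the rescaled gradient
// stays in a local variable instead of overwriting the caller's grad buffer.
void sgd_like_update(std::vector<float>* weight, const std::vector<float>& grad,
                     float rescale_grad, float wd, float lr) {
  for (size_t i = 0; i < weight->size(); ++i) {
    // The bug wrote this value back into grad[i]; the fix keeps it local.
    const float rescaled = rescale_grad * grad[i] + wd * (*weight)[i];
    (*weight)[i] -= lr * rescaled;  // only the weight is mutated
  }
}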
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes