-
Notifications
You must be signed in to change notification settings - Fork 282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fix] OSS restore state to proper device #46
Conversation
Codecov Report
@@ Coverage Diff @@
## master #46 +/- ##
=======================================
Coverage 94.18% 94.18%
=======================================
Files 35 35
Lines 2065 2065
=======================================
Hits 1945 1945
Misses 120 120
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
@@ -148,7 +148,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: | |||
self.load_local_state_dict(state_dict["state"][self.rank]) | |||
|
|||
# Restore the global param_groups | |||
self.param_groups = state_dict["param_groups"] | |||
self.param_groups = recursive_copy_to_device(state_dict["param_groups"], non_blocking=True, device=self._device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Line too long?
Is there a test for this already? If so, should there be some assert in the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no test, it's a good point, I'll add that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We follow PyTorch coding style of 120 columns:
https://github.com/pytorch/pytorch/blob/master/.flake8#L3
It is enforced here:
https://github.com/facebookresearch/fairscale/blob/master/setup.cfg#L29
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it's supposed to be automatic with vscode+black, it's pure FB tooling and set to 120 cols, this formatting crap is so annoying.. I'll fix that in the next PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah no ok, black is happy with that indeed, it's 121 cols but for some reason this passes (same for the init above, also 121). Anyway, all good then
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Before submitting
What does this PR do?
Fixes the state pull and push behavior to being something more easily predictable, when a state is pulled it gets moved to CPU (failsafe if this is a reasonably big model). Upon restore the
param_groups
were not restored to the proper device. It looks like there are still probable duplicates in that field thoughPR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃