Issue1416 #1533
Conversation
Codecov Report
@@           Coverage Diff            @@
##             master    #1533   +/- ##
=========================================
  Coverage          ?   97.11%
=========================================
  Files             ?       55
  Lines             ?     2077
  Branches          ?      341
=========================================
  Hits              ?     2017
  Misses            ?       31
  Partials          ?       29
=========================================
@ptrcklv thanks for taking this on! One of the maintainers will take a look soon!
@ptrcklv — thanks so much for this contribution and great documentation!!
left a few formatting comments / requests for additional docs. please re-request a review once you've made changes!
trainer1.optimizer.state_dict()["state"][k]["exp_avg"],
trainer2.optimizer.state_dict()["state"][k]["exp_avg"],
why are we only checking equivalence of these fields? could we check the entire state_dict's values?
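A minimal sketch of what that fuller comparison could look like; the helper name and unittest wiring are hypothetical, not part of this PR:

    import torch

    # Hypothetical helper: compare every field of both optimizer
    # state_dicts, not just "exp_avg".
    def assert_optimizer_states_equal(test_case, optimizer1, optimizer2):
        state1 = optimizer1.state_dict()["state"]
        state2 = optimizer2.state_dict()["state"]
        test_case.assertEqual(set(state1.keys()), set(state2.keys()))
        for k in state1:
            test_case.assertEqual(set(state1[k].keys()), set(state2[k].keys()))
            for field, value in state1[k].items():
                if isinstance(value, torch.Tensor):
                    test_case.assertTrue(torch.equal(value, state2[k][field]))
                else:
                    test_case.assertEqual(value, state2[k][field])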
@@ -216,6 +216,35 @@ def test_warmup(self):
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, 1)

    def test_save_load(self):
        fd, checkpoint_path = tempfile.mkstemp()
can we put this in a try/except/finally or use a context with tempfile.NamedTemporaryFile() as f: to ensure proper cleanup?
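For illustration, a sketch of the suggested context-manager form; trainer1, trainer2, and model are assumed from the surrounding test:

    import tempfile

    # The temporary file is deleted automatically when the context exits,
    # even if an assertion fails partway through the test.
    with tempfile.NamedTemporaryFile() as f:
        trainer1.save(f.name)
        trainer2.load(f.name, model=model)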
I switched to NamedTemporaryFile; however, I had copied the tempfile.mkstemp() pattern from test_save_load() in test_multitask_classifier.py. Maybe you want to update it there, too?
got it — yes, we likely need to clean up other parts of the codebase as well. :)
Parameters
----------
trainer_path :
nit: no : here. see https://github.com/snorkel-team/snorkel/blob/master/snorkel/classification/multitask_classifier.py for an example of docstrings
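For reference, a sketch of the requested numpy-style format; the load() signature is assumed from the PR description, not copied from the linked file:

    def load(self, trainer_path, model=None):
        """Load the trainer config and optimizer state saved at trainer_path.

        Parameters
        ----------
        trainer_path
            The path to the saved trainer config to be loaded
        model
            MultitaskClassifier for which the optimizer has been set
        """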
@@ -216,6 +216,35 @@ def test_warmup(self):
        trainer.fit(model, [dataloaders[0]])
        self.assertEqual(trainer.warmup_steps, 1)

    def test_save_load(self):
have you tested this with resuming training from a saved checkpoint?
yes, I have included it in the test now
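A rough sketch of what such a resume test might look like; Trainer, model, and dataloaders are assumed to be set up as in the rest of test_trainer.py, and the save()/load() signatures follow this PR's description:

    def test_resume_training(self):
        with tempfile.NamedTemporaryFile() as f:
            trainer1 = Trainer(n_epochs=1)
            trainer1.fit(model, [dataloaders[0]])
            trainer1.save(f.name)

            # Load the checkpoint into a fresh trainer and keep fitting,
            # which exercises the restored optimizer state.
            trainer2 = Trainer(n_epochs=1)
            trainer2.load(f.name, model=model)
            trainer2.fit(model, [dataloaders[0]])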
trainer_path :
    The path to the saved trainer config to be loaded
model :
    MultitaskClassifier for which the optimizer has been set. Parameters of
    optimizer must fit to model parameters. This model shall be the model
    which was fit by the stored Trainer.
nit: no : for parameters (see above)
fantastic, lgtm! thank you for the contribution!!
Description of proposed changes
save() and load() methods for Trainer that serialize the optimizer state + trainer config (dependent on whether a model has been fitted with that Trainer instance)
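A hypothetical usage sketch of the new methods; the model and dataloader names are placeholders, and the exact signatures follow this PR's discussion rather than a released API:

    trainer = Trainer(n_epochs=1)
    trainer.fit(model, [train_dataloader])
    trainer.save("trainer_checkpoint.pt")  # persists optimizer state + config

    # Later: restore the optimizer state for the same model and resume training.
    new_trainer = Trainer()
    new_trainer.load("trainer_checkpoint.pt", model=model)
    new_trainer.fit(model, [train_dataloader])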
Related issue(s)
Fixes #1416
Test plan
Adapted test_save_load() in test_trainer.py.
Checklist
Need help on these? Just ask!
tox -e complex and/or tox -e spark if appropriate.