multi-node distributed training with spark #935
Conversation
Amazing work! We should still wait for Azul's review before merging.
awesome @jmoralez!🎉
Adds the functionality to perform distributed data parallel training with Spark. The logic is as follows:

There were a couple of challenges:

- Removing the `_trainer` attribute (and thus the `trainer` property) from the model.
- The `save` method of the models used the trainer's `save_checkpoint` method. Since we won't have the trainer anymore, this implements very simple methods to save and load models, which use only the init params and weights (which will also make the files smaller). The premise here is that we don't actually need everything the checkpoint contains in order to load the model for inference. This tries to maintain backward compatibility by using the same key names as PyTorch Lightning does (`hyper_parameters` and `state_dict`).
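A minimal sketch of that save/load idea, assuming a plain `torch.nn.Module`; the class and helper names here (`MyModel`, `save`, `load`) are hypothetical and only illustrate the approach described above, not the PR's actual implementation:

```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        # keep the init params around so they can be stored alongside the weights
        self.hparams = {"input_size": input_size, "hidden_size": hidden_size}
        self.net = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def forward(self, x):
        return self.net(x)

    def save(self, path: str) -> None:
        # same top-level keys as a lightning checkpoint, but nothing else:
        # no optimizer state, no callbacks, no loop state
        torch.save(
            {"hyper_parameters": self.hparams, "state_dict": self.state_dict()},
            path,
        )

    @classmethod
    def load(cls, path: str) -> "MyModel":
        content = torch.load(path, map_location="cpu")
        # rebuild from the init params, then restore the weights for inference
        model = cls(**content["hyper_parameters"])
        model.load_state_dict(content["state_dict"])
        return model
```

Because only the init params and weights are written, the resulting file is smaller than a full Lightning checkpoint, and loading it never needs a trainer.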
Also makes the following change, which isn't strictly necessary and could be made in a separate PR:
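For readers unfamiliar with running data parallel training from Spark, below is a minimal, generic sketch using PySpark's `TorchDistributor` (PySpark 3.4+). It only illustrates the general pattern of launching a DDP training function on Spark workers; it is not necessarily how this PR wires things up, and the model and hyperparameters are placeholders.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from pyspark.ml.torch.distributor import TorchDistributor


def train_fn(num_epochs: int):
    # TorchDistributor sets MASTER_ADDR/PORT, RANK and WORLD_SIZE on each worker,
    # so the default env:// initialization works here.
    dist.init_process_group("gloo")
    model = torch.nn.Linear(10, 1)
    ddp_model = DDP(model)  # gradients are synchronized across workers
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-2)
    for _ in range(num_epochs):
        x = torch.randn(32, 10)  # placeholder batch; real code would read a partition
        y = torch.randn(32, 1)
        loss = torch.nn.functional.mse_loss(ddp_model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    dist.destroy_process_group()
    # return plain tensors so the result can be pickled back to the driver
    return {k: v.cpu() for k, v in model.state_dict().items()}


# launch two training processes across the cluster, CPU only
state = TorchDistributor(num_processes=2, local_mode=False, use_gpu=False).run(train_fn, 5)
```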