
multi-node distributed training with spark #935

Merged
merged 17 commits into main from multi-node2
Apr 10, 2024
Conversation

jmoralez
Member

@jmoralez jmoralez commented Mar 19, 2024

Adds the functionality to perform distributed data-parallel training with Spark. The logic is as follows:

  • The user provides a Spark dataframe and sets a configuration specifying how many nodes there are and how many GPUs each node has.
  • We'll then have one task per GPU in the cluster, so we partition the dataframe accordingly.
  • We save the partitioned dataframe and get the names of the generated parquet files.
  • Each task computes its global rank, loads its corresponding file, and uses it for training.
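The rank/partition bookkeeping in the steps above can be sketched roughly as follows (the function names and file layout here are illustrative, not the actual neuralforecast API):

```python
# Hypothetical sketch: with N nodes x G GPUs, the dataframe is written as
# N*G parquet partitions, and each task derives a unique global rank from
# its node index and local GPU index to pick its file.

def global_rank(node_rank: int, local_rank: int, gpus_per_node: int) -> int:
    """Unique rank for each GPU task across the whole cluster."""
    return node_rank * gpus_per_node + local_rank

def assign_file(files: list, node_rank: int, local_rank: int, gpus_per_node: int) -> str:
    """Each task loads exactly the parquet partition matching its global rank."""
    return files[global_rank(node_rank, local_rank, gpus_per_node)]

# Example: 2 nodes x 2 GPUs -> 4 partitions
files = [f"part-{i:05d}.parquet" for i in range(4)]
print(assign_file(files, node_rank=1, local_rank=0, gpus_per_node=2))  # part-00002.parquet
```

This one-file-per-rank scheme avoids any shuffling at training time: each worker only reads its own partition.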

There were a couple of challenges:

  • The final model is serialized and sent back to the driver, so we have to make sure it doesn't contain anything exotic (to avoid pickling errors); we therefore remove the _trainer attribute (and thus the trainer property) from the model.
  • The models' save method used the trainer's save_checkpoint method. Since we no longer have a trainer, this PR implements very simple methods to save and load models using only the init params and the weights (which also makes the files smaller). The premise is that we don't actually need everything the checkpoint contains in order to load the model for inference. This tries to maintain backward compatibility by using the same key names as PyTorch Lightning (hyper_parameters and state_dict).
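A minimal sketch of that save/load scheme, assuming a toy model class (the real code would serialize an actual PyTorch state_dict, e.g. via torch.save; TinyModel and its weights dict here are stand-ins):

```python
import io
import pickle

# Sketch (not the actual neuralforecast implementation): persist only the
# init params and the weights, under the same keys PyTorch Lightning uses,
# so the saved file still looks like a regular checkpoint dict.

class TinyModel:
    def __init__(self, input_size, h):
        self.input_size = input_size
        self.h = h
        self.weights = {"layer.weight": [0.0] * input_size}  # stand-in for a state_dict

    def save(self, buf):
        # Keep the Lightning key names for backward compatibility.
        ckpt = {
            "hyper_parameters": {"input_size": self.input_size, "h": self.h},
            "state_dict": self.weights,
        }
        pickle.dump(ckpt, buf)

    @classmethod
    def load(cls, buf):
        ckpt = pickle.load(buf)
        model = cls(**ckpt["hyper_parameters"])  # rebuild from init params...
        model.weights = ckpt["state_dict"]       # ...then restore the weights
        return model

buf = io.BytesIO()
TinyModel(input_size=3, h=2).save(buf)
buf.seek(0)
restored = TinyModel.load(buf)
```

Because only hyper_parameters and state_dict are stored, the file carries no trainer state, which is what keeps it small and pickle-safe.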

Also makes the following change, which isn't strictly necessary and could be made in a separate PR:

  • Ensures that the original aliases are preserved when saving and loading models. Right now, loading a saved model falls back to the default alias, so if an AutoNHITS was trained with the alias 'my_model', predictions made after loading it will come back in a column named 'NHITS' instead of 'my_model'.
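The fix amounts to storing the alias alongside the other init params. A hypothetical sketch (AutoNHITS and the alias attribute come from the PR description; the checkpoint plumbing here is illustrative):

```python
# Sketch of preserving a user-set alias across save/load.

class Model:
    def __init__(self, alias=None):
        self.alias = alias  # prediction column name; defaults to the class name

    @property
    def name(self):
        return self.alias if self.alias is not None else type(self).__name__

class NHITS(Model):
    pass

def to_ckpt(model):
    # Persist the alias alongside the other init params...
    return {"hyper_parameters": {"alias": model.alias}}

def from_ckpt(cls, ckpt):
    # ...so the restored model keeps its original prediction column name.
    return cls(**ckpt["hyper_parameters"])

m = NHITS(alias="my_model")
restored = from_ckpt(NHITS, to_ckpt(m))
print(restored.name)  # my_model
```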


@jmoralez jmoralez marked this pull request as ready for review March 20, 2024 17:34
Contributor
@cchallu cchallu left a comment

Amazing work! We should still wait for Azul's review before merging.

Member
@AzulGarza AzulGarza left a comment

awesome @jmoralez!🎉

@AzulGarza AzulGarza self-requested a review April 10, 2024 05:02
@jmoralez jmoralez merged commit 8121bfc into main Apr 10, 2024
17 checks passed
@jmoralez jmoralez deleted the multi-node2 branch April 10, 2024 16:21

3 participants