-
Notifications
You must be signed in to change notification settings - Fork 282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add FullyShardedDataParallel (FSDP) #413
Conversation
* [mypy]: fixed all the mypy errors * make older version python happy with typing hints * [chore] Fix lint errors that broke master (#348) authored-by: Anjali Sridhar <anj@devfair0443.h2.fair> Co-authored-by: anj-s <32556631+anj-s@users.noreply.github.com>
* Test CPU offload * remove dead code
…pper (#42) * Only sync fp32_to_fp16 stream for the top-most (root) ShardParams wrapper * Fix mypy, add test, address some comments * Add missing assert * Comments
* Test backward hooks are registered * expand * fs_test * passing * assert again * add assert not called * naming
* format change * [test]: test apply_to_tensors * formatting * added some skeletons * name for TODO * fixing the use of lru_cache * formatting
* Do reduce-scatter in a separate CUDA stream * Add _post_backward_stream to stubs
* Fix delayed_reduce_scatter test * Decompose NestedWrappedModule from ModuleWithDelay
* add unit test pack/unpack kwargs * added two more corner cases * more doc and more tests * more corner cases * formatting * Update fairscale/utils/containers.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * with pytest.raises is awesome * addressed comment * add tuple to be tested Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
- Add two new tests (TestParamInit and TestSerialization) which would have failed previously. These mostly cover the fairseq usage that was not captured by tests before. - Deprecate the `compute_device` option, since we don't actually use it anywhere (nor do we need it in fairseq) - Remove `_move_fp32_shard_to_cuda` and embrace a stronger invariant: p.data and p._fp32_shard should always be the same at the start and end of each function (namely, state_dict, and also forward/backward). - Slightly unrelated, but refactor streams logic a bit, so we have a single `self._streams` dictionary -- this will make an upcoming PR that adds more streams easier
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Congrats @myleott, @min-xu-ai and @sshleifer :)
optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9) | ||
for _ in range(num_steps): | ||
optim.zero_grad() | ||
with torch.cuda.amp.autocast(enabled=autocast): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Had a quick sneak at the PR, was wondering if this test will pass if we add a GradScaler
object? If weights/states are sharded across will we need to use something similar to the ShardedGradScaler
: https://github.com/facebookresearch/fairscale/blob/master/fairscale/optim/grad_scaler.py#L24
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good question and I think you're probably right. PR's welcome :)
Merging so that future developments can be done as separate PRs. Thanks for all the help reviewing this @anj-s and @blefaudeux! |
Great work @myleott @sshleifer @min-xu-ai! Excited to see this PR :) |
This is full ZeRO-DP stage 3, right? |
Yep, ZeRO stage 3 and CPU offloading. It can also implement stage 2 (with |
Absolutely amazing!!! This dramatically changes what we can do with huge models. Can't wait to test it out. |
OK, I'd like to start integrating this right away in HF transformers - do you have examples I could learn from/mimic? Also wrt:
where does And I'm also thinking about the long term - as fairscale has been backporting this upstream to pytorch, we should probably extended the HF trainer API to assume that eventually this will be native and until then just require |
The doc is still being worked on. We are adding more tests and making sure FSDP work with more models. The docstring right now has a minimal example. model = FSDP(nn.Sequential( I think Myle's usage in Transformer might be slightly different with respect to checkpoint_wrapper & FSDP's order. Again, I'd warn that this is fresh off the press. Sharp edges are likely in multiple places.
They are only roughly == in terms memory overhead. I'd think they will be coexisting for a while if not forever. They may have different fits for different model and model sizes. We are still gaining experiences with them. In the end, it will be a comparison of usability, generalizability, performance, memory efficiency, etc.
yes |
Thank you for the example, @min-xu-ai Oh, the OP mentioned nothing about the model needing to be Could one traverse the model graph and inject those intermediary layers automatically? Perhaps with pytorch FX?
I appreciate the disclaimer. I realize this is an early alpha and I promise to thread with care.
Thank you for clarifying that. I'm just starting to think about the user-facing API in the HF trainer, hence trying to think how to best name things. |
Sorry for misleading you. That's just an example. In fact, it was a ModuleDict in vissl's case. I used nn.Sequential out of habit.
Definitely a possibility! I think figuring out where to inject for best results will be the harder part though. |
Correct, it’s not implemented the same way as ShardedDDP, so there may be different use cases that prefer one or the other. That said, the version here (with
To clarify, no need for nn.Sequential, almost any module structure works. I agree we should traverse the graph and wrap the intermediate layers with FSDP automatically. In fact, there’s almost no downside in terms of perf of just wrapping every layer that’s big enough, because FSDP internally manages the child instances (e.g., it will share CUDA streams with the children to avoid unnecessary synchronization). The one caveat is that some model structures break if you wrap a child module. In particular, if the parent module accesses the weights of the child module directly (i.e., not via the child module’s forward), then we don’t execute the hook to all-gather the params, so it breaks. The nn.Linears in fairseq’s MultiheadAttention is one example; had we wrapped So for now we’re requiring layers to be wrapped manually, but I do think this can be relaxed once we add better fallback logic/error reporting for the case above.
I’ll try to put up a diff in fairseq today or tomorrow that works for Transformer. |
For the sake of precision, that's true (overlap FW and gather) if you can wrap subparts of your model with FSDP, not if the model is in one FSDP wrap |
Thank you very much for your detailed commentary to my questions - that's very useful and I'm looking forward to studying the examples Request: to make things consistent with DDP could this be named On the other hand it implements ZeRO-DP, so it matches the name in the paper. These are conflicting arguments. May I suggest that removing ambiguity for the user is more important than matching the name paper? Note that you rename |
Dear @myleott, Thanks for your incredible work. I have a sill question: Currently, sharding is done by splitting the contiguous parameter tensor into
I was wondering If you considered to perform a chunking in the following way. Chunk each individual parameters into their own world_size chunks and concatenate them across the entire model in a contiguous way. For all rank I, [chunks_world_size(p1, rank_i), chunks_world_size(p2, rank_i), ..., chunks_world_size(pn, rank_i)]]. Then, it could be possible to override the the parameter getattr function to perform the in-place all_gather when accessed + the following ones in async way, and perform memory release with a post forward hook, which could maybe reduce some memory. My assumption is: It would avoid to call I hope I was able to properly express myself :) Best, |
Nice idea. I think this would be possible, especially as we move towards automatic wrapping of children modules (discussed above). As always, I'm sure there are some corner cases we'll need to be careful of, but this is a good direction to explore 😄 A related, but separate next step is to plug into the DDP Communication Hooks interface to handle the reductions. Right now we're doing our own bucketing, which is really slow when |
Dear @myleott, Sounds great ! If you are interested, we could try to explore this idea with your team. I recently worked on integrating a DDP Comm Hook to perform training for unstable NaN Loss such as the CTC Loss: https://github.com/PyTorchLightning/pytorch-lightning/blob/490c40a8be339afd08dcfdcdd658db6b0be4671b/pytorch_lightning/plugins/ddp_comm_hooks/allreduce_invalid.py#L52. Side Note: I think it would be also interesting to work on composable DDP Comm Hooks by making it more modular. I am going to keep studying the code. Thanks you and your team for such a clean work ! Learning a lot :) Best, |
Not sure I totally follow this idea, Thomas. Say the world_size is 2. Do you mean break each param in 2 chunks and concat all 1st chunks on rank 0 and concat all 2nd chunks on rank 1? Is the concat needed so that you don't all_gather small params one by one? With concat, you just all_gather a partial view into the big array? The idea to use getattr to trigger all_gather is nice but it assumes different ranks will always execute all getattr in the same order, which is controlled by the model code that we don't control. |
Co-authored-by: @min-xu-ai and @sshleifer
Overview
Recent work by Microsoft and Google has shown that data parallel training can be made significantly more efficient by sharding the model parameters and optimizer state across data parallel workers. These ideas are encapsulated in the new
FullyShardedDataParallel
(FSDP) wrapper, which is a drop-in replacement for PyTorch'sDistributedDataParallel
(DDP) wrapper.Compared to PyTorch DDP:
reshard_after_forward=False
has the same communication cost as PyTorch DDP and is similar to ZeRO-2reshard_after_forward=True
increases total communication by 50% and is similar to ZeRO-3:cpu_offload=True
option, it's possible to train 1T parameter models on 256 GPUs.General usage notes
reshard_after_forward=True
reshard_after_forward=False
(wrapping each layer is not required, but will improve speed further)torch.cuda.amp.autocast
for mixed precision, that's fully compatible with the FSDP wrapper, just setmixed_precision=True
FSDP(checkpoint_wrapper(module))
overcheckpoint_wrapper(FSDP(module))
. The latter will result in more communication and will be slower.How it works
In standard distributed data parallel (DDP) training every worker processes a separate batch and the gradients are summed across workers using an all-reduce operation. While DDP has become very popular, it wastes GPU memory because the model weights and optimizer states are replicated across all DDP workers.
The key insight to unlock full parameter sharding is that we can decompose the all-reduce operation in DDP into separate all-gather and reduce-scatter operations:
Then, we can rearrange the reduce-scatter + all-gather so that each DDP worker only needs to store a single shard of parameters and optimizer state. The figure below illustrates standard DDP training (left) and fully sharded training (right):
To maximize memory efficiency we can discard the full weights after each layer's forward pass, saving memory for subsequent layers. This can be implemented by applying the FSDP wrapper to every layer in your network (with
reshard_after_forward=True
). In pseudo-code: