FSDP integration #6152
@@ -0,0 +1,39 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast, Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, GradClipAlgorithmType

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
    from fairscale.nn.data_parallel import FullyShardedDataParallel


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Fully Sharded Training."""

    def clip_gradients(
        self,
        optimizer: 'Optimizer',
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
        model: Optional[Module] = None
    ) -> None:
        # The wrapped model manages clipping of its own (sharded) gradients.
        model = cast(FullyShardedDataParallel, model)
        # TODO: expose the norm type once the precision plugin supports it.
        model.clip_grad_norm_(clip_val, norm_type=2.0)
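A brief, hedged illustration of why clipping is delegated to the wrapped model here rather than done with the usual PyTorch utility: under FSDP each rank only holds a shard of the parameters and gradients, so the wrapper has to combine the per-shard norms before scaling. The helper below is an illustrative sketch, not code from this PR.

# Illustrative sketch only (not part of this PR): contrast standard clipping with FSDP clipping.
import torch
from torch.nn.utils import clip_grad_norm_
from fairscale.nn.data_parallel import FullyShardedDataParallel


def clip_gradients_sketch(model: torch.nn.Module, clip_val: float) -> None:
    if isinstance(model, FullyShardedDataParallel):
        # The FSDP wrapper owns the parameter shards, so it reduces the per-shard
        # gradient norms across ranks itself before scaling.
        model.clip_grad_norm_(clip_val, norm_type=2.0)
    else:
        # A plain (non-sharded) model can use the standard PyTorch utility.
        clip_grad_norm_(model.parameters(), clip_val)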
@@ -0,0 +1,218 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Dict, Generator, List, Optional

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
    from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, FlattenParamsWrapper, wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel

    from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded


class FullyShardedPlugin(DDPPlugin):

    def __init__(
        self,
        cpu_offload: bool = False,
        flatten_parameters: bool = True,
        reshard_after_forward: bool = True,
        move_grads_to_cpu: Optional[bool] = None,
        fp32_reduce_scatter: Optional[bool] = None,
        compute_dtype: Optional[torch.dtype] = None,
        bucket_cap_mb: int = 25,
        automatic_module_wrap: bool = False,
        min_num_params: int = 1e8,
        parallel_devices: Optional[List[torch.device]] = None,
        num_nodes: Optional[int] = None,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm: Optional[bool] = None
    ):
        """
        Provides the ability to run training using the Fully Sharded Data Parallel implementation in FairScale.

        Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
        size whilst using efficient communication to reduce overhead. In practice, this means we can remain
        at parity with PyTorch DDP whilst scaling our model sizes dramatically. The technique is similar
        to ZeRO-Stage 3 but has been built for upstreaming to PyTorch.

        For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html

        .. warning:: ``FullyShardedPlugin`` is in beta and subject to change.

        Defaults have been set and options have been exposed, but may require configuration
        based on your level of memory/speed efficiency. We suggest having a look at this PR for
        more information: https://github.com/facebookresearch/fairscale/pull/413

        Many of the helpful docstrings below came from the original FairScale documentation:
        https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html

        Arguments:

            cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode (default: False).

            move_grads_to_cpu: Moves gradient shards to CPU after reduction.
                Only disable if using CPU based optimizers (defaults to ``cpu_offload``).

            flatten_parameters: Flattens parameters into a single contiguous tensor for speed efficiency
                (default: True).

            reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows
                down training. Only relevant when nesting FullyShardedDataParallel wrappers inside the model
                (default: True).

            fp32_reduce_scatter: Reduce-scatter gradients in FP32. Only relevant in mixed precision
                (default: None).

            compute_dtype: dtype for full parameters for computation. Defaults to ``torch.float32``,
                unless using mixed precision, in which case it defaults to ``torch.float16``.

            bucket_cap_mb: Bucket parameters so that gradient reduction
                can potentially overlap with backward computation.
                ``bucket_cap_mb`` controls the bucket size in MegaBytes (MB).
                Buckets are sub-divided based on world_size,
                so the max shard size is roughly ``bucket_cap_mb / world_size``.
                Values <= 0 disable bucketing (default: 25).

            automatic_module_wrap: Automatically wrap the LightningModule with Fully Sharded recursively,
                using ``min_num_params`` to determine the number of parameters to wrap at a time
                (default: False).

            min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``
                (default: 1e8).
        """
        if not _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
            raise MisconfigurationException(
                "Fully Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`"
            )

        super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm)
        self.cpu_offload = cpu_offload
        self.move_grads_to_cpu = move_grads_to_cpu
        self.flatten_parameters = flatten_parameters
        self.reshard_after_forward = reshard_after_forward
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.compute_dtype = compute_dtype
        self.bucket_cap_mb = bucket_cap_mb
        self.automatic_module_wrap = automatic_module_wrap
        self.min_num_params = min_num_params
        self._process_group = None

    @property
    def process_group(self):
        if self._process_group is None:
            self._process_group = torch.distributed.new_group()
        return self._process_group

    def setup_distributed(self):
        super().setup_distributed()
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        precision = self.lightning_module.trainer.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            auto_wrap_policy=wrap_policy,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
            mixed_precision=precision == "mixed",
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
            bucket_cap_mb=self.bucket_cap_mb,
        ):
            yield

    def configure_ddp(self):
        with self.model_sharded_context():
            if self.automatic_module_wrap and not self._model_has_nested_fsdp():
                self.model = auto_wrap(LightningFullyShardedModule(self.model))
                if not isinstance(self.model, FullyShardedDataParallel):
                    self.model = wrap(self.model)
            else:
                self.model = wrap(LightningFullyShardedModule(self.model))

        if not self.cpu_offload:
            # When using CPU Offload, FSDP will manage the CUDA movement for us
            self.model_to_device()
        # setup optimizers after fully sharded has wrapped the lightning module
        self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

    def model_to_device(self):
        self.model.to(self.root_device)
        # ensure we update the device type in the lightning module
        self.lightning_module.to(self.root_device)
Review comment on lines +173 to +175:

we might need to be cautious about this, as and when we perform
    def pre_dispatch(self):
        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)
        self.configure_ddp()
        self.barrier()

    @property
    def lightning_module(self) -> LightningModule:
        return unwrap_lightning_module_fully_sharded(self.model)

    def on_save(self, checkpoint: dict) -> dict:
        state_dict = self.collate_state_dict()
        checkpoint['state_dict'] = state_dict
        return checkpoint
Review comment thread on the state-dict collection in on_save:

@SeanNaren, after getting detailed memory usage I finally figured out why the full model originally fits in one GPU but OOMs when checkpointing: by this point the full state dict has already been collected once, and here we try to collect it again, which doubles the size. One easy workaround exists now, but it is not ideal: we would summon the full parameters twice, which is unnecessary. I feel we should modify that file to let the training type plugin control state-dict collection, especially if we would like to collect only a sharded state dict in the future. cc @ananthsub @min-xu-ai. I think this is the root cause of the OOM; facebookresearch/fairscale#658 should not be the problem.

Reply: Thanks @shuyingsunshine21 for your help here! This makes sense, since we're allocating new memory. I agree with allowing the training type plugin to return the state dict; we already rely on the accelerator to dump the optimizer dicts. I'm happy to make the change!

Reply: @SeanNaren, thanks, no worries. If you have not already made the change, I could help send a small PR for that.
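As a rough illustration of the proposal in the thread above, the checkpoint logic could ask the training type plugin for the model state dict so the full parameters are only gathered once. This is a hypothetical sketch; the hook and class names are invented for illustration and are not the API introduced by this PR.

# Hypothetical sketch: a single, plugin-owned gathering point for the model state dict.
from typing import Dict

import torch


class _SketchFullyShardedPlugin:

    def __init__(self, wrapped_model: torch.nn.Module):
        self.model = wrapped_model  # assumed to be the FSDP-wrapped module

    def lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
        # The full parameters are gathered here and nowhere else.
        state_dict = self.model.state_dict()
        return {k.partition('module.')[2]: v for k, v in state_dict.items()}


class _SketchCheckpointConnector:

    def __init__(self, plugin: _SketchFullyShardedPlugin):
        self.plugin = plugin

    def dump_checkpoint(self) -> dict:
        # Delegate to the plugin instead of calling model.state_dict() directly and
        # then gathering a second time in on_save, which doubles peak memory.
        return {"state_dict": self.plugin.lightning_module_state_dict()}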
    def collate_state_dict(self):
        """
        Collects the model's sharded state dict from all processes before returning.

        Returns: The unsharded model state dict.
        """
        state_dict = self.model.state_dict()
        # Remove the wrapper's `module.` prefix so keys match the LightningModule state dict.
        state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()}
        return state_dict
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Setup optimizers after the Fully Sharded Model has been made
        return True

    def _model_has_nested_fsdp(self):
        for module in self.model.modules():
            if isinstance(module, FullyShardedDataParallel):
                return True
        return False

    @classmethod
    def register_plugins(cls, plugin_registry: Dict):
        plugin_registry.register("fsdp", cls, description="Fully Sharded with LightningModule wrap")
        plugin_registry.register(
            "fsdp_offload", cls, description="Fully Sharded Training with CPU Offloading.", cpu_offload=True
        )
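A hedged usage sketch of the aliases registered above: once "fsdp" and "fsdp_offload" are in the plugin registry they should be selectable by name from the Trainer, with FullyShardedNativeMixedPrecisionPlugin picked up when precision=16. The exact Trainer arguments assume a Lightning build containing this PR, FairScale installed, and multiple GPUs.

# Usage sketch (assumes a Lightning build containing this PR, FairScale, and 2+ GPUs).
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,))), batch_size=8)
    trainer = pl.Trainer(
        gpus=2,
        precision=16,    # routes gradient clipping through FullyShardedNativeMixedPrecisionPlugin
        plugins="fsdp",  # registered alias above; "fsdp_offload" additionally sets cpu_offload=True
        max_epochs=1,
    )
    trainer.fit(ToyModel(), data)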
Review comment:

If manually wrapping contents inside the LightningModule, is this final outer-layer wrap needed, or could we defer this to the user in the LightningModule too? Then we would not need to wrap the model in the dummy LightningFullyShardedModule to map forward to one of the step functions. Would it also mean users don't have to refer to self.trainer.model inside of the LightningModule? Would this avoid the parameter-flattening issue across stages?
Reply:

Ah, I understand! Theoretically I think so, since then we're just using the LightningModule as a wrapper. The cases I see:

1. auto_wrap handles recursive wrapping.
2. configure_sharded_model is used, but all other layers are expected to be included in a higher wrapper class (wrap the entire LightningModule).
3. configure_sharded_model is used and no high-level wrapping is required.

Solutions:

1. plugins=fsdp or plugins=fsdp_auto_wrap
2. plugins=fsdp or plugins=fsdp_auto_wrap
3. plugins=fsdp_manual, where we do not wrap the highest-level module, allowing the user to do whatever they'd like in configure_optimizers.

In either case, it's important to fix the flattening issue for 1. and 2., which for most users trying this out will be the first step, I think. Thoughts @ananthsub?
Reply:

Yes, exactly those 3 cases!

To unblock the initial integration, I was wondering if we should start with option 3 to unblock power users in the release candidates, with the caveat that they are responsible for the full wrapping. Maybe this could be an option on the plugin as to whether the outer wrap on the LightningModule needs to be applied, in order to distinguish between cases 2 and 3.

Completely agreed with you that most users will opt for cases 1 and 2, so we'll need to figure out the parameter flattening, whether in Lightning or FairScale, but I wanted to offer this as one way we could sequence these cases.
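To make case 3 concrete before the observations below, here is a rough, hypothetical sketch of user-controlled wrapping. It assumes a configure_sharded_model-style hook is invoked inside the plugin's model_sharded_context so that wrap picks up the surrounding enable_wrap settings; the hook and the model structure are illustrative, not settled by this discussion.

# Hypothetical case-3 sketch: the user wraps submodules themselves and the plugin
# applies no outer LightningModule wrap. Hook name and structure are illustrative.
import torch
import pytorch_lightning as pl
from fairscale.nn import wrap  # no-op outside an enable_wrap(...) context


class ManuallyWrappedModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = None  # built lazily so wrapping happens under enable_wrap

    def configure_sharded_model(self):
        # Assumed to be called inside model_sharded_context(), so wrap() inherits
        # the FSDP settings from the surrounding enable_wrap(...).
        self.model = torch.nn.Sequential(
            wrap(torch.nn.Linear(32, 32)),
            torch.nn.ReLU(),
            wrap(torch.nn.Linear(32, 2)),
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        # With manual wrapping, the optimizer is built over the wrapped submodules directly.
        return torch.optim.SGD(self.model.parameters(), lr=0.1)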
Reply:

I started making changes to see if any issues arise with case 3, and a few observations:

The user still has to define a single model, be it a Module containing submodules in a sequential wrapper or their own model structure defining a forward function. This means self.model will still probably be required in every case for FSDP to work in configure_optimizers.

I also ran into an issue with clipping grad norms, which in manual mode cannot be handled automatically, as we do not wrap the model.

A potential solution, albeit not as elegant as I'd like, would be to go through the immediate children of the LightningModule, find the root FSDP module, and call clip_grad_norm_ on it, as in the sketch below. I assume this will be a negligible cost added on top of the training loop, but what are your thoughts @ananthsub?
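A hedged sketch of that workaround, walking the immediate children of the LightningModule to find the root FSDP module and delegating clipping to it; the helper name is illustrative.

# Illustrative sketch of the proposed manual-mode clipping workaround.
from typing import Union

from torch.nn import Module
from fairscale.nn.data_parallel import FullyShardedDataParallel


def clip_grad_via_root_fsdp(lightning_module: Module, clip_val: Union[int, float]) -> None:
    # Look through the immediate children for the outermost FSDP wrapper and let it
    # compute the global gradient norm across all of its shards.
    for child in lightning_module.children():
        if isinstance(child, FullyShardedDataParallel):
            child.clip_grad_norm_(clip_val)
            return
    raise ValueError("No FullyShardedDataParallel child found to clip gradients on.")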