FSDP with full state dict #7487
@@ -0,0 +1,46 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

from torch.nn import Module
from torch.optim import Optimizer

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
    """Mixed Precision for Fully Sharded Training."""

    precision = "mixed"

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float],
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE,
        model: Optional[Module] = None,
    ) -> None:
        clip_val = float(clip_val)
        if clip_val <= 0:
            return
        # See https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html,
        # section `Gradient Clipping`: using `torch.nn.utils.clip_grad_norm_` is incorrect for an
        # FSDP module. To overcome this, one needs to call `sharded_module.clip_grad_norm(clip_val)`.
        # However, since we rely on the LightningModule's `configure_sharded_model` to wrap FSDP,
        # it would be hard to trace back the root FSDP module. For now we only support clipping by value.
        assert (
            gradient_clip_algorithm == GradClipAlgorithmType.VALUE
        ), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`"
        self.clip_grad_by_value(optimizer, clip_val)
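For context, ``clip_grad_by_value`` is inherited from the base precision plugin. Below is a minimal sketch of what clipping by value amounts to, assuming it ultimately delegates to ``torch.nn.utils.clip_grad_value_`` over the optimizer's parameters; the helper name is illustrative, not the plugin's actual implementation.

import torch
from torch.optim import Optimizer


def clip_grad_by_value_sketch(optimizer: Optimizer, clip_val: float) -> None:
    # Collect every parameter the optimizer updates and clamp each gradient
    # element-wise into [-clip_val, +clip_val].
    parameters = [p for group in optimizer.param_groups for p in group["params"]]
    torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)

Because value clipping is element-wise, it can run per rank on the sharded gradients, which is why it stays correct under FSDP, whereas norm clipping would need the global (unsharded) gradient norm.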
@@ -0,0 +1,210 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Dict, Generator, List, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
    from fairscale.nn import default_auto_wrap_policy, enable_wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel


class DDPFullyShardedPlugin(DDPPlugin):

    def __init__(
        self,
        cpu_offload: bool = False,
        flatten_parameters: bool = True,
        reshard_after_forward: bool = True,
        move_grads_to_cpu: Optional[bool] = None,
        fp32_reduce_scatter: Optional[bool] = None,
        compute_dtype: Optional[torch.dtype] = None,
        bucket_cap_mb: int = 25,
        # pyre-ignore[9]
        min_num_params: int = 1e8,
        state_dict_to_cpu: bool = True,
        parallel_devices: Optional[List[torch.device]] = None,
        # pyre-ignore[9]
        cluster_environment: ClusterEnvironment = None,
    ):
""" | ||
Plugin for Fully Sharded Data Parallel provided by FairScale. | ||
|
||
Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model | ||
size, whilst using efficient communication to reduce overhead. In practice, this means we can remain | ||
at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar | ||
to ZeRO-Stage 3 but has been built for upstreaming to PyTorch. | ||
`For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. | ||
.. warning:: ``FullyShardedPlugin`` is in beta and subject to change. | ||
|
||
Defaults have been set and options have been exposed, but may require configuration | ||
based on your level of memory/speed efficiency. We suggest having a look at this PR for more information. | ||
`https://github.com/facebookresearch/fairscale/pull/413` | ||
|
||
Many of the helpful doc strings below came from the original FairScale documentation: | ||
`https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html` | ||
|
||
Arguments: | ||
cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode. | ||
(Default: False). | ||
move_grads_to_cpu: Moves gradient shards to CPU after reduction. | ||
Only disable if using CPU based optimizers | ||
(Default to ``cpu_offload``). | ||
flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency | ||
(Default: True). | ||
reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows | ||
down training. This is only relevant when resharding individual layers. | ||
(Default: True). | ||
fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision | ||
(Default: None). | ||
compute_dtype: dtype for full parameters for computation. Default to torch.float32, | ||
unless using mixed precision, in which case defaults to torch.float16. | ||
(Default: None). | ||
bucket_cap_mb: bucket parameters so that gradient reduction | ||
can potentially overlap with backward computation. | ||
bucket_cap_mb controls the bucket size in MegaBytes (MB). | ||
Buckets are sub-divided based on world_size, | ||
so the max shard size is roughly bucket_cap_mb / world_size. | ||
Values <= 0 disable bucketing. | ||
(Default: 25). | ||
min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``. | ||
(Default: 1e8) | ||
state_dict_to_cpu: Whether to return parameters (returned by :func:`state_dict`) on CPU device. | ||
If ``False``, this will default to ``compute_device``. | ||
(Defautl: True). | ||
""" | ||
        super().__init__(
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
        )
        self.cpu_offload = cpu_offload
        self.move_grads_to_cpu = move_grads_to_cpu
        self.flatten_parameters = flatten_parameters
        self.reshard_after_forward = reshard_after_forward
        self.fp32_reduce_scatter = fp32_reduce_scatter
        self.compute_dtype = compute_dtype
        self.bucket_cap_mb = bucket_cap_mb
        self.min_num_params = min_num_params
        self.state_dict_device = torch.device("cpu") if state_dict_to_cpu else None
        self._process_group = None

    @property
    def process_group(self):
        if self._process_group is None:
            self._process_group = torch.distributed.new_group()
        return self._process_group

    def setup_distributed(self) -> None:
        if not self.on_gpu:
            raise MisconfigurationException(
                "You selected accelerator to be `ddp_fully_sharded`, but GPU is not available."
            )
        super().setup_distributed()
        torch.cuda.set_device(self.root_device)

Review thread on lines +116 to +117 (the GPU availability check):
Reviewer: this could be easily unit tested.
Author: sounds great, adding a unit test for it.
    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator:
        precision = self.lightning_module.trainer.precision

        def wrap_policy(*args, **kwargs):
            return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params)

        with enable_wrap(
            wrapper_cls=FullyShardedDataParallel,
            auto_wrap_policy=wrap_policy,
            process_group=self.process_group,
            cpu_offload=self.cpu_offload,
            move_grads_to_cpu=self.move_grads_to_cpu,
            flatten_parameters=self.flatten_parameters,
            mixed_precision=precision == "mixed",
            reshard_after_forward=self.reshard_after_forward,
            fp32_reduce_scatter=self.fp32_reduce_scatter,
            compute_dtype=self.compute_dtype,
            bucket_cap_mb=self.bucket_cap_mb,
            state_dict_device=self.state_dict_device,
        ):
            yield
    def connect(self, model: Module) -> None:
        super().connect(model)
        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
        if not model_call_configure_sharded_model_hook:
            # If the model has not called configure sharded model, we reset
            # the training type plugin's call_configure_sharded_model_hook
            # to give the trainer a chance to configure it.
            self.call_configure_sharded_model_hook = True

Review thread on ``connect``:
Author: this is needed when we load from a checkpoint. Since we now checkpoint the full state, we need the unwrapped model first to load into, which allows us to configure sharding (i.e. wrap) afterwards. Added a unit test to make sure this holds at the start of the different stages (``on_x_start``).
    def configure_ddp(self) -> None:
        if not self.cpu_offload:
            # When using CPU Offload, FSDP will manage the CUDA movement for us.
            # Note: this would be problematic for large models (which could not fit in one GPU),
            # as FSDP module.to(device) would first summon all parameters.
            # (TODO: need to figure out a solution)
            self.model_to_device()

        # setup optimizers after fully sharded has wrapped the lightning module
        self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer)

    def pre_dispatch(self) -> None:
        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)
        self.configure_ddp()
        self.barrier()

    def model_to_device(self) -> None:
        # ensure we update the device type in the lightning module
        self.lightning_module.to(self.root_device)

    def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]:
        # Currently this is the same as the default TrainingTypePlugin, i.e. it returns
        # the full state dict for FSDP. In the future, we will provide a sharded state dict.
        return super().lightning_module_state_dict()

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Setup optimizers after the Fully Sharded Model has been made
        return True

    def training_step(self, *args, **kwargs):
        return self.model.training_step(*args, **kwargs)

    def validation_step(self, *args, **kwargs):
        return self.model.validation_step(*args, **kwargs)

    def test_step(self, *args, **kwargs):
        return self.model.test_step(*args, **kwargs)

    def predict_step(self, *args, **kwargs):
        return self.model.predict_step(*args, **kwargs)

    def post_training_step(self):
        pass

    @classmethod
    def register_plugins(cls, plugin_registry: Dict):
        plugin_registry.register(
            "fsdp",
            cls,
            description="Fully sharded training with checkpointing the full state dict.",
        )
Review thread on ``register_plugins``:
Reviewer: for this, by looking at facebookresearch/fairscale#413, probably we could register one with …
Author: after some thought, not sure how important this is; currently a user has to wrap their model in nested FSDP …
Reviewer: could we name it …
Author: this already exists at the highest level (in the accelerator connector via the Enum it's called …). But going off #7259 (comment), I think we should introduce …
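To illustrate the point above that a user currently has to wrap their model in nested FSDP themselves, here is a minimal sketch of a ``LightningModule`` using ``configure_sharded_model`` with FairScale's ``wrap``/``auto_wrap``; the hook runs inside the plugin's ``model_sharded_context``, so the wrappers pick up the settings passed to ``enable_wrap``. The class name, model, and layer sizes are made up for illustration.

import torch
from torch import nn
import pytorch_lightning as pl
from fairscale.nn import auto_wrap, wrap


class SketchModel(pl.LightningModule):
    def configure_sharded_model(self) -> None:
        # Inside `model_sharded_context`, `auto_wrap` applies the plugin's
        # auto-wrap policy to submodules; `wrap` wraps a module explicitly.
        self.backbone = auto_wrap(nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32)))
        self.head = wrap(nn.Linear(32, 2))

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-3)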
Review thread on the precision plugin:
Reviewer: would we need this if the precision plugin was owned by the training type plugin?
Author: then it should not be needed.
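Finally, a rough end-to-end sketch of how the pieces in this PR would be used together, assuming the plugin stays registered under the ``"fsdp"`` name, that the Trainer accepts registered plugin names as strings, and that ``precision=16`` routes through ``FullyShardedNativeMixedPrecisionPlugin``; ``SketchModel`` is the hypothetical module from the previous sketch and all argument values are illustrative.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

model = SketchModel()
dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
train_loader = DataLoader(dataset, batch_size=8)

trainer = pl.Trainer(
    gpus=4,                 # FSDP needs GPUs; setup_distributed raises a MisconfigurationException otherwise
    plugins="fsdp",         # resolves to DDPFullyShardedPlugin via register_plugins
    precision=16,           # mixed precision for fully sharded training
    gradient_clip_val=0.5,  # clipping by value; `norm` clipping is not supported here
)
trainer.fit(model, train_loader)

# Checkpoints from this run contain the full (unsharded) state dict, so they can be
# loaded into the unwrapped model first and sharded again via `configure_sharded_model`.
trainer.save_checkpoint("model.ckpt")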