diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b5e1974ab973..d07d903be16bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241)) +- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487)) + + ### Changed diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index 444d2aaef978b..58d43dc54cb7f 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -6,6 +6,9 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -15,6 +18,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 @@ -32,6 +36,7 @@ "DDP2Plugin", "DDPPlugin", "DDPSpawnPlugin", + "DDPFullyShardedPlugin", "DeepSpeedPlugin", "DeepSpeedPrecisionPlugin", "DoublePrecisionPlugin", @@ -39,17 +44,18 @@ "NativeMixedPrecisionPlugin", "PrecisionPlugin", "ShardedNativeMixedPrecisionPlugin", + "FullyShardedNativeMixedPrecisionPlugin", "SingleDevicePlugin", "SingleTPUPlugin", "TPUHalfPrecisionPlugin", "TPUSpawnPlugin", - 'RPCPlugin', - 'RPCSequentialPlugin', - 'TrainingTypePlugin', - 'ParallelPlugin', - 'Plugin', - 'DDPShardedPlugin', - 'DDPSpawnShardedPlugin', + "RPCPlugin", + "RPCSequentialPlugin", + "TrainingTypePlugin", + "ParallelPlugin", + "Plugin", + "DDPShardedPlugin", + "DDPSpawnShardedPlugin", ] from pathlib import Path diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index d32aac829a13d..904e5f9f44a27 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,6 +1,9 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.precision.mixed
import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py new file mode 100644 index 0000000000000..dedf274237f09 --- /dev/null +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -0,0 +1,46 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional, Union + +from torch.nn import Module +from torch.optim import Optimizer + +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities import GradClipAlgorithmType + + +class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): + """Mixed Precision for Full Sharded Training""" + + precision = "mixed" + + def clip_gradients( + self, + optimizer: Optimizer, + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.VALUE, + model: Optional[Module] = None, + ) -> None: + clip_val = float(clip_val) + if clip_val <= 0: + return + # see https://fairscale.readthedocs.io/en/latest/api/nn/fsdp_tips.html + # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect + # for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val) + # however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to + # trace back the root FSDP. Now we only support clip by value. 
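As the comment above explains, norm-based clipping would have to go through the root FSDP module's own gradient-clipping method, which this plugin cannot reach once `configure_sharded_model` has wrapped the submodules, so only clip-by-value is supported. For orientation, here is a rough sketch of what the value-based path amounts to, assuming the inherited `clip_grad_by_value` helper ultimately delegates to `torch.nn.utils.clip_grad_value_` over the optimizer's parameters (the gradient unscaling handled by the native-AMP base class is omitted):

    import torch
    from torch.optim import Optimizer

    def clip_by_value_sketch(optimizer: Optimizer, clip_val: float) -> None:
        # clamp every gradient element owned by this rank to [-clip_val, clip_val];
        # unlike a global-norm clip, this needs no reduction across shards
        params = [p for group in optimizer.param_groups for p in group["params"]]
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_val)

From the user's side this means pairing a non-zero `gradient_clip_val` with the value-based clipping algorithm when running under this precision plugin.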
+ assert ( + gradient_clip_algorithm == GradClipAlgorithmType.VALUE + ), "`gradient_clip_algorithm`: `norm` is currently not supported for `FullyShardedNativeMixedPrecisionPlugin`" + self.clip_grad_by_value(optimizer, clip_val) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index 30723d67da3f4..3cb43e44f5565 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -3,6 +3,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.fully_sharded import DDPFullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py new file mode 100644 index 0000000000000..476df9be13cfe --- /dev/null +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -0,0 +1,208 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib +from typing import Any, Dict, Generator, List, Optional, Union + +import torch +from torch import Tensor +from torch.nn import Module + +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import default_auto_wrap_policy, enable_wrap + from fairscale.nn.data_parallel import FullyShardedDataParallel + + +class DDPFullyShardedPlugin(DDPPlugin): + + def __init__( + self, + cpu_offload: bool = False, + flatten_parameters: bool = True, + reshard_after_forward: bool = True, + move_grads_to_cpu: Optional[bool] = None, + fp32_reduce_scatter: Optional[bool] = None, + compute_dtype: Optional[torch.dtype] = None, + bucket_cap_mb: int = 25, + min_num_params: int = 1e8, + state_dict_to_cpu: bool = True, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment: ClusterEnvironment = None, + ): + """ + Plugin for Fully Sharded Data Parallel provided by FairScale. + + Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. 
The technique is similar + to ZeRO-Stage 3 but has been built for upstreaming to PyTorch. + `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. + .. warning:: ``FullyShardedPlugin`` is in beta and subject to change. + + Defaults have been set and options have been exposed, but may require configuration + based on your level of memory/speed efficiency. We suggest having a look at this PR for more information. + `https://github.com/facebookresearch/fairscale/pull/413` + + Many of the helpful doc strings below came from the original FairScale documentation: + `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html` + + Arguments: + cpu_offload: Offload FP32 params to CPU. Only usable in precision=16 mode. + (Default: False). + move_grads_to_cpu: Moves gradient shards to CPU after reduction. + Only disable if using CPU based optimizers + (Default to ``cpu_offload``). + flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency + (Default: True). + reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows + down training. This is only relevant when resharding individual layers. + (Default: True). + fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision + (Default: None). + compute_dtype: dtype for full parameters for computation. Default to torch.float32, + unless using mixed precision, in which case defaults to torch.float16. + (Default: None). + bucket_cap_mb: bucket parameters so that gradient reduction + can potentially overlap with backward computation. + bucket_cap_mb controls the bucket size in MegaBytes (MB). + Buckets are sub-divided based on world_size, + so the max shard size is roughly bucket_cap_mb / world_size. + Values <= 0 disable bucketing. + (Default: 25). + min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``. + (Default: 1e8) + state_dict_to_cpu: Whether to return parameters (returned by :func:`state_dict`) on CPU device. + If ``False``, this will default to ``compute_device``. + (Defautl: True). + """ + + super().__init__( + parallel_devices=parallel_devices, + cluster_environment=cluster_environment, + ) + self.cpu_offload = cpu_offload + self.move_grads_to_cpu = move_grads_to_cpu + self.flatten_parameters = flatten_parameters + self.reshard_after_forward = reshard_after_forward + self.fp32_reduce_scatter = fp32_reduce_scatter + self.compute_dtype = compute_dtype + self.bucket_cap_mb = bucket_cap_mb + self.min_num_params = min_num_params + self.state_dict_device = torch.device("cpu") if state_dict_to_cpu else None + self._process_group = None + + @property + def process_group(self): + if self._process_group is None: + self._process_group = torch.distributed.new_group() + return self._process_group + + def setup_distributed(self) -> None: + if not self.on_gpu: + raise MisconfigurationException( + "You selected accelerator to be `ddp_fully_sharded`, but GPU is not available." 
+ ) + super().setup_distributed() + torch.cuda.set_device(self.root_device) + + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: + precision = self.lightning_module.trainer.precision + + def wrap_policy(*args, **kwargs): + return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) + + with enable_wrap( + wrapper_cls=FullyShardedDataParallel, + auto_wrap_policy=wrap_policy, + process_group=self.process_group, + cpu_offload=self.cpu_offload, + move_grads_to_cpu=self.move_grads_to_cpu, + flatten_parameters=self.flatten_parameters, + mixed_precision=precision == "mixed", + reshard_after_forward=self.reshard_after_forward, + fp32_reduce_scatter=self.fp32_reduce_scatter, + compute_dtype=self.compute_dtype, + bucket_cap_mb=self.bucket_cap_mb, + state_dict_device=self.state_dict_device, + ): + yield + + def connect(self, model: Module) -> None: + super().connect(model) + model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) + if not model_call_configure_sharded_model_hook: + # if model has not called configure sharded model, we reset + # the training type plugin's call_configure_sharded_model_hook + # to give trainer a chance to configure. + self.call_configure_sharded_model_hook = True + + def configure_ddp(self) -> None: + if not self.cpu_offload: + # When using CPU Offload, FSDP will manage the CUDA movement for us. + # Note: this would be problematic for large model (which could not fit in one GPU) + # as FSDP module.to(device) would first summon all parameters + # (TODO: need to figure out solution) + self.model_to_device() + + # setup optimizers after fully sharded has wrapped the lightning module + self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) + + def pre_dispatch(self) -> None: + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + self.configure_ddp() + self.barrier() + + def model_to_device(self) -> None: + # ensure we update the device type in the lightning module + self.lightning_module.to(self.root_device) + + def lightning_module_state_dict(self) -> Dict[str, Union[Any, Tensor]]: + # Currently it is same as default TrainingTypePlugin, i.e. return + # the full state dict for FSDP, in the future, we will provide sharded + # state dict. 
+ return super().lightning_module_state_dict() + + @property + def setup_optimizers_in_pre_dispatch(self) -> bool: + # Setup optimizers after the Fully Sharded Model has been made + return True + + def training_step(self, *args, **kwargs): + return self.model.training_step(*args, **kwargs) + + def validation_step(self, *args, **kwargs): + return self.model.validation_step(*args, **kwargs) + + def test_step(self, *args, **kwargs): + return self.model.test_step(*args, **kwargs) + + def predict_step(self, *args, **kwargs): + return self.model.predict_step(*args, **kwargs) + + def post_training_step(self): + pass + + @classmethod + def register_plugins(cls, plugin_registry: Dict): + plugin_registry.register( + "fsdp", + cls, + description="Fully sharded training with checkpointing the full state dict.", + ) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 3068b4ffccb2b..4d692ec517d19 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -26,6 +26,7 @@ ApexMixedPrecisionPlugin, DataParallelPlugin, DDP2Plugin, + DDPFullyShardedPlugin, DDPPlugin, DDPShardedPlugin, DDPSpawnPlugin, @@ -33,6 +34,7 @@ DeepSpeedPlugin, DeepSpeedPrecisionPlugin, DoublePrecisionPlugin, + FullyShardedNativeMixedPrecisionPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -265,8 +267,13 @@ def use_dp(self) -> bool: @property def use_ddp(self) -> bool: return self._distrib_type in ( - DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED, DistributedType.TPU_SPAWN + DistributedType.DDP, + DistributedType.DDP_SPAWN, + DistributedType.DDP_SHARDED, + DistributedType.DDP_SHARDED_SPAWN, + DistributedType.DDP_FULLY_SHARDED, + DistributedType.DEEPSPEED, + DistributedType.TPU_SPAWN, ) @property @@ -281,6 +288,14 @@ def use_horovod(self) -> bool: def use_deepspeed(self) -> bool: return self._distrib_type == DistributedType.DEEPSPEED + @property + def _is_sharded_training_type(self) -> bool: + return isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)) + + @property + def _is_fully_sharded_training_type(self) -> bool: + return isinstance(self.training_type_plugin, DDPFullyShardedPlugin) + @property def is_distributed(self) -> bool: # Used for custom plugins. @@ -365,8 +380,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: raise MisconfigurationException(msg) else: log.info("Using native 16bit precision.") - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if self._is_sharded_training_type: return ShardedNativeMixedPrecisionPlugin() + if self._is_fully_sharded_training_type: + return FullyShardedNativeMixedPrecisionPlugin() return NativeMixedPrecisionPlugin() if self.amp_type == AMPType.APEX: @@ -375,7 +392,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: "You have asked for Apex AMP but you have not installed it yet." " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" ) - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if self._is_sharded_training_type or self._is_fully_sharded_training_type: raise MisconfigurationException( "Sharded Plugin is not supported with Apex AMP," " please using native AMP for 16-bit precision." 
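Taken together with the `register_plugins("fsdp", ...)` alias above and the connector changes in this and the following hunk, a user can request fully sharded training directly from the `Trainer`, and with native 16-bit precision the connector selects the new `FullyShardedNativeMixedPrecisionPlugin` (Apex is rejected for any sharded plugin). A minimal selection sketch, mirroring the tests added later in this diff; the single-GPU setting is an assumption and `fairscale` >= 0.3.4 must be installed:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins import DDPFullyShardedPlugin, FullyShardedNativeMixedPrecisionPlugin

    # registered alias -- equivalent to passing a default DDPFullyShardedPlugin instance
    trainer = Trainer(gpus=1, precision=16, plugins="fsdp")
    assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin)
    assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin)

    # explicit instance when the defaults need tweaking
    trainer = Trainer(
        gpus=1,
        precision=16,
        plugins=[DDPFullyShardedPlugin(cpu_offload=True, min_num_params=int(1e8))],
    )

    # amp_backend="apex" combined with "fsdp" raises a MisconfigurationException instead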
@@ -407,6 +424,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN + use_ddp_fully_sharded = self._distrib_type == DistributedType.DDP_FULLY_SHARDED # TODO: decouple from TE # ddp script mode uses the same flags as TE @@ -426,6 +444,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: ddp_plugin_cls = DDPPlugin elif use_ddp_spawn or use_ddp_cpu_spawn: ddp_plugin_cls = DDPSpawnPlugin + elif use_ddp_fully_sharded: + ddp_plugin_cls = DDPFullyShardedPlugin else: ddp_plugin_cls = DDPPlugin @@ -481,10 +501,11 @@ def select_accelerator(self) -> Accelerator: acc_cls = TPUAccelerator else: acc_cls = CPUAccelerator - + # as precision_plugin is dependent on training_type_plugin, make sure + # that we first select training_type_plugin, then precision_plugin return acc_cls( - precision_plugin=self.precision_plugin, training_type_plugin=self.training_type_plugin, + precision_plugin=self.precision_plugin, ) def select_cluster_environment(self) -> ClusterEnvironment: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index e18e18cdf953a..6664be43bef88 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -36,6 +36,7 @@ _BOLTS_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, + _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _GROUP_AVAILABLE, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index e01f8862486d3..98e10a9126a44 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -80,6 +80,7 @@ def is_interactive_compatible(self) -> bool: DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' + DDP_FULLY_SHARDED = "ddp_fully_sharded" class DeviceType(LightningEnum): diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 8e63d9d3da156..f40d092f68e9f 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -77,6 +77,7 @@ def _compare_version(package: str, op, version) -> bool: _FAIRSCALE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and not _IS_WINDOWS and _module_available('fairscale.nn') _FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3") _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.3") +_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _module_available("hydra") diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 12810ba30ce3c..630a341ec2d30 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -24,6 +24,7 @@ _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, + _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE, @@ -69,6 +70,7 @@ def __new__( rpc: bool = False, fairscale: bool = False, fairscale_pipe: bool = False, + fairscale_fully_sharded: bool = False, deepspeed: bool = 
False, **kwargs ): @@ -89,6 +91,8 @@ def __new__( special: running in special mode, outside pytest suit rpc: requires Remote Procedure Call (RPC) fairscale: if `fairscale` module is required to run the test + fairscale_pipe: if `fairscale` with pipe module is required to run the test + fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test deepspeed: if `deepspeed` module is required to run the test kwargs: native pytest.mark.skipif keyword arguments """ @@ -160,6 +164,10 @@ def __new__( conditions.append(not _FAIRSCALE_PIPE_AVAILABLE) reasons.append("Fairscale Pipe") + if fairscale_fully_sharded: + conditions.append(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE) + reasons.append("Fairscale Fully Sharded") + if deepspeed: conditions.append(not _DEEPSPEED_AVAILABLE) reasons.append("Deepspeed") diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py new file mode 100644 index 0000000000000..c4826d09dbaf8 --- /dev/null +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -0,0 +1,210 @@ +import os +from typing import Any, Dict, Optional +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.plugins import DDPFullyShardedPlugin, FullyShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import FullyShardedDataParallel, wrap + + +def test_invalid_on_cpu(tmpdir): + """ + Test to ensure that to raise Misconfiguration for FSDP on CPU. 
+ """ + with pytest.raises( + MisconfigurationException, + match="You selected accelerator to be `ddp_fully_sharded`, but GPU is not available.", + ): + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins="fsdp", + ) + assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) + trainer.accelerator.setup_environment() + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) +@mock.patch("torch.cuda.device_count", return_value=1) +@mock.patch("torch.cuda.is_available", return_value=True) +@RunIf(amp_apex=True, fairscale_fully_sharded=True) +def test_invalid_apex_sharded(device_count_mock, mock_cuda_available, tmpdir): + """ + Test to ensure that we raise an error when we try to use apex and fully sharded + """ + with pytest.raises( + MisconfigurationException, + match="Sharded Plugin is not supported with Apex AMP", + ): + Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins="fsdp", + gpus=1, + precision=16, + amp_backend="apex", + ) + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) +@mock.patch("torch.cuda.device_count", return_value=1) +@mock.patch("torch.cuda.is_available", return_value=True) +@RunIf(amp_native=True, fairscale_fully_sharded=True) +def test_fsdp_with_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): + """ + Test to ensure that plugin native amp plugin is correctly chosen when using sharded + """ + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins="fsdp", + gpus=1, + precision=16, + ) + assert isinstance(trainer.accelerator.training_type_plugin, DDPFullyShardedPlugin) + assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) + + +class TestFSDPModel(BoringModel): + + def setup(self, stage: str) -> None: + if stage != "fit": + # when running stages like test, validate, and predict, we will skip setting up, + # will directly use the module itself unless we load from checkpoint + return + # resetting call_configure_sharded_model_hook attribute so that we could call + # configure sharded model + self.call_configure_sharded_model_hook = False + # for loading full state dict, we first need to create a new unwrapped model + # to load state dict and then wrapping + self.layer = torch.nn.Sequential( + torch.nn.Linear(32, 32), + torch.nn.ReLU(), + torch.nn.Linear(32, 2), + ) + + def configure_sharded_model(self) -> None: + for i, layer in enumerate(self.layer): + if i % 2 == 0: + self.layer[i] = wrap(layer) + self.layer = wrap(self.layer) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # when loading full state dict, we first need to create a new unwrapped model + self.setup("fit") + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + def on_train_start(self) -> None: + self._assert_layer_fsdp_instance() + + def on_test_start(self) -> None: + self._assert_layer_fsdp_instance() + + def on_validation_start(self) -> None: + self._assert_layer_fsdp_instance() + + def on_prediction_start(self) -> None: + self._assert_layer_fsdp_instance() + + def _assert_layer_fsdp_instance(self) -> None: + assert isinstance(self.layer, FullyShardedDataParallel) + assert isinstance(self.layer.module[0], FullyShardedDataParallel) + assert isinstance(self.layer.module[2], FullyShardedDataParallel) + # root should not be resharding + assert self.layer.reshard_after_forward is False + # Assert that the nested layers are set reshard_after_forward to True + assert 
self.layer.module[0].reshard_after_forward is True + assert self.layer.module[2].reshard_after_forward is True + + +@RunIf( + min_gpus=1, + skip_windows=True, + fairscale_fully_sharded=True, + amp_native=True, + special=True, +) +def test_fully_sharded_plugin_checkpoint(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using a single GPU, and all stages can be run. + """ + + model = TestFSDPModel() + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + plugins="fsdp", + precision=16, + max_epochs=1, + ) + _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) + + +@RunIf( + min_gpus=2, + skip_windows=True, + fairscale_fully_sharded=True, + amp_native=True, + special=True, +) +def test_fully_sharded_plugin_checkpoint_multi_gpus(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run. + """ + + model = TestFSDPModel() + ck = ModelCheckpoint(save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + gpus=2, + plugins="fsdp", + precision=16, + max_epochs=1, + callbacks=[ck], + ) + _run_multiple_stages(trainer, model) + + +def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.accelerator.lightning_module_state_dict() + + if trainer.is_global_zero: + saved_model = cls.load_from_checkpoint(ckpt_path) + + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param.float().cpu(), shard_param) + + +def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): + trainer.fit(model) + + model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) + trainer_accelerator_call_configure_sharded_model_hook = (trainer.accelerator.call_configure_sharded_model_hook) + + model_path = (model_path if model_path else trainer.checkpoint_callback.last_model_path) + + assert model_call_configure_sharded_model_hook + assert not trainer_accelerator_call_configure_sharded_model_hook + trainer.save_checkpoint(model_path, weights_only=True) + + _assert_save_equality(trainer, model_path, cls=TestFSDPModel) + + # Test entry point + trainer.test(model) # model is wrapped, will not call configure_shared_model + + # provide model path, will create a new unwrapped model and load and then call configure_shared_model to wrap + trainer.test(ckpt_path=model_path)
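Putting the pieces together, the workflow these tests exercise looks roughly like the sketch below: layers are created inside `configure_sharded_model`, where the plugin's `model_sharded_context` makes FairScale's `wrap` produce `FullyShardedDataParallel` modules, and `save_checkpoint` writes the full (unsharded) state dict that `load_from_checkpoint` can later consume. This is a hedged illustration modelled on `TestFSDPModel`, not part of the patch; the class name, the random `TensorDataset`, and the checkpoint path are made up, and it assumes at least one GPU plus `fairscale` >= 0.3.4.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE

    if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
        from fairscale.nn import wrap


    class SketchFSDPModel(LightningModule):

        def configure_sharded_model(self) -> None:
            # runs inside DDPFullyShardedPlugin.model_sharded_context(), so each
            # ``wrap`` call returns a FullyShardedDataParallel-wrapped module
            self.layer = wrap(
                torch.nn.Sequential(
                    wrap(torch.nn.Linear(32, 32)),
                    torch.nn.ReLU(),
                    wrap(torch.nn.Linear(32, 2)),
                )
            )

        def training_step(self, batch, batch_idx):
            (x,) = batch  # TensorDataset yields 1-tuples
            return self.layer(x).sum()

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)


    if __name__ == "__main__":
        train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=2)
        model = SketchFSDPModel()
        trainer = Trainer(gpus=1, precision=16, plugins="fsdp", max_epochs=1)
        trainer.fit(model, train_data)
        # full state dict gathered for saving; reload via SketchFSDPModel.load_from_checkpoint(...)
        trainer.save_checkpoint("fsdp_sketch.ckpt", weights_only=True)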