Skip to content

Commit

Permalink
Flexible and easy to use HSDP setting (#19504)
Browse files Browse the repository at this point in the history
Co-authored-by: awaelchli <aedu.waelchli@gmail.com>
  • Loading branch information
Liyang90 and awaelchli authored Jun 6, 2024
1 parent 1a6786d commit 7668a6b
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 6 deletions.
2 changes: 2 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))

- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504))


### Changed

Expand Down
20 changes: 19 additions & 1 deletion src/lightning/fabric/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@
from lightning.fabric.utilities.types import _PATH, _Stateful

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

_POLICY = Union[Set[Type[Module]], Callable[[Module, bool, int], bool], ModuleWrapPolicy]
_SHARDING_STRATEGY = Union[ShardingStrategy, Literal["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD"]]


_FSDP_ALIASES = ("fsdp", "fsdp_cpu_offload")


Expand Down Expand Up @@ -117,10 +119,14 @@ class FSDPStrategy(ParallelStrategy, _Sharded):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
Expand All @@ -146,6 +152,7 @@ def __init__(
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "sharded",
device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -163,6 +170,11 @@ def __init__(
# Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
self._fsdp_kwargs.setdefault("use_orig_params", True)

if device_mesh is not None:
if not _TORCH_GREATER_EQUAL_2_2:
raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
self._fsdp_kwargs["device_mesh"] = device_mesh

self._activation_checkpointing_kwargs = _activation_checkpointing_kwargs(
activation_checkpointing, activation_checkpointing_policy
)
Expand Down Expand Up @@ -244,6 +256,12 @@ def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()

# if 'device_mesh' in the `_fsdp_kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
if isinstance(self._fsdp_kwargs.get("device_mesh"), tuple):
from torch.distributed.device_mesh import init_device_mesh

self._fsdp_kwargs["device_mesh"] = init_device_mesh("cuda", self._fsdp_kwargs["device_mesh"])

@override
def setup_module_and_optimizers(
self, module: Module, optimizers: List[Optimizer]
Expand Down
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a call to `torch.distributed.destroy_process_group` in atexit handler if process group needs destruction ([#19931](https://github.com/Lightning-AI/pytorch-lightning/pull/19931))

- Added support for configuring hybrid-sharding by passing a tuple for the `FSDPStrategy(device_mesh=...)` argument ([#19504](https://github.com/Lightning-AI/pytorch-lightning/pull/19504))


### Changed

Expand Down
41 changes: 38 additions & 3 deletions src/lightning/pytorch/strategies/fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,21 @@
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Mapping, Optional, Set, Type, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
Type,
Union,
)

import torch
from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only
Expand Down Expand Up @@ -53,7 +67,10 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.imports import (
_TORCH_GREATER_EQUAL_2_1,
_TORCH_GREATER_EQUAL_2_2,
)
from lightning.fabric.utilities.init import _EmptyInit, _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
Expand All @@ -70,6 +87,7 @@
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

Expand Down Expand Up @@ -114,10 +132,14 @@ class FSDPStrategy(ParallelStrategy):
- ``"SHARD_GRAD_OP"``: Shards gradients and optimizer states only. Model parameters get replicated.
- ``"NO_SHARD"``: No sharding (identical to regular DDP).
- ``"HYBRID_SHARD"``: Shards model parameters, gradients, and optimizer states within a single machine, but
replicates across machines.
replicates across machines. See also the `device_mesh` parameter below.
Also accepts a :class:`torch.distributed.fsdp.ShardingStrategy` enum value.
device_mesh: A tuple `(replication size, sharding size)` that defines over how many devices to shard and
replicate the model. The product of the two numbers must equal the world size. Only valid in combination
with the `HYBRID_SHARD` sharding strategy.
state_dict_type: The format in which the state of the model and optimizers gets saved into the checkpoint.
- ``"full"``: The full weights and optimizer states get assembled on rank 0 and saved to a single file.
Expand Down Expand Up @@ -147,6 +169,7 @@ def __init__(
activation_checkpointing_policy: Optional["_POLICY"] = None,
sharding_strategy: "_SHARDING_STRATEGY" = "FULL_SHARD",
state_dict_type: Literal["full", "sharded"] = "full",
device_mesh: Optional[Union[Tuple[int], "DeviceMesh"]] = None,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -162,6 +185,12 @@ def __init__(
self.cpu_offload = _init_cpu_offload(cpu_offload)
self.mixed_precision = mixed_precision
self.kwargs = _auto_wrap_policy_kwargs(auto_wrap_policy, kwargs)

if device_mesh is not None:
if not _TORCH_GREATER_EQUAL_2_2:
raise ValueError("The `device_mesh` argument is only supported in torch >= 2.2.")
self.kwargs["device_mesh"] = device_mesh

self.sharding_strategy = _init_sharding_strategy(sharding_strategy, self.kwargs)

# Avoids the need for user to reference params in `configure_optimizers` via
Expand Down Expand Up @@ -242,6 +271,12 @@ def setup_environment(self) -> None:
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

# if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
from torch.distributed.device_mesh import init_device_mesh

self.kwargs["device_mesh"] = init_device_mesh("cuda", self.kwargs["device_mesh"])

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

Expand Down
7 changes: 6 additions & 1 deletion tests/tests_fabric/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_sharding_strategy():


@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
def test_hybrid_shard_configuration(sharding_strategy):
def test_hybrid_shard_configuration(sharding_strategy, monkeypatch):
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
FSDPStrategy(sharding_strategy=sharding_strategy)
Expand All @@ -85,6 +85,11 @@ def test_hybrid_shard_configuration(sharding_strategy):
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy._fsdp_kwargs["process_group"] is process_group

monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False)
with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."):
FSDPStrategy(device_mesh=Mock())

monkeypatch.setattr("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True)
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
Expand Down
7 changes: 6 additions & 1 deletion tests/tests_pytorch/strategies/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def test_sharding_strategy():


@pytest.mark.parametrize("sharding_strategy", ["HYBRID_SHARD", "_HYBRID_SHARD_ZERO2"])
def test_hybrid_sharding_strategy(sharding_strategy):
def test_hybrid_shard_configuration(sharding_strategy, monkeypatch):
"""Test that the hybrid sharding strategies can only be used with automatic wrapping or a manually specified pg."""
with pytest.raises(RuntimeError, match="The hybrid sharding strategy requires you to pass at least one of"):
FSDPStrategy(sharding_strategy=sharding_strategy)
Expand All @@ -514,6 +514,11 @@ def test_hybrid_sharding_strategy(sharding_strategy):
assert strategy.sharding_strategy.name == sharding_strategy
assert strategy.kwargs["process_group"] is process_group

monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", False)
with pytest.raises(ValueError, match="`device_mesh` argument is only supported in torch >= 2.2."):
FSDPStrategy(device_mesh=Mock())

monkeypatch.setattr("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_2", True)
device_mesh = Mock()
strategy = FSDPStrategy(sharding_strategy=sharding_strategy, device_mesh=device_mesh)
assert strategy.sharding_strategy.name == sharding_strategy
Expand Down

0 comments on commit 7668a6b

Please sign in to comment.