(5/n) Support 2D Parallelism in Lightning Trainer #19878

Merged · 9 commits · May 17, 2024
48 changes: 28 additions & 20 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -163,7 +163,13 @@ def _configure_launcher(self) -> None:
def setup_environment(self) -> None:
super().setup_environment()
self._setup_distributed()
self._setup_device_mesh()
if self._data_parallel_size == "auto":
self._data_parallel_size = self.num_nodes
if self._tensor_parallel_size == "auto":
self._tensor_parallel_size = self.num_processes
self._device_mesh = _setup_device_mesh(
self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device
)

@override
def setup_module(self, module: TModel) -> TModel:
@@ -303,25 +309,6 @@ def _setup_distributed(self) -> None:
assert self.cluster_environment is not None
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)

def _setup_device_mesh(self) -> None:
from torch.distributed.device_mesh import init_device_mesh

if self._data_parallel_size == "auto":
self._data_parallel_size = self.num_nodes
if self._tensor_parallel_size == "auto":
self._tensor_parallel_size = self.num_processes
if self._data_parallel_size * self._tensor_parallel_size != self.world_size:
raise RuntimeError(
f"The sizes `data_parallel_size={self._data_parallel_size}` and"
f" `tensor_parallel_size={self._tensor_parallel_size}` multiplied should equal the world size"
f" ({self.world_size})."
)
self._device_mesh = init_device_mesh(
device_type=self.root_device.type,
mesh_shape=(self._data_parallel_size, self._tensor_parallel_size),
mesh_dim_names=("data_parallel", "tensor_parallel"),
)

def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)

@@ -510,6 +497,27 @@ def _load_checkpoint(
)


def _setup_device_mesh(
data_parallel_size: int,
tensor_parallel_size: int,
world_size: int,
device: torch.device,
) -> "DeviceMesh":
from torch.distributed.device_mesh import init_device_mesh

if data_parallel_size * tensor_parallel_size != world_size:
raise RuntimeError(
f"The sizes `data_parallel_size={data_parallel_size}` and"
f" `tensor_parallel_size={tensor_parallel_size}` multiplied should equal the world size"
f" ({world_size})."
)
return init_device_mesh(
device_type=device.type,
mesh_shape=(data_parallel_size, tensor_parallel_size),
mesh_dim_names=("data_parallel", "tensor_parallel"),
)


def _has_dtensor_modules(module: object) -> TypeGuard[Module]:
from torch.distributed._tensor import DTensor

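For reference, a minimal standalone sketch of what the new module-level `_setup_device_mesh` helper computes: the data-parallel and tensor-parallel sizes must multiply to the world size, and the resulting 2D mesh names its dimensions `data_parallel` and `tensor_parallel`. The concrete sizes below (2 nodes × 4 devices) are illustrative, and the snippet assumes it runs inside an already-initialized distributed environment (e.g. launched with `torchrun`).

```python
from torch.distributed.device_mesh import init_device_mesh

# Illustrative values: 2 nodes x 4 GPUs per node -> world size 8.
data_parallel_size, tensor_parallel_size, world_size = 2, 4, 8

# Mirrors the validation performed by the helper above.
if data_parallel_size * tensor_parallel_size != world_size:
    raise RuntimeError("data_parallel_size * tensor_parallel_size must equal the world size")

# Builds a 2D mesh; sub-meshes can later be sliced by name, e.g. mesh["tensor_parallel"].
mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(data_parallel_size, tensor_parallel_size),
    mesh_dim_names=("data_parallel", "tensor_parallel"),
)
```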
14 changes: 14 additions & 0 deletions src/lightning/pytorch/core/module.py
@@ -20,6 +20,7 @@
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Dict,
@@ -76,6 +77,9 @@
OptimizerLRScheduler,
)

if TYPE_CHECKING:
from torch.distributed.device_mesh import DeviceMesh

_ONNX_AVAILABLE = RequirementCache("onnx")

warning_cache = WarningCache()
@@ -110,6 +114,7 @@ class LightningModule(
"trainer",
"fabric",
"strict_loading",
"device_mesh",
]
+ _DeviceDtypeModuleMixin.__jit_unused_properties__
+ HyperparametersMixin.__jit_unused_properties__
@@ -142,6 +147,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self._fabric: Optional["lf.Fabric"] = None
self._fabric_optimizers: List[_FabricOptimizer] = []

# access to device mesh in `configure_model()` hook
self._device_mesh: Optional["DeviceMesh"] = None

@overload
def optimizers(
self, use_pl_optimizer: Literal[True] = True
@@ -319,6 +327,12 @@ def loggers(self) -> Union[List[Logger], List[FabricLogger]]:
return self._trainer.loggers
return []

@property
def device_mesh(self) -> Optional["DeviceMesh"]:
"""Strategies like ``ModelParallelStrategy`` will create a device mesh that can be accessed in the
:meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook to parallelize the LightningModule."""
return self._device_mesh

def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
trainer = self._trainer
if trainer:
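To illustrate how the new `device_mesh` property is intended to be consumed, here is a hedged sketch of a LightningModule that applies tensor parallelism inside `configure_model()`. It assumes the strategy has populated the mesh with the `data_parallel`/`tensor_parallel` dimension names used by the Fabric helper above; the module, layer names, and parallelization plan are purely illustrative.

```python
import lightning.pytorch as pl
from torch import nn
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyTensorParallelModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(128, 256)
        self.w2 = nn.Linear(256, 128)

    def configure_model(self) -> None:
        # The strategy sets up the device mesh before calling this hook,
        # so the property is expected to be non-None here.
        assert self.device_mesh is not None
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Shard w1 column-wise and w2 row-wise across the tensor-parallel group.
        parallelize_module(self, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})
```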
2 changes: 2 additions & 0 deletions src/lightning/pytorch/strategies/__init__.py
@@ -18,6 +18,7 @@
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
from lightning.pytorch.strategies.parallel import ParallelStrategy
from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy # noqa: F401
@@ -31,6 +32,7 @@
"DDPStrategy",
"DeepSpeedStrategy",
"FSDPStrategy",
"ModelParallelStrategy",
"ParallelStrategy",
"SingleDeviceStrategy",
"Strategy",
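Finally, a hedged sketch of selecting the newly exported strategy from the Trainer. The constructor arguments are not shown in this diff; the example assumes the PyTorch strategy mirrors the Fabric one, with `data_parallel_size`/`tensor_parallel_size` defaulting to `"auto"` (number of nodes × processes per node).

```python
import lightning.pytorch as pl
from lightning.pytorch.strategies import ModelParallelStrategy

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,    # with "auto", tensor-parallel size follows the processes per node
    num_nodes=2,  # with "auto", data-parallel size follows the number of nodes
    strategy=ModelParallelStrategy(),
)
```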