Skip to content

Commit

Permalink
Add timeout to DeepSpeedStrategy (#20474)
Browse files Browse the repository at this point in the history
* allow user to pass kwargs to DeepSpeedStrategy

* Update deepspeed.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update deepspeed.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make timeout explicit in DeepSpeedStrategy

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
  • Loading branch information
3 people authored Dec 10, 2024
1 parent 1c4612e commit 9983f3a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
8 changes: 7 additions & 1 deletion src/lightning/fabric/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import platform
from collections.abc import Mapping
from contextlib import AbstractContextManager, ExitStack
from datetime import timedelta
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
Expand All @@ -29,6 +30,7 @@
from typing_extensions import override

from lightning.fabric.accelerators import Accelerator, CUDAAccelerator
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.precision import Precision
from lightning.fabric.strategies.ddp import DDPStrategy
Expand Down Expand Up @@ -97,6 +99,7 @@ def __init__(
load_full_weights: bool = False,
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
Expand Down Expand Up @@ -241,6 +244,7 @@ def __init__(
process_group_backend=process_group_backend,
)
self._backward_sync_control = None # DeepSpeed handles gradient accumulation internally
self._timeout: Optional[timedelta] = timeout

self.config = self._load_config(config)
if self.config is None:
Expand Down Expand Up @@ -648,7 +652,9 @@ def _init_deepspeed_distributed(self) -> None:
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
)
self._process_group_backend = self._get_process_group_backend()
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
deepspeed.init_distributed(
self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
)

def _set_node_environment_variables(self) -> None:
assert self.cluster_environment is not None
Expand Down
8 changes: 7 additions & 1 deletion src/lightning/pytorch/strategies/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import OrderedDict
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional, Union

Expand All @@ -30,6 +31,7 @@

import lightning.pytorch as pl
from lightning.fabric.plugins import ClusterEnvironment
from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout
from lightning.fabric.strategies import _StrategyRegistry
from lightning.fabric.strategies.deepspeed import (
_DEEPSPEED_AVAILABLE,
Expand Down Expand Up @@ -119,6 +121,7 @@ def __init__(
load_full_weights: bool = False,
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
Expand Down Expand Up @@ -264,6 +267,7 @@ def __init__(
precision_plugin=precision_plugin,
process_group_backend=process_group_backend,
)
self._timeout: Optional[timedelta] = timeout

self.config = self._load_config(config)
if self.config is None:
Expand Down Expand Up @@ -364,7 +368,9 @@ def _init_deepspeed_distributed(self) -> None:
f"MEMBER: {self.global_rank + 1}/{self.world_size}"
)
self._process_group_backend = self._get_process_group_backend()
deepspeed.init_distributed(self._process_group_backend, distributed_port=self.cluster_environment.main_port)
deepspeed.init_distributed(
self._process_group_backend, distributed_port=self.cluster_environment.main_port, timeout=self._timeout
)

def _set_node_environment_variables(self) -> None:
assert self.cluster_environment is not None
Expand Down

0 comments on commit 9983f3a

Please sign in to comment.