Clean up environment access in plugins #6941
LightningEnvironment (pytorch_lightning.plugins.environments):

```diff
@@ -14,9 +14,9 @@
 import os
 import socket
 from typing import Optional

 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.utilities import rank_zero_only


 class LightningEnvironment(ClusterEnvironment):

@@ -34,6 +34,8 @@ class LightningEnvironment(ClusterEnvironment):
     def __init__(self):
         super().__init__()
         self._master_port = None
+        self._global_rank: int = 0
+        self._world_size: int = 1

     def creates_children(self) -> bool:
         return False
```
```diff
@@ -46,8 +48,18 @@ def master_port(self) -> int:
         self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
         return int(self._master_port)

-    def world_size(self) -> Optional[int]:
-        return None
+    def world_size(self) -> int:
+        return self._world_size
+
+    def set_world_size(self, size: int) -> None:
+        self._world_size = size
+
+    def global_rank(self) -> int:
+        return self._global_rank
+
+    def set_global_rank(self, rank: int) -> None:
+        self._global_rank = rank
+        rank_zero_only.rank = rank

     def local_rank(self) -> int:
         return int(os.environ.get("LOCAL_RANK", 0))
```

Review comment on `set_world_size`: Do we need `set_world_size`? Can you just use a setter and make everything properties?

Reply: Yes, but this interface is old and all the getters have been methods from the beginning.
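For illustration only, the property-style interface the reviewer is suggesting would look roughly like the standalone sketch below. This is hypothetical; as the reply above explains, the actual `ClusterEnvironment` interface keeps getters and setters as plain methods.

```python
class PropertyBasedEnvironment:
    """Hypothetical sketch of the property-style interface suggested in the review."""

    def __init__(self) -> None:
        self._world_size: int = 1
        self._global_rank: int = 0

    @property
    def world_size(self) -> int:
        return self._world_size

    @world_size.setter
    def world_size(self, size: int) -> None:
        self._world_size = size

    @property
    def global_rank(self) -> int:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank


env = PropertyBasedEnvironment()
env.world_size = 8       # setter would replace env.set_world_size(8)
print(env.world_size)    # property access would replace env.world_size()
```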
TorchElasticEnvironment (pytorch_lightning.plugins.environments):

```diff
@@ -24,8 +24,11 @@
 class TorchElasticEnvironment(ClusterEnvironment):

-    def __init__(self):
-        super().__init__()
+    @staticmethod
+    def is_using_torchelastic() -> bool:
+        """ Returns ``True`` if the current process was launched using the torchelastic command. """
+        required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
+        return all(v in os.environ for v in required_env_vars)

     def creates_children(self) -> bool:
         return True
```

Review comment on `is_using_torchelastic`: Why don't we have something similar for SLURM?

Reply: We have; it's in the accelerator connector, and I will do a follow-up. I didn't want to make the PR larger.
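As a rough illustration of how such a detection helper can be used when picking a cluster environment (the real selection logic lives in the accelerator connector, per the reply above; the function below only mirrors the check from the diff, and the branching is a hypothetical sketch):

```python
import os


def is_using_torchelastic() -> bool:
    # Same check as in the diff above: torchelastic exports these variables
    # for every process it launches.
    required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
    return all(v in os.environ for v in required_env_vars)


# Hypothetical selection step, loosely mirroring what a connector might do:
if is_using_torchelastic():
    print("launched by torchelastic, would pick TorchElasticEnvironment")
else:
    print("no torchelastic variables found, would pick LightningEnvironment")
```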
```diff
@@ -51,6 +54,17 @@ def world_size(self) -> Optional[int]:
         world_size = os.environ.get('WORLD_SIZE')
         return int(world_size) if world_size is not None else world_size

+    def set_world_size(self, size: int) -> None:
+        log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+
+    def global_rank(self) -> int:
+        return int(os.environ["RANK"])
+
+    def set_global_rank(self, rank: int) -> None:
+        log.debug(
+            "TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
+        )
+
     def local_rank(self) -> int:
         return int(os.environ['LOCAL_RANK'])
```
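In other words, under torchelastic the rank and world size are owned by the launcher's environment variables, so the new setters are deliberate no-ops. A rough usage sketch of that behaviour (assuming the package-level import shown below is available, as it was in Lightning around this PR; the environment variables are set manually here only to imitate a torchelastic launch):

```python
import os

from pytorch_lightning.plugins.environments import TorchElasticEnvironment

# Pretend we were launched by torchelastic (normally the launcher sets these).
os.environ.update({
    "WORLD_SIZE": "8", "RANK": "3", "LOCAL_RANK": "1",
    "GROUP_RANK": "0", "LOCAL_WORLD_SIZE": "4",
})

env = TorchElasticEnvironment()
env.set_world_size(2)     # ignored; only a debug message is logged
print(env.world_size())   # 8 -- still read from os.environ["WORLD_SIZE"]
print(env.global_rank())  # 3 -- read from os.environ["RANK"]
```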
DDP plugin:

```diff
@@ -78,11 +78,11 @@ def __init__(
         self._ddp_kwargs = kwargs
         self._has_spawned_children = False
         self.task_idx = None
-        self.node_rank = 0
         self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
         self._ddp_comm_state = ddp_comm_state
         self._ddp_comm_hook = ddp_comm_hook
         self._ddp_comm_wrapper = ddp_comm_wrapper
+        self.set_world_ranks()

     @property
     def root_device(self):
```

Review comment on `self.task_idx = None`: remove this?
```diff
@@ -193,7 +193,7 @@ def setup_distributed(self):
         # set up server using proc 0's ip address
         # try to init for 20 times at max in case ports are taken
         # where to store ip_table
-        self.init_ddp_connection(self.global_rank, self.world_size)
+        self.init_ddp_connection()

         # on world_size=0 let everyone know training is starting
         if self.is_global_zero and not torch.distributed.is_initialized():
```
```diff
@@ -213,11 +213,11 @@ def _check_can_spawn_children(self):
                 " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
             )

-    def set_world_ranks(self):
-        self.local_rank = self.task_idx
-        self.node_rank = self.cluster_environment.node_rank()
-        self.global_rank = self.node_rank * self.num_processes + self.local_rank
-        self.world_size = self.num_nodes * self.num_processes
+    def set_world_ranks(self) -> None:
+        if self.cluster_environment is not None:
+            self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+            self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()

     def pre_configure_ddp(self):
         # if unset, default `find_unused_parameters` `True`
```
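To make the rank arithmetic in `set_world_ranks` concrete, here is a small standalone sketch with hypothetical values (2 nodes with 4 processes each); the formulas are the same ones the plugin now pushes into the cluster environment:

```python
num_nodes = 2       # machines in the job
num_processes = 4   # processes (e.g. GPUs) per machine
node_rank = 1       # this machine's index, provided by the cluster environment
local_rank = 2      # this process's index within the machine

global_rank = node_rank * num_processes + local_rank  # 1 * 4 + 2 = 6
world_size = num_nodes * num_processes                # 2 * 4 = 8
print(global_rank, world_size)                        # 6 8
```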
```diff
@@ -260,11 +260,11 @@ def determine_ddp_device_ids(self):
             return None
         return [self.root_device.index]

-    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
-        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
+    def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
+        global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
+        world_size = world_size if world_size is not None else self.cluster_environment.world_size()
+        os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
         os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
         os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

         if not torch.distributed.is_initialized():
             log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
             torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
```
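With the new defaults, the caller in `setup_distributed` no longer needs to pass anything; the two calls below are equivalent in a sketch that assumes a configured DDP plugin instance named `plugin`:

```python
# Before: rank and world size had to be supplied explicitly by the caller.
plugin.init_ddp_connection(plugin.global_rank, plugin.world_size)

# After: omitted arguments fall back to the plugin's cluster environment.
plugin.init_ddp_connection()
```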
HorovodPlugin:

```diff
@@ -31,24 +31,31 @@ class HorovodPlugin(ParallelPlugin):

     def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
+        rank_zero_only.rank = self.global_rank
+
+    @property
+    def global_rank(self) -> int:
+        return hvd.rank()
+
+    @property
+    def local_rank(self) -> int:
+        return hvd.local_rank()
+
+    @property
+    def world_size(self) -> int:
+        return hvd.size()

     @property
     def root_device(self):
         return self.parallel_devices[self.local_rank]

     @property
     def distributed_sampler_kwargs(self):
-        distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
+        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs

     def setup(self, model):
         self._model = model

-        self.global_rank = hvd.rank()
-        self.local_rank = hvd.local_rank()
-        self.world_size = hvd.size()
-        rank_zero_only.rank = self.global_rank
-
         self.model_to_device()

     def pre_dispatch(self):
```

Review comment on the new `global_rank` / `local_rank` / `world_size` properties (lines +37 to +46): Wouldn't it be cleaner to also have a Horovod environment plugin here? For this part it's similar to torchelastic and should be handled like that as well.
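Purely for illustration, the dedicated Horovod cluster environment the reviewer suggests above could look roughly like the sketch below. Nothing like this is added in the PR, and the class name and details are hypothetical; the Horovod plugin keeps reading ranks from `hvd` directly.

```python
import horovod.torch as hvd

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class HorovodEnvironment(ClusterEnvironment):
    """Hypothetical sketch: expose Horovod's ranks through the ClusterEnvironment interface."""
    # master_address(), master_port(), node_rank() etc. are omitted for brevity.

    def creates_children(self) -> bool:
        # horovodrun launches the worker processes itself.
        return True

    def world_size(self) -> int:
        return hvd.size()

    def set_world_size(self, size: int) -> None:
        pass  # determined by the Horovod launcher; cannot be overridden

    def global_rank(self) -> int:
        return hvd.rank()

    def set_global_rank(self, rank: int) -> None:
        pass  # determined by the Horovod launcher; cannot be overridden

    def local_rank(self) -> int:
        return hvd.local_rank()
```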
```diff
@@ -63,14 +70,14 @@ def _unpack_lightning_optimizer(opt):
         # increased total batch size
         for optimizer in optimizers:
             for param_group in optimizer.param_groups:
-                param_group["lr"] *= hvd.size()
+                param_group["lr"] *= self.world_size

         # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
         lr_schedulers = self.lightning_module.trainer.lr_schedulers
         for scheduler in lr_schedulers:
             scheduler = scheduler["scheduler"]
             if isinstance(scheduler, _LRScheduler):
-                scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
+                scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
```
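The scaling itself is unchanged; only the source of the factor moves from `hvd.size()` to `self.world_size`. For example, with 4 workers a base learning rate of 0.01 becomes 0.04. A standalone sketch of the same adjustment, with a hypothetical model, optimizer, and world size:

```python
import torch

world_size = 4  # stands in for hvd.size() / self.world_size
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Horovod-style LR scaling: grow the LR with the effective total batch size.
for param_group in optimizer.param_groups:
    param_group["lr"] *= world_size

print(optimizer.param_groups[0]["lr"])  # 0.04
```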
Review comment: Shouldn't this be `pass` too?

Reply: It doesn't make a difference. `pass` is only required when we have nothing under the function; here the docstring is already enough :)
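To illustrate the point about `pass`: a docstring already counts as a function body in Python, so no extra statement is needed (minimal standalone example):

```python
def noop_with_docstring() -> None:
    """A docstring alone is a valid function body; no `pass` is required."""


def noop_without_docstring() -> None:
    pass  # needed only because the body would otherwise be empty


noop_with_docstring()
noop_without_docstring()
```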