Clean up environment access in plugins #6941

Merged (66 commits, Apr 13, 2021)
66 commits (all by awaelchli):

- c87cf9a: initial draft (Apr 9, 2021)
- 903e673: x (Apr 9, 2021)
- bdf4eaa: x (Apr 9, 2021)
- f9478c8: x (Apr 9, 2021)
- a2f67bd: x (Apr 9, 2021)
- 64c8c36: x (Apr 9, 2021)
- bb2acdc: torchelastic (Apr 9, 2021)
- 03fe590: x (Apr 9, 2021)
- d4060e2: x (Apr 9, 2021)
- 1ac0a69: rename (Apr 9, 2021)
- ab00577: spawn (Apr 10, 2021)
- 7a7f4ac: init ddp (Apr 10, 2021)
- 8d33db9: x (Apr 10, 2021)
- 3fa1264: horovod (Apr 10, 2021)
- 22b1ebf: horovod (Apr 10, 2021)
- cd7327a: ranks for DP (Apr 10, 2021)
- 405734f: slurm variables (Apr 10, 2021)
- 39f2961: fix test (Apr 10, 2021)
- ca64092: slurm env vars in test (Apr 10, 2021)
- b17a6b7: fix test (Apr 10, 2021)
- 06fec87: rank (Apr 10, 2021)
- c2744e6: fix test (Apr 10, 2021)
- 93a0538: TYPO (Apr 10, 2021)
- ee3f7f8: cpu_te (Apr 10, 2021)
- ca6ee97: slurm environment tests (Apr 10, 2021)
- 90f1d37: rpc (Apr 10, 2021)
- 391624d: clean up (Apr 10, 2021)
- 77af73d: added new tests (Apr 10, 2021)
- 5764ef5: None check (Apr 10, 2021)
- d6b2f7c: add test description (Apr 10, 2021)
- 15324de: add comments (Apr 10, 2021)
- 0048ae5: add more plugins to test (Apr 10, 2021)
- d9310a5: add tests for automatic plugin selection (Apr 10, 2021)
- 1c64dfa: include local world size for elastic (Apr 10, 2021)
- 5992563: clean up (Apr 10, 2021)
- d968744: add torch elastic local world size to env variable in test (Apr 10, 2021)
- c99043d: make changes to tpu plugin (Apr 12, 2021)
- 11c00da: format tests (Apr 12, 2021)
- 1b57674: isort (Apr 12, 2021)
- 50638f8: Merge branch 'master' into bugfix/elastic-world-size (Apr 12, 2021)
- de8454a: redundant init (Apr 12, 2021)
- 4e4afcd: move helper function (Apr 12, 2021)
- be33371: fix test (Apr 12, 2021)
- eb143a7: type hint (Apr 12, 2021)
- e892e57: typing rank properties (Apr 12, 2021)
- d8d9e8b: typo (Apr 12, 2021)
- 7897e41: Merge branch 'master' into bugfix/elastic-world-size (Apr 12, 2021)
- b2b705b: missing env var (Apr 12, 2021)
- 8d3dbcf: flake (Apr 12, 2021)
- 1a7af2e: changelog (Apr 13, 2021)
- d809f2d: redundant init (Apr 13, 2021)
- 1bb6e9d: redundant init (Apr 13, 2021)
- 9161f4f: Merge branch 'master' into bugfix/elastic-world-size (Apr 13, 2021)
- 354c901: Apply suggestions from code review (Apr 13, 2021)
- 2218fac: Update pytorch_lightning/plugins/environments/lightning_environment.py (Apr 13, 2021)
- 7005b4e: add ddp2 test and fix (Apr 13, 2021)
- 5bbfe17: Merge remote-tracking branch 'origin/bugfix/elastic-world-size' into … (Apr 13, 2021)
- a6d0f5d: Update pytorch_lightning/plugins/environments/lightning_environment.py (Apr 13, 2021)
- c3b9db4: test for ddp_cpu, ddp_spawn (Apr 13, 2021)
- 246384d: set ranks in ddp_spawn (Apr 13, 2021)
- dae1d73: fix signature (Apr 13, 2021)
- 01eb6de: patch xla (Apr 13, 2021)
- 45e9f78: world size defaults to 1 (Apr 13, 2021)
- 9490ce9: deprecation docs (Apr 13, 2021)
- a0a53b7: rename test file (Apr 13, 2021)
- 7d39a92: log debug setter (Apr 13, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -243,6 +243,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


- Fixed process rank not being available right away after `Trainer` instantiation ([#6941](https://github.com/PyTorchLightning/pytorch-lightning/pull/6941))


## [1.2.7] - 2021-04-06

### Fixed
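To illustrate the fixed behaviour described in the changelog entry, here is a minimal usage sketch (not taken from the PR). It assumes the public `Trainer.global_rank` and `Trainer.world_size` attributes and a two-GPU DDP setup:

```python
# Sketch only: rank information is usable right after Trainer construction.
from pytorch_lightning import Trainer

trainer = Trainer(accelerator="ddp", gpus=2, num_nodes=1)

# Before this fix, the ranks could still hold stale defaults at this point;
# now the training type plugin reads them from the cluster environment.
if trainer.global_rank == 0:
    print(f"world size: {trainer.world_size}")
```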
15 changes: 13 additions & 2 deletions pytorch_lightning/plugins/environments/cluster_environment.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Optional


class ClusterEnvironment(ABC):
@@ -31,9 +30,21 @@ def master_port(self) -> int:
""" An open and configured port in the master node through which all processes communicate. """

@abstractmethod
def world_size(self) -> Optional[int]:
def world_size(self) -> int:
""" The number of processes across all devices and nodes. """

@abstractmethod
def set_world_size(self, size: int) -> None:
pass

@abstractmethod
def global_rank(self) -> int:
""" The rank (index) of the currently running process across all nodes and devices. """
Reviewer: Shouldn't this be pass too?

awaelchli (Contributor Author): It doesn't make a difference. pass is only required when we have nothing under the function; here the docstring is already enough :)
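For background on the reply above: in Python a docstring is itself a valid function body, so `pass` is only needed when the body would otherwise be empty. A tiny standalone illustration (not part of the diff):

```python
from abc import ABC, abstractmethod


class Example(ABC):

    @abstractmethod
    def with_docstring(self) -> int:
        """ The docstring alone is a complete function body. """

    @abstractmethod
    def without_docstring(self) -> int:
        pass  # needed here, since there is no docstring or other statement
```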


@abstractmethod
def set_global_rank(self, rank: int) -> None:
pass

@abstractmethod
def local_rank(self) -> int:
""" The rank (index) of the currently running process inside of the current node. """
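To show how the extended abstract interface is meant to be implemented, here is a hypothetical minimal subclass. It is a sketch under the assumption that the remaining abstract methods (`creates_children`, `master_address`, `node_rank`) keep their existing signatures; the `MY_*` variable names are made up for illustration.

```python
import os

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class MyClusterEnvironment(ClusterEnvironment):
    """ Hypothetical environment that reads everything from custom env variables. """

    def creates_children(self) -> bool:
        return True  # processes are launched externally, not spawned by Lightning

    def master_address(self) -> str:
        return os.environ["MY_MASTER_ADDR"]

    def master_port(self) -> int:
        return int(os.environ["MY_MASTER_PORT"])

    def world_size(self) -> int:
        return int(os.environ["MY_WORLD_SIZE"])

    def set_world_size(self, size: int) -> None:
        pass  # externally managed, ignore

    def global_rank(self) -> int:
        return int(os.environ["MY_RANK"])

    def set_global_rank(self, rank: int) -> None:
        pass  # externally managed, ignore

    def local_rank(self) -> int:
        return int(os.environ["MY_LOCAL_RANK"])

    def node_rank(self) -> int:
        return int(os.environ["MY_NODE_RANK"])
```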
18 changes: 15 additions & 3 deletions pytorch_lightning/plugins/environments/lightning_environment.py
@@ -14,9 +14,9 @@

import os
import socket
from typing import Optional

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.utilities import rank_zero_only


class LightningEnvironment(ClusterEnvironment):
@@ -34,6 +34,8 @@ class LightningEnvironment(ClusterEnvironment):
def __init__(self):
super().__init__()
self._master_port = None
self._global_rank: int = 0
self._world_size: int = 1

def creates_children(self) -> bool:
return False
@@ -46,8 +48,18 @@ def master_port(self) -> int:
self._master_port = os.environ.get("MASTER_PORT", find_free_network_port())
return int(self._master_port)

def world_size(self) -> Optional[int]:
return None
def world_size(self) -> int:
return self._world_size

def set_world_size(self, size: int) -> None:
Contributor: Do we need set_world_size? Can you just use a setter and make everything properties?

awaelchli (Contributor Author): Yes, but this interface is old and all getters have been methods from the beginning. It would be better to have proper setters and getters; I decided to make the interface consistent rather than keep backward compatibility, and I propose to do a deprecation in a follow-up to keep this PR manageable.

self._world_size = size

def global_rank(self) -> int:
return self._global_rank

def set_global_rank(self, rank: int) -> None:
self._global_rank = rank
rank_zero_only.rank = rank

def local_rank(self) -> int:
return int(os.environ.get("LOCAL_RANK", 0))
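The review thread above defers a property-based interface to a follow-up deprecation. Purely as a sketch of what that direction could look like (hypothetical, not part of this PR):

```python
class PropertyStyleEnvironment:
    """ Hypothetical property-based variant of the getters/setters discussed above. """

    def __init__(self) -> None:
        self._world_size: int = 1
        self._global_rank: int = 0

    @property
    def world_size(self) -> int:
        return self._world_size

    @world_size.setter
    def world_size(self, size: int) -> None:
        self._world_size = size

    @property
    def global_rank(self) -> int:
        return self._global_rank

    @global_rank.setter
    def global_rank(self, rank: int) -> None:
        self._global_rank = rank
```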
16 changes: 11 additions & 5 deletions pytorch_lightning/plugins/environments/slurm_environment.py
@@ -23,9 +23,6 @@

class SLURMEnvironment(ClusterEnvironment):

def __init__(self):
super().__init__()

def creates_children(self) -> bool:
return True

@@ -69,8 +66,17 @@ def master_port(self) -> int:

return int(default_port)

def world_size(self):
return None
def world_size(self) -> int:
return int(os.environ["SLURM_NTASKS"])

def set_world_size(self, size: int) -> None:
log.debug("SLURMEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["SLURM_PROCID"])

def set_global_rank(self, rank: int) -> None:
log.debug("SLURMEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")

def local_rank(self) -> int:
return int(os.environ['SLURM_LOCALID'])
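Since the SLURM getters now read `SLURM_NTASKS`, `SLURM_PROCID` and `SLURM_LOCALID` directly, a test can exercise them by faking those variables. A rough sketch, not necessarily identical to the tests added in this PR, assuming `SLURMEnvironment` is importable from `pytorch_lightning.plugins.environments`:

```python
import os
from unittest import mock

from pytorch_lightning.plugins.environments import SLURMEnvironment


@mock.patch.dict(os.environ, {
    "SLURM_NTASKS": "4",
    "SLURM_PROCID": "2",
    "SLURM_LOCALID": "1",
})
def test_slurm_environment_ranks():
    env = SLURMEnvironment()
    # values come straight from the scheduler-provided variables
    assert env.world_size() == 4
    assert env.global_rank() == 2
    assert env.local_rank() == 1
    # setters are no-ops for SLURM; the world size is still scheduler-controlled
    env.set_world_size(100)
    assert env.world_size() == 4
```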
pytorch_lightning/plugins/environments/torchelastic_environment.py

@@ -24,8 +24,11 @@

class TorchElasticEnvironment(ClusterEnvironment):

def __init__(self):
super().__init__()
@staticmethod
def is_using_torchelastic() -> bool:
Member: Why don't we have something similar for SLURM?

awaelchli (Contributor Author): We have; it's in the accelerator connector, and I will do a follow-up. I didn't want to make the PR larger.

""" Returns ``True`` if the current process was launched using the torchelastic command. """
required_env_vars = ("RANK", "GROUP_RANK", "LOCAL_RANK", "LOCAL_WORLD_SIZE")
return all(v in os.environ for v in required_env_vars)

def creates_children(self) -> bool:
return True
@@ -51,6 +54,17 @@ def world_size(self) -> Optional[int]:
world_size = os.environ.get('WORLD_SIZE')
return int(world_size) if world_size is not None else world_size

def set_world_size(self, size: int) -> None:
log.debug("TorchElasticEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")

def global_rank(self) -> int:
return int(os.environ["RANK"])

def set_global_rank(self, rank: int) -> None:
log.debug(
"TorchElasticEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
)

def local_rank(self) -> int:
return int(os.environ['LOCAL_RANK'])

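Regarding the question above about a SLURM counterpart to `is_using_torchelastic`: the author notes that the detection logic currently lives in the accelerator connector, with a follow-up planned. A hypothetical sketch of what such a helper might look like (not part of this PR; the exact variables to check are an assumption):

```python
import os


class SLURMDetection:
    """ Hypothetical counterpart to TorchElasticEnvironment.is_using_torchelastic. """

    @staticmethod
    def is_using_slurm() -> bool:
        """ Returns ``True`` if the current process appears to be launched by SLURM. """
        # SLURM sets these for each task started via srun; real heuristics may differ
        # from what the accelerator connector does on master.
        required_env_vars = ("SLURM_NTASKS", "SLURM_PROCID", "SLURM_LOCALID")
        return all(v in os.environ for v in required_env_vars)
```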
22 changes: 11 additions & 11 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -78,11 +78,11 @@ def __init__(
self._ddp_kwargs = kwargs
self._has_spawned_children = False
self.task_idx = None
Contributor: Remove this?

self.node_rank = 0
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self.set_world_ranks()

@property
def root_device(self):
@@ -193,7 +193,7 @@ def setup_distributed(self):
# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)
self.init_ddp_connection()

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
@@ -213,11 +213,11 @@ def _check_can_spawn_children(self):
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes
def set_world_ranks(self) -> None:
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

def pre_configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
@@ -260,11 +260,11 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Optional[int] = None) -> None:
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)
14 changes: 10 additions & 4 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -19,6 +19,14 @@

class DDP2Plugin(DDPPlugin):

@property
def global_rank(self) -> int:
return self.node_rank

@property
def world_size(self) -> int:
return self.num_nodes

def setup(self, model):
self._model = model
# set the task idx
@@ -64,7 +72,5 @@ def _is_single_process_single_device(self) -> bool:
return False

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
self.global_rank = self.node_rank
self.world_size = self.num_nodes
self.cluster_environment.set_global_rank(self.node_rank)
self.cluster_environment.set_world_size(self.num_nodes)
26 changes: 16 additions & 10 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -61,11 +61,16 @@ def __init__(
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.node_rank = 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper
self._local_rank = 0
self.set_world_ranks()

@property
def local_rank(self) -> int:
return self._local_rank

def __getstate__(self):
""" Makes this plugin pickleable without destroying the queue in the current process. """
@@ -95,12 +100,12 @@ def setup(self, model):
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()

def set_world_ranks(self, process_idx):
self.local_rank = process_idx
self.node_rank = self.cluster_environment.node_rank()
self.task_idx = self.cluster_environment.local_rank()
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes
def set_world_ranks(self, process_idx: int = 0) -> None:
self._local_rank = process_idx
if self.cluster_environment is not None:
self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
rank_zero_only.rank = self.cluster_environment.global_rank()

@property
def mp_spawn_kwargs(self):
@@ -213,11 +218,12 @@ def configure_ddp(self):
)
self._register_ddp_hooks()

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
def init_ddp_connection(self, global_rank: Optional[int], world_size: Optional[int]) -> None:
# TODO: this code is duplicated in DDP and DDPSpawn, make this a function
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
global_rank = global_rank if global_rank is not None else self.cluster_environment.global_rank()
world_size = world_size if world_size is not None else self.cluster_environment.world_size()
os.environ["MASTER_ADDR"] = self.cluster_environment.master_address()
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
16 changes: 16 additions & 0 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -27,6 +27,22 @@ class DataParallelPlugin(ParallelPlugin):
def __init__(self, parallel_devices: Optional[List[torch.device]]):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None)

@property
def global_rank(self) -> int:
return 0

@property
def local_rank(self) -> int:
return 0

@property
def node_rank(self) -> int:
return 0

@property
def world_size(self) -> int:
return 1

def setup(self, model):
# model needs to be moved to the device before it is wrapped
model.to(self.root_device)
25 changes: 16 additions & 9 deletions pytorch_lightning/plugins/training_type/horovod.py
@@ -31,24 +31,31 @@ class HorovodPlugin(ParallelPlugin):

def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
rank_zero_only.rank = self.global_rank

@property
def global_rank(self) -> int:
return hvd.rank()

@property
def local_rank(self) -> int:
return hvd.local_rank()

@property
def world_size(self) -> int:
return hvd.size()
Member (on lines +37 to +46): Wouldn't it be cleaner to also have a Horovod environment plugin here? Since this part is similar to torchelastic, it should also be handled like that.


@property
def root_device(self):
return self.parallel_devices[self.local_rank]

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

self.global_rank = hvd.rank()
self.local_rank = hvd.local_rank()
self.world_size = hvd.size()
rank_zero_only.rank = self.global_rank

self.model_to_device()

def pre_dispatch(self):
Expand All @@ -63,14 +70,14 @@ def _unpack_lightning_optimizer(opt):
# increased total batch size
for optimizer in optimizers:
for param_group in optimizer.param_groups:
param_group["lr"] *= hvd.size()
param_group["lr"] *= self.world_size

# Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
lr_schedulers = self.lightning_module.trainer.lr_schedulers
for scheduler in lr_schedulers:
scheduler = scheduler["scheduler"]
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]
scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

# Horovod: broadcast parameters & optimizer state to ensure consistent initialization
hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
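On the suggestion above, a Horovod environment plugin could wrap the same `hvd` calls used in this diff. A hypothetical sketch only; the rendezvous-related methods (`master_address`, `master_port`, `node_rank`) return placeholders because Horovod manages the rendezvous itself:

```python
import horovod.torch as hvd

from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment


class HorovodEnvironment(ClusterEnvironment):
    """ Hypothetical environment plugin wrapping the Horovod runtime (sketch only). """

    def creates_children(self) -> bool:
        return True  # processes are launched by horovodrun, not spawned by Lightning

    def master_address(self) -> str:
        return "127.0.0.1"  # unused: Horovod handles the rendezvous internally

    def master_port(self) -> int:
        return -1  # unused, see above

    def world_size(self) -> int:
        return hvd.size()

    def set_world_size(self, size: int) -> None:
        pass  # managed by Horovod, ignore

    def global_rank(self) -> int:
        return hvd.rank()

    def set_global_rank(self, rank: int) -> None:
        pass  # managed by Horovod, ignore

    def local_rank(self) -> int:
        return hvd.local_rank()

    def node_rank(self) -> int:
        return 0  # placeholder: a real implementation would derive the node index
```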
19 changes: 16 additions & 3 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -36,9 +36,6 @@ def __init__(
super().__init__()
self.parallel_devices = parallel_devices
self.cluster_environment = cluster_environment
self.global_rank = 0
self.world_size = 1
self.local_rank = 0

@property
@abstractmethod
Expand All @@ -53,6 +50,22 @@ def on_gpu(self):
def lightning_module(self):
return unwrap_lightning_module(self._model)

@property
def global_rank(self) -> int:
return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0

@property
def local_rank(self) -> int:
return self.cluster_environment.local_rank() if self.cluster_environment is not None else 0

@property
def node_rank(self) -> int:
return self.cluster_environment.node_rank() if self.cluster_environment is not None else 0

@property
def world_size(self) -> int:
return self.cluster_environment.world_size() if self.cluster_environment is not None else 1

@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
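With the new base-class properties, a plugin's ranks always mirror its cluster environment, or fall back to single-process defaults when there is none. A rough usage sketch, assuming the `DDPPlugin` and `LightningEnvironment` constructors shown earlier in this diff:

```python
import torch

from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.plugins.environments import LightningEnvironment

env = LightningEnvironment()
plugin = DDPPlugin(parallel_devices=[torch.device("cpu")], cluster_environment=env)

# Ranks resolve through the environment immediately; no training loop is needed.
assert plugin.global_rank == env.global_rank() == 0
assert plugin.world_size == env.world_size() == 1
```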