Commit
accelerator refactor - add parallel plugins (#5714)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
1 parent 692f77b · commit 3bacac7
Showing 7 changed files with 873 additions and 3 deletions.
@@ -0,0 +1,272 @@

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from time import sleep
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
    find_free_network_port,
    rank_zero_only,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
    from hydra.core.hydra_config import HydraConfig
    from hydra.utils import get_original_cwd, to_absolute_path


class DDPPlugin(ParallelPlugin):
    """
    Plugin for multi-process single-device training on one or multiple nodes.
    The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
    where N is the number of devices (e.g. GPU) per node.
    It is very similar to how :mod:`torch.distributed.launch` launches processes.
    """

    distributed_backend = "ddp"

    def __init__(
        self,
        parallel_devices,
        num_nodes=1,
        cluster_environment: ClusterEnvironment = None,
        sync_batchnorm=False,
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
        self.interactive_ddp_procs = []
        self.num_nodes = num_nodes
        self.sync_batchnorm = sync_batchnorm
        self.dist = LightningDistributed()
        self._ddp_kwargs = kwargs
        self._has_spawned_children = False
        self.task_idx = None
        self.node_rank = 0
        self.num_processes = len(parallel_devices)

    @property
    def root_device(self):
        return self.parallel_devices[self.local_rank]

    @property
    def lightning_module(self):
        # the model may not be wrapped with DistributedDataParallel if this is called too early
        # FIXME: uncomment once this class is actually used
        # return unwrap_lightning_module(self._model)
        pass

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, model):
        self._model = model

        # start the other scripts
        # TODO: make sure this works, in torchelastic we should not launch child processes!
        if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
        assert self.global_rank == 0
        self._check_can_spawn_children()
        self._has_spawned_children = True

        # DDP Environment variables
        os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
        os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

        # allow the user to pass the node rank
        node_rank = "0"
        node_rank = os.environ.get("NODE_RANK", node_rank)
        node_rank = os.environ.get("GROUP_RANK", node_rank)
        os.environ["NODE_RANK"] = node_rank
        os.environ["LOCAL_RANK"] = "0"

        # when the user is using Hydra, resolve the absolute path
        path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

        # pull out the command used to run the script and resolve the absolute file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception:
            full_path = os.path.abspath(command[0])

        command[0] = full_path
        # use the same Python interpreter that is currently running
        command = [sys.executable] + command

        # the visible devices tell us how many GPUs we want to use.
        # by the time the code reaches this point, the devices have already been scoped for the
        # trainer script. so, to call the child scripts, we leave CUDA_VISIBLE_DEVICES alone
        # but forward the selected GPUs via environment variables
        if self.parallel_devices is None:
            raise MisconfigurationException("you selected (distributed_backend = ddp) but did not set Trainer(gpus=?)")

        os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
        os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

        if self.lightning_module.logger is not None:
            os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

        num_gpus = len(self.parallel_devices)
        os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

        self.interactive_ddp_procs = []

        for local_rank in range(1, self.num_processes):
            env_copy = os.environ.copy()
            env_copy["LOCAL_RANK"] = f"{local_rank}"

            # remove env var if global seed not set
            if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
                del env_copy["PL_GLOBAL_SEED"]

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues with dataloaders;
            # stagger process startup with a random delay of 1-5 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
                "You tried to run `.fit` or `.test` multiple times in the same script."
                " This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
            )

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank * self.num_processes + self.local_rank
        self.world_size = self.num_nodes * self.num_processes

    def configure_ddp(self):
        # if unset, default `find_unused_parameters` to `True`
        self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
        self._model = LightningDistributedDataParallel(
            self.model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def determine_ddp_device_ids(self):
        if self.root_device.type == "cpu":
            return None
        return [self.root_device.index]

    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        # TODO: From where to get cluster environment?
        os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
        os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
        torch_backend = "nccl" if self.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

    def pre_training(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up the server using proc 0's ip address
        # try to init at most 20 times in case ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_training
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)

        # on global rank 0, let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)

        # move the model to the correct device
        self.model_to_device()

        self.configure_ddp()

        self.barrier()

    def post_training(self):
        if "WORLD_SIZE" in os.environ:
            del os.environ["WORLD_SIZE"]

    def barrier(self, *args, **kwargs):
        if torch_distrib.is_initialized():
            torch_distrib.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        return self.dist.broadcast(obj)

    def model_to_device(self):
        if self.root_device.type == "cuda":
            torch.cuda.set_device(self.root_device)
        self.model.to(self.root_device)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if isinstance(output, torch.Tensor):
            output = sync_ddp_if_available(output, group, reduce_op)
        return output
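For orientation, here is a minimal, self-contained sketch (not Lightning code) of the launch pattern described in the DDPPlugin docstring: the master process re-invokes its own script once per additional local device and passes the rank through the environment, much as `_call_children_scripts` does with `LOCAL_RANK` and `PL_IN_DDP_SUBPROCESS`. The constant `NUM_DEVICES` and the `IN_CHILD` guard are illustrative assumptions, not Lightning API.

    import os
    import subprocess
    import sys

    NUM_DEVICES = 2  # hypothetical number of devices on this node

    def main():
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        # only the master process (rank 0, and not itself a spawned child) launches the others
        if local_rank == 0 and os.environ.get("IN_CHILD", "0") != "1":
            for rank in range(1, NUM_DEVICES):
                env = os.environ.copy()
                env["LOCAL_RANK"] = str(rank)  # analogous to LOCAL_RANK in DDPPlugin
                env["IN_CHILD"] = "1"          # analogous to PL_IN_DDP_SUBPROCESS
                # re-run this very script with the same interpreter
                subprocess.Popen([sys.executable] + sys.argv, env=env)
        print(f"process started with LOCAL_RANK={local_rank}")

    if __name__ == "__main__":
        main()

The second file in the commit, shown below, derives DDP2Plugin from DDPPlugin and overrides only the pieces that differ.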
@@ -0,0 +1,54 @@

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class DDP2Plugin(DDPPlugin):

    def setup(self, model):
        self._model = model
        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()
        # the difference from DDP is that we don't launch child processes here

    def reduce(self, output, *args, **kwargs):
        if isinstance(output, Result):
            output.dp_reduce()

        elif isinstance(output, torch.Tensor):
            output = output.mean()

        return output

    @property
    def root_device(self):
        return self.parallel_devices[0]

    def model_to_device(self):
        # no need to do anything when model is wrapped in torch.nn.DataParallel
        pass

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
        return distributed_sampler_kwargs

    def set_world_ranks(self):
        self.local_rank = self.task_idx
        self.node_rank = self.cluster_environment.node_rank()
        self.global_rank = self.node_rank
        self.world_size = self.num_nodes
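To make the contrast concrete, here is a small worked example (hypothetical numbers, not Lightning code) of the rank bookkeeping the two plugins implement in `set_world_ranks` and `distributed_sampler_kwargs`, assuming 2 nodes with 4 devices each. DDPPlugin runs one process per device, while DDP2Plugin runs one process per node and reduces across that node's devices in DataParallel style.

    num_nodes, num_processes = 2, 4   # hypothetical cluster: 2 nodes x 4 devices
    node_rank, local_rank = 1, 2      # example position of one process

    # DDPPlugin: one process per device
    ddp_global_rank = node_rank * num_processes + local_rank   # -> 6
    ddp_world_size = num_nodes * num_processes                 # -> 8
    ddp_sampler_kwargs = dict(num_replicas=num_nodes * num_processes, rank=ddp_global_rank)

    # DDP2Plugin: one process per node
    ddp2_global_rank = node_rank                               # -> 1
    ddp2_world_size = num_nodes                                # -> 2
    ddp2_sampler_kwargs = dict(num_replicas=num_nodes, rank=ddp2_global_rank)

    print(ddp_sampler_kwargs)   # {'num_replicas': 8, 'rank': 6}
    print(ddp2_sampler_kwargs)  # {'num_replicas': 2, 'rank': 1}

With DDP2, each node therefore sees a distinct shard of the data (one replica per node), and the per-device results are merged inside the node by the DataParallel-style `reduce` shown above.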