Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

accelerator refactor - add parallel plugins #5714

Merged
merged 17 commits into from
Jan 31, 2021
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


- Refactored Accelerators and Plugins (
[#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715),
)
- Refactored Accelerators and Plugins
* Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
* Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))


### Deprecated
Expand Down
272 changes: 272 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
from time import sleep
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
import torch.distributed as torch_distrib

from pytorch_lightning import _logger as log
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
find_free_network_port,
rank_zero_only,
ReduceOp,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


class DDPPlugin(ParallelPlugin):
"""
Plugin for multi-process single-device training on one or multiple nodes.

The master process in each node spawns N-1 child processes via :func:`subprocess.Popen`,
where N is the number of devices (e.g. GPU) per node.
It is very similar to how :mod:`torch.distributed.launch` launches processes.
"""

distributed_backend = "ddp"

def __init__(
self,
parallel_devices,
num_nodes=1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm=False,
**kwargs: Dict[str, Any],
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.interactive_ddp_procs = []
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self.dist = LightningDistributed()
self._ddp_kwargs = kwargs
self._has_spawned_children = False
self.task_idx = None
self.node_rank = 0
self.num_processes = len(parallel_devices)

@property
def root_device(self):
return self.parallel_devices[self.local_rank]

@property
def lightning_module(self):
# the model may not be wrapped with DistributedDataParallel if calling this too early
# fixme: uncomment when this class will actually be used
# return unwrap_lightning_module(self._model)
pass

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

# start the other scripts
# TODO: make sure this works, in torchelastic we should not launch child processes!
if os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
self._call_children_scripts()

# set the task idx
self.task_idx = self.cluster_environment.local_rank()

def _call_children_scripts(self):

# bookkeeping of spawned processes
assert self.global_rank == 0
self._check_can_spawn_children()
self._has_spawned_children = True

# DDP Environment variables
os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", str(find_free_network_port()))

# allow the user to pass the node rank
node_rank = "0"
node_rank = os.environ.get("NODE_RANK", node_rank)
node_rank = os.environ.get("GROUP_RANK", node_rank)
os.environ["NODE_RANK"] = node_rank
os.environ["LOCAL_RANK"] = "0"

# when user is using hydra find the absolute path
path_lib = os.path.abspath if not _HYDRA_AVAILABLE else to_absolute_path

# pull out the commands used to run the script and resolve the abs file path
command = sys.argv
try:
full_path = path_lib(command[0])
except Exception as e:
full_path = os.path.abspath(command[0])

command[0] = full_path
# use the same python interpreter and actually running
command = [sys.executable] + command

# the visible devices tell us how many GPUs we want to use.
# when the trainer script was called the device has already been scoped by the time
# code reaches this point. so, to call the scripts, we need to leave cuda visible devices alone
# but forward the GPUs selected via environment variables
if self.parallel_devices is None:
raise MisconfigurationException("you selected (distribute_backend = ddp) but did not set Trainer(gpus=?)")

os.environ["PL_TRAINER_GPUS"] = ",".join([str(device.index) for device in self.parallel_devices])
os.environ["PL_IN_DDP_SUBPROCESS"] = "1"

if self.lightning_module.logger is not None:
os.environ["PL_EXP_VERSION"] = str(self.lightning_module.logger.version)

num_gpus = len(self.parallel_devices)
os.environ["WORLD_SIZE"] = f"{num_gpus * self.num_nodes}"

self.interactive_ddp_procs = []

for local_rank in range(1, self.num_processes):
env_copy = os.environ.copy()
env_copy["LOCAL_RANK"] = f"{local_rank}"

# remove env var if global seed not set
if os.environ.get("PL_GLOBAL_SEED") is None and "PL_GLOBAL_SEED" in env_copy:
del env_copy["PL_GLOBAL_SEED"]

# start process
# if hydra is available and initialized, make sure to set the cwd correctly
cwd: Optional[str] = None
if _HYDRA_AVAILABLE:
if HydraConfig.initialized():
cwd = get_original_cwd()
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
self.interactive_ddp_procs.append(proc)

# starting all processes at once can cause issues
# with dataloaders delay between 1-10 seconds
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
"You tried to run `.fit` or `.test` multiple times in the same script."
" This is not supported in DDP mode, switch to `distributed_backend='ddp_spawn'` instead."
)

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
self.global_rank = self.node_rank * self.num_processes + self.local_rank
self.world_size = self.num_nodes * self.num_processes

def configure_ddp(self):
# if unset, default `find_unused_parameters` `True`
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get("find_unused_parameters", True)
self._model = LightningDistributedDataParallel(
self.model,
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)

def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
return None
return [self.root_device.index]

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
# TODO: From where to get cluster environment?
os.environ["MASTER_ADDR"] = str(self.cluster_environment.master_address())
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
os.environ["WORLD_SIZE"] = str(self.cluster_environment.world_size())
torch_backend = "nccl" if self.on_gpu else "gloo"

if not torch.distributed.is_initialized():
log.info(f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)

def pre_training(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# TODO: we moved it to the trainer.fit after calling pre_training
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

# move the model to the correct device
self.model_to_device()

self.configure_ddp()

self.barrier()

def post_training(self):
if "WORLD_SIZE" in os.environ:
del os.environ["WORLD_SIZE"]

def barrier(self, *args, **kwargs):
if torch_distrib.is_initialized():
torch_distrib.barrier()

def broadcast(self, obj: object, src: int = 0) -> object:
return self.dist.broadcast(obj)

def model_to_device(self):
if self.root_device.type == "cuda":
torch.cuda.set_device(self.root_device)
self.model.to(self.root_device)

def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if isinstance(output, torch.Tensor):
output = sync_ddp_if_available(output, group, reduce_op)
return output
54 changes: 54 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin


class DDP2Plugin(DDPPlugin):

def setup(self, model):
self._model = model
# set the task idx
self.task_idx = self.cluster_environment.local_rank()
# the difference to DDP is that we don't call children processes here

def reduce(self, output, *args, **kwargs):
if isinstance(output, Result):
output.dp_reduce()

elif isinstance(output, torch.Tensor):
output = output.mean()

return output

@property
def root_device(self):
return self.parallel_devices[0]

def model_to_device(self):
# no need to do anything when model is wrapped in torch.nn.DataParallel
pass

@property
def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
return distributed_sampler_kwargs

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
self.global_rank = self.node_rank
self.world_size = self.num_nodes
Loading