Skip to content

Commit

Permalink
[Fix] Move init dist connection into the setup function (#6506)
Browse files Browse the repository at this point in the history
* Move connection setup into the setup function. Call setup hook after we set up the accelerator

* Added CHANGELOG.md

* fix setup order in callback test

* fix input arguments in test

* Mock distributed function, remove protection to turn into training type hook

* Remove import

* Add missing mock, ensure custom plugin does not create children process

* Skip test on windows

* Update deepspeed to init connection in setup

* Do not initialize distributed module

* Move DeepSpeed tests to special tests since dist communication is being set up

* Special the test to see if this fixes CI

* Delete accelerator connector test to see if its causing build to fail

* Delete deepspeed test

* Revert "Delete accelerator connector test to see if its causing build to fail"

This reverts commit edde60b

* Revert "Delete deepspeed test"

This reverts commit 9d317429

* Reverse hook

* Reverse setup hooks to debug again

* Add todo so i know where i left off

* For single device move in pre_dispatch after setup function

* Add additional model to device hook if any additional parameters have been set

* See if we can enable deepspeed tests

* Revert "See if we can enable deepspeed tests"

This reverts commit b5450de

* See if this hook approach works

* Introduce new granular hooks

* Remove import, fix tpu spawn by moving the function to setup

* Added missing special test

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
  • Loading branch information
SeanNaren and awaelchli authored Mar 18, 2021
1 parent b606171 commit 4e9b453
Show file tree
Hide file tree
Showing 16 changed files with 139 additions and 100 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


## [1.2.4] - 2021-03-16

### Changed
Expand Down
34 changes: 21 additions & 13 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,28 @@ def __init__(
self.lr_schedulers: Sequence = []
self.optimizer_frequencies: Sequence = []

def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
def connect(self, model: LightningModule) -> None:
"""Transfers ownership of the model to this plugin"""
self.training_type_plugin.connect(model)

def setup_environment(self) -> None:
"""
Connects the plugins to the training process, creates optimizers
Setup any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook
which allows the user to access the accelerator environment before setup is complete.
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
"""
Setup plugins for the trainer fit and creates optimizers.
Args:
trainer: the trainer instance to connect to
model: the model to train
trainer: the trainer instance
model: the LightningModule
"""
self.connect_training_type_plugin(self.training_type_plugin, model)
self.setup_training_type_plugin(self.training_type_plugin, model)
self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)
self.setup_precision_plugin(self.precision_plugin)

def start_training(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_training(trainer)
Expand Down Expand Up @@ -332,14 +343,11 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin
"""
plugin.connect(model)
def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator."""
plugin.setup(model)

def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None:
"""Attaches the precision plugin to the accelerator"""
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
Expand Down
65 changes: 31 additions & 34 deletions pytorch_lightning/plugins/training_type/ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,16 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

def setup_environment(self):
# start the other scripts
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
self._call_children_scripts()

# set the task idx
self.task_idx = self.cluster_environment.local_rank()

self.setup_distributed()

def _call_children_scripts(self):

# bookkeeping of spawned processes
Expand Down Expand Up @@ -161,6 +161,34 @@ def _call_children_scripts(self):
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

def setup_distributed(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
Expand Down Expand Up @@ -213,37 +241,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

def pre_dispatch(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ def distributed_sampler_kwargs(self):
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

# pass in a state q
Expand Down
10 changes: 0 additions & 10 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,7 @@ def _load_config(self, config):
return config

def pre_dispatch(self):
self.set_world_ranks()
self.init_ddp_connection(self.global_rank, self.world_size)

self.init_deepspeed()

# set warning rank
rank_zero_only.rank = self.global_rank

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device
self.barrier()

def init_deepspeed(self):
Expand Down
8 changes: 0 additions & 8 deletions pytorch_lightning/plugins/training_type/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,6 @@ def on_gpu(self):
def lightning_module(self):
return unwrap_lightning_module(self._model)

@abstractmethod
def setup(self, model):
raise NotImplementedError

def connect(self, model, *args, **kwargs):
self.setup(model)
return self.model

@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
Expand Down
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def model_to_device(self) -> None:

self._model.to(self.root_device)

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.model_to_device()
return self.model

Expand Down
7 changes: 1 addition & 6 deletions pytorch_lightning/plugins/training_type/single_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@ def __init__(self, device: Union[torch.device, int]):
def on_tpu(self) -> bool:
return True

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
self.model_to_device()
return self._model

def model_to_device(self) -> None:
self._model.to(self.root_device)
self.model.to(self.root_device)

def pre_dispatch(self) -> None:
if isinstance(self.device, int):
Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ def __init__(
self.tpu_local_core_rank = 0
self.start_method = None

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.create_mp_queue()
self._model = model
return self._model
return self.model

def create_mp_queue(self):
self.start_method = 'fork'
Expand Down
14 changes: 12 additions & 2 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,19 @@ def __init__(self) -> None:
self._model = None
self._results = None

@abstractmethod
def connect(self, model: 'Module') -> None:
"""Called by the accelerator to connect it with this plugin"""
"""Called by the accelerator to connect the accelerator and the model with this plugin"""
self.model = model

def setup_environment(self) -> None:
"""
Setup any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook
which allows the user to access the accelerator environment before setup is complete.
"""

def setup(self, model: 'Module') -> None:
"""Called by the accelerator to finish setup."""

@property
@abstractmethod
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,10 @@ def fit(
# ----------------------------
# SET UP TRAINING
# ----------------------------
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator.connect(model)
self.accelerator.setup_environment()
self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

# ----------------------------
Expand Down
30 changes: 21 additions & 9 deletions tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
"SLURM_LOCALID": "10"
}
)
def test_accelerator_choice_ddp_slurm():
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_slurm(setup_distributed_mock):

class CB(Callback):

Expand Down Expand Up @@ -136,7 +137,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp2_slurm(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):

class CB(Callback):

Expand Down Expand Up @@ -165,7 +167,8 @@ def on_fit_start(self, trainer, pl_module):
@RunIf(min_gpus=1)
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

Expand Down Expand Up @@ -193,7 +196,8 @@ def on_fit_start(self, trainer, pl_module):
@RunIf(min_gpus=1)
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp2_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

Expand Down Expand Up @@ -224,7 +228,8 @@ def on_fit_start(self, trainer, pl_module):
"NODE_RANK": "0",
})
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

Expand Down Expand Up @@ -259,7 +264,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):

class CB(Callback):

Expand Down Expand Up @@ -294,7 +300,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock):
"""
Test that we choose the custom cluster even when SLURM or TE flags are around
"""
Expand All @@ -304,6 +311,9 @@ class CustomCluster(LightningEnvironment):
def master_address(self):
return 'asdf'

def creates_children(self) -> bool:
return True

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
Expand Down Expand Up @@ -336,7 +346,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_custom_accelerator(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_custom_accelerator(device_count_mock, setup_distributed_mock):

class Accel(Accelerator):
pass
Expand Down Expand Up @@ -371,7 +382,8 @@ class TrainTypePlugin(SingleDevicePlugin):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_dist_backend_accelerator_mapping(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock):

class CB(Callback):

Expand Down
Loading

0 comments on commit 4e9b453

Please sign in to comment.