
[Fix] Move init dist connection into the setup function #6506

Merged
merged 33 commits into from
Mar 18, 2021

Changes from all commits
33 commits
6bf721e
Move connection setup into the setup function. Call setup hook after …
Mar 13, 2021
1576176
Added CHANGELOG.md
Mar 13, 2021
7148ee6
fix setup order in callback test
awaelchli Mar 13, 2021
4fd0c02
fix input arguments in test
awaelchli Mar 13, 2021
cbfa681
Mock distributed function, remove protection to turn into training ty…
Mar 13, 2021
2a1dfbf
Remove import
Mar 13, 2021
e9c3f83
Add missing mock, ensure custom plugin does not create children process
Mar 14, 2021
2141a1f
Merge branch 'master' into fix/setup_ddp_hook
Mar 14, 2021
96ca54f
Merge branch 'master' into fix/setup_ddp_hook
SeanNaren Mar 15, 2021
ffe1c3f
Skip test on windows
Mar 15, 2021
1709cdb
Update deepspeed to init connection in setup
Mar 15, 2021
708f97f
Do not initialize distributed module
Mar 15, 2021
ec33b96
Move DeepSpeed tests to special tests since dist communication is bei…
Mar 16, 2021
d782554
Merge branch 'master' into fix/setup_ddp_hook
Mar 16, 2021
0c03487
Special the test to see if this fixes CI
Mar 16, 2021
edde60b
Delete accelerator connector test to see if its causing build to fail
Mar 16, 2021
9d31742
Delete deepspeed test
Mar 16, 2021
9db893a
Revert "Delete accelerator connector test to see if its causing build…
Mar 16, 2021
56ef252
Revert "Delete deepspeed test"
Mar 16, 2021
cad0671
Reverse hook
Mar 16, 2021
6b7d835
Reverse setup hooks to debug again
Mar 16, 2021
4651e57
Add todo so i know where i left off
Mar 17, 2021
d7ec33e
For single device move in pre_dispatch after setup function
Mar 17, 2021
72097ba
Merge branch 'master' into fix/setup_ddp_hook
Mar 17, 2021
bd2a53a
Add additional model to device hook if any additional parameters have…
Mar 17, 2021
b5450de
See if we can enable deepspeed tests
Mar 17, 2021
136ddc5
Revert "See if we can enable deepspeed tests"
Mar 17, 2021
0210f17
See if this hook approach works
Mar 18, 2021
1bae940
Introduce new granular hooks
Mar 18, 2021
69d6c32
Remove import, fix tpu spawn by moving the function to setup
Mar 18, 2021
91fff3a
Added missing special test
Mar 18, 2021
88e2e09
Merge branch 'master' into fix/setup_ddp_hook
Mar 18, 2021
3eced98
Clean up the setup comment, since its run on train and test
Mar 18, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))


- Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))


## [1.2.4] - 2021-03-16

### Changed
34 changes: 21 additions & 13 deletions pytorch_lightning/accelerators/accelerator.py
@@ -65,17 +65,28 @@ def __init__(
self.lr_schedulers: Sequence = []
self.optimizer_frequencies: Sequence = []

def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
def connect(self, model: LightningModule) -> None:
"""Transfers ownership of the model to this plugin"""
self.training_type_plugin.connect(model)

def setup_environment(self) -> None:
"""
Connects the plugins to the training process, creates optimizers
Setup any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook
which allows the user to access the accelerator environment before setup is complete.
"""
self.training_type_plugin.setup_environment()

def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
"""
Setup plugins for the trainer fit and creates optimizers.
Args:
trainer: the trainer instance to connect to
model: the model to train
trainer: the trainer instance
model: the LightningModule
"""
self.connect_training_type_plugin(self.training_type_plugin, model)
self.setup_training_type_plugin(self.training_type_plugin, model)
self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)
self.setup_precision_plugin(self.precision_plugin)
Comment on lines +87 to +89

Contributor Author:
If anyone extended and made their own accelerator, this will be a breaking change, so we might need to handle a deprecation path here.

Contributor:
Do you need to rename them?

Contributor Author:
I don't, I guess? It's just a bit concerning because if I don't rename them, connect_training_type_plugin will be calling plugin.setup, and connect will be calling training_type_plugin.connect. Just confusing function names. I think if it becomes an issue we can make this BW compatible; however, in most cases it seems users should be defining plugins, not accelerators.
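A deprecation path, if one were added, could be as small as keeping the old names as thin wrappers around the new ones. A minimal sketch (not part of this PR; `AcceleratorCompat` and the warning wording are hypothetical):

```python
import warnings


class AcceleratorCompat:
    """Sketch of keeping the old hook names as deprecation wrappers."""

    def setup_training_type_plugin(self, plugin, model) -> None:
        # new name: attaches the training type plugin to the accelerator
        plugin.setup(model)

    def connect_training_type_plugin(self, plugin, model) -> None:
        # old name kept so existing subclasses and callers keep working
        warnings.warn(
            "`connect_training_type_plugin` was renamed to `setup_training_type_plugin`"
            " and will be removed in a future release.",
            DeprecationWarning,
        )
        self.setup_training_type_plugin(plugin, model)
```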


def start_training(self, trainer: 'Trainer') -> None:
self.training_type_plugin.start_training(trainer)
@@ -332,14 +343,11 @@ def setup_optimizers(self, trainer: 'Trainer') -> None:
self.lr_schedulers = lr_schedulers
self.optimizer_frequencies = optimizer_frequencies

def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator.
Also transfers ownership of the model to this plugin

"""
plugin.connect(model)
def setup_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None:
"""Attaches the training type plugin to the accelerator."""
plugin.setup(model)

def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None:
def setup_precision_plugin(self, plugin: PrecisionPlugin) -> None:
"""Attaches the precision plugin to the accelerator"""
model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
65 changes: 31 additions & 34 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -80,16 +80,16 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

def setup_environment(self):
# start the other scripts
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
self._call_children_scripts()

# set the task idx
self.task_idx = self.cluster_environment.local_rank()

self.setup_distributed()

def _call_children_scripts(self):

# bookkeeping of spawned processes
@@ -161,6 +161,34 @@ def _call_children_scripts(self):
delay = np.random.uniform(1, 5, 1)[0]
sleep(delay)

def setup_distributed(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)
Comment on lines +183 to +186

Member:
btw, shall we have this as a single message instead of 4 separate ones?
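For reference, a minimal sketch of what a single consolidated message could look like (not part of this PR; `log_ddp_banner` is a hypothetical helper, and `distributed_backend`/`world_size` stand in for the plugin attributes used above):

```python
import logging

log = logging.getLogger(__name__)


def log_ddp_banner(distributed_backend: str, world_size: int) -> None:
    # one log record instead of four separate ones
    message = "\n".join([
        "-" * 100,
        f"distributed_backend={distributed_backend}",
        f"All DDP processes registered. Starting ddp with {world_size} processes",
        "-" * 100,
    ])
    log.info(message)
```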


# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

def _check_can_spawn_children(self):
if self._has_spawned_children:
raise RuntimeError(
@@ -213,37 +241,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

def pre_dispatch(self):
# TODO: check if needed
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is not None:
seed_everything(int(seed))

# determine which process we are and world size
self.set_world_ranks()

# set warning rank
rank_zero_only.rank = self.global_rank

# set up server using proc 0's ip address
# try to init for 20 times at max in case ports are taken
# where to store ip_table
self.init_ddp_connection(self.global_rank, self.world_size)

# TODO: we moved it to the trainer.fit after calling pre_dispatch
# ... need to double check that it is the correct place
# self.trainer.call_setup_hook(self.model)

Comment on lines -231 to -235

Contributor:
yeah my silly todo....
"need to double check that it is the correct place"

Thanks for double checking @SeanNaren 😄

# on world_size=0 let everyone know training is starting
if self.is_global_zero and not torch.distributed.is_initialized():
log.info("-" * 100)
log.info(f"distributed_backend={self.distributed_backend}")
log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
log.info("-" * 100)

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device

if self.sync_batchnorm:
self.model = self.configure_sync_batchnorm(self.model)

2 changes: 0 additions & 2 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -77,8 +77,6 @@ def distributed_sampler_kwargs(self):
return distributed_sampler_kwargs

def setup(self, model):
self._model = model

os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

# pass in a state q
10 changes: 0 additions & 10 deletions pytorch_lightning/plugins/training_type/deepspeed.py
@@ -192,17 +192,7 @@ def _load_config(self, config):
return config

def pre_dispatch(self):
self.set_world_ranks()
self.init_ddp_connection(self.global_rank, self.world_size)

self.init_deepspeed()

# set warning rank
rank_zero_only.rank = self.global_rank

# set the ranks and devices
self.dist.rank = self.global_rank
self.dist.device = self.root_device
self.barrier()

def init_deepspeed(self):
8 changes: 0 additions & 8 deletions pytorch_lightning/plugins/training_type/parallel.py
@@ -53,14 +53,6 @@ def on_gpu(self):
def lightning_module(self):
return unwrap_lightning_module(self._model)

@abstractmethod
def setup(self, model):
raise NotImplementedError

def connect(self, model, *args, **kwargs):
self.setup(model)
return self.model

@property
def is_global_zero(self) -> bool:
return self.global_rank == 0
3 changes: 1 addition & 2 deletions pytorch_lightning/plugins/training_type/single_device.py
@@ -64,8 +64,7 @@ def model_to_device(self) -> None:

self._model.to(self.root_device)

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.model_to_device()
return self.model

7 changes: 1 addition & 6 deletions pytorch_lightning/plugins/training_type/single_tpu.py
@@ -39,13 +39,8 @@ def __init__(self, device: Union[torch.device, int]):
def on_tpu(self) -> bool:
return True

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
self._model = model
self.model_to_device()
return self._model

def model_to_device(self) -> None:
self._model.to(self.root_device)
self.model.to(self.root_device)

def pre_dispatch(self) -> None:
if isinstance(self.device, int):
5 changes: 2 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -53,10 +53,9 @@ def __init__(
self.tpu_local_core_rank = 0
self.start_method = None

def connect(self, model: torch.nn.Module) -> torch.nn.Module:
def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.create_mp_queue()
self._model = model
return self._model
return self.model

def create_mp_queue(self):
self.start_method = 'fork'
14 changes: 12 additions & 2 deletions pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -34,9 +34,19 @@ def __init__(self) -> None:
self._model = None
self._results = None

@abstractmethod
def connect(self, model: 'Module') -> None:
"""Called by the accelerator to connect it with this plugin"""
"""Called by the accelerator to connect the accelerator and the model with this plugin"""
self.model = model

def setup_environment(self) -> None:
"""
Setup any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook
which allows the user to access the accelerator environment before setup is complete.
"""

def setup(self, model: 'Module') -> None:
"""Called by the accelerator to finish setup."""

@property
@abstractmethod
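To make the new split concrete, here is a minimal sketch of how a custom training type plugin might divide its work across these hooks (a simplified stand-in class, not the real `TrainingTypePlugin`, which also declares abstract properties a concrete plugin must implement):

```python
import torch.nn as nn


class MyTrainingTypePlugin:
    """Simplified stand-in showing where each piece of work now belongs."""

    def __init__(self) -> None:
        self._model = None

    def connect(self, model: nn.Module) -> None:
        # called first by the accelerator: only takes ownership of the model
        self._model = model

    def setup_environment(self) -> None:
        # called before the LightningModule/DataModule `setup` hook:
        # spawn worker processes or initialize the distributed connection here
        pass

    def setup(self, model: nn.Module) -> None:
        # called after the user's `setup` hook: finish model-dependent setup
        self._model = model
```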
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -428,8 +428,10 @@ def fit(
# ----------------------------
# SET UP TRAINING
# ----------------------------
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator.connect(model)
self.accelerator.setup_environment()
self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment
self.accelerator.setup(self, model) # note: this sets up self.lightning_module

# ----------------------------
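This reordering is what the CHANGELOG entry refers to: the user's `setup` hook now runs after `setup_environment` has initialized the process group. A minimal sketch of a module relying on that (hypothetical `MyModel`; assumes a DDP run where `torch.distributed` is available):

```python
import torch.distributed as dist
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def setup(self, stage=None):
        # under DDP the process group is already initialized by the time setup runs
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            print(f"setup called on rank {rank} of {world_size}")
```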
30 changes: 21 additions & 9 deletions tests/accelerators/test_accelerator_connector.py
@@ -98,7 +98,8 @@ def test_accelerator_choice_ddp_spawn(cuda_available_mock, device_count_mock):
"SLURM_LOCALID": "10"
}
)
def test_accelerator_choice_ddp_slurm():
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_slurm(setup_distributed_mock):

class CB(Callback):

@@ -136,7 +137,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp2_slurm(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp2_slurm(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -165,7 +167,8 @@ def on_fit_start(self, trainer, pl_module):
@RunIf(min_gpus=1)
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -193,7 +196,8 @@ def on_fit_start(self, trainer, pl_module):
@RunIf(min_gpus=1)
@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2", "LOCAL_RANK": "10", "NODE_RANK": "0"})
@mock.patch('torch.cuda.device_count', return_value=2)
def test_accelerator_choice_ddp2_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp2_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -224,7 +228,8 @@ def on_fit_start(self, trainer, pl_module):
"NODE_RANK": "0",
})
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_te(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_te(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -259,7 +264,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_slurm(device_count_mock, setup_distributed_mock):

class CB(Callback):

@@ -294,7 +300,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_accelerator_choice_ddp_cpu_custom_cluster(device_count_mock, setup_distributed_mock):
"""
Test that we choose the custom cluster even when SLURM or TE flags are around
"""
@@ -304,6 +311,9 @@ class CustomCluster(LightningEnvironment):
def master_address(self):
return 'asdf'

def creates_children(self) -> bool:
return True

class CB(Callback):

def on_fit_start(self, trainer, pl_module):
@@ -336,7 +346,8 @@ def on_fit_start(self, trainer, pl_module):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_custom_accelerator(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_custom_accelerator(device_count_mock, setup_distributed_mock):

class Accel(Accelerator):
pass
@@ -371,7 +382,8 @@ class TrainTypePlugin(SingleDevicePlugin):
}
)
@mock.patch('torch.cuda.device_count', return_value=0)
def test_dist_backend_accelerator_mapping(device_count_mock):
@mock.patch('pytorch_lightning.plugins.DDPPlugin.setup_distributed', autospec=True)
def test_dist_backend_accelerator_mapping(device_count_mock, setup_distributed_mock):

class CB(Callback):
