[Fix] Move init dist connection into the setup function #6506
Changes from all commits
@@ -80,16 +80,16 @@ def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
        return distributed_sampler_kwargs

    def setup(self, model):
        self._model = model

    def setup_environment(self):
        # start the other scripts
        if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
            self._call_children_scripts()

        # set the task idx
        self.task_idx = self.cluster_environment.local_rank()

        self.setup_distributed()

    def _call_children_scripts(self):

        # bookkeeping of spawned processes
@@ -161,6 +161,34 @@ def _call_children_scripts(self):
        delay = np.random.uniform(1, 5, 1)[0]
        sleep(delay)

    def setup_distributed(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # on world_size=0 let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)
Review comment on lines +183 to +186: btw, shall we have this as a single message instead of 4 separate ones?
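For illustration, here is a self-contained sketch of what the suggested single-message variant could look like (the joined formatting and the stand-in values below are assumptions, not code from this PR):

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger(__name__)

    # Stand-ins for self.distributed_backend and self.world_size in the plugin.
    distributed_backend = "ddp"
    world_size = 4

    # Collapse the four separate log.info calls into one multi-line message.
    log.info(
        "\n".join([
            "-" * 100,
            f"distributed_backend={distributed_backend}",
            f"All DDP processes registered. Starting ddp with {world_size} processes",
            "-" * 100,
        ])
    )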
        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

    def _check_can_spawn_children(self):
        if self._has_spawned_children:
            raise RuntimeError(
@@ -213,37 +241,6 @@ def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size)

    def pre_dispatch(self):
        # TODO: check if needed
        seed = os.environ.get("PL_GLOBAL_SEED")
        if seed is not None:
            seed_everything(int(seed))

        # determine which process we are and world size
        self.set_world_ranks()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        self.init_ddp_connection(self.global_rank, self.world_size)

        # TODO: we moved it to the trainer.fit after calling pre_dispatch
        # ... need to double check that it is the correct place
        # self.trainer.call_setup_hook(self.model)
Review comment on lines -231 to -235: yeah, my silly todo.... Thanks for double checking @SeanNaren 😄
        # on world_size=0 let everyone know training is starting
        if self.is_global_zero and not torch.distributed.is_initialized():
            log.info("-" * 100)
            log.info(f"distributed_backend={self.distributed_backend}")
            log.info(f"All DDP processes registered. Starting ddp with {self.world_size} processes")
            log.info("-" * 100)

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device

        if self.sync_batchnorm:
            self.model = self.configure_sync_batchnorm(self.model)
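Taken together, the diff moves the distributed bootstrapping out of pre_dispatch and into setup_distributed, driven by the new setup_environment hook. Below is a minimal, self-contained toy sketch of the resulting call order; the method names mirror the diff, but the bodies are stand-ins rather than the real PyTorch Lightning implementation:

    class ToyDDPPlugin:
        def setup(self, model):
            # setup now only attaches the model; no process-group work happens here
            self._model = model

        def setup_environment(self):
            # child-script spawning and rank bookkeeping would happen here,
            # then the distributed connection is initialised
            self.task_idx = 0  # stand-in for cluster_environment.local_rank()
            self.setup_distributed()

        def setup_distributed(self):
            # in the real plugin: seeding, set_world_ranks(), rank_zero_only.rank,
            # init_ddp_connection(), rank-zero logging, and dist.rank / dist.device
            self.global_rank, self.world_size = 0, 1
            self.init_ddp_connection(self.global_rank, self.world_size)

        def init_ddp_connection(self, global_rank, world_size):
            print(f"init_process_group(rank={global_rank}, world_size={world_size})")

    plugin = ToyDDPPlugin()
    plugin.setup(model=object())  # 1. attach the model
    plugin.setup_environment()    # 2. ranks + dist connection (previously in pre_dispatch)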
Review comment: If anyone extended this and made their own accelerator, this will be a breaking change, so we might need to handle a deprecation path here.

Review comment: Do you need to rename them?

Review comment: I don't, I guess? It's just a bit concerning because if I don't rename them, connect_training_type_plugin will be calling plugin.setup, and connect will be calling training_type_plugin.connect. Just confusing function names. I think if it becomes an issue we can make this backwards compatible; however, in most cases it seems users should be defining plugins, not accelerators.
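One hypothetical shape such a deprecation path could take (a sketch under the assumption that the old hook is kept as a thin wrapper forwarding to the new one; not something this PR implements):

    import warnings

    # Make the deprecation visible regardless of the interpreter's default filters.
    warnings.simplefilter("always", DeprecationWarning)

    class DDPPluginSketch:
        def setup_distributed(self):
            print("setting up the distributed connection")

        def pre_dispatch(self):
            # hypothetical shim: preserve the old entry point for custom
            # accelerators/plugins while steering users to the new hooks
            warnings.warn(
                "`pre_dispatch` no longer initialises the distributed connection; "
                "override `setup_distributed` / `setup_environment` instead.",
                DeprecationWarning,
            )
            self.setup_distributed()

    DDPPluginSketch().pre_dispatch()  # warns, then forwards to the new hook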