From 16bdfebeb7f9e0e8a3a7b640ca507213703ae45b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 7 Dec 2022 10:44:31 -0700 Subject: [PATCH] Fix all gather while writing to a file during T5 finetuning (#5561) (#5564) * Gather from data parallel only instead of all ranks Signed-off-by: MaximumEntropy * Fix Signed-off-by: MaximumEntropy Signed-off-by: MaximumEntropy Signed-off-by: MaximumEntropy Co-authored-by: Sandeep Subramanian --- .../nlp/models/language_modeling/megatron_finetune_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index c49d7b50580a..e445bf7f4482 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -427,7 +427,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): ) # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks. - gathered_outputs = [None for _ in range(self.world_size)] + gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())] torch.distributed.all_gather_object( gathered_outputs, [ @@ -439,6 +439,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): } for x in output ], + group=parallel_state.get_data_parallel_group(), ) # Figure out what the suffix of the file should be. @@ -455,7 +456,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): # PTL models have a self.global_rank attribute and we want to write to disk only on global rank 0. if self.global_rank == 0: - for rank in range(0, self.world_size): + for rank in range(0, parallel_state.get_data_parallel_world_size()): for batch in gathered_outputs[rank]: for pred, label, input, category in zip( batch['preds'], batch['labels'], batch['inputs'], batch['categories']