From b046495edc863174390b1afa20523e084fab4b78 Mon Sep 17 00:00:00 2001
From: MaximumEntropy
Date: Tue, 6 Dec 2022 22:17:57 -0800
Subject: [PATCH 1/2] Gather from data parallel only instead of all ranks

Signed-off-by: MaximumEntropy
---
 .../nlp/models/language_modeling/megatron_finetune_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
index 941048304f6a..bbffe8753dd2 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
@@ -427,7 +427,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
             )

         # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks.
-        gathered_outputs = [None for _ in range(self.world_size)]
+        gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
             gathered_outputs,
             [
                 {
                 }
                 for x in output
             ],
+            group=parallel_state.get_data_parallel_group(),
         )

         # Figure out what the suffix of the file should be.

From 652bdb0681f73950b1951dac73429cd70d4aba34 Mon Sep 17 00:00:00 2001
From: MaximumEntropy
Date: Wed, 7 Dec 2022 00:40:21 -0800
Subject: [PATCH 2/2] Fix

Signed-off-by: MaximumEntropy
---
 .../nlp/models/language_modeling/megatron_finetune_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
index bbffe8753dd2..46d6455327af 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py
@@ -456,7 +456,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):

         # PTL models have a self.global_rank attribute and we want to write to disk only on global rank 0.
         if self.global_rank == 0:
-            for rank in range(0, self.world_size):
+            for rank in range(0, parallel_state.get_data_parallel_world_size()):
                 for batch in gathered_outputs[rank]:
                     for pred, label, input, category in zip(
                         batch['preds'], batch['labels'], batch['inputs'], batch['categories']
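For context, the two patches change the same thing: the gather buffer is sized by the data-parallel world size instead of the global world size, and the object gather is restricted to the data-parallel process group. Below is a minimal, self-contained sketch of that pattern, not code from NeMo itself. The function name and the `outputs` argument are illustrative, and the `megatron.core` import path is an assumption (some NeMo versions obtain `parallel_state` from `apex.transformer` instead).

```python
# Sketch only: gather per-rank prediction dicts across the data-parallel
# group rather than across every rank. Assumes model-parallel state has
# already been initialized (e.g. parallel_state.initialize_model_parallel).
import torch
from megatron.core import parallel_state  # assumption; may be apex.transformer in older NeMo


def gather_data_parallel_outputs(outputs):
    # One slot per data-parallel rank, not per global rank: ranks that differ
    # only in tensor/pipeline parallelism hold duplicate copies of the data.
    gathered = [None for _ in range(parallel_state.get_data_parallel_world_size())]

    # Passing group= limits the collective to the data-parallel group, so each
    # shard of the evaluation data is collected exactly once.
    torch.distributed.all_gather_object(
        gathered,
        outputs,
        group=parallel_state.get_data_parallel_group(),
    )

    # Typically a single rank (e.g. global rank 0) flattens and writes results.
    if torch.distributed.get_rank() == 0:
        return [batch for rank_outputs in gathered for batch in rank_outputs]
    return None
```

Without the `group` argument, `torch.distributed.all_gather_object` runs over the default (world) group, so with tensor or pipeline parallelism the same data-parallel shard would be gathered multiple times and the buffer would need `world_size` slots; scoping the call to the data-parallel group matches the loop in the second patch, which iterates over `get_data_parallel_world_size()` entries when writing predictions on rank 0.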