Fix all gather while writing to a file during T5 finetuning (NVIDIA#5561) (NVIDIA#5564)

* Gather from data parallel only instead of all ranks

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

* Fix

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
Co-authored-by: Sandeep Subramanian <sandeep.subramanian.1@umontreal.ca>
2 people authored and titu1994 committed Mar 24, 2023
1 parent 956a554 commit 16bdfeb
Showing 1 changed file with 3 additions and 2 deletions.
@@ -427,7 +427,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
         )
 
         # Gather the outputs object from all data parallel ranks since we are using the DistributedSampler which splits data across DDP ranks.
-        gathered_outputs = [None for _ in range(self.world_size)]
+        gathered_outputs = [None for _ in range(parallel_state.get_data_parallel_world_size())]
         torch.distributed.all_gather_object(
             gathered_outputs,
             [
@@ -439,6 +439,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
                 }
                 for x in output
             ],
+            group=parallel_state.get_data_parallel_group(),
         )
 
         # Figure out what the suffix of the file should be.
@@ -455,7 +456,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
 
         # PTL models have a self.global_rank attribute and we want to write to disk only on global rank 0.
         if self.global_rank == 0:
-            for rank in range(0, self.world_size):
+            for rank in range(0, parallel_state.get_data_parallel_world_size()):
                 for batch in gathered_outputs[rank]:
                     for pred, label, input, category in zip(
                         batch['preds'], batch['labels'], batch['inputs'], batch['categories']
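For context, here is a minimal standalone sketch of the pattern this commit adopts: size the gather buffer by the data-parallel world size and pass the data-parallel process group to torch.distributed.all_gather_object, since the DistributedSampler only shards data across data-parallel ranks while the global world size also counts tensor/pipeline-parallel ranks. The parallel_state helpers are the ones the diff uses; the import path (apex.transformer) and the gather_predictions wrapper are assumptions for illustration, not code from this repository.

# Minimal sketch, assuming an apex/Megatron-style parallel_state module;
# gather_predictions is a hypothetical helper, not NeMo code.
import torch.distributed as dist
from apex.transformer import parallel_state  # assumed import path

def gather_predictions(local_outputs):
    """Collect each data-parallel rank's output list into one list of lists."""
    # Size the destination by the data-parallel world size, not the global world
    # size: with tensor/pipeline model parallelism the global world size is larger.
    gathered = [None for _ in range(parallel_state.get_data_parallel_world_size())]
    # Scope the collective to the data-parallel group so model-parallel replicas
    # of the same data shard are not gathered as duplicate predictions.
    dist.all_gather_object(
        gathered,
        local_outputs,
        group=parallel_state.get_data_parallel_group(),
    )
    return gathered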
