Remove broadcast (#5558)
Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

MaximumEntropy authored and web-flow committed Dec 7, 2022
1 parent 52aac8e commit 8c19e33
Showing 1 changed file with 0 additions and 9 deletions.
@@ -464,15 +464,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A
     else:
         encoder_input = torch.zeros((batch_size, seq_length, self.hidden_size), dtype=self.autocast_dtype).cuda()

-    if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
-        # Broadcasting encoder inputs to all ranks for now, but this is inefficient.
-        # TODO: Make Enc-Dec improvement to only broadcast encoder_ids/embeddings when needed
-        torch.distributed.broadcast(
-            encoder_input,
-            parallel_state.get_pipeline_model_parallel_first_rank(),
-            group=parallel_state.get_pipeline_model_parallel_group(),
-        )
-
     predicted_token_ids, log_probs = self.frozen_model.decode(
         tokens_enc=input_ids,
         enc_mask=enc_mask,
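The deleted block used `torch.distributed.broadcast` to ship the encoder input from the first pipeline stage to every rank in the pipeline-parallel group, which this commit removes as unnecessary. A minimal, single-process sketch of that collective (using the CPU `gloo` backend and rank 0 as the source; the NeMo/Megatron `parallel_state` helpers for resolving the source rank and process group are not reproduced here):

```python
import os
import torch
import torch.distributed as dist

# Stand up a one-process group so broadcast() can run locally.
# In the real model, the group and source rank come from
# parallel_state.get_pipeline_model_parallel_group() / _first_rank().
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# The tensor must be pre-allocated with the right shape on every rank;
# broadcast overwrites it in place on all ranks except the source.
encoder_input = torch.zeros(4)
if dist.get_rank() == 0:
    encoder_input = torch.arange(4, dtype=torch.float32)

dist.broadcast(encoder_input, src=0)
print(encoder_input.tolist())

dist.destroy_process_group()
```

Note the in-place semantics: every participating rank must call `broadcast` with a same-shaped tensor, which is part of why broadcasting the full `(batch, seq, hidden)` encoder input on every step was flagged as inefficient in the removed TODO.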
