adding return_moe_loss to bert model (bigscience-workshop#204)
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
Alexander Jipa and azzhipa authored Aug 9, 2023
1 parent 712057c commit 319c9b1
Showing 1 changed file with 12 additions and 8 deletions.
megatron/model/bert_model.py (12 additions, 8 deletions)
@@ -125,7 +125,8 @@ def __init__(self,
                  add_binary_head=True,
                  parallel_output=True,
                  pre_process=True,
-                 post_process=True):
+                 post_process=True,
+                 return_moe_loss=False):
         super().__init__(config=config)
         args = get_args()
@@ -137,6 +138,7 @@ def __init__(self,
         self.parallel_output = parallel_output
         self.pre_process = pre_process
         self.post_process = post_process
+        self.return_moe_loss = return_moe_loss
 
         self.return_embeddings = args.output_bert_embeddings
         if self.return_embeddings:
@@ -183,7 +185,7 @@ def forward(self, bert_model_input, attention_mask,
         )
 
         if self.post_process and self.add_binary_head:
-            lm_output, pooled_output, _ = lm_output
+            lm_output, pooled_output, moe_losses = lm_output
 
         # Return pooled output (e.g., when computing Bert embeddings).
         if self.return_embeddings:
@@ -206,12 +208,14 @@ def forward(self, bert_model_input, attention_mask,
             pooled_output = None
 
         if self.post_process:
-            lm_output = lm_output if self.add_binary_head else lm_output[0]
-            return post_language_model_processing(lm_output, pooled_output,
-                                                  self.lm_head, self.binary_head,
-                                                  lm_labels,
-                                                  self.shared_embedding_or_output_weight(),
-                                                  self.fp16_lm_cross_entropy)
+            if not self.add_binary_head:
+                lm_output, moe_losses = lm_output
+            lm_output = post_language_model_processing(lm_output, pooled_output,
+                                                       self.lm_head, self.binary_head,
+                                                       lm_labels,
+                                                       self.shared_embedding_or_output_weight(),
+                                                       self.fp16_lm_cross_entropy)
+            return (*lm_output, moe_losses) if self.return_moe_loss else lm_output
         else:
             return lm_output
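With return_moe_loss=True, forward() now returns the list of per-layer MoE auxiliary losses alongside its usual outputs. A minimal caller-side sketch under that assumption; the batch tensors and the 0.01 loss weight are illustrative placeholders, not part of this commit:

    # Hypothetical usage, assuming return_moe_loss=True and lm_labels given:
    # forward() then yields (lm_loss, binary_logits, moe_losses) instead of
    # the previous (lm_loss, binary_logits).
    model = BertModel(config,
                      num_tokentypes=2,
                      add_binary_head=True,
                      return_moe_loss=True)

    lm_loss_, sop_logits, moe_losses = model(tokens, padding_mask,
                                             tokentype_ids=tokentype_ids,
                                             lm_labels=lm_labels)

    # Fold the expert load-balancing terms into the training objective;
    # the 0.01 coefficient is an illustrative choice, not taken from the diff.
    loss = lm_loss_.float().mean() + 0.01 * sum(moe_losses)

The grouping in the new return statement matters: Python's conditional expression binds tighter than a bare tuple, so (*lm_output, moe_losses) must be parenthesized explicitly for the flag to gate the whole tuple rather than only its last element.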
