
Commit

fix the post ln (#4350)
Signed-off-by: Yi Dong <yidong@nvidia.com>

Co-authored-by: Eric Harper <complex451@gmail.com>
yidong72 and ericharper authored Jun 8, 2022
1 parent 4fa2156 · commit 5f6f5a1
Showing 1 changed file with 4 additions and 1 deletion.
nemo/collections/nlp/modules/common/megatron/transformer.py (5 changes: 4 additions & 1 deletion)
@@ -1262,6 +1262,7 @@ def forward(
# Post-LN normalization after residual
if self.transformer_block_type == 'post_ln':
    normalization_output = self.input_layernorm(layernorm_input)
    layernorm_input = normalization_output
elif self.transformer_block_type in ['pre_ln', 'normformer']:
    # Layer norm post the self attention.
    normalization_output = self.post_attention_layernorm(layernorm_input)
@@ -1319,7 +1320,9 @@ def forward(

layernorm_input = bias_dropout_add_func(attention_output, attention_bias, residual, self.hidden_dropout)
normalization_output = self.post_inter_attention_layernorm(layernorm_input)

# Post-LN normalization after residual
if self.transformer_block_type == 'post_ln':
    layernorm_input = normalization_output
# MLP.
mlp_output, mlp_bias = self.mlp(normalization_output)
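
For context, here is a minimal, hypothetical sketch (standalone PyTorch, not the NeMo implementation) of the pre-LN vs. post-LN wiring this commit adjusts. In a post-LN block the LayerNorm runs after the residual add, so its output must become the stream fed to the next sub-layer; the added layernorm_input = normalization_output assignments make the NeMo layer do that when transformer_block_type == 'post_ln'. ToyBlock and its members are illustrative names only.

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    """Hypothetical toy block contrasting pre-LN and post-LN residual wiring."""

    def __init__(self, hidden_size: int, block_type: str = 'post_ln'):
        super().__init__()
        self.block_type = block_type
        self.norm = nn.LayerNorm(hidden_size)
        self.sublayer = nn.Linear(hidden_size, hidden_size)  # stand-in for attention or MLP

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.block_type == 'pre_ln':
            # Pre-LN: normalize first; the residual stream stays un-normalized.
            return x + self.sublayer(self.norm(x))
        # Post-LN: residual add first, then normalize; the normalized tensor is
        # what the next sub-layer must receive as its input.
        return self.norm(x + self.sublayer(x))

x = torch.randn(2, 4, 8)
print(ToyBlock(8, 'post_ln')(x).shape)  # torch.Size([2, 4, 8])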

