Commit: Fix OOM
Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>
hsiehjackson committed Aug 18, 2023
1 parent ef0e981 commit ae3f7d2
Showing 1 changed file with 4 additions and 4 deletions.
@@ -170,10 +170,10 @@ def forward(
 # convert to Megatron mask
 if self.use_flash_attention:
     enc_attn_mask_3d = enc_attn_mask < 0.5
-
-enc_attn_mask_3d = attn_mask_postprocess(build_attention_mask_3d(
-    source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type,
-))
+else:
+    enc_attn_mask_3d = attn_mask_postprocess(build_attention_mask_3d(
+        source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type,
+    ))

 # transformer encoder
 enc_output = self.model(
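The OOM comes from the size difference between the two mask forms: the flash-attention path only needs the compact per-token padding mask, while the non-flash path expands it into a full sequence-by-sequence mask. A minimal sketch of that difference, using a hypothetical `build_attention_mask_3d_sketch` stand-in (the real NeMo `build_attention_mask_3d` / `attn_mask_postprocess` helpers are not reproduced here):

```python
import numpy as np

def build_attention_mask_3d_sketch(source_mask, target_mask):
    # Hypothetical stand-in: expand two [batch, seq] padding masks
    # into a full [batch, seq, seq] pairwise attention mask.
    return target_mask[:, :, None] & source_mask[:, None, :]

batch, seq = 2, 2048
enc_attn_mask = np.ones((batch, seq), dtype=bool)  # all tokens valid

# Flash-attention path: keep the compact [batch, seq] boolean mask.
flash_mask = enc_attn_mask < 0.5

# Non-flash path: the 3D mask is `seq` times larger per batch element,
# which is why building it unconditionally (as before this commit,
# overwriting the flash mask) exhausted memory at long sequence lengths.
mask_3d = build_attention_mask_3d_sketch(enc_attn_mask, enc_attn_mask)

assert flash_mask.shape == (batch, seq)
assert mask_3d.shape == (batch, seq, seq)
assert mask_3d.nbytes == flash_mask.nbytes * seq
```

Guarding the 3D-mask construction behind the `else:` means the expensive expansion is skipped entirely whenever flash attention is enabled.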
