Commit

WIP
thomasw21 committed Nov 26, 2021
1 parent b3cf175 commit 46d5c33
Showing 2 changed files with 27 additions and 2 deletions.
15 changes: 15 additions & 0 deletions megatron/model/fused_softmax.py
@@ -16,7 +16,12 @@

import torch
import torch.nn as nn

from megatron import logging
from megatron.enums import AttnMaskType
from megatron.model import utils

logger = logging.get_logger(__name__)

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Expand Down Expand Up @@ -134,6 +139,13 @@ def forward(self, input, mask):
    def is_kernel_available(self, mask, b, np, sq, sk):
        attn_batches = b * np

        print("is_kernel_available")
        print(f"self.scaled_masked_softmax_fusion = {self.scaled_masked_softmax_fusion}")
        print(f"self.input_in_float16 = {self.input_in_float16}")
        print(f"mask is not None = {mask is not None}")
        print(f"sq = {sq}")
        print(f"sk = {sk}")
        print(f"attn_batches = {attn_batches}")
        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
@@ -149,10 +161,12 @@ def is_kernel_available(self, mask, b, np, sq, sk):
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    print(f"sq = {sq}, batch_per_block = {batch_per_block}")
                    if sq % batch_per_block == 0:
                        return True
        return False

    @utils.log_debug_usage(logger, "Using fused softmax")
    def forward_fused_softmax(self, input, mask):
        b, np, sq, sk = input.size()
        scale = self.scale if self.scale is not None else 1.0
@@ -168,6 +182,7 @@ def forward_fused_softmax(self, input, mask):
            # input is 4D tensor (b, np, sq, sk)
            return ScaledMaskedSoftmax.apply(input, mask, scale)

    @utils.log_debug_usage(logger, "Using torch softmax")
    def forward_torch_softmax(self, input, mask):
        if self.input_in_float16 and self.softmax_in_fp32:
            input = input.float()
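The two forward paths above are wrapped with utils.log_debug_usage, whose implementation lives in megatron/model/utils.py and is not part of this diff. A minimal sketch of what such a log-once decorator could look like, assuming only the call signature visible above (an illustration, not the repository's actual code):

import functools
import logging


def log_debug_usage(logger: logging.Logger, msg: str):
    """Return a decorator that logs `msg` at debug level the first time the wrapped function is called."""

    def decorator(func):
        func.__logged_message__ = False  # hypothetical flag used only by this sketch

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if not func.__logged_message__:
                logger.debug(msg)
                func.__logged_message__ = True
            return func(*args, **kwargs)

        return wrapped

    return decorator

Under this reading, forward_fused_softmax and forward_torch_softmax would each announce their code path once, which is what the test below looks for in the captured output.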
14 changes: 12 additions & 2 deletions tests/test_model.py
@@ -6,7 +6,7 @@
import torch

from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
from megatron.testing_utils import TestCasePlus, mockenv_context
from megatron.testing_utils import TestCasePlus, mockenv_context, CaptureStdout
from megatron.training import setup_model_and_optimizer
from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -49,6 +49,10 @@ def get_default_args():
"--checkpoint-activations": "",

# DATA_ARGS

# LOGGING_ARGS
"--log-level": "debug",
"--log-level-replica": "info",
}


@@ -257,10 +261,16 @@ def test_prefix_lm_wo_reset_attention_mask(self):
                    # Make sure that the last prefix token predicts the first token.
                    self.assertTrue(loss_mask[batch_id, id - 1] == 1)

                model(*input_batch)
                with CaptureStdout() as cs:
                    model(*input_batch)

                self.assertIn("Using fused softmax", cs.out)
                self.assertIn("Using torch softmax", cs.out)

                # TODO: Check all invariants


    def test_gpt_rotary_embeddings(self):
        """Test rotary embeddings"""
        command_args = get_default_args()
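The new assertions use CaptureStdout from megatron.testing_utils. A minimal sketch of a stdout-capturing context manager with the same shape, written as an assumption for illustration rather than the repository's implementation:

import io
import sys


class CaptureStdoutSketch:
    """Capture everything written to sys.stdout inside a `with` block into `self.out`."""

    def __enter__(self):
        self._buffer = io.StringIO()
        self._saved_stdout = sys.stdout
        sys.stdout = self._buffer
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        sys.stdout = self._saved_stdout
        self.out = self._buffer.getvalue()
        return False  # never suppress exceptions raised inside the block


with CaptureStdoutSketch() as cs:
    print("Using torch softmax")
assert "Using torch softmax" in cs.out

Since the new messages are presumably emitted at debug level, the test also raises --log-level to debug so they are not filtered out before reaching the captured output.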
