
Commit

Checking we use fused kernels to compute scaled masked softmax on prefix lm (#209)
thomasw21 authored Nov 26, 2021
1 parent b3cf175 commit b227590
Showing 3 changed files with 119 additions and 95 deletions.
7 changes: 7 additions & 0 deletions megatron/model/fused_softmax.py
@@ -16,7 +16,12 @@

import torch
import torch.nn as nn

from megatron import logging
from megatron.enums import AttnMaskType
from megatron.model.utils import log_debug_usage

logger = logging.get_logger(__name__)

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
@@ -153,6 +158,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
return True
return False

@log_debug_usage(logger, "Using fused softmax")
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
@@ -168,6 +174,7 @@ def forward_fused_softmax(self, input, mask):
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)

@log_debug_usage(logger, "Using torch softmax")
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
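Note on the new decorator: `log_debug_usage` is imported from `megatron.model.utils`, but its implementation is not part of this diff. Below is a minimal sketch of what such a decorator could look like (an editor's assumption, not code from the repository): it wraps the function and emits the given message through the module logger at debug level whenever the wrapped function runs, which is what lets the tests detect which softmax path was taken.

```python
import functools

def log_debug_usage(logger, msg: str):
    """Sketch of a decorator factory: log `msg` at debug level each time the
    wrapped function is called. The real helper in megatron.model.utils may
    differ, e.g. it could log only on first use to avoid flooding the output."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            logger.debug(msg)  # e.g. "Using fused softmax" / "Using torch softmax"
            return func(*args, **kwargs)
        return wrapper
    return decorator
```

Because the message is logged at debug level, it only shows up once the log level is lowered, which is why the test defaults below add `--log-level debug`.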
201 changes: 107 additions & 94 deletions tests/test_model.py
@@ -6,7 +6,7 @@
import torch

from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
from megatron.testing_utils import TestCasePlus, mockenv_context
from megatron.testing_utils import TestCasePlus, mockenv_context, CaptureStdout
from megatron.training import setup_model_and_optimizer
from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -49,6 +49,10 @@ def get_default_args():
"--checkpoint-activations": "",

# DATA_ARGS

# LOGGING_ARGS
"--log-level": "debug",
"--log-level-replica": "info",
}


@@ -98,7 +102,7 @@ def test_gpt(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
@@ -141,91 +145,96 @@ def test_prefix_lm_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

# eod is a special token, this also guarantees that the whole row is considered as a document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All tokens in the past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
with CaptureStdout() as cs:
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token, this also guarantees that the whole row is considered as a document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All tokens in the past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
)
)
)
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
)
)
)
# Rows we did not modify should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
# Rows we did not modify should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
)
)
)

## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the last prefix token and make sure that the first token changes
# guaranteed to be positive as we avoid pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the last prefix token and make sure that the first token changes
# guaranteed to be positive as we avoid pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
)
)
)
# Rows we did not modify should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
# Rows we did not modify should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
)
)
)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

def test_prefix_lm_wo_reset_attention_mask(self):
"""
@@ -241,24 +250,28 @@ def test_prefix_lm_wo_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]
with CaptureStdout() as cs:
initialize_megatron()
args = get_args()

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

model(*input_batch)
for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

model(*input_batch)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)
#TODO: Check all invariants

def test_gpt_rotary_embeddings(self):
@@ -278,7 +291,7 @@ def test_gpt_rotary_embeddings(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
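A word on the capture pattern used above: `CaptureStdout` comes from `megatron.testing_utils` (a helper in the style of `transformers.testing_utils`), and its implementation is not shown in this commit. The sketch below illustrates the assumed capture-and-assert idiom, with a `print` standing in for the debug message that the Megatron logger is expected to write to stdout:

```python
import io
import sys

class CaptureStdout:
    """Capture everything written to sys.stdout inside the `with` block.
    Sketch only: the real helper lives in megatron.testing_utils and may also
    capture stderr or replay the output after the block exits."""

    def __enter__(self):
        self._buf, self._old = io.StringIO(), sys.stdout
        sys.stdout = self._buf
        return self

    def __exit__(self, *exc):
        sys.stdout = self._old
        self.out = self._buf.getvalue()
        return False

# Usage mirroring the tests above.
with CaptureStdout() as cs:
    print("Using fused softmax")  # stand-in for logger.debug(...) in fused_softmax.py
assert "Using fused softmax" in cs.out
assert "Using torch softmax" not in cs.out
```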
6 changes: 5 additions & 1 deletion tests/test_training.py
@@ -317,7 +317,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
--num-layers 2
--hidden-size 64
--num-attention-heads 2
--num-attention-heads 4
--seq-length 128
--max-position-embeddings 1024
--micro-batch-size 1
@@ -392,6 +392,10 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 1, "tensorboard files")

# test that the fused scaled masked softmax kernel is used
self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

if reweight_loss_based_on_position_frequency:
self.assertIn("Using loss reweighting", cs.out)

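The assertions in these tests only confirm that the fused path was logged; whether it can be taken at all is decided by `FusedScaleMaskSoftmax.is_kernel_available`, of which only the tail is visible in the first hunk. As a rough sketch of that dispatch (the availability conditions named in the comments are assumptions about typical constraints, not code from this diff: fp16/bf16 inputs, fusion enabled, a supported mask type, and sequence lengths within the limits of the custom CUDA kernel):

```python
# Sketch of the fused-vs-torch dispatch inside FusedScaleMaskSoftmax.forward.
# The call shape matches the methods decorated above; the availability checks
# mentioned here are assumed, not taken from the repository.
def forward(self, input, mask):
    b, np, sq, sk = input.size()  # (batch, heads, query length, key length)
    if self.is_kernel_available(mask, b, np, sq, sk):
        return self.forward_fused_softmax(input, mask)   # logs "Using fused softmax"
    return self.forward_torch_softmax(input, mask)       # logs "Using torch softmax"
```

The bump of `--num-attention-heads` from 2 to 4 in tests/test_training.py is presumably there to satisfy one of those constraints (the fused kernel has divisibility requirements on the batch × heads and sequence dimensions), so the fused path is actually selected and the assertion on "Using fused softmax" can pass.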
