Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Checking we use fused kernels to compute scaled masked softmax on prefix lm #209

Merged
merged 3 commits into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions megatron/model/fused_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

import torch
import torch.nn as nn

from megatron import logging
from megatron.enums import AttnMaskType
from megatron.model.utils import log_debug_usage

logger = logging.get_logger(__name__)

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Expand Down Expand Up @@ -153,6 +158,7 @@ def is_kernel_available(self, mask, b, np, sq, sk):
return True
return False

@log_debug_usage(logger, "Using fused softmax")
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
Expand All @@ -168,6 +174,7 @@ def forward_fused_softmax(self, input, mask):
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)

@log_debug_usage(logger, "Using torch softmax")
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
Expand Down
201 changes: 107 additions & 94 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch

from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
from megatron.testing_utils import TestCasePlus, mockenv_context
from megatron.testing_utils import TestCasePlus, mockenv_context, CaptureStdout
from megatron.training import setup_model_and_optimizer
from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
Expand Down Expand Up @@ -49,6 +49,10 @@ def get_default_args():
"--checkpoint-activations": "",

# DATA_ARGS

# LOGGING_ARGS
"--log-level": "debug",
"--log-level-replica": "info",
}


Expand Down Expand Up @@ -98,7 +102,7 @@ def test_gpt(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
Expand Down Expand Up @@ -141,91 +145,96 @@ def test_prefix_lm_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

# eod is a special token, this also guarantees that the whole row is considered as a document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All token in past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
with CaptureStdout() as cs:
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token, this also guarantees that the whole row is considered as a document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All token in past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
)
)
)
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
)
)
)
# Unchanged changed rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
# Unchanged changed rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
)
)
)

## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the the last prefix token and make sure that the first token changed
# guaranteed to be positive as we avoid pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the the last prefix token and make sure that the first token changed
# guaranteed to be positive as we avoid pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
)
)
)
# Unchanged changed rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
# Unchanged changed rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
)
)
)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

def test_prefix_lm_wo_reset_attention_mask(self):
"""
Expand All @@ -241,24 +250,28 @@ def test_prefix_lm_wo_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]
with CaptureStdout() as cs:
initialize_megatron()
args = get_args()

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

model(*input_batch)
for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

model(*input_batch)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)
#TODO: Check all invariants

def test_gpt_rotary_embeddings(self):
Expand All @@ -278,7 +291,7 @@ def test_gpt_rotary_embeddings(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
Expand Down
6 changes: 5 additions & 1 deletion tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_

--num-layers 2
--hidden-size 64
--num-attention-heads 2
--num-attention-heads 4
--seq-length 128
--max-position-embeddings 1024
--micro-batch-size 1
Expand Down Expand Up @@ -392,6 +392,10 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 1, "tensorboard files")

# test use scaled softmax
self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

if reweight_loss_based_on_position_frequency:
self.assertIn("Using loss reweighting", cs.out)

Expand Down