
Commit c9afebc

Revert "Checking we use fused kernels to compute scaled masked softma…
Browse files Browse the repository at this point in the history
…x on prefix lm (#209)"

This reverts commit b227590.
thomasw21 committed Nov 27, 2021
1 parent b227590 commit c9afebc
Showing 3 changed files with 95 additions and 119 deletions.
7 changes: 0 additions & 7 deletions megatron/model/fused_softmax.py
@@ -16,12 +16,7 @@

import torch
import torch.nn as nn

from megatron import logging
from megatron.enums import AttnMaskType
from megatron.model.utils import log_debug_usage

logger = logging.get_logger(__name__)

class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
@@ -158,7 +153,6 @@ def is_kernel_available(self, mask, b, np, sq, sk):
return True
return False

@log_debug_usage(logger, "Using fused softmax")
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
@@ -174,7 +168,6 @@ def forward_fused_softmax(self, input, mask):
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)

@log_debug_usage(logger, "Using torch softmax")
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
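
For context, the lines removed from megatron/model/fused_softmax.py above were a thin logging layer around the two softmax code paths. The snippet below is a hypothetical reconstruction of that idea, not megatron's actual log_debug_usage implementation (which lives in megatron/model/utils.py and is not shown in this diff): a decorator that, under the assumption it logs on first use, emits a debug message the first time the wrapped function runs, so a test can later check which path was exercised.

import functools
import logging

logger = logging.getLogger(__name__)

def log_debug_usage(logger, msg):
    """Sketch only (assumed behaviour): log `msg` the first time the decorated function is called."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if not wrapper._logged:
                logger.debug(msg)
                wrapper._logged = True
            return func(*args, **kwargs)
        wrapper._logged = False
        return wrapper
    return decorator

@log_debug_usage(logger, "Using fused softmax")
def forward_fused_softmax(input, mask):
    # Placeholder body; the real method dispatches to the fused CUDA kernel.
    ...
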
201 changes: 94 additions & 107 deletions tests/test_model.py
@@ -6,7 +6,7 @@
import torch

from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
from megatron.testing_utils import TestCasePlus, mockenv_context, CaptureStdout
from megatron.testing_utils import TestCasePlus, mockenv_context
from megatron.training import setup_model_and_optimizer
from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -49,10 +49,6 @@ def get_default_args():
"--checkpoint-activations": "",

# DATA_ARGS

# LOGGING_ARGS
"--log-level": "debug",
"--log-level-replica": "info",
}


@@ -102,7 +98,7 @@ def test_gpt(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
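
A note on the seq_length + 1 → seq_length change above: batch pipes for autoregressive LMs commonly take a (batch, seq_length + 1) block of token ids and shift it by one position to build inputs and labels, which is presumably why the reverted commit sampled one extra token per row. A minimal sketch of that shifting pattern, illustrating the general idea rather than the actual get_gpt_batch_pipe:

import torch

batch_size, seq_length, vocab_size = 4, 128, 50257  # illustrative sizes

# Sample seq_length + 1 ids so that inputs and labels both end up with seq_length tokens.
token_ids = torch.randint(vocab_size, (batch_size, seq_length + 1))

inputs = token_ids[:, :-1].contiguous()  # tokens the model sees
labels = token_ids[:, 1:].contiguous()   # next-token targets, shifted by one

assert inputs.shape == labels.shape == (batch_size, seq_length)
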
@@ -145,96 +141,91 @@ def test_prefix_lm_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

# eod is a special token, this also guarantees that the whole row is considered as a document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

with CaptureStdout() as cs:
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token, this also guarantees that the whole row is considered as a document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All tokens in the past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
)
for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All tokens in the past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
)
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
)
)
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
)
# Unchanged rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
)
)
# Unchanged rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
)
)

## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the last prefix token and make sure that the first token changed
# guaranteed to be positive as we avoid pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
)
## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the last prefix token and make sure that the first token changed
# guaranteed to be positive as we avoid pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
)
# Unchanged rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
)
)
# Unchanged rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)
)

def test_prefix_lm_wo_reset_attention_mask(self):
"""
@@ -250,28 +241,24 @@ def test_prefix_lm_wo_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()

with CaptureStdout() as cs:
initialize_megatron()
args = get_args()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]
model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)
for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

model(*input_batch)
model(*input_batch)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)
#TODO: Check all invariants

def test_gpt_rotary_embeddings(self):
@@ -291,7 +278,7 @@ def test_gpt_rotary_embeddings(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
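
The assertions dropped from the tests above (and from tests/test_training.py below) verified the softmax path by capturing stdout and looking for the "Using fused softmax" / "Using torch softmax" messages. Here is a self-contained sketch of that capture-and-assert pattern, using contextlib.redirect_stdout as a stand-in for megatron's CaptureStdout helper, whose implementation is not shown in this diff:

import contextlib
import io
import unittest


def run_softmax(use_fused: bool) -> None:
    # Placeholder for the code under test; the real tests run a full model forward
    # and rely on the debug-usage log message to reveal which path was taken.
    print("Using fused softmax" if use_fused else "Using torch softmax")


class TestSoftmaxPath(unittest.TestCase):
    def test_fused_kernel_is_used(self):
        buf = io.StringIO()
        with contextlib.redirect_stdout(buf):
            run_softmax(use_fused=True)
        out = buf.getvalue()
        self.assertIn("Using fused softmax", out)
        self.assertNotIn("Using torch softmax", out)


if __name__ == "__main__":
    unittest.main()
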
6 changes: 1 addition & 5 deletions tests/test_training.py
@@ -317,7 +317,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
--num-layers 2
--hidden-size 64
--num-attention-heads 4
--num-attention-heads 2
--seq-length 128
--max-position-embeddings 1024
--micro-batch-size 1
@@ -392,10 +392,6 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 1, "tensorboard files")

# test use scaled softmax
self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

if reweight_loss_based_on_position_frequency:
self.assertIn("Using loss reweighting", cs.out)

