Turns out there's no issue with the way we build prefix lm
thomasw21 committed Nov 26, 2021
1 parent 46d5c33 commit e7a12e7
Showing 3 changed files with 111 additions and 100 deletions.
megatron/model/fused_softmax.py (10 changes: 6 additions & 4 deletions)
@@ -19,7 +19,7 @@

from megatron import logging
from megatron.enums import AttnMaskType
from megatron.model import utils
from megatron.model.utils import log_debug_usage

logger = logging.get_logger(__name__)

@@ -154,19 +154,21 @@ def is_kernel_available(self, mask, b, np, sq, sk):
and sk % 4 == 0 # sk must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
):
print("wtf")
if 0 <= sk <= 2048:
batch_per_block = self.get_batch_per_block(sq, sk, b, np)

if self.attn_mask_type == AttnMaskType.causal:
print(f"causal {attn_batches} / {batch_per_block}")
if attn_batches % batch_per_block == 0:
return True
else:
print(sq, batch_per_block)
print(f"non causal {sq} / {batch_per_block}")
if sq % batch_per_block == 0:
return True
return False

@utils.log_debug_usage(logger, "Using fused softmax")
@log_debug_usage(logger, "Using fused softmax")
def forward_fused_softmax(self, input, mask):
b, np, sq, sk = input.size()
scale = self.scale if self.scale is not None else 1.0
@@ -182,7 +184,7 @@ def forward_fused_softmax(self, input, mask):
# input is 4D tensor (b, np, sq, sk)
return ScaledMaskedSoftmax.apply(input, mask, scale)

@utils.log_debug_usage(logger, "Using torch softmax")
@log_debug_usage(logger, "Using torch softmax")
def forward_torch_softmax(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()
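For context, the change above drops the `utils.` indirection and imports `log_debug_usage` directly from megatron/model/utils.py. The decorator itself is not shown in this diff; a minimal sketch of how such a usage-logging decorator might be written, assuming it only needs to emit its message the first time the wrapped function runs (which is all the assertions in the tests below rely on), is:

import functools
import logging

def log_debug_usage(logger: logging.Logger, msg: str):
    """Decorate a function so that `msg` is logged the first time it is called."""
    def decorator(func):
        func.__logged_message__ = False  # per-function flag

        @functools.wraps(func)
        def wrapped(*args, **kwargs):
            if not func.__logged_message__:
                logger.debug(msg)  # e.g. "Using fused softmax" or "Using torch softmax"
                func.__logged_message__ = True
            return func(*args, **kwargs)

        return wrapped

    return decorator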
tests/test_model.py (195 changes: 100 additions & 95 deletions)
@@ -102,7 +102,7 @@ def test_gpt(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
@@ -145,91 +145,96 @@ def test_prefix_lm_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))

# eod is a special token; this also guarantees that the whole row is considered a single document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All tokens in the past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
with CaptureStdout() as cs:
initialize_megatron()
args = get_args()
tokenizer = get_tokenizer()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token; this also guarantees that the whole row is considered a single document.
token_ids[token_ids == tokenizer.eod] += 1
token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size

# process batch to have non empty prefix
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
for id in prefix_indices[batch_id]:
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

output = model(*input_batch)

## --------------- CHANGE A TARGET TOKEN ---------------------------
# get a modified version of the first batch
# guaranteed to exist as each row has at least one partial document
changed_target_index = prefix_indices[0][0]
token_ids_changed_target = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_target[0, changed_target_index] = \
(token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size

# Test change
output_changed_target = model(token_ids_changed_target, *input_batch[1:])

# All tokens in the past should be unchanged
self.assertTrue(
torch.all(
equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
)
)
)
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
# All tokens in the future should have changed
self.assertFalse(
torch.any(
equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
)
)
)
# Unmodified rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
# Unmodified rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_target[1, :])
)
)
)

## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the last prefix token and make sure that the first token changes
# guaranteed to be non-negative as we avoided the pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
## --------------- CHANGE AN INPUT TOKEN ---------------------------
# Let's change the last prefix token and make sure that the first token changes
# guaranteed to be non-negative as we avoided the pathological case previously
last_prefix_index = prefix_indices[0][0] - 1
token_ids_changed_input = input_batch[0].clone()
# We increment the token id on the changed index.
token_ids_changed_input[0, last_prefix_index] = \
(token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
# make sure we're not changing a token to eod as it's a special token
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size

output_changed_input = model(token_ids_changed_input, *input_batch[1:])

# All tokens should be changed
self.assertFalse(
torch.any(
equal_vectors(output[0, :], output_changed_input[0, :])
)
)
)
# Unmodified rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
# Unmodified rows should not change either
self.assertTrue(
torch.all(
equal_vectors(output[1, :], output_changed_input[1, :])
)
)
)

self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

def test_prefix_lm_wo_reset_attention_mask(self):
"""
@@ -245,28 +250,28 @@ def test_prefix_lm_wo_reset_attention_mask(self):
with patch('sys.argv', flatten_arguments(command_args)):
with mockenv_context(**self.dist_env_1_gpu):
deepspeed.init_distributed()
initialize_megatron()
args = get_args()

model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]
with CaptureStdout() as cs:
initialize_megatron()
args = get_args()

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
model = model[0]

for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})

for batch_id in range(len(prefix_indices)):
id = prefix_indices[batch_id]
self.assertTrue(loss_mask[batch_id, id] == 1)
self.assertTrue(id > 0)
# Make sure that the last prefix token predicts the first token.
self.assertTrue(loss_mask[batch_id, id -1] == 1)

with CaptureStdout() as cs:
model(*input_batch)

self.assertIn("Using fused softmax", cs.out)
self.assertIn("Using torch softmax", cs.out)

self.assertNotIn("Using torch softmax", cs.out)
#TODO: Check all invariants


@@ -288,7 +293,7 @@ def test_gpt_rotary_embeddings(self):
model, _, _ = setup_model_and_optimizer(gpt_model_provider)
model = model[0]

token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))

# eod is a special token
token_ids[token_ids == tokenizer.eod] += 1
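A side note on the recurring `args.seq_length + 1` change in the tests above: Megatron-style batch pipes conventionally build the model inputs and the next-token labels from the same buffer by shifting it one position, so producing `seq_length` input tokens requires `seq_length + 1` raw tokens per row. A small sketch of that convention (an illustration of the shift, not the repository's `get_prefix_lm_batch_pipe`):

import torch

def split_tokens_and_labels(token_ids: torch.Tensor):
    # token_ids: (batch, seq_length + 1) -> tokens/labels of shape (batch, seq_length)
    tokens = token_ids[:, :-1].contiguous()  # what the model consumes
    labels = token_ids[:, 1:].contiguous()   # next-token targets, shifted by one
    return tokens, labels

# Example with the test's shapes: micro_batch_size=1, seq_length=128.
batch = torch.randint(0, 50257, (1, 128 + 1))  # 50257 is just an arbitrary vocab size here
tokens, labels = split_tokens_and_labels(batch)
assert tokens.shape == labels.shape == (1, 128)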
tests/test_training.py (6 changes: 5 additions & 1 deletion)
@@ -317,7 +317,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
--num-layers 2
--hidden-size 64
--num-attention-heads 2
--num-attention-heads 4
--seq-length 128
--max-position-embeddings 1024
--micro-batch-size 1
@@ -392,6 +392,10 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 1, "tensorboard files")

# test that the fused (scaled) softmax path is used
self.assertIn("Using fused softmax", cs.out)
self.assertNotIn("Using torch softmax", cs.out)

if reweight_loss_based_on_position_frequency:
self.assertIn("Using loss reweighting", cs.out)

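The `--num-attention-heads` bump from 2 to 4 above ties into the `is_kernel_available` hunk in megatron/model/fused_softmax.py: the fused kernel requires `np * b` (attention heads times micro batch size) to be divisible by 4, so with `--micro-batch-size 1` and only 2 heads the model presumably falls back to the torch softmax path, which the new assertions forbid. A standalone sketch of the divisibility constraints visible in that hunk (other preconditions of `is_kernel_available`, such as dtype checks and the lower bound on `sk`, are omitted here):

def fused_softmax_eligible(b, np, sq, sk, causal, batch_per_block):
    # Mirrors only the checks shown in the diff; batch_per_block stands in for
    # the value returned by self.get_batch_per_block(sq, sk, b, np).
    attn_batches = b * np
    if sk % 4 != 0 or attn_batches % 4 != 0:  # both must be divisible by 4
        return False
    if not (0 <= sk <= 2048):  # kernel only covers key lengths up to 2048
        return False
    if causal:
        return attn_batches % batch_per_block == 0
    return sq % batch_per_block == 0

# With 2 heads and micro batch size 1, attn_batches == 2 and the divisibility-by-4
# precondition fails; with 4 heads, attn_batches == 4 and the fused path can be taken.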
