diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py
index c0c3309b7..6fd055f0d 100644
--- a/megatron/model/fused_softmax.py
+++ b/megatron/model/fused_softmax.py
@@ -16,12 +16,7 @@
 import torch
 import torch.nn as nn
-
-from megatron import logging
 from megatron.enums import AttnMaskType
-from megatron.model.utils import log_debug_usage
-
-logger = logging.get_logger(__name__)
 
 
 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     """
@@ -158,7 +153,6 @@ def is_kernel_available(self, mask, b, np, sq, sk):
                     return True
         return False
 
-    @log_debug_usage(logger, "Using fused softmax")
     def forward_fused_softmax(self, input, mask):
         b, np, sq, sk = input.size()
         scale = self.scale if self.scale is not None else 1.0
@@ -174,7 +168,6 @@ def forward_fused_softmax(self, input, mask):
             # input is 4D tensor (b, np, sq, sk)
             return ScaledMaskedSoftmax.apply(input, mask, scale)
 
-    @log_debug_usage(logger, "Using torch softmax")
     def forward_torch_softmax(self, input, mask):
         if self.input_in_float16 and self.softmax_in_fp32:
             input = input.float()
diff --git a/tests/test_model.py b/tests/test_model.py
index 05eb13751..189fcce01 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -6,7 +6,7 @@
 import torch
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.testing_utils import TestCasePlus, mockenv_context, CaptureStdout
+from megatron.testing_utils import TestCasePlus, mockenv_context
 from megatron.training import setup_model_and_optimizer
 from pretrain_gpt import model_provider as gpt_model_provider, get_batch_pipe as get_gpt_batch_pipe
 from pretrain_prefix_lm import model_provider as prefix_lm_model_provider, get_batch_pipe as get_prefix_lm_batch_pipe
@@ -49,10 +49,6 @@ def get_default_args():
         "--checkpoint-activations": "",
 
         # DATA_ARGS
-
-        # LOGGING_ARGS
-        "--log-level": "debug",
-        "--log-level-replica": "info",
     }
@@ -102,7 +98,7 @@ def test_gpt(self):
             model, _, _ = setup_model_and_optimizer(gpt_model_provider)
             model = model[0]
 
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
+            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
 
             # eod is a special token
             token_ids[token_ids == tokenizer.eod] += 1
@@ -145,96 +141,91 @@ def test_prefix_lm_reset_attention_mask(self):
         with patch('sys.argv', flatten_arguments(command_args)):
             with mockenv_context(**self.dist_env_1_gpu):
                 deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
+                tokenizer = get_tokenizer()
+
+                model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
+                model = model[0]
+
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+
+                # eod is a special token; this also guarantees that the whole row is considered as a document.
+                token_ids[token_ids == tokenizer.eod] += 1
+                token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
+
+                # process batch to have a non-empty prefix
+                input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
 
-                with CaptureStdout() as cs:
-                    initialize_megatron()
-                    args = get_args()
-                    tokenizer = get_tokenizer()
-
-                    model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
-                    model = model[0]
-
-                    token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
-
-                    # eod is a special token, this also guarantees that the whole row is considered as a document.
-                    token_ids[token_ids == tokenizer.eod] += 1
-                    token_ids[token_ids == tokenizer.eod] %= args.padded_vocab_size
-
-                    # process batch to have non empty prefix
-                    input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
-
-                    for batch_id in range(len(prefix_indices)):
-                        for id in prefix_indices[batch_id]:
-                            self.assertTrue(loss_mask[batch_id, id] == 1)
-                            self.assertTrue(id > 0)
-                            # Make sure that the last prefix token predicts the first token.
-                            self.assertTrue(loss_mask[batch_id, id -1] == 1)
-
-                    output = model(*input_batch)
-
-                    ## --------------- CHANGE A TARGET TOKEN ---------------------------
-                    # get a modified version of the first batch
-                    # guaranteed to exist as each row has at least one partial document
-                    changed_target_index = prefix_indices[0][0]
-                    token_ids_changed_target = input_batch[0].clone()
-                    # We increment the token id on the changed index.
-                    token_ids_changed_target[0, changed_target_index] = \
-                        (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
-                    # make sure we're not changing a token to eod as it's a special token
-                    token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
-                    token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size
-
-                    # Test change
-                    output_changed_target = model(token_ids_changed_target, *input_batch[1:])
-
-                    # All token in past should be unchanged
-                    self.assertTrue(
-                        torch.all(
-                            equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
-                        )
-                    )
-                    # All tokens in the future should have changed
-                    self.assertFalse(
-                        torch.any(
-                            equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
-                        )
-                    )
-                    # Unchanged changed rows should not change either
-                    self.assertTrue(
-                        torch.all(
-                            equal_vectors(output[1, :], output_changed_target[1, :])
-                        )
-                    )
-
-                    ## --------------- CHANGE AN INPUT TOKEN ---------------------------
-                    # Let's change the the last prefix token and make sure that the first token changed
-                    # guaranteed to be positive as we avoid pathological case previously
-                    last_prefix_index = prefix_indices[0][0] - 1
-                    token_ids_changed_input = input_batch[0].clone()
-                    # We increment the token id on the changed index.
-                    token_ids_changed_input[0, last_prefix_index] = \
-                        (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
-                    # make sure we're not changing a token to eod as it's a special token
-                    token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
-                    token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size
-
-                    output_changed_input = model(token_ids_changed_input, *input_batch[1:])
-
-                    # All tokens should be changed
-                    self.assertFalse(
-                        torch.any(
-                            equal_vectors(output[0, :], output_changed_input[0, :])
-                        )
-                    )
-                    # Unchanged changed rows should not change either
-                    self.assertTrue(
-                        torch.all(
-                            equal_vectors(output[1, :], output_changed_input[1, :])
-                        )
-                    )
-
-                    self.assertIn("Using fused softmax", cs.out)
-                    self.assertNotIn("Using torch softmax", cs.out)
+                for batch_id in range(len(prefix_indices)):
+                    for id in prefix_indices[batch_id]:
+                        self.assertTrue(loss_mask[batch_id, id] == 1)
+                        self.assertTrue(id > 0)
+                        # Make sure that the last prefix token predicts the first token.
+                        self.assertTrue(loss_mask[batch_id, id - 1] == 1)
+
+                output = model(*input_batch)
+
+                ## --------------- CHANGE A TARGET TOKEN ---------------------------
+                # get a modified version of the first batch
+                # guaranteed to exist as each row has at least one partial document
+                changed_target_index = prefix_indices[0][0]
+                token_ids_changed_target = input_batch[0].clone()
+                # We increment the token id on the changed index.
+                token_ids_changed_target[0, changed_target_index] = \
+                    (token_ids_changed_target[0, changed_target_index] + 1) % args.padded_vocab_size
+                # make sure we're not changing a token to eod as it's a special token
+                token_ids_changed_target[token_ids_changed_target == tokenizer.eod] += 1
+                token_ids_changed_target[token_ids_changed_target == tokenizer.eod] %= args.padded_vocab_size
+
+                # Test change
+                output_changed_target = model(token_ids_changed_target, *input_batch[1:])
+
+                # All tokens in the past should be unchanged
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[0, :changed_target_index], output_changed_target[0, :changed_target_index])
+                    )
+                )
+                # All tokens in the future should have changed
+                self.assertFalse(
+                    torch.any(
+                        equal_vectors(output[0, changed_target_index:], output_changed_target[0, changed_target_index:])
+                    )
+                )
+                # Unchanged rows should not change either
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[1, :], output_changed_target[1, :])
+                    )
+                )
+
+                ## --------------- CHANGE AN INPUT TOKEN ---------------------------
+                # Let's change the last prefix token and make sure that the first token changed
+                # guaranteed to be positive as we avoid the pathological case previously
+                last_prefix_index = prefix_indices[0][0] - 1
+                token_ids_changed_input = input_batch[0].clone()
+                # We increment the token id on the changed index.
+                token_ids_changed_input[0, last_prefix_index] = \
+                    (token_ids_changed_input[0, last_prefix_index] + 1) % args.padded_vocab_size
+                # make sure we're not changing a token to eod as it's a special token
+                token_ids_changed_input[token_ids_changed_input == tokenizer.eod] += 1
+                token_ids_changed_input[token_ids_changed_input == tokenizer.eod] %= args.padded_vocab_size
+
+                output_changed_input = model(token_ids_changed_input, *input_batch[1:])
+
+                # All tokens should be changed
+                self.assertFalse(
+                    torch.any(
+                        equal_vectors(output[0, :], output_changed_input[0, :])
+                    )
+                )
+                # Unchanged rows should not change either
+                self.assertTrue(
+                    torch.all(
+                        equal_vectors(output[1, :], output_changed_input[1, :])
+                    )
+                )
 
     def test_prefix_lm_wo_reset_attention_mask(self):
         """
@@ -250,28 +241,24 @@ def test_prefix_lm_wo_reset_attention_mask(self):
         with patch('sys.argv', flatten_arguments(command_args)):
             with mockenv_context(**self.dist_env_1_gpu):
                 deepspeed.init_distributed()
+                initialize_megatron()
+                args = get_args()
 
-                with CaptureStdout() as cs:
-                    initialize_megatron()
-                    args = get_args()
-
-                    model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
-                    model = model[0]
+                model, _, _ = setup_model_and_optimizer(prefix_lm_model_provider)
+                model = model[0]
 
-                    token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
-                    input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
+                token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
+                input_batch, (_, loss_mask), prefix_indices = get_prefix_lm_batch_pipe({"text": token_ids})
 
-                    for batch_id in range(len(prefix_indices)):
-                        id = prefix_indices[batch_id]
-                        self.assertTrue(loss_mask[batch_id, id] == 1)
-                        self.assertTrue(id > 0)
-                        # Make sure that the last prefix token predicts the first token.
-                        self.assertTrue(loss_mask[batch_id, id -1] == 1)
+                for batch_id in range(len(prefix_indices)):
+                    id = prefix_indices[batch_id]
+                    self.assertTrue(loss_mask[batch_id, id] == 1)
+                    self.assertTrue(id > 0)
+                    # Make sure that the last prefix token predicts the first token.
+                    self.assertTrue(loss_mask[batch_id, id - 1] == 1)
 
-                    model(*input_batch)
+                model(*input_batch)
 
-                    self.assertIn("Using fused softmax", cs.out)
-                    self.assertNotIn("Using torch softmax", cs.out)
 
         #TODO: Check all invariants
 
     def test_gpt_rotary_embeddings(self):
         """
@@ -291,7 +278,7 @@ def test_gpt_rotary_embeddings(self):
             model, _, _ = setup_model_and_optimizer(gpt_model_provider)
             model = model[0]
 
-            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length + 1))
+            token_ids = torch.randint(args.padded_vocab_size, (args.micro_batch_size, args.seq_length))
 
             # eod is a special token
             token_ids[token_ids == tokenizer.eod] += 1
diff --git a/tests/test_training.py b/tests/test_training.py
index b9918e2f5..6b652280d 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -317,7 +317,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
             --num-layers 2
             --hidden-size 64
-            --num-attention-heads 4
+            --num-attention-heads 2
             --seq-length 128
             --max-position-embeddings 1024
             --micro-batch-size 1
@@ -392,10 +392,6 @@ def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_
         tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
         self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
 
-        # test use scaled softmax
-        self.assertIn("Using fused softmax", cs.out)
-        self.assertNotIn("Using torch softmax", cs.out)
-
        if reweight_loss_based_on_position_frequency:
            self.assertIn("Using loss reweighting", cs.out)
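
Note on the removed decorator: `log_debug_usage(logger, "Using fused softmax")` wraps a function so that the given message is logged the first time the function is called, which is what the deleted `cs.out` assertions keyed on. A minimal sketch of such a decorator (a hypothetical reconstruction, not necessarily the exact `megatron.model.utils` implementation):

    import functools

    def log_debug_usage(logger, msg):
        def decorator(func):
            # Remember whether the message was already emitted for this function.
            func.__logged_message__ = False

            @functools.wraps(func)
            def wrapped(*args, **kwargs):
                if not func.__logged_message__:
                    logger.debug(msg)  # emitted once, on first use
                    func.__logged_message__ = True
                return func(*args, **kwargs)

            return wrapped

        return decorator

Removing the decorator drops the one-time debug line, so stdout capture can no longer distinguish the fused kernel path from the torch fallback; that is why the `assertIn`/`assertNotIn` checks are deleted alongside it.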
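Note on the test helper: the assertions above reduce `equal_vectors(...)` with `torch.all` / `torch.any`, so the helper must return a boolean tensor with one entry per position, comparing hidden-state vectors elementwise. A minimal sketch consistent with that usage (hypothetical; the repository's actual helper may differ):

    import torch

    def equal_vectors(tensor1, tensor2, dim=-1):
        # True at each position whose vectors along `dim` are identical.
        return torch.linalg.norm(tensor1 - tensor2, dim=dim) == 0

For example, `torch.all(equal_vectors(output[0, :i], output_changed[0, :i]))` holds exactly when every position before index `i` produced an unchanged output vector.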