diff --git a/megatron/arguments.py b/megatron/arguments.py
index 06330558e..294f9f85b 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -875,6 +875,11 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Mask loss for the end of document tokens.')
     group.add_argument('--loss-on-targets-only', action='store_true',
                        help='Mask loss on input sequence.')
+    group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
+                       help='Some objectives require us to sample loss_mask. This might introduce bias towards '
+                            'specific positions. This option tries to un-bias the loss by reweighting the loss on '
+                            'specific positions based on how frequently we train on that position. '
+                            'This is mostly used for prefix_lm training.')
 
     return parser
 
diff --git a/megatron/model/glu_activations.py b/megatron/model/glu_activations.py
index cf136cb02..c479d9683 100644
--- a/megatron/model/glu_activations.py
+++ b/megatron/model/glu_activations.py
@@ -1,10 +1,9 @@
-from functools import wraps
-
 import torch
 from torch import nn
 from torch.nn import functional as F
 
 from megatron import logging
+from megatron.model.utils import log_debug_usage
 
 logger = logging.get_logger(__name__)
 
@@ -38,21 +37,11 @@ class SwiGLU(_GLUBaseModule):
     def __init__(self):
         super().__init__(F.silu)
 
-def log_debug_usage(func, msg: str):
-    func.__logged_message__ = False
-    @wraps(func)
-    def wrapped(*args, **kwargs):
-        if func.__logged_message__ is False:
-            logger.debug(msg)
-            func.__logged_message__ = True
-        return func(*args, **kwargs)
-    return wrapped
-
-
-liglu = log_debug_usage(torch.jit.script(LiGLU()), "Using GLU activation: LiGLU.")
-geglu = log_debug_usage(torch.jit.script(GEGLU()), "Using GLU activation: GELU.")
-reglu = log_debug_usage(torch.jit.script(ReGLU()), "Using GLU activation: ReGLU.")
-swiglu = log_debug_usage(torch.jit.script(SwiGLU()), "Using GLU activation: SwiGLU.")
+
+liglu = log_debug_usage(logger, "Using GLU activation: LiGLU.")(torch.jit.script(LiGLU()))
+geglu = log_debug_usage(logger, "Using GLU activation: GELU.")(torch.jit.script(GEGLU()))
+reglu = log_debug_usage(logger, "Using GLU activation: ReGLU.")(torch.jit.script(ReGLU()))
+swiglu = log_debug_usage(logger, "Using GLU activation: SwiGLU.")(torch.jit.script(SwiGLU()))
 
 
 GLU_ACTIVATIONS = {
diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py
index cb762d48b..9abc4799e 100644
--- a/megatron/model/gpt_model.py
+++ b/megatron/model/gpt_model.py
@@ -169,14 +169,23 @@ def CrossEntropy(output, labels):
 
         if is_prefix:
             micro_batch_size, sequence_length = loss_mask.shape
-            expected_number_of_tokens = micro_batch_size * sequence_length
+            average_tokens_per_sample: torch.Tensor
             if args.loss_on_targets_only:
                 # HACK: This is useful when we obtain loss masks that are microbatch dependent. Consequently, if we want to
                 # preserve the notion that all tokens have the same impact on the loss, we can only normalise using a
-                # microbatch independent value.
+                # microbatch independent value. It should be the expected weight over a microbatch.
                 # Here we still use `sequence_length`, that's batch size dependent, in order to be backwards compatible with
                 # current experiment on vanilla gpt.
-                expected_number_of_tokens /= 2
+                if args.reweight_loss_based_on_position_frequency:
+                    reweight = torch.arange(
+                        sequence_length, 0, -1, dtype=torch.float, device=loss_mask.device
+                    ) / (sequence_length + 1) * 2
+                    average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean()
+                else:
+                    average_tokens_per_sample = (sequence_length + 1) / 2
+            else:
+                average_tokens_per_sample = sequence_length
+            expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
         else:
             expected_number_of_tokens = loss_mask.sum()
 
diff --git a/megatron/model/utils.py b/megatron/model/utils.py
index 492f1f10b..8c3908a93 100644
--- a/megatron/model/utils.py
+++ b/megatron/model/utils.py
@@ -16,6 +16,7 @@
 """Utilities for models."""
 
 import math
+from functools import wraps
 
 import torch
 
@@ -73,3 +74,18 @@ def openai_gelu(x):
 @torch.jit.script
 def erf_gelu(x):
     return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))
+
+def log_debug_usage(logger, msg: str):
+    def log_debug_usage_(func):
+        """Log `msg` the first time the wrapped function is used."""
+        func.__logged_message__ = False
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            if func.__logged_message__ is False:
+                logger.debug(msg)
+                func.__logged_message__ = True
+            return func(*args, **kwargs)
+
+        return wrapped
+    return log_debug_usage_
diff --git a/megatron/utils.py b/megatron/utils.py
index f234d1650..56cc6a622 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -26,14 +26,17 @@
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C
 
-from megatron import get_args
+from megatron import get_args, logging
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
+from megatron.model.utils import log_debug_usage
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate, VocabParallelEmbedding
 from megatron import get_num_microbatches
 
+logger = logging.get_logger(__name__)
+
 def unwrap_model(model, module_instances=(torchDDP)):
     return_list = True
     if not isinstance(model, list):
@@ -371,3 +374,12 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_
         prefix_indices.append(prefix_index)
 
     return prefix_indices
+
+
+@log_debug_usage(logger, "Using loss reweighting")
+def reweight_loss_mask_(loss_mask: torch.Tensor, tokens: torch.Tensor):
+    """Reweight loss mask in-place"""
+    _, seq_length = tokens.shape
+    weight_loss = torch.arange(seq_length, 0, -1, dtype=torch.float, device=loss_mask.device) / (seq_length + 1) * 2
+    # in-place operation
+    loss_mask *= weight_loss[None, :]
\ No newline at end of file
diff --git a/pretrain_prefix_lm.py b/pretrain_prefix_lm.py
index a6f1a55a6..391186e75 100644
--- a/pretrain_prefix_lm.py
+++ b/pretrain_prefix_lm.py
@@ -25,12 +25,11 @@
 from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
 from megatron.model import GPTModel, GPTModelPipe
 from megatron.training import pretrain
-from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
+from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_
 from megatron.utils import average_losses_across_data_parallel_group
 
 import deepspeed
 from deepspeed.runtime.utils import see_memory_usage
-import os
 import subprocess
 
 def model_provider(pre_process=True, post_process=True):
@@ -108,6 +107,10 @@ def get_batch(data_iterator):
         loss_on_targets_only=args.loss_on_targets_only
     )
 
+    # weight loss_mask
+    if args.reweight_loss_based_on_position_frequency:
+        reweight_loss_mask_(loss_mask, tokens)
+
     return tokens, labels, loss_mask, attention_mask, position_ids
 
 def get_batch_pipe(data):
@@ -146,6 +149,10 @@ def get_batch_pipe(data):
         loss_on_targets_only=args.loss_on_targets_only
     )
 
+    # weight loss_mask
+    if args.reweight_loss_based_on_position_frequency:
+        reweight_loss_mask_(loss_mask, tokens)
+
     return (tokens, position_ids, attention_mask), (labels, loss_mask), prefix_indices
 
 def loss_func(loss_mask, output_tensor):
diff --git a/tests/test_training.py b/tests/test_training.py
index ef6ae643b..6b652280d 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -298,8 +298,8 @@ def test_training_all(self, variation):
         if variation == "glu":
             self.assertIn("Using GLU activation: GELU", cs.out)
 
-    @parameterized.expand([(True, ), (False, )])
-    def test_training_prefix_lm_all(self, loss_on_targets_only):
+    @parameterized.expand([(True, True), (False, False), (True, False), (False, True)])
+    def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_on_position_frequency):
         # all in one test
         src_dir = self.src_dir
         data_dir = f"{self.data_dir}/gpt2"
@@ -325,6 +325,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only):
            --global-batch-size 16
            --train-samples {n_samples}
            {"--loss-on-targets-only" if loss_on_targets_only else ""}
+           {"--reweight-loss-based-on-position-frequency" if reweight_loss_based_on_position_frequency else ""}
 
            --optimizer adam
            --adam-beta1 0.9
@@ -353,6 +354,8 @@ def test_training_prefix_lm_all(self, loss_on_targets_only):
            --log-timers-to-tensorboard
            --log-batch-size-to-tensorboard
            --log-validation-ppl-to-tensorboard
+
+           --log-level debug
         """.split()
 
         ds_args = f"""
@@ -389,6 +392,9 @@ def test_training_prefix_lm_all(self, loss_on_targets_only):
         tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
         self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
 
+        if reweight_loss_based_on_position_frequency:
+            self.assertIn("Using loss reweighting", cs.out)
+
         # 2. test training from checkpoint: resume
         # now do it again, this time resuming from the checkpoint
         with CaptureStdout() as cs:
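Reviewer note, not part of the patch: the sketch below reproduces the reweighting math from `reweight_loss_mask_` and the `average_tokens_per_sample` computation in `gpt_model.py`, and cross-checks them with a brute-force average over prefix split points. The toy `seq_length = 8` and the assumption that the split point is uniform over sequence positions are illustrative only.

# Standalone sketch (assumes a uniform prefix split point; toy sequence length).
import torch

seq_length = 8

# Same weights as reweight_loss_mask_: position i (0-indexed) gets
# (seq_length - i) / (seq_length + 1) * 2, so early positions, which are
# less often targets under a uniform prefix split, get larger weights.
weight_loss = torch.arange(seq_length, 0, -1, dtype=torch.float) / (seq_length + 1) * 2

# Expected total weight per sample, as in gpt_model.py: flip + cumsum gives
# the suffix sums of the weights (the target-segment weight for each split
# point), and their mean is the expectation over split points.
average_tokens_per_sample = weight_loss.flip(-1).cumsum(-1).mean()

# Brute-force check: average the reweighted target weight over all splits.
totals = torch.stack([weight_loss[split:].sum() for split in range(seq_length)])
print(average_tokens_per_sample.item(), totals.mean().item())  # the two should match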