Reweighting strategy for prefix LM (#190)
* First test to un-bias the loss for prefix LM

* Woops

* Add the same code for the non-DeepSpeed mode

* Improve testing

* Woops

* Test moving it inside?

* This changes the normalization factor in loss computation

* Fix

* Woops

* Better refactoring of loss normalization
thomasw21 authored Nov 26, 2021
1 parent 1202668 commit b3cf175
Showing 7 changed files with 69 additions and 25 deletions.
5 changes: 5 additions & 0 deletions megatron/arguments.py
@@ -875,6 +875,11 @@ def __call__(self, parser, args, values, option_string=None):
help='Mask loss for the end of document tokens.')
group.add_argument('--loss-on-targets-only', action='store_true',
help='Mask loss on input sequence.')
group.add_argument('--reweight-loss-based-on-position-frequency', action="store_true",
help='Some objectives require us to sample loss_mask. This might introduce bias towards '
'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
'positions based on how frequently we train on that position. '
'This is mostly used for prefix_lm training.')

return parser

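For context on the new flag (not part of the commit): a rough sketch of the bias it targets, assuming the prefix boundary is sampled uniformly over the sequence and the loss is restricted to target tokens (the `--loss-on-targets-only` case). Under that assumption, late positions land in the loss far more often than early ones, which is the skew the reweighting tries to undo.

```python
# Illustrative sketch only. Assumes boundary k is uniform over 0..seq_length-1 and
# position p contributes to the loss iff p >= k (loss on targets only).
import torch

seq_length = 8
positions = torch.arange(seq_length, dtype=torch.float)

# Under these assumptions, P(position p is trained on) = (p + 1) / seq_length.
train_frequency = (positions + 1) / seq_length
print(train_frequency)  # tensor([0.1250, 0.2500, ..., 0.8750, 1.0000]) -- skewed toward late positions
```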
23 changes: 6 additions & 17 deletions megatron/model/glu_activations.py
@@ -1,10 +1,9 @@
from functools import wraps

import torch
from torch import nn
from torch.nn import functional as F

from megatron import logging
from megatron.model.utils import log_debug_usage

logger = logging.get_logger(__name__)

@@ -38,21 +37,11 @@ class SwiGLU(_GLUBaseModule):
def __init__(self):
super().__init__(F.silu)

def log_debug_usage(func, msg: str):
func.__logged_message__ = False
@wraps(func)
def wrapped(*args, **kwargs):
if func.__logged_message__ is False:
logger.debug(msg)
func.__logged_message__ = True
return func(*args, **kwargs)
return wrapped


liglu = log_debug_usage(torch.jit.script(LiGLU()), "Using GLU activation: LiGLU.")
geglu = log_debug_usage(torch.jit.script(GEGLU()), "Using GLU activation: GELU.")
reglu = log_debug_usage(torch.jit.script(ReGLU()), "Using GLU activation: ReGLU.")
swiglu = log_debug_usage(torch.jit.script(SwiGLU()), "Using GLU activation: SwiGLU.")

liglu = log_debug_usage(logger, "Using GLU activation: LiGLU.")(torch.jit.script(LiGLU()))
geglu = log_debug_usage(logger, "Using GLU activation: GELU.")(torch.jit.script(GEGLU()))
reglu = log_debug_usage(logger, "Using GLU activation: ReGLU.")(torch.jit.script(ReGLU()))
swiglu = log_debug_usage(logger, "Using GLU activation: SwiGLU.")(torch.jit.script(SwiGLU()))


GLU_ACTIVATIONS = {
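A minimal usage sketch of the refactored wrappers (not part of the commit; assumes `megatron` is importable and that, as in standard GLU variants, the last dimension is split in half): the scripted GLU modules behave exactly as before, the wrapper only adds a one-time debug log.

```python
import torch
from megatron.model.glu_activations import liglu

x = torch.randn(2, 8)   # last dimension gets chunked into two halves
y = liglu(x)            # calls logger.debug("Using GLU activation: LiGLU.") once per process
assert y.shape == (2, 4)
```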
15 changes: 12 additions & 3 deletions megatron/model/gpt_model.py
@@ -169,14 +169,23 @@ def CrossEntropy(output, labels):

if is_prefix:
micro_batch_size, sequence_length = loss_mask.shape
expected_number_of_tokens = micro_batch_size * sequence_length
average_tokens_per_sample: torch.Tensor
if args.loss_on_targets_only:
# HACK: This is useful when we obtain loss masks that are microbatch dependent. Consequently, if we want to
# preserve the notion that all tokens have the same impact on the loss, we can only normalise using a
# microbatch independent value.
# microbatch independent value. It should be the expected weight over a microbatch.
# Here we still use `sequence_length`, which is batch-size dependent, in order to be backwards compatible with
# the current experiments on vanilla gpt.
expected_number_of_tokens /= 2
if args.reweight_loss_based_on_position_frequency:
reweight = torch.arange(
sequence_length, 0, -1, dtype=torch.float, device=loss_mask.device
) / (sequence_length + 1) * 2
average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean()
else:
average_tokens_per_sample = (sequence_length + 1) / 2
else:
average_tokens_per_sample = sequence_length
expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
else:
expected_number_of_tokens = loss_mask.sum()

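As a sanity check on the normalisation term above (not part of the commit): the closed form `reweight.flip(-1).cumsum(-1).mean()` can be compared against a direct enumeration over every possible prefix boundary. The sketch assumes the number of target tokens per sample is uniform over 1..sequence_length.

```python
import torch

sequence_length = 8
reweight = torch.arange(
    sequence_length, 0, -1, dtype=torch.float
) / (sequence_length + 1) * 2

# Closed form used above: expected total weight of the target positions of a sample.
average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean()

# Direct enumeration over every possible boundary k (targets are positions k..end).
expected = torch.stack([reweight[k:].sum() for k in range(sequence_length)]).mean()
assert torch.allclose(average_tokens_per_sample, expected)

# Without reweighting, the analogue reduces to the average number of targets, (L + 1) / 2.
assert torch.arange(1, sequence_length + 1, dtype=torch.float).mean() == (sequence_length + 1) / 2
```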
16 changes: 16 additions & 0 deletions megatron/model/utils.py
@@ -16,6 +16,7 @@
"""Utilities for models."""

import math
from functools import wraps

import torch

@@ -73,3 +74,18 @@ def openai_gelu(x):
@torch.jit.script
def erf_gelu(x):
return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype)+torch.ones_like(x).to(dtype=x.dtype))

def log_debug_usage(logger, msg: str):
def log_debug_usage_(func):
"""Helper function in order to log a message when using a function for the first time"""
func.__logged_message__ = False

@wraps(func)
def wrapped(*args, **kwargs):
if func.__logged_message__ is False:
logger.debug(msg)
func.__logged_message__ = True
return func(*args, **kwargs)

return wrapped
return log_debug_usage_
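A small usage sketch of the new helper (not part of the commit): any object exposing a `.debug` method works as the logger, and the message is emitted only on the first call. The decorated function and message below are hypothetical.

```python
import logging

from megatron.model.utils import log_debug_usage

logger = logging.getLogger(__name__)

@log_debug_usage(logger, "Using toy_double.")  # hypothetical message
def toy_double(x):
    return 2 * x

toy_double(1)  # logger.debug("Using toy_double.") is called here, once
toy_double(2)  # silent: func.__logged_message__ is already True
```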
14 changes: 13 additions & 1 deletion megatron/utils.py
@@ -26,14 +26,17 @@
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_args
from megatron import get_args, logging
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.model.utils import log_debug_usage
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate, VocabParallelEmbedding
from megatron import get_num_microbatches

logger = logging.get_logger(__name__)

def unwrap_model(model, module_instances=(torchDDP)):
return_list = True
if not isinstance(model, list):
@@ -371,3 +374,12 @@ def get_prefix_indices(data, eod_token, partial_prefix_indices, reset_attention_
prefix_indices.append(prefix_index)

return prefix_indices


@log_debug_usage(logger, "Using loss reweighting")
def reweight_loss_mask_(loss_mask: torch.Tensor, tokens: torch.Tensor):
"""Reweight loss mask in-place"""
_, seq_length = tokens.shape
weight_loss = torch.arange(seq_length, 0, -1, dtype=torch.float, device=loss_mask.device) / (seq_length + 1) * 2
# in-place operation
loss_mask *= weight_loss[None, :]
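An illustrative look at the weights this helper applies (not part of the commit): they decrease linearly along the sequence and average to 1, so positions that end up in the loss more often under a sampled prefix split are weighted down without changing the overall loss scale.

```python
import torch

seq_length = 4
loss_mask = torch.ones(2, seq_length)  # toy micro-batch of 2 samples, nothing masked

weight_loss = torch.arange(seq_length, 0, -1, dtype=torch.float) / (seq_length + 1) * 2
loss_mask *= weight_loss[None, :]

print(weight_loss)         # tensor([1.6000, 1.2000, 0.8000, 0.4000])
print(weight_loss.mean())  # tensor(1.)
```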
11 changes: 9 additions & 2 deletions pretrain_prefix_lm.py
@@ -25,12 +25,11 @@
from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_
from megatron.utils import average_losses_across_data_parallel_group

import deepspeed
from deepspeed.runtime.utils import see_memory_usage
import os
import subprocess

def model_provider(pre_process=True, post_process=True):
@@ -108,6 +107,10 @@ def get_batch(data_iterator):
loss_on_targets_only=args.loss_on_targets_only
)

# reweight loss_mask in-place to correct for position sampling frequency
if args.reweight_loss_based_on_position_frequency:
reweight_loss_mask_(loss_mask, tokens)

return tokens, labels, loss_mask, attention_mask, position_ids

def get_batch_pipe(data):
@@ -146,6 +149,10 @@ def get_batch_pipe(data):
loss_on_targets_only=args.loss_on_targets_only
)

# reweight loss_mask in-place to correct for position sampling frequency
if args.reweight_loss_based_on_position_frequency:
reweight_loss_mask_(loss_mask, tokens)

return (tokens, position_ids, attention_mask), (labels, loss_mask), prefix_indices

def loss_func(loss_mask, output_tensor):
10 changes: 8 additions & 2 deletions tests/test_training.py
@@ -298,8 +298,8 @@ def test_training_all(self, variation):
if variation == "glu":
self.assertIn("Using GLU activation: GELU", cs.out)

@parameterized.expand([(True, ), (False, )])
def test_training_prefix_lm_all(self, loss_on_targets_only):
@parameterized.expand([(True, True), (False, False), (True, False), (False, True)])
def test_training_prefix_lm_all(self, loss_on_targets_only, reweight_loss_based_on_position_frequency):
# all in one test
src_dir = self.src_dir
data_dir = f"{self.data_dir}/gpt2"
@@ -325,6 +325,7 @@ def test_training_prefix_lm_all(self, loss_on_targets_only):
--global-batch-size 16
--train-samples {n_samples}
{"--loss-on-targets-only" if loss_on_targets_only else ""}
{"--reweight-loss-based-on-position-frequency" if reweight_loss_based_on_position_frequency else ""}
--optimizer adam
--adam-beta1 0.9
@@ -353,6 +354,8 @@ def test_training_prefix_lm_all(self, loss_on_targets_only):
--log-timers-to-tensorboard
--log-batch-size-to-tensorboard
--log-validation-ppl-to-tensorboard
--log-level debug
""".split()

ds_args = f"""
@@ -389,6 +392,9 @@ def test_training_prefix_lm_all(self, loss_on_targets_only):
tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
self.assertEqual(len(tensorboard_files), 1, "tensorboard files")

if reweight_loss_based_on_position_frequency:
self.assertIn("Using loss reweighting", cs.out)

# 2. test training from checkpoint: resume
# now do it again, this time resuming from the checkpoint
with CaptureStdout() as cs:
