WIP: Shared t5 code #286

Open · wants to merge 25 commits into base: main
4 changes: 2 additions & 2 deletions megatron/arguments.py
@@ -925,8 +925,8 @@ def __call__(self, parser, args, values, option_string=None):
'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
'positions based on how frequently we train on that position.'
'This is mostly used for prefix_lm training')
group.add_argument("--noise_density", type=float, default=None, help="Span corruption noise density")
group.add_argument("--mean_noise_span_length", type=int, default=None, help="Span corruption mean noise span length")
group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")


return parser
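A note on the renamed flags above: argparse normalizes dashes in option names to underscores when it builds the namespace attribute, so switching to --noise-density and --mean-noise-span-length keeps args.noise_density and args.mean_noise_span_length working unchanged. A minimal standalone sketch (the group name here is hypothetical, not the Megatron parser):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group("data")
# Dashed option names are stored with underscores in the parsed namespace.
group.add_argument("--noise-density", type=float, default=None)
group.add_argument("--mean-noise-span-length", type=int, default=None)

args = parser.parse_args(["--noise-density", "0.15", "--mean-noise-span-length", "3"])
print(args.noise_density, args.mean_noise_span_length)  # 0.15 3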
24 changes: 18 additions & 6 deletions megatron/data/mlm_dataset.py
@@ -3,7 +3,7 @@
import numpy as np
import torch

from megatron import print_rank_0, get_tokenizer
from megatron import print_rank_0, get_tokenizer, get_args
from megatron.data.blendable_dataset import BlendableDataset
from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_
from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
@@ -297,13 +297,14 @@ def __init__(
# according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
# One extra token is needed because autoregressive training shifts inputs and labels by one; it is accounted for below.
sequence_length=self.sequence_length + 1,
sequence_length=self.sequence_length,
noise_density=self.noise_density,
mean_noise_span_length=self.mean_noise_span_length
)
self.number_of_raw_tokens = number_of_raw_tokens
self.inputs_length = inputs_length
self.targets_length = targets_length
# Add one extra token for the autoregressive shift used when computing the loss.
self.number_of_raw_tokens = number_of_raw_tokens + 1
self.targets_length = targets_length + 1
self.num_noise_spans = num_noise_spans

# Build the samples mapping.
@@ -320,13 +321,24 @@ def __init__(

# Vocab stuff.
tokenizer = get_tokenizer()
self.sep_id = tokenizer.sep
# TODO @thomasw21 find if overloading eod is acceptable.
# self.sep_id = tokenizer.sep
self.sep_id = tokenizer.eod
self.sentinel_token_ids = tokenizer.additional_special_tokens_ids
assert self.sep_id is not None, "MLM dataset requires tokenizer to have a <sep> token"
assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script"
assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more"

args = get_args()
if hasattr(args, "encoder_seq_length") and args.encoder_seq_length is not None:
# T5 style
assert self.inputs_length == args.encoder_seq_length
assert self.targets_length == args.decoder_seq_length + 1
else:
assert self.inputs_length + self.targets_length == args.seq_length

def __len__(self):
return len(self.samples_mapping)
return len(self._gpt_dataset)

def __getitem__(self, idx):
if isinstance(idx, slice):
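For context on the length bookkeeping above: compute_input_and_target_lengths follows the usual T5 span-corruption arithmetic, picking the largest raw-token count whose corrupted encoder input still fits the requested sequence length, where each noise span contributes one sentinel to the input and sentinel-plus-noise tokens to the target. A minimal sketch modeled on the reference T5 preprocessing; the function name, rounding, and return order here are assumptions, and the repo's helper may treat the trailing EOS differently, which is why the diff re-adds the +1 outside the call:

def sketch_input_and_target_lengths(sequence_length, noise_density, mean_noise_span_length):
    """Largest raw-token count whose span-corrupted encoder input fits `sequence_length`."""
    def lengths(tokens_length):
        num_noise_tokens = int(round(tokens_length * noise_density))
        num_noise_spans = max(1, int(round(num_noise_tokens / mean_noise_span_length)))
        num_nonnoise_tokens = tokens_length - num_noise_tokens
        # Encoder input keeps non-noise tokens plus one sentinel per span;
        # the target holds the noise tokens plus one sentinel per span.
        return num_nonnoise_tokens + num_noise_spans, num_noise_tokens + num_noise_spans, num_noise_spans

    tokens_length = sequence_length
    while lengths(tokens_length + 1)[0] <= sequence_length:
        tokens_length += 1
    inputs_length, targets_length, num_noise_spans = lengths(tokens_length)
    return tokens_length, inputs_length, targets_length, num_noise_spans

print(sketch_input_and_target_lengths(512, noise_density=0.15, mean_noise_span_length=3))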
1 change: 1 addition & 0 deletions megatron/model/__init__.py
@@ -18,6 +18,7 @@
from .distributed import DistributedDataParallel
from .bert_model import BertModel
from .gpt_model import GPTModel, GPTModelPipe
from .shared_t5_model import SharedT5ModelPipe
from .t5_model import T5Model
from .language_model import get_language_model
from .module import Float16Module
6 changes: 3 additions & 3 deletions megatron/model/gpt_model.py
@@ -21,7 +21,7 @@
from megatron import get_args
from megatron import mpu
from megatron.enums import AttnMaskType
from .module import MegatronModule, fp32_to_float16
from .module import MegatronModule, fp32_to_16bit

from .language_model import parallel_lm_logits
from .language_model import get_language_model
@@ -213,9 +213,9 @@ def __init__(

def _to_float16(inputs):
if args.fp16:
return fp32_to_float16(inputs, lambda v: v.half())
return fp32_to_16bit(inputs, lambda v: v.half())
elif args.bf16:
return fp32_to_float16(inputs, lambda v: v.bfloat16())
return fp32_to_16bit(inputs, lambda v: v.bfloat16())
else:
return inputs

4 changes: 2 additions & 2 deletions megatron/model/module.py
@@ -122,7 +122,7 @@ def conversion_helper(val, conversion):
return rtn


def fp32_to_float16(val, float16_convertor):
def fp32_to_16bit(val, float16_convertor):
"""Convert fp32 `val` to fp16/bf16"""
def half_conversion(val):
val_typecheck = val
@@ -168,7 +168,7 @@ def float16_convertor(val):

def forward(self, *inputs, **kwargs):
if mpu.is_pipeline_first_stage():
inputs = fp32_to_float16(inputs, self.float16_convertor)
inputs = fp32_to_16bit(inputs, self.float16_convertor)
outputs = self.module(*inputs, **kwargs)
if mpu.is_pipeline_last_stage():
outputs = float16_to_fp32(outputs)
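The rename to fp32_to_16bit reflects that the same recursive traversal serves both fp16 and bf16: the target dtype is injected through the convertor callable while the structure handling stays shared. A toy standalone sketch of that pattern (the helper name to_16bit is illustrative; the Megatron implementation also unwraps Parameter/Variable wrappers, which is omitted here):

import torch

def to_16bit(val, convertor):
    # Recurse through nested tuples/lists and convert only fp32 tensors.
    if isinstance(val, (tuple, list)):
        return type(val)(to_16bit(v, convertor) for v in val)
    if torch.is_tensor(val) and val.dtype == torch.float32:
        return convertor(val)
    return val

batch = (torch.randn(2, 4), [torch.randn(3), torch.arange(5)])
half_batch = to_16bit(batch, lambda v: v.half())        # fp16 path
bf16_batch = to_16bit(batch, lambda v: v.bfloat16())    # bf16 path
print(half_batch[0].dtype, bf16_batch[0].dtype, bf16_batch[1][1].dtype)
# torch.float16 torch.bfloat16 torch.int64 (integer tensors pass through untouched)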
180 changes: 180 additions & 0 deletions megatron/model/shared_t5_model.py
@@ -0,0 +1,180 @@
import torch
from deepspeed import PipelineModule
from deepspeed.runtime.pipe import TiedLayerSpec, LayerSpec
from torch.nn import LayerNorm

from megatron.enums import AttnMaskType, LayerType

from megatron.model.transformer import ParallelTransformerLayerPipe

from megatron.model.language_model import EmbeddingPipe, parallel_lm_logits

from megatron.model.utils import init_method_normal, scaled_init_method_normal

from megatron import get_args, mpu

from megatron.model.module import MegatronModule, fp32_to_16bit, float16_to_fp32

def cross_entropy(output, labels):
labels, loss_mask = labels[0], labels[1]

losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)

expected_number_of_tokens = loss_mask.sum()

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens
return loss

class SharedT5ModelPipe(PipelineModule, MegatronModule):
"""Share encoder decoder language model."""

def __init__(
self,
num_tokentypes=0,
parallel_output=True,
):
args = get_args()
self.parallel_output = parallel_output

init_method = init_method_normal(args.init_method_std)

self.specs = []

def _to_16bit(inputs):
if args.fp16:
return fp32_to_16bit(inputs, lambda v: v.half())
elif args.bf16:
return fp32_to_16bit(inputs, lambda v: v.bfloat16())
else:
return inputs

self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))

# Embedding layer
self.specs.append(TiedLayerSpec('embed',
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
forward_fn=lambda module, input_and_target: (module(input_and_target[:3]), module(input_and_target[3:])),
init_method=init_method,
num_tokentypes=num_tokentypes,
tied_weight_attr='word_embeddings_weight'))

assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention mask s"
# Drop everything beside tokens
# self.specs.append(lambda inputs, targets: (inputs[0], targets[0]))
if args.fp32_residual_connection:
self.specs.append(lambda input_and_target: (input_and_target[0].transpose(0, 1).contiguous().float(), input_and_target[1].transpose(0, 1).contiguous().float()))
else:
self.specs.append(lambda input_and_target: (input_and_target[0].transpose(0, 1).contiguous(), input_and_target[1].transpose(0, 1).contiguous()))

### ----- Encoder -----
for layer_idx in range(args.num_layers):
self.specs.append(
TiedLayerSpec(
f"block_{layer_idx}",
ParallelTransformerLayerPipe,
init_method=init_method,
forward_fn=lambda module, input_and_target: (module(input_and_target[0]), input_and_target[1]),
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
layer_type=LayerType.encoder,
layer_number=layer_idx,
self_attn_mask_type=AttnMaskType.causal,
tied_weight_attr=None,
tied_weight_attrs=["self_attention", "mlp"]
))

# Final layernorm after encoder layers
self.specs.append(
LayerSpec(
LayerNorm,
args.hidden_size,
forward_fn=lambda module, input_and_target: (module(input_and_target[0]), input_and_target[1]),
eps=args.layernorm_epsilon
))

# Decoder
for layer_idx in range(args.num_layers):
self.specs.append(
TiedLayerSpec(
f"block_{layer_idx}",
ParallelTransformerLayerPipe,
init_method=init_method,
forward_fn=lambda module, encoded_and_target: (encoded_and_target[0], module(encoded_and_target[1], encoder_output=encoded_and_target[0])),
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
layer_number=layer_idx,
layer_type=LayerType.decoder,
self_attn_mask_type=AttnMaskType.padding,
tied_weight_attr=None,
tied_weight_attrs=["self_attention", "mlp"]
)
)

# Drop encoded tokens
self.specs.append(lambda encoded_and_target: encoded_and_target[1])

# Final layernorm after decoder layers
self.specs.append(
LayerSpec(
LayerNorm,
args.hidden_size,
eps=args.layernorm_epsilon
))

# Undo data format change
self.specs.append(lambda x: x.transpose(0, 1).contiguous())

def _logits_helper(embedding, lm_output):
"""A wrapper to massage inputs/outputs from pipeline. """
return parallel_lm_logits(
lm_output,
embedding.word_embeddings_weight,
self.parallel_output)

self.specs.append(
TiedLayerSpec('embed',
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
forward_fn=_logits_helper,
tied_weight_attr='word_embeddings_weight')
)

# Convert to fp32 if needed
if args.fp16 or args.bf16:
self.specs.append(float16_to_fp32)

if args.checkpoint_activations:
interval = args.checkpoint_num_layers
else:
interval = 0

from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
num_mp=mpu.get_tensor_model_parallel_world_size(),
num_dp=mpu.get_data_parallel_world_size())

# here one can extend the regex to include more layers to be counted towards partitioning,
# e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first
# and last embedding layers and then partition that transformers+2 layers - so to get a good
balance you may want to use fewer transformer layers
#
# caveat emptor: the current implementation of PP fails unless each stage has at least one
# transformer layer
if args.pp_partition_method is not None:
partition_method = args.pp_partition_method
else:
partition_method = 'type:transformer'

super().__init__(layers=self.specs,
loss_fn=cross_entropy,
topology=topo,
activation_checkpoint_interval=interval,
partition_method=partition_method)
4 changes: 2 additions & 2 deletions megatron/text_generation_utils.py
@@ -26,7 +26,7 @@
from megatron import get_args
from megatron import get_tokenizer
from megatron import mpu
from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model
from megatron.utils import get_attention_masks_and_position_ids, unwrap_model
from megatron.p2p_communication import recv_forward, send_forward

# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
@@ -42,7 +42,7 @@ def get_batch(context_tokens):
# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda()
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
attention_mask, _, position_ids = get_attention_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
11 changes: 7 additions & 4 deletions megatron/utils.py
@@ -151,14 +151,16 @@ def check_adlr_autoresume_termination(iteration, model,
sys.exit(0)


def get_ltor_masks_and_position_ids(

def get_attention_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
prefix_indices,
loss_on_targets_only,
ltor=True,
):
"""
Build masks and position id for left to right model.
@@ -177,9 +179,10 @@ def get_ltor_masks_and_position_ids(
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(torch.ones(
(att_mask_batch, seq_length, seq_length), device=data.device)).view(
att_mask_batch, 1, seq_length, seq_length)
attention_mask = torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
if ltor:
attention_mask = torch.tril(attention_mask)
attention_mask = attention_mask.view(att_mask_batch, 1, seq_length, seq_length)

# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
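The new ltor switch makes the renamed helper usable for both causal (GPT/prefix-LM) and bidirectional (encoder-style) attention: the mask is built dense and only lower-triangularized when ltor is True. A minimal standalone sketch of just that mask construction (the helper name is illustrative; the reset-position, prefix, and loss-mask handling of the full function are omitted):

import torch

def build_attention_mask(batch, seq_length, ltor=True, device="cpu"):
    # Dense all-ones mask; lower-triangularize only for left-to-right (causal) models.
    mask = torch.ones((batch, seq_length, seq_length), device=device)
    if ltor:
        mask = torch.tril(mask)
    return mask.view(batch, 1, seq_length, seq_length)

print(build_attention_mask(1, 4, ltor=True)[0, 0])   # causal lower-triangular mask
print(build_attention_mask(1, 4, ltor=False)[0, 0])  # full bidirectional attention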
6 changes: 3 additions & 3 deletions pretrain_gpt.py
@@ -25,7 +25,7 @@
from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group
from megatron.model import GPTModel, GPTModelPipe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices
from megatron.utils import get_attention_masks_and_position_ids, get_prefix_indices
from megatron.utils import average_losses_across_data_parallel_group

import deepspeed
@@ -110,7 +110,7 @@ def get_batch(data_iterator):
tokens = tokens_[:, :-1].contiguous()

# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,
@@ -141,7 +141,7 @@ def get_batch_pipe(data):
tokens = tokens_[:, :-1].contiguous()

# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids(
tokens,
tokenizer.eod,
args.reset_position_ids,