Distill BLOOM - tentative 2 #354

Draft · wants to merge 5 commits into main
20 changes: 20 additions & 0 deletions megatron/arguments.py
@@ -242,6 +242,9 @@ def parse_args(extra_args_provider=None, defaults={},
# Checks.
if args.ffn_hidden_size is None:
args.ffn_hidden_size = 4 * args.hidden_size

if args.student_ffn_hidden_size is None:
args.student_ffn_hidden_size = 4 * args.student_hidden_size

if args.kv_channels is None:
assert args.hidden_size % args.num_attention_heads == 0
@@ -353,9 +356,18 @@ def _add_network_size_args(parser):
help='Number of transformer layers.')
group.add_argument('--hidden-size', type=int, default=None,
help='Transformer hidden size.')
group.add_argument('--student-num-layers', type=int, default=None,
help='Number of student transformer layers.')
group.add_argument('--student-hidden-size', type=int, default=None,
help='Student transformer hidden size.')
group.add_argument('--student-num-attention-heads', type=int, default=None,
help='Number of student transformer attention heads.')
group.add_argument('--ffn-hidden-size', type=int, default=None,
help='Transformer Feed-Forward Network hidden size. '
'This is set to 4*hidden-size if not provided')
group.add_argument('--student-ffn-hidden-size', type=int, default=None,
help='Student Transformer Feed-Forward Network hidden size. '
'This is set to 4*student-hidden-size if not provided')
group.add_argument('--num-attention-heads', type=int, default=None,
help='Number of transformer attention heads.')
group.add_argument('--kv-channels', type=int, default=None,
@@ -660,6 +672,10 @@ def _add_checkpointing_args(parser):
help='Do not save current rng state.')
group.add_argument('--load', type=str, default=None,
help='Directory containing a model checkpoint.')
group.add_argument('--teacher-load', type=str, default=None,
help='Directory containing the teacher model checkpoint.')
group.add_argument('--student-load', type=str, default=None,
help='Directory containing the student model checkpoint.')
group.add_argument('--no-load-optim', action='store_true', default=None,
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', default=None,
@@ -715,8 +731,12 @@ def _add_distributed_args(parser):

group.add_argument('--tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism.')
group.add_argument('--student-tensor-model-parallel-size', type=int, default=1,
help='Degree of tensor model parallelism for the student model.')
group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism.')
group.add_argument('--student-pipeline-model-parallel-size', type=int, default=1,
help='Degree of pipeline model parallelism for the student model.')
group.add_argument('--model-parallel-size', type=int, default=None,
help='Old model parallel argument, do not use. Use '
'--tensor-model-parallel-size instead.')
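As a quick illustration of the new defaulting rule added to parse_args() above, here is a standalone argparse sketch with only the two relevant flags; the values are hypothetical and not part of the diff.

# Illustrative sketch only: the 4x default rule for the student FFN size.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--student-hidden-size', type=int, default=None)
parser.add_argument('--student-ffn-hidden-size', type=int, default=None)
args = parser.parse_args(['--student-hidden-size', '1024'])

if args.student_ffn_hidden_size is None:
    args.student_ffn_hidden_size = 4 * args.student_hidden_size

print(args.student_ffn_hidden_size)  # -> 4096, mirroring the teacher's ffn_hidden_size rule
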
17 changes: 16 additions & 1 deletion megatron/model/fused_layer_norm.py
@@ -41,7 +41,10 @@ def forward(ctx, input, weight, bias, normalized_shape, eps):

ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
if isinstance(input, tuple):
input_ = input[0].contiguous()
else:
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
@@ -109,3 +112,15 @@ def forward(self, input):
input, self.weight, self.bias, self.normalized_shape, self.eps)
else:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias)

class MixedFusedLayerNormTeacher(MixedFusedLayerNorm):

@torch.no_grad()
def forward(self, input):
input, original_input = input
return (super().forward(input), original_input)

class MixedFusedLayerNormStudent(MixedFusedLayerNorm):
def forward(self, input):
input, logits_teacher = input
return (super().forward(input), logits_teacher)
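
For clarity, a minimal, self-contained sketch of the tuple pass-through pattern these two wrappers introduce, using plain torch.nn.LayerNorm in place of the fused CUDA kernel; the class names and shapes below are illustrative only.

import torch

class LayerNormTeacherSketch(torch.nn.LayerNorm):
    @torch.no_grad()  # teacher-side activations carry no gradient
    def forward(self, inputs):
        hidden, original_input = inputs          # unpack (hidden_states, original_input)
        return (super().forward(hidden), original_input)

class LayerNormStudentSketch(torch.nn.LayerNorm):
    def forward(self, inputs):
        hidden, teacher_logits = inputs          # unpack (hidden_states, teacher_logits)
        return (super().forward(hidden), teacher_logits)

# The second tuple element rides along untouched so a later pipeline stage can consume it.
hidden_states = torch.randn(8, 2, 16)
teacher_ln = LayerNormTeacherSketch(16)
normed, passthrough = teacher_ln((hidden_states, hidden_states.clone()))
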
3 changes: 2 additions & 1 deletion megatron/model/fused_softmax.py
@@ -188,7 +188,7 @@ def forward_fused_softmax(self, input, mask):

if self.attn_mask_type == AttnMaskType.causal:
assert sq == sk, "causal mask is only for self attention"
assert mask is None, "Mask is silently ignored due to the use of a custom kernel"
# assert mask is None, "Mask is silently ignored due to the use of a custom kernel"

# input is 3D tensor (attn_batches, sq, sk)
input = input.view(-1, sq, sk)
@@ -236,3 +236,4 @@ def get_batch_per_block(sq, sk, b, np):
import scaled_masked_softmax_cuda

return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)

183 changes: 167 additions & 16 deletions megatron/model/gpt_model.py
@@ -30,9 +30,11 @@

from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec
from megatron.model.fused_layer_norm import MixedFusedLayerNorm as LayerNorm
from megatron.model.fused_layer_norm import MixedFusedLayerNormTeacher as LayerNormTeacher
from megatron.model.fused_layer_norm import MixedFusedLayerNormStudent as LayerNormStudent
from megatron.model.module import float16_to_fp32
from .language_model import EmbeddingPipe
from .transformer import ParallelTransformerLayerPipe
from .language_model import EmbeddingPipe, EmbeddingPipeTeacher, EmbeddingPipeStudent
from .transformer import ParallelTransformerLayerPipe, ParallelTransformerLayerPipeTeacher, ParallelTransformerLayerPipeStudent


def post_language_model_processing(lm_output, labels, logit_weights,
@@ -195,6 +197,57 @@ def CrossEntropy(output, labels):
return CrossEntropy


def get_ts_loss(is_prefix: bool):
def TeacherStudentLoss(output, labels):
output, teacher_logits = output[0], output[1]
labels, loss_mask = labels[0], labels[1]

args = get_args()

losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)

if is_prefix:
micro_batch_size, sequence_length = loss_mask.shape
average_tokens_per_sample: torch.Tensor
if args.loss_on_targets_only:
# HACK: This is useful when we obtain loss masks that are microbatch dependent. Consequently, if we want to
# preserve the notion that all tokens have the same impact on the loss, we can only normalise using a
# microbatch-independent value. It should be the expected weight over a microbatch.
# Here we still use `sequence_length`, which is batch-size dependent, in order to be backwards compatible with
# current experiments on vanilla GPT.
if args.reweight_loss_based_on_position_frequency:
reweight = torch.arange(
sequence_length, 0, -1, dtype=torch.float, device=loss_mask.device
) / (sequence_length + 1) * 2
average_tokens_per_sample = reweight.flip(-1).cumsum(-1).mean()
else:
average_tokens_per_sample = (sequence_length + 1) / 2
else:
average_tokens_per_sample = sequence_length
expected_number_of_tokens = average_tokens_per_sample * micro_batch_size
else:
expected_number_of_tokens = loss_mask.sum()

loss_mask = loss_mask.view(-1)
loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens

# TODO: check if the formula is correct
teacher_logits = teacher_logits.detach()
# First pass it on CPU - otherwise we get OOM errors
softmax_labels = torch.nn.Softmax(dim=-1)(teacher_logits)
softmax_labels = softmax_labels.permute(1, 0, 2)

student_log_softmax = -torch.nn.LogSoftmax(dim=-1)(output)

# print(output.shape, teacher_logits.shape)
# print(student_log_softmax.shape, softmax_labels.shape)
softmax_logits = student_log_softmax * softmax_labels
logits_loss = softmax_logits.mean()

return loss + logits_loss
return TeacherStudentLoss
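
For readers following the loss formula, here is a single-process sketch of the same computation, with torch.nn.functional.cross_entropy standing in for mpu.vocab_parallel_cross_entropy (so it assumes un-sharded vocab logits) and with the prefix-LM normalisation branch omitted; the function name and shapes are illustrative only.

import torch
import torch.nn.functional as F

def teacher_student_loss_sketch(student_logits, teacher_logits, labels, loss_mask):
    # Hard-label term: token-level LM cross entropy, averaged over unmasked target tokens.
    # student_logits: [seq, batch, vocab]; labels, loss_mask: [batch, seq].
    lm_losses = F.cross_entropy(student_logits.permute(1, 2, 0),  # [batch, vocab, seq]
                                labels, reduction='none')
    lm_loss = (lm_losses * loss_mask).sum() / loss_mask.sum()

    # Soft-label term: teacher softmax against student negative log-softmax, teacher detached,
    # teacher logits permuted from [batch, seq, vocab] to [seq, batch, vocab].
    teacher_probs = F.softmax(teacher_logits.detach(), dim=-1).permute(1, 0, 2)
    distill_loss = (-F.log_softmax(student_logits, dim=-1) * teacher_probs).mean()

    return lm_loss + distill_loss

# Hypothetical shapes: seq=4, batch=2, vocab=10
loss = teacher_student_loss_sketch(torch.randn(4, 2, 10), torch.randn(2, 4, 10),
                                   torch.randint(0, 10, (2, 4)), torch.ones(2, 4))
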


class GPTModelPipe(PipelineModule,MegatronModule):
"""GPT-2 Language model."""

@@ -223,7 +276,7 @@ def _to_float16(inputs):

# Embedding layer
self.specs.append(TiedLayerSpec('embed',
EmbeddingPipe,
EmbeddingPipeTeacher,
args.hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
@@ -239,14 +292,14 @@ def _to_float16(inputs):
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
else:
if getattr(args, 'pretrain_causal_attention', False):
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1]))
else:
# EmbeddingPipe returns attention mask as well
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))

for layer_idx in range(args.num_layers):
self.specs.append(
LayerSpec(ParallelTransformerLayerPipe,
LayerSpec(ParallelTransformerLayerPipeTeacher,
init_method=init_method,
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
@@ -256,14 +309,16 @@

# Undo data format change
def undo(x):
if not getattr(args, 'pretrain_causal_attention', False):
x = x[0]
# if not getattr(args, 'pretrain_causal_attention', False):
# x = x[0]
if isinstance(x, tuple):
return (x[0][0].transpose(0, 1).contiguous(), x[1])
return x.transpose(0, 1).contiguous()
self.specs.append(undo)

# Final layernorm after transformer layers
self.specs.append(
LayerSpec(LayerNorm,
LayerSpec(LayerNormTeacher,
args.hidden_size,
eps=args.layernorm_epsilon))

@@ -276,7 +331,7 @@ def _logits_helper(embedding, lm_output):

self.specs.append(
TiedLayerSpec('embed',
EmbeddingPipe,
EmbeddingPipeTeacher,
args.hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
@@ -286,34 +341,130 @@ def _logits_helper(embedding, lm_output):
tied_weight_attr='word_embeddings_weight')
)

# self.specs.append(lambda x: print(x[0]))
# Convert to fp32 if needed
if args.fp16 or args.bf16:
self.specs.append(float16_to_fp32)
# if args.fp16 or args.bf16:
# self.specs.append(float16_to_fp32)

if args.checkpoint_activations:
interval = args.checkpoint_num_layers
else:
interval = 0

from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
num_mp=mpu.get_tensor_model_parallel_world_size(),
num_dp=mpu.get_data_parallel_world_size())


# here one can extend the regex to include more layers to be counted towards partitioning,
# e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first
# and last embedding layers and then partition that transformers+2 layers - so to get a good
# balance you may want to use less transformer layers
#
# caveat emptor: the current implementation of PP fails unless each stage has at least one
# transformer layer

# Beginning student model

init_method = init_method_normal(args.init_method_std)


def _to_float16(inputs):
if args.fp16:
return fp32_to_float16(inputs, lambda v: v.half())
elif args.bf16:
return fp32_to_float16(inputs, lambda v: v.bfloat16())
else:
return inputs

# self.specs.append(_to_float16)
self.specs.append(lambda x: (x[0], x[1]))

# Embedding layer
self.specs.append(TiedLayerSpec('embed_student',
EmbeddingPipeStudent,
args.student_hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
tied_weight_attr='word_embeddings_weight'))

if args.fp32_residual_connection:
if getattr(args, 'pretrain_causal_attention', False):
self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
else:
# EmbeddingPipe returns attention mask as well
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
else:
if getattr(args, 'pretrain_causal_attention', False):
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1]))
else:
# EmbeddingPipe returns attention mask as well
self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:]))

for layer_idx in range(args.student_num_layers):
self.specs.append(
LayerSpec(ParallelTransformerLayerPipeStudent,
init_method=init_method,
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.student_num_layers),
layer_number=layer_idx,
# TODO: Change naming of class from GPT to something that encapsulates prefix LM.
self_attn_mask_type=attn_mask_type))

# Undo data format change
# def undo(x):
# if not getattr(args, 'pretrain_causal_attention', False):
# x = x[0]
# return x.transpose(0, 1).contiguous()
# self.specs.append(undo)

# Final layernorm after transformer layers
self.specs.append(
LayerSpec(LayerNormStudent,
args.student_hidden_size,
eps=args.layernorm_epsilon))

def _logits_helper(embedding, lm_output):
"""A wrapper to massage inputs/outputs from pipeline. """
return parallel_lm_logits(
lm_output,
embedding.word_embeddings_weight,
self.parallel_output)

self.specs.append(
TiedLayerSpec('embed_student',
EmbeddingPipeStudent,
args.student_hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
init_method=init_method,
num_tokentypes=num_tokentypes,
forward_fn=_logits_helper,
tied_weight_attr='word_embeddings_weight')
)

# Convert to fp32 if needed
if args.fp16 or args.bf16:
self.specs.append(float16_to_fp32)

if args.checkpoint_activations:
interval = args.checkpoint_num_layers
else:
interval = 0

# transformer layer
if args.pp_partition_method is not None:
partition_method = args.pp_partition_method
else:
partition_method = 'type:transformer'

from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(),
num_mp=mpu.get_tensor_model_parallel_world_size(),
num_dp=mpu.get_data_parallel_world_size())



super().__init__(layers=self.specs,
loss_fn=get_cross_entropy(is_prefix=attn_mask_type is AttnMaskType.prefix),
loss_fn=get_ts_loss(is_prefix=attn_mask_type is AttnMaskType.prefix),
topology=topo,
activation_checkpoint_interval=interval,
partition_method=partition_method)
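
Stepping back from the diff: the spec list now stacks the frozen teacher pipeline in front of the student pipeline and threads the teacher logits to the loss as the second tuple element. Below is a toy, non-pipelined sketch of that data flow (all names hypothetical, plain PyTorch rather than DeepSpeed pipeline specs).

import torch

class ToyDistilledGPT(torch.nn.Module):
    """Toy stand-in: the teacher runs first and its logits ride alongside the student's."""
    def __init__(self, vocab=10, d_teacher=32, d_student=16):
        super().__init__()
        self.teacher = torch.nn.Sequential(torch.nn.Embedding(vocab, d_teacher),
                                           torch.nn.Linear(d_teacher, vocab))
        self.student = torch.nn.Sequential(torch.nn.Embedding(vocab, d_student),
                                           torch.nn.Linear(d_student, vocab))

    def forward(self, tokens):
        with torch.no_grad():                    # teacher stages are frozen
            teacher_logits = self.teacher(tokens)
        student_logits = self.student(tokens)
        return (student_logits, teacher_logits)  # same tuple shape the loss_fn above expects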