Commit

Fix shared T5
thomasw21 committed Jun 21, 2022
1 parent 86aa795 commit 9b4a741
Showing 1 changed file with 16 additions and 11 deletions.
megatron/model/shared_t5_model.py (27 changes: 16 additions & 11 deletions)
@@ -49,23 +49,26 @@ def _to_16bit(inputs):
else:
return inputs

self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))
self.specs.append(lambda inputss: tuple(tuple(_to_16bit(inputs)) for inputs in inputss))
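For context: this spec casts both halves of the (inputs, targets) pair to 16-bit, and the fix additionally materializes each half as a tuple before handing it to the next pipeline stage. A minimal standalone sketch of why that can matter, assuming _to_16bit may hand back a lazy iterable; the stand-in helper below is illustrative only, not the real conversion code in this file:

import torch

def _to_16bit_standin(tensors):
    # Illustrative stand-in: casts each tensor to fp16 but returns a lazy
    # generator rather than a tuple.
    return (t.half() for t in tensors)

inputs = (torch.randn(2, 4), torch.randn(2, 4))
targets = (torch.randn(2, 4), torch.randn(2, 4))
inputss = (inputs, targets)

old = tuple(_to_16bit_standin(x) for x in inputss)         # elements are generators
new = tuple(tuple(_to_16bit_standin(x)) for x in inputss)  # elements are real tuples

print(type(old[0]).__name__)  # generator -- exhausted after one pass, not indexable
print(type(new[0]).__name__)  # tuple -- safe to unpack or index downstream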

# Embedding layer
self.specs.append(TiedLayerSpec('embed',
EmbeddingPipe,
args.hidden_size,
args.padded_vocab_size,
args.hidden_dropout,
forward_fn=lambda module, inputs, targets: (module(*inputs), module(*targets)),
init_method=init_method,
num_tokentypes=num_tokentypes,
tied_weight_attr='word_embeddings_weight'))
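For context: the embedding is registered as a TiedLayerSpec whose forward_fn runs the same module over both the encoder inputs and the decoder targets, so one weight matrix embeds both streams. A minimal sketch of that pattern, using a plain nn.Embedding as a stand-in for EmbeddingPipe (the real module also takes positions, masks, etc., omitted here):

import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=100, embedding_dim=8)  # stand-in for EmbeddingPipe

def embed_forward_fn(module, inputs, targets):
    # mirrors: lambda module, inputs, targets: (module(*inputs), module(*targets))
    return module(*inputs), module(*targets)

input_tokens = (torch.randint(0, 100, (2, 5)),)   # encoder-side token ids
target_tokens = (torch.randint(0, 100, (2, 5)),)  # decoder-side token ids

enc_emb, dec_emb = embed_forward_fn(embed, input_tokens, target_tokens)
print(enc_emb.shape, dec_emb.shape)  # both produced by the same tied weights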

assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention mask s"
# Drop everything beside tokens
self.specs.append(lambda inputs, targets: (inputs[0], targets[0]))
if args.fp32_residual_connection:
self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous().float(), target_tokens.transpose(0, 1).contiguous().float()))
else:
self.specs.append(lambda x: x.transpose(0, 1).contiguous())
self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous(), target_tokens.transpose(0, 1).contiguous()))
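For context: the same transpose (plus an optional upcast when fp32 residual connections are enabled) is now applied to both streams rather than to a single tensor; Megatron's transformer stack typically expects a sequence-first [seq, batch, hidden] layout, which is presumably what the flip is for. A minimal sketch, with the per-stream transform factored into a hypothetical helper named to_seq_first:

import torch

def to_seq_first(x, fp32_residual_connection=False):
    # mirrors the per-stream transform: transpose(0, 1).contiguous()[.float()]
    x = x.transpose(0, 1).contiguous()
    return x.float() if fp32_residual_connection else x

input_emb = torch.randn(2, 5, 8).half()   # [batch, seq, hidden]
target_emb = torch.randn(2, 5, 8).half()

input_emb, target_emb = to_seq_first(input_emb, True), to_seq_first(target_emb, True)
print(input_emb.shape, input_emb.dtype)  # torch.Size([5, 2, 8]) torch.float32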

### ----- Encoder -----
for layer_idx in range(args.num_layers):
@@ -74,22 +77,21 @@ def _to_16bit(inputs):
f"block_{layer_idx}",
ParallelTransformerLayerPipe,
init_method=init_method,
# Inputs: (input_tokens, target_tokens,
forward_fn=lambda module, *inputs: ,
forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
layer_type=LayerType.encoder,
layer_number=layer_idx,
self_attn_mask_type=AttnMaskType.causal,
tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
tied_weight_attrs=["self_attention", "mlp"]
))
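For context: each encoder block's forward_fn now sends only the encoder stream through the layer and threads the target stream along untouched, and the tie list is narrowed from all four submodules to self_attention and mlp, so the layernorms are no longer shared with the decoder block of the same index. A minimal sketch of the forward_fn pattern, with an nn.Linear standing in for ParallelTransformerLayerPipe:

import torch
import torch.nn as nn

block = nn.Linear(8, 8)  # stand-in for the shared transformer block

def encoder_forward_fn(module, input_tokens, target_tokens):
    # mirrors: lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens)
    return module(input_tokens), target_tokens

input_hidden = torch.randn(5, 2, 8)   # encoder-side hidden states
target_hidden = torch.randn(5, 2, 8)  # decoder-side stream, passed through unchanged
input_hidden, target_hidden = encoder_forward_fn(block, input_hidden, target_hidden)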

# Final layernorm after encoder layers
self.specs.append(
TiedLayerSpec(
"final_layer_norm",
LayerSpec(
LayerNorm,
args.hidden_size,
forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
eps=args.layernorm_epsilon
))

@@ -100,19 +102,22 @@ def _to_16bit(inputs):
f"block_{layer_idx}",
ParallelTransformerLayerPipe,
init_method=init_method,
forward_fn=lambda module, encoded_tokens, target_tokens: (encoded_tokens, module(target_tokens, encoder_output=encoded_tokens)),
output_layer_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers),
layer_number=layer_idx,
layer_type=LayerType.decoder,
self_attn_mask_type=AttnMaskType.padding,
tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
tied_weight_attrs=["self_attention", "mlp"]
)
)
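For context: the decoder-side forward_fn keeps the encoded stream intact (later stages still need it) and also feeds it to the shared block as encoder_output, i.e. the cross-attention memory. A minimal sketch with a toy module that accepts the same encoder_output keyword; the addition inside it is a crude stand-in for cross-attention, not the real ParallelTransformerLayerPipe:

import torch
import torch.nn as nn

class ToyDecoderBlock(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, encoder_output=None):
        out = self.proj(hidden_states)
        if encoder_output is not None:
            out = out + encoder_output  # crude stand-in for cross-attention
        return out

block = ToyDecoderBlock(8)

def decoder_forward_fn(module, encoded_tokens, target_tokens):
    # mirrors: lambda module, encoded_tokens, target_tokens:
    #          (encoded_tokens, module(target_tokens, encoder_output=encoded_tokens))
    return encoded_tokens, module(target_tokens, encoder_output=encoded_tokens)

encoded = torch.randn(5, 2, 8)
targets = torch.randn(5, 2, 8)
encoded, targets = decoder_forward_fn(block, encoded, targets)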

# Drop encoded tokens
self.specs.append(lambda encoded_tokens, target_tokens: target_tokens)

# Final layernorm after decoder layers
self.specs.append(
TiedLayerSpec(
"final_layer_norm",
LayerSpec(
LayerNorm,
args.hidden_size,
eps=args.layernorm_epsilon
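Taken together, the spec list threads an (input_tokens, target_tokens) pair through the pipeline: 16-bit cast, tied embedding, shared encoder blocks that touch only the first stream, shared decoder blocks that consume both, then a stage that drops the encoded stream so only the decoder output reaches the final layernorm. A toy end-to-end sketch of that flow using plain callables and tiny modules (not DeepSpeed's PipelineModule; the + encoded below is a crude stand-in for cross-attention):

import torch
import torch.nn as nn

hidden = 8
embed = nn.Embedding(100, hidden)            # tied embedding for both streams
shared_block = nn.Linear(hidden, hidden)     # stand-in for a shared encoder/decoder block
enc_final_norm = nn.LayerNorm(hidden)
dec_final_norm = nn.LayerNorm(hidden)

specs = [
    lambda inputs, targets: (embed(inputs), embed(targets)),              # tied embedding
    lambda inputs, targets: (shared_block(inputs), targets),              # "encoder": first stream only
    lambda inputs, targets: (enc_final_norm(inputs), targets),            # encoder final layernorm
    lambda encoded, targets: (encoded, shared_block(targets) + encoded),  # "decoder": uses both streams
    lambda encoded, targets: targets,                                     # drop encoded tokens
    lambda tokens: dec_final_norm(tokens),                                # decoder final layernorm
]

x = (torch.randint(0, 100, (2, 5)), torch.randint(0, 100, (2, 5)))  # (input_tokens, target_tokens)
for stage in specs:
    x = stage(*x) if isinstance(x, tuple) else stage(x)

print(x.shape)  # torch.Size([2, 5, 8]) -- only the decoder-side stream survives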
