diff --git a/megatron/model/shared_t5_model.py b/megatron/model/shared_t5_model.py
index b375c6742..908568f83 100644
--- a/megatron/model/shared_t5_model.py
+++ b/megatron/model/shared_t5_model.py
@@ -49,7 +49,7 @@ def _to_16bit(inputs):
             else:
                 return inputs
 
-        self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss))
+        self.specs.append(lambda inputss: tuple(tuple(_to_16bit(inputs)) for inputs in inputss))
 
         # Embedding layer
         self.specs.append(TiedLayerSpec('embed',
@@ -57,15 +57,18 @@ def _to_16bit(inputs):
                                         args.hidden_size,
                                         args.padded_vocab_size,
                                         args.hidden_dropout,
+                                        forward_fn=lambda module, inputs, targets: (module(*inputs), module(*targets)),
                                         init_method=init_method,
                                         num_tokentypes=num_tokentypes,
                                         tied_weight_attr='word_embeddings_weight'))
 
         assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention mask s"
+        # Drop everything beside tokens
+        self.specs.append(lambda inputs, targets: (inputs[0], targets[0]))
         if args.fp32_residual_connection:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous().float(), target_tokens.transpose(0, 1).contiguous().float()))
         else:
-            self.specs.append(lambda x: x.transpose(0, 1).contiguous())
+            self.specs.append(lambda input_tokens, target_tokens: (input_tokens.transpose(0, 1).contiguous(), target_tokens.transpose(0, 1).contiguous()))
 
         ### ----- Encoder -----
         for layer_idx in range(args.num_layers):
@@ -74,22 +77,21 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
-                    # Inputs: (input_tokens, target_tokens,
-                    forward_fn=lambda module, *inputs: ,
+                    forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
                     layer_type=LayerType.encoder,
                     layer_number=layer_idx,
                     self_attn_mask_type=AttnMaskType.causal,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 ))
 
 
         # Final layernorm after encoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
+                forward_fn=lambda module, input_tokens, target_tokens: (module(input_tokens), target_tokens),
                 eps=args.layernorm_epsilon
             ))
 
@@ -100,19 +102,22 @@ def _to_16bit(inputs):
                     f"block_{layer_idx}",
                     ParallelTransformerLayerPipe,
                     init_method=init_method,
+                    forward_fn=lambda module, encoded_tokens, target_tokens: (encoded_tokens, module(target_tokens, encoder_output=encoded_tokens)),
                     output_layer_init_method=scaled_init_method_normal(args.init_method_std, args.num_layers),
                     layer_number=layer_idx,
                     layer_type=LayerType.decoder,
                     self_attn_mask_type=AttnMaskType.padding,
-                    tied_weight_attrs=["input_layernorm", "self_attention", "post_attention_layernorm", "mlp"]
+                    tied_weight_attrs=["self_attention", "mlp"]
                 )
             )
 
 
+        # Drop encoded tokens
+        self.specs.append(lambda encoded_tokens, target_tokens: target_tokens)
+
         # Final layernorm after decoder layers
         self.specs.append(
-            TiedLayerSpec(
-                "final_layer_norm",
+            LayerSpec(
                 LayerNorm,
                 args.hidden_size,
                 eps=args.layernorm_epsilon
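
For reference, a minimal standalone sketch (not part of this patch) of the tuple-threading pattern the forward_fn lambdas above rely on: every pipeline stage receives an (input_tokens, target_tokens) pair, encoder-style stages transform only the first element, decoder-style stages transform the second conditioned on the first, and a final lambda drops the encoded inputs. It uses plain callables instead of DeepSpeed LayerSpec/TiedLayerSpec, and all names below are illustrative only.

    # Illustrative sketch -- mirrors the forward_fn pattern in the diff above.

    def run_specs(specs, inputs, targets):
        """Thread an (inputs, targets) pair through a list of stage callables."""
        state = (inputs, targets)
        for spec in specs[:-1]:
            state = spec(*state)      # every intermediate stage returns a pair
        return specs[-1](*state)      # the last stage may return a single value

    # Encoder-style stage: transforms the inputs, passes the targets through untouched.
    encoder_stage = lambda input_tokens, target_tokens: (input_tokens + 1, target_tokens)

    # Decoder-style stage: keeps the encoded inputs, transforms the targets
    # conditioned on them (analogous to module(target_tokens, encoder_output=...)).
    decoder_stage = lambda encoded_tokens, target_tokens: (encoded_tokens, target_tokens + encoded_tokens)

    # Final stage drops the encoded inputs, like the
    # `lambda encoded_tokens, target_tokens: target_tokens` spec in the diff.
    drop_encoded = lambda encoded_tokens, target_tokens: target_tokens

    print(run_specs([encoder_stage, decoder_stage, drop_encoded], 1, 10))  # prints 12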