diff --git a/Jenkinsfile b/Jenkinsfile
index 9d9bea8871bb..b6eeb7c53ade 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -3694,6 +3694,50 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
         sh "rm -rf examples/nlp/language_modeling/t5_index_mappings"
       }
     }
+    stage('L2: Megatron T5 w/ Mixture of Expert Pretraining') {
+      when {
+        anyOf {
+          branch 'main'
+          changeRequest target: 'main'
+        }
+      }
+      failFast true
+      steps {
+        sh "python examples/nlp/language_modeling/megatron_t5_pretraining.py \
+        trainer.devices=2 \
+        trainer.accelerator=gpu \
+        trainer.log_every_n_steps=1 \
+        trainer.val_check_interval=10 \
+        trainer.limit_val_batches=2 \
+        trainer.accumulate_grad_batches=1 \
+        trainer.max_steps=10 \
+        trainer.precision=16 \
+        trainer.gradient_clip_val=1.0 \
+        exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \
+        model.pipeline_model_parallel_split_rank=1 \
+        model.seq_length=256 \
+        model.encoder.num_layers=4 \
+        model.decoder.num_layers=1 \
+        model.encoder.num_moe_experts=4 \
+        model.decoder.num_moe_experts=4 \
+        model.encoder.moe_frequency=3 \
+        model.decoder.moe_frequency=1 \
+        model.encoder.hidden_size=64 \
+        model.decoder.hidden_size=64 \
+        model.encoder.num_attention_heads=8 \
+        model.decoder.num_attention_heads=8 \
+        model.decoder.ffn_hidden_size=2048 \
+        model.encoder.activation='gelu' \
+        model.encoder.activations_checkpoint_method='block' \
+        model.encoder.activations_checkpoint_num_layers=1 \
+        model.encoder.transformer_block_type='pre_ln' \
+        model.decoder.transformer_block_type='post_ln' \
+        model.data.data_prefix=[.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document,.5,/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document] \
+        model.data.index_mapping_dir=examples/nlp/language_modeling/t5_index_mappings"
+        sh "rm -rf examples/nlp/language_modeling/t5_pretrain_results"
+        sh "rm -rf examples/nlp/language_modeling/t5_index_mappings"
+      }
+    }
     stage('L2: Megatron T5 Prompt Learning') {
       when {
         anyOf {
           branch 'main'
           changeRequest target: 'main'
         }
       }
diff --git a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml
index f68b9ecf87b2..dc17fcd59bbb 100644
--- a/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_model_base_config.yaml
@@ -33,3 +33,6 @@ activations_checkpoint_method: null # 'uniform', 'block'
 activations_checkpoint_num_layers: 1
 megatron_legacy: False # Whether to use the legacy Megatron model. This affects the way q,k,v is partitioned from the mixed q,k,v layer in ParallelAttention. This needs to be True for models converted from HF.
 normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True.
+num_moe_experts: 1 # When >1, FFNs are changed to MoE layers
+moe_frequency: 1 # every Nth FFN layer will be made MoE
+moe_dropout: 0.0 # Dropout value for MoE layers
\ No newline at end of file
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py
index 63d14cfe84d1..23902f9a57b9 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py
@@ -82,6 +82,9 @@ def get_decoder_model(
     normalize_attention_scores=True,
     sequence_parallel=False,
     gradient_accumulation_fusion=False,
+    num_moe_experts=1,
+    moe_frequency=1,
+    moe_dropout=0.0,
 ):
     """Build language model and return along with the key to save."""
 
@@ -134,6 +137,9 @@
             parent_model_type=parent_model_type,
             megatron_legacy=megatron_legacy,
             normalize_attention_scores=normalize_attention_scores,
+            num_moe_experts=num_moe_experts,
+            moe_frequency=moe_frequency,
+            moe_dropout=moe_dropout,
         )
     elif arch == "retro":
         decoder = MegatronRetrievalTransformerDecoderModule(
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
index 1917979fc66a..c72a744d99e6 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py
@@ -84,6 +84,9 @@ def get_encoder_model(
     normalize_attention_scores=True,
     sequence_parallel=False,
     gradient_accumulation_fusion=False,
+    num_moe_experts=1,
+    moe_frequency=1,
+    moe_dropout=0.0,
 ):
     """Build language model and return along with the key to save."""
 
@@ -136,6 +139,9 @@
             parent_model_type=parent_model_type,
             megatron_legacy=megatron_legacy,
             normalize_attention_scores=normalize_attention_scores,
+            num_moe_experts=num_moe_experts,
+            moe_frequency=moe_frequency,
+            moe_dropout=moe_dropout,
         )
     elif arch == "retro":
         encoder = MegatronRetrievalTransformerEncoderModule(
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py
index 5104855c860d..2a3e6bc8d65c 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py
@@ -80,6 +80,9 @@ def __init__(
         parent_model_type=ModelType.encoder_or_decoder,
         megatron_legacy=False,
         normalize_attention_scores=True,
+        num_moe_experts=1,
+        moe_frequency=1,
+        moe_dropout=0.0,
     ):
         super(MegatronTransformerDecoderModule, self).__init__()
 
@@ -139,6 +142,9 @@ def __init__(
             gradient_accumulation_fusion=False,  # TODO: This has to be False for enc-dec models for now.
             megatron_legacy=megatron_legacy,
             normalize_attention_scores=normalize_attention_scores,
+            num_moe_experts=num_moe_experts,
+            moe_frequency=moe_frequency,
+            moe_dropout=moe_dropout,
         )
 
         self._model_key = 'model'
diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py
index b48d89cd9644..a9b4eb4298b5 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py
@@ -77,6 +77,9 @@ def __init__(
         parent_model_type=ModelType.encoder_or_decoder,
         megatron_legacy=False,
         normalize_attention_scores=True,
+        num_moe_experts=1,
+        moe_frequency=1,
+        moe_dropout=0.0,
     ):
         super(MegatronTransformerEncoderModule, self).__init__()
 
@@ -137,6 +140,9 @@ def __init__(
             gradient_accumulation_fusion=False,  # TODO: This has to be False for enc-dec models for now.
             megatron_legacy=megatron_legacy,
             normalize_attention_scores=normalize_attention_scores,
+            num_moe_experts=num_moe_experts,
+            moe_frequency=moe_frequency,
+            moe_dropout=moe_dropout,
         )
 
         self._model_key = 'model'
diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
index 573cbab7fc4c..2ceb8252f0a5 100644
--- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
+++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py
@@ -196,6 +196,9 @@ def __init__(
                 num_self_attention_per_cross_attention=encoder_cfg.get('num_self_attention_per_cross_attention', 1),
                 megatron_legacy=encoder_cfg.get('megatron_legacy', False),
                 normalize_attention_scores=encoder_cfg.get('normalize_attention_scores', True),
+                num_moe_experts=encoder_cfg.get('num_moe_experts', 1),
+                moe_frequency=encoder_cfg.get('moe_frequency', 1),
+                moe_dropout=encoder_cfg.get('moe_dropout', 0.0),
             )
 
         if add_decoder:
@@ -300,6 +303,9 @@ def __init__(
                 parent_model_type=ModelType.encoder_and_decoder,
                 megatron_legacy=decoder_cfg.get('megatron_legacy', False),
                 normalize_attention_scores=decoder_cfg.get('normalize_attention_scores', True),
+                num_moe_experts=decoder_cfg.get('num_moe_experts', 1),
+                moe_frequency=decoder_cfg.get('moe_frequency', 1),
+                moe_dropout=decoder_cfg.get('moe_dropout', 0.0),
             )
 
         self.enc_dec_model = MegatronTransformerEncoderDecoderModule(
diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index cbacbd1ee9ac..40057643b509 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -286,6 +286,135 @@ def forward(self, hidden_states):
         return output, output_bias
 
 
+class SwitchMLP(MegatronModule):
+    """Top-1 MoE
+
+    Currently supports Sinkhorn-based expert routing."""
+
+    def __init__(
+        self,
+        num_experts,
+        init_method,
+        output_layer_init_method,
+        hidden_size,
+        ffn_hidden_size,
+        use_cpu_initialization=False,
+        bias_activation_fusion=True,
+        openai_gelu=False,
+        onnx_safe=False,
+        activation='gelu',
+        bias=True,
+        transformer_block_type='pre_ln',
+        normalization='layernorm',
+        layernorm_epsilon=1e-5,
+        persist_layer_norm=False,
+        sequence_parallel=False,
+        gradient_accumulation_fusion=False,
+        dropout=0.0,
+    ):
+        super(SwitchMLP, self).__init__()
+
+        self.num_experts = num_experts
+        self.route_algo = SwitchMLP.sinkhorn
+        self.router = tensor_parallel.RowParallelLinear(
+            hidden_size,
+            num_experts,
+            input_is_parallel=False,
+            init_method=init_method,
+            skip_bias_add=False,
+            use_cpu_initialization=use_cpu_initialization,
+            bias=bias,
+            sequence_parallel_enabled=sequence_parallel,
+            gradient_accumulation_fusion=gradient_accumulation_fusion,
+        )
+
+        mlp_args = {
+            'init_method': init_method,
+            'output_layer_init_method': output_layer_init_method,
+            'hidden_size': hidden_size,
+            'ffn_hidden_size': ffn_hidden_size,
+            'use_cpu_initialization': use_cpu_initialization,
+            'bias_activation_fusion': bias_activation_fusion,
+            'openai_gelu': openai_gelu,
+            'onnx_safe': onnx_safe,
+            'activation': activation,
+            'bias': bias,
+            'transformer_block_type': transformer_block_type,
+            'normalization': normalization,
+            'layernorm_epsilon': layernorm_epsilon,
+            'persist_layer_norm': persist_layer_norm,
+            'sequence_parallel': sequence_parallel,
+            'gradient_accumulation_fusion': gradient_accumulation_fusion,
+            'dropout': dropout,
+        }
+        self.experts = torch.nn.ModuleList([ParallelMLP(**mlp_args) for _ in range(num_experts)])
+
+    def forward(self, hidden_states):
+        hidden_shape = hidden_states.shape
+        route, _ = self.router(hidden_states)
+        route = route.view(-1, self.num_experts)
+        if self.training:
+            with torch.no_grad():
+                norm_route = self.route_algo(
+                    route.detach().to(dtype=torch.float32)
+                )  # explicit fp32 conversion for stability
+                _, max_ind = torch.max(norm_route, dim=1)
+            route = torch.sigmoid(route)
+            max_prob = route[torch.arange(route.size(0)), max_ind]
+        else:
+            route = torch.sigmoid(route)
+            max_prob, max_ind = torch.max(route, dim=1)
+        max_prob = torch.unsqueeze(max_prob, 1)
+
+        hidden_states = hidden_states.view(-1, hidden_shape[-1])
+
+        local_indices = (max_ind == 0).nonzero()
+        hidden = hidden_states[local_indices, :]
+        output, output_bias = self.experts[0](hidden)
+        output_bias = output_bias.expand_as(output)
+
+        output_total = torch.empty_like(hidden_states, dtype=output.dtype)
+        output_bias_total = torch.empty_like(hidden_states, dtype=output_bias.dtype)
+
+        output_total[local_indices, :] = output
+        output_bias_total[local_indices, :] = output_bias
+
+        for expert_num, expert in enumerate(self.experts):
+            if expert_num == 0:
+                continue
+            local_indices = (max_ind == expert_num).nonzero()
+            hidden = hidden_states[local_indices, :]
+            output, output_bias = expert(hidden)
+            output_bias = output_bias.expand_as(output)
+            output_total[local_indices, :] = output
+            output_bias_total[local_indices, :] = output_bias
+
+        output_total = output_total * max_prob
+        output_bias_total = output_bias_total * max_prob
+        output_total = output_total.view(hidden_shape)
+        output_bias_total = output_bias_total.view(hidden_shape)
+
+        return output_total, output_bias_total
+
+    @classmethod
+    def sinkhorn(cls, cost, tol=0.0001):
+        "Megatron-LM's Sinkhorn implementation"
+
+        cost = torch.exp(cost)
+        d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
+        d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
+
+        eps = 0.00000001
+        error = 1e9
+        d1_old = d1
+        while error > tol:
+            d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
+            d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
+            error = torch.mean(torch.abs(d1_old - d1))
+            d1_old = d1
+        return d1 * cost * d0.unsqueeze(1)
+
+
 class CoreAttention(MegatronModule):
     """ Region where selective activation recomputation is applied.
         See Figure 3. in Reducing Activation Recomputation in Large Transformer Models
@@ -1128,6 +1257,9 @@ def __init__(
         activations_checkpoint_granularity=None,
         sequence_parallel=False,
         normalize_attention_scores=True,
+        num_moe_experts=1,
+        moe_frequency=1,
+        moe_dropout=0.0,
     ):
         super(ParallelTransformerLayer_, self).__init__()
 
@@ -1335,25 +1467,47 @@ def __init__(
             self.post_inter_attention_layernorm = MixedFusedRMSNorm(hidden_size, layernorm_epsilon)
 
         # MLP
-        self.mlp = ParallelMLP(
-            init_method=init_method,
-            output_layer_init_method=output_layer_init_method,
-            hidden_size=hidden_size,
-            ffn_hidden_size=ffn_hidden_size,
-            use_cpu_initialization=use_cpu_initialization,
-            bias_activation_fusion=bias_activation_fusion,
-            openai_gelu=openai_gelu,
-            onnx_safe=onnx_safe,
-            activation=activation,
-            bias=bias,
-            transformer_block_type=transformer_block_type,
-            normalization=normalization,
-            layernorm_epsilon=layernorm_epsilon,
-            persist_layer_norm=persist_layer_norm,
-            sequence_parallel=sequence_parallel,
-            gradient_accumulation_fusion=gradient_accumulation_fusion,
-            dropout=ffn_dropout,
-        )
+        if num_moe_experts > 1 and self.layer_number % moe_frequency == 0:
+            self.mlp = SwitchMLP(
+                num_experts=num_moe_experts,
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                hidden_size=hidden_size,
+                ffn_hidden_size=ffn_hidden_size,
+                use_cpu_initialization=use_cpu_initialization,
+                bias_activation_fusion=bias_activation_fusion,
+                openai_gelu=openai_gelu,
+                onnx_safe=onnx_safe,
+                activation=activation,
+                bias=bias,
+                transformer_block_type=transformer_block_type,
+                normalization=normalization,
+                layernorm_epsilon=layernorm_epsilon,
+                persist_layer_norm=persist_layer_norm,
+                sequence_parallel=sequence_parallel,
+                gradient_accumulation_fusion=gradient_accumulation_fusion,
+                dropout=moe_dropout,
+            )
+        else:
+            self.mlp = ParallelMLP(
+                init_method=init_method,
+                output_layer_init_method=output_layer_init_method,
+                hidden_size=hidden_size,
+                ffn_hidden_size=ffn_hidden_size,
+                use_cpu_initialization=use_cpu_initialization,
+                bias_activation_fusion=bias_activation_fusion,
+                openai_gelu=openai_gelu,
+                onnx_safe=onnx_safe,
+                activation=activation,
+                bias=bias,
+                transformer_block_type=transformer_block_type,
+                normalization=normalization,
+                layernorm_epsilon=layernorm_epsilon,
+                persist_layer_norm=persist_layer_norm,
+                sequence_parallel=sequence_parallel,
+                gradient_accumulation_fusion=gradient_accumulation_fusion,
+                dropout=ffn_dropout,
+            )
 
     def _get_bias_droput_add_func(self, transformer_block_type='pre_ln', position_after='attention'):
         """
@@ -1596,6 +1750,9 @@ def __init__(
         sequence_parallel=False,
         gradient_accumulation_fusion=False,
         normalize_attention_scores=True,
+        num_moe_experts=1,
+        moe_frequency=1,
+        moe_dropout=0.0,
     ):
         super(ParallelTransformerLayer, self).__init__(
             init_method=init_method,
@@ -1632,6 +1789,9 @@ def __init__(
             sequence_parallel=sequence_parallel,
             gradient_accumulation_fusion=gradient_accumulation_fusion,
             normalize_attention_scores=normalize_attention_scores,
+            num_moe_experts=num_moe_experts,
+            moe_frequency=moe_frequency,
+            moe_dropout=moe_dropout,
         )
 
         if precision == 32:
@@ -1848,6 +2008,9 @@ def __init__(
         fp8_amax_compute_algo='most_recent',
         use_emha=False,
         normalize_attention_scores=True,
+        num_moe_experts=1,
+        moe_frequency=1,
+        moe_dropout=0.0,
    ):
         super(ParallelTransformer, self).__init__()
 
@@ -1943,6 +2106,7 @@ def __init__(
                 num_layers % parallel_state.get_pipeline_model_parallel_world_size() == 0
             ), 'num_layers must be divisible by pipeline_model_parallel_size'
 
+        assert moe_frequency <= num_layers, 'MoE frequency must be <= number of transformer layers'
         # TODO: Add similar assert for encoder-decoder.
 
         self.num_layers = self.get_num_layers(num_layers)
@@ -2014,6 +2178,9 @@ def build_layer(layer_number):
                     activations_checkpoint_granularity=activations_checkpoint_granularity,
                     sequence_parallel=sequence_parallel,
                     normalize_attention_scores=normalize_attention_scores,
+                    num_moe_experts=num_moe_experts,
+                    moe_frequency=moe_frequency,
+                    moe_dropout=moe_dropout,
                 )
 
         if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
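
Note on the routing logic: the SwitchMLP added above routes each token to exactly one expert. During training the router logits are balanced with the Sinkhorn iteration before taking the argmax (so expert load stays roughly even), while the gradient flows only through the sigmoid of the raw logits; at inference the sigmoid scores are used directly. The snippet below is a minimal, framework-free sketch of that logic for readers outside the NeMo codebase. It is not part of the PR: plain torch.nn.Linear modules stand in for tensor_parallel.RowParallelLinear and ParallelMLP, bias and model-parallel handling are omitted, and the ToyTop1MoE name is hypothetical.

```python
# Minimal sketch of top-1 Sinkhorn routing (assumptions: no tensor/model parallelism,
# nn.Linear instead of RowParallelLinear, a plain 2-layer MLP instead of ParallelMLP).
import torch
import torch.nn as nn


def sinkhorn(cost: torch.Tensor, tol: float = 1e-4) -> torch.Tensor:
    """Iteratively rescale the (tokens x experts) score matrix toward balanced load."""
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    eps, error, d1_old = 1e-8, 1e9, d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


class ToyTop1MoE(nn.Module):
    def __init__(self, hidden_size: int, ffn_hidden_size: int, num_experts: int):
        super().__init__()
        self.num_experts = num_experts
        self.router = nn.Linear(hidden_size, num_experts)  # stand-in for RowParallelLinear
        self.experts = nn.ModuleList(
            nn.Sequential(  # stand-in for ParallelMLP
                nn.Linear(hidden_size, ffn_hidden_size), nn.GELU(), nn.Linear(ffn_hidden_size, hidden_size)
            )
            for _ in range(num_experts)
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_shape = hidden_states.shape
        route = self.router(hidden_states).view(-1, self.num_experts)
        if self.training:
            # Expert choice comes from the Sinkhorn-balanced scores (no gradient);
            # the scaling probability comes from the sigmoid of the raw logits.
            with torch.no_grad():
                norm_route = sinkhorn(route.detach().to(torch.float32))
                max_ind = torch.argmax(norm_route, dim=1)
            max_prob = torch.sigmoid(route)[torch.arange(route.size(0)), max_ind]
        else:
            max_prob, max_ind = torch.max(torch.sigmoid(route), dim=1)
        max_prob = max_prob.unsqueeze(1)

        flat = hidden_states.view(-1, hidden_shape[-1])
        out = torch.empty_like(flat)
        for expert_num, expert in enumerate(self.experts):
            idx = (max_ind == expert_num).nonzero(as_tuple=True)[0]
            if idx.numel() > 0:
                out[idx] = expert(flat[idx])  # each token is processed by exactly one expert
        return (out * max_prob).view(hidden_shape)


if __name__ == "__main__":
    moe = ToyTop1MoE(hidden_size=8, ffn_hidden_size=32, num_experts=4)
    x = torch.randn(5, 3, 8)  # [sequence, batch, hidden]
    print(moe(x).shape)  # torch.Size([5, 3, 8])
```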