From 6864e22a9fc26b6e9e2efaf47dd7c44f01d98d8f Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Wed, 5 Jun 2024 18:05:14 -0700 Subject: [PATCH 1/4] initial moe lora impl Signed-off-by: Chen Cui --- .../common/megatron/adapters/mcore_mixins.py | 76 +++++++++++++++---- .../megatron/adapters/parallel_adapters.py | 74 ++++++++++++++++++ nemo/collections/nlp/parts/peft_config.py | 39 ++++++++-- 3 files changed, 170 insertions(+), 19 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index fe9e900f4ad0..6b9a0e942be4 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -14,19 +14,16 @@ import torch import torch.nn.functional as F -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.tensor_parallel import ColumnParallelLinear from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.custom_layers.transformer_engine import ( - SplitAlongDim, - TEColumnParallelLinear, - TELayerNormColumnParallelLinear, -) +from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.experts import SequentialMLP from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import make_viewless_tensor @@ -42,6 +39,8 @@ MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, + LoraMoeHto4HAdapterConfig, + LoraMoe4HtoHAdapterConfig, ) from nemo.core import adapter_mixins @@ -270,6 +269,8 @@ def mcore_register_adapters(self): LoraUnfusedHto4HAdapterConfig._target_, LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, + LoraMoeHto4HAdapterConfig._target_, + LoraMoe4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_, ] ) # only self attn (packed qkv) for now @@ -284,9 +285,12 @@ def mcore_register_adapters(self): # overlap is used. 
self.linear_fc1.return_layernorm_output_gathered = True - def forward(self, hidden_states): + def forward(self, hidden_states, expert_idx=None): # [s, b, 4 * h/p] - if self.linear_fc1.te_return_bias: + if isinstance(self.linear_fc1, ColumnParallelLinear): + layernorm_output = hidden_states + intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) + elif self.linear_fc1.te_return_bias: intermediate_parallel, bias_parallel, layernorm_output = self.linear_fc1(hidden_states) else: # bias_parallel is None @@ -297,15 +301,19 @@ def forward(self, hidden_states): lora_adapter = None lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER) + lora_moe_fc1_adapter = self.get_adapter_module(AdapterName.LORA_MOE_Hto4H_ADAPTER) if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']: lora_adapter = lora_fc1_adapter if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']: assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER" lora_adapter = lora_unfused_fc1_adapter + lora_output = 0 if lora_adapter: lora_output = lora_adapter(layernorm_output) - intermediate_parallel = intermediate_parallel + lora_output + elif lora_moe_fc1_adapter and self.adapter_cfg[AdapterName.LORA_MOE_Hto4H_ADAPTER]['enabled']: + lora_output = lora_moe_fc1_adapter(layernorm_output, expert_idx) + intermediate_parallel = intermediate_parallel + lora_output if self.config.bias_activation_fusion: if self.activation_func == F.gelu: @@ -343,13 +351,55 @@ def glu(x): # LoRA logic if self.is_adapter_available(): - lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER) - if lora_linear_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']: - lora_output = lora_linear_fc2_adapter(intermediate_parallel) - output = output + lora_output + lora_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER) + lora_moe_fc2_adapter = self.get_adapter_module(AdapterName.LORA_MOE_4HtoH_ADAPTER) + + lora_output = 0 + if lora_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']: + lora_output = lora_fc2_adapter(intermediate_parallel) + elif lora_moe_fc2_adapter and self.adapter_cfg[AdapterName.LORA_MOE_4HtoH_ADAPTER]['enabled']: + lora_output = lora_moe_fc2_adapter(intermediate_parallel, expert_idx) + + output = output + lora_output return output, output_bias +class MCoreSequentialMLPMixin(SequentialMLP, MCoreAdapterModuleMixin): + def mcore_register_adapters(self): + """ + Setup NeMo IA3 adapter to this MCore layer. 
+ """ + self.set_accepted_adapter_types( + [ + LoraMoeHto4HAdapterConfig._target_, + LoraMoe4HtoHAdapterConfig._target_, + ] + ) + + def forward(self, permuted_local_hidden_states, tokens_per_expert): + output_local = torch.zeros_like(permuted_local_hidden_states) + output_bias_local = None + if self.add_bias: + output_bias_local = torch.zeros_like(permuted_local_hidden_states) + + cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) + # Insert zero at the begining for offset index's convenience + zero_tensor = torch.zeros(1, dtype=torch.long) + cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) + for expert_num, expert in enumerate(self.local_experts): + start = cumsum_num_tokens[expert_num] + end = cumsum_num_tokens[expert_num + 1] + hidden = permuted_local_hidden_states[start:end] + output, output_bias = expert(hidden, expert_num) # expert: MLP + + output_local[start:end] = output + if self.add_bias: + output_bias = output_bias.expand_as(output) + output_bias_local[start:end, :] = output_bias + + return output_local, output_bias_local + + class MCoreGPTEmbeddingMixin(LanguageModelEmbedding, MCoreAdapterModuleMixin): def mcore_register_adapters(self): diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 51510f1b881e..61c1ac5909d9 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -83,6 +83,8 @@ class AdapterName(str, enum.Enum): LORA_Hto4H_ADAPTER = "lora_hto4h_adapter" LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter" LORA_4HtoH_ADAPTER = "lora_4htoh_adapter" + LORA_MOE_Hto4H_ADAPTER = "lora_moe_hto4h_adapter" + LORA_MOE_4HtoH_ADAPTER = "lora_moe_4htoh_adapter" MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter" PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter" @@ -609,6 +611,78 @@ class LoraUnfusedKQVAdapterConfig(AdapterConfig): _target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__) +class LoraMoeAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + num_moe_experts: int, + in_features: int, + out_features: int, + dim: int, + activation: str = 'identity', + norm_position: Optional[str] = None, + norm_type: Optional[str] = None, + column_init_method: str = 'xavier', + row_init_method: str = 'zero', + gather_output: bool = False, + input_is_parallel: bool = False, + dropout: float = 0.0, + model_parallel_config: Optional[ModelParallelConfig] = None, + alpha: float | None = None, + dropout_position: str = 'post', + a2a_experimental: bool = False, + **kwargs, + ): + super().__init__() + + self.num_moe_experts = num_moe_experts + adapter_args = { + "in_features": in_features, + "out_features": out_features, + "dim": dim, + "activation": activation, + "norm_position": norm_position, + "norm_type": norm_type, + "column_init_method": column_init_method, + "row_init_method": row_init_method, + "gather_output": gather_output, + "input_is_parallel": input_is_parallel, + "dropout": dropout, + "model_parallel_config": model_parallel_config, + "alpha": alpha, + "dropout_position": dropout_position, + "a2a_experimental": a2a_experimental, + } + self.expert_adapters = nn.ModuleList() + for i in range(num_moe_experts): + self.expert_adapters.append(ParallelLinearAdapter(**adapter_args)) + + def forward(self, x, expert_idx): + return self.expert_adapters[expert_idx](x) + +@dataclass 
+class LoraMoeHto4HAdapterConfig(AdapterConfig): + num_moe_experts: int + in_features: int + out_features: int + dim: int + activation: str = 'identity' + norm_position: Optional[str] = None + norm_type: Optional[str] = None + column_init_method: str = 'xavier' + row_init_method: str = 'zero' + gather_output: bool = False + input_is_parallel: bool = False + dropout: float = 0.0 + dropout_position: str = 'post' + alpha: float | None = None + a2a_experimental: bool = False + _target_: str = "{0}.{1}".format(LoraMoeAdapter.__module__, LoraMoeAdapter.__name__) + +@dataclass +class LoraMoe4HtoHAdapterConfig(LoraMoeHto4HAdapterConfig): + input_is_parallel: bool = True + + class PromptEncoderAdapter(nn.Module, AdapterModuleUtil): """ The Tensor Parallel MLP prompt encoder network that is used to generate the virtual diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 820e2ad63f24..e7c1a566a673 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -24,7 +24,8 @@ MCoreMLPMixin, MCoreSelfAttentionMixin, MCoreTransformerLayerMixin, - ) + MCoreSequentialMLPMixin, +) except (ImportError, ModuleNotFoundError): MCoreGPTEmbeddingMixin = MCoreSelfAttentionMixin = MCoreTransformerLayerMixin = MCoreMLPMixin = None @@ -42,6 +43,8 @@ ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, PromptEncoderAdapterConfig, + LoraMoeHto4HAdapterConfig, + LoraMoe4HtoHAdapterConfig, ) PEFT_MODULE_MAP = { @@ -169,7 +172,10 @@ def __init__(self, cfg): elif module == PEFT_MODULE_MAP["hto4h_module"]: hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size - if lora_cfg.get("variant", "nemo") == "canonical": + if cfg.get('num_moe_experts', None): + _adapter_name = AdapterName.LORA_MOE_Hto4H_ADAPTER + _adapter_cfg_cls = LoraMoeHto4HAdapterConfig + elif lora_cfg.get("variant", "nemo") == "canonical": _adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER _adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig else: @@ -180,13 +186,32 @@ def __init__(self, cfg): cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls ) name_key_to_cfg[_adapter_name] = adapter_cfg - name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] + if _adapter_name == AdapterName.LORA_MOE_Hto4H_ADAPTER: + name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)] + for i in range(int(cfg.num_moe_experts)): + name_key_to_mcore_mixins[_adapter_name].append((f"mlp.experts.local_experts.{i}", MCoreMLPMixin)) + else: + name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] + + elif module == PEFT_MODULE_MAP["4htoh_module"]: + if cfg.get('num_moe_experts', None): + _adapter_name = AdapterName.LORA_MOE_4HtoH_ADAPTER + _adapter_cfg_cls = LoraMoe4HtoHAdapterConfig + else: + _adapter_name = AdapterName.LORA_4HtoH_ADAPTER + _adapter_cfg_cls = Lora4HtoHAdapterConfig + adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig + cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, _adapter_cfg_cls ) - name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg - name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)] + name_key_to_cfg[_adapter_name] = adapter_cfg + if _adapter_name == AdapterName.LORA_MOE_4HtoH_ADAPTER: + name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)] + for i in range(int(cfg.num_moe_experts)): + 
name_key_to_mcore_mixins[_adapter_name].append((f"mlp.experts.local_experts.{i}", MCoreMLPMixin)) + else: + name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] else: logging.error( f"Unrecognized target_module string: {module}.\n" @@ -221,6 +246,8 @@ def _create_lora_config( assert kv_channels is not None, "kv_channels must be provided for canonical Lora" config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels}) config_args.pop("out_features") + elif adapter_cfg_cls in (LoraMoeHto4HAdapterConfig, LoraMoe4HtoHAdapterConfig): + config_args.update({'num_moe_experts': cfg.num_moe_experts}) if lora_cfg.weight_tying: position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None) From cc3e60eed864922db3ce79ab080e0a9ee989cd55 Mon Sep 17 00:00:00 2001 From: cuichenx Date: Thu, 6 Jun 2024 02:54:18 +0000 Subject: [PATCH 2/4] Apply isort and black reformatting Signed-off-by: cuichenx --- .../common/megatron/adapters/mcore_mixins.py | 36 ++++++++++++++----- .../megatron/adapters/parallel_adapters.py | 24 +++++++++---- nemo/collections/nlp/parts/peft_config.py | 17 +++++---- 3 files changed, 55 insertions(+), 22 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index 6b9a0e942be4..8fd0aed30738 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -34,13 +34,13 @@ LoraDenseAttentionAdapterConfig, LoraHto4HAdapterConfig, LoraKQVAdapterConfig, + LoraMoe4HtoHAdapterConfig, + LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, - LoraMoeHto4HAdapterConfig, - LoraMoe4HtoHAdapterConfig, ) from nemo.core import adapter_mixins @@ -141,11 +141,19 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): if SplitAlongDim is not None: # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,) + (query, key, value) = SplitAlongDim( + mixed_qkv, + 3, + split_arg_list, + ) else: # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,) + (query, key, value) = torch.split( + mixed_qkv, + split_arg_list, + dim=3, + ) # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) @@ -230,11 +238,21 @@ def forward( if self.checkpoint_core_attention: core_attn_out = self._checkpointed_attention_forward( - query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params, + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, ) else: core_attn_out = self.core_attention( - query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params, + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, ) if packed_seq_params is not None: @@ -324,7 +342,9 @@ def forward(self, hidden_states, expert_idx=None): intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) elif self.activation_func == F.silu and 
self.config.gated_linear_unit: intermediate_parallel = bias_swiglu_impl( - intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store, + intermediate_parallel, + bias_parallel, + self.config.activation_func_fp8_input_store, ) else: @@ -364,6 +384,7 @@ def glu(x): return output, output_bias + class MCoreSequentialMLPMixin(SequentialMLP, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ @@ -400,7 +421,6 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert): return output_local, output_bias_local - class MCoreGPTEmbeddingMixin(LanguageModelEmbedding, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 61c1ac5909d9..d3a6655829f7 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -277,7 +277,9 @@ def _get_init_fn(self, init_method: str): raise NotImplementedError("out_init_method should be zero, normal, kaiming or xavier") return init_fn - def adapter_unfreeze(self,): + def adapter_unfreeze( + self, + ): """ Can be customized to allow for selective training of only some params in the PEFT. """ @@ -404,7 +406,7 @@ class LoraQAdapter(ParallelLinearAdapter): class LoraDenseAttentionAdapter(ParallelLinearAdapter): """ - Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes + Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes and they do not use an bottleneck activation function """ @@ -413,7 +415,7 @@ class LoraDenseAttentionAdapter(ParallelLinearAdapter): class LoraHto4HAdapter(ParallelLinearAdapter): """ - Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes + Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes and they do not use an bottleneck activation function """ @@ -422,7 +424,7 @@ class LoraHto4HAdapter(ParallelLinearAdapter): class Lora4HtoHAdapter(ParallelLinearAdapter): """ - Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes + Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes and they do not use an bottleneck activation function """ @@ -659,6 +661,7 @@ def __init__( def forward(self, x, expert_idx): return self.expert_adapters[expert_idx](x) + @dataclass class LoraMoeHto4HAdapterConfig(AdapterConfig): num_moe_experts: int @@ -678,6 +681,7 @@ class LoraMoeHto4HAdapterConfig(AdapterConfig): a2a_experimental: bool = False _target_: str = "{0}.{1}".format(LoraMoeAdapter.__module__, LoraMoeAdapter.__name__) + @dataclass class LoraMoe4HtoHAdapterConfig(LoraMoeHto4HAdapterConfig): input_is_parallel: bool = True @@ -762,14 +766,20 @@ def set_inference_table(self, prompt_representation: torch.Tensor): self.is_inference_ready = True return True - def clear_inference_table(self,): + def clear_inference_table( + self, + ): self.inference_table.fill_(0.0) self.is_inference_ready = False - def get_inference_table(self,): + def get_inference_table( + self, + ): return self.inference_table.data - def inner_forward(self,): + def inner_forward( + self, + ): input_embeds = 
self.embedding(self.indices).unsqueeze(0) intermediate_parallel, bias_parallel = self.first(input_embeds) intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel) diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index e7c1a566a673..eee8aca3058c 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -23,9 +23,9 @@ MCoreGPTEmbeddingMixin, MCoreMLPMixin, MCoreSelfAttentionMixin, - MCoreTransformerLayerMixin, MCoreSequentialMLPMixin, -) + MCoreTransformerLayerMixin, + ) except (ImportError, ModuleNotFoundError): MCoreGPTEmbeddingMixin = MCoreSelfAttentionMixin = MCoreTransformerLayerMixin = MCoreMLPMixin = None @@ -37,14 +37,14 @@ LoraHto4HAdapterConfig, LoraKQVAdapterConfig, LoraKQVAdapterWeightTyingConfig, + LoraMoe4HtoHAdapterConfig, + LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, PromptEncoderAdapterConfig, - LoraMoeHto4HAdapterConfig, - LoraMoe4HtoHAdapterConfig, ) PEFT_MODULE_MAP = { @@ -189,11 +189,12 @@ def __init__(self, cfg): if _adapter_name == AdapterName.LORA_MOE_Hto4H_ADAPTER: name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)] for i in range(int(cfg.num_moe_experts)): - name_key_to_mcore_mixins[_adapter_name].append((f"mlp.experts.local_experts.{i}", MCoreMLPMixin)) + name_key_to_mcore_mixins[_adapter_name].append( + (f"mlp.experts.local_experts.{i}", MCoreMLPMixin) + ) else: name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] - elif module == PEFT_MODULE_MAP["4htoh_module"]: if cfg.get('num_moe_experts', None): _adapter_name = AdapterName.LORA_MOE_4HtoH_ADAPTER @@ -209,7 +210,9 @@ def __init__(self, cfg): if _adapter_name == AdapterName.LORA_MOE_4HtoH_ADAPTER: name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)] for i in range(int(cfg.num_moe_experts)): - name_key_to_mcore_mixins[_adapter_name].append((f"mlp.experts.local_experts.{i}", MCoreMLPMixin)) + name_key_to_mcore_mixins[_adapter_name].append( + (f"mlp.experts.local_experts.{i}", MCoreMLPMixin) + ) else: name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] else: From 5163a9be8e201c45a829afc7abda129fbef1fd27 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Thu, 6 Jun 2024 17:56:09 -0700 Subject: [PATCH 3/4] fix dangling adapter Signed-off-by: Chen Cui --- .../modules/common/megatron/adapters/mcore_mixins.py | 11 +++-------- .../common/megatron/adapters/parallel_adapters.py | 12 +++--------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index 8fd0aed30738..461149b46a78 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -280,7 +280,7 @@ def forward( class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ - Setup NeMo IA3 adapter to this MCore layer. + Setup NeMo IA3 and LoRA adapter to this MCore layer. """ self.set_accepted_adapter_types( [ @@ -388,14 +388,9 @@ def glu(x): class MCoreSequentialMLPMixin(SequentialMLP, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ - Setup NeMo IA3 adapter to this MCore layer. 
+ We don't want the SequentialMLP layer to take any adapters. We only want to override the forward() behavior """ - self.set_accepted_adapter_types( - [ - LoraMoeHto4HAdapterConfig._target_, - LoraMoe4HtoHAdapterConfig._target_, - ] - ) + pass def forward(self, permuted_local_hidden_states, tokens_per_expert): output_local = torch.zeros_like(permuted_local_hidden_states) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index d3a6655829f7..fb904013fbfc 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -766,20 +766,14 @@ def set_inference_table(self, prompt_representation: torch.Tensor): self.is_inference_ready = True return True - def clear_inference_table( - self, - ): + def clear_inference_table(self): self.inference_table.fill_(0.0) self.is_inference_ready = False - def get_inference_table( - self, - ): + def get_inference_table(self): return self.inference_table.data - def inner_forward( - self, - ): + def inner_forward(self): input_embeds = self.embedding(self.indices).unsqueeze(0) intermediate_parallel, bias_parallel = self.first(input_embeds) intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel) From 7ac5a5d7715f16c48442fe337fbf6424b85f9084 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Tue, 11 Jun 2024 10:47:41 -0700 Subject: [PATCH 4/4] update to newest mcore code Signed-off-by: Chen Cui --- .../nlp/modules/common/megatron/adapters/mcore_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index eba1497b33b8..bcfe07f702a0 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -400,7 +400,7 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert): cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) # Insert zero at the begining for offset index's convenience - zero_tensor = torch.zeros(1, dtype=torch.long) + zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device) cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) for expert_num, expert in enumerate(self.local_experts): start = cumsum_num_tokens[expert_num]
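
Reviewer note (appended after the series, not part of the patches): below is a minimal, self-contained sketch of the pattern this series introduces -- one independent LoRA pair per expert, selected by an expert index, combined with the cumsum-based token slicing used in the overridden SequentialMLP.forward(). All names here (TinyLoRA, TinyMoeLoRA, moe_lora_forward) are illustrative stand-ins, not the NeMo/MCore API, and tensor/expert parallelism, bias handling, and gated activations are deliberately omitted.

import torch
import torch.nn as nn


class TinyLoRA(nn.Module):
    """One low-rank adapter: out_proj(in_proj(x)), with out_proj zero-initialized."""

    def __init__(self, in_features: int, out_features: int, dim: int, alpha: float | None = None):
        super().__init__()
        self.linear_in = nn.Linear(in_features, dim, bias=False)
        self.linear_out = nn.Linear(dim, out_features, bias=False)
        nn.init.xavier_uniform_(self.linear_in.weight)
        nn.init.zeros_(self.linear_out.weight)  # zero init -> adapter starts as a no-op
        self.scale = (alpha if alpha is not None else dim) / dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_out(self.linear_in(x)) * self.scale


class TinyMoeLoRA(nn.Module):
    """One independent LoRA block per expert; expert_idx picks which one runs."""

    def __init__(self, num_experts: int, in_features: int, out_features: int, dim: int):
        super().__init__()
        self.expert_adapters = nn.ModuleList(
            [TinyLoRA(in_features, out_features, dim) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        return self.expert_adapters[expert_idx](x)


def moe_lora_forward(hidden, tokens_per_expert, base_experts, moe_lora):
    """Slice the permuted tokens per expert via a cumulative sum (as the patched
    SequentialMLP.forward does) and add each expert's LoRA output to its slice."""
    out = torch.empty(hidden.shape[0], base_experts[0].out_features, dtype=hidden.dtype)
    offsets = torch.cat(
        (torch.zeros(1, dtype=torch.long), torch.cumsum(tokens_per_expert, dim=0))
    )
    for expert_idx, expert in enumerate(base_experts):
        start, end = offsets[expert_idx], offsets[expert_idx + 1]
        chunk = hidden[start:end]
        out[start:end] = expert(chunk) + moe_lora(chunk, expert_idx)
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden_size, ffn_size, num_experts, rank = 16, 32, 4, 8
    base_experts = nn.ModuleList(
        [nn.Linear(hidden_size, ffn_size, bias=False) for _ in range(num_experts)]
    )
    for p in base_experts.parameters():
        p.requires_grad_(False)  # base expert weights stay frozen under PEFT
    moe_lora = TinyMoeLoRA(num_experts, hidden_size, ffn_size, rank)

    tokens_per_expert = torch.tensor([3, 0, 5, 2])  # tokens already permuted by expert
    hidden = torch.randn(int(tokens_per_expert.sum().item()), hidden_size)
    out = moe_lora_forward(hidden, tokens_per_expert, base_experts, moe_lora)
    assert torch.allclose(out[:3], base_experts[0](hidden[:3]))  # LoRA is a no-op at init

The zero-initialized output projection keeps every expert's adapter a no-op at the start of fine-tuning, so the adapted MoE model initially reproduces the frozen base model and only the per-expert low-rank weights receive gradients -- the same property the ParallelLinearAdapter-based LoraMoeAdapter in patch 1 relies on via row_init_method='zero'.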