diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
index dc111766d222..a85c155cc0a8 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -142,11 +142,19 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         if SplitAlongDim is not None:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
+            (query, key, value) = SplitAlongDim(
+                mixed_qkv,
+                3,
+                split_arg_list,
+            )
         else:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)
+            (query, key, value) = torch.split(
+                mixed_qkv,
+                split_arg_list,
+                dim=3,
+            )
 
         # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
@@ -231,11 +239,21 @@ def forward(
 
         if self.checkpoint_core_attention:
             core_attn_out = self._checkpointed_attention_forward(
-                query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+                query,
+                key,
+                value,
+                attention_mask,
+                attn_mask_type=attn_mask_type,
+                packed_seq_params=packed_seq_params,
             )
         else:
             core_attn_out = self.core_attention(
-                query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+                query,
+                key,
+                value,
+                attention_mask,
+                attn_mask_type=attn_mask_type,
+                packed_seq_params=packed_seq_params,
             )
 
         if packed_seq_params is not None:
@@ -316,7 +334,9 @@ def forward(self, hidden_states):
                     intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
             elif self.activation_func == F.silu and self.config.gated_linear_unit:
                 intermediate_parallel = bias_swiglu_impl(
-                    intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store,
+                    intermediate_parallel,
+                    bias_parallel,
+                    self.config.activation_func_fp8_input_store,
                 )
 
             else:
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
index 33ddc542be06..458ff58de47a 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py
@@ -277,7 +277,9 @@ def _get_init_fn(self, init_method: str):
             raise NotImplementedError("out_init_method should be zero, normal, kaiming or xavier")
         return init_fn
 
-    def adapter_unfreeze(self,):
+    def adapter_unfreeze(
+        self,
+    ):
         """
         Can be customized to allow for selective training of only some params in the PEFT.
         """
@@ -404,7 +406,7 @@ class LoraQAdapter(ParallelLinearAdapter):
 
 class LoraDenseAttentionAdapter(ParallelLinearAdapter):
     """
-    Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes 
+    Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
     and they do not use an bottleneck activation function
     """
 
@@ -413,7 +415,7 @@ class LoraDenseAttentionAdapter(ParallelLinearAdapter):
 
 class LoraHto4HAdapter(ParallelLinearAdapter):
     """
-    Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes 
+    Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
     and they do not use an bottleneck activation function
     """
 
@@ -422,7 +424,7 @@ class LoraHto4HAdapter(ParallelLinearAdapter):
 
 class Lora4HtoHAdapter(ParallelLinearAdapter):
     """
-    Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes 
+    Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
     and they do not use an bottleneck activation function
     """
 
@@ -690,14 +692,20 @@ def set_inference_table(self, prompt_representation: torch.Tensor):
         self.is_inference_ready = True
         return True
 
-    def clear_inference_table(self,):
+    def clear_inference_table(
+        self,
+    ):
         self.inference_table.fill_(0.0)
         self.is_inference_ready = False
 
-    def get_inference_table(self,):
+    def get_inference_table(
+        self,
+    ):
         return self.inference_table.data
 
-    def inner_forward(self,):
+    def inner_forward(
+        self,
+    ):
         input_embeds = self.embedding(self.indices).unsqueeze(0)
         intermediate_parallel, bias_parallel = self.first(input_embeds)
         intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)