Apply isort and black reformatting
Signed-off-by: michal2409 <michal2409@users.noreply.github.com>
michal2409 committed May 16, 2024
1 parent 2d50d12 commit 94c4a14
Showing 2 changed files with 40 additions and 12 deletions.
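The hunks below (which account for all 40 additions and 12 deletions) all show the same pattern from black's "magic trailing comma" rule: a call whose argument list ends in a trailing comma is exploded onto one argument per line. A minimal sketch of that behaviour, running the black Python API on the torch.split line from the first hunk (requires the black package; no isort import changes are visible in these hunks):

import black

# The single-line call from the first hunk, with its trailing comma.
src = "(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)\n"

# black keeps the trailing comma and puts each argument on its own line.
print(black.format_str(src, mode=black.Mode()))
# (query, key, value) = torch.split(
#     mixed_qkv,
#     split_arg_list,
#     dim=3,
# )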
@@ -142,11 +142,19 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
if SplitAlongDim is not None:

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-(query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
+(query, key, value) = SplitAlongDim(
+    mixed_qkv,
+    3,
+    split_arg_list,
+)
else:

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)
+(query, key, value) = torch.split(
+    mixed_qkv,
+    split_arg_list,
+    dim=3,
+)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
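The hunk above splits the fused QKV projection along its last dimension, following the shape comments in the diff. A standalone sketch with illustrative sizes (sq, b, ng, np, hn are chosen here for the example and are not the model's real dimensions):

import torch

sq, b = 8, 2        # sequence length, micro-batch size
ng, hn = 4, 16      # number of query groups, head dimension
np_ = 8             # total number of attention heads ("np" in the comment)

# [sq, b, ng, (np/ng + 2) * hn]
mixed_qkv = torch.randn(sq, b, ng, (np_ // ng + 2) * hn)
split_arg_list = [(np_ // ng) * hn, hn, hn]

# --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
query, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, hn)
print(query.shape, key.shape, value.shape)
# torch.Size([8, 2, 8, 16]) torch.Size([8, 2, 4, 16]) torch.Size([8, 2, 4, 16])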
@@ -231,11 +239,21 @@ def forward(

if self.checkpoint_core_attention:
core_attn_out = self._checkpointed_attention_forward(
-    query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+    query,
+    key,
+    value,
+    attention_mask,
+    attn_mask_type=attn_mask_type,
+    packed_seq_params=packed_seq_params,
)
else:
core_attn_out = self.core_attention(
-    query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+    query,
+    key,
+    value,
+    attention_mask,
+    attn_mask_type=attn_mask_type,
+    packed_seq_params=packed_seq_params,
)

if packed_seq_params is not None:
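The branch above chooses between a checkpointed and a direct core-attention call. As a hedged illustration of what activation checkpointing does in general, here is PyTorch's generic torch.utils.checkpoint applied to a simplified stand-in attention function; this is not the repository's _checkpointed_attention_forward:

import torch
from torch.utils.checkpoint import checkpoint

def core_attention(q, k, v):
    # Simplified stand-in for the real core attention (no masking, no dropout).
    scores = torch.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
    return scores @ v

q = torch.randn(2, 4, 16, requires_grad=True)
k = torch.randn(2, 4, 16, requires_grad=True)
v = torch.randn(2, 4, 16, requires_grad=True)

# Intermediate activations are dropped in the forward pass and recomputed
# during backward, trading extra compute for lower activation memory.
out = checkpoint(core_attention, q, k, v, use_reentrant=False)
out.sum().backward()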
@@ -316,7 +334,9 @@ def forward(self, hidden_states):
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(
-    intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store,
+    intermediate_parallel,
+    bias_parallel,
+    self.config.activation_func_fp8_input_store,
)

else:
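bias_swiglu_impl above is a fused kernel. As an unfused reference for what SwiGLU with a bias generally computes (this is an assumption about the fused call's semantics, it ignores the fp8 input-store flag, and it is not the repository's implementation):

import torch
import torch.nn.functional as F

def bias_swiglu_reference(x, bias):
    # Add the bias, split the last dimension in half, gate one half with SiLU.
    y = x + bias
    gate, value = torch.chunk(y, 2, dim=-1)
    return F.silu(gate) * value

x = torch.randn(8, 2, 256)   # e.g. the output of a gated-linear-unit projection
bias = torch.randn(256)
print(bias_swiglu_reference(x, bias).shape)  # torch.Size([8, 2, 128])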
@@ -277,7 +277,9 @@ def _get_init_fn(self, init_method: str):
raise NotImplementedError("out_init_method should be zero, normal, kaiming or xavier")
return init_fn

-def adapter_unfreeze(self,):
+def adapter_unfreeze(
+    self,
+):
"""
Can be customized to allow for selective training of only some params in the PEFT.
"""
@@ -404,7 +406,7 @@ class LoraQAdapter(ParallelLinearAdapter):

class LoraDenseAttentionAdapter(ParallelLinearAdapter):
"""
-Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
+Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
and they do not use an bottleneck activation function
"""

@@ -413,7 +415,7 @@ class LoraDenseAttentionAdapter(ParallelLinearAdapter):

class LoraHto4HAdapter(ParallelLinearAdapter):
"""
-Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
+Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
and they do not use an bottleneck activation function
"""

@@ -422,7 +424,7 @@ class LoraHto4HAdapter(ParallelLinearAdapter):

class Lora4HtoHAdapter(ParallelLinearAdapter):
"""
-Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
+Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes
and they do not use an bottleneck activation function
"""

@@ -690,14 +692,20 @@ def set_inference_table(self, prompt_representation: torch.Tensor):
self.is_inference_ready = True
return True

-def clear_inference_table(self,):
+def clear_inference_table(
+    self,
+):
self.inference_table.fill_(0.0)
self.is_inference_ready = False

-def get_inference_table(self,):
+def get_inference_table(
+    self,
+):
return self.inference_table.data

-def inner_forward(self,):
+def inner_forward(
+    self,
+):
input_embeds = self.embedding(self.indices).unsqueeze(0)
intermediate_parallel, bias_parallel = self.first(input_embeds)
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
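inner_forward above pushes the prompt embeddings through a linear layer and then fused_bias_gelu. As a hedged, unfused reference for that call, assuming the fused kernel combines a bias add with tanh-approximate GELU (this is not the repository's fused implementation):

import torch
import torch.nn.functional as F

def bias_gelu_reference(x, bias):
    # Bias add followed by the tanh approximation of GELU.
    return F.gelu(x + bias, approximate="tanh")

x = torch.randn(1, 10, 64)
bias = torch.randn(64)
print(bias_gelu_reference(x, bias).shape)  # torch.Size([1, 10, 64])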