Commit
Apply isort and black reformatting
Signed-off-by: Victor49152 <Victor49152@users.noreply.github.com>
Victor49152 committed Oct 16, 2024
1 parent d78c682 commit fbd6987
Showing 1 changed file with 35 additions and 24 deletions.
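
The change is purely mechanical. A minimal sketch of how such a pass can be reproduced with the formatters' Python APIs follows; the line length, string-normalization, and isort profile settings are illustrative assumptions, since the repository's own pyproject.toml is the source of truth for NeMo's formatting configuration.

# Sketch only: re-run isort and black over the changed file.
# The mode settings below are assumptions, not NeMo's pinned configuration.
import black
import isort

path = "nemo/collections/diffusion/models/dit/dit_attention.py"
with open(path) as f:
    source = f.read()

# Order and group imports, then apply black's code style.
source = isort.code(source, profile="black")
source = black.format_str(
    source,
    mode=black.Mode(line_length=119, string_normalization=False),
)

with open(path, "w") as f:
    f.write(source)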
59 changes: 35 additions & 24 deletions nemo/collections/diffusion/models/dit/dit_attention.py
@@ -69,7 +69,7 @@ def __init__(
bias=self.config.add_qkv_bias,
skip_bias_add=False,
is_expert=False,
-tp_comm_buffer_name='qkv'
+tp_comm_buffer_name='qkv',
)

if not context_pre_only:
@@ -126,23 +126,22 @@ def __init__(
else:
self.added_k_layernorm = None


def _split_qkv(self, mixed_qkv):
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
new_tensor_shape = mixed_qkv.size()[:-1] + (
self.num_query_groups_per_partition,
(
(self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
* self.hidden_size_per_attention_head
),
)
mixed_qkv = mixed_qkv.view(*new_tensor_shape)

split_arg_list = [
(
self.num_attention_heads_per_partition
// self.num_query_groups_per_partition
* self.hidden_size_per_attention_head
),
self.hidden_size_per_attention_head,
self.hidden_size_per_attention_head,
@@ -151,11 +150,19 @@ def _split_qkv(self, mixed_qkv):
if SplitAlongDim is not None:

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-(query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list, )
+(query, key, value) = SplitAlongDim(
+    mixed_qkv,
+    3,
+    split_arg_list,
+)
else:

# [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-(query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3, )
+(query, key, value) = torch.split(
+    mixed_qkv,
+    split_arg_list,
+    dim=3,
+)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
@@ -202,14 +209,14 @@ def get_added_query_key_value_tensors(self, added_hidden_states, key_value_state
return query, key, value

def forward(
self,
hidden_states,
attention_mask,
key_value_states=None,
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
additional_hidden_states=None,
):
# hidden_states: [sq, b, h]

@@ -226,7 +233,6 @@ def forward(
query, key, value = self.get_query_key_value_tensors(hidden_states)
added_query, added_key, added_value = self.get_added_query_key_value_tensors(additional_hidden_states)


query = torch.cat([added_query, query], dim=0)
key = torch.cat([added_key, key], dim=0)
value = torch.cat([added_value, value], dim=0)
@@ -243,7 +249,6 @@ def forward(
key = key.squeeze(1)
value = value.squeeze(1)


# ================================================
# relative positional embedding (rotary embedding)
# ================================================
@@ -256,10 +261,16 @@ def forward(
else:
cu_seqlens_q = cu_seqlens_kv = None
query = apply_rotary_pos_emb(
-query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
+query,
+q_pos_emb,
+config=self.config,
+cu_seqlens=cu_seqlens_q,
)
key = apply_rotary_pos_emb(
-key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
+key,
+k_pos_emb,
+config=self.config,
+cu_seqlens=cu_seqlens_kv,
)

# TODO, can apply positional embedding to value_layer so it has
@@ -299,18 +310,18 @@ def forward(
# =================
# Output. [sq, b, h]
# =================
-encoder_attention_output = core_attn_out[:additional_hidden_states.shape[0], :, :]
-attention_output = core_attn_out[additional_hidden_states.shape[0]:, :, :]
+encoder_attention_output = core_attn_out[: additional_hidden_states.shape[0], :, :]
+attention_output = core_attn_out[additional_hidden_states.shape[0] :, :, :]

output, bias = self.linear_proj(attention_output)
encoder_output, encoder_bias = self.added_linear_proj(encoder_attention_output)

output = output + bias
encoder_output = encoder_output + encoder_bias


return output, encoder_output


class FluxSingleAttention(SelfAttention):
"""Self-attention layer class
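
For readers following the tensor bookkeeping rather than the formatting, the shape comments in the diff above ([sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn], followed by a split into query, key, and value) describe a grouped-query-attention layout. A minimal, self-contained sketch of that split is shown below; the concrete sizes (sq, b, heads, groups, head dim) are illustrative assumptions rather than values from the NeMo configuration, and the variable names mirror the shape comments rather than the module's attributes.

import torch

# Illustrative sizes (assumptions): sequence length, batch, attention heads (np),
# query groups (ng), and head dimension (hn) on one tensor-parallel rank.
sq, b = 16, 2
n_heads, n_groups, hn = 8, 2, 64

# Fused QKV projection output: [sq, b, (np + 2*ng) * hn]
mixed_qkv = torch.randn(sq, b, (n_heads + 2 * n_groups) * hn)

# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
mixed_qkv = mixed_qkv.view(sq, b, n_groups, (n_heads // n_groups + 2) * hn)

# Per group: np/ng * hn query channels, then hn key and hn value channels.
split_arg_list = [n_heads // n_groups * hn, hn, hn]
query, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(sq, b, -1, hn)

print(query.shape)  # torch.Size([16, 2, 8, 64])
print(key.shape)    # torch.Size([16, 2, 2, 64])
print(value.shape)  # torch.Size([16, 2, 2, 64])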
