From fbd69879f678db2d492ed6ea98c2d57b1a83ed47 Mon Sep 17 00:00:00 2001
From: Victor49152
Date: Wed, 16 Oct 2024 16:16:08 +0000
Subject: [PATCH] Apply isort and black reformatting

Signed-off-by: Victor49152
---
 .../diffusion/models/dit/dit_attention.py | 59 +++++++++++--------
 1 file changed, 35 insertions(+), 24 deletions(-)

diff --git a/nemo/collections/diffusion/models/dit/dit_attention.py b/nemo/collections/diffusion/models/dit/dit_attention.py
index 434a24faf51e..9e60b11dd1c6 100644
--- a/nemo/collections/diffusion/models/dit/dit_attention.py
+++ b/nemo/collections/diffusion/models/dit/dit_attention.py
@@ -69,7 +69,7 @@ def __init__(
             bias=self.config.add_qkv_bias,
             skip_bias_add=False,
             is_expert=False,
-            tp_comm_buffer_name='qkv'
+            tp_comm_buffer_name='qkv',
         )
 
         if not context_pre_only:
@@ -126,23 +126,22 @@ def __init__(
         else:
             self.added_k_layernorm = None
 
-
     def _split_qkv(self, mixed_qkv):
         # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
         new_tensor_shape = mixed_qkv.size()[:-1] + (
             self.num_query_groups_per_partition,
             (
-                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
-                * self.hidden_size_per_attention_head
+                (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
+                * self.hidden_size_per_attention_head
             ),
         )
         mixed_qkv = mixed_qkv.view(*new_tensor_shape)
 
         split_arg_list = [
             (
-                self.num_attention_heads_per_partition
-                // self.num_query_groups_per_partition
-                * self.hidden_size_per_attention_head
+                self.num_attention_heads_per_partition
+                // self.num_query_groups_per_partition
+                * self.hidden_size_per_attention_head
             ),
             self.hidden_size_per_attention_head,
             self.hidden_size_per_attention_head,
@@ -151,11 +150,19 @@ def _split_qkv(self, mixed_qkv):
         if SplitAlongDim is not None:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list, )
+            (query, key, value) = SplitAlongDim(
+                mixed_qkv,
+                3,
+                split_arg_list,
+            )
         else:
 
             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3, )
+            (query, key, value) = torch.split(
+                mixed_qkv,
+                split_arg_list,
+                dim=3,
+            )
 
         # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
@@ -202,14 +209,14 @@ def get_added_query_key_value_tensors(self, added_hidden_states, key_value_state
         return query, key, value
 
     def forward(
-        self,
-        hidden_states,
-        attention_mask,
-        key_value_states=None,
-        inference_params=None,
-        rotary_pos_emb=None,
-        packed_seq_params=None,
-        additional_hidden_states=None,
+        self,
+        hidden_states,
+        attention_mask,
+        key_value_states=None,
+        inference_params=None,
+        rotary_pos_emb=None,
+        packed_seq_params=None,
+        additional_hidden_states=None,
     ):
         # hidden_states: [sq, b, h]
 
@@ -226,7 +233,6 @@ def forward(
         query, key, value = self.get_query_key_value_tensors(hidden_states)
         added_query, added_key, added_value = self.get_added_query_key_value_tensors(additional_hidden_states)
 
-
         query = torch.cat([added_query, query], dim=0)
         key = torch.cat([added_key, key], dim=0)
         value = torch.cat([added_value, value], dim=0)
@@ -243,7 +249,6 @@ def forward(
             key = key.squeeze(1)
             value = value.squeeze(1)
 
-
         # ================================================
         # relative positional embedding (rotary embedding)
         # ================================================
@@ -256,10 +261,16 @@ def forward(
             else:
                 cu_seqlens_q = cu_seqlens_kv = None
             query = apply_rotary_pos_emb(
-                query, q_pos_emb, config=self.config, cu_seqlens=cu_seqlens_q,
+                query,
+                q_pos_emb,
+                config=self.config,
+                cu_seqlens=cu_seqlens_q,
             )
             key = apply_rotary_pos_emb(
-                key, k_pos_emb, config=self.config, cu_seqlens=cu_seqlens_kv,
+                key,
+                k_pos_emb,
+                config=self.config,
+                cu_seqlens=cu_seqlens_kv,
             )
 
             # TODO, can apply positional embedding to value_layer so it has
@@ -299,8 +310,8 @@ def forward(
         # =================
         # Output. [sq, b, h]
         # =================
-        encoder_attention_output = core_attn_out[:additional_hidden_states.shape[0], :, :]
-        attention_output = core_attn_out[additional_hidden_states.shape[0]:, :, :]
+        encoder_attention_output = core_attn_out[: additional_hidden_states.shape[0], :, :]
+        attention_output = core_attn_out[additional_hidden_states.shape[0] :, :, :]
 
         output, bias = self.linear_proj(attention_output)
         encoder_output, encoder_bias = self.added_linear_proj(encoder_attention_output)
@@ -308,9 +319,9 @@ def forward(
         output = output + bias
         encoder_output = encoder_output + encoder_bias
 
-
         return output, encoder_output
 
+
 class FluxSingleAttention(SelfAttention):
     """Self-attention layer class
 
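
Note: the patch above appears to be an automated isort + black pass over the one file it touches, not a hand-written edit. The sketch below shows one way such a pass could be reproduced locally; it is an illustration under assumptions, not the exact command used for this commit. The line_length=119 and string_normalization=False settings are inferred from the diff (long lines kept intact, single quotes such as 'qkv' preserved) rather than read from this patch; in the repository the authoritative settings live in pyproject.toml.

    # Sketch: reproduce an isort + black pass on the file touched by this patch.
    # Assumptions: isort>=5 and black are installed; the settings below only
    # approximate the repository configuration (normally taken from pyproject.toml).
    from pathlib import Path

    import black
    import isort

    path = Path("nemo/collections/diffusion/models/dit/dit_attention.py")
    source = path.read_text()

    # Import ordering first, then code layout, matching the commit subject.
    source = isort.code(source, line_length=119)
    source = black.format_str(source, mode=black.Mode(line_length=119, string_normalization=False))

    path.write_text(source)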