removing all reshapes to test perf
NouamaneTazi committed Sep 21, 2022
1 parent c0dd0e9 commit 006ccb8
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/diffusers/models/attention.py
@@ -91,7 +91,7 @@ def forward(self, hidden_states):
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)
 
         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
@@ -150,10 +150,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual

@@ -262,9 +262,9 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.to_k(context)
         value = self.to_v(context)
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

@@ -290,7 +290,7 @@ def _attention(self, query, key, value):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
     def _sliced_attention(self, query, key, value, sequence_length, dim):
@@ -309,7 +309,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
             hidden_states[start_idx:end_idx] = attn_slice
 
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
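
Note: the calls commented out above fold the attention heads into the batch dimension before the matmuls and unfold them afterwards. A minimal sketch of what those helpers do, written as free functions and assuming the usual multi-head split used by diffusers' CrossAttention at this point in history (the exact implementation is not shown in this diff):

import torch

def reshape_heads_to_batch_dim(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * head_dim) -> (batch * heads, seq_len, head_dim)
    batch_size, seq_len, dim = tensor.shape
    head_dim = dim // heads
    tensor = tensor.reshape(batch_size, seq_len, heads, head_dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, head_dim)

def reshape_batch_dim_to_heads(tensor: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch * heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
    batch_heads, seq_len, head_dim = tensor.shape
    batch_size = batch_heads // heads
    tensor = tensor.reshape(batch_size, heads, seq_len, head_dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size, seq_len, heads * head_dim)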

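Dropping the permute/transpose steps is not a numerically equivalent change: a plain reshape flattens a (batch, channel, height, width) tensor in channel-major order, whereas the original permute-then-reshape produced a channels-last flattening. A quick check of the SpatialTransformer case, with made-up shapes:

import torch

batch, channel, height, width = 2, 4, 8, 8
x = torch.randn(batch, channel, height, width)

old = x.permute(0, 2, 3, 1).reshape(batch, height * width, channel)  # original code path
new = x.reshape(batch, height * width, channel)                      # this commit

print(old.shape == new.shape)    # True: the shapes agree
print(torch.allclose(old, new))  # False in general: the element ordering differs

That is consistent with the commit message: the change isolates the cost of the reshapes for a performance test rather than preserving the module's outputs.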

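A rough way to measure the effect of a change like this (not part of the commit) is to time the attention forward pass before and after it is applied. The sketch below assumes the CrossAttention constructor of diffusers from this era (query_dim, heads, dim_head); the shapes, iteration counts, and CPU-only timing are arbitrary choices:

import time
import torch
from diffusers.models.attention import CrossAttention

attn = CrossAttention(query_dim=320, heads=8, dim_head=40).eval()
hidden_states = torch.randn(2, 4096, 320)

with torch.no_grad():
    for _ in range(3):  # warm-up iterations
        attn(hidden_states)
    start = time.perf_counter()
    for _ in range(20):
        attn(hidden_states)
    elapsed_ms = (time.perf_counter() - start) / 20 * 1e3
    print(f"avg forward: {elapsed_ms:.2f} ms")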
