From 006ccb8a8c6bc7eb7e512392e692a29d9b1553cd Mon Sep 17 00:00:00 2001
From: Nouamane Tazi
Date: Wed, 21 Sep 2022 15:22:59 +0000
Subject: [PATCH] removing all reshapes to test perf

---
 src/diffusers/models/attention.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 02454d603632..f3623c6e7ed5 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -91,7 +91,7 @@ def forward(self, hidden_states):

         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
-        hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+        hidden_states = hidden_states.reshape(batch, channel, height, width)

         # res connect and rescale
         hidden_states = (hidden_states + residual) / self.rescale_output_factor
@@ -150,10 +150,10 @@ def forward(self, hidden_states, context=None):
         residual = hidden_states
         hidden_states = self.norm(hidden_states)
         hidden_states = self.proj_in(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
+        hidden_states = hidden_states.reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
             hidden_states = block(hidden_states, context=context)
-        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = hidden_states.reshape(batch, channel, height, weight)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states + residual
@@ -262,9 +262,9 @@ def forward(self, hidden_states, context=None, mask=None):
         key = self.to_k(context)
         value = self.to_v(context)

-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        # query = self.reshape_heads_to_batch_dim(query)
+        # key = self.reshape_heads_to_batch_dim(key)
+        # value = self.reshape_heads_to_batch_dim(value)

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
@@ -290,7 +290,7 @@ def _attention(self, query, key, value):
         # compute attention output
         hidden_states = torch.matmul(attention_probs, value)
         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states

     def _sliced_attention(self, query, key, value, sequence_length, dim):
@@ -309,7 +309,7 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):

             hidden_states[start_idx:end_idx] = attn_slice

         # reshape hidden_states
-        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
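
For context on what the commented-out calls do: reshape_heads_to_batch_dim folds the head dimension into the batch dimension, (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head), and reshape_batch_dim_to_heads undoes it. The sketch below is not part of the patch; it is a minimal standalone timing harness for comparing the per-head path against the reshape-free path this diff produces. All shapes, helper names, and iteration counts are illustrative assumptions, and note that dropping the reshapes also changes the math: attention then runs as a single head over the full hidden dimension.

# Illustrative benchmark only -- not part of the patch. Shapes are assumptions.
import time

import torch


def to_batch_dim(x, heads):
    # (batch, seq, heads * dim_head) -> (batch * heads, seq, dim_head),
    # mirroring what diffusers' reshape_heads_to_batch_dim does
    b, n, d = x.shape
    dim_head = d // heads
    return x.reshape(b, n, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, n, dim_head)


def to_head_dim(x, heads):
    # inverse reshape, mirroring reshape_batch_dim_to_heads
    bh, n, dim_head = x.shape
    b = bh // heads
    return x.reshape(b, heads, n, dim_head).permute(0, 2, 1, 3).reshape(b, n, heads * dim_head)


def attn_with_reshapes(q, k, v, heads):
    # standard multi-head path: reshape, attend per head, reshape back
    q, k, v = (to_batch_dim(t, heads) for t in (q, k, v))
    probs = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    return to_head_dim(probs @ v, heads)


def attn_without_reshapes(q, k, v):
    # what the patched code effectively computes: one big head over the full dim
    probs = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
    return probs @ v


def bench(fn, *args, iters=20):
    fn(*args)  # warmup
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # assumed SD-like sizes: batch 2, 1024 tokens, hidden width 320, 8 heads
    q = k = v = torch.randn(2, 1024, 320)
    print("with head reshapes:   ", bench(attn_with_reshapes, q, k, v, 8))
    print("without head reshapes:", bench(attn_without_reshapes, q, k, v))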