Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PEFT fix: T5 prefix-tuning with new cache format #34312

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions src/transformers/models/longt5/modeling_longt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1479,9 +1479,23 @@ def forward(
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
14 changes: 13 additions & 1 deletion src/transformers/models/mt5/modeling_mt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,11 +1046,23 @@ def forward(
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
20 changes: 17 additions & 3 deletions src/transformers/models/pix2struct/modeling_pix2struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,11 +1440,25 @@ def forward(

# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if encoder_hidden_states is not None:
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
18 changes: 16 additions & 2 deletions src/transformers/models/pop2piano/modeling_pop2piano.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,9 +882,23 @@ def forward(
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1004,9 +1004,23 @@ def forward(
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
14 changes: 13 additions & 1 deletion src/transformers/models/t5/modeling_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,11 +1059,23 @@ def forward(
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
20 changes: 19 additions & 1 deletion src/transformers/models/udop/modeling_udop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1434,7 +1434,25 @@ def forward(
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min

if self.is_decoder and encoder_attention_mask is not None:
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
18 changes: 16 additions & 2 deletions src/transformers/models/umt5/modeling_umt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,9 +745,23 @@ def forward(
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if past_key_values is not None and not past_key_values.is_updated.get(0, False):
past_seen_tokens_cross_attn = past_key_values.cross_attention_cache.get_seq_length()
encoder_sequence_length += past_seen_tokens_cross_attn

if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
)
# Edge case for PEFT tuning when several virtual tokens are added. We'll need to create attn mask
# of shape current/actual `encoder_hidden_states` + vitual token `encoder_hidden_states`
elif encoder_attention_mask is not None and encoder_attention_mask.shape[1] < encoder_sequence_length:
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length - encoder_attention_mask.shape[1])
new_encoder_attention_mask = torch.ones(
encoder_hidden_shape, device=encoder_attention_mask.device, dtype=torch.long
)
encoder_attention_mask = torch.cat([new_encoder_attention_mask, encoder_attention_mask], dim=-1)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
Expand Down
Loading