Upgrade Transformers to v4.42.x (adapter-hub#719)
Changes needed for sync:
- remove setting `_hf_peft_config_loaded` for HF Trainer
- fix BEiT interpolate_pos_encoding
- add sdpa to GPT-2
- add LlamaForTokenClassification head conversion
- copy changes to Mistral implementation

---------

Co-authored-by: Leon Engländer <leon.englaender@gmail.com>
dainis-boumber and lenglaender committed Aug 30, 2024
1 parent d8bac6d commit e41b094
Showing 6 changed files with 192 additions and 79 deletions.
6 changes: 2 additions & 4 deletions setup.py
@@ -58,7 +58,7 @@
     "sphinx-multiversion==0.2.4",
     "timeout-decorator",
     "torch>=1.10,!=1.12.0",
-    "transformers~=4.41.2",
+    "transformers~=4.42.4",
 ]


@@ -143,13 +143,11 @@ def deps_list(*pkgs):
     description="A Unified Library for Parameter-Efficient and Modular Transfer Learning",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
-    keywords="NLP deep learning transformer pytorch BERT adapters",
+    keywords="NLP deep learning transformer pytorch BERT adapters PEFT LoRA",
     license="Apache",
     url="https://github.com/adapter-hub/adapters",
     package_dir={"": "src"},
     packages=find_packages("src"),
-    include_package_data=True,
-    package_data={"transformers": ["*.cu", "*.cpp", "*.cuh", "*.h", "*.pyx"]},
     zip_safe=False,
     extras_require=extras,
     python_requires=">=3.8.0",
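
The pin above moves adapters from the Transformers 4.41 line to 4.42.x. A quick sanity check for an existing environment (illustrative only, not part of the commit):

    import transformers

    # "transformers~=4.42.4" accepts any 4.42.x release at or above 4.42.4
    print(transformers.__version__)
    assert transformers.__version__.startswith("4.42"), "this adapters sync expects a 4.42.x release"
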
16 changes: 16 additions & 0 deletions src/adapters/head_utils.py
@@ -673,6 +673,14 @@
         },
         "layers": [None, "qa_outputs"],
     },
+    "LlamaForTokenClassification": {
+        "config": {
+            "head_type": "tagging",
+            "layers": 1,
+            "activation_function": None,
+        },
+        "layers": [None, "score"],
+    },
     # Mistral
     "MistralForSequenceClassification": {
         "config": {
@@ -690,6 +698,14 @@
         },
         "layers": ["lm_head"],
     },
+    "MistralForTokenClassification": {
+        "config": {
+            "head_type": "tagging",
+            "layers": 1,
+            "activation_function": None,
+        },
+        "layers": [None, "score"],
+    },
     # Electra
     "ElectraForTokenClassification": {
         "config": {
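
With the two new entries above, checkpoints saved with the static LlamaForTokenClassification or MistralForTokenClassification heads should be convertible to flex "tagging" heads when loaded into an AdapterModel class. A minimal sketch, assuming a locally available Llama token-classification checkpoint (path and adapter name are placeholders):

    from adapters import AutoAdapterModel

    # Loading converts the checkpoint's static "score" head into a flex tagging head
    # according to the "LlamaForTokenClassification" entry in head_utils.py.
    model = AutoAdapterModel.from_pretrained("path/to/llama-token-classification-checkpoint")
    model.add_adapter("ner")
    model.train_adapter("ner")  # freeze the base model, train only the adapter weights
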
3 changes: 0 additions & 3 deletions src/adapters/model_mixin.py
@@ -1459,9 +1459,6 @@ def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], tra
         if not train_embeddings:
             self.freeze_embeddings()

-        # Hack to prevent HF Trainer from throwing an error due to peft missing.
-        self._hf_peft_config_loaded = True
-
     def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False):
         """
         Sets the model into mode for training of adapter fusion determined by a list of adapter names. If
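
The deleted flag was a workaround that made the Hugging Face Trainer accept a model whose base weights were frozen by train_adapter(); after the v4.42 sync it is apparently no longer needed. A rough sketch of the affected flow (model and adapter names are placeholders; adapters.AdapterTrainer remains available as the dedicated trainer class):

    import adapters
    from transformers import AutoModelForSequenceClassification

    model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
    adapters.init(model)
    model.add_adapter("sst2")
    model.train_adapter("sst2")  # freezes the base weights; the _hf_peft_config_loaded flag is no longer set here

    # The model can now be passed to transformers.Trainer (or adapters.AdapterTrainer) as usual.
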
9 changes: 7 additions & 2 deletions src/adapters/models/beit/modeling_beit.py
@@ -33,7 +33,8 @@ def forward(
         hidden_states: torch.Tensor,
         head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
-        relative_position_bias: Optional[BeitRelativePositionBias] = None,
+        relative_position_bias: Optional["BeitRelativePositionBias"] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
         mixed_query_layer = self.query(hidden_states)

@@ -50,7 +51,9 @@

         # Add relative position bias if present.
         if self.relative_position_bias is not None:
-            attention_scores = attention_scores + self.relative_position_bias().unsqueeze(0)
+            attention_scores = attention_scores + self.relative_position_bias(
+                interpolate_pos_encoding, attention_scores.shape[2]
+            ).unsqueeze(0)

         # Add shared relative position bias if provided.
         if relative_position_bias is not None:
@@ -87,12 +90,14 @@ def forward(
         head_mask: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
         relative_position_bias: Optional[BeitRelativePositionBias] = None,
+        interpolate_pos_encoding: bool = False,
     ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
         self_attention_outputs = self.attention(
             self.layernorm_before(hidden_states),  # in BEiT, layernorm is applied before self-attention
             head_mask,
             output_attentions=output_attentions,
             relative_position_bias=relative_position_bias,
+            interpolate_pos_encoding=interpolate_pos_encoding,
         )
         attention_output = self_attention_outputs[0]
         outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
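
The forwarded flag lets adapter-equipped BEiT models run on input resolutions other than the pre-training size, since interpolate_pos_encoding now also reaches the relative position bias. A small sketch (checkpoint, adapter name, and input size are illustrative):

    import torch
    import adapters
    from transformers import BeitForImageClassification

    model = BeitForImageClassification.from_pretrained("microsoft/beit-base-patch16-224")
    adapters.init(model)
    model.add_adapter("vision_task")
    model.set_active_adapters("vision_task")

    # 384x384 input, larger than the 224x224 pre-training resolution
    pixel_values = torch.randn(1, 3, 384, 384)
    outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)
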
106 changes: 105 additions & 1 deletion src/adapters/models/gpt2/modeling_gpt2.py
@@ -20,12 +20,16 @@
 import torch
 import torch.utils.checkpoint

-from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block
+from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2SdpaAttention
+from transformers.utils import logging

 from ...composition import adjust_tensors_for_parallel, adjust_tensors_for_parallel_
 from .mixin_gpt2 import GPT2AttentionAdaptersMixin, GPT2DecoderBlockAdaptersMixin


+logger = logging.get_logger(__name__)
+
+
 class GPT2AttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2Attention):
     def forward(
         self,
@@ -65,8 +69,10 @@ def forward(
         else:
             present = None

+        # >>> START AH Changes <<<
         key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
         (query,) = adjust_tensors_for_parallel(key, query)
+        # >>> END AH Changes <<<

         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
@@ -84,6 +90,104 @@ def forward(
         return outputs  # a, present, (attentions)


+class GPT2SdpaAttentionWithAdapters(GPT2AttentionAdaptersMixin, GPT2SdpaAttention):
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        if output_attentions or head_mask is not None:
+            logger.warning_once(
+                "`GPT2SdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
+                "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
+                "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
+                'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        # Initial attention projections
+        is_cross_attention = encoder_hidden_states is not None
+        if is_cross_attention:
+            if not hasattr(self, "q_attn"):
+                raise ValueError(
+                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
+                    "Please make sure to instantiate class with `GPT2SdpaAttention(..., is_cross_attention=True)`."
+                )
+
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        # Optional kv caching
+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        present = None
+        if use_cache is True:
+            present = (key, value)
+
+        # >>> START AH Changes <<<
+        key, value, attention_mask = self.prefix_tuning(key, value, hidden_states, attention_mask)
+        (query,) = adjust_tensors_for_parallel(key, query)
+        bsz = key.shape[0]
+        # >>> END AH Changes <<<
+
+        # Avoid torch==2.1.2 specific bug for the memory-efficient backend in SDPA
+        if self.require_contiguous_qkv and query.device.type == "cuda" and attention_mask is not None:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+        is_causal = True if attention_mask is None and q_len > 1 and not is_cross_attention else False
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query,
+            key,
+            value,
+            attn_mask=attention_mask,
+            dropout_p=self.attn_dropout.p if self.training else 0.0,
+            is_causal=is_causal,
+        )
+
+        # Reshape outputs
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.embed_dim)
+
+        # Final projection
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        return attn_output, present, None
+
+
 class GPT2BlockWithAdapters(GPT2DecoderBlockAdaptersMixin, GPT2Block):
     def forward(
         self,
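
The new class mirrors the upstream GPT2SdpaAttention forward pass while keeping the prefix-tuning and parallel-composition hooks, so GPT-2 presumably remains adapter-compatible when loaded under the SDPA attention implementation. A minimal sketch (adapter name is a placeholder; "prefix_tuning" is the standard config identifier in adapters):

    import adapters
    from transformers import GPT2LMHeadModel

    model = GPT2LMHeadModel.from_pretrained("gpt2", attn_implementation="sdpa")
    adapters.init(model)
    model.add_adapter("prefix_demo", config="prefix_tuning")
    model.train_adapter("prefix_demo")
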