From 7faf1438390f222cd0102f61a5e9602daad93bfb Mon Sep 17 00:00:00 2001
From: aymeric-roucher
Date: Tue, 22 Oct 2024 12:30:18 +0000
Subject: [PATCH] Change sdpa

---
 src/transformers/models/aria/configuration_aria.py | 4 +---
 src/transformers/models/aria/modeling_aria.py      | 3 +++
 src/transformers/models/aria/modular_aria.py       | 4 ++++
 3 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/aria/configuration_aria.py b/src/transformers/models/aria/configuration_aria.py
index 1e022396f2e980..bf2b262c0cc398 100644
--- a/src/transformers/models/aria/configuration_aria.py
+++ b/src/transformers/models/aria/configuration_aria.py
@@ -91,8 +91,8 @@ def __init__(
         self.image_size = image_size
         self.attention_dropout = attention_dropout
         self.layer_norm_eps = layer_norm_eps
-        self._attn_implementation = "eager"
         self.hidden_act = hidden_act
+        self._supports_sdpa = False
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
@@ -198,7 +198,6 @@ def __init__(
         self.moe_z_loss_coeff = moe_z_loss_coeff
         self.moe_aux_loss_coeff = moe_aux_loss_coeff
         self.moe_num_shared_experts = moe_num_shared_experts
-        self._attn_implementation = "eager"
 
         super().__init__(
             pad_token_id=pad_token_id,
@@ -249,7 +248,6 @@ def __init__(
         super().__init__(**kwargs)
         self.ignore_index = ignore_index
         self.image_token_index = image_token_index
-        self._attn_implementation = "eager"
 
         # Convert the keys and values of projector_patch_to_query_dict to integers
         # This ensures consistency even if they were provided as strings
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 5d7f8e1802333f..ce2df84bc307dd 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -878,6 +878,7 @@ class AriaVisionTransformer(AriaPreTrainedModel):
     """
 
     config_class = AriaVisionConfig
+    _supports_sdpa = False
 
     def __init__(self, config: AriaVisionConfig):
         super().__init__(config)
@@ -992,6 +993,7 @@ class AriaVisionModel(AriaPreTrainedModel):
 
     config_class = AriaVisionConfig
     main_input_name = "pixel_values"
+    _supports_sdpa = False
 
     def __init__(self, config: AriaVisionConfig):
         super().__init__(config)
@@ -2836,6 +2838,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
     Args:
         config (AriaConfig): Configuration object for the model.
     """
+    _supports_sdpa = False
 
     def __init__(self, config: AriaConfig):
         super().__init__(config)
diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py
index 7f490778da8a95..66fd24698246ae 100644
--- a/src/transformers/models/aria/modular_aria.py
+++ b/src/transformers/models/aria/modular_aria.py
@@ -85,6 +85,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer):
 
     This class extends the original Idefics3VisionTransformer by removing the post-layernorm operation.
     """
+    _supports_sdpa = False
 
     def __init__(self, config: AriaVisionConfig):
         super().__init__(config)
@@ -110,6 +111,7 @@ class AriaVisionModel(SiglipVisionModel):
 
     config_class = AriaVisionConfig
     main_input_name = "pixel_values"
+    _supports_sdpa = False
 
     def __init__(self, config: AriaVisionConfig):
         super().__init__(config)
@@ -1130,6 +1132,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
     Args:
         config (AriaConfig): Configuration object for the model.
     """
+    _supports_sdpa = False
 
     def __init__(self, config: AriaConfig):
         super().__init__(config)
@@ -1150,6 +1153,7 @@ def __init__(self, config: AriaConfig):
             config.text_config, attn_implementation=config._attn_implementation
         )
         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+        self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
        self.post_init()
 
     def get_input_embeddings(self):