diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index ce2df84bc307dd..53a877a6be7dd8 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -11,12 +11,11 @@ import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss from torch.nn.init import trunc_normal_ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache -from ...generation.utils import GenerationMixin +from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import PreTrainedModel @@ -57,28 +56,6 @@ from .configuration_aria import AriaTextConfig -class AriaPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. - """ - - config_class = AriaConfig - base_model_prefix = "model" - _no_split_modules = [] - supports_gradient_checkpointing = True - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_cache_class = True - - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA (Scaled Dot Product Attention) or not. - """ - return self.language_model._supports_sdpa - - class IdentityOp(torch.nn.Module): """ An identity operation that returns the input unchanged. @@ -159,6 +136,28 @@ def trunc_normal_tf_( tensor.mul_(std).add_(mean) +class AriaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. + """ + + config_class = AriaConfig + base_model_prefix = "model" + _no_split_modules = [] + supports_gradient_checkpointing = True + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + @property + def _supports_sdpa(self): + """ + Retrieve language_model's attribute to check whether the model supports + SDPA (Scaled Dot Product Attention) or not. + """ + return self.language_model._supports_sdpa + + class AriaVisionEmbeddings(nn.Module): """ This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable @@ -878,6 +877,7 @@ class AriaVisionTransformer(AriaPreTrainedModel): """ config_class = AriaVisionConfig + _supports_sdpa = False def __init__(self, config: AriaVisionConfig): @@ -1243,7 +1243,6 @@ def __init__(self, config: AriaTextConfig): self.config = config self.weight = nn.Parameter(torch.empty((self.config.moe_num_experts, self.config.hidden_size))) - # self.weight = nn.Linear(self.config.moe_num_experts, self.config.hidden_size, bias=None) # FIXME: initialize the weight # Simplify code a lot compared to original, since we do not need training. @@ -1253,14 +1252,12 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1) scores = F.softmax(top_logits, dim=-1) - initial_type = top_indices.dtype - tokens_per_expert = torch.histc( top_indices.flatten().to(torch.float32), bins=self.config.moe_num_experts, min=0, max=self.config.moe_num_experts - 1, - ).to(initial_type) + ) return scores, top_indices, tokens_per_expert @@ -1328,7 +1325,7 @@ def __init__(self, in_features, out_features, groups): self.in_features = in_features self.out_features = out_features self.groups = groups - self.weight = nn.Parameter(torch.empty(groups, in_features, out_features, dtype=torch.bfloat16)) + self.weight = nn.Parameter(torch.empty(groups, in_features, out_features)) def forward(self, input, tokens_per_expert): """ @@ -2700,6 +2697,7 @@ def forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, + **loss_kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -2762,18 +2760,7 @@ def forward( loss = None if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs) if not return_dict: output = (logits,) + outputs[1:] @@ -2838,6 +2825,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): Args: config (AriaConfig): Configuration object for the model. """ + _supports_sdpa = False def __init__(self, config: AriaConfig): @@ -2859,6 +2847,7 @@ def __init__(self, config: AriaConfig): config.text_config, attn_implementation=config._attn_implementation ) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 + self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" self.post_init() def get_input_embeddings(self): @@ -2900,8 +2889,6 @@ def get_image_features( image_features = self.multi_modal_projector(selected_image_feature) return image_features - @add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( self, input_ids: torch.LongTensor = None, diff --git a/src/transformers/models/aria/modular_aria.py b/src/transformers/models/aria/modular_aria.py index 66fd24698246ae..0c54c10d1120e9 100644 --- a/src/transformers/models/aria/modular_aria.py +++ b/src/transformers/models/aria/modular_aria.py @@ -11,7 +11,7 @@ from ...activations import ACT2FN from ...configuration_utils import PretrainedConfig from ...feature_extraction_utils import BatchFeature -from ...generation.utils import GenerationMixin +from ...generation import GenerationMixin from ...image_processing_utils import BaseImageProcessor from ...image_utils import ImageInput from ...modeling_outputs import BaseModelOutput @@ -85,6 +85,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer): This class extends the original Idefics3VisionTransformer by removing the post-layernorm operation. """ + _supports_sdpa = False def __init__(self, config: AriaVisionConfig): @@ -1132,6 +1133,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin): Args: config (AriaConfig): Configuration object for the model. """ + _supports_sdpa = False def __init__(self, config: AriaConfig): diff --git a/tests/models/aria/test_modeling_aria.py b/tests/models/aria/test_modeling_aria.py index fcba5a1492d6dd..8d6fa83df16564 100644 --- a/tests/models/aria/test_modeling_aria.py +++ b/tests/models/aria/test_modeling_aria.py @@ -517,10 +517,7 @@ def test_aria_merge_inputs_error_bug(self): def test_tokenizer_integration(self): slow_tokenizer = AutoTokenizer.from_pretrained( - "rhymes-ai/Aria", - bos_token="<|startoftext|>", - eos_token="<|endoftext|>", - use_fast=False + "rhymes-ai/Aria", bos_token="<|startoftext|>", eos_token="<|endoftext|>", use_fast=False ) slow_tokenizer.add_tokens("", True)