Commit ae8643f: Formatting
aymeric-roucher committed Oct 22, 2024
1 parent 7faf143 commit ae8643f
Showing 3 changed files with 34 additions and 48 deletions.
73 changes: 30 additions & 43 deletions src/transformers/models/aria/modeling_aria.py
@@ -11,12 +11,11 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn.init import trunc_normal_

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation.utils import GenerationMixin
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
from ...modeling_outputs import BaseModelOutput, ModelOutput
from ...modeling_utils import PreTrainedModel
@@ -57,28 +56,6 @@
from .configuration_aria import AriaTextConfig


class AriaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""

config_class = AriaConfig
base_model_prefix = "model"
_no_split_modules = []
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True

@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA (Scaled Dot Product Attention) or not.
"""
return self.language_model._supports_sdpa


class IdentityOp(torch.nn.Module):
"""
An identity operation that returns the input unchanged.
@@ -159,6 +136,28 @@ def trunc_normal_tf_(
tensor.mul_(std).add_(mean)
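
Note: the hunk above ends a TF-style truncated-normal initializer. A minimal sketch of what such an initializer does (a rejection-sampling variant; the name, bounds, and defaults here are assumptions, not the file's exact implementation):

import torch

def trunc_normal_tf_sketch(tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0,
                           a: float = -2.0, b: float = 2.0) -> torch.Tensor:
    # Sample from a unit normal, resample values outside [a, b], then scale and shift;
    # the final `tensor.mul_(std).add_(mean)` matches the line shown above.
    with torch.no_grad():
        tensor.normal_()
        outliers = (tensor < a) | (tensor > b)
        while outliers.any():
            tensor[outliers] = torch.randn(int(outliers.sum()), device=tensor.device, dtype=tensor.dtype)
            outliers = (tensor < a) | (tensor > b)
        tensor.mul_(std).add_(mean)
    return tensor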


class AriaPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
"""

config_class = AriaConfig
base_model_prefix = "model"
_no_split_modules = []
supports_gradient_checkpointing = True
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_cache_class = True

@property
def _supports_sdpa(self):
"""
Retrieve language_model's attribute to check whether the model supports
SDPA (Scaled Dot Product Attention) or not.
"""
return self.language_model._supports_sdpa


class AriaVisionEmbeddings(nn.Module):
"""
This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` to enable images of variable
Expand Down Expand Up @@ -878,6 +877,7 @@ class AriaVisionTransformer(AriaPreTrainedModel):
"""

config_class = AriaVisionConfig

_supports_sdpa = False

def __init__(self, config: AriaVisionConfig):
@@ -1243,7 +1243,6 @@ def __init__(self, config: AriaTextConfig):
self.config = config

self.weight = nn.Parameter(torch.empty((self.config.moe_num_experts, self.config.hidden_size)))
# self.weight = nn.Linear(self.config.moe_num_experts, self.config.hidden_size, bias=None)
# FIXME: initialize the weight

# Simplify code a lot compared to original, since we do not need training.
@@ -1253,14 +1252,12 @@ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
top_logits, top_indices = torch.topk(logits, k=self.config.moe_topk, dim=1)
scores = F.softmax(top_logits, dim=-1)

initial_type = top_indices.dtype

tokens_per_expert = torch.histc(
top_indices.flatten().to(torch.float32),
bins=self.config.moe_num_experts,
min=0,
max=self.config.moe_num_experts - 1,
).to(initial_type)
)

return scores, top_indices, tokens_per_expert
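
Note: the simplification above drops the cast of `tokens_per_expert` back to `top_indices.dtype`, so the histogram now stays float32. A small worked example of this routing step, with hypothetical sizes (8 tokens, 4 experts, top-2 selection):

import torch
import torch.nn.functional as F

num_experts, top_k = 4, 2                        # hypothetical config values
logits = torch.randn(8, num_experts)             # router logits, one row per token
top_logits, top_indices = torch.topk(logits, k=top_k, dim=1)
scores = F.softmax(top_logits, dim=-1)           # per-token weights over its chosen experts
tokens_per_expert = torch.histc(                 # torch.histc needs a floating-point input
    top_indices.flatten().to(torch.float32),
    bins=num_experts, min=0, max=num_experts - 1,
)
print(scores.shape, top_indices.shape, tokens_per_expert)  # (8, 2), (8, 2), 4 float counts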

@@ -1328,7 +1325,7 @@ def __init__(self, in_features, out_features, groups):
self.in_features = in_features
self.out_features = out_features
self.groups = groups
self.weight = nn.Parameter(torch.empty(groups, in_features, out_features, dtype=torch.bfloat16))
self.weight = nn.Parameter(torch.empty(groups, in_features, out_features))

def forward(self, input, tokens_per_expert):
"""
@@ -2700,6 +2697,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**loss_kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -2762,18 +2760,7 @@ def forward(

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **loss_kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
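
Note: the commit replaces the in-line shift-and-flatten cross-entropy with the shared `self.loss_function` helper and forwards `**loss_kwargs` to it. A minimal sketch of the causal-LM loss being delegated, assuming logits of shape (batch, seq_len, vocab_size) and -100 as the ignore index:

import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> torch.Tensor:
    # Upcast for stability, shift so that tokens < n predict token n, then flatten.
    logits = logits.float()
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    # Positions labeled -100 are ignored, as CrossEntropyLoss did in the removed block.
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)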
@@ -2838,6 +2825,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
Args:
config (AriaConfig): Configuration object for the model.
"""

_supports_sdpa = False

def __init__(self, config: AriaConfig):
@@ -2859,6 +2847,7 @@ def __init__(self, config: AriaConfig):
config.text_config, attn_implementation=config._attn_implementation
)
self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2"
self.post_init()

def get_input_embeddings(self):
@@ -2900,8 +2889,6 @@ def get_image_features(
image_features = self.multi_modal_projector(selected_image_feature)
return image_features

@add_start_docstrings_to_model_forward(ARIA_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=AriaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
4 changes: 3 additions & 1 deletion src/transformers/models/aria/modular_aria.py
@@ -11,7 +11,7 @@
from ...activations import ACT2FN
from ...configuration_utils import PretrainedConfig
from ...feature_extraction_utils import BatchFeature
from ...generation.utils import GenerationMixin
from ...generation import GenerationMixin
from ...image_processing_utils import BaseImageProcessor
from ...image_utils import ImageInput
from ...modeling_outputs import BaseModelOutput
@@ -85,6 +85,7 @@ class AriaVisionTransformer(Idefics3VisionTransformer):
This class extends the original Idefics3VisionTransformer by removing the post-layernorm operation.
"""

_supports_sdpa = False

def __init__(self, config: AriaVisionConfig):
@@ -1132,6 +1133,7 @@ class AriaForConditionalGeneration(AriaPreTrainedModel, GenerationMixin):
Args:
config (AriaConfig): Configuration object for the model.
"""

_supports_sdpa = False

def __init__(self, config: AriaConfig):
5 changes: 1 addition & 4 deletions tests/models/aria/test_modeling_aria.py
@@ -517,10 +517,7 @@ def test_aria_merge_inputs_error_bug(self):

def test_tokenizer_integration(self):
slow_tokenizer = AutoTokenizer.from_pretrained(
"rhymes-ai/Aria",
bos_token="<|startoftext|>",
eos_token="<|endoftext|>",
use_fast=False
"rhymes-ai/Aria", bos_token="<|startoftext|>", eos_token="<|endoftext|>", use_fast=False
)
slow_tokenizer.add_tokens("<image>", True)

