diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 9c03d06d94ad48..67bd31fdaeede5 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -79,6 +79,7 @@ FlashAttention-2 is currently supported for the following architectures: * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) * [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel) * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) +* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) * [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) * [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel) @@ -88,6 +89,10 @@ FlashAttention-2 is currently supported for the following architectures: * [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) * [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel) +* [RAG](https://huggingface.co/docs/transformers/model_doc/rag#transformers.RagModel) +* [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel) +* [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel) +* [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) * [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) @@ -225,6 +230,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2) * [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) * [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader) +* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel) * [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) * [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) * [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) @@ -233,11 +239,16 @@ For now, Transformers supports SDPA inference and training for the following arc * [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) * [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) +* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model) +* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model) * 
[Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel) * [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel) * [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) * [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) * [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) +* [Llava](https://huggingface.co/docs/transformers/model_doc/llava) +* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) +* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video) * [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision) * [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model) * [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi) @@ -277,10 +288,15 @@ For now, Transformers supports SDPA inference and training for the following arc * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) * [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron) +* [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel) +* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava) +* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava) +* [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel) * [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel) * [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel) * [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel) * [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel) +* [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel) * [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModell) * [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel) * [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) diff --git a/docs/source/ko/_toctree.yml b/docs/source/ko/_toctree.yml index 883db54c7a3cd7..351f89c7891d59 100644 --- a/docs/source/ko/_toctree.yml +++ b/docs/source/ko/_toctree.yml @@ -673,7 +673,7 @@ - local: in_translation title: (번역중) XLSR-Wav2Vec2 title: (번역중) 오디오 모델 - - isExpanded: false + - isExpanded: false sections: - local: model_doc/vivit title: ViViT diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4e4a1ee26c12d7..0f696cc3ac6a4d 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1475,11 +1475,7 @@ def from_legacy_cache( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor` - if self.self_attention_cache.key_cache == []: - return 0 - if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []: - return 0 - return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum() + return self.self_attention_cache.get_seq_length(layer_idx) def reset(self): if hasattr(self.self_attention_cache, "reset"): diff --git a/src/transformers/configuration_utils.py b/src/transformers/configuration_utils.py index 8bc08ca625961e..1d892c49a231fc 100755 --- a/src/transformers/configuration_utils.py +++ b/src/transformers/configuration_utils.py @@ -296,6 +296,7 @@ def __init__(self, **kwargs): # Attention implementation to use, if relevant. self._attn_implementation_internal = kwargs.pop("attn_implementation", None) + self._attn_implementation_autoset = False # Drop the transformers version info self.transformers_version = kwargs.pop("transformers_version", None) @@ -776,6 +777,10 @@ def __eq__(self, other): def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" + def __iter__(self): + for attr in self.__dict__: + yield attr + def to_diff_dict(self) -> Dict[str, Any]: """ Removes all attributes from config which correspond to the default config attributes for better readability and diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 1acd40641132b3..37d57248c46a17 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -93,7 +93,7 @@ class GenerationMode(ExplicitEnum): class GenerationConfig(PushToHubMixin): # no-format - rf""" + """ Class that holds a configuration for a generation task. 
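The `__iter__` added to `PretrainedConfig` above is what allows the attention-dispatch code later in this patch to walk a composite config generically (`for key in config: ...`). A minimal sketch of that pattern, using `LlavaConfig` only as an example of a config that carries sub-configs (requires this change to be in place):

```python
from transformers import LlavaConfig, PretrainedConfig

config = LlavaConfig()  # composite config holding `text_config` and `vision_config`

# `__iter__` yields attribute names, so sub-configs can be discovered without hardcoding keys
sub_config_keys = [key for key in config if isinstance(getattr(config, key), PretrainedConfig)]
print(sub_config_keys)  # expected to include 'text_config' and 'vision_config'
```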
A `generate` call supports the following generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models: diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9ede527ecb7b80..c399a8a2c829c7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1535,8 +1535,12 @@ def _prepare_generation_config( def _get_initial_cache_position(self, input_ids, model_kwargs): """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` - if "inputs_embeds" in model_kwargs: + if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder: cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder: + cache_position = ( + torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 + ) else: cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 @@ -1633,7 +1637,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None): cache_kwargs = { "config": self.config.get_text_config(), - "max_batch_size": batch_size, + "batch_size": batch_size, "max_cache_len": max_cache_len, "device": device, "dtype": cache_dtype, diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c84aec21a32663..a6fbd7b1a91453 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -1420,9 +1420,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" ) # Save config and origin of the pretrained weights if given in model - config = self._autoset_attn_implementation( - config, torch_dtype=torch.get_default_dtype(), check_device_map=False - ) + if not getattr(config, "_attn_implementation_autoset", False): + config = self._autoset_attn_implementation( + config, torch_dtype=torch.get_default_dtype(), check_device_map=False + ) self.config = config self.name_or_path = config.name_or_path @@ -1500,6 +1501,9 @@ def _from_config(cls, config, **kwargs): torch_dtype (`torch.dtype`, *optional*): Override the default `torch.dtype` and load the model under this dtype. """ + # when we init a model from within another model (e.g. VLMs) and dispatch on FA2 + # a warning is raised that dtype should be fp16. 
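For reference, the new `decoder_inputs_embeds` branch above builds `cache_position` with the same `torch.compile`-friendly arithmetic as the existing branches; a standalone sketch with an illustrative shape:

```python
import torch

# (batch, seq_len, hidden) decoder embeddings; positions 0..seq_len-1 are derived from the
# tensor's shape instead of torch.arange so the traced graph stays compile-friendly
decoder_inputs_embeds = torch.randn(2, 7, 512)
cache_position = torch.ones_like(decoder_inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) - 1
print(cache_position)  # tensor([0, 1, 2, 3, 4, 5, 6])
```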
Since we never pass dtype from within + # modeling code, we can try to infer it here same way as done in `from_pretrained` torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype()) use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) @@ -1518,12 +1522,13 @@ def _from_config(cls, config, **kwargs): attn_implementation = None config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation) - config = cls._autoset_attn_implementation( - config, - use_flash_attention_2=use_flash_attention_2, - check_device_map=False, - torch_dtype=torch_dtype, - ) + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, + use_flash_attention_2=use_flash_attention_2, + check_device_map=False, + torch_dtype=torch_dtype, + ) if is_deepspeed_zero3_enabled(): import deepspeed @@ -1570,7 +1575,11 @@ def _autoset_attn_implementation( ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) - if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: + if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [ + "eager", + "sdpa", + "flash_attention_2", + ]: message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' @@ -1581,6 +1590,22 @@ def _autoset_attn_implementation( # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. requested_attn_implementation = config._attn_implementation_internal + # Composite models consisting of several PretrainedModels have to specify attention impl as a dict + # where keys are sub-config names. But most people will specify one `str` which means that should dispatch it + # for all sub-models. + # Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict. + # Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)` + # If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238 + for key in config: + if isinstance(getattr(config, key), PretrainedConfig): + sub_config = getattr(config, key) + curr_attn_implementation = ( + requested_attn_implementation + if not isinstance(requested_attn_implementation, dict) + else requested_attn_implementation.get(key, None) + ) + sub_config._attn_implementation_internal = curr_attn_implementation + if use_flash_attention_2: logger.warning_once( 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' @@ -1611,9 +1636,12 @@ def _autoset_attn_implementation( "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends." 
) torch.backends.cuda.enable_flash_sdp(False) + elif isinstance(requested_attn_implementation, dict): + config._attn_implementation = None else: config._attn_implementation = "eager" + config._attn_implementation_autoset = True return config @classmethod @@ -2771,6 +2799,9 @@ def save_pretrained( # Attach architecture to the config model_to_save.config.architectures = [model_to_save.__class__.__name__] + # Unset attn implementation so it can be set to another one when loading back + model_to_save.config._attn_implementation_autoset = False + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be # loaded from the Hub. if self._auto_class is not None: @@ -4055,9 +4086,10 @@ def from_pretrained( init_contexts.append(init_empty_weights()) config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. - config = cls._autoset_attn_implementation( - config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map - ) + if not getattr(config, "_attn_implementation_autoset", False): + config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map + ) with ContextManagers(init_contexts): # Let's make sure we don't run the init function of buffer modules diff --git a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py index beb249202b96c7..491c6ce164611a 100644 --- a/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +++ b/src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py @@ -176,8 +176,24 @@ def __init__(self, config: ASTConfig) -> None: self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`ASTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) diff --git a/src/transformers/models/blip_2/modeling_blip_2.py b/src/transformers/models/blip_2/modeling_blip_2.py index 4b0ed4f71d9c95..eba82cd1b3c8e4 100644 --- a/src/transformers/models/blip_2/modeling_blip_2.py +++ b/src/transformers/models/blip_2/modeling_blip_2.py @@ -410,6 +410,7 @@ class Blip2PreTrainedModel(PreTrainedModel): config_class = Blip2Config base_model_prefix = "blip" supports_gradient_checkpointing = True + _no_split_modules = [ "Blip2Attention", "Blip2QFormerMultiHeadAttention", @@ -1455,13 +1456,9 @@ def __init__(self, config: Blip2Config): self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: - language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config.text_config) else: - language_model = AutoModelForSeq2SeqLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) # Update _tied_weights_keys using the base model used. if language_model._tied_weights_keys is not None: @@ -2020,13 +2017,9 @@ def __init__(self, config: Blip2Config): self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: - language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config.text_config) else: - language_model = AutoModelForSeq2SeqLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) # Update _tied_weights_keys using the base model used. 
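With the dispatch changes in `modeling_utils.py` above, composite models stop forwarding `attn_implementation=config._attn_implementation` to their sub-models; the requested implementation is written into each sub-config instead, and it can now be given either as a single string or as a dict keyed by sub-config name. A hedged usage sketch (the LLaVA checkpoint id is real; the per-key choices are purely illustrative and `flash_attention_2` assumes flash-attn is installed):

```python
from transformers import LlavaForConditionalGeneration

# a single string still dispatches the same implementation to every sub-model
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", attn_implementation="sdpa"
)

# alternatively, a dict keyed by sub-config name sets each sub-model individually
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation={"vision_config": "sdpa", "text_config": "flash_attention_2"},
)
```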
if language_model._tied_weights_keys is not None: diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index f946f828eec639..04a3a73de0455e 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -1204,10 +1204,10 @@ def __init__(self, config: CLIPConfig): self.text_embed_dim = text_config.hidden_size self.vision_embed_dim = vision_config.hidden_size - text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation) + text_model = CLIPTextModel._from_config(text_config) self.text_model = text_model.text_model - vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation) + vision_model = CLIPVisionModel._from_config(vision_config) self.vision_model = vision_model.vision_model self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) @@ -1590,9 +1590,7 @@ def __init__(self, config: CLIPConfig) -> None: super().__init__(config) self.num_labels = config.num_labels - vision_model = CLIPVisionModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) + vision_model = CLIPVisionModel._from_config(config.vision_config) self.vision_model = vision_model.vision_model # Classifier head diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 03194c15d98f1c..e0b053e43906b8 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -248,8 +248,24 @@ def __init__(self, config: DeiTConfig) -> None: self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`DeiTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index d1029160dd0cc2..9ebedce07fb833 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -180,6 +180,8 @@ class EncoderDecoderModel(PreTrainedModel, GenerationMixin): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False + _supports_flash_attn_2 = True + _supports_sdpa = True def __init__( self, @@ -210,12 +212,12 @@ def __init__( if encoder is None: from ..auto.modeling_auto import AutoModel - encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation) + encoder = AutoModel.from_config(config.encoder) if decoder is None: from ..auto.modeling_auto import AutoModelForCausalLM - decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation) + decoder = AutoModelForCausalLM.from_config(config.decoder) self.encoder = encoder self.decoder = decoder @@ -233,6 +235,9 @@ def __init__( # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced + # update `_attn_implementation` because the attn is set in a deepcopied config within PreTrainedModel + self.config.encoder._attn_implementation = self.encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index bc983744559fc9..8bd24728b03885 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -933,18 +933,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - # Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa - @classmethod - def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: - # We remove the checks on `is_torch_sdpa_available()` and `cls._supports_sdpa` as Falcon supports SDPA from torch==2.0.0 (no requirement on 2.1). - _is_bettertransformer = getattr(cls, "use_bettertransformer", False) - if _is_bettertransformer: - return config - - if not hard_check_only: - config._attn_implementation = "sdpa" - return config - LLAMA_INPUTS_DOCSTRING = r""" Args: diff --git a/src/transformers/models/idefics2/configuration_idefics2.py b/src/transformers/models/idefics2/configuration_idefics2.py index 1333895407e6e5..64743d1cd470e7 100644 --- a/src/transformers/models/idefics2/configuration_idefics2.py +++ b/src/transformers/models/idefics2/configuration_idefics2.py @@ -57,7 +57,7 @@ class Idefics2VisionConfig(PretrainedConfig): The epsilon used by the layer normalization layers. attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities. 
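Relatedly, `EncoderDecoderModel` above now declares FlashAttention-2/SDPA support and mirrors each sub-model's dispatched implementation back into the shared config; a small sketch, using BERT only as an example pairing:

```python
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(), BertConfig())
model = EncoderDecoderModel(config=config)  # encoder/decoder built via AutoModel*.from_config

# the implementation each half actually dispatched on is mirrored into the shared config
print(model.config.encoder._attn_implementation, model.config.decoder._attn_implementation)
```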
- intializer_range (`float`, *optional*, defaults to 0.02): + initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation for initializing all weight matrices in the model. Example: @@ -134,6 +134,10 @@ class Idefics2PerceiverConfig(PretrainedConfig): Args: hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): The non-linear activation function (function or string) in the perceiver block. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. resampler_n_latents (`int`, *optional*, defaults to 64): Number of latent embeddings to resample ("compress") the input sequence to (usually < 128). resampler_depth (`int`, *optional*, defaults to 3): @@ -153,6 +157,8 @@ class Idefics2PerceiverConfig(PretrainedConfig): def __init__( self, hidden_act="silu", + hidden_size=4096, + rms_norm_eps=1e-06, resampler_n_latents=64, resampler_depth=3, resampler_n_heads=16, @@ -162,6 +168,8 @@ def __init__( **kwargs, ): self.hidden_act = hidden_act + self.hidden_size = hidden_size + self.rms_norm_eps = rms_norm_eps self.resampler_n_latents = resampler_n_latents self.resampler_depth = resampler_depth self.resampler_n_heads = resampler_n_heads @@ -258,5 +266,12 @@ def __init__( ) self.text_config = text_config + if self.text_config.hidden_size != self.perceiver_config.hidden_size: + self.perceiver_config.hidden_size = self.text_config.hidden_size + self.perceiver_config.rms_norm_eps = self.text_config.rms_norm_eps + logger.warning_once( + "Perceiver config has a different `hidden_size` than text config, which means default values were used. " + "In your model's config on the hub, add `hidden_size` and `rms_norm_eps` keys under the `perceiver_config` dict. " + ) super().__init__(**kwargs, tie_word_embeddings=tie_word_embeddings) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index daa8bfb055b561..3d46c3bd82e788 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -38,7 +38,7 @@ replace_return_docstrings, ) from ..auto import AutoModel -from .configuration_idefics2 import Idefics2Config, Idefics2VisionConfig +from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig if is_flash_attn_2_available(): @@ -572,9 +572,86 @@ def forward( ) -class Idefics2VisionTransformer(nn.Module): +IDEFICS2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`Idefics2Config`] or [`Idefics2VisionConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. 
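To make the new `perceiver_config` fields concrete: when the text config disagrees with the perceiver defaults, `hidden_size` and `rms_norm_eps` are copied over (and the warning above is emitted). A sketch with arbitrary `MistralConfig` values:

```python
from transformers import Idefics2Config, MistralConfig

# a text config whose hidden size differs from the perceiver default (4096)
config = Idefics2Config(text_config=MistralConfig(hidden_size=2048, rms_norm_eps=1e-5))

# the perceiver sub-config is synced so the resampler can later be built from it alone
print(config.perceiver_config.hidden_size)   # 2048
print(config.perceiver_config.rms_norm_eps)  # 1e-05
```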
Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Idefics2 Model outputting raw hidden-states without any specific head on top.", + IDEFICS2_START_DOCSTRING, +) +class Idefics2PreTrainedModel(PreTrainedModel): + config_class = Idefics2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = ( + self.config.text_config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.text_config.initializer_range + ) + + if hasattr(module, "class_embedding"): + module.class_embedding.data.normal_(mean=0.0, std=std) + + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +IDEFICS2_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)): + The tensors corresponding to the input images. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details ([]`LlavaProcessor`] uses + [`CLIPImageProcessor`] for processing images). + pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): + Mask to avoid performing attention on padding pixel indices. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+""" + + +@add_start_docstrings( + """Idefics2 vision encoder model that returnss raw image embeddings.""", + IDEFICS2_START_DOCSTRING, +) +class Idefics2VisionTransformer(Idefics2PreTrainedModel): + _supports_sdpa = False + config_class = Idefics2VisionConfig + def __init__(self, config: Idefics2VisionConfig): - super().__init__() + super().__init__(config) embed_dim = config.hidden_size self.config = config @@ -687,12 +764,12 @@ def __init__(self, config, layer_idx: Optional[int] = None) -> None: super().__init__() self.layer_idx = None - self.hidden_size = config.text_config.hidden_size - self.num_heads = config.perceiver_config.resampler_n_heads - self.head_dim = config.perceiver_config.resampler_head_dim - self.num_key_value_heads = config.perceiver_config.num_key_value_heads + self.hidden_size = config.hidden_size + self.num_heads = config.resampler_n_heads + self.head_dim = config.resampler_head_dim + self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads - self.attention_dropout = config.perceiver_config.attention_dropout + self.attention_dropout = config.attention_dropout self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) @@ -918,20 +995,20 @@ def forward( class Idefics2PerceiverLayer(nn.Module): def __init__(self, config, layer_idx: int): super().__init__() - self.hidden_size = config.text_config.hidden_size - self.n_latents = config.perceiver_config.resampler_n_latents - self.depth = config.perceiver_config.resampler_depth - self.rms_norm_eps = config.text_config.rms_norm_eps + self.hidden_size = config.hidden_size + self.n_latents = config.resampler_n_latents + self.depth = config.resampler_depth + self.rms_norm_eps = config.rms_norm_eps self.input_latents_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self.input_context_norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) self.post_attention_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self.mlp = Idefics2MLP( - hidden_size=config.text_config.hidden_size, - intermediate_size=config.text_config.hidden_size * 4, - output_size=config.text_config.hidden_size, - hidden_act=config.perceiver_config.hidden_act, + hidden_size=config.hidden_size, + intermediate_size=config.hidden_size * 4, + output_size=config.hidden_size, + hidden_act=config.hidden_act, ) def forward( @@ -987,20 +1064,37 @@ def forward( return outputs -class Idefics2PerceiverResampler(nn.Module): +IDEFICS2_INPUTS_DOCSTRING = r""" + Args: + context (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`): + The hidden states of the image after vision encoder and modality projection. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) +""" + + +@add_start_docstrings( + "Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed ", + "`n_latents` inputs to decrease embedding sequence length. 
The Resampler acts as a form of learned pooling and ", + "is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)", + IDEFICS2_START_DOCSTRING, +) +class Idefics2PerceiverResampler(Idefics2PreTrainedModel): + _supports_sdpa = False + config_class = Idefics2PerceiverConfig + def __init__(self, config) -> None: - """ - Instantiates a Perceiver Resampler that operates over a sequence of embeddings (say from a ResNet or ViT or - MAE) of a given dimension, performs `depth` blocks of cross-attention with a fixed `n_latents` inputs, then - returns a Tensor of shape [bsz, n_latents, embed_dim]. The Resampler acts as a form of learned pooling and - is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206). - """ - super().__init__() - self.hidden_size = config.text_config.hidden_size - self.hidden_act = config.perceiver_config.hidden_act - self.n_latents = config.perceiver_config.resampler_n_latents - self.depth = config.perceiver_config.resampler_depth - self.rms_norm_eps = config.text_config.rms_norm_eps + super().__init__(config) + self.hidden_size = config.hidden_size + self.hidden_act = config.hidden_act + self.n_latents = config.resampler_n_latents + self.depth = config.resampler_depth + self.rms_norm_eps = config.rms_norm_eps # Create Latents for Perceiver self.latents = nn.Parameter(torch.ones(self.n_latents, self.hidden_size)) @@ -1014,7 +1108,7 @@ def __init__(self, config) -> None: def forward( self, context: torch.Tensor, - attention_mask, + attention_mask: torch.Tensor, ) -> torch.Tensor: # seq embed -> bsz seq embed latents = self.latents.unsqueeze(0).expand((context.shape[0], *self.latents.size())) @@ -1057,7 +1151,7 @@ def __init__(self, config): output_size=config.text_config.hidden_size, hidden_act=config.text_config.hidden_act, ) - self.perceiver_resampler = Idefics2PerceiverResampler(config) + self.perceiver_resampler = Idefics2PerceiverResampler._from_config(config.perceiver_config) def forward(self, image_hidden_states, attention_mask): image_hidden_states = self.modality_projection(image_hidden_states) @@ -1065,80 +1159,6 @@ def forward(self, image_hidden_states, attention_mask): return image_hidden_states -IDEFICS2_START_DOCSTRING = r""" - This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the - library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads - etc.) - - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. - Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage - and behavior. - - Parameters: - config ([`Idefics2Config`] or [`Idefics2VisionConfig`]): - Model configuration class with all the parameters of the model. Initializing with a config file does not - load the weights associated with the model, only the configuration. Check out the - [`~PreTrainedModel.from_pretrained`] method to load the model weights. 
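Because the resampler is now an `Idefics2PreTrainedModel` built from `config.perceiver_config` alone (see the `_from_config` call above), it can also be instantiated and exercised standalone; a minimal sketch assuming default configs and random inputs:

```python
import torch
from transformers import Idefics2Config
from transformers.models.idefics2.modeling_idefics2 import Idefics2PerceiverResampler

config = Idefics2Config()
resampler = Idefics2PerceiverResampler._from_config(config.perceiver_config)

# resample a dummy sequence of image embeddings down to `resampler_n_latents` tokens
context = torch.randn(1, 128, config.perceiver_config.hidden_size)
attention_mask = torch.ones(1, 128, dtype=torch.long)
latents = resampler(context, attention_mask)
print(latents.shape)  # (1, 64, hidden_size) with the default 64 latents
```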
-""" - - -@add_start_docstrings( - "The bare Idefics2 Model outputting raw hidden-states without any specific head on top.", - IDEFICS2_START_DOCSTRING, -) -class Idefics2PreTrainedModel(PreTrainedModel): - config_class = Idefics2Config - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] - _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True - _supports_cache_class = True - - def _init_weights(self, module): - std = ( - self.config.initializer_range - if hasattr(self.config, "initializer_range") - else self.config.text_config.initializer_range - ) - - if hasattr(module, "class_embedding"): - module.class_embedding.data.normal_(mean=0.0, std=std) - - if isinstance(module, (nn.Linear, nn.Conv2d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - @classmethod - def _autoset_attn_implementation( - cls, - config, - use_flash_attention_2: bool = False, - torch_dtype: Optional[torch.dtype] = None, - device_map: Optional[Union[str, Dict[str, int]]] = None, - check_device_map: bool = True, - **kwargs, - ): - """ - Overrides the method in `PreTrainedModel` to update the vision config with the correct attention implementation - """ - config = super()._autoset_attn_implementation( - config=config, - use_flash_attention_2=use_flash_attention_2, - torch_dtype=torch_dtype, - device_map=device_map, - check_device_map=check_device_map, - **kwargs, - ) - config.vision_config._attn_implementation = config._attn_implementation - return config - - IDEFICS2_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1219,14 +1239,14 @@ def __init__(self, config: Idefics2Config): self.padding_idx = self.config.text_config.pad_token_id self.vocab_size = self.config.text_config.vocab_size - self.vision_model = Idefics2VisionTransformer(config.vision_config) + self.vision_model = Idefics2VisionTransformer._from_config(config.vision_config) self.connector = Idefics2Connector(config) - self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation) + self.text_model = AutoModel.from_config(config.text_config) self.image_seq_len = config.perceiver_config.resampler_n_latents self.image_token_id = self.config.image_token_id - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" self.post_init() diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 748eda8c026377..31d43948fbd565 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -621,12 +621,13 @@ class Idefics3PreTrainedModel(PreTrainedModel): _no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True _supports_cache_class = True # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights def _init_weights(self, module): std = ( - self.config.initializer_range + 
self.config.text_config.initializer_range if hasattr(self.config, "initializer_range") else self.config.text_config.initializer_range ) @@ -667,6 +668,7 @@ def _init_weights(self, module): ) class Idefics3VisionTransformer(Idefics3PreTrainedModel): config_class = Idefics3VisionConfig + _supports_sdpa = False def __init__(self, config: Idefics3VisionConfig): super().__init__(config) @@ -824,18 +826,16 @@ def __init__(self, config: Idefics3Config): self.padding_idx = self.config.text_config.pad_token_id self.vocab_size = self.config.text_config.vocab_size - self.vision_model = Idefics3VisionTransformer._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) + self.vision_model = Idefics3VisionTransformer._from_config(config.vision_config) self.connector = Idefics3Connector(config) - self.text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation) + self.text_model = AutoModel.from_config(config.text_config) self.image_seq_len = int( ((config.vision_config.image_size // config.vision_config.patch_size) ** 2) / (config.scale_factor**2) ) self.image_token_id = self.config.image_token_id - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_2 = config.text_config._attn_implementation == "flash_attention_2" self.post_init() diff --git a/src/transformers/models/instructblip/modeling_instructblip.py b/src/transformers/models/instructblip/modeling_instructblip.py index de4e84b82f8377..5cce774ce0716a 100644 --- a/src/transformers/models/instructblip/modeling_instructblip.py +++ b/src/transformers/models/instructblip/modeling_instructblip.py @@ -315,6 +315,7 @@ class InstructBlipPreTrainedModel(PreTrainedModel): config_class = InstructBlipConfig base_model_prefix = "blip" supports_gradient_checkpointing = True + _no_split_modules = [ "InstructBlipQFormerEmbeddings", "InstructBlipAttention", @@ -1298,13 +1299,9 @@ def __init__(self, config: InstructBlipConfig): self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: - language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config.text_config) else: - language_model = AutoModelForSeq2SeqLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) if language_model._no_split_modules is not None: self._no_split_modules.extend(language_model._no_split_modules) diff --git a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py index a300268ed71327..c9f12391666c22 100644 --- a/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py +++ b/src/transformers/models/instructblipvideo/modeling_instructblipvideo.py @@ -317,6 +317,7 @@ class InstructBlipVideoPreTrainedModel(PreTrainedModel): config_class = InstructBlipVideoConfig base_model_prefix = "blip" supports_gradient_checkpointing = True + _no_split_modules = [ "InstructBlipVideoQFormerEmbeddings", "InstructBlipVideoAttention", @@ -1292,13 +1293,9 @@ def __init__(self, config: InstructBlipVideoConfig): self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size) if config.use_decoder_only_language_model: - language_model 
= AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config.text_config) else: - language_model = AutoModelForSeq2SeqLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + language_model = AutoModelForSeq2SeqLM.from_config(config.text_config) if language_model._no_split_modules is not None: self._no_split_modules.extend(language_model._no_split_modules) diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index 31593bc62d098c..50b3d4c6a89533 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -125,8 +125,9 @@ class LlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of Llava isn't meant for training from scratch - only @@ -150,14 +151,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - LLAVA_INPUTS_DOCSTRING = r""" Args: @@ -245,9 +238,7 @@ def __init__(self, config: LlavaConfig): self.multi_modal_projector = LlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() @@ -282,6 +273,20 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m def get_image_features( self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated. 
selected_image_feature = image_outputs.hidden_states[vision_feature_layer] diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index 03ab28dfff9cb1..0cbda9cfd64b74 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -234,8 +234,9 @@ class LlavaNextPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlavaNextVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of LlavaNext isn't meant for training from scratch - only @@ -259,14 +260,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - LLAVA_NEXT_INPUTS_DOCSTRING = r""" Args: @@ -360,9 +353,7 @@ def __init__(self, config: LlavaNextConfig): self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.post_init() @@ -714,6 +705,57 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) return image_features, feature_lens + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ + # ! 
infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + @add_start_docstrings_to_model_forward(LLAVA_NEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -805,34 +847,12 @@ def forward( ) or (input_ids.shape[-1] == 1 and pixel_values is not None) if pixel_values is not None and pixel_values.size(0) > 0: - # ! infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - # figure out if pixel_values is concatenated or stacked - if pixel_values.dim() == 5: - # stacking when input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [ - pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches) - ] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of (num_patches, num_channels, height, width) - raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[vision_feature_layer] - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - image_features = self.multi_modal_projector(selected_image_feature) - image_features = torch.split(image_features, image_num_patches, dim=0) + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) # NOTE we only support multimodal_patch_merge_type == "spatial_unpad" image_features, feature_lens = self.pack_image_features( diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 3fd6bb47fc7661..96f4373afd9ec6 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ 
b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -277,8 +277,9 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["LlavaNextVideoVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of LlavaNextVideo isn't meant for training from scratch - only @@ -302,14 +303,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING = r""" Args: @@ -406,9 +399,7 @@ def __init__( self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.vision_resampler = LlavaNextVideoPooler(config) @@ -753,6 +744,57 @@ def pack_image_features(self, image_features, image_sizes, vision_feature_select feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) return image_features, feature_lens + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ + # ! 
infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + @add_start_docstrings_to_model_forward(LLAVA_NEXT_VIDEO_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -892,7 +934,12 @@ def forward( image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: - image_features = self._get_image_features(pixel_values, image_sizes) + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=self.vision_feature_layer, + vision_feature_select_strategy=self.vision_feature_select_strategy, + ) image_features, feature_lens = self.pack_image_features( image_features, image_sizes, @@ -902,7 +949,11 @@ def forward( video_features = video_feature_lens = None if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: - video_features = self._get_video_features(pixel_values_videos) + video_features = self.get_video_features( + pixel_values_videos, + vision_feature_layer=self.vision_feature_layer, + vision_feature_select_strategy=self.vision_feature_select_strategy, + ) video_features = [feature.flatten(0, 1) for feature in video_features] video_feature_lens = [feature.size(0) for feature in video_features] video_features = torch.cat(video_features, dim=0) @@ -1089,46 +1140,35 @@ def prepare_inputs_for_generation( return model_inputs - def _get_image_features(self, pixel_values, image_sizes): - # ! 
infer image_num_patches from image_sizes - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - if pixel_values.dim() == 5: - # stacked if input is (batch_size, num_patches, num_channels, height, width) - _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - # otherwise has to be stacked from list of (num_patches, num_channels, height, width) - raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - image_features = self.multi_modal_projector(selected_image_feature) - image_features = torch.split(image_features, image_num_patches, dim=0) - return image_features + def get_video_features( + self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str + ): + """ + Obtains video last hidden states from the vision tower and apply multimodal projection. - def _get_video_features(self, pixel_values): + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input video. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches + and are of shape `(num_videos, video_length, embed_dim)`). 
+ """ batch_size, frames, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width) - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + video_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_video_features = video_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_video_features = selected_video_features[:, 1:] + elif vision_feature_select_strategy == "full": + selected_video_features = selected_video_features # Same as image features except that video has pooling layer - image_features = self.vision_resampler(selected_image_feature) - image_features = self.multi_modal_projector(image_features) - image_features = torch.split(image_features, frames, dim=0) - return image_features + video_features = self.vision_resampler(selected_video_features) + video_features = self.multi_modal_projector(video_features) + video_features = torch.split(video_features, frames, dim=0) + return video_features diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py index ec5a05733ec878..c1ed7571941b9e 100644 --- a/src/transformers/models/llava_next_video/modular_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py @@ -225,7 +225,30 @@ def __init__(self, config: LlavaNextVideoConfig, **super_kwargs): self.vision_resampler = LlavaNextVideoPooler(config) self.post_init() - def _get_image_features(self, pixel_values, image_sizes): + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ # ! 
infer image_num_patches from image_sizes image_num_patches = [ image_size_to_num_patches( @@ -244,30 +267,47 @@ def _get_image_features(self, pixel_values, image_sizes): raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": + selected_image_feature = image_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": selected_image_feature = selected_image_feature[:, 1:] - elif self.vision_feature_select_strategy == "full": + elif vision_feature_select_strategy == "full": selected_image_feature = selected_image_feature image_features = self.multi_modal_projector(selected_image_feature) image_features = torch.split(image_features, image_num_patches, dim=0) return image_features - def _get_video_features(self, pixel_values): + def get_video_features( + self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str + ): + """ + Obtains video last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input video. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches + and are of shape `(num_videos, video_length, embed_dim)`). 
+ """ batch_size, frames, channels, height, width = pixel_values.shape pixel_values = pixel_values.reshape(batch_size * frames, channels, height, width) - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[self.vision_feature_layer] - if self.vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif self.vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature + video_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_video_features = video_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_video_features = selected_video_features[:, 1:] + elif vision_feature_select_strategy == "full": + selected_video_features = selected_video_features # Same as image features except that video has pooling layer - image_features = self.vision_resampler(selected_image_feature) - image_features = self.multi_modal_projector(image_features) - image_features = torch.split(image_features, frames, dim=0) - return image_features + video_features = self.vision_resampler(selected_video_features) + video_features = self.multi_modal_projector(video_features) + video_features = torch.split(video_features, frames, dim=0) + return video_features @replace_return_docstrings(output_type=LlavaNextVideoCausalLMOutputWithPast, config_class="LlavaNextVideoConfig") def forward( @@ -407,7 +447,12 @@ def forward( image_features = feature_lens = None if pixel_values is not None and pixel_values.size(0) > 0: - image_features = self._get_image_features(pixel_values, image_sizes) + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=self.vision_feature_layer, + vision_feature_select_strategy=self.vision_feature_select_strategy, + ) image_features, feature_lens = self.pack_image_features( image_features, image_sizes, @@ -417,7 +462,11 @@ def forward( video_features = video_feature_lens = None if pixel_values_videos is not None and pixel_values_videos.size(0) > 0: - video_features = self._get_video_features(pixel_values_videos) + video_features = self.get_video_features( + pixel_values_videos, + vision_feature_layer=self.vision_feature_layer, + vision_feature_select_strategy=self.vision_feature_select_strategy, + ) video_features = [feature.flatten(0, 1) for feature in video_features] video_feature_lens = [feature.size(0) for feature in video_features] video_features = torch.cat(video_features, dim=0) diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 7bacd2a54fc97f..946688bfcf07f4 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -363,18 +363,14 @@ def _init_weights(self, module): class LlavaOnevisionForConditionalGeneration(LlavaOnevisionPreTrainedModel, GenerationMixin): def __init__(self, config: LlavaOnevisionConfig): super().__init__(config) - self.vision_tower = AutoModel.from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) + self.vision_tower = AutoModel.from_config(config.vision_config) self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config) embed_std = 1 / math.sqrt(config.text_config.hidden_size) self.image_newline = nn.Parameter(torch.randn(config.text_config.hidden_size, 
dtype=self.dtype) * embed_std) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.post_init() # Copied from transformers.models.llava_next.modeling_llava_next.LlavaNextForConditionalGeneration.get_input_embeddings @@ -485,6 +481,91 @@ def apply_pooling(self, image_features): image_features = image_features.view(batch_frames, -1, dim) return image_features + def get_image_features( + self, + pixel_values: torch.FloatTensor, + image_sizes: torch.Tensor, + vision_feature_layer: int, + vision_feature_select_strategy: str, + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_patches, channels, height, width)`) + The tensors corresponding to the input images. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (List[`torch.Tensor`]): List of image feature tensor, each contains all the visual feature of all patches + and are of shape `(num_patches, image_length, embed_dim)`). + """ + # ! infer image_num_patches from image_sizes + image_num_patches = [ + image_size_to_num_patches( + image_size=imsize, + grid_pinpoints=self.config.image_grid_pinpoints, + patch_size=self.config.vision_config.image_size, + ) + for imsize in image_sizes + ] + if pixel_values.dim() == 5: + # stacked if input is (batch_size, num_patches, num_channels, height, width) + _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)] + pixel_values = torch.cat(_pixel_values_list, dim=0) + elif pixel_values.dim() != 4: + # otherwise has to be stacked from list of (num_patches, num_channels, height, width) + raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") + + image_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_image_feature = image_features.hidden_states[vision_feature_layer] + if vision_feature_select_strategy == "default": + selected_image_feature = selected_image_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_image_feature = selected_image_feature + image_features = self.multi_modal_projector(selected_image_feature) + image_features = torch.split(image_features, image_num_patches, dim=0) + return image_features + + def get_video_features( + self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str + ): + """ + Obtains video last hidden states from the vision tower, apply multimodal projection and pooling. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input video. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. 
+ Can be one of `"default"` or `"full"` + Returns: + video_features (List[`torch.Tensor`]): List of video feature tensor, each contains all the visual feature of all patches + and are of shape `(num_videos, video_length, embed_dim)`). + """ + batch_size, frames, channels, height, width = pixel_values.shape + pixel_values = pixel_values.view(batch_size * frames, channels, height, width) + video_features = self.vision_tower(pixel_values, output_hidden_states=True) + selected_video_feature = video_features.hidden_states[vision_feature_layer] + + if vision_feature_select_strategy == "default": + selected_video_feature = selected_video_feature[:, 1:] + elif vision_feature_select_strategy == "full": + selected_video_feature = selected_video_feature + video_features = self.multi_modal_projector(selected_video_feature) + + video_features = self.apply_pooling(video_features) + video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1) + + return video_features + @add_start_docstrings(LLAVA_ONEVISION_INPUTS_DOCSTRING) def forward( self, @@ -584,35 +665,12 @@ def forward( # Images are processed with Anyres if pixel_values is not None: - image_num_patches = [ - image_size_to_num_patches( - image_size=imsize, - grid_pinpoints=self.config.image_grid_pinpoints, - patch_size=self.config.vision_config.image_size, - ) - for imsize in image_sizes - ] - - # unpad extra patches and concatenate them - if pixel_values.dim() == 5: - _pixel_values_list = [ - pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches) - ] - # [batch_size*frames*num_patches, num_channels, height, width] where frames=1 for images - pixel_values = torch.cat(_pixel_values_list, dim=0) - elif pixel_values.dim() != 4: - raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") - - image_features = self.vision_tower(pixel_values, output_hidden_states=True) - selected_image_feature = image_features.hidden_states[vision_feature_layer] - - if vision_feature_select_strategy == "default": - selected_image_feature = selected_image_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_image_feature = selected_image_feature - image_features = self.multi_modal_projector(selected_image_feature) - - image_features = torch.split(image_features, image_num_patches, dim=0) + image_features = self.get_image_features( + pixel_values, + image_sizes, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) image_features, feature_lens = self.pack_image_features( image_features, image_sizes, @@ -636,20 +694,14 @@ def forward( # Video are simply embedded and further pooled to decrease seq len if pixel_values_videos is not None: - batch_size, frames, channels, height, width = pixel_values_videos.shape - pixel_values_videos = pixel_values_videos.view(batch_size * frames, channels, height, width) - video_features = self.vision_tower(pixel_values_videos, output_hidden_states=True) - selected_video_feature = video_features.hidden_states[vision_feature_layer] - - if vision_feature_select_strategy == "default": - selected_video_feature = selected_video_feature[:, 1:] - elif vision_feature_select_strategy == "full": - selected_video_feature = selected_video_feature - video_features = self.multi_modal_projector(selected_video_feature) - - video_features = self.apply_pooling(video_features) - video_features = video_features.reshape(batch_size, frames * video_features.shape[1], -1) - image_newline = 
self.image_newline[None, None, :].repeat(batch_size, 1, 1).to(video_features.device) + video_features = self.get_video_features( + pixel_values_videos, + vision_feature_layer=vision_feature_layer, + vision_feature_select_strategy=vision_feature_select_strategy, + ) + image_newline = ( + self.image_newline[None, None, :].repeat(video_features.shape[0], 1, 1).to(video_features.device) + ) video_features = torch.cat((video_features, image_newline), dim=1) video_features = video_features.flatten(0, 1) n_video_tokens = (input_ids == self.config.video_token_index).sum().item() diff --git a/src/transformers/models/longt5/configuration_longt5.py b/src/transformers/models/longt5/configuration_longt5.py index 0e541ae2a1b4fa..b6e7d21b3d677b 100644 --- a/src/transformers/models/longt5/configuration_longt5.py +++ b/src/transformers/models/longt5/configuration_longt5.py @@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig): model_type = "longt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d351e798ac7f88..29536d9ad6f284 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -24,7 +24,9 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -317,7 +320,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5 class LongT5Attention(nn.Module): - def __init__(self, config: LongT5Config, has_relative_attention_bias=False): + def __init__( + self, + config: LongT5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -328,6 +336,13 @@ def __init__(self, config: LongT5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -404,11 +419,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -432,94 +450,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) - - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, 
key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -529,22 +525,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -1008,9 +1004,11 @@ def unshape(states): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 class LongT5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = LongT5Attention( + config, 
has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1023,6 +1021,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -1033,6 +1032,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -1042,7 +1042,7 @@ def forward( class LongT5LayerLocalSelfAttention(nn.Module): """Local self attention used in encoder""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1073,7 +1073,7 @@ def forward( class LongT5LayerTransientGlobalSelfAttention(nn.Module): """Transient-Global self attention used in encoder""" - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention( config, has_relative_attention_bias=has_relative_attention_bias @@ -1105,9 +1105,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 class LongT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1122,6 +1122,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -1134,6 +1135,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -1141,7 +1143,7 @@ def forward( class LongT5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder if config.is_decoder: @@ -1156,9 +1158,11 @@ def __init__(self, config, has_relative_attention_bias=False): f"but got {config.encoder_attention_type}." 
) self.layer = nn.ModuleList() - self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(LongT5LayerCrossAttention(config)) + self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(LongT5LayerFF(config)) @@ -1176,34 +1180,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ @@ -1213,35 +1202,25 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. 
Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1256,7 +1235,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1273,6 +1252,8 @@ class LongT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" supports_gradient_checkpointing = True _no_split_modules = ["LongT5Block"] + _supports_cache_class = True + _supports_static_cache = False # TODO: @raushan more involved due to local/global attn @property # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs @@ -1376,7 +1357,10 @@ def __init__(self, config, embed_tokens=None): self.block_len = self.local_radius + 1 self.block = nn.ModuleList( - [LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1408,6 +1392,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1430,36 +1415,65 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if inputs_embeds is None: assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" inputs_embeds = self.embed_tokens(input_ids) batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - - if use_cache is True: - assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used if self.is_decoder: - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape, inputs_embeds.device + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, ) + # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used elif self.config.encoder_attention_type == "local": - extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) + causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) else: # we need to use both local attention mask and standard extended mask for transient-global attention - extended_attention_mask = attention_mask + causal_mask = attention_mask # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1472,17 +1486,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1491,7 +1497,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -1502,7 +1508,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1512,20 +1518,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1533,7 +1543,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states 
(self-attention position bias), (self-attention weights), @@ -1541,9 +1551,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1557,12 +1564,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1571,12 +1584,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + LONGT5_START_DOCSTRING = r""" @@ -1693,6 +1829,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" LONGT5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1817,6 +1956,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1883,6 +2023,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1975,6 +2116,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -2050,6 +2192,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index d1cc3a13bf3cc3..c5ae615a12b5cc 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1979,12 +1979,8 @@ def __init__(self, config: MllamaConfig): self.vision_output_dim = config.vision_config.vision_output_dim self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 - self.vision_model = MllamaVisionModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) - self.language_model = MllamaForCausalLM._from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.vision_model = MllamaVisionModel._from_config(config.vision_config) + self.language_model = MllamaForCausalLM._from_config(config.text_config) self.multi_modal_projector = nn.Linear( config.vision_config.vision_output_dim, config.text_config.hidden_size, diff --git a/src/transformers/models/mt5/configuration_mt5.py b/src/transformers/models/mt5/configuration_mt5.py index ef629718b1b591..267179f81247e8 100644 --- a/src/transformers/models/mt5/configuration_mt5.py +++ b/src/transformers/models/mt5/configuration_mt5.py @@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig): model_type = "mt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 9051414d7414fa..659a84c5fe3784 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -25,7 +25,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,6 +45,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, 
replace_return_docstrings, ) @@ -214,7 +217,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5 class MT5Attention(nn.Module): - def __init__(self, config: MT5Config, has_relative_attention_bias=False): + def __init__( + self, + config: MT5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -225,6 +233,13 @@ def __init__(self, config: MT5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -301,11 +316,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -329,94 +347,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, 
key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -426,22 +422,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -450,9 +446,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 class MT5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = MT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = MT5Attention( + 
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -465,6 +463,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -475,6 +474,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -483,9 +483,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5 class MT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -500,6 +500,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -512,6 +513,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -520,13 +522,15 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 class MT5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(MT5LayerCrossAttention(config)) + self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(MT5LayerFF(config)) @@ -544,34 +548,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -585,25 +574,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -614,10 +596,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -636,11 +614,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): @@ -780,6 +758,9 @@ class MT5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet + _supports_static_cache = True + _supports_cache_class = True _no_split_modules = ["MT5Block"] _keep_in_fp32_modules = ["wo"] @@ -892,7 +873,7 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + 
[MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] ) self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -968,6 +949,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -994,6 +976,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1001,23 +990,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
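The deprecation message above points callers at the new cache objects. A minimal usage sketch, assuming a build of Transformers that includes this patch; `google/mt5-small` and the prompt are only examples, and `generate` is assumed to accept the pre-built cache through `past_key_values`:

    import torch
    from transformers import AutoTokenizer, MT5ForConditionalGeneration
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
    model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

    inputs = tokenizer("Berlin is the capital of <extra_id_0>.", return_tensors="pt")

    # explicit cache object instead of the deprecated legacy tuple format
    past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
    output_ids = model.generate(**inputs, past_key_values=past_key_values, use_cache=True, max_new_tokens=8)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))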
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1032,17 +1055,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1051,15 +1066,15 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -1079,7 +1094,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1089,20 +1104,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1110,7 +1129,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs 
= hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1118,9 +1137,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1140,12 +1156,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1154,12 +1176,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path.
+            # Details: https://github.com/pytorch/pytorch/issues/110213
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+        return causal_mask
+
+    @staticmethod
+    # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
+    def _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask: torch.Tensor,
+        sequence_length: int,
+        target_length: int,
+        dtype: torch.dtype,
+        device: torch.device,
+        cache_position: torch.Tensor,
+        batch_size: int,
+        **kwargs,
+    ):
+        """
+        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+        Args:
+            attention_mask (`torch.Tensor`):
+                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+                `(batch_size, 1, query_length, key_value_length)`.
+            sequence_length (`int`):
+                The sequence length being processed.
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + MT5_START_DOCSTRING = r""" @@ -1454,6 +1599,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1533,6 +1679,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1685,6 +1832,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1779,6 +1927,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/musicgen/configuration_musicgen.py b/src/transformers/models/musicgen/configuration_musicgen.py index ef2e0244c1406f..0d282355defa96 100644 --- a/src/transformers/models/musicgen/configuration_musicgen.py +++ b/src/transformers/models/musicgen/configuration_musicgen.py @@ -236,20 +236,3 @@ def from_sub_models_config( # This is a property because you might want to change the codec model on the fly def sampling_rate(self): return self.audio_encoder.sampling_rate - - @property - def _attn_implementation(self): - # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) - if hasattr(self, "_attn_implementation_internal"): - if self._attn_implementation_internal is None: - # `config.attn_implementation` should never be None, for backward compatibility. 
- return "eager" - else: - return self._attn_implementation_internal - else: - return "eager" - - @_attn_implementation.setter - def _attn_implementation(self, value): - self._attn_implementation_internal = value - self.decoder._attn_implementation = value diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 626097f5c7cbcc..c18e1d1c9d86b1 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1713,7 +1713,7 @@ def __init__( audio_encoder = AutoModel.from_config(config.audio_encoder) if decoder is None: - decoder = MusicgenForCausalLM(config.decoder) + decoder = MusicgenForCausalLM._from_config(config.decoder) self.text_encoder = text_encoder self.audio_encoder = audio_encoder @@ -1737,6 +1737,9 @@ def __init__( # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced + self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation + self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.text_encoder.config = self.config.text_encoder self.audio_encoder.config = self.config.audio_encoder self.decoder.config = self.config.decoder diff --git a/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py b/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py index b29187facb3d1b..8a77cea0252234 100644 --- a/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/configuration_musicgen_melody.py @@ -250,20 +250,3 @@ def from_sub_models_config( # This is a property because you might want to change the codec model on the fly def sampling_rate(self): return self.audio_encoder.sampling_rate - - @property - def _attn_implementation(self): - # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) - if hasattr(self, "_attn_implementation_internal"): - if self._attn_implementation_internal is None: - # `config.attn_implementation` should never be None, for backward compatibility. 
- return "eager" - else: - return self._attn_implementation_internal - else: - return "eager" - - @_attn_implementation.setter - def _attn_implementation(self, value): - self._attn_implementation_internal = value - self.decoder._attn_implementation = value diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 166623796d65d0..d2f339afc41451 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1628,7 +1628,7 @@ def __init__( audio_encoder = AutoModel.from_config(config.audio_encoder) if decoder is None: - decoder = MusicgenMelodyForCausalLM(config.decoder) + decoder = MusicgenMelodyForCausalLM._from_config(config.decoder) self.text_encoder = text_encoder self.audio_encoder = audio_encoder @@ -1636,6 +1636,9 @@ def __init__( # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced + self.config.text_encoder._attn_implementation = self.text_encoder.config._attn_implementation + self.config.audio_encoder._attn_implementation = self.audio_encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.text_encoder.config = self.config.text_encoder self.audio_encoder.config = self.config.audio_encoder self.decoder.config = self.config.decoder diff --git a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py index bf9dbd951b5b06..0f44e4bd40208c 100644 --- a/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py +++ b/src/transformers/models/omdet_turbo/modeling_omdet_turbo.py @@ -288,7 +288,7 @@ def put(self, key, value) -> None: class OmDetTurboLanguageBackbone(nn.Module): def __init__(self, config: OmDetTurboConfig): super().__init__() - self.model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation) + self.model = AutoModel.from_config(config.text_config) self.text_projection = nn.Parameter(torch.zeros(config.text_projection_in_dim, config.text_projection_out_dim)) def forward(self, hidden_states, mask=None, encode_type="task"): diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 1607261eaac673..e198dab420abe8 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -193,12 +193,12 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = False _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True - _supports_sdpa = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of PaliGemmaisn't meant for training from scratch - only @@ -221,14 +221,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. 
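With `_supports_flash_attn_2` and `_supports_sdpa` now set on `PaliGemmaPreTrainedModel` itself (rather than deferring to the language model), the attention backend can be chosen at load time in the usual way. A short sketch; the checkpoint name is only an example, and `flash_attention_2` additionally requires the separate `flash-attn` package:

    import torch
    from transformers import PaliGemmaForConditionalGeneration

    model = PaliGemmaForConditionalGeneration.from_pretrained(
        "google/paligemma-3b-pt-224",   # example checkpoint
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",     # or "flash_attention_2" / "eager"
    )
    print(model.config._attn_implementation)  # the backend that was selected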
- """ - return self.language_model._supports_sdpa - PALIGEMMA_INPUTS_DOCSTRING = r""" Args: @@ -310,11 +302,8 @@ def __init__(self, config: PaliGemmaConfig): self.vision_tower = AutoModel.from_config(config=config.vision_config) self.multi_modal_projector = PaliGemmaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self._attn_implementation = config._attn_implementation - language_model = AutoModelForCausalLM.from_config( - config=config.text_config, attn_implementation=self._attn_implementation - ) + language_model = AutoModelForCausalLM.from_config(config=config.text_config) if language_model._tied_weights_keys is not None: self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys] @@ -354,6 +343,11 @@ def tie_weights(self): def _update_causal_mask( self, attention_mask, token_type_ids, inputs_embeds, past_key_values, cache_position, is_training: bool = False ): + if self.config.text_config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + using_static_cache = isinstance(past_key_values, StaticCache) dtype = inputs_embeds.dtype min_dtype = torch.finfo(dtype).min @@ -398,6 +392,22 @@ def _update_causal_mask( ) return causal_mask + def get_image_features(self, pixel_values: torch.FloatTensor): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ + image_outputs = self.vision_tower(pixel_values) + selected_image_feature = image_outputs.last_hidden_state + image_features = self.multi_modal_projector(selected_image_feature) + image_features = image_features / (self.config.hidden_size**0.5) + return image_features + @add_start_docstrings_to_model_forward(PALIGEMMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=PaliGemmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) def forward( @@ -483,10 +493,7 @@ def forward( # Merge text and images if pixel_values is not None: - image_outputs = self.vision_tower(pixel_values.to(inputs_embeds.dtype)) - selected_image_feature = image_outputs.last_hidden_state - image_features = self.multi_modal_projector(selected_image_feature) - image_features = image_features / (self.config.hidden_size**0.5) + image_features = self.get_image_features(pixel_values) special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index d6f92e9fe03495..6a64a27e007b3e 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -25,7 +25,9 @@ from transformers.generation import GenerationConfig from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -37,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + 
is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -136,6 +139,9 @@ more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ @@ -245,7 +251,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano class Pop2PianoAttention(nn.Module): - def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): + def __init__( + self, + config: Pop2PianoConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -256,6 +267,13 @@ def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -332,11 +350,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -360,94 +381,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
""" # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if 
past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -457,22 +456,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + 
(present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -481,9 +480,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = Pop2PianoAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -496,6 +497,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -506,6 +508,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -514,9 +517,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano class Pop2PianoLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -531,6 +534,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -543,6 +547,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -551,13 +556,17 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano class Pop2PianoBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + Pop2PianoLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) if self.is_decoder: - self.layer.append(Pop2PianoLayerCrossAttention(config)) + self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(Pop2PianoLayerFF(config)) @@ -575,34 +584,19 @@ def forward( use_cache=False, 
output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -616,25 +610,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -645,10 +632,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -667,11 +650,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class Pop2PianoPreTrainedModel(PreTrainedModel): @@ -684,6 +667,8 @@ class Pop2PianoPreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" 
is_parallelizable = False supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["Pop2PianoBlock"] _keep_in_fp32_modules = ["wo"] @@ -769,7 +754,10 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -803,6 +791,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -825,6 +814,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -832,28 +828,55 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -866,17 +889,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -885,7 +900,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: @@ -895,7 +910,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -905,20 +920,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -926,7 +943,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -934,9 +951,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -950,12 +964,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -964,12 +984,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + 
input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. 
+ device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class Pop2PianoConcatEmbeddingToMel(nn.Module): """Embedding Matrix for `composer` tokens.""" @@ -1122,6 +1265,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1177,6 +1321,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index e923e535da8e34..ce0e427048cf23 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -544,6 +544,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2AudioAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of Qwen2Audio isn't meant for training from scratch - only @@ -559,14 +560,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - QWEN2AUDIOENCODER_START_DOCSTRING = r""" This model inherits from [`PreTrainedModel`].
Check the superclass documentation for the generic methods the @@ -859,13 +852,11 @@ def forward(self, audio_features): class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) - self.audio_tower = AutoModel.from_config(config.audio_config, attn_implementation=config._attn_implementation) + self.audio_tower = AutoModel.from_config(config.audio_config) self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides self.post_init() diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index f4cb84a2444eb6..07531248f63b1d 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1443,9 +1443,7 @@ class Qwen2VLForConditionalGeneration(Qwen2VLPreTrainedModel, GenerationMixin): def __init__(self, config): super().__init__(config) - self.visual = Qwen2VisionTransformerPretrainedModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) + self.visual = Qwen2VisionTransformerPretrainedModel._from_config(config.vision_config) self.model = Qwen2VLModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index 5e6f13ca478f32..dfc2664b78a3dc 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -232,6 +232,8 @@ class RagPreTrainedModel(PreTrainedModel): config_class = RagConfig base_model_prefix = "rag" + _supports_flash_attn_2 = True + _supports_sdpa = True @classmethod def from_pretrained(cls, *args, **kwargs): @@ -506,16 +508,12 @@ def __init__( if question_encoder is None: from ..auto.modeling_auto import AutoModel - question_encoder = AutoModel.from_config( - config.question_encoder, attn_implementation=config._attn_implementation - ) + question_encoder = AutoModel.from_config(config.question_encoder) if generator is None: from ..auto.modeling_auto import AutoModelForSeq2SeqLM - generator = AutoModelForSeq2SeqLM.from_config( - config.generator, attn_implementation=config._attn_implementation - ) + generator = AutoModelForSeq2SeqLM.from_config(config.generator) self.retriever = retriever if self.retriever is not None: diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 507e0768a226ef..a3d06cbb4792b4 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -669,6 +669,7 @@ class SiglipPreTrainedModel(PreTrainedModel): config_class = SiglipConfig base_model_prefix = "siglip" supports_gradient_checkpointing = True + _no_split_modules = [ "SiglipTextEmbeddings", "SiglipEncoderLayer", @@ -1218,8 +1219,8 @@ def __init__(self, config: SiglipConfig): vision_config = config.vision_config # First, initialize the text and vision models 
with proper attention implementation - text_model = SiglipTextModel._from_config(text_config, attn_implementation=config._attn_implementation) - vision_model = SiglipVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation) + text_model = SiglipTextModel._from_config(text_config) + vision_model = SiglipVisionModel._from_config(vision_config) # Second, get the text and vision submodules (for backward compatibility) self.text_model = text_model.text_model @@ -1454,9 +1455,7 @@ def __init__(self, config: SiglipConfig) -> None: # Create the vision model with proper attention # and take only vision_model submodule (for backward compatibility) - vision_model = SiglipVisionModel._from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) + vision_model = SiglipVisionModel._from_config(config.vision_config) self.vision_model = vision_model.vision_model # Classifier head diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index a1caa7cf6da2f7..0d2b911bebe582 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -183,6 +183,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel, GenerationMixin): main_input_name = "inputs" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False + _supports_flash_attn_2 = True + _supports_sdpa = True def __init__( self, @@ -213,10 +215,10 @@ def __init__( super().__init__(config) if encoder is None: - encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation) + encoder = AutoModel.from_config(config.encoder) if decoder is None: - decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation) + decoder = AutoModelForCausalLM.from_config(config.decoder) self.encoder = encoder self.decoder = decoder @@ -234,6 +236,8 @@ def __init__( # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced + self.config.encoder._attn_implementation = self.encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index c39e85bacdd3d1..b150b04eea57b8 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -24,7 +24,9 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, @@ -39,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -355,7 +358,12 @@ def forward(self, hidden_states, output_router_logits): # Copied from transformers.models.t5.modeling_t5.T5Attention with 
T5->SwitchTransformers class SwitchTransformersAttention(nn.Module): - def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): + def __init__( + self, + config: SwitchTransformersConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -366,6 +374,13 @@ def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -442,11 +457,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -470,94 +488,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values.
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, 
key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -567,22 +563,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -591,10 +587,10 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers class SwitchTransformersLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.SelfAttention = SwitchTransformersAttention( - config, 
has_relative_attention_bias=has_relative_attention_bias + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx ) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -608,6 +604,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -618,6 +615,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -626,9 +624,11 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers class SwitchTransformersLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = SwitchTransformersAttention( + config, has_relative_attention_bias=False, layer_idx=layer_idx + ) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -643,6 +643,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -655,6 +656,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -662,16 +664,18 @@ def forward( class SwitchTransformersBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False, is_sparse=False): + def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.is_sparse = is_sparse self.layer = nn.ModuleList() self.layer.append( - SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) + SwitchTransformersLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) ) if self.is_decoder: - self.layer.append(SwitchTransformersLayerCrossAttention(config)) + self.layer.append(SwitchTransformersLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) @@ -690,34 +694,19 @@ def forward( output_attentions=False, output_router_logits=True, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -727,35 +716,25 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -775,11 +754,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + (router_tuple,) + outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,) else: outputs = outputs + attention_outputs + (router_tuple,) - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) class SwitchTransformersPreTrainedModel(PreTrainedModel): @@ -791,6 +770,8 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel): config_class = SwitchTransformersConfig base_model_prefix = "switch_transformers" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False _no_split_modules = ["SwitchTransformersBlock"] @property @@ -897,7 +878,9 @@ def __init__(self, config, embed_tokens=None): is_sparse = (i % sparse_step == 1 
or sparse_step == 1) if sparse_step > 0 else False self.block.append( - SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) + SwitchTransformersBlock( + config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse, layer_idx=i + ) ) self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -930,6 +913,7 @@ def forward( output_hidden_states=None, output_router_logits=True, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -952,6 +936,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -959,28 +950,55 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -993,17 +1011,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_router_probs = () if output_router_logits else None @@ -1013,7 +1023,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -1024,7 +1034,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1034,21 +1044,26 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + output_router_logits, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, ) router_probs = layer_outputs[-1] @@ -1059,7 +1074,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1067,9 +1082,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1086,12 +1098,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1101,13 +1119,136 @@ def forward( ) return MoEModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, router_probs=all_router_probs, ) + # Copied from 
transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. 
+ target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + SWITCH_TRANSFORMERS_START_DOCSTRING = r""" @@ -1228,6 +1369,9 @@ def forward( should not be returned during inference. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length.
""" SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" @@ -1355,6 +1499,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]: r""" Returns: @@ -1435,6 +1580,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1535,6 +1681,7 @@ def forward( output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = True, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1618,6 +1765,7 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/t5/configuration_t5.py b/src/transformers/models/t5/configuration_t5.py index e5f2615611b879..be6fbe9528d10a 100644 --- a/src/transformers/models/t5/configuration_t5.py +++ b/src/transformers/models/t5/configuration_t5.py @@ -73,7 +73,12 @@ class T5Config(PretrainedConfig): model_type = "t5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": "d_kv", + } def __init__( self, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 91596f013ab4f5..9012c8db9feb0a 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -25,7 +25,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,6 +45,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -339,7 +342,12 @@ def forward(self, hidden_states): class T5Attention(nn.Module): - def __init__(self, config: T5Config, has_relative_attention_bias=False): + def __init__( + self, + config: T5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -350,6 +358,13 @@ def __init__(self, config: T5Config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -426,11 +441,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -454,94 +472,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, 
key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -551,22 +547,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -574,9 +570,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): class T5LayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) 
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -589,6 +587,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -599,6 +598,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -606,9 +606,9 @@ def forward( class T5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -623,6 +623,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -635,6 +636,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -642,13 +644,15 @@ def forward( class T5Block(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) if self.is_decoder: - self.layer.append(T5LayerCrossAttention(config)) + self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(T5LayerFF(config)) @@ -666,34 +670,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -707,25 +696,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -736,10 +718,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -758,11 +736,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class T5ClassificationHead(nn.Module): @@ -794,6 +772,9 @@ class T5PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_quantized_cache = False # enc-dec models don't support yet + _supports_static_cache = True + _supports_cache_class = True _no_split_modules = ["T5Block"] _keep_in_fp32_modules = ["wo"] @@ -905,7 +886,7 @@ def __init__(self, config, embed_tokens=None): self.is_decoder = config.is_decoder self.block = nn.ModuleList( - [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [T5Block(config, 
has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] ) self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -981,6 +962,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): # Model parallel if self.model_parallel: @@ -1007,6 +989,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -1014,23 +1003,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - if attention_mask is None: + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1045,17 +1068,9 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1064,15 +1079,15 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) # Ensure that attention_mask is always on the same device as hidden_states - if attention_mask is not None: - attention_mask = attention_mask.to(hidden_states.device) + if causal_mask is not None: + causal_mask = causal_mask.to(hidden_states.device) if position_bias is not None: position_bias = position_bias.to(hidden_states.device) if encoder_hidden_states is not None: @@ -1092,7 +1107,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1102,20 +1117,24 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + return_dict, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + return_dict=return_dict, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1123,7 +1142,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs 
= hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1131,9 +1150,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1153,12 +1169,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1167,12 +1189,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. 
This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + T5_START_DOCSTRING = r""" @@ -1286,6 +1431,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
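As a rough illustration of the `cache_position` values described above (toy numbers, not part of the patch), mirroring the `torch.arange(past_key_values_length, past_key_values_length + seq_length)` default computed by the decoder stack:

    import torch

    # prefill over a 5-token prompt: nothing is cached yet
    past_key_values_length, seq_length = 0, 5
    print(torch.arange(past_key_values_length, past_key_values_length + seq_length))  # tensor([0, 1, 2, 3, 4])

    # next decoding step: 5 tokens already cached, 1 new token
    past_key_values_length, seq_length = 5, 1
    print(torch.arange(past_key_values_length, past_key_values_length + seq_length))  # tensor([5])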
""" T5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1446,6 +1594,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1525,6 +1674,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1656,6 +1806,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1750,6 +1901,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index 6be8752d5b63b0..1928ac8a5c20c9 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -34,13 +34,16 @@ ) from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ( ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_torchdynamo_compiling, replace_return_docstrings, ) @@ -154,6 +157,9 @@ more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ @@ -411,6 +417,8 @@ class UdopPreTrainedModel(PreTrainedModel): config_class = UdopConfig base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False _keep_in_fp32_modules = ["wo"] def _init_weights(self, module): @@ -598,7 +606,12 @@ def forward(self, hidden_states): # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop class UdopAttention(nn.Module): - def __init__(self, config: UdopConfig, has_relative_attention_bias=False): + def __init__( + self, + config: UdopConfig, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + ): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -609,6 +622,13 @@ def __init__(self, config: UdopConfig, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` " + "when creating this class." + ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -685,11 +705,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -713,94 +736,72 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. 
Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - def shape(states): - """projection""" - return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - def unshape(states): - """reshape""" - return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states - - # get query states - query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) - - # get key/value states - key_states = project( - hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls 
+ if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True - # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -810,22 +811,22 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): position_bias_masked = position_bias scores += position_bias_masked - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( - scores - ) # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.dropout( - attn_weights, p=self.dropout, training=self.training - ) # (batch_size, n_heads, seq_length, key_length) + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.o(attn_output) - present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -834,9 +835,11 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop class UdopLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - 
self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.SelfAttention = UdopAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -849,6 +852,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -859,6 +863,7 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -867,9 +872,9 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop class UdopLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False) + self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -884,6 +889,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -896,6 +902,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -904,13 +911,17 @@ def forward( # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop class UdopBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(UdopLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + self.layer.append( + UdopLayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + ) if self.is_decoder: - self.layer.append(UdopLayerCrossAttention(config)) + self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(UdopLayerFF(config)) @@ -928,34 +939,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - if not self.is_decoder: - logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (key / value) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.layer[0]( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -969,25 +965,18 @@ def forward( do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.layer[1]( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -998,10 +987,6 @@ def forward( ) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1020,11 +1005,11 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs - return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) class UdopCellEmbeddings(nn.Module): @@ -1286,7 +1271,7 @@ def __init__(self, config, embed_tokens=None, embed_patches=None): self.num_layers = config.num_layers self.block = nn.ModuleList( - [UdopBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(self.num_layers)] + [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)] ) self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) @@ -1338,6 +1323,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not 
None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -1399,26 +1385,54 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min if self.is_decoder and encoder_attention_mask is not None: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) @@ -1427,7 +1441,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None @@ -1436,34 +1449,35 @@ def forward( position_bias = None else: position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox) - position_bias = position_bias + extended_attention_mask + position_bias = position_bias + causal_mask encoder_decoder_position_bias = None hidden_states = inputs_embeds hidden_states = self.dropout(hidden_states) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=head_mask[i], - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) if use_cache is False: # MP fixes layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention weights), @@ -1472,9 +1486,6 @@ def forward( position_bias = layer_outputs[2] if self.is_decoder and encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now @@ -1488,13 +1499,19 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, attention_mask, - present_key_value_states, 
+ next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1505,12 +1522,135 @@ def forward( return BaseModelOutputWithAttentionMask( last_hidden_state=hidden_states, attention_mask=attention_mask, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. 
+ + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", @@ -1584,6 +1724,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[Tensor, ...]: r""" Returns: @@ -1653,6 +1794,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1759,6 +1901,7 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, labels: Optional[Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[Tensor, ...]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1837,6 +1980,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/umt5/configuration_umt5.py b/src/transformers/models/umt5/configuration_umt5.py index d7323d759fd086..ba8ea0460ba071 100644 --- a/src/transformers/models/umt5/configuration_umt5.py +++ b/src/transformers/models/umt5/configuration_umt5.py @@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig): model_type = "umt5" keys_to_ignore_at_inference = ["past_key_values"] - attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} + attribute_map = { + "hidden_size": "d_model", + "num_attention_heads": "num_heads", + "num_hidden_layers": "num_layers", + "head_dim": 
"d_kv", + } def __init__( self, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index bd621fc2fb3ac2..985dc5e4426dff 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -23,7 +23,9 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -40,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -155,7 +158,7 @@ class UMT5Attention(nn.Module): T5's attention using relative_attention_bias. """ - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.has_relative_attention_bias = has_relative_attention_bias @@ -166,6 +169,13 @@ def __init__(self, config, has_relative_attention_bias=False): self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) # Mesh TensorFlow initialization to avoid scaling before softmax self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) @@ -230,11 +240,14 @@ def _relative_position_bucket(self, relative_position): relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket(relative_position) @@ -249,78 +262,95 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.Tensor] = None, ): - is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - # use encoder_hidden_states if cross attention - current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - # checking that the `sequence_length` of the `past_key_value` is the same as the he provided - # `encoder_hidden_states` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = encoder_hidden_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k(current_states)) - value_states = self._shape(self.v(current_states)) - if past_key_value is not None and not is_cross_attention: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - - query_states = self._shape(self.q(hidden_states)) - attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = 
value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # compute positional bias - if self.has_relative_attention_bias: - query_length = seq_length if past_key_value is not None: - query_length += past_key_value[0].shape[2] - position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) - else: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length + key_length = key_states.shape[-2] + if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, seq_length, key_states.size(2)), - device=attention_scores.device, - dtype=attention_scores.dtype, - requires_grad=self.training, + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + if attention_mask is not None: - position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - attention_scores += position_bias # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask - # attn_output = torch.bmm(attn_probs, value_states) ? 
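# A minimal, self-contained sketch of the caching pattern used by the rewritten attention above:
# cross-attention key/value projections of the (static) encoder states are computed once and then
# reused on every later decoding step, tracked by an `is_updated`-style flag. The helper names
# below (`TinyCrossAttentionCache`, `cross_attend`) are hypothetical and only illustrate the idea.
import torch


class TinyCrossAttentionCache:
    def __init__(self):
        self.key = None
        self.value = None
        self.is_updated = False


def cross_attend(query_states, encoder_hidden_states, k_proj, v_proj, cache):
    # Project the encoder states only on the first call; afterwards reuse the cached tensors.
    if not cache.is_updated:
        cache.key = k_proj(encoder_hidden_states)
        cache.value = v_proj(encoder_hidden_states)
        cache.is_updated = True
    scores = torch.matmul(query_states, cache.key.transpose(-1, -2))
    attn_weights = scores.softmax(dim=-1)
    return torch.matmul(attn_weights, cache.value)


d_model = 8
k_proj = torch.nn.Linear(d_model, d_model, bias=False)
v_proj = torch.nn.Linear(d_model, d_model, bias=False)
cache = TinyCrossAttentionCache()
encoder_states = torch.randn(1, 5, d_model)  # encoder output, computed once
for step in range(3):  # three decoding steps, one new query each
    query = torch.randn(1, 1, d_model)
    out = cross_attend(query, encoder_states, k_proj, v_proj, cache)
    print(step, out.shape)  # torch.Size([1, 1, 8]) on every step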
- context_states = torch.matmul(attn_weights, value_states) - # attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ? - context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) - attn_output = self.o(context_states) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, seq_length, -1) + + attn_output = self.o(attn_output) return attn_output, attn_weights, past_key_value class UMT5LayerSelfAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True) + self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -330,6 +360,7 @@ def forward( attention_mask=None, layer_head_mask=None, past_key_value=None, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.SelfAttention( @@ -337,6 +368,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, past_key_value=past_key_value, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them @@ -344,9 +376,9 @@ def forward( class UMT5LayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False) + self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -357,6 +389,7 @@ def forward( attention_mask=None, layer_head_mask=None, past_key_value=None, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.EncDecAttention( @@ -365,6 +398,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, past_key_value=past_key_value, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -372,13 +406,13 @@ def forward( class UMT5Block(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() - self.layer.append(UMT5LayerSelfAttention(config)) + self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx)) if self.is_decoder: - self.layer.append(UMT5LayerCrossAttention(config)) + self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx)) self.layer.append(UMT5LayerFF(config)) @@ -393,16 +427,14 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - - hidden_states, self_attn_weights, present_key_value = self.layer[0]( + hidden_states, self_attn_weights, past_key_value 
= self.layer[0]( hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) # clamp inf values to enable fp16 training @@ -412,18 +444,16 @@ def forward( hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None do_cross_attention = self.is_decoder and encoder_hidden_states is not None if do_cross_attention: - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1]( + hidden_states, cross_attn_weights, past_key_value = self.layer[1]( hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, + cache_position=cache_position, ) # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16: @@ -431,8 +461,6 @@ def forward( clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - present_key_value += cross_attn_present_key_value - # Apply Feed Forward layer hidden_states = self.layer[-1](hidden_states) @@ -444,7 +472,7 @@ def forward( outputs = ( hidden_states, - present_key_value, + past_key_value, ) if output_attentions: @@ -481,6 +509,8 @@ class UMT5PreTrainedModel(PreTrainedModel): config_class = UMT5Config base_model_prefix = "transformer" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True _no_split_modules = ["UMT5Block"] _keep_in_fp32_modules = ["wo"] @@ -594,7 +624,7 @@ def __init__(self, config, embed_tokens=None): super().__init__(config) self.embed_tokens = embed_tokens self.is_decoder = config.is_decoder - self.block = nn.ModuleList([UMT5Block(config) for i in range(config.num_layers)]) + self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)]) self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -622,6 +652,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -644,6 +675,13 @@ def forward( err_msg_prefix = "decoder_" if self.is_decoder else "" raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if inputs_embeds is None: if self.embed_tokens is None: raise ValueError("You have to initialize the model with valid token embeddings") @@ -651,28 +689,57 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length - if use_cache is True: if not self.is_decoder: raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") - if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if self.is_decoder and (use_cache or past_key_values is not None): + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + elif not self.is_decoder: + # do not pass cache object down the line for encoder stack + # it messes indexing later in decoder-stack because cache object is modified in-place + past_key_values = None + + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * len(self.block) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
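# A hedged usage sketch for the new cache path set up above: instead of the deprecated tuple format,
# an `EncoderDecoderCache` (a self-attention cache plus a cross-attention cache) can be passed in.
# The checkpoint name is only an example, and passing the cache is optional because `generate()`
# builds one by default.
import torch
from transformers import AutoTokenizer, DynamicCache, EncoderDecoderCache, UMT5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")

inputs = tokenizer("Translate to German: The house is wonderful.", return_tensors="pt")
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())  # replaces the legacy tuples
with torch.no_grad():
    output_ids = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# An existing tuple-style cache can still be converted once with:
#     past_key_values = EncoderDecoderCache.from_legacy_cache(legacy_tuple)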
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + elif attention_mask is not None: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min + else: + causal_mask = None # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -685,24 +752,16 @@ def forward( else: encoder_extended_attention_mask = None - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.is_decoder else None hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + for i, layer_module in enumerate(self.block): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] @@ -713,7 +772,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, encoder_hidden_states, encoder_extended_attention_mask, layer_head_mask, @@ -721,24 +780,26 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - present_key_value_states += (layer_outputs[1],) + next_decoder_cache = layer_outputs[1] if output_attentions: all_attentions += (layer_outputs[2],) @@ -752,12 +813,18 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -766,12 +833,135 @@ def forward( ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from 
transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. 
+            target_length (`int`):
+                The target length: when generating with static cache, the mask should be as long as the static cache,
+                to account for the 0 padding, the part of the cache that is not filled yet.
+            dtype (`torch.dtype`):
+                The dtype to use for the 4D attention mask.
+            device (`torch.device`):
+                The device to place the 4D attention mask on.
+            cache_position (`torch.Tensor`):
+                Indices depicting the position of the input sequence tokens in the sequence.
+            batch_size (`int`):
+                Batch size.
+        """
+        if attention_mask is not None and attention_mask.dim() == 4:
+            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+            causal_mask = attention_mask
+        else:
+            min_dtype = torch.finfo(dtype).min
+            causal_mask = torch.full(
+                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
+            )
+            if sequence_length != 1:
+                causal_mask = torch.triu(causal_mask, diagonal=1)
+            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+            if attention_mask is not None:
+                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+        return causal_mask
+

UMT5_START_DOCSTRING = r"""
@@ -885,6 +1075,9 @@ def forward(
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+            Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+            cache in the correct position and to infer the complete sequence length.
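# A small standalone sketch of what the helper above produces: a `(batch_size, 1, query_length,
# key_value_length)` mask in which blocked positions hold the dtype minimum and allowed positions
# hold 0, combining the causal structure, `cache_position`, and the 2D padding mask. The numbers
# below are toy values chosen only for illustration.
import torch

batch_size, sequence_length, target_length = 2, 3, 5
dtype, device = torch.float32, "cpu"
cache_position = torch.arange(2, 2 + sequence_length, device=device)  # queries follow 2 cached tokens
attention_mask = torch.tensor([[1, 1, 1, 1, 1], [0, 1, 1, 1, 1]])  # second sample is left-padded

min_dtype = torch.finfo(dtype).min
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = torch.triu(causal_mask, diagonal=1)  # sequence_length > 1 here
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask == 0, min_dtype
)
print(causal_mask.shape)  # torch.Size([2, 1, 3, 5]); query i may attend keys 0..cache_position[i]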
""" UMT5_ENCODER_INPUTS_DOCSTRING = r""" @@ -1022,6 +1215,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1084,6 +1278,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1197,6 +1392,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1268,6 +1464,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = decoder_outputs[0] diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index c9703d263e7d20..c4ec1b5196929a 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -23,7 +23,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput +from ...modeling_outputs import ModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ( add_start_docstrings, @@ -126,8 +126,9 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): std = ( @@ -148,14 +149,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. 
- """ - return self.language_model._supports_sdpa - VIDEO_LLAVA_INPUTS_DOCSTRING = r""" Args: @@ -248,9 +241,7 @@ def __init__(self, config: VideoLlavaConfig): self.multi_modal_projector = VideoLlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() @@ -364,41 +355,59 @@ def _merge_input_ids_with_visual_features( return final_embedding, final_attention_mask, final_labels, position_ids, final_input_ids - def _get_vision_features( - self, - pixel_values_images: Optional[torch.FloatTensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[int] = None, - vision_feature_select_strategy: Optional[str] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - if pixel_values_images is None and pixel_values_videos is None: - raise ValueError("You have to specify `pixel_values_images` or `pixel_values_videos`") + def get_image_features( + self, pixel_values_images: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str + ): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. - # videos do not need to select features and it's always "full" (as it is done in the orig implementation) - if pixel_values_videos is not None: - batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape + Args: + pixel_values_images (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + vision_feature_select_strategy (`str`): + The feature selection strategy used to select the vision feature from the vision backbone. + Can be one of `"default"` or `"full"` + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). 
+ """ + + image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) + image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1) - pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width) - video_outputs = self.video_tower(pixel_values, output_hidden_states=True) - video_outputs = video_outputs.hidden_states[vision_feature_layer].squeeze(1) + if vision_feature_select_strategy == "default": + image_outputs = image_outputs[:, 1:] + elif vision_feature_select_strategy == "full": + image_outputs = image_outputs else: - video_outputs = None - num_frames = 0 + raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") - if pixel_values_images is not None: - image_outputs = self.image_tower(pixel_values_images, output_hidden_states=True) - image_outputs = image_outputs.hidden_states[vision_feature_layer].squeeze(1) + image_features = self.multi_modal_projector(image_outputs) - if vision_feature_select_strategy == "default": - image_outputs = image_outputs[:, 1:] - elif vision_feature_select_strategy == "full": - image_outputs = image_outputs - else: - raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}") - else: - image_outputs = None + return image_features - return image_outputs, video_outputs, num_frames + def get_video_features(self, pixel_values_videos: torch.FloatTensor, vision_feature_layer: int): + """ + Obtains video last hidden states from the vision tower and apply multimodal projection. + + Args: + pixel_values_videos (`torch.FloatTensor]` of shape `(batch_size, num_frames, channels, height, width)`) + The tensors corresponding to the input videos. + vision_feature_layer (`int`): + The index of the layer to select the vision feature. + Returns: + video_features (`torch.Tensor`): Video feature tensor of shape `(num_videos * num_frames, image_length, embed_dim)`). + frames (`int`): Number of frames the videos have. 
+ """ + batch_size_vid, num_frames, channels, height, width = pixel_values_videos.shape + + pixel_values = pixel_values_videos.reshape(batch_size_vid * num_frames, channels, height, width) + video_outputs = self.video_tower(pixel_values, output_hidden_states=True) + video_features = video_outputs.hidden_states[vision_feature_layer].squeeze(1) + video_features = self.multi_modal_projector(video_features) + + return video_features, num_frames @add_start_docstrings_to_model_forward(VIDEO_LLAVA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=VideoLlavaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -543,110 +552,106 @@ def forward( ) legacy_processing = inputs_not_expanded or pixels_present - if pixel_values_images is not None or pixel_values_videos is not None: - image_outputs, video_outputs, num_frames = self._get_vision_features( - pixel_values_images=pixel_values_images, - pixel_values_videos=pixel_values_videos, + image_features = None + if pixel_values_images is not None: + image_features = self.get_image_features( + pixel_values_images, vision_feature_layer=vision_feature_layer, vision_feature_select_strategy=vision_feature_select_strategy, ) - image_features = video_features = None - if image_outputs is not None: - image_features = self.multi_modal_projector(image_outputs) - if video_outputs is not None: - video_features = self.multi_modal_projector(video_outputs) - - if legacy_processing: - logger.warning_once( - "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " - "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " - "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " - "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
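# A hedged usage sketch for the two helpers defined above (`get_image_features`, `get_video_features`).
# The checkpoint name and the dummy pixel values are examples only; in practice the processor supplies
# real pixel values, and the model config holds the defaults that `forward` itself uses.
import torch
from transformers import VideoLlavaForConditionalGeneration

model = VideoLlavaForConditionalGeneration.from_pretrained("LanguageBind/Video-LLaVA-7B-hf")  # example checkpoint

pixel_values_images = torch.randn(1, 3, 224, 224)      # (batch_size, channels, height, width)
pixel_values_videos = torch.randn(1, 8, 3, 224, 224)   # (batch_size, num_frames, channels, height, width)

image_features = model.get_image_features(
    pixel_values_images,
    vision_feature_layer=model.config.vision_feature_layer,
    vision_feature_select_strategy=model.config.vision_feature_select_strategy,
)
video_features, num_frames = model.get_video_features(
    pixel_values_videos, vision_feature_layer=model.config.vision_feature_layer
)
print(image_features.shape, video_features.shape, num_frames)  # projected multimodal features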
- ) - if input_ids.shape[1] != 1: - for features, frames in ((image_features, 1), (video_features, num_frames)): - if features is not None: - ( - inputs_embeds, - attention_mask, - labels, - position_ids, - input_ids, - ) = self._merge_input_ids_with_visual_features( - features, - inputs_embeds, - input_ids, - attention_mask, - labels, - num_frames=frames, - ) - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) - else: - # Retrieve the first layer to inspect the logits and mask out the hidden states - # that are set to 0 - first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] - - # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 - batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) - - target_length = input_ids.shape[1] - past_length = first_layer_past_key_value.shape[-1] - - extended_attention_mask = torch.ones( - (attention_mask.shape[0], past_length), - dtype=attention_mask.dtype, - device=attention_mask.device, - ) + video_features = None + if pixel_values_videos is not None: + video_features, num_frames = self.get_video_features( + pixel_values_videos=pixel_values_videos, vision_feature_layer=vision_feature_layer + ) - # Filter out only the tokens that can be un-attended, this can happen - # if one uses Llava + Fused modules where the cache on the - # first iteration is already big enough, or if one passes custom cache - valid_indices = non_attended_tokens < extended_attention_mask.size(-1) - new_batch_index = batch_index[valid_indices] - new_non_attended_tokens = non_attended_tokens[valid_indices] + if legacy_processing: + logger.warning_once( + "Expanding inputs for image tokens in Video-LLaVa should be done in processing. " + "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly " + "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. " + "Using processors without these attributes in the config is deprecated and will throw an error in v4.47." 
+ ) + if input_ids.shape[1] != 1: + for features, frames in ((image_features, 1), (video_features, num_frames)): + if features is not None: + ( + inputs_embeds, + attention_mask, + labels, + position_ids, + input_ids, + ) = self._merge_input_ids_with_visual_features( + features, + inputs_embeds, + input_ids, + attention_mask, + labels, + num_frames=frames, + ) + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device) + else: + # Retrieve the first layer to inspect the logits and mask out the hidden states + # that are set to 0 + first_layer_past_key_value = past_key_values[0][0][:, :, :, 0] + + # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941 + batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0) + target_length = input_ids.shape[1] + past_length = first_layer_past_key_value.shape[-1] + extended_attention_mask = torch.ones( + (attention_mask.shape[0], past_length), + dtype=attention_mask.dtype, + device=attention_mask.device, + ) - # Zero-out the places where we don't need to attend - extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + # Filter out only the tokens that can be un-attended, this can happen + # if one uses Llava + Fused modules where the cache on the + # first iteration is already big enough, or if one passes custom cache + valid_indices = non_attended_tokens < extended_attention_mask.size(-1) + new_batch_index = batch_index[valid_indices] + new_non_attended_tokens = non_attended_tokens[valid_indices] - attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) - position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 - cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[ - -target_length: - ] + # Zero-out the places where we don't need to attend + extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0 + attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1) + position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 + cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[-target_length:] - # TODO: @raushan retain only the new behavior after v4.47 - else: - if image_outputs is not None: - n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() - n_image_features = image_features.shape[1] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - special_image_mask = ( - (input_ids == self.config.image_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + # TODO: @raushan retain only the new behavior after v4.47 + else: + if pixel_values_images is not None: + n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item() + n_image_features = image_features.shape[1] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) - if video_outputs is not None: - n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item() - 
n_video_features = video_features.shape[1] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - special_image_mask = ( - (input_ids == self.config.video_token_index) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) + special_image_mask = ( + (input_ids == self.config.image_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + if pixel_values_videos is not None: + n_video_tokens = (input_ids == self.config.video_token_index).sum(dim=-1)[0].item() + n_video_features = video_features.shape[1] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" ) - video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) + special_image_mask = ( + (input_ids == self.config.video_token_index) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features) outputs = self.language_model( attention_mask=attention_mask, diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 3af32a9caace0e..dd7baa34406fb0 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -132,8 +132,9 @@ class VipLlavaPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["VipLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True _supports_cache_class = True + _supports_flash_attn_2 = True + _supports_sdpa = True def _init_weights(self, module): # important: this ported version of VipLlava isn't meant for training from scratch - only @@ -157,14 +158,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - @property - def _supports_sdpa(self): - """ - Retrieve language_model's attribute to check whether the model supports - SDPA or not. - """ - return self.language_model._supports_sdpa - VIPLLAVA_INPUTS_DOCSTRING = r""" Args: @@ -248,9 +241,7 @@ def __init__(self, config: VipLlavaConfig): self.multi_modal_projector = VipLlavaMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config( - config.text_config, attn_implementation=config._attn_implementation - ) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.post_init() @@ -284,6 +275,17 @@ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_m # Ignore copy def get_image_features(self, pixel_values: torch.FloatTensor, vision_feature_layers: List[int]): + """ + Obtains image last hidden states from the vision tower and apply multimodal projection. 
+ + Args: + pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`) + The tensors corresponding to the input images. + vision_feature_layers (`List[int]`): + The list og indexes of the layers to select the vision feature. + Returns: + image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`). + """ image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) # For VIP-llava, the image features are computed this way diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index b044dda300ab48..152a9601403301 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -161,6 +161,8 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin): main_input_name = "pixel_values" supports_gradient_checkpointing = True _supports_param_buffer_assignment = False + _supports_flash_attn_2 = True + _supports_sdpa = True def __init__( self, @@ -191,10 +193,10 @@ def __init__( super().__init__(config) if encoder is None: - encoder = AutoModel.from_config(config.encoder, attn_implementation=config._attn_implementation) + encoder = AutoModel.from_config(config.encoder) if decoder is None: - decoder = AutoModelForCausalLM.from_config(config.decoder, attn_implementation=config._attn_implementation) + decoder = AutoModelForCausalLM.from_config(config.decoder) self.encoder = encoder self.decoder = decoder @@ -212,6 +214,8 @@ def __init__( # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced + self.config.encoder._attn_implementation = self.encoder.config._attn_implementation + self.config.decoder._attn_implementation = self.decoder.config._attn_implementation self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder diff --git a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py index 5b90faa8862c97..4b39de3df1c882 100755 --- a/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +++ b/src/transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py @@ -161,6 +161,8 @@ def clip_loss(similarity: torch.Tensor) -> torch.Tensor: class VisionTextDualEncoderModel(PreTrainedModel): config_class = VisionTextDualEncoderConfig base_model_prefix = "vision_text_dual_encoder" + _supports_flash_attn_2 = True + _supports_sdpa = True def __init__( self, @@ -184,18 +186,18 @@ def __init__( if isinstance(config.vision_config, CLIPVisionConfig): vision_model = CLIPVisionModel(config.vision_config) else: - vision_model = AutoModel.from_config( - config.vision_config, attn_implementation=config._attn_implementation - ) + vision_model = AutoModel.from_config(config.vision_config) if text_model is None: - text_model = AutoModel.from_config(config.text_config, attn_implementation=config._attn_implementation) + text_model = AutoModel.from_config(config.text_config) self.vision_model = vision_model self.text_model = text_model # make sure that the individual model's config refers to the shared config # so that the updates to the config will be synced + self.config.vision_config._attn_implementation = 
self.vision_model.config._attn_implementation + self.config.text_config._attn_implementation = self.text_model.config._attn_implementation self.vision_model.config = self.config.vision_config self.text_model.config = self.config.text_config diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 76ebd18ed32d7b..bb08acfc0bba67 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -250,8 +250,24 @@ def __init__(self, config: ViTConfig) -> None: self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`ViTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) diff --git a/src/transformers/models/vit_mae/modeling_vit_mae.py b/src/transformers/models/vit_mae/modeling_vit_mae.py index 0be169a51b276d..e319f2f655aabf 100755 --- a/src/transformers/models/vit_mae/modeling_vit_mae.py +++ b/src/transformers/models/vit_mae/modeling_vit_mae.py @@ -424,8 +424,24 @@ def __init__(self, config: ViTMAEConfig) -> None: self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`ViTMAESdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py index b962ac597dabb8..39274dd28fef5b 100644 --- a/src/transformers/models/vit_msn/modeling_vit_msn.py +++ b/src/transformers/models/vit_msn/modeling_vit_msn.py @@ -241,8 +241,24 @@ def __init__(self, config: ViTMSNConfig) -> None: self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`ViTMSNSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 2d00c973b85c18..f7ef3e55f5f799 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -299,8 +299,24 @@ def __init__(self, config: YolosConfig) -> None: self.attention_probs_dropout_prob = config.attention_probs_dropout_prob def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False + self, + hidden_states: torch.FloatTensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: + if output_attentions or head_mask is not None: + logger.warning_once( + "`YolosSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support " + "`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but " + "specifying the manual implementation will be required from Transformers version v5.0.0 onwards. " + 'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
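# A hedged sketch of the behaviour described in the warnings above: with the default SDPA attention,
# requesting attention weights (or passing a head mask) silently falls back to the eager code path,
# so loading with `attn_implementation="eager"` is the explicit way to ask for it. The checkpoint
# name is only an example.
import torch
from transformers import ViTModel

model_sdpa = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")  # dispatches to SDPA when available
model_eager = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", attn_implementation="eager")

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model_eager(pixel_values=pixel_values, output_attentions=True)  # no fallback warning here
print(len(outputs.attentions), outputs.attentions[0].shape)  # one tensor per layer: (1, num_heads, seq, seq)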
+ ) + return super().forward( + hidden_states=hidden_states, + head_mask=head_mask, + output_attentions=output_attentions, + ) + mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) diff --git a/tests/models/blip_2/test_modeling_blip_2.py b/tests/models/blip_2/test_modeling_blip_2.py index f2ccb2da8dba94..e5d04bd85a3404 100644 --- a/tests/models/blip_2/test_modeling_blip_2.py +++ b/tests/models/blip_2/test_modeling_blip_2.py @@ -27,6 +27,7 @@ require_torch_fp16, require_torch_gpu, require_torch_multi_accelerator, + require_torch_sdpa, require_vision, slow, torch_device, @@ -456,6 +457,7 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT test_resize_embeddings = False test_attention_outputs = False test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = Blip2ForConditionalGenerationDecoderOnlyModelTester(self) @@ -488,6 +490,66 @@ def test_save_load_fast_init_from_base(self): def test_save_load_fast_init_to_base(self): pass + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model. + This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention". + In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model + is loaded, because we manually replicate requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model + that has a different set of sub-configs has to overwrite this test. 
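# A hedged sketch of the dispatch behaviour this test verifies, outside the test harness: each
# sub-model's config ends up reflecting the attention backend it was actually dispatched to, and
# requesting "eager" propagates to every sub-config. The checkpoint name is only an example.
from transformers import Blip2ForConditionalGeneration

model_sdpa = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")
print(model_sdpa.language_model.config._attn_implementation)  # "sdpa" if the sub-model supports it, else "eager"
print(model_sdpa.vision_model.config._attn_implementation)
print(model_sdpa.qformer.config._attn_implementation)

model_eager = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", attn_implementation="eager"
)
print(model_eager.config._attn_implementation)                 # "eager" on the top-level config
print(model_eager.language_model.config._attn_implementation)  # and on every sub-config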
+ """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" + qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.qformer.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and any( + module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] + ): + raise ValueError("The SDPA model should have SDPA attention layers") + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() @@ -715,6 +777,7 @@ class Blip2ModelTest(ModelTesterMixin, PipelineTesterMixin, GenerationTesterMixi test_resize_embeddings = False test_attention_outputs = False test_torchscript = False + _is_composite = True # TODO: Fix the failed tests def is_pipeline_test_to_skip( @@ -768,6 +831,66 @@ def test_save_load_fast_init_to_base(self): def test_cpu_offload(self): pass + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model. + This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention". + In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model + is loaded, because we manually replicate requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model + that has a different set of sub-configs has to overwrite this test. 
+ """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" + qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.qformer.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and any( + module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] + ): + raise ValueError("The SDPA model should have SDPA attention layers") + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 88824756a6fb54..a7c8c8ef8410e8 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -191,6 +191,53 @@ class CLIPModelTesterMixin(ModelTesterMixin): different output logits, and are not supposed to be used or tested with padding_side="left". 
""" + def test_sdpa_can_dispatch_composite_models(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with SDPA + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + # Load model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + # SigLip has one shared cls attr for all models, so we assign both submodels heer + vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + def test_eager_matches_sdpa_inference( self, torch_dtype: str, @@ -252,24 +299,6 @@ def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, # but it would be nicer to have an efficient way to use parameterized.expand cases = [ @@ -461,6 +490,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(None,), ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + class CLIPTextModelTester: def __init__( @@ -639,6 +672,10 @@ def 
test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(None, "right"), # "left" is not supported for text model ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + @require_torch_sdpa def test_sdpa_can_dispatch_on_flash(self): self.skipTest(reason="CLIPTextModel has two attention masks: `causal_attention_mask` and `attention_mask`") @@ -704,6 +741,7 @@ class CLIPModelTest(CLIPModelTesterMixin, PipelineTesterMixin, unittest.TestCase test_pruning = False test_resize_embeddings = False test_attention_outputs = False + _is_composite = True def setUp(self): self.model_tester = CLIPModelTester(self) @@ -975,6 +1013,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(None, "right"), # "left" is not supported for text model ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + @require_torch_sdpa def test_sdpa_can_dispatch_on_flash(self): self.skipTest(reason="CLIP text tower has two attention masks: `causal_attention_mask` and `attention_mask`") @@ -1104,6 +1146,7 @@ class CLIPForImageClassificationModelTest(CLIPModelTesterMixin, PipelineTesterMi test_pruning = False test_resize_embeddings = False test_attention_outputs = False + _is_composite = True def setUp(self): self.model_tester = CLIPForImageClassificationModelTester(self) @@ -1143,6 +1186,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(None,), ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 5e5263b6afb98c..0ee4b75ed803e3 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -18,7 +18,14 @@ import unittest from transformers import is_torch_available, logging -from transformers.testing_utils import CaptureLogger, require_deterministic_for_xpu, require_torch, slow, torch_device +from transformers.testing_utils import ( + CaptureLogger, + require_deterministic_for_xpu, + require_torch, + require_torch_sdpa, + slow, + torch_device, +) from ...test_modeling_common import ids_tensor from ..bart.test_modeling_bart import BartStandaloneDecoderModelTester @@ -54,6 +61,8 @@ @require_torch class EncoderDecoderMixin: + supports_sdpa = False + def get_encoder_decoder_model(self, config, decoder_config): raise NotImplementedError @@ -670,6 +679,67 @@ def test_real_model_save_load_from_pretrained(self): max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + if not self.supports_sdpa: + self.skipTest("SDPA is not supported") + + inputs_dict = self.prepare_config_and_inputs() + encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"] + config = EncoderDecoderConfig.from_encoder_decoder_configs( + encoder_config=encoder_config, decoder_config=decoder_config + ) + model = EncoderDecoderModel(config=config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = EncoderDecoderModel.from_pretrained(tmpdirname) + model_sdpa 
= model_sdpa.eval().to(torch_device) + + # see https://github.com/huggingface/transformers/pull/32238 + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager" + decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager" + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn) + self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn) + + # Also test that nothing break if we request SDPA explicitly, when both sub-parts support it. + # If the model supports sdpa (i.e. all of sub-models supports it) we'll dispatch safely + # Otherwise we should raise error that SDPA is not supported, as some of the sub-models doesn't support + if encoder_attn == "sdpa" and decoder_attn == "sdpa": + model_sdpa_explicit = EncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa") + model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device) + + self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa") + else: + with self.assertRaises(ValueError): + model_sdpa_explicit = EncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa") + + model_eager = EncoderDecoderModel.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + @require_torch class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): @@ -949,6 +1019,8 @@ def get_pretrained_model(self): @require_torch class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): + supports_sdpa = True + def get_encoder_decoder_model(self, config, decoder_config): encoder_model = BertModel(config) decoder_model = GPT2LMHeadModel(decoder_config) diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 8f9a918dca0082..94670803daa998 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -88,6 +88,10 @@ def setUp(self): def test_model_outputs_equivalence(self, **kwargs): pass + @unittest.skip("Gemma2's forcefully disables sdpa due to softcapping") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different") def test_eager_matches_sdpa_inference(self): diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py index 250c47c3a7e8ce..bbade169550f8c 100644 --- a/tests/models/idefics/test_modeling_idefics.py +++ 
b/tests/models/idefics/test_modeling_idefics.py @@ -580,11 +580,9 @@ def test_model_from_pretrained(self): model = IdeficsModel.from_pretrained(model_name) self.assertIsNotNone(model) - @require_torch_sdpa - @slow - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) - def test_eager_matches_sdpa_inference(self, torch_dtype: str): - self.skipTest(reason="Idefics has a hard requirement on SDPA, skipping this test") + @unittest.skip("Idefics has a hard requirement on SDPA") + def test_sdpa_can_dispatch_non_composite_models(self): + pass @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @@ -806,6 +804,10 @@ def test_training_gradient_checkpointing_use_reentrant(self): def test_training_gradient_checkpointing_use_reentrant_false(self): pass + @unittest.skip("Idefics has a hard requirement on SDPA") + def test_sdpa_can_dispatch_non_composite_models(self): + pass + @unittest.skipIf(not is_torch_greater_or_equal_than_2_0, reason="pytorch 2.0 or higher is required") @require_torch diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index 4071fcbb232805..854b8b934578e0 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -16,6 +16,7 @@ import copy import gc +import tempfile import unittest from io import BytesIO @@ -36,6 +37,7 @@ require_torch, require_torch_gpu, require_torch_multi_gpu, + require_torch_sdpa, slow, torch_device, ) @@ -180,6 +182,7 @@ class Idefics2ModelTest(ModelTesterMixin, unittest.TestCase): test_pruning = False test_resize_embeddings = True test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = Idefics2VisionText2TextModelTester(self) @@ -327,6 +330,43 @@ def test_resize_embeddings_untied(self): # Check that the model can still do a forward pass successfully (every parameter should be resized) model(**self._prepare_for_class(inputs_dict, model_class)) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + vision_attn = None if model.vision_model._supports_sdpa else "eager" + perceiver_attn = None if model.connector.perceiver_resampler._supports_sdpa else "eager" + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == perceiver_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_sdpa.connector.perceiver_resampler.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in 
model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + @require_torch class Idefics2ForConditionalGenerationModelTest(GenerationTesterMixin, ModelTesterMixin, unittest.TestCase): diff --git a/tests/models/instructblip/test_modeling_instructblip.py b/tests/models/instructblip/test_modeling_instructblip.py index 8292567334bf3b..5182ac20cd993e 100644 --- a/tests/models/instructblip/test_modeling_instructblip.py +++ b/tests/models/instructblip/test_modeling_instructblip.py @@ -32,6 +32,7 @@ require_accelerate, require_bitsandbytes, require_torch, + require_torch_sdpa, require_vision, slow, torch_device, @@ -460,6 +461,7 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene test_resize_embeddings = False test_attention_outputs = False test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = InstructBlipForConditionalGenerationDecoderOnlyModelTester(self) @@ -529,6 +531,66 @@ def test_model_from_pretrained(self): model = InstructBlipForConditionalGeneration.from_pretrained(model_name) self.assertIsNotNone(model) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model. + This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention". + In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model + is loaded, because we manually replicate requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model + that has a different set of sub-configs has to overwrite this test. 
+ """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" + qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.qformer.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and any( + module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] + ): + raise ValueError("The SDPA model should have SDPA attention layers") + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py index 8a9326c22ac11c..298c7a8d7ff46f 100644 --- a/tests/models/instructblipvideo/test_modeling_instructblipvideo.py +++ b/tests/models/instructblipvideo/test_modeling_instructblipvideo.py @@ -32,6 +32,7 @@ require_accelerate, require_bitsandbytes, require_torch, + require_torch_sdpa, require_vision, slow, torch_device, @@ -481,6 +482,7 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest( test_resize_embeddings = False test_attention_outputs = False test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = InstructBlipVideoForConditionalGenerationDecoderOnlyModelTester(self) @@ -550,6 +552,66 @@ def test_model_from_pretrained(self): model = InstructBlipVideoForConditionalGeneration.from_pretrained(model_name) self.assertIsNotNone(model) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly on SDPA/eager when requested so when loading the model. 
+ This tests only by looking at layer names, as usually SDPA layers are calles "SDPAAttention". + In contrast to the above test, this one checks if the "config._attn_implamentation" is a dict after the model + is loaded, because we manually replicate requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover most general cases of composite models, VLMs with vision and text configs. Any model + that has a different set of sub-configs has to overwrite this test. + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.vision_model._supports_sdpa else "eager" + qformer_attn = "sdpa" if model.qformer._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model.qformer.config._attn_implementation == qformer_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.qformer.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and any( + module_attn == "sdpa" for module_attn in [text_attn, vision_attn, qformer_attn] + ): + raise ValueError("The SDPA model should have SDPA attention layers") + # We will verify our results on an image of cute cats def prepare_video(): diff --git a/tests/models/kosmos2/test_modeling_kosmos2.py b/tests/models/kosmos2/test_modeling_kosmos2.py index 22cbffcfdb6b13..de6c0b15d661f9 100644 --- a/tests/models/kosmos2/test_modeling_kosmos2.py +++ b/tests/models/kosmos2/test_modeling_kosmos2.py @@ -25,8 +25,17 @@ from transformers import AutoModelForImageTextToText, AutoProcessor, Kosmos2Config from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2TextConfig, Kosmos2VisionConfig -from transformers.testing_utils import IS_ROCM_SYSTEM, require_torch, require_vision, slow, torch_device -from transformers.utils import 
is_torch_available, is_vision_available +from transformers.testing_utils import ( + IS_ROCM_SYSTEM, + require_torch, + require_vision, + slow, + torch_device, +) +from transformers.utils import ( + is_torch_available, + is_vision_available, +) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ( @@ -257,6 +266,7 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase) test_pruning = False test_resize_embeddings = False test_attention_outputs = False + _is_composite = True # TODO: `image-to-text` pipeline for this model needs Processor. def is_pipeline_test_to_skip( diff --git a/tests/models/llava/test_modeling_llava.py b/tests/models/llava/test_modeling_llava.py index 07415900bb93db..405fad1bd31c8d 100644 --- a/tests/models/llava/test_modeling_llava.py +++ b/tests/models/llava/test_modeling_llava.py @@ -186,6 +186,7 @@ class LlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTesterM pipeline_model_mapping = {"image-to-text": LlavaForConditionalGeneration} if is_torch_available() else {} test_pruning = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = LlavaVisionText2TextModelTester(self) @@ -260,6 +261,16 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch class LlavaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_next/test_modeling_llava_next.py b/tests/models/llava_next/test_modeling_llava_next.py index a54aeab8a28252..6589bf14d24c65 100644 --- a/tests/models/llava_next/test_modeling_llava_next.py +++ b/tests/models/llava_next/test_modeling_llava_next.py @@ -218,6 +218,7 @@ class LlavaNextForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes all_generative_model_classes = (LlavaNextForConditionalGeneration,) if is_torch_available() else () test_pruning = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = LlavaNextVisionText2TextModelTester(self) @@ -316,6 +317,16 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch class LlavaNextForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_next_video/test_modeling_llava_next_video.py b/tests/models/llava_next_video/test_modeling_llava_next_video.py index 30eaa7fb050c7c..05fc8a49e1e9b9 100644 --- a/tests/models/llava_next_video/test_modeling_llava_next_video.py +++ b/tests/models/llava_next_video/test_modeling_llava_next_video.py @@ -236,6 +236,7 @@ class LlavaNextVideoForConditionalGenerationModelTest(ModelTesterMixin, Generati all_generative_model_classes = (LlavaNextVideoForConditionalGeneration,) if is_torch_available() else () test_pruning = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = LlavaNextVideoVisionText2TextModelTester(self) @@ -340,6 +341,16 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch class LlavaNextVideoForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/llava_onevision/test_modeling_llava_onevision.py b/tests/models/llava_onevision/test_modeling_llava_onevision.py index 0e9c88cb3463fd..0a33898b63072b 100644 --- a/tests/models/llava_onevision/test_modeling_llava_onevision.py +++ b/tests/models/llava_onevision/test_modeling_llava_onevision.py @@ -219,6 +219,7 @@ class LlavaOnevisionForConditionalGenerationModelTest(ModelTesterMixin, Generati all_generative_model_classes = (LlavaOnevisionForConditionalGeneration,) if is_torch_available() else () test_pruning = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = LlavaOnevisionVisionText2TextModelTester(self) @@ -306,6 +307,16 @@ def test_training_gradient_checkpointing_use_reentrant_false(self): def test_assisted_decoding_with_num_logits_to_keep(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch class LlavaOnevisionForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/longt5/test_modeling_longt5.py b/tests/models/longt5/test_modeling_longt5.py index c0cf21b2369d0a..a9d3e7479e9578 100644 --- a/tests/models/longt5/test_modeling_longt5.py +++ b/tests/models/longt5/test_modeling_longt5.py @@ -31,6 +31,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( MODEL_FOR_QUESTION_ANSWERING_MAPPING, @@ -574,6 +575,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): lm_labels, ) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) @@ -602,7 +638,7 @@ def test_export_to_onnx(self): (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), f"{tmpdirname}/longt5_test.onnx", export_params=True, - opset_version=13, + opset_version=14, input_names=["input_ids", "decoder_input_ids"], ) diff --git a/tests/models/mllama/test_modeling_mllama.py b/tests/models/mllama/test_modeling_mllama.py index 5c5ca3985ee08f..fafa2f71331ba3 100644 --- a/tests/models/mllama/test_modeling_mllama.py +++ b/tests/models/mllama/test_modeling_mllama.py @@ -274,6 +274,7 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester test_pruning = False test_head_masking = False test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = MllamaVisionText2TextModelTester(self) diff --git a/tests/models/mt5/test_modeling_mt5.py b/tests/models/mt5/test_modeling_mt5.py index 6e912ec3607d40..20412da2e1db06 100644 --- a/tests/models/mt5/test_modeling_mt5.py +++ b/tests/models/mt5/test_modeling_mt5.py @@ -40,6 +40,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoModelForSeq2SeqLM, @@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small MT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + 
_torch_compile_test_ckpt = "google/mt5-small" + def setUp(self): self.model_tester = MT5ModelTester(self) self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37) @@ -627,12 +631,9 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ] if labels is not None: input_names.append("labels") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: @@ -647,7 +648,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa "visual_feats", "visual_pos", ] - labels = inputs.get("labels", None) start_positions = inputs.get("start_positions", None) end_positions = inputs.get("end_positions", None) @@ -657,15 +657,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa input_names.append("start_positions") if end_positions is not None: input_names.append("end_positions") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( not hasattr(model.config, "problem_type") or model.config.problem_type is None ): model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) model_output = model(**filtered_inputs) @@ -718,6 +715,41 @@ def flatten_output(output): # (Even with this call, there are still memory leak by ~0.04MB) self.clear_torch_jit_class_registry() + # overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_config(self): self.config_tester.run_common_tests() diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index cc30238c8df9f5..438178bfc6faa2 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -654,8 +654,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device) - 
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch_dtype, @@ -663,20 +661,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] @@ -1042,6 +1026,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # not to test torchscript as the model tester doesn't prepare `input_values` and `padding_mask` # (and `torchscript` hates `None` values). test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = MusicgenTester(self) @@ -1420,7 +1405,7 @@ def test_save_load_fast_init_from_base(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: @@ -1432,7 +1417,9 @@ def test_flash_attn_2_inference_equivalence(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch.bfloat16, + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, ) model_fa.to(torch_device) @@ -1505,7 +1492,88 @@ def test_flash_attn_2_inference_equivalence(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding + def test_flash_attn_2_conversion(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, + ).to(torch_device) + + for _, module in model.named_modules(): + if "FlashAttention" in module.__class__.__name__: + return + + self.assertTrue(False, "FlashAttention2 modules not found in model") + + @require_torch_sdpa + @require_torch_gpu + @slow + def test_sdpa_can_dispatch_on_flash(self): + if not self.has_attentions: + 
self.skipTest(reason="Model architecture does not support attentions") + + torch.compiler.reset() + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability + + if not torch.version.cuda or major < 8: + self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]: + self.skipTest( + reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input" + ) + if config.model_type in ["paligemma"]: + self.skipTest( + "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input" + ) + if config.model_type in ["idefics", "idefics2", "idefics3"]: + self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input") + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation={"decoder": "sdpa", "audio_encoder": None, "text_encoder": None}, + ) + model.to(torch_device) + + inputs_dict.pop("attention_mask", None) + inputs_dict.pop("decoder_attention_mask", None) + + for name, inp in inputs_dict.items(): + if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: + inputs_dict[name] = inp.to(torch.float16) + + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + _ = model(**inputs_dict) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: @@ -1517,7 +1585,9 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch.bfloat16, + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, ) model_fa.to(torch_device) @@ -1587,7 +1657,7 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding def test_flash_attn_2_generate_left_padding(self): # Ignore copy for model_class in self.greedy_sample_model_classes: @@ -1622,7 +1692,7 @@ def test_flash_attn_2_generate_left_padding(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, low_cpu_mem_usage=True, ).to(torch_device) @@ -1636,7 +1706,7 @@ def test_flash_attn_2_generate_left_padding(self): @require_torch_gpu 
@mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right def test_flash_attn_2_generate_padding_right(self): # Ignore copy for model_class in self.greedy_sample_model_classes: @@ -1670,7 +1740,7 @@ def test_flash_attn_2_generate_padding_right(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, low_cpu_mem_usage=True, ).to(torch_device) @@ -1684,7 +1754,7 @@ def test_flash_attn_2_generate_padding_right(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache def test_flash_attn_2_generate_use_cache(self): max_new_tokens = 30 @@ -1713,7 +1783,7 @@ def test_flash_attn_2_generate_use_cache(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, low_cpu_mem_usage=True, ).to(torch_device) @@ -1726,6 +1796,53 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + audio_encoder_attn = "sdpa" if model.audio_encoder._supports_sdpa else "eager" + text_encoder_attn = "sdpa" if model.text_encoder._supports_sdpa else "eager" + decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model_sdpa.audio_encoder.config._attn_implementation == audio_encoder_attn) + self.assertTrue(model_sdpa.text_encoder.config._attn_implementation == text_encoder_attn) + self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn) + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.audio_encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.text_encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, 
submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow @@ -1792,8 +1909,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch_dtype, @@ -1801,20 +1916,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 35af9fe0768da8..f53fc21ba80c09 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -311,7 +311,9 @@ def test_flash_attn_2_inference_equivalence(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) model_fa.to(torch_device) @@ -391,7 +393,9 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", ) model_fa.to(torch_device) @@ -454,148 +458,10 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding - def test_flash_attn_2_generate_left_padding(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, 
torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do left padding - dummy_attention_mask[:, :-1] = 0 - dummy_attention_mask[:, -1:] = 1 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right - def test_flash_attn_2_generate_padding_right(self): - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( - torch_device - ) - - dummy_input = inputs_dict[model.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - # make sure we do right padding - dummy_attention_mask[:, :-1] = 1 - dummy_attention_mask[:, -1:] = 0 - - out = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - out_fa = model.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False - ) - - self.assertTrue(torch.allclose(out, out_fa)) - - @require_flash_attn - @require_torch_gpu - @mark.flash_attn_test - @slow - # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_2_generate_use_cache - def test_flash_attn_2_generate_use_cache(self): - max_new_tokens = 30 - - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", 
torch.ones_like(dummy_input)) - - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation="flash_attention_2", - low_cpu_mem_usage=True, - ).to(torch_device) - - # Just test that a large cache works as expected - _ = model.generate( - dummy_input, - attention_mask=dummy_attention_mask, - max_new_tokens=max_new_tokens, - do_sample=False, - use_cache=True, - ) - @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow - # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_eager_matches_sdpa_inference + # Copied from tests.test_modeling_common.ModelTesterMixin.test_eager_matches_sdpa_inference def test_eager_matches_sdpa_inference(self, torch_dtype: str): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") @@ -658,8 +524,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch_dtype, @@ -667,20 +531,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] @@ -839,74 +689,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): self.assertTrue(len(fail_cases) == 0, "\n".join(fail_cases)) - @require_torch_sdpa - @slow - # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_eager_matches_sdpa_generate - def test_eager_matches_sdpa_generate(self): - max_new_tokens = 30 - - # Ignore copy - for model_class in self.greedy_sample_model_classes: - if not model_class._supports_sdpa: - self.skipTest(f"{model_class.__name__} does not support SDPA") - - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - - dummy_input = inputs_dict[model_class.main_input_name] - if dummy_input.dtype in [torch.float32, torch.bfloat16]: - dummy_input = dummy_input.to(torch.float16) - - # make sure that all models have enough positions for generation - if hasattr(config, "max_position_embeddings"): - config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - - dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) - - model_sdpa = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - ).to(torch_device) - - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - - model_eager = model_class.from_pretrained( - tmpdirname, 
- torch_dtype=torch.float16, - low_cpu_mem_usage=True, - attn_implementation="eager", - ).to(torch_device) - - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - - # Just test that a large cache works as expected - res_eager = model_eager.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False - ) - - res_sdpa = model_sdpa.generate( - dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False - ) - - self.assertTrue(torch.allclose(res_eager, res_sdpa)) - def prepare_musicgen_melody_inputs_dict( config, @@ -1048,6 +830,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester # not to test torchscript as the model tester doesn't prepare `input_features` and `padding_mask` # (and `torchscript` hates `None` values). test_torchscript = False + _is_composite = True def setUp(self): self.model_tester = MusicgenMelodyTester(self) @@ -1406,7 +1189,7 @@ def test_save_load_fast_init_from_base(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence def test_flash_attn_2_inference_equivalence(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: @@ -1418,7 +1201,9 @@ def test_flash_attn_2_inference_equivalence(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch.bfloat16, + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, ) model_fa.to(torch_device) @@ -1491,7 +1276,88 @@ def test_flash_attn_2_inference_equivalence(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding + def test_flash_attn_2_conversion(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, + ).to(torch_device) + + for _, module in model.named_modules(): + if "FlashAttention" in module.__class__.__name__: + return + + self.assertTrue(False, "FlashAttention2 modules not found in model") + + @require_torch_sdpa + @require_torch_gpu + @slow + def test_sdpa_can_dispatch_on_flash(self): + if not 
self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + torch.compiler.reset() + compute_capability = torch.cuda.get_device_capability() + major, _ = compute_capability + + if not torch.version.cuda or major < 8: + self.skipTest(reason="This test requires an NVIDIA GPU with compute capability >= 8.0") + + for model_class in self.all_model_classes: + if not model_class._supports_sdpa: + self.skipTest(f"{model_class.__name__} does not support SDPA") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + inputs_dict = self._prepare_for_class(inputs_dict, model_class) + if config.model_type in ["llava", "llava_next", "vipllava", "video_llava"]: + self.skipTest( + reason="Llava-like models currently (transformers==4.39.1) requires an attention_mask input" + ) + if config.model_type in ["paligemma"]: + self.skipTest( + "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input" + ) + if config.model_type in ["idefics", "idefics2", "idefics3"]: + self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input") + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation={"decoder": "sdpa", "audio_encoder": None, "text_encoder": None}, + ) + model.to(torch_device) + + inputs_dict.pop("attention_mask", None) + inputs_dict.pop("decoder_attention_mask", None) + + for name, inp in inputs_dict.items(): + if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: + inputs_dict[name] = inp.to(torch.float16) + + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + _ = model(**inputs_dict) + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence_right_padding def test_flash_attn_2_inference_equivalence_right_padding(self): for model_class in self.all_model_classes: if not model_class._supports_flash_attn_2: @@ -1503,7 +1369,9 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, + torch_dtype=torch.bfloat16, + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, ) model_fa.to(torch_device) @@ -1573,7 +1441,7 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_left_padding def test_flash_attn_2_generate_left_padding(self): # Ignore copy for model_class in self.greedy_sample_model_classes: @@ -1608,7 +1476,7 @@ def test_flash_attn_2_generate_left_padding(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, low_cpu_mem_usage=True, ).to(torch_device) @@ -1622,7 +1490,7 @@ def test_flash_attn_2_generate_left_padding(self): 
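The hunks above swap the single string `attn_implementation="flash_attention_2"` for a dict keyed by sub-config name, so only the MusicGen Melody decoder is switched to FlashAttention-2 (or SDPA) while the frozen audio and text encoders keep their default implementation. As a rough sketch of that loading pattern outside the test harness (the checkpoint name is only an example, and a flash-attn install plus an Ampere-or-newer GPU are assumed):

```python
import torch
from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration

checkpoint = "facebook/musicgen-melody"  # example checkpoint, not the tiny model used in the tests

processor = AutoProcessor.from_pretrained(checkpoint)
model = MusicgenMelodyForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    # one entry per sub-config: only the decoder runs FlashAttention-2,
    # `None` leaves the audio/text encoders on their default implementation
    attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
).to("cuda")

inputs = processor(text=["80s pop track with a groovy bass line"], return_tensors="pt").to("cuda")
audio_values = model.generate(**inputs, do_sample=False, max_new_tokens=64)
```

The `test_sdpa_can_dispatch_on_flash` override above uses the same dict with `"sdpa"` for the decoder and additionally wraps the forward pass in `torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)` to pin PyTorch SDPA to its FlashAttention backend.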
@require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_padding_right def test_flash_attn_2_generate_padding_right(self): # Ignore copy for model_class in self.greedy_sample_model_classes: @@ -1656,7 +1524,7 @@ def test_flash_attn_2_generate_padding_right(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, low_cpu_mem_usage=True, ).to(torch_device) @@ -1670,7 +1538,7 @@ def test_flash_attn_2_generate_padding_right(self): @require_torch_gpu @mark.flash_attn_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache + # Adapted from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_generate_use_cache def test_flash_attn_2_generate_use_cache(self): max_new_tokens = 30 @@ -1699,7 +1567,7 @@ def test_flash_attn_2_generate_use_cache(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, low_cpu_mem_usage=True, ).to(torch_device) @@ -1712,6 +1580,53 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + audio_encoder_attn = "sdpa" if model.audio_encoder._supports_sdpa else "eager" + text_encoder_attn = "sdpa" if model.text_encoder._supports_sdpa else "eager" + decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model_sdpa.audio_encoder.config._attn_implementation == audio_encoder_attn) + self.assertTrue(model_sdpa.text_encoder.config._attn_implementation == text_encoder_attn) + self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn) + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.audio_encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.text_encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = 
False + for name, submodule in model_sdpa.named_modules(): + if "SdpaAttention" in submodule.__class__.__name__: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow @@ -1775,8 +1690,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch_dtype, @@ -1784,20 +1697,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 8 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] diff --git a/tests/models/paligemma/test_modeling_paligemma.py b/tests/models/paligemma/test_modeling_paligemma.py index 644ac2cc5bd1b4..cfc2a2c29b1d70 100644 --- a/tests/models/paligemma/test_modeling_paligemma.py +++ b/tests/models/paligemma/test_modeling_paligemma.py @@ -187,6 +187,7 @@ class PaliGemmaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTes test_pruning = False test_torchscript = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = PaliGemmaVisionText2TextModelTester(self) @@ -319,6 +320,16 @@ def test_generate_from_inputs_embeds_with_static_cache(self): def test_static_cache_matches_dynamic(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @slow @require_torch diff --git a/tests/models/pop2piano/test_modeling_pop2piano.py b/tests/models/pop2piano/test_modeling_pop2piano.py index 3a33b5a98128e2..39ff67f08ce5a9 100644 --- a/tests/models/pop2piano/test_modeling_pop2piano.py +++ b/tests/models/pop2piano/test_modeling_pop2piano.py @@ -620,7 +620,7 @@ def test_export_to_onnx(self): (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), f"{tmpdirname}/Pop2Piano_test.onnx", export_params=True, - opset_version=9, + opset_version=14, input_names=["input_ids", "decoder_input_ids"], ) diff --git a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py index 4054055082c781..314f870f5d9096 100644 --- a/tests/models/qwen2_audio/test_modeling_qwen2_audio.py +++ b/tests/models/qwen2_audio/test_modeling_qwen2_audio.py @@ -15,6 +15,7 @@ """Testing suite for the PyTorch Qwen2Audio model.""" import gc +import tempfile import unittest from io import BytesIO from urllib.request import urlopen @@ -29,6 +30,7 @@ ) from transformers.testing_utils import ( require_torch, + require_torch_sdpa, slow, torch_device, ) @@ -152,6 +154,7 @@ class Qwen2AudioForConditionalGenerationModelTest(ModelTesterMixin, unittest.Tes all_model_classes = (Qwen2AudioForConditionalGeneration,) if is_torch_available() else () test_pruning = False test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = Qwen2AudioModelTester(self) @@ -165,6 +168,53 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + # overwrite because Qwen2 is audio+text model (not vision+text) + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + text_attn = "sdpa" if model.language_model._supports_sdpa else "eager" + vision_attn = "sdpa" if model.audio_tower._supports_sdpa else "eager" + + # `None` as it is the requested one which will be assigned to each sub-config + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model.language_model.config._attn_implementation == text_attn) + self.assertTrue(model.audio_tower.config._attn_implementation == vision_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.language_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.audio_tower.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should 
not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + @require_torch class Qwen2AudioForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 9d1e3109b313c3..2fe06b1511a471 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -71,6 +71,51 @@ class SiglipModelTesterMixin(ModelTesterMixin): + def test_sdpa_can_dispatch_composite_models(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # Load the model with SDPA + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + # Load model with eager attention + model_eager = model_class.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + # SigLip has one shared cls attr for all models, so we assign both submodels heer + vision_attn = text_attn = "sdpa" if model._supports_sdpa else "eager" + + if hasattr(model_sdpa, "vision_model") and hasattr(model_sdpa, "text_model"): + self.assertTrue(model_sdpa.vision_model.config._attn_implementation == vision_attn) + self.assertTrue(model_sdpa.text_model.config._attn_implementation == text_attn) + self.assertTrue(model_eager.vision_model.config._attn_implementation == "eager") + self.assertTrue(model_eager.text_model.config._attn_implementation == "eager") + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + def test_eager_matches_sdpa_inference( self, torch_dtype: str, @@ -132,23 +177,6 @@ def get_mean_reldiff(msg, current_case, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and 
model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving the model each time, # but it would be nicer to have an efficient way to use parameterized.expand cases = [ @@ -400,6 +428,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(False,), ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + class SiglipTextModelTester: def __init__( @@ -562,6 +594,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(False, True), ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + class SiglipModelTester: def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True): @@ -629,6 +665,7 @@ class SiglipModelTest(SiglipModelTesterMixin, PipelineTesterMixin, unittest.Test test_cpu_offload = False test_disk_offload_safetensors = False test_disk_offload_bin = False + _is_composite = True # Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip def setUp(self): @@ -851,6 +888,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): use_attention_mask_options=(False, True), ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + class SiglipForImageClassificationModelTester(SiglipModelTester): def __init__(self, parent): @@ -888,6 +929,7 @@ class SiglipForImageClassificationModelTest(SiglipModelTesterMixin, PipelineTest test_cpu_offload = False test_disk_offload_safetensors = False test_disk_offload_bin = False + _is_composite = True def setUp(self): self.model_tester = SiglipForImageClassificationModelTester(self) @@ -925,6 +967,10 @@ def test_eager_matches_sdpa_inference(self, torch_dtype: str): torch_dtype=torch_dtype, logit_keys=("logits",), use_attention_mask_options=(False,) ) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + super().test_sdpa_can_dispatch_composite_models() + # We will verify our results on an image of cute cats def prepare_img(): diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index b193cacfb40042..6e0b7fa9782fbc 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -18,7 +18,13 @@ import unittest from transformers import is_torch_available -from transformers.testing_utils import require_deterministic_for_xpu, require_torch, slow, torch_device +from transformers.testing_utils import ( + require_deterministic_for_xpu, + require_torch, + require_torch_sdpa, + slow, + torch_device, +) from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from ..bert.test_modeling_bert import BertModelTester @@ -441,6 +447,66 @@ def test_real_model_save_load_from_pretrained(self): max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + inputs_dict = self.prepare_config_and_inputs() + encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"] + config = 
SpeechEncoderDecoderConfig.from_encoder_decoder_configs( + encoder_config=encoder_config, decoder_config=decoder_config + ) + model = SpeechEncoderDecoderModel(config=config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = SpeechEncoderDecoderModel.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + # see https://github.com/huggingface/transformers/pull/32238 + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager" + decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager" + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn) + self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn) + + # Also test that nothing break if we request SDPA explicitly, when both sub-parts support it. + # If the model supports sdpa (i.e. all of sub-models supports it) we'll dispatch safely + # Otherwise we should raise error that SDPA is not supported, as some of the sub-models doesn't support + if encoder_attn == "sdpa" and decoder_attn == "sdpa": + model_sdpa_explicit = SpeechEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa") + model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device) + + self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa") + else: + with self.assertRaises(ValueError): + model_sdpa_explicit = SpeechEncoderDecoderModel.from_pretrained( + tmpdirname, attn_implementation="sdpa" + ) + + model_eager = SpeechEncoderDecoderModel.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + @require_torch class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase): diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 13215b2826fe0c..7adb1f40c6e696 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -36,6 +36,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -645,6 +646,41 @@ def test_decoder_model_past_with_3d_attn_mask(self): lm_labels, ) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + 
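For context on the `test_custom_4d_attention_mask` override being added here (and repeated for T5, UDOP and UMT5 further down), the sketch below illustrates, with made-up token ids, the kind of packed input and 4D mask the shared `_get_custom_4d_mask_test_data` helper is expected to provide. It is a schematic of the idea only, not the helper itself:

```python
import torch

input_ids = torch.tensor([[10, 11, 12, 13],
                          [10, 11, 12, 14],
                          [10, 11, 12, 15]])  # three decoder sequences sharing the prefix 10 11 12

# pack the shared prefix once, followed by the three distinct last tokens -> shape (1, 6)
input_ids_shared_prefix = torch.cat([input_ids[0, :-1], input_ids[:, -1]]).unsqueeze(0)

# 4D "may attend" mask of shape (1, 1, 6, 6): causal over the shared prefix, while each
# of the three packed last tokens sees the prefix plus itself only
allowed = torch.tensor([
    [1, 0, 0, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],
    [1, 1, 1, 0, 0, 0],
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 0, 1, 0],
    [1, 1, 1, 0, 0, 1],
])[None, None].bool()

# The overridden tests feed a mask like this as `decoder_attention_mask`, run both the packed
# single row and the plain three-sequence batch through the model, and compare the
# softmax-normalised logits of the matching last tokens with torch.testing.assert_close.
```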
model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_decoder_model_past_with_large_inputs(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) diff --git a/tests/models/t5/test_modeling_t5.py b/tests/models/t5/test_modeling_t5.py index fe9b40a54abef7..68dd5a52b3d69b 100644 --- a/tests/models/t5/test_modeling_t5.py +++ b/tests/models/t5/test_modeling_t5.py @@ -27,6 +27,7 @@ require_sentencepiece, require_tokenizers, require_torch, + require_torch_gpu, slow, torch_device, ) @@ -44,6 +45,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -578,6 +580,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, # The small T5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google-t5/t5-small" + def setUp(self): self.model_tester = T5ModelTester(self) self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) @@ -630,12 +635,9 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa ] if labels is not None: input_names.append("labels") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - model_output = model(**filtered_inputs) - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) else: @@ -650,7 +652,6 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa "visual_feats", "visual_pos", ] - labels = inputs.get("labels", None) start_positions = inputs.get("start_positions", None) end_positions = inputs.get("end_positions", None) @@ -660,15 +661,12 @@ def _create_and_check_torch_fx_tracing(self, config, inputs_dict, output_loss=Fa input_names.append("start_positions") if end_positions is not None: input_names.append("end_positions") - filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} input_names = list(filtered_inputs.keys()) - if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( not hasattr(model.config, "problem_type") or model.config.problem_type is None ): model.config.problem_type = "single_label_classification" - traced_model = symbolic_trace(model, input_names) traced_output = traced_model(**filtered_inputs) model_output = model(**filtered_inputs) @@ -721,6 +719,41 @@ def flatten_output(output): # (Even with this call, there are still memory 
leak by ~0.04MB) self.clear_torch_jit_class_registry() + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_config(self): self.config_tester.run_common_tests() @@ -1482,6 +1515,7 @@ def test_summarization(self): [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], padding="max_length", truncation=True, + max_length=512, return_tensors="pt", ).to(torch_device) self.assertEqual(512, dct["input_ids"].shape[1]) @@ -1604,14 +1638,76 @@ def test_contrastive_search_t5(self): outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) + # TODO: @arthur? + # PR #31938 caused regression on this test which was fixed by PR #34089 self.assertListEqual( generated_text, [ - "Liana Barrientos has been married 10 times, nine of them in the Bronx. Her husbands filed for " - "permanent residence after the marriages, prosecutors say." + "Liana Barrientos has been married 10 times, nine of them in the Bronx . Her husbands filed for " + "permanent residence after the marriages, prosecutors say ." ], ) + @slow + @require_torch_gpu + def test_compile_static_cache(self): + NUM_TOKENS_TO_GENERATE = 40 + EXPECTED_TEXT_COMPLETION = [ + "theory of relativity states that 1) the speed of light is constant in all inertial reference frames. the laws of physics are the same for all inertial reference frames.", + "ketchup is my favorite condiment.", + ] + + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp.", + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. 
I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", + ] + model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, dynamic_text) + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION, static_compiled_text) + + @slow + @require_torch_gpu + def test_compile_static_cache_encoder(self): + prompts = [ + "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativity is not hard to grasp.", + "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.", + ] + model = T5EncoderModel.from_pretrained("google-t5/t5-small").to(torch_device) + tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small") + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + logits = model(**inputs) + + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + logits_compiled = model(**inputs) + self.assertTrue(torch.allclose(logits[0][:, -3:, -3], logits_compiled[0][:, -3:, -3], atol=1e-5)) + @require_torch class TestAsymmetricT5(unittest.TestCase): diff --git a/tests/models/udop/test_modeling_udop.py b/tests/models/udop/test_modeling_udop.py index a3ae498606a379..9d82173b1aed6c 100644 --- a/tests/models/udop/test_modeling_udop.py +++ b/tests/models/udop/test_modeling_udop.py @@ -37,6 +37,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor @@ -348,6 +349,7 @@ def test_forward_signature(self): expected_arg_names = [ "attention_mask", "bbox", + "cache_position", "cross_attn_head_mask", "decoder_attention_mask", "decoder_head_mask", @@ -365,6 +367,43 @@ def test_forward_signature(self): expected_arg_names = sorted(expected_arg_names) self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, 
dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + bbox=input_dict["bbox"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + bbox=input_dict["bbox"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @unittest.skip( "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" ) @@ -534,6 +573,41 @@ def test_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + @unittest.skip( "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" 
) diff --git a/tests/models/umt5/test_modeling_umt5.py b/tests/models/umt5/test_modeling_umt5.py index 1bd01da8e6caec..ec4c1d019b6d17 100644 --- a/tests/models/umt5/test_modeling_umt5.py +++ b/tests/models/umt5/test_modeling_umt5.py @@ -41,6 +41,7 @@ if is_torch_available(): import torch + import torch.nn.functional as F from transformers import ( AutoTokenizer, @@ -316,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin # The small UMT5 model needs higher percentages for CPU/MP tests model_split_percents = [0.5, 0.8, 0.9] + # used in `test_torch_compile` + _torch_compile_test_ckpt = "google/umt5-small" + def setUp(self): self.model_tester = UMT5ModelTester(self) @@ -486,6 +490,41 @@ def test_inputs_embeds(self): with torch.no_grad(): model(**inputs)[0] + # overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids` + def test_custom_4d_attention_mask(self): + for model_class in self.all_generative_model_classes: + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config).to(device=torch_device, dtype=torch.float32) + + ( + input_ids, + _, + input_ids_shared_prefix, + mask_shared_prefix, + _, + ) = self._get_custom_4d_mask_test_data() + + logits = model.forward( + decoder_input_ids=input_ids, + input_ids=input_dict["input_ids"][:3], + ).logits + # logits.shape == torch.Size([3, 4, ...]) + + logits_shared_prefix = model( + input_ids=input_dict["input_ids"][:1], + decoder_input_ids=input_ids_shared_prefix, + decoder_attention_mask=mask_shared_prefix, + )[0] + # logits_shared_prefix.shape == torch.Size([1, 6, ...]) + + out_last_tokens = logits[:, -1, :] # last tokens in each batch line + out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens + + # comparing softmax-normalized logits: + normalized_0 = F.softmax(out_last_tokens) + normalized_1 = F.softmax(out_shared_prefix_last_tokens) + torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4) + def test_with_sequence_classification_head(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) diff --git a/tests/models/video_llava/test_modeling_video_llava.py b/tests/models/video_llava/test_modeling_video_llava.py index 492dcb9bae1f92..1bd01843981deb 100644 --- a/tests/models/video_llava/test_modeling_video_llava.py +++ b/tests/models/video_llava/test_modeling_video_llava.py @@ -206,6 +206,7 @@ class VideoLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTe test_pruning = False test_resize_embeddings = True test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = VideoLlavaVisionText2TextModelTester(self) @@ -237,6 +238,16 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. 
Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @unittest.skip( reason="After #33533, this still passes, but many subsequential tests fail with `device-side assert triggered`" ) diff --git a/tests/models/vipllava/test_modeling_vipllava.py b/tests/models/vipllava/test_modeling_vipllava.py index 862e144ecdd7d8..2c241c23f26158 100644 --- a/tests/models/vipllava/test_modeling_vipllava.py +++ b/tests/models/vipllava/test_modeling_vipllava.py @@ -168,6 +168,7 @@ class VipLlavaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTest test_pruning = False test_resize_embeddings = True test_head_masking = False + _is_composite = True def setUp(self): self.model_tester = VipLlavaVisionText2TextModelTester(self) @@ -242,6 +243,16 @@ def test_sdpa_can_compile_dynamic(self): def test_sdpa_can_dispatch_on_flash(self): pass + @unittest.skip("FlashAttention only support fp16 and bf16 data type") + def test_flash_attn_2_fp32_ln(self): + pass + + @unittest.skip( + "VLMs need lots of steps to prepare images/mask correctly to get pad-free inputs. Can be tested as part of LLM test" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + @require_torch class VipLlavaForConditionalGenerationIntegrationTest(unittest.TestCase): diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index e5bc88d5bfb272..7def8a9ac96507 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -27,17 +27,24 @@ require_nltk, require_sentencepiece, require_torch, + require_torch_sdpa, require_vision, slow, to_2tuple, torch_device, ) -from transformers.utils import cached_property, is_torch_available, is_vision_available +from transformers.utils import ( + cached_property, + is_torch_available, + is_vision_available, +) from ...test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from ..bart.test_modeling_bart import BartModelTester from ..bert.test_modeling_bert import BertModelTester from ..deit.test_modeling_deit import DeiTModelTester +from ..donut.test_modeling_donut_swin import DonutSwinModelTester +from ..gpt2.test_modeling_gpt2 import GPT2ModelTester from ..layoutlmv3.test_modeling_layoutlmv3 import LayoutLMv3ModelTester from ..swin.test_modeling_swin import SwinModelTester from ..trocr.test_modeling_trocr import TrOCRStandaloneDecoderModelTester @@ -53,6 +60,8 @@ BartForCausalLM, BertLMHeadModel, DeiTModel, + DonutSwinModel, + GPT2LMHeadModel, LayoutLMv3Model, SwinModel, TrOCRForCausalLM, @@ -72,6 +81,8 @@ @require_torch class EncoderDecoderMixin: + supports_sdpa = False + def get_encoder_decoder_model(self, config, decoder_config): pass @@ -374,6 +385,69 @@ def test_real_model_save_load_from_pretrained(self): max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + if not self.supports_sdpa: + self.skipTest("SDPA is not supported") + + inputs_dict = self.prepare_config_and_inputs() + encoder_config, decoder_config = inputs_dict["config"], inputs_dict["decoder_config"] + config = VisionEncoderDecoderConfig.from_encoder_decoder_configs( + encoder_config=encoder_config, decoder_config=decoder_config + ) + model = VisionEncoderDecoderModel(config=config) + 
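Outside the test harness, the dispatch behaviour that `test_sdpa_can_dispatch_composite_models` pins down looks roughly like the sketch below. The hub checkpoint is just an example pairing of a ViT encoder with a GPT-2 decoder (both of which support SDPA), not one of the tiny models built in the suite, and the `"sdpa"` results assume a PyTorch build where scaled_dot_product_attention is available:

```python
from transformers import VisionEncoderDecoderModel

# Example composite checkpoint: ViT encoder + GPT-2 decoder.
checkpoint = "nlpconnect/vit-gpt2-image-captioning"

# With no explicit `attn_implementation`, each sub-model dispatches to SDPA on its own,
# which is exactly what the assertions in the test check via the sub-configs.
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
print(model.config._attn_implementation)          # "sdpa"
print(model.encoder.config._attn_implementation)  # "sdpa"
print(model.decoder.config._attn_implementation)  # "sdpa"

# Requesting a single implementation for the whole stack propagates to every sub-config;
# asking for "sdpa" when one sub-model does not support it raises a ValueError instead.
model_eager = VisionEncoderDecoderModel.from_pretrained(checkpoint, attn_implementation="eager")
assert model_eager.encoder.config._attn_implementation == "eager"
assert model_eager.decoder.config._attn_implementation == "eager"
```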
+ with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = VisionEncoderDecoderModel.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + # see https://github.com/huggingface/transformers/pull/32238 + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + encoder_attn = "sdpa" if model.encoder._supports_sdpa else "eager" + decoder_attn = "sdpa" if model.decoder._supports_sdpa else "eager" + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + self.assertTrue(model_sdpa.encoder.config._attn_implementation == encoder_attn) + self.assertTrue(model_sdpa.decoder.config._attn_implementation == decoder_attn) + + # Also test that nothing break if we request SDPA explicitly, when both sub-parts support it. + # If the model supports sdpa (i.e. all of sub-models supports it) we'll dispatch safely + # Otherwise we should raise error that SDPA is not supported, as some of the sub-models doesn't support + if encoder_attn == "sdpa" and decoder_attn == "sdpa": + model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained(tmpdirname, attn_implementation="sdpa") + model_sdpa_explicit = model_sdpa_explicit.eval().to(torch_device) + + self.assertTrue(model_sdpa_explicit.config._attn_implementation == "sdpa") + else: + with self.assertRaises(ValueError): + model_sdpa_explicit = VisionEncoderDecoderModel.from_pretrained( + tmpdirname, attn_implementation="sdpa" + ) + + model_eager = VisionEncoderDecoderModel.from_pretrained( + tmpdirname, + attn_implementation="eager", + ) + model_eager = model_eager.eval().to(torch_device) + + self.assertTrue(model_eager.config._attn_implementation == "eager") + self.assertTrue(model_eager.encoder.config._attn_implementation == "eager") + self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa: + raise ValueError("The SDPA model should have SDPA attention layers") + @require_torch class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase): @@ -497,6 +571,8 @@ def prepare_config_and_inputs(self): @require_torch class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase): + supports_sdpa = True # one submodel support SDPA + def get_pretrained_model_and_inputs(self): model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained( "hf-internal-testing/tiny-random-vit", "hf-internal-testing/tiny-bert" @@ -649,6 +725,8 @@ def test_real_model_save_load_from_pretrained(self): @require_torch class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase): + supports_sdpa = True # one submodel support SDPA + def get_encoder_decoder_model(self, config, decoder_config): encoder_model = ViTModel(config).eval() decoder_model = TrOCRForCausalLM(decoder_config).eval() @@ -804,6 +882,240 @@ def test_real_model_save_load_from_pretrained(self): pass +@require_torch +class VIT2GPT2Test(EncoderDecoderMixin, unittest.TestCase): + supports_sdpa = True # both submodels support SDPA + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = 
ViTModel(config).eval() + decoder_model = GPT2LMHeadModel(decoder_config).eval() + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = ViTModelTester(self, batch_size=13) + model_tester_decoder = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs() + config, pixel_values, labels = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_attention_mask, + decoder_head_mask, + decoder_token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "pixel_values": pixel_values, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "decoder_head_mask": decoder_head_mask, + "labels": decoder_input_ids, + } + + def check_encoder_decoder_model_output_attentions( + self, + config, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + pixel_values, + labels=None, + **kwargs, + ): + # make the decoder inputs a different shape from the encoder inputs to harden the test + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + **kwargs, + ) + + encoder_attentions = outputs_encoder_decoder["encoder_attentions"] + self.assertEqual(len(encoder_attentions), config.num_hidden_layers) + + seq_len = (encoder_model.config.image_size // encoder_model.config.patch_size) ** 2 + 1 + + decoder_attentions = outputs_encoder_decoder["decoder_attentions"] + num_decoder_layers = ( + decoder_config.num_decoder_layers + if hasattr(decoder_config, "num_decoder_layers") + else decoder_config.num_hidden_layers + ) + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), + ) + + cross_attentions = outputs_encoder_decoder["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len), # 4 6 16 + ) + + def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + # Generate until max length + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): + 
enc_dec_model.config.decoder.eos_token_id = None + if hasattr(enc_dec_model.generation_config, "eos_token_id"): + enc_dec_model.generation_config.eos_token_id = None + enc_dec_model.to(torch_device) + + generated_output = enc_dec_model.generate( + pixel_values=pixel_values, + decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id, + **kwargs, + ) + self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,)) + + @unittest.skip(reason="VIT2GPT2 also has an integration test for testinf save-load") + def test_real_model_save_load_from_pretrained(self): + pass + + +@require_torch +class Donut2GPT2Test(EncoderDecoderMixin, unittest.TestCase): + supports_sdpa = True # one submodel (GPT2) support SDPA + + def get_encoder_decoder_model(self, config, decoder_config): + encoder_model = DonutSwinModel(config).eval() + decoder_model = GPT2LMHeadModel(decoder_config).eval() + return encoder_model, decoder_model + + def prepare_config_and_inputs(self): + model_tester_encoder = DonutSwinModelTester(self, batch_size=13) + model_tester_decoder = GPT2ModelTester(self, batch_size=13, hidden_size=32, max_position_embeddings=512) + encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs() + decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs() + config, pixel_values, labels = encoder_config_and_inputs + ( + decoder_config, + decoder_input_ids, + decoder_attention_mask, + decoder_head_mask, + decoder_token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = decoder_config_and_inputs + + # make sure that cross attention layers are added + decoder_config.add_cross_attention = True + # disable cache for now + decoder_config.use_cache = False + return { + "config": config, + "pixel_values": pixel_values, + "decoder_config": decoder_config, + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "decoder_head_mask": decoder_head_mask, + "labels": decoder_input_ids, + } + + def check_encoder_decoder_model_output_attentions( + self, + config, + decoder_config, + decoder_input_ids, + decoder_attention_mask, + pixel_values, + labels=None, + **kwargs, + ): + # make the decoder inputs a different shape from the encoder inputs to harden the test + decoder_input_ids = decoder_input_ids[:, :-1] + decoder_attention_mask = decoder_attention_mask[:, :-1] + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + enc_dec_model.to(torch_device) + outputs_encoder_decoder = enc_dec_model( + pixel_values=pixel_values, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + output_attentions=True, + **kwargs, + ) + + encoder_attentions = outputs_encoder_decoder["encoder_attentions"] + self.assertEqual(len(encoder_attentions), config.num_hidden_layers) + + seq_len = encoder_model.config.image_size // encoder_model.config.patch_size + + decoder_attentions = outputs_encoder_decoder["decoder_attentions"] + num_decoder_layers = ( + decoder_config.num_decoder_layers + if hasattr(decoder_config, "num_decoder_layers") + else decoder_config.num_hidden_layers + ) + self.assertEqual(len(decoder_attentions), num_decoder_layers) + + self.assertEqual( + decoder_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]), + ) + + cross_attentions = 
outputs_encoder_decoder["cross_attentions"] + self.assertEqual(len(cross_attentions), num_decoder_layers) + + cross_attention_input_seq_len = decoder_input_ids.shape[-1] + self.assertEqual( + cross_attentions[0].shape[-3:], + (decoder_config.num_attention_heads, cross_attention_input_seq_len, seq_len), # 4 6 16 + ) + + def check_encoder_decoder_model_generate(self, config, decoder_config, pixel_values=None, **kwargs): + encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) + enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) + + # Generate until max length + if hasattr(enc_dec_model.config, "eos_token_id"): + enc_dec_model.config.eos_token_id = None + if hasattr(enc_dec_model.config, "decoder") and hasattr(enc_dec_model.config.decoder, "eos_token_id"): + enc_dec_model.config.decoder.eos_token_id = None + if hasattr(enc_dec_model.generation_config, "eos_token_id"): + enc_dec_model.generation_config.eos_token_id = None + enc_dec_model.to(torch_device) + + generated_output = enc_dec_model.generate( + pixel_values=pixel_values, + decoder_start_token_id=enc_dec_model.config.decoder.bos_token_id, + **kwargs, + ) + self.assertEqual(generated_output.shape, (pixel_values.shape[0],) + (decoder_config.max_length,)) + + @unittest.skip(reason="Donut has an Integration test for that") + def test_real_model_save_load_from_pretrained(self): + pass + + @require_vision @require_torch class TrOCRModelIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 104923957568aa..964b7b912b4e0f 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -37,6 +37,7 @@ from transformers import ( AutoModel, AutoModelForCausalLM, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig, @@ -207,6 +208,7 @@ class ModelTesterMixin: test_model_parallel = False is_encoder_decoder = False has_attentions = True + _is_composite = False model_split_percents = [0.5, 0.7, 0.9] def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): @@ -3006,6 +3008,7 @@ def test_inputs_embeds_matches_input_ids_with_generate(self): *get_values(MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES), ]: continue + model = model_class(config) model.to(torch_device) model.eval() @@ -3950,6 +3953,147 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + def test_attn_implementation_composite_models(self): + """ + Tests if composite models can receive a dict object as attn_implementation, where each key should be + one of the sub-configs from the model's config. 
+ """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_model_classes: + if not self._is_composite: + self.skipTest("Model is not a composite model.") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + sub_configs = { + key: getattr(config, key) for key in config if isinstance(getattr(config, key), PretrainedConfig) + } + + # set eager as it will be the one supported in all models + # we just need to test if passing 'attn_implementation' as a dict fails or not + attn_implementation_per_subconfig = {} + for key, sub_config in sub_configs.items(): + attn_implementation_per_subconfig[key] = "eager" + + config._attn_implementation = attn_implementation_per_subconfig + model = model_class(config) + for key in model.config: + if isinstance(getattr(model.config, key), PretrainedConfig): + sub_config = getattr(model.config, key) + self.assertTrue(sub_config._attn_implementation == "eager") + + for name, submodule in model.named_modules(): + class_name = submodule.__class__.__name__ + if ( + "SdpaAttention" in class_name + or "SdpaSelfAttention" in class_name + or "FlashAttention" in class_name + ): + raise ValueError("The eager model should not have SDPA/FA2 attention layers") + + @require_torch_sdpa + def test_sdpa_can_dispatch_non_composite_models(self): + """ + Tests if non-composite models dispatch correctly to SDPA/eager attention when requested to do so when loading the model. + This is checked only via the layer class names, as SDPA attention layers are usually called "SdpaAttention". + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self.all_model_classes[0]._supports_sdpa or self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} does not support SDPA") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(model_eager.config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and model_sdpa.config.model_type != "falcon": + raise ValueError("The SDPA model should have SDPA attention layers") + + @require_torch_sdpa + def test_sdpa_can_dispatch_composite_models(self): + """ + Tests if composite models dispatch correctly to SDPA/eager attention when requested to do so when loading the model. + This is checked only via the layer class names, as SDPA attention layers are usually called "SdpaAttention".
+ In contrast to the above test, this one checks whether "config._attn_implementation" is a dict after the model + is loaded, because we manually replicate the requested attn implementation on each sub-config when loading. + See https://github.com/huggingface/transformers/pull/32238 for more info + + The test tries to cover the most general case of composite models: VLMs with vision and text configs. Any model + that has a different set of sub-configs has to override this test. + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not self._is_composite: + self.skipTest(f"{self.all_model_classes[0].__name__} is not a composite model") + + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_sdpa = model_class.from_pretrained(tmpdirname) + model_sdpa = model_sdpa.eval().to(torch_device) + + vision_model_names = {"visual", "image_tower", "vision_tower", "vision_model"} + language_model_names = {"language_model", "model", "text_model"} + vision_model_name = [name for name in vision_model_names if hasattr(model_sdpa, name)][0] + language_model_name = [name for name in language_model_names if hasattr(model_sdpa, name)][0] + + vision_model_sdpa = getattr(model, vision_model_name) + language_model_sdpa = getattr(model, language_model_name) + text_attn = "sdpa" if language_model_sdpa._supports_sdpa else "eager" + vision_attn = "sdpa" if vision_model_sdpa._supports_sdpa else "eager" + + # `None` was requested here, so each sub-config gets assigned its own default implementation + # Sub-model will dispatch to SDPA if it can (checked below that `SDPA` layers are present) + self.assertTrue(language_model_sdpa.config._attn_implementation == text_attn) + self.assertTrue(vision_model_sdpa.config._attn_implementation == vision_attn) + + model_eager = model_class.from_pretrained(tmpdirname, attn_implementation="eager") + model_eager = model_eager.eval().to(torch_device) + self.assertTrue(getattr(model_eager, language_model_name).config._attn_implementation == "eager") + self.assertTrue(getattr(model_eager, vision_model_name).config._attn_implementation == "eager") + + for name, submodule in model_eager.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + raise ValueError("The eager model should not have SDPA attention layers") + + has_sdpa = False + for name, submodule in model_sdpa.named_modules(): + class_name = submodule.__class__.__name__ + if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: + has_sdpa = True + break + if not has_sdpa and any(module_attn == "sdpa" for module_attn in [text_attn, vision_attn]): + raise ValueError("The SDPA model should have SDPA attention layers") + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow @@ -4012,7 +4156,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): # This means that the class needs to be instantiated much later, after `use_mask` is set, which means a significant refactor of the code. # However masking there is not done at any layers that matters (i.e self-attention), therefore we can safely deactivate it.
deactivate_mask = "use_mask_token" in inspect.signature(model_class).parameters - is_encoder_decoder = model.config.is_encoder_decoder with tempfile.TemporaryDirectory() as tmpdirname: @@ -4020,8 +4163,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): model_sdpa = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) model_sdpa = model_sdpa.eval().to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch_dtype, @@ -4029,22 +4170,6 @@ def get_mean_reldiff(failcase, x, ref, atol, rtol): ) model_eager = model_eager.eval().to(torch_device) - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - # We use these for loops instead of parameterized.expand just for the interest of avoiding loading/saving 16 times the model, # but it would be nicer to have an efficient way to use parameterized.expand fail_cases = [] @@ -4279,7 +4404,7 @@ def test_sdpa_can_dispatch_on_flash(self): self.skipTest( "PaliGemma-like models currently (transformers==4.41.0) requires an attention_mask input" ) - if config.model_type in ["idefics"]: + if config.model_type in ["idefics", "idefics2", "idefics3"]: self.skipTest(reason="Idefics currently (transformers==4.39.1) requires an image_attention_mask input") model = model_class(config) @@ -4382,8 +4507,6 @@ def test_eager_matches_sdpa_generate(self): low_cpu_mem_usage=True, ).to(torch_device) - self.assertTrue(model_sdpa.config._attn_implementation == "sdpa") - model_eager = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, @@ -4391,22 +4514,6 @@ def test_eager_matches_sdpa_generate(self): attn_implementation="eager", ).to(torch_device) - self.assertTrue(model_eager.config._attn_implementation == "eager") - - for name, submodule in model_eager.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - class_name = submodule.__class__.__name__ - if "SdpaAttention" in class_name or "SdpaSelfAttention" in class_name: - has_sdpa = True - break - if not has_sdpa: - raise ValueError("The SDPA model should have SDPA attention layers") - # Just test that a large cache works as expected res_eager = model_eager.generate( dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=max_new_tokens, do_sample=False @@ -4429,6 +4536,8 @@ def test_sdpa_matches_eager_sliding_window(self): self.skipTest(f"No generative model classes for {self.__class__.__name__}") for model_class in self.all_generative_model_classes: + if not model_class._supports_sdpa: + self.skipTest(reason="Model architecture does not support SDPA") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() if config.model_type
not in WINDOW_ATTENTION_MODELS: @@ -4531,6 +4640,62 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + def test_flash_attn_2_can_dispatch_composite_models(self): + """ + Tests if composite models can dispatch on FA2 if the sub-models support FA2. + The test is needed as we handle composite models differently and cannot check them + with the above tests. If any of the sub-models does not support FA2, we'll raise an error when dispatching + that particular sub-model. Otherwise we dispatch safely in all sub-models, where "sub-models" are specific + backbone models (LM/vision/audio/etc.) + """ + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + if not is_torch_fp16_available_on_device(torch_device): + self.skipTest(f"float16 not supported on {torch_device} (on the specific device currently used)") + + torch_dtype = torch.float16 + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + if not self._is_composite: + self.skipTest("This model is not a composite model!") + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch_dtype) + + supports_fa2_all_modules = all( + module._supports_flash_attn_2 + for name, module in model.named_modules() + if isinstance(module, PreTrainedModel) and name != "" + ) + if not supports_fa2_all_modules: + with self.assertRaises(ValueError): + model_fa2 = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2" + ) + else: + model_fa2 = model_class.from_pretrained( + tmpdirname, torch_dtype=torch_dtype, attn_implementation="flash_attention_2" + ) + for key in model_fa2.config: + if isinstance(getattr(model_fa2.config, key), PretrainedConfig): + sub_config = getattr(model_fa2.config, key) + self.assertTrue(sub_config._attn_implementation == "flash_attention_2") + + has_fa2 = False + for name, submodule in model_fa2.named_modules(): + class_name = submodule.__class__.__name__ + if "FlashAttention" in class_name: + has_fa2 = True + break + if not has_fa2: + raise ValueError("The FA2 model should have FA2 layers") + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test @@ -4679,7 +4844,7 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): if 0 in inputs_dict["attention_mask"][:, -1]: inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) dummy_attention_mask = inputs_dict["attention_mask"] - inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id + inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id model = ( model_class.from_pretrained( @@ -4945,10 +5110,15 @@ def test_torch_compile(self): batch_size = 1 n_iter = 3 - tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) + tokenizer = AutoTokenizer.from_pretrained(ckpt) + if self.is_encoder_decoder: + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) + else: + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) model.generation_config.max_new_tokens
= 4 @@ -5020,10 +5190,15 @@ def test_compile_cuda_graph_time(self): os.environ["TOKENIZERS_PARALLELISM"] = "false" - tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) - model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( - torch_device - ) + tokenizer = AutoTokenizer.from_pretrained(ckpt) + if self.is_encoder_decoder: + model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) + else: + model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( + torch_device + ) cache_implementation = "static" if model.config.model_type == "gemma2": diff --git a/tests/utils/test_configuration_utils.py b/tests/utils/test_configuration_utils.py index d2701bf35e6603..35a651d0e59873 100644 --- a/tests/utils/test_configuration_utils.py +++ b/tests/utils/test_configuration_utils.py @@ -228,6 +228,7 @@ def test_config_common_kwargs_is_complete(self): "_name_or_path", "_commit_hash", "_attn_implementation_internal", + "_attn_implementation_autoset", "transformers_version", ], ) diff --git a/utils/check_repo.py b/utils/check_repo.py index 6872dada3d9384..10be5cdcd26230 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -82,6 +82,8 @@ "SeamlessM4Tv2TextToUnitModel", "SeamlessM4Tv2CodeHifiGan", "SeamlessM4Tv2TextToUnitForConditionalGeneration", + "Idefics2PerceiverResampler", + "Idefics2VisionTransformer", "Idefics3VisionTransformer", ] @@ -225,7 +227,6 @@ "BeitForMaskedImageModeling", "ChineseCLIPTextModel", "ChineseCLIPVisionModel", - "CLIPTextModel", "CLIPTextModelWithProjection", "CLIPVisionModelWithProjection", "ClvpForCausalLM", @@ -327,6 +328,7 @@ "SiglipVisionModel", "SiglipTextModel", "ChameleonVQVAE", # no autoclass for VQ-VAE models + "CLIPTextModel", "MoshiForConditionalGeneration", # no auto class for speech-to-speech ]
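For context on the feature exercised by the new composite-model tests: `attn_implementation` may be given as a dict whose keys name the sub-configs of a composite config (see https://github.com/huggingface/transformers/pull/32238). A minimal usage sketch, assuming a LLaVA-style checkpoint; the checkpoint name below is a placeholder chosen for illustration and is not referenced by this diff:

from transformers import LlavaForConditionalGeneration

# Request a different attention backend per sub-model; the dict keys must match the
# sub-config attribute names on the composite config (here "vision_config" and "text_config").
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",  # placeholder checkpoint, for illustration only
    attn_implementation={"vision_config": "eager", "text_config": "sdpa"},
)

# After loading, each sub-config carries the implementation it was assigned,
# mirroring the assertions in test_attn_implementation_composite_models above.
assert model.config.vision_config._attn_implementation == "eager"
assert model.config.text_config._attn_implementation == "sdpa"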