Merge branch 'huggingface:main' into fix_fsdp_with_fp8_in_trainer
eljandoubi authored Oct 22, 2024
2 parents b94376d + 93352e8 commit b9b9eb4
Showing 87 changed files with 4,886 additions and 2,017 deletions.
16 changes: 16 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -79,6 +79,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [PhiMoE](https://huggingface.co/docs/transformers/model_doc/phimoe#transformers.PhimoeModel)
@@ -88,6 +89,10 @@ FlashAttention-2 is currently supported for the following architectures:
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
* [RAG](https://huggingface.co/docs/transformers/model_doc/rag#transformers.RagModel)
* [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel)
* [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel)
* [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
@@ -225,6 +230,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Dinov2](https://huggingface.co/docs/transformers/en/model_doc/dinov2)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Dpr](https://huggingface.co/docs/transformers/model_doc/dpr#transformers.DprReader)
* [EncoderDecoder](https://huggingface.co/docs/transformers/model_doc/encoder_decoder#transformers.EncoderDecoderModel)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
@@ -233,11 +239,16 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Idefics3](https://huggingface.co/docs/transformers/model_doc/idefics3#transformers.Idefics3Model)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
@@ -277,10 +288,15 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [SpeechEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/speech_encoder_decoder#transformers.SpeechEncoderDecoderModel)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VisionEncoderDecoder](https://huggingface.co/docs/transformers/model_doc/vision_encoder_decoder#transformers.VisionEncoderDecoderModel)
* [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
* [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
* [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
* [ViTMSN](https://huggingface.co/docs/transformers/model_doc/vit_msn#transformers.ViTMSNModel)
* [VisionTextDualEncoder](https://huggingface.co/docs/transformers/model_doc/vision_text_dual_encoder#transformers.VisionTextDualEncoderModel)
* [VideoMAE](https://huggingface.co/docs/transformers/model_doc/videomae#transformers.VideoMAEModel)
* [ViViT](https://huggingface.co/docs/transformers/model_doc/vivit#transformers.VivitModel)
* [wav2vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
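
Both lists above are consumed through the `attn_implementation` argument at load time. A minimal usage sketch (the checkpoint name is illustrative; `"flash_attention_2"` additionally requires the `flash-attn` package and a supported GPU):

```python
import torch
from transformers import AutoModelForCausalLM

# Request a specific attention backend; any architecture listed above accepts it.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
)
```
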
2 changes: 1 addition & 1 deletion docs/source/ko/_toctree.yml
@@ -673,7 +673,7 @@
- local: in_translation
title: (번역중) XLSR-Wav2Vec2
title: (번역중) 오디오 모델
- isExpanded: false
- isExpanded: false
sections:
- local: model_doc/vivit
title: ViViT
6 changes: 1 addition & 5 deletions src/transformers/cache_utils.py
@@ -1475,11 +1475,7 @@ def from_legacy_cache(
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
if self.self_attention_cache.key_cache == []:
return 0
if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
return 0
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
return self.self_attention_cache.get_seq_length(layer_idx)

def reset(self):
if hasattr(self.self_attention_cache, "reset"):
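
The hand-rolled emptiness checks are replaced by delegating to the wrapped cache. For reference, a simplified paraphrase (not the exact source) of what the underlying `DynamicCache.get_seq_length` does:

```python
def get_seq_length(self, layer_idx: int = 0) -> int:
    # No entry (or an empty placeholder) for this layer means nothing is cached yet.
    if len(self.key_cache) <= layer_idx or len(self.key_cache[layer_idx]) == 0:
        return 0
    # Key states are (batch, heads, seq_len, head_dim); seq_len is dim -2.
    return self.key_cache[layer_idx].shape[-2]
```
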
5 changes: 5 additions & 0 deletions src/transformers/configuration_utils.py
@@ -296,6 +296,7 @@ def __init__(self, **kwargs):

# Attention implementation to use, if relevant.
self._attn_implementation_internal = kwargs.pop("attn_implementation", None)
self._attn_implementation_autoset = False

# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
@@ -776,6 +777,10 @@ def __eq__(self, other):
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

def __iter__(self):
for attr in self.__dict__:
yield attr

def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
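
The new `__iter__` makes a config iterable over its attribute names, which is what the composite-config walk in `_autoset_attn_implementation` below relies on (`for key in config: ...`). A small sketch, assuming a composite config such as LLaVA's:

```python
from transformers import LlavaConfig, PretrainedConfig

config = LlavaConfig()  # composite: carries text_config and vision_config
sub_configs = [k for k in config if isinstance(getattr(config, k), PretrainedConfig)]
print(sub_configs)  # expected to include 'text_config' and 'vision_config'
```
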
2 changes: 1 addition & 1 deletion src/transformers/generation/configuration_utils.py
@@ -93,7 +93,7 @@ class GenerationMode(ExplicitEnum):

class GenerationConfig(PushToHubMixin):
# no-format
rf"""
"""
Class that holds a configuration for a generation task. A `generate` call supports the following generation methods
for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
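
The fix matters because an `rf`-prefixed string is an f-string expression, which CPython evaluates and discards rather than storing as `__doc__`; only plain or raw string literals become docstrings. A quick illustration:

```python
suffix = "demo"

class FString:
    rf"""docstring {suffix}"""

class Raw:
    r"""docstring"""

print(FString.__doc__)  # None -- the f-string was evaluated, then thrown away
print(Raw.__doc__)      # 'docstring'
```
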
8 changes: 6 additions & 2 deletions src/transformers/generation/utils.py
@@ -1535,8 +1535,12 @@ def _prepare_generation_config(
def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
if "inputs_embeds" in model_kwargs:
if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
cache_position = (
torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
)
else:
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
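
All three branches build the same values as `torch.arange(seq_len)`; the `ones_like(...).cumsum(0) - 1` form just derives the length from a tensor so the graph stays `torch.compile`-friendly. A small check with a dummy embeddings tensor:

```python
import torch

decoder_inputs_embeds = torch.randn(2, 5, 32)  # (batch, seq_len, hidden)
cache_position = (
    torch.ones_like(decoder_inputs_embeds[0, :, 0], dtype=torch.int64).cumsum(0) - 1
)
assert torch.equal(cache_position, torch.arange(5))  # tensor([0, 1, 2, 3, 4])
```
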

@@ -1633,7 +1637,7 @@ def get_layer_device_map(execution_device_map: Optional[dict] = None):

cache_kwargs = {
"config": self.config.get_text_config(),
"max_batch_size": batch_size,
"batch_size": batch_size,
"max_cache_len": max_cache_len,
"device": device,
"dtype": cache_dtype,
58 changes: 45 additions & 13 deletions src/transformers/modeling_utils.py
@@ -1420,9 +1420,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
if not getattr(config, "_attn_implementation_autoset", False):
config = self._autoset_attn_implementation(
config, torch_dtype=torch.get_default_dtype(), check_device_map=False
)
self.config = config

self.name_or_path = config.name_or_path
@@ -1500,6 +1501,9 @@ def _from_config(cls, config, **kwargs):
torch_dtype (`torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype.
"""
# when we init a model from within another model (e.g. VLMs) and dispatch on FA2
# a warning is raised that dtype should be fp16. Since we never pass dtype from within
# modeling code, we can try to infer it here same way as done in `from_pretrained`
torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

attn_implementation = None

config._attn_implementation = kwargs.pop("attn_implementation", attn_implementation)
config = cls._autoset_attn_implementation(
config,
use_flash_attention_2=use_flash_attention_2,
check_device_map=False,
torch_dtype=torch_dtype,
)
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config,
use_flash_attention_2=use_flash_attention_2,
check_device_map=False,
torch_dtype=torch_dtype,
)

if is_deepspeed_zero3_enabled():
import deepspeed
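
The guard means `_from_config` now respects a config whose attention implementation was already autoset by a parent model instead of re-dispatching. From the public API this path is reached via `from_config`; a sketch (checkpoint name illustrative):

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

torch.set_default_dtype(torch.bfloat16)  # _from_config infers dtype from here if none is passed

config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B")
model = AutoModelForCausalLM.from_config(config, attn_implementation="sdpa")
```
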
@@ -1570,7 +1575,11 @@ def _autoset_attn_implementation(
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)

if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
if not isinstance(config._attn_implementation, dict) and config._attn_implementation not in [
"eager",
"sdpa",
"flash_attention_2",
]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
# If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available.
requested_attn_implementation = config._attn_implementation_internal

# Composite models consisting of several PretrainedModels have to specify attention impl as a dict
# where keys are sub-config names. But most people will specify one `str` which means that should dispatch it
# for all sub-models.
# Below we check if a config is composite and manually prepare a dict of attn impl if not already passed as a dict.
# Later each sub-module will dispatch with its own attn impl, by calling `XXXModel._from_config(config.text_config)`
# If any of sub-modules doesn't support requested attn, an error will be raised. See https://github.com/huggingface/transformers/pull/32238
for key in config:
if isinstance(getattr(config, key), PretrainedConfig):
sub_config = getattr(config, key)
curr_attn_implementation = (
requested_attn_implementation
if not isinstance(requested_attn_implementation, dict)
else requested_attn_implementation.get(key, None)
)
sub_config._attn_implementation_internal = curr_attn_implementation

if use_flash_attention_2:
logger.warning_once(
'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.'
Expand Down Expand Up @@ -1611,9 +1636,12 @@ def _autoset_attn_implementation(
"Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
)
torch.backends.cuda.enable_flash_sdp(False)
elif isinstance(requested_attn_implementation, dict):
config._attn_implementation = None
else:
config._attn_implementation = "eager"

config._attn_implementation_autoset = True
return config

@classmethod
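
This is the user-visible effect of the composite-model handling above: `attn_implementation` may now be a dict keyed by sub-config name, with each sub-model dispatching independently (see https://github.com/huggingface/transformers/pull/32238). A sketch with an illustrative checkpoint:

```python
from transformers import LlavaForConditionalGeneration

# Keys are sub-config names; each sub-model gets its own attention backend.
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    attn_implementation={"text_config": "flash_attention_2", "vision_config": "sdpa"},
)
```
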
@@ -2771,6 +2799,9 @@ def save_pretrained(
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]

# Unset attn implementation so it can be set to another one when loading back
model_to_save.config._attn_implementation_autoset = False

# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
if self._auto_class is not None:
@@ -4055,9 +4086,10 @@ def from_pretrained(
init_contexts.append(init_empty_weights())

config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)
if not getattr(config, "_attn_implementation_autoset", False):
config = cls._autoset_attn_implementation(
config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map
)

with ContextManagers(init_contexts):
# Let's make sure we don't run the init function of buffer modules
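
Combined with the `save_pretrained` reset above, the flag is purely an in-memory marker: dispatch runs once per load and never leaks into a checkpoint. A round-trip sketch under the same illustrative checkpoint:

```python
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", attn_implementation="sdpa"
)
model.save_pretrained("./llava-ckpt")  # clears config._attn_implementation_autoset

# A fresh load can therefore request a different backend.
reloaded = LlavaForConditionalGeneration.from_pretrained(
    "./llava-ckpt", attn_implementation="eager"
)
```
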
src/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py
@@ -176,8 +176,24 @@ def __init__(self, config: ASTConfig) -> None:
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

def forward(
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
self,
hidden_states: torch.FloatTensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
if output_attentions or head_mask is not None:
logger.warning_once(
"`ASTSdpaAttention` is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
"`output_attentions=True` or `head_mask`. Falling back to the manual attention implementation, but "
"specifying the manual implementation will be required from Transformers version v5.0.0 onwards. "
'This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
head_mask=head_mask,
output_attentions=output_attentions,
)

mixed_query_layer = self.query(hidden_states)

key_layer = self.transpose_for_scores(self.key(hidden_states))
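
Callers who need `output_attentions` or `head_mask` can avoid the warning and the silent fallback by requesting the manual implementation up front; a sketch with a common AST checkpoint:

```python
from transformers import ASTModel

model = ASTModel.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    attn_implementation="eager",  # SDPA cannot return attention probabilities
)
```
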
17 changes: 5 additions & 12 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -410,6 +410,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
config_class = Blip2Config
base_model_prefix = "blip"
supports_gradient_checkpointing = True

_no_split_modules = [
"Blip2Attention",
"Blip2QFormerMultiHeadAttention",
@@ -1455,13 +1456,9 @@ def __init__(self, config: Blip2Config):

self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
@@ -2020,13 +2017,9 @@ def __init__(self, config: Blip2Config):

self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
if config.use_decoder_only_language_model:
language_model = AutoModelForCausalLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
else:
language_model = AutoModelForSeq2SeqLM.from_config(
config.text_config, attn_implementation=config._attn_implementation
)
language_model = AutoModelForSeq2SeqLM.from_config(config.text_config)

# Update _tied_weights_keys using the base model used.
if language_model._tied_weights_keys is not None:
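
Dropping the explicit `attn_implementation=config._attn_implementation` is safe because `_autoset_attn_implementation` now stamps the requested backend onto every sub-config (here `config.text_config`) before the inner model is built. The user-facing call is unchanged; a sketch:

```python
from transformers import Blip2ForConditionalGeneration

# The backend propagates to config.text_config automatically, so the inner
# AutoModelForCausalLM.from_config call no longer needs it passed explicitly.
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", attn_implementation="sdpa"
)
```
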