From 4ee821501cbd09d2c19d3bf3c6854591b95b910f Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Tue, 14 Jan 2025 19:04:48 +0800 Subject: [PATCH] [Feature]Support transformers==4.48 (#985) * update requirements * support internlm3, llama, mistral, mixtral, qwen2 and qwen2moe in transformers==4.48 --- requirements/deepspeed.txt | 3 +- requirements/runtime.txt | 19 +- xtuner/model/modules/dispatch/__init__.py | 93 ++-- xtuner/model/modules/dispatch/internlm3.py | 132 +++++ xtuner/model/modules/dispatch/llama.py | 567 ++++----------------- xtuner/model/modules/dispatch/mistral.py | 487 ++++-------------- xtuner/model/modules/dispatch/qwen2.py | 422 ++++----------- xtuner/model/sft.py | 13 +- 8 files changed, 446 insertions(+), 1290 deletions(-) create mode 100644 xtuner/model/modules/dispatch/internlm3.py diff --git a/requirements/deepspeed.txt b/requirements/deepspeed.txt index d7f9c3c0d..f6cda0a03 100644 --- a/requirements/deepspeed.txt +++ b/requirements/deepspeed.txt @@ -1,3 +1,2 @@ -# Minimum 0.12.3, see https://github.com/microsoft/DeepSpeed/pull/4587 -deepspeed>=0.12.3 +deepspeed==0.16.2 mpi4py-mpich diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 3a4d2f84e..b07b16539 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,27 +1,18 @@ -# Minimum 0.40.0.post4 to fix some 4-bit precision bugs -bitsandbytes>=0.40.0.post4 -# Minimum 2.16.0 to fix some bugs, see https://github.com/huggingface/datasets/pull/6444 -datasets>=2.16.0 +bitsandbytes==0.45.0 +datasets>=3.2.0 einops # Minimum 0.1.2 to fix some bugs, see https://github.com/InternLM/lagent/pull/44 lagent>=0.1.2 # Minimum 0.10.3 to support distributed evaluation for MMBench # see https://github.com/open-mmlab/mmengine/pull/1469 -mmengine>=0.10.3 +mmengine==0.10.6 openpyxl -# Minimum 0.4.0 to support QLoRA, see https://github.com/huggingface/peft/pull/476 -peft>=0.4.0 +peft>=0.14.0 scikit-image scipy SentencePiece tiktoken torch torchvision -# Minimum 4.36.0 to support `Cache` data structure used by KV Cache -# Registering a causal mask in `LlamaModel` is not friendly for very large -# `max_position_embeddings`. 
Refer to -# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923 -# transformers >= 4.43.0 use _flash_attention_forward but not self._flash_attention_forward -# to calculate attn output which lead to bc braeking -transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2,<=4.42.4 +transformers==4.48.0 transformers_stream_generator diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index e81ec7a3a..41340e624 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -34,76 +34,75 @@ 'possible to return the `attn_weights`.') LOWEST_TRANSFORMERS_VERSION = dict( + InternLM3ForCausalLM=digit_version('4.48'), InternLM2ForCausalLM=digit_version('4.36'), InternLMForCausalLM=digit_version('4.36'), - LlamaForCausalLM=digit_version('4.36'), + LlamaForCausalLM=digit_version('4.48'), Phi3ForCausalLM=digit_version('4.39'), - MistralForCausalLM=digit_version('4.36'), + MistralForCausalLM=digit_version('4.48'), # Training mixtral with lower version may lead to nccl timeout # Refer to https://github.com/microsoft/DeepSpeed/issues/5066 - MixtralForCausalLM=digit_version('4.40'), + MixtralForCausalLM=digit_version('4.48'), CohereForCausalLM=digit_version('4.40'), - Qwen2ForCausalLM=digit_version('4.39'), - Qwen2MoeForCausalLM=digit_version('4.40'), + Qwen2ForCausalLM=digit_version('4.48'), + Qwen2MoeForCausalLM=digit_version('4.48'), DeepseekV2ForCausalLM=digit_version('4.40'), ) ATTN_DISPATCH_MAPPING = dict( + InternLM3Attention=LazyObject('xtuner.model.modules.dispatch.internlm3', + 'internlm3_attn_forward'), InternLM2FlashAttention2=LazyObject( 'xtuner.model.modules.dispatch.internlm2', 'internlm2_attn_forward'), InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm', 'internlm_attn_forward'), - LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama', - 'llama_attn_forward'), + LlamaAttention=LazyObject('xtuner.model.modules.dispatch.llama', + 'llama_attn_forward'), Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3', 'phi3_attn_forward'), - MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral', - 'mistral_attn_forward'), - MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral', - 'mistral_attn_forward'), + MistralAttention=LazyObject('xtuner.model.modules.dispatch.mistral', + 'mistral_attn_forward'), + MixtralAttention=LazyObject('xtuner.model.modules.dispatch.mistral', + 'mistral_attn_forward'), CohereFlashAttention2=LazyObject('xtuner.model.modules.dispatch.cohere', 'cohere_attn_forward'), - Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2', - 'qwen2_attn_forward'), - Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2', - 'qwen2_attn_forward'), + Qwen2Attention=LazyObject('xtuner.model.modules.dispatch.qwen2', + 'qwen2_attn_forward'), + Qwen2MoeAttention=LazyObject('xtuner.model.modules.dispatch.qwen2', + 'qwen2_attn_forward'), DeepseekV2FlashAttention2=LazyObject( 'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_attn_forward'), ) -ATTN_LEGACY_DISPATCH_MAPPING = dict( - LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama', - 'llama_attn_forward_legacy'), ) - VARLEN_ATTN_DISPATCH_MAPPING = dict( + InternLM3Attention=LazyObject('xtuner.model.modules.dispatch.internlm3', + 'internlm3_attn_forward'), InternLM2FlashAttention2=LazyObject( 'xtuner.model.modules.dispatch.internlm2', 'internlm2_varlen_attn_forward'), 
InternLMAttention=LazyObject('xtuner.model.modules.dispatch.internlm', 'internlm_varlen_attn_forward'), - LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama', - 'llama_varlen_attn_forward'), + LlamaAttention=LazyObject('xtuner.model.modules.dispatch.llama', + 'llama_attn_forward'), Phi3FlashAttention2=LazyObject('xtuner.model.modules.dispatch.phi3', 'phi3_varlen_attn_forward'), - MistralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral', - 'mistral_varlen_attn_forward'), - MixtralFlashAttention2=LazyObject('xtuner.model.modules.dispatch.mistral', - 'mistral_varlen_attn_forward'), + MistralAttention=LazyObject('xtuner.model.modules.dispatch.mistral', + 'mistral_attn_forward'), + MixtralAttention=LazyObject('xtuner.model.modules.dispatch.mistral', + 'mistral_attn_forward'), CohereFlashAttention2=None, - Qwen2FlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2', - 'qwen2_varlen_attn_forward'), - Qwen2MoeFlashAttention2=LazyObject('xtuner.model.modules.dispatch.qwen2', - 'qwen2_varlen_attn_forward'), + Qwen2Attention=LazyObject('xtuner.model.modules.dispatch.qwen2', + 'qwen2_attn_forward'), + Qwen2MoeAttention=LazyObject('xtuner.model.modules.dispatch.qwen2', + 'qwen2_attn_forward'), DeepseekV2FlashAttention2=LazyObject( 'xtuner.model.modules.dispatch.deepseek_v2', 'deepseek_varlen_attn_forward'), ) -VARLEN_ATTN_LEGACY_DISPATCH_MAPPING = dict( - LlamaFlashAttention2=LazyObject('xtuner.model.modules.dispatch.llama', - 'llama_varlen_attn_forward_legacy'), ) - RMS_DISPATCH_MAPPING = dict( + InternLM3RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels', + 'rms_norm_forward'), InternLM2RMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels', 'rms_norm_forward'), InternLMRMSNorm=LazyObject('xtuner.model.modules.dispatch.triton_kernels', @@ -126,12 +125,7 @@ ROTE_DISPATCH_MAPPING = dict( InternLMRotaryEmbedding=LazyObject( - 'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'), - MistralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral', - 'MistralRotaryEmbedding'), - MixtralRotaryEmbedding=LazyObject('xtuner.model.modules.dispatch.mistral', - 'MistralRotaryEmbedding'), -) + 'xtuner.model.modules.dispatch.internlm', 'InternLMRotaryEmbedding'), ) def log_once(func): @@ -158,15 +152,7 @@ def dispatch_attn_forward(model): attn_forward = None for module in model.modules(): name = type(module).__name__ - if (IS_LOW_VERSION_TRANSFORMERS - and name in ATTN_LEGACY_DISPATCH_MAPPING): - if attn_forward is None: - attn_forward = ATTN_LEGACY_DISPATCH_MAPPING[name] - attn_forward = attn_forward.build() - print_log(f'Dispatch {name} legacy forward. {NO_ATTN_WEIGHTS_MSG}', - 'current') - module.forward = types.MethodType(attn_forward, module) - elif name in ATTN_DISPATCH_MAPPING: + if name in ATTN_DISPATCH_MAPPING: if attn_forward is None: attn_forward = ATTN_DISPATCH_MAPPING[name] attn_forward = attn_forward.build() @@ -186,16 +172,7 @@ def dispatch_varlen_attn_forward(model): varlen_attn_forward = None for module in model.modules(): name = type(module).__name__ - if (IS_LOW_VERSION_TRANSFORMERS - and name in VARLEN_ATTN_LEGACY_DISPATCH_MAPPING): - if varlen_attn_forward is None: - varlen_attn_forward = VARLEN_ATTN_LEGACY_DISPATCH_MAPPING[name] - varlen_attn_forward = varlen_attn_forward.build() - print_log( - f'Dispatch legacy {name} varlen forward. 
' - f'{NO_ATTN_WEIGHTS_MSG}', 'current') - module.forward = types.MethodType(varlen_attn_forward, module) - elif name in VARLEN_ATTN_DISPATCH_MAPPING: + if name in VARLEN_ATTN_DISPATCH_MAPPING: if varlen_attn_forward is None: varlen_attn_forward = VARLEN_ATTN_DISPATCH_MAPPING[name] varlen_attn_forward = varlen_attn_forward.build() diff --git a/xtuner/model/modules/dispatch/internlm3.py b/xtuner/model/modules/dispatch/internlm3.py new file mode 100644 index 000000000..0532bb0ae --- /dev/null +++ b/xtuner/model/modules/dispatch/internlm3.py @@ -0,0 +1,132 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import warnings +from typing import Callable, Optional, Tuple + +import torch +import torch.distributed as dist +from mmengine import MessageHub +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb, + eager_attention_forward, + repeat_kv) +from transformers.processing_utils import Unpack + +from xtuner.parallel.sequence import get_sequence_parallel_world_size +from xtuner.parallel.sequence.attention import ( + post_process_for_sequence_parallel_attn, + pre_process_for_sequence_parallel_attn) + + +def internlm3_attn_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], +): + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed + # for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) + + # different from LlamaAttention.forward + # repeat k/v heads if n_kv_heads < n_heads for sequence parallel + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + enable_sequence_parallel = ( + dist.is_initialized() and get_sequence_parallel_world_size() > 1 + and self.training) + if enable_sequence_parallel: + # Reashape for `pre_process_for_sequence_parallel_attn` + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + # different places end + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != 'eager': + if self.config._attn_implementation == 'sdpa' and kwargs.get( + 'output_attentions', False): + warnings.warn( + 
'`torch.nn.functional.scaled_dot_product_attention` does not ' + 'support `output_attentions=True`. Falling back to eager ' + 'attention. This warning can be removed using the argument' + ' `attn_implementation="eager"` when loading the model.') + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation] + + message_hub = MessageHub.get_instance('varlen_attn_args') + rank = dist.get_rank() + cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') + use_varlen_atten = (cumulative_len is not None) + if use_varlen_atten: + # When gradient_checkpointing is enabled, the flash_attn_kwargs + # parameter is not automatically passed to the model. In such + # cases, parameters like cu_seq_lens_q and max_length_q are + # computed based on position_ids. However, when sequence + # parallel is enabled, position_ids is split along the + # sequence length, leading to incorrect calculations of these + # parameters. + # To address this issue, it is necessary to manually provide + # the flash_attn_kwargs parameters. + max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') + kwargs['cu_seq_lens_q'] = cumulative_len + kwargs['cu_seq_lens_k'] = cumulative_len + kwargs['max_length_q'] = max_seqlen + kwargs['max_length_k'] = max_seqlen + kwargs.pop('position_ids', None) + + # Hacky: `sdpa_attention_forward` does repeat_kv based on + # module.num_key_value_groups but it is done before + num_key_value_groups = self.num_key_value_groups + self.num_key_value_groups = 1 + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + self.num_key_value_groups = num_key_value_groups + + # different from LlamaAttention.forward + if enable_sequence_parallel: + attn_output = post_process_for_sequence_parallel_attn(attn_output) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py index 8132096fd..a81dff790 100644 --- a/xtuner/model/modules/dispatch/llama.py +++ b/xtuner/model/modules/dispatch/llama.py @@ -1,79 +1,51 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import warnings -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch import torch.distributed as dist from mmengine import MessageHub +from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb, + eager_attention_forward, repeat_kv) -from transformers.utils import is_flash_attn_greater_or_equal_2_10 +from transformers.processing_utils import Unpack -from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask, - varlen_flash_attn) -from .triton_kernels import apply_rotary_emb - -try: - from transformers.cache_utils import Cache -except ImportError: - - class Cache: - pass - - -def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) - to (batch, seqlen, num_attention_heads, head_dim)""" - batch, slen, num_key_value_heads, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, :, - None, :].expand(batch, slen, - num_key_value_heads, n_rep, - head_dim) - return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, - head_dim) +from xtuner.parallel.sequence import get_sequence_parallel_world_size +from xtuner.parallel.sequence.attention import ( + post_process_for_sequence_parallel_attn, + pre_process_for_sequence_parallel_attn) +# modified from transformers.model.llama.modeling_llama.LlamaAttention.forward +# and support sequence parallel def llama_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, - **kwargs, + **kwargs: Unpack[FlashAttentionKwargs], ): - # Modified from https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/llama/modeling_llama.py#L422 # noqa:E501 - output_attentions = False + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - bsz, q_len, _ = hidden_states.size() + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - past_key_value = getattr(self, 'past_key_value', 
past_key_value) - if past_key_value is not None: - # sin and cos are specific to RoPE models; - # cache_position needed for the static cache + # sin and cos are specific to RoPE models; cache_position needed + # for the static cache cache_kwargs = { 'sin': sin, 'cos': cos, @@ -82,443 +54,80 @@ def llama_attn_forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) + # different from LlamaAttention.forward + # repeat k/v heads if n_kv_heads < n_heads for sequence parallel key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - assert SUPPORT_FLASH2 - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons therefore the input hidden states gets silently - # casted in float32. Hence, we need cast them back in the correct dtype - # just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not - # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - if is_flash_attn_greater_or_equal_2_10(): - causal = self.is_causal - else: - # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm - # is bumped to 2.1. For details, please see the comment in - # LlamaFlashAttention2 __init__. - causal = self.is_causal and q_len != 1 - - # the shape of attention_mask used by flash_attn and - # F.scaled_dot_product_attention are different - assert attention_mask is None or attention_mask.ndim == 2, \ - ('When using flash_attn, attention_mask.ndim should equal to 2.' - f'But got attention_mask.shape = {attention_mask.shape}.' 
- 'We can pass the `attn_implementation="flash_attention_2"` flag ' - 'to `.from_pretrained` method when instantiating a Internlm2 ' - 'model.') - - if attention_mask is not None: - attn_output = flash_attn_w_mask( - query_states, - key_states, - value_states, - attention_mask, - causal=causal, - dropout_p=dropout_rate, - training=self.training) - else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal=causal, - dropout_p=dropout_rate, - training=self.training) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def llama_attn_forward_legacy( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501 - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in ' - 'v4.37. Please make sure use `attention_mask` instead.`') - - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. ' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, - self.layer_idx) - assert position_ids is not None - if self.training: - cos, sin = self.rotary_emb( - value_states, seq_len=position_ids.max() + 1) - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs) - - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - - assert SUPPORT_FLASH2 - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons therefore the input hidden states gets silently - # casted in float32. Hence, we need cast them back in the correct dtype - # just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not - # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - if is_flash_attn_greater_or_equal_2_10(): - causal = self.is_causal - else: - # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm - # is bumped to 2.1. For details, please see the comment in - # LlamaFlashAttention2 __init__. - causal = self.is_causal and q_len != 1 - - # the shape of attention_mask used by flash_attn and - # F.scaled_dot_product_attention are different - assert attention_mask is None or attention_mask.ndim == 2, \ - ('When using flash_attn, attention_mask.ndim should equal to 2.' - f'But got attention_mask.shape = {attention_mask.shape}.' - 'We can pass the `attn_implementation="flash_attention_2"` flag ' - 'to `.from_pretrained` method when instantiating a Internlm2 ' - 'model.') - - if attention_mask is not None: - attn_output = flash_attn_w_mask( - query_states, - key_states, - value_states, - attention_mask=attention_mask, - causal=causal, - dropout_p=dropout_rate, - training=self.training) - else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal=causal, - dropout_p=dropout_rate, - training=self.training) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. - return attn_output, None, past_key_value - - -def llama_varlen_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - - message_hub = MessageHub.get_instance('varlen_attn_args') - rank = dist.get_rank() - cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') - use_varlen_atten = (cumulative_len is not None) - - if 'padding_mask' in kwargs: - warnings.warn('Passing `padding_mask` is deprecated and will be ' - 'removed in v4.37. 
Please make sure use ' - '`attention_mask` instead.`') - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, position_ids) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin) - - past_key_value = getattr(self, 'past_key_value', past_key_value) - - if past_key_value is not None: - # sin and cos are specific to RoPE models; - # cache_position needed for the static cache - cache_kwargs = { - 'sin': sin, - 'cos': cos, - 'cache_position': cache_position - } - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons therefore the input hidden states gets silently casted - # in float32. Hence, we need cast them back in the correct dtype - # just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not - # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - assert SUPPORT_FLASH2 - if use_varlen_atten: - attn_output = varlen_flash_attn( - query_states, - key_states, - value_states, - cumulative_len, - max_seqlen, - causal=True, - dropout_p=dropout_rate, - training=self.training) - else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal=True, - training=self.training) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - return attn_output, None, past_key_value - - -def llama_varlen_attn_forward_legacy( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - - message_hub = MessageHub.get_instance('varlen_attn_args') - rank = dist.get_rank() - cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') - use_varlen_atten = (cumulative_len is not None) - - if 'padding_mask' in kwargs: - warnings.warn('Passing `padding_mask` is deprecated and will be ' - 'removed in v4.37. Please make sure use ' - '`attention_mask` instead.`') - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) - - kv_seq_len = key_states.shape[-3] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. 
' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, - self.layer_idx) - - if use_varlen_atten: - cos, sin = self.rotary_emb(value_states, max_seqlen) - # position_ids (1, seq_len) - # cos, sin (1, seq_len, dim) -> (seq_len, dim) - cos = cos[position_ids].squeeze(0) - sin = sin[position_ids].squeeze(0) - query_states = apply_rotary_emb(query_states, cos, sin) - key_states = apply_rotary_emb(key_states, cos, sin) - else: + enable_sequence_parallel = ( + dist.is_initialized() and get_sequence_parallel_world_size() > 1 + and self.training) + if enable_sequence_parallel: + # Reashape for `pre_process_for_sequence_parallel_attn` query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - cos, sin = self.rotary_emb(value_states, kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - - if past_key_value is not None: - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs) - + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons therefore the input hidden states gets silently casted - # in float32. Hence, we need cast them back in the correct dtype - # just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not - # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != 'eager': + if self.config._attn_implementation == 'sdpa' and kwargs.get( + 'output_attentions', False): + warnings.warn( + '`torch.nn.functional.scaled_dot_product_attention` does not ' + 'support `output_attentions=True`. Falling back to eager ' + 'attention. 
This warning can be removed using the argument' + ' `attn_implementation="eager"` when loading the model.') else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation] - assert SUPPORT_FLASH2 + message_hub = MessageHub.get_instance('varlen_attn_args') + rank = dist.get_rank() + cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') + use_varlen_atten = (cumulative_len is not None) if use_varlen_atten: - attn_output = varlen_flash_attn( - query_states, - key_states, - value_states, - cumulative_len, - max_seqlen, - causal=True, - dropout_p=dropout_rate, - training=self.training) - else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal=True, - dropout_p=dropout_rate, - training=self.training) - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - + # When gradient_checkpointing is enabled, the flash_attn_kwargs + # parameter is not automatically passed to the model. In such + # cases, parameters like cu_seq_lens_q and max_length_q are + # computed based on position_ids. However, when sequence + # parallel is enabled, position_ids is split along the + # sequence length, leading to incorrect calculations of these + # parameters. + # To address this issue, it is necessary to manually provide + # the flash_attn_kwargs parameters. + max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') + kwargs['cu_seq_lens_q'] = cumulative_len + kwargs['cu_seq_lens_k'] = cumulative_len + kwargs['max_length_q'] = max_seqlen + kwargs['max_length_k'] = max_seqlen + kwargs.pop('position_ids', None) + + # Hacky: `sdpa_attention_forward` does repeat_kv based on + # module.num_key_value_groups but it is done before + num_key_value_groups = self.num_key_value_groups + self.num_key_value_groups = 1 + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + self.num_key_value_groups = num_key_value_groups + + # different from LlamaAttention.forward + if enable_sequence_parallel: + attn_output = post_process_for_sequence_parallel_attn(attn_output) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. - return attn_output, None, past_key_value + return attn_output, attn_weights diff --git a/xtuner/model/modules/dispatch/mistral.py b/xtuner/model/modules/dispatch/mistral.py index dc6c7fed8..da87ac189 100644 --- a/xtuner/model/modules/dispatch/mistral.py +++ b/xtuner/model/modules/dispatch/mistral.py @@ -1,447 +1,134 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-import inspect import warnings -from typing import Optional +from typing import Callable, Optional, Tuple import torch import torch.distributed as dist -import torch.nn as nn -import transformers from mmengine import MessageHub -from mmengine.utils import digit_version from transformers.cache_utils import Cache -from transformers.models.mistral.modeling_mistral import (apply_rotary_pos_emb, - repeat_kv) +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.mistral.modeling_mistral import ( + apply_rotary_pos_emb, eager_attention_forward, repeat_kv) +from transformers.processing_utils import Unpack from xtuner.parallel.sequence import get_sequence_parallel_world_size from xtuner.parallel.sequence.attention import ( post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn) -from .attention import flash_attn_wo_mask, varlen_flash_attn -from .triton_kernels import apply_rotary_emb - -SUPPORT_FLASH2 = False - -try: - from flash_attn import flash_attn_func - _flash_supports_window_size = 'window_size' in list( - inspect.signature(flash_attn_func).parameters) - SUPPORT_FLASH2 = True -except ImportError: - pass - -TRANSFORMERS_VERSION = digit_version(transformers.__version__) -IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43') - -if not IS_LOW_VERSION_TRANSFORMERS: - from transformers.modeling_flash_attention_utils import \ - _flash_attention_forward - - -class MistralRotaryEmbedding(nn.Module): - - def __init__(self, - dim, - max_position_embeddings=2048, - base=10000, - device=None): - super().__init__() - - self.dim = dim - self.max_position_embeddings = max_position_embeddings - self.base = base - self.inv_freq = 1.0 / ( - base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - - # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, - device=self.inv_freq.device, - dtype=torch.get_default_dtype()) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange( - self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(device)) - # Different from paper, but it uses a different permutation - # in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(device) - self.cos_cached = emb.cos().to(dtype) - self.sin_cached = emb.sin().to(dtype) - - def forward(self, x, seq_len=None): - # x: [bs, num_attention_heads, seq_len, head_size] - if (seq_len > self.max_seq_len_cached - or self.cos_cached.device != x.device # noqa: W503 - or self.cos_cached.dtype != x.dtype): # noqa: W503 - self._set_cos_sin_cache( - seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) - to (batch, seqlen, num_attention_heads, head_dim)""" - batch, slen, num_key_value_heads, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, :, - None, :].expand(batch, slen, - num_key_value_heads, n_rep, - head_dim) - return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, - head_dim) +# modified from transformers.model.mistral.modeling_mistral.MistralAttention.forward and # noqa: E501 +# support sequence parallel def mistral_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ): - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in ' - 'v4.37. Please make sure use `attention_mask` instead.`') - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop('padding_mask') - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. 
' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, - self.layer_idx) - - assert position_ids is not None - if self.training: - cos, sin = self.rotary_emb( - value_states, seq_len=position_ids.max() + 1) - else: - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window) + cos, sin) if past_key_value is not None: - # Activate slicing cache only if the config has a value - # `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if (getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - 'past key must have a shape of (`batch_size, num_heads, ' - 'self.config.sliding_window-1, head_dim`), got' - f' {past_key.shape}') - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat( - [attention_mask, - torch.ones_like(attention_mask[:, -1:])], - dim=-1) - - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed + # for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) + # different from MistralAttention.forward # repeat k/v heads if n_kv_heads < n_heads for sequence parallel key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons therefore the input hidden states gets silently - # casted in float32. Hence, we need cast them back in the correct dtype - # just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not - # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) enable_sequence_parallel = ( dist.is_initialized() and get_sequence_parallel_world_size() > 1 and self.training) if enable_sequence_parallel: - query_states, key_states, value_states = \ - pre_process_for_sequence_parallel_attn( - query_states, key_states, value_states) - # num_heads has been changed because of sequence parallel - # `self.num_heads`` is not used in self._flash_attention_forward - # in mistral/mixtral, we are doing this to avoid some unnecessary risk - ori_num_head = self.num_heads - self.num_heads = query_states.shape[-2] - - if IS_LOW_VERSION_TRANSFORMERS: - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_length=query_states.shape[1], - dropout=dropout_rate, - use_sliding_windows=use_sliding_windows, - ) - else: - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_states.shape[1], - dropout=dropout_rate, - sliding_window=getattr(self.config, 'sliding_window', None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) - - if enable_sequence_parallel: - attn_output = post_process_for_sequence_parallel_attn(attn_output) - self.num_heads = ori_num_head - - attn_output = attn_output.reshape(bsz, q_len, - self.hidden_size).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def mistral_varlen_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -): - is_training = self.training - - message_hub = MessageHub.get_instance('varlen_attn_args') - rank = dist.get_rank() - cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') - - assert is_training == (past_key_value is None) - use_varlen_atten = (cumulative_len is not None) - - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in v4.37' - ' Please make sure use `attention_mask` instead.`') - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop('padding_mask') - bsz, q_len, _ = hidden_states.size() - assert bsz == 1, (f'If utilizing local attention, the batch size should be' - f' set to 1, but got {bsz}') - # attention_mask is set to None if no padding token in input_ids - assert attention_mask is None - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, 
self.num_heads, self.head_dim) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) - - assert _flash_supports_window_size, \ - ('The current flash attention version does not support sliding window ' - 'attention, for a more memory efficient implementation make sure ' - 'to upgrade flash-attn library.') - - kv_seq_len = key_states.shape[-3] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. ' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, - self.layer_idx) - - if use_varlen_atten: - cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, - cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - else: + # Reashape for `pre_process_for_sequence_parallel_attn` query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # Because the input can be padded, the absolute sequence length - # depends on the max position id. - rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - - # Activate slicing cache only if the config has a value - # `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if (getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window # noqa: W503 - and cache_has_contents): # noqa: W503 - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - 'past key must have a shape of (`batch_size, num_heads, ' - 'self.config.sliding_window-1, head_dim`), got' - f' {past_key.shape}') - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat( - [attention_mask, - torch.ones_like(attention_mask[:, -1:])], - dim=-1) - - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs) + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for - # training stability reasons, therefore the input hidden states gets - # silently casted in 
float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != 'eager': + if self.config._attn_implementation == 'sdpa' and kwargs.get( + 'output_attentions', False): + warnings.warn( + '`torch.nn.functional.scaled_dot_product_attention` does not ' + 'support `output_attentions=True`. Falling back to eager ' + 'attention. This warning can be removed using the argument' + ' `attn_implementation="eager"` when loading the model.') else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation] - # ----------------- flash attention forward ------------------------# - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - causal = self.is_causal and q_len != 1 - - use_sliding_windows = ( - _flash_supports_window_size and # noqa: W504 - getattr(self.config, 'sliding_window', None) is not None # noqa: W503 - and kv_seq_len > self.config.sliding_window) # noqa: W503 - window_size = (self.config.sliding_window, - self.config.sliding_window) if use_sliding_windows else (-1, - -1) + message_hub = MessageHub.get_instance('varlen_attn_args') + rank = dist.get_rank() + cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') + use_varlen_atten = (cumulative_len is not None) if use_varlen_atten: - attn_output = varlen_flash_attn( - query_states, - key_states, - value_states, - cumulative_len, - max_seqlen, - causal=causal, - dropout_p=dropout_rate, - window_size=window_size, - training=self.training) - else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal=causal, - dropout_p=dropout_rate, - window_size=window_size, - training=self.training) - - # ---------------- flash attention forward end ------------------- # + # When gradient_checkpointing is enabled, the flash_attn_kwargs + # parameter is not automatically passed to the model. In such + # cases, parameters like cu_seq_lens_q and max_length_q are + # computed based on position_ids. However, when sequence + # parallel is enabled, position_ids is split along the + # sequence length, leading to incorrect calculations of these + # parameters. + # To address this issue, it is necessary to manually provide + # the flash_attn_kwargs parameters. 
+ max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') + kwargs['cu_seq_lens_q'] = cumulative_len + kwargs['cu_seq_lens_k'] = cumulative_len + kwargs['max_length_q'] = max_seqlen + kwargs['max_length_k'] = max_seqlen + kwargs.pop('position_ids', None) + + # Hacky: `sdpa_attention_forward` does repeat_kv based on + # module.num_key_value_groups but it is done before + num_key_value_groups = self.num_key_value_groups + self.num_key_value_groups = 1 + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=getattr(self.config, 'sliding_window', + None), # main diff with Llama + **kwargs, + ) + self.num_key_value_groups = num_key_value_groups + + # different from MistralAttention.forward + if enable_sequence_parallel: + attn_output = post_process_for_sequence_parallel_attn(attn_output) - attn_output = attn_output.reshape(bsz, q_len, - self.hidden_size).contiguous() + attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights diff --git a/xtuner/model/modules/dispatch/qwen2.py b/xtuner/model/modules/dispatch/qwen2.py index 20f2f40f3..179a3aba4 100644 --- a/xtuner/model/modules/dispatch/qwen2.py +++ b/xtuner/model/modules/dispatch/qwen2.py @@ -1,380 +1,140 @@ # Copyright (c) OpenMMLab. All rights reserved. -import inspect import warnings -from typing import Optional +from typing import Callable, Optional, Tuple import torch import torch.distributed as dist -import transformers from mmengine import MessageHub -from mmengine.utils import digit_version from transformers.cache_utils import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.qwen2.modeling_qwen2 import (apply_rotary_pos_emb, + eager_attention_forward, repeat_kv) +from transformers.processing_utils import Unpack from xtuner.parallel.sequence import get_sequence_parallel_world_size from xtuner.parallel.sequence.attention import ( post_process_for_sequence_parallel_attn, pre_process_for_sequence_parallel_attn) -from .attention import flash_attn_wo_mask, varlen_flash_attn - -SUPPORT_FLASH2 = False - -try: - from flash_attn import flash_attn_func - _flash_supports_window_size = 'window_size' in list( - inspect.signature(flash_attn_func).parameters) - SUPPORT_FLASH2 = True -except ImportError: - pass - -TRANSFORMERS_VERSION = digit_version(transformers.__version__) -IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43') - -if not IS_LOW_VERSION_TRANSFORMERS: - from transformers.modeling_flash_attention_utils import \ - _flash_attention_forward +# modified from transformers.model.qwen2.modeling_qwen2.Qwen2Attention.forward and # noqa: E501 +# support sequence parallel def qwen2_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], ): - if 
'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in ' - 'v4.37. Please make sure use `attention_mask` instead.`') - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop('padding_mask') - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. ' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, - self.layer_idx) + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) - assert position_ids is not None - rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose( + 1, 2) + cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window) + cos, sin) if past_key_value is not None: - # Activate slicing cache only if the config has a value - # `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if (getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - 'past key must have a shape of (`batch_size, num_heads, ' - 'self.config.sliding_window-1, head_dim`), got' - f' {past_key.shape}') - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat( - [attention_mask, - torch.ones_like(attention_mask[:, -1:])], - dim=-1) - - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models + # sin and cos are specific to RoPE models; cache_position needed + # for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) + # different from Qwen2Attention.forward # repeat k/v heads if n_kv_heads < n_heads for sequence parallel key_states = repeat_kv(key_states, 
self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
-    dropout_rate = 0.0 if not self.training else self.attention_dropout
-
-    # In PEFT, usually we cast the layer norms in float32 for training
-    # stability reasons therefore the input hidden states gets silently
-    # casted in float32. Hence, we need cast them back in the correct dtype
-    # just to be sure everything works as expected.
-    # This might slowdown training & inference so it is recommended to not
-    # cast the LayerNorms in fp32.
-    input_dtype = query_states.dtype
-    if input_dtype == torch.float32:
-        if torch.is_autocast_enabled():
-            target_dtype = torch.get_autocast_gpu_dtype()
-        # Handle the case where the model is quantized
-        elif hasattr(self.config, '_pre_quantization_dtype'):
-            target_dtype = self.config._pre_quantization_dtype
-        else:
-            target_dtype = self.q_proj.weight.dtype
-
-        query_states = query_states.to(target_dtype)
-        key_states = key_states.to(target_dtype)
-        value_states = value_states.to(target_dtype)
-
-    # Reashape to the expected shape for Flash Attention
-    query_states = query_states.transpose(1, 2)
-    key_states = key_states.transpose(1, 2)
-    value_states = value_states.transpose(1, 2)
 
     enable_sequence_parallel = (
         dist.is_initialized() and get_sequence_parallel_world_size() > 1
         and self.training)
     if enable_sequence_parallel:
+        # Reshape for `pre_process_for_sequence_parallel_attn`
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
-        # num_heads has been changed because of sequence parallel
-        # `self.num_heads`` is not used in self._flash_attention_forward
-        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
-        ori_num_head = self.num_heads
-        self.num_heads = query_states.shape[-2]
-
-    if IS_LOW_VERSION_TRANSFORMERS:
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            query_length=query_states.shape[1],
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
-        )
-    else:
-        if (self.config.use_sliding_window
-                and getattr(self.config, 'sliding_window', None) is not None
-                and self.layer_idx >= self.config.max_window_layers):
-            # There may be bugs here, but we are aligned with Transformers
-            sliding_window = self.config.sliding_window
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+    sliding_window = None
+    if (self.config.use_sliding_window
+            and getattr(self.config, 'sliding_window', None) is not None
+            and self.layer_idx >= self.config.max_window_layers):
+        sliding_window = self.config.sliding_window
+
+    attention_interface: Callable = eager_attention_forward
+    if self.config._attn_implementation != 'eager':
+        if self.config._attn_implementation == 'sdpa' and kwargs.get(
+                'output_attentions', False):
+            warnings.warn(
+                '`torch.nn.functional.scaled_dot_product_attention` does not '
+                'support `output_attentions=True`. Falling back to eager '
+                'attention. 
This warning can be removed using the argument' + ' `attn_implementation="eager"` when loading the model.') else: - sliding_window = None - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_states.shape[1], - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - if enable_sequence_parallel: - attn_output = post_process_for_sequence_parallel_attn(attn_output) - self.num_heads = ori_num_head - - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -def qwen2_varlen_attn_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, -): - is_training = self.training + attention_interface = ALL_ATTENTION_FUNCTIONS[ + self.config._attn_implementation] message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') - - assert is_training == (past_key_value is None) use_varlen_atten = (cumulative_len is not None) - - if 'padding_mask' in kwargs: - warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in v4.37' - ' Please make sure use `attention_mask` instead.`') - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop('padding_mask') - bsz, q_len, _ = hidden_states.size() - - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) - - query_states = query_states.view(bsz, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim).transpose(1, 2) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. 
' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, - self.layer_idx) - - assert position_ids is not None - rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) - - if past_key_value is not None: - # Activate slicing cache only if the config has a value - # `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - if (getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window - and cache_has_contents): - slicing_tokens = 1 - self.config.sliding_window - - past_key = past_key_value[self.layer_idx][0] - past_value = past_key_value[self.layer_idx][1] - - past_key = past_key[:, :, slicing_tokens:, :].contiguous() - past_value = past_value[:, :, slicing_tokens:, :].contiguous() - - if past_key.shape[-2] != self.config.sliding_window - 1: - raise ValueError( - 'past key must have a shape of (`batch_size, num_heads, ' - 'self.config.sliding_window-1, head_dim`), got' - f' {past_key.shape}') - - if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat( - [attention_mask, - torch.ones_like(attention_mask[:, -1:])], - dim=-1) - - cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs) - - # repeat k/v heads if n_kv_heads < n_heads for sequence parallel - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - dropout_rate = 0.0 if not self.training else self.attention_dropout - - # In PEFT, usually we cast the layer norms in float32 for - # training stability reasons, therefore the input hidden states gets - # silently casted in float32. Hence, we need - # cast them back in float16 just to be sure everything works as expected. - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, '_pre_quantization_dtype'): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - # Reashape to the expected shape for Flash Attention - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - # ----------------- flash attention forward ------------------------# - - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - causal = self.is_causal and q_len != 1 - - use_sliding_windows = ( - _flash_supports_window_size - and getattr(self.config, 'sliding_window', None) is not None - and kv_seq_len > self.config.sliding_window - and self.config.use_sliding_window) - # Decide whether to use SWA or not by layer index. 
-    if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
-        use_sliding_windows = False
-
-    window_size = (self.config.sliding_window,
-                   self.config.sliding_window) if use_sliding_windows else (-1,
-                                                                            -1)
-
     if use_varlen_atten:
-        attn_output = varlen_flash_attn(
-            query_states,
-            key_states,
-            value_states,
-            cumulative_len,
-            max_seqlen,
-            causal=causal,
-            dropout_p=dropout_rate,
-            window_size=window_size,
-            training=self.training)
-    else:
-        attn_output = flash_attn_wo_mask(
-            query_states,
-            key_states,
-            value_states,
-            causal=causal,
-            dropout_p=dropout_rate,
-            window_size=window_size,
-            training=self.training)
-
-    # ---------------- flash attention forward end ------------------- #
+        # When gradient_checkpointing is enabled, the flash_attn_kwargs
+        # parameter is not automatically passed to the model. In such
+        # cases, parameters like cu_seq_lens_q and max_length_q are
+        # computed based on position_ids. However, when sequence
+        # parallel is enabled, position_ids is split along the
+        # sequence length, leading to incorrect calculations of these
+        # parameters.
+        # To address this issue, it is necessary to manually provide
+        # the flash_attn_kwargs parameters.
+        max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
+        kwargs['cu_seq_lens_q'] = cumulative_len
+        kwargs['cu_seq_lens_k'] = cumulative_len
+        kwargs['max_length_q'] = max_seqlen
+        kwargs['max_length_k'] = max_seqlen
+        kwargs.pop('position_ids', None)
+
+    # Hacky: `sdpa_attention_forward` repeats k/v via num_key_value_groups,
+    # but repeat_kv was already applied above, so temporarily set it to 1.
+    num_key_value_groups = self.num_key_value_groups
+    self.num_key_value_groups = 1
+    attn_output, attn_weights = attention_interface(
+        self,
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        dropout=0.0 if not self.training else self.attention_dropout,
+        scaling=self.scaling,
+        sliding_window=sliding_window,  # main diff with Llama
+        **kwargs,
+    )
+    self.num_key_value_groups = num_key_value_groups
+
+    # different from Qwen2Attention.forward
+    if enable_sequence_parallel:
+        attn_output = post_process_for_sequence_parallel_attn(attn_output)
 
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
     attn_output = self.o_proj(attn_output)
-
-    if not output_attentions:
-        attn_weights = None
-
-    return attn_output, attn_weights, past_key_value
+    return attn_output, attn_weights
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index 743a6ba83..f58e17c43 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -186,12 +186,13 @@ def _prepare_for_long_context_training(cfg, llm_cfg,
     @staticmethod
     def _prepare_for_flash_attn(cfg, llm_cfg):
         cls_name = type(llm_cfg).__name__
-        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
-                             'MixtralConfig', 'Qwen2Config', 'Qwen2MoeConfig',
-                             'Starcoder2Config', 'Starcoder2Config',
-                             'Phi3Config')
-        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
-                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
+        SUPPORT_SDPA_ATTN = ('InternLM3Config', 'LlamaConfig', 'GemmaConfig',
+                             'MistralConfig', 'MixtralConfig', 'Qwen2Config',
+                             'Qwen2MoeConfig', 'Starcoder2Config',
+                             'Starcoder2Config', 'Phi3Config')
+        SUPPORT_FLASH_ATTN2 = ('InternLM3Config', 'InternLM2Config',
+                               'LlamaConfig', 'GemmaConfig', 'MistralConfig',
+                               'MixtralConfig', 'Qwen2Config',
                                'Qwen2MoeConfig', 'Starcoder2Config',
                                'Starcoder2Config', 'Phi3Config',
                                'DeepseekV2Config')
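
The dispatched forwards above read the packed-sequence (varlen) arguments from the mmengine MessageHub instance named `varlen_attn_args`, keyed per rank, and pass them on as flash-attn kwargs (`cu_seq_lens_q/k`, `max_length_q/k`). A minimal sketch of the producer side, assuming a hypothetical `publish_varlen_args` helper and an illustrative list of packed sample lengths (xtuner's data pipeline supplies the real values):

import torch
import torch.distributed as dist
from mmengine import MessageHub


def publish_varlen_args(seq_lens, device='cpu'):
    """Hypothetical helper: store per-rank cumulative sequence lengths and
    the max length so the dispatched attention forwards can pick them up."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    # flash-attn expects int32 cumulative lengths, e.g. [0, 512, 1536, 3584]
    cumulative_len = torch.cumsum(
        torch.tensor([0] + list(seq_lens), device=device), dim=0).int()
    message_hub = MessageHub.get_instance('varlen_attn_args')
    message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
    message_hub.update_info(f'max_seqlen_rank_{rank}', max(seq_lens))


# e.g. three samples of lengths 512, 1024 and 2048 packed into one sequence;
# in real training the tensor lives on the current CUDA device.
publish_varlen_args([512, 1024, 2048])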
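
The `num_key_value_groups` save-and-restore around `attention_interface` exists because `repeat_kv` has already been applied to the k/v states earlier in these forwards, while backends such as `sdpa_attention_forward` repeat k/v again based on `module.num_key_value_groups`. A small shape check (illustrative only, not part of the patch) shows that a group count of 1 turns that second repeat into a no-op:

import torch
from transformers.models.qwen2.modeling_qwen2 import repeat_kv

# (batch, num_kv_heads, seq_len, head_dim) with 2 query heads per kv head
k = torch.randn(2, 4, 16, 64)
k_rep = repeat_kv(k, 2)  # already done in the dispatched forward -> (2, 8, 16, 64)
assert repeat_kv(k_rep, 1).shape == (2, 8, 16, 64)   # n_rep=1: no extra repeat
assert repeat_kv(k_rep, 2).shape == (2, 16, 16, 64)  # double repeat if the group count were kept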