From 0b2ce4b5bbb1f2ed346aa1512fe73bfb1dc9b968 Mon Sep 17 00:00:00 2001
From: HIT-cwh <2892770585@qq.com>
Date: Mon, 4 Nov 2024 07:13:56 +0000
Subject: [PATCH] qwen support transformers>=4.46

---
 xtuner/model/modules/dispatch/__init__.py |   2 +-
 xtuner/model/modules/dispatch/qwen2.py    | 296 ++++++++--------
 2 files changed, 108 insertions(+), 190 deletions(-)

diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index e81ec7a3a..1f0260ab9 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -43,7 +43,7 @@
     # Refer to https://github.com/microsoft/DeepSpeed/issues/5066
     MixtralForCausalLM=digit_version('4.40'),
     CohereForCausalLM=digit_version('4.40'),
-    Qwen2ForCausalLM=digit_version('4.39'),
+    Qwen2ForCausalLM=digit_version('4.46'),
     Qwen2MoeForCausalLM=digit_version('4.40'),
     DeepseekV2ForCausalLM=digit_version('4.40'),
 )
diff --git a/xtuner/model/modules/dispatch/qwen2.py b/xtuner/model/modules/dispatch/qwen2.py
index 20f2f40f3..1baacfc14 100644
--- a/xtuner/model/modules/dispatch/qwen2.py
+++ b/xtuner/model/modules/dispatch/qwen2.py
@@ -1,13 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
 import warnings
-from typing import Optional
+from typing import Optional, Tuple
 
 import torch
 import torch.distributed as dist
-import transformers
 from mmengine import MessageHub
-from mmengine.utils import digit_version
 from transformers.cache_utils import Cache
 from transformers.models.qwen2.modeling_qwen2 import (apply_rotary_pos_emb,
                                                       repeat_kv)
@@ -28,31 +26,28 @@
 except ImportError:
     pass
 
-TRANSFORMERS_VERSION = digit_version(transformers.__version__)
-IS_LOW_VERSION_TRANSFORMERS = TRANSFORMERS_VERSION < digit_version('4.43')
-
-if not IS_LOW_VERSION_TRANSFORMERS:
+try:
     from transformers.modeling_flash_attention_utils import \
         _flash_attention_forward
+except ImportError:
+    _flash_attention_forward = None
 
 
+# Modified from https://github.com/huggingface/transformers/blob/v4.46.1/src/transformers/models/qwen2/modeling_qwen2.py#L364  # noqa: E501
+# and sequence parallel is supported.
 def qwen2_attn_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            Tuple[torch.Tensor,
+                  torch.Tensor]] = None,  # will become mandatory in v4.46
 ):
-    if 'padding_mask' in kwargs:
-        warnings.warn(
-            'Passing `padding_mask` is deprecated and will be removed in '
-            'v4.37. Please make sure use `attention_mask` instead.`')
-
-        # overwrite attention_mask with padding_mask
-        attention_mask = kwargs.pop('padding_mask')
     bsz, q_len, _ = hidden_states.size()
 
     query_states = self.q_proj(hidden_states)
@@ -66,74 +61,38 @@ def qwen2_attn_forward(
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                      self.head_dim).transpose(1, 2)
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        if self.layer_idx is None:
-            raise ValueError(
-                'The cache structure has changed since version v4.36. '
-                f'If you are using {self.__class__.__name__} '
-                'for auto-regressive decoding with k/v caching, '
-                'please make sure to initialize the attention class '
-                'with a layer index.')
-        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
-                                                       self.layer_idx)
-
-    assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
-    cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
+    if position_embeddings is None:
+        warnings.warn(
+            'The attention layers in this model are transitioning from'
+            ' computing the RoPE embeddings internally through `position_ids` '
+            '(2D tensor with the indexes of the tokens), to using externally '
+            'computed `position_embeddings` (Tuple of tensors, containing cos '
+            'and sin). In v4.46 `position_ids` will be removed and '
+            '`position_embeddings` will be mandatory.')
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids)
-
-    use_sliding_windows = (
-        _flash_supports_window_size
-        and getattr(self.config, 'sliding_window', None) is not None
-        and kv_seq_len > self.config.sliding_window
-        and self.config.use_sliding_window)
+                                                    cos, sin)
 
     if past_key_value is not None:
-        # Activate slicing cache only if the config has a value
-        # `sliding_windows` attribute
-        cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-        if (getattr(self.config, 'sliding_window', None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents):
-            slicing_tokens = 1 - self.config.sliding_window
-
-            past_key = past_key_value[self.layer_idx][0]
-            past_value = past_key_value[self.layer_idx][1]
-
-            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-            past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-            if past_key.shape[-2] != self.config.sliding_window - 1:
-                raise ValueError(
-                    'past key must have a shape of (`batch_size, num_heads, '
-                    'self.config.sliding_window-1, head_dim`), got'
-                    f' {past_key.shape}')
-
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, slicing_tokens:]
-                attention_mask = torch.cat(
-                    [attention_mask,
-                     torch.ones_like(attention_mask[:, -1:])],
-                    dim=-1)
-
-        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
+        cache_kwargs = {
+            'sin': sin,
+            'cos': cos,
+            'cache_position': cache_position
+        }  # Specific to RoPE models
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_idx, cache_kwargs)
 
-    # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
+    # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
     dropout_rate = 0.0 if not self.training else self.attention_dropout
 
-    # In PEFT, usually we cast the layer norms in float32 for training
-    # stability reasons therefore the input hidden states gets silently
-    # casted in float32. Hence, we need cast them back in the correct dtype
-    # just to be sure everything works as expected.
-    # This might slowdown training & inference so it is recommended to not
-    # cast the LayerNorms in fp32.
+    # In PEFT, usually we cast the layer norms in float32 for
+    # training stability reasons, therefore the input hidden states gets
+    # silently casted in float32. Hence, we need
+    # cast them back in float16 just to be sure everything works as expected.
     input_dtype = query_states.dtype
     if input_dtype == torch.float32:
         if torch.is_autocast_enabled():
@@ -144,6 +103,12 @@
         else:
             target_dtype = self.q_proj.weight.dtype
 
+        warnings.warn(
+            f'The input hidden states seems to be silently casted in float32,'
+            ' this might be related to the fact you have upcasted embedding '
+            'or layer norm layers in float32. We will cast back the input in'
+            f' {target_dtype}.')
+
         query_states = query_states.to(target_dtype)
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
@@ -160,47 +125,35 @@ def qwen2_attn_forward(
         query_states, key_states, value_states = \
             pre_process_for_sequence_parallel_attn(
                 query_states, key_states, value_states)
-        # num_heads has been changed because of sequence parallel
-        # `self.num_heads`` is not used in self._flash_attention_forward
-        # in mistral/mixtral, we are doing this to avoid some unnecessary risk
-        ori_num_head = self.num_heads
-        self.num_heads = query_states.shape[-2]
-
-    if IS_LOW_VERSION_TRANSFORMERS:
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            query_length=query_states.shape[1],
-            dropout=dropout_rate,
-            use_sliding_windows=use_sliding_windows,
-        )
+
+    if (self.config.use_sliding_window
+            and getattr(self.config, 'sliding_window', None) is not None
+            and self.layer_idx >= self.config.max_window_layers):
+        sliding_window = self.config.sliding_window
     else:
-        if (self.config.use_sliding_window
-                and getattr(self.config, 'sliding_window', None) is not None
-                and self.layer_idx >= self.config.max_window_layers):
-            # There may be bugs here, but we are aligned with Transformers
-            sliding_window = self.config.sliding_window
-        else:
-            sliding_window = None
-        attn_output = _flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            query_states.shape[1],
-            dropout=dropout_rate,
-            sliding_window=sliding_window,
-            is_causal=self.is_causal,
-            use_top_left_mask=self._flash_attn_uses_top_left_mask,
-        )
+        sliding_window = None
+
+    if _flash_attention_forward is None:
+        raise RuntimeError('Please install Transformers >= 4.46.1.')
+
+    attn_output = _flash_attention_forward(
+        query_states,
+        key_states,
+        value_states,
+        attention_mask,
+        query_states.shape[1],
+        position_ids=position_ids,
+        dropout=dropout_rate,
+        sliding_window=sliding_window,
+        is_causal=self.is_causal,
+        use_top_left_mask=self._flash_attn_uses_top_left_mask,
+    )
 
     if enable_sequence_parallel:
         attn_output = post_process_for_sequence_parallel_attn(attn_output)
-        self.num_heads = ori_num_head
 
-    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+    attn_output = attn_output.reshape(bsz, q_len,
+                                      self.hidden_size).contiguous()
     attn_output = self.o_proj(attn_output)
 
     if not output_attentions:
@@ -210,32 +163,24 @@ def qwen2_attn_forward(
 
 
 def qwen2_varlen_attn_forward(
-    self,
-    hidden_states: torch.Tensor,
-    attention_mask: Optional[torch.Tensor] = None,
-    position_ids: Optional[torch.LongTensor] = None,
-    past_key_value: Optional[Cache] = None,
-    output_attentions: bool = False,
-    use_cache: bool = False,
-    **kwargs,
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            Tuple[torch.Tensor,
+                  torch.Tensor]] = None,  # will become mandatory in v4.46
 ):
-    is_training = self.training
-
     message_hub = MessageHub.get_instance('varlen_attn_args')
     rank = dist.get_rank()
     cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
     max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
-
-    assert is_training == (past_key_value is None)
     use_varlen_atten = (cumulative_len is not None)
 
-    if 'padding_mask' in kwargs:
-        warnings.warn(
-            'Passing `padding_mask` is deprecated and will be removed in v4.37'
-            ' Please make sure use `attention_mask` instead.`')
-
-        # overwrite attention_mask with padding_mask
-        attention_mask = kwargs.pop('padding_mask')
     bsz, q_len, _ = hidden_states.size()
 
     query_states = self.q_proj(hidden_states)
@@ -249,58 +194,30 @@ def qwen2_varlen_attn_forward(
     value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                      self.head_dim).transpose(1, 2)
 
-    kv_seq_len = key_states.shape[-2]
-    if past_key_value is not None:
-        if self.layer_idx is None:
-            raise ValueError(
-                'The cache structure has changed since version v4.36. '
-                f'If you are using {self.__class__.__name__} '
-                'for auto-regressive decoding with k/v caching, '
-                'please make sure to initialize the attention class '
-                'with a layer index.')
-        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
-                                                       self.layer_idx)
-
-    assert position_ids is not None
-    rotary_seq_len = max(kv_seq_len, position_ids.max().item() + 1)
-    cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
-
+    if position_embeddings is None:
+        warnings.warn(
+            'The attention layers in this model are transitioning from'
+            ' computing the RoPE embeddings internally through `position_ids` '
+            '(2D tensor with the indexes of the tokens), to using externally '
+            'computed `position_embeddings` (Tuple of tensors, containing cos '
+            'and sin). In v4.46 `position_ids` will be removed and '
+            '`position_embeddings` will be mandatory.')
+        cos, sin = self.rotary_emb(value_states, position_ids)
+    else:
+        cos, sin = position_embeddings
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
-                                                    cos, sin, position_ids)
+                                                    cos, sin)
 
     if past_key_value is not None:
-        # Activate slicing cache only if the config has a value
-        # `sliding_windows` attribute
-        cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
-        if (getattr(self.config, 'sliding_window', None) is not None
-                and kv_seq_len > self.config.sliding_window
-                and cache_has_contents):
-            slicing_tokens = 1 - self.config.sliding_window
-
-            past_key = past_key_value[self.layer_idx][0]
-            past_value = past_key_value[self.layer_idx][1]
-
-            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
-            past_value = past_value[:, :, slicing_tokens:, :].contiguous()
-
-            if past_key.shape[-2] != self.config.sliding_window - 1:
-                raise ValueError(
-                    'past key must have a shape of (`batch_size, num_heads, '
-                    'self.config.sliding_window-1, head_dim`), got'
-                    f' {past_key.shape}')
-
-            if attention_mask is not None:
-                attention_mask = attention_mask[:, slicing_tokens:]
-                attention_mask = torch.cat(
-                    [attention_mask,
-                     torch.ones_like(attention_mask[:, -1:])],
-                    dim=-1)
-
-        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
+        cache_kwargs = {
+            'sin': sin,
+            'cos': cos,
+            'cache_position': cache_position
+        }  # Specific to RoPE models
         key_states, value_states = past_key_value.update(
             key_states, value_states, self.layer_idx, cache_kwargs)
 
-    # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
+    # repeat k/v heads if n_kv_heads < n_heads
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
     dropout_rate = 0.0 if not self.training else self.attention_dropout
@@ -319,6 +236,12 @@
         else:
             target_dtype = self.q_proj.weight.dtype
 
+        warnings.warn(
+            f'The input hidden states seems to be silently casted in float32,'
+            ' this might be related to the fact you have upcasted embedding '
+            'or layer norm layers in float32. We will cast back the input in'
+            f' {target_dtype}.')
+
         query_states = query_states.to(target_dtype)
         key_states = key_states.to(target_dtype)
         value_states = value_states.to(target_dtype)
@@ -330,15 +253,10 @@ def qwen2_varlen_attn_forward(
 
     # ----------------- flash attention forward ------------------------#
 
-    if not self._flash_attn_uses_top_left_mask:
-        causal = self.is_causal
-    else:
-        causal = self.is_causal and q_len != 1
-
     use_sliding_windows = (
         _flash_supports_window_size
         and getattr(self.config, 'sliding_window', None) is not None
-        and kv_seq_len > self.config.sliding_window
+        and key_states.shape[1] > self.config.sliding_window
        and self.config.use_sliding_window)
     # Decide whether to use SWA or not by layer index.
     if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
@@ -355,7 +273,7 @@ def qwen2_varlen_attn_forward(
             value_states,
             cumulative_len,
             max_seqlen,
-            causal=causal,
+            causal=self.is_causal,
             dropout_p=dropout_rate,
             window_size=window_size,
             training=self.training)
@@ -364,7 +282,7 @@ def qwen2_varlen_attn_forward(
             query_states,
             key_states,
             value_states,
-            causal=causal,
+            causal=self.is_causal,
             dropout_p=dropout_rate,
             window_size=window_size,
             training=self.training)
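Note on how this change is gated: the `__init__.py` hunk only raises the minimum transformers version recorded for `Qwen2ForCausalLM`; the rewritten forwards are swapped in by xtuner's dispatch code elsewhere. Below is a minimal, hypothetical sketch of such a version gate, not xtuner's actual dispatch implementation; `QWEN2_MIN_TRANSFORMERS` and `can_dispatch_qwen2` are illustrative names, while `digit_version` (from mmengine) and `transformers.__version__` are real.

    import transformers
    from mmengine.utils import digit_version

    # Assumed minimum, mirroring the value set in the __init__.py hunk above.
    QWEN2_MIN_TRANSFORMERS = '4.46'


    def can_dispatch_qwen2() -> bool:
        # The rewritten forwards rely on `_flash_attention_forward` plus the
        # `cache_position` / `position_embeddings` arguments, which this patch
        # assumes are available from transformers 4.46 onwards.
        installed = digit_version(transformers.__version__)
        return installed >= digit_version(QWEN2_MIN_TRANSFORMERS)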