 )
 from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
 from transformers.utils import logging
+from transformers.cache_utils import Cache

 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import cross_entropy_1d
 from ..layer._operation import gather_forward_split_backward

-try:
-    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
-
-    LATEST_VERSION = True
-except ImportError:
-    LATEST_VERSION = False
+from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa


 class LlamaPipelineForwards:
@@ -62,13 +58,13 @@ def llama_model_forward(
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
+                batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape
+                batch_size, seq_length, _ = inputs_embeds.shape[:2]
             else:
-                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
@@ -100,22 +96,23 @@ def llama_model_forward(
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
-
-        # embed positions, for the first stage, hidden_states is the input embeddings,
-        # for the other stages, hidden_states is the output of the previous stage
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
-        if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            position_ids = position_ids.unsqueeze(0)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
         else:
-            attention_mask = self._prepare_decoder_attention_mask(
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
             )

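Side note (not part of the diff): a minimal sketch, assuming transformers >= 4.36 where these helpers are defined in `transformers.modeling_attn_mask_utils`, of what `_prepare_4d_causal_attention_mask` produces from a 2D padding mask. The eager path above relies on this 4D additive mask; the SDPA variant may instead return `None` and let PyTorch apply its own causal masking.

```python
# Illustrative only: expand a 2D padding mask into the (bsz, 1, q_len, kv_len)
# additive causal mask consumed by the eager attention path.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_key_values_length = 2, 4, 0
inputs_embeds = torch.zeros(batch_size, seq_length, 8)          # only dtype/device are read
attention_mask_2d = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])  # 1 = attend, 0 = padding

mask_4d = _prepare_4d_causal_attention_mask(
    attention_mask_2d, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 4]); masked positions hold the dtype's minimum value
```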
@@ -129,45 +126,38 @@ def llama_model_forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None

         start_idx, end_idx = stage_index[0], stage_index[1]
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)

-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
             if self.gradient_checkpointing and self.training:

-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                 )

             hidden_states = layer_outputs[0]

             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

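Side note (not part of the diff): `self._gradient_checkpointing_func` is the hook installed by `PreTrainedModel.gradient_checkpointing_enable()` in recent transformers releases; it wraps `torch.utils.checkpoint.checkpoint`, which is why the layer is now called positionally with `decoder_layer.__call__`. A minimal, self-contained sketch of the same call pattern; the toy `layer` function and the `use_reentrant=False` choice are assumptions for illustration, not the library's exact defaults.

```python
# Illustrative only: mimic the _gradient_checkpointing_func call pattern with a toy layer.
from functools import partial

import torch

def layer(hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache):
    # Stand-in for decoder_layer.__call__: returns a tuple, like the real decoder layer.
    return (hidden_states * 2.0,)

# Roughly what gradient_checkpointing_enable() wires up under the hood.
_gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

hidden_states = torch.randn(2, 4, 8, requires_grad=True)
layer_outputs = _gradient_checkpointing_func(layer, hidden_states, None, None, None, False, False)
layer_outputs[0].sum().backward()  # activations are recomputed during backward instead of stored
```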
@@ -438,11 +428,15 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         bsz, q_len, _ = hidden_states.size()
         assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."

@@ -452,23 +446,23 @@ def forward(

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-            past_key_value = (key_states, value_states) if use_cache else None
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # repeat k/v heads if n_kv_heads < n_heads
-        if llama_version == 2:
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)

         me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
         query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
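Side note (not part of the diff): the tuple-based KV cache removed above is replaced by the `Cache` objects introduced in transformers 4.36. A minimal sketch of the two calls the new code relies on, `get_usable_length` and `update`, using `DynamicCache` (the default concrete implementation):

```python
# Illustrative only: the Cache API used in place of (key, value) tuples.
import torch
from transformers.cache_utils import DynamicCache

past_key_value = DynamicCache()
layer_idx = 0
bsz, num_heads, seq_len, head_dim = 1, 4, 3, 8

key_states = torch.randn(bsz, num_heads, seq_len, head_dim)
value_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Nothing is cached for this layer yet, so only the new tokens count.
kv_seq_len = seq_len + past_key_value.get_usable_length(seq_len, layer_idx)  # 3

# update() appends the new states for this layer and returns the full tensors.
key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
print(key_states.shape, kv_seq_len)  # torch.Size([1, 4, 3, 8]) 3
```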