
Commit cdb166c

Merge pull request #5499 from wangbluo/update_llama2
Update llama2
2 parents 5fcd779 + 9206dd1 commit cdb166c

3 files changed: +46 -52 lines changed

colossalai/shardformer/modeling/llama.py (+45 -51)

@@ -11,19 +11,15 @@
 )
 from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
 from transformers.utils import logging
+from transformers.cache_utils import Cache
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig
 
 from ..layer import cross_entropy_1d
 from ..layer._operation import gather_forward_split_backward
 
-try:
-    from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
-
-    LATEST_VERSION = True
-except ImportError:
-    LATEST_VERSION = False
+from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
 
 
 class LlamaPipelineForwards:
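
Note, not part of the diff: with the try/except fallback removed, this module now hard-requires a transformers release that ships both mask helpers (the 4.36.0 pin added to requirements.txt below). A minimal sketch of an explicit guard a caller could add; the exact message is illustrative:

import transformers
from packaging import version

# _prepare_4d_causal_attention_mask_for_sdpa first appeared around transformers 4.36;
# the previously pinned 4.33.0 does not provide it.
if version.parse(transformers.__version__) < version.parse("4.36.0"):
    raise RuntimeError(f"transformers>=4.36.0 is required, found {transformers.__version__}")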
@@ -62,13 +58,13 @@ def llama_model_forward(
         # retrieve input_ids and inputs_embeds
         if stage_manager.is_first_stage():
             if input_ids is not None and inputs_embeds is not None:
-                raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+                raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
             elif input_ids is not None:
-                batch_size, seq_length = input_ids.shape
+                batch_size, seq_length = input_ids.shape[:2]
             elif inputs_embeds is not None:
-                batch_size, seq_length, _ = inputs_embeds.shape
+                batch_size, seq_length, _ = inputs_embeds.shape[:2]
             else:
-                raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+                raise ValueError("You have to specify either input_ids or inputs_embeds")
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             if inputs_embeds is None:
                 inputs_embeds = self.embed_tokens(input_ids)
@@ -100,22 +96,23 @@ def llama_model_forward(
             position_ids = torch.arange(
                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
             )
-            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
-
-        # embed positions, for the first stage, hidden_states is the input embeddings,
-        # for the other stages, hidden_states is the output of the previous stage
-        if attention_mask is None:
-            attention_mask = torch.ones(
-                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
-            )
-        if LATEST_VERSION:
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+            position_ids = position_ids.unsqueeze(0)
+
+        if self._use_flash_attention_2:
+            # 2d mask is passed through the layers
+            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        elif self._use_sdpa and not output_attentions:
+            # output_attentions=True can not be supported when using SDPA, and we fall back on
+            # the manual implementation that requires a 4D causal mask in all cases.
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask,
+                (batch_size, seq_length),
+                inputs_embeds,
+                past_key_values_length,
             )
         else:
-            attention_mask = self._prepare_decoder_attention_mask(
+            # 4d mask is passed through the layers
+            attention_mask = _prepare_4d_causal_attention_mask(
                 attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
             )

@@ -129,45 +126,38 @@ def llama_model_forward(
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
+        next_decoder_cache = None
 
         start_idx, end_idx = stage_index[0], stage_index[1]
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
-            past_key_value = past_key_values[idx] if past_key_values is not None else None
-
             if self.gradient_checkpointing and self.training:
 
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
-                    None,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
-                    past_key_value=past_key_value,
+                    past_key_value=past_key_values,
                     output_attentions=output_attentions,
                     use_cache=use_cache,
                 )
 
             hidden_states = layer_outputs[0]
 
             if use_cache:
-                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
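
Not part of the diff: since transformers 4.35, calling gradient_checkpointing_enable() on the model attaches self._gradient_checkpointing_func to checkpointing-capable submodules, which is what the loop above now invokes instead of the hand-rolled create_custom_forward wrapper. A minimal sketch with a deliberately tiny, illustrative config:

from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=256, num_attention_heads=4)
model = LlamaForCausalLM(config)

# use_reentrant=False selects the non-reentrant torch.utils.checkpoint variant.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()  # checkpointing only kicks in during training
print(hasattr(model.model, "_gradient_checkpointing_func"))  # expected: True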

@@ -438,11 +428,15 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if "padding_mask" in kwargs:
+            warnings.warn(
+                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+            )
         bsz, q_len, _ = hidden_states.size()
         assert q_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
 
@@ -452,23 +446,23 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-            past_key_value = (key_states, value_states) if use_cache else None
+            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        if llama_version == 2:
-            key_states = repeat_kv(key_states, self.num_key_value_groups)
-            value_states = repeat_kv(value_states, self.num_key_value_groups)
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         me_input_shape = (bsz, q_len, self.num_heads, self.head_dim)
         query_states = query_states.transpose(1, 2).contiguous().view(*me_input_shape)
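
Not part of the diff: the tuple-style KV cache handled above is replaced by the transformers 4.36 Cache API (imported at the top of the file); DynamicCache is one concrete implementation shipped with the library. A minimal sketch of the two calls this hunk relies on, with illustrative shapes:

import torch
from transformers.cache_utils import DynamicCache

bsz, num_kv_heads, new_tokens, head_dim = 1, 4, 3, 8
layer_idx = 0
cache = DynamicCache()

key_states = torch.randn(bsz, num_kv_heads, new_tokens, head_dim)
value_states = torch.randn(bsz, num_kv_heads, new_tokens, head_dim)

# update() appends this layer's new keys/values and returns the full cached tensors.
k_all, v_all = cache.update(key_states, value_states, layer_idx, cache_kwargs=None)
print(k_all.shape)  # torch.Size([1, 4, 3, 8])

# get_usable_length() reports how many already-cached positions can be reused.
print(cache.get_usable_length(1, layer_idx))  # 3 after the update above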

requirements/requirements-test.txt (-1)

@@ -3,7 +3,6 @@ pytest
 coverage==7.2.3
 git+https://github.com/hpcaitech/pytest-testmon
 torchvision
-transformers==4.33.0
 timm
 titans
 torchaudio

requirements/requirements.txt (+1)

@@ -16,3 +16,4 @@ ray
 sentencepiece
 google
 protobuf
+transformers==4.36.0
