     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config.enable_sequence_parallelism:
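Note for context (not part of the diff): `_prepare_4d_causal_attention_mask` comes from `transformers.modeling_attn_mask_utils` (transformers >= 4.35) and returns a 4-D float mask rather than the boolean mask the old `BloomModel._prepare_attn_mask` produced, which is presumably why the `.bool()` cast is added right after the call. A minimal sketch of the helper's behaviour, with illustrative shapes that are not taken from this file:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size, past_key_values_length = 2, 4, 8, 0
attention_mask = torch.ones(batch_size, seq_length, dtype=torch.long)  # 1 = attend to this token
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)       # values unused, only dtype/device matter

# Returns a [batch_size, 1, seq_length, seq_length + past] float mask:
# 0.0 where attention is allowed, dtype-min where it is masked.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_key_values_length,
)
print(causal_mask.shape)  # torch.Size([2, 1, 4, 4])

# Bloom's attention path in this file still consumes a boolean mask (True = masked),
# so the diff casts the float mask: dtype-min -> True, 0.0 -> False.
causal_mask = causal_mask.bool()
print(causal_mask.dtype)  # torch.bool
```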
@@ -226,21 +227,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
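Note for context (not part of the diff): transformers 4.35+ replaced the closure-based checkpointing pattern in its modeling code; `gradient_checkpointing_enable()` now installs `self._gradient_checkpointing_func` on each module (roughly a wrapper around `torch.utils.checkpoint.checkpoint`), and non-tensor flags such as `use_cache` and `output_attentions` are passed positionally to `block.__call__`. A self-contained sketch of why the two call styles compute the same thing; the `ToyBlock` and the stand-in checkpoint function are hypothetical, not from this PR:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint

class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        # Flags are accepted positionally, just like a transformer block's forward.
        return (self.linear(hidden_states),)

block = ToyBlock()
hidden_states = torch.randn(2, 4, 8, requires_grad=True)

# Old pattern (removed above): non-tensor flags captured in a closure around the block.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs, use_cache=False, output_attentions=False)
    return custom_forward

old_out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, use_reentrant=False
)

# New pattern (added above): the block's bound __call__ is the checkpointed callable and
# the flags ride along positionally, mirroring self._gradient_checkpointing_func(block.__call__, ...).
checkpointing_func = torch.utils.checkpoint.checkpoint  # stand-in for the installed function
new_out = checkpointing_func(block.__call__, hidden_states, False, False, use_reentrant=False)

assert torch.allclose(old_out[0], new_out[0])
```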
@@ -1000,11 +995,13 @@ def forward(
 
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1016,21 +1013,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(