     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
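For context: `_prepare_4d_causal_attention_mask` returns an additive float mask (0.0 for visible positions, the dtype minimum for masked ones), whereas the removed `self._prepare_attn_mask` produced a boolean mask in which True marks positions to drop, which is the convention Bloom's attention consumes; the trailing `.bool()` restores that convention. A minimal sketch of the conversion, assuming a transformers release that ships `modeling_attn_mask_utils` (shapes below are illustrative only, not taken from the patch):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Illustrative sizes only.
batch_size, seq_length, hidden_size, past_key_values_length = 2, 4, 16, 0
attention_mask = torch.ones(batch_size, seq_length + past_key_values_length)
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)

# 4D additive mask: [batch_size, 1, seq_length, seq_length + past_key_values_length],
# 0.0 where attention is allowed, dtype-min where it is not.
float_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_key_values_length,
)

# dtype-min -> True (masked), 0.0 -> False (visible), matching the old boolean convention.
bool_mask = float_mask.bool()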
@@ -227,21 +228,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
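The checkpointing change mirrors the upstream transformers refactor: instead of wrapping each block in a `create_custom_forward` closure and calling `torch.utils.checkpoint.checkpoint` directly, the model now calls `self._gradient_checkpointing_func`, which `gradient_checkpointing_enable()` attaches to the module (in recent transformers versions this is, roughly, a `functools.partial` around `torch.utils.checkpoint.checkpoint`), and `use_cache` / `output_attentions` travel as ordinary positional arguments. A rough sketch of that equivalence, under those assumptions:

import functools
import torch

# Roughly what gradient_checkpointing_enable() sets up on the module in recent
# transformers releases; use_reentrant is shown only as an example of a
# configurable kwarg, not as the library default.
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

# The per-layer call in the new code then amounts to:
#   outputs = _gradient_checkpointing_func(
#       block.__call__, hidden_states, alibi, causal_mask, layer_past,
#       head_mask[i], use_cache, output_attentions,
#   )
# i.e. the flags match BloomBlock.forward's positional parameter order, so no
# closure is needed to capture them.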
@@ -1002,11 +997,13 @@ def forward(

         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1015,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(