 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -205,12 +206,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
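
For reference, and not part of the diff: a minimal sketch of what the new helper returns and why the result is cast with `.bool()`. The helper is the one imported above (available in transformers >= 4.35); the batch/sequence/hidden sizes below are made up purely for illustration.

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 5, 8
attention_mask = torch.ones(batch_size, seq_length)                # 1 = token is visible
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)   # stands in for inputs_embeds

# Returns a float mask of shape [batch, 1, seq_length, past_len + seq_length]:
# 0.0 where attention is allowed, dtype-min where it is blocked.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=0,
)

# Bloom's attention expects a boolean mask with True meaning "do not attend",
# so the float mask is cast: dtype-min entries become True, zeros become False.
causal_mask = causal_mask.bool()
print(causal_mask.shape)  # torch.Size([2, 1, 5, 5])

The removed `self._prepare_attn_mask` produced this boolean form directly, which is why the diff adds the cast right after the helper call.
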
@@ -227,21 +229,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
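
Again for reference, not part of the diff: a self-contained sketch of the checkpointing pattern this hunk switches to. In transformers >= 4.35, `gradient_checkpointing_enable()` installs `self._gradient_checkpointing_func` (essentially a partial of `torch.utils.checkpoint.checkpoint`), so per-layer flags like `use_cache` and `output_attentions` are passed positionally instead of being captured by the removed `create_custom_forward` closure. The tiny block, shapes, and the direct `functools.partial` assignment below are illustrative assumptions, not the library's exact code.

import functools
import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    def __init__(self, hidden_size: int = 8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        # Extra flags are accepted positionally, matching the updated call in the diff.
        return self.proj(hidden_states)


block = TinyBlock()
# transformers stores a partial of torch.utils.checkpoint.checkpoint on the model;
# we mimic that assignment here (use_reentrant=False just keeps the sketch simple,
# the library picks its own default).
_gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)

hidden_states = torch.randn(2, 5, 8, requires_grad=True)
# Equivalent in spirit to the old create_custom_forward wrapper: the block is
# re-run during backward instead of storing its intermediate activations.
out = _gradient_checkpointing_func(block.__call__, hidden_states, False, False)
out.sum().backward()
print(hidden_states.grad.shape)  # torch.Size([2, 5, 8])

The same substitution is applied to the sequence-parallel `forward` in the hunks below.
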
@@ -1002,11 +998,13 @@ def forward(
 
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1016,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(