Commit f1ebe54

[shardformer] update bloom model (#5518)
* update bloom model
* remove the version restriction
1 parent: cdb166c

File tree: 2 files changed (+15, -30 lines)


colossalai/shardformer/modeling/bloom.py (+15, -24)
@@ -21,7 +21,7 @@
     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config.enable_sequence_parallelism:
@@ -226,21 +227,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -1000,11 +995,13 @@ def forward(
 
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1016,21 +1013,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
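
The modeling change above builds the causal mask with the shared helper from transformers.modeling_attn_mask_utils instead of BloomModel's private _prepare_attn_mask. A minimal sketch of the new call, assuming transformers >= 4.36 and purely illustrative tensor sizes (none of this is taken from the commit):

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

# Illustrative sizes only.
batch_size, seq_length, hidden_size, past_key_values_length = 2, 8, 64, 0
attention_mask = torch.ones(batch_size, seq_length + past_key_values_length)
hidden_states = torch.randn(batch_size, seq_length, hidden_size)

# Returns an additive float mask of shape
# [batch_size, 1, seq_length, seq_length + past_key_values_length].
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_key_values_length,
)
# The diff then casts the additive mask to bool before handing it to the Bloom blocks.
causal_mask = causal_mask.bool()
print(causal_mask.shape)  # torch.Size([2, 1, 8, 8])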

colossalai/shardformer/policies/bloom.py (-6)
@@ -23,12 +23,6 @@
 class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
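
Dropping this assertion removes the transformers <= 4.33.0 pin, in line with the commit message ("remove the version restriction"): the modeling file now calls _prepare_4d_causal_attention_mask and self._gradient_checkpointing_func, both of which were added to transformers after 4.33. A minimal sketch of the checkpointing hook the updated forward relies on, assuming transformers >= 4.36 and an arbitrary tiny config that is not part of this commit:

import torch
from transformers import BloomConfig, BloomModel

# Tiny configuration purely for illustration.
config = BloomConfig(hidden_size=64, n_head=4, n_layer=2)
model = BloomModel(config)

# gradient_checkpointing_enable() installs self._gradient_checkpointing_func on the
# model, which is why the manual create_custom_forward wrapper around
# torch.utils.checkpoint.checkpoint could be deleted in the modeling diff above.
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.randint(0, config.vocab_size, (2, 8))
output = model(input_ids)
print(hasattr(model, "_gradient_checkpointing_func"))  # True on recent transformers
print(output.last_hidden_state.shape)                  # torch.Size([2, 8, 64])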
