
Commit 7227e81

[shardformer] update bloom model (#5518)
* update bloom model
* remove the version restriction
1 parent e2ff589 commit 7227e81

File tree

2 files changed, +15 -30 lines changed


colossalai/shardformer/modeling/bloom.py: +15 -24
@@ -21,7 +21,7 @@
     BloomModel,
 )
 from transformers.utils import logging
-
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
 from colossalai.shardformer.shard import ShardConfig
@@ -205,12 +205,13 @@ def bloom_model_forward(
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
         # causal_mask is constructed every stage and its input is passed through different stages
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
-
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         if shard_config and shard_config.enable_sequence_parallelism:
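
For context, a minimal sketch (not part of the commit) of what the replacement helper does, assuming a transformers release that ships transformers.modeling_attn_mask_utils (roughly 4.35+): _prepare_4d_causal_attention_mask returns an additive float mask (0.0 where attention is allowed, dtype-min where it is blocked), whereas the removed self._prepare_attn_mask produced a boolean mask with True at blocked positions, which is the convention the surrounding BLOOM code keeps using, hence the added .bool() conversion. Shapes below are illustrative only.

```python
# Minimal sketch, not from the commit: the mask returned by the new helper
# and why the diff converts it with .bool().
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, hidden_size = 2, 4, 8
attention_mask = torch.ones(batch_size, seq_length)               # 1 = token is visible
hidden_states = torch.zeros(batch_size, seq_length, hidden_size)  # stands in for inputs_embeds

# Additive float mask of shape [batch, 1, seq, seq]:
# 0.0 where attention is allowed, dtype-min where it is blocked.
additive_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=0,
)

# .bool() recovers the True-means-blocked boolean convention of the old
# self._prepare_attn_mask: dtype-min entries become True, zeros become False.
causal_mask = additive_mask.bool()
print(causal_mask.shape, causal_mask[0, 0])
```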
@@ -227,21 +228,15 @@ def bloom_model_forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
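
The second change drops the hand-rolled create_custom_forward closure around torch.utils.checkpoint.checkpoint in favour of self._gradient_checkpointing_func, the callable that recent transformers releases attach to the model when gradient_checkpointing_enable() is used. A minimal, self-contained sketch of the pattern follows; the tiny modules are illustrative stand-ins, not code from transformers or ColossalAI, and the partial(...) line only approximates what transformers installs.

```python
# Minimal sketch, not from the commit: the _gradient_checkpointing_func pattern.
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    def forward(self, hidden_states, scale):
        return torch.tanh(hidden_states * scale)


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = TinyBlock()
        # Stand-in for what gradient_checkpointing_enable() sets up in recent transformers.
        self._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

    def forward(self, hidden_states, scale):
        # New style: pass the layer's __call__ and its positional args directly,
        # instead of wrapping them in a create_custom_forward closure.
        return self._gradient_checkpointing_func(self.block.__call__, hidden_states, scale)


x = torch.randn(2, 4, requires_grad=True)
out = TinyModel()(x, 2.0)
out.sum().backward()   # gradients flow through the checkpointed block
print(x.grad.shape)
```

Note that use_cache and output_attentions, which the old closure captured from the enclosing scope, are now passed explicitly as positional arguments to the checkpointed call.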
@@ -1002,11 +997,13 @@ def forward(
 
         alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
 
-        causal_mask = self._prepare_attn_mask(
+        causal_mask = _prepare_4d_causal_attention_mask(
             attention_mask,
             input_shape=(batch_size, seq_length),
+            inputs_embeds=hidden_states,
             past_key_values_length=past_key_values_length,
         )
+        causal_mask = causal_mask.bool()
         # split the input tensor along sequence dimension
         # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
         hidden_states = split_forward_gather_backward(
@@ -1018,21 +1015,15 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                    return custom_forward
-
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self._gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
                     head_mask[i],
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(

colossalai/shardformer/policies/bloom.py: -6
@@ -24,12 +24,6 @@
 class BloomPolicy(Policy):
     def __init__(self) -> None:
         super().__init__()
-        import transformers
-        from packaging.version import Version
-
-        assert Version(transformers.__version__) <= Version(
-            "4.33.0"
-        ), "The Bloom model should run on a transformers version not greater than 4.33.0."
 
     def config_sanity_check(self):
         pass
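
With the modeling code ported to the newer mask and checkpointing APIs, the upper-bound pin on transformers is removed. If a guard in the other direction were wanted, a hypothetical check could look like the sketch below; the 4.36.0 floor is an assumption for illustration, since the commit only drops the <= 4.33.0 ceiling and does not state a minimum version.

```python
# Hypothetical sketch, not part of the commit: verify that the installed
# transformers provides the API the updated BLOOM path now imports, instead
# of enforcing the removed "<= 4.33.0" ceiling.
import transformers
from packaging.version import Version


def check_transformers_for_bloom(min_version: str = "4.36.0") -> None:
    # "4.36.0" is an illustrative assumption; the commit does not pin a minimum.
    if Version(transformers.__version__) < Version(min_version):
        raise RuntimeError(
            f"transformers {transformers.__version__} predates "
            f"_prepare_4d_causal_attention_mask; please upgrade to >= {min_version}."
        )
    # Or probe for the symbol directly and fail loudly if it is missing.
    from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask  # noqa: F401
```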
