
Commit 0d0a582

wangbluo, ver217, pre-commit-ci[bot], Camille7777, and Edenzzzz authored
[shardformer] update transformers (#5583)
* flash_attention forward upgrade
* llama_model_forward
* remove useless comment
* update the requirements.txt
* add the transformers version requirements
* remove the LATEST VERSION try
* [shardformer] update bloom model (#5518)
* update bloom model
* remove the version restriction
* [shardformer] update falcon (#5520)
* [shardformer] update mistral model (#5511)
* [shardformer] update gpt2 (#5502)
* [shardformer] update gptj model (#5503)
* [shardformer] update opt (#5522)
* [shardformer] update t5 model (#5524)
* [shardformer] update whisper model (#5529)
* [shardformer] update vit model (#5530)
* update vit model
* remove the output_hidden_states
* [shardformer] fix llama modeling
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [zero] support multiple (partial) backward passes (#5596)
* [zero] support multiple (partial) backward passes
* [misc] update requirements
* fix conflicts
* [doc] fix ColossalMoE readme (#5599)
* fix readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* merge with main
* llama_model_forward
* remove useless comment
* remove the LATEST VERSION try
* [shardformer] update bloom model (#5518)
* update bloom model
* remove the version restriction
* [shardformer] update mistral model (#5511)
* [shardformer] update opt (#5522)
* [shardformer] update whisper model (#5529)
* [shardformer] fix llama modeling
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* [hotfix] Fix examples no pad token & auto parallel codegen bug (#5606)
* fix no pad token bug
* fixed some auto parallel codegen bugs, but might not run on torch 2.1
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [shardformer] fix pipeline grad ckpt (#5620)
* [shardformer] fix pipeline grad ckpt
* [shardformer] fix whisper (#5628)
* [test] fix llama model test
* fix the opt upgrade (#5634)
* [shardformer] fix attn replacement (#5636)
* [shardformer] update flashattention replacement (#5637)
* update transformers; fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [test] fix llama test (#5638)
* [gemini] fix buffer cast (#5639)
* Fix shardformer upgrade (#5640)
* fix llama model
* fix the mistral
* fix the shardformer model
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [shardformer] support pipeline parallelism for mistral (#5642)
* [shardformer] fix attn replacement (#5636)
* [shardformer] update flashattention replacement (#5637)
* update transformers; fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Support LLaMA-3 CPT and ST (#5619)
* support LLaMA-3
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Run pre-commit
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [example] update llama example (#5626)
* [plugin] support dp inside for hybrid parallel
* [example] update llama benchmark
* [example] update llama readme
* [example] llama3 (#5631)
* release llama3
* [release] llama3
* [test] fix llama test (#5638)
* [gemini] fix buffer cast (#5639)
* support pp for mistral
* fix
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
---------
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Camille Zhong <44392324+Camille7777@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
1 parent f4c5aaf commit 0d0a582

27 files changed: +1153 −439 lines

colossalai/shardformer/modeling/bloom.py

+15 −23
@@ -6,6 +6,7 @@
 from torch.distributed import ProcessGroup
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -205,12 +206,13 @@ def bloom_model_forward(
     alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

     # causal_mask is constructed every stage and its input is passed through different stages
-    causal_mask = self._prepare_attn_mask(
+    causal_mask = _prepare_4d_causal_attention_mask(
         attention_mask,
         input_shape=(batch_size, seq_length),
+        inputs_embeds=hidden_states,
         past_key_values_length=past_key_values_length,
     )
-
+    causal_mask = causal_mask.bool()
     # split the input tensor along sequence dimension
     # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
     if shard_config and shard_config.enable_sequence_parallelism:
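The hunk above swaps BLOOM's removed private helper `self._prepare_attn_mask` for the public `_prepare_4d_causal_attention_mask` utility that newer transformers releases expose in `transformers.modeling_attn_mask_utils`. As a minimal sketch outside the diff (assuming a transformers version that ships this helper): the helper returns an additive 4D float mask, zeros where attention is allowed and dtype-min where it is blocked, which is why the diff adds a `.bool()` conversion before handing it to BLOOM's `masked_fill`-based attention.

# Standalone illustration (not part of the commit) of the new mask helper.
# Assumes a transformers release that provides _prepare_4d_causal_attention_mask.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask

batch_size, seq_length, past_key_values_length, hidden = 2, 5, 0, 8
attention_mask = torch.ones(batch_size, seq_length)          # 1 = token is attendable
hidden_states = torch.zeros(batch_size, seq_length, hidden)  # only dtype/device are used here

# Additive float mask of shape [batch, 1, seq_length, past + seq_length]:
# 0.0 where attention is allowed, dtype-min where it is blocked.
causal_mask = _prepare_4d_causal_attention_mask(
    attention_mask,
    input_shape=(batch_size, seq_length),
    inputs_embeds=hidden_states,
    past_key_values_length=past_key_values_length,
)

# BLOOM's attention blocks positions via masked_fill, so the diff converts to bool:
# True = blocked position, False = keep.
causal_mask = causal_mask.bool()
print(causal_mask.shape)  # torch.Size([2, 1, 5, 5])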
@@ -227,21 +229,15 @@ def bloom_model_forward(
             all_hidden_states = all_hidden_states + (hidden_states,)

         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                return custom_forward
-
-            outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(block),
+            outputs = self._gradient_checkpointing_func(
+                block.__call__,
                 hidden_states,
                 alibi,
                 causal_mask,
                 layer_past,
                 head_mask[i],
+                use_cache,
+                output_attentions,
             )
         else:
             outputs = block(
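The checkpointing hunk drops the hand-rolled `create_custom_forward` closure: in recent transformers, `PreTrainedModel.gradient_checkpointing_enable()` installs `self._gradient_checkpointing_func` (by default a `functools.partial` around `torch.utils.checkpoint.checkpoint`), so `use_cache` and `output_attentions` are simply passed through as positional arguments. A minimal sketch of the same pattern on a toy module; the `Block` and `TinyModel` classes below are hypothetical, not Colossal-AI or transformers code.

# Toy reproduction (hypothetical Block/TinyModel) of the checkpointing pattern
# the diff switches to; transformers' gradient_checkpointing_enable() installs
# self._gradient_checkpointing_func roughly like the partial below.
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        # Mirror HF layers: return a tuple whose first element is hidden_states.
        return (torch.relu(self.proj(hidden_states)),)


class TinyModel(nn.Module):
    def __init__(self, hidden: int = 8, layers: int = 2):
        super().__init__()
        self.h = nn.ModuleList([Block(hidden) for _ in range(layers)])
        self.gradient_checkpointing = False
        # Roughly what transformers installs on gradient_checkpointing_enable():
        self._gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

    def forward(self, hidden_states, use_cache=False, output_attentions=False):
        for block in self.h:
            if self.gradient_checkpointing and self.training:
                # Same calling convention as the diff: the layer's __call__ plus
                # its positional arguments, no create_custom_forward wrapper.
                outputs = self._gradient_checkpointing_func(
                    block.__call__, hidden_states, use_cache, output_attentions
                )
            else:
                outputs = block(hidden_states, use_cache, output_attentions)
            hidden_states = outputs[0]
        return hidden_states


model = TinyModel()
model.gradient_checkpointing = True
model.train()
x = torch.randn(2, 4, 8, requires_grad=True)
model(x).sum().backward()  # each Block's forward is recomputed during backward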
@@ -1002,11 +998,13 @@ def forward(

     alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)

-    causal_mask = self._prepare_attn_mask(
+    causal_mask = _prepare_4d_causal_attention_mask(
         attention_mask,
         input_shape=(batch_size, seq_length),
+        inputs_embeds=hidden_states,
         past_key_values_length=past_key_values_length,
     )
+    causal_mask = causal_mask.bool()
     # split the input tensor along sequence dimension
     # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
     hidden_states = split_forward_gather_backward(
@@ -1018,21 +1016,15 @@ def forward(
             all_hidden_states = all_hidden_states + (hidden_states,)

         if self.gradient_checkpointing and self.training:
-
-            def create_custom_forward(module):
-                def custom_forward(*inputs):
-                    # None for past_key_value
-                    return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-
-                return custom_forward
-
-            outputs = torch.utils.checkpoint.checkpoint(
-                create_custom_forward(block),
+            outputs = self._gradient_checkpointing_func(
+                block.__call__,
                 hidden_states,
                 alibi,
                 causal_mask,
                 layer_past,
                 head_mask[i],
+                use_cache,
+                output_attentions,
             )
         else:
             outputs = block(
