
Commit 10a815a

flybird11111 and pre-commit-ci[bot] committed Apr 24, 2024
[shardformer] update flashattention replacement (hpcaitech#5637)
* update transformers
  update transformers
  fix
  fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f3be14d commit 10a815a

File tree: 7 files changed (+63 -19 lines)
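The five policy diffs below all apply the same pattern: `preprocess` now records `config._attn_implementation`, and `module_policy` maps that string to whichever attention class transformers actually instantiated, so the flash-attention forward replacement is registered against `attn_cls` instead of a hard-coded eager class. A minimal standalone sketch of that resolution step, using the Llama classes named in the diff (the `resolve_attn_cls` helper is illustrative and not part of the patch; it assumes a transformers release that still ships the per-backend attention classes, roughly 4.36-4.47):

# Sketch of the attention-class resolution pattern used by the updated policies.
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaFlashAttention2,
    LlamaSdpaAttention,
)

# Keys mirror the values transformers writes to config._attn_implementation.
ATTN_IMPLEMENTATION = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}


def resolve_attn_cls(config):
    # The chosen backend is recorded on the config when the model is built,
    # so looking it up yields the class that is actually in the module tree.
    return ATTN_IMPLEMENTATION[config._attn_implementation]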

 

colossalai/shardformer/policies/gpt2.py (+8 -1)

@@ -35,13 +35,20 @@ def preprocess(self):
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
         from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
 
+        ATTN_IMPLEMENTATION = {
+            "eager": GPT2Attention,
+        }
+
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -186,7 +193,7 @@ def module_policy(self):
                     "forward": get_gpt2_flash_attention_forward(),
                 },
                 policy=policy,
-                target_key=GPT2Attention,
+                target_key=attn_cls,
             )
         if not self.shard_config.pipeline_stage_manager:
             policy[GPT2Model].method_replacement = {

colossalai/shardformer/policies/gptj.py (+8 -1)

@@ -30,13 +30,20 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
         from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
 
+        ATTN_IMPLEMENTATION = {
+            "eager": GPTJAttention,
+        }
+
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = col_nn.VocabParallelEmbedding1D
@@ -160,7 +167,7 @@ def module_policy(self):
                     "forward": get_gptj_flash_attention_forward(),
                 },
                 policy=policy,
-                target_key=GPTJAttention,
+                target_key=attn_cls,
             )
         if not self.shard_config.pipeline_stage_manager:
             self.append_or_create_method_replacement(

colossalai/shardformer/policies/llama.py (+19 -5)

@@ -36,13 +36,27 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+        from transformers.models.llama.modeling_llama import (
+            LlamaAttention,
+            LlamaDecoderLayer,
+            LlamaFlashAttention2,
+            LlamaModel,
+            LlamaSdpaAttention,
+        )
 
+        ATTN_IMPLEMENTATION = {
+            "eager": LlamaAttention,
+            "flash_attention_2": LlamaFlashAttention2,
+            "sdpa": LlamaSdpaAttention,
+        }
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -93,7 +107,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         elif sp_mode == "all_to_all":
             decoder_attribute_replacement = {
@@ -102,15 +116,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             if getattr(self.model.config, "num_key_value_heads", False):
                 decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
 
-            policy[LlamaAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={
@@ -221,7 +235,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         if self.pipeline_stage_manager is None:
             # replace llama model forward method
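With `attn_cls` resolved from the config, the Llama policy now targets the attention class that is actually present in the module tree even when transformers builds the model with the SDPA or flash-attention backend. A hedged usage sketch (the tiny config values are illustrative; it assumes transformers ~4.36-4.47, where the backend-specific classes exist, and a torch build with scaled-dot-product attention):

# Illustrative check of which attention class the config selects; the policy's
# ATTN_IMPLEMENTATION lookup follows the same string.
from transformers import LlamaConfig, LlamaForCausalLM

config = LlamaConfig(
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
model = LlamaForCausalLM(config)
print(config._attn_implementation)            # "sdpa"
print(type(model.model.layers[0].self_attn))  # LlamaSdpaAttention on these versions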

colossalai/shardformer/policies/mistral.py (+16 -7)

@@ -26,13 +26,26 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.mistral.modeling_mistral import MistralAttention, MistralDecoderLayer, MistralModel
+        from transformers.models.mistral.modeling_mistral import (
+            MistralAttention,
+            MistralDecoderLayer,
+            MistralFlashAttention2,
+            MistralModel,
+        )
+
+        ATTN_IMPLEMENTATION = {
+            "eager": MistralAttention,
+            "flash_attention_2": MistralFlashAttention2,
+        }
 
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -128,10 +141,10 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         if self.shard_config.enable_flash_attention:
             self.append_or_create_method_replacement(
                 description={
-                    "forward": get_mistral_flash_attention_forward(),
+                    "forward": get_mistral_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=MistralAttention,
+                target_key=attn_cls,
             )
 
         return policy
@@ -143,10 +156,6 @@ def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict)
         method_replacement = {"forward": partial(new_forward)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
 
-    def set_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
-        method_replacement = {"forward": partial(new_forward)}
-        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
-
 
 class MistralModelPolicy(MistralPolicy):
     def __init__(self) -> None:

colossalai/shardformer/policies/opt.py (+11 -3)

@@ -44,13 +44,21 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self):
-        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
+        from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer, OptFlashAttention2
+
+        ATTN_IMPLEMENTATION = {
+            "eager": OPTAttention,
+            "flash_attention_2": OptFlashAttention2,
+        }
 
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.model.config._attn_implementation]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -81,7 +89,7 @@ def module_policy(self):
                 ]
             )
 
-            policy[OPTAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement={
                     "embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                     "num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -151,7 +159,7 @@ def module_policy(self):
                     "forward": get_opt_flash_attention_forward(self.shard_config),
                 },
                 policy=policy,
-                target_key=OPTAttention,
+                target_key=attn_cls,
             )
         if not self.shard_config.pipeline_stage_manager:
             self.append_or_create_method_replacement(

tests/kit/model_zoo/transformers/llama.py (-1)

@@ -65,7 +65,6 @@ def data_gen_for_casual_lm():
     num_attention_heads=4,
     max_position_embeddings=128,
     num_labels=16,
-    attn_implementation="eager",
 )
 
 if hasattr(config, "pad_token_id"):

tests/test_shardformer/test_model/test_shard_mistral.py (+1 -1)

@@ -156,7 +156,7 @@ def check_mistral(rank, world_size, port):
     run_mistral_test()
 
 
-@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
+@pytest.mark.skip("something wrong with pipeline parallelism")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
