@@ -36,13 +36,27 @@ def config_sanity_check(self):
 
     def preprocess(self):
         self.tie_weight = self.tie_weight_check()
+        self.origin_attn_implement = self.model.config._attn_implementation
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel
+        from transformers.models.llama.modeling_llama import (
+            LlamaAttention,
+            LlamaDecoderLayer,
+            LlamaFlashAttention2,
+            LlamaModel,
+            LlamaSdpaAttention,
+        )
 
+        ATTN_IMPLEMENTATION = {
+            "eager": LlamaAttention,
+            "flash_attention_2": LlamaFlashAttention2,
+            "sdpa": LlamaSdpaAttention,
+        }
         policy = {}
 
+        attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
+
         embedding_cls = None
         if self.shard_config.enable_tensor_parallelism:
             embedding_cls = VocabParallelEmbedding1D
@@ -93,7 +107,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
         elif sp_mode == "all_to_all":
             decoder_attribute_replacement = {
@@ -102,15 +116,15 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             if getattr(self.model.config, "num_key_value_heads", False):
                 decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
 
-            policy[LlamaAttention] = ModulePolicyDescription(
+            policy[attn_cls] = ModulePolicyDescription(
                 attribute_replacement=decoder_attribute_replacement,
             )
             self.append_or_create_method_replacement(
                 description={
                     "forward": get_llama_seq_parallel_attention_forward(sp_mode, sp_size, sp_group),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             self.append_or_create_method_replacement(
                 description={
@@ -221,7 +235,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     "forward": get_llama_flash_attention_forward(self.shard_config, sp_mode, sp_group, sp_size),
                 },
                 policy=policy,
-                target_key=LlamaAttention,
+                target_key=attn_cls,
             )
             if self.pipeline_stage_manager is None:
                 # replace llama model forward method
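
For context, the diff replaces the hard-coded `LlamaAttention` target with `attn_cls`, which is looked up from `config._attn_implementation` ("eager", "flash_attention_2", or "sdpa"), so the policy's forward replacements land on whichever attention class HuggingFace actually instantiated. Below is a minimal, self-contained sketch of that dispatch pattern; the stub classes and the `resolve_attn_cls` helper are illustrative stand-ins, not part of transformers or ColossalAI.

```python
# Sketch of the attention-implementation dispatch used in the patch above.
# The three classes are stand-ins for the real ones in
# transformers.models.llama.modeling_llama; only the lookup logic matters here.

class LlamaAttention:  # stand-in for the eager implementation
    pass

class LlamaFlashAttention2(LlamaAttention):  # stand-in for the FlashAttention-2 variant
    pass

class LlamaSdpaAttention(LlamaAttention):  # stand-in for the SDPA variant
    pass

# Same mapping shape as the diff's ATTN_IMPLEMENTATION dict.
ATTN_IMPLEMENTATION = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}

def resolve_attn_cls(attn_implementation: str) -> type:
    """Map the config string to the attention class used as target_key."""
    try:
        return ATTN_IMPLEMENTATION[attn_implementation]
    except KeyError as err:
        raise ValueError(f"unsupported attention implementation: {attn_implementation!r}") from err

if __name__ == "__main__":
    # A model loaded with attn_implementation="sdpa" resolves to LlamaSdpaAttention,
    # so method replacements patch the module that actually runs in the forward pass.
    print(resolve_attn_cls("sdpa").__name__)
```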