Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shardformer] fix attn replacement #5636

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions colossalai/shardformer/policies/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@

import colossalai.shardformer.layer as col_nn

from ..modeling.falcon import (
FalconPipelineForwards,
build_falcon_alibi_tensor_fn,
get_falcon_flash_attention_forward,
get_tp_falcon_decoder_layer_forward,
)
from ..modeling.falcon import FalconPipelineForwards, build_falcon_alibi_tensor_fn, get_tp_falcon_decoder_layer_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["FalconPolicy"]
Expand All @@ -30,7 +25,7 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.falcon.modeling_falcon import FalconAttention, FalconDecoderLayer, FalconModel
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel

if not self.model.config.new_decoder_architecture and self.model.config.multi_query:
warnings.warn(
Expand Down Expand Up @@ -141,11 +136,12 @@ def module_policy(self):
)

if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={"forward": get_falcon_flash_attention_forward()},
policy=policy,
target_key=FalconAttention,
)
warnings.warn("Falcon doesn't support flash attention now, fallback to transformers attention.")
# self.append_or_create_method_replacement(
# description={"forward": get_falcon_flash_attention_forward()},
# policy=policy,
# target_key=FalconAttention,
# )
return policy

def postprocess(self):
Expand Down
34 changes: 18 additions & 16 deletions colossalai/shardformer/policies/sam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings

import colossalai.shardformer.layer as col_nn

from ..modeling.sam import forward_fn, get_sam_flash_attention_forward, get_sam_vision_flash_attention_forward
from ..modeling.sam import forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["SamPolicy", "SamModelPolicy"]
Expand All @@ -15,7 +17,6 @@ def preprocess(self):

def module_policy(self):
from transformers.models.sam.modeling_sam import (
SamAttention,
SamTwoWayAttentionBlock,
SamTwoWayTransformer,
SamVisionAttention,
Expand Down Expand Up @@ -210,20 +211,21 @@ def module_policy(self):

# use flash attention
if self.shard_config.enable_flash_attention:
self.append_or_create_method_replacement(
description={
"forward": get_sam_flash_attention_forward(),
},
policy=policy,
target_key=SamAttention,
)
self.append_or_create_method_replacement(
description={
"forward": get_sam_vision_flash_attention_forward(),
},
policy=policy,
target_key=SamVisionAttention,
)
warnings.warn("Flash attention is not supported in SAM model. Fallback to normal attention.")
# self.append_or_create_method_replacement(
# description={
# "forward": get_sam_flash_attention_forward(),
# },
# policy=policy,
# target_key=SamAttention,
# )
# self.append_or_create_method_replacement(
# description={
# "forward": get_sam_vision_flash_attention_forward(),
# },
# policy=policy,
# target_key=SamVisionAttention,
# )

return policy

Expand Down
16 changes: 16 additions & 0 deletions colossalai/shardformer/policies/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def module_policy(self):
WhisperDecoderLayer,
WhisperEncoder,
WhisperEncoderLayer,
WhisperFlashAttention2,
WhisperSdpaAttention,
)

policy = {}
Expand Down Expand Up @@ -242,6 +244,20 @@ def module_policy(self):
policy=policy,
target_key=WhisperAttention,
)
self.append_or_create_method_replacement(
description={
"forward": get_whisper_flash_attention_forward(),
},
policy=policy,
target_key=WhisperFlashAttention2,
)
self.append_or_create_method_replacement(
description={
"forward": get_whisper_flash_attention_forward(),
},
policy=policy,
target_key=WhisperSdpaAttention,
)
if not self.shard_config.pipeline_stage_manager:
self.append_or_create_method_replacement(
description={
Expand Down
1 change: 0 additions & 1 deletion tests/kit/model_zoo/transformers/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def data_gen_for_audio_classification():
encoder_ffn_dim=1536,
encoder_layers=2,
vocab_size=51866,
_attn_implementation="eager",
)

# register the Whisper variants
Expand Down
Loading