[shardformer] upgrade transformers to 4.39.3 #5815

Merged · 5 commits · Jun 14, 2024
Changes from 1 commit
[shardformer]upgrade transformers for mistral (#5808)
* upgrade transformers for mistral

* fix

* fix
flybird11111 authored Jun 12, 2024
commit c4dee4fdd1b60a3f6c2321d349672fcfff8b011b
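
For context (not part of the diff): starting with transformers 4.36, attention-backend selection moved to `config._attn_implementation`, and the `_use_flash_attention_2` flag the old code checked no longer exists in 4.39.3; PyTorch SDPA was also added as a backend, which is why the Mistral forwards below gain an `sdpa` branch. A minimal sketch of how that attribute is chosen, assuming transformers 4.39.3 (the tiny config values are made up for illustration):

```python
from transformers import AutoModelForCausalLM, MistralConfig

# `attn_implementation` is the public knob ("eager", "sdpa", "flash_attention_2");
# the rewritten checks in this commit read it back via `self._attn_implementation`.
cfg = MistralConfig(  # deliberately tiny, illustration only
    hidden_size=64, intermediate_size=128, num_hidden_layers=1,
    num_attention_heads=4, num_key_value_heads=2, vocab_size=1000,
)
model = AutoModelForCausalLM.from_config(cfg, attn_implementation="sdpa")
print(model.config._attn_implementation)  # -> "sdpa"
```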
31 changes: 26 additions & 5 deletions colossalai/shardformer/modeling/mistral.py
@@ -4,7 +4,10 @@
 import torch
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -77,7 +80,7 @@ def mistral_model_forward(
         else:
             position_ids = position_ids.view(-1, seq_length).long()

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -97,9 +100,18 @@ def mistral_model_forward(
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
@@ -462,7 +474,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

-        if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
             is_padding_right = attention_mask[:, -1].sum().item() != batch_size
             if is_padding_right:
                 raise ValueError(
@@ -481,9 +493,18 @@ def forward(
                 is_causal=True,
             )
         else:
-            if self._use_flash_attention_2:
+            if self._attn_implementation == "flash_attention_2":
                 # 2d mask is passed through the layers
                 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+            elif self._attn_implementation == "sdpa" and not output_attentions:
+                # output_attentions=True can not be supported when using SDPA, and we fall back on
+                # the manual implementation that requires a 4D causal mask in all cases.
+                attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                    attention_mask,
+                    (batch_size, seq_length),
+                    inputs_embeds,
+                    past_key_values_length,
+                )
             else:
                 # 4d mask is passed through the layers
                 attention_mask = _prepare_4d_causal_attention_mask(
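
The two mask helpers used in the branches above share a signature but behave differently: `_prepare_4d_causal_attention_mask` always materializes a 4D additive mask, while the SDPA variant can return `None` so that `torch.nn.functional.scaled_dot_product_attention` may take its fused causal path. A standalone sketch of that difference, assuming transformers 4.39.3 (shapes and the padded mask are arbitrary illustration values):

```python
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)

batch_size, seq_length, hidden = 2, 5, 8
inputs_embeds = torch.randn(batch_size, seq_length, hidden)
past_key_values_length = 0
# Second sequence is right-padded, the situation the flash-attention check above guards against.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Eager path: always a (batch, 1, q_len, kv_len) additive float mask.
eager_mask = _prepare_4d_causal_attention_mask(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(eager_mask.shape)  # torch.Size([2, 1, 5, 5])

# SDPA path: same arguments; may return None for purely causal, unpadded inputs.
sdpa_mask = _prepare_4d_causal_attention_mask_for_sdpa(
    attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
print(None if sdpa_mask is None else sdpa_mask.shape)
```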
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/mistral.py
@@ -42,11 +42,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             MistralDecoderLayer,
             MistralFlashAttention2,
             MistralModel,
+            MistralSdpaAttention,
         )

         ATTN_IMPLEMENTATION = {
             "eager": MistralAttention,
             "flash_attention_2": MistralFlashAttention2,
+            "sdpa": MistralSdpaAttention,
         }

         policy = {}
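
A quick way to sanity-check the extended mapping (a sketch, not part of the PR; assumes transformers 4.39.3 and a torch build with SDPA support): the class keyed under `"sdpa"` should be exactly what transformers instantiates for a Mistral model built with that implementation.

```python
from transformers import AutoModelForCausalLM, MistralConfig
from transformers.models.mistral.modeling_mistral import (
    MistralAttention,
    MistralFlashAttention2,
    MistralSdpaAttention,
)

# Same mapping as in the policy, reproduced here for a standalone check.
ATTN_IMPLEMENTATION = {
    "eager": MistralAttention,
    "flash_attention_2": MistralFlashAttention2,
    "sdpa": MistralSdpaAttention,
}

cfg = MistralConfig(  # tiny config, illustration only
    hidden_size=64, intermediate_size=128, num_hidden_layers=1,
    num_attention_heads=4, num_key_value_heads=2, vocab_size=1000,
)
model = AutoModelForCausalLM.from_config(cfg, attn_implementation="sdpa")
attn = model.model.layers[0].self_attn
# The instantiated attention module matches the class the mapping resolves.
assert isinstance(attn, ATTN_IMPLEMENTATION[model.config._attn_implementation])
```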