Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shardformer]upgrade transformers for gpt2/gptj/whisper #5807

Merged
merged 3 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion colossalai/shardformer/modeling/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,10 @@ def gpt2_for_sequence_classification_forward(
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
# If no pad token is found, argmax returns 0 and the index becomes -1; wrap it with modulo (to seq_len - 1) instead of relying on negative (reverse) indexing, for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
Expand Down
20 changes: 15 additions & 5 deletions colossalai/shardformer/modeling/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def _get_attention_mask(
hidden_states: torch.Tensor,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]],
attention_mask: Optional[torch.FloatTensor],
use_flash_attention_2: bool = False,
) -> Optional[Union[torch.Tensor, dict]]:
batch_size, seq_len = hidden_states.shape[:2]
past_key_values_length = 0
Expand All @@ -47,7 +48,7 @@ def _get_attention_mask(
attention_mask,
is_causal=True,
)
elif attention_mask is not None:
elif use_flash_attention_2 and attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
Expand Down Expand Up @@ -162,7 +163,9 @@ def gptj_model_forward(

output_shape = input_shape + (hidden_states.size(-1),)

attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

if self.gradient_checkpointing and self.training:
if use_cache:
Expand Down Expand Up @@ -419,7 +422,10 @@ def gptj_for_sequence_classification_forward(
sequence_lengths = -1
else:
if input_ids is not None:
sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
# If no pad token is found, argmax returns 0 and the index becomes -1; wrap it with modulo (to seq_len - 1) instead of relying on negative (reverse) indexing, for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
logger.warning_once(
Expand Down Expand Up @@ -712,7 +718,9 @@ def forward(

hidden_states = self.drop(hidden_states)

attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

Expand Down Expand Up @@ -886,7 +894,9 @@ def forward(
hidden_states = self.drop(hidden_states)

output_shape = input_shape + (hidden_states.size(-1),)
attention_mask = _get_attention_mask(self, shard_config, hidden_states, past_key_values, attention_mask)
attention_mask = _get_attention_mask(
self, shard_config, hidden_states, past_key_values, attention_mask, self._use_flash_attention_2
)

if self.gradient_checkpointing and self.training:
if use_cache:
Expand Down
32 changes: 27 additions & 5 deletions colossalai/shardformer/modeling/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
_HIDDEN_STATES_START_POSITION,
WhisperDecoder,
WhisperEncoder,
WhisperForAudioClassification,
Expand Down Expand Up @@ -166,6 +167,7 @@ def forward(
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
position_ids=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
Expand Down Expand Up @@ -199,9 +201,13 @@ def forward(

# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
)

hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
Expand Down Expand Up @@ -599,6 +605,7 @@ def whisper_decoder_forward(
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
position_ids=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
Expand Down Expand Up @@ -716,9 +723,13 @@ def whisper_decoder_forward(

# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
input_ids, past_key_values_length=past_key_values_length, position_ids=position_ids
)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
positions = self.embed_positions(
inputs_embeds, past_key_values_length=past_key_values_length, position_ids=position_ids
)

hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
Expand Down Expand Up @@ -841,6 +852,7 @@ def whisper_model_forward(
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
Expand Down Expand Up @@ -944,6 +956,7 @@ def whisper_model_forward(
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down Expand Up @@ -986,6 +999,7 @@ def whisper_for_conditional_generation_forward(
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
Expand Down Expand Up @@ -1048,6 +1062,7 @@ def whisper_for_conditional_generation_forward(
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
decoder_position_ids=decoder_position_ids,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
Expand Down Expand Up @@ -1118,6 +1133,12 @@ def whisper_for_audio_classification_forward(
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

if self.config.use_weighted_layer_sum:
output_hidden_states = True
elif output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# The audio-classification model holds only an encoder
Expand All @@ -1138,7 +1159,8 @@ def whisper_for_audio_classification_forward(
return encoder_outputs

if self.config.use_weighted_layer_sum:
hidden_states = torch.stack(encoder_outputs, dim=1)
hidden_states = encoder_outputs[_HIDDEN_STATES_START_POSITION]
hidden_states = torch.stack(hidden_states, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
Expand Down
8 changes: 2 additions & 6 deletions colossalai/shardformer/policies/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,11 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel

ATTN_IMPLEMENTATION = {
"eager": GPTJAttention,
}
from transformers.models.gptj.modeling_gptj import GPTJ_ATTENTION_CLASSES, GPTJBlock, GPTJModel

policy = {}

attn_cls = ATTN_IMPLEMENTATION[self.origin_attn_implement]
attn_cls = GPTJ_ATTENTION_CLASSES[self.origin_attn_implement]

embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ ray
sentencepiece
google
protobuf
transformers>=4.36.2,<4.40.0
transformers==4.39.3
peft>=0.7.1
bitsandbytes>=0.39.0
rpyc==6.0.0
Expand Down
Loading