Fix repo consistency (#36063)
* fix 1

* fix 2

* fix modular

* simplify at the same time

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Cyril Vallez <cyril.vallez@gmail.com>
3 people authored Feb 6, 2025
1 parent ed98ad3 commit 37faa97
Showing 2 changed files with 14 additions and 23 deletions.
6 changes: 1 addition & 5 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -927,11 +927,7 @@ def forward(
         hidden_states = outputs[0]
         logits = self.score(hidden_states)
 
-        if input_ids is not None:
-            batch_size, sequence_length = input_ids.shape[:2]
-        else:
-            batch_size, sequence_length = inputs_embeds.shape[:2]
-
+        batch_size = logits.shape[0]
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
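For context, a minimal standalone sketch of why this simplification is safe; the tensor shapes and the Linear stand-in for self.score below are made up for illustration and are not taken from the repository. The score head maps hidden states of shape (batch, seq, hidden) to logits of shape (batch, seq, num_labels), so logits.shape[0] already carries the batch size and no branch on input_ids vs. inputs_embeds is needed.

import torch

# Hypothetical shapes, chosen only to illustrate the point.
batch_size, seq_len, hidden_size, num_labels = 3, 7, 16, 4
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
score = torch.nn.Linear(hidden_size, num_labels, bias=False)  # stand-in for the model's score head
logits = score(hidden_states)

# The leading dimension of logits is the batch size, whether the forward pass
# started from input_ids or from inputs_embeds.
assert logits.shape[0] == batch_size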
31 changes: 13 additions & 18 deletions src/transformers/models/gpt_neox/modular_gpt_neox.py
@@ -625,29 +625,24 @@ def forward(
         hidden_states = outputs[0]
         logits = self.score(hidden_states)
 
-        if input_ids is not None:
-            batch_size, sequence_length = input_ids.shape[:2]
-        else:
-            batch_size, sequence_length = inputs_embeds.shape[:2]
-
+        batch_size = logits.shape[0]
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
 
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
 
         loss = None
         if labels is not None:
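For readers following the new pooling logic, here is a small self-contained sketch of how the rightmost non-pad token is located. The pad_token_id value and the padded batch below are hypothetical; only the indexing trick mirrors the diff.

import torch

pad_token_id = 0  # hypothetical value, for illustration only
input_ids = torch.tensor(
    [
        [0, 0, 5, 6, 7],  # left-padded sequence
        [8, 9, 1, 0, 0],  # right-padded sequence
    ]
)
num_labels = 3
logits = torch.randn(input_ids.shape[0], input_ids.shape[1], num_labels)

# Rightmost token that is not the pad token: multiplying positions by the
# non-pad mask zeroes out padded positions, so argmax returns the last real token.
non_pad_mask = (input_ids != pad_token_id).to(logits.device, torch.int32)
token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([4, 2])

# Pool exactly one logit vector per batch element.
batch_size = logits.shape[0]
pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
print(pooled_logits.shape)  # torch.Size([2, 3])

Taking the rightmost non-pad position works for both left- and right-padded batches, which is the rationale stated in the comment the diff adds.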
