diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 1684dd6dc04a..d83ee58af5ee 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -927,11 +927,7 @@ def forward(
         hidden_states = outputs[0]
         logits = self.score(hidden_states)
 
-        if input_ids is not None:
-            batch_size, sequence_length = input_ids.shape[:2]
-        else:
-            batch_size, sequence_length = inputs_embeds.shape[:2]
-
+        batch_size = logits.shape[0]
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
diff --git a/src/transformers/models/gpt_neox/modular_gpt_neox.py b/src/transformers/models/gpt_neox/modular_gpt_neox.py
index 95bbdaa77671..295882a9eedb 100644
--- a/src/transformers/models/gpt_neox/modular_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modular_gpt_neox.py
@@ -625,29 +625,24 @@ def forward(
         hidden_states = outputs[0]
         logits = self.score(hidden_states)
 
-        if input_ids is not None:
-            batch_size, sequence_length = input_ids.shape[:2]
-        else:
-            batch_size, sequence_length = inputs_embeds.shape[:2]
-
+        batch_size = logits.shape[0]
         if self.config.pad_token_id is None and batch_size != 1:
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
-            sequence_lengths = -1
+            last_non_pad_token = -1
+        elif input_ids is not None:
+            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+            token_indices = torch.arange(input_ids.shape[-1], device=logits.device)
+            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
         else:
-            if input_ids is not None:
-                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
-                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
-                sequence_lengths = sequence_lengths % input_ids.shape[-1]
-                sequence_lengths = sequence_lengths.to(logits.device)
-            else:
-                sequence_lengths = -1
-                logger.warning_once(
-                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
-                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
-                )
+            last_non_pad_token = -1
+            logger.warning_once(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
 
-        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
 
         loss = None
         if labels is not None:
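
For reference, a minimal standalone sketch of the new pooling index computation, using toy values (the `pad_token_id = 0` and the example `input_ids` are assumptions for illustration, not values from the diff):

```python
import torch

# Assumed toy setup: one right-padded and one left-padded sequence.
pad_token_id = 0
input_ids = torch.tensor(
    [
        [5, 6, 7, 0, 0],  # right-padded: last real token at index 2
        [0, 0, 5, 6, 7],  # left-padded: last real token at index 4
    ]
)

# Same computation as the added lines above: multiplying positions by the
# non-pad mask zeroes out padded positions, so argmax returns the largest
# surviving index, i.e. the rightmost non-pad token on either padding side.
non_pad_mask = (input_ids != pad_token_id).to(torch.int32)
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)  # tensor([2, 4])
```

The removed variant located the first pad token with `torch.eq(...).argmax(-1) - 1` (plus a modulo wrap for ONNX compatibility), whereas the mask-and-argmax form selects the rightmost non-pad index directly, which is what the in-diff comment about handling both left- and right-padding refers to.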