diff --git a/docs/source/en/model_doc/splinter.mdx b/docs/source/en/model_doc/splinter.mdx index 50d4e8db74816d..9623ec75016bea 100644 --- a/docs/source/en/model_doc/splinter.mdx +++ b/docs/source/en/model_doc/splinter.mdx @@ -72,3 +72,8 @@ This model was contributed by [yuvalkirstain](https://huggingface.co/yuvalkirsta [[autodoc]] SplinterForQuestionAnswering - forward + +## SplinterForPreTraining + +[[autodoc]] SplinterForPreTraining + - forward diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index a81ae60898d511..d071411ec7f809 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1532,6 +1532,7 @@ _import_structure["models.splinter"].extend( [ "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST", + "SplinterForPreTraining", "SplinterForQuestionAnswering", "SplinterLayer", "SplinterModel", @@ -3830,6 +3831,7 @@ from .models.speech_to_text_2 import Speech2Text2ForCausalLM, Speech2Text2PreTrainedModel from .models.splinter import ( SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST, + SplinterForPreTraining, SplinterForQuestionAnswering, SplinterLayer, SplinterModel, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index b7589b98b23a8f..980c560019b040 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -161,6 +161,7 @@ ("openai-gpt", "OpenAIGPTLMHeadModel"), ("retribert", "RetriBertModel"), ("roberta", "RobertaForMaskedLM"), + ("splinter", "SplinterForPreTraining"), ("squeezebert", "SqueezeBertForMaskedLM"), ("t5", "T5ForConditionalGeneration"), ("tapas", "TapasForMaskedLM"), diff --git a/src/transformers/models/splinter/__init__.py b/src/transformers/models/splinter/__init__.py index d21e5c04c21715..9f056d7200a197 100644 --- a/src/transformers/models/splinter/__init__.py +++ b/src/transformers/models/splinter/__init__.py @@ -42,6 +42,7 @@ _import_structure["modeling_splinter"] = [ "SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST", "SplinterForQuestionAnswering", + "SplinterForPreTraining", "SplinterLayer", "SplinterModel", "SplinterPreTrainedModel", @@ -68,6 +69,7 @@ else: from .modeling_splinter import ( SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST, + SplinterForPreTraining, SplinterForQuestionAnswering, SplinterLayer, SplinterModel, diff --git a/src/transformers/models/splinter/modeling_splinter.py b/src/transformers/models/splinter/modeling_splinter.py index 0bf8411f2f76c4..ae8ba4fa34b0c7 100755 --- a/src/transformers/models/splinter/modeling_splinter.py +++ b/src/transformers/models/splinter/modeling_splinter.py @@ -16,6 +16,7 @@ import math +from dataclasses import dataclass from typing import List, Optional, Tuple, Union import torch @@ -24,7 +25,7 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN -from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, QuestionAnsweringModelOutput +from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, ModelOutput, QuestionAnsweringModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging @@ -940,3 +941,171 @@ def forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + +@dataclass +class SplinterForPreTrainingOutput(ModelOutput): + """ + Class for outputs of Splinter as a span selection model. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`): + Span-start scores (before SoftMax). + end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`): + Span-end scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + start_logits: torch.FloatTensor = None + end_logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@add_start_docstrings( + """ + Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task + is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans + instead. + """, + SPLINTER_START_DOCSTRING, +) +class SplinterForPreTraining(SplinterPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.splinter = SplinterModel(config) + self.splinter_qass = QuestionAwareSpanSelectionHead(config) + self.question_token_id = config.question_token_id + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward( + SPLINTER_INPUTS_DOCSTRING.format("batch_size, num_questions, sequence_length") + ) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + question_positions: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, SplinterForPreTrainingOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. 
+ end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*): + The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size, + num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be + the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size, + sequence_length)`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if question_positions is None and start_positions is not None and end_positions is not None: + raise TypeError("question_positions must be specified in order to calculate the loss") + + elif question_positions is None and input_ids is None: + raise TypeError("question_positions must be specified when input_embeds is used") + + elif question_positions is None: + question_positions = self._prepare_question_positions(input_ids) + + outputs = self.splinter( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + batch_size, sequence_length, dim = sequence_output.size() + # [batch_size, num_questions, sequence_length] + start_logits, end_logits = self.splinter_qass(sequence_output, question_positions) + + num_questions = question_positions.size(1) + if attention_mask is not None: + attention_mask_for_each_question = attention_mask.unsqueeze(1).expand( + batch_size, num_questions, sequence_length + ) + start_logits = start_logits + (1 - attention_mask_for_each_question) * -10000.0 + end_logits = end_logits + (1 - attention_mask_for_each_question) * -10000.0 + + total_loss = None + # [batch_size, num_questions, sequence_length] + if start_positions is not None and end_positions is not None: + # sometimes the start/end positions are outside our model inputs, we ignore these terms + start_positions.clamp_(0, max(0, sequence_length - 1)) + end_positions.clamp_(0, max(0, sequence_length - 1)) + + # Ignore zero positions in the loss. Splinter never predicts zero + # during pretraining and zero is used for padding question + # tokens as well as for start and end positions of padded + # question tokens. 
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id) + start_loss = loss_fct( + start_logits.view(batch_size * num_questions, sequence_length), + start_positions.view(batch_size * num_questions), + ) + end_loss = loss_fct( + end_logits.view(batch_size * num_questions, sequence_length), + end_positions.view(batch_size * num_questions), + ) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return SplinterForPreTrainingOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor: + rows, flat_positions = torch.where(input_ids == self.config.question_token_id) + num_questions = torch.bincount(rows) + positions = torch.full( + (input_ids.size(0), num_questions.max()), + self.config.pad_token_id, + dtype=torch.long, + device=input_ids.device, + ) + cols = torch.cat([torch.arange(n) for n in num_questions]) + positions[rows, cols] = flat_positions + return positions diff --git a/tests/models/splinter/test_modeling_splinter.py b/tests/models/splinter/test_modeling_splinter.py index 9b62b822c09896..bc355bd2cd0719 100644 --- a/tests/models/splinter/test_modeling_splinter.py +++ b/tests/models/splinter/test_modeling_splinter.py @@ -14,7 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch Splinter model. """ - +import copy import unittest from transformers import is_torch_available @@ -27,7 +27,7 @@ if is_torch_available(): import torch - from transformers import SplinterConfig, SplinterForQuestionAnswering, SplinterModel + from transformers import SplinterConfig, SplinterForPreTraining, SplinterForQuestionAnswering, SplinterModel from transformers.models.splinter.modeling_splinter import SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST @@ -36,6 +36,7 @@ def __init__( self, parent, batch_size=13, + num_questions=3, seq_length=7, is_training=True, use_input_mask=True, @@ -43,6 +44,7 @@ def __init__( use_labels=True, vocab_size=99, hidden_size=32, + question_token_id=1, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37, @@ -59,6 +61,7 @@ def __init__( ): self.parent = parent self.batch_size = batch_size + self.num_questions = num_questions self.seq_length = seq_length self.is_training = is_training self.use_input_mask = use_input_mask @@ -66,6 +69,7 @@ def __init__( self.use_labels = use_labels self.vocab_size = vocab_size self.hidden_size = hidden_size + self.question_token_id = question_token_id self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size @@ -82,6 +86,7 @@ def __init__( def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + input_ids[:, 1] = self.question_token_id input_mask = None if self.use_input_mask: @@ -91,13 +96,13 @@ def prepare_config_and_inputs(self): if self.use_token_type_ids: token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - sequence_labels = None - token_labels = None - choice_labels = None + start_positions = None + end_positions = None + question_positions = None if self.use_labels: - sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) - token_labels = ids_tensor([self.batch_size, self.seq_length], 
self.num_labels) - choice_labels = ids_tensor([self.batch_size], self.num_choices) + start_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size) + end_positions = ids_tensor([self.batch_size, self.num_questions], self.type_sequence_label_size) + question_positions = ids_tensor([self.batch_size, self.num_questions], self.num_labels) config = SplinterConfig( vocab_size=self.vocab_size, @@ -112,12 +117,20 @@ def prepare_config_and_inputs(self): type_vocab_size=self.type_vocab_size, is_decoder=False, initializer_range=self.initializer_range, + question_token_id=self.question_token_id, ) - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + return (config, input_ids, token_type_ids, input_mask, start_positions, end_positions, question_positions) def create_and_check_model( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + self, + config, + input_ids, + token_type_ids, + input_mask, + start_positions, + end_positions, + question_positions, ): model = SplinterModel(config=config) model.to(torch_device) @@ -128,7 +141,14 @@ def create_and_check_model( self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) def create_and_check_for_question_answering( - self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + self, + config, + input_ids, + token_type_ids, + input_mask, + start_positions, + end_positions, + question_positions, ): model = SplinterForQuestionAnswering(config=config) model.to(torch_device) @@ -137,12 +157,36 @@ def create_and_check_for_question_answering( input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - start_positions=sequence_labels, - end_positions=sequence_labels, + start_positions=start_positions[:, 0], + end_positions=end_positions[:, 0], ) self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length)) self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length)) + def create_and_check_for_pretraining( + self, + config, + input_ids, + token_type_ids, + input_mask, + start_positions, + end_positions, + question_positions, + ): + model = SplinterForPreTraining(config=config) + model.to(torch_device) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=start_positions, + end_positions=end_positions, + question_positions=question_positions, + ) + self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.num_questions, self.seq_length)) + self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.num_questions, self.seq_length)) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -150,11 +194,15 @@ def prepare_config_and_inputs_for_common(self): input_ids, token_type_ids, input_mask, - sequence_labels, - token_labels, - choice_labels, + start_positions, + end_positions, + question_positions, ) = config_and_inputs - inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + inputs_dict = { + "input_ids": input_ids, + "token_type_ids": token_type_ids, + "attention_mask": input_mask, + } return config, inputs_dict @@ -165,11 +213,44 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase): ( SplinterModel, SplinterForQuestionAnswering, + SplinterForPreTraining, ) if 
is_torch_available() else () ) + def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): + inputs_dict = copy.deepcopy(inputs_dict) + if return_labels: + if issubclass(model_class, SplinterForPreTraining): + inputs_dict["start_positions"] = torch.zeros( + self.model_tester.batch_size, + self.model_tester.num_questions, + dtype=torch.long, + device=torch_device, + ) + inputs_dict["end_positions"] = torch.zeros( + self.model_tester.batch_size, + self.model_tester.num_questions, + dtype=torch.long, + device=torch_device, + ) + inputs_dict["question_positions"] = torch.zeros( + self.model_tester.batch_size, + self.model_tester.num_questions, + dtype=torch.long, + device=torch_device, + ) + elif issubclass(model_class, SplinterForQuestionAnswering): + inputs_dict["start_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + inputs_dict["end_positions"] = torch.zeros( + self.model_tester.batch_size, dtype=torch.long, device=torch_device + ) + + return inputs_dict + def setUp(self): self.model_tester = SplinterModelTester(self) self.config_tester = ConfigTester(self, config_class=SplinterConfig, hidden_size=37) @@ -191,6 +272,44 @@ def test_for_question_answering(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + def test_for_pretraining(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_pretraining(*config_and_inputs) + + def test_inputs_embeds(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class)) + + if not self.is_encoder_decoder: + input_ids = inputs["input_ids"] + del inputs["input_ids"] + else: + encoder_input_ids = inputs["input_ids"] + decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids) + del inputs["input_ids"] + inputs.pop("decoder_input_ids", None) + + wte = model.get_input_embeddings() + if not self.is_encoder_decoder: + inputs["inputs_embeds"] = wte(input_ids) + else: + inputs["inputs_embeds"] = wte(encoder_input_ids) + inputs["decoder_inputs_embeds"] = wte(decoder_input_ids) + + with torch.no_grad(): + if isinstance(model, SplinterForPreTraining): + with self.assertRaises(TypeError): + # question_positions must not be None. + model(**inputs)[0] + else: + model(**inputs)[0] + @slow def test_model_from_pretrained(self): for model_name in SPLINTER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: @@ -217,3 +336,122 @@ def test_splinter_question_answering(self): self.assertEqual(torch.argmax(output.start_logits), 10) self.assertEqual(torch.argmax(output.end_logits), 12) + + @slow + def test_splinter_pretraining(self): + model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass") + + # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . 
[SEP]" + # Output should be the spans "Brad" and "the United Kingdom" + input_ids = torch.tensor( + [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]] + ) + question_positions = torch.tensor([[1, 5]], dtype=torch.long) + output = model(input_ids, question_positions=question_positions) + + expected_shape = torch.Size((1, 2, 16)) + self.assertEqual(output.start_logits.shape, expected_shape) + self.assertEqual(output.end_logits.shape, expected_shape) + + self.assertEqual(torch.argmax(output.start_logits[0, 0]), 7) + self.assertEqual(torch.argmax(output.end_logits[0, 0]), 7) + self.assertEqual(torch.argmax(output.start_logits[0, 1]), 10) + self.assertEqual(torch.argmax(output.end_logits[0, 1]), 12) + + @slow + def test_splinter_pretraining_loss_requires_question_positions(self): + model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass") + + # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]" + # Output should be the spans "Brad" and "the United Kingdom" + input_ids = torch.tensor( + [[101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102]] + ) + start_positions = torch.tensor([[7, 10]], dtype=torch.long) + end_positions = torch.tensor([7, 12], dtype=torch.long) + with self.assertRaises(TypeError): + model( + input_ids, + start_positions=start_positions, + end_positions=end_positions, + ) + + @slow + def test_splinter_pretraining_loss(self): + model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass") + + # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . [SEP]" + # Output should be the spans "Brad" and "the United Kingdom" + input_ids = torch.tensor( + [ + [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102], + [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102], + ] + ) + start_positions = torch.tensor([[7, 10], [7, 10]], dtype=torch.long) + end_positions = torch.tensor([[7, 12], [7, 12]], dtype=torch.long) + question_positions = torch.tensor([[1, 5], [1, 5]], dtype=torch.long) + output = model( + input_ids, + start_positions=start_positions, + end_positions=end_positions, + question_positions=question_positions, + ) + self.assertAlmostEqual(output.loss.item(), 0.0024, 4) + + @slow + def test_splinter_pretraining_loss_with_padding(self): + model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass") + + # Input: "[CLS] [QUESTION] was born in [QUESTION] . Brad returned to the United Kingdom later . 
[SEP]" + # Output should be the spans "Brad" and "the United Kingdom" + input_ids = torch.tensor( + [ + [101, 104, 1108, 1255, 1107, 104, 119, 7796, 1608, 1106, 1103, 1244, 2325, 1224, 119, 102], + ] + ) + start_positions = torch.tensor([[7, 10]], dtype=torch.long) + end_positions = torch.tensor([7, 12], dtype=torch.long) + question_positions = torch.tensor([[1, 5]], dtype=torch.long) + start_positions_with_padding = torch.tensor([[7, 10, 0]], dtype=torch.long) + end_positions_with_padding = torch.tensor([7, 12, 0], dtype=torch.long) + question_positions_with_padding = torch.tensor([[1, 5, 0]], dtype=torch.long) + output = model( + input_ids, + start_positions=start_positions, + end_positions=end_positions, + question_positions=question_positions, + ) + output_with_padding = model( + input_ids, + start_positions=start_positions_with_padding, + end_positions=end_positions_with_padding, + question_positions=question_positions_with_padding, + ) + + self.assertAlmostEqual(output.loss.item(), output_with_padding.loss.item(), 4) + + # Note that the original code uses 0 to denote padded question tokens + # and their start and end positions. As the pad_token_id of the model's + # config is used for the losse's ignore_index in SplinterForPreTraining, + # we add this test to ensure anybody making changes to the default + # value of the config, will be aware of the implication. + self.assertEqual(model.config.pad_token_id, 0) + + @slow + def test_splinter_pretraining_prepare_question_positions(self): + model = SplinterForPreTraining.from_pretrained("tau/splinter-base-qass") + + input_ids = torch.tensor( + [ + [101, 104, 1, 2, 104, 3, 4, 102], + [101, 1, 104, 2, 104, 3, 104, 102], + [101, 1, 2, 104, 104, 3, 4, 102], + [101, 1, 2, 3, 4, 5, 104, 102], + ] + ) + question_positions = torch.tensor([[1, 4, 0], [2, 4, 6], [3, 4, 0], [6, 0, 0]], dtype=torch.long) + output_without_positions = model(input_ids) + output_with_positions = model(input_ids, question_positions=question_positions) + self.assertTrue((output_without_positions.start_logits == output_with_positions.start_logits).all()) + self.assertTrue((output_without_positions.end_logits == output_with_positions.end_logits).all())