From 6745567b6e3d294dc996ee4ac0686ab74b5b2915 Mon Sep 17 00:00:00 2001
From: David Reguera <33068707+nablabits@users.noreply.github.com>
Date: Mon, 28 Aug 2023 15:31:33 +0200
Subject: [PATCH] Add type hints for several pytorch models (batch-4) (#25749)

* Add type hints for MGP STR model

* Add missing type hints for plbart model

* Add type hints for Pix2struct model

* Add missing type hints to Rag model and tweak the docstring

* Add missing type hints to Sam model

* Add missing type hints to Swin2sr model

* Fix a type hint for Pix2StructTextModel

Co-authored-by: Matt

* Fix typo on Rag model docstring

Co-authored-by: Matt

* Fix linter

---------

Co-authored-by: Matt
---
 .../models/mgp_str/modeling_mgp_str.py        | 20 ++++++++-----
 .../models/pix2struct/modeling_pix2struct.py  | 28 +++++++++----------
 .../models/plbart/modeling_plbart.py          |  4 +--
 src/transformers/models/rag/modeling_rag.py   | 18 +++++-------
 src/transformers/models/sam/modeling_sam.py   |  2 +-
 .../models/swin2sr/modeling_swin2sr.py        |  2 +-
 6 files changed, 38 insertions(+), 36 deletions(-)

diff --git a/src/transformers/models/mgp_str/modeling_mgp_str.py b/src/transformers/models/mgp_str/modeling_mgp_str.py
index 5e34faf408858e..5d1f5bea7bfd35 100644
--- a/src/transformers/models/mgp_str/modeling_mgp_str.py
+++ b/src/transformers/models/mgp_str/modeling_mgp_str.py
@@ -380,7 +380,13 @@ def get_input_embeddings(self) -> nn.Module:
         return self.embeddings.proj
 
     @add_start_docstrings_to_model_forward(MGP_STR_INPUTS_DOCSTRING)
-    def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -437,12 +443,12 @@ def __init__(self, config: MgpstrConfig) -> None:
     @replace_return_docstrings(output_type=MgpstrModelOutput, config_class=MgpstrConfig)
     def forward(
         self,
-        pixel_values,
-        output_attentions=None,
-        output_a3_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_a3_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], MgpstrModelOutput]:
         r"""
         output_a3_attentions (`bool`, *optional*):
             Whether or not to return the attentions tensors of a3 modules. See `a3_attentions` under returned tensors
diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py
index 015007a9679b9b..288e31a126e675 100644
--- a/src/transformers/models/pix2struct/modeling_pix2struct.py
+++ b/src/transformers/models/pix2struct/modeling_pix2struct.py
@@ -1387,21 +1387,21 @@ def set_output_embeddings(self, new_embeddings):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        inputs_embeds=None,
-        head_mask=None,
-        cross_attn_head_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        labels=None,
-        return_dict=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        labels: Optional[torch.LongTensor] = None,
+        return_dict: Optional[bool] = None,
         **kwargs,
-    ):
+    ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
         r"""
         Returns:
 
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 4271a37ee5dd1b..62f62dbb953917 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1177,7 +1177,7 @@ def forward(
         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds=None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1302,7 +1302,7 @@ def forward(
         encoder_outputs: Optional[List[torch.FloatTensor]] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
-        decoder_inputs_embeds=None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.Tensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index 21ee10386a8157..15f4fca475df10 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -462,16 +462,12 @@ def from_pretrained_question_encoder_generator(
             `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
         context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
             Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
-            retriever.
-
-            If the model has is not initialized with a `retriever` ``context_input_ids` has to be provided to the
-            forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`]. context_attention_mask
-            (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*,
-            returned when *output_retrieved=True*): Attention mask post-processed from the retrieved documents and the
-            question encoder `input_ids` by the retriever.
-
-            If the model has is not initialized with a `retriever` `context_attention_mask` has to be provided to the
-            forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
+            retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided to
+            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
+        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
+            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
+            retriever. If the model was not initialized with a `retriever`, `context_attention_mask` has to be
+            provided to the forward pass. `context_attention_mask` is returned by [`~RagRetriever.__call__`].
         use_cache (`bool`, *optional*, defaults to `True`):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
             (see `past_key_values`).
@@ -545,7 +541,7 @@ def forward(
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         doc_scores: Optional[torch.FloatTensor] = None,
         context_input_ids: Optional[torch.LongTensor] = None,
-        context_attention_mask=None,
+        context_attention_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
diff --git a/src/transformers/models/sam/modeling_sam.py b/src/transformers/models/sam/modeling_sam.py
index 3b8e1aba714cda..abf5544a5b4de6 100644
--- a/src/transformers/models/sam/modeling_sam.py
+++ b/src/transformers/models/sam/modeling_sam.py
@@ -1296,7 +1296,7 @@ def forward(
         target_embedding: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict=None,
+        return_dict: Optional[bool] = None,
         **kwargs,
     ) -> List[Dict[str, torch.Tensor]]:
         r"""
diff --git a/src/transformers/models/swin2sr/modeling_swin2sr.py b/src/transformers/models/swin2sr/modeling_swin2sr.py
index 9464981bafeef1..72de9ac1cb5c97 100644
--- a/src/transformers/models/swin2sr/modeling_swin2sr.py
+++ b/src/transformers/models/swin2sr/modeling_swin2sr.py
@@ -903,7 +903,7 @@ def pad_and_normalize(self, pixel_values):
     )
     def forward(
         self,
-        pixel_values,
+        pixel_values: torch.FloatTensor,
         head_mask: Optional[torch.FloatTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
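
Note: since this patch only adds annotations, runtime behavior is unchanged. As a quick sanity check, the sketch below (an illustrative example, not part of the commit; it assumes a transformers checkout with this patch applied, plus torch, and downloads no weights) inspects two of the patched forward signatures:

    import inspect

    from transformers import MgpstrModel, Swin2SRModel

    # `pixel_values` was untyped on both forward methods before this patch;
    # afterwards it should carry a torch.FloatTensor annotation.
    for cls in (MgpstrModel, Swin2SRModel):
        sig = inspect.signature(cls.forward)
        print(f"{cls.__name__}.forward{sig}")
        assert sig.parameters["pixel_values"].annotation is not inspect.Parameter.empty

`Optional[bool] = None` (rather than plain `bool`) is the right hint for flags like `output_attentions` and `return_dict` because the models treat `None` as "fall back to the config default", as the `output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions` line in `MgpstrModel.forward` shows.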