Add type hints for pytorch models (final batch) (huggingface#25750)
* Add type hints for table_transformer

* Add type hints to Timesformer model

* Add type hints to Timm Backbone model

* Add type hints to TVLT family models

* Add type hints to Vivit family models

* Use the typing instance instead of the Python builtin (e.g. `Tuple` rather than `tuple`); a short sketch follows this commit message.

* Fix the `replace_return_docstrings` decorator for Vivit model

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

---------

Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
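
The pattern applied across all five files is sketched below (the class is illustrative only, not lifted from the diff): parameters that default to `None` become `Optional[...]`, and `forward` gains a `Union` return annotation covering both the tuple path and the `ModelOutput` path. The `typing` constructs (`Optional`, `Tuple`, `Union`) are used rather than builtin generics such as `tuple[...]`, which would have raised the minimum supported Python version.

```python
from typing import Optional, Tuple, Union

import torch
from torch import nn


class ToyVideoModel(nn.Module):
    """Illustrative stand-in for the models touched in this commit."""

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        pixel_mask: Optional[torch.FloatTensor] = None,  # Optional[...] because the default is None
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], dict]:  # typing.Tuple/Union, not the builtin tuple
        hidden_states = pixel_values  # placeholder computation
        if not return_dict:
            return (hidden_states,)
        return {"last_hidden_state": hidden_states}
```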
2 people authored and blbadger committed Nov 8, 2023
1 parent 102e049 commit ec7b5be
Showing 5 changed files with 82 additions and 75 deletions.
48 changes: 24 additions & 24 deletions src/transformers/models/table_transformer/modeling_table_transformer.py
@@ -17,7 +17,7 @@

import math
from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor, nn
@@ -814,15 +814,15 @@ def _set_gradient_checkpointing(self, module, value=False):
Pixel values can be obtained using [`DetrImageProcessor`]. See [`DetrImageProcessor.__call__`] for details.
-pixel_mask (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
+pixel_mask (`torch.FloatTensor` of shape `(batch_size, height, width)`, *optional*):
Mask to avoid performing attention on padding pixel values. Mask values selected in `[0, 1]`:
- 1 for pixels that are real (i.e. **not masked**),
- 0 for pixels that are padding (i.e. **masked**).
[What are attention masks?](../glossary#attention-mask)
-decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, num_queries)`, *optional*):
+decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
Not used by default. Can be used to mask object queries.
encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
@@ -1190,16 +1190,16 @@ def unfreeze_backbone(self):
@replace_return_docstrings(output_type=TableTransformerModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values,
-pixel_mask=None,
-decoder_attention_mask=None,
-encoder_outputs=None,
-inputs_embeds=None,
-decoder_inputs_embeds=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+pixel_values: torch.FloatTensor,
+pixel_mask: Optional[torch.FloatTensor] = None,
+decoder_attention_mask: Optional[torch.FloatTensor] = None,
+encoder_outputs: Optional[torch.FloatTensor] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], TableTransformerModelOutput]:
r"""
Returns:
@@ -1351,17 +1351,17 @@ def _set_aux_loss(self, outputs_class, outputs_coord):
@replace_return_docstrings(output_type=TableTransformerObjectDetectionOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values,
-pixel_mask=None,
-decoder_attention_mask=None,
-encoder_outputs=None,
-inputs_embeds=None,
-decoder_inputs_embeds=None,
-labels=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+pixel_values: torch.FloatTensor,
+pixel_mask: Optional[torch.FloatTensor] = None,
+decoder_attention_mask: Optional[torch.FloatTensor] = None,
+encoder_outputs: Optional[torch.FloatTensor] = None,
+inputs_embeds: Optional[torch.FloatTensor] = None,
+decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+labels: Optional[List[Dict]] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], TableTransformerObjectDetectionOutput]:
r"""
labels (`List[Dict]` of len `(batch_size,)`, *optional*):
Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
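As a usage note that is not part of the diff: the new `labels: Optional[List[Dict]]` hint matches the DETR-style target format this model family expects, one dict per image with `class_labels` and normalized `boxes`. A minimal sketch, assuming the public `microsoft/table-transformer-detection` checkpoint and that `timm` is installed:

```python
import torch
from transformers import TableTransformerForObjectDetection

model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")

pixel_values = torch.randn(1, 3, 512, 512)  # one RGB image; normally produced by the image processor
labels = [  # List[Dict]: one entry per image in the batch
    {
        "class_labels": torch.tensor([0], dtype=torch.long),               # (num_objects,)
        "boxes": torch.tensor([[0.5, 0.5, 0.4, 0.2]], dtype=torch.float),  # normalized (cx, cy, w, h)
    }
]

outputs = model(pixel_values=pixel_values, labels=labels)
print(outputs.loss)          # bipartite matching loss described in the docstring above
print(outputs.logits.shape)  # (batch_size, num_queries, num_labels + 1)
```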
10 changes: 5 additions & 5 deletions src/transformers/models/timesformer/modeling_timesformer.py
@@ -559,11 +559,11 @@ class PreTrainedModel
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+pixel_values: torch.FloatTensor,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
r"""
Returns:
11 changes: 9 additions & 2 deletions src/transformers/models/timm_backbone/modeling_timm_backbone.py
@@ -13,7 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

+import torch

from ...modeling_outputs import BackboneOutput
from ...modeling_utils import PreTrainedModel
@@ -107,7 +109,12 @@ def _init_weights(self, module):
pass

def forward(
-self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs
+self,
+pixel_values: torch.FloatTensor,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[BackboneOutput, Tuple[Tensor, ...]]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states = (
60 changes: 30 additions & 30 deletions src/transformers/models/tvlt/modeling_tvlt.py
@@ -715,16 +715,16 @@ class PreTrainedModel
@replace_return_docstrings(output_type=TvltModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values,
-audio_values,
-pixel_mask=None,
-audio_mask=None,
-mask_pixel=False,
-mask_audio=False,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-) -> Union[tuple, TvltModelOutput]:
+pixel_values: torch.FloatTensor,
+audio_values: torch.FloatTensor,
+pixel_mask: Optional[torch.FloatTensor] = None,
+audio_mask: Optional[torch.FloatTensor] = None,
+mask_pixel: bool = False,
+mask_audio: bool = False,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], TvltModelOutput]:
r"""
Returns:
@@ -1049,17 +1049,17 @@ def concatenate_mask(self, mask_token, sequence, ids_restore):
@replace_return_docstrings(output_type=TvltForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values,
-audio_values,
-pixel_mask=None,
-audio_mask=None,
-labels=None,
-pixel_values_mixed=None,
-pixel_mask_mixed=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-) -> Union[tuple, TvltForPreTrainingOutput]:
+pixel_values: torch.FloatTensor,
+audio_values: torch.FloatTensor,
+pixel_mask: Optional[torch.FloatTensor] = None,
+audio_mask: Optional[torch.FloatTensor] = None,
+labels: Optional[torch.LongTensor] = None,
+pixel_values_mixed: Optional[torch.FloatTensor] = None,
+pixel_mask_mixed: Optional[torch.FloatTensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], TvltForPreTrainingOutput]:
r"""
pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
@@ -1250,15 +1250,15 @@ def __init__(self, config):
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values,
-audio_values,
-pixel_mask=None,
-audio_mask=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-labels=None,
-) -> Union[tuple, SequenceClassifierOutput]:
+pixel_values: torch.FloatTensor,
+audio_values: torch.FloatTensor,
+pixel_mask: Optional[torch.FloatTensor] = None,
+audio_mask: Optional[torch.FloatTensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+labels: Optional[torch.LongTensor] = None,
+) -> Union[Tuple[torch.FloatTensor], SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
28 changes: 14 additions & 14 deletions src/transformers/models/vivit/modeling_vivit.py
@@ -486,15 +486,15 @@ def _prune_heads(self, heads_to_prune):
self.encoder.layer[layer].attention.prune_heads(heads)

@add_start_docstrings_to_model_forward(VIVIT_INPUTS_DOCSTRING)
-@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
+@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values=None,
-head_mask=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+pixel_values: Optional[torch.FloatTensor] = None,
+head_mask: Optional[torch.FloatTensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPooling]:
r"""
Returns:
@@ -628,13 +628,13 @@ def __init__(self, config):
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
-pixel_values=None,
-head_mask=None,
-labels=None,
-output_attentions=None,
-output_hidden_states=None,
-return_dict=None,
-):
+pixel_values: Optional[torch.FloatTensor] = None,
+head_mask: Optional[torch.FloatTensor] = None,
+labels: Optional[torch.LongTensor] = None,
+output_attentions: Optional[bool] = None,
+output_hidden_states: Optional[bool] = None,
+return_dict: Optional[bool] = None,
+) -> Union[Tuple[torch.FloatTensor], ImageClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
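
To make the new `Union[Tuple[...], BaseModelOutputWithPooling]` annotation (and the corrected `replace_return_docstrings` target) concrete, here is a hedged usage sketch that is not part of the diff, assuming a randomly initialized default `VivitConfig`:

```python
import torch
from transformers import VivitConfig, VivitModel

config = VivitConfig()  # default video shape: (num_frames, num_channels, image_size, image_size)
model = VivitModel(config)
video = torch.randn(1, config.num_frames, config.num_channels, config.image_size, config.image_size)

outputs = model(pixel_values=video)                            # BaseModelOutputWithPooling branch
print(type(outputs).__name__, outputs.pooler_output.shape)

tuple_outputs = model(pixel_values=video, return_dict=False)   # Tuple branch of the Union annotation
print(isinstance(tuple_outputs, tuple))
```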
