[DRAFT] [models] add ViTSTR in TF and PT #1048

Closed
wants to merge 17 commits into from

Conversation

felixdittrich92
Contributor

@felixdittrich92 felixdittrich92 commented Sep 6, 2022

This PR:

  • adds VisionTransformer as a module in models/modules/vision_transformer (TF/PT)
  • adds a ViTSTR head (TF/PT)

TODOS:

  • toy run / check that all works fine
  • check code quality

Any feedback is welcome :)

@frgfm wdyt? I think this is more flexible than adding ViT as a fixed classification model (where we run into trouble with different sizes for patch_embeds): it can easily be extended and only needs a head 😅

related to: #513 #1003
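
To illustrate the patch-embedding size issue mentioned above: in a ViT, the number of patch tokens (and therefore the length of the positional embedding) depends on both the input size and the patch size, so a model built for one input shape does not directly fit another. A minimal sketch in plain Python; the helper name `num_patches` and the 32x128 / 4x8 values are illustrative assumptions, not values taken from this PR:

```python
from typing import Tuple


def num_patches(img_size: Tuple[int, int], patch_size: Tuple[int, int]) -> int:
    """Number of non-overlapping patches for an (H, W) image (hypothetical helper)."""
    height, width = img_size
    patch_h, patch_w = patch_size
    if height % patch_h or width % patch_w:
        raise ValueError("image size must be divisible by patch size")
    return (height // patch_h) * (width // patch_w)


# Square classification input, as in the ViT paper: 224x224 image, 16x16 patches
print(num_patches((224, 224), (16, 16)))  # 196
# Wide text-line crop (illustrative values only): 32x128 image, 4x8 patches
print(num_patches((32, 128), (4, 8)))  # 128
```

Keeping the patch embedding separate from the head means only this token count, not the rest of the encoder, has to change with the input shape.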

@felixdittrich92 felixdittrich92 added the labels module: models (Related to doctr.models), ext: tests (Related to tests folder), framework: pytorch (Related to PyTorch backend), framework: tensorflow (Related to TensorFlow backend), topic: text recognition (Related to the task of text recognition), type: new feature (New feature) Sep 6, 2022
@felixdittrich92 felixdittrich92 added this to the 0.6.0 milestone Sep 6, 2022
@felixdittrich92 felixdittrich92 self-assigned this Sep 6, 2022
@codecov

codecov bot commented Sep 6, 2022

Codecov Report

Merging #1048 (006446b) into main (1cc073d) will increase coverage by 0.12%.
The diff coverage is 97.73%.

@@            Coverage Diff             @@
##             main    #1048      +/-   ##
==========================================
+ Coverage   94.94%   95.06%   +0.12%     
==========================================
  Files         135      141       +6     
  Lines        5634     5893     +259     
==========================================
+ Hits         5349     5602     +253     
- Misses        285      291       +6     
Flag        Coverage Δ
unittests   95.06% <97.73%> (+0.12%) ⬆️

Impacted Files                                         Coverage Δ
doctr/models/recognition/vitstr/pytorch.py             94.59% <94.59%> (ø)
doctr/models/recognition/vitstr/tensorflow.py          97.43% <97.43%> (ø)
doctr/models/modules/__init__.py                       100.00% <100.00%> (ø)
doctr/models/modules/transformer/pytorch.py            100.00% <100.00%> (ø)
doctr/models/modules/transformer/tensorflow.py         98.79% <100.00%> (+0.04%) ⬆️
...octr/models/modules/vision_transformer/__init__.py  100.00% <100.00%> (ø)
doctr/models/modules/vision_transformer/pytorch.py     100.00% <100.00%> (ø)
...tr/models/modules/vision_transformer/tensorflow.py  100.00% <100.00%> (ø)
doctr/models/recognition/__init__.py                   100.00% <100.00%> (ø)
doctr/models/recognition/vitstr/__init__.py            100.00% <100.00%> (ø)
... and 1 more

@felixdittrich92
Contributor Author

felixdittrich92 commented Sep 7, 2022

NOTE:

  • ViT works fine (also tested with timm's ViT implementation)
  • slow tests (incl. ONNX export) pass for TF and PT

To debug:

  • loss does not decrease (needs debugging); also tested with timm's ViT implementation, same result: stuck at a loss of ~2.9
  • tested with the original image / patch sizes from the paper: makes no difference
  • test the data again; toy runs with a 500K MjSynth split
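
One generic sanity check for a plateaued loss (a suggestion, not something reported in this thread): assuming the loss is a per-character cross-entropy, a model that predicts a uniform distribution over V classes scores ln(V). A plateau close to that value would suggest near-uniform predictions. The helper name and the vocabulary sizes below are hypothetical:

```python
import math


def uniform_cross_entropy(num_classes: int) -> float:
    """Cross-entropy (in nats) of a uniform prediction over `num_classes` classes."""
    if num_classes < 1:
        raise ValueError("num_classes must be positive")
    return math.log(num_classes)


# Baselines for a few vocabulary sizes (illustrative values, not docTR's vocab)
for size in (10, 37, 100):
    print(size, round(uniform_cross_entropy(size), 3))
```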

@felixdittrich92
Contributor Author

@frgfm before I continue debugging this, let me know what ideas you have in mind 🤗
Should we go this way and keep ViT as a module, or start by implementing it as a classification model and then think about how to make the patch_embedding flexible? wdyt?

Collaborator

@frgfm frgfm left a comment

Thanks Felix! I added some comments :)

@@ -57,10 +57,10 @@ def scaled_dot_product_attention(
 class PositionwiseFeedForward(nn.Sequential):
     """Position-wise Feed-Forward Network"""

-    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1) -> None:
+    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1, use_gelu: bool = False) -> None:
Collaborator

boolean value for activation selection is rather limited: are we positive that only relu & gelu can be used for such architecture types?

         super(PositionwiseFeedForward, self).__init__()
+        self.use_gelu = use_gelu
Collaborator

instantiate a self.activation_fn in the constructor to avoid conditional execution in the call 👍
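
A minimal sketch of both suggestions above: accept the activation itself instead of a boolean, and store it once in the constructor. It is written in plain Python so that it runs without a framework; the class name, the `matvec` helper, and the list-based weights are illustrative assumptions, not the code from this PR. In PyTorch the same pattern would typically take an `nn.Module` instance such as `nn.ReLU()` or `nn.GELU()`.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def relu(x: float) -> float:
    return max(0.0, x)


def gelu(x: float) -> float:
    # Exact GELU: x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matvec(x: Vector, weight: Matrix) -> Vector:
    """Multiply a row vector of length n by an (n x m) matrix."""
    return [sum(xi * row[j] for xi, row in zip(x, weight)) for j in range(len(weight[0]))]


class FeedForwardSketch:
    """Position-wise feed-forward block with a pluggable activation (biases omitted)."""

    def __init__(self, w1: Matrix, w2: Matrix, activation_fn: Callable[[float], float] = relu) -> None:
        self.w1 = w1  # (d_model x ffd)
        self.w2 = w2  # (ffd x d_model)
        # Chosen once here, so __call__ needs no `if use_gelu` branch
        self.activation_fn = activation_fn

    def __call__(self, x: Vector) -> Vector:
        hidden = [self.activation_fn(h) for h in matvec(x, self.w1)]
        return matvec(hidden, self.w2)


w1 = [[1.0, -1.0]]  # d_model=1, ffd=2
w2 = [[1.0], [1.0]]
print(FeedForwardSketch(w1, w2)([2.0]))  # ReLU: [2.0]
print(FeedForwardSketch(w1, w2, activation_fn=gelu)([2.0]))
```

With this shape, supporting another activation only means passing a different callable; the block itself does not change.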

class PatchEmbedding(nn.Module):
    """Compute 2D patch embedding"""

    # Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
Collaborator

Just FYI: can you confirm that you made a lot of modifications?
"inspired by" implies rather light reuse
"borrowed from" implies more significant reuse

Comment on lines +45 to +47
"""VisionTransformer architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_."""
Collaborator

let's specify the constructor args

Comment on lines +52 to +54
"""VisionTransformer architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_."""
Collaborator

same here
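
As a sketch of what documented constructor arguments could look like, here is a Google-style `Args:` section added to the docstring quoted above. The class name, parameter names, and defaults are hypothetical placeholders (the defaults follow the ViT-Base configuration from the paper), not the actual signature from this PR:

```python
from typing import Tuple


class VisionTransformerSketch:
    """VisionTransformer architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.

    Args:
        img_size: size (H, W) of the input image
        patch_size: size (H, W) of each patch
        d_model: dimension of the token embeddings
        num_layers: number of transformer encoder blocks
        num_heads: number of attention heads per block
        dropout: dropout probability
    """

    def __init__(
        self,
        img_size: Tuple[int, int] = (224, 224),
        patch_size: Tuple[int, int] = (16, 16),
        d_model: int = 768,
        num_layers: int = 12,
        num_heads: int = 12,
        dropout: float = 0.1,
    ) -> None:
        # Hypothetical signature: this sketch only stores the arguments
        self.img_size = img_size
        self.patch_size = patch_size
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout = dropout


print(VisionTransformerSketch.__doc__)
```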

@felixdittrich92
Contributor Author

@frgfm haha, I am currently focusing on implementing it as a classification model 😅 we have to make a decision!

@felixdittrich92
Contributor Author

outdated by: #1050

@felixdittrich92 felixdittrich92 deleted the vit branch September 8, 2022 10:25
@felixdittrich92 felixdittrich92 removed this from the 0.6.0 milestone Sep 26, 2022
2 participants