[DRAFT] [models] add ViTSTR in TF and PT #1048
Conversation
Codecov Report
@@ Coverage Diff @@
## main #1048 +/- ##
==========================================
+ Coverage 94.94% 95.06% +0.12%
==========================================
Files 135 141 +6
Lines 5634 5893 +259
==========================================
+ Hits 5349 5602 +253
- Misses 285 291 +6
Flags with carried forward coverage won't be shown.
NOTE:
To debug:
@frgfm before I continue debugging this, let me know what ideas you have in mind 🤗
Thanks Felix! I added some comments :)
@@ -57,10 +57,10 @@ def scaled_dot_product_attention(
 class PositionwiseFeedForward(nn.Sequential):
     """Position-wise Feed-Forward Network"""

-    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1) -> None:
+    def __init__(self, d_model: int, ffd: int, dropout: float = 0.1, use_gelu: bool = False) -> None:
A boolean value for activation selection is rather limited: are we positive that only ReLU & GELU can be used for such architecture types?
        super(PositionwiseFeedForward, self).__init__()
        self.use_gelu = use_gelu
instantiate a self.activation_fn in the constructor to avoid conditional execution in the call 👍
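For illustration, here is a minimal sketch of what both review suggestions could look like together: accept the activation as a module type instead of a boolean flag, and instantiate it once in the constructor so the forward pass has no branching. The activation_fct parameter name is an assumption for this sketch, not the actual signature in the PR.

from typing import Callable

from torch import nn


class PositionwiseFeedForward(nn.Sequential):
    """Position-wise Feed-Forward Network (illustrative sketch)"""

    def __init__(
        self,
        d_model: int,
        ffd: int,
        dropout: float = 0.1,
        # hypothetical argument: any activation module can be passed, not just ReLU/GELU
        activation_fct: Callable[[], nn.Module] = nn.ReLU,
    ) -> None:
        # build the whole block in the constructor: no `if self.use_gelu` in the call path
        super().__init__(
            nn.Linear(d_model, ffd),
            activation_fct(),
            nn.Dropout(p=dropout),
            nn.Linear(ffd, d_model),
            nn.Dropout(p=dropout),
        )

Usage would then read e.g. PositionwiseFeedForward(512, 2048, activation_fct=nn.GELU), which keeps the module open to other activations without touching the class.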
class PatchEmbedding(nn.Module):
    """Compute 2D patch embedding"""

    # Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
Just FYI: can you confirm that you made a lot of modifications? "Inspired by" is rather light, while "borrowed from" is more significant.
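For context, this is the usual timm-style patch embedding being referenced: a strided convolution cuts the image into non-overlapping patches and projects each one to the embedding dimension in a single step. The sketch below is illustrative only; the sizes, defaults, and exact layout are assumptions, not the PR's implementation.

import torch
from torch import nn


class PatchEmbedding(nn.Module):
    """Compute 2D patch embedding (illustrative sketch)"""

    def __init__(self, img_size: int = 32, patch_size: int = 4, in_channels: int = 3, embed_dim: int = 384) -> None:
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # a convolution with kernel_size == stride == patch_size splits the image
        # into patches and projects each patch to embed_dim in one operation
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) -> (B, embed_dim, H/P, W/P) -> (B, num_patches, embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)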
"""VisionTransformer architecture as described in | ||
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", | ||
<https://arxiv.org/pdf/2010.11929.pdf>`_.""" |
let's specify the constructor args
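A hedged sketch of what specifying the constructor args in the docstring could look like; the argument names and descriptions below are assumptions about the signature, not taken from the PR.

class VisionTransformer(nn.Module):
    """VisionTransformer architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.

    Args:
        d_model: dimension of the token embeddings (assumed argument name)
        num_layers: number of transformer encoder layers (assumed)
        num_heads: number of attention heads (assumed)
        ffd_ratio: expansion ratio of the feed-forward layer (assumed)
        dropout: dropout rate applied to embeddings and attention (assumed)
    """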
"""VisionTransformer architecture as described in | ||
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", | ||
<https://arxiv.org/pdf/2010.11929.pdf>`_.""" |
same here
@frgfm haha, I'm currently focusing on implementing it as a classification model 😅 we have to make a decision!
outdated by: #1050
This PR:
TODOS:
Any feedback is welcome :)
@frgfm wdyt? I think it is a more flexible approach than adding ViT as a fixed classification model (where we run into trouble with different sizes for patch_embeds) - it can easily be extended and only needs a head 😅
related to: #513 #1003
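To illustrate the "only needs a head" point, a minimal sketch of wrapping a ViT-style encoder as a backbone with a small classification head; the class name, the encoder output shape, and the use of a class token are assumptions for this sketch, not doctr's actual API.

import torch
from torch import nn


class ViTClassifier(nn.Module):
    """Illustrative wrapper: a ViT encoder used as a feature extractor plus a linear head."""

    def __init__(self, encoder: nn.Module, d_model: int, num_classes: int) -> None:
        super().__init__()
        self.encoder = encoder  # any ViT-style backbone returning (B, num_tokens, d_model)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.encoder(x)  # (B, num_patches + 1, d_model), assuming a class token at index 0
        return self.head(features[:, 0])  # classify from the class token embedding

With this split, the same encoder can be reused for recognition heads such as ViTSTR, while classification only attaches a different head on top.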