[models] Vit: fix intermediate size scale and unify TF to PT #1063

Merged: 3 commits, merged on Sep 19, 2022
Changes shown from 2 commits
6 changes: 4 additions & 2 deletions doctr/models/classification/vit/pytorch.py
@@ -62,6 +62,7 @@ class VisionTransformer(nn.Sequential):
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
+        ffd_ratio: multiplier for the hidden dimension of the feedforward layer
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
@@ -74,6 +75,7 @@ def __init__(
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
+        ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
@@ -82,7 +84,7 @@ def __init__(

_layers: List[nn.Module] = [
PatchEmbedding(input_shape, patch_size, d_model),
-            EncoderBlock(num_layers, num_heads, d_model, dropout, nn.GELU()),
+            EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
]
if include_top:
_layers.append(ClassifierHead(d_model, num_classes))
@@ -121,7 +123,7 @@ def _vit(


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer architecture as described in
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

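For context, a minimal sketch of what the new `ffd_ratio` argument does on the PyTorch side: with the ViT-B defaults (`d_model=768`, `ffd_ratio=4`) the encoder's feed-forward hidden size becomes 3072 instead of staying at 768. The `input_shape` and `patch_size` values below are illustrative assumptions, not taken from this diff.

```python
import torch
from doctr.models.classification.vit.pytorch import VisionTransformer

# Hypothetical instantiation; input_shape and patch_size values are assumed here.
model = VisionTransformer(
    input_shape=(3, 32, 32),
    patch_size=(4, 4),
    d_model=768,
    num_layers=12,
    num_heads=12,
    ffd_ratio=4,  # feed-forward hidden dim = 768 * 4 = 3072 (ViT-B); was fixed to 768 before this fix
    num_classes=1000,
    include_top=True,
)
logits = model(torch.rand(1, 3, 32, 32))  # expected (1, 1000) class logits from the cls token
```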
76 changes: 28 additions & 48 deletions doctr/models/classification/vit/tensorflow.py
@@ -31,6 +31,23 @@
}


+class ClassifierHead(layers.Layer, NestedObject):
+    """Classifier head for Vision Transformer
+
+    Args:
+        num_classes: number of output classes
+    """
+
+    def __init__(self, num_classes: int) -> None:
+        super().__init__()
+
+        self.head = layers.Dense(num_classes, kernel_initializer="he_normal")
+
+    def call(self, x: tf.Tensor) -> tf.Tensor:
+        # (batch_size, num_classes) cls token
+        return self.head(x[:, 0])


class VisionTransformer(Sequential):
"""VisionTransformer architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
@@ -42,6 +59,7 @@ class VisionTransformer(Sequential):
d_model: dimension of the transformer layers
num_layers: number of transformer layers
num_heads: number of attention heads
+        ffd_ratio: multiplier for the hidden dimension of the feedforward layer
dropout: dropout rate
num_classes: number of output classes
include_top: whether the classifier head should be instantiated
@@ -54,60 +72,22 @@ def __init__(
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
+        ffd_ratio: int = 4,
dropout: float = 0.0,
num_classes: int = 1000,
include_top: bool = True,
cfg: Optional[Dict[str, Any]] = None,
) -> None:

-        # Note: fix for onnx export
-        _vit = _VisionTransformer(
-            input_shape,
-            patch_size,
-            d_model,
-            num_layers,
-            num_heads,
-            dropout,
-            num_classes,
-            include_top,
-        )
-        super().__init__(_vit)
-        self.cfg = cfg
-
-
-class _VisionTransformer(layers.Layer, NestedObject):
-    def __init__(
-        self,
-        input_shape: Tuple[int, int, int] = (32, 32, 3),
-        patch_size: Tuple[int, int] = (4, 4),
-        d_model: int = 768,
-        num_layers: int = 12,
-        num_heads: int = 12,
-        dropout: float = 0.0,
-        num_classes: int = 1000,
-        include_top: bool = True,
-        cfg: Optional[Dict[str, Any]] = None,
-    ) -> None:
-
-        super().__init__()
-        self.include_top = include_top
-
-        self.patch_embedding = PatchEmbedding(input_shape, patch_size, d_model)
-        self.encoder = EncoderBlock(num_layers, num_heads, d_model, dropout, activation_fct=GELU())
-
-        if self.include_top:
-            self.head = layers.Dense(num_classes, kernel_initializer="he_normal")
-
-    def __call__(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
-
-        embeddings = self.patch_embedding(x, **kwargs)
-        encoded = self.encoder(embeddings, **kwargs)
-
-        if self.include_top:
-            # (batch_size, num_classes) cls token
-            return self.head(encoded[:, 0], **kwargs)
-
-        return encoded
+        _layers = [
+            PatchEmbedding(input_shape, patch_size, d_model),
+            EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, activation_fct=GELU()),
+        ]
+        if include_top:
+            _layers.append(ClassifierHead(num_classes))
+
+        super().__init__(_layers)
+        self.cfg = cfg


def _vit(
@@ -136,7 +116,7 @@ def _vit(


def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
"""VisionTransformer architecture as described in
"""VisionTransformer-B architecture as described in
`"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
<https://arxiv.org/pdf/2010.11929.pdf>`_.

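The TensorFlow refactor mirrors the PyTorch layout: instead of wrapping an inner `_VisionTransformer` layer, the model is now a plain `Sequential` of `PatchEmbedding`, `EncoderBlock`, and the new `ClassifierHead`, which reads only the cls token. A standalone sketch of that cls-token selection (illustrative shapes, not doctr code):

```python
import tensorflow as tf

# Illustrative shapes: 64 patches + 1 cls token, d_model = 768.
encoded = tf.random.normal((2, 65, 768))
head = tf.keras.layers.Dense(1000, kernel_initializer="he_normal")
logits = head(encoded[:, 0])  # (2, 1000): classify from the cls token only
```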
3 changes: 2 additions & 1 deletion doctr/models/modules/transformer/pytorch.py
@@ -108,6 +108,7 @@ def __init__(
num_layers: int,
num_heads: int,
d_model: int,
+        dff: int,
dropout: float,
activation_fct: Callable[[Any], Any] = nn.ReLU(),
) -> None:
@@ -124,7 +125,7 @@ def __init__(
[MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
)
self.position_feed_forward = nn.ModuleList(
-            [PositionwiseFeedForward(d_model, d_model, dropout, activation_fct) for _ in range(self.num_layers)]
+            [PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)]
)

def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
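With the extra `dff` argument, the feed-forward width is no longer tied to `d_model`; the ViT constructor above passes `d_model * ffd_ratio`. A minimal sketch of the call, with illustrative argument values:

```python
from torch import nn
from doctr.models.modules.transformer.pytorch import EncoderBlock

d_model, ffd_ratio = 768, 4
encoder = EncoderBlock(
    num_layers=12,
    num_heads=12,
    d_model=d_model,
    dff=d_model * ffd_ratio,  # 3072: feed-forward hidden size, previously stuck at d_model
    dropout=0.0,
    activation_fct=nn.GELU(),
)
```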
3 changes: 2 additions & 1 deletion doctr/models/modules/transformer/tensorflow.py
@@ -142,6 +142,7 @@ def __init__(
num_layers: int,
num_heads: int,
d_model: int,
+        dff: int,
dropout: float,
activation_fct: Callable[[Any], Any] = layers.ReLU(),
) -> None:
@@ -156,7 +157,7 @@

self.attention = [MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)]
self.position_feed_forward = [
-            PositionwiseFeedForward(d_model, d_model, dropout, activation_fct) for _ in range(self.num_layers)
+            PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
]

def call(self, x: tf.Tensor, mask: Optional[tf.Tensor] = None, **kwargs: Any) -> tf.Tensor:
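The point of passing `dff` is the same on the TensorFlow side: the position-wise feed-forward block should expand from `d_model` to a wider hidden size and project back. A standalone Keras illustration of that shape flow (not doctr's `PositionwiseFeedForward` implementation):

```python
import tensorflow as tf
from tensorflow.keras import layers

d_model, dff = 768, 3072
ffn = tf.keras.Sequential([
    layers.Dense(dff, activation="gelu"),  # expand: d_model -> dff
    layers.Dense(d_model),                 # project back: dff -> d_model
])
out = ffn(tf.random.normal((2, 65, d_model)))  # (2, 65, 768)
```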
2 changes: 1 addition & 1 deletion doctr/models/modules/vision_transformer/pytorch.py
@@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
assert W % self.patch_size[1] == 0, "Image width must be divisible by patch width"

# patchify image without convolution
-    # adopted from:
+    # adapted from:
# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial15/Vision_Transformer.html
# NOTE: patchify with Conv2d works only with padding="valid" correctly on smaller images
# and has currently no ONNX support so we use this workaround
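For reference, the reshape-based patchify trick from the linked UvA notebook looks roughly like this (a sketch of the idea, not doctr's exact `PatchEmbedding` code):

```python
import torch

def img_to_patch(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """Split (B, C, H, W) into (B, num_patches, C * patch_size**2) without a Conv2d."""
    B, C, H, W = x.shape
    x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)        # (B, H', W', C, p, p)
    return x.flatten(1, 2).flatten(2, 4)   # (B, H'*W', C*p*p)

patches = img_to_patch(torch.rand(1, 3, 32, 32), patch_size=4)  # (1, 64, 48)
```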