How to replicate DeiT-SimA result #4

Open
aaab8b opened this issue Jul 26, 2022 · 12 comments

aaab8b commented Jul 26, 2022

Thanks for your wonderful work on SimA! I'm trying to replicate your DeiT-S -> SimA result, but I can't find any hyperparameter settings. Are all hyperparameters, including drop path, the same as for the original DeiT-S? Thank you.

soroush-abbasi (Collaborator) commented:

Hi,

Thanks for your interest in our work! I explained how to run DeiT-S -> SimA here (link). We use the default hyperparameters (our code), except that we train for 300 epochs to be comparable to the original DeiT-S numbers reported in their paper.


aaab8b commented Jul 29, 2022

> Hi,
>
> Thanks for your interest in our work! I explained how to run DeiT-S -> SimA here (link). We use the default hyperparameters (our code), except that we train for 300 epochs to be comparable to the original DeiT-S numbers reported in their paper.

Thanks for your reply! I'm actually trying to use DeiT-Tiny for my experiments because of GPU memory limits (mostly 4 or 8 2080 Tis). I replaced the MHSA in DeiT-Tiny with SimA in exactly the same way, but the performance dropped drastically, by about 4-5%. Do you have any suggestions for boosting the performance of small models?

soroush-abbasi (Collaborator) commented:

Hi,

Please use our settings for DeiT-Tiny training. Since DeiT-Tiny has less capacity, I would guess you might want to reduce the regularization parameters (e.g., the drop path rate).
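
A minimal sketch of what lowering the drop path rate changes, using timm's DropPath (the layer the SimA blocks use for stochastic depth); the tensor shape and the 0.05 value are only illustrative assumptions:

import torch
from timm.models.layers import DropPath

# Stochastic depth: in training mode, DropPath zeroes the whole residual branch for a
# random subset of samples and rescales the survivors by 1/(1 - p); a lower p regularizes less.
branch_out = torch.randn(8, 197, 192)   # hypothetical DeiT-Tiny-sized token tensor
stronger = DropPath(0.1).train()        # current rate
weaker = DropPath(0.05).train()         # reduced rate for a lower-capacity model
print(stronger(branch_out).shape, weaker(branch_out).shape)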


aaab8b commented Aug 3, 2022

> Hi,
>
> Please use our settings for DeiT-Tiny training. Since DeiT-Tiny has less capacity, I would guess you might want to reduce the regularization parameters (e.g., the drop path rate).

Thank you for your reply. I will try a lower drop path rate. Also, when you trained DeiT-S -> SimA (79.8, as claimed in your paper), was the drop path rate 0.1 or 0.05?


aaab8b commented Aug 4, 2022

> Hi,
>
> Thanks for your interest in our work! I explained how to run DeiT-S -> SimA here (link). We use the default hyperparameters (our code), except that we train for 300 epochs to be comparable to the original DeiT-S numbers reported in their paper.

Also, with this code the original PatchEmbed module is replaced by a ConvPatchEmbed module, which costs slightly more parameters and FLOPs. With the original PatchEmbed and your default settings, the performance of DeiT-S -> SimA also drops drastically compared to the original DeiT-S... The FLOPs of DeiT-S -> SimA claimed in your paper are not actually 4.6B; they should be 5.0B.


aaab8b commented Aug 4, 2022

And the parameter count of DeiT-S -> SimA should be 23M instead of 22M.


soroush-abbasi commented Aug 4, 2022

Hi,

Thanks for bringing this up. These are the exact numbers we calculate for FLOPs and parameters:

FLOPs: 4594733568.0
Parameters: 22557304.0

Could you please elaborate on how you calculated the above numbers? We use thop to calculate the FLOPs.

We used ConvPatchEmbed in our code. Note that there are various approaches to computing the input tokens from the original image, and this choice is orthogonal to the main transformer architecture. But thanks for pointing out this difference! We will clarify it in the next arXiv version.

Please let me know if you have any questions. Thanks! Have a good day!
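
A minimal sketch of the thop call that produces numbers of this kind; the 224x224 input, eval mode, and direct use of the sima_small_deit definition posted in the next comment are assumptions:

import torch
from thop import profile

# Assumption: sima_small_deit is the model definition posted in the next comment.
model = sima_small_deit()
model.eval()  # single-tensor output; avoids SyncBatchNorm's distributed sync path
macs, params = profile(model, inputs=(torch.randn(1, 3, 224, 224),))
print(f"thop count (multiply-accumulates for most layers): {macs}, parameters: {params}")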

soroush-abbasi (Collaborator) commented:

Below is the full architecture for DeiT-S -> SimA


"""
Implementation of SimA (Simple Softmax-free Attention)
Based on timm, DeiT and XCiT code bases
https://github.com/rwightman/pytorch-image-models/tree/master/timm
https://github.com/facebookresearch/deit/
https://github.com/facebookresearch/xcit/
"""
import math

import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import _cfg, Mlp
from timm.models.registry import register_model
from timm.models.layers import DropPath, trunc_normal_, to_2tuple
import torch.nn.functional as F


class PositionalEncodingFourier(nn.Module):
    """
    Positional encoding relying on a fourier kernel matching the one used in the
    "Attention is all of Need" paper. The implementation builds on DeTR code
    https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim

    def forward(self, B, H, W):
        mask = torch.zeros(B, H, W).bool().to(self.token_projection.weight.device)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        eps = 1e-6
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=mask.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.hidden_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(),
                             pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(),
                             pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)
        return pos


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return torch.nn.Sequential(
        nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
        ),
        nn.SyncBatchNorm(out_planes)
    )


class ConvPatchEmbed(nn.Module):
    """ Image to Patch Embedding using multiple convolutional layers
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        if patch_size[0] == 16:
            self.proj = torch.nn.Sequential(
                conv3x3(3, embed_dim // 8, 2),
                nn.GELU(),
                conv3x3(embed_dim // 8, embed_dim // 4, 2),
                nn.GELU(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                nn.GELU(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        elif patch_size[0] == 8:
            self.proj = torch.nn.Sequential(
                conv3x3(3, embed_dim // 4, 2),
                nn.GELU(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                nn.GELU(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        else:
            raise("For convolutional projection, patch size has to be in [8, 16]")

    def forward(self, x, padding_size=None):
        B, C, H, W = x.shape
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)

        return x, (Hp, Wp)





class SimA(nn.Module):
    """ SimA attention block
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        # SimA: replace the softmax with l1 normalization of q and k along the token dimension
        k = F.normalize(k, p=1.0, dim=-2)
        q = F.normalize(q, p=1.0, dim=-2)
        # Pick the cheaper multiplication order: (q k^T) v is O(N^2 d) per head, q (k^T v) is O(N d^2)
        if (N / (C // self.num_heads)) < 1:
            x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
        else:
            x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)


        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}


class SimABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 num_tokens=196, eta=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SimA(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)

        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x


class SimAVisionTransformer(nn.Module):
    """
    Based on timm, DeiT and XCiT code bases
    https://github.com/rwightman/pytorch-image-models/tree/master/timm
    https://github.com/facebookresearch/deit/
    https://github.com/facebookresearch/xcit/
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 cls_attn_layers=2, use_pos=True, patch_proj='linear', eta=None, tokens_norm=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            cls_attn_layers: (int) Depth of Class attention layers
            use_pos: (bool) whether to use positional encoding
            eta: (float) layerscale initialization value
            tokens_norm: (bool) Whether to normalize all tokens or just the cls_token in the CA
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim,
                                          patch_size=patch_size)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList([
            SimABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, num_tokens=num_patches, eta=eta)
            for i in range(depth)])

        
        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        self.use_pos = use_pos

        # Classifier head
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def forward_features(self, x):
        B, C, H, W = x.shape

        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos:
            pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding

        x = self.pos_drop(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x, Hp, Wp)

        x = self.norm(x)[:, 0]
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        if self.training:
            return x, x
        else:
            return x



@register_model
def sima_small_deit(pretrained=False, **kwargs):
    model = SimAVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), eta=1.0, tokens_norm=True, **kwargs)
    model.default_cfg = _cfg()
    return model
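
Since the discussion above is about a DeiT-Tiny-sized variant, here is a hypothetical registration sketch that mirrors sima_small_deit with DeiT-Tiny dimensions (embed_dim=192, num_heads=3); the name sima_tiny_deit and the lower drop path rate in the usage example are my assumptions, not part of the released code:

@register_model
def sima_tiny_deit(pretrained=False, **kwargs):
    # Hypothetical DeiT-Tiny-sized SimA (dimensions follow DeiT-Tiny: 192-dim, 3 heads, depth 12).
    model = SimAVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), eta=1.0, tokens_norm=True, **kwargs)
    model.default_cfg = _cfg()
    return model


if __name__ == "__main__":
    # Quick shape check; eval mode returns a single tensor and keeps SyncBatchNorm off its sync path.
    m = sima_tiny_deit(drop_path_rate=0.05)  # reduced regularization, as suggested above
    m.eval()
    print(m(torch.randn(1, 3, 224, 224)).shape)  # expected: torch.Size([1, 1000])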


aaab8b commented Aug 5, 2022

> Hi,
>
> Thanks for bringing this up. These are the exact numbers we calculate for FLOPs and parameters:
>
> FLOPs: 4594733568.0 Parameters: 22557304.0
>
> Could you please elaborate on how you calculated the above numbers? We use thop to calculate the FLOPs.
>
> We used ConvPatchEmbed in our code. Note that there are various approaches to computing the input tokens from the original image, and this choice is orthogonal to the main transformer architecture. But thanks for pointing out this difference! We will clarify it in the next arXiv version.
>
> Please let me know if you have any questions. Thanks! Have a good day!

Thanks for your reply! I used fvcore to calculate the FLOPs and parameters. I agree that the various approaches to computing the input tokens are orthogonal. But if you replace PatchEmbed with ConvPatchEmbed in the original DeiT-S, there is a performance improvement as well (I ran this experiment), so the comparison in your ablation study may not be fair.
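
A minimal sketch of the fvcore calls that produce this kind of count (the input size, eval mode, and use of the sima_small_deit definition posted above are assumptions):

import torch
from fvcore.nn import FlopCountAnalysis, flop_count_table, parameter_count

model = sima_small_deit()
model.eval()
flops = FlopCountAnalysis(model, torch.randn(1, 3, 224, 224))
print(flops.total())               # total count (fvcore treats one fused multiply-add as one flop)
print(flop_count_table(flops))     # per-module breakdown like the table posted later in this thread
print(parameter_count(model)[""])  # total number of parameters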

soroush-abbasi (Collaborator) commented:

Could you please use the above code and calculate the FLOPs again? I wonder why your numbers are inconsistent with ours (yours: 5 GFLOPs vs. ours: 4.6 GFLOPs). Could you please share the checkpoint for DeiT-S + ConvPatchEmbed? How much is the improvement? I can evaluate it and add it to the next version of the arXiv paper.


aaab8b commented Aug 8, 2022

> Could you please use the above code and calculate the FLOPs again? I wonder why your numbers are inconsistent with ours (yours: 5 GFLOPs vs. ours: 4.6 GFLOPs). Could you please share the checkpoint for DeiT-S + ConvPatchEmbed? How much is the improvement? I can evaluate it and add it to the next version of the arXiv paper.

Here is the FLOPs and parameter breakdown for DeiT-S -> SimA with ConvPatchEmbed, calculated with fvcore.

module #parameters or shape #flops
model 22.588M 4.967G
cls_token (1, 1, 384)
patch_embed.proj 0.874M 0.412G
patch_embed.proj.0 1.392K 19.268M
patch_embed.proj.0.0 1.296K 16.257M
patch_embed.proj.0.0.weight (48, 3, 3, 3)
patch_embed.proj.0.1 96 3.011M
patch_embed.proj.0.1.weight (48,)
patch_embed.proj.0.1.bias (48,)
patch_embed.proj.2 41.664K 0.132G
patch_embed.proj.2.0 41.472K 0.13G
patch_embed.proj.2.0.weight (96, 48, 3, 3)
patch_embed.proj.2.1 0.192K 1.505M
patch_embed.proj.2.1.weight (96,)
patch_embed.proj.2.1.bias (96,)
patch_embed.proj.4 0.166M 0.131G
patch_embed.proj.4.0 0.166M 0.13G
patch_embed.proj.4.0.weight (192, 96, 3, 3)
patch_embed.proj.4.1 0.384K 0.753M
patch_embed.proj.4.1.weight (192,)
patch_embed.proj.4.1.bias (192,)
patch_embed.proj.6 0.664M 0.13G
patch_embed.proj.6.0 0.664M 0.13G
patch_embed.proj.6.0.weight (384, 192, 3, 3)
patch_embed.proj.6.1 0.768K 0.376M
patch_embed.proj.6.1.weight (384,)
patch_embed.proj.6.1.bias (384,)
blocks 21.303M 4.55G
blocks.0 1.775M 0.379G
blocks.0.gamma1 (384,)
blocks.0.gamma2 (384,)
blocks.0.norm1 0.768K 0.378M
blocks.0.norm1.weight (384,)
blocks.0.norm1.bias (384,)
blocks.0.attn 0.591M 0.146G
blocks.0.attn.qkv 0.444M 87.146M
blocks.0.attn.proj 0.148M 29.049M
blocks.0.norm2 0.768K 0.378M
blocks.0.norm2.weight (384,)
blocks.0.norm2.bias (384,)
blocks.0.mlp 1.182M 0.232G
blocks.0.mlp.fc1 0.591M 0.116G
blocks.0.mlp.fc2 0.59M 0.116G
blocks.1 1.775M 0.379G
blocks.1.gamma1 (384,)
blocks.1.gamma2 (384,)
blocks.1.norm1 0.768K 0.378M
blocks.1.norm1.weight (384,)
blocks.1.norm1.bias (384,)
blocks.1.attn 0.591M 0.146G
blocks.1.attn.qkv 0.444M 87.146M
blocks.1.attn.proj 0.148M 29.049M
blocks.1.norm2 0.768K 0.378M
blocks.1.norm2.weight (384,)
blocks.1.norm2.bias (384,)
blocks.1.mlp 1.182M 0.232G
blocks.1.mlp.fc1 0.591M 0.116G
blocks.1.mlp.fc2 0.59M 0.116G
blocks.2 1.775M 0.379G
blocks.2.gamma1 (384,)
blocks.2.gamma2 (384,)
blocks.2.norm1 0.768K 0.378M
blocks.2.norm1.weight (384,)
blocks.2.norm1.bias (384,)
blocks.2.attn 0.591M 0.146G
blocks.2.attn.qkv 0.444M 87.146M
blocks.2.attn.proj 0.148M 29.049M
blocks.2.norm2 0.768K 0.378M
blocks.2.norm2.weight (384,)
blocks.2.norm2.bias (384,)
blocks.2.mlp 1.182M 0.232G
blocks.2.mlp.fc1 0.591M 0.116G
blocks.2.mlp.fc2 0.59M 0.116G
blocks.3 1.775M 0.379G
blocks.3.gamma1 (384,)
blocks.3.gamma2 (384,)
blocks.3.norm1 0.768K 0.378M
blocks.3.norm1.weight (384,)
blocks.3.norm1.bias (384,)
blocks.3.attn 0.591M 0.146G
blocks.3.attn.qkv 0.444M 87.146M
blocks.3.attn.proj 0.148M 29.049M
blocks.3.norm2 0.768K 0.378M
blocks.3.norm2.weight (384,)
blocks.3.norm2.bias (384,)
blocks.3.mlp 1.182M 0.232G
blocks.3.mlp.fc1 0.591M 0.116G
blocks.3.mlp.fc2 0.59M 0.116G
blocks.4 1.775M 0.379G
blocks.4.gamma1 (384,)
blocks.4.gamma2 (384,)
blocks.4.norm1 0.768K 0.378M
blocks.4.norm1.weight (384,)
blocks.4.norm1.bias (384,)
blocks.4.attn 0.591M 0.146G
blocks.4.attn.qkv 0.444M 87.146M
blocks.4.attn.proj 0.148M 29.049M
blocks.4.norm2 0.768K 0.378M
blocks.4.norm2.weight (384,)
blocks.4.norm2.bias (384,)
blocks.4.mlp 1.182M 0.232G
blocks.4.mlp.fc1 0.591M 0.116G
blocks.4.mlp.fc2 0.59M 0.116G
blocks.5 1.775M 0.379G
blocks.5.gamma1 (384,)
blocks.5.gamma2 (384,)
blocks.5.norm1 0.768K 0.378M
blocks.5.norm1.weight (384,)
blocks.5.norm1.bias (384,)
blocks.5.attn 0.591M 0.146G
blocks.5.attn.qkv 0.444M 87.146M
blocks.5.attn.proj 0.148M 29.049M
blocks.5.norm2 0.768K 0.378M
blocks.5.norm2.weight (384,)
blocks.5.norm2.bias (384,)
blocks.5.mlp 1.182M 0.232G
blocks.5.mlp.fc1 0.591M 0.116G
blocks.5.mlp.fc2 0.59M 0.116G
blocks.6 1.775M 0.379G
blocks.6.gamma1 (384,)
blocks.6.gamma2 (384,)
blocks.6.norm1 0.768K 0.378M
blocks.6.norm1.weight (384,)
blocks.6.norm1.bias (384,)
blocks.6.attn 0.591M 0.146G
blocks.6.attn.qkv 0.444M 87.146M
blocks.6.attn.proj 0.148M 29.049M
blocks.6.norm2 0.768K 0.378M
blocks.6.norm2.weight (384,)
blocks.6.norm2.bias (384,)
blocks.6.mlp 1.182M 0.232G
blocks.6.mlp.fc1 0.591M 0.116G
blocks.6.mlp.fc2 0.59M 0.116G
blocks.7 1.775M 0.379G
blocks.7.gamma1 (384,)
blocks.7.gamma2 (384,)
blocks.7.norm1 0.768K 0.378M
blocks.7.norm1.weight (384,)
blocks.7.norm1.bias (384,)
blocks.7.attn 0.591M 0.146G
blocks.7.attn.qkv 0.444M 87.146M
blocks.7.attn.proj 0.148M 29.049M
blocks.7.norm2 0.768K 0.378M
blocks.7.norm2.weight (384,)
blocks.7.norm2.bias (384,)
blocks.7.mlp 1.182M 0.232G
blocks.7.mlp.fc1 0.591M 0.116G
blocks.7.mlp.fc2 0.59M 0.116G
blocks.8 1.775M 0.379G
blocks.8.gamma1 (384,)
blocks.8.gamma2 (384,)
blocks.8.norm1 0.768K 0.378M
blocks.8.norm1.weight (384,)
blocks.8.norm1.bias (384,)
blocks.8.attn 0.591M 0.146G
blocks.8.attn.qkv 0.444M 87.146M
blocks.8.attn.proj 0.148M 29.049M
blocks.8.norm2 0.768K 0.378M
blocks.8.norm2.weight (384,)
blocks.8.norm2.bias (384,)
blocks.8.mlp 1.182M 0.232G
blocks.8.mlp.fc1 0.591M 0.116G
blocks.8.mlp.fc2 0.59M 0.116G
blocks.9 1.775M 0.379G
blocks.9.gamma1 (384,)
blocks.9.gamma2 (384,)
blocks.9.norm1 0.768K 0.378M
blocks.9.norm1.weight (384,)
blocks.9.norm1.bias (384,)
blocks.9.attn 0.591M 0.146G
blocks.9.attn.qkv 0.444M 87.146M
blocks.9.attn.proj 0.148M 29.049M
blocks.9.norm2 0.768K 0.378M
blocks.9.norm2.weight (384,)
blocks.9.norm2.bias (384,)
blocks.9.mlp 1.182M 0.232G
blocks.9.mlp.fc1 0.591M 0.116G
blocks.9.mlp.fc2 0.59M 0.116G
blocks.10 1.775M 0.379G
blocks.10.gamma1 (384,)
blocks.10.gamma2 (384,)
blocks.10.norm1 0.768K 0.378M
blocks.10.norm1.weight (384,)
blocks.10.norm1.bias (384,)
blocks.10.attn 0.591M 0.146G
blocks.10.attn.qkv 0.444M 87.146M
blocks.10.attn.proj 0.148M 29.049M
blocks.10.norm2 0.768K 0.378M
blocks.10.norm2.weight (384,)
blocks.10.norm2.bias (384,)
blocks.10.mlp 1.182M 0.232G
blocks.10.mlp.fc1 0.591M 0.116G
blocks.10.mlp.fc2 0.59M 0.116G
blocks.11 1.775M 0.379G
blocks.11.gamma1 (384,)
blocks.11.gamma2 (384,)
blocks.11.norm1 0.768K 0.378M
blocks.11.norm1.weight (384,)
blocks.11.norm1.bias (384,)
blocks.11.attn 0.591M 0.146G
blocks.11.attn.qkv 0.444M 87.146M
blocks.11.attn.proj 0.148M 29.049M
blocks.11.norm2 0.768K 0.378M
blocks.11.norm2.weight (384,)
blocks.11.norm2.bias (384,)
blocks.11.mlp 1.182M 0.232G
blocks.11.mlp.fc1 0.591M 0.116G
blocks.11.mlp.fc2 0.59M 0.116G
norm 0.768K 0.378M
norm.weight (384,)
norm.bias (384,)
head 0.385M 0.384M
head.weight (1000, 384)
head.bias (1000,)
pos_embeder.token_projection 24.96K 4.817M
pos_embeder.token_projection.weight (384, 64, 1, 1)
pos_embeder.token_projection.bias (384,)

I haven't finished training DeiT-S + ConvPatchEmbed yet, but the validation accuracy is already considerably higher than the original DeiT-S.
[attached image]

soroush-abbasi (Collaborator) commented:

Thanks for running this experiment! Let's see the final converged accuracy.
