forked from keyu-tian/SparK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom.py
89 lines (70 loc) · 3.39 KB
/
custom.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from typing import List
from timm.models.registry import register_model
class YourConvNet(nn.Module):
    """
    This is a template for your custom ConvNet.
    It is required to implement the following three methods: `get_downsample_ratio`, `get_feature_map_channels`, `forward`.
    You can refer to the implementations in `pretrain/models/resnet.py` for an example.
    """
    # NOTE: keep file paths in this (non-raw) docstring written with forward
    # slashes; a backslash path would be parsed as escape sequences
    # (e.g. a backslash followed by `r` becomes a carriage return).

    def get_downsample_ratio(self) -> int:
        """
        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).

        :return: the TOTAL downsample ratio of the ConvNet.
            E.g., for a ResNet-50, this should return 32.
        :raises NotImplementedError: always, until you implement it.
        """
        raise NotImplementedError

    def get_feature_map_channels(self) -> List[int]:
        """
        This func would ONLY be used in `SparseEncoder's __init__` (see `pretrain/encoder.py`).

        :return: a list of the number of channels of each feature map.
            E.g., for a ResNet-50, this should return [256, 512, 1024, 2048].
        :raises NotImplementedError: always, until you implement it.
        """
        raise NotImplementedError

    def forward(self, inp_bchw: torch.Tensor, hierarchical: bool = False):
        """
        The forward with `hierarchical=True` would ONLY be used in `SparseEncoder.forward` (see `pretrain/encoder.py`).

        :param inp_bchw: input image tensor, shape: (batch_size, channels, height, width).
        :param hierarchical: return the logits (not hierarchical), or the feature maps (hierarchical).
        :return:
            - hierarchical == False: return the logits of the classification task, shape: (batch_size, num_classes).
            - hierarchical == True: return a list of all feature maps, which should have the same length as the return value of `get_feature_map_channels`.
              E.g., for a ResNet-50, it should return a list [1st_feat_map, 2nd_feat_map, 3rd_feat_map, 4th_feat_map].
              for an input size of 224, the shapes are [(B, 256, 56, 56), (B, 512, 28, 28), (B, 1024, 14, 14), (B, 2048, 7, 7)]
        :raises NotImplementedError: always, until you implement it.
        """
        raise NotImplementedError
@register_model
def your_convnet_small(pretrained=False, **kwargs):
    """
    Factory for the template ConvNet, registered via timm's `@register_model`
    so that `create_model('your_convnet_small')` can look it up by name.

    :param pretrained: accepted in the signature (presumably because timm's
        `create_model` passes it — confirm for your timm version); it is NOT used here.
    :param kwargs: forwarded to `YourConvNet(...)` once the guard below is removed.
    :return: a `YourConvNet` instance (only after the guard below is removed).
    :raises NotImplementedError: always, while this remains a template.
    """
    # Template guard: delete this `raise` after implementing `YourConvNet`.
    # The message explains the failure to whoever calls `create_model`,
    # instead of surfacing a bare, context-free NotImplementedError.
    raise NotImplementedError(
        'your_convnet_small is a template: implement YourConvNet, then remove this raise'
    )
    return YourConvNet(**kwargs)  # unreachable until the guard above is removed
@torch.no_grad()
def convnet_test(model_name='your_convnet_small', cnn=None):
    """
    Sanity-check a ConvNet against the interface `SparseEncoder` relies on.

    Runs one hierarchical forward pass on a random (4, 3, 224, 224) input and checks:
      - the result is a non-empty list with one entry per `get_feature_map_channels()` item;
      - every feature map is 4-D and has the advertised channel count;
      - the last feature map's spatial size equals the input size // `get_downsample_ratio()`.

    :param model_name: timm model name to build when `cnn` is not given
        (defaults to the template's registered name).
    :param cnn: an already-constructed model to check; when given, timm is not used.
    :raises AssertionError: if the model violates any of the checks above.
    """
    if cnn is None:
        # Imported lazily so that checking a ready-made model does not require timm.
        from timm.models import create_model
        cnn = create_model(model_name)

    # Query the interface once and reuse the values for printing and checking.
    downsample_ratio = cnn.get_downsample_ratio()
    feature_map_channels = cnn.get_feature_map_channels()
    print('get_downsample_ratio:', downsample_ratio)
    print('get_feature_map_channels:', feature_map_channels)

    # check the forward function; a single forward pass serves every check below
    B, C, H, W = 4, 3, 224, 224
    inp = torch.rand(B, C, H, W)
    feats = cnn(inp, hierarchical=True)
    assert isinstance(feats, list), f'hierarchical forward must return a list, got {type(feats).__name__}'
    assert len(feats) == len(feature_map_channels), (
        f'got {len(feats)} feature maps but get_feature_map_channels() lists {len(feature_map_channels)}'
    )
    assert feats, 'hierarchical forward returned no feature maps'
    print([tuple(t.shape) for t in feats])

    # check rank and channel number (rank first, so spatial indexing below is safe)
    for idx, (feat, ch) in enumerate(zip(feats, feature_map_channels)):
        assert feat.ndim == 4, f'feature map {idx} must be 4-D, got ndim={feat.ndim}'
        assert feat.shape[1] == ch, f'feature map {idx} has {feat.shape[1]} channels, expected {ch}'

    # check the downsample ratio on the last (deepest) feature map
    assert feats[-1].shape[-2] == H // downsample_ratio, (
        f'last feature map height {feats[-1].shape[-2]} != {H} // {downsample_ratio}'
    )
    assert feats[-1].shape[-1] == W // downsample_ratio, (
        f'last feature map width {feats[-1].shape[-1]} != {W} // {downsample_ratio}'
    )
if __name__ == "__main__":
    # Run the interface sanity check on the registered template model
    # when this file is executed directly (not on import).
    convnet_test()