-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathblocks.py
65 lines (51 loc) · 2.11 KB
/
blocks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseBlock(nn.Module):
    """Bottleneck block: 1x1 expansion -> 3x3 depthwise conv -> 1x1 projection.

    The first two stages are each followed by batch norm and ReLU6; the
    projection is followed by batch norm only.  When the block keeps both
    the resolution and the channel count, the input is added back onto the
    output as a residual shortcut.

    Class attributes:
        alpha: width multiplier used to build thinner models.  It is read
            when an instance is constructed and scales both channel counts
            as ``int(alpha * channels)``.  Changing it on the class affects
            blocks created afterwards.
    """

    alpha = 1

    def __init__(self, input_channel, output_channel, t=6, downsample=False):
        """Build the three conv/batch-norm stages.

        Args:
            input_channel: input channel count before ``alpha`` is applied.
            output_channel: output channel count before ``alpha`` is applied.
            t: expansion factor; the expansion layer has
                ``t * int(alpha * input_channel)`` channels.
            downsample: if true, the depthwise conv uses stride 2 and the
                residual shortcut is disabled.

        Raises:
            ValueError: if ``alpha`` truncates either channel count below 1,
                or if ``t`` yields fewer than one expansion channel.
        """
        super(BaseBlock, self).__init__()
        self.stride = 2 if downsample else 1
        self.downsample = downsample
        # The residual add needs matching shapes: same resolution (no
        # downsampling) and same channel count.  Both counts are scaled by
        # the same alpha below, so comparing the unscaled values is enough.
        self.shortcut = (not downsample) and (input_channel == output_channel)

        # Apply the width multiplier (truncating towards zero).
        scaled_in = int(self.alpha * input_channel)
        scaled_out = int(self.alpha * output_channel)
        if scaled_in < 1 or scaled_out < 1:
            # Fail here with a clear message: a zero output width would
            # otherwise silently build zero-channel layers, and a zero input
            # width only fails later with an unrelated error about groups.
            raise ValueError(
                "alpha={} reduces channels ({}, {}) to ({}, {}); "
                "each scaled channel count must be at least 1".format(
                    self.alpha, input_channel, output_channel,
                    scaled_in, scaled_out))

        # Width of the expansion layer on the main path.
        expanded = t * scaled_in
        if expanded < 1:
            raise ValueError(
                "expansion factor t={} gives {} expansion channels; "
                "at least 1 is required".format(t, expanded))

        # 1x1 pointwise conv: expand to the wider representation.
        self.conv1 = nn.Conv2d(scaled_in, expanded, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(expanded)

        # 3x3 depthwise conv: groups == channels, so each channel is
        # filtered independently; stride 2 here performs the downsampling.
        self.conv2 = nn.Conv2d(expanded, expanded, kernel_size=3,
                               stride=self.stride, padding=1,
                               groups=expanded, bias=False)
        self.bn2 = nn.BatchNorm2d(expanded)

        # 1x1 pointwise conv: project down to the output width.
        self.conv3 = nn.Conv2d(expanded, scaled_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(scaled_out)

    def forward(self, inputs):
        """Run the block on ``inputs`` and return the result.

        ``inputs`` must have ``int(alpha * input_channel)`` channels, i.e.
        the channel count ``conv1`` was built with.
        """
        # Main path: expand -> depthwise -> project (no activation after
        # the final batch norm).
        x = F.relu6(self.bn1(self.conv1(inputs)), inplace=True)
        x = F.relu6(self.bn2(self.conv2(x)), inplace=True)
        x = self.bn3(self.conv3(x))

        # Shortcut path: residual add only when shapes are guaranteed equal.
        if self.shortcut:
            x = x + inputs
        return x
if __name__ == "__main__":
    # Smoke test: push one CIFAR-10 image through a downsampling block and
    # print the block structure plus the output shape and value range.
    import torchvision.transforms as transforms
    from torchvision.datasets import CIFAR10

    # Convert to a tensor, then normalize every channel with mean 0.5 / std 0.5.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = CIFAR10("~/dataset/cifar10", transform=preprocess)

    # First sample's image, with a leading batch dimension added.
    first_sample = dataset[0]
    batch = first_sample[0].unsqueeze(0)
    print(batch.shape)

    # With alpha = 0.5 the nominal (6, 5) channels become (3, 2), so the
    # block accepts the 3-channel image.
    BaseBlock.alpha = 0.5
    block = BaseBlock(6, 5, downsample=True)
    output = block(batch)

    print(block)
    print(output.shape, output.max(), output.min())