-
Notifications
You must be signed in to change notification settings - Fork 9
/
Unet.py
90 lines (71 loc) · 3.27 KB
/
Unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os

from basic_blocks import *
class UnetGenerator(nn.Module):
    """U-Net style generator: four encoder levels, a bridge, four decoder levels.

    Encoder level ``k`` (1-based) produces ``num_filter * 2**(k-1)`` channels and
    the bridge produces ``num_filter * 16``. Each decoder level upsamples with
    ``conv_trans_block``, concatenates the matching encoder activation along
    dim 1 (presumably the channel axis of NCHW input -- confirm), and fuses the
    result with ``conv_block_2``. A final 3x3 convolution (stride 1, padding 1)
    maps ``num_filter`` channels to ``out_dim``. The building blocks come from
    ``basic_blocks``.

    In ``mode="test"`` the encoder activations can be cached on disk and later
    compared against; see ``forward``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        num_filter: Channel width of the first encoder level.
    """

    # Directory holding the cached encoder activations used by ``forward`` in
    # test mode. Override on the class or on an instance to relocate the cache.
    tensor_store_dir = "tensor_store"

    # (activation name, threshold passed to ``difference``) for every cached
    # activation, listed in encoder order: down_1 .. down_4, then the bridge.
    _CACHED_FEATURES = (
        ("down_1", 0.6),
        ("down_2", 0.6),
        ("down_3", 0.8),
        ("down_4", 1.0),
        ("bridge", 1.2),
    )

    def __init__(self, in_dim, out_dim, num_filter):
        super(UnetGenerator, self).__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.num_filter = num_filter
        # One activation instance is shared by every block; LeakyReLU has no
        # parameters, so sharing it does not couple the blocks' weights.
        act_fn = nn.LeakyReLU(0.2, inplace=True)

        print("\n------Initiating U-Net------\n")

        nf = num_filter

        # Encoder: conv block, then pooling, with the width doubling per level.
        self.down_1 = conv_block_2(in_dim, nf, act_fn)
        self.pool_1 = maxpool()
        self.down_2 = conv_block_2(nf * 1, nf * 2, act_fn)
        self.pool_2 = maxpool()
        self.down_3 = conv_block_2(nf * 2, nf * 4, act_fn)
        self.pool_3 = maxpool()
        self.down_4 = conv_block_2(nf * 4, nf * 8, act_fn)
        self.pool_4 = maxpool()

        self.bridge = conv_block_2(nf * 8, nf * 16, act_fn)

        # Decoder: each up_* block receives the transposed-conv output
        # concatenated with the same-width encoder activation, hence its input
        # width is twice its output width.
        self.trans_1 = conv_trans_block(nf * 16, nf * 8, act_fn)
        self.up_1 = conv_block_2(nf * 16, nf * 8, act_fn)
        self.trans_2 = conv_trans_block(nf * 8, nf * 4, act_fn)
        self.up_2 = conv_block_2(nf * 8, nf * 4, act_fn)
        self.trans_3 = conv_trans_block(nf * 4, nf * 2, act_fn)
        self.up_3 = conv_block_2(nf * 4, nf * 2, act_fn)
        self.trans_4 = conv_trans_block(nf * 2, nf * 1, act_fn)
        self.up_4 = conv_block_2(nf * 2, nf * 1, act_fn)

        # Kept as a Sequential so state_dict keys ("out.0.*") stay unchanged.
        self.out = nn.Sequential(
            nn.Conv2d(nf, out_dim, 3, 1, 1)
        )

    def forward(self, input, mode="train", comparison=False):
        """Run the network on ``input``.

        Args:
            input: Input batch. Presumably ``(N, in_dim, H, W)`` with spatial
                sizes that survive four poolings so the skip connections line
                up -- confirm against ``basic_blocks.maxpool``.
            mode: ``"test"`` enables the on-disk activation cache; any other
                value (default ``"train"``) runs a plain forward pass.
            comparison: Only used when ``mode == "test"``. When false, the
                encoder activations and bridge are saved under
                ``tensor_store_dir``. When true, the previously saved tensors
                are loaded and each current activation is replaced by
                ``difference(current, saved, threshhold=...)`` before decoding.

        Returns:
            Output of the final convolution, with ``out_dim`` channels.

        Raises:
            FileNotFoundError: in comparison mode, if no cache has been saved.
        """
        down_1 = self.down_1(input)
        down_2 = self.down_2(self.pool_1(down_1))
        down_3 = self.down_3(self.pool_2(down_2))
        down_4 = self.down_4(self.pool_3(down_3))
        bridge = self.bridge(self.pool_4(down_4))

        if mode == "test":
            features = [down_1, down_2, down_3, down_4, bridge]
            if comparison:
                features = self._compare_with_cache(features)
            else:
                self._save_to_cache(features)
            down_1, down_2, down_3, down_4, bridge = features

        up_1 = self.up_1(torch.cat([self.trans_1(bridge), down_4], dim=1))
        up_2 = self.up_2(torch.cat([self.trans_2(up_1), down_3], dim=1))
        up_3 = self.up_3(torch.cat([self.trans_3(up_2), down_2], dim=1))
        up_4 = self.up_4(torch.cat([self.trans_4(up_3), down_1], dim=1))
        return self.out(up_4)

    def _feature_path(self, name):
        """Return the cache file path for the activation called ``name``."""
        # os.path.join reproduces 'tensor_store\\temp_<name>.pt' on Windows
        # and yields a real sub-directory path on POSIX systems.
        return os.path.join(self.tensor_store_dir, "temp_{}.pt".format(name))

    def _save_to_cache(self, features):
        """Save each tensor in ``features`` (encoder order) to its cache file."""
        os.makedirs(self.tensor_store_dir, exist_ok=True)
        for (name, _), tensor in zip(self._CACHED_FEATURES, features):
            torch.save(tensor, self._feature_path(name))

    def _compare_with_cache(self, features):
        """Return ``difference(current, saved)`` for each cached activation."""
        compared = []
        for (name, threshold), current in zip(self._CACHED_FEATURES, features):
            # Load onto the current tensor's device so the two can be combined
            # even if the cache was written from a different device.
            previous = torch.load(self._feature_path(name),
                                  map_location=current.device)
            compared.append(difference(current, previous, threshhold=threshold))
        return compared