-
Notifications
You must be signed in to change notification settings - Fork 47
/
run.py
182 lines (131 loc) · 7.93 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/bin/env python
import getopt
import math
import numpy
import PIL
import PIL.Image
import sys
import torch
##########################################################
# This script only runs inference, so autograd is switched off globally to avoid
# the cost of tracking gradients.
torch.set_grad_enabled(False)

# Keep the cuDNN backend explicitly enabled for computational performance.
torch.backends.cudnn.enabled = True

##########################################################

# Defaults; each one can be overridden with the command-line option of the same name.
args_strModel = 'sintel-final'  # 'sintel-final', or 'sintel-clean', or 'chairs-final', or 'chairs-clean', or 'kitti-final'
args_strOne = './images/one.png'  # path to the first frame
args_strTwo = './images/two.png'  # path to the second frame
args_strOut = './out.flo'  # path to where the output should be stored

objParsed = getopt.getopt(sys.argv[1:], '', ['model=', 'one=', 'two=', 'out='])[0]

for strOption, strArg in objParsed:
    if strArg == '':
        continue  # an empty value leaves the default in place
    # end

    if strOption == '--model':
        args_strModel = strArg
    elif strOption == '--one':
        args_strOne = strArg
    elif strOption == '--two':
        args_strTwo = strArg
    elif strOption == '--out':
        args_strOut = strArg
    # end
# end
##########################################################
# Cache of normalized sampling grids. A grid only depends on the spatial size of
# the flow and on the device it lives on, so those form the key.
backwarp_tenGrid = {}


def backwarp(tenInput, tenFlow):
    """Sample tenInput at the positions displaced by tenFlow (backward warping).

    Args:
        tenInput: 4-D tensor (batch, channels, height, width) to sample from.
        tenFlow: 4-D tensor (batch, 2, height, width); channel 0 is the
            horizontal and channel 1 the vertical displacement, in pixels of
            tenInput.

    Returns:
        The bilinearly sampled tensor. Positions that fall outside tenInput are
        clamped to its border.
    """
    intHeight, intWidth = tenFlow.shape[2], tenFlow.shape[3]
    objKey = (intHeight, intWidth, tenFlow.device)

    if objKey not in backwarp_tenGrid:
        tenHor = torch.linspace(-1.0, 1.0, intWidth).view(1, 1, 1, -1).repeat(1, 1, intHeight, 1)
        tenVer = torch.linspace(-1.0, 1.0, intHeight).view(1, 1, -1, 1).repeat(1, 1, 1, intWidth)

        # Place the grid on the same device as the flow instead of assuming CUDA.
        backwarp_tenGrid[objKey] = torch.cat([tenHor, tenVer], 1).to(tenFlow.device)
    # end

    # Convert pixel displacements into the [-1, 1] coordinates used by grid_sample
    # (with align_corners=True the full extent spans size - 1 pixels).
    tenFlow = torch.cat([
        tenFlow[:, 0:1, :, :] * (2.0 / (tenInput.shape[3] - 1.0)),
        tenFlow[:, 1:2, :, :] * (2.0 / (tenInput.shape[2] - 1.0)),
    ], 1)

    # grid_sample expects the coordinate pair in the last dimension.
    tenGrid = (backwarp_tenGrid[objKey] + tenFlow).permute(0, 2, 3, 1)

    return torch.nn.functional.grid_sample(input=tenInput, grid=tenGrid, mode='bilinear', padding_mode='border', align_corners=True)
# end
##########################################################
class Network(torch.nn.Module):
    """Coarse-to-fine optical flow network that loads pretrained spynet weights.

    The two frames are reduced into image pyramids; starting from a zero flow at
    the coarsest level, each level upsamples the previous flow, warps the second
    frame with it and adds a residual predicted by that level's conv stack.
    """

    def __init__(self):
        super().__init__()

        class Preprocess(torch.nn.Module):
            """Reverse the channel order, then subtract a per-channel mean and rescale."""

            def __init__(self):
                super().__init__()
            # end

            def forward(self, tenInput):
                tenMean = torch.tensor(data=[0.485, 0.456, 0.406], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)
                tenScale = torch.tensor(data=[1.0 / 0.229, 1.0 / 0.224, 1.0 / 0.225], dtype=tenInput.dtype, device=tenInput.device).view(1, 3, 1, 1)

                return (tenInput.flip([1]) - tenMean) * tenScale
            # end
        # end

        class Basic(torch.nn.Module):
            """Five 7x7 convolutions (8 -> 32 -> 64 -> 32 -> 16 -> 2 channels) with ReLUs in between."""

            def __init__(self, intLevel):
                super().__init__()

                intChannels = [8, 32, 64, 32, 16, 2]
                objLayers = []

                for intIn, intOut in zip(intChannels[:-1], intChannels[1:]):
                    if objLayers:
                        # an activation sits between consecutive convolutions, none after the last
                        objLayers.append(torch.nn.ReLU(inplace=False))
                    # end

                    objLayers.append(torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=7, stride=1, padding=3))
                # end

                self.netBasic = torch.nn.Sequential(*objLayers)
            # end

            def forward(self, tenInput):
                return self.netBasic(tenInput)
            # end
        # end

        self.netPreprocess = Preprocess()
        self.netBasic = torch.nn.ModuleList([Basic(intLevel) for intLevel in range(6)])

        # The published checkpoints name their entries 'module...'; rename them to
        # match the 'net...' attributes used here before loading.
        strUrl = 'http://content.sniklaus.com/github/pytorch-spynet/network-' + args_strModel + '.pytorch'
        objWeights = torch.hub.load_state_dict_from_url(url=strUrl, file_name='spynet-' + args_strModel)

        self.load_state_dict({strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in objWeights.items()})
    # end

    def forward(self, tenOne, tenTwo):
        # Pyramids are stored coarsest-first; index 0 is always the smallest level.
        tenPyramidOne = [self.netPreprocess(tenOne)]
        tenPyramidTwo = [self.netPreprocess(tenTwo)]

        for _ in range(5):
            tenCoarsest = tenPyramidOne[0]

            if tenCoarsest.shape[2] <= 32 and tenCoarsest.shape[3] <= 32:
                break  # small enough, no further levels are added
            # end

            tenPyramidOne.insert(0, torch.nn.functional.avg_pool2d(input=tenCoarsest, kernel_size=2, stride=2, count_include_pad=False))
            tenPyramidTwo.insert(0, torch.nn.functional.avg_pool2d(input=tenPyramidTwo[0], kernel_size=2, stride=2, count_include_pad=False))
        # end

        # Start from a zero flow at half the coarsest resolution so that the first
        # upsampling step below lands on the coarsest level.
        tenCoarsest = tenPyramidOne[0]
        tenFlow = tenCoarsest.new_zeros([tenCoarsest.shape[0], 2, tenCoarsest.shape[2] // 2, tenCoarsest.shape[3] // 2])

        for netLevel, tenLevelOne, tenLevelTwo in zip(self.netBasic, tenPyramidOne, tenPyramidTwo):
            # Doubling the resolution also doubles the displacement magnitudes.
            tenUpsampled = torch.nn.functional.interpolate(input=tenFlow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0

            # Odd-sized levels are one row / column larger than the upsampled flow.
            if tenUpsampled.shape[2] != tenLevelOne.shape[2]:
                tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[0, 0, 0, 1], mode='replicate')
            # end

            if tenUpsampled.shape[3] != tenLevelOne.shape[3]:
                tenUpsampled = torch.nn.functional.pad(input=tenUpsampled, pad=[0, 1, 0, 0], mode='replicate')
            # end

            tenWarped = backwarp(tenInput=tenLevelTwo, tenFlow=tenUpsampled)
            tenResidual = netLevel(torch.cat([tenLevelOne, tenWarped, tenUpsampled], 1))

            tenFlow = tenResidual + tenUpsampled
        # end

        return tenFlow
    # end
# end
# Lazily created network instance, shared by all calls to estimate().
netNetwork = None

##########################################################

def estimate(tenOne, tenTwo):
    """Estimate the optical flow from the first frame to the second.

    The network and the inputs are placed on the CUDA device when one is
    available and on the CPU otherwise.

    Args:
        tenOne: first frame, a tensor viewable as (3, height, width).
        tenTwo: second frame, same height and width as tenOne.

    Returns:
        A CPU tensor of shape (2, height, width) holding the horizontal and
        vertical flow at the input resolution.

    Raises:
        AssertionError: if the frames differ in size or are not 1024x416.
    """
    global netNetwork

    objDevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if netNetwork is None:
        netNetwork = Network().to(objDevice).eval()
    # end

    assert(tenOne.shape[1] == tenTwo.shape[1])
    assert(tenOne.shape[2] == tenTwo.shape[2])

    intWidth = tenOne.shape[2]
    intHeight = tenOne.shape[1]

    assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
    assert(intHeight == 416) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue

    tenPreprocessedOne = tenOne.to(objDevice).view(1, 3, intHeight, intWidth)
    tenPreprocessedTwo = tenTwo.to(objDevice).view(1, 3, intHeight, intWidth)

    # The network runs on sizes rounded up to the next multiple of 32.
    intPreprocessedWidth = int(math.ceil(intWidth / 32.0) * 32.0)
    intPreprocessedHeight = int(math.ceil(intHeight / 32.0) * 32.0)

    tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
    tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)

    tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False)

    # The flow was predicted on the resized frames; rescale the displacements so
    # they are expressed in pixels of the original resolution.
    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)

    return tenFlow[0, :, :, :].cpu()
# end
##########################################################
if __name__ == '__main__':
    # Load both frames as float32 tensors of shape (3, height, width) with the
    # channel order reversed and values scaled by 1 / 255.
    tenFrames = []

    for strPath in (args_strOne, args_strTwo):
        with PIL.Image.open(strPath) as objImage:
            # Force three 8-bit channels so RGBA, grayscale or palette images
            # do not break the channel reversal below.
            npyImage = numpy.array(objImage.convert('RGB'))
        # end

        tenFrames.append(torch.FloatTensor(numpy.ascontiguousarray(npyImage[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))))
    # end

    tenOne, tenTwo = tenFrames

    tenOutput = estimate(tenOne, tenTwo)

    # Write the result: the four bytes 'PIEH', then width and height as int32,
    # then the flow as float32 with the two components interleaved per pixel.
    # The context manager closes the file even if one of the writes fails.
    with open(args_strOut, 'wb') as objOutput:
        numpy.array([80, 73, 69, 72], numpy.uint8).tofile(objOutput)
        numpy.array([tenOutput.shape[2], tenOutput.shape[1]], numpy.int32).tofile(objOutput)
        numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
    # end
# end