forked from HyeongminLEE/pytorch-sepconv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TestModule.py
54 lines (46 loc) · 2.26 KB
/
TestModule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from PIL import Image
import torch
from torchvision import transforms
from math import log10
from torchvision.utils import save_image as imwrite
from torch.autograd import Variable
import os
def to_variable(x):
    """Return ``x`` wrapped in an autograd ``Variable``.

    The tensor is first moved to the CUDA device when
    ``torch.cuda.is_available()`` reports one; otherwise it is wrapped as-is.
    """
    placed = x.cuda() if torch.cuda.is_available() else x
    return Variable(placed)
class Middlebury_eval:
    """Holds a fixed list of twelve sequence names in ``im_list``."""

    def __init__(self, input_dir):
        # ``input_dir`` is accepted but not read by this constructor.
        sequence_names = [
            'Army',
            'Backyard',
            'Basketball',
            'Dumptruck',
            'Evergreen',
            'Grove',
            'Mequon',
            'Schefflera',
            'Teddy',
            'Urban',
            'Wooden',
            'Yosemite',
        ]
        self.im_list = sequence_names
class Middlebury_other:
    """Evaluate a two-frame interpolation model against ground-truth frames.

    Every sub-directory of ``input_dir`` is treated as one sequence holding
    ``frame0.jpg`` and ``frame2.jpg``; the sub-directory of the same name in
    ``gt_dir`` must hold ``frame1.jpg``. All frames are loaded eagerly in the
    constructor and kept in ``input0_list``, ``input1_list`` and ``gt_list``,
    index-aligned with ``im_list``.
    """

    def __init__(self, input_dir, gt_dir):
        """Load all sequences found under ``input_dir``.

        Args:
            input_dir: Directory whose sub-directories each contain
                ``frame0.jpg`` and ``frame2.jpg``.
            gt_dir: Directory with identically named sub-directories, each
                containing ``frame1.jpg``.
        """
        # os.listdir returns entries in arbitrary order and also lists plain
        # files (e.g. ``.DS_Store``); keep only directories and sort them so
        # runs are reproducible and stray files do not break image loading.
        self.im_list = sorted(
            name for name in os.listdir(input_dir)
            if os.path.isdir(os.path.join(input_dir, name))
        )
        self.transform = transforms.Compose([transforms.ToTensor()])
        self.input0_list = []
        self.input1_list = []
        self.gt_list = []
        for item in self.im_list:
            self.input0_list.append(
                self._load(os.path.join(input_dir, item, 'frame0.jpg')))
            self.input1_list.append(
                self._load(os.path.join(input_dir, item, 'frame2.jpg')))
            self.gt_list.append(
                self._load(os.path.join(gt_dir, item, 'frame1.jpg')))

    def _load(self, path):
        """Read the image at ``path`` as a tensor with a leading dim of 1."""
        # The context manager closes the file handle once the pixel data has
        # been converted; a bare Image.open() would leave it open.
        with Image.open(path) as img:
            return to_variable(self.transform(img).unsqueeze(0))

    @staticmethod
    def _emit(msg, logfile):
        """Print ``msg`` and, when a log file is given, append it there too."""
        print(msg, end='')
        if logfile is not None:
            logfile.write(msg)

    def Test(self, model, output_dir, logfile=None, output_name='output.png'):
        """Run ``model`` on every sequence, save the outputs and report PSNR.

        Args:
            model: Callable invoked as ``model(frame0, frame2)`` that returns
                the predicted middle frame. When ``logfile`` is given it must
                also expose ``model.epoch.item()``.
            output_dir: Directory under which one sub-directory per sequence
                is created to hold the predicted frame.
            logfile: Optional object with a ``write`` method that receives
                the same lines that are printed.
            output_name: File name used for each saved prediction.

        The per-sequence PSNR is ``-10 * log10(mse)`` (this presumes pixel
        values are scaled to a peak of 1 - confirm against the model output).
        A zero error is reported as ``inf``; with no sequences loaded the
        average is reported as ``nan`` instead of dividing by zero.
        """
        if logfile is not None:
            logfile.write('{:<7s}{:<3d}'.format('Epoch: ', model.epoch.item()) + '\n')

        total_psnr = 0.0
        sequences = zip(self.im_list, self.input0_list, self.input1_list, self.gt_list)
        # Evaluation only: skip autograd bookkeeping so no graph is retained
        # for each forward pass.
        with torch.no_grad():
            for name, frame0, frame2, gt in sequences:
                seq_dir = os.path.join(output_dir, name)
                os.makedirs(seq_dir, exist_ok=True)

                frame_out = model(frame0, frame2)
                diff = gt - frame_out
                mse = torch.mean(diff * diff).item()
                # log10(0) raises ValueError; a perfect match has infinite PSNR.
                psnr = float('inf') if mse == 0 else -10 * log10(mse)
                total_psnr += psnr

                # The former ``range=(0, 1)`` keyword is omitted: it only takes
                # effect together with ``normalize=True`` and newer torchvision
                # releases reject it (renamed to ``value_range``).
                imwrite(frame_out, os.path.join(seq_dir, output_name))
                self._emit('{:<15s}{:<20.16f}'.format(name + ': ', psnr) + '\n', logfile)

        count = len(self.im_list)
        av_psnr = total_psnr / count if count else float('nan')
        self._emit('{:<15s}{:<20.16f}'.format('Average: ', av_psnr) + '\n', logfile)