-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathDWSRx2.py
79 lines (74 loc) · 3.25 KB
/
DWSRx2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import print_function
import os, glob, time
import tensorflow as tf
import numpy as np
from netx2 import model
import cv2
import pywt as pw
import ntpath
'''
Developed by Tiantong for NTIER CVPR 2017 SR Competition
tong.renly@gmail.com
'''
# Module-level configuration for the x2 test run.
#
# TEST_EXPERIMENT selects the "restore checkpoint and run inference" branch of
# the __main__ section; when it is False that section only prints a notice.
TEST_EXPERIMENT = True
# X2
TEST_MODEL_PATH = './Weightx2/x2.ckpt'
TEST_SAVE_PATH = './Resultx2Lum/'
print('Developed by Tiantong for NTIER CVPR 2017 SR Competition, team iPAL')
print('For track 1: bicubic downsample x2')

# raw_input only exists on Python 2; fall back to input so the prompt also
# works when the script is run under Python 3.
try:
    _read_line = raw_input
except NameError:
    _read_line = input

TEST_PATH = _read_line('Please enter the testing path [hit enter to run official DIV2K set]:')
if not os.path.exists(TEST_PATH):
    # Covers both an empty answer (plain Enter) and a path that does not exist.
    print('Using official DIV2K set: ./Testx2Lum')
    TEST_PATH = './Testx2Lum/'
# Joining with an empty component appends a trailing separator only when one
# is missing, so a user-supplied "some/dir" keeps working with code that
# builds glob patterns by concatenating onto TEST_PATH.
TEST_PATH = os.path.join(TEST_PATH, '')

# Wavelet name handed to pywt.dwt2 / pywt.idwt2 ('db1' is Daubechies-1, i.e. Haar).
WV = 'db1'
def path_leaf(path):
    """Return the last component of *path*.

    ``ntpath`` is used for the split, so both ``/`` and ``\\`` count as
    separators. When *path* ends with a separator the split leaves an empty
    tail; in that case the base name of the remaining head is returned.
    """
    parent, leaf = ntpath.split(path)
    if leaf:
        return leaf
    return ntpath.basename(parent)
if __name__ == '__main__':
    # Script entry point: restore the checkpoint at TEST_MODEL_PATH, run the
    # network on the wavelet sub-bands of every PNG under TEST_PATH, and write
    # the reconstructed images to TEST_SAVE_PATH.

    # The placeholder is created without a shape argument.
    test_input = tf.placeholder(np.float32)
    # Feeding Forward
    # model() returns three values; only the first one is used here.
    test_output, _, _ = model(test_input)
    # The with-block closes the session on exit, so no explicit close is needed.
    with tf.Session(config=tf.ConfigProto()) as sess:
        tf.initialize_all_variables().run()
        saver = tf.train.Saver(tf.all_variables())
        if TEST_EXPERIMENT:
            print('>>>>>>>>Resuming Experiments For Testing')
            saver.restore(sess, TEST_MODEL_PATH)
            if not os.path.exists(str(TEST_SAVE_PATH)):
                os.makedirs(str(TEST_SAVE_PATH))
            # os.path.join works whether or not TEST_PATH ends with a
            # separator; plain concatenation silently matched nothing for
            # a path such as "some/dir".
            for testImgName in glob.glob(os.path.join(TEST_PATH, '*.png')):
                img_name = path_leaf(testImgName)
                print('Test Image %s' % img_name)
                # Flag 0 asks OpenCV for a single-channel grayscale image.
                testBBImg = cv2.imread(testImgName, 0)
                if testBBImg is None:
                    # imread reports failure by returning None instead of
                    # raising; skip the file rather than crash in dwt2.
                    print('Skipping unreadable image %s' % img_name)
                    continue
                # Single-level 2-D DWT: approximation plus three detail bands.
                tcA, (tcH, tcV, tcD) = pw.dwt2(testBBImg, WV)
                # Same 1/255 scaling for every band before feeding the network.
                in_bands = [band.astype(np.float32) / 255
                            for band in (tcA, tcH, tcV, tcD)]
                # (4, h, w) -> (h, w, 4) -> (1, h, w, 4)
                test_elem = np.rollaxis(np.array(in_bands), 0, 3)
                test_data = test_elem[np.newaxis, ...]
                start_time = time.time()
                output_data = sess.run([test_output],
                                       feed_dict={test_input: test_data})
                duration = time.time() - start_time
                predicted = output_data[0]
                # The network output is added to the input coefficients, and
                # both are scaled back by 255 before the inverse transform.
                out_bands = [predicted[0, :, :, k] * 255 + in_bands[k] * 255
                             for k in range(4)]
                srcoeffs = (out_bands[0], tuple(out_bands[1:]))
                sr_img = pw.idwt2(srcoeffs, WV)
                out_path = os.path.join(str(TEST_SAVE_PATH), str(img_name))
                # imwrite returns False on failure instead of raising.
                if not cv2.imwrite(out_path, sr_img):
                    print('Failed to write %s' % out_path)
                print('Image%s, processing time%s' % (img_name,
                                                      str(duration)))
        else:
            print('>>>>>>>>>Wrong script for New Experiment!')
    print('Testing finished, the SR luminance information is stored at ./Resultx2Lum\n'
          'Please Run FinalColorSRx2.m in Matlab to generate color SR results.')