-
Notifications
You must be signed in to change notification settings - Fork 266
/
Copy pathinference.py
90 lines (75 loc) · 3.46 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import matplotlib.pyplot as plt
from pylab import *
import os
import sys
import cv2
from PIL import Image
from keras.preprocessing.image import *
from keras.models import load_model
import keras.backend as K
from keras.applications.imagenet_utils import preprocess_input
from models import *
def _center_pad_offsets(img_h, img_w, image_size):
    """Return (top, bottom, left, right) padding that centers an image.

    The padding places an ``img_h`` x ``img_w`` image in the middle of a
    canvas of ``image_size = (height, width)``. Padding is clamped at zero, so
    a dimension that is already at least as large as the target gets no
    padding.
    """
    pad_h = max(image_size[0] - img_h, 0)
    pad_w = max(image_size[1] - img_w, 0)
    # Floor division keeps every offset an int on Python 3 too: np.pad needs
    # integral widths, and the same offsets are reused for the crop box.
    top = pad_h // 2
    left = pad_w // 2
    return top, pad_h - top, left, pad_w - left


def inference(model_name, weight_file, image_size, image_list, data_dir, label_dir, return_results=True,
              save_dir=None, label_suffix='.png', data_suffix='.jpg'):
    """Run a trained segmentation model over a list of images.

    Builds the model whose constructor is named ``model_name`` and looks it up
    in this module's globals (presumably populated by ``from models import *``).
    It then loads ``weight_file`` from ``Models/<model_name>/`` next to this
    script and predicts a per-pixel class map for every image id in
    ``image_list``.

    Each input image is zero-padded, centered, up to ``image_size`` before
    prediction. The argmax class map is cropped back to the original image
    size. The result is a palette ('P') image that reuses the palette of the
    matching label file.

    Args:
        model_name: name of a model constructor available in globals().
        weight_file: weights file name inside ``Models/<model_name>/``.
        image_size: (height, width) that images are padded to, and the input
            size the model is built for.
        image_list: iterable of image ids. Surrounding whitespace, such as the
            trailing newline of lines read from a file, is stripped.
        data_dir: directory holding ``<id><data_suffix>`` input images.
        label_dir: directory holding ``<id><label_suffix>`` label images. Only
            their palette is used.
        return_results: if true, collect and return the result images.
        save_dir: if given, each result is also saved there as ``<id>.png``.
        label_suffix: file extension of the label images.
        data_suffix: file extension of the input images.

    Returns:
        List of PIL images, in ``image_list`` order. The list is empty when
        ``return_results`` is false.
    """
    image_size = tuple(image_size)
    current_dir = os.path.dirname(os.path.realpath(__file__))
    checkpoint_path = os.path.join(current_dir, 'Models', model_name, weight_file)
    batch_shape = (1,) + image_size + (3,)

    # NOTE(review): `tf` is not among this file's visible imports; presumably
    # it is re-exported by `from models import *` -- confirm.
    config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    session = tf.Session(config=config)
    K.set_session(session)

    # Build the network for the same size the images are padded to. This was
    # previously hard-coded to (512, 512, 3) regardless of image_size.
    model = globals()[model_name](batch_shape=batch_shape, input_shape=image_size + (3,))
    model.load_weights(checkpoint_path, by_name=True)
    model.summary()

    results = []
    for total, img_num in enumerate(image_list, start=1):
        img_num = img_num.strip()
        print('#%d: %s' % (total, img_num))

        # Close both files promptly; otherwise two handles per image stay open
        # until garbage collection.
        with Image.open(os.path.join(data_dir, img_num + data_suffix)) as img:
            image = img_to_array(img)
        with Image.open(os.path.join(label_dir, img_num + label_suffix)) as label:
            palette = label.getpalette()  # None if the label is not a palette image

        img_h, img_w = image.shape[0:2]
        # NOTE(review): images larger than image_size get no padding and are
        # fed at their own size, which a fixed batch_shape may reject.
        top, bottom, left, right = _center_pad_offsets(img_h, img_w, image_size)
        image = np.pad(image, ((top, bottom), (left, right), (0, 0)), 'constant', constant_values=0.)

        batch = np.expand_dims(image, axis=0)
        # NOTE(review): padding is applied before preprocess_input, so the
        # padded zeros are transformed along with the real pixels; confirm
        # this matches how the model was trained.
        batch = preprocess_input(batch)
        prediction = model.predict(batch, batch_size=1)
        # Assumes the output is (1, H, W, num_classes) -- TODO confirm. The
        # argmax over the last axis gives the per-pixel class index.
        class_map = np.argmax(np.squeeze(prediction), axis=-1).astype(np.uint8)

        result_img = Image.fromarray(class_map, mode='P')
        if palette is not None:
            result_img.putpalette(palette)
        # Undo the centered padding so the result has the input image's size.
        result_img = result_img.crop((left, top, left + img_w, top + img_h))

        if return_results:
            results.append(result_img)
        if save_dir:
            result_img.save(os.path.join(save_dir, img_num + '.png'))
    return results
if __name__ == '__main__':
    # Usage: python inference.py IMAGE_ID [IMAGE_ID ...]
    #   e.g. python inference.py 2007_000491
    # Alternative model names previously used here:
    # 'AtrousFCN_Resnet50_16s', 'Atrous_DenseNet'.
    model_name = 'DenseNet_FCN'
    weight_file = 'checkpoint_weights.hdf5'
    image_size = (512, 512)
    data_dir = os.path.expanduser('~/.keras/datasets/VOC2012/VOCdevkit/VOC2012/JPEGImages')
    label_dir = os.path.expanduser('~/.keras/datasets/VOC2012/VOCdevkit/VOC2012/SegmentationClass')
    image_list = sys.argv[1:]
    if not image_list:
        # Exit before building the model and loading weights; with no ids
        # that work would produce nothing.
        sys.exit('usage: %s IMAGE_ID [IMAGE_ID ...]  (e.g. 2007_000491)' % sys.argv[0])
    results = inference(model_name, weight_file, image_size, image_list, data_dir, label_dir)
    for result in results:
        # Image.show() lost its `command` parameter in Pillow 9.0, so passing
        # command=None raises TypeError on current Pillow.
        result.show(title='result')