-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_cat_edge.py
74 lines (56 loc) · 2.36 KB
/
generate_cat_edge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import numpy as np
import cv2
import time
# Class id for every part name, assigned by position in the tuple below
# ('bg' -> 0, ..., 'cloth' -> 18). generate_cat_edge looks part names up here;
# the names are the suffix of the annotation filenames after the image id.
LABELS = {
    part_name: class_id
    for class_id, part_name in enumerate((
        'bg', 'skin', 'nose', 'eye_g', 'l_eye', 'r_eye',
        'l_brow', 'r_brow', 'l_ear', 'r_ear', 'mouth', 'u_lip',
        'l_lip', 'hair', 'hat', 'ear_r', 'neck_l', 'neck', 'cloth',
    ))
}
def generate_cat_edge(seg, seg_name, edge_width=3):
    """Build a categorical edge map from a single-part annotation mask.

    A pixel is an edge pixel when it is non-zero in ``seg`` and at least one
    of its 4-connected neighbours (left, right, up, down) that lies inside the
    image is zero. Edge pixels receive the class id ``LABELS[seg_name]``; the
    map is then thickened with a square ``edge_width`` x ``edge_width``
    dilation.

    Args:
        seg: 2-D mask array; any non-zero value counts as foreground.
        seg_name: Part name used to look up the class id in ``LABELS``.
        edge_width: Side length of the rectangular dilation kernel.

    Returns:
        A float64 array with ``seg.shape``: 0 for non-edge pixels and the
        class id on (dilated) edge pixels.

    Raises:
        ValueError: If ``seg`` is not 2-D.
        KeyError: If ``seg_name`` is not a key of ``LABELS``.
    """
    if seg.ndim != 2:
        raise ValueError(f"seg must be 2-D, got shape {seg.shape}")

    fg = seg != 0
    bg = ~fg

    # Vectorised replacement for four per-pixel Python loops. Each line
    # compares every pixel with one neighbour via shifted slices, which
    # naturally skips neighbours that would fall outside the image.
    boundary = np.zeros(seg.shape, dtype=bool)
    boundary[:, 1:] |= fg[:, 1:] & bg[:, :-1]   # left neighbour is background
    boundary[:, :-1] |= fg[:, :-1] & bg[:, 1:]  # right neighbour is background
    boundary[1:, :] |= fg[1:, :] & bg[:-1, :]   # upper neighbour is background
    boundary[:-1, :] |= fg[:-1, :] & bg[1:, :]  # lower neighbour is background

    # float64 (np.zeros default) to keep the same output dtype as before.
    edge = np.zeros(seg.shape)
    edge[boundary] = LABELS[seg_name]

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (edge_width, edge_width))
    return cv2.dilate(edge, kernel)
def main(
    im_path="D:/Dataset/CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img/",
    parsing_anno_path="D:/Dataset/CelebAMask-HQ/CelebAMask-HQ/CelebAMask-HQ-mask-anno_acc/",
    save_dir="D:/Dataset/CelebAMask-HQ/CelebAMask-HQ/CelebAMask-HQ-mask-anno_cat_edges/",
    out_size=473,
):
    """Write one resized categorical edge PNG per (image, part) annotation.

    For every file in ``im_path`` the image id is the file stem zero-padded to
    five digits. Each annotation in ``parsing_anno_path`` named
    ``<id>_<part>.<ext>`` for that id is converted with ``generate_cat_edge``.
    The result is resized to ``out_size`` x ``out_size`` with nearest-neighbour
    sampling and saved as ``<save_dir>/<id>_<part>.png``.

    Args:
        im_path: Directory whose file names define the image ids to process.
        parsing_anno_path: Directory holding the per-part annotation masks.
        save_dir: Output directory; created if missing.
        out_size: Side length of the saved square edge maps.
    """
    image_list = os.listdir(im_path)
    os.makedirs(save_dir, exist_ok=True)

    # Index annotations by image id once, instead of rescanning the whole
    # listing for every image (O(images * annotations) before). Splitting on
    # the first "_" keeps multi-word part names such as "l_brow" intact.
    annotations_by_id = {}
    for ann_name in os.listdir(parsing_anno_path):
        stem = os.path.splitext(ann_name)[0]
        img_id, sep, part_name = stem.partition("_")
        if sep:
            annotations_by_id.setdefault(img_id, []).append((ann_name, part_name))

    for im_name in image_list:
        start = time.perf_counter()
        parent_img_name = os.path.splitext(im_name)[0].zfill(5)
        print(parent_img_name)
        for ann_name, part_name in annotations_by_id.get(parent_img_name, []):
            annotation_path = os.path.join(parsing_anno_path, ann_name)
            parsing_anno = cv2.imread(annotation_path, cv2.IMREAD_GRAYSCALE)
            if parsing_anno is None:
                # imread returns None for unreadable/non-image files.
                print("skipping unreadable annotation:", annotation_path)
                continue
            edge = generate_cat_edge(parsing_anno, part_name)
            # Class ids are small integers; store them as explicit 8-bit labels.
            edge = edge.astype(np.uint8)
            # cv2.resize's third positional parameter is `dst`, not
            # `interpolation`; pass it by keyword so labels are not blended by
            # the default bilinear interpolation.
            edge = cv2.resize(edge, (out_size, out_size), interpolation=cv2.INTER_NEAREST)
            out_path = os.path.join(save_dir, parent_img_name + "_" + part_name + ".png")
            cv2.imwrite(out_path, edge)
        print("time :", time.perf_counter() - start)


if __name__ == "__main__":
    main()