-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsplit_TrainValidTest.py
51 lines (44 loc) · 1.95 KB
/
split_TrainValidTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import cv2
import os
import numpy as np
import random
def seperate(
    image_path="D:/Dataset/CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img/",
    txt="D:/Dataset/CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-to-CelebA-mapping_png.txt",
    ca_train_path="C:/Users/USER/Downloads/misf-main/misf-main/celebA/train/",
    ca_valid_path="C:/Users/USER/Downloads/misf-main/misf-main/celebA/valid/",
    ca_test_path="C:/Users/USER/Downloads/misf-main/misf-main/celebA/test/",
    hq_train_path="D:/Dataset/CelebAMask-HQ/train/",
    hq_valid_path="D:/Dataset/CelebAMask-HQ/valid/",
    hq_test_path="D:/Dataset/CelebAMask-HQ/test/",
    write_splits=("test",),
    stop_on_unmatched=True,
):
    """Split CelebAMask-HQ images into train/valid/test folders mirroring a CelebA split.

    Every line of the mapping file *txt* after the first (header) line is
    expected to hold three whitespace-separated fields ``idx orig_idx orig_file``.
    The split of an entry is decided by which CelebA reference directory
    contains a file named ``orig_file``, checked in train -> valid -> test order.
    Selected images are read from ``<image_path>/<idx>.jpg`` and written to
    ``<hq_<split>_path>/<idx zero-padded to 5 digits>.jpg``.

    Args:
        image_path: directory holding the CelebAMask-HQ images.
        txt: path of the CelebA-HQ -> CelebA mapping file.
        ca_train_path, ca_valid_path, ca_test_path: CelebA directories whose
            file names define the reference split.
        hq_train_path, hq_valid_path, hq_test_path: output directories;
            created if missing.
        write_splits: iterable of split names ("train", "valid", "test") whose
            images are actually written. Defaults to ("test",), matching the
            original script, in which the train/valid writes were commented out.
        stop_on_unmatched: when an entry is found in none of the reference
            directories, its padded output name is printed; if True (default,
            original behavior) processing then stops, otherwise it continues.

    Side effects: prints every mapping line, then the number of entries in each
    output directory and their sum.
    """
    out_dirs = {"train": hq_train_path, "valid": hq_valid_path, "test": hq_test_path}
    for out_dir in out_dirs.values():
        os.makedirs(out_dir, exist_ok=True)

    # List each reference directory once. The previous version re-ran
    # os.listdir on up to three directories for every mapping line
    # (O(lines x files)). Dict insertion order keeps the train -> valid -> test
    # lookup priority of the original if/elif chain.
    reference = {
        "train": set(os.listdir(ca_train_path)),
        "valid": set(os.listdir(ca_valid_path)),
        "test": set(os.listdir(ca_test_path)),
    }
    wanted = frozenset(write_splits)

    with open(txt, "r") as f:
        lines = f.readlines()[1:]  # the first line is a header

    for line in lines:
        fields = line.split()
        if not fields:  # tolerate blank / trailing lines instead of crashing
            continue
        idx, orig_idx, orig_file = fields
        print(idx, orig_idx, orig_file)

        split = next(
            (name for name, names in reference.items() if orig_file in names),
            None,
        )
        if split is None:
            print(idx.zfill(5) + ".jpg")
            if stop_on_unmatched:
                break
            continue
        if split not in wanted:
            continue

        # Decode only images that will actually be written.
        src = os.path.join(image_path, idx + ".jpg")
        image = cv2.imread(src, cv2.IMREAD_COLOR)
        if image is None:  # cv2.imread returns None for missing/unreadable files
            print("could not read", src)
            continue
        cv2.imwrite(os.path.join(out_dirs[split], idx.zfill(5) + ".jpg"), image)

    counts = {name: len(os.listdir(path)) for name, path in out_dirs.items()}
    for name, count in counts.items():
        print(name, count)
    print("sum", sum(counts.values()))
if __name__ == "__main__":
    # Script entry point: run the split with the default paths and settings.
    seperate()