forked from bubbliiiing/segformer-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvoc_annotation.py
98 lines (84 loc) · 4.01 KB
/
voc_annotation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import random
import numpy as np
from PIL import Image
from tqdm import tqdm
#-------------------------------------------------------#
# 想要增加测试集修改trainval_percent
# 修改train_percent用于改变验证集的比例 9:1
#
# 当前该库将测试集当作验证集使用,不单独划分测试集
#-------------------------------------------------------#
trainval_percent = 1
train_percent = 0.9
#-------------------------------------------------------#
# 指向VOC数据集所在的文件夹
# 默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit'
if __name__ == "__main__":
random.seed(0)
print("Generate txt in ImageSets.")
segfilepath = os.path.join(VOCdevkit_path, 'VOC2007/SegmentationClass')
saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Segmentation')
temp_seg = os.listdir(segfilepath)
total_seg = []
for seg in temp_seg:
if seg.endswith(".png"):
total_seg.append(seg)
num = len(total_seg)
list = range(num)
tv = int(num*trainval_percent)
tr = int(tv*train_percent)
trainval= random.sample(list,tv)
train = random.sample(trainval,tr)
print("train and val size",tv)
print("traub suze",tr)
ftrainval = open(os.path.join(saveBasePath,'trainval.txt'), 'w')
ftest = open(os.path.join(saveBasePath,'test.txt'), 'w')
ftrain = open(os.path.join(saveBasePath,'train.txt'), 'w')
fval = open(os.path.join(saveBasePath,'val.txt'), 'w')
for i in list:
name = total_seg[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print("Generate txt in ImageSets done.")
print("Check datasets format, this may take a while.")
print("检查数据集格式是否符合要求,这可能需要一段时间。")
classes_nums = np.zeros([256], np.int)
for i in tqdm(list):
name = total_seg[i]
png_file_name = os.path.join(segfilepath, name)
if not os.path.exists(png_file_name):
raise ValueError("未检测到标签图片%s,请查看具体路径下文件是否存在以及后缀是否为png。"%(png_file_name))
png = np.array(Image.open(png_file_name), np.uint8)
if len(np.shape(png)) > 2:
print("标签图片%s的shape为%s,不属于灰度图或者八位彩图,请仔细检查数据集格式。"%(name, str(np.shape(png))))
print("标签图片需要为灰度图或者八位彩图,标签的每个像素点的值就是这个像素点所属的种类。"%(name, str(np.shape(png))))
classes_nums += np.bincount(np.reshape(png, [-1]), minlength=256)
print("打印像素点的值与数量。")
print('-' * 37)
print("| %15s | %15s |"%("Key", "Value"))
print('-' * 37)
for i in range(256):
if classes_nums[i] > 0:
print("| %15s | %15s |"%(str(i), str(classes_nums[i])))
print('-' * 37)
if classes_nums[255] > 0 and classes_nums[0] > 0 and np.sum(classes_nums[1:255]) == 0:
print("检测到标签中像素点的值仅包含0与255,数据格式有误。")
print("二分类问题需要将标签修改为背景的像素点值为0,目标的像素点值为1。")
elif classes_nums[0] > 0 and np.sum(classes_nums[1:]) == 0:
print("检测到标签中仅仅包含背景像素点,数据格式有误,请仔细检查数据集格式。")
print("JPEGImages中的图片应当为.jpg文件、SegmentationClass中的图片应当为.png文件。")
print("如果格式有误,参考:")
print("https://github.com/bubbliiiing/segmentation-format-fix")