-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
116 lines (86 loc) · 4.5 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
refer to https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/utils/metrics.py
"""
import numpy
import numpy as np
import os
import cv2
__all__ = ['SegmentationMetric']


class SegmentationMetric(object):
    """Accumulate a confusion matrix and derive semantic-segmentation metrics.

    The confusion matrix has shape ``(numClass, numClass)``; entry ``[i, j]``
    counts pixels whose ground-truth class is ``i`` and whose predicted class
    is ``j`` (rows = ground truth, columns = prediction).
    """

    def __init__(self, numClass):
        """Create an empty accumulator for ``numClass`` classes (ids 0..numClass-1)."""
        self.numClass = numClass
        self.confusionMatrix = np.zeros((self.numClass,) * 2)

    def pixelAccuracy(self):
        """Return overall pixel accuracy: correctly classified pixels / counted pixels.

        In the binary case this is PA = (TP + TN) / (TP + TN + FP + FN).
        Returns NaN if no pixels have been accumulated yet.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(self.confusionMatrix).sum() / self.confusionMatrix.sum()

    def classPixelAccuracy(self):
        """Return per-class pixel accuracy (more precisely: precision), TP / (TP + FP).

        The result is an array of length ``numClass``, e.g. ``[0.90, 0.80, 0.96]``
        for classes 0, 1, 2. A class that was never predicted has a zero column
        sum and yields NaN.
        """
        # Column sums are the number of pixels predicted as each class (TP + FP).
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.diag(self.confusionMatrix) / self.confusionMatrix.sum(axis=0)

    def meanPixelAccuracy(self):
        """Return the mean of ``classPixelAccuracy()``, skipping NaN entries.

        ``np.nanmean`` ignores NaN values (it does not treat them as 0), e.g.
        ``np.nanmean([0.90, 0.80, 0.96, nan, nan]) == (0.90 + 0.80 + 0.96) / 3``.
        """
        return np.nanmean(self.classPixelAccuracy())

    def _classIoU(self):
        """Return per-class IoU = TP / (TP + FP + FN); NaN for classes absent from both GT and prediction."""
        intersection = np.diag(self.confusionMatrix)
        # Row sum = TP + FN (ground truth), column sum = TP + FP (prediction);
        # subtracting TP once gives TP + FP + FN.
        union = (self.confusionMatrix.sum(axis=1)
                 + self.confusionMatrix.sum(axis=0)
                 - intersection)
        with np.errstate(divide='ignore', invalid='ignore'):
            return intersection / union

    def meanIntersectionOverUnion(self):
        """Return mean IoU over classes, skipping classes whose IoU is NaN."""
        return np.nanmean(self._classIoU())

    def genConfusionMatrix(self, imgPredict, imgLabel):
        """Return the (numClass, numClass) confusion matrix of one prediction/label pair.

        Same algorithm as ``fast_hist()`` in FCN's score.py. Pixels whose label
        is outside ``[0, numClass)`` (unlabeled pixels) are ignored. Prediction
        values must lie in ``[0, numClass)``: a negative value makes
        ``np.bincount`` raise, and a value >= numClass yields more than
        ``numClass ** 2`` bins so the reshape raises ValueError.
        """
        mask = (imgLabel >= 0) & (imgLabel < self.numClass)
        # Encode each (gt, pred) pair as the flat index gt * numClass + pred,
        # e.g. [0 4 10 0 5 11 10 5 15] for numClass = 4.
        # Cast to int64 first: with uint8 labels (as cv2.imread returns) the
        # product would silently wrap around once numClass * label > 255.
        label = self.numClass * imgLabel[mask].astype(np.int64) + imgPredict[mask]
        count = np.bincount(label, minlength=self.numClass ** 2)
        return count.reshape(self.numClass, self.numClass)

    def Frequency_Weighted_Intersection_over_Union(self):
        """Return frequency-weighted IoU.

        FWIoU = sum_i [(TP_i + FN_i) / total] * [TP_i / (TP_i + FP_i + FN_i)],
        summed only over classes that occur in the ground truth.
        """
        with np.errstate(divide='ignore', invalid='ignore'):
            freq = self.confusionMatrix.sum(axis=1) / self.confusionMatrix.sum()
        iu = self._classIoU()
        present = freq > 0
        return (freq[present] * iu[present]).sum()

    def addBatch(self, imgPredict, imgLabel):
        """Add the confusion matrix of one prediction/label pair (same-shape arrays)."""
        assert imgPredict.shape == imgLabel.shape, \
            'imgPredict and imgLabel must have the same shape'
        self.confusionMatrix += self.genConfusionMatrix(imgPredict, imgLabel)

    def reset(self):
        """Discard all accumulated counts."""
        self.confusionMatrix = np.zeros((self.numClass, self.numClass))
def _loadBinaryMask(path, threshold=2):
    """Read channel 0 of the image at *path* and binarize it into classes 0/1.

    Pixels with a value greater than *threshold* become 1, all others 0.
    Raises FileNotFoundError if OpenCV cannot read the image.
    """
    img = cv2.imread(path)
    # cv2.imread reports failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError('cannot read image: %s' % path)
    # img is (H, W, C); img[:, :, 0] equals np.transpose(img, [2, 0, 1])[0].
    return np.where(img[:, :, 0] > threshold, 1, 0)


if __name__ == '__main__':
    metric = SegmentationMetric(2)  # number of classes: 0 and 1 after thresholding
    imgPredict = _loadBinaryMask('./1.png')
    imgLabel = _loadBinaryMask('./2.png')
    metric.addBatch(imgPredict, imgLabel)

    confusionMatrix = metric.confusionMatrix
    # Column sums = number of pixels predicted as each class.
    columnTotals = np.sum(confusionMatrix, axis=0)
    print('ConfusionMatrix :')
    print(confusionMatrix)
    print('Add:')
    print(columnTotals)
    print('%:')
    print(confusionMatrix / columnTotals)

    pa = metric.pixelAccuracy()
    cpa = metric.classPixelAccuracy()
    mpa = metric.meanPixelAccuracy()
    mIoU = metric.meanIntersectionOverUnion()
    print('pa is : %f' % pa)
    print('cpa is :')
    print(cpa)
    print('mpa is : %f' % mpa)
    print('mIoU is : %f' % mIoU)