Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
bubbliiiing authored Dec 14, 2021
1 parent c29a0a3 commit 7be89de
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 14 deletions.
12 changes: 7 additions & 5 deletions get_miou.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tqdm import tqdm

from pspnet import Pspnet
from utils.utils_metrics import compute_mIoU
from utils.utils_metrics import compute_mIoU, show_results

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
Expand Down Expand Up @@ -39,9 +39,10 @@
#-------------------------------------------------------#
VOCdevkit_path = 'VOCdevkit'

image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines()
gt_dir = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/")
pred_dir = "miou_out"
image_ids = open(os.path.join(VOCdevkit_path, "VOC2007/ImageSets/Segmentation/val.txt"),'r').read().splitlines()
gt_dir = os.path.join(VOCdevkit_path, "VOC2007/SegmentationClass/")
miou_out_path = "miou_out"
pred_dir = os.path.join(miou_out_path, 'detection-results')

if miou_mode == 0 or miou_mode == 1:
if not os.path.exists(pred_dir):
Expand All @@ -61,5 +62,6 @@

if miou_mode == 0 or miou_mode == 2:
print("Get miou.")
compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 执行计算mIoU的函数
hist, IoUs, PA_Recall, Precision = compute_mIoU(gt_dir, pred_dir, image_ids, num_classes, name_classes) # 执行计算mIoU的函数
print("Get miou done.")
show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)
92 changes: 83 additions & 9 deletions utils/utils_metrics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import csv
import os
from os.path import join

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from tensorflow.keras import backend
Expand Down Expand Up @@ -46,9 +49,15 @@ def fast_hist(a, b, n):
def per_class_iu(hist):
return np.diag(hist) / np.maximum((hist.sum(1) + hist.sum(0) - np.diag(hist)), 1)

def per_class_PA(hist):
def per_class_PA_Recall(hist):
return np.diag(hist) / np.maximum(hist.sum(1), 1)

def per_class_Precision(hist):
return np.diag(hist) / np.maximum(hist.sum(0), 1)

def per_Accuracy(hist):
return np.sum(np.diag(hist)) / np.maximum(np.sum(hist), 1)

def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes):
print('Num classes', num_classes)
#-----------------------------------------#
Expand Down Expand Up @@ -90,22 +99,87 @@ def compute_mIoU(gt_dir, pred_dir, png_name_list, num_classes, name_classes):
hist += fast_hist(label.flatten(), pred.flatten(),num_classes)
# 每计算10张就输出一下目前已计算的图片中所有类别平均的mIoU值
if ind > 0 and ind % 10 == 0:
print('{:d} / {:d}: mIou-{:0.2f}; mPA-{:0.2f}'.format(ind, len(gt_imgs),
100 * np.nanmean(per_class_iu(hist)),
100 * np.nanmean(per_class_PA(hist))))
print('{:d} / {:d}: mIou-{:0.2f}%; mPA-{:0.2f}%; Accuracy-{:0.2f}%'.format(
ind,
len(gt_imgs),
100 * np.nanmean(per_class_iu(hist)),
100 * np.nanmean(per_class_PA_Recall(hist)),
100 * per_Accuracy(hist)
)
)
#------------------------------------------------#
# 计算所有验证集图片的逐类别mIoU值
#------------------------------------------------#
mIoUs = per_class_iu(hist)
mPA = per_class_PA(hist)
IoUs = per_class_iu(hist)
PA_Recall = per_class_PA_Recall(hist)
Precision = per_class_Precision(hist)
#------------------------------------------------#
# 逐类别输出一下mIoU值
#------------------------------------------------#
for ind_class in range(num_classes):
print('===>' + name_classes[ind_class] + ':\tmIou-' + str(round(mIoUs[ind_class] * 100, 2)) + '; mPA-' + str(round(mPA[ind_class] * 100, 2)))
print('===>' + name_classes[ind_class] + ':\tIou-' + str(round(IoUs[ind_class] * 100, 2)) \
+ '; Recall (equal to the PA)-' + str(round(PA_Recall[ind_class] * 100, 2))+ '; Precision-' + str(round(Precision[ind_class] * 100, 2)))

#-----------------------------------------------------------------#
# 在所有验证集图像上求所有类别平均的mIoU值,计算时忽略NaN值
#-----------------------------------------------------------------#
print('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(mPA) * 100, 2)))
return mIoUs
print('===> mIoU: ' + str(round(np.nanmean(IoUs) * 100, 2)) + '; mPA: ' + str(round(np.nanmean(PA_Recall) * 100, 2)) + '; Accuracy: ' + str(round(per_Accuracy(hist) * 100, 2)))
return np.array(hist, np.int), IoUs, PA_Recall, Precision

def adjust_axes(r, t, fig, axes):
bb = t.get_window_extent(renderer=r)
text_width_inches = bb.width / fig.dpi
current_fig_width = fig.get_figwidth()
new_fig_width = current_fig_width + text_width_inches
propotion = new_fig_width / current_fig_width
x_lim = axes.get_xlim()
axes.set_xlim([x_lim[0], x_lim[1] * propotion])

def draw_plot_func(values, name_classes, plot_title, x_label, output_path, tick_font_size = 12, plt_show = True):
fig = plt.gcf()
axes = plt.gca()
plt.barh(range(len(values)), values, color='royalblue')
plt.title(plot_title, fontsize=tick_font_size + 2)
plt.xlabel(x_label, fontsize=tick_font_size)
plt.yticks(range(len(values)), name_classes, fontsize=tick_font_size)
r = fig.canvas.get_renderer()
for i, val in enumerate(values):
str_val = " " + str(val)
if val < 1.0:
str_val = " {0:.2f}".format(val)
t = plt.text(val, i, str_val, color='royalblue', va='center', fontweight='bold')
if i == (len(values)-1):
adjust_axes(r, t, fig, axes)

fig.tight_layout()
fig.savefig(output_path)
if plt_show:
plt.show()
plt.close()

def show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes, tick_font_size = 12):
draw_plot_func(IoUs, name_classes, "mIoU = {0:.2f}%".format(np.nanmean(IoUs)*100), "Intersection over Union", \
os.path.join(miou_out_path, "mIoU.png"), tick_font_size = tick_font_size, plt_show = True)
print("Save mIoU out to " + os.path.join(miou_out_path, "mIoU.png"))

draw_plot_func(PA_Recall, name_classes, "mPA = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Pixel Accuracy", \
os.path.join(miou_out_path, "mPA.png"), tick_font_size = tick_font_size, plt_show = False)
print("Save mPA out to " + os.path.join(miou_out_path, "mPA.png"))

draw_plot_func(PA_Recall, name_classes, "mRecall = {0:.2f}%".format(np.nanmean(PA_Recall)*100), "Recall", \
os.path.join(miou_out_path, "Recall.png"), tick_font_size = tick_font_size, plt_show = False)
print("Save Recall out to " + os.path.join(miou_out_path, "Recall.png"))

draw_plot_func(Precision, name_classes, "mPrecision = {0:.2f}%".format(np.nanmean(Precision)*100), "Precision", \
os.path.join(miou_out_path, "Precision.png"), tick_font_size = tick_font_size, plt_show = False)
print("Save Precision out to " + os.path.join(miou_out_path, "Precision.png"))

with open(os.path.join(miou_out_path, "confusion_matrix.csv"), 'w', newline='') as f:
writer = csv.writer(f)
writer_list = []
writer_list.append([' '] + [str(c) for c in name_classes])
for i in range(len(hist)):
writer_list.append([name_classes[i]] + [str(x) for x in hist[i]])
writer.writerows(writer_list)
print("Save confusion_matrix out to " + os.path.join(miou_out_path, "confusion_matrix.csv"))

0 comments on commit 7be89de

Please sign in to comment.