-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetric_caller.py
56 lines (46 loc) · 1.82 KB
/
metric_caller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
# @Time : 2020/12/19
# @Author : Lart Pang
# @GitHub : https://github.com/lartpang
import numpy as np
from py_sod_metrics.sod_metrics import Emeasure, Fmeasure, MAE, Smeasure, WeightedFmeasure
class CalTotalMetric(object):
__slots__ = ["cal_mae", "cal_fm", "cal_sm", "cal_em", "cal_wfm"]
def __init__(self):
self.cal_mae = MAE()
self.cal_fm = Fmeasure()
self.cal_sm = Smeasure()
self.cal_em = Emeasure()
self.cal_wfm = WeightedFmeasure()
def step(self, pred: np.ndarray, gt: np.ndarray, gt_path: str):
assert pred.ndim == gt.ndim and pred.shape == gt.shape, (pred.shape, gt.shape, gt_path)
assert pred.dtype == np.uint8, pred.dtype
assert gt.dtype == np.uint8, gt.dtype
self.cal_mae.step(pred, gt)
self.cal_fm.step(pred, gt)
self.cal_sm.step(pred, gt)
self.cal_em.step(pred, gt)
self.cal_wfm.step(pred, gt)
def get_results(self, bit_width: int = 3) -> dict:
fm = self.cal_fm.get_results()["fm"]
wfm = self.cal_wfm.get_results()["wfm"]
sm = self.cal_sm.get_results()["sm"]
em = self.cal_em.get_results()["em"]
mae = self.cal_mae.get_results()["mae"]
results = {
"Smeasure": sm,
"wFmeasure": wfm,
"MAE": mae,
"adpEm": em["adp"],
"meanEm": em["curve"].mean(),
"maxEm": em["curve"].max(),
"adpFm": fm["adp"],
"meanFm": fm["curve"].mean(),
"maxFm": fm["curve"].max(),
}
def _round_w_zero_padding(_x):
_x = str(_x.round(bit_width))
_x += "0" * (bit_width - len(_x.split(".")[-1]))
return _x
results = {name: _round_w_zero_padding(metric) for name, metric in results.items()}
return results