Fix convention for src/melt/tools/metrics/calibration_metric.py
minhtrung23 committed Sep 8, 2024
1 parent 5860b17 commit 0185b8e
Showing 1 changed file with 48 additions and 31 deletions.
79 changes: 48 additions & 31 deletions src/melt/tools/metrics/calibration_metric.py
@@ -1,52 +1,60 @@
from typing import Dict
import calibration as cal
"""Module for evaluating the calibration of probabilistic models."""


from typing import Dict, List
import numpy as np
try:
from melt.calibration import get_ece_em, get_ece, get_selective_stats, get_platt_scaler
print("Import successful")
except ImportError as e:
print(f"Import error: {e}")
from .utils import normalize_text
from .base import BaseMetric
from .post_process import softmax_options_prob
from typing import List


class CalibrationMetric(BaseMetric):
"""Evaluate the calibration of probabilistic models"""
"""Evaluate the calibration of probabilistic models."""

# def __init__(self) -> None:
# pass

def get_cal_score(self, max_probs: List[float], correct: List[int]):
def get_cal_score(self, max_probs: List[float], correct: List[int]) -> Dict[str, float]:
"""Calculates various calibration scores based on
the predicted probabilities (max_probs) and
the ground truth labels (correct).
Args:
max_probs (List[float]): A list of the maximum probabilities
predicted by the model for each instance.
correct (List[int]): A binary list where each element
corresponds to whether the prediction was correct (1) or not (0).
Returns:
A dictionary containing ECE scores for 10 bins and 1 bin,
Dict[str, float]: A dictionary containing ECE scores for 10 bins and 1 bin,
coverage accuracy area, accuracy in the top 10 percentile,
and Platt ECE scores for 10 bins and 1 bin.
"""
ece_10_bin = cal.get_ece_em(max_probs, correct, num_bins=10)
ece_1_bin = cal.get_ece(max_probs, correct, num_bins=1)
coverage_acc_area, acc_top_10_percentile = cal.get_selective_stats(
max_probs, correct
max_probs_array = np.array(max_probs)
correct_array = np.array(correct)


ece_10_bin = get_ece_em(max_probs_array, correct_array, num_bins=10)
ece_1_bin = get_ece(max_probs_array, correct_array, num_bins=1)
coverage_acc_area, acc_top_10_percentile = get_selective_stats(
max_probs_array, correct_array
)
if np.sum(correct) == 0 or np.sum(correct) == len(correct):
if np.sum(correct_array) == 0 or np.sum(correct_array) == len(correct_array):
platt_ece_10_bin = 0.0
platt_ece_1_bin = 0.0
else:
platt_scaler, clf = cal.get_platt_scaler(
np.array(max_probs), np.array(correct), get_clf=True
)
cal_max_probs = platt_scaler(np.array(max_probs))
platt_ece_10_bin = cal.get_ece_em(
cal_max_probs, correct, num_bins=10
)
platt_ece_1_bin = cal.get_ece(cal_max_probs, correct, num_bins=1)
platt_scaler, _ = get_platt_scaler(max_probs_array, correct_array, get_clf=False)
cal_max_probs = platt_scaler(max_probs_array)
platt_ece_10_bin = get_ece_em(cal_max_probs, correct_array, num_bins=10)
platt_ece_1_bin = get_ece(cal_max_probs, correct_array, num_bins=1)


return {
"ece_10_bin": ece_10_bin,
@@ -57,17 +65,20 @@ def get_cal_score(self, max_probs: List[float], correct: List[int]):
"platt_ece_1_bin": platt_ece_1_bin,
}

def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict):

def evaluate(self, data: Dict, args) -> (Dict, Dict):
"""Evaluates the given predictions against the references
in the dictionary.
Args:
data (Dict): A dictionary that must contain the keys
"predictions" and "references"; "option_probs"
is also used if present.
Returns:
Returns a tuple of two dictionaries:
Tuple[Dict, Dict]: Returns a tuple of two dictionaries:
- The first dictionary is the updated data with
additional key "max_probs".
- The second dictionary result contains the mean of
@@ -81,31 +92,37 @@ def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict):
]
references = data["references"]


accuracy = [
int(normalize_text(str(pred)) == normalize_text(str(ref)))
for pred, ref in zip(predictions, references)
]
sum_option_probs = []
for i in range(len(data["option_probs"])):
sum_option_probs.append(
[np.array(x).sum() for x in data["option_probs"][i]]
)
option_probs = data.get("option_probs", [])
if option_probs:
sum_option_probs = [
[np.array(x).sum() for x in option_probs[i]]
for i in range(len(option_probs))
]
else:
sum_option_probs = []


if "gpt" in args.filepath:
probs = softmax_options_prob(sum_option_probs)
probs = np.zeros_like(probs)
labels = np.array(
[args.class_names.index(str(ref)) for ref in references]
)
labels = np.array([args.class_names.index(str(ref)) for ref in references])


for i, label in enumerate(labels):
probs[i][label] = 1
else:
probs = softmax_options_prob(sum_option_probs)


max_probs = np.max(probs, axis=1)
data["max_probs"] = list(max_probs)
result["max_probs"] = max_probs.mean()
result.update(self.get_cal_score(max_probs, accuracy))


return data, result
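
For reference, the expected calibration error (ECE) reported by get_ece_em groups predictions into confidence bins and averages the gap between mean confidence and empirical accuracy in each bin. Below is a minimal sketch of an equal-mass-bin ECE; the function name and binning details are illustrative assumptions, not the actual melt.calibration implementation.

import numpy as np


def ece_equal_mass(max_probs, correct, num_bins=10):
    """Illustrative equal-mass-bin ECE: weighted mean |confidence - accuracy| gap."""
    max_probs = np.asarray(max_probs, dtype=float)
    correct = np.asarray(correct, dtype=float)
    order = np.argsort(max_probs)            # sort predictions by confidence
    bins = np.array_split(order, num_bins)   # equal-count ("equal-mass") bins
    ece = 0.0
    for idx in bins:
        if len(idx) == 0:
            continue
        gap = abs(max_probs[idx].mean() - correct[idx].mean())
        ece += (len(idx) / len(max_probs)) * gap
    return ece


# Example: three confident correct predictions and one over-confident mistake.
print(ece_equal_mass([0.9, 0.8, 0.6, 0.55], [1, 1, 0, 1], num_bins=2))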

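The Platt-scaled variants refit the confidences with a logistic curve before recomputing ECE, which is why the diff guards against correct being all zeros or all ones (the logistic fit is degenerate in that case). A hedged sketch of the idea follows, assuming scikit-learn is available; the real get_platt_scaler may use log-odds features and different fitting details.

import numpy as np
from sklearn.linear_model import LogisticRegression


def make_platt_scaler(max_probs, correct):
    """Fit a logistic curve mapping raw confidences to calibrated probabilities."""
    x = np.asarray(max_probs, dtype=float).reshape(-1, 1)
    y = np.asarray(correct, dtype=int)
    clf = LogisticRegression()
    clf.fit(x, y)  # raises if y contains a single class, hence the guard above
    return lambda p: clf.predict_proba(np.asarray(p, dtype=float).reshape(-1, 1))[:, 1]


scaler = make_platt_scaler([0.9, 0.8, 0.6, 0.55], [1, 1, 0, 1])
print(scaler([0.7, 0.95]))  # recalibrated confidences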