From 0185b8ef137ac36945608fc898fc790e8a9a79fe Mon Sep 17 00:00:00 2001
From: minhtrung23
Date: Sun, 8 Sep 2024 12:22:22 +0000
Subject: [PATCH] Fix convention for src/melt/tools/metrics/calibration_metric.py

---
 src/melt/tools/metrics/calibration_metric.py | 79 ++++++++++++--------
 1 file changed, 48 insertions(+), 31 deletions(-)

diff --git a/src/melt/tools/metrics/calibration_metric.py b/src/melt/tools/metrics/calibration_metric.py
index b011dc0..d242570 100644
--- a/src/melt/tools/metrics/calibration_metric.py
+++ b/src/melt/tools/metrics/calibration_metric.py
@@ -1,52 +1,60 @@
-from typing import Dict
-import calibration as cal
+"""Module for evaluating the calibration of probabilistic models."""
+
+
+from typing import Dict, List
 import numpy as np
+try:
+    from melt.calibration import get_ece_em, get_ece, get_selective_stats, get_platt_scaler
+    print("Import successful")
+except ImportError as e:
+    print(f"Import error: {e}")
 from .utils import normalize_text
 from .base import BaseMetric
 from .post_process import softmax_options_prob
-from typing import List
 
 
 class CalibrationMetric(BaseMetric):
-    """Evaluate the calibration of probabilistic models"""
+    """Evaluate the calibration of probabilistic models."""
 
-    # def __init__(self) -> None:
-    #     pass
-    def get_cal_score(self, max_probs: List[float], correct: List[int]):
+    def get_cal_score(self, max_probs: List[float], correct: List[int]) -> Dict[str, float]:
         """Calculates various calibration scores based on
         the predicted probabilities (max_probs)
        and the ground truth labels (correct).
+
         Args:
             max_probs (List[float]): A list of the maximum probabilities
             predicted by the model for each instance.
+
             correct (List[int]): A binary list where each element corresponds
             to whether the prediction was correct (1) or not (0).
+
         Returns:
-            A dictionary containing ECE scores for 10 bins and 1 bin,
+            Dict[str, float]: A dictionary containing ECE scores for 10 bins and 1 bin,
             coverage accuracy area, accuracy in the top 10 percentile,
             and Platt ECE scores for 10 bins and 1 bin.
         """
-        ece_10_bin = cal.get_ece_em(max_probs, correct, num_bins=10)
-        ece_1_bin = cal.get_ece(max_probs, correct, num_bins=1)
-        coverage_acc_area, acc_top_10_percentile = cal.get_selective_stats(
-            max_probs, correct
+        max_probs_array = np.array(max_probs)
+        correct_array = np.array(correct)
+
+
+        ece_10_bin = get_ece_em(max_probs_array, correct_array, num_bins=10)
+        ece_1_bin = get_ece(max_probs_array, correct_array, num_bins=1)
+        coverage_acc_area, acc_top_10_percentile = get_selective_stats(
+            max_probs_array, correct_array
         )
-        if np.sum(correct) == 0 or np.sum(correct) == len(correct):
+        if np.sum(correct_array) == 0 or np.sum(correct_array) == len(correct_array):
             platt_ece_10_bin = 0.0
             platt_ece_1_bin = 0.0
         else:
-            platt_scaler, clf = cal.get_platt_scaler(
-                np.array(max_probs), np.array(correct), get_clf=True
-            )
-            cal_max_probs = platt_scaler(np.array(max_probs))
-            platt_ece_10_bin = cal.get_ece_em(
-                cal_max_probs, correct, num_bins=10
-            )
-            platt_ece_1_bin = cal.get_ece(cal_max_probs, correct, num_bins=1)
+            platt_scaler, _ = get_platt_scaler(max_probs_array, correct_array, get_clf=False)
+            cal_max_probs = platt_scaler(max_probs_array)
+            platt_ece_10_bin = get_ece_em(cal_max_probs, correct_array, num_bins=10)
+            platt_ece_1_bin = get_ece(cal_max_probs, correct_array, num_bins=1)
+
 
         return {
             "ece_10_bin": ece_10_bin,
@@ -57,17 +65,20 @@ def get_cal_score(self, max_probs: List[float], correct: List[int]):
             "platt_ece_1_bin": platt_ece_1_bin,
         }
 
-    def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict):
+
+    def evaluate(self, data: Dict, args) -> (Dict, Dict):
         """Evaluates the given predictions against the references
         in the dictionary.
+
         Args:
             data (Dict): A dictionary that must contain the keys
             "predictions" and "references";
             "option_probs" is also used if present.
+
         Returns:
-            Returns a tuple of two dictionaries:
+            Tuple[Dict, Dict]: Returns a tuple of two dictionaries:
             - The first dictionary is the updated data with
             additional key "max_probs".
             - The second dictionary result contains the mean of
@@ -81,31 +92,37 @@ def evaluate(self, data: Dict, args, **kwargs) -> (Dict, Dict):
         ]
         references = data["references"]
+
         accuracy = [
             int(normalize_text(str(pred)) == normalize_text(str(ref)))
             for pred, ref in zip(predictions, references)
         ]
 
-        sum_option_probs = []
-        for i in range(len(data["option_probs"])):
-            sum_option_probs.append(
-                [np.array(x).sum() for x in data["option_probs"][i]]
-            )
+        option_probs = data.get("option_probs", [])
+        if option_probs:
+            sum_option_probs = [
+                [np.array(x).sum() for x in option_probs[i]]
+                for i in range(len(option_probs))
+            ]
+        else:
+            sum_option_probs = []
+
         if "gpt" in args.filepath:
             probs = softmax_options_prob(sum_option_probs)
             probs = np.zeros_like(probs)
-            labels = np.array(
-                [args.class_names.index(str(ref)) for ref in references]
-            )
+            labels = np.array([args.class_names.index(str(ref)) for ref in references])
+
             for i, label in enumerate(labels):
                 probs[i][label] = 1
         else:
             probs = softmax_options_prob(sum_option_probs)
+
         max_probs = np.max(probs, axis=1)
 
         data["max_probs"] = list(max_probs)
         result["max_probs"] = max_probs.mean()
         result.update(self.get_cal_score(max_probs, accuracy))
+
         return data, result