diff --git a/src/lm_polygraph/ue_metrics/pred_rej_area.py b/src/lm_polygraph/ue_metrics/pred_rej_area.py
index 6ce77555..dc20164a 100644
--- a/src/lm_polygraph/ue_metrics/pred_rej_area.py
+++ b/src/lm_polygraph/ue_metrics/pred_rej_area.py
@@ -10,6 +10,17 @@ class PredictionRejectionArea(UEMetric):
     Calculates area under Prediction-Rejection curve.
     """
 
+    def __init__(self, max_rejection: float = 1.0):
+        """
+        Parameters:
+            max_rejection (float): a maximum proportion of instances that will be rejected.
+                1.0 indicates entire set, 0.5 - half of the set
+        """
+        super().__init__()
+        # Reject a zero or negative fraction makes the metric undefined;
+        # more than the whole set cannot be rejected.
+        if not 0.0 < max_rejection <= 1.0:
+            raise ValueError("max_rejection must be in (0, 1]")
+        self.max_rejection = max_rejection
+
     def __str__(self):
         return "prr"
 
@@ -30,12 +41,15 @@ def __call__(self, estimator: List[float], target: List[float]) -> float:
         # ue: greater is more uncertain
         ue = np.array(estimator)
         num_obs = len(ue)
+        # Clamp to at least 1: num_rej == 0 would make cumsum[-num_rej:]
+        # keep the WHOLE array (a[-0:] == a[0:]) and divide by zero below.
+        num_rej = max(1, int(self.max_rejection * num_obs))
         # Sort in ascending order: the least uncertain come first
         ue_argsort = np.argsort(ue)
         # want sorted_metrics to be increasing => smaller scores is better
         sorted_metrics = np.array(target)[ue_argsort]
         # Since we want all plots to coincide when all the data is discarded
-        cumsum = np.cumsum(sorted_metrics)
-        scores = (cumsum / np.arange(1, num_obs + 1))[::-1]
-        prr_score = np.sum(scores) / num_obs
+        cumsum = np.cumsum(sorted_metrics)[-num_rej:]
+        scores = (cumsum / np.arange((num_obs - num_rej) + 1, num_obs + 1))[::-1]
+        prr_score = np.sum(scores) / num_rej
         return prr_score