Skip to content

Commit

Permalink
move to init
Browse files Browse the repository at this point in the history
  • Loading branch information
Artem Vazhentsev committed Sep 6, 2024
1 parent 5e13f8e commit 5ce3175
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/lm_polygraph/ue_metrics/pred_rej_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@ class PredictionRejectionArea(UEMetric):
Calculates area under Prediction-Rejection curve.
"""

def __init__(self, max_rejection: float = 1.0):
"""
Parameters:
max_rejection (float): a maximum proportion of instances that will be rejected.
1.0 indicates entire set, 0.5 - half of the set
"""
super().__init__()
self.max_rejection = max_rejection

def __str__(self):
return "prr"

def __call__(
self, estimator: List[float], target: List[float], max_rejection: float = 1.0
) -> float:
def __call__(self, estimator: List[float], target: List[float]) -> float:
"""
Measures the area under the Prediction-Rejection curve between `estimator` and `target`.
Expand All @@ -24,8 +31,6 @@ def __call__(
Higher values indicate more uncertainty.
target (List[int]): a batch of ground-truth uncertainty estimations.
Higher values indicate less uncertainty.
max_rejection (float): a maximum proportion of instances that will be rejected.
1.0 indicates entire set, 0.5 - half of the set
Returns:
float: area under the Prediction-Rejection curve.
Higher values indicate better uncertainty estimations.
Expand All @@ -34,7 +39,7 @@ def __call__(
# ue: greater is more uncertain
ue = np.array(estimator)
num_obs = len(ue)
num_rej = int(max_rejection * num_obs)
num_rej = int(self.max_rejection * num_obs)
# Sort in ascending order: the least uncertain come first
ue_argsort = np.argsort(ue)
# want sorted_metrics to be increasing => smaller scores is better
Expand Down

0 comments on commit 5ce3175

Please sign in to comment.