From f4aa65444ff40cc7c385f94d28b90800dc8fb413 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Jesus?= <36162088+sgpjesus@users.noreply.github.com> Date: Wed, 10 Apr 2024 15:03:58 +0100 Subject: [PATCH] Add an alpha to `GroupThreshold` to control balance between original score and fairness (#191) * Add an alpha to threshold to control balance between score and fairness * Add alpha to docstring --- src/aequitas/audit.py | 11 ++++++++++- .../postprocessing/balanced_group_threshold.py | 9 +++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/aequitas/audit.py b/src/aequitas/audit.py index 37bad0b4..a22b8e1b 100644 --- a/src/aequitas/audit.py +++ b/src/aequitas/audit.py @@ -6,6 +6,7 @@ from .bias import Bias from .group import Group from .plot import summary, disparity, absolute +from .flow.methods.postprocessing import Threshold class Audit: @@ -49,7 +50,7 @@ class is a wrapper around the Group and Bias classes. the keys are the sensitive attribute columns and the values are the reference groups. By default, 'maj'. """ - self.df = df + self.df = df.copy(deep=True) self.score_column = score_column self.threshold = threshold self.label_column = label_column @@ -253,6 +254,14 @@ def _validate_score_column(self): # If not binarized and a threshold is not passed, raise an error if not self.binarized and self.threshold is None: raise ValueError("Scores are not binarized. Please pass a threshold.") + if not self.binarized: + self.threshold_object = Threshold(**self.threshold) + self.threshold_object.fit( + None, self.df[self.score_column], self.df[self.label_column] + ) + self.df[self.score_column] = self.threshold_object.transform( + None, self.df[self.score_column] + ) def _validate_label_column(self): # Check if column exists diff --git a/src/aequitas/flow/methods/postprocessing/balanced_group_threshold.py b/src/aequitas/flow/methods/postprocessing/balanced_group_threshold.py index c5905714..7a5cb198 100644 --- a/src/aequitas/flow/methods/postprocessing/balanced_group_threshold.py +++ b/src/aequitas/flow/methods/postprocessing/balanced_group_threshold.py @@ -14,6 +14,7 @@ def __init__( threshold_type: str, threshold_value: Union[float, int], fairness_metric: str, + alpha: float = 1, ): """Initialize a new instance of the BalancedGroupThreshold class. @@ -35,11 +36,14 @@ def __init__( - tpr: true positive rate - fpr: false positive rate - pprev: predicted prevalence + alpha : float, optional + The alpha value to use for the model score correction. The default is 1. """ self.logger = create_logger("methods.postprocessing.BalancedGroupThreshold") self.threshold_type = threshold_type self.threshold_value = threshold_value self.fairness_metric = fairness_metric + self.alpha = alpha self.thresholds = {} @@ -100,6 +104,11 @@ def process_group(group_df): # Forward fill the 'value' column group_df["value"].fillna(method="ffill", inplace=True) group_df["value"].fillna(0, inplace=True) + + # Apply model score correction + group_df["value"] = group_df["value"] * self.alpha + ( + 1 - group_df["y_hat"] + ) * (1 - self.alpha) return group_df # Create a single DataFrame