Skip to content

Commit

Permalink
Bug fixes in pre-processing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
reluzita authored Feb 26, 2024
1 parent cb8c6b6 commit ad98c0c
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/aequitas/flow/datasets/folktables.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

BOOL_COLUMNS = {
"ACSIncome": ["SEX"],
"ACSEmployment": ["SEX", "DIS", "NATIVTY", "DEAR", "DEYE", "DREM"],
"ACSEmployment": ["SEX", "DIS", "NATIVITY", "DEAR", "DEYE", "DREM"],
"ACSMobility": [
"SEX",
"DIS",
Expand Down
9 changes: 7 additions & 2 deletions src/aequitas/flow/methods/preprocessing/data_repairer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
self.repair_level = repair_level
self.columns = columns
self.definition = definition
self.used_in_inference = True

def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series] = None) -> None:
"""
Expand All @@ -72,7 +73,11 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series] = None) -> N
super().fit(X, y, s)

if self.columns is None:
self.columns = X.columns.tolist()
self.columns = [
column
for column in X.columns
if (X[column].dtype != "category" and X[column].dtype != "bool")
]
if s is None:
raise ValueError("s must be passed.")
self._quantile_points = np.linspace(0, 1, self.definition)
Expand Down Expand Up @@ -141,7 +146,7 @@ def transform(
Transformed features, labels, and sensitive attribute.
"""
super().transform(X, y, s)

if s is None:
raise ValueError("s must be passed.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
feature_importance_threshold: Optional[float] = 0.1,
n_estimators: Optional[int] = 10,
seed: int = 0,
n_jobs: int = 1,
):
"""Iteratively removes the most important features with respect to the sensitive
attribute.
Expand All @@ -32,6 +33,8 @@ def __init__(
The number of trees in the random forest. Defaults to 10.
seed : int, optional
The seed for the random forest. Defaults to 0.
n_jobs : int, optional
The number of jobs to run in parallel. Defaults to 1.
"""
self.logger = create_logger(
"methods.preprocessing.FeatureImportanceSuppression"
Expand All @@ -45,6 +48,7 @@ def __init__(
self.feature_importance_threshold = feature_importance_threshold
self.n_estimators = n_estimators
self.seed = seed
self.n_jobs = n_jobs

def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
"""Iteratively removes the most important features to predict the sensitive
Expand All @@ -64,7 +68,7 @@ def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
self.logger.info("Identifying features to remove.")

rf = RandomForestClassifier(
n_estimators=self.n_estimators, random_state=self.seed
n_estimators=self.n_estimators, random_state=self.seed, n_jobs=self.n_jobs
)

features = pd.concat([X, y], axis=1)
Expand Down
1 change: 1 addition & 0 deletions src/aequitas/flow/methods/preprocessing/massaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(

self.classifier = instantiate_object(classifier, **classifier_args)
self.logger.info(f"Created base estimator {self.classifier}")
self.used_in_inference = False

def _rank(
self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]
Expand Down

0 comments on commit ad98c0c

Please sign in to comment.