Skip to content

Commit

Permalink
Fixed folktables index when reading from csv (#179)
Browse files Browse the repository at this point in the history
* Fixed folktables index when reading from csv
  • Loading branch information
reluzita authored Feb 27, 2024
1 parent ad98c0c commit 592d76a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions src/aequitas/flow/datasets/folktables.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,16 @@ def load_data(self):
self.data = pd.read_parquet(path)
else:
if self.split_type == "predefined":
datasets = [pd.read_csv(p) for p in path]
self._indexes = [d.index for d in datasets]
self.data = pd.concat(datasets)
train = pd.read_csv(path[0])
train_index = train.index[-1]
validation = pd.read_csv(path[1])
validation.set_index(validation.index + train_index + 1, inplace=True)
validation_index = validation.index[-1]
test = pd.read_csv(path[2])
test.set_index(test.index + validation_index + 1, inplace=True)
self._indexes = [train.index, validation.index, test.index]

self.data = pd.concat([train, validation, test])
else:
self.data = pd.read_csv(path)

Expand Down
2 changes: 1 addition & 1 deletion src/aequitas/flow/methods/preprocessing/label_flipping.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _feature_suppression(self, X: pd.DataFrame, s: pd.Series) -> pd.DataFrame:
else self.unawareness_features
)
X_transformed = X_transformed.drop(columns=unawareness_features_list)

return X_transformed

def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None:
Expand Down

0 comments on commit 592d76a

Please sign in to comment.