diff --git a/src/aequitas/flow/datasets/folktables.py b/src/aequitas/flow/datasets/folktables.py index bd7bf082..ae0a55ff 100644 --- a/src/aequitas/flow/datasets/folktables.py +++ b/src/aequitas/flow/datasets/folktables.py @@ -231,9 +231,16 @@ def load_data(self): self.data = pd.read_parquet(path) else: if self.split_type == "predefined": - datasets = [pd.read_csv(p) for p in path] - self._indexes = [d.index for d in datasets] - self.data = pd.concat(datasets) + train = pd.read_csv(path[0]) + train_index = train.index[-1] + validation = pd.read_csv(path[1]) + validation.set_index(validation.index + train_index + 1, inplace=True) + validation_index = validation.index[-1] + test = pd.read_csv(path[2]) + test.set_index(test.index + validation_index + 1, inplace=True) + self._indexes = [train.index, validation.index, test.index] + + self.data = pd.concat([train, validation, test]) else: self.data = pd.read_csv(path) diff --git a/src/aequitas/flow/methods/preprocessing/label_flipping.py b/src/aequitas/flow/methods/preprocessing/label_flipping.py index ae3bc96d..52f741c6 100644 --- a/src/aequitas/flow/methods/preprocessing/label_flipping.py +++ b/src/aequitas/flow/methods/preprocessing/label_flipping.py @@ -140,7 +140,7 @@ def _feature_suppression(self, X: pd.DataFrame, s: pd.Series) -> pd.DataFrame: else self.unawareness_features ) X_transformed = X_transformed.drop(columns=unawareness_features_list) - + return X_transformed def fit(self, X: pd.DataFrame, y: pd.Series, s: Optional[pd.Series]) -> None: