From a89641e6cb7bc6c5be024e98690810e7048bdbad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9rgio=20Jesus?= <36162088+sgpjesus@users.noreply.github.com> Date: Tue, 27 Aug 2024 10:43:29 +0100 Subject: [PATCH] Add unit tests to datasets (#202) * Add tests to folktables datasets * Update the location of Generic dataset tests * Update BAF tests --- src/aequitas/flow/datasets/baf.py | 17 +++-- src/aequitas/flow/datasets/folktables.py | 47 +++++------- tests/flow/datasets/test_baf.py | 91 +++++++++++++++++++++++ tests/flow/datasets/test_folktables.py | 90 ++++++++++++++++++++++ tests/{ => flow/datasets}/test_generic.py | 51 ++++++------- 5 files changed, 236 insertions(+), 60 deletions(-) create mode 100644 tests/flow/datasets/test_baf.py create mode 100644 tests/flow/datasets/test_folktables.py rename tests/{ => flow/datasets}/test_generic.py (83%) diff --git a/src/aequitas/flow/datasets/baf.py b/src/aequitas/flow/datasets/baf.py index 9175e985..f0177b31 100644 --- a/src/aequitas/flow/datasets/baf.py +++ b/src/aequitas/flow/datasets/baf.py @@ -83,9 +83,7 @@ def __init__( ): super().__init__() - self.label_column = ( - LABEL_COLUMN if label_column is None else label_column - ) + self.label_column = LABEL_COLUMN if label_column is None else label_column if sensitive_column == "customer_age" or sensitive_column is None: self.sensitive_column = SENSITIVE_COLUMN @@ -107,11 +105,12 @@ def __init__( else: self.variant = variant self.logger.debug(f"Variant: {self.variant}") - if url(path) or path.exists(): - self.path = path + + self.extension = extension + self.path = path + if url(path) or self._check_paths(): self._download = False else: - self.path = path self._download = True if split_type not in SPLIT_TYPES: raise ValueError(f"Invalid split_type value. Try one of: {SPLIT_TYPES}") @@ -120,7 +119,6 @@ def __init__( self.splits = splits self._validate_splits() self.logger.debug("Splits successfully validated.") - self.extension = extension self.seed = seed self.data: pd.DataFrame = None self.include_month = include_month @@ -203,6 +201,11 @@ def create_splits(self) -> None: ), ) + def _check_paths(self) -> bool: + """Check if the data is already downloaded.""" + check_path = Path(self.path) / f"{self.variant}.{self.extension}" + return check_path.exists() + def _download_data(self) -> None: """Obtains the data of the sample dataset from Aequitas repository.""" self.logger.info("Downloading sample data from repository.") diff --git a/src/aequitas/flow/datasets/folktables.py b/src/aequitas/flow/datasets/folktables.py index ae0a55ff..a29bcb5e 100644 --- a/src/aequitas/flow/datasets/folktables.py +++ b/src/aequitas/flow/datasets/folktables.py @@ -212,37 +212,28 @@ def load_data(self): if self._download: self._download_data() - if self.split_type == "predefined": - path = [] - for split in ["train", "validation", "test"]: - if isinstance(self.path, str): - path.append(self.path + f"/{self.variant}.{split}.{self.extension}") - else: - path.append(self.path / f"{self.variant}.{split}.{self.extension}") - else: - path = self.path / f"{self.variant}.{self.extension}" + path = [] + for split in ["train", "validation", "test"]: + if isinstance(self.path, str): + path.append(self.path + f"/{self.variant}.{split}.{self.extension}") + else: + path.append(self.path / f"{self.variant}.{split}.{self.extension}") if self.extension == "parquet": - if self.split_type == "predefined": - datasets = [pd.read_parquet(p) for p in path] - self._indexes = [d.index for d in datasets] - self.data = pd.concat(datasets) - else: 
- self.data = pd.read_parquet(path) + datasets = [pd.read_parquet(p) for p in path] + self._indexes = [d.index for d in datasets] + self.data = pd.concat(datasets) else: - if self.split_type == "predefined": - train = pd.read_csv(path[0]) - train_index = train.index[-1] - validation = pd.read_csv(path[1]) - validation.set_index(validation.index + train_index + 1, inplace=True) - validation_index = validation.index[-1] - test = pd.read_csv(path[2]) - test.set_index(test.index + validation_index + 1, inplace=True) - self._indexes = [train.index, validation.index, test.index] - - self.data = pd.concat([train, validation, test]) - else: - self.data = pd.read_csv(path) + train = pd.read_csv(path[0]) + train_index = train.index[-1] + validation = pd.read_csv(path[1]) + validation.set_index(validation.index + train_index + 1, inplace=True) + validation_index = validation.index[-1] + test = pd.read_csv(path[2]) + test.set_index(test.index + validation_index + 1, inplace=True) + self._indexes = [train.index, validation.index, test.index] + + self.data = pd.concat([train, validation, test]) for col in CATEGORICAL_COLUMNS[self.variant]: self.data[col] = self.data[col].astype("category") diff --git a/tests/flow/datasets/test_baf.py b/tests/flow/datasets/test_baf.py new file mode 100644 index 00000000..3442fa6c --- /dev/null +++ b/tests/flow/datasets/test_baf.py @@ -0,0 +1,91 @@ +import unittest +from aequitas.flow.datasets.baf import BankAccountFraud, VARIANTS, DEFAULT_PATH + + +# TODO: These tests can be merged with the ones in test_folktables.py + +class TestBankAccountFraudDataset(unittest.TestCase): + # Test loading related functionalities. + def test_load_variants(self): + for variant in VARIANTS: + dataset = BankAccountFraud(variant) + dataset.load_data() + self.assertTrue(len(dataset.data) > 0) + self.assertTrue("customer_age_bin" in dataset.data.columns) + self.assertTrue("fraud_bool" in dataset.data.columns) + + def test_load_invalid_variant(self): + with self.assertRaises(ValueError): + BankAccountFraud("invalid_variant") + + def test_download(self): + # Remove default folder of datasets even if not empty + if DEFAULT_PATH.exists(): + for file in DEFAULT_PATH.iterdir(): + file.unlink() + DEFAULT_PATH.rmdir() + for variant in VARIANTS: + dataset = BankAccountFraud(variant) + dataset.load_data() + self.assertTrue(dataset.path.exists()) + + # Test split related functionalities. 
+ def test_invalid_split_type(self): + with self.assertRaises(ValueError): + BankAccountFraud(VARIANTS[0], split_type="invalid_split_type") + + def test_default_split(self): + dataset = BankAccountFraud(VARIANTS[0]) + dataset.load_data() + dataset.create_splits() + self.assertTrue(len(dataset.train) > 0) + self.assertTrue(len(dataset.test) > 0) + self.assertTrue(len(dataset.validation) > 0) + + def test_random_split(self): + dataset = BankAccountFraud( + VARIANTS[0], + split_type="random", + splits={"train": 0.6, "validation": 0.2, "test": 0.2}, + ) + dataset.load_data() + dataset.create_splits() + self.assertTrue(len(dataset.train) > 0) + self.assertTrue(len(dataset.test) > 0) + self.assertTrue(len(dataset.validation) > 0) + + def test_invalid_random_split_missing_key(self): + with self.assertRaises(ValueError): + BankAccountFraud( + VARIANTS[0], + split_type="random", + splits={"train": 0.6, "validation": 0.2}, + ) + + def test_invalid_random_split_more_than_1(self): + with self.assertRaises(ValueError): + BankAccountFraud( + VARIANTS[0], + split_type="random", + splits={"train": 0.6, "validation": 0.2, "test": 0.3}, + ) + + # Test sensitive column related issues. + def test_housing_sensitive_column(self): + dataset = BankAccountFraud(VARIANTS[0], sensitive_column="housing_status") + dataset.load_data() + self.assertTrue("housing_status" in dataset.data.columns) + self.assertTrue(dataset.data.s.name == "housing_status") + + def test_invalid_sensitive_column(self): + with self.assertRaises(ValueError): + BankAccountFraud(VARIANTS[0], sensitive_column="invalid_column") + + def test_invalid_sensitive_column_type(self): + with self.assertRaises(ValueError): + BankAccountFraud(VARIANTS[0], sensitive_column="name_email_similarity") + # Numerical column + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/flow/datasets/test_folktables.py b/tests/flow/datasets/test_folktables.py new file mode 100644 index 00000000..c0ac67a4 --- /dev/null +++ b/tests/flow/datasets/test_folktables.py @@ -0,0 +1,90 @@ +import unittest +from aequitas.flow.datasets.folktables import FolkTables, VARIANTS, DEFAULT_PATH + + +# TODO: Test CSV related functionalities. + +class TestFolktablesDataset(unittest.TestCase): + # Test loading related functionalities. + def test_load_variants(self): + for variant in VARIANTS: + dataset = FolkTables(variant) + dataset.load_data() + self.assertTrue(len(dataset.data) > 0) + self.assertTrue("RAC1P" in dataset.data.columns) + self.assertTrue("AGEP" in dataset.data.columns) + + def test_load_invalid_variant(self): + with self.assertRaises(ValueError): + FolkTables("invalid_variant") + + def test_download(self): + # Remove default folder of datasets even if not empty + if DEFAULT_PATH.exists(): + for file in DEFAULT_PATH.iterdir(): + file.unlink() + DEFAULT_PATH.rmdir() + for variant in VARIANTS: + dataset = FolkTables(variant) + dataset.load_data() + self.assertTrue(dataset.path.exists()) + + # Test split related functionalities. 
+ def test_invalid_split_type(self): + with self.assertRaises(ValueError): + FolkTables(VARIANTS[0], split_type="invalid_split_type") + + def test_default_split(self): + dataset = FolkTables(VARIANTS[0]) + dataset.load_data() + dataset.create_splits() + self.assertTrue(len(dataset.train) > 0) + self.assertTrue(len(dataset.test) > 0) + self.assertTrue(len(dataset.validation) > 0) + + def test_random_split(self): + dataset = FolkTables( + VARIANTS[0], + split_type="random", + splits={"train": 0.6, "validation": 0.2, "test": 0.2}, + ) + dataset.load_data() + dataset.create_splits() + self.assertTrue(len(dataset.train) > 0) + self.assertTrue(len(dataset.test) > 0) + self.assertTrue(len(dataset.validation) > 0) + + def test_invalid_random_split_missing_key(self): + with self.assertRaises(ValueError): + FolkTables( + VARIANTS[0], + split_type="random", + splits={"train": 0.6, "validation": 0.2}, + ) + + def test_invalid_random_split_more_than_1(self): + with self.assertRaises(ValueError): + FolkTables( + VARIANTS[0], + split_type="random", + splits={"train": 0.6, "validation": 0.2, "test": 0.3}, + ) + + # Test sensitive column related issues. + def test_age_sensitive_column(self): + dataset = FolkTables(VARIANTS[0], sensitive_column="AGEP") + dataset.load_data() + self.assertTrue("AGEP" in dataset.data.columns) + self.assertTrue("AGEP_bin" in dataset.data.columns) + + def test_invalid_sensitive_column(self): + with self.assertRaises(ValueError): + FolkTables(VARIANTS[0], sensitive_column="invalid_column") + + def test_invalid_sensitive_column_type(self): + with self.assertRaises(ValueError): + FolkTables(VARIANTS[0], sensitive_column="SCHL") # Numerical column + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_generic.py b/tests/flow/datasets/test_generic.py similarity index 83% rename from tests/test_generic.py rename to tests/flow/datasets/test_generic.py index 8653ba2f..65f6d85a 100644 --- a/tests/test_generic.py +++ b/tests/flow/datasets/test_generic.py @@ -1,9 +1,12 @@ import unittest import os import pandas as pd + +from pathlib import Path from aequitas.flow.datasets.generic import GenericDataset -BASE_DIR = os.path.dirname(__file__) +BASE_DIR = Path(__file__) +TEST_DIR = BASE_DIR.parents[2] class TestGenericDataset(unittest.TestCase): @@ -33,7 +36,7 @@ def test_load_data_from_path_parquet(self): sensitive_column="sensitive", extension="parquet", dataset_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data.parquet" + TEST_DIR, "test_artifacts/test_generic/data.parquet" ), ) dataset.load_data() @@ -46,7 +49,7 @@ def test_load_data_from_path_csv(self): label_column="label", sensitive_column="sensitive", extension="csv", - dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + dataset_path=os.path.join(TEST_DIR, "test_artifacts/test_generic/data.csv"), ) dataset.load_data() self.assertEqual(len(dataset.data), 10) @@ -59,13 +62,13 @@ def test_load_data_from_multiple_paths(self): sensitive_column="sensitive", extension="csv", train_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_train.csv" + TEST_DIR, "test_artifacts/test_generic/data_train.csv" ), validation_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + TEST_DIR, "test_artifacts/test_generic/data_validation.csv" ), test_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_test.csv" + TEST_DIR, "test_artifacts/test_generic/data_test.csv" ), ) dataset.load_data() @@ -104,7 +107,7 @@ def 
test_create_splits_column_from_path(self): dataset = GenericDataset( label_column="label", sensitive_column="sensitive", - dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + dataset_path=os.path.join(TEST_DIR, "test_artifacts/test_generic/data.csv"), split_type="column", extension="csv", split_column="sensitive", @@ -119,20 +122,20 @@ def test_create_splits_column_from_path(self): def test_all_paths_provided(self): self.assertRaisesRegex( ValueError, - "If single dataset path is passed, the other paths must be None.", + "If single dataset path is passed, the other paths must be None.", GenericDataset, label_column="label", sensitive_column="sensitive", train_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_train.csv" + TEST_DIR, "test_artifacts/test_generic/data_train.csv" ), validation_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + TEST_DIR, "test_artifacts/test_generic/data_validation.csv" ), test_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_test.csv" + TEST_DIR, "test_artifacts/test_generic/data_test.csv" ), - dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + dataset_path=os.path.join(TEST_DIR, "test_artifacts/test_generic/data.csv"), ) def test_missing_paths(self): @@ -143,10 +146,10 @@ def test_missing_paths(self): label_column="label", sensitive_column="sensitive", train_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_train.csv" + TEST_DIR, "test_artifacts/test_generic/data_train.csv" ), validation_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + TEST_DIR, "test_artifacts/test_generic/data_validation.csv" ), ) @@ -157,12 +160,12 @@ def test_invalid_path(self): GenericDataset, label_column="label", sensitive_column="sensitive", - train_path=os.path.join(BASE_DIR, "test_artifacts/data_train.csv"), + train_path=os.path.join(TEST_DIR, "test_artifacts/data_train.csv"), validation_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_validation.csv" + TEST_DIR, "test_artifacts/test_generic/data_validation.csv" ), test_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data_test.csv" + TEST_DIR, "test_artifacts/test_generic/data_test.csv" ), ) @@ -173,7 +176,7 @@ def test_missing_split_key(self): GenericDataset, label_column="label", sensitive_column="sensitive", - dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + dataset_path=os.path.join(TEST_DIR, "test_artifacts/test_generic/data.csv"), split_values={"train": 0.63, "validation": 0.37}, ) @@ -185,25 +188,23 @@ def test_invalid_splits(self): GenericDataset, label_column="label", sensitive_column="sensitive", - dataset_path=os.path.join(BASE_DIR, "test_artifacts/test_generic/data.csv"), + dataset_path=os.path.join(TEST_DIR, "test_artifacts/test_generic/data.csv"), split_values={"train": 0.63, "validation": 0.37, "test": 0.2}, ) def test_invalid_splits_warn(self): with self.assertLogs("datasets.GenericDataset", level="WARN") as cm: - dataset = GenericDataset( + GenericDataset( label_column="label", sensitive_column="sensitive", dataset_path=os.path.join( - BASE_DIR, "test_artifacts/test_generic/data.csv" + TEST_DIR, "test_artifacts/test_generic/data.csv" ), - split_values={"train": 0.3, "validation": 0.1, "test": 0.2}, + split_values={"train": 0.5, "validation": 0.1, "test": 0.2}, ) self.assertEqual( cm.output, - [ - "WARNING:datasets.GenericDataset:Using only 0.6000000000000001 of the dataset." 
-            ],
+            ["WARNING:datasets.GenericDataset:Using only 0.8 of the dataset."],
         )
 
     def test_missing_splits_column(self):
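
The TODO note at the top of tests/flow/datasets/test_baf.py observes that the BAF and folktables suites are near-identical. A minimal sketch of one way that merge could look, assuming both dataset classes keep the constructor signature exercised above; the DatasetTestMixin name and the common.py module are hypothetical and not part of this patch:

# tests/flow/datasets/common.py (hypothetical module, sketched from the
# duplicated tests in this patch; not part of the change itself)
import unittest

from aequitas.flow.datasets.baf import BankAccountFraud, VARIANTS


class DatasetTestMixin:
    """Shared dataset checks; concrete suites fill in the class attributes."""

    dataset_class = None   # e.g. BankAccountFraud or FolkTables
    variants = ()          # e.g. that dataset's module-level VARIANTS
    expected_columns = ()  # columns every loaded variant must contain

    def test_load_variants(self):
        for variant in self.variants:
            dataset = self.dataset_class(variant)
            dataset.load_data()
            self.assertGreater(len(dataset.data), 0)
            for column in self.expected_columns:
                self.assertIn(column, dataset.data.columns)

    def test_load_invalid_variant(self):
        with self.assertRaises(ValueError):
            self.dataset_class("invalid_variant")

    def test_invalid_split_type(self):
        with self.assertRaises(ValueError):
            self.dataset_class(self.variants[0], split_type="invalid_split_type")


# Concrete suite mirroring TestBankAccountFraudDataset above; a folktables
# counterpart would set dataset_class = FolkTables and its own columns.
class TestBankAccountFraudDataset(DatasetTestMixin, unittest.TestCase):
    dataset_class = BankAccountFraud
    variants = VARIANTS
    expected_columns = ("customer_age_bin", "fraud_bool")

Keeping the mixin free of unittest.TestCase means the shared tests are only collected through the concrete subclasses, so they do not run twice. The merged suite would run the same way as the tests added here, e.g. pytest tests/flow/datasets, or unittest discovery if the test directories carry __init__.py files.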