-
Notifications
You must be signed in to change notification settings - Fork 855
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Issue 265 privileged class bank dataset (#449)
* Updated readme for bank dataset
* Added age >60 to unprivileged group in bank_dataset.py
* Added tests for bank dataset
* Fixed linting errors for all tests/test_standard_datasets.py
* Added binary_age to fetch_bank
* Download bank dataset in ci.yml
- Loading branch information
1 parent
502ff47
commit 6f9972e
Showing 6 changed files with 86 additions and 28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,74 @@ | ||
""" Tests for standard dataset classes """ | ||
|
||
from unittest.mock import patch | ||
import numpy as np | ||
import pandas as pd | ||
|
||
pd.set_option('display.max_rows', 50) | ||
pd.set_option('display.max_columns', 10) | ||
pd.set_option('display.width', 200) | ||
import os | ||
|
||
from aif360.datasets import AdultDataset | ||
from aif360.datasets import BankDataset | ||
from aif360.datasets import CompasDataset | ||
from aif360.datasets import GermanDataset | ||
from aif360.metrics import BinaryLabelDatasetMetric | ||
|
||
pd.set_option('display.max_rows', 50) | ||
pd.set_option('display.max_columns', 10) | ||
pd.set_option('display.width', 200) | ||
|
||
def test_compas(): | ||
''' Test default loading for compas ''' | ||
# just test that there are no errors for default loading... | ||
cd = CompasDataset() | ||
# print(cd) | ||
compas_dataset = CompasDataset() | ||
compas_dataset.validate_dataset() | ||
|
||
def test_german(): | ||
gd = GermanDataset() | ||
bldm = BinaryLabelDatasetMetric(gd) | ||
''' Test default loading for german ''' | ||
german_dataset = GermanDataset() | ||
bldm = BinaryLabelDatasetMetric(german_dataset) | ||
assert bldm.num_instances() == 1000 | ||
|
||
def test_adult_test_set(): | ||
ad = AdultDataset() | ||
# test, train = ad.split([16281]) | ||
test, train = ad.split([15060]) | ||
''' Test default loading for adult, test set ''' | ||
adult_dataset = AdultDataset() | ||
test, _ = adult_dataset.split([15060]) | ||
assert np.any(test.labels) | ||
|
||
def test_adult(): | ||
ad = AdultDataset() | ||
# print(ad.feature_names) | ||
assert np.isclose(ad.labels.mean(), 0.2478, atol=5e-5) | ||
|
||
bldm = BinaryLabelDatasetMetric(ad) | ||
''' Test default loading for adult, mean''' | ||
adult_dataset = AdultDataset() | ||
assert np.isclose(adult_dataset.labels.mean(), 0.2478, atol=5e-5) | ||
bldm = BinaryLabelDatasetMetric(adult_dataset) | ||
assert bldm.num_instances() == 45222 | ||
|
||
def test_adult_no_drop(): | ||
ad = AdultDataset(protected_attribute_names=['sex'], | ||
''' Test default loading for adult, number of instances ''' | ||
adult_dataset = AdultDataset(protected_attribute_names=['sex'], | ||
privileged_classes=[['Male']], categorical_features=[], | ||
features_to_keep=['age', 'education-num']) | ||
bldm = BinaryLabelDatasetMetric(ad) | ||
bldm = BinaryLabelDatasetMetric(adult_dataset) | ||
assert bldm.num_instances() == 48842 | ||
|
||
def test_bank(): | ||
''' Test for errors during default loading ''' | ||
bank_dataset = BankDataset() | ||
bank_dataset.validate_dataset() | ||
|
||
def test_bank_priviliged_attributes(): | ||
''' Test if protected attribute age is correctly processed ''' | ||
# Bank Data Set | ||
bank_dataset = BankDataset() | ||
num_priv = bank_dataset.protected_attributes.sum() | ||
|
||
# Raw data | ||
# TO DO: add file path. | ||
filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), | ||
'..', 'aif360', 'data', 'raw', 'bank', 'bank-additional-full.csv') | ||
|
||
bank_dataset_unpreproc = pd.read_csv(filepath, sep = ";", na_values = ["unknown"]) | ||
bank_dataset_unpreproc = bank_dataset_unpreproc.dropna() | ||
num_priv_raw = len(bank_dataset_unpreproc[(bank_dataset_unpreproc["age"] >= 25) & (bank_dataset_unpreproc["age"] < 60)]) | ||
assert num_priv == num_priv_raw | ||
|
||
|
||
|
||
|