More sklearn-compatible algorithms #318

Merged
merged 88 commits on Jul 1, 2022
8cfa9de
Initial sklearn-compatible datasets and metrics
May 16, 2019
1f4ae57
added initial dataset tests
May 16, 2019
2aef3fc
fixed to_list for older pandas versions
May 17, 2019
2b1799a
added metrics tests
hoffmansc May 17, 2019
9da5abd
added README and docs
hoffmansc May 21, 2019
025ecc1
simpler dataset loading and 'groups' for metrics
hoffmansc May 23, 2019
8e96177
fixes to categoricals
hoffmansc Jun 5, 2019
8abb897
fixes for tests, updated README
hoffmansc Jun 5, 2019
15a8eb2
added travis badge to README
hoffmansc Jun 6, 2019
3f594a4
updated todo with external blockers
hoffmansc Jun 13, 2019
7754b32
added reweighing workaround to example
hoffmansc Jun 13, 2019
17b0c95
added Reweighing algorithm
hoffmansc Jun 18, 2019
cc9246f
clean up comments
hoffmansc Jun 18, 2019
8c58f65
fixed package version in docs
hoffmansc Jun 18, 2019
1e7899c
adding hyperlinks to SLEPs
animeshsingh Jun 20, 2019
c1c1e40
added binary_age opt to german; fixed NAs in bank
hoffmansc Jun 24, 2019
93a7cdf
modified onehot_transformer to return DataFrame
hoffmansc Jun 24, 2019
8e52268
tweaks to reweighing to conform with sklearn
hoffmansc Jun 24, 2019
0183449
updated README
hoffmansc Jun 24, 2019
89b4a79
fixed docstring formatting
hoffmansc Jun 24, 2019
d57b6df
changed metrics to use prot_attr
hoffmansc Jun 24, 2019
d8958bb
added __all__ to __init__s
hoffmansc Jun 24, 2019
0bd3837
updated notebook with reweighing example
hoffmansc Jun 27, 2019
4107dd7
initial adversarial debiasing port
hoffmansc Jul 11, 2019
df85e42
multiclass/multigroup support for adv debiasing
hoffmansc Jul 16, 2019
d2d0ddc
fix build errors
hoffmansc Jul 30, 2019
7a2414a
Add ensure_binary option to check_groups
hoffmansc Aug 12, 2019
aac9954
`numeric_only` converts index and label as well
hoffmansc Oct 29, 2019
dc317cf
changed Reweighing to return X, sample_weight
hoffmansc Oct 29, 2019
0f184c3
made sample_weight optional in check_inputs
hoffmansc Oct 29, 2019
ec4a1de
matched tests to new numeric dataset format
hoffmansc Oct 29, 2019
f8c4fc5
added generalized_fnr/fpr metrics
hoffmansc Oct 29, 2019
7ce2f42
fixed dataset_processing
hoffmansc Oct 29, 2019
973a774
initial calibrated equalized odds port
hoffmansc Oct 29, 2019
40cad96
fixed adversarial debiasing reproducibility
hoffmansc Oct 30, 2019
dc410a2
updated Getting Started notebook
hoffmansc Oct 30, 2019
e0856e3
updated readme
hoffmansc Oct 31, 2019
8f8cd76
fixed tests and added additional tests
hoffmansc Oct 31, 2019
e01f23f
added COMPAS and other dataset fixes; fixed german dataset to match p…
hoffmansc Nov 11, 2019
e92f846
fix more edge cases in metrics
hoffmansc Nov 12, 2019
27aa55c
removed unused import
hoffmansc Nov 12, 2019
831775c
make cache dir if necessary
hoffmansc Dec 9, 2019
a0e56b0
docstring, formatting, and typo fixes
hoffmansc Dec 13, 2019
0e48ead
more gitignores
hoffmansc Dec 13, 2019
0cbc3f4
docstrings and add alpha=sqrt(global_step) option
hoffmansc Dec 13, 2019
8be6449
docstrings and input is now predict_proba output
hoffmansc Dec 13, 2019
994bdf0
moved tests to main test folder
hoffmansc Dec 18, 2019
372e111
more docs and formatting changes
hoffmansc Dec 19, 2019
8d10893
postprocessor takes DataFrame if use_proba
hoffmansc Dec 19, 2019
e0ff2b6
readme changes overwritten in the merge
hoffmansc Dec 19, 2019
a2cd77e
train, test were swapped for adult
hoffmansc Dec 19, 2019
ee7f23c
remove branch mentions
hoffmansc Dec 19, 2019
c8154ec
remove "attributes" line if none present
hoffmansc Dec 20, 2019
7ef94e7
moved example to main folder
hoffmansc Dec 28, 2019
c5af647
use_proba -> needs_proba
hoffmansc Jan 31, 2020
042bb12
fixed/renamed/reordered/added some attributes
hoffmansc Jan 31, 2020
ff9e70c
fixed sample_weight=None bug and classes_ typo
hoffmansc Feb 5, 2020
57b2ab5
improved specificity_score and added fpr/fnr error
hoffmansc Feb 6, 2020
8fdd6dc
made foreign_worker and education (bank) ordered
hoffmansc Feb 6, 2020
2cf455f
various fixes to address PR comments
hoffmansc Feb 19, 2020
789e96b
added comments to tests
hoffmansc Feb 19, 2020
9867938
initial ROC and LFR + other WIP
hoffmansc Apr 22, 2020
478e9a9
Merge branch 'master' into sklearn-compat
Jun 3, 2020
e75f1a0
Merge branch 'sklearn-compat' of https://github.com/IBM/AIF360 into s…
hoffmansc Jun 18, 2020
35d49e5
Merge branch 'master' into sklearn-compat
hoffmansc Jun 18, 2020
f09b941
allow prot_attr/target input to be Series
hoffmansc Jul 27, 2020
2d8cdf1
default to predict if no predict_proba
hoffmansc Jul 27, 2020
cab9bb9
Merge branch 'master' into sklearn-compat
hoffmansc Jul 27, 2020
351259b
matches old lfr
hoffmansc Jul 31, 2020
f08cf7d
use pytorch to calculate grad
hoffmansc Aug 5, 2020
1fdb02e
standardize defaults in metaestimators
hoffmansc Jan 15, 2021
debae37
clean up tests
hoffmansc Jun 21, 2022
c46affb
ROC improvements
hoffmansc Jun 21, 2022
3018660
minor tweaks
hoffmansc Jun 21, 2022
532d83d
Merge branch 'master' into sklearn-compat
hoffmansc Jun 21, 2022
c61728f
infer proba behavior from estimator tag
hoffmansc Jun 27, 2022
ad69dc4
rename LFR
hoffmansc Jun 27, 2022
0c7fe69
use scores instead of costs (higher is better)
hoffmansc Jun 27, 2022
16d800e
fix test imports
hoffmansc Jun 27, 2022
78ece3e
add new algorithms to docs
hoffmansc Jun 27, 2022
2918d36
suppress zero division warnings
hoffmansc Jun 27, 2022
6a710d0
new notebooks
hoffmansc Jun 27, 2022
01090f9
adjust tests
hoffmansc Jun 27, 2022
35c3e2d
small fixes
hoffmansc Jun 28, 2022
f978b5f
fix tests
hoffmansc Jun 28, 2022
304586a
additional checks and errors
hoffmansc Jul 1, 2022
2effc0b
add note to notebook
hoffmansc Jul 1, 2022
35f87a7
propagate classes_ in metaestimators
hoffmansc Jul 1, 2022
14 changes: 7 additions & 7 deletions aif360/algorithms/postprocessing/reject_option_classification.py
@@ -216,24 +216,24 @@ def fit_predict(self, dataset_true, dataset_pred):
         return self.fit(dataset_true, dataset_pred).predict(dataset_pred)
 
 # Function to obtain the pareto frontier
-def _get_pareto_frontier(costs, return_mask = True): # <- Fastest for many points
+def _get_pareto_frontier(scores, return_mask = True): # <- Fastest for many points
     """
-    :param costs: An (n_points, n_costs) array
+    :param scores: An (n_points, n_scores) array
     :param return_mask: True to return a mask, False to return integer indices of efficient points.
     :return: An array of indices of pareto-efficient points.
        If return_mask is True, this will be an (n_points, ) boolean array
        Otherwise it will be a (n_efficient_points, ) integer array of indices.
 
    adapted from: https://stackoverflow.com/questions/32791911/fast-calculation-of-pareto-front-in-python
    """
-    is_efficient = np.arange(costs.shape[0])
-    n_points = costs.shape[0]
+    is_efficient = np.arange(scores.shape[0])
+    n_points = scores.shape[0]
     next_point_index = 0  # Next index in the is_efficient array to search for
 
-    while next_point_index<len(costs):
-        nondominated_point_mask = np.any(costs<=costs[next_point_index], axis=1)
+    while next_point_index<len(scores):
+        nondominated_point_mask = np.any(scores>=scores[next_point_index], axis=1)
         is_efficient = is_efficient[nondominated_point_mask]  # Remove dominated points
-        costs = costs[nondominated_point_mask]
+        scores = scores[nondominated_point_mask]
         next_point_index = np.sum(nondominated_point_mask[:next_point_index])+1
 
     if return_mask:
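Since the helper now maximizes scores rather than minimizing costs, here is a self-contained sketch of the new behavior. The mask-building tail is not shown in this diff and is reconstructed from the Stack Overflow answer cited in the docstring; the score values are invented.

import numpy as np

def get_pareto_frontier(scores, return_mask=True):
    # scores: (n_points, n_scores) array; higher is better on every column.
    is_efficient = np.arange(scores.shape[0])
    n_points = scores.shape[0]
    next_point_index = 0
    while next_point_index < len(scores):
        nondominated = np.any(scores >= scores[next_point_index], axis=1)
        is_efficient = is_efficient[nondominated]   # drop dominated points
        scores = scores[nondominated]
        next_point_index = np.sum(nondominated[:next_point_index]) + 1
    if return_mask:  # tail reconstructed from the cited answer
        mask = np.zeros(n_points, dtype=bool)
        mask[is_efficient] = True
        return mask
    return is_efficient

# (balanced accuracy, fairness score) pairs; point 0 is dominated by point 3.
scores = np.array([[0.70, 0.90], [0.80, 0.80], [0.60, 0.60], [0.75, 0.95]])
print(get_pareto_frontier(scores))  # [False  True False  True]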
@@ -60,6 +60,6 @@ def fit_transform(self, dataset):
         repaired_features = repairer.repair(features)
         repaired.features = np.array(repaired_features, dtype=np.float64)
         # protected attribute shouldn't change
-        repaired.features[:, index] = repaired.protected_attributes[:, 0]
+        repaired.features[:, index] = repaired.protected_attributes[:, repaired.protected_attribute_names.index(self.sensitive_attribute)]
 
         return repaired
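The removed line always restored column 0 of protected_attributes, which is only correct when the repaired attribute happens to be listed first. A toy illustration of the lookup the new line performs (attribute names invented):

protected_attribute_names = ['race', 'sex']   # hypothetical dataset attributes
sensitive_attribute = 'sex'                   # the attribute being repaired

# Old behavior: column 0 would restore 'race' values over the 'sex' column.
# New behavior: find the column that actually matches the repaired attribute.
col = protected_attribute_names.index(sensitive_attribute)
print(col)  # 1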
2 changes: 1 addition & 1 deletion aif360/algorithms/preprocessing/lfr.py
@@ -145,7 +145,7 @@ def transform(self, dataset, threshold=0.5):
         dataset_new.labels = transformed_bin_labels
         dataset_new.scores = np.array(transformed_labels)
 
-        return dataset_new
+        return dataset_new
 
     def fit_transform(self, dataset, maxiter=5000, maxfun=5000, threshold=0.5):
         """Fit and transform methods sequentially.
47 changes: 31 additions & 16 deletions aif360/sklearn/datasets/utils.py
@@ -15,7 +15,8 @@ def check_already_dropped(labels, dropped_cols, name, dropped_by='numeric_only',
     haven't.
 
     Args:
-        labels (single label or list-like): Column labels to check.
+        labels (label, pandas.Series, or list-like of labels/Series): Column
+            labels to check.
         dropped_cols (set or pandas.Index): Columns that were already dropped.
         name (str): Original arg that triggered the check (e.g. dropcols).
         dropped_by (str, optional): Original arg that caused ``dropped_cols``
@@ -27,28 +28,38 @@ def check_already_dropped(labels, dropped_cols, name, dropped_by='numeric_only',
     Returns:
         list: Columns in labels which are not in dropped_cols.
     """
-    if not is_list_like(labels):
+    if isinstance(labels, pd.Series) or not is_list_like(labels):
         labels = [labels]
-    str_labels = [c for c in labels if isinstance(c, str)]
-    already_dropped = dropped_cols.intersection(str_labels)
+    str_labels = [c for c in labels if not isinstance(c, pd.Series)]
+    try:
+        already_dropped = dropped_cols.intersection(str_labels)
+        if isinstance(already_dropped, pd.MultiIndex):
+            raise TypeError  # list of lists results in MultiIndex
+    except TypeError as e:
+        raise TypeError("Only labels or Series are allowed for {}. Got types:\n"
+                        "{}".format(name, [type(c) for c in labels]))
     if warn and any(already_dropped):
         warnings.warn("Some column labels from `{}` were already dropped by "
                       "`{}`:\n{}".format(name, dropped_by, already_dropped.tolist()),
                       ColumnAlreadyDroppedWarning, stacklevel=2)
-    return [c for c in labels if not isinstance(c, str) or c not in already_dropped]
+    return [c for c in labels if isinstance(c, pd.Series)
+            or c not in already_dropped]
 
-def standardize_dataset(df, prot_attr, target, sample_weight=None, usecols=[],
-                        dropcols=[], numeric_only=False, dropna=True):
+def standardize_dataset(df, *, prot_attr, target, sample_weight=None,
+                        usecols=[], dropcols=[], numeric_only=False,
+                        dropna=True):
     """Separate data, targets, and possibly sample weights and populate
     protected attributes as sample properties.
 
     Args:
         df (pandas.DataFrame): DataFrame with features and target together.
-        prot_attr (single label or list-like): Label or list of labels
-            corresponding to protected attribute columns. Even if these are
-            dropped from the features, they remain in the index.
-        target (single label or list-like): Column label of the target (outcome)
-            variable.
+        prot_attr (label, pandas.Series, or list-like of labels/Series): Single
+            label, Series, or list-like of labels/Series corresponding to
+            protected attribute columns. Even if these are dropped from the
+            features, they remain in the index. If a Series is provided, it will
+            be added to the index but not show up in the features.
+        target (label, pandas.Series, or list-like of labels/Series): Column
+            label(s) or values of the target (outcome) variable.
         sample_weight (single label, optional): Name of the column containing
             sample weights.
         usecols (single label or list-like, optional): Column(s) to keep. All
@@ -77,9 +88,11 @@ def standardize_dataset(df, prot_attr, target, sample_weight=None, usecols=[],
         >>> import pandas as pd
         >>> from sklearn.linear_model import LinearRegression
 
-        >>> df = pd.DataFrame([[1, 2, 3], [4, 5, 6]], columns=['X', 'y', 'Z'])
-        >>> train = standardize_dataset(df, prot_attr='Z', target='y')
-        >>> reg = LinearRegression().fit(*train)
+        >>> df = pd.DataFrame([[0.5, 1, 1, 0.75], [-0.5, 0, 0, 0.25]],
+        ...                   columns=['X', 'y', 'Z', 'w'])
+        >>> train = standardize_dataset(df, prot_attr='Z', target='y',
+        ...                             sample_weight='w')
+        >>> reg = LinearRegression().fit(**train._asdict())
 
         >>> import numpy as np
         >>> from sklearn.datasets import make_classification
@@ -105,7 +118,9 @@ def standardize_dataset(df, prot_attr, target, sample_weight=None, usecols=[],
     target = check_already_dropped(target, nonnumeric, 'target')
     if len(target) == 0:
         raise ValueError("At least one target must be present.")
-    y = pd.concat([df.pop(t) for t in target], axis=1).squeeze() # maybe Series
+    y = pd.concat([df.pop(t) if not isinstance(t, pd.Series) else
+                   t.set_axis(df.index, inplace=False) for t in target], axis=1)
+    y = y.squeeze() # maybe Series
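A quick sketch of the keyword-only API and the new Series input path, on toy data. The import path and the printed output are assumptions based on the updated docstring, not guaranteed by this hunk.

import pandas as pd
from aif360.sklearn.datasets import standardize_dataset  # assumed import path

df = pd.DataFrame({'age': [25, 60, 33], 'credit': [1, 0, 1]})
sex = pd.Series(['M', 'F', 'F'], name='sex')  # not a column of df

# prot_attr and target are now keyword-only; positional calls raise TypeError.
X, y = standardize_dataset(df, prot_attr=sex, target='credit')
print(X.columns.tolist())                        # ['age'] -- sex is not a feature
print(X.index.get_level_values('sex').tolist())  # ['M', 'F', 'F'] (expected)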
4 changes: 2 additions & 2 deletions aif360/sklearn/inprocessing/adversarial_debiasing.py
@@ -56,7 +56,7 @@ def __init__(self, prot_attr=None, scope_name='classifier',
                 entire model (classifier and adversary).
             adversary_loss_weight (float or ``None``, optional): If ``None``,
                 this will use the suggestion from the paper:
-                :math:`\alpha = \sqrt(global_step)` with inverse time decay on
+                :math:`\alpha = \sqrt{global\_step}` with inverse time decay on
                 the learning rate. Otherwise, it uses the provided coefficient
                 with exponential learning rate decay.
             num_epochs (int, optional): Number of epochs for which to train.
@@ -340,7 +340,7 @@ def predict(self, X):
         """
         scores = self.decision_function(X)
         if scores.ndim == 1:
-            indices = (scores > 0).astype(np.int)
+            indices = (scores > 0).astype(int)
         else:
             indices = scores.argmax(axis=1)
         return self.classes_[indices]
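For intuition about the corrected formula: with adversary_loss_weight=None, the adversary term is weighted by alpha = sqrt(global_step) while the learning rate decays inversely with time, so the adversary's influence grows as training stabilizes. A rough sketch of that schedule (decay constants are illustrative, not the exact TensorFlow settings):

import numpy as np

for t in [1, 100, 10000]:
    alpha = np.sqrt(t)               # adversary weight grows with sqrt(t)
    lr = 0.001 / (1.0 + 0.001 * t)   # inverse time decay, illustrative constants
    print(f"step={t:>5}  alpha={alpha:7.2f}  lr={lr:.6f}")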
116 changes: 72 additions & 44 deletions aif360/sklearn/metrics/metrics.py
@@ -1,12 +1,12 @@
-import warnings
-
 import numpy as np
 import pandas as pd
 from sklearn.metrics import make_scorer as _make_scorer, recall_score
 from sklearn.metrics import multilabel_confusion_matrix
+from sklearn.metrics._classification import _prf_divide, _check_zero_division
 from sklearn.neighbors import NearestNeighbors
 from sklearn.utils import check_X_y
-from sklearn.exceptions import UndefinedMetricWarning, deprecated
+from sklearn.utils.validation import column_or_1d
+from sklearn.exceptions import deprecated
 
 from aif360.sklearn.utils import check_groups
 from aif360.detectors.mdss.ScoringFunctions import BerkJones, Bernoulli
@@ -77,7 +77,7 @@ def difference(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
     return func(*unpriv, **kwargs) - func(*priv, **kwargs)
 
 def ratio(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
-          **kwargs):
+          zero_division='warn', **kwargs):
     """Compute the ratio between unprivileged and privileged subsets for an
     arbitrary metric.
 
@@ -96,11 +96,15 @@ def ratio(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
         priv_group (scalar, optional): The label of the privileged group.
         sample_weight (array-like, optional): Sample weights passed through to
             func.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
         **kwargs: Additional keyword args to be passed through to func.
 
     Returns:
         scalar: Ratio of metric values for unprivileged and privileged groups.
     """
+    _check_zero_division(zero_division)
     groups, _ = check_groups(y, prot_attr)
     idx = (groups == priv_group)
     unpriv = map(lambda a: a[~idx], (y,) + args)
@@ -112,13 +116,14 @@ def ratio(func, y, *args, prot_attr=None, priv_group=1, sample_weight=None,
     numerator = func(*unpriv, **kwargs)
     denominator = func(*priv, **kwargs)
 
-    if denominator == 0:
-        warnings.warn("The ratio is ill-defined and being set to 0.0 because "
-                      "'{}' for privileged samples is 0.".format(func.__name__),
-                      UndefinedMetricWarning)
-        return 0.
-
-    return numerator / denominator
+    if func == base_rate:
+        modifier = 'positive privileged'
+    elif func == selection_rate:
+        modifier = 'predicted privileged'
+    else:
+        modifier = f'value for {func.__name__} on privileged'
+    return _prf_divide(np.array([numerator]), np.array([denominator]), 'ratio',
+                       modifier, None, ('ratio',), zero_division).item()
 
 
 # =========================== SCORER FACTORY =================================
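The effect of this rewrite: a zero denominator no longer unconditionally returns 0.0 with a generic warning; _prf_divide lets the caller choose 0, 1, or the warning behavior. A sketch, assuming ratio is exported from aif360.sklearn.metrics and using invented data:

import numpy as np
import pandas as pd
from sklearn.metrics import recall_score
from aif360.sklearn.metrics import ratio  # assumed to be exported

sex = pd.Series(['M', 'M', 'F', 'F'], name='sex')
y_true = pd.Series([1, 1, 1, 1],
                   index=pd.MultiIndex.from_arrays([sex], names=['sex']))
y_pred = np.array([0, 0, 1, 1])  # privileged ('M') recall is 0 -> zero denominator

# 'warn' (the default) acts as 0 but warns; explicit 0 or 1 is silent.
print(ratio(recall_score, y_true, y_pred, prot_attr='sex', priv_group='M',
            zero_division=1))  # 1.0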
@@ -151,24 +156,26 @@ def score(y, y_pred, **kwargs):
     return scorer
 
 # ================================ HELPERS =====================================
-def specificity_score(y_true, y_pred, pos_label=1, sample_weight=None):
+def specificity_score(y_true, y_pred, pos_label=1, sample_weight=None,
+                      zero_division='warn'):
     """Compute the specificity or true negative rate.
 
     Args:
         y_true (array-like): Ground truth (correct) target values.
         y_pred (array-like): Estimated targets as returned by a classifier.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
     """
+    _check_zero_division(zero_division)
     MCM = multilabel_confusion_matrix(y_true, y_pred, labels=[pos_label],
                                       sample_weight=sample_weight)
-    tn, fp, fn, tp = MCM.ravel()
+    tn, fp = MCM[:, 0, 0], MCM[:, 0, 1]
     negs = tn + fp
-    if negs == 0:
-        warnings.warn('specificity_score is ill-defined and being set to 0.0 '
-                      'due to no negative samples.', UndefinedMetricWarning)
-        return 0.
-    return tn / negs
+    return _prf_divide(tn, negs, 'specificity', 'negative', None,
+                       ('specificity',), zero_division).item()
 
 def base_rate(y_true, y_pred=None, pos_label=1, sample_weight=None):
     r"""Compute the base rate, :math:`Pr(Y = \text{pos_label}) = \frac{P}{P+N}`.
@@ -200,7 +207,8 @@ def selection_rate(y_true, y_pred, pos_label=1, sample_weight=None):
     """
     return base_rate(y_pred, pos_label=pos_label, sample_weight=sample_weight)
 
-def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None):
+def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None,
+                    zero_division='warn'):
     r"""Return the ratio of generalized false positives to negative examples in
     the dataset, :math:`GFPR = \tfrac{GFP}{N}`.
 
@@ -212,22 +220,29 @@ def generalized_fpr(y_true, probas_pred, pos_label=1, sample_weight=None,
         probas_pred (array-like): Probability estimates of the positive class.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
 
     Returns:
-        float: Generalized false positive rate. If there are no negative samples
-        in y_true, this will raise an
-        :class:`~sklearn.exceptions.UndefinedMetricWarning` and return 0.
+        float: Generalized false positive rate.
     """
+    _check_zero_division(zero_division)
+    y_true, probas_pred = column_or_1d(y_true), column_or_1d(probas_pred)
+
     idx = (y_true != pos_label)
-    if not np.any(idx):
-        warnings.warn("generalized_fpr is ill-defined because there are no "
-                      "negative samples in y_true.", UndefinedMetricWarning)
-        return 0.
+    gfps = probas_pred[idx]
     if sample_weight is None:
-        return probas_pred[idx].mean()
-    return np.average(probas_pred[idx], weights=sample_weight[idx])
+        gfp = np.array([gfps.sum()])
+        neg = np.array([len(gfps)])
+    else:
+        gfp = np.array([np.dot(gfps, sample_weight[idx])])
+        neg = np.array([sample_weight[idx].sum()])
+    return _prf_divide(gfp, neg, 'generalized FPR', 'negative', None,
+                       ('generalized FPR',), zero_division).item()
 
-def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None):
+def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None,
+                    zero_division='warn'):
     r"""Return the ratio of generalized false negatives to positive examples in
     the dataset, :math:`GFNR = \tfrac{GFN}{P}`.
 
@@ -239,20 +254,26 @@ def generalized_fnr(y_true, probas_pred, pos_label=1, sample_weight=None,
         probas_pred (array-like): Probability estimates of the positive class.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
 
     Returns:
-        float: Generalized false negative rate. If there are no positive samples
-        in y_true, this will raise an
-        :class:`~sklearn.exceptions.UndefinedMetricWarning` and return 0.
+        float: Generalized false negative rate.
     """
+    _check_zero_division(zero_division)
+    y_true, probas_pred = column_or_1d(y_true), column_or_1d(probas_pred)
+
     idx = (y_true == pos_label)
-    if not np.any(idx):
-        warnings.warn("generalized_fnr is ill-defined because there are no "
-                      "positive samples in y_true.", UndefinedMetricWarning)
-        return 0.
+    gfns = 1 - probas_pred[idx]
     if sample_weight is None:
-        return 1 - probas_pred[idx].mean()
-    return 1 - np.average(probas_pred[idx], weights=sample_weight[idx])
+        gfn = np.array([gfns.sum()])
+        pos = np.array([len(gfns)])
+    else:
+        gfn = np.array([np.dot(gfns, sample_weight[idx])])
+        pos = np.array([sample_weight[idx].sum()])
+    return _prf_divide(gfn, pos, 'generalized FNR', 'positive', None,
+                       ('generalized FNR',), zero_division).item()
 
 
 # ============================ GROUP FAIRNESS ==================================
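Both generalized rates now build an explicit numerator and denominator and delegate the division to _prf_divide. A toy check with invented soft scores:

import numpy as np
from aif360.sklearn.metrics import generalized_fpr, generalized_fnr

y_true = np.array([0, 0, 1, 1])
probas = np.array([0.2, 0.4, 0.9, 0.7])

print(generalized_fpr(y_true, probas))  # mean p over true negatives ~ 0.3
print(generalized_fnr(y_true, probas))  # mean (1 - p) over true positives ~ 0.2

# With no negatives, zero_division decides instead of always returning 0.
print(generalized_fpr(np.ones(2), np.array([0.5, 0.5]), zero_division=1))  # 1.0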
@@ -291,7 +312,7 @@ def statistical_parity_difference(*y, prot_attr=None, priv_group=1, pos_label=1,
                       pos_label=pos_label, sample_weight=sample_weight)
 
 def disparate_impact_ratio(*y, prot_attr=None, priv_group=1, pos_label=1,
-                           sample_weight=None):
+                           sample_weight=None, zero_division='warn'):
     r"""Ratio of selection rates.
 
     .. math::
@@ -313,6 +334,9 @@ def disparate_impact_ratio(*y, prot_attr=None, priv_group=1, pos_label=1,
         priv_group (scalar, optional): The label of the privileged group.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
+        zero_division ('warn', 0 or 1): Sets the value to return when there is a
+            zero division. If set to “warn”, this acts as 0, but warnings are
+            also raised.
 
     Returns:
         float: Disparate impact.
@@ -322,7 +346,8 @@ def disparate_impact_ratio(*y, prot_attr=None, priv_group=1, pos_label=1,
     """
     rate = base_rate if len(y) == 1 or y[1] is None else selection_rate
     return ratio(rate, *y, prot_attr=prot_attr, priv_group=priv_group,
-                 pos_label=pos_label, sample_weight=sample_weight)
+                 pos_label=pos_label, sample_weight=sample_weight,
+                 zero_division=zero_division)
 
 def equal_opportunity_difference(y_true, y_pred, prot_attr=None, priv_group=1,
                                  pos_label=1, sample_weight=None):
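With the parameter threaded through, a zero privileged selection rate is governed by the same sklearn machinery. A sketch on toy data:

import numpy as np
import pandas as pd
from aif360.sklearn.metrics import disparate_impact_ratio

sex = pd.Series(['M', 'M', 'F', 'F'], name='sex')
y_true = pd.Series([1, 0, 1, 0],
                   index=pd.MultiIndex.from_arrays([sex], names=['sex']))
y_pred = np.array([0, 0, 1, 0])  # the privileged group is never selected

# Default 'warn' returns 0.0 with a warning; explicit values are silent.
print(disparate_impact_ratio(y_true, y_pred, prot_attr='sex', priv_group='M',
                             zero_division=0))  # 0.0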
@@ -384,8 +409,8 @@ def average_odds_difference(y_true, y_pred, prot_attr=None, priv_group=1,
                            sample_weight=sample_weight)
     return (tpr_diff + fpr_diff) / 2
 
-def average_odds_error(y_true, y_pred, prot_attr=None, pos_label=1,
-                       sample_weight=None):
+def average_odds_error(y_true, y_pred, prot_attr=None, priv_group=None,
+                       pos_label=1, sample_weight=None):
     r"""A relaxed version of equality of odds.
 
     Returns the average of the absolute difference in FPR and TPR for the
@@ -403,14 +428,17 @@ def average_odds_error(y_true, y_pred, prot_attr=None, pos_label=1,
         y_pred (array-like): Estimated targets as returned by a classifier.
         prot_attr (array-like, keyword-only): Protected attribute(s). If
             ``None``, all protected attributes in y_true are used.
-        priv_group (scalar, optional): The label of the privileged group.
+        priv_group (scalar, optional): The label of the privileged group. If
+            prot_attr is binary, this may be ``None``.
         pos_label (scalar, optional): The label of the positive class.
         sample_weight (array-like, optional): Sample weights.
 
     Returns:
         float: Average odds error.
     """
-    priv_group = check_groups(y_true, prot_attr=prot_attr)[0][0]
+    if priv_group is None:
+        priv_group = check_groups(y_true, prot_attr=prot_attr,
+                                  ensure_binary=True)[0][0]
     fpr_diff = -difference(specificity_score, y_true, y_pred,
                            prot_attr=prot_attr, priv_group=priv_group,
                            pos_label=pos_label, sample_weight=sample_weight)
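The new priv_group argument matters once the protected attribute has more than two groups, where silently taking the first group was ambiguous. A sketch with invented data; the expected value assumes the absolute-difference definition in the docstring:

import numpy as np
import pandas as pd
from aif360.sklearn.metrics import average_odds_error

race = pd.Series(['a', 'b', 'c', 'a', 'b', 'c'], name='race')
y_true = pd.Series([1, 1, 0, 0, 1, 0],
                   index=pd.MultiIndex.from_arrays([race], names=['race']))
y_pred = np.array([1, 0, 0, 1, 1, 0])

# Explicit privileged group; inference is now restricted to binary attributes
# via ensure_binary=True.
print(average_odds_error(y_true, y_pred, prot_attr='race', priv_group='a'))
# 0.75 expected: |TPR diff| = 0.5, |FPR diff| = 1.0, average = 0.75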