-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfair_fn.py
61 lines (37 loc) · 1.92 KB
/
fair_fn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# @Author : Peizhao Li
# @Contact : peizhaoli05@gmail.com
import numpy as np
from model import IFBaseClass
def grad_ferm(grad_fn: "IFBaseClass.grad", x: np.ndarray, y: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Fair empirical risk minimization (FERM) gradient for a binary sensitive attribute.

    Difference of the per-group average gradients, restricted to the positive
    samples (``y == 1``)::

        grad Exp(L | s=0, y=1) - grad Exp(L | s=1, y=1)

    Args:
        grad_fn: Called as ``grad_fn(x=..., y=...)`` on each group's subset and
            must return a 2-tuple whose first item is the gradient (the second
            item is ignored). The result is divided by the group size here, so
            grad_fn presumably returns a summed (not averaged) gradient --
            TODO confirm against ``model.IFBaseClass``.
        x: Samples; the first axis indexes samples (``N = x.shape[0]``).
        y: Per-sample labels; only samples with ``y[i] == 1`` are used.
        s: Per-sample sensitive attribute; samples whose value is neither 0
            nor 1 are ignored.

    Returns:
        Group-0 average gradient minus group-1 average gradient.

    Raises:
        ValueError: If either group has no samples with ``y == 1``; its
            average would otherwise be a division by zero.
    """
    N = x.shape[0]
    idx_grp_0_y_1 = [i for i in range(N) if s[i] == 0 and y[i] == 1]
    idx_grp_1_y_1 = [i for i in range(N) if s[i] == 1 and y[i] == 1]
    # Fail fast (before calling grad_fn on an empty subset): an empty group
    # makes its average a division by zero, i.e. inf/nan or ZeroDivisionError.
    for grp, idx in ((0, idx_grp_0_y_1), (1, idx_grp_1_y_1)):
        if not idx:
            raise ValueError(f"grad_ferm: no samples with s == {grp} and y == 1")
    grad_grp_0_y_1, _ = grad_fn(x=x[idx_grp_0_y_1], y=y[idx_grp_0_y_1])
    grad_grp_1_y_1, _ = grad_fn(x=x[idx_grp_1_y_1], y=y[idx_grp_1_y_1])
    return (grad_grp_0_y_1 / len(idx_grp_0_y_1)) - (grad_grp_1_y_1 / len(idx_grp_1_y_1))
def loss_ferm(loss_fn: "IFBaseClass.log_loss", x: np.ndarray, y: np.ndarray, s: np.ndarray) -> float:
    """
    Fair empirical risk minimization (FERM) loss gap for a binary sensitive attribute.

    Uses the same grouping and sign convention as ``grad_ferm``: the difference
    of per-group average losses over the positive samples (``y == 1``)::

        Exp(L | s=0, y=1) - Exp(L | s=1, y=1)

    Args:
        loss_fn: Called positionally as ``loss_fn(x_subset, y_subset)`` on each
            group's subset. The result is divided by the group size here, so
            loss_fn presumably returns a summed (not averaged) loss -- TODO
            confirm against ``model.IFBaseClass``.
        x: Samples; the first axis indexes samples (``N = x.shape[0]``).
        y: Per-sample labels; only samples with ``y[i] == 1`` are used.
        s: Per-sample sensitive attribute; samples whose value is neither 0
            nor 1 are ignored.

    Returns:
        Group-0 average loss minus group-1 average loss.

    Raises:
        ValueError: If either group has no samples with ``y == 1``; its
            average would otherwise be a division by zero.
    """
    N = x.shape[0]
    idx_grp_0_y_1 = [i for i in range(N) if s[i] == 0 and y[i] == 1]
    idx_grp_1_y_1 = [i for i in range(N) if s[i] == 1 and y[i] == 1]
    # Fail fast (before calling loss_fn on an empty subset): an empty group
    # makes its average a division by zero, i.e. inf/nan or ZeroDivisionError.
    for grp, idx in ((0, idx_grp_0_y_1), (1, idx_grp_1_y_1)):
        if not idx:
            raise ValueError(f"loss_ferm: no samples with s == {grp} and y == 1")
    loss_grp_0_y_1 = loss_fn(x[idx_grp_0_y_1], y[idx_grp_0_y_1])
    loss_grp_1_y_1 = loss_fn(x[idx_grp_1_y_1], y[idx_grp_1_y_1])
    return (loss_grp_0_y_1 / len(idx_grp_0_y_1)) - (loss_grp_1_y_1 / len(idx_grp_1_y_1))
def grad_dp(grad_fn: "IFBaseClass.grad_pred", x: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Demographic parity (DP) gradient for a binary sensitive attribute.

    Difference of per-group averages of ``grad_fn`` over all samples of each
    group (no label conditioning)::

        grad_fn(x[s == 1])[0] / n_1 - grad_fn(x[s == 0])[0] / n_0

    Note the sign: group 1 minus group 0, the reverse of ``grad_ferm``.

    Args:
        grad_fn: Called as ``grad_fn(x=...)`` on each group's subset and must
            return a 2-tuple whose first item is the gradient (the second item
            is ignored). Presumably the summed gradient of the model's
            predictions -- TODO confirm against ``model.IFBaseClass``.
        x: Samples; the first axis indexes samples (``N = x.shape[0]``).
        s: Per-sample sensitive attribute; samples whose value is neither 0
            nor 1 are ignored.

    Returns:
        Group-1 average gradient minus group-0 average gradient.

    Raises:
        ValueError: If either group is empty; its average would otherwise be
            a division by zero.
    """
    N = x.shape[0]
    idx_grp_0 = [i for i in range(N) if s[i] == 0]
    idx_grp_1 = [i for i in range(N) if s[i] == 1]
    # Fail fast (before calling grad_fn on an empty subset): an empty group
    # makes its average a division by zero, i.e. inf/nan or ZeroDivisionError.
    for grp, idx in ((0, idx_grp_0), (1, idx_grp_1)):
        if not idx:
            raise ValueError(f"grad_dp: no samples with s == {grp}")
    grad_grp_0, _ = grad_fn(x=x[idx_grp_0])
    grad_grp_1, _ = grad_fn(x=x[idx_grp_1])
    return (grad_grp_1 / len(idx_grp_1)) - (grad_grp_0 / len(idx_grp_0))
def loss_dp(x: np.ndarray, s: np.ndarray, pred: np.ndarray) -> float:
    """
    Demographic parity (DP) gap: difference of per-group average predictions.

    Computes::

        mean(pred[s == 1]) - mean(pred[s == 0])

    using the same sign convention as ``grad_dp`` (group 1 minus group 0).

    Args:
        x: Samples; only ``x.shape[0]`` (the sample count N) is used, to bound
            which entries of ``s``/``pred`` are considered.
        s: Per-sample sensitive attribute; samples whose value is neither 0
            nor 1 are ignored.
        pred: Per-sample model outputs (presumably scores or probabilities --
            confirm with callers).

    Returns:
        Group-1 average prediction minus group-0 average prediction.

    Raises:
        ValueError: If either group is empty; its average would otherwise be
            0/0, silently yielding nan.
    """
    N = x.shape[0]
    idx_grp_0 = [i for i in range(N) if s[i] == 0]
    idx_grp_1 = [i for i in range(N) if s[i] == 1]
    # An empty group would give np.sum([]) / 0 -> nan (with only a warning),
    # which would silently poison any downstream objective; fail loudly instead.
    for grp, idx in ((0, idx_grp_0), (1, idx_grp_1)):
        if not idx:
            raise ValueError(f"loss_dp: no samples with s == {grp}")
    pred_grp_0 = np.sum(pred[idx_grp_0])
    pred_grp_1 = np.sum(pred[idx_grp_1])
    return (pred_grp_1 / len(idx_grp_1)) - (pred_grp_0 / len(idx_grp_0))