-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathnon_param_dro.py
174 lines (135 loc) · 5.37 KB
/
non_param_dro.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#!/usr/bin/env python3
"""
Main program
"""
import numpy as np
import torch as th
import scipy.optimize
import abc
# Threshold on the spread of the losses relative to the largest loss: at or
# below it, Chi2ConstrainedAdversary.best_response returns uniform weights.
MIN_REL_DIFFERENCE = 1e-5
def bisection_search(objective, min_val, max_val, xtol=1e-5, maxiter=100):
    """Find an approximate root of ``objective`` on ``[min_val, max_val]``.

    When the objective changes sign over the interval, the root is located
    with ``scipy.optimize.bisect``. Otherwise no root is bracketed and the
    endpoint whose objective value is closest to zero is returned instead
    (``max_val`` on ties).

    Args:
        objective (callable): Scalar function of one scalar argument.
        min_val (float): Left end of the search interval.
        max_val (float): Right end of the search interval.
        xtol (float, optional): Tolerance forwarded to
            ``scipy.optimize.bisect``. Defaults to 1e-5.
        maxiter (int, optional): Maximum number of bisection iterations.
            Defaults to 100.

    Returns:
        float: The root estimate, or one of the two endpoints when the
        objective does not change sign between them.
    """
    # Evaluate each endpoint exactly once: the objective may be expensive and
    # both values are needed for the sign test and for the fallback below.
    value_at_min = objective(min_val)
    value_at_max = objective(max_val)

    if value_at_min * value_at_max >= 0:
        # No sign change: the root lies outside the interval (or on its
        # boundary), so fall back to the endpoint closest to a root.
        if np.abs(value_at_min) < np.abs(value_at_max):
            return min_val
        return max_val

    # The root lies inside the interval, use the bisection method
    # (=binary search) to find it
    root, results = scipy.optimize.bisect(
        objective, min_val, max_val, xtol=xtol, maxiter=maxiter,
        full_output=True, disp=False,
    )
    # disp=False keeps scipy from raising, so report non-convergence here
    if not results.converged:
        print("Bisect didn't converge")
    return root
class NonParametricAdversary(abc.ABC):
    """Common interface of the non-parametric adversaries.

    Both methods are placeholders that raise ``NotImplementedError``;
    concrete adversaries are expected to override them.
    """

    def is_valid_response(self, q: np.ndarray):
        """Tell whether the sample weights ``q`` are admissible.

        Args:
            q (np.ndarray): Candidate weights over the samples.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def best_response(self, losses: th.Tensor):
        """Compute the adversary's sample weights for the given losses.

        Args:
            losses (th.Tensor): Sample losses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class KLConstrainedAdversary(NonParametricAdversary):
    """KL constrained adversary

    The best response is a softmax of the losses whose temperature is chosen
    so that the KL divergence from the uniform weights matches ``kappa``.

    Args:
        kappa (float): KL bound
        log_tau_min (float, optional): Minimum value to check for the log
            temperature. Defaults to -10.
        log_tau_max (float, optional): Maximum value to check for the log
            temperature. Defaults to 10.
    """

    def __init__(
        self, kappa: float, log_tau_min: float = -10, log_tau_max: float = 10
    ) -> None:
        super().__init__()
        self.kappa = kappa
        self.log_tau_min = log_tau_min
        self.log_tau_max = log_tau_max

    def find_optimal_tau(self, losses: np.ndarray):
        """Find \\tau^* such that KL(q^*_\\tau^* || p) = \\kappa

        The search is carried out over \\log_10 \\tau, between
        ``log_tau_min`` and ``log_tau_max``, because useful temperatures can
        span several orders of magnitude and are more evenly spread out on
        the log scale.

        Args:
            losses (np.ndarray): sample losses

        Returns:
            float: The optimal \\tau
        """

        def kl_delta(log_tau):
            tau = 10 ** log_tau
            log_q_star_ = losses / tau
            # Shift by the max before exponentiating to avoid overflow
            log_q_star_ -= log_q_star_.max()
            # Normalize by the mean so that exp(log_q_star) averages to 1,
            # i.e. log_q_star is the log-ratio of q^* to the uniform weights
            log_q_star = log_q_star_ - np.log(np.mean(np.exp(log_q_star_)))
            return (np.exp(log_q_star) * log_q_star).mean() - self.kappa

        log_tau_star = bisection_search(
            kl_delta, self.log_tau_min, self.log_tau_max, xtol=1e-5,
            maxiter=100
        )
        return 10 ** (log_tau_star)

    def is_valid_response(self, q: np.ndarray):
        """Check that ``q`` lies in the KL ball around the uniform weights.

        ``q`` is treated as a probability vector over the samples (like the
        output of ``best_response``), so the divergence is
        KL(q || uniform) = sum_i q_i * log(m * q_i) with m = len(q).
        The comparison is exact: no numerical tolerance is applied.

        Args:
            q (np.ndarray): Candidate weights over the samples.

        Returns:
            bool: Whether the divergence is at most ``kappa``.
        """
        q = np.asarray(q)
        # Entries with zero weight contribute nothing; masking them also
        # avoids evaluating log(0)
        support = q[q > 0]
        kl = (support * np.log(support * len(q))).sum()
        return kl <= self.kappa

    def best_response(self, losses: th.Tensor):
        """Return softmax weights over the losses at the optimal temperature.

        Args:
            losses (th.Tensor): Sample losses.

        Returns:
            th.Tensor: ``softmax(losses / tau^*)`` taken along dimension 0.
        """
        # The temperature search runs in numpy on a detached CPU copy
        tau_star = self.find_optimal_tau(losses.detach().cpu().numpy())
        return th.softmax(losses / tau_star, dim=0)
class Chi2ConstrainedAdversary(NonParametricAdversary):
    """Chi-square constrained adversary

    The best response has the form q proportional to relu(losses - eta),
    with eta chosen so that the chi-square divergence from the uniform
    weights, computed as 0.5 * mean((m * q - 1) ** 2) for m samples, matches
    ``bound``.

    Args:
        bound (float): Bound on the chi-square divergence.
        eta_min (float, optional): Stored on the instance; the search
            interval is derived from the losses instead. Defaults to -10.
        eta_max (float, optional): Stored on the instance; the search
            interval is derived from the losses instead. Defaults to 10.
    """

    def __init__(
        self, bound: float, eta_min: float = -10, eta_max: float = 10
    ) -> None:
        super().__init__()
        self.bound = bound
        self.eta_min = eta_min
        self.eta_max = eta_max

    def is_valid_response(self, q: np.ndarray):
        """Check that ``q`` lies in the chi-square ball around uniform.

        Args:
            q (np.ndarray): Candidate weights over the samples.

        Returns:
            bool: Whether 0.5 * mean((len(q) * q - 1) ** 2) <= bound.
        """
        return 0.5*((q*len(q)-1)**2).mean() <= self.bound

    def find_optimal_eta(self, losses: np.ndarray):
        """Find \\eta^* such that chi2(q^*_\\eta^* || p) = bound

        where q^*_\\eta is proportional to max(losses - \\eta, 0) and p is
        uniform.

        Args:
            losses (np.ndarray): sample losses

        Returns:
            float: The optimal \\eta
        """
        losses = np.asarray(losses)

        def chi2_delta(eta):
            # The 1e-12 floor keeps the normalization well defined
            q_star_ = np.maximum(1e-12, losses - eta)
            q_star = q_star_ / q_star_.sum()
            m = len(losses)
            return 0.5*((m*q_star - 1)**2).mean() - self.bound

        loss_max = losses.max()
        loss_min = losses.min()
        below_max = losses[losses < loss_max]
        if below_max.size == 0:
            # All losses are equal: every eta below them yields uniform
            # weights, so return one that keeps relu(losses - eta) positive
            return float(loss_max) - 1.0

        # Lower end of the bracket. q^* only depends on losses - eta, so the
        # bound is applied to the losses shifted by their minimum, which
        # keeps it valid for negative losses as well.
        eta_min = loss_min - (
            (loss_max - loss_min) / (np.sqrt(2 * self.bound + 1) - 1)
        )
        # Upper end of the bracket: the largest loss strictly below the
        # maximum, where q^* puts all its mass on the maximal losses. At
        # eta = losses.max() itself the floor above makes every weight
        # equal, which hides the sign change from the bisection.
        eta_max = below_max.max()
        eta_star = bisection_search(
            chi2_delta, eta_min, eta_max, xtol=1e-3, maxiter=100
        )
        return float(eta_star)

    def best_response(self, losses: th.Tensor):
        """Return the weights proportional to relu(losses - eta^*).

        Args:
            losses (th.Tensor): Sample losses.

        Returns:
            th.Tensor: Weights over the samples that sum to 1.
        """
        loss_max = losses.max()
        # If the losses are too close, just return uniform weights.
        # Comparing against |max| avoids dividing by a zero maximum and
        # keeps the test meaningful when the losses are negative.
        if loss_max - losses.min() <= MIN_REL_DIFFERENCE * loss_max.abs():
            return th.ones_like(losses) / len(losses)
        # failsafe for batch sizes small compared to uncertainty set size:
        # put all the mass on the maximal losses
        if len(losses) <= 1 + 2 * self.bound:
            out = (losses == loss_max).float()
            out /= out.sum()
            return out
        # Otherwise find optimal eta
        eta_star = self.find_optimal_eta(losses.detach().cpu().numpy())
        q_star_ = th.relu(losses - eta_star)
        return q_star_ / q_star_.sum()
class CVaRConstrainedAdversary(NonParametricAdversary):
    """CVaR constrained adversary

    Args:
        alpha (float): With m samples, no sample may receive a weight larger
            than 1 / (alpha * m).
    """

    def __init__(self, alpha: float) -> None:
        super().__init__()
        self.alpha = alpha

    def is_valid_response(self, q: np.ndarray):
        """Check that no entry of ``q`` exceeds the cap 1 / (len(q) * alpha).

        Args:
            q (np.ndarray): Candidate weights over the samples.

        Returns:
            bool: Whether the largest weight respects the cap.
        """
        weight_cap = 1/(len(q)*self.alpha)
        return q.max() <= weight_cap

    def best_response(self, losses: th.Tensor):
        """Put the capped weight on the largest losses.

        The ``int(alpha * m)`` largest losses each receive 1 / (alpha * m);
        the next largest one, if there is one, receives whatever mass is
        left so that the weights sum to 1.

        Args:
            losses (th.Tensor): Sample losses.

        Returns:
            th.Tensor: Weights with the same shape as ``losses``.
        """
        n_samples = len(losses)
        # Number of samples that get the maximum weight
        n_capped = int(self.alpha * n_samples)
        leftover = 1.0 - n_capped / (self.alpha * n_samples)
        weights = th.zeros_like(losses)
        # Sample indices ordered from largest to smallest loss
        ranking = th.argsort(losses, descending=True)
        weights[ranking[:n_capped]] = 1.0 / (self.alpha * n_samples)
        if n_capped < n_samples:
            weights[ranking[n_capped]] = leftover
        return weights