-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsampler.py
111 lines (91 loc) · 4.75 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from __future__ import division
import torch
from numpy.random import gamma
from torch.optim import Optimizer
class H_SA_SGHMC(Optimizer):
    """Stochastic Gradient Hamiltonian Monte-Carlo sampler.

    Uses a burn-in procedure to adapt its own hyperparameters (a
    per-parameter moving estimate of the gradient variance, used as a
    preconditioner) during the initial stages of sampling, in the style of
    scale-adapted SGHMC ("[1]" in the inline comments below — presumably
    Springenberg et al., "Bayesian Optimization with Robust Bayesian
    Neural Networks"; confirm against the repository's references).
    """

    def __init__(self, params, lr=1e-2, base_C=0.05, gauss_sig=0.1, alpha0=10, beta0=10):
        """ Set up a SGHMC Optimizer.

        Parameters
        ----------
        params : iterable
            Parameters serving as optimization variable.
        lr: float, optional
            Base learning rate for this optimizer.
            Must be tuned to the specific function being minimized.
            Default: `1e-2`.
        base_C: float, optional
            (Constant) momentum decay per time-step.
            Default: `0.05`.
        gauss_sig: float, optional
            Standard deviation of the Gaussian prior over the weights;
            the implied weight decay is ``1 / gauss_sig**2``.
            Passing ``0`` disables weight decay entirely.
            Default: `0.1`.
        alpha0, beta0: float, optional
            Shape and rate hyperparameters of the Gamma hyper-prior used
            when ``step(resample_prior=True)`` Gibbs-resamples the
            weight-decay precision. Default: `10`.

        Raises
        ------
        ValueError
            If ``lr``, ``base_C`` or the derived weight decay is negative.
        """
        self.eps = 1e-6
        self.alpha0 = alpha0
        self.beta0 = beta0
        # gauss_sig == 0 deliberately disables the Gaussian prior
        # (weight_decay == 0), so only negative values are invalid.
        if gauss_sig == 0:
            self.weight_decay = 0
        else:
            self.weight_decay = 1 / (gauss_sig ** 2)
        # BUGFIX: the original checked `<= 0.0` (rejecting the valid
        # gauss_sig == 0 case) and formatted the undefined bare name
        # `weight_decay`, raising NameError instead of ValueError.
        if self.weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(self.weight_decay))
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if base_C < 0:
            raise ValueError("Invalid friction term: {}".format(base_C))
        defaults = dict(
            lr=lr,
            base_C=base_C,
        )
        super(H_SA_SGHMC, self).__init__(params, defaults)

    def step(self, burn_in=False, resample_momentum=False, resample_prior=False):
        """Simulate discretized Hamiltonian dynamics for one step.

        Parameters
        ----------
        burn_in : bool, optional
            If True, also update the moving-average window ``tau``, the
            smoothed gradient ``g`` and the gradient-variance estimate
            ``V_hat`` used as a preconditioner (Eqs. 8-9 in [1]).
        resample_momentum : bool, optional
            If True, redraw the momentum from its stationary Gaussian
            (variance ``lr**2 * V_inv_sqrt`` under the momentum
            reparametrisation).
        resample_prior : bool, optional
            If True, Gibbs-resample the per-parameter weight-decay
            precision from its Gamma posterior.

        Returns
        -------
        None
            Kept as ``loss = None`` for interface compatibility with
            ``torch.optim.Optimizer.step``.
        """
        loss = None
        for group in self.param_groups:  # iterate over blocks -> the ones defined in defaults. We dont use groups.
            for p in group["params"]:  # these are weight and bias matrices
                if p.grad is None:
                    continue
                state = self.state[p]  # define dict for each individual param
                if len(state) == 0:
                    # Lazy per-parameter state initialisation on first step.
                    state["iteration"] = 0
                    state["tau"] = torch.ones_like(p)
                    state["g"] = torch.ones_like(p)
                    state["V_hat"] = torch.ones_like(p)
                    state["v_momentum"] = torch.zeros_like(p)
                    state['weight_decay'] = self.weight_decay
                state["iteration"] += 1  # this is kind of useless now but lets keep it provisionally
                if resample_prior:
                    # Gibbs step: precision ~ Gamma(alpha0 + n/2, beta0 + ||p||^2/2).
                    alpha = self.alpha0 + p.data.nelement() / 2
                    beta = self.beta0 + (p.data ** 2).sum().item() / 2
                    gamma_sample = gamma(shape=alpha, scale=1 / (beta), size=None)
                    state['weight_decay'] = gamma_sample
                base_C, lr = group["base_C"], group["lr"]
                weight_decay = state["weight_decay"]
                tau, g, V_hat = state["tau"], state["g"], state["V_hat"]
                d_p = p.grad.data
                if weight_decay != 0:
                    # Add the Gaussian-prior gradient. Keyword form replaces the
                    # deprecated Tensor.add_(scalar, tensor) overload; result is
                    # identical: d_p += weight_decay * p.data.
                    d_p.add_(p.data, alpha=weight_decay)
                # update hyperparameters during burn-in
                if burn_in:  # We update g first as it makes most sense
                    tau.add_(-tau * (g ** 2) / (
                            V_hat + self.eps) + 1)  # specifies the moving average window, see Eq 9 in [1] left
                    tau_inv = 1. / (tau + self.eps)
                    g.add_(-tau_inv * g + tau_inv * d_p)  # average gradient see Eq 9 in [1] right
                    V_hat.add_(-tau_inv * V_hat + tau_inv * (d_p ** 2))  # gradient variance see Eq 8 in [1]
                V_sqrt = torch.sqrt(V_hat)
                V_inv_sqrt = 1. / (V_sqrt + self.eps)  # preconditioner
                if resample_momentum:  # equivalent to var = M under momentum reparametrisation
                    state["v_momentum"] = torch.normal(mean=torch.zeros_like(d_p),
                                                       std=torch.sqrt((lr ** 2) * V_inv_sqrt))
                v_momentum = state["v_momentum"]
                # Injected-noise variance; clamp guards against a negative value
                # (and hence NaN from sqrt) when lr**4 dominates.
                noise_var = (2. * (lr ** 2) * V_inv_sqrt * base_C - (lr ** 4))
                noise_std = torch.sqrt(torch.clamp(noise_var, min=1e-16))
                # sample random epsilon
                noise_sample = torch.normal(mean=torch.zeros_like(d_p), std=torch.ones_like(d_p) * noise_std)
                # update momentum (Eq 10 right in [1])
                v_momentum.add_(- (lr ** 2) * V_inv_sqrt * d_p - base_C * v_momentum + noise_sample)
                # update theta (Eq 10 left in [1])
                p.data.add_(v_momentum)
        return loss