-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgpu_cage.py
95 lines (73 loc) · 3.46 KB
/
gpu_cage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from torch.distributions.beta import Beta
def probability_y(pi_y):
    """Return a softmax taken over every element of ``pi_y``.

    Args:
        pi_y: tensor of unnormalised log-scores.

    Returns:
        Tensor with the same shape as ``pi_y`` equal to
        ``exp(pi_y) / exp(pi_y).sum()``; the entries sum to 1 over the
        whole tensor (the sum is not taken along a single axis).
    """
    if pi_y.numel() == 0:
        # Nothing to normalise; an empty input yields an empty output.
        return torch.exp(pi_y)
    # Softmax is invariant to a constant shift. Subtracting the maximum keeps
    # exp() from overflowing to inf, which would turn the ratio into nan.
    # The shift is detached so it adds nothing to the autograd graph.
    shifted = pi_y - pi_y.max().detach()
    weights = torch.exp(shifted)
    return weights / weights.sum()
def phi(theta, l, device):
    """Multiply ``theta`` element-wise by the magnitude of ``l``.

    ``abs(l)`` is cast to double precision and moved to ``device`` before
    the multiplication; ``theta`` itself is used as passed in.
    """
    magnitude = torch.abs(l)
    magnitude = magnitude.double()
    magnitude = magnitude.to(device=device)
    return theta * magnitude
def calculate_normalizer(theta, k, n_classes, device):
    """Compute the normalising constant shared by all classes.

    For each class ``y`` the term is ``prod(1 + exp(theta[y]))`` taken over a
    tensor shaped like ``k``; the terms are summed over the classes.

    Args:
        theta: indexable by class; ``theta[y]`` must broadcast against
            ``k.shape`` (presumably one weight per labeling function --
            confirm against the caller).
        k: tensor whose shape fixes how many factors enter each product.
        n_classes: number of classes to sum over.
        device: device on which the computation is carried out.

    Returns:
        A 0-d tensor holding the sum, or the int ``0`` when ``n_classes``
        is 0.
    """
    # The all-ones input does not depend on the class, so build it once and
    # directly on the target device instead of on the CPU inside the loop.
    all_fired = torch.ones(k.shape, device=device)
    z = 0
    for y in range(n_classes):
        m_y = torch.exp(phi(theta[y], all_fired, device=device))
        z += (1 + m_y).prod()
    return z
def probability_l_y(theta, l, k, n_classes, device):
    """Return the normalised score of the labels ``l`` under every class.

    Entry ``[i, y]`` is ``exp(sum_j theta[y][j] * |l[i, j]|) / Z`` where ``Z``
    comes from ``calculate_normalizer``.

    Args:
        theta: indexable by class; ``theta[y]`` is multiplied element-wise
            with ``|l|``.
        l: 2-D tensor; the first axis indexes rows of the output and the
            second axis is summed over.
        k: tensor forwarded to ``calculate_normalizer``.
        n_classes: number of columns in the result.
        device: device on which the result is allocated.

    Returns:
        Double-precision tensor of shape ``(l.shape[0], n_classes)``.
    """
    z = calculate_normalizer(theta, k, n_classes, device=device)
    # Allocate in double precision up front: phi() yields doubles, and a
    # default float32 buffer would round every value on assignment.
    probability = torch.zeros(
        (l.shape[0], n_classes), dtype=torch.double, device=device
    )
    for y in range(n_classes):
        scores = phi(theta[y], l, device)
        probability[:, y] = torch.exp(scores.sum(1)) / z
    return probability
def probability_s_given_y_l(pi, s, y, l, k, continuous_mask, ratio_agreement=0.85, model=1, theta_process=2):
    """Return the per-sample density of the continuous scores ``s`` for class ``y``.

    Each labeling function (LF) ``i`` contributes a Beta density whose mean is
    ``ratio_agreement`` when ``k[i] == y`` and ``1 - ratio_agreement``
    otherwise, with total concentration ``exp(pi[i])``. The density is used
    only where ``l[:, i]`` is 1 and ``continuous_mask[i]`` is 1; otherwise
    the LF contributes a factor of 1.

    Args:
        pi: per-LF log-concentrations for class ``y``.
        s: 2-D tensor of scores, one column per LF; values must lie in the
            support of the Beta distribution.
        y: class index compared against ``k`` (a scalar).
        l: 2-D tensor, one column per LF, used as a 0/1 indicator.
        k: 1-D tensor giving the class associated with each LF.
        continuous_mask: per-LF 0/1 indicator selecting continuous LFs.
        ratio_agreement: Beta mean used for LFs whose class equals ``y``.
        model: unused; kept for interface compatibility.
        theta_process: unused; kept for interface compatibility.

    Returns:
        Tensor with one entry per row of ``s``, or the int ``1`` when ``k``
        is empty.
    """
    # After .t() the LF axis is last. Drop only the leading singleton axis:
    # a bare squeeze() would also collapse the LF axis when there is exactly
    # one LF, leaving a 0-d tensor that cannot be indexed with r[i].
    agrees = torch.eq(k.view(-1, 1), y).double().t().squeeze(0)
    r = ratio_agreement * agrees + (1 - ratio_agreement) * (1 - agrees)
    params = torch.exp(pi)
    probability = 1
    for i in range(k.shape[0]):
        beta = Beta(r[i] * params[i], params[i] * (1 - r[i]))
        density = torch.exp(beta.log_prob(s[:, i].double()))
        # Use the density where the LF fired, 1 where it did not.
        lf_term = density * l[:, i].double() + (1 - l[:, i]).double()
        # Discrete LFs (mask == 0) contribute a neutral factor of 1.
        probability *= lf_term * continuous_mask[i] + (1 - continuous_mask[i])
    return probability
def probability(theta, pi_y, pi, l, s, k, n_classes, continuous_mask, device):
    """Return the unnormalised per-class probability of each sample.

    The result is the element-wise product of ``probability_l_y`` (discrete
    labels) and, per class, ``probability_s_given_y_l`` (continuous scores,
    called with its default ``ratio_agreement``). ``pi_y`` is accepted but
    not used here.

    Returns:
        Double-precision tensor of shape ``(s.shape[0], n_classes)``. Rows
        are not normalised over classes.
    """
    n_samples = s.shape[0]
    cont_given_y = torch.ones(
        n_samples, n_classes, dtype=torch.double, device=device
    )
    for cls in range(n_classes):
        cont_given_y[:, cls] = probability_s_given_y_l(
            pi[cls], s, cls, l, k, continuous_mask
        )
    discrete_joint = probability_l_y(theta, l, k, n_classes, device=device)
    return discrete_joint * cont_given_y
def log_likelihood_loss(theta, pi_y, pi, l, s, k, n_classes, continuous_mask, device):
    """Return the negative log-likelihood of the data, marginalised over classes.

    The per-class values from ``probability`` are summed over classes, logged,
    summed over samples, negated, and divided by ``s.shape[0]``.
    """
    joint = probability(
        theta, pi_y, pi, l, s, k, n_classes, continuous_mask, device
    )
    marginal = joint.sum(1)
    total_log_likelihood = torch.log(marginal).sum()
    return -total_log_likelihood / s.shape[0]
def log_likelihood_loss_supervised(theta, pi_y, pi, y, l, s, k, n_classes, continuous_mask, device):
    """Return the NLL loss of the true labels ``y`` under the class posterior.

    The per-class values from ``probability`` are normalised across classes
    for each sample, logged, and passed to ``torch.nn.NLLLoss`` with its
    default settings.
    """
    joint = probability(
        theta, pi_y, pi, l, s, k, n_classes, continuous_mask, device
    )
    # Normalise each row so the classes sum to 1 for every sample.
    posterior = joint / joint.sum(1, keepdim=True)
    criterion = torch.nn.NLLLoss()
    return criterion(torch.log(posterior), y)
def precision_loss(theta, k, n_classes, a, device):
    """Return a cross-entropy-style penalty tying each LF to its own class.

    For LF ``j`` and class ``y`` the unnormalised score is
    ``m[j] * prod_{i != j} (1 + m[i])`` with ``m = exp(theta[y])``. Scores are
    normalised over classes, the value at the LF's own class ``k[j]`` is
    selected, and the loss is
    ``-sum(a * log(p) + (1 - a) * log(1 - p))``.

    Args:
        theta: indexable by class; ``theta[y]`` holds one weight per LF.
        k: 1-D integer tensor giving the class associated with each LF.
        n_classes: number of classes.
        a: weight on the ``log(p)`` term (presumably a per-LF quality guide
            in [0, 1] -- confirm against the caller).
        device: device on which all intermediates are allocated.

    Returns:
        0-d tensor holding the loss.
    """
    n_lfs = k.shape[0]
    # Loop-invariant helpers, built once and directly on the target device.
    all_fired = torch.ones(n_lfs, device=device)
    eye = torch.eye(n_lfs, device=device).double()

    columns = []
    for y in range(n_classes):
        m_y = torch.exp(phi(theta[y], all_fired, device=device))
        # Entry (i, j) is 1 + m_y[i] off the diagonal and m_y[i] on it, so the
        # product down column j is m_y[j] * prod_{i != j} (1 + m_y[i]).
        per_lf_matrix = (1 + m_y).view(-1, 1) - eye
        columns.append(per_lf_matrix.prod(0).double())

    # Stacking keeps the table on `device` and in double precision; the
    # previous CPU buffer forced host/device copies on every assignment.
    prob = torch.stack(columns, dim=1)
    prob = prob / prob.sum(1, keepdim=True)

    # Select, for every LF, the probability of its own class without routing
    # the values through a float32 buffer.
    lf_index = torch.arange(n_lfs, device=device)
    correct_prob = prob[lf_index, k.long().to(device)]

    loss = a * torch.log(correct_prob) + (1 - a) * torch.log(1 - correct_prob)
    return -loss.sum()