-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathweighted_cage.py
95 lines (78 loc) · 3.88 KB
/
weighted_cage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
from torch.distributions.beta import Beta
import math
def probability_y(pi_y):
    """Return the softmax of ``pi_y`` taken over all of its elements.

    Mathematically this is ``exp(pi_y) / exp(pi_y).sum()``. The maximum is
    subtracted before exponentiating so that large log-weights do not
    overflow ``exp`` to ``inf`` (which would make the ratio ``nan``). Softmax
    is shift-invariant, so the result is unchanged.

    Args:
        pi_y: tensor of unnormalised log-weights (any shape).

    Returns:
        Tensor of the same shape as ``pi_y`` whose entries sum to 1. An empty
        input gives an empty output.
    """
    if pi_y.numel() == 0:
        # Nothing to normalise; ``max`` would raise on an empty tensor.
        return torch.exp(pi_y)
    # The shift cancels out mathematically, so detach it: it only has to act
    # as a constant and need not be part of the autograd graph.
    pi = torch.exp(pi_y - pi_y.max().detach())
    return pi / pi.sum()
def phi(theta, l):
    """Return ``theta`` multiplied elementwise by ``|l|``.

    ``|l|`` is cast to float64 before the product, so under PyTorch type
    promotion the result is float64 even when ``theta`` is float32. The usual
    broadcasting rules apply between ``theta`` and ``l``.
    """
    magnitude = l.abs().double()
    return theta * magnitude
def calculate_normalizer(theta, k, n_classes):
    """Return Z = sum over y of prod_j (1 + exp(phi(theta[y], 1)_j)).

    For each class ``y``, ``phi`` is evaluated against an all-ones tensor
    shaped like ``k``. ``prod_j (1 + exp(a_j))`` equals the sum of
    ``exp(sum_j a_j * l_j)`` over every 0/1 vector ``l``, so Z is that sum
    additionally summed over classes.

    Args:
        theta: per-class parameters indexable as ``theta[y]``.
        k: tensor whose shape fixes the shape of the all-ones argument
            (presumably one entry per labelling function; confirm at callers).
        n_classes: number of classes to sum over.

    Returns:
        The scalar normaliser (a tensor), or the int 0 if ``n_classes`` is 0.
    """
    unit_votes = torch.ones(k.shape)
    return sum(
        (1 + torch.exp(phi(theta[cls], unit_votes))).prod()
        for cls in range(n_classes)
    )
def probability_l_y(theta, l, k, n_classes, weights):
    """Return, for every instance and class, the normalised log-linear score of ``l``.

    Entry ``[i, y]`` is ``exp(sum_j weights[j] * theta[y][j] * |l[i, j]|) / Z``,
    where ``Z`` is ``calculate_normalizer`` evaluated on the weighted
    parameters ``weights.view(1, -1) * theta``.

    Args:
        theta: per-class parameters; assumed shape (n_classes, n_lfs) — TODO confirm.
        l: tensor whose rows are instances (``l.shape[0]`` rows); its columns
            are assumed to line up with ``weights``.
        k: forwarded to ``calculate_normalizer`` (only its shape is used there).
        n_classes: number of classes.
        weights: per-column weights, flattened with ``view(1, -1)``.

    Returns:
        float64 tensor of shape ``(l.shape[0], n_classes)``.
    """
    # Weight the parameters once instead of on every class iteration.
    weighted_theta = weights.view(1, -1) * theta
    z = calculate_normalizer(weighted_theta, k, n_classes)
    # Allocate the buffer as float64: a float32 buffer rounded, and could
    # overflow, the double-precision values before the final cast.
    probability = torch.zeros((l.shape[0], n_classes), dtype=torch.float64)
    for y in range(n_classes):
        probability[:, y] = torch.exp(phi(weighted_theta[y], l).sum(1)) / z
    return probability
def probability_s_given_y_l(pi, s, y, l, k, continuous_mask, weights, ratio_agreement=0.85, model=1, theta_process=2):
    """Return the per-instance likelihood of the scores ``s`` given class ``y``.

    For each labelling function ``i`` (one per entry of ``k``), a Beta
    distribution is built with

        concentration1 = weights[i] * (r[i] * exp(pi[i]) - 1) + 1
        concentration0 = weights[i] * (exp(pi[i]) * (1 - r[i]) - 1) + 1

    where ``r[i]`` is ``ratio_agreement`` if ``k[i] == y`` and
    ``1 - ratio_agreement`` otherwise. The factor for LF ``i`` is
    ``(density(s[:, i]) * l[:, i] + (1 - l[:, i])) * mask + (1 - mask)``,
    with ``mask = continuous_mask[i]``. For 0/1-valued ``l`` and mask, that is
    the Beta density only where the LF fired and is continuous, and 1
    elsewhere. The factors are multiplied together across LFs.

    Args:
        pi: per-LF log-parameters (exponentiated here).
        s: scores; column ``i`` belongs to LF ``i``.
        y: class index compared against ``k``.
        l: LF outputs; column ``i`` belongs to LF ``i``.
        k: class associated with each LF.
        continuous_mask: per-LF mask selecting which LFs use the Beta factor.
        weights: per-LF weights applied to the Beta concentrations.
        ratio_agreement: value of ``r[i]`` when ``k[i] == y``.
        model, theta_process: accepted for interface compatibility; unused.

    Returns:
        Tensor with one entry per row of ``s``, or the int 1 if ``k`` is empty.
    """
    # 1.0 where LF i is associated with class y, else 0.0. Flattening, rather
    # than transposing and squeezing, keeps this 1-D even with a single LF;
    # squeeze would produce a 0-d tensor and ``r[i]`` would raise.
    eq = torch.eq(k.view(-1), y).double()
    r = ratio_agreement * eq + (1 - ratio_agreement) * (1 - eq)
    params = torch.exp(pi)
    probability = 1
    for i in range(k.shape[0]):
        m = Beta(weights[i] * (r[i] * params[i] - 1) + 1,
                 weights[i] * ((params[i] * (1 - r[i])) - 1) + 1)
        density = torch.exp(m.log_prob(s[:, i].double()))
        probability *= (density * l[:, i].double() + (1 - l[:, i]).double()) * continuous_mask[i] + (1 - continuous_mask[i])
    return probability
def probability(theta, pi_y, pi, l, s, k, n_classes, continuous_mask, weights):
    """Return the joint likelihood of every instance paired with every class.

    Entry ``[i, y]`` is ``probability_l_y(...)[i, y]`` multiplied by the
    ``i``-th entry of ``probability_s_given_y_l`` evaluated for class ``y``
    with ``pi[y]``. Rows are not normalised.

    ``pi_y`` is accepted for interface compatibility but not used here.

    Returns:
        Tensor of shape ``(s.shape[0], n_classes)``, assuming ``l`` and ``s``
        have the same number of rows (otherwise normal broadcasting applies).
    """
    lf_term = probability_l_y(theta, l, k, n_classes, weights)
    score_term = torch.ones(s.shape[0], n_classes).double()
    for cls in range(n_classes):
        score_term[:, cls] = probability_s_given_y_l(pi[cls], s, cls, l, k, continuous_mask, weights)
    return lf_term * score_term
def log_likelihood_loss(theta, pi_y, pi, l, s, k, n_classes, continuous_mask, weights):
    """Return the negative log marginal likelihood averaged over instances.

    Computes ``-(1/N) * sum_i log(sum_y joint[i, y])``, where ``joint`` comes
    from ``probability`` and ``N = s.shape[0]``. No epsilon is added: a row
    whose joint likelihood is 0 for every class produces ``inf``.
    """
    joint = probability(theta, pi_y, pi, l, s, k, n_classes, continuous_mask, weights)
    marginal = joint.sum(1)
    return -torch.log(marginal).sum() / s.shape[0]
def log_likelihood_loss_supervised(theta, pi_y, pi, y, l, s, k, n_classes, continuous_mask, weights):
    """Return the NLL of the true labels ``y`` under the per-instance class posterior.

    The joint likelihood from ``probability`` is normalised row-wise so that
    each instance's class probabilities sum to 1. ``torch.nn.NLLLoss`` is
    then applied to their log (default reduction: mean over instances).

    Args:
        y: one class index per instance; ``NLLLoss`` requires an integer
            (int64) tensor.

    Other arguments are forwarded unchanged to ``probability``.
    """
    joint = probability(theta, pi_y, pi, l, s, k, n_classes, continuous_mask, weights)
    # Row-wise normalisation: divide each row by its own sum.
    posterior = joint / joint.sum(1, keepdim=True)
    return torch.nn.NLLLoss()(torch.log(posterior), y)
def precision_loss(theta, k, n_classes, a, weights):
    """Return a weighted cross-entropy between ``a`` and each LF's class probability.

    For each class ``y``, ``m_y = exp(phi(theta[y] * weights, 1))``. The
    unnormalised score of class ``y`` for LF ``j`` is
    ``m_y[j] * prod_{i != j} (1 + m_y[i])``; the scores are then normalised
    across classes per LF. With ``p_j`` the normalised probability of class
    ``k[j]``, the return value is

        -(1/e) * sum_j exp(weights[j]) * (a * log(p_j) + (1 - a) * log(1 - p_j))

    Args:
        theta: per-class parameters indexable as ``theta[y]``.
        k: integer class index per LF (``k.shape[0]`` LFs).
        n_classes: number of classes; must be at least 1.
        a: target value(s); a scalar or a per-LF tensor presumably, since
            either broadcasts here — confirm at callers.
        weights: per-LF weights.

    Returns:
        Scalar float64 tensor.
    """
    n_lfs = k.shape[0]
    prob = torch.ones(n_lfs, n_classes).double()
    z_per_lf = 0
    eye = torch.eye(n_lfs).double()
    for y in range(n_classes):
        m_y = torch.exp(phi(theta[y] * weights, torch.ones(n_lfs)))
        # Row i of this matrix is (1 + m_y[i]), except on the diagonal where it
        # is m_y[i]. The product down column j is therefore
        # m_y[j] * prod_{i != j} (1 + m_y[i]).
        per_lf_matrix = torch.tensordot((1 + m_y).view(-1, 1), torch.ones(m_y.shape).double().view(1, -1), 1) - eye
        prob[:, y] = per_lf_matrix.prod(0)
        z_per_lf += prob[:, y]
    prob /= z_per_lf.view(-1, 1)
    # Gather prob[j, k[j]] in float64. A float32 buffer rounds values close to
    # 1 up to exactly 1, which turns log(1 - p) into -inf.
    correct_prob = prob[torch.arange(n_lfs), k.long()]
    lf_scale = torch.exp(weights)
    loss = (1 / math.exp(1)) * (a * lf_scale * torch.log(correct_prob) + (1 - a) * lf_scale * torch.log(1 - correct_prob))
    return -loss.sum()