-
Notifications
You must be signed in to change notification settings - Fork 2
/
utils.py
120 lines (102 loc) · 3.95 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import math
import numpy as np
from matplotlib import pyplot as plt
import torch.distributed as dist
import pdb
def matrix_log_density_gaussian(x, mu, logvar):
    """Gaussian log density for every pair of batch elements.

    Broadcasts so that entry ``(i, j, d)`` of the result is the log density
    of ``x[i, d]`` under the Gaussian with mean ``mu[j, d]`` and log variance
    ``logvar[j, d]``, yielding a ``(batch_size, batch_size, dim)`` tensor
    instead of the usual ``(batch_size, dim)``.

    Parameters
    ----------
    x: torch.Tensor
        Values at which to evaluate the density. Shape: (batch_size, dim).
    mu: torch.Tensor
        Means. Shape: (batch_size, dim).
    logvar: torch.Tensor
        Log variances. Shape: (batch_size, dim).
    """
    batch_size, dim = x.shape
    # Insert singleton axes so broadcasting produces all (i, j) pairs.
    x_rows = x.view(batch_size, 1, dim)
    mu_cols = mu.view(1, batch_size, dim)
    logvar_cols = logvar.view(1, batch_size, dim)
    return log_density_gaussian(x_rows, mu_cols, logvar_cols)
def log_density_gaussian(x, mu, logvar):
    """Elementwise log density of a diagonal Gaussian.

    Computes ``log N(x; mu, exp(logvar))`` per element; shapes broadcast.

    Parameters
    ----------
    x: torch.Tensor or np.ndarray or float
        Value at which to compute the density.
    mu: torch.Tensor or np.ndarray or float
        Mean.
    logvar: torch.Tensor or np.ndarray or float
        Log variance.
    """
    # -0.5 * log(2*pi*var): the normalization constant in log space.
    log_norm_const = - 0.5 * (math.log(2 * math.pi) + logvar)
    inv_var = torch.exp(-logvar)
    squared_err = (x - mu) ** 2
    return log_norm_const - 0.5 * (squared_err * inv_var)
def log_importance_weight_matrix(batch_size, dataset_size):
    """Log importance weights for minibatch stratified sampling (MSS).

    Builds the ``(batch_size, batch_size)`` matrix of log weights used to
    estimate the aggregate posterior from a minibatch.

    Parameters
    ----------
    batch_size: int
        number of training images in the batch
    dataset_size: int
        number of training images in the dataset
    """
    n_total = dataset_size
    n_others = batch_size - 1
    strat_weight = (n_total - n_others) / (n_total * n_others)
    # Default weight 1/M everywhere, then overwrite special entries
    # through a flattened view (stride batch_size hits column 0, etc.).
    weights = torch.Tensor(batch_size, batch_size).fill_(1 / n_others)
    flat = weights.view(-1)
    flat[::n_others + 1] = 1 / n_total
    flat[1::n_others + 1] = strat_weight
    weights[n_others - 1, 0] = strat_weight
    return weights.log()
def _get_log_pz_qz_prodzi_qzCx(latent_sample, latent_dist, n_data, is_mss=True):
    """Compute the log densities needed for a TC-style ELBO decomposition.

    Parameters
    ----------
    latent_sample: torch.Tensor
        Samples from the posterior. Shape: (batch_size, hidden_dim).
    latent_dist: tuple of torch.Tensor
        (mean, logvar) of the posterior, each (batch_size, hidden_dim).
    n_data: int
        Number of examples in the full dataset.
    is_mss: bool
        If True, apply minibatch stratified sampling importance weights.

    Returns
    -------
    (log_pz, log_qz, log_prod_qzi, log_q_zCx) tensors of shape (batch_size,).
    """
    batch_size, hidden_dim = latent_sample.shape

    # log q(z|x): each sample's density under its own posterior.
    log_q_zCx = log_density_gaussian(latent_sample, *latent_dist).sum(dim=1)

    # log p(z): standard-normal prior (zero mean, zero log variance).
    prior_params = torch.zeros_like(latent_sample)
    log_pz = log_density_gaussian(latent_sample, prior_params, prior_params).sum(1)

    # Pairwise densities q(z_i | x_j) for all batch pairs.
    pairwise_log_qz = matrix_log_density_gaussian(latent_sample, *latent_dist)

    if is_mss:
        # Stratified importance weights correct the minibatch estimator.
        log_iw = log_importance_weight_matrix(batch_size, n_data).to(latent_sample.device)
        pairwise_log_qz = pairwise_log_qz + log_iw.view(batch_size, batch_size, 1)

    log_qz = torch.logsumexp(pairwise_log_qz.sum(2), dim=1, keepdim=False)
    log_prod_qzi = torch.logsumexp(pairwise_log_qz, dim=1, keepdim=False).sum(1)

    return log_pz, log_qz, log_prod_qzi, log_q_zCx
def l2_norm(input, axis=1):
    """Return the L2 norm along ``axis`` and the normalized tensor.

    A small epsilon is added to the norm to avoid division by zero.
    Returns ``(norm, normalized)`` where ``norm`` keeps the reduced dim.
    """
    magnitude = torch.norm(input, 2, axis, True)
    normalized = torch.div(input, magnitude + 0.00001)
    return magnitude, normalized
def one_hot(y, num_class):
    """One-hot encode integer labels ``y`` into a float (len(y), num_class) tensor."""
    encoding = torch.zeros((len(y), num_class))
    encoding.scatter_(1, y.unsqueeze(1), 1)
    return encoding
def DBindex(cl_data_file):
    """Davies-Bouldin-style index over per-class feature sets.

    For each class, computes its centroid and RMS spread; the score for
    class i is the worst-case ratio (spread_i + spread_j) / dist(i, j)
    over all other classes j, and the index is the mean of these scores.
    Lower values indicate better-separated classes.

    Parameters
    ----------
    cl_data_file: dict
        Maps class label -> array-like of feature vectors for that class.
    """
    labels = cl_data_file.keys()
    n_classes = len(labels)

    centroids = []
    spreads = []
    for label in labels:
        feats = cl_data_file[label]
        centroid = np.mean(feats, axis=0)
        centroids.append(centroid)
        # RMS distance of samples to their centroid.
        spreads.append(np.sqrt(np.mean(np.sum(np.square(feats - centroid), axis=1))))

    # Pairwise centroid distances via broadcasting of tiled centroid stacks.
    mu_i = np.tile(np.expand_dims(np.array(centroids), axis=0), (n_classes, 1, 1))
    mu_j = np.transpose(mu_i, (1, 0, 2))
    pair_dists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis=2))

    scores = []
    for i in range(n_classes):
        scores.append(np.max([(spreads[i] + spreads[j]) / pair_dists[i, j]
                              for j in range(n_classes) if j != i]))
    return np.mean(scores)
def sparsity(cl_data_file):
    """Mean number of non-zero feature entries per sample, averaged over classes.

    Parameters
    ----------
    cl_data_file: dict
        Maps class label -> iterable of feature vectors for that class.
    """
    per_class = [
        np.mean([np.sum(feat != 0) for feat in feats])
        for feats in cl_data_file.values()
    ]
    return np.mean(per_class)