-
Notifications
You must be signed in to change notification settings - Fork 4
/
loss.py
42 lines (35 loc) · 1.27 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn.functional as F
import math
from utils.utils import gamma2as, gamma2logas
def diffusion_elbo(gamma_0, gamma_1, d_gamma_t,
                   x, noise, noise_hat):
    """Compute the diffusion training loss (negative ELBO) plus diagnostics.

    The loss is the sum of three terms:

    * prior loss ``KL(q(z_1 | x) || N(0, 1))`` where ``q(z_1 | x)`` has mean
      ``alpha_1 * x`` and variance ``var_1``;
    * reconstruction loss ``E[-log p(x | z_0)]``, a Gaussian negative
      log-likelihood with variance ``var_0`` and residual ``(1 - alpha_0) * x``;
    * diffusion loss ``0.5 * d_gamma_t * mean((noise - noise_hat) ** 2)``,
      computed per example and averaged over the batch.

    The prior and reconstruction terms are per-element averages: they depend
    on ``x`` only through the mean of ``x ** 2`` over every element.

    Args:
        gamma_0: schedule value at t=0, converted to ``(log_alpha, log_var)``
            by ``gamma2logas``. Must be a single-element tensor because
            ``.item()`` is called on the terms derived from it.
        gamma_1: schedule value at t=1; same requirements as ``gamma_0``.
        d_gamma_t: per-example weights for the diffusion term. Its first dim
            is used as the batch size. Presumably the derivative of gamma at
            the sampled timesteps, shape ``(batch,)`` -- TODO confirm.
        x: clean data batch; the first dim is the batch dim.
        noise: noise that was added to produce the model input; same shape
            as ``noise_hat``.
        noise_hat: the model's noise prediction.

    Returns:
        ``(loss, extra_dict)`` where ``loss`` is the differentiable total
        loss and ``extra_dict`` holds detached diagnostics:
        ``'kld'`` (float prior loss), ``'ll'`` (float reconstruction
        log-likelihood), ``'loss_T_raw'`` (detached per-example diffusion
        loss, already divided by the batch size), ``'loss_T'`` (float
        diffusion loss) and ``'elbo'`` (float, ``-loss``).
    """
    log_alpha_0, log_var_0 = gamma2logas(gamma_0)
    log_alpha_1, log_var_1 = gamma2logas(gamma_1)

    # Mean of x**2 over all elements. Unlike x.view(-1), this works for
    # non-contiguous tensors.
    x_sq_mean = x.pow(2).mean()

    # Prior loss KL(q(z_1 | x) || N(0, 1)) per element, with
    # q mean = alpha_1 * x and q variance = var_1:
    #   0.5 * (var_1 + alpha_1**2 * E[x**2] - 1 - log var_1)
    prior_loss = 0.5 * (log_var_1.exp() + x_sq_mean *
                        torch.exp(log_alpha_1 * 2) - 1 - log_var_1)

    # Reconstruction loss E[-log p(x | z_0)]: the residual is (1 - alpha_0) * x.
    # expm1(log_alpha_0) == alpha_0 - 1 stays accurate when alpha_0 ~ 1.
    l2 = x_sq_mean * torch.expm1(log_alpha_0) ** 2
    ll = -0.5 * (log_var_0 + l2 / log_var_0.exp() + math.log(2 * math.pi))
    recon_loss = -ll

    # Diffusion loss. Flatten every non-batch dim before averaging so each
    # example reduces to one value, even for inputs with more than two dims
    # (e.g. (B, C, T)). That keeps the product with d_gamma_t aligned on the
    # batch axis.
    batch_size = d_gamma_t.shape[0]
    diff = noise - noise_hat
    per_example_mse = diff.pow(2).flatten(1).mean(1)
    loss_T_raw = 0.5 * (d_gamma_t * per_example_mse) / batch_size
    loss_T = loss_T_raw.sum()

    loss = prior_loss + recon_loss + loss_T
    elbo = -loss

    extra_dict = {
        'kld': prior_loss.item(),
        'll': ll.item(),
        'loss_T_raw': loss_T_raw.detach(),
        'loss_T': loss_T.item(),
        'elbo': elbo.item(),
    }
    return loss, extra_dict