-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
40 lines (37 loc) · 1.46 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
import torch.nn.functional as F
import numpy as np
import os
def label2onehot(labels, dim):
    """Build a one-hot matrix from a batch of class indices.

    Args:
        labels: Tensor whose first dimension is the batch; each entry is
            cast to ``long`` and used as the column to set in its row.
            (Presumably 1-D of shape ``(batch,)`` -- confirm with callers.)
        dim: Number of columns (classes) in the result.

    Returns:
        A ``(batch, dim)`` tensor of zeros, created with ``torch.zeros``
        defaults, with a single ``1`` per row at the label's column.
    """
    n_rows = labels.size(0)
    onehot = torch.zeros(n_rows, dim)
    # Pair row i with column labels[i] via advanced indexing.
    row_idx = torch.arange(n_rows)
    col_idx = labels.long()
    onehot[row_idx, col_idx] = 1
    return onehot
def classification_loss(logit, target):
    """Return the per-sample binary cross-entropy between logits and targets.

    The element-wise BCE-with-logits loss is summed over every element and
    then divided by the size of the first dimension of ``logit``.

    Args:
        logit: Raw (pre-sigmoid) scores.
        target: Tensor of the same shape as ``logit`` holding the targets.

    Returns:
        A scalar tensor: ``sum(BCE) / logit.size(0)``.
    """
    # ``reduction='sum'`` is the supported spelling of the deprecated
    # ``size_average=False`` and avoids a deprecation warning on every call.
    summed = F.binary_cross_entropy_with_logits(logit, target, reduction='sum')
    return summed / logit.size(0)
def gradient_penalty(y, x):
    """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2.

    Args:
        y: Tensor computed from ``x`` (must be part of ``x``'s autograd graph).
        x: Tensor with ``requires_grad=True`` that ``y`` depends on; its first
            dimension is treated as the batch.

    Returns:
        A scalar tensor: the mean over the batch of
        ``(||dy/dx||_2 - 1) ** 2``. The result stays attached to the graph
        (``create_graph=True``) so it can itself be back-propagated.
    """
    # Match y's device and dtype instead of forcing CUDA/float32, so the
    # function also works on CPU-only machines and with non-default dtypes.
    weight = torch.ones_like(y)
    dydx = torch.autograd.grad(outputs=y,
                               inputs=x,
                               grad_outputs=weight,
                               retain_graph=True,
                               create_graph=True)[0]

    # Flatten everything but the batch dimension; reshape (unlike view) also
    # accepts a non-contiguous gradient.
    dydx = dydx.reshape(dydx.size(0), -1)
    dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
    return torch.mean((dydx_l2norm - 1) ** 2)
def seed_torch(seed=10):
    """Seed NumPy and PyTorch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: Integer seed applied to every generator below (default 10).

    Side effects:
        * Sets the ``PYTHONHASHSEED`` environment variable to ``str(seed)``.
        * Seeds ``numpy.random``, the PyTorch CPU generator and the CUDA
          generators (current device and all devices).
        * Disables cuDNN benchmarking and enables cuDNN deterministic mode.

    Note:
        Python's built-in ``random`` module is not seeded here.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def update_lr(optim, lr):
    """Set the learning rate of every parameter group in an optimizer.

    Args:
        optim: Object exposing ``param_groups``, an iterable of dict-like
            groups (e.g. a ``torch.optim`` optimizer).
        lr: New learning rate written to each group's ``'lr'`` entry.
    """
    for group in optim.param_groups:
        group['lr'] = lr