utils.py
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.autograd import grad as torch_grad


def is_image_file(filename):
    # Accept only filenames with a supported image extension.
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])


def load_img(filepath):
    # Load as RGB and resize to a fixed 256x256 input resolution.
    img = Image.open(filepath).convert('RGB')
    img = img.resize((256, 256), Image.BICUBIC)
    return img


def save_img(image_tensor, filename):
    # Expects a CHW float tensor with values in [-1, 1]. detach() and cpu()
    # make this safe for tensors that still live on the GPU / in the graph.
    image_numpy = image_tensor.detach().cpu().float().numpy()
    # CHW -> HWC, then map [-1, 1] to [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    image_numpy = image_numpy.clip(0, 255).astype(np.uint8)
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(filename)
    print("Image saved as {}".format(filename))


def _gradient_penalty(D, real_data, generated_data, use_cuda):
    """WGAN-GP penalty: 10 * E[(||grad_x D(x_hat)||_2 - 1)^2] (Gulrajani et al., 2017)."""
    batch_size = real_data.size()[0]

    # Sample one interpolation coefficient per example and broadcast it
    # over the channel and spatial dimensions.
    alpha = torch.rand(batch_size, 1, 1, 1)
    alpha = alpha.expand_as(real_data)
    if use_cuda:
        alpha = alpha.cuda()

    # Interpolate between real and generated samples.
    interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
    if use_cuda:
        interpolated = interpolated.cuda()
    interpolated = Variable(interpolated, requires_grad=True)

    # Critic score for the interpolated examples.
    prob_interpolated = D(interpolated)

    # Gradients of the critic output with respect to its inputs.
    # create_graph=True keeps this computation differentiable so the penalty
    # can actually influence the critic's parameter update during backprop.
    gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                           grad_outputs=torch.ones(prob_interpolated.size()).cuda() if use_cuda
                           else torch.ones(prob_interpolated.size()),
                           create_graph=True, retain_graph=True)[0]

    # Gradients have shape (batch_size, num_channels, img_height, img_width),
    # so flatten to take the norm per example in the batch.
    gradients = gradients.view(batch_size, -1)

    # Norms near 0 make the gradient of the square root unstable, so compute
    # the norm manually and add a small epsilon.
    gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

    # Penalty coefficient lambda = 10, as in the WGAN-GP paper.
    return 10 * ((gradients_norm - 1) ** 2).mean()
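

if __name__ == "__main__":
    # Smoke-test sketch (an assumption, not part of the original training
    # code): build a tiny throwaway critic and random "real"/"fake" batches,
    # then evaluate the penalty. In an actual WGAN-GP step the critic loss
    # would be D(fake).mean() - D(real).mean() + _gradient_penalty(...).
    use_cuda = torch.cuda.is_available()
    D = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=4, stride=2, padding=1),  # 256x256 -> 128x128
        nn.LeakyReLU(0.2),
        nn.Flatten(),
        nn.Linear(8 * 128 * 128, 1),
    )
    real = torch.randn(4, 3, 256, 256)
    fake = torch.randn(4, 3, 256, 256)
    if use_cuda:
        D, real, fake = D.cuda(), real.cuda(), fake.cuda()

    gp = _gradient_penalty(D, real, fake, use_cuda)
    print("gradient penalty: {:.4f}".format(gp.item()))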