forked from tkarras/progressive_growing_of_gans
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sliced_wasserstein.py
executable file
·96 lines (80 loc) · 4.15 KB
/
sliced_wasserstein.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
import numpy as np
import scipy.ndimage
#----------------------------------------------------------------------------
def get_descriptors_for_minibatch(minibatch, nhood_size, nhoods_per_image):
    """Sample random square neighborhoods (patches) from a batch of images.

    Args:
        minibatch: 4-D array of shape (minibatch, channel, height, width);
            exactly 3 channels are required.
        nhood_size: Patch side length. Each patch spans ``nhood_size // 2``
            pixels on either side of a random center, so an odd size gives
            ``nhood_size`` pixels per side and an even size gives
            ``nhood_size + 1``.
        nhoods_per_image: Number of patches drawn from every image.

    Returns:
        Array of shape (minibatch * nhoods_per_image, channel, 2*H+1, 2*H+1)
        with ``H = nhood_size // 2`` and the same dtype as ``minibatch``.
        Rows ``[i * nhoods_per_image, (i + 1) * nhoods_per_image)`` come from
        image ``i``. Note that axis 2 of each patch walks the image columns and
        axis 3 walks the rows (the patch is stored transposed); this is
        harmless for the distance computation because every descriptor shares
        the same layout.
    """
    S = minibatch.shape  # (minibatch, channel, height, width)
    assert len(S) == 4 and S[1] == 3
    N = nhoods_per_image * S[0]
    # Floor division keeps every index integral. Under Python 3 true division
    # the ogrid ranges and the image index become floats and the flat
    # indexing below fails.
    H = nhood_size // 2
    nhood, chan, x, y = np.ogrid[0:N, 0:S[1], -H:H+1, -H:H+1]
    img = nhood // nhoods_per_image
    # Random patch centers, restricted so the whole patch stays in bounds.
    x = x + np.random.randint(H, S[3] - H, size=(N, 1, 1, 1))
    y = y + np.random.randint(H, S[2] - H, size=(N, 1, 1, 1))
    # Row-major flat offset into the (minibatch, channel, height, width) array.
    idx = ((img * S[1] + chan) * S[2] + y) * S[3] + x
    return minibatch.flat[idx]
#----------------------------------------------------------------------------
def finalize_descriptors(desc):
    """Normalize patch descriptors per channel and flatten them to vectors.

    Args:
        desc: 4-D array (neighborhood, channel, dim2, dim3), or a list of such
            arrays, which is first concatenated along the neighborhood axis.
            The caller's array is never modified.

    Returns:
        2-D array (neighborhood, channel * dim2 * dim3). Every channel is
        shifted to zero mean and scaled to unit standard deviation, with both
        statistics taken over all neighborhoods and spatial positions. A
        channel with zero variance is left centered (all zeros) rather than
        turned into NaN by a division by zero.
    """
    if isinstance(desc, list):
        desc = np.concatenate(desc, axis=0)
    assert desc.ndim == 4  # (neighborhood, channel, dim2, dim3)
    # Out-of-place arithmetic: an in-place update would overwrite the caller's
    # array, and it would fail to cast the float result back into an integer
    # input array.
    desc = desc - np.mean(desc, axis=(0, 2, 3), keepdims=True)
    std = np.std(desc, axis=(0, 2, 3), keepdims=True)
    std[std == 0] = 1  # constant channel: keep the zeros instead of 0/0
    desc = desc / std
    return desc.reshape(desc.shape[0], -1)
#----------------------------------------------------------------------------
def sliced_wasserstein(A, B, dir_repeats, dirs_per_repeat):
    """Estimate the sliced Wasserstein distance between two descriptor sets.

    Both sets are projected onto random unit directions. The sorted
    projections are compared pointwise, which gives the 1-D Wasserstein
    distance per direction. The result is averaged over neighborhoods and
    directions, then over repeats.

    Args:
        A, B: 2-D arrays of identical shape (neighborhood, descriptor_component).
        dir_repeats: Number of independent batches of random directions.
        dirs_per_repeat: Number of random directions drawn per batch.

    Returns:
        Scalar mean absolute difference of sorted projections.
    """
    assert A.ndim == 2 and A.shape == B.shape  # (neighborhood, descriptor_component)
    results = []
    for _ in range(dir_repeats):  # range, not Python 2's xrange
        dirs = np.random.randn(A.shape[1], dirs_per_repeat)  # (descriptor_component, direction)
        dirs /= np.sqrt(np.sum(np.square(dirs), axis=0, keepdims=True))  # normalize descriptor components for each direction
        dirs = dirs.astype(np.float32)
        projA = np.matmul(A, dirs)  # (neighborhood, direction)
        projB = np.matmul(B, dirs)
        projA = np.sort(projA, axis=0)  # sort neighborhood projections for each direction
        projB = np.sort(projB, axis=0)
        dists = np.abs(projA - projB)  # pointwise wasserstein distances
        results.append(np.mean(dists))  # average over neighborhoods and directions
    return np.mean(results)  # average over repeats
#----------------------------------------------------------------------------
def downscale_minibatch(minibatch, lod):
    """Halve the spatial resolution of a 4-D batch ``lod`` times by 2x2 averaging.

    Args:
        minibatch: 4-D array (minibatch, channel, height, width).
        lod: Number of halvings. 0 returns ``minibatch`` itself unchanged.
            Height and width must stay even at every step, otherwise the four
            strided slices differ in shape.

    Returns:
        For ``lod > 0``, a uint8 array in which each pixel is the rounded mean
        of a ``2**lod x 2**lod`` input block, clipped to [0, 255].
    """
    if lod == 0:
        return minibatch
    t = minibatch.astype(np.float32)
    for _ in range(lod):  # range, not Python 2's xrange
        t = (t[:, :, 0::2, 0::2] + t[:, :, 0::2, 1::2] + t[:, :, 1::2, 0::2] + t[:, :, 1::2, 1::2]) * 0.25
    return np.round(t).clip(0, 255).astype(np.uint8)
#----------------------------------------------------------------------------
# 5x5 binomial (approximately Gaussian) kernel: the outer product of the row
# [1, 4, 6, 4, 1] with itself, divided by 256 so the float32 weights sum to 1.
gaussian_filter = np.outer(np.float32([1, 4, 6, 4, 1]), np.float32([1, 4, 6, 4, 1])) / 256.0
def pyr_down(minibatch):  # matches cv2.pyrDown()
    """Blur a 4-D batch with ``gaussian_filter`` and keep every other pixel.

    The 5x5 kernel is given singleton extent along the first two axes, so each
    image/channel plane is blurred independently, using ``mode='mirror'`` at
    the borders. Rows and columns 0, 2, 4, ... are then kept, so each spatial
    size becomes ``ceil(size / 2)``.
    """
    assert minibatch.ndim == 4
    kernel = gaussian_filter.reshape((1, 1) + gaussian_filter.shape)
    blurred = scipy.ndimage.convolve(minibatch, kernel, mode='mirror')
    return blurred[:, :, 0::2, 0::2]
def pyr_up(minibatch):  # matches cv2.pyrUp()
    """Double the spatial size of a 4-D batch by zero insertion and blurring.

    Input pixels are placed at even rows and columns of a zero array twice
    the height and width (same dtype as the input). The result is then
    convolved per plane with ``4 * gaussian_filter`` using ``mode='mirror'``.
    The factor 4 compensates for the three quarters of the samples that are
    zero.
    """
    assert minibatch.ndim == 4
    batch, channels, height, width = minibatch.shape
    upsampled = np.zeros((batch, channels, 2 * height, 2 * width), dtype=minibatch.dtype)
    upsampled[:, :, 0::2, 0::2] = minibatch
    kernel = 4.0 * gaussian_filter.reshape((1, 1) + gaussian_filter.shape)
    return scipy.ndimage.convolve(upsampled, kernel, mode='mirror')
def generate_laplacian_pyramid(minibatch, num_levels):
    """Decompose a 4-D batch into a Laplacian pyramid of ``num_levels`` levels.

    Args:
        minibatch: 4-D array (minibatch, channel, height, width). It is copied
            to float32 and never modified. Height and width must be divisible
            by ``2 ** (num_levels - 1)``; otherwise ``pyr_up`` of a level does
            not match the shape of the level above it and the subtraction fails.
        num_levels: Total number of levels returned (>= 1).

    Returns:
        List of float32 arrays. Entries ``0 .. num_levels - 2`` hold each level
        minus the upsampled next-coarser level. The last entry holds the
        coarsest low-pass image.
    """
    # np.array(..., dtype=...) always copies. np.float32(minibatch) may return
    # the caller's array itself when it is already float32, and the in-place
    # subtraction below would then overwrite the caller's data.
    pyramid = [np.array(minibatch, dtype=np.float32)]
    for _ in range(1, num_levels):  # range, not Python 2's xrange
        pyramid.append(pyr_down(pyramid[-1]))
        pyramid[-2] -= pyr_up(pyramid[-1])
    return pyramid
def reconstruct_laplacian_pyramid(pyramid):
    """Rebuild a batch from a Laplacian pyramid (inverse of the generator above).

    Starts from the coarsest (last) level. Walking from the second-to-last
    level back to level 0, it upsamples the running image with ``pyr_up`` and
    adds that level.
    """
    result = pyramid[-1]
    for residual in reversed(pyramid[:-1]):
        result = pyr_up(result) + residual
    return result
#----------------------------------------------------------------------------