-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathwgpot.py
139 lines (102 loc) · 4 KB
/
wgpot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
WGPOT
Wasserstein Distance and Optimal Transport Map
of Gaussian Processes
Jiacheng Zhu
jzhu4@andrew.cmu.edu
"""
import numpy as np
import scipy.io
import scipy.linalg
# from matplotlib import pyplot as plt
def GP_W_barycenter(gp_list, lbda=None, err=None, max_iter=10**2, verbose=True):
    """Compute the 2-Wasserstein barycenter of a list of Gaussian processes.

    Each GP is a ``(mean, cov)`` pair evaluated on a common grid of ``d``
    points: ``mean`` has ``d`` entries (shape ``(d, 1)`` or ``(d,)``) and
    ``cov`` is a ``(d, d)`` covariance matrix.

    Parameters
    ----------
    gp_list : sequence of (mean, cov) pairs
        The GPs to average. Item ``[0]`` is the mean, item ``[1]`` the covariance.
    lbda : array_like, optional
        Barycentric weights, shape ``(1, m)`` or ``(m,)``. Defaults to uniform
        weights ``1/m``. Weights are used as given (not renormalised).
    err : float, optional
        Stop once the Wasserstein distance between two successive covariance
        iterates is at most ``err``. Defaults to ``1e-6``.
    max_iter : int, optional
        Maximum number of fixed-point iterations after the first step
        (default 100).
    verbose : bool, optional
        Print per-iteration progress (default True).

    Returns
    -------
    mu_mean : ndarray, shape (d, 1)
        Weighted average of the input means.
    K_next : ndarray, shape (d, d)
        Last fixed-point iterate of the barycenter covariance.
    """
    # Notice: Initialization -- means as columns, covariances stacked on axis 2.
    m_gp = len(gp_list)  # Number of GPs
    d_gp = gp_list[0][0].shape[0]  # Dimension of the Gaussians
    means_array = np.zeros((d_gp, m_gp))
    cov_mats = np.zeros((d_gp, d_gp, m_gp))
    for i, gp in enumerate(gp_list):
        # Flatten so that means of shape (d, 1) and (d,) are both accepted.
        means_array[:, i] = np.reshape(gp[0], -1)
        cov_mats[:, :, i] = gp[1]

    if err is None:
        err = 1e-6
    if lbda is None:
        # Uniform weights.
        lbda = (1.0 / m_gp) * np.ones((1, m_gp))
    else:
        # F_map and the mean average below expect a (1, m) row of weights.
        lbda = np.reshape(np.asarray(lbda, dtype=float), (1, m_gp))

    # Notice: Iteration
    # The barycenter covariance is the fixed point of F_map. Start from the
    # first GP's covariance and iterate until successive iterates are within
    # `err` in Wasserstein distance (means are zero so only covariances count).
    zero_mean = np.zeros((d_gp, 1))
    K = cov_mats[:, :, 0]
    K_next = F_map(K, cov_mats, lbda)
    wd = Wasserstein_GP((zero_mean, K), (zero_mean, K_next))
    count = 0
    while wd > err and count < max_iter:
        K = K_next
        K_next = F_map(K, cov_mats, lbda)
        count += 1
        if verbose:
            print('count =', count)
        wd = Wasserstein_GP((zero_mean, K), (zero_mean, K_next))
        if verbose:
            print(' W-d in this iteration =', wd)

    # Only a still-unmet tolerance is a failure; meeting it on the final
    # allowed iteration counts as convergence.
    if wd > err:
        print('Barycenter did not converge')

    # The barycenter mean is the weighted average of the input means: (d, 1).
    mu_mean = np.dot(means_array, lbda.T)
    return mu_mean, K_next
# Notice: 2-Wasserstein distance of GPs (returns the distance, not its square)
def Wasserstein_GP(gp_0, gp_1):
    """Return the 2-Wasserstein distance between two Gaussian processes.

    Each GP is a ``(mean, cov)`` pair evaluated on the same grid. The squared
    distance is

        ||mu_0 - mu_1||^2 + tr(K_0) + tr(K_1) - 2 tr((K_0^{1/2} K_1 K_0^{1/2})^{1/2})

    and this function returns its square root.

    Round-off can push the squared distance slightly below zero when the two
    GPs (nearly) coincide; it is clamped at 0 so the result is 0.0 rather
    than NaN.
    """
    mu_0, K_0 = gp_0[0], gp_0[1]
    mu_1, K_1 = gp_1[0], gp_1[1]
    sqrtK_0 = scipy.linalg.sqrtm(K_0)
    K_0_K_1_K_0 = np.dot(np.dot(sqrtK_0, K_1), sqrtK_0)
    cov_dist = np.trace(K_0) + np.trace(K_1) - 2 * np.trace(scipy.linalg.sqrtm(K_0_K_1_K_0))
    l2norm = np.sum(np.square(np.abs(mu_0 - mu_1)))
    # sqrtm may return complex values with negligible imaginary parts.
    sq_dist = np.real(l2norm + cov_dist)
    return np.sqrt(np.maximum(sq_dist, 0.0))
# Notice
# The covariance matrix of the barycenter is the fixed point of
# the following map F
def F_map(K, cov_mats, lbda):
    """Apply one fixed-point step for the Wasserstein barycenter covariance.

    Computes ``K^{-1/2} T^2 K^{-1/2}`` where
    ``T = sum_i lbda_i (K^{1/2} K_i K^{1/2})^{1/2}`` and ``K_i = cov_mats[:, :, i]``.

    Parameters
    ----------
    K : ndarray, shape (d, d)
        Current covariance iterate; its square root must be invertible.
    cov_mats : ndarray, shape (d, d, m)
        Covariance matrices of the GPs being averaged.
    lbda : array_like, shape (1, m) or (m,)
        Barycentric weights.

    Returns
    -------
    ndarray, shape (d, d)
        The next iterate.
    """
    # Accept both a (1, m) row and a flat (m,) vector of weights.
    weights = np.asarray(lbda).reshape(-1)
    sqrtK = np.real(scipy.linalg.sqrtm(K))
    d_gp = cov_mats.shape[0]
    T = np.zeros((d_gp, d_gp))
    for i, w in enumerate(weights):
        K_bar_K_i_K_bar = np.dot(np.dot(sqrtK, cov_mats[:, :, i]), sqrtK)
        T = T + w * np.real(scipy.linalg.sqrtm(K_bar_K_i_K_bar))
    T_sq = np.dot(T, T)
    # MATLAB-style right division T_sq / sqrtK without forming an inverse:
    # X sqrtK = T_sq  <=>  sqrtK^H X^H = T_sq^H.
    right_div = np.linalg.solve(sqrtK.conj().T, T_sq.conj().T).conj().T
    # Left division sqrtK \ right_div gives sqrtK^{-1} T^2 sqrtK^{-1}.
    return np.linalg.solve(sqrtK, right_div)
def logmap(mu_gp1, K_gp1, mu_gp2, K_gp2):
    """Logarithmic map at GP1 towards GP2 in the Wasserstein geometry.

    The logarithmic map from GD1 to GD2 on the Riemannian manifold of
    Gaussian distributions with the W metric (cf. "W geometry of Gaussian
    measure" and the logmap.m from [Anton NIPS 2017]).

    Returns
    -------
    v_mu : ndarray
        Mean displacement ``mu_gp2 - mu_gp1``.
    T : ndarray, shape (d, d)
        ``T_12 - I``, where
        ``T_12 = K2^{1/2} (K2^{1/2} K1 K2^{1/2})^{-1/2} K2^{1/2}``
        satisfies ``T_12 K1 T_12 = K2`` (the transport map of Gaussian
        processes, Proposition 2 of "Procrustes Metrics on Covariance
        Operators and Optimal Transportation of Gaussian Processes").
        K2 and ``K2^{1/2} K1 K2^{1/2}`` are assumed invertible.
    """
    # Point from GP1 towards GP2, the same direction as the covariance part
    # below, so that (mu_gp1 + v_mu, (I + T) K_gp1 (I + T)) == (mu_gp2, K_gp2).
    v_mu = mu_gp2 - mu_gp1
    d_gp = mu_gp1.shape[0]
    sqrtK2 = np.real(scipy.linalg.sqrtm(K_gp2))
    sqrt_sK2_K1_sK2 = np.real(scipy.linalg.sqrtm(np.dot(np.dot(sqrtK2, K_gp1), sqrtK2)))
    # (K2^{1/2} K1 K2^{1/2})^{-1/2} K2^{1/2}, via a solve instead of an inverse.
    scd_part = np.linalg.solve(sqrt_sK2_K1_sK2, sqrtK2)
    T = np.dot(sqrtK2, scd_part) - np.eye(d_gp)
    return v_mu, T
def expmap(mu_gp1, K_gp1, v_mu_t, v_K_t):
    """Exponential map at (mu_gp1, K_gp1) along the tangent (v_mu_t, v_K_t).

    The mean is translated by ``v_mu_t``; the covariance is pushed forward by
    the linear map ``A = I + v_K_t`` as ``A K_gp1 A``.

    Returns
    -------
    (q_mu, q_K) : the mean and covariance of the resulting GP.
    """
    shifted_mean = mu_gp1 + v_mu_t
    stretch = np.eye(mu_gp1.shape[0]) + v_K_t
    pushed_cov = np.dot(stretch, np.dot(K_gp1, stretch))
    return shifted_mean, pushed_cov