-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathtopoloss_pytorch.py
225 lines (183 loc) · 10.8 KB
/
topoloss_pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# =============================================================================
# Created By : Xiaoling Hu
# Created Date: Tue June 22 9:00:00 PDT 2021
# =============================================================================
import math
import time

import gudhi as gd
import numpy
import numpy as np
from pylab import *
import torch
def compute_dgm_force(lh_dgm, gt_dgm, pers_thresh=0.03, pers_thresh_perfect=0.99, do_return_perfect=False):
    """
    Match a likelihood persistence diagram against a ground-truth diagram.

    The ``gt_n_holes`` most persistent likelihood points are matched to the
    ground truth, where ``gt_n_holes`` is the number of ground-truth points.
    The remaining likelihood points are candidates for removal.

    Args:
        lh_dgm: likelihood persistence diagram, an (N, 2) array of
            (birth, death) values. An empty array of any shape is accepted
            and treated as N == 0.
        gt_dgm: ground-truth persistence diagram. Only its number of rows is
            used, and it may be empty.
        pers_thresh: persistence threshold, also called the dynamic value
            (the difference between a point's death and birth values).
            Unmatched likelihood points whose persistence is not strictly
            greater than this are treated as noise and left alone.
            Default: 0.03
        pers_thresh_perfect: likelihood points whose persistence is strictly
            greater than this are considered already perfect. They receive
            no fixing force. Default: 0.99
        do_return_perfect: also return the indices of the perfect points.
            Default: False

    Returns:
        force_list: float array shaped like ``lh_dgm``.
            - A hole to fix gets the row (0 - birth, 1 - death), i.e. the
              displacement to (0, 1).
            - A hole to remove gets the row (p, -p) / sqrt(2), where p is
              its persistence. This is a displacement towards the diagonal.
            - All other rows are zero.
        idx_holes_to_fix: indices (into lh_dgm) of matched points that are
            not yet perfect.
        idx_holes_to_remove: indices of unmatched points whose persistence
            is above ``pers_thresh``.
        idx_holes_perfect: only returned when ``do_return_perfect`` is True.
            Indices of the perfect points.
    """
    lh_dgm = np.asarray(lh_dgm)
    if lh_dgm.size == 0:
        # An empty diagram may arrive with shape (0,). Normalise it to (0, 2)
        # so the column indexing below works instead of raising IndexError.
        lh_dgm = lh_dgm.reshape(0, 2)
    lh_pers = np.abs(lh_dgm[:, 1] - lh_dgm[:, 0])
    gt_n_holes = len(gt_dgm)  # number of holes in the ground truth

    if gt_n_holes == 0:
        # Nothing to match: every likelihood point is a removal candidate.
        idx_holes_to_fix = []
        idx_holes_to_remove = list(range(lh_pers.size))
        idx_holes_perfect = []
    else:
        # Likelihood points ordered by decreasing persistence.
        lh_pers_sorted_indices = np.argsort(lh_pers)[::-1]
        # "Perfect" holes already have persistence close to 1 and need no fixing.
        lh_n_holes_perfect = int(np.sum(lh_pers > pers_thresh_perfect))
        if lh_n_holes_perfect >= 1:
            idx_holes_perfect = lh_pers_sorted_indices[:lh_n_holes_perfect]
        else:
            idx_holes_perfect = []
        # The gt_n_holes most persistent points are matched to the ground
        # truth. Those that are not already perfect must be pushed to (0, 1).
        idx_holes_to_fix = list(
            set(lh_pers_sorted_indices[:gt_n_holes]) - set(idx_holes_perfect))
        # Everything beyond the matched points is pushed to the diagonal.
        # This can include surplus perfect points when there are more of
        # them than ground-truth holes.
        idx_holes_to_remove = lh_pers_sorted_indices[gt_n_holes:]

    # Only remove points whose persistence is large enough; lower-persistence
    # points are treated as meaningless noise.
    idx_valid = np.where(lh_pers > pers_thresh)[0]
    idx_holes_to_remove = list(
        set(idx_holes_to_remove).intersection(set(idx_valid)))

    force_list = np.zeros(lh_dgm.shape)
    # Push each hole to fix to (0, 1).
    force_list[idx_holes_to_fix, 0] = 0 - lh_dgm[idx_holes_to_fix, 0]
    force_list[idx_holes_to_fix, 1] = 1 - lh_dgm[idx_holes_to_fix, 1]
    # Push each hole to remove towards the diagonal.
    force_list[idx_holes_to_remove, 0] = lh_pers[idx_holes_to_remove] / math.sqrt(2.0)
    force_list[idx_holes_to_remove, 1] = -lh_pers[idx_holes_to_remove] / math.sqrt(2.0)

    if do_return_perfect:
        return force_list, idx_holes_to_fix, idx_holes_to_remove, idx_holes_perfect
    return force_list, idx_holes_to_fix, idx_holes_to_remove
def getCriticalPoints(likelihood):
    """
    Compute the dimension-0 persistence pairs of an image and their critical pixels.

    A GUDHI cubical complex is built on ``1 - likelihood``, one top-dimensional
    cell per pixel. Each regular persistence pair is then mapped back to the
    pixels that create (birth) and kill (death) the feature.

    Args:
        likelihood: 2-D array of likelihood values (expected in [0, 1]),
            e.g. the output of the neural network.

    Returns:
        pd_lh: (K, 2) array of (birth, death) values taken from
            ``1 - likelihood``.
        bcp_lh: (K, 2) integer array of (row, col) birth critical pixels.
        dcp_lh: (K, 2) integer array of (row, col) death critical pixels.
        ok: True on success. When GUDHI reports no regular pairs at all, the
            function returns ``(0, 0, 0, False)`` so the caller can skip the
            input. K may also be 0 with ``ok`` True; the arrays are then
            empty but still have two columns.
    """
    lh = 1 - likelihood
    height, width = lh.shape[0], lh.shape[1]
    lh_vector = np.asarray(lh).flatten()  # C order: column index varies fastest
    # GUDHI lays out top-dimensional cells with the FIRST dimension varying
    # fastest (Fortran order). lh_vector is C-ordered, so the dimensions are
    # given as [width, height]: GUDHI then sees the transposed image, which
    # has the same topology. [height, width] is only correct for square
    # inputs and scrambles pixel adjacency otherwise.
    lh_cubic = gd.CubicalComplex(
        dimensions=[width, height],
        top_dimensional_cells=lh_vector
    )
    # persistence() must run before cofaces_of_persistence_pairs().
    # Its return value (the diagram) is not needed here.
    lh_cubic.persistence(homology_coeff_field=2, min_persistence=0)
    pairs_lh = lh_cubic.cofaces_of_persistence_pairs()
    # pairs_lh[0] holds the regular (finite) pairs, documented by GUDHI as one
    # entry per homology dimension. If there are none, signal the caller to skip.
    if len(pairs_lh[0]) == 0:
        return 0, 0, 0, False
    # Dimension-0 pairs as (birth_index, death_index) into lh_vector.
    # reshape keeps the result (K, 2) even when K == 0.
    pairs = np.asarray(pairs_lh[0][0], dtype=np.int64).reshape(-1, 2)
    pd_lh = lh_vector[pairs]
    # Convert flat C-order indices back to (row, col) pixel coordinates.
    bcp_lh = np.stack(np.divmod(pairs[:, 0], width), axis=1)
    dcp_lh = np.stack(np.divmod(pairs[:, 1], width), axis=1)
    return pd_lh, bcp_lh, dcp_lh, True
def _patch_pixel(critical_point, patch_shape):
    """Return a critical point as an in-patch (row, col) int tuple, or None if outside the patch."""
    row, col = int(critical_point[0]), int(critical_point[1])
    if 0 <= row < patch_shape[0] and 0 <= col < patch_shape[1]:
        return row, col
    return None


def getTopoLoss(likelihood_tensor, gt_tensor, topo_size=100):
    """
    Calculate the topology loss between a prediction and its ground truth.

    The image is processed in square patches of side ``topo_size``. In each
    patch the persistence diagrams of prediction and ground truth are
    matched. Critical pixels of holes to fix, or to remove, then receive a
    weight of 1 and a target value in the reference map. The loss is the
    squared error between the weighted prediction and the reference map.

    Critical points are located on a detached CPU/NumPy copy. Gradients flow
    only through the final expression, which uses ``likelihood_tensor``
    directly.

    Args:
        likelihood_tensor: prediction tensor. ``torch.sigmoid`` is applied to
            it to obtain the likelihood used for locating critical points.
            Squeezed to 2-D (assumes a single image such as (1, 1, H, W) or
            (H, W) -- TODO confirm against callers).
        gt_tensor: ground-truth tensor with the same spatial size, also
            squeezed to 2-D.
        topo_size: side length of the patches. Default: 100

    Returns:
        loss_topo: 0-dim tensor on the same device as ``likelihood_tensor``.
    """
    likelihood = torch.sigmoid(likelihood_tensor).clone()
    gt = gt_tensor.clone()
    likelihood = torch.squeeze(likelihood).cpu().detach().numpy()
    gt = torch.squeeze(gt).cpu().detach().numpy()

    topo_cp_weight_map = np.zeros(likelihood.shape)
    topo_cp_ref_map = np.zeros(likelihood.shape)

    for y in range(0, likelihood.shape[0], topo_size):
        for x in range(0, likelihood.shape[1], topo_size):
            # Slicing clamps at the array edge, so border patches are smaller.
            lh_patch = likelihood[y:y + topo_size, x:x + topo_size]
            gt_patch = gt[y:y + topo_size, x:x + topo_size]

            # Skip patches that are entirely 1 or entirely 0.
            if np.min(lh_patch) == 1 or np.max(lh_patch) == 0:
                continue
            if np.min(gt_patch) == 1 or np.max(gt_patch) == 0:
                continue

            # Critical points of prediction and ground truth. Skip the patch
            # if either has no persistence pairs.
            pd_lh, bcp_lh, dcp_lh, lh_has_pairs = getCriticalPoints(lh_patch)
            pd_gt, _, _, gt_has_pairs = getCriticalPoints(gt_patch)
            if not (lh_has_pairs and gt_has_pairs):
                continue

            _, idx_holes_to_fix, idx_holes_to_remove = compute_dgm_force(pd_lh, pd_gt, pers_thresh=0.03)

            # NOTE(review): the diagram is computed on 1 - likelihood, where
            # the target point is (0, 1). In likelihood space that would mean
            # 1 at the birth pixel and 0 at the death pixel, the opposite of
            # the targets below. Kept as in the original; confirm the
            # intended direction.
            for hole in idx_holes_to_fix:
                birth = _patch_pixel(bcp_lh[hole], lh_patch.shape)
                death = _patch_pixel(dcp_lh[hole], lh_patch.shape)
                if birth is not None:
                    topo_cp_weight_map[y + birth[0], x + birth[1]] = 1
                    topo_cp_ref_map[y + birth[0], x + birth[1]] = 0
                if death is not None:
                    topo_cp_weight_map[y + death[0], x + death[1]] = 1
                    topo_cp_ref_map[y + death[0], x + death[1]] = 1

            # Holes to remove: pull the birth and death pixel values towards
            # each other, which moves the point onto the diagonal.
            for hole in idx_holes_to_remove:
                birth = _patch_pixel(bcp_lh[hole], lh_patch.shape)
                death = _patch_pixel(dcp_lh[hole], lh_patch.shape)
                if birth is not None:
                    topo_cp_weight_map[y + birth[0], x + birth[1]] = 1
                    topo_cp_ref_map[y + birth[0], x + birth[1]] = (
                        lh_patch[death] if death is not None else 1)
                if death is not None:
                    topo_cp_weight_map[y + death[0], x + death[1]] = 1
                    topo_cp_ref_map[y + death[0], x + death[1]] = (
                        lh_patch[birth] if birth is not None else 0)

    # Build the maps on the input's device instead of a hard-coded .cuda(),
    # so the loss also works on CPU and on non-default GPUs.
    device = likelihood_tensor.device
    topo_cp_weight_map = torch.tensor(topo_cp_weight_map, dtype=torch.float, device=device)
    topo_cp_ref_map = torch.tensor(topo_cp_ref_map, dtype=torch.float, device=device)
    # NOTE(review): critical points and the removal targets come from
    # torch.sigmoid(likelihood_tensor), but the penalty is applied to
    # likelihood_tensor itself. If it holds logits, the two are on different
    # scales -- confirm whether the sigmoid was intended here. Kept as-is to
    # preserve the original loss.
    # Squared error between the predicted critical points and their reference values.
    loss_topo = (((likelihood_tensor * topo_cp_weight_map) - topo_cp_ref_map) ** 2).sum()
    return loss_topo