forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
alignment.py
54 lines (49 loc) · 2.28 KB
/
alignment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utilities for data pre-processing, e.g. matching moving objects across frames."""
import numpy as np
def compute_overlap(mask1, mask2):
    """Return the intersection-over-union (IoU) of two masks.

    Args:
      mask1: Array-like mask; truthy (nonzero/True) entries mark the region.
      mask2: Array-like mask of the same (or a broadcast-compatible) shape.

    Returns:
      The IoU as a float in [0, 1]. When both masks are empty the union is
      empty; 0.0 is returned instead of the NaN (and RuntimeWarning) that a
      0/0 division would produce.
    """
    # Logical (not bitwise) ops so non-boolean masks such as 2 vs. 1 are
    # treated as overlapping regions rather than bit patterns.
    intersection = np.count_nonzero(np.logical_and(mask1, mask2))
    union = np.count_nonzero(np.logical_or(mask1, mask2))
    if union == 0:
        return 0.0
    return intersection / union
def _best_match(mask, seg_img, candidate_ids):
    """Return (id, iou) of the candidate segment in seg_img that best overlaps mask.

    Candidates are scanned in order and only a strictly larger IoU replaces
    the current best, so ties keep the earliest candidate. Returns
    (-1, float('-inf')) when candidate_ids is empty.
    """
    best_id = -1
    best_overlap = float('-inf')
    for candidate_id in candidate_ids:
        overlap = compute_overlap(mask, seg_img == candidate_id)
        if overlap > best_overlap:
            best_overlap = overlap
            best_id = candidate_id
    return best_id, best_overlap


def align(seg_img1, seg_img2, seg_img3, threshold_same=0.3):
    """Greedily match segments across three segmentation maps by IoU.

    For each segment id in seg_img1 (in the ascending order given by
    np.unique), the best-IoU, still-unclaimed segment in seg_img2 is found.
    If that IoU exceeds threshold_same, the seg_img2 segment's best-IoU,
    still-unclaimed segment in seg_img3 is found. If that IoU also exceeds
    threshold_same, all three segments are written to the outputs labelled
    with the seg_img1 id. The matched seg_img2/seg_img3 ids are then removed
    from further consideration, so each of them is used at most once.

    Args:
      seg_img1: Array of integer segment ids for the first frame.
      seg_img2: Array of segment ids for the second frame, same shape as
        seg_img1.
      seg_img3: Array of segment ids for the third frame, same shape as
        seg_img1.
      threshold_same: Minimum IoU (strictly exceeded) for two segments to
        be considered the same object.

    Returns:
      Tuple (res_img1, res_img2, res_img3) of arrays shaped and typed like
      the inputs (np.zeros_like). Pixels of segments matched across all
      three frames carry the seg_img1 id; every other pixel is 0.
      NOTE(review): id 0 is matched like any other id, and 0 is also the
      "unmatched" fill value; if 0 denotes background, confirm this is
      intended.
    """
    res_img1 = np.zeros_like(seg_img1)
    res_img2 = np.zeros_like(seg_img2)
    res_img3 = np.zeros_like(seg_img3)
    # np.unique already flattens its input.
    remaining_objects2 = list(np.unique(seg_img2))
    remaining_objects3 = list(np.unique(seg_img3))
    for seg_id in np.unique(seg_img1):
        # Build the seg_img1 mask once, not once per seg_img2 candidate.
        mask1 = seg_img1 == seg_id
        max_segid2, max_overlap2 = _best_match(
            mask1, seg_img2, remaining_objects2)
        if max_overlap2 > threshold_same:
            # Chain the match forward: seg_img2 segment -> seg_img3 segment.
            mask2 = seg_img2 == max_segid2
            max_segid3, max_overlap3 = _best_match(
                mask2, seg_img3, remaining_objects3)
            if max_overlap3 > threshold_same:
                res_img1[mask1] = seg_id
                res_img2[mask2] = seg_id
                res_img3[seg_img3 == max_segid3] = seg_id
                # Claim the matched ids so later objects cannot reuse them.
                remaining_objects2.remove(max_segid2)
                remaining_objects3.remove(max_segid3)
    return res_img1, res_img2, res_img3