-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsift.py
58 lines (40 loc) · 1.57 KB
/
sift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn as nn
import cv2
import numpy as np
class SIFT(nn.Module):
    """Match two images with OpenCV SIFT features and mutual nearest neighbours."""

    def __init__(self):
        super().__init__()
        try:
            self.sift = cv2.SIFT_create()
        except AttributeError:
            # Builds that do not expose SIFT in the main module fall back to
            # the contrib (xfeatures2d) location.
            self.sift = cv2.xfeatures2d.SIFT_create()

    def forward(self, img1, img2, device='cpu'):
        """Detect SIFT keypoints in both images and return matched coordinates.

        Args:
            img1: First image, passed unchanged to ``detectAndCompute``
                (presumably a uint8 NumPy image -- confirm against callers).
            img2: Second image, same format as ``img1``.
            device: Torch device used for descriptor matching and for the
                returned tensor.

        Returns:
            Float tensor of shape (M, 4) on ``device``. Each row holds the
            keypoint coordinates in ``img1`` followed by those of its match in
            ``img2``, i.e. ``kp.pt`` of both keypoints side by side. ``M`` is 0
            when either image yields no descriptors or nothing matches.
        """
        kp1, des1 = self.sift.detectAndCompute(img1, None)
        kp2, des2 = self.sift.detectAndCompute(img2, None)

        # No descriptors are returned for an image without keypoints; there is
        # nothing to match in that case.
        if des1 is None or des2 is None or len(des1) == 0 or len(des2) == 0:
            return torch.zeros((0, 4), dtype=torch.float32, device=device)

        desc1 = torch.from_numpy(des1).float().to(device)
        desc2 = torch.from_numpy(des2).float().to(device)

        match_ids, _ = mnn_matching(desc1, desc2, threshold=0)
        match_ids = match_ids.cpu().numpy()

        pts1 = np.array([kp.pt for kp in kp1])
        pts2 = np.array([kp.pt for kp in kp2])

        p1 = pts1[match_ids[:, 0]]
        p2 = pts2[match_ids[:, 1]]

        matches = np.hstack((p1, p2))
        matches = torch.tensor(matches).float().to(device)
        return matches
def mnn_matching(desc1, desc2, threshold=None):
    """Find mutual nearest-neighbour matches between two descriptor sets.

    Rows of both inputs are L2-normalised, so the similarity used for matching
    is the cosine similarity between descriptors.

    Args:
        desc1: Float tensor of shape (N1, D), one descriptor per row.
        desc2: Float tensor of shape (N2, D), one descriptor per row.
        threshold: If not ``None``, only matches whose similarity is strictly
            greater than this value are kept.

    Returns:
        A tuple ``(matches, scores)``. ``matches`` is an integer tensor of
        shape (M, 2) whose rows are ``(index into desc1, index into desc2)``;
        ``scores`` is a tensor of shape (M,) with the similarity of each
        match. Both are empty when either input has no rows.
    """
    # A max-reduction over a zero-sized dimension raises, so handle empty
    # descriptor sets explicitly.
    if desc1.shape[0] == 0 or desc2.shape[0] == 0:
        empty_matches = torch.empty((0, 2), dtype=torch.long, device=desc1.device)
        empty_scores = torch.empty((0,), dtype=desc1.dtype, device=desc1.device)
        return empty_matches, empty_scores

    # Clamp the norm so an all-zero descriptor yields zero similarity instead
    # of NaN, which would otherwise poison the max-reductions below.
    eps = 1e-12
    desc1 = desc1 / desc1.norm(dim=1, keepdim=True).clamp(min=eps)
    desc2 = desc2 / desc2.norm(dim=1, keepdim=True).clamp(min=eps)

    similarity = desc1 @ desc2.t()

    # Best desc2 candidate for every desc1 row, and vice versa.
    best_scores, nn12 = similarity.max(dim=1)
    nn21 = similarity.max(dim=0)[1]

    # Keep a pair only when each descriptor is the other's nearest neighbour.
    ids1 = torch.arange(0, similarity.shape[0], device=desc1.device)
    mutual = ids1 == nn21[nn12]
    matches = torch.stack([ids1[mutual], nn12[mutual]]).t()
    scores = best_scores[mutual]

    if threshold is not None:
        keep = scores > threshold
        matches = matches[keep]
        scores = scores[keep]
    return matches, scores