# tracking.py
import os
import sys
import math
import time
from time import sleep

import cv2
import numpy as np
from imutils.video import FPS
from tqdm import tqdm

# My own library with utility functions
from utility.utility import *

# Root directory of the project
ROOT_DIR = os.path.abspath("../../")
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Import Mask RCNN
sys.path.append(ROOT_DIR)  # To find local version of the library
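# Note: utility.utility is this project's helper module; the script relies on its
# get_dict(), get_gt() and draw_bbox() functions (names taken from the calls below).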
# Tracking using the OpenCV implementation of CSRT
def opencv_tracking(video_path, detection_path, resize=1, txt_path="det/det_track_maskrcnn.txt"):
    start = time.time()
    # Bounding-box output file
    f = open(txt_path, "w")
    # Stats file
    stat = open("stats/stat.txt", "a")
    # Convert the detection file to a dictionary
    gt_dict = get_dict(detection_path)
    params = cv2.TrackerCSRT_Params()
    params.psr_threshold = 0.08
    # Initialize the tracker
    tracker = cv2.TrackerCSRT_create(params)
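    # Note: 0.08 is above OpenCV's default psr_threshold of 0.035, so the tracker
    # should report failure more readily when the peak-to-sidelobe ratio of its
    # correlation response drops, letting a fresh detection re-initialize it sooner.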
    # Input video
    video = cv2.VideoCapture(video_path)
    if not video.isOpened():
        print("Could not open video")
        sys.exit()
    length_input = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    print("Total frames: {}".format(length_input))
    # Output video
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter('output/tracking.mp4', fourcc, 30.0,
                          (int(video.get(cv2.CAP_PROP_FRAME_WIDTH)),
                           int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    frame_id = 0
    ret = True
    success = False
    initBB = None
    best_score = 0
    det_frame = 0
    track_frame = 0
    fps = None
    prev_box = [0, 0]
    frame_diff = 0
    bbox_offset = 10
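    # Main loop: read each frame, draw the Mask R-CNN detections, (re-)initialize
    # the CSRT tracker from the detections when needed, then track.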
    with tqdm(total=length_input, file=sys.stdout) as pbar:
        while ret:
            ret, frame = video.read()
            if not ret:
                break
            # Get the bboxes for a single frame
            boxes, scores, names = [], [], []
            boxes, scores, names, complete = get_gt(frame_id, gt_dict)
            (H, W) = frame.shape[:2]
            # Draw the detection boxes
            frame = draw_bbox(frame, [], complete, show_label=False, tracking=True)
            if len(boxes) > 0:
                # Just log for stats
                det_frame += 1
                # If no bbox has been initialized yet (or tracking was lost)
                if initBB is None or (prev_box[0] == 0 and prev_box[1] == 0):
                    best_score = 0
                    for i, bbox in enumerate(boxes):
                        coor = np.array(bbox[:4], dtype=np.int32)
                        # bbox padded with an offset to better discriminate using some background information
                        initBB = (coor[0] - bbox_offset, coor[1] - bbox_offset,
                                  coor[2] + 2 * bbox_offset, coor[3] + 2 * bbox_offset)
                        cv2.putText(frame, "Tracking lost. New point: ({}, {})".format(initBB[0], initBB[1]), (16, 50),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
                        # Re-initialize the tracker on the highest-scoring detection
                        if scores[i] > best_score:
                            tracker = cv2.TrackerCSRT_create(params)
                            tracker.init(frame, initBB)
                            fps = FPS().start()
                            best_score = scores[i]
                else:
                    min_distance = 99999
                    for i, bbox in enumerate(boxes):
                        coor = np.array(bbox[:4], dtype=np.int32)
                        # bbox padded with an offset to better discriminate using some background information
                        initBB = (coor[0] - bbox_offset, coor[1] - bbox_offset,
                                  coor[2] + 2 * bbox_offset, coor[3] + 2 * bbox_offset)
                        # Distance between the new detection and the last position tracked by CSRT
                        eucl = math.sqrt((coor[0] - prev_box[0]) ** 2 + (coor[1] - prev_box[1]) ** 2)
                        # Accept the detection only if it is close enough (gate relaxed after N frames)
                        if frame_diff > 8 or eucl < 200:
                            # Keep the closest bbox
                            if eucl < min_distance:
                                min_distance = eucl
                                tracker = cv2.TrackerCSRT_create(params)
                                tracker.init(frame, initBB)
                                fps = FPS().start()
                                frame_diff = 0
            else:
                frame_diff += 1
                cv2.putText(frame, "Frames without a valid detection: {}".format(frame_diff), (16, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
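            # Re-acquisition strategy, in short: detections re-seed the CSRT tracker when
            # tracking has been lost, or when a detection lands within 200 px of the last
            # tracked position; the distance gate is dropped after 8 detection-free frames.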
            # If there is a new bbox, update the tracker
            if initBB is not None:
                (success, tracked_box) = tracker.update(frame)
                if success:
                    # Save the tracking boxes (detections included)
                    f.write('{},-1,{},{},{},{},{},-1,-1,-1\n'.format(frame_id, tracked_box[0], tracked_box[1],
                                                                     tracked_box[2], tracked_box[3], 1))
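                    # Assumed format: MOT Challenge text layout, one line per box:
                    # frame, id, bb_left, bb_top, bb_width, bb_height, conf, x, y, z
                    # (the id and the 3D x/y/z fields are left as -1 here)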
                    track_frame += 1
                    p1 = (int(tracked_box[0]), int(tracked_box[1]))
                    p2 = (int(tracked_box[0] + tracked_box[2]), int(tracked_box[1] + tracked_box[3]))
                    cv2.rectangle(frame, p1, p2, (255, 0, 100), 6, 3)
                    # Update the FPS counter
                    fps.update()
                    fps.stop()
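                    # Note: imutils' FPS gives the average frame rate between start()
                    # and stop(), i.e. since the tracker was last (re-)initialized,
                    # not an instantaneous value.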
                    prev_box = [tracked_box[0], tracked_box[1]]
                    # Initialize the set of information we'll be displaying on the frame
                    info = [
                        ("Tracker", "CSRT"),
                        ("Success", "Yes" if success else "No"),
                        ("FPS", "{:.2f}".format(fps.fps())),
                    ]
                    # Loop over the info tuples and draw them on the frame
                    for (i, (k, v)) in enumerate(info):
                        text = "{}: {}".format(k, v)
                        cv2.putText(frame, text, (16, H - ((i * 25) + 20)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
                else:
                    # Tracking failed: force re-initialization from the next detections
                    initBB = None
            # Save the video frame
            out.write(frame)
            frame_id += 1
            # Progress bar update (the short sleep keeps the tqdm output readable)
            pbar.update(1)
            sleep(0.1)
    stat.write("\n---- Tracking ----\n")
    stat.write("Total frames: {}\n".format(frame_id))
    stat.write("Total frames with a detected position: {}\n".format(det_frame))
    stat.write("Total frames tracked: {}\n".format(track_frame))
    f.close()
    out.release()
    stat.close()
    # Timing information
    end = time.time()
    print("Tracking time: ", end - start)
    print("FPS: {}".format(length_input / (end - start)))
if __name__ == '__main__':
    import argparse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(
        description='Apply CSRT tracking to Mask R-CNN detections on a video.')
    parser.add_argument('--det', required=False,
                        default="det/det_maskrcnn.txt",
                        metavar="/path/to/detections.txt",
                        help='Path to the detections file')
    parser.add_argument('--video', required=True,
                        metavar="path or URL to video",
                        help='Video to apply the tracking to')
    args = parser.parse_args()
    print("Video: ", args.video)
    print("Detections: ", args.det)
    opencv_tracking(args.video, args.det)
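
# Example usage (the video path is illustrative; --det falls back to the default above):
#   python tracking.py --video input/video.mp4 --det det/det_maskrcnn.txt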