-
Notifications
You must be signed in to change notification settings - Fork 0
/
video_pred.py
114 lines (94 loc) · 3.64 KB
/
video_pred.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# -*- coding:utf-8 -*-
# author:平手友梨奈ii
# e-mail:1353593259@qq.com
# datetime:1993/12/01
# filename:video_pred.py
# software: PyCharm
from detector import Detector
from tracker import matching_cascade
import numpy as np
import cv2
from PIL import Image
from utils import letterbox_image, create_tracker, box2xyah
import tensorflow as tf
from kalman_filter import KalmanFilter
# from official_code.kalman_filter import KalmanFilter
from visualize import visualize_results
# the input size
INPUT_SIZE = [416, 416]
def main(video_path,
         model_path,
         track_target=0,
         visualize=True):
    """Run detection + Kalman-filter tracking over every frame of a video.

    Args:
        video_path: path to the input video file.
        model_path: path to the saved detector model.
        track_target: COCO class id to track (0-person; 1-bicycle; 2-car; 7-truck).
        visualize: whether to display the tracking results in an OpenCV window.
    """
    detector = Detector(model_path=model_path)
    kalman_filter = KalmanFilter()

    capture = cv2.VideoCapture(video_path)
    height = capture.get(cv2.CAP_PROP_FRAME_HEIGHT)
    width = capture.get(cv2.CAP_PROP_FRAME_WIDTH)

    tracking_list = []      # currently active trackers
    label_count = 0         # running id counter handed out to new tracks
    is_first_frame = True

    try:
        while True:
            success, frame = capture.read()
            if not success:
                break

            # OpenCV delivers BGR; the detector / PIL pipeline expects RGB.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_pil = Image.fromarray(np.uint8(frame))
            new_frame = letterbox_image(frame_pil, INPUT_SIZE)

            image_array = np.expand_dims(np.array(new_frame, dtype='float32') / 255.0, axis=0)
            image_shape = np.expand_dims(np.array([height, width], dtype='float32'), axis=0)
            image_constant = tf.constant(image_array, dtype=tf.float32)
            image_shape = tf.constant(image_shape, dtype=tf.float32)

            # detect the current frame; the detector returns a dict whose
            # insertion order is relied on here: [0]=boxes, [2]=classes
            # (index [1] — scores — is unused).
            results = detector.detect(image_constant, image_shape)
            pred_results = list(results.values())
            boxes = pred_results[0].numpy()
            classes = pred_results[2].numpy()

            # keep only detections of the target class
            target_ids = np.where(classes == track_target)[0]
            track_boxes = boxes[target_ids]
            num_tracks = len(track_boxes)

            # convert detections to (x, y, aspect, height) measurements;
            # use an empty list when there are none so matching_cascade
            # always receives the same list-of-measurements format
            # (the original passed a raw empty ndarray in that case).
            if num_tracks > 0:
                track_boxes = list(box2xyah(track_boxes))
            else:
                track_boxes = []

            if not is_first_frame:
                # associate current detections with existing tracks
                tracking_list, label_count = matching_cascade(tracking_list, track_boxes,
                                                              kalman_filter, label_count)
            elif num_tracks > 0:
                # bootstrap: initialize one tracker per first-frame detection
                is_first_frame = False
                for measurement in track_boxes:
                    mean_init, cov_init = kalman_filter.initiate(measurement=measurement)
                    tracking_list.append(create_tracker(mean=mean_init,
                                                        cov=cov_init,
                                                        detection=measurement))

            if visualize:
                # visualize results
                img = visualize_results(tracking_list, height, frame)
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
                cv2.imshow('avoid invasion', img)
                # exit on ESC
                if cv2.waitKey(30) & 0xff == 27:
                    break
    finally:
        # single cleanup point: the original released the capture on each
        # break path and never destroyed the display window
        capture.release()
        cv2.destroyAllWindows()
if __name__ == '__main__':
    # Smoke-test the tracking pipeline on a bundled sample video.
    video_file = './video/person.avi'
    saved_model_dir = './saved_model_coco'
    main(video_file, saved_model_dir, visualize=True)