-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathtrack.py
174 lines (147 loc) · 6.79 KB
/
track.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
import argparse
import cv2
import numpy as np
from functools import partial
from pathlib import Path
import torch
from boxmot import TRACKERS
from boxmot.tracker_zoo import create_tracker
from boxmot.utils import ROOT, WEIGHTS, TRACKER_CONFIGS
from boxmot.utils.checks import RequirementsChecker
from tracking.detectors import get_yolo_inferer
checker = RequirementsChecker()
checker.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator, colors
from ultralytics.data.utils import VID_FORMATS
from ultralytics.utils.plotting import save_one_box
def on_predict_start(predictor, persist=False):
    """
    Attach freshly built BoxMOT trackers to ``predictor`` at prediction start.

    One tracker is created per dataset batch entry (``predictor.dataset.bs``)
    via ``create_tracker``. Each is configured from ``predictor.custom_args``
    (tracking method, ReID weights, half precision, per-class mode) and the
    method's YAML file under ``TRACKER_CONFIGS``. Trackers that expose a
    ``model`` attribute get ``model.warmup()`` called before use. The
    resulting list is assigned to ``predictor.trackers``, replacing any
    previous value.

    Args:
        predictor (object): The predictor to initialize trackers for; it must
            carry ``custom_args``, ``dataset`` and ``device``.
        persist (bool, optional): Passed as ``True`` by ``run`` via
            ``functools.partial``. NOTE(review): currently unused -- trackers
            are rebuilt on every call regardless of this flag. Defaults to False.

    Raises:
        AssertionError: If the requested tracking method is not in ``TRACKERS``.
    """
    opts = predictor.custom_args
    method = opts.tracking_method
    assert method in TRACKERS, \
        f"'{method}' is not supported. Supported ones are {TRACKERS}"

    config_path = TRACKER_CONFIGS / (method + '.yaml')

    def _build_tracker():
        # Build a single tracker. Motion-only trackers have no `model`
        # attribute, so only trackers that carry one are warmed up.
        trk = create_tracker(
            method,
            config_path,
            opts.reid_model,
            predictor.device,
            opts.half,
            opts.per_class,
        )
        if hasattr(trk, 'model'):
            trk.model.warmup()
        return trk

    # Assigned only after every tracker was built successfully.
    predictor.trackers = [_build_tracker() for _ in range(predictor.dataset.bs)]
@torch.no_grad()
def run(args):
    """
    Run detection plus BoxMOT tracking over ``args.source``, optionally displaying it.

    Detector loading:
        If the ``args.yolo_model`` path contains one of the ``ul_models`` tags,
        Ultralytics loads those weights directly. Otherwise the pipeline is
        bootstrapped with ``'yolov8n.pt'``, and the predictor's model is
        replaced by the inferer returned from ``get_yolo_inferer``.

    Display:
        Only the tracker of batch slot 0 (``trackers[0]``) is visualized. When
        ``args.show`` is set, pressing space or ``q`` stops tracking early and
        the OpenCV windows are closed afterwards.

    Args:
        args (argparse.Namespace): Parsed CLI options (see ``parse_opt``).
    """
    ul_models = ['yolov8', 'yolov9', 'yolov10', 'yolo11', 'rtdetr', 'sam']
    # Decide once whether Ultralytics can load the weights itself. The flag
    # picks the weights to load and whether to swap in a custom inferer.
    model_name = str(args.yolo_model)
    is_ultralytics_model = any(tag in model_name for tag in ul_models)

    yolo = YOLO(args.yolo_model if is_ultralytics_model else 'yolov8n.pt')

    results = yolo.track(
        source=args.source,
        conf=args.conf,
        iou=args.iou,
        agnostic_nms=args.agnostic_nms,
        show=False,
        stream=True,
        device=args.device,
        show_conf=args.show_conf,
        save_txt=args.save_txt,
        show_labels=args.show_labels,
        save=args.save,
        verbose=args.verbose,
        exist_ok=args.exist_ok,
        project=args.project,
        name=args.name,
        classes=args.classes,
        imgsz=args.imgsz,
        vid_stride=args.vid_stride,
        line_width=args.line_width
    )

    # NOTE: `results` is a stream (stream=True). The callback, model swap and
    # custom_args below are installed before the first frame is pulled from
    # it -- presumably inference only starts on iteration, so keep this order.
    yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True))

    if not is_ultralytics_model:
        # Replace the placeholder yolov8n model with the requested detector.
        inferer_cls = get_yolo_inferer(args.yolo_model)
        yolo.predictor.model = inferer_cls(
            model=args.yolo_model,
            device=yolo.predictor.device,
            args=yolo.predictor.args,
        )

    # on_predict_start reads the tracker configuration from here.
    yolo.predictor.custom_args = args

    for r in results:
        img = yolo.predictor.trackers[0].plot_results(r.orig_img, args.show_trajectories)
        if args.show:
            cv2.imshow('BoxMOT', img)
            key = cv2.waitKey(1) & 0xFF
            if key in (ord(' '), ord('q')):
                break

    if args.show:
        # Release the display window(s) opened by cv2.imshow above.
        cv2.destroyAllWindows()
def parse_opt(argv=None):
    """
    Define and parse the command-line options for this tracking script.

    Args:
        argv (list[str] | None, optional): Arguments to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            ``__main__`` entry point relies on.

    Returns:
        argparse.Namespace: Parsed options, consumed by ``run``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8n',
                        help='yolo model path')
    parser.add_argument('--reid-model', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt',
                        help='reid model path')
    parser.add_argument('--tracking-method', type=str, default='deepocsort',
                        help='deepocsort, botsort, strongsort, ocsort, bytetrack, imprassoc')
    parser.add_argument('--source', type=str, default='0',
                        help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640],
                        help='inference size h,w')
    parser.add_argument('--conf', type=float, default=0.5,
                        help='confidence threshold')
    parser.add_argument('--iou', type=float, default=0.7,
                        help='intersection over union (IoU) threshold for NMS')
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show', action='store_true',
                        help='display tracking video results')
    parser.add_argument('--save', action='store_true',
                        help='save video tracking results')
    # Class ids for COCO-trained models: 0 is person, 1 is bicycle, 2 is car ... 79 is toothbrush
    parser.add_argument('--classes', nargs='+', type=int,
                        help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--project', default=ROOT / 'runs' / 'track',
                        help='save results to project/name')
    parser.add_argument('--name', default='exp',
                        help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment')
    parser.add_argument('--half', action='store_true',
                        help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1,
                        help='video frame-rate stride')
    # NOTE: the next two flags use store_false -- passing them turns the
    # feature OFF. Their names are kept for CLI backward compatibility.
    parser.add_argument('--show-labels', action='store_false',
                        help='hide labels (draw only bboxes)')
    parser.add_argument('--show-conf', action='store_false',
                        help='hide confidence scores')
    parser.add_argument('--show-trajectories', action='store_true',
                        help='show trajectories')
    parser.add_argument('--save-txt', action='store_true',
                        help='save tracking results in a txt file')
    parser.add_argument('--save-id-crops', action='store_true',
                        help='save each crop to its respective id folder')
    parser.add_argument('--line-width', default=None, type=int,
                        help='The line width of the bounding boxes. If None, it is scaled to the image size.')
    parser.add_argument('--per-class', default=False, action='store_true',
                        help='not mix up classes when tracking')
    parser.add_argument('--verbose', default=True, action='store_true',
                        help='print results per frame (default: on)')
    # `--verbose` alone can never disable output because its default is
    # already True. Provide an explicit off-switch sharing the same `dest`;
    # the first action's default (True) remains the namespace default.
    parser.add_argument('--no-verbose', dest='verbose', action='store_false',
                        help='do not print results per frame')
    parser.add_argument('--agnostic-nms', default=False, action='store_true',
                        help='class-agnostic NMS')
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags and hand them to the tracking loop.
    run(parse_opt())