-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: adding notebook example of bytetrack with car detection and trac…
…king (#22)
- Loading branch information
1 parent
7ec1476
commit 0f2f63b
Showing
4 changed files
with
378 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -129,4 +129,13 @@ dmypy.json | |
.pyre/ | ||
|
||
# Poetry | ||
poetry.lock | ||
poetry.lock | ||
|
||
# model | ||
*.pt | ||
|
||
# images | ||
*.png | ||
|
||
# videos | ||
*.mp4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
git+https://github.com/artefactory-fr/bytetrack.git@main | ||
opencv-python==4.8.1.78 | ||
ultralytics==8.0.216 | ||
matplotlib==3.8.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,294 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Executive Summary: Car Detection with ByteTrack - An Introductory Guide\n", | ||
"\n", | ||
"This guide is designed to provide a beginner-friendly introduction to the application of ByteTrack for car detection in video footage. ByteTrack is an advanced algorithm that leverages the capabilities of the YOLO (You Only Look Once) model for object detection, specifically focusing on tracking objects across video frames.\n", | ||
"\n", | ||
"For more information on YOLO and ultralytics, visit [this link](https://github.com/ultralytics/ultralytics).\n", | ||
"\n", | ||
"For more information on ByteTrack, visit [this link](https://github.com/ifzhang/ByteTrack).\n", | ||
"\n", | ||
"1. **Frame Extraction**: \n", | ||
" This video is decomposed into frames, transforming continuous video into discrete snapshots for analysis.\n", | ||
"\n", | ||
"2. **Detection and tracking**: \n", | ||
" We initialize the ByteTracker object and load the pre-trained Yolo model, indicating its parameters. Going through all the frames of the video, the YOLO model enables object detection. Tracking is handled by the ByteTrack algorithm, using the bounding boxes and assigning each of it an ID that enables to track its movement.\n", | ||
"\n", | ||
"3. **Visualization of Tracking**: \n", | ||
" Recomposing the video from the frames with object detected, writing it in a MP4 format to same folder.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n", | ||
"import glob\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import cv2\n", | ||
"import numpy as np\n", | ||
"import pandas as pd\n", | ||
"\n", | ||
"# YOLO and video packages \n", | ||
"from ultralytics import YOLO\n", | ||
"from bytetracker import BYTETracker\n", | ||
"from bytetracker.basetrack import BaseTrack\n", | ||
"from utils import draw_all_bbox_on_image, yolo_results_to_bytetrack_format, scale_bbox_as_xyxy\n", | ||
"from IPython.display import Video" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Download the video\n", | ||
"VIDEO_PATH = 'videos/traffic.mp4'\n", | ||
"!if [ ! -f $VIDEO_PATH ]; then mkdir -p videos && wget https://storage.googleapis.com/bytetrack-data-public/traffic.mp4 -O $VIDEO_PATH; fi" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Reading video" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"Video(VIDEO_PATH, width=800,embed=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 1. Frame Extraction " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# You can run this only once:\n", | ||
"# Transform this VIDEO_PATH into a list of frames in this folder under frames/\n", | ||
"!mkdir -p frames && ffmpeg -i $VIDEO_PATH -vf fps=12 frames/%d.png -hide_banner -loglevel panic" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# - list and sort PNG frames in the 'frames' directory, ensuring they are ordered numerically for subsequent processing.\n", | ||
"# - usinglob to find all PNG files and sorts them based on the numeric part of their filenames, avoiding lexicographic order issues" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"available_frames = glob.glob(\"frames/*.png\")\n", | ||
"available_frames = sorted(available_frames, key=lambda x: int(x.split(\"/\")[-1].split(\".\")[0]))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%matplotlib inline\n", | ||
"\n", | ||
"MODEL_WEIGHTS = \"yolov8m.pt\"\n", | ||
"\n", | ||
"model = YOLO(MODEL_WEIGHTS)\n", | ||
"results = model(available_frames[0])[0]\n", | ||
"\n", | ||
"plt.imshow(cv2.cvtColor(results.plot(), cv2.COLOR_BGR2RGB))\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Classes for prediction, indicating which object to detect\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"### We will track only car \n", | ||
"CAR_CLASS_ID = 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"\n", | ||
" #### BYTETracker Parameters\n", | ||
" - `track_thresh`: Threshold for considering a detection as a potential object to track.\n", | ||
" - `track_buffer`: Number of frames to keep tracking information for an object before discarding it.\n", | ||
" - `match_thresh`: Threshold for matching detections between consecutive frames.\n", | ||
" - `frame_rate`: Frame rate of the video or sequence being processed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tracker = BYTETracker(track_thresh= 0.15, track_buffer = 3, match_thresh = 0.85, frame_rate= 12)\n", | ||
"BaseTrack._count = 0" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = YOLO(MODEL_WEIGHTS, task=\"detect\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 2. Detection and tracking" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"all_tracked_objects = []\n", | ||
"for frame_id, image_filename in enumerate(available_frames):\n", | ||
" img = cv2.imread(image_filename)\n", | ||
" detections = model.predict(img, classes=[CAR_CLASS_ID], conf=0.15, verbose=False)[0]\n", | ||
" detections_bytetrack_format = yolo_results_to_bytetrack_format(detections)\n", | ||
" tracked_objects = tracker.update(detections_bytetrack_format, frame_id)\n", | ||
" if len(tracked_objects) > 0:\n", | ||
" tracked_objects = np.insert(tracked_objects, 0, frame_id, axis=1)\n", | ||
" all_tracked_objects.append(tracked_objects)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### Scaling the bounding boxes to match with original image size " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"df_tracked = pd.DataFrame(np.concatenate(all_tracked_objects), columns=[\"frame_id\", \"x1\", \"y1\", \"x2\", \"y2\", \"track_id\", \"class\", \"confidence\"])\n", | ||
"df_tracked[[\"x1\", \"y1\", \"x2\", \"y2\"]] = df_tracked[[\"x1\", \"y1\", \"x2\", \"y2\"]].apply(\n", | ||
" lambda x: scale_bbox_as_xyxy(x[0:4], detections.orig_shape), axis=1, result_type=\"expand\"\n", | ||
" )\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"#### 3. Visualization of Tracking" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"fourcc = cv2.VideoWriter_fourcc(*'H264')\n", | ||
"OUTPUT_WITH_BBOX = \"videos/traffic_tracked.mp4\"\n", | ||
"out = cv2.VideoWriter(OUTPUT_WITH_BBOX, fourcc, 12, (1280, 720))\n", | ||
"for frame_id, image_filename in enumerate(available_frames):\n", | ||
" image = cv2.imread(image_filename)\n", | ||
" if frame_id in df_tracked.frame_id.astype('int').values:\n", | ||
" df_current_frame = df_tracked[df_tracked.frame_id == frame_id][[\"x1\", \"y1\", \"x2\", \"y2\", \"track_id\", \"class\", \"confidence\"]].to_numpy()\n", | ||
" image = draw_all_bbox_on_image(image, df_current_frame)\n", | ||
" out.write(image)\n", | ||
"out.release()\n", | ||
"print(\"Video with bounding box is saved at:\", OUTPUT_WITH_BBOX)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"print(\"Number of detected objects: \", len(df_tracked.track_id.unique()))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"video_path = \"videos/traffic_tracked.mp4\"\n", | ||
"display(Video(video_path, embed=True, width=800))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "3.8", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.13" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
def draw_all_bbox_on_image(image, tracking_objects: np.ndarray): | ||
""" | ||
A list of of detections with track id, class id and confidence. | ||
[ | ||
[x, y, x, y, track_id, class_id, conf], | ||
[x, y, x, y, track_id, class_id, conf], | ||
... | ||
] | ||
Plot this on the image with the track id, class id and confidence. | ||
""" | ||
for detection in tracking_objects: | ||
x1, y1, x2, y2, track_id, _, conf = detection | ||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) | ||
cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2) | ||
cv2.putText( | ||
image, | ||
f"{int(track_id)} ({conf:.2f})", | ||
(x1, y1 - 10), | ||
0, | ||
1, | ||
(0, 255, 0), | ||
2, | ||
) | ||
return image | ||
|
||
|
||
def yolo_results_to_bytetrack_format(detections): | ||
"""Transforms YOLO detections into the bytetrack format. | ||
Args: | ||
detections: A list of YOLO detections. | ||
Returns: | ||
A list of bytetrack detections. | ||
""" | ||
boxes = detections.numpy().boxes.xyxyn | ||
scores = detections.numpy().boxes.conf | ||
classes = detections.numpy().boxes.cls | ||
return np.stack( | ||
[ | ||
boxes[:, 0], | ||
boxes[:, 1], | ||
boxes[:, 2], | ||
boxes[:, 3], | ||
scores, | ||
classes, | ||
], | ||
axis=1, | ||
) | ||
|
||
|
||
def scale_bbox_as_xyxy(bbox: np.ndarray, target_img_size: tuple): | ||
"""Scales a bounding box to a target image size. | ||
Args: | ||
bbox: A bounding box in the format [x, y, x, y]. | ||
target_img_size: The target image size as a tuple (h, W). | ||
Returns: | ||
The scaled bounding box. | ||
""" | ||
x1, y1, x2, y2 = bbox | ||
h, w = target_img_size | ||
scaled_bbox = np.array([x1 * w, y1 * h, x2 * w, y2 * h]) | ||
return scaled_bbox |