fix: adding notebook example of bytetrack with car detection and tracking (#22)
nmathieufact authored Mar 14, 2024
1 parent 7ec1476 commit 0f2f63b
Showing 4 changed files with 378 additions and 1 deletion.
11 changes: 10 additions & 1 deletion .gitignore
@@ -129,4 +129,13 @@ dmypy.json
.pyre/

# Poetry
poetry.lock
poetry.lock

# model
*.pt

# images
*.png

# videos
*.mp4
4 changes: 4 additions & 0 deletions examples/requirements.txt
@@ -0,0 +1,4 @@
git+https://github.com/artefactory-fr/bytetrack.git@main
opencv-python==4.8.1.78
ultralytics==8.0.216
matplotlib==3.8.2
294 changes: 294 additions & 0 deletions examples/test_bytetrack_car.ipynb
@@ -0,0 +1,294 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Executive Summary: Car Detection with ByteTrack - An Introductory Guide\n",
"\n",
"This guide is designed to provide a beginner-friendly introduction to the application of ByteTrack for car detection in video footage. ByteTrack is an advanced algorithm that leverages the capabilities of the YOLO (You Only Look Once) model for object detection, specifically focusing on tracking objects across video frames.\n",
"\n",
"For more information on YOLO and ultralytics, visit [this link](https://github.com/ultralytics/ultralytics).\n",
"\n",
"For more information on ByteTrack, visit [this link](https://github.com/ifzhang/ByteTrack).\n",
"\n",
"1. **Frame Extraction**: \n",
" This video is decomposed into frames, transforming continuous video into discrete snapshots for analysis.\n",
"\n",
"2. **Detection and tracking**: \n",
" We initialize the ByteTracker object and load the pre-trained Yolo model, indicating its parameters. Going through all the frames of the video, the YOLO model enables object detection. Tracking is handled by the ByteTrack algorithm, using the bounding boxes and assigning each of it an ID that enables to track its movement.\n",
"\n",
"3. **Visualization of Tracking**: \n",
" Recomposing the video from the frames with object detected, writing it in a MP4 format to same folder.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"import cv2\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# YOLO and video packages \n",
"from ultralytics import YOLO\n",
"from bytetracker import BYTETracker\n",
"from bytetracker.basetrack import BaseTrack\n",
"from utils import draw_all_bbox_on_image, yolo_results_to_bytetrack_format, scale_bbox_as_xyxy\n",
"from IPython.display import Video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Download the video\n",
"VIDEO_PATH = 'videos/traffic.mp4'\n",
"!if [ ! -f $VIDEO_PATH ]; then mkdir -p videos && wget https://storage.googleapis.com/bytetrack-data-public/traffic.mp4 -O $VIDEO_PATH; fi"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Reading video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Video(VIDEO_PATH, width=800,embed=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 1. Frame Extraction "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# You can run this only once:\n",
"# Transform this VIDEO_PATH into a list of frames in this folder under frames/\n",
"!mkdir -p frames && ffmpeg -i $VIDEO_PATH -vf fps=12 frames/%d.png -hide_banner -loglevel panic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# - list and sort PNG frames in the 'frames' directory, ensuring they are ordered numerically for subsequent processing.\n",
"# - usinglob to find all PNG files and sorts them based on the numeric part of their filenames, avoiding lexicographic order issues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"available_frames = glob.glob(\"frames/*.png\")\n",
"available_frames = sorted(available_frames, key=lambda x: int(x.split(\"/\")[-1].split(\".\")[0]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"\n",
"MODEL_WEIGHTS = \"yolov8m.pt\"\n",
"\n",
"model = YOLO(MODEL_WEIGHTS)\n",
"results = model(available_frames[0])[0]\n",
"\n",
"plt.imshow(cv2.cvtColor(results.plot(), cv2.COLOR_BGR2RGB))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Classes for prediction, indicating which object to detect\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### We will track only car \n",
"CAR_CLASS_ID = 2"
]
},
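{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A quick sanity check (a sketch): the pretrained COCO model exposes its class-id-to-name mapping via `model.names`, where id 2 is `car`."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Should print 'car' for the pretrained COCO model\n",
  "print(model.names[CAR_CLASS_ID])"
 ]
},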
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
" #### BYTETracker Parameters\n",
" - `track_thresh`: Threshold for considering a detection as a potential object to track.\n",
" - `track_buffer`: Number of frames to keep tracking information for an object before discarding it.\n",
" - `match_thresh`: Threshold for matching detections between consecutive frames.\n",
" - `frame_rate`: Frame rate of the video or sequence being processed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tracker = BYTETracker(track_thresh= 0.15, track_buffer = 3, match_thresh = 0.85, frame_rate= 12)\n",
"BaseTrack._count = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = YOLO(MODEL_WEIGHTS, task=\"detect\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 2. Detection and tracking"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"all_tracked_objects = []\n",
"for frame_id, image_filename in enumerate(available_frames):\n",
" img = cv2.imread(image_filename)\n",
" detections = model.predict(img, classes=[CAR_CLASS_ID], conf=0.15, verbose=False)[0]\n",
" detections_bytetrack_format = yolo_results_to_bytetrack_format(detections)\n",
" tracked_objects = tracker.update(detections_bytetrack_format, frame_id)\n",
" if len(tracked_objects) > 0:\n",
" tracked_objects = np.insert(tracked_objects, 0, frame_id, axis=1)\n",
" all_tracked_objects.append(tracked_objects)"
]
},
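{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A quick look at the raw tracker output (a sketch based on the column layout used below): each row returned by `tracker.update` is `[x1, y1, x2, y2, track_id, class, confidence]`, and `frame_id` was prepended above, giving 8 columns per row."
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Columns: frame_id, x1, y1, x2, y2, track_id, class, confidence\n",
  "tracked_array = np.concatenate(all_tracked_objects)\n",
  "print(\"Tracked rows:\", tracked_array.shape)\n",
  "tracked_array[:3]"
 ]
},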
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Scaling the bounding boxes to match with original image size "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df_tracked = pd.DataFrame(np.concatenate(all_tracked_objects), columns=[\"frame_id\", \"x1\", \"y1\", \"x2\", \"y2\", \"track_id\", \"class\", \"confidence\"])\n",
"df_tracked[[\"x1\", \"y1\", \"x2\", \"y2\"]] = df_tracked[[\"x1\", \"y1\", \"x2\", \"y2\"]].apply(\n",
" lambda x: scale_bbox_as_xyxy(x[0:4], detections.orig_shape), axis=1, result_type=\"expand\"\n",
" )\n"
]
},
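{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Inspect a few rows of the tracking results, now scaled to pixel coordinates\n",
  "df_tracked.head()"
 ]
},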
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3. Visualization of Tracking"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fourcc = cv2.VideoWriter_fourcc(*'H264')\n",
"OUTPUT_WITH_BBOX = \"videos/traffic_tracked.mp4\"\n",
"out = cv2.VideoWriter(OUTPUT_WITH_BBOX, fourcc, 12, (1280, 720))\n",
"for frame_id, image_filename in enumerate(available_frames):\n",
" image = cv2.imread(image_filename)\n",
" if frame_id in df_tracked.frame_id.astype('int').values:\n",
" df_current_frame = df_tracked[df_tracked.frame_id == frame_id][[\"x1\", \"y1\", \"x2\", \"y2\", \"track_id\", \"class\", \"confidence\"]].to_numpy()\n",
" image = draw_all_bbox_on_image(image, df_current_frame)\n",
" out.write(image)\n",
"out.release()\n",
"print(\"Video with bounding box is saved at:\", OUTPUT_WITH_BBOX)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Number of detected objects: \", len(df_tracked.track_id.unique()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"video_path = \"videos/traffic_tracked.mp4\"\n",
"display(Video(video_path, embed=True, width=800))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "3.8",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
70 changes: 70 additions & 0 deletions examples/utils.py
@@ -0,0 +1,70 @@
import cv2
import numpy as np


def draw_all_bbox_on_image(image, tracking_objects: np.ndarray):
"""
A list of of detections with track id, class id and confidence.
[
[x, y, x, y, track_id, class_id, conf],
[x, y, x, y, track_id, class_id, conf],
...
]
Plot this on the image with the track id, class id and confidence.
"""
for detection in tracking_objects:
x1, y1, x2, y2, track_id, _, conf = detection
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
cv2.putText(
image,
f"{int(track_id)} ({conf:.2f})",
(x1, y1 - 10),
0,
1,
(0, 255, 0),
2,
)
return image


def yolo_results_to_bytetrack_format(detections):
"""Transforms YOLO detections into the bytetrack format.
Args:
detections: A list of YOLO detections.
Returns:
A list of bytetrack detections.
"""
boxes = detections.numpy().boxes.xyxyn
scores = detections.numpy().boxes.conf
classes = detections.numpy().boxes.cls
return np.stack(
[
boxes[:, 0],
boxes[:, 1],
boxes[:, 2],
boxes[:, 3],
scores,
classes,
],
axis=1,
)


def scale_bbox_as_xyxy(bbox: np.ndarray, target_img_size: tuple):
"""Scales a bounding box to a target image size.
Args:
        bbox: A bounding box in normalized coordinates [x1, y1, x2, y2].
        target_img_size: The target image size as a tuple (h, w).
Returns:
The scaled bounding box.
"""
x1, y1, x2, y2 = bbox
h, w = target_img_size
scaled_bbox = np.array([x1 * w, y1 * h, x2 * w, y2 * h])
return scaled_bbox
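

# Minimal smoke test for these helpers (an illustrative sketch, not part of
# the tracking pipeline): the normalized boxes and the track/class/confidence
# values below are made up purely for demonstration.
if __name__ == "__main__":
    frame = np.zeros((720, 1280, 3), dtype=np.uint8)  # blank 720p frame
    boxes_xyxyn = np.array([[0.10, 0.10, 0.30, 0.40], [0.50, 0.20, 0.70, 0.60]])
    # Scale each normalized box to a (h, w) = (720, 1280) image.
    boxes_xyxy = np.array([scale_bbox_as_xyxy(b, (720, 1280)) for b in boxes_xyxyn])
    # Append made-up track_id, class_id and confidence columns.
    tracked = np.hstack([boxes_xyxy, np.array([[1, 2, 0.90], [2, 2, 0.80]])])
    annotated = draw_all_bbox_on_image(frame, tracked)
    cv2.imwrite("smoke_test.png", annotated)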
