fea: add embedding norfair (#6)
Add embedding in norfair notebook
TomDarmon authored Nov 9, 2023
1 parent 5beb916 commit 2780905
Showing 3 changed files with 72 additions and 16 deletions.
29 changes: 27 additions & 2 deletions lib/norfair_helper/utils.py
@@ -1,13 +1,15 @@
from typing import List

import cv2
import numpy as np
from norfair import Detection
from norfair import Detection, get_cutout

from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy


def yolo_to_norfair_detection(
    yolo_detections: np.array, original_img_size: tuple
    yolo_detections: np.array,
    original_img_size: tuple,
) -> List[Detection]:
    """Convert YOLO detections (xywh format) to a list of Norfair Detection objects."""
    norfair_detections: List[Detection] = []
@@ -23,3 +25,26 @@ def yolo_to_norfair_detection(
        scores = np.array([detection_output[5].item(), detection_output[5].item()])
        norfair_detections.append(Detection(points=bbox, scores=scores, label=detection_output[0]))
    return norfair_detections


def compute_embeddings(norfair_detections: List[Detection], image: np.array):
    """
    Add an embedding attribute to every Detection object in norfair_detections.
    """
    for detection in norfair_detections:
        cutout = get_cutout(detection.points, image)
        # Only compute an embedding for non-empty cutouts (e.g. skip degenerate or out-of-frame boxes)
        if cutout.shape[0] > 0 and cutout.shape[1] > 0:
            detection.embedding = get_hist(cutout)
    return norfair_detections


def get_hist(image: np.array):
    """Compute an appearance embedding as a normalized 2D color histogram in Lab space."""
    hist = cv2.calcHist(
        [cv2.cvtColor(image, cv2.COLOR_BGR2Lab)],
        [0, 1],  # first two Lab channels
        None,  # no mask
        [128, 128],  # 128 bins per channel
        [0, 256, 0, 256],
    )
    return cv2.normalize(hist, hist).flatten()
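
For context, a minimal sketch of how two get_hist embeddings can be compared; it uses cv2.compareHist with the correlation method, the same distance the notebook's embedding_distance relies on. The frame path and crop coordinates are hypothetical placeholders, not part of this commit.

import cv2

from lib.norfair_helper.utils import get_hist

frame = cv2.imread("data/frame_0001.jpg")  # hypothetical frame path
cutout_a = frame[100:300, 200:400]  # hypothetical crops standing in for get_cutout results
cutout_b = frame[110:310, 210:410]

emb_a = get_hist(cutout_a)
emb_b = get_hist(cutout_b)

# Correlation is 1.0 for identical histograms, so 1 - correlation behaves like a distance.
distance = 1 - cv2.compareHist(emb_a, emb_b, cv2.HISTCMP_CORREL)
print(distance)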
16 changes: 11 additions & 5 deletions lib/norfair_helper/video.py
@@ -2,12 +2,16 @@
import numpy as np
from norfair import Tracker, draw_boxes

from lib.norfair_helper.utils import yolo_to_norfair_detection
from lib.norfair_helper.utils import compute_embeddings, yolo_to_norfair_detection
from lib.sequence import Sequence


def generate_tracking_video(
    sequence: Sequence, tracker: Tracker, frame_size: tuple, output_path: str
    sequence: Sequence,
    tracker: Tracker,
    frame_size: tuple,
    output_path: str,
    add_embedding: bool = False,
) -> str:
    """
    Generate a video with the tracking results.
@@ -17,6 +21,7 @@ def generate_tracking_video(
        tracker: The tracker to use.
        frame_size: The size of the frames.
        output_path: The path to save the video to.
        add_embedding: Whether to compute an appearance embedding for each detection (used for re-identification).
    Returns:
        The path to the video.
@@ -26,11 +31,12 @@
    out = cv2.VideoWriter(output_path, fourcc, 20.0, frame_size)  # Changed file extension to .mp4

    for frame, detection in sequence:
        frame = np.array(frame)
        detections_list = yolo_to_norfair_detection(detection, frame_size)
        if add_embedding:
            detections_list = compute_embeddings(detections_list, frame)
        tracked_objects = tracker.update(detections=detections_list)
        frame_detected = draw_boxes(
            np.array(frame), tracked_objects, draw_ids=True, color="by_label"
        )
        frame_detected = draw_boxes(frame, tracked_objects, draw_ids=True, color="by_label")
        frame_detected = cv2.cvtColor(frame_detected, cv2.COLOR_BGR2RGB)
        out.write(frame_detected)
    out.release()
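
As a usage sketch of the updated signature (mirroring the notebook below; the frame locations, output path, and tracker settings here are assumptions):

import glob
import os

from norfair import Tracker

from lib.norfair_helper.video import generate_tracking_video
from lib.sequence import Sequence

frame_paths = sorted(glob.glob("data/sequence_01/*.jpg"))  # hypothetical frame locations
sequence = Sequence(frame_paths)

tracker = Tracker(distance_function="sqeuclidean", distance_threshold=350)

video_path = generate_tracking_video(
    sequence=sequence,
    tracker=tracker,
    frame_size=(2560, 1440),
    output_path=os.path.join("output", "tracking.mp4"),
    add_embedding=True,  # attach histogram embeddings so a reid distance function can use them
)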
43 changes: 34 additions & 9 deletions notebooks/norfair_starter_kit.ipynb
@@ -33,7 +33,8 @@
"import sys; sys.path.append('..')\n",
"import os\n",
"\n",
"from norfair import Tracker\n",
"import cv2\n",
"from norfair import Tracker, OptimizedKalmanFilterFactory\n",
"\n",
"from lib.sequence import Sequence\n",
"from lib.norfair_helper.video import generate_tracking_video\n"
@@ -96,7 +97,7 @@
" detections.sort()\n",
" return detections\n",
"\n",
"frame_path = get_sequence_frames(SEQUENCES[1])\n",
"frame_path = get_sequence_frames(SEQUENCES[3])\n",
"test_sequence = Sequence(frame_path)\n",
"test_sequence"
]
@@ -158,6 +159,7 @@
" tracker=basic_tracker,\n",
" frame_size=(2560, 1440),\n",
" output_path=os.path.join(VIDEO_OUTPUT_PATH, \"basic_tracking.mp4\"),\n",
" add_embedding=False,\n",
")\n",
"video_path"
]
@@ -175,8 +177,29 @@
"metadata": {},
"outputs": [],
"source": [
"def reid_distance_advanced(new_object, unmatched_object):\n",
" return 0 # ALWAYS MATCH"
"def always_match(new_object, unmatched_object):\n",
" return 0 # ALWAYS MATCH\n",
"\n",
"\n",
"def embedding_distance(matched_not_init_trackers, unmatched_trackers):\n",
" snd_embedding = unmatched_trackers.last_detection.embedding\n",
"\n",
" # Find last non-empty embedding if current is None\n",
" if snd_embedding is None:\n",
" snd_embedding = next((detection.embedding for detection in reversed(unmatched_trackers.past_detections) if detection.embedding is not None), None)\n",
"\n",
" if snd_embedding is None:\n",
" return 1 # No match if no embedding is found\n",
"\n",
" # Iterate over past detections and calculate distance\n",
" for detection_fst in matched_not_init_trackers.past_detections:\n",
" if detection_fst.embedding is not None:\n",
" distance = 1 - cv2.compareHist(snd_embedding, detection_fst.embedding, cv2.HISTCMP_CORREL)\n",
" # If similar a tiny bit similar, we return the distance to the tracker\n",
" if distance < 0.9:\n",
" return distance\n",
"\n",
" return 1 # No match if no matching embedding is found between the 2"
]
},
{
@@ -187,12 +210,13 @@
"source": [
"advanced_tracker = Tracker(\n",
" distance_function=\"sqeuclidean\",\n",
" filter_factory = OptimizedKalmanFilterFactory(R=5, Q=0.05),\n",
" distance_threshold=350, # Higher value means objects further away will be matched\n",
" initialization_delay=10, # Wait 15 frames before an object is starts to be tracked\n",
" hit_counter_max=20, # Inertia, higher values means an object will take time to enter in reid phase\n",
" reid_distance_function=reid_distance_advanced, # function to decide on which metric to reid\n",
" reid_distance_threshold=0.5, # If the distance is below 0.5 the object is matched\n",
" reid_hit_counter_max=200, # inertia, higher values means an object will enter reid phase longer\n",
" initialization_delay=12, # Wait 15 frames before an object is starts to be tracked\n",
" hit_counter_max=15, # Inertia, higher values means an object will take time to enter in reid phase\n",
" reid_distance_function=embedding_distance, # function to decide on which metric to reid\n",
" reid_distance_threshold=0.9, # If the distance is below the object is matched\n",
" reid_hit_counter_max=200, #higher values means an object will stay reid phase longer\n",
" )"
]
},
@@ -207,6 +231,7 @@
" tracker=advanced_tracker,\n",
" frame_size=(2560, 1440),\n",
" output_path=os.path.join(VIDEO_OUTPUT_PATH, \"advance_tracking.mp4\"),\n",
" add_embedding=True,\n",
")\n",
"video_path"
]