Skip to content

Commit

Permalink
Tp/add reid lib (#4)
Browse files Browse the repository at this point in the history
* add matcher

* add reid processor

* add tracked object filter

* add metadata object

* add object filterd modif

* add tracked object class

* add utils

* add args & constants

* fix drop duplicate

* add notebook

* modify pyproject
  • Loading branch information
tristanpepinartefact authored Nov 7, 2023
1 parent 1b9d3f1 commit 308be7b
Show file tree
Hide file tree
Showing 10 changed files with 807 additions and 1 deletion.
157 changes: 157 additions & 0 deletions notebooks/starter_kit_reid.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload \n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys \n",
"sys.path.append(\"..\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from track_reid.reid_processor import ReidProcessor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def bounding_box_distance(obj1, obj2):\n",
" # Get the bounding boxes from the Metadata of each TrackedObject\n",
" bbox1 = obj1.metadata.bbox\n",
" bbox2 = obj2.metadata.bbox\n",
"\n",
" # Calculate the Euclidean distance between the centers of the bounding boxes\n",
" center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)\n",
" center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)\n",
" distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)\n",
"\n",
" return distance\n",
"\n",
"def select_by_category(obj1, obj2):\n",
" # Compare the categories of the two objects\n",
" return 1 if obj1.category == obj2.category else 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example usage:\n",
"data = np.array([\n",
" [1, 1, \"car\", 100, 200, 300, 400, 0.9],\n",
" [1, 2, \"person\", 50, 150, 200, 400, 0.8],\n",
" [2, 1, \"truck\", 120, 220, 320, 420, 0.95],\n",
" [2, 2, \"person\", 60, 160, 220, 420, 0.85],\n",
" [3, 1, \"car\", 110, 210, 310, 410, 0.91],\n",
" [3, 3, \"person\", 61, 170, 220, 420, 0.91],\n",
" [3, 4, \"car\", 60, 160, 220, 420, 0.91],\n",
" [3, 6, \"person\", 60, 160, 220, 420, 0.91],\n",
" [4, 1, \"truck\", 130, 230, 330, 430, 0.92],\n",
" [4, 2, \"person\", 65, 165, 225, 425, 0.83],\n",
" [5, 1, \"car\", 115, 215, 315, 415, 0.93],\n",
" [5, 2, \"person\", 57, 157, 207, 407, 0.84],\n",
" [5, 4, \"car\", 60, 160, 220, 420, 0.91],\n",
" [5, 8, \"person\", 60, 160, 220, 420, 0.91],\n",
"])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"processor = ReidProcessor(filter_confidence_threshold=0.4, \n",
" filter_time_threshold=0,\n",
" cost_function=bounding_box_distance,\n",
" selection_function=select_by_category,\n",
" max_attempt_to_rematch=0,\n",
" max_frames_to_rematch=100)\n",
"\n",
"\n",
"columns = ['frame_id', 'object_id', 'category', 'x1', 'y1', 'x2', 'y2', 'confidence']\n",
"df = pd.DataFrame(data, columns=columns)\n",
"# Convert numerical columns to the appropriate data type\n",
"df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']] = df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']].astype(int)\n",
"df['confidence'] = df['confidence'].astype(float)\n",
"\n",
"\n",
"for frame_id, frame_data in df.groupby(\"frame_id\"):\n",
"\n",
" bytetrack_output = frame_data.values\n",
" if bytetrack_output.ndim == 1 : \n",
" bytetrack_output = np.expand_dims(bytetrack_output, 0)\n",
"\n",
" results = processor.update(bytetrack_output)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(processor.all_tracked_objects[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(processor.all_tracked_objects[0].metadata)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "track-reid",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"vscode": {
"interpreter": {
"hash": "a7fd834062a85a1fb9d4482d7456bec56e0ff99e4dd054f5e10ff6e3cdc923c6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ select = [
"I",
"N",
"Q",
"RET",
"ARG",
"PTH",
"PD",
Expand Down
1 change: 1 addition & 0 deletions track_reid/args/reid_args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Category labels known to the re-identification package.
# NOTE(review): presumably the allowed values of an object's category — confirm against callers.
POSSIBLE_CLASSES = ["car", "person", "truck", "animal"]
22 changes: 22 additions & 0 deletions track_reid/constants/reid_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import ClassVar

from pydantic import BaseModel


class ReidConstants(BaseModel):
    """Integer state codes used by the re-identification process.

    Every code is declared as a model field with a default value, and
    ``DESCRIPTION`` maps each code to a short human-readable label.
    """

    # NOTE(review): "BYETRACK" looks like a misspelling of "BYTETRACK" (see the
    # labels below); the name is left as is because other modules may use it.
    BYETRACK_OUTPUT: int = -2
    FILTERED_OUTPUT: int = -1
    STABLE: int = 0
    SWITCHER: int = 1
    CANDIDATE: int = 2

    # Annotated as ClassVar so it is a plain class attribute, not a model field.
    # The keys are the integer defaults declared above (-2, -1, 0, 1, 2).
    DESCRIPTION: ClassVar[dict] = {
        BYETRACK_OUTPUT: "bytetrack output not in reid process",
        FILTERED_OUTPUT: "bytetrack output entering reid process",
        STABLE: "stable object",
        SWITCHER: "lost object to be re-matched",
        CANDIDATE: "new object to be matched",
    }


# Module-level instance built from the default values above.
reid_constants = ReidConstants()
91 changes: 91 additions & 0 deletions track_reid/matcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Callable, Dict, List

import numpy as np
from scipy.optimize import linear_sum_assignment

from track_reid.tracked_object import TrackedObject


class Matcher:
    """Pairs objects from two lists by solving a linear assignment problem.

    The cost of pairing two objects is given by ``cost_function`` and the
    pairs that are allowed at all are given by ``selection_function``
    (a result equal to 0 means the pair must never be matched).

    Both callables are invoked as ``function(object1, object2)`` with
    ``object1`` taken from the first list and ``object2`` from the second.
    """

    # Lower bound of the cost assigned to pairs rejected by the selection function.
    _MIN_FORBIDDEN_COST = 1e3

    def __init__(self, cost_function: Callable, selection_function: Callable) -> None:
        """Stores the two callables used to build the cost and selection matrices.

        Args:
            cost_function (Callable): returns the cost of matching two objects.
            selection_function (Callable): returns 0 when two objects must not
                be matched, any other value otherwise.
        """
        self.cost_function = cost_function
        self.selection_function = selection_function

    @staticmethod
    def _pairwise_matrix(function: Callable, objects1: List, objects2: List) -> np.ndarray:
        """Evaluates ``function(object1, object2)`` for every pair of objects.

        The result has one row per element of ``objects2`` and one column per
        element of ``objects1``. The function is called exactly once per pair
        and the dtype is inferred from all the results, so a first integer
        result does not truncate later float results.
        """
        if not objects1 or not objects2:
            return np.array([])  # Return an empty array if either list is empty

        return np.array([[function(obj1, obj2) for obj1 in objects1] for obj2 in objects2])

    def compute_cost_matrix(
        self, objects1: List[TrackedObject], objects2: List[TrackedObject]
    ) -> np.ndarray:
        """Computes the cost matrix between a list of M TrackedObjects objects1,
        and a list of N TrackedObjects objects2.

        Args:
            objects1 (List[TrackedObject]): list of objects to be matched.
            objects2 (List[TrackedObject]): list of candidates for matches.

        Returns:
            np.ndarray: matrix of size [N, M] whose element [i, j] is
                ``cost_function(objects1[j], objects2[i])``, or an empty
                1-D array if either list is empty.
        """
        return self._pairwise_matrix(self.cost_function, objects1, objects2)

    def compute_selection_matrix(
        self, objects1: List[TrackedObject], objects2: List[TrackedObject]
    ) -> np.ndarray:
        """Computes the selection matrix between a list of M TrackedObjects objects1,
        and a list of N TrackedObjects objects2.

        Args:
            objects1 (List[TrackedObject]): list of objects to be matched.
            objects2 (List[TrackedObject]): list of candidates for matches.

        Returns:
            np.ndarray: matrix of size [N, M] whose element [i, j] is
                ``selection_function(objects1[j], objects2[i])`` (0 means the
                pair must not be matched), or an empty 1-D array if either
                list is empty.
        """
        return self._pairwise_matrix(self.selection_function, objects1, objects2)

    def match(
        self, objects1: List[TrackedObject], objects2: List[TrackedObject]
    ) -> List[Dict[TrackedObject, TrackedObject]]:
        """Computes the matching between objects in list objects1 and objects in objects2.

        Pairs rejected by the selection function are never returned, even when
        the assignment solver had no better option for them.

        Args:
            objects1 (List[TrackedObject]): list of objects to be matched.
            objects2 (List[TrackedObject]): list of candidates for matches.

        Returns:
            List[Dict[TrackedObject, TrackedObject]]: one single-entry dict
                ``{object1: object2}`` per matched pair; empty if there is
                no allowed match.
        """
        if not objects1 or not objects2:
            return []  # Return an empty list if either list is empty

        # Work on a float copy so the penalty below can always be stored exactly.
        cost_matrix = np.array(self.compute_cost_matrix(objects1, objects2), dtype=float)
        allowed = np.asarray(self.compute_selection_matrix(objects1, objects2)) != 0

        if not allowed.any():
            return []

        # The penalty must exceed twice the sum of all allowed costs: this way
        # replacing a rejected pair by allowed ones always lowers the total cost,
        # whatever the scale of the cost function.
        allowed_cost_total = float(np.abs(cost_matrix[allowed]).sum())
        forbidden_cost = max(self._MIN_FORBIDDEN_COST, 2.0 * allowed_cost_total + 1.0)
        cost_matrix[~allowed] = forbidden_cost

        # Find the best matches using the linear sum assignment
        row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=False)

        # Rows index objects2 and columns index objects1. The solver assigns as
        # many pairs as possible, so rejected pairs have to be dropped explicitly.
        return [
            {objects1[col]: objects2[row]}
            for row, col in zip(row_indices, col_indices)
            if allowed[row, col]
        ]
Loading

0 comments on commit 308be7b

Please sign in to comment.