From 5beb916cc3d768b7e2662155572170c3226f0041 Mon Sep 17 00:00:00 2001 From: TomDarmon <36815861+TomDarmon@users.noreply.github.com> Date: Thu, 9 Nov 2023 15:37:38 +0100 Subject: [PATCH 01/13] fix: remove useless dependencies (#11) Co-authored-by: TomDarmon --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1557e0a..255f1fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,8 +12,6 @@ pandas = "1.5.3" numpy = "1.24.2" llist = "0.7.1" pydantic = "2.4.2" -torch = ">=1.13.1" -Cython = ">=0.29.23" bytetracker = { git = "https://github.com/TomDarmon/bytetrack-pip.git", branch = "main" } [tool.poetry.group.dev.dependencies] From 27809051beb96e5eea4f4c1733e71bc4c0e0cccc Mon Sep 17 00:00:00 2001 From: TomDarmon <36815861+TomDarmon@users.noreply.github.com> Date: Thu, 9 Nov 2023 18:26:20 +0100 Subject: [PATCH 02/13] fea: add embedding norfair (#6) Add embedding in norfair notebook --- lib/norfair_helper/utils.py | 29 +++++++++++++++++-- lib/norfair_helper/video.py | 16 +++++++---- notebooks/norfair_starter_kit.ipynb | 43 +++++++++++++++++++++++------ 3 files changed, 72 insertions(+), 16 deletions(-) diff --git a/lib/norfair_helper/utils.py b/lib/norfair_helper/utils.py index 11bf4af..aff22c8 100644 --- a/lib/norfair_helper/utils.py +++ b/lib/norfair_helper/utils.py @@ -1,13 +1,15 @@ from typing import List +import cv2 import numpy as np -from norfair import Detection +from norfair import Detection, get_cutout from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy def yolo_to_norfair_detection( - yolo_detections: np.array, original_img_size: tuple + yolo_detections: np.array, + original_img_size: tuple, ) -> List[Detection]: """convert detections_as_xywh to norfair detections""" norfair_detections: List[Detection] = [] @@ -23,3 +25,26 @@ def yolo_to_norfair_detection( scores = np.array([detection_output[5].item(), detection_output[5].item()]) norfair_detections.append(Detection(points=bbox, scores=scores, label=detection_output[0])) return norfair_detections + + +def compute_embeddings(norfair_detections: List[Detection], image: np.array): + """ + Add embedding attribute to all Detection objects in norfair_detections. + """ + for detection in norfair_detections: + object = get_cutout(detection.points, image) + if object.shape[0] > 0 and object.shape[1] > 0: + detection.embedding = get_hist(object) + return norfair_detections + + +def get_hist(image: np.array): + """Compute an embedding with histograms""" + hist = cv2.calcHist( + [cv2.cvtColor(image, cv2.COLOR_BGR2Lab)], + [0, 1], + None, + [128, 128], + [0, 256, 0, 256], + ) + return cv2.normalize(hist, hist).flatten() diff --git a/lib/norfair_helper/video.py b/lib/norfair_helper/video.py index 3ea8530..d84e09d 100644 --- a/lib/norfair_helper/video.py +++ b/lib/norfair_helper/video.py @@ -2,12 +2,16 @@ import numpy as np from norfair import Tracker, draw_boxes -from lib.norfair_helper.utils import yolo_to_norfair_detection +from lib.norfair_helper.utils import compute_embeddings, yolo_to_norfair_detection from lib.sequence import Sequence def generate_tracking_video( - sequence: Sequence, tracker: Tracker, frame_size: tuple, output_path: str + sequence: Sequence, + tracker: Tracker, + frame_size: tuple, + output_path: str, + add_embedding: bool = False, ) -> str: """ Generate a video with the tracking results. @@ -17,6 +21,7 @@ def generate_tracking_video( tracker: The tracker to use. frame_size: The size of the frames. output_path: The path to save the video to. + add_embedding: Whether to add the embedding to the video. Returns: The path to the video. @@ -26,11 +31,12 @@ def generate_tracking_video( out = cv2.VideoWriter(output_path, fourcc, 20.0, frame_size) # Changed file extension to .mp4 for frame, detection in sequence: + frame = np.array(frame) detections_list = yolo_to_norfair_detection(detection, frame_size) + if add_embedding: + detections_list = compute_embeddings(detections_list, frame) tracked_objects = tracker.update(detections=detections_list) - frame_detected = draw_boxes( - np.array(frame), tracked_objects, draw_ids=True, color="by_label" - ) + frame_detected = draw_boxes(frame, tracked_objects, draw_ids=True, color="by_label") frame_detected = cv2.cvtColor(frame_detected, cv2.COLOR_BGR2RGB) out.write(frame_detected) out.release() diff --git a/notebooks/norfair_starter_kit.ipynb b/notebooks/norfair_starter_kit.ipynb index ff9ff59..c911c52 100644 --- a/notebooks/norfair_starter_kit.ipynb +++ b/notebooks/norfair_starter_kit.ipynb @@ -33,7 +33,8 @@ "import sys; sys.path.append('..')\n", "import os\n", "\n", - "from norfair import Tracker\n", + "import cv2\n", + "from norfair import Tracker, OptimizedKalmanFilterFactory\n", "\n", "from lib.sequence import Sequence\n", "from lib.norfair_helper.video import generate_tracking_video\n" @@ -96,7 +97,7 @@ " detections.sort()\n", " return detections\n", "\n", - "frame_path = get_sequence_frames(SEQUENCES[1])\n", + "frame_path = get_sequence_frames(SEQUENCES[3])\n", "test_sequence = Sequence(frame_path)\n", "test_sequence" ] @@ -158,6 +159,7 @@ " tracker=basic_tracker,\n", " frame_size=(2560, 1440),\n", " output_path=os.path.join(VIDEO_OUTPUT_PATH, \"basic_tracking.mp4\"),\n", + " add_embedding=False,\n", ")\n", "video_path" ] @@ -175,8 +177,29 @@ "metadata": {}, "outputs": [], "source": [ - "def reid_distance_advanced(new_object, unmatched_object):\n", - " return 0 # ALWAYS MATCH" + "def always_match(new_object, unmatched_object):\n", + " return 0 # ALWAYS MATCH\n", + "\n", + "\n", + "def embedding_distance(matched_not_init_trackers, unmatched_trackers):\n", + " snd_embedding = unmatched_trackers.last_detection.embedding\n", + "\n", + " # Find last non-empty embedding if current is None\n", + " if snd_embedding is None:\n", + " snd_embedding = next((detection.embedding for detection in reversed(unmatched_trackers.past_detections) if detection.embedding is not None), None)\n", + "\n", + " if snd_embedding is None:\n", + " return 1 # No match if no embedding is found\n", + "\n", + " # Iterate over past detections and calculate distance\n", + " for detection_fst in matched_not_init_trackers.past_detections:\n", + " if detection_fst.embedding is not None:\n", + " distance = 1 - cv2.compareHist(snd_embedding, detection_fst.embedding, cv2.HISTCMP_CORREL)\n", + " # If similar a tiny bit similar, we return the distance to the tracker\n", + " if distance < 0.9:\n", + " return distance\n", + "\n", + " return 1 # No match if no matching embedding is found between the 2" ] }, { @@ -187,12 +210,13 @@ "source": [ "advanced_tracker = Tracker(\n", " distance_function=\"sqeuclidean\",\n", + " filter_factory = OptimizedKalmanFilterFactory(R=5, Q=0.05),\n", " distance_threshold=350, # Higher value means objects further away will be matched\n", - " initialization_delay=10, # Wait 15 frames before an object is starts to be tracked\n", - " hit_counter_max=20, # Inertia, higher values means an object will take time to enter in reid phase\n", - " reid_distance_function=reid_distance_advanced, # function to decide on which metric to reid\n", - " reid_distance_threshold=0.5, # If the distance is below 0.5 the object is matched\n", - " reid_hit_counter_max=200, # inertia, higher values means an object will enter reid phase longer\n", + " initialization_delay=12, # Wait 15 frames before an object is starts to be tracked\n", + " hit_counter_max=15, # Inertia, higher values means an object will take time to enter in reid phase\n", + " reid_distance_function=embedding_distance, # function to decide on which metric to reid\n", + " reid_distance_threshold=0.9, # If the distance is below the object is matched\n", + " reid_hit_counter_max=200, #higher values means an object will stay reid phase longer\n", " )" ] }, @@ -207,6 +231,7 @@ " tracker=advanced_tracker,\n", " frame_size=(2560, 1440),\n", " output_path=os.path.join(VIDEO_OUTPUT_PATH, \"advance_tracking.mp4\"),\n", + " add_embedding=True,\n", ")\n", "video_path" ] From 457e9795121e590861f791454c3da9a7ebf2ae19 Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Fri, 10 Nov 2023 16:29:43 +0100 Subject: [PATCH 03/13] Tp/update reid to real life date (#13) Co-authored-by: TomDarmon <36815861+TomDarmon@users.noreply.github.com> Co-authored-by: TomDarmon Co-authored-by: github-actions --- .github/workflows/release.yaml | 1 + .gitignore | 1 + CHANGELOG.md | 166 +++++++++++++++++++ bin/download_sample_sequences.sh | 2 + notebooks/starter_kit_reid.ipynb | 210 ++++++++++++++++++------ pyproject.toml | 7 +- trackreid/__init__.py | 2 +- trackreid/args/reid_args.py | 19 ++- trackreid/constants/reid_constants.py | 6 +- trackreid/matcher.py | 3 +- trackreid/reid_processor.py | 228 +++++++++++++++++--------- trackreid/tracked_object.py | 41 +++-- trackreid/tracked_object_filter.py | 6 +- trackreid/tracked_object_metadata.py | 44 +++-- trackreid/utils.py | 27 ++- 15 files changed, 591 insertions(+), 172 deletions(-) create mode 100644 CHANGELOG.md diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 595481b..6a859ed 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -29,6 +29,7 @@ jobs: - uses: actions/checkout@v4 with: fetch-depth: 0 + token: ${{ secrets.GH_TOKEN }} - name: Python Semantic Release uses: python-semantic-release/python-semantic-release@master diff --git a/.gitignore b/.gitignore index 8234a9f..92b4f04 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,7 @@ secrets/* # Data ignore everythin data/detections and data/frames data/detections/* data/frames/* +*.mp4 # poetry poetry.lock diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d92e098 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,166 @@ +# CHANGELOG + + + +## 0.0.2 (2023-11-09) + +### Fix + +* fix: remove useless dependencies (#11) (#12) + +Co-authored-by: TomDarmon <tom.darmon@artefact.com> ([`c082c1c`](https://github.com/artefactory-fr/track-reid/commit/c082c1c8766840512a9c1ad58b6848ff560e4ec7)) + +* fix: env in pyproject (#10) + +Co-authored-by: TomDarmon <tom.darmon@artefact.com> ([`5e88ea0`](https://github.com/artefactory-fr/track-reid/commit/5e88ea04c00443b7202a2ff4dbaac09f2db01034)) + +### Unknown + +* Update release.yaml ([`81695d5`](https://github.com/artefactory-fr/track-reid/commit/81695d5ac249df3b7dbbf2f2bdbe7075e42bea62)) + + +## 0.0.1 (2023-11-08) + +### Fix + +* fix: version path (#9) + +Co-authored-by: TomDarmon <tom.darmon@artefact.com> ([`677204b`](https://github.com/artefactory-fr/track-reid/commit/677204bf2e755d851b5b26d9acf808b957bdf55f)) + +* fix: ci rules (#8) + +* fix: ci rules + +* fix: add version + +--------- + +Co-authored-by: TomDarmon <tom.darmon@artefact.com> ([`3b7c12d`](https://github.com/artefactory-fr/track-reid/commit/3b7c12d7b21b511bdf76b6fe0856b4d489eefcaa)) + +* fix: ci ([`db9b31a`](https://github.com/artefactory-fr/track-reid/commit/db9b31a244c97c63bb3e7073bbf780d46416a0dc)) + +### Unknown + +* Fea/release semantic versionning (#7) + +* refacto: rename package name + +* fix: poetry lock + +* refacto: adapt makefile+pyproject order+template + +* fea: release workflow + +* fix: custom token + +* fix: poetry in CI + +* fix: author + +* fix :poetry core + +* fix: install cython + +* fix: require cython + +* fix: GH token name + +* fix: doc + +* fix: install poetry pip + +* fix: syntax + +* fix: remove path + +* fix: add lapx as dependency + +* fix: official script + +* fix: remove lock + +* fix: poetry lock ignore + +* fix: remove tests + +* fix: eol + +--------- + +Co-authored-by: TomDarmon <tom.darmon@artefact.com> ([`870f546`](https://github.com/artefactory-fr/track-reid/commit/870f5466d22019c76b429a4cb4a4633e7a181fdd)) + +* Tp/add reid lib (#4) + +* add matcher + +* add reid processor + +* add tracked object filter + +* add metadata object + +* add object filterd modif + +* add tracked object class + +* add utils + +* add args & constants + +* fix drop duplicate + +* add notebook + +* modify pyproject ([`308be7b`](https://github.com/artefactory-fr/track-reid/commit/308be7b1a0febd48d90040f42aff97b86137720f)) + +* fix requirements, use forked version of bytetrack (#5) + +* fix requirements, use forked version of bytetrack + +* fix ci + +* fix makefile ([`1b9d3f1`](https://github.com/artefactory-fr/track-reid/commit/1b9d3f1adf88e785adbfb29f328de63c2047e092)) + +* fea: norfair demo nb (#3) + +* fea: norfair demo nb + +* fix: data folder + script + +* fix: improve video generation + +* fix: script + +* fix: use .txt instead of pip tools + +* fix: pre-commit + +--------- + +Co-authored-by: TomDarmon <tom.darmon@artefact.com> ([`6fbc3aa`](https://github.com/artefactory-fr/track-reid/commit/6fbc3aaf4207d8b96c2f6568accdcc5dcadf67de)) + +* Merge pull request #2 from artefactory-fr/tp/add_requirements + +Tp/add requirements ([`555ed40`](https://github.com/artefactory-fr/track-reid/commit/555ed4060bf327aff0c2009c68079cb8fb0aa68b)) + +* delete requirements developer & unify ([`14eaf69`](https://github.com/artefactory-fr/track-reid/commit/14eaf693659df82ec7a5cc50c502cd04a713c686)) + +* use make install requirements in ci ([`754195d`](https://github.com/artefactory-fr/track-reid/commit/754195d42dcc80445919a245132a9d58ce78730b)) + +* fix ci ([`56aefb3`](https://github.com/artefactory-fr/track-reid/commit/56aefb3da4485c2fec579dfc6303100bc4185cbc)) + +* fix ci ([`0334a92`](https://github.com/artefactory-fr/track-reid/commit/0334a927499391ae26beb92c4eceff1ca97789b4)) + +* fix lint ([`1b71012`](https://github.com/artefactory-fr/track-reid/commit/1b71012b5d9beca0ba4f816e1a2dac7247af6d14)) + +* Merge branch 'main' into tp/add_requirements ([`af92203`](https://github.com/artefactory-fr/track-reid/commit/af92203acdd6c472c9bafc20e94cb8c140338d93)) + +* add requirements.in ([`f32c712`](https://github.com/artefactory-fr/track-reid/commit/f32c71233701ed064dd34300db93df28e4a48cae)) + +* add requirements ([`4620d17`](https://github.com/artefactory-fr/track-reid/commit/4620d1768bd6933c4c53f9ec392677e2f1925bc4)) + +* Merge pull request #1 from artefactory-fr/fix/CI + +fix: ci ([`f9e1c83`](https://github.com/artefactory-fr/track-reid/commit/f9e1c83c8cbcda7f19515690aa9c493d53c1db24)) + +* initial commit ([`30eb0e5`](https://github.com/artefactory-fr/track-reid/commit/30eb0e53ff06540c225a0df60a0d71bea2f69471)) diff --git a/bin/download_sample_sequences.sh b/bin/download_sample_sequences.sh index d357733..b7e5b14 100644 --- a/bin/download_sample_sequences.sh +++ b/bin/download_sample_sequences.sh @@ -9,6 +9,8 @@ sequences_frames=$(gsutil ls gs://data-track-reid/frames | head -$N_SEQUENCES) sequences_detections=$(echo "$sequences_detections" | tail -n +2) sequences_frames=$(echo "$sequences_frames" | tail -n +2) +mkdir -p data/detections +mkdir -p data/frames # download the sequences to data/detections and data/frames for sequence in $sequences_detections; do diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index 58605c0..d8d27aa 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -28,7 +28,17 @@ "source": [ "import numpy as np\n", "import pandas as pd\n", - "from track_reid.reid_processor import ReidProcessor" + "from trackreid.reid_processor import ReidProcessor\n", + "from trackreid.args.reid_args import OUTPUT_POSITIONS\n", + "import cv2\n", + "from tqdm import tqdm \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Real life data" ] }, { @@ -37,21 +47,47 @@ "metadata": {}, "outputs": [], "source": [ - "def bounding_box_distance(obj1, obj2):\n", - " # Get the bounding boxes from the Metadata of each TrackedObject\n", - " bbox1 = obj1.metadata.bbox\n", - " bbox2 = obj2.metadata.bbox\n", + "import os\n", + "from lib.sequence import Sequence\n", + "from bytetracker import BYTETracker\n", + "from lib.bbox.utils import xy_center_to_xyxy, rescale_bbox" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_PATH = \"../data\"\n", + "DETECTION_PATH = f\"{DATA_PATH}/detections\"\n", + "FRAME_PATH = f\"{DATA_PATH}/frames\"\n", + "VIDEO_OUTPUT_PATH = \"private\"\n", "\n", - " # Calculate the Euclidean distance between the centers of the bounding boxes\n", - " center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)\n", - " center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)\n", - " distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)\n", + "SEQUENCES = os.listdir(FRAME_PATH)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_sequence_frames(sequence):\n", + " frames = os.listdir(f\"{FRAME_PATH}/{sequence}\")\n", + " frames = [os.path.join(f\"{FRAME_PATH}/{sequence}\", frame) for frame in frames]\n", + " frames.sort()\n", + " return frames\n", "\n", - " return distance\n", + "def get_sequence_detections(sequence):\n", + " detections = os.listdir(f\"{DETECTION_PATH}/{sequence}\")\n", + " detections = [os.path.join(f\"{DETECTION_PATH}/{sequence}\", detection) for detection in detections]\n", + " detections.sort()\n", + " return detections\n", "\n", - "def select_by_category(obj1, obj2):\n", - " # Compare the categories of the two objects\n", - " return 1 if obj1.category == obj2.category else 0" + "frame_path = get_sequence_frames(SEQUENCES[2])\n", + "test_sequence = Sequence(frame_path)\n", + "test_sequence" ] }, { @@ -60,23 +96,31 @@ "metadata": {}, "outputs": [], "source": [ - "# Example usage:\n", - "data = np.array([\n", - " [1, 1, \"car\", 100, 200, 300, 400, 0.9],\n", - " [1, 2, \"person\", 50, 150, 200, 400, 0.8],\n", - " [2, 1, \"truck\", 120, 220, 320, 420, 0.95],\n", - " [2, 2, \"person\", 60, 160, 220, 420, 0.85],\n", - " [3, 1, \"car\", 110, 210, 310, 410, 0.91],\n", - " [3, 3, \"person\", 61, 170, 220, 420, 0.91],\n", - " [3, 4, \"car\", 60, 160, 220, 420, 0.91],\n", - " [3, 6, \"person\", 60, 160, 220, 420, 0.91],\n", - " [4, 1, \"truck\", 130, 230, 330, 430, 0.92],\n", - " [4, 2, \"person\", 65, 165, 225, 425, 0.83],\n", - " [5, 1, \"car\", 115, 215, 315, 415, 0.93],\n", - " [5, 2, \"person\", 57, 157, 207, 407, 0.84],\n", - " [5, 4, \"car\", 60, 160, 220, 420, 0.91],\n", - " [5, 8, \"person\", 60, 160, 220, 420, 0.91],\n", - "])\n" + "class DetectionHandler():\n", + " def __init__(self, image_shape) -> None:\n", + " self.image_shape = image_shape\n", + "\n", + " def process(self, detection_output):\n", + " if detection_output.size:\n", + " if detection_output.ndim == 1:\n", + " detection_output = np.expand_dims(detection_output, 0)\n", + "\n", + " processed_detection = np.zeros(detection_output.shape)\n", + "\n", + " for idx, detection in enumerate(detection_output):\n", + " clss = detection[0]\n", + " conf = detection[5]\n", + " bbox = detection[1:5]\n", + " xyxy_bbox = xy_center_to_xyxy(bbox)\n", + " rescaled_bbox = rescale_bbox(xyxy_bbox,self.image_shape)\n", + " processed_detection[idx,:4] = rescaled_bbox\n", + " processed_detection[idx,4] = conf\n", + " processed_detection[idx,5] = clss\n", + " \n", + " return processed_detection\n", + " else:\n", + " return detection_output\n", + " " ] }, { @@ -85,28 +129,30 @@ "metadata": {}, "outputs": [], "source": [ - "processor = ReidProcessor(filter_confidence_threshold=0.4, \n", - " filter_time_threshold=0,\n", - " cost_function=bounding_box_distance,\n", - " selection_function=select_by_category,\n", - " max_attempt_to_rematch=0,\n", - " max_frames_to_rematch=100)\n", + "class TrackingHandler():\n", + " def __init__(self, tracker) -> None:\n", + " self.tracker = tracker\n", "\n", + " def update(self, detection_outputs, frame_id):\n", "\n", - "columns = ['frame_id', 'object_id', 'category', 'x1', 'y1', 'x2', 'y2', 'confidence']\n", - "df = pd.DataFrame(data, columns=columns)\n", - "# Convert numerical columns to the appropriate data type\n", - "df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']] = df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']].astype(int)\n", - "df['confidence'] = df['confidence'].astype(float)\n", + " if not detection_outputs.size :\n", + " return detection_outputs\n", + " \n", + " processed_detections = self._pre_process(detection_outputs)\n", + " tracked_objects = self.tracker.update(processed_detections, _ = None)\n", + " processed_tracked = self._post_process(tracked_objects, frame_id)\n", + " return processed_tracked\n", "\n", + " def _pre_process(self,detection_outputs : np.ndarray):\n", + " return detection_outputs\n", "\n", - "for frame_id, frame_data in df.groupby(\"frame_id\"):\n", + " def _post_process(self, tracked_objects : np.ndarray, frame_id):\n", "\n", - " bytetrack_output = frame_data.values\n", - " if bytetrack_output.ndim == 1 : \n", - " bytetrack_output = np.expand_dims(bytetrack_output, 0)\n", + " if tracked_objects.size :\n", + " if tracked_objects.ndim == 1:\n", + " tracked_objects = np.expand_dims(tracked_objects, 0)\n", "\n", - " results = processor.update(bytetrack_output)\n" + " return tracked_objects" ] }, { @@ -115,7 +161,75 @@ "metadata": {}, "outputs": [], "source": [ - "print(processor.all_tracked_objects[0])" + "def bounding_box_distance(obj1, obj2):\n", + " # Get the bounding boxes from the Metadata of each TrackedObject\n", + " bbox1 = obj1.metadata.bbox\n", + " bbox2 = obj2.metadata.bbox\n", + "\n", + " # Calculate the Euclidean distance between the centers of the bounding boxes\n", + " center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)\n", + " center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)\n", + " distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)\n", + "\n", + " return distance\n", + "\n", + "# TODO : discard by zone\n", + "def select_by_category(obj1, obj2):\n", + " # Compare the categories of the two objects\n", + " return 1 if obj1.category == obj2.category else 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bytetracker.basetrack import BaseTrack\n", + "BaseTrack._count = 0\n", + "\n", + "# Define the codec using VideoWriter_fourcc() and create a VideoWriter object\n", + "fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", + "out = cv2.VideoWriter('output_corrected.mp4', fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", + "\n", + "\n", + "detection_handler = DetectionHandler(image_shape=[2560, 1440])\n", + "tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n", + "reid_processor = ReidProcessor(filter_confidence_threshold=0.1, \n", + " filter_time_threshold=0,\n", + " cost_function=bounding_box_distance,\n", + " selection_function=select_by_category,\n", + " max_attempt_to_rematch=1,\n", + " max_frames_to_rematch=100)\n", + "\n", + "frame_id = 0\n", + "\n", + "for frame, detection in tqdm(test_sequence):\n", + " frame = np.array(frame)\n", + "\n", + " frame_id += 1\n", + "\n", + " processed_detections = detection_handler.process(detection)\n", + " processed_tracked = tracking_handler.update(processed_detections, frame_id)\n", + " reid_results = reid_processor.process(processed_tracked, frame_id)\n", + "\n", + " if reid_results.size:\n", + " for res in reid_results:\n", + " object_id =res[OUTPUT_POSITIONS[\"object_id\"]]\n", + " x1, y1, x2, y2 = res[OUTPUT_POSITIONS[\"bbox\"]]\n", + " class_id = res[OUTPUT_POSITIONS[\"category\"]]\n", + " confidence_score = res[OUTPUT_POSITIONS[\"confidence\"]]\n", + "\n", + " frame_id, object_id, class_id, x1, y1, x2, y2 = int(frame_id), int(object_id), int(class_id), int(x1), int(y1), int(x2), int(y2)\n", + " color = (0, 255, 0) if class_id == 0.0 else (0, 0, 255) # green for class 0, red for class 1\n", + " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", + " cv2.putText(frame, str(object_id), (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n", + "\n", + "\n", + " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", + " # Write the frame to the video file\n", + " out.write(frame)\n", + "out.release()" ] }, { @@ -124,7 +238,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(processor.all_tracked_objects[0].metadata)" + "reid_processor.all_tracked_objects" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 255f1fb..a7209d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "trackreid" authors = ["tristanpepinartefact "] description = "This Git repository is dedicated to the development of a Python library aimed at correcting the results of tracking algorithms. The primary goal of this library is to reconcile and reassign lost or misidentified IDs, ensuring a consistent and accurate tracking of objects over time. " -version = "0.0.1" +version = "0.0.2" readme = "README.md" @@ -102,7 +102,7 @@ profile = "black" [tool.semantic_release] -version_variables = ["deployer/__init__.py:__version__"] +version_variables = ["trackreid/__init__.py:__version__"] version_toml = ["pyproject.toml:tool.poetry.version"] branch = "main" upload_to_pypi = false @@ -113,3 +113,6 @@ tag_format = "{version}" [tool.semantic_release.changelog] exclude_commit_patterns = ['''^chore\(release\).*'''] + +[tool.semantic_release.remote] +token = { env = "GH_TOKEN" } diff --git a/trackreid/__init__.py b/trackreid/__init__.py index 6c8e6b9..3b93d0b 100644 --- a/trackreid/__init__.py +++ b/trackreid/__init__.py @@ -1 +1 @@ -__version__ = "0.0.0" +__version__ = "0.0.2" diff --git a/trackreid/args/reid_args.py b/trackreid/args/reid_args.py index bab95c1..3d61812 100644 --- a/trackreid/args/reid_args.py +++ b/trackreid/args/reid_args.py @@ -1 +1,18 @@ -POSSIBLE_CLASSES = ["car", "person", "truck", "animal"] +POSSIBLE_CLASSES = [0.0, 1.0] +MAPPING_CLASSES = {0.0: "shop_item", 1.0: "personal_item"} + +INPUT_POSITIONS = { + "object_id": 4, + "category": 5, + "bbox": [0, 1, 2, 3], + "confidence": 6, +} + +OUTPUT_POSITIONS = { + "frame_id": 0, + "object_id": 1, + "category": 2, + "bbox": [3, 4, 5, 6], + "confidence": 7, + "mean_confidence": 8, +} diff --git a/trackreid/constants/reid_constants.py b/trackreid/constants/reid_constants.py index 950bf8b..65c5a95 100644 --- a/trackreid/constants/reid_constants.py +++ b/trackreid/constants/reid_constants.py @@ -4,14 +4,16 @@ class ReidConstants(BaseModel): - BYETRACK_OUTPUT: int = -2 + LOST_FOREVER: int = -3 + TRACKER_OUTPUT: int = -2 FILTERED_OUTPUT: int = -1 STABLE: int = 0 SWITCHER: int = 1 CANDIDATE: int = 2 DESCRIPTION: ClassVar[dict] = { - BYETRACK_OUTPUT: "bytetrack output not in reid process", + LOST_FOREVER: "switcher never rematched", + TRACKER_OUTPUT: "bytetrack output not in reid process", FILTERED_OUTPUT: "bytetrack output entering reid process", STABLE: "stable object", SWITCHER: "lost object to be re-matched", diff --git a/trackreid/matcher.py b/trackreid/matcher.py index 18702a4..b3bcc26 100644 --- a/trackreid/matcher.py +++ b/trackreid/matcher.py @@ -2,7 +2,8 @@ import numpy as np from scipy.optimize import linear_sum_assignment -from track_reid.tracked_object import TrackedObject + +from trackreid.tracked_object import TrackedObject class Matcher: diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index 57d18a3..fe5e919 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -1,13 +1,20 @@ from __future__ import annotations -from typing import Dict, List, Set +from typing import Dict, List, Set, Union import numpy as np -from track_reid.constants.reid_constants import reid_constants -from track_reid.matcher import Matcher -from track_reid.tracked_object import TrackedObject -from track_reid.tracked_object_filter import TrackedObjectFilter -from track_reid.utils import filter_objects_by_state, get_top_list_correction + +from trackreid.args.reid_args import INPUT_POSITIONS, OUTPUT_POSITIONS +from trackreid.constants.reid_constants import reid_constants +from trackreid.matcher import Matcher +from trackreid.tracked_object import TrackedObject +from trackreid.tracked_object_filter import TrackedObjectFilter +from trackreid.utils import ( + filter_objects_by_state, + get_nb_output_cols, + get_top_list_correction, + reshape_tracker_result, +) class ReidProcessor: @@ -28,47 +35,66 @@ def __init__( ) self.all_tracked_objects: List[TrackedObject] = [] + self.last_frame_tracked_objects: Set[TrackedObject] = set() + self.switchers: List[TrackedObject] = [] self.candidates: List[TrackedObject] = [] - self.last_tracker_ids: Set[int] = set() - self.max_frames_to_rematch = max_frames_to_rematch self.max_attempt_to_rematch = max_attempt_to_rematch self.frame_id = 0 + self.nb_output_cols = get_nb_output_cols(output_positions=OUTPUT_POSITIONS) - def update(self, tracker_output: np.ndarray): - reshaped_tracker_output = self._reshape_input(tracker_output) - self._preprocess(tracker_output=reshaped_tracker_output) - self._perform_reid_process(tracker_output=reshaped_tracker_output) - reid_output = self._postprocess(tracker_output=tracker_output) - return reid_output - - def _preprocess(self, tracker_output: np.ndarray): - self.all_tracked_objects = self._update_tracked_objects(tracker_output=tracker_output) + def process(self, tracker_output: np.ndarray, frame_id: int): + if tracker_output.size: # empty tracking + reshaped_tracker_output = reshape_tracker_result(tracker_output=tracker_output) + self.all_tracked_objects = self._preprocess( + tracker_output=reshaped_tracker_output, frame_id=frame_id + ) + self._perform_reid_process(tracker_output=reshaped_tracker_output) + reid_output = self._postprocess(tracker_output=tracker_output) + return reid_output + else: + return tracker_output + + def _preprocess(self, tracker_output: np.ndarray, frame_id: int) -> List["TrackedObject"]: + self.all_tracked_objects = self._update_tracked_objects( + tracker_output=tracker_output, frame_id=frame_id + ) self.all_tracked_objects = self._apply_filtering() + return self.all_tracked_objects - def _update_tracked_objects(self, tracker_output: np.ndarray): - self.frame_id = tracker_output[0, 0] - for object_id, data_line in zip(tracker_output[:, 1], tracker_output): + def _update_tracked_objects(self, tracker_output: np.ndarray, frame_id: int): + self.frame_id = frame_id + for object_id, data_line in zip( + tracker_output[:, INPUT_POSITIONS["object_id"]], tracker_output + ): if object_id not in self.all_tracked_objects: new_tracked_object = TrackedObject( - object_ids=object_id, state=reid_constants.BYETRACK_OUTPUT, metadata=data_line + object_ids=object_id, + state=reid_constants.TRACKER_OUTPUT, + frame_id=frame_id, + metadata=data_line, ) self.all_tracked_objects.append(new_tracked_object) else: self.all_tracked_objects[self.all_tracked_objects.index(object_id)].update_metadata( - data_line + data_line, frame_id=frame_id ) return self.all_tracked_objects - @staticmethod - def _reshape_input(bytetrack_output: np.ndarray): - if bytetrack_output.ndim == 1: - bytetrack_output = np.expand_dims(bytetrack_output, 0) - return bytetrack_output + def _get_current_tracked_objects(self, current_tracker_ids: Set[Union[int, float]]): + tracked_objects = filter_objects_by_state( + self.all_tracked_objects, states=reid_constants.TRACKER_OUTPUT, exclusion=True + ) + + current_tracked_objects = set( + [tracked_id for tracked_id in tracked_objects if tracked_id in current_tracker_ids] + ) + + return tracked_objects, current_tracked_objects def _apply_filtering(self): for tracked_object in self.all_tracked_objects: @@ -77,59 +103,75 @@ def _apply_filtering(self): return self.all_tracked_objects def _perform_reid_process(self, tracker_output: np.ndarray): - tracked_ids = filter_objects_by_state( - self.all_tracked_objects, states=reid_constants.BYETRACK_OUTPUT, exclusion=True + current_tracker_ids: List[Union[int, float]] = list( + tracker_output[:, INPUT_POSITIONS["object_id"]] ) - current_tracker_ids = set(tracker_output[:, 1]).intersection(set(tracked_ids)) + # TODO: we can get rid of self.switchers and self.candidates by + # applying: + # candidates = filter_objects_by_state( + # self.all_tracked_objects, states=reid_constants.CANDIDATE, exclusion=False + # ) + # switchers = filter_objects_by_state( + # self.all_tracked_objects, states=reid_constants.SWITCHER, exclusion=False + # ) + + self.all_tracked_objects, self.switchers = self.correct_reid_chains( + all_tracked_objects=self.all_tracked_objects, + current_tracker_ids=current_tracker_ids, + switchers=self.switchers, + ) - self.compute_stable_objects( - current_tracker_ids=current_tracker_ids, tracked_ids=self.all_tracked_objects + tracked_objects, current_tracked_objects = self._get_current_tracked_objects( + current_tracker_ids=current_tracker_ids ) self.switchers = self.drop_switchers( - self.switchers, - current_tracker_ids, + switchers=self.switchers, + current_tracked_objects=current_tracked_objects, max_frames_to_rematch=self.max_frames_to_rematch, frame_id=self.frame_id, ) - self.candidates.extend(self.identify_candidates(tracked_ids=tracked_ids)) + self.candidates = self.drop_candidates( + self.candidates, self.max_attempt_to_rematch, self.frame_id + ) + + self.candidates.extend(self.identify_candidates(tracked_objects=tracked_objects)) self.switchers.extend( self.identify_switchers( - current_tracker_ids=current_tracker_ids, - last_bytetrack_ids=self.last_tracker_ids, - tracked_ids=tracked_ids, + current_tracked_objects=current_tracked_objects, + last_frame_tracked_objects=self.last_frame_tracked_objects, + all_tracked_objects=self.all_tracked_objects, ) ) matches = self.matcher.match(self.candidates, self.switchers) - self.process_matches( + self.all_tracked_objects, self.switchers, self.candidates = self.process_matches( all_tracked_objects=self.all_tracked_objects, matches=matches, candidates=self.candidates, switchers=self.switchers, - current_tracker_ids=current_tracker_ids, ) - self.candidates = self.drop_candidates( - self.candidates, + _, current_tracked_objects = self._get_current_tracked_objects( + current_tracker_ids=current_tracker_ids ) - self.last_tracker_ids = current_tracker_ids.copy() + self.last_frame_tracked_objects = current_tracked_objects.copy() @staticmethod def identify_switchers( - tracked_ids: List["TrackedObject"], - current_tracker_ids: Set[int], - last_bytetrack_ids: Set[int], + all_tracked_objects: List["TrackedObject"], + current_tracked_objects: Set["TrackedObject"], + last_frame_tracked_objects: Set["TrackedObject"], ): switchers = [] - lost_ids = last_bytetrack_ids - current_tracker_ids + lost_ids = last_frame_tracked_objects - current_tracked_objects - for tracked_id in tracked_ids: + for tracked_id in all_tracked_objects: if tracked_id in lost_ids: switchers.append(tracked_id) tracked_id.state = reid_constants.SWITCHER @@ -137,29 +179,46 @@ def identify_switchers( return switchers @staticmethod - def identify_candidates(tracked_ids: List["TrackedObject"]): + def identify_candidates(tracked_objects: List["TrackedObject"]): candidates = [] - for current_object in tracked_ids: + for current_object in tracked_objects: if current_object.state == reid_constants.FILTERED_OUTPUT: current_object.state = reid_constants.CANDIDATE candidates.append(current_object) return candidates @staticmethod - def compute_stable_objects(tracked_ids: list, current_tracker_ids: Set[int]): - top_list_correction = get_top_list_correction(tracked_ids) + def correct_reid_chains( + all_tracked_objects: List["TrackedObject"], + current_tracker_ids: List[Union[int, float]], + switchers: List["TrackedObject"], + ): + top_list_correction = get_top_list_correction(all_tracked_objects) for current_object in current_tracker_ids: - tracked_id = tracked_ids[tracked_ids.index(current_object)] + tracked_id = all_tracked_objects[all_tracked_objects.index(current_object)] + object_state = tracked_id.state if current_object not in top_list_correction: - tracked_ids.remove(tracked_id) + all_tracked_objects.remove(tracked_id) + if object_state == reid_constants.SWITCHER: + switchers.remove(tracked_id) + new_object, tracked_id = tracked_id.cut(current_object) - new_object.state = reid_constants.STABLE tracked_id.state = reid_constants.STABLE + all_tracked_objects.append(tracked_id) + + # 2 cases to take : + if new_object in current_tracker_ids: + new_object.state = reid_constants.STABLE + all_tracked_objects.append(new_object) + + elif new_object.nb_corrections > 1: + new_object.state = reid_constants.SWITCHER + switchers.append(new_object) + all_tracked_objects.append(new_object) - tracked_ids.append(new_object) - tracked_ids.append(tracked_id) + return all_tracked_objects, switchers @staticmethod def process_matches( @@ -167,27 +226,26 @@ def process_matches( matches: Dict["TrackedObject", "TrackedObject"], switchers: List["TrackedObject"], candidates: List["TrackedObject"], - current_tracker_ids: Set[int], ): for match in matches: candidate_match, switcher_match = match.popitem() switcher_match.merge(candidate_match) + switcher_match.state = reid_constants.STABLE all_tracked_objects.remove(candidate_match) switchers.remove(switcher_match) candidates.remove(candidate_match) - current_tracker_ids.discard(candidate_match.id) - current_tracker_ids.add(switcher_match.id) + return all_tracked_objects, switchers, candidates @staticmethod def drop_switchers( switchers: List["TrackedObject"], - current_tracker_ids: Set[int], + current_tracked_objects: Set["TrackedObject"], max_frames_to_rematch: int, frame_id: int, ): - switchers_to_drop = set(switchers).intersection(current_tracker_ids) + switchers_to_drop = set(switchers).intersection(current_tracked_objects) filtered_switchers = switchers.copy() for switcher in switchers: @@ -195,38 +253,46 @@ def drop_switchers( switcher.state = reid_constants.STABLE filtered_switchers.remove(switcher) elif switcher.get_nb_frames_since_last_appearance(frame_id) > max_frames_to_rematch: + switcher.state = reid_constants.LOST_FOREVER filtered_switchers.remove(switcher) return filtered_switchers @staticmethod - def drop_candidates(candidates: List["TrackedObject"]): + def drop_candidates( + candidates: List["TrackedObject"], max_attempt_to_rematch: int, frame_id: int + ): + filtered_candidates = candidates.copy() # for now drop candidates if there was no match - for candidate in candidates: - candidate.state = reid_constants.STABLE - return [] + for candidate in filtered_candidates: + if candidate.get_age(frame_id) >= max_attempt_to_rematch: + candidate.state = reid_constants.STABLE + candidates.remove(candidate) + return candidates def _postprocess(self, tracker_output: np.ndarray): filtered_objects = list( filter( lambda obj: obj.get_state() == reid_constants.STABLE - and obj in tracker_output[:, 1], + and obj in tracker_output[:, INPUT_POSITIONS["object_id"]], self.all_tracked_objects, ) ) - reid_output = [] - for object in filtered_objects: - reid_output.append( - [ - self.frame_id, - object.id, - object.category, - object.bbox[0], - object.bbox[1], - object.bbox[2], - object.bbox[3], - object.confidence, - ] - ) + + reid_output = np.zeros((len(filtered_objects), self.nb_output_cols)) + + for idx, object in enumerate(filtered_objects): + for required_variable in OUTPUT_POSITIONS: + if required_variable == "frame_id": + output = self.frame_id + else: + try: + output = getattr(object, required_variable) + except: # noqa: E722 + raise NameError( + f"Attribute {required_variable} not in TrackedObject.Check your required output names." + ) + + reid_output[idx, OUTPUT_POSITIONS[required_variable]] = output return reid_output diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index b2dd1dc..901285b 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -1,31 +1,36 @@ from __future__ import annotations -from typing import Union +from typing import Optional, Union import numpy as np from llist import sllist -from track_reid.constants.reid_constants import reid_constants -from track_reid.tracked_object_metadata import TrackedObjectMetaData -from track_reid.utils import split_list_around_value + +from trackreid.constants.reid_constants import reid_constants +from trackreid.tracked_object_metadata import TrackedObjectMetaData +from trackreid.utils import split_list_around_value class TrackedObject: def __init__( self, - object_ids: Union[int, sllist], + object_ids: Union[Union[float, int], sllist], state: int, metadata: Union[np.ndarray, TrackedObjectMetaData], + frame_id: Optional[int] = None, ): self.state = state - if isinstance(object_ids, int): + if isinstance(object_ids, Union[float, int]): self.re_id_chain = sllist([object_ids]) elif isinstance(object_ids, sllist): self.re_id_chain = object_ids else: raise NameError("unrocognized type for object_ids.") if isinstance(metadata, np.ndarray): - self.metadata = TrackedObjectMetaData(metadata) + assert ( + frame_id is not None + ), "Please provide a frame_id for TrackedObject initialization" + self.metadata = TrackedObjectMetaData(metadata, frame_id) elif isinstance(metadata, TrackedObjectMetaData): self.metadata = metadata.copy() else: @@ -37,16 +42,14 @@ def merge(self, other_object): # Merge the re_id_chains self.re_id_chain.extend(other_object.re_id_chain) - - # Merge the metadata (you should implement a proper merge logic in TrackedObjectMetaData) self.metadata.merge(other_object.metadata) - self.state = reid_constants.STABLE + self.state = other_object.state # Return the merged object return self @property - def id(self): + def object_id(self): return self.re_id_chain.first.value @property @@ -65,6 +68,10 @@ def mean_confidence(self): def bbox(self): return self.metadata.bbox + @property + def nb_corrections(self): + return len(self.re_id_chain) + def get_age(self, frame_id): return frame_id - self.metadata.first_frame_id @@ -75,22 +82,22 @@ def get_state(self): return self.state def __hash__(self): - return hash(self.id) + return hash(self.object_id) def __repr__(self): return ( - f"TrackedObject(current_id={self.id}, re_id_chain={list(self.re_id_chain)}" + f"TrackedObject(current_id={self.object_id}, re_id_chain={list(self.re_id_chain)}" + f", state={self.state}: {reid_constants.DESCRIPTION[self.state]})" ) def __str__(self): return f"{self.__repr__()}, metadata : {self.metadata}" - def update_metadata(self, data_line: np.ndarray): - self.metadata.update(data_line) + def update_metadata(self, data_line: np.ndarray, frame_id: int): + self.metadata.update(data_line=data_line, frame_id=frame_id) def __eq__(self, other): - if isinstance(other, int): + if isinstance(other, Union[float, int]): return other in self.re_id_chain elif isinstance(other, TrackedObject): return self.re_id_chain == other.re_id_chain @@ -112,7 +119,7 @@ def cut(self, object_id: int): def format_data(self): return [ - self.id, + self.object_id, self.category, self.bbox[0], self.bbox[1], diff --git a/trackreid/tracked_object_filter.py b/trackreid/tracked_object_filter.py index bc936b0..fdcc44f 100644 --- a/trackreid/tracked_object_filter.py +++ b/trackreid/tracked_object_filter.py @@ -1,4 +1,4 @@ -from track_reid.constants.reid_constants import reid_constants +from trackreid.constants.reid_constants import reid_constants class TrackedObjectFilter: @@ -7,7 +7,7 @@ def __init__(self, confidence_threshold, frames_seen_threshold): self.frames_seen_threshold = frames_seen_threshold def update(self, tracked_object): - if tracked_object.get_state() == reid_constants.BYETRACK_OUTPUT: + if tracked_object.get_state() == reid_constants.TRACKER_OUTPUT: if ( tracked_object.metadata.mean_confidence() > self.confidence_threshold and tracked_object.metadata.observations >= self.frames_seen_threshold @@ -15,4 +15,4 @@ def update(self, tracked_object): tracked_object.state = reid_constants.FILTERED_OUTPUT elif tracked_object.metadata.mean_confidence() < self.confidence_threshold: - tracked_object.state = reid_constants.BYETRACK_OUTPUT + tracked_object.state = reid_constants.TRACKER_OUTPUT diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index dc255f1..9ba974b 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -1,30 +1,32 @@ import json from pathlib import Path -from track_reid.args.reid_args import POSSIBLE_CLASSES +import numpy as np + +from trackreid.args.reid_args import INPUT_POSITIONS, POSSIBLE_CLASSES class TrackedObjectMetaData: - def __init__(self, data_line): - self.first_frame_id = int(data_line[0]) + def __init__(self, data_line, frame_id): + self.first_frame_id = frame_id self.class_counts = {class_name: 0 for class_name in POSSIBLE_CLASSES} self.observations = 0 self.confidence_sum = 0 self.confidence = 0 - self.update(data_line) + self.update(data_line, frame_id) - def update(self, data_line): - self.last_frame_id = int(data_line[0]) - class_name = data_line[2] + def update(self, data_line, frame_id): + self.last_frame_id = frame_id + class_name = data_line[INPUT_POSITIONS["category"]] self.class_counts[class_name] = self.class_counts.get(class_name, 0) + 1 - self.bbox = list(map(int, data_line[3:7])) - confidence = float(data_line[7]) + self.bbox = list(data_line[INPUT_POSITIONS["bbox"]].astype(int)) + confidence = float(data_line[INPUT_POSITIONS["confidence"]]) self.confidence = confidence self.confidence_sum += confidence self.observations += 1 def merge(self, other_object): - if not isinstance(other_object, TrackedObjectMetaData): + if not isinstance(other_object, type(self)): raise TypeError("Can only merge with another TrackedObjectMetaData.") self.observations += other_object.observations @@ -38,15 +40,29 @@ def merge(self, other_object): ) + other_object.class_counts.get(class_name, 0) def copy(self): - # Create a new instance of TrackedObjectMetaData with the same data + # Create a new instance of TrackedObjectMetaData + # initialize with fake data + + # TODO: make something better here, input order might change copy_obj = TrackedObjectMetaData( - [self.first_frame_id, 0, list(self.class_counts.keys())[0], *self.bbox, self.confidence] + data_line=np.array( + [ + 0, + list(self.class_counts.keys())[0], + *self.bbox, + self.confidence, + ] + ), + frame_id=self.first_frame_id, ) # Update the copied instance with the actual class counts and observations copy_obj.class_counts = self.class_counts.copy() copy_obj.observations = self.observations copy_obj.confidence_sum = self.confidence_sum copy_obj.confidence = self.confidence + copy_obj.bbox = self.bbox + copy_obj.first_frame_id = self.first_frame_id + copy_obj.last_frame_id = self.last_frame_id return copy_obj @@ -105,6 +121,6 @@ def __repr__(self) -> str: def __str__(self): return ( f"First frame seen: {self.first_frame_id}, nb observations: {self.observations}, " - + "class Proportions: {self.class_proportions()}, Bounding Box: {self.bbox}, " - + "Mean Confidence: {self.mean_confidence()}" + + f"class Proportions: {self.class_proportions()}, Bounding Box: {self.bbox}, " + + f"Mean Confidence: {self.mean_confidence()}" ) diff --git a/trackreid/utils.py b/trackreid/utils.py index 7c559f3..1e6f595 100644 --- a/trackreid/utils.py +++ b/trackreid/utils.py @@ -1,5 +1,6 @@ from typing import List, Union +import numpy as np from llist import sllist @@ -9,9 +10,12 @@ def get_top_list_correction(tracked_ids: list): return top_list_correction -def split_list_around_value(my_list: sllist, value_to_split: int): +def split_list_around_value(my_list: sllist, value_to_split: float): if value_to_split == my_list.last.value: raise NameError("split on the last") + if value_to_split not in my_list: + raise NameError(f"{value_to_split} is not in the list") + before = sllist() after = sllist() @@ -21,9 +25,9 @@ def split_list_around_value(my_list: sllist, value_to_split: int): before.append(current.value) if current.value == value_to_split: break + current = current.next current = current.next - while current: after.append(current.value) current = current.next @@ -39,3 +43,22 @@ def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exc else: filtered_objects = [obj for obj in tracked_objects if obj.state in states] return filtered_objects + + +def reshape_tracker_result(tracker_output: np.ndarray): + if tracker_output.ndim == 1: + tracker_output = np.expand_dims(tracker_output, 0) + return tracker_output + + +def get_nb_output_cols(output_positions: dict): + nb_cols = 0 + for feature in output_positions.values(): + if type(feature) is int: + nb_cols += 1 + elif type(feature) is list: + nb_cols += len(feature) + else: + raise TypeError("Unkown type in required output positions.") + + return nb_cols From befb92dd81ae225562b19917e2a4e07c3ccd5c72 Mon Sep 17 00:00:00 2001 From: TomDarmon <36815861+TomDarmon@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:14:41 +0100 Subject: [PATCH 04/13] fix: bytetrack internal (#15) Co-authored-by: TomDarmon --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a7209d1..82409cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,8 @@ pandas = "1.5.3" numpy = "1.24.2" llist = "0.7.1" pydantic = "2.4.2" -bytetracker = { git = "https://github.com/TomDarmon/bytetrack-pip.git", branch = "main" } +bytetracker = {git = "git@github.com:artefactory-fr/bytetrack.git", branch = "main"} + [tool.poetry.group.dev.dependencies] black = "22.10.0" From 7bc79a68d328d0ab5351ce017a59b31cddaa289d Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Tue, 14 Nov 2023 11:22:47 +0100 Subject: [PATCH 05/13] Tp/fix_issues (#14) --- .github/workflows/ci.yaml | 5 +- .github/workflows/deploy_docs.yaml | 6 +- .pre-commit-config.yaml | 8 - Makefile | 21 +- README.md | 12 +- bin/install_with_conda.sh | 18 - bin/install_with_venv.sh | 20 - lib/sequence.py | 1 + notebooks/starter_kit_reid.ipynb | 167 +++--- pyproject.toml | 10 +- .../unit_tests/tracked_objects/object_1.json | 28 + .../unit_tests/tracked_objects/object_24.json | 24 + .../unit_tests/tracked_objects/object_4.json | 25 + tests/unit_tests/test_matcher.py | 84 +++ tests/unit_tests/test_metadata.py | 52 ++ tests/unit_tests/test_placeholder.py | 6 - tests/unit_tests/test_tracked_objects.py | 96 ++++ tests/unit_tests/test_utils.py | 91 ++++ trackreid/args/reid_args.py | 1 + trackreid/constants/reid_constants.py | 11 +- trackreid/matcher.py | 89 ++-- trackreid/reid_processor.py | 485 ++++++++++++------ trackreid/tracked_object.py | 45 +- trackreid/tracked_object_filter.py | 6 +- trackreid/tracked_object_metadata.py | 79 ++- trackreid/utils.py | 12 + 26 files changed, 1000 insertions(+), 402 deletions(-) delete mode 100644 bin/install_with_conda.sh delete mode 100644 bin/install_with_venv.sh create mode 100644 tests/data/unit_tests/tracked_objects/object_1.json create mode 100644 tests/data/unit_tests/tracked_objects/object_24.json create mode 100644 tests/data/unit_tests/tracked_objects/object_4.json create mode 100644 tests/unit_tests/test_matcher.py create mode 100644 tests/unit_tests/test_metadata.py delete mode 100644 tests/unit_tests/test_placeholder.py create mode 100644 tests/unit_tests/test_tracked_objects.py create mode 100644 tests/unit_tests/test_utils.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5e4111b..8d94afd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,7 +29,10 @@ jobs: - name: Install requirements run: | - poetry install + make install - name: Run Pre commit hooks run: make format-code + + - name: Test with pytest + run: make run-tests diff --git a/.github/workflows/deploy_docs.yaml b/.github/workflows/deploy_docs.yaml index 3ac59df..8d4a6b6 100644 --- a/.github/workflows/deploy_docs.yaml +++ b/.github/workflows/deploy_docs.yaml @@ -15,10 +15,12 @@ jobs: uses: actions/setup-python@v2 with: python-version: "3.10" - + - name: Install poetry + run: | + make download-poetry - name: Install requirements run: | - make install_project_requirements + make install - name: Deploying MkDocs documentation run: | mkdocs build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6643f2c..2572435 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,12 +31,4 @@ repos: types: [file] files: (.ipynb)$ language: system - - id: pytest-check - name: Tests (pytest) - stages: [push] - entry: pytest tests/ - types: [python] - language: system - pass_filenames: false - always_run: true exclude: ^(.svn|CVS|.bzr|.hg|.git|__pycache__|.tox|.ipynb_checkpoints|assets|tests/assets/|venv/|.venv/) diff --git a/Makefile b/Makefile index 1385ca7..4de9643 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,4 @@ -PYTHON_VERSION = 3.10 -USE_CONDA ?= 1 -INSTALL_SCRIPT = install_with_conda.sh -ifeq (false,$(USE_CONDA)) - INSTALL_SCRIPT = install_with_venv.sh -endif - +PYTHON_VERSION = 3.10.13 # help: help - Display this makefile's help information .PHONY: help help: @@ -17,6 +11,7 @@ download-poetry: # help: install - Install python dependencies using poetry .PHONY: install install: + @poetry config virtualenvs.create true @poetry env use $(PYTHON_VERSION) @poetry lock -n @poetry install -n @@ -27,12 +22,6 @@ install: install-requirements: @poetry install -n - -.PHONY: install-dev-requirements -# help : install-dev-requirements - Install Python Dependencies for development -install-dev-requirements: - @poetry install -n --with dev - .PHONY: update-requirements #help: update-requirements - Update Python Dependencies (requirements.txt and requirements-dev.txt) update-requirements: @@ -43,6 +32,12 @@ update-requirements: format-code: @poetry run pre-commit run -a +.PHONY: run-tests +#help: run-tests - Run all tests with pytest +run-tests: + @export PYTHONPATH=. + @poetry run pytest + # help: deploy_docs - Deploy documentation to GitHub Pages .PHONY: deploy_docs deploy_docs: diff --git a/README.md b/README.md index 276c58b..f7015f3 100644 --- a/README.md +++ b/README.md @@ -27,13 +27,23 @@ This Git repository is dedicated to the development of a Python library aimed at ## Installation +First, install poetry: + +```bash +make download-poetry +``` + To install the required packages in a virtual environment, run the following command: ```bash make install ``` -TODO: Choose between conda and venv if necessary or let the Makefile as is and copy/paste the [MORE INFO installation section](MORE_INFO.md#eased-installation) to explain how to choose between conda and venv. +You can then activate the env with the following command: + +```bash +poetry shell +``` A complete list of available commands can be found using the following command: diff --git a/bin/install_with_conda.sh b/bin/install_with_conda.sh deleted file mode 100644 index 58b6a47..0000000 --- a/bin/install_with_conda.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -e - -read -p "Want to install conda env named 'track-reid'? (y/n)" answer -if [ "$answer" = "y" ]; then - echo "Installing conda env..." - conda create -n track-reid python=3.10 -y - source $(conda info --base)/etc/profile.d/conda.sh - conda activate track-reid - echo "Installing requirements..." - make install_project_requirements - python3 -m ipykernel install --user --name=track-reid - conda install -c conda-forge --name track-reid notebook -y - echo "Installing pre-commit..." - make install_precommit - echo "Installation complete!"; -else - echo "Installation of conda env aborted!"; -fi diff --git a/bin/install_with_venv.sh b/bin/install_with_venv.sh deleted file mode 100644 index b3389db..0000000 --- a/bin/install_with_venv.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -e - -read -p "Want to install virtual env named 'venv' in this project ? (y/n)" answer -if [ "$answer" = "y" ]; then - echo "Installing virtual env..." - declare VENV_DIR=$(pwd)/venv - if ! [ -d "$VENV_DIR" ]; then - python3 -m venv $VENV_DIR - fi - - source $VENV_DIR/bin/activate - echo "Installing requirements..." - make install_project_requirements - python3 -m ipykernel install --user --name=venv - echo "Installing pre-commit..." - make install_precommit - echo "Installation complete!"; -else - echo "Installation of virtual env aborted!"; -fi diff --git a/lib/sequence.py b/lib/sequence.py index 97b765a..b458f0b 100644 --- a/lib/sequence.py +++ b/lib/sequence.py @@ -30,6 +30,7 @@ def __next__(self): raise StopIteration frame = Image.open(self.frame_paths[self.index]) + try: detection = np.loadtxt(self.detection_paths[self.index], dtype="float") except OSError: # file doesn't exist not detection return empty file diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index d8d27aa..ff99d17 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -6,7 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "%load_ext autoreload \n", + "%load_ext autoreload\n", "%autoreload 2" ] }, @@ -16,22 +16,21 @@ "metadata": {}, "outputs": [], "source": [ - "import sys \n", - "sys.path.append(\"..\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ + "import os\n", + "import sys\n", + "\n", + "import cv2\n", "import numpy as np\n", - "import pandas as pd\n", - "from trackreid.reid_processor import ReidProcessor\n", + "from bytetracker import BYTETracker\n", + "from bytetracker.basetrack import BaseTrack\n", + "from tqdm import tqdm\n", + "\n", + "from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy\n", + "from lib.sequence import Sequence\n", "from trackreid.args.reid_args import OUTPUT_POSITIONS\n", - "import cv2\n", - "from tqdm import tqdm \n" + "from trackreid.reid_processor import ReidProcessor\n", + "\n", + "sys.path.append(\"..\")\n" ] }, { @@ -41,18 +40,6 @@ "# Real life data" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from lib.sequence import Sequence\n", - "from bytetracker import BYTETracker\n", - "from lib.bbox.utils import xy_center_to_xyxy, rescale_bbox" - ] - }, { "cell_type": "code", "execution_count": null, @@ -64,7 +51,7 @@ "FRAME_PATH = f\"{DATA_PATH}/frames\"\n", "VIDEO_OUTPUT_PATH = \"private\"\n", "\n", - "SEQUENCES = os.listdir(FRAME_PATH)\n" + "SEQUENCES = os.listdir(DETECTION_PATH)\n" ] }, { @@ -83,11 +70,7 @@ " detections = os.listdir(f\"{DETECTION_PATH}/{sequence}\")\n", " detections = [os.path.join(f\"{DETECTION_PATH}/{sequence}\", detection) for detection in detections]\n", " detections.sort()\n", - " return detections\n", - "\n", - "frame_path = get_sequence_frames(SEQUENCES[2])\n", - "test_sequence = Sequence(frame_path)\n", - "test_sequence" + " return detections" ] }, { @@ -116,11 +99,10 @@ " processed_detection[idx,:4] = rescaled_bbox\n", " processed_detection[idx,4] = conf\n", " processed_detection[idx,5] = clss\n", - " \n", + "\n", " return processed_detection\n", " else:\n", - " return detection_output\n", - " " + " return detection_output\n" ] }, { @@ -137,16 +119,16 @@ "\n", " if not detection_outputs.size :\n", " return detection_outputs\n", - " \n", + "\n", " processed_detections = self._pre_process(detection_outputs)\n", - " tracked_objects = self.tracker.update(processed_detections, _ = None)\n", - " processed_tracked = self._post_process(tracked_objects, frame_id)\n", + " tracked_objects = self.tracker.update(processed_detections, frame_id = frame_id)\n", + " processed_tracked = self._post_process(tracked_objects)\n", " return processed_tracked\n", "\n", " def _pre_process(self,detection_outputs : np.ndarray):\n", " return detection_outputs\n", "\n", - " def _post_process(self, tracked_objects : np.ndarray, frame_id):\n", + " def _post_process(self, tracked_objects : np.ndarray):\n", "\n", " if tracked_objects.size :\n", " if tracked_objects.ndim == 1:\n", @@ -185,60 +167,57 @@ "metadata": {}, "outputs": [], "source": [ - "from bytetracker.basetrack import BaseTrack\n", - "BaseTrack._count = 0\n", - "\n", - "# Define the codec using VideoWriter_fourcc() and create a VideoWriter object\n", - "fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", - "out = cv2.VideoWriter('output_corrected.mp4', fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", - "\n", - "\n", - "detection_handler = DetectionHandler(image_shape=[2560, 1440])\n", - "tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n", - "reid_processor = ReidProcessor(filter_confidence_threshold=0.1, \n", - " filter_time_threshold=0,\n", - " cost_function=bounding_box_distance,\n", - " selection_function=select_by_category,\n", - " max_attempt_to_rematch=1,\n", - " max_frames_to_rematch=100)\n", - "\n", - "frame_id = 0\n", - "\n", - "for frame, detection in tqdm(test_sequence):\n", - " frame = np.array(frame)\n", - "\n", - " frame_id += 1\n", - "\n", - " processed_detections = detection_handler.process(detection)\n", - " processed_tracked = tracking_handler.update(processed_detections, frame_id)\n", - " reid_results = reid_processor.process(processed_tracked, frame_id)\n", - "\n", - " if reid_results.size:\n", - " for res in reid_results:\n", - " object_id =res[OUTPUT_POSITIONS[\"object_id\"]]\n", - " x1, y1, x2, y2 = res[OUTPUT_POSITIONS[\"bbox\"]]\n", - " class_id = res[OUTPUT_POSITIONS[\"category\"]]\n", - " confidence_score = res[OUTPUT_POSITIONS[\"confidence\"]]\n", - "\n", - " frame_id, object_id, class_id, x1, y1, x2, y2 = int(frame_id), int(object_id), int(class_id), int(x1), int(y1), int(x2), int(y2)\n", - " color = (0, 255, 0) if class_id == 0.0 else (0, 0, 255) # green for class 0, red for class 1\n", - " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", - " cv2.putText(frame, str(object_id), (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n", - "\n", - "\n", - " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", - " # Write the frame to the video file\n", - " out.write(frame)\n", - "out.release()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "reid_processor.all_tracked_objects" + "for sequence in SEQUENCES :\n", + " frame_path = get_sequence_frames(sequence)\n", + " test_sequence = Sequence(frame_path)\n", + " test_sequence\n", + " frame_id = 0\n", + " BaseTrack._count = 0\n", + "\n", + " # Define the codec using VideoWriter_fourcc() and create a VideoWriter object\n", + " fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", + " out = cv2.VideoWriter(f'{sequence}.mp4', fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", + "\n", + "\n", + " detection_handler = DetectionHandler(image_shape=[2560, 1440])\n", + " tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n", + " reid_processor = ReidProcessor(filter_confidence_threshold=0.1,\n", + " filter_time_threshold=5,\n", + " cost_function=bounding_box_distance,\n", + " #cost_function_threshold=500, # max cost to rematch 2 objects\n", + " selection_function=select_by_category,\n", + " max_attempt_to_match=5,\n", + " max_frames_to_rematch=500)\n", + "\n", + " for frame, detection in tqdm(test_sequence):\n", + " frame = np.array(frame)\n", + "\n", + " frame_id += 1\n", + "\n", + " processed_detections = detection_handler.process(detection)\n", + " processed_tracked = tracking_handler.update(processed_detections, frame_id)\n", + " reid_results = reid_processor.update(processed_tracked, frame_id)\n", + "\n", + " if len(reid_results) > 0:\n", + " for res in reid_results:\n", + " object_id = int(res[OUTPUT_POSITIONS[\"object_id\"]])\n", + " bbox = list(map(int, res[OUTPUT_POSITIONS[\"bbox\"]]))\n", + " class_id = int(res[OUTPUT_POSITIONS[\"category\"]])\n", + " tracker_id = int(res[OUTPUT_POSITIONS[\"tracker_id\"]])\n", + " mean_confidence = float(res[OUTPUT_POSITIONS[\"mean_confidence\"]])\n", + " #mean_confidence_per_object[object_id].append((frame_id, mean_confidence))\n", + " x1, y1, x2, y2 = bbox\n", + " color = (0, 0, 255) if class_id else (0, 255, 0) # green for class 0, red for class 1\n", + " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", + " cv2.putText(frame, f\"{object_id} ({tracker_id}) : {round(mean_confidence,2)}\", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n", + " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", + "\n", + " # Write the frame to the video file\n", + " out.write(frame)\n", + " out.release()\n", + "\n", + " print(sequence, len(reid_processor.seen_objects),reid_processor.nb_corrections)\n", + " print(reid_processor.seen_objects)\n" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 82409cc..df722d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,11 +7,14 @@ readme = "README.md" [tool.poetry.dependencies] -python = ">=3.8, <3.11.0" -pandas = "1.5.3" +python = "3.10.13" numpy = "1.24.2" llist = "0.7.1" pydantic = "2.4.2" +lapx = "^0.5.5" +opencv-python = "^4.8.1.78" +tqdm = "^4.66.1" +pillow = "^10.1.0" bytetracker = {git = "git@github.com:artefactory-fr/bytetrack.git", branch = "main"} @@ -21,12 +24,13 @@ ruff = "0.0.272" isort = "5.12.0" pre-commit = "3.3.3" pytest = "7.3.2" +ipykernel = "6.24.0" mkdocs = "1.4.3" mkdocs-material = "9.1.15" mkdocstrings-python = "1.1.2" bandit = "1.7.5" nbstripout = "0.6.1" -ipykernel = "6.24.0" + [build-system] diff --git a/tests/data/unit_tests/tracked_objects/object_1.json b/tests/data/unit_tests/tracked_objects/object_1.json new file mode 100644 index 0000000..13a9864 --- /dev/null +++ b/tests/data/unit_tests/tracked_objects/object_1.json @@ -0,0 +1,28 @@ +{ + "object_id": 1.0, + "state": 0, + "re_id_chain": [ + 1.0, + 2.0, + 14.0, + 18.0, + 21.0 + ], + "metadata": { + "first_frame_id": 15, + "last_frame_id": 251, + "class_counts": { + "shop_item": 175, + "personal_item": 0 + }, + "bbox": [ + 598, + 208, + 814, + 447 + ], + "confidence": 0.610211, + "confidence_sum": 111.30582399999996, + "observations": 175 + } +} diff --git a/tests/data/unit_tests/tracked_objects/object_24.json b/tests/data/unit_tests/tracked_objects/object_24.json new file mode 100644 index 0000000..29ae090 --- /dev/null +++ b/tests/data/unit_tests/tracked_objects/object_24.json @@ -0,0 +1,24 @@ +{ + "object_id": 24.0, + "state": -2, + "re_id_chain": [ + 24.0 + ], + "metadata": { + "first_frame_id": 154, + "last_frame_id": 251, + "class_counts": { + "shop_item": 2, + "personal_item": 0 + }, + "bbox": [ + 1430, + 664, + 1531, + 830 + ], + "confidence": 0.48447, + "confidence_sum": 1.108755, + "observations": 2 + } +} diff --git a/tests/data/unit_tests/tracked_objects/object_4.json b/tests/data/unit_tests/tracked_objects/object_4.json new file mode 100644 index 0000000..facc938 --- /dev/null +++ b/tests/data/unit_tests/tracked_objects/object_4.json @@ -0,0 +1,25 @@ +{ + "object_id": 4.0, + "state": 0, + "re_id_chain": [ + 4.0, + 13.0 + ], + "metadata": { + "first_frame_id": 38, + "last_frame_id": 251, + "class_counts": { + "shop_item": 0, + "personal_item": 216 + }, + "bbox": [ + 548, + 455, + 846, + 645 + ], + "confidence": 0.700626, + "confidence_sum": 149.68236100000004, + "observations": 216 + } +} diff --git a/tests/unit_tests/test_matcher.py b/tests/unit_tests/test_matcher.py new file mode 100644 index 0000000..e75aa05 --- /dev/null +++ b/tests/unit_tests/test_matcher.py @@ -0,0 +1,84 @@ +import json +from pathlib import Path + +from trackreid.matcher import Matcher +from trackreid.tracked_object import TrackedObject + +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_OBJECTS = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_OBJECTS.append(TrackedObject.from_dict(json.load(file))) + + +def test_matcher_no_match(): + def dummy_cost_function(candidate, switcher): + return abs(candidate.object_id - switcher.object_id) + + def dummy_selection_function(candidate, switcher): # noqa: ARG001 + return 0 + + matcher = Matcher(dummy_cost_function, dummy_selection_function) + + candidates = [] + switchers = [] + for obj in ALL_TRACKED_OBJECTS: + candidates.append(obj) + switchers.append(obj) + + matches = matcher.match(candidates, switchers) + + assert len(matches) == 0 + + +def test_matcher_all_match(): + def dummy_cost_function(candidate, switcher): + return abs(candidate.object_id - switcher.object_id) + + def dummy_selection_function(candidate, switcher): # noqa: ARG001 + return 1 + + matcher = Matcher(dummy_cost_function, dummy_selection_function) + + candidates = [] + switchers = [] + for obj in ALL_TRACKED_OBJECTS: + candidates.append(obj) + switchers.append(obj) + + print(candidates) + print(switchers) + + matches = matcher.match(candidates, switchers) + + assert len(matches) == 3 + for i in range(3): + assert matches[i][candidates[i]] == switchers[i] + + +def test_matcher_middle_case(): + def dummy_cost_function(candidate, switcher): + return abs(candidate.object_id - switcher.object_id) + + def dummy_selection_function(candidate, switcher): + return (candidate.object_id % 2 == switcher.object_id % 2) and ( + candidate.object_id != switcher.object_id + ) + + matcher = Matcher(dummy_cost_function, dummy_selection_function) + + candidates = [] + switchers = [] + for obj in ALL_TRACKED_OBJECTS: + candidates.append(obj) + switchers.append(obj) + + matches = matcher.match(candidates, switchers) + print(matches) + + assert len(matches) == 2 + for match in matches: + for candidate, switcher in match.items(): + assert candidate.object_id % 2 == switcher.object_id % 2 diff --git a/tests/unit_tests/test_metadata.py b/tests/unit_tests/test_metadata.py new file mode 100644 index 0000000..df21166 --- /dev/null +++ b/tests/unit_tests/test_metadata.py @@ -0,0 +1,52 @@ +import json +from pathlib import Path + +from trackreid.tracked_object_metadata import TrackedObjectMetaData + +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_METADATA = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_METADATA.append(TrackedObjectMetaData.from_dict(json.load(file)["metadata"])) + + +def test_tracked_metadata_copy(): + tracked_metadata = ALL_TRACKED_METADATA[0].copy() + copied_metadata = tracked_metadata.copy() + assert copied_metadata.first_frame_id == 15 + assert copied_metadata.last_frame_id == 251 + assert copied_metadata.class_counts == {"shop_item": 175, "personal_item": 0} + assert copied_metadata.bbox == [598, 208, 814, 447] + assert copied_metadata.confidence == 0.610211 + assert copied_metadata.confidence_sum == 111.30582399999996 + assert copied_metadata.observations == 175 + + assert round(copied_metadata.percentage_of_time_seen(251), 2) == 73.84 + class_proportions = copied_metadata.class_proportions() + assert round(class_proportions["shop_item"], 2) == 1.0 + assert round(class_proportions["personal_item"], 2) == 0.0 + + tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy() + tracked_metadata.merge(tracked_metadata_2) + # test impact of merge inplace in a copy, should be none + + assert copied_metadata.class_counts == {"shop_item": 175, "personal_item": 0} + assert copied_metadata.bbox == [598, 208, 814, 447] + assert copied_metadata.confidence == 0.610211 + assert copied_metadata.confidence_sum == 111.30582399999996 + assert copied_metadata.observations == 175 + + +def test_tracked_metadata_merge(): + tracked_metadata_1 = ALL_TRACKED_METADATA[0].copy() + tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy() + tracked_metadata_1.merge(tracked_metadata_2) + assert tracked_metadata_1.last_frame_id == 251 + assert tracked_metadata_1.class_counts["shop_item"] == 175 + assert tracked_metadata_1.class_counts["personal_item"] == 216 + assert tracked_metadata_1.bbox == [548, 455, 846, 645] + assert tracked_metadata_1.confidence == 0.700626 + assert tracked_metadata_1.confidence_sum == 260.988185 + assert tracked_metadata_1.observations == 391 diff --git a/tests/unit_tests/test_placeholder.py b/tests/unit_tests/test_placeholder.py deleted file mode 100644 index 338a8e0..0000000 --- a/tests/unit_tests/test_placeholder.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Placeholder test file for unit tests. To be replaced with actual tests.""" - - -def test_placeholder() -> None: - """To be replaced with actual tests.""" - pass diff --git a/tests/unit_tests/test_tracked_objects.py b/tests/unit_tests/test_tracked_objects.py new file mode 100644 index 0000000..76c2575 --- /dev/null +++ b/tests/unit_tests/test_tracked_objects.py @@ -0,0 +1,96 @@ +import json +from pathlib import Path + +from trackreid.tracked_object import TrackedObject + +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_OBJECTS = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_OBJECTS.append(TrackedObject.from_dict(json.load(file))) + + +def test_tracked_object_copy(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + copied_object = tracked_object.copy() + assert copied_object.object_id == tracked_object.object_id + assert copied_object.state == tracked_object.state + assert copied_object.category == tracked_object.category + assert round(copied_object.confidence, 2) == round(tracked_object.confidence, 2) + assert round(copied_object.mean_confidence, 2) == round(tracked_object.mean_confidence, 2) + assert copied_object.bbox == tracked_object.bbox + assert copied_object.nb_ids == tracked_object.nb_ids + assert copied_object.nb_corrections == tracked_object.nb_corrections + + tracked_object_2 = ALL_TRACKED_OBJECTS[1].copy() + tracked_object.merge(tracked_object_2) + + assert round(copied_object.confidence, 2) != round(tracked_object.confidence, 2) + assert round(copied_object.mean_confidence, 2) != round(tracked_object.mean_confidence, 2) + assert copied_object.bbox != tracked_object.bbox + assert copied_object.nb_ids != tracked_object.nb_ids + assert copied_object.nb_corrections != tracked_object.nb_corrections + + +def test_tracked_object_properties(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.object_id == 1.0 + assert tracked_object.state == 0 + assert tracked_object.category == "shop_item" + assert round(tracked_object.confidence, 2) == 0.61 + assert round(tracked_object.mean_confidence, 2) == 0.64 + assert tracked_object.bbox == [598, 208, 814, 447] + assert tracked_object.nb_ids == 5 + assert tracked_object.nb_corrections == 4 + + +def test_tracked_object_merge(): + tracked_object_1 = ALL_TRACKED_OBJECTS[0].copy() + tracked_object_2 = ALL_TRACKED_OBJECTS[1].copy() + tracked_object_1.merge(tracked_object_2) + assert tracked_object_1.object_id == 1.0 + assert tracked_object_1.state == 0 + assert tracked_object_1.category == "personal_item" + assert round(tracked_object_1.confidence, 2) == 0.70 + assert round(tracked_object_1.mean_confidence, 2) == 0.67 + assert tracked_object_1.bbox == [548, 455, 846, 645] + assert tracked_object_1.nb_ids == 7 + assert tracked_object_1.nb_corrections == 6 + + +def test_tracked_object_cut(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + new_object, cut_object = tracked_object.cut(2.0) + assert new_object.object_id == 14.0 + assert new_object.state == 0 + assert new_object.category == "shop_item" + assert round(new_object.confidence, 2) == 0.61 + assert round(new_object.mean_confidence, 2) == 0.64 + assert new_object.bbox == [598, 208, 814, 447] + assert new_object.nb_ids == 3 + assert new_object.nb_corrections == 2 + assert cut_object.object_id == 1.0 + assert cut_object.state == 0 + assert cut_object.category == "shop_item" + assert round(cut_object.confidence, 2) == 0.61 + assert round(cut_object.mean_confidence, 2) == 0.64 + assert cut_object.bbox == [598, 208, 814, 447] + assert cut_object.nb_ids == 2 + assert cut_object.nb_corrections == 1 + + +def test_get_age(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.get_age(100) == 85 + + +def test_get_nb_frames_since_last_appearance(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.get_nb_frames_since_last_appearance(300) == 49 + + +def test_get_state(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.get_state() == 0 diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py new file mode 100644 index 0000000..2a15f30 --- /dev/null +++ b/tests/unit_tests/test_utils.py @@ -0,0 +1,91 @@ +import json +from pathlib import Path + +import numpy as np +from llist import sllist + +from trackreid import utils +from trackreid.tracked_object import TrackedObject + +# Load tracked object data +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_OBJECTS = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_OBJECTS.append(TrackedObject.from_dict(json.load(file))) + + +# Define tests +def test_get_top_list_correction(): + top_list_correction = utils.get_top_list_correction(ALL_TRACKED_OBJECTS) + assert top_list_correction == [21.0, 13.0, 24.0] + + +def test_split_list_around_value_1(): + my_list = sllist([1, 2, 3, 4, 5]) + value_to_split = 3 + before, after = utils.split_list_around_value(my_list, value_to_split) + assert list(before) == [1, 2, 3] + assert list(after) == [4, 5] + + +def test_split_list_around_value_2(): + my_list = sllist([1, 2, 3, 4, 5]) + value_to_split = 1 + before, after = utils.split_list_around_value(my_list, value_to_split) + assert list(before) == [1] + assert list(after) == [2, 3, 4, 5] + + +def test_split_list_around_value_3(): + my_list = sllist([1, 2, 3, 4, 5]) + value_to_split = 4 + before, after = utils.split_list_around_value(my_list, value_to_split) + assert list(before) == [1, 2, 3, 4] + assert list(after) == [5] + + +def test_filter_objects_by_state(): + states = 0 + assert utils.filter_objects_by_state(ALL_TRACKED_OBJECTS, states, exclusion=False) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[1], + ] + + +def test_filter_objects_by_state_2(): + states = -2 + assert utils.filter_objects_by_state(ALL_TRACKED_OBJECTS, states, exclusion=True) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[1], + ] + + +def test_filter_objects_by_category(): + category = "shop_item" + assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=False) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[2], + ] + + +def test_filter_objects_by_category_2(): + category = "personal_item" + assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=True) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[2], + ] + + +def test_reshape_tracker_result(): + tracker_output = np.array([1, 1, 3, 4, 5, 6, 7]) + assert np.array_equal( + utils.reshape_tracker_result(tracker_output), np.array([[1, 1, 3, 4, 5, 6, 7]]) + ) + + +def test_get_nb_output_cols(): + output_positions = {"feature1": 1, "feature2": [1, 2, 3]} + assert utils.get_nb_output_cols(output_positions) == 4 diff --git a/trackreid/args/reid_args.py b/trackreid/args/reid_args.py index 3d61812..2e81095 100644 --- a/trackreid/args/reid_args.py +++ b/trackreid/args/reid_args.py @@ -15,4 +15,5 @@ "bbox": [3, 4, 5, 6], "confidence": 7, "mean_confidence": 8, + "tracker_id": 9, } diff --git a/trackreid/constants/reid_constants.py b/trackreid/constants/reid_constants.py index 65c5a95..299229a 100644 --- a/trackreid/constants/reid_constants.py +++ b/trackreid/constants/reid_constants.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -class ReidConstants(BaseModel): +class States(BaseModel): LOST_FOREVER: int = -3 TRACKER_OUTPUT: int = -2 FILTERED_OUTPUT: int = -1 @@ -21,4 +21,13 @@ class ReidConstants(BaseModel): } +class Matches(BaseModel): + DISALLOWED_MATCH: int = 1e6 + + +class ReidConstants(BaseModel): + STATES: States = States() + MATCHES: Matches = Matches() + + reid_constants = ReidConstants() diff --git a/trackreid/matcher.py b/trackreid/matcher.py index b3bcc26..7f58723 100644 --- a/trackreid/matcher.py +++ b/trackreid/matcher.py @@ -1,91 +1,112 @@ -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional, Union +import lap import numpy as np -from scipy.optimize import linear_sum_assignment +from trackreid.constants.reid_constants import reid_constants from trackreid.tracked_object import TrackedObject class Matcher: - def __init__(self, cost_function: Callable, selection_function: Callable) -> None: + def __init__( + self, + cost_function: Callable, + selection_function: Callable, + cost_function_threshold: Optional[Union[int, float]] = None, + ) -> None: self.cost_function = cost_function self.selection_function = selection_function + self.cost_function_threshold = cost_function_threshold def compute_cost_matrix( - self, objects1: List[TrackedObject], objects2: List[TrackedObject] + self, candidates: List[TrackedObject], switchers: List[TrackedObject] ) -> np.ndarray: - """Computes a cost matrix of size [M, N] between a list of M TrackedObjects objects1, - and a list of N TrackedObjects objects2. + """Computes a cost matrix of size [M, N] between a list of M TrackedObjects candidates, + and a list of N TrackedObjects switchers. Args: - objects1 (List[TrackedObject]): list of objects to be matched. - objects2 (List[TrackedObject]): list of candidates for matches. + candidates (List[TrackedObject]): list of candidates for matches. + switchers (List[TrackedObject]): list of objects to be matched. Returns: np.ndarray: cost to match each pair of objects. """ - if not objects1 or not objects2: + if not candidates or not switchers: return np.array([]) # Return an empty array if either list is empty - # Create matrices with all combinations of objects1 and objects2 - objects1_matrix, objects2_matrix = np.meshgrid(objects1, objects2) + # Create matrices with all combinations of candidates and switchers + candidates_matrix, switchers_matrix = np.meshgrid(candidates, switchers) # Use np.vectorize to apply the scoring function to all combinations - cost_matrix = np.vectorize(self.cost_function)(objects1_matrix, objects2_matrix) + cost_matrix = np.vectorize(self.cost_function)(candidates_matrix, switchers_matrix) return cost_matrix def compute_selection_matrix( - self, objects1: List[TrackedObject], objects2: List[TrackedObject] + self, candidates: List[TrackedObject], switchers: List[TrackedObject] ) -> np.ndarray: - """Computes a selection matrix of size [M, N] between a list of M TrackedObjects objects1, - and a list of N TrackedObjects objects2. + """Computes a selection matrix of size [M, N] between a list of M TrackedObjects candidates, + and a list of N TrackedObjects switchers. Args: - objects1 (List[TrackedObject]): list of objects to be matched. - objects2 (List[TrackedObject]): list of candidates for matches. + candidates (List[TrackedObject]): list of candidates for matches. + switchers (List[TrackedObject]): list of objects to be rematched. Returns: np.ndarray: cost each pair of objects be matched or not ? """ - if not objects1 or not objects2: + if not candidates or not switchers: return np.array([]) # Return an empty array if either list is empty - # Create matrices with all combinations of objects1 and objects2 - objects1_matrix, objects2_matrix = np.meshgrid(objects1, objects2) + # Create matrices with all combinations of candidates and switchers + candidates_matrix, switchers_matrix = np.meshgrid(candidates, switchers) # Use np.vectorize to apply the scoring function to all combinations - selection_matrix = np.vectorize(self.selection_function)(objects1_matrix, objects2_matrix) + selection_matrix = np.vectorize(self.selection_function)( + candidates_matrix, switchers_matrix + ) return selection_matrix def match( - self, objects1: List[TrackedObject], objects2: List[TrackedObject] + self, candidates: List[TrackedObject], switchers: List[TrackedObject] ) -> List[Dict[TrackedObject, TrackedObject]]: - """Computes a dict of matching between objects in list objects1 and objects in objects2. + """Computes a dict of matching between objects in list candidates and objects in switchers. Args: - objects1 (List[TrackedObject]): list of objects to be matched. - objects2 (List[TrackedObject]): list of candidates for matches. + candidates (List[TrackedObject]): list of candidates for matches. + switchers (List[TrackedObject]): list of objects to be matched. Returns: List[Dict[TrackedObject, TrackedObject]]: list of pairs of TrackedObjects if there is a match. """ - if not objects1 or not objects2: + if not candidates or not switchers: return [] # Return an empty array if either list is empty - cost_matrix = self.compute_cost_matrix(objects1, objects2) - selection_matrix = self.compute_selection_matrix(objects1, objects2) + cost_matrix = self.compute_cost_matrix(candidates, switchers) + selection_matrix = self.compute_selection_matrix(candidates, switchers) - # Set a large cost value for elements to be discarded - cost_matrix[selection_matrix == 0] = 1e3 + # Set a elements values to be discard at DISALLOWED_MATCH value, large cost + cost_matrix[selection_matrix == 0] = reid_constants.MATCHES.DISALLOWED_MATCH + if self.cost_function_threshold is not None: + cost_matrix[ + cost_matrix > self.cost_function_threshold + ] = reid_constants.MATCHES.DISALLOWED_MATCH - # Find the best matches using the linear sum assignment - row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=False) + matches = self.linear_assigment(cost_matrix, candidates=candidates, switchers=switchers) + + return matches + + @staticmethod + def linear_assigment(cost_matrix, candidates, switchers): + _, _, row_cols = lap.lapjv( + cost_matrix, extend_cost=True, cost_limit=reid_constants.MATCHES.DISALLOWED_MATCH - 0.1 + ) matches = [] - for row, col in zip(row_indices, col_indices): - matches.append({objects1[col]: objects2[row]}) + for candidate_idx, switcher_idx in enumerate(row_cols): + if switcher_idx >= 0: + matches.append({candidates[candidate_idx]: switchers[switcher_idx]}) return matches diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index fe5e919..13731da 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Dict, List, Set, Union +from typing import Callable, Dict, List, Optional, Set, Union import numpy as np -from trackreid.args.reid_args import INPUT_POSITIONS, OUTPUT_POSITIONS +from trackreid.args.reid_args import INPUT_POSITIONS, MAPPING_CLASSES, OUTPUT_POSITIONS from trackreid.constants.reid_constants import reid_constants from trackreid.matcher import Matcher from trackreid.tracked_object import TrackedObject @@ -20,14 +20,32 @@ class ReidProcessor: def __init__( self, - filter_confidence_threshold, - filter_time_threshold, - cost_function, - selection_function, - max_frames_to_rematch: int = 100, - max_attempt_to_rematch: int = 1, + filter_confidence_threshold: float, + filter_time_threshold: int, + cost_function: Callable, + selection_function: Callable, + max_frames_to_rematch: int, + max_attempt_to_match: int, + cost_function_threshold: Optional[Union[int, float]] = None, ) -> None: - self.matcher = Matcher(cost_function=cost_function, selection_function=selection_function) + """ + Initializes the ReidProcessor class. + + Args: + filter_confidence_threshold: Confidence threshold for the filter. + filter_time_threshold: Time threshold for the filter. + cost_function: Cost function to be used. + selection_function: Selection function to be used. + max_frames_to_rematch (int): Maximum number of frames to rematch. + max_attempt_to_match (int): Maximum attempts to match. + cost_function_threshold (int, float): Maximum cost to rematch 2 objects. + """ + + self.matcher = Matcher( + cost_function=cost_function, + selection_function=selection_function, + cost_function_threshold=cost_function_threshold, + ) self.tracked_filter = TrackedObjectFilter( confidence_threshold=filter_confidence_threshold, @@ -37,35 +55,98 @@ def __init__( self.all_tracked_objects: List[TrackedObject] = [] self.last_frame_tracked_objects: Set[TrackedObject] = set() - self.switchers: List[TrackedObject] = [] - self.candidates: List[TrackedObject] = [] - self.max_frames_to_rematch = max_frames_to_rematch - self.max_attempt_to_rematch = max_attempt_to_rematch + self.max_attempt_to_match = max_attempt_to_match self.frame_id = 0 self.nb_output_cols = get_nb_output_cols(output_positions=OUTPUT_POSITIONS) - def process(self, tracker_output: np.ndarray, frame_id: int): + @property + def nb_corrections(self) -> int: + nb_corrections = 0 + for obj in self.all_tracked_objects: + nb_corrections += obj.nb_corrections + return nb_corrections + + @property + def nb_tracker_ids(self) -> int: + tracker_ids = 0 + for obj in self.all_tracked_objects: + tracker_ids += obj.nb_ids + return tracker_ids + + @property + def corrected_objects(self) -> List["TrackedObject"]: + return [obj for obj in self.all_tracked_objects if obj.nb_corrections] + + @property + def seen_objects(self) -> List["TrackedObject"]: + return filter_objects_by_state( + tracked_objects=self.all_tracked_objects, + states=[reid_constants.STATES.TRACKER_OUTPUT, reid_constants.STATES.FILTERED_OUTPUT], + exclusion=True, + ) + + @property + def mean_nb_corrections(self) -> float: + return self.nb_corrections / len(self.all_tracked_objects) + + def update( + self, tracker_output: np.ndarray, frame_id: int + ) -> Union[np.ndarray, List[TrackedObject]]: + """ + Processes the tracker output. + + Args: + tracker_output (np.ndarray): The tracker output. + frame_id (int): The frame id. + + Returns: + Union[np.ndarray, List[TrackedObject]]: The processed output. + """ if tracker_output.size: # empty tracking - reshaped_tracker_output = reshape_tracker_result(tracker_output=tracker_output) - self.all_tracked_objects = self._preprocess( - tracker_output=reshaped_tracker_output, frame_id=frame_id + self.all_tracked_objects, current_tracker_ids = self._preprocess( + tracker_output=tracker_output, frame_id=frame_id ) - self._perform_reid_process(tracker_output=reshaped_tracker_output) - reid_output = self._postprocess(tracker_output=tracker_output) + self._perform_reid_process(current_tracker_ids=current_tracker_ids) + reid_output = self._postprocess(current_tracker_ids=current_tracker_ids) return reid_output else: return tracker_output def _preprocess(self, tracker_output: np.ndarray, frame_id: int) -> List["TrackedObject"]: + """ + Preprocesses the tracker output. + + Args: + tracker_output (np.ndarray): The tracker output. + frame_id (int): The frame id. + + Returns: + List["TrackedObject"]: The preprocessed output. + """ + reshaped_tracker_output = reshape_tracker_result(tracker_output=tracker_output) + current_tracker_ids = list(reshaped_tracker_output[:, INPUT_POSITIONS["object_id"]]) + self.all_tracked_objects = self._update_tracked_objects( - tracker_output=tracker_output, frame_id=frame_id + tracker_output=reshaped_tracker_output, frame_id=frame_id ) self.all_tracked_objects = self._apply_filtering() - return self.all_tracked_objects + return self.all_tracked_objects, current_tracker_ids + + def _update_tracked_objects( + self, tracker_output: np.ndarray, frame_id: int + ) -> List[TrackedObject]: + """ + Updates the tracked objects. - def _update_tracked_objects(self, tracker_output: np.ndarray, frame_id: int): + Args: + tracker_output (np.ndarray): The tracker output. + frame_id (int): The frame id. + + Returns: + List[TrackedObject]: The updated tracked objects. + """ self.frame_id = frame_id for object_id, data_line in zip( tracker_output[:, INPUT_POSITIONS["object_id"]], tracker_output @@ -73,7 +154,7 @@ def _update_tracked_objects(self, tracker_output: np.ndarray, frame_id: int): if object_id not in self.all_tracked_objects: new_tracked_object = TrackedObject( object_ids=object_id, - state=reid_constants.TRACKER_OUTPUT, + state=reid_constants.STATES.TRACKER_OUTPUT, frame_id=frame_id, metadata=data_line, ) @@ -85,214 +166,320 @@ def _update_tracked_objects(self, tracker_output: np.ndarray, frame_id: int): return self.all_tracked_objects - def _get_current_tracked_objects(self, current_tracker_ids: Set[Union[int, float]]): + def _get_current_frame_tracked_objects( + self, current_tracker_ids: Set[Union[int, float]] + ) -> Set[Union[int, float]]: + """ + Retrieves the tracked objects for the current frame. + + Args: + current_tracker_ids (Set[Union[int, float]]): The set of current tracker IDs. + + Returns: + Set[Union[int, float]]: The set of tracked objects for the current frame. + """ tracked_objects = filter_objects_by_state( - self.all_tracked_objects, states=reid_constants.TRACKER_OUTPUT, exclusion=True + self.all_tracked_objects, states=reid_constants.STATES.TRACKER_OUTPUT, exclusion=True ) - current_tracked_objects = set( + current_frame_tracked_objects = set( [tracked_id for tracked_id in tracked_objects if tracked_id in current_tracker_ids] ) - return tracked_objects, current_tracked_objects + return current_frame_tracked_objects - def _apply_filtering(self): + def _apply_filtering(self) -> List[TrackedObject]: + """ + Applies filtering to the tracked objects. + + Returns: + List[TrackedObject]: The filtered tracked objects. + """ for tracked_object in self.all_tracked_objects: self.tracked_filter.update(tracked_object) return self.all_tracked_objects - def _perform_reid_process(self, tracker_output: np.ndarray): - current_tracker_ids: List[Union[int, float]] = list( - tracker_output[:, INPUT_POSITIONS["object_id"]] - ) - - # TODO: we can get rid of self.switchers and self.candidates by - # applying: - # candidates = filter_objects_by_state( - # self.all_tracked_objects, states=reid_constants.CANDIDATE, exclusion=False - # ) - # switchers = filter_objects_by_state( - # self.all_tracked_objects, states=reid_constants.SWITCHER, exclusion=False - # ) + def _perform_reid_process(self, current_tracker_ids: List[Union[int, float]]) -> None: + """ + Performs the reid process. - self.all_tracked_objects, self.switchers = self.correct_reid_chains( - all_tracked_objects=self.all_tracked_objects, - current_tracker_ids=current_tracker_ids, - switchers=self.switchers, + Args: + current_tracker_ids (List[Union[int, float]]): The current tracker IDs. + """ + self.all_tracked_objects = self.correct_reid_chains( + all_tracked_objects=self.all_tracked_objects, current_tracker_ids=current_tracker_ids ) - tracked_objects, current_tracked_objects = self._get_current_tracked_objects( + current_frame_tracked_objects = self._get_current_frame_tracked_objects( current_tracker_ids=current_tracker_ids ) - self.switchers = self.drop_switchers( - switchers=self.switchers, - current_tracked_objects=current_tracked_objects, + self.all_tracked_objects = self.update_switchers_states( + all_tracked_objects=self.all_tracked_objects, + current_frame_tracked_objects=current_frame_tracked_objects, max_frames_to_rematch=self.max_frames_to_rematch, frame_id=self.frame_id, ) - self.candidates = self.drop_candidates( - self.candidates, self.max_attempt_to_rematch, self.frame_id + self.all_tracked_objects = self.update_candidates_states( + all_tracked_objects=self.all_tracked_objects, + max_attempt_to_match=self.max_attempt_to_match, + frame_id=self.frame_id, ) - self.candidates.extend(self.identify_candidates(tracked_objects=tracked_objects)) + self.all_tracked_objects = self.identify_switchers( + current_frame_tracked_objects=current_frame_tracked_objects, + last_frame_tracked_objects=self.last_frame_tracked_objects, + all_tracked_objects=self.all_tracked_objects, + ) - self.switchers.extend( - self.identify_switchers( - current_tracked_objects=current_tracked_objects, - last_frame_tracked_objects=self.last_frame_tracked_objects, - all_tracked_objects=self.all_tracked_objects, - ) + self.all_tracked_objects = self.identify_candidates( + all_tracked_objects=self.all_tracked_objects + ) + + candidates = filter_objects_by_state( + self.all_tracked_objects, states=reid_constants.STATES.CANDIDATE, exclusion=False + ) + switchers = filter_objects_by_state( + self.all_tracked_objects, states=reid_constants.STATES.SWITCHER, exclusion=False ) - matches = self.matcher.match(self.candidates, self.switchers) + matches = self.matcher.match(candidates, switchers) - self.all_tracked_objects, self.switchers, self.candidates = self.process_matches( + self.all_tracked_objects = self.process_matches( all_tracked_objects=self.all_tracked_objects, matches=matches, - candidates=self.candidates, - switchers=self.switchers, ) - _, current_tracked_objects = self._get_current_tracked_objects( + current_frame_tracked_objects = self._get_current_frame_tracked_objects( current_tracker_ids=current_tracker_ids ) - self.last_frame_tracked_objects = current_tracked_objects.copy() + self.last_frame_tracked_objects = current_frame_tracked_objects.copy() @staticmethod def identify_switchers( all_tracked_objects: List["TrackedObject"], - current_tracked_objects: Set["TrackedObject"], + current_frame_tracked_objects: Set["TrackedObject"], last_frame_tracked_objects: Set["TrackedObject"], - ): - switchers = [] - lost_ids = last_frame_tracked_objects - current_tracked_objects + ) -> List["TrackedObject"]: + """ + Identifies switchers in the list of all tracked objects, and + update their states. A switcher is an object that is lost, and probably + needs to be rematched. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + current_frame_tracked_objects (Set["TrackedObject"]): Set of currently tracked objects. + last_frame_tracked_objects Set["TrackedObject"]: Set of last timestep tracked objects. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + lost_objects = last_frame_tracked_objects - current_frame_tracked_objects - for tracked_id in all_tracked_objects: - if tracked_id in lost_ids: - switchers.append(tracked_id) - tracked_id.state = reid_constants.SWITCHER + for tracked_object in all_tracked_objects: + if tracked_object in lost_objects: + tracked_object.state = reid_constants.STATES.SWITCHER - return switchers + return all_tracked_objects @staticmethod - def identify_candidates(tracked_objects: List["TrackedObject"]): - candidates = [] + def identify_candidates(all_tracked_objects: List["TrackedObject"]) -> List["TrackedObject"]: + """ + Identifies candidates in the list of all tracked objects, and + update their states. A candidate is an object that was never seen before and + that probably needs to be rematched. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + tracked_objects = filter_objects_by_state( + all_tracked_objects, states=reid_constants.STATES.TRACKER_OUTPUT, exclusion=True + ) for current_object in tracked_objects: - if current_object.state == reid_constants.FILTERED_OUTPUT: - current_object.state = reid_constants.CANDIDATE - candidates.append(current_object) - return candidates + if current_object.state == reid_constants.STATES.FILTERED_OUTPUT: + current_object.state = reid_constants.STATES.CANDIDATE + return all_tracked_objects @staticmethod def correct_reid_chains( all_tracked_objects: List["TrackedObject"], current_tracker_ids: List[Union[int, float]], - switchers: List["TrackedObject"], - ): + ) -> List["TrackedObject"]: + """ + Corrects the reid chains to prevent duplicates when an object reappears with a corrected id. + For instance, if an object has a reid chain [1, 3, 6, 7], only the id 7 should be in the tracker's output. + If another id from the chain (e.g., 3) is in the tracker's output, the reid chain is split into two: + [1, 3] and [6, 7]. The first object's state is set to stable as 3 is in the current tracker output, + and a new object with reid chain [6, 7] is created. + The new object's state can be: + - stable, if the tracker output is in the new reid chain + - switcher, if not + - nothing, if this is a singleton object, in which case the reid process is performed automatically. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + current_tracker_ids (List[Union[int, float]]): The current tracker IDs. + + Returns: + List["TrackedObject"]: The corrected tracked objects. + """ top_list_correction = get_top_list_correction(all_tracked_objects) + to_correct = set(current_tracker_ids) - set(top_list_correction) - for current_object in current_tracker_ids: + for current_object in to_correct: tracked_id = all_tracked_objects[all_tracked_objects.index(current_object)] - object_state = tracked_id.state - if current_object not in top_list_correction: - all_tracked_objects.remove(tracked_id) - if object_state == reid_constants.SWITCHER: - switchers.remove(tracked_id) - - new_object, tracked_id = tracked_id.cut(current_object) + all_tracked_objects.remove(tracked_id) + new_object, tracked_id = tracked_id.cut(current_object) - tracked_id.state = reid_constants.STABLE - all_tracked_objects.append(tracked_id) + tracked_id.state = reid_constants.STATES.STABLE + all_tracked_objects.append(tracked_id) - # 2 cases to take : - if new_object in current_tracker_ids: - new_object.state = reid_constants.STABLE - all_tracked_objects.append(new_object) + if new_object in current_tracker_ids: + new_object.state = reid_constants.STATES.STABLE + all_tracked_objects.append(new_object) - elif new_object.nb_corrections > 1: - new_object.state = reid_constants.SWITCHER - switchers.append(new_object) - all_tracked_objects.append(new_object) + elif new_object.nb_corrections > 1: + new_object.state = reid_constants.STATES.SWITCHER + all_tracked_objects.append(new_object) - return all_tracked_objects, switchers + return all_tracked_objects @staticmethod def process_matches( all_tracked_objects: List["TrackedObject"], matches: Dict["TrackedObject", "TrackedObject"], - switchers: List["TrackedObject"], - candidates: List["TrackedObject"], - ): + ) -> List["TrackedObject"]: + """ + Processes the matches. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + matches (Dict["TrackedObject", "TrackedObject"]): The matches. + + Returns: + List["TrackedObject"]: The processed tracked objects. + """ for match in matches: candidate_match, switcher_match = match.popitem() - switcher_match.merge(candidate_match) - switcher_match.state = reid_constants.STABLE + switcher_match.state = reid_constants.STATES.STABLE all_tracked_objects.remove(candidate_match) - switchers.remove(switcher_match) - candidates.remove(candidate_match) - return all_tracked_objects, switchers, candidates + return all_tracked_objects @staticmethod - def drop_switchers( - switchers: List["TrackedObject"], - current_tracked_objects: Set["TrackedObject"], + def update_switchers_states( + all_tracked_objects: List["TrackedObject"], + current_frame_tracked_objects: Set["TrackedObject"], max_frames_to_rematch: int, frame_id: int, - ): - switchers_to_drop = set(switchers).intersection(current_tracked_objects) - filtered_switchers = switchers.copy() + ) -> List["TrackedObject"]: + """ + Updates the state of switchers in the list of all tracked objects: + - If a switcher is lost for too long, it will be flaged as lost forever + - If a switcher reapears in the tracking output, it will be flaged as + a stable object. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + current_frame_tracked_objects (Set["TrackedObject"]): Set of currently tracked objects. + max_frames_to_rematch (int): Maximum number of frames to rematch. + frame_id (int): Current frame id. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + switchers = filter_objects_by_state( + all_tracked_objects, reid_constants.STATES.SWITCHER, exclusion=False + ) + switchers_to_drop = set(switchers).intersection(current_frame_tracked_objects) for switcher in switchers: if switcher in switchers_to_drop: - switcher.state = reid_constants.STABLE - filtered_switchers.remove(switcher) + switcher.state = reid_constants.STATES.STABLE elif switcher.get_nb_frames_since_last_appearance(frame_id) > max_frames_to_rematch: - switcher.state = reid_constants.LOST_FOREVER - filtered_switchers.remove(switcher) + switcher.state = reid_constants.STATES.LOST_FOREVER - return filtered_switchers + return all_tracked_objects @staticmethod - def drop_candidates( - candidates: List["TrackedObject"], max_attempt_to_rematch: int, frame_id: int - ): - filtered_candidates = candidates.copy() - # for now drop candidates if there was no match - for candidate in filtered_candidates: - if candidate.get_age(frame_id) >= max_attempt_to_rematch: - candidate.state = reid_constants.STABLE - candidates.remove(candidate) - return candidates - - def _postprocess(self, tracker_output: np.ndarray): - filtered_objects = list( - filter( - lambda obj: obj.get_state() == reid_constants.STABLE - and obj in tracker_output[:, INPUT_POSITIONS["object_id"]], - self.all_tracked_objects, - ) + def update_candidates_states( + all_tracked_objects: List["TrackedObject"], max_attempt_to_match: int, frame_id: int + ) -> List["TrackedObject"]: + """ + Updates the state of candidates in the list of all tracked objects. + If a candidate has not been rematched despite max_attempt_to_match attempts, + if will be flaged as a stable object. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + max_attempt_to_match (int): Maximum attempt to match a candidate. + frame_id (int): Current frame id. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + candidates = filter_objects_by_state( + tracked_objects=all_tracked_objects, + states=reid_constants.STATES.CANDIDATE, + exclusion=False, ) - reid_output = np.zeros((len(filtered_objects), self.nb_output_cols)) + for candidate in candidates: + if candidate.get_age(frame_id) >= max_attempt_to_match: + candidate.state = reid_constants.STATES.STABLE + return all_tracked_objects - for idx, object in enumerate(filtered_objects): - for required_variable in OUTPUT_POSITIONS: - if required_variable == "frame_id": - output = self.frame_id - else: - try: - output = getattr(object, required_variable) - except: # noqa: E722 - raise NameError( - f"Attribute {required_variable} not in TrackedObject.Check your required output names." - ) + def _postprocess(self, current_tracker_ids: List[Union[int, float]]) -> np.ndarray: + """ + Postprocesses the current tracker IDs. + + Args: + current_tracker_ids (List[Union[int, float]]): The current tracker IDs. + + Returns: + np.ndarray: The postprocessed output. + """ + stable_objects = [ + obj + for obj in self.all_tracked_objects + if obj.get_state() == reid_constants.STATES.STABLE and obj in current_tracker_ids + ] + reid_output = np.zeros((len(stable_objects), self.nb_output_cols)) + + for idx, stable_object in enumerate(stable_objects): + for required_variable in OUTPUT_POSITIONS: + output = ( + self.frame_id + if required_variable == "frame_id" + else getattr(stable_object, required_variable, None) + ) + if output is None: + raise NameError( + f"Attribute {required_variable} not in TrackedObject. Check your required output names." + ) + if required_variable == "category": + inverted_dict = {v: k for k, v in MAPPING_CLASSES.items()} + output = inverted_dict[output] reid_output[idx, OUTPUT_POSITIONS[required_variable]] = output return reid_output + + def to_dict(self) -> Dict: + """ + Converts the tracked objects to a dictionary. + + Returns: + Dict: The dictionary representation of the tracked objects. + """ + data = dict() + for tracked_object in self.all_tracked_objects: + data[tracked_object.object_id] = tracked_object.to_dict() + return data diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index 901285b..8806f5b 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from typing import Optional, Union import numpy as np @@ -23,7 +24,7 @@ def __init__( if isinstance(object_ids, Union[float, int]): self.re_id_chain = sllist([object_ids]) elif isinstance(object_ids, sllist): - self.re_id_chain = object_ids + self.re_id_chain = sllist(object_ids) else: raise NameError("unrocognized type for object_ids.") if isinstance(metadata, np.ndarray): @@ -36,6 +37,9 @@ def __init__( else: raise NameError("unrocognized type for metadata.") + def copy(self): + return TrackedObject(object_ids=self.re_id_chain, state=self.state, metadata=self.metadata) + def merge(self, other_object): if not isinstance(other_object, TrackedObject): raise TypeError("Can only merge with another TrackedObject.") @@ -52,6 +56,10 @@ def merge(self, other_object): def object_id(self): return self.re_id_chain.first.value + @property + def tracker_id(self): + return self.re_id_chain.last.value + @property def category(self): return max(self.metadata.class_counts, key=self.metadata.class_counts.get) @@ -69,9 +77,13 @@ def bbox(self): return self.metadata.bbox @property - def nb_corrections(self): + def nb_ids(self): return len(self.re_id_chain) + @property + def nb_corrections(self): + return self.nb_ids - 1 + def get_age(self, frame_id): return frame_id - self.metadata.first_frame_id @@ -87,7 +99,7 @@ def __hash__(self): def __repr__(self): return ( f"TrackedObject(current_id={self.object_id}, re_id_chain={list(self.re_id_chain)}" - + f", state={self.state}: {reid_constants.DESCRIPTION[self.state]})" + + f", state={self.state}: {reid_constants.STATES.DESCRIPTION[self.state]})" ) def __str__(self): @@ -113,7 +125,7 @@ def cut(self, object_id: int): self.re_id_chain = before new_object = TrackedObject( - state=reid_constants.STABLE, object_ids=after, metadata=self.metadata + state=reid_constants.STATES.STABLE, object_ids=after, metadata=self.metadata ) return new_object, self @@ -127,3 +139,28 @@ def format_data(self): self.bbox[3], self.confidence, ] + + def to_dict(self): + data = { + "object_id": float(self.object_id), + "state": int(self.state), + "re_id_chain": list(self.re_id_chain), + "metadata": self.metadata.to_dict(), + } + return data + + def to_json(self): + return json.dumps(self.to_dict(), indent=4) + + @classmethod + def from_dict(cls, data: dict): + obj = cls.__new__(cls) + obj.state = data["state"] + obj.re_id_chain = sllist(data["re_id_chain"]) + obj.metadata = TrackedObjectMetaData.from_dict(data["metadata"]) + return obj + + @classmethod + def from_json(cls, json_str: str): + data = json.loads(json_str) + return cls.from_dict(data) diff --git a/trackreid/tracked_object_filter.py b/trackreid/tracked_object_filter.py index fdcc44f..e7ec627 100644 --- a/trackreid/tracked_object_filter.py +++ b/trackreid/tracked_object_filter.py @@ -7,12 +7,12 @@ def __init__(self, confidence_threshold, frames_seen_threshold): self.frames_seen_threshold = frames_seen_threshold def update(self, tracked_object): - if tracked_object.get_state() == reid_constants.TRACKER_OUTPUT: + if tracked_object.get_state() == reid_constants.STATES.TRACKER_OUTPUT: if ( tracked_object.metadata.mean_confidence() > self.confidence_threshold and tracked_object.metadata.observations >= self.frames_seen_threshold ): - tracked_object.state = reid_constants.FILTERED_OUTPUT + tracked_object.state = reid_constants.STATES.FILTERED_OUTPUT elif tracked_object.metadata.mean_confidence() < self.confidence_threshold: - tracked_object.state = reid_constants.TRACKER_OUTPUT + tracked_object.state = reid_constants.STATES.TRACKER_OUTPUT diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index 9ba974b..5b6a22f 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -1,15 +1,12 @@ import json -from pathlib import Path -import numpy as np - -from trackreid.args.reid_args import INPUT_POSITIONS, POSSIBLE_CLASSES +from trackreid.args.reid_args import INPUT_POSITIONS, MAPPING_CLASSES, POSSIBLE_CLASSES class TrackedObjectMetaData: def __init__(self, data_line, frame_id): self.first_frame_id = frame_id - self.class_counts = {class_name: 0 for class_name in POSSIBLE_CLASSES} + self.class_counts = {class_name: 0 for class_name in MAPPING_CLASSES.values()} self.observations = 0 self.confidence_sum = 0 self.confidence = 0 @@ -17,7 +14,7 @@ def __init__(self, data_line, frame_id): def update(self, data_line, frame_id): self.last_frame_id = frame_id - class_name = data_line[INPUT_POSITIONS["category"]] + class_name = MAPPING_CLASSES.get(data_line[INPUT_POSITIONS["category"]]) self.class_counts[class_name] = self.class_counts.get(class_name, 0) + 1 self.bbox = list(data_line[INPUT_POSITIONS["bbox"]].astype(int)) confidence = float(data_line[INPUT_POSITIONS["confidence"]]) @@ -34,63 +31,55 @@ def merge(self, other_object): self.confidence = other_object.confidence self.bbox = other_object.bbox self.last_frame_id = other_object.last_frame_id - for class_name in POSSIBLE_CLASSES: + for class_name in MAPPING_CLASSES.values(): self.class_counts[class_name] = self.class_counts.get( class_name, 0 ) + other_object.class_counts.get(class_name, 0) def copy(self): - # Create a new instance of TrackedObjectMetaData - # initialize with fake data - - # TODO: make something better here, input order might change - copy_obj = TrackedObjectMetaData( - data_line=np.array( - [ - 0, - list(self.class_counts.keys())[0], - *self.bbox, - self.confidence, - ] - ), - frame_id=self.first_frame_id, - ) + copy_obj = TrackedObjectMetaData.__new__(TrackedObjectMetaData) # Update the copied instance with the actual class counts and observations + copy_obj.bbox = self.bbox.copy() copy_obj.class_counts = self.class_counts.copy() copy_obj.observations = self.observations copy_obj.confidence_sum = self.confidence_sum copy_obj.confidence = self.confidence - copy_obj.bbox = self.bbox copy_obj.first_frame_id = self.first_frame_id copy_obj.last_frame_id = self.last_frame_id return copy_obj - def save_to_json(self, filename): + def to_dict(self): data = { - "first_frame_id": self.first_frame_id, + "first_frame_id": int(self.first_frame_id), + "last_frame_id": int(self.last_frame_id), "class_counts": self.class_counts, - "bbox": self.bbox, - "confidence": self.confidence, - "confidence_sum": self.confidence_sum, - "observations": self.observations, + "bbox": [int(i) for i in self.bbox], + "confidence": float(self.confidence), + "confidence_sum": float(self.confidence_sum), + "observations": int(self.observations), } + return data - with Path.open(filename, "w") as file: - json.dump(data, file) + def to_json(self): + return json.dumps(self.to_dict(), indent=4) + + @classmethod + def from_dict(cls, data: dict): + obj = cls.__new__(cls) + obj.first_frame_id = data["first_frame_id"] + obj.last_frame_id = data["last_frame_id"] + obj.class_counts = data["class_counts"] + obj.bbox = data["bbox"] + obj.confidence = data["confidence"] + obj.confidence_sum = data["confidence_sum"] + obj.observations = data["observations"] + return obj @classmethod - def load_from_json(cls, filename): - with Path.open(filename, "r") as file: - data = json.load(file) - obj = cls.__new__(cls) - obj.first_frame_id = data["first_frame_id"] - obj.class_counts = data["class_counts"] - obj.bbox = data["bbox"] - obj.confidence = data["confidence"] - obj.confidence_sum = data["confidence_sum"] - obj.observations = data["observations"] - return obj + def from_json(cls, json_str: str): + data = json.loads(json_str) + return cls.from_dict(data) def class_proportions(self): if self.observations > 0: @@ -99,7 +88,7 @@ def class_proportions(self): for class_name, count in self.class_counts.items() } else: - proportions = {class_name: 0.0 for class_name in POSSIBLE_CLASSES} + proportions = {MAPPING_CLASSES[class_name]: 0.0 for class_name in POSSIBLE_CLASSES} return proportions def percentage_of_time_seen(self, frame_id): @@ -121,6 +110,6 @@ def __repr__(self) -> str: def __str__(self): return ( f"First frame seen: {self.first_frame_id}, nb observations: {self.observations}, " - + f"class Proportions: {self.class_proportions()}, Bounding Box: {self.bbox}, " - + f"Mean Confidence: {self.mean_confidence()}" + + f"class proportions: {self.class_proportions()}, bbox: {self.bbox}, " + + f"mean confidence: {self.mean_confidence()}" ) diff --git a/trackreid/utils.py b/trackreid/utils.py index 1e6f595..3423a8f 100644 --- a/trackreid/utils.py +++ b/trackreid/utils.py @@ -45,6 +45,18 @@ def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exc return filtered_objects +def filter_objects_by_category( + tracked_objects: List, category: Union[Union[float, int], list], exclusion=False +): + if isinstance(category, Union[float, int]): + category = [category] + if exclusion: + filtered_objects = [obj for obj in tracked_objects if obj.category not in category] + else: + filtered_objects = [obj for obj in tracked_objects if obj.category in category] + return filtered_objects + + def reshape_tracker_result(tracker_output: np.ndarray): if tracker_output.ndim == 1: tracker_output = np.expand_dims(tracker_output, 0) From c661fabbd7607b3afd97089419b01d2faf75aee3 Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Thu, 16 Nov 2023 18:12:43 +0100 Subject: [PATCH 06/13] Tp/add to txt saving (#16) --- .gitignore | 2 +- notebooks/starter_kit_reid.ipynb | 128 +++++++++++++++++++---- pyproject.toml | 1 - tests/unit_tests/test_matcher.py | 4 - tests/unit_tests/test_metadata.py | 12 +-- tests/unit_tests/test_tracked_objects.py | 8 +- tests/unit_tests/test_utils.py | 4 +- trackreid/args/reid_args.py | 4 +- trackreid/reid_processor.py | 51 +++++++-- trackreid/tracked_object.py | 2 + trackreid/tracked_object_metadata.py | 23 ++-- trackreid/utils.py | 7 ++ 12 files changed, 193 insertions(+), 53 deletions(-) diff --git a/.gitignore b/.gitignore index 92b4f04..fe7c188 100644 --- a/.gitignore +++ b/.gitignore @@ -142,6 +142,6 @@ secrets/* data/detections/* data/frames/* *.mp4 - +*.txt # poetry poetry.lock diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index ff99d17..09faf18 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -18,6 +18,7 @@ "source": [ "import os\n", "import sys\n", + "from datetime import datetime\n", "\n", "import cv2\n", "import numpy as np\n", @@ -33,6 +34,28 @@ "sys.path.append(\"..\")\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this demo, you have to install bytetrack. You can do so by typing the following command : \n", + "```bash\n", + "pip install git+https://github.com/artefactory-fr/bytetrack.git@main\n", + "````\n", + "\n", + "Baseline data can be found in `gs://data-track-reid/predictions/baseline`. You can copy them in `../data/predictions/` using the following commands (in a terminal at the root of the project):\n", + "\n", + "```bash\n", + "mkdir -p ./data/predictions/\n", + "gsutil -m cp -r gs://data-track-reid/predictions/baseline ./data/predictions/\n", + "```\n", + "\n", + "Then you can reoganize the data using the following : \n", + "```bash \n", + "find ./data/predictions/baseline -mindepth 2 -type f -name \"*.txt\" -exec sh -c 'mv \"$0\" \"${0%/*/*}/$(basename \"${0%/*}\").txt\"' {} \\; && find ./data/predictions/baseline -mindepth 1 -type d -empty -delete\n", + "```" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -49,9 +72,11 @@ "DATA_PATH = \"../data\"\n", "DETECTION_PATH = f\"{DATA_PATH}/detections\"\n", "FRAME_PATH = f\"{DATA_PATH}/frames\"\n", + "PREDICTIONS_PATH = f\"{DATA_PATH}/predictions\"\n", "VIDEO_OUTPUT_PATH = \"private\"\n", "\n", - "SEQUENCES = os.listdir(DETECTION_PATH)\n" + "SEQUENCES = os.listdir(DETECTION_PATH)\n", + "GENERATE_VIDEOS = False\n" ] }, { @@ -167,30 +192,38 @@ "metadata": {}, "outputs": [], "source": [ - "for sequence in SEQUENCES :\n", + "timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')\n", + "print(timestamp)\n", + "folder_save = os.path.join(PREDICTIONS_PATH,timestamp)\n", + "os.makedirs(folder_save, exist_ok=True)\n", + "GENERATE_VIDEOS = False\n", + "for sequence in tqdm(SEQUENCES) :\n", " frame_path = get_sequence_frames(sequence)\n", " test_sequence = Sequence(frame_path)\n", - " test_sequence\n", " frame_id = 0\n", " BaseTrack._count = 0\n", + " from datetime import datetime\n", "\n", - " # Define the codec using VideoWriter_fourcc() and create a VideoWriter object\n", - " fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", - " out = cv2.VideoWriter(f'{sequence}.mp4', fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", + " file_path = os.path.join(folder_save,sequence) + '.txt'\n", + " video_path = os.path.join(folder_save,sequence) + '.mp4'\n", "\n", + " if GENERATE_VIDEOS:\n", + " fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", + " out = cv2.VideoWriter(video_path, fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", "\n", " detection_handler = DetectionHandler(image_shape=[2560, 1440])\n", " tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n", " reid_processor = ReidProcessor(filter_confidence_threshold=0.1,\n", " filter_time_threshold=5,\n", " cost_function=bounding_box_distance,\n", - " #cost_function_threshold=500, # max cost to rematch 2 objects\n", + " cost_function_threshold=5000, # max cost to rematch 2 objects\n", " selection_function=select_by_category,\n", " max_attempt_to_match=5,\n", - " max_frames_to_rematch=500)\n", + " max_frames_to_rematch=500,\n", + " save_to_txt=True,\n", + " file_path=file_path)\n", "\n", - " for frame, detection in tqdm(test_sequence):\n", - " frame = np.array(frame)\n", + " for frame, detection in test_sequence:\n", "\n", " frame_id += 1\n", "\n", @@ -198,26 +231,85 @@ " processed_tracked = tracking_handler.update(processed_detections, frame_id)\n", " reid_results = reid_processor.update(processed_tracked, frame_id)\n", "\n", - " if len(reid_results) > 0:\n", + " if GENERATE_VIDEOS and len(reid_results) > 0:\n", + " frame = np.array(frame)\n", " for res in reid_results:\n", " object_id = int(res[OUTPUT_POSITIONS[\"object_id\"]])\n", " bbox = list(map(int, res[OUTPUT_POSITIONS[\"bbox\"]]))\n", " class_id = int(res[OUTPUT_POSITIONS[\"category\"]])\n", " tracker_id = int(res[OUTPUT_POSITIONS[\"tracker_id\"]])\n", " mean_confidence = float(res[OUTPUT_POSITIONS[\"mean_confidence\"]])\n", - " #mean_confidence_per_object[object_id].append((frame_id, mean_confidence))\n", " x1, y1, x2, y2 = bbox\n", " color = (0, 0, 255) if class_id else (0, 255, 0) # green for class 0, red for class 1\n", " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", " cv2.putText(frame, f\"{object_id} ({tracker_id}) : {round(mean_confidence,2)}\", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n", - " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", "\n", - " # Write the frame to the video file\n", - " out.write(frame)\n", - " out.release()\n", + " if GENERATE_VIDEOS:\n", + " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", + " out.write(frame)\n", + "\n", + " if GENERATE_VIDEOS :\n", + " out.release()\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "\n", + "def count_occurrences(file_path, case):\n", + " object_counts = defaultdict(int)\n", + " class_counts = defaultdict(lambda: defaultdict(int))\n", + "\n", + " with open(file_path, 'r') as f:\n", + " for line in f:\n", + " data = line.split()\n", + "\n", + " if case != 'baseline':\n", + " object_id = int(data[1])\n", + " category = int(data[2])\n", + " else:\n", + " object_id = int(data[1])\n", + " category = int(data[-1])\n", + "\n", + " object_counts[object_id] += 1\n", + " class_counts[object_id][category] += 1\n", + "\n", + " return object_counts, class_counts\n", + "\n", + "def filter_counts(object_counts, class_counts, min_occurrences=10):\n", + " filtered_objects = {}\n", + "\n", + " for object_id, count in object_counts.items():\n", + " if count > min_occurrences and class_counts[object_id][0] > class_counts[object_id][1]:\n", + " filtered_objects[object_id] = count\n", + "\n", + " return filtered_objects\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PATH_PREDICTIONS = f\"../data/predictions/{timestamp}\"\n", + "\n", + "for sequence in SEQUENCES:\n", + " print(\"-\"*50)\n", + " print(sequence)\n", + "\n", + " for case in [\"baseline\", timestamp]:\n", + " object_counts, class_counts = count_occurrences(f'../data/predictions/{case}/{sequence}.txt', case=case)\n", + " filtered_objects = filter_counts(object_counts, class_counts)\n", "\n", - " print(sequence, len(reid_processor.seen_objects),reid_processor.nb_corrections)\n", - " print(reid_processor.seen_objects)\n" + " print(case)\n", + " print(filtered_objects)" ] } ], diff --git a/pyproject.toml b/pyproject.toml index df722d0..f0459b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,6 @@ lapx = "^0.5.5" opencv-python = "^4.8.1.78" tqdm = "^4.66.1" pillow = "^10.1.0" -bytetracker = {git = "git@github.com:artefactory-fr/bytetrack.git", branch = "main"} [tool.poetry.group.dev.dependencies] diff --git a/tests/unit_tests/test_matcher.py b/tests/unit_tests/test_matcher.py index e75aa05..1e21867 100644 --- a/tests/unit_tests/test_matcher.py +++ b/tests/unit_tests/test_matcher.py @@ -48,9 +48,6 @@ def dummy_selection_function(candidate, switcher): # noqa: ARG001 candidates.append(obj) switchers.append(obj) - print(candidates) - print(switchers) - matches = matcher.match(candidates, switchers) assert len(matches) == 3 @@ -76,7 +73,6 @@ def dummy_selection_function(candidate, switcher): switchers.append(obj) matches = matcher.match(candidates, switchers) - print(matches) assert len(matches) == 2 for match in matches: diff --git a/tests/unit_tests/test_metadata.py b/tests/unit_tests/test_metadata.py index df21166..7b195ac 100644 --- a/tests/unit_tests/test_metadata.py +++ b/tests/unit_tests/test_metadata.py @@ -17,7 +17,7 @@ def test_tracked_metadata_copy(): copied_metadata = tracked_metadata.copy() assert copied_metadata.first_frame_id == 15 assert copied_metadata.last_frame_id == 251 - assert copied_metadata.class_counts == {"shop_item": 175, "personal_item": 0} + assert copied_metadata.class_counts == {0: 175, 1: 0} assert copied_metadata.bbox == [598, 208, 814, 447] assert copied_metadata.confidence == 0.610211 assert copied_metadata.confidence_sum == 111.30582399999996 @@ -25,14 +25,14 @@ def test_tracked_metadata_copy(): assert round(copied_metadata.percentage_of_time_seen(251), 2) == 73.84 class_proportions = copied_metadata.class_proportions() - assert round(class_proportions["shop_item"], 2) == 1.0 - assert round(class_proportions["personal_item"], 2) == 0.0 + assert round(class_proportions.get(0), 2) == 1.0 + assert round(class_proportions.get(1), 2) == 0.0 tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy() tracked_metadata.merge(tracked_metadata_2) # test impact of merge inplace in a copy, should be none - assert copied_metadata.class_counts == {"shop_item": 175, "personal_item": 0} + assert copied_metadata.class_counts == {0: 175, 1: 0} assert copied_metadata.bbox == [598, 208, 814, 447] assert copied_metadata.confidence == 0.610211 assert copied_metadata.confidence_sum == 111.30582399999996 @@ -44,8 +44,8 @@ def test_tracked_metadata_merge(): tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy() tracked_metadata_1.merge(tracked_metadata_2) assert tracked_metadata_1.last_frame_id == 251 - assert tracked_metadata_1.class_counts["shop_item"] == 175 - assert tracked_metadata_1.class_counts["personal_item"] == 216 + assert tracked_metadata_1.class_counts.get(0) == 175 + assert tracked_metadata_1.class_counts.get(1) == 216 assert tracked_metadata_1.bbox == [548, 455, 846, 645] assert tracked_metadata_1.confidence == 0.700626 assert tracked_metadata_1.confidence_sum == 260.988185 diff --git a/tests/unit_tests/test_tracked_objects.py b/tests/unit_tests/test_tracked_objects.py index 76c2575..7c70952 100644 --- a/tests/unit_tests/test_tracked_objects.py +++ b/tests/unit_tests/test_tracked_objects.py @@ -38,7 +38,7 @@ def test_tracked_object_properties(): tracked_object = ALL_TRACKED_OBJECTS[0].copy() assert tracked_object.object_id == 1.0 assert tracked_object.state == 0 - assert tracked_object.category == "shop_item" + assert tracked_object.category == 0 assert round(tracked_object.confidence, 2) == 0.61 assert round(tracked_object.mean_confidence, 2) == 0.64 assert tracked_object.bbox == [598, 208, 814, 447] @@ -52,7 +52,7 @@ def test_tracked_object_merge(): tracked_object_1.merge(tracked_object_2) assert tracked_object_1.object_id == 1.0 assert tracked_object_1.state == 0 - assert tracked_object_1.category == "personal_item" + assert tracked_object_1.category == 1 assert round(tracked_object_1.confidence, 2) == 0.70 assert round(tracked_object_1.mean_confidence, 2) == 0.67 assert tracked_object_1.bbox == [548, 455, 846, 645] @@ -65,7 +65,7 @@ def test_tracked_object_cut(): new_object, cut_object = tracked_object.cut(2.0) assert new_object.object_id == 14.0 assert new_object.state == 0 - assert new_object.category == "shop_item" + assert new_object.category == 0 assert round(new_object.confidence, 2) == 0.61 assert round(new_object.mean_confidence, 2) == 0.64 assert new_object.bbox == [598, 208, 814, 447] @@ -73,7 +73,7 @@ def test_tracked_object_cut(): assert new_object.nb_corrections == 2 assert cut_object.object_id == 1.0 assert cut_object.state == 0 - assert cut_object.category == "shop_item" + assert cut_object.category == 0 assert round(cut_object.confidence, 2) == 0.61 assert round(cut_object.mean_confidence, 2) == 0.64 assert cut_object.bbox == [598, 208, 814, 447] diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index 2a15f30..e280c20 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -64,7 +64,7 @@ def test_filter_objects_by_state_2(): def test_filter_objects_by_category(): - category = "shop_item" + category = 0 assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=False) == [ ALL_TRACKED_OBJECTS[0], ALL_TRACKED_OBJECTS[2], @@ -72,7 +72,7 @@ def test_filter_objects_by_category(): def test_filter_objects_by_category_2(): - category = "personal_item" + category = 1 assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=True) == [ ALL_TRACKED_OBJECTS[0], ALL_TRACKED_OBJECTS[2], diff --git a/trackreid/args/reid_args.py b/trackreid/args/reid_args.py index 2e81095..cdcd456 100644 --- a/trackreid/args/reid_args.py +++ b/trackreid/args/reid_args.py @@ -1,5 +1,5 @@ -POSSIBLE_CLASSES = [0.0, 1.0] -MAPPING_CLASSES = {0.0: "shop_item", 1.0: "personal_item"} +POSSIBLE_CLASSES = [0, 1] +MAPPING_CLASSES = {0: "shop_item", 1: "personal_item"} INPUT_POSITIONS = { "object_id": 4, diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index 13731da..70f2c69 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -4,7 +4,7 @@ import numpy as np -from trackreid.args.reid_args import INPUT_POSITIONS, MAPPING_CLASSES, OUTPUT_POSITIONS +from trackreid.args.reid_args import INPUT_POSITIONS, OUTPUT_POSITIONS from trackreid.constants.reid_constants import reid_constants from trackreid.matcher import Matcher from trackreid.tracked_object import TrackedObject @@ -27,6 +27,8 @@ def __init__( max_frames_to_rematch: int, max_attempt_to_match: int, cost_function_threshold: Optional[Union[int, float]] = None, + save_to_txt: bool = True, + file_path: str = "tracks.txt", ) -> None: """ Initializes the ReidProcessor class. @@ -61,6 +63,18 @@ def __init__( self.frame_id = 0 self.nb_output_cols = get_nb_output_cols(output_positions=OUTPUT_POSITIONS) + self.save_to_txt = save_to_txt + self.file_path = file_path + + def set_file_path(self, new_file_path: str) -> None: + """ + Sets a new file path. + + Args: + new_file_path (str): The new file path. + """ + self.file_path = new_file_path + @property def nb_corrections(self) -> int: nb_corrections = 0 @@ -110,9 +124,14 @@ def update( ) self._perform_reid_process(current_tracker_ids=current_tracker_ids) reid_output = self._postprocess(current_tracker_ids=current_tracker_ids) - return reid_output + else: - return tracker_output + reid_output = tracker_output + + if self.save_to_txt: + self.save_results_to_txt(file_path=self.file_path, reid_output=reid_output) + + return reid_output def _preprocess(self, tracker_output: np.ndarray, frame_id: int) -> List["TrackedObject"]: """ @@ -207,6 +226,7 @@ def _perform_reid_process(self, current_tracker_ids: List[Union[int, float]]) -> Args: current_tracker_ids (List[Union[int, float]]): The current tracker IDs. """ + self.all_tracked_objects = self.correct_reid_chains( all_tracked_objects=self.all_tracked_objects, current_tracker_ids=current_tracker_ids ) @@ -341,7 +361,7 @@ def correct_reid_chains( all_tracked_objects.append(tracked_id) if new_object in current_tracker_ids: - new_object.state = reid_constants.STATES.STABLE + new_object.state = reid_constants.STATES.CANDIDATE all_tracked_objects.append(new_object) elif new_object.nb_corrections > 1: @@ -436,7 +456,10 @@ def update_candidates_states( candidate.state = reid_constants.STATES.STABLE return all_tracked_objects - def _postprocess(self, current_tracker_ids: List[Union[int, float]]) -> np.ndarray: + def _postprocess( + self, + current_tracker_ids: List[Union[int, float]], + ) -> np.ndarray: """ Postprocesses the current tracker IDs. @@ -465,13 +488,25 @@ def _postprocess(self, current_tracker_ids: List[Union[int, float]]) -> np.ndarr raise NameError( f"Attribute {required_variable} not in TrackedObject. Check your required output names." ) - if required_variable == "category": - inverted_dict = {v: k for k, v in MAPPING_CLASSES.items()} - output = inverted_dict[output] reid_output[idx, OUTPUT_POSITIONS[required_variable]] = output return reid_output + def save_results_to_txt(self, file_path: str, reid_output: np.ndarray) -> None: + """ + Saves the reid_output to a txt file. + + Args: + file_path (str): The path to the txt file. + reid_output (np.ndarray): The output of _post_process. + """ + with open(file_path, "a") as f: # noqa: PTH123 + for row in reid_output: + line = " ".join( + str(int(val)) if val.is_integer() else "{:.6f}".format(val) for val in row + ) + f.write(line + "\n") + def to_dict(self) -> Dict: """ Converts the tracked objects to a dictionary. diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index 8806f5b..96c5181 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -127,6 +127,8 @@ def cut(self, object_id: int): new_object = TrackedObject( state=reid_constants.STATES.STABLE, object_ids=after, metadata=self.metadata ) + # set potential age 0 for new object + new_object.metadata.first_frame_id = new_object.metadata.last_frame_id return new_object, self def format_data(self): diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index 5b6a22f..bee82ff 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -1,12 +1,13 @@ import json from trackreid.args.reid_args import INPUT_POSITIONS, MAPPING_CLASSES, POSSIBLE_CLASSES +from trackreid.utils import get_key_from_value class TrackedObjectMetaData: def __init__(self, data_line, frame_id): self.first_frame_id = frame_id - self.class_counts = {class_name: 0 for class_name in MAPPING_CLASSES.values()} + self.class_counts = {class_name: 0 for class_name in map(int, POSSIBLE_CLASSES)} self.observations = 0 self.confidence_sum = 0 self.confidence = 0 @@ -14,9 +15,9 @@ def __init__(self, data_line, frame_id): def update(self, data_line, frame_id): self.last_frame_id = frame_id - class_name = MAPPING_CLASSES.get(data_line[INPUT_POSITIONS["category"]]) + class_name = int(data_line[INPUT_POSITIONS["category"]]) self.class_counts[class_name] = self.class_counts.get(class_name, 0) + 1 - self.bbox = list(data_line[INPUT_POSITIONS["bbox"]].astype(int)) + self.bbox = list(map(int, data_line[INPUT_POSITIONS["bbox"]])) confidence = float(data_line[INPUT_POSITIONS["confidence"]]) self.confidence = confidence self.confidence_sum += confidence @@ -31,7 +32,7 @@ def merge(self, other_object): self.confidence = other_object.confidence self.bbox = other_object.bbox self.last_frame_id = other_object.last_frame_id - for class_name in MAPPING_CLASSES.values(): + for class_name in map(int, POSSIBLE_CLASSES): self.class_counts[class_name] = self.class_counts.get( class_name, 0 ) + other_object.class_counts.get(class_name, 0) @@ -50,10 +51,13 @@ def copy(self): return copy_obj def to_dict(self): + class_counts_str = { + MAPPING_CLASSES[class_name]: count for class_name, count in self.class_counts.items() + } data = { "first_frame_id": int(self.first_frame_id), "last_frame_id": int(self.last_frame_id), - "class_counts": self.class_counts, + "class_counts": class_counts_str, "bbox": [int(i) for i in self.bbox], "confidence": float(self.confidence), "confidence_sum": float(self.confidence_sum), @@ -66,10 +70,15 @@ def to_json(self): @classmethod def from_dict(cls, data: dict): + class_counts_str = data["class_counts"] + class_counts = { + get_key_from_value(MAPPING_CLASSES, class_name): count + for class_name, count in class_counts_str.items() + } obj = cls.__new__(cls) obj.first_frame_id = data["first_frame_id"] obj.last_frame_id = data["last_frame_id"] - obj.class_counts = data["class_counts"] + obj.class_counts = class_counts obj.bbox = data["bbox"] obj.confidence = data["confidence"] obj.confidence_sum = data["confidence_sum"] @@ -88,7 +97,7 @@ def class_proportions(self): for class_name, count in self.class_counts.items() } else: - proportions = {MAPPING_CLASSES[class_name]: 0.0 for class_name in POSSIBLE_CLASSES} + proportions = {class_name: 0.0 for class_name in map(int, POSSIBLE_CLASSES)} return proportions def percentage_of_time_seen(self, frame_id): diff --git a/trackreid/utils.py b/trackreid/utils.py index 3423a8f..bcc09b3 100644 --- a/trackreid/utils.py +++ b/trackreid/utils.py @@ -74,3 +74,10 @@ def get_nb_output_cols(output_positions: dict): raise TypeError("Unkown type in required output positions.") return nb_cols + + +def get_key_from_value(dictionary, target_value): + for key, value in dictionary.items(): + if value == target_value: + return key + return None From 81ecb2c77bd5558f8eeefb9febc6d74d485bd2c9 Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:15:20 +0100 Subject: [PATCH 07/13] Tp/improve user interface (#17) --- notebooks/starter_kit_reid.ipynb | 100 +++++----- .../unit_tests/tracked_objects/object_1.json | 4 +- .../unit_tests/tracked_objects/object_24.json | 4 +- .../unit_tests/tracked_objects/object_4.json | 4 +- tests/unit_tests/test_utils.py | 5 +- trackreid/args/reid_args.py | 19 -- trackreid/configs/input_data_positions.py | 25 +++ trackreid/configs/output_data_positions.py | 35 ++++ .../{constants => configs}/reid_constants.py | 0 trackreid/cost_functions/__init__.py | 1 + .../cost_functions/bounding_box_distance.py | 28 +++ trackreid/matcher.py | 42 +++- trackreid/reid_processor.py | 179 +++++++++++++++--- trackreid/selection_functions/__init__.py | 1 + .../selection_functions/select_by_category.py | 18 ++ trackreid/tracked_object.py | 2 +- trackreid/tracked_object_filter.py | 27 ++- trackreid/tracked_object_metadata.py | 23 +-- trackreid/utils.py | 89 +++++++-- 19 files changed, 477 insertions(+), 129 deletions(-) delete mode 100644 trackreid/args/reid_args.py create mode 100644 trackreid/configs/input_data_positions.py create mode 100644 trackreid/configs/output_data_positions.py rename trackreid/{constants => configs}/reid_constants.py (100%) create mode 100644 trackreid/cost_functions/__init__.py create mode 100644 trackreid/cost_functions/bounding_box_distance.py create mode 100644 trackreid/selection_functions/__init__.py create mode 100644 trackreid/selection_functions/select_by_category.py diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index 09faf18..685d1a6 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -10,6 +10,28 @@ "%autoreload 2" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this demo, you have to install bytetrack. You can do so by typing the following command : \n", + "```bash\n", + "pip install git+https://github.com/artefactory-fr/bytetrack.git@main\n", + "````\n", + "\n", + "Baseline data can be found in `gs://data-track-reid/predictions/baseline`. You can copy them in `../data/predictions/` using the following commands (in a terminal at the root of the project):\n", + "\n", + "```bash\n", + "mkdir -p ./data/predictions/\n", + "gsutil -m cp -r gs://data-track-reid/predictions/baseline ./data/predictions/\n", + "```\n", + "\n", + "Then you can reoganize the data using the following : \n", + "```bash \n", + "find ./data/predictions/baseline -mindepth 2 -type f -name \"*.txt\" -exec sh -c 'mv \"$0\" \"${0%/*/*}/$(basename \"${0%/*}\").txt\"' {} \\; && find ./data/predictions/baseline -mindepth 1 -type d -empty -delete\n", + "```" + ] + }, { "cell_type": "code", "execution_count": null, @@ -28,32 +50,40 @@ "\n", "from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy\n", "from lib.sequence import Sequence\n", - "from trackreid.args.reid_args import OUTPUT_POSITIONS\n", - "from trackreid.reid_processor import ReidProcessor\n", + "from trackreid.configs.output_data_positions import output_data_positions\n", "\n", "sys.path.append(\"..\")\n" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "For this demo, you have to install bytetrack. You can do so by typing the following command : \n", - "```bash\n", - "pip install git+https://github.com/artefactory-fr/bytetrack.git@main\n", - "````\n", - "\n", - "Baseline data can be found in `gs://data-track-reid/predictions/baseline`. You can copy them in `../data/predictions/` using the following commands (in a terminal at the root of the project):\n", - "\n", - "```bash\n", - "mkdir -p ./data/predictions/\n", - "gsutil -m cp -r gs://data-track-reid/predictions/baseline ./data/predictions/\n", - "```\n", "\n", - "Then you can reoganize the data using the following : \n", - "```bash \n", - "find ./data/predictions/baseline -mindepth 2 -type f -name \"*.txt\" -exec sh -c 'mv \"$0\" \"${0%/*/*}/$(basename \"${0%/*}\").txt\"' {} \\; && find ./data/predictions/baseline -mindepth 1 -type d -empty -delete\n", - "```" + "from trackreid.reid_processor import ReidProcessor\n", + "from trackreid.cost_functions import bounding_box_distance\n", + "from trackreid.selection_functions import select_by_category\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ReidProcessor.print_input_data_format_requirements()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ReidProcessor.print_output_data_format_information()" ] }, { @@ -162,30 +192,6 @@ " return tracked_objects" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def bounding_box_distance(obj1, obj2):\n", - " # Get the bounding boxes from the Metadata of each TrackedObject\n", - " bbox1 = obj1.metadata.bbox\n", - " bbox2 = obj2.metadata.bbox\n", - "\n", - " # Calculate the Euclidean distance between the centers of the bounding boxes\n", - " center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)\n", - " center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)\n", - " distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)\n", - "\n", - " return distance\n", - "\n", - "# TODO : discard by zone\n", - "def select_by_category(obj1, obj2):\n", - " # Compare the categories of the two objects\n", - " return 1 if obj1.category == obj2.category else 0" - ] - }, { "cell_type": "code", "execution_count": null, @@ -234,11 +240,11 @@ " if GENERATE_VIDEOS and len(reid_results) > 0:\n", " frame = np.array(frame)\n", " for res in reid_results:\n", - " object_id = int(res[OUTPUT_POSITIONS[\"object_id\"]])\n", - " bbox = list(map(int, res[OUTPUT_POSITIONS[\"bbox\"]]))\n", - " class_id = int(res[OUTPUT_POSITIONS[\"category\"]])\n", - " tracker_id = int(res[OUTPUT_POSITIONS[\"tracker_id\"]])\n", - " mean_confidence = float(res[OUTPUT_POSITIONS[\"mean_confidence\"]])\n", + " object_id = int(res[output_data_positions.object_id])\n", + " bbox = list(map(int, res[output_data_positions.bbox]))\n", + " class_id = int(res[output_data_positions.category])\n", + " tracker_id = int(res[output_data_positions.tracker_id])\n", + " mean_confidence = float(res[output_data_positions.mean_confidence])\n", " x1, y1, x2, y2 = bbox\n", " color = (0, 0, 255) if class_id else (0, 255, 0) # green for class 0, red for class 1\n", " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", diff --git a/tests/data/unit_tests/tracked_objects/object_1.json b/tests/data/unit_tests/tracked_objects/object_1.json index 13a9864..9609a5b 100644 --- a/tests/data/unit_tests/tracked_objects/object_1.json +++ b/tests/data/unit_tests/tracked_objects/object_1.json @@ -12,8 +12,8 @@ "first_frame_id": 15, "last_frame_id": 251, "class_counts": { - "shop_item": 175, - "personal_item": 0 + "0": 175, + "1": 0 }, "bbox": [ 598, diff --git a/tests/data/unit_tests/tracked_objects/object_24.json b/tests/data/unit_tests/tracked_objects/object_24.json index 29ae090..aef62df 100644 --- a/tests/data/unit_tests/tracked_objects/object_24.json +++ b/tests/data/unit_tests/tracked_objects/object_24.json @@ -8,8 +8,8 @@ "first_frame_id": 154, "last_frame_id": 251, "class_counts": { - "shop_item": 2, - "personal_item": 0 + "0": 2, + "1": 0 }, "bbox": [ 1430, diff --git a/tests/data/unit_tests/tracked_objects/object_4.json b/tests/data/unit_tests/tracked_objects/object_4.json index facc938..ecb65dd 100644 --- a/tests/data/unit_tests/tracked_objects/object_4.json +++ b/tests/data/unit_tests/tracked_objects/object_4.json @@ -9,8 +9,8 @@ "first_frame_id": 38, "last_frame_id": 251, "class_counts": { - "shop_item": 0, - "personal_item": 216 + "0": 0, + "1": 216 }, "bbox": [ 548, diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py index e280c20..c96bb56 100644 --- a/tests/unit_tests/test_utils.py +++ b/tests/unit_tests/test_utils.py @@ -5,6 +5,7 @@ from llist import sllist from trackreid import utils +from trackreid.configs.output_data_positions import OutputDataPositions from trackreid.tracked_object import TrackedObject # Load tracked object data @@ -87,5 +88,5 @@ def test_reshape_tracker_result(): def test_get_nb_output_cols(): - output_positions = {"feature1": 1, "feature2": [1, 2, 3]} - assert utils.get_nb_output_cols(output_positions) == 4 + output_positions = OutputDataPositions() + assert utils.get_nb_output_cols(output_positions) == 10 diff --git a/trackreid/args/reid_args.py b/trackreid/args/reid_args.py deleted file mode 100644 index cdcd456..0000000 --- a/trackreid/args/reid_args.py +++ /dev/null @@ -1,19 +0,0 @@ -POSSIBLE_CLASSES = [0, 1] -MAPPING_CLASSES = {0: "shop_item", 1: "personal_item"} - -INPUT_POSITIONS = { - "object_id": 4, - "category": 5, - "bbox": [0, 1, 2, 3], - "confidence": 6, -} - -OUTPUT_POSITIONS = { - "frame_id": 0, - "object_id": 1, - "category": 2, - "bbox": [3, 4, 5, 6], - "confidence": 7, - "mean_confidence": 8, - "tracker_id": 9, -} diff --git a/trackreid/configs/input_data_positions.py b/trackreid/configs/input_data_positions.py new file mode 100644 index 0000000..ec750ce --- /dev/null +++ b/trackreid/configs/input_data_positions.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel, Field + + +class InputDataPositions(BaseModel): + bbox: list = Field( + [0, 1, 2, 3], + description="List of bounding box coordinate positions in the input (numpy array)." + + "Coordinates are in the format x,y,w,h by default.", + ) + object_id: int = Field( + 4, + description="Position of the ID assigned by the tracker to each item in the input (numpy array)", + ) + category: int = Field( + 5, + description="Position of the category assigned to each detected object in the input (numpy array)", + ) + confidence: int = Field( + 6, + description="Position of the confidence score (range [0, 1]) for each" + + "detected object in the input (numpy array)", + ) + + +input_data_positions = InputDataPositions() diff --git a/trackreid/configs/output_data_positions.py b/trackreid/configs/output_data_positions.py new file mode 100644 index 0000000..e010f5c --- /dev/null +++ b/trackreid/configs/output_data_positions.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class OutputDataPositions(BaseModel): + frame_id: int = Field(0, description="Position of the frame id in the output (numpy array)") + object_id: int = Field( + 1, + description="Position of the ID assigned by the reid processor to each item in the output (numpy array)", + ) + category: int = Field( + 2, + description="Position of the category assigned to each detected object in the output (numpy array)", + ) + bbox: list = Field( + [3, 4, 5, 6], + description="List of bounding box coordinate positions in the output (numpy array)." + + "Coordinates are in the format x,y,w,h by default.", + ) + confidence: int = Field( + 7, + description="Position of the confidence score (range [0, 1]) for each" + + " detected object in the output (numpy array)", + ) + mean_confidence: int = Field( + 8, + description="Position of the mean confidence score over object life time (range [0, 1]) for each" + + " tracked object in the output (numpy array)", + ) + tracker_id: int = Field( + 9, + description="Position of the id assigned to the tracker to each object (prior re-identification).", + ) + + +output_data_positions = OutputDataPositions() diff --git a/trackreid/constants/reid_constants.py b/trackreid/configs/reid_constants.py similarity index 100% rename from trackreid/constants/reid_constants.py rename to trackreid/configs/reid_constants.py diff --git a/trackreid/cost_functions/__init__.py b/trackreid/cost_functions/__init__.py new file mode 100644 index 0000000..8c0f3bd --- /dev/null +++ b/trackreid/cost_functions/__init__.py @@ -0,0 +1 @@ +from .bounding_box_distance import bounding_box_distance # noqa: F401 diff --git a/trackreid/cost_functions/bounding_box_distance.py b/trackreid/cost_functions/bounding_box_distance.py new file mode 100644 index 0000000..8deff92 --- /dev/null +++ b/trackreid/cost_functions/bounding_box_distance.py @@ -0,0 +1,28 @@ +import numpy as np + +from trackreid.tracked_object import TrackedObject + + +def bounding_box_distance(candidate: TrackedObject, switcher: TrackedObject) -> float: + """ + Calculates the Euclidean distance between the centers of the bounding boxes of two TrackedObjects. + This distance is used as a measure of dissimilarity between the two objects, with a smaller distance + indicating a higher likelihood of the objects being the same. + + Args: + candidate (TrackedObject): The first TrackedObject. + switcher (TrackedObject): The second TrackedObject. + + Returns: + float: The Euclidean distance between the centers of the bounding boxes of the two TrackedObjects. + """ + # Get the bounding boxes from the Metadata of each TrackedObject + bbox1 = candidate.metadata.bbox + bbox2 = switcher.metadata.bbox + + # Calculate the Euclidean distance between the centers of the bounding boxes + center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2) + center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2) + distance = np.sqrt((center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2) + + return distance diff --git a/trackreid/matcher.py b/trackreid/matcher.py index 7f58723..9148c2e 100644 --- a/trackreid/matcher.py +++ b/trackreid/matcher.py @@ -3,7 +3,7 @@ import lap import numpy as np -from trackreid.constants.reid_constants import reid_constants +from trackreid.configs.reid_constants import reid_constants from trackreid.tracked_object import TrackedObject @@ -14,6 +14,26 @@ def __init__( selection_function: Callable, cost_function_threshold: Optional[Union[int, float]] = None, ) -> None: + """ + Initializes the Matcher object with the provided cost function, selection function, and cost function threshold. + + Args: + cost_function (Callable): A function that calculates the cost of matching two objects. This function should + take two TrackedObject instances as input and return a numerical value representing the cost of matching + these two objects. A lower cost indicates a higher likelihood of a match. + + selection_function (Callable): A function that determines whether two objects should be considered for + matching. This function should take two TrackedObject instances as input and return a binary value (0 or 1). + A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 + indicates that the pair should not be considered. + + cost_function_threshold (Optional[Union[int, float]]): An optional threshold value for the cost function. + If provided, any pair of objects with a matching cost greater than this threshold will not be considered for + matching. If not provided, all selected pairs will be considered regardless of their matching cost. + + Returns: + None + """ self.cost_function = cost_function self.selection_function = selection_function self.cost_function_threshold = cost_function_threshold @@ -99,7 +119,25 @@ def match( return matches @staticmethod - def linear_assigment(cost_matrix, candidates, switchers): + def linear_assigment( + cost_matrix: np.ndarray, candidates: List[TrackedObject], switchers: List[TrackedObject] + ) -> List[Dict[TrackedObject, TrackedObject]]: + """ + Performs linear assignment on the cost matrix to find the optimal match between candidates and switchers. + + The function uses the Jonker-Volgenant algorithm to solve the linear assignment problem. The algorithm finds the + optimal assignment (minimum total cost) for the given cost matrix. The cost matrix is a 2D numpy array where + each cell represents the cost of assigning a candidate to a switcher. + + Args: + cost_matrix (np.ndarray): A 2D array representing the cost of assigning each candidate to each switcher. + candidates (List[TrackedObject]): A list of candidate TrackedObjects for matching. + switchers (List[TrackedObject]): A list of switcher TrackedObjects to be matched. + + Returns: + List[Dict[TrackedObject, TrackedObject]]: A list of dictionaries where each dictionary represents a match. + The key is a candidate and the value is the corresponding switcher. + """ _, _, row_cols = lap.lapjv( cost_matrix, extend_cost=True, cost_limit=reid_constants.MATCHES.DISALLOWED_MATCH - 0.1 ) diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index 70f2c69..7665f1d 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -4,8 +4,9 @@ import numpy as np -from trackreid.args.reid_args import INPUT_POSITIONS, OUTPUT_POSITIONS -from trackreid.constants.reid_constants import reid_constants +from trackreid.configs.input_data_positions import input_data_positions +from trackreid.configs.output_data_positions import output_data_positions +from trackreid.configs.reid_constants import reid_constants from trackreid.matcher import Matcher from trackreid.tracked_object import TrackedObject from trackreid.tracked_object_filter import TrackedObjectFilter @@ -31,16 +32,43 @@ def __init__( file_path: str = "tracks.txt", ) -> None: """ - Initializes the ReidProcessor class. + This initializes the ReidProcessor class. + For information about the required input format and output details, use the following methods: + + ReidProcessor.print_input_data_format_requirements() + ReidProcessor.print_output_data_format_information() + Args: - filter_confidence_threshold: Confidence threshold for the filter. - filter_time_threshold: Time threshold for the filter. - cost_function: Cost function to be used. - selection_function: Selection function to be used. - max_frames_to_rematch (int): Maximum number of frames to rematch. - max_attempt_to_match (int): Maximum attempts to match. - cost_function_threshold (int, float): Maximum cost to rematch 2 objects. + filter_confidence_threshold (float): Confidence threshold for the filter. The filter will only consider + tracked objects that have a mean confidence score during the all transaction above this threshold. + + filter_time_threshold (int): Time threshold for the filter. The filter will only consider tracked objects + that have been seen for a number of frames above this threshold. + + cost_function (Callable): A function that calculates the cost of matching two objects. The cost function + should take two TrackedObject instances as input and return a numerical value representing the cost of + matching these two objects. A lower cost indicates a higher likelihood of a match. + + selection_function (Callable): A function that determines whether two objects should be considered for + matching. The selection function should take two TrackedObject instances as input and return a binary value + (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return + value of 0 indicates that the pair should not be considered. + + max_frames_to_rematch (int): Maximum number of frames to rematch. If a switcher is lost for a number of + frames greater than this value, it will be flagged as lost forever. + + max_attempt_to_match (int): Maximum number of attempts to match a candidate. If a candidate has not been + rematched despite a number of attempts equal to this value, it will be flagged as a stable object. + + cost_function_threshold (Optional[Union[int, float]]): An maximal threshold value for the cost function. + If provided, any pair of objects with a matching cost greater than this threshold will not be considered + for matching. If not provided, all selected pairs will be considered regardless of their matching cost. + + save_to_txt (bool): A flag indicating whether to save the results to a text file. If set to True, the + results will be saved to a text file specified by the file_path parameter. + + file_path (str): The path to the text file where the results will be saved if save_to_txt is set to True. """ self.matcher = Matcher( @@ -61,14 +89,14 @@ def __init__( self.max_attempt_to_match = max_attempt_to_match self.frame_id = 0 - self.nb_output_cols = get_nb_output_cols(output_positions=OUTPUT_POSITIONS) + self.nb_output_cols = get_nb_output_cols(output_positions=output_data_positions) self.save_to_txt = save_to_txt self.file_path = file_path def set_file_path(self, new_file_path: str) -> None: """ - Sets a new file path. + Sets a new file path for saving txt data. Args: new_file_path (str): The new file path. @@ -77,6 +105,12 @@ def set_file_path(self, new_file_path: str) -> None: @property def nb_corrections(self) -> int: + """ + Calculates and returns the total number of corrections made across all tracked objects. + + Returns: + int: Total number of corrections. + """ nb_corrections = 0 for obj in self.all_tracked_objects: nb_corrections += obj.nb_corrections @@ -84,6 +118,12 @@ def nb_corrections(self) -> int: @property def nb_tracker_ids(self) -> int: + """ + Calculates and returns the total number of tracker IDs across all tracked objects. + + Returns: + int: Total number of tracker IDs. + """ tracker_ids = 0 for obj in self.all_tracked_objects: tracker_ids += obj.nb_ids @@ -91,10 +131,23 @@ def nb_tracker_ids(self) -> int: @property def corrected_objects(self) -> List["TrackedObject"]: + """ + Returns a list of tracked objects that have been corrected. + + Returns: + List[TrackedObject]: List of corrected tracked objects. + """ return [obj for obj in self.all_tracked_objects if obj.nb_corrections] @property def seen_objects(self) -> List["TrackedObject"]: + """ + Returns a list of tracked objects that have been seen, excluding those in the + states TRACKER_OUTPUT and FILTERED_OUTPUT. + + Returns: + List[TrackedObject]: List of seen tracked objects. + """ return filter_objects_by_state( tracked_objects=self.all_tracked_objects, states=[reid_constants.STATES.TRACKER_OUTPUT, reid_constants.STATES.FILTERED_OUTPUT], @@ -103,21 +156,49 @@ def seen_objects(self) -> List["TrackedObject"]: @property def mean_nb_corrections(self) -> float: + """ + Calculates and returns the mean number of corrections across all tracked objects. + + Returns: + float: Mean number of corrections. + """ return self.nb_corrections / len(self.all_tracked_objects) - def update( - self, tracker_output: np.ndarray, frame_id: int - ) -> Union[np.ndarray, List[TrackedObject]]: + def update(self, tracker_output: np.ndarray, frame_id: int) -> np.ndarray: """ - Processes the tracker output. + Processes the tracker output and updates internal states. + + All input data should be of numeric type, either integers or floats. + Here's an example of how the input data should look like based on the schema: + + | bbox (0-3) | object_id (4) | category (5) | confidence (6) | + |-----------------|---------------|--------------|----------------| + | 50, 60, 120, 80 | 1 | 1 | 0.91 | + | 50, 60, 120, 80 | 2 | 0 | 0.54 | + + Each row represents a detected object. The first four columns represent the bounding box coordinates + (x, y, width, height), the fifth column represents the object ID assigned by the tracker, + the sixth column represents the category of the detected object, and the seventh column represents + the confidence score of the detection. + + You can use ReidProcessor.print_input_data_requirements() for more insight. + + Here's an example of how the output data looks like based on the schema: + + | frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | + |--------------|---------------|--------------|-----------------|----------------|---------------------|----------------| + | 1 | 1 | 1 | 50, 60, 120, 80 | 0.91 | 0.85 | 1 | + | 2 | 2 | 0 | 50, 60, 120, 80 | 0.54 | 0.60 | 2 | + + You can use ReidProcessor.print_output_data_format_information() for more insight. Args: tracker_output (np.ndarray): The tracker output. frame_id (int): The frame id. Returns: - Union[np.ndarray, List[TrackedObject]]: The processed output. - """ + np.ndarray: The processed output. + """ # noqa: E501 if tracker_output.size: # empty tracking self.all_tracked_objects, current_tracker_ids = self._preprocess( tracker_output=tracker_output, frame_id=frame_id @@ -145,7 +226,7 @@ def _preprocess(self, tracker_output: np.ndarray, frame_id: int) -> List["Tracke List["TrackedObject"]: The preprocessed output. """ reshaped_tracker_output = reshape_tracker_result(tracker_output=tracker_output) - current_tracker_ids = list(reshaped_tracker_output[:, INPUT_POSITIONS["object_id"]]) + current_tracker_ids = list(reshaped_tracker_output[:, input_data_positions.object_id]) self.all_tracked_objects = self._update_tracked_objects( tracker_output=reshaped_tracker_output, frame_id=frame_id @@ -168,7 +249,7 @@ def _update_tracked_objects( """ self.frame_id = frame_id for object_id, data_line in zip( - tracker_output[:, INPUT_POSITIONS["object_id"]], tracker_output + tracker_output[:, input_data_positions.object_id], tracker_output ): if object_id not in self.all_tracked_objects: new_tracked_object = TrackedObject( @@ -462,6 +543,8 @@ def _postprocess( ) -> np.ndarray: """ Postprocesses the current tracker IDs. + It selects the stable TrackedObjects, and formats their datas in the output + to match requirements. Args: current_tracker_ids (List[Union[int, float]]): The current tracker IDs. @@ -478,7 +561,7 @@ def _postprocess( reid_output = np.zeros((len(stable_objects), self.nb_output_cols)) for idx, stable_object in enumerate(stable_objects): - for required_variable in OUTPUT_POSITIONS: + for required_variable in output_data_positions.model_json_schema()["properties"].keys(): output = ( self.frame_id if required_variable == "frame_id" @@ -488,7 +571,7 @@ def _postprocess( raise NameError( f"Attribute {required_variable} not in TrackedObject. Check your required output names." ) - reid_output[idx, OUTPUT_POSITIONS[required_variable]] = output + reid_output[idx, getattr(output_data_positions, required_variable)] = output return reid_output @@ -518,3 +601,55 @@ def to_dict(self) -> Dict: for tracked_object in self.all_tracked_objects: data[tracked_object.object_id] = tracked_object.to_dict() return data + + @staticmethod + def print_input_data_format_requirements(): + """ + + Prints the input data format requirements. + + All input data should be of numeric type, either integers or floats. + Here's an example of how the input data should look like based on the schema: + + | bbox (0-3) | object_id (4) | category (5) | confidence (6) | + |-----------------|---------------|--------------|----------------| + | 50, 60, 120, 80 | 1 | 1 | 0.91 | + | 50, 60, 120, 80 | 2 | 0 | 0.54 | + + Each row represents a detected object. The first four columns represent the bounding box coordinates + (x, y, width, height), the fifth column represents the object ID assigned by the tracker, + the sixth column represents the category of the detected object, and the seventh column represents + the confidence score of the detection. + """ + input_schema = input_data_positions.model_json_schema() + + print("Input Data Format Requirements:") + for name, properties in input_schema["properties"].items(): + print("-" * 50) + print(f"{name}: {properties['description']}") + print( + f"{name} (position of {name} in the input array must be): {properties['default']}" + ) + + @staticmethod + def print_output_data_format_information(): + """ + Prints the output data format information. + + Here's an example of how the output data looks like based on the schema: + + | frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | + |--------------|---------------|--------------|------------|----------------|-------------------|------------------| + | 1 | 1 | 1 | 50,60,120,80 | 0.91 | 0.85 | 1 | + | 2 | 2 | 0 | 50,60,120,80 | 0.54 | 0.60 | 2 | + + """ # noqa: E501 + output_schema = output_data_positions.model_json_schema() + + print("\nOutput Data Format:") + for name, properties in output_schema["properties"].items(): + print("-" * 50) + print(f"{name}: {properties['description']}") + print( + f"{name} (position of {name} in the output array will be): {properties['default']}" + ) diff --git a/trackreid/selection_functions/__init__.py b/trackreid/selection_functions/__init__.py new file mode 100644 index 0000000..b16c7b2 --- /dev/null +++ b/trackreid/selection_functions/__init__.py @@ -0,0 +1 @@ +from .select_by_category import select_by_category # noqa: F401 diff --git a/trackreid/selection_functions/select_by_category.py b/trackreid/selection_functions/select_by_category.py new file mode 100644 index 0000000..38aec83 --- /dev/null +++ b/trackreid/selection_functions/select_by_category.py @@ -0,0 +1,18 @@ +from trackreid.tracked_object import TrackedObject + + +def select_by_category(candidate: TrackedObject, switcher: TrackedObject) -> int: + """ + Compares the categories of two TrackedObject instances. + This selection function is used as a measure of similarity between the two objects, + matches are discard if this function returns 0. + + Args: + candidate (TrackedObject): The first TrackedObject instance. + switcher (TrackedObject): The second TrackedObject instance. + + Returns: + int: Returns 1 if the categories of the two objects are the same, otherwise returns 0. + """ + # Compare the categories of the two objects + return 1 if candidate.category == switcher.category else 0 diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index 96c5181..b28f7af 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -6,7 +6,7 @@ import numpy as np from llist import sllist -from trackreid.constants.reid_constants import reid_constants +from trackreid.configs.reid_constants import reid_constants from trackreid.tracked_object_metadata import TrackedObjectMetaData from trackreid.utils import split_list_around_value diff --git a/trackreid/tracked_object_filter.py b/trackreid/tracked_object_filter.py index e7ec627..35461f9 100644 --- a/trackreid/tracked_object_filter.py +++ b/trackreid/tracked_object_filter.py @@ -1,12 +1,37 @@ -from trackreid.constants.reid_constants import reid_constants +from trackreid.configs.reid_constants import reid_constants class TrackedObjectFilter: + """ + The TrackedObjectFilter class is used to filter tracked objects based on their + confidence and the number of frames they have been observed in. + + Args: + confidence_threshold (float): The minimum mean confidence level required for a tracked + object to be considered valid. + frames_seen_threshold (int): The minimum number of frames a tracked object + must be observed in to be considered valid. + """ + def __init__(self, confidence_threshold, frames_seen_threshold): self.confidence_threshold = confidence_threshold self.frames_seen_threshold = frames_seen_threshold def update(self, tracked_object): + """ + The update method is used to update the state of a tracked object based on its confidence + and the number of frames it has been observed in. + + If the tracked object's state is TRACKER_OUTPUT, and its mean confidence is greater than the + confidence_threshold, and it has been observed in more frames than the frames_seen_threshold, + its state is updated to FILTERED_OUTPUT. + + If the tracked object's mean confidence is less than the confidence_threshold, its state is + updated to TRACKER_OUTPUT. + + Args: + tracked_object (TrackedObject): The tracked object to update. + """ if tracked_object.get_state() == reid_constants.STATES.TRACKER_OUTPUT: if ( tracked_object.metadata.mean_confidence() > self.confidence_threshold diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index bee82ff..197e3f1 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -1,13 +1,12 @@ import json -from trackreid.args.reid_args import INPUT_POSITIONS, MAPPING_CLASSES, POSSIBLE_CLASSES -from trackreid.utils import get_key_from_value +from trackreid.configs.input_data_positions import input_data_positions class TrackedObjectMetaData: def __init__(self, data_line, frame_id): self.first_frame_id = frame_id - self.class_counts = {class_name: 0 for class_name in map(int, POSSIBLE_CLASSES)} + self.class_counts = {} self.observations = 0 self.confidence_sum = 0 self.confidence = 0 @@ -15,10 +14,11 @@ def __init__(self, data_line, frame_id): def update(self, data_line, frame_id): self.last_frame_id = frame_id - class_name = int(data_line[INPUT_POSITIONS["category"]]) + + class_name = int(data_line[input_data_positions.category]) self.class_counts[class_name] = self.class_counts.get(class_name, 0) + 1 - self.bbox = list(map(int, data_line[INPUT_POSITIONS["bbox"]])) - confidence = float(data_line[INPUT_POSITIONS["confidence"]]) + self.bbox = list(map(int, data_line[input_data_positions.bbox])) + confidence = float(data_line[input_data_positions.confidence]) self.confidence = confidence self.confidence_sum += confidence self.observations += 1 @@ -32,7 +32,7 @@ def merge(self, other_object): self.confidence = other_object.confidence self.bbox = other_object.bbox self.last_frame_id = other_object.last_frame_id - for class_name in map(int, POSSIBLE_CLASSES): + for class_name in other_object.class_counts.keys(): self.class_counts[class_name] = self.class_counts.get( class_name, 0 ) + other_object.class_counts.get(class_name, 0) @@ -52,7 +52,7 @@ def copy(self): def to_dict(self): class_counts_str = { - MAPPING_CLASSES[class_name]: count for class_name, count in self.class_counts.items() + str(class_name): count for class_name, count in self.class_counts.items() } data = { "first_frame_id": int(self.first_frame_id), @@ -71,10 +71,7 @@ def to_json(self): @classmethod def from_dict(cls, data: dict): class_counts_str = data["class_counts"] - class_counts = { - get_key_from_value(MAPPING_CLASSES, class_name): count - for class_name, count in class_counts_str.items() - } + class_counts = {int(class_name): count for class_name, count in class_counts_str.items()} obj = cls.__new__(cls) obj.first_frame_id = data["first_frame_id"] obj.last_frame_id = data["last_frame_id"] @@ -97,7 +94,7 @@ def class_proportions(self): for class_name, count in self.class_counts.items() } else: - proportions = {class_name: 0.0 for class_name in map(int, POSSIBLE_CLASSES)} + proportions = None return proportions def percentage_of_time_seen(self, frame_id): diff --git a/trackreid/utils.py b/trackreid/utils.py index bcc09b3..65f8bcb 100644 --- a/trackreid/utils.py +++ b/trackreid/utils.py @@ -3,14 +3,35 @@ import numpy as np from llist import sllist +from trackreid.configs.output_data_positions import OutputDataPositions -def get_top_list_correction(tracked_ids: list): + +def get_top_list_correction(tracked_ids: List): + """ + Function to get the last value of each re_id_chain in tracked_ids. + + Args: + tracked_ids (list): List of tracked ids. + + Returns: + list: List of last values of each re_id_chain in tracked_ids. + """ top_list_correction = [tracked_id.re_id_chain.last.value for tracked_id in tracked_ids] return top_list_correction def split_list_around_value(my_list: sllist, value_to_split: float): + """ + Function to split a list around a given value. + + Args: + my_list (sllist): The list to split. + value_to_split (float): The value to split the list around. + + Returns: + tuple: Two lists, before and after the split value. + """ if value_to_split == my_list.last.value: raise NameError("split on the last") if value_to_split not in my_list: @@ -35,7 +56,18 @@ def split_list_around_value(my_list: sllist, value_to_split: float): return before, after -def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exclusion=False): +def filter_objects_by_state(tracked_objects: List, states: Union[int, List[int]], exclusion=False): + """ + Function to filter tracked objects by their state. + + Args: + tracked_objects (List): List of tracked objects. + states (Union[int, list]): State or list of states to filter by. + exclusion (bool, optional): If True, exclude objects with the given states. Defaults to False. + + Returns: + list: List of filtered tracked objects. + """ if isinstance(states, int): states = [states] if exclusion: @@ -46,8 +78,21 @@ def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exc def filter_objects_by_category( - tracked_objects: List, category: Union[Union[float, int], list], exclusion=False + tracked_objects: List, + category: Union[Union[float, int], List[Union[float, int]]], + exclusion=False, ): + """ + Function to filter tracked objects by their category. + + Args: + tracked_objects (List): List of tracked objects. + category (Union[Union[float, int], list]): Category or list of categories to filter by. + exclusion (bool, optional): If True, exclude objects with the given categories. Defaults to False. + + Returns: + list: List of filtered tracked objects. + """ if isinstance(category, Union[float, int]): category = [category] if exclusion: @@ -58,26 +103,38 @@ def filter_objects_by_category( def reshape_tracker_result(tracker_output: np.ndarray): + """ + Function to reshape the tracker output if it has only one dimension. + + Args: + tracker_output (np.ndarray): The tracker output to reshape. + + Returns: + np.ndarray: The reshaped tracker output. + """ if tracker_output.ndim == 1: tracker_output = np.expand_dims(tracker_output, 0) return tracker_output -def get_nb_output_cols(output_positions: dict): +def get_nb_output_cols(output_positions: OutputDataPositions): + """ + Function to get the number of output columns based on the model json schema. + + Args: + output_positions (OutputDataPositions): The output data positions. + + Returns: + int: The number of output columns. + """ + schema = output_positions.model_json_schema() nb_cols = 0 - for feature in output_positions.values(): - if type(feature) is int: + for feature in schema["properties"]: + if schema["properties"][feature]["type"] == "integer": nb_cols += 1 - elif type(feature) is list: - nb_cols += len(feature) + elif schema["properties"][feature]["type"] == "array": + nb_cols += len(schema["properties"][feature]["default"]) else: - raise TypeError("Unkown type in required output positions.") + raise TypeError("Unknown type in required output positions.") return nb_cols - - -def get_key_from_value(dictionary, target_value): - for key, value in dictionary.items(): - if value == target_value: - return key - return None From e362e62bd086e4e1c1b2c588d36e4fb027fcc4ca Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Thu, 23 Nov 2023 10:37:34 +0100 Subject: [PATCH 08/13] Add docstring (#18) --- trackreid/configs/reid_constants.py | 4 +- trackreid/reid_processor.py | 126 +++++++++++++++------- trackreid/tracked_object.py | 153 ++++++++++++++++++++++++--- trackreid/tracked_object_filter.py | 5 +- trackreid/tracked_object_metadata.py | 144 ++++++++++++++++++++++++- 5 files changed, 369 insertions(+), 63 deletions(-) diff --git a/trackreid/configs/reid_constants.py b/trackreid/configs/reid_constants.py index 299229a..9606e06 100644 --- a/trackreid/configs/reid_constants.py +++ b/trackreid/configs/reid_constants.py @@ -13,8 +13,8 @@ class States(BaseModel): DESCRIPTION: ClassVar[dict] = { LOST_FOREVER: "switcher never rematched", - TRACKER_OUTPUT: "bytetrack output not in reid process", - FILTERED_OUTPUT: "bytetrack output entering reid process", + TRACKER_OUTPUT: "tracker output not in reid process", + FILTERED_OUTPUT: "tracker output entering reid process", STABLE: "stable object", SWITCHER: "lost object to be re-matched", CANDIDATE: "new object to be matched", diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index 7665f1d..8fde685 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -7,7 +7,9 @@ from trackreid.configs.input_data_positions import input_data_positions from trackreid.configs.output_data_positions import output_data_positions from trackreid.configs.reid_constants import reid_constants +from trackreid.cost_functions import bounding_box_distance from trackreid.matcher import Matcher +from trackreid.selection_functions import select_by_category from trackreid.tracked_object import TrackedObject from trackreid.tracked_object_filter import TrackedObjectFilter from trackreid.utils import ( @@ -19,58 +21,84 @@ class ReidProcessor: - def __init__( - self, - filter_confidence_threshold: float, - filter_time_threshold: int, - cost_function: Callable, - selection_function: Callable, - max_frames_to_rematch: int, - max_attempt_to_match: int, - cost_function_threshold: Optional[Union[int, float]] = None, - save_to_txt: bool = True, - file_path: str = "tracks.txt", - ) -> None: - """ - This initializes the ReidProcessor class. - For information about the required input format and output details, use the following methods: + """ + The ReidProcessor class is designed to correct the results of tracking algorithms by reconciling and reassigning + lost or misidentified IDs. This ensures a consistent and accurate tracking of objects over time. - ReidProcessor.print_input_data_format_requirements() - ReidProcessor.print_output_data_format_information() + All input data should be of numeric type, either integers or floats. + Here's an example of how the input data should look like based on the schema: + | bbox (0-3) | object_id (4) | category (5) | confidence (6) | + |-----------------|---------------|--------------|----------------| + | 50, 60, 120, 80 | 1 | 1 | 0.91 | + | 50, 60, 120, 80 | 2 | 0 | 0.54 | - Args: - filter_confidence_threshold (float): Confidence threshold for the filter. The filter will only consider - tracked objects that have a mean confidence score during the all transaction above this threshold. + Each row represents a detected object. The first four columns represent the bounding box coordinates + (x, y, width, height), the fifth column represents the object ID assigned by the tracker, + the sixth column represents the category of the detected object, and the seventh column represents + the confidence score of the detection. - filter_time_threshold (int): Time threshold for the filter. The filter will only consider tracked objects - that have been seen for a number of frames above this threshold. + You can use ReidProcessor.print_input_data_requirements() for more insight. - cost_function (Callable): A function that calculates the cost of matching two objects. The cost function - should take two TrackedObject instances as input and return a numerical value representing the cost of - matching these two objects. A lower cost indicates a higher likelihood of a match. + Here's an example of how the output data looks like based on the schema: - selection_function (Callable): A function that determines whether two objects should be considered for - matching. The selection function should take two TrackedObject instances as input and return a binary value - (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return - value of 0 indicates that the pair should not be considered. + | frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | + |--------------|---------------|--------------|-----------------|----------------|---------------------|----------------| + | 1 | 1 | 1 | 50, 60, 120, 80 | 0.91 | 0.85 | 1 | + | 2 | 2 | 0 | 50, 60, 120, 80 | 0.54 | 0.60 | 2 | - max_frames_to_rematch (int): Maximum number of frames to rematch. If a switcher is lost for a number of - frames greater than this value, it will be flagged as lost forever. + You can use ReidProcessor.print_output_data_format_information() for more insight. - max_attempt_to_match (int): Maximum number of attempts to match a candidate. If a candidate has not been - rematched despite a number of attempts equal to this value, it will be flagged as a stable object. - cost_function_threshold (Optional[Union[int, float]]): An maximal threshold value for the cost function. - If provided, any pair of objects with a matching cost greater than this threshold will not be considered - for matching. If not provided, all selected pairs will be considered regardless of their matching cost. + Args: + filter_confidence_threshold (float): Confidence threshold for the filter. The filter will only consider + tracked objects that have a mean confidence score during the all transaction above this threshold. - save_to_txt (bool): A flag indicating whether to save the results to a text file. If set to True, the - results will be saved to a text file specified by the file_path parameter. + filter_time_threshold (int): Time threshold for the filter. The filter will only consider tracked objects + that have been seen for a number of frames above this threshold. - file_path (str): The path to the text file where the results will be saved if save_to_txt is set to True. - """ + max_frames_to_rematch (int): Maximum number of frames to rematch. If a switcher is lost for a number of + frames greater than this value, it will be flagged as lost forever. + + max_attempt_to_match (int): Maximum number of attempts to match a candidate. If a candidate has not been + rematched despite a number of attempts equal to this value, it will be flagged as a stable object. + + selection_function (Callable): A function that determines whether two objects should be considered for + matching. The selection function should take two TrackedObject instances as input and return a binary value + (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return + value of 0 indicates that the pair should not be considered. + Defaults to select_by_category. + cost_function (Callable): A function that calculates the cost of matching two objects. The cost function + should take two TrackedObject instances as input and return a numerical value representing the cost of + matching these two objects. A lower cost indicates a higher likelihood of a match. + Defaults to bounding_box_distance. + + cost_function_threshold (Optional[Union[int, float]]): An maximal threshold value for the cost function. + If provided, any pair of objects with a matching cost greater than this threshold will not be considered + for matching. If not provided, all selected pairs will be considered regardless of their matching cost. + Defaults to None. + + save_to_txt (bool): A flag indicating whether to save the results to a text file. If set to True, the + results will be saved to a text file specified by the file_path parameter. + Default to False. + + file_path (str): The path to the text file where the results will be saved if save_to_txt is set to True. + Defaults to tracks.txt + """ # noqa: E501 + + def __init__( + self, + filter_confidence_threshold: float, + filter_time_threshold: int, + max_frames_to_rematch: int, + max_attempt_to_match: int, + selection_function: Callable = select_by_category, + cost_function: Callable = bounding_box_distance, + cost_function_threshold: Optional[Union[int, float]] = None, + save_to_txt: bool = False, + file_path: str = "tracks.txt", + ) -> None: self.matcher = Matcher( cost_function=cost_function, selection_function=selection_function, @@ -302,7 +330,23 @@ def _apply_filtering(self) -> List[TrackedObject]: def _perform_reid_process(self, current_tracker_ids: List[Union[int, float]]) -> None: """ - Performs the reid process. + Performs the re-identification process on tracked objects. + + This method is responsible for managing the state of tracked objects and identifying potential + candidates for re-identification. It follows these steps: + + 1. correct_reid_chains: Corrects the re-identification chains of all tracked objects + based on the current tracker IDs. This avoids potential duplicates. + 2. update_switchers_states: Updates the states of switchers (objects that have switched IDs) + based on the current frame's tracked objects, the maximum number of frames to rematch, and the current frame ID. + 3. update_candidates_states: Updates the states of candidate objects (potential matches for re-identification) + based on the maximum number of attempts to match and the current frame ID. + 4. identify_switchers: Identifies switchers based on the current and last frame's tracked objects and + updates the state of all tracked objects accordingly. + 5. identify_candidates: Identifies candidates for re-identification and updates the state of all + tracked objects accordingly. + 6. match: Matches candidates with switchers using Jonker-Volgenant algorithm. + 7. process_matches: Processes the matches and updates the state of all tracked objects accordingly. Args: current_tracker_ids (List[Union[int, float]]): The current tracker IDs. diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index b28f7af..bd42a19 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -12,6 +12,47 @@ class TrackedObject: + """ + The TrackedObject class represents an object that is being tracked in a video frame. + It contains information about the object's state, its unique identifiers, and metadata. + + The object's state is an integer that represents the current state of the object in the + reid process. The states can take the following values: + + - LOST_FOREVER (-3): "Switcher never rematched" + - TRACKER_OUTPUT (-2): "Tracker output not in reid process" + - FILTERED_OUTPUT (-1): "Tracker output entering reid process" + - STABLE (0): "Stable object" + - SWITCHER (1): "Lost object to be re-matched" + - CANDIDATE (2): "New object to be matched" + + The object's unique identifiers are stored in a singly linked list (sllist) called re_id_chain. The re_id_chain + is a crucial component in the codebase. It stores the history of the object's unique identifiers, allowing for + tracking of the object across different frames. The first value in the re_id_chain + is the original object ID, while the last value is the most recent tracker ID assigned to the object. + + The metadata is an instance of the TrackedObjectMetaData class, which contains additional information + about the object. + + The TrackedObject class provides several methods for manipulating and accessing the data it contains. + These include methods for merging two TrackedObject instances, updating the metadata, and converting the + TrackedObject instance to a dictionary or JSON string. + + The TrackedObject class also provides several properties for accessing specific pieces of data, such as the object's + unique identifier, its state, and its metadata. + + Args: + object_ids (Union[Union[float, int], sllist]): The unique identifiers for the object. + state (int): The current state of the object. + metadata (Union[np.ndarray, TrackedObjectMetaData]): The metadata for the object. It can be either a + TrackedObjectMetaData object, or a data line, i.e. output of detection model. If metadata is initialized + with a TrackedObjectMetaData object, a frame_id must be given. + frame_id (Optional[int], optional): The frame ID where the object was first seen. Defaults to None. + + Raises: + NameError: If the type of object_ids or metadata is unrecognized. + """ + def __init__( self, object_ids: Union[Union[float, int], sllist], @@ -40,7 +81,7 @@ def __init__( def copy(self): return TrackedObject(object_ids=self.re_id_chain, state=self.state, metadata=self.metadata) - def merge(self, other_object): + def merge(self, other_object: TrackedObject): if not isinstance(other_object, TrackedObject): raise TypeError("Can only merge with another TrackedObject.") @@ -54,43 +95,80 @@ def merge(self, other_object): @property def object_id(self): + """ + Returns the first value in the re_id_chain which represents the object id. + """ return self.re_id_chain.first.value @property def tracker_id(self): + """ + Returns the last value in the re_id_chain which represents the last tracker id. + """ return self.re_id_chain.last.value @property def category(self): + """ + Returns the category with the maximum count in the class_counts dictionary of the metadata. + """ return max(self.metadata.class_counts, key=self.metadata.class_counts.get) @property def confidence(self): + """ + Returns the confidence value from the metadata. + """ return self.metadata.confidence @property def mean_confidence(self): + """ + Returns the mean confidence value from the metadata. + """ return self.metadata.mean_confidence() @property def bbox(self): + """ + Returns the bounding box coordinates from the metadata. + """ return self.metadata.bbox @property def nb_ids(self): + """ + Returns the number of ids in the re_id_chain. + """ return len(self.re_id_chain) @property def nb_corrections(self): + """ + Returns the number of corrections which is the number of ids in the re_id_chain minus one. + """ return self.nb_ids - 1 - def get_age(self, frame_id): + def get_age(self, frame_id: int): + """ + Calculates and returns the age of the tracked object based on the given frame id. + Age is defined as the difference between the current frame id and the first frame id where + the object was detected. + """ return frame_id - self.metadata.first_frame_id - def get_nb_frames_since_last_appearance(self, frame_id): + def get_nb_frames_since_last_appearance(self, frame_id: int): + """ + Calculates and returns the number of frames since the last appearance of the tracked object. + This is computed as the difference between the current frame id and the last frame id where + the object was detected. + """ return frame_id - self.metadata.last_frame_id def get_state(self): + """ + Returns the current state of the tracked object. + """ return self.state def __hash__(self): @@ -106,6 +184,20 @@ def __str__(self): return f"{self.__repr__()}, metadata : {self.metadata}" def update_metadata(self, data_line: np.ndarray, frame_id: int): + """ + Updates the metadata of the tracked object based on new detection data. + + This method is used to update the metadata of a tracked object whenever new detection data is available. + It updates the metadata by calling the update method of the TrackedObjectMetaData instance associated with + the tracked object. + + Args: + data_line (np.ndarray): The detection data for a single frame. It contains information such as the + class name, bounding box coordinates, and confidence level of the detection. + + frame_id (int): The frame id where the object was detected. This is used to update the last frame id of + the tracked object. + """ self.metadata.update(data_line=data_line, frame_id=frame_id) def __eq__(self, other): @@ -116,6 +208,20 @@ def __eq__(self, other): return False def cut(self, object_id: int): + """ + Splits the re_id_chain of the tracked object at the specified object_id and creates a new TrackedObject + instance with the remaining part of the re_id_chain. The original TrackedObject instance retains the part + of the re_id_chain before the specified object_id. + + Args: + object_id (int): The object_id at which to split the re_id_chain. + + Raises: + NameError: If the specified object_id is not found in the re_id_chain of the tracked object. + + Returns: + tuple: A tuple containing the new TrackedObject instance and the original TrackedObject instance. + """ if object_id not in self.re_id_chain: raise NameError( f"Trying to cut object {self} with {object_id} that is not in the re-id chain." @@ -131,18 +237,13 @@ def cut(self, object_id: int): new_object.metadata.first_frame_id = new_object.metadata.last_frame_id return new_object, self - def format_data(self): - return [ - self.object_id, - self.category, - self.bbox[0], - self.bbox[1], - self.bbox[2], - self.bbox[3], - self.confidence, - ] - def to_dict(self): + """ + Converts the TrackedObject instance to a dictionary. + + Returns: + dict: A dictionary representation of the TrackedObject instance. + """ data = { "object_id": float(self.object_id), "state": int(self.state), @@ -152,10 +253,25 @@ def to_dict(self): return data def to_json(self): + """ + Converts the TrackedObject instance to a JSON string. + + Returns: + str: A JSON string representation of the TrackedObject instance. + """ return json.dumps(self.to_dict(), indent=4) @classmethod def from_dict(cls, data: dict): + """ + Creates a new TrackedObject instance from a dictionary. + + Args: + data (dict): A dictionary containing the data for the TrackedObject instance. + + Returns: + TrackedObject: A new TrackedObject instance created from the dictionary. + """ obj = cls.__new__(cls) obj.state = data["state"] obj.re_id_chain = sllist(data["re_id_chain"]) @@ -164,5 +280,14 @@ def from_dict(cls, data: dict): @classmethod def from_json(cls, json_str: str): + """ + Creates a new TrackedObject instance from a JSON string. + + Args: + json_str (str): A JSON string containing the data for the TrackedObject instance. + + Returns: + TrackedObject: A new TrackedObject instance created from the JSON string. + """ data = json.loads(json_str) return cls.from_dict(data) diff --git a/trackreid/tracked_object_filter.py b/trackreid/tracked_object_filter.py index 35461f9..c4a56fa 100644 --- a/trackreid/tracked_object_filter.py +++ b/trackreid/tracked_object_filter.py @@ -1,4 +1,5 @@ from trackreid.configs.reid_constants import reid_constants +from trackreid.tracked_object import TrackedObject class TrackedObjectFilter: @@ -13,11 +14,11 @@ class TrackedObjectFilter: must be observed in to be considered valid. """ - def __init__(self, confidence_threshold, frames_seen_threshold): + def __init__(self, confidence_threshold: float, frames_seen_threshold: int): self.confidence_threshold = confidence_threshold self.frames_seen_threshold = frames_seen_threshold - def update(self, tracked_object): + def update(self, tracked_object: TrackedObject): """ The update method is used to update the state of a tracked object based on its confidence and the number of frames it has been observed in. diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index 197e3f1..ea04bfc 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -1,10 +1,24 @@ import json +import numpy as np + from trackreid.configs.input_data_positions import input_data_positions class TrackedObjectMetaData: - def __init__(self, data_line, frame_id): + """ + The TrackedObjectMetaData class is used to store and manage metadata for tracked objects in a video frame. + This metadata includes information such as the frame ID where the object was first seen, the class counts + (how many times each class was detected), the bounding box coordinates, and the confidence level of the detection. + + This metadata is then use in selection and cost functions to compute likelihood of a match between two objects. + + Usage: + An instance of TrackedObjectMetaData is created by passing a data_line (which contains the detection data + for a single frame) and a frame_id (which identifies the frame where the object was detected). + """ + + def __init__(self, data_line: np.ndarray, frame_id: int): self.first_frame_id = frame_id self.class_counts = {} self.observations = 0 @@ -12,7 +26,33 @@ def __init__(self, data_line, frame_id): self.confidence = 0 self.update(data_line, frame_id) - def update(self, data_line, frame_id): + def update(self, data_line: np.ndarray, frame_id: int): + """ + Updates the metadata of a tracked object based on new detection data. + + This method is used to update the metadata of a tracked object whenever new detection data is available. + It updates the last frame id, class counts, bounding box, confidence, confidence sum, and observations. + + Args: + data_line (np.ndarra): The detection data for a single frame. It contains information such as the + class name, bounding box coordinates, and confidence level of the detection. + + frame_id (int): The frame id where the object was detected. This is used to update the last frame id of + the tracked object. + + Updates: + last_frame_id: The last frame id is updated to the frame id where the object was detected. + + class_counts: The class counts are updated by incrementing the count of the detected class by 1. + + bbox: The bounding box is updated to the bounding box coordinates from the detection data. + + confidence: The confidence is updated to the confidence level from the detection data. + + confidence_sum: The confidence sum is updated by adding the confidence level from the detection data. + + observations: The total number of observations is incremented by 1. + """ self.last_frame_id = frame_id class_name = int(data_line[input_data_positions.category]) @@ -24,6 +64,24 @@ def update(self, data_line, frame_id): self.observations += 1 def merge(self, other_object): + """ + Merges the metadata of another TrackedObjectMetaData instance into the current one. + + Args: + other_object (TrackedObjectMetaData): The other TrackedObjectMetaData instance whose metadata + is to be merged with the current instance. + + Raises: + TypeError: If the other_object is not an instance of TrackedObjectMetaData. + + Updates: + observations: The total number of observations is updated by adding the observations of the other object. + confidence_sum: The total confidence sum is updated by adding the confidence sum of the other object. + confidence: The confidence is updated to the confidence of the other object. + bbox: The bounding box is updated to the bounding box of the other object. + last_frame_id: The last frame id is updated to the last frame id of the other object. + class_counts: The class counts are updated by adding the class counts of the other object for each class. + """ if not isinstance(other_object, type(self)): raise TypeError("Can only merge with another TrackedObjectMetaData.") @@ -38,8 +96,14 @@ def merge(self, other_object): ) + other_object.class_counts.get(class_name, 0) def copy(self): + """ + Creates a copy of the current TrackedObjectMetaData instance. + + Returns: + TrackedObjectMetaData: A new instance of TrackedObjectMetaData with the same + properties as the current instance. + """ copy_obj = TrackedObjectMetaData.__new__(TrackedObjectMetaData) - # Update the copied instance with the actual class counts and observations copy_obj.bbox = self.bbox.copy() copy_obj.class_counts = self.class_counts.copy() copy_obj.observations = self.observations @@ -51,6 +115,17 @@ def copy(self): return copy_obj def to_dict(self): + """ + Converts the TrackedObjectMetaData instance to a dictionary. + + The class_counts dictionary is converted to a string-keyed dictionary. + The bounding box list is converted to a list of integers. + The first_frame_id, last_frame_id, confidence, confidence_sum, and observations are converted to their + respective types. + + Returns: + dict: A dictionary representation of the TrackedObjectMetaData instance. + """ class_counts_str = { str(class_name): count for class_name, count in self.class_counts.items() } @@ -66,10 +141,29 @@ def to_dict(self): return data def to_json(self): + """ + Converts the TrackedObjectMetaData instance to a JSON string. + + Returns: + str: A JSON string representation of the TrackedObjectMetaData instance. + """ return json.dumps(self.to_dict(), indent=4) @classmethod def from_dict(cls, data: dict): + """ + Creates a new instance of the class from a dictionary. + + The dictionary should contain the following keys: "first_frame_id", "last_frame_id", "class_counts", + "bbox", "confidence", "confidence_sum", and "observations". The "class_counts" key should map to a + dictionary where the keys are class names (as integers) and the values are counts. + + Args: + data (dict): A dictionary containing the data to populate the new instance. + + Returns: + TrackedObjectMetaData: A new instance of TrackedObjectMetaData populated with the data from the dictionary. + """ class_counts_str = data["class_counts"] class_counts = {int(class_name): count for class_name, count in class_counts_str.items()} obj = cls.__new__(cls) @@ -84,10 +178,25 @@ def from_dict(cls, data: dict): @classmethod def from_json(cls, json_str: str): + """ + Creates a new instance of the class from a JSON string. + + Args: + json_str (str): A JSON string representation of the TrackedObjectMetaData instance. + + Returns: + TrackedObjectMetaData: A new instance of TrackedObjectMetaData populated with the data from the JSON string. + """ data = json.loads(json_str) return cls.from_dict(data) def class_proportions(self): + """ + Calculates the proportions of each class in the tracked object. + + Returns: + dict: A dictionary where the keys are class names and the values are the proportions of each class. + """ if self.observations > 0: proportions = { class_name: count / self.observations @@ -97,7 +206,16 @@ def class_proportions(self): proportions = None return proportions - def percentage_of_time_seen(self, frame_id): + def percentage_of_time_seen(self, frame_id: int): + """ + Calculates the percentage of time the tracked object has been seen. + + Args: + frame_id (int): The current frame id. + + Returns: + float: The percentage of time the tracked object has been seen. + """ if self.observations > 0: percentage = (self.observations / (frame_id - self.first_frame_id + 1)) * 100 else: @@ -105,15 +223,33 @@ def percentage_of_time_seen(self, frame_id): return percentage def mean_confidence(self): + """ + Calculates the mean confidence of the tracked object. + + Returns: + float: The mean confidence of the tracked object. + """ if self.observations > 0: return self.confidence_sum / self.observations else: return 0.0 def __repr__(self) -> str: + """ + Returns a string representation of the TrackedObjectMetaData instance. + + Returns: + str: A string representation of the TrackedObjectMetaData instance. + """ return f"TrackedObjectMetaData(bbox={self.bbox})" def __str__(self): + """ + Returns a string representation of the TrackedObjectMetaData instance. + + Returns: + str: A string representation of the TrackedObjectMetaData instance. + """ return ( f"First frame seen: {self.first_frame_id}, nb observations: {self.observations}, " + f"class proportions: {self.class_proportions()}, bbox: {self.bbox}, " From 5708a9a37b2da99b3fa4ec0fc8db844fc8370a70 Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Thu, 23 Nov 2023 14:29:05 +0100 Subject: [PATCH 09/13] update readme (#20) --- README.md | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f7015f3..f65be64 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,6 @@ [![Pre-commit](https://img.shields.io/badge/pre--commit-enabled-informational?logo=pre-commit&logoColor=white)](https://github.com/artefactory-fr/track-reid/blob/main/.pre-commit-config.yaml) -TODO: if not done already, check out the [Skaff documentation](https://artefact.roadie.so/catalog/default/component/repo-builder-ds/docs/) for more information about the generated repository. - This Git repository is dedicated to the development of a Python library aimed at correcting the results of tracking algorithms. The primary goal of this library is to reconcile and reassign lost or misidentified IDs, ensuring a consistent and accurate tracking of objects over time. ## Table of Contents @@ -53,11 +51,22 @@ make help ## Usage -TODO: Add usage instructions here +For a quickstart, please refer to the documentation [here](https://artefactory-fr.github.io/track-reid/quickstart_user/). You also have at disposal a demo notebook in `notebooks/starer_kit_reid.ipynb`. -## Documentation +Lets say you have a `dataset` iterable object, composed for each iteartion of a frame id and its associated tracking results. You can call the `ReidProcessor` update class using the following: + +```python +for frame_id, tracker_output in dataset: + corrected_results = reid_processor.update(frame_id = frame_id, tracker_output=tracker_output) +``` -TODO: Github pages is not enabled by default, you need to enable it in the repository settings: Settings > Pages > Source: "Deploy from a branch" / Branch: "gh-pages" / Folder: "/(root)" +At the end of the for loop, information about the correction can be retrieved using the `ReidProcessor` properties. For instance, the list of tracked object can be accessed using: + +```python +reid_processor.seen_objects() +``` + +## Documentation A detailed documentation of this project is available [here](https://artefactory-fr.github.io/track-reid/) From 5b64484c7dc1af992bf201bda8ee05087c35a874 Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Thu, 23 Nov 2023 14:29:16 +0100 Subject: [PATCH 10/13] Tp/add mkdocs (#19) --- docs/code.md | 1 - docs/custom_cost_selection.md | 148 ++++++++++++++++++++++ docs/index.md | 2 +- docs/quickstart_dev.md | 57 +++++++++ docs/quickstart_user.md | 117 +++++++++++++++++ docs/reference/cost_functions.md | 15 +++ docs/reference/matcher.md | 3 + docs/reference/reid_processor.md | 3 + docs/reference/selection_functions.md | 17 +++ docs/reference/tracked_object.md | 3 + docs/reference/tracked_object_filter.md | 3 + docs/reference/tracked_object_metadata.md | 3 + mkdocs.yaml | 31 ----- mkdocs.yml | 46 +++++++ pyproject.toml | 2 +- trackreid/matcher.py | 17 +-- trackreid/reid_processor.py | 33 ++--- trackreid/tracked_object.py | 14 +- trackreid/tracked_object_filter.py | 8 +- trackreid/tracked_object_metadata.py | 49 +++---- 20 files changed, 458 insertions(+), 114 deletions(-) delete mode 100644 docs/code.md create mode 100644 docs/custom_cost_selection.md create mode 100644 docs/quickstart_dev.md create mode 100644 docs/quickstart_user.md create mode 100644 docs/reference/cost_functions.md create mode 100644 docs/reference/matcher.md create mode 100644 docs/reference/reid_processor.md create mode 100644 docs/reference/selection_functions.md create mode 100644 docs/reference/tracked_object.md create mode 100644 docs/reference/tracked_object_filter.md create mode 100644 docs/reference/tracked_object_metadata.md delete mode 100644 mkdocs.yaml create mode 100644 mkdocs.yml diff --git a/docs/code.md b/docs/code.md deleted file mode 100644 index aa45473..0000000 --- a/docs/code.md +++ /dev/null @@ -1 +0,0 @@ -# Code diff --git a/docs/custom_cost_selection.md b/docs/custom_cost_selection.md new file mode 100644 index 0000000..71bb9c1 --- /dev/null +++ b/docs/custom_cost_selection.md @@ -0,0 +1,148 @@ +# Designing custom cost and selection functions + +## Custom cost function + +In our codebase, a cost function is utilized to quantify the dissimilarity between two objects, specifically instances of [TrackedObjects](reference/tracked_object.md). The cost function plays a pivotal role in the matching process within the [Matcher class](reference/matcher.md), where it computes a cost matrix. Each element in this matrix represents the cost of assigning a candidate to a switcher. For a deeper understanding of cost functions, please refer to the [related documentation](reference/cost_functions.md). + +When initializing the [ReidProcessor](reference/reid_processor.md), you have the option to provide a custom cost function. The requirements for designing one are as follows: + +- The cost function must accept 2 [TrackedObjects](reference/tracked_object.md) instances: a candidate (a new object that appears and can potentially be matched), and a switcher (an object that has been lost and can potentially be re-matched). +- All the [metadata](reference/tracked_object_metadata.md) of each [TrackedObject](reference/tracked_object.md) can be utilized to compute a cost. +- If additional metadata is required, you should modify the [metadata](reference/tracked_object_metadata.md) class accordingly. Please refer to the [developer quickstart documentation](quickstart_dev.md) if needed. + +Here is an example of an Intersection over Union (IoU) distance function that you can use: + +```python +def bounding_box_iou_distance(candidate: TrackedObject, switcher: TrackedObject) -> float: + """ + Calculates the Intersection over Union (IoU) between the bounding boxes of two TrackedObjects. + This measure is used as a measure of similarity between the two objects, with a higher IoU + indicating a higher likelihood of the objects being the same. + + Args: + candidate (TrackedObject): The first TrackedObject. + switcher (TrackedObject): The second TrackedObject. + + Returns: + float: The IoU between the bounding boxes of the two TrackedObjects. + """ + # Get the bounding boxes from the Metadata of each TrackedObject + bbox1 = candidate.metadata.bbox + bbox2 = switcher.metadata.bbox + + # Calculate the intersection of the bounding boxes + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + # If the bounding boxes do not overlap, return 0 + if x2 < x1 or y2 < y1: + return 0.0 + + # Calculate the area of the intersection + intersection_area = (x2 - x1) * (y2 - y1) + + # Calculate the area of each bounding box + bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + + # Calculate the IoU + iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area) + + return 1 - iou + +``` + +Next, pass this function during the initialization of your [ReidProcessor](reference/reid_processor.md): + +```python +reid_processor = ReidProcessor(cost_function_threshold=0.3, + cost_function = bounding_box_iou_distance, + filter_confidence_threshold=..., + filter_time_threshold=..., + max_attempt_to_match=..., + max_frames_to_rematch=..., + save_to_txt=True, + file_path="your_file.txt") +``` + +In this case, candidates and switchers with bounding boxes will be matched if their IoU is below 0.7. Among possible matches, the two bounding boxes with the lowest cost (i.e., larger IoU) will be matched. You can use all the available metadata. For instance, here is an example of a cost function based on the difference in confidence: + +```python +def confidence_difference(candidate: TrackedObject, switcher: TrackedObject) -> float: + """ + Calculates the absolute difference between the confidence values of two TrackedObjects. + This measure is used as a measure of dissimilarity between the two objects, with a smaller difference + indicating a higher likelihood of the objects being the same. + + Args: + candidate (TrackedObject): The first TrackedObject. + switcher (TrackedObject): The second TrackedObject. + + Returns: + float: The absolute difference between the confidence values of the two TrackedObjects. + """ + # Get the confidence values from the Metadata of each TrackedObject + confidence1 = candidate.metadata.confidence + confidence2 = switcher.metadata.confidence + + # Calculate the absolute difference between the confidence values + difference = abs(confidence1 - confidence2) + + return difference + +``` + +Then, pass this function during the initialization of your [ReidProcessor](reference/reid_processor.md): + +```python +reid_processor = ReidProcessor(cost_function_threshold=0.1, + cost_function = confidence_difference, + filter_confidence_threshold=..., + filter_time_threshold=..., + max_attempt_to_match=..., + max_frames_to_rematch=..., + save_to_txt=True, + file_path="your_file.txt") +``` + +In this case, candidates and switchers will be matched if their confidence is similar, with a threshold acceptance of 0.1. Among possible matches, the two objects with the lowest cost (i.e., lower confidence difference) will be matched. + +## Custom Selection function + +In the codebase, a selection function is used to determine whether two objects, specifically [TrackedObjects](reference/tracked_object.md) instances, should be considered for matching. The selection function is a key part of the matching process in the [Matcher class](reference/matcher.md). For a deeper understanding of selection functions, please refer to the [related documentation](reference/selection_functions.md). + +Here is an example of a selection function per zone that you can use: + +```python + +# Define the area of interest, [x_min, y_min, x_max, y_max] +AREA_OF_INTEREST = [0, 0, 500, 500] + +def select_by_area(candidate: TrackedObject, switcher: TrackedObject) -> int: + + # Check if both objects are inside the area of interest + if (candidate.bbox[0] > AREA_OF_INTEREST[0] and candidate.bbox[1] > AREA_OF_INTEREST[1] and + candidate.bbox[0] + candidate.bbox[2] < AREA_OF_INTEREST[2] and candidate.bbox[1] + candidate.bbox[3] < AREA_OF_INTEREST[3] and + switcher.bbox[0] > AREA_OF_INTEREST[0] and switcher.bbox[1] > AREA_OF_INTEREST[1] and + switcher.bbox[0] + switcher.bbox[2] < AREA_OF_INTEREST[2] and switcher.bbox[1] + switcher.bbox[3] < AREA_OF_INTEREST[3]): + return 1 + else: + return 0 + +``` + +Then, pass this function during the initialization of your [ReidProcessor](reference/reid_processor.md): + +```python +reid_processor = ReidProcessor(selection_function = select_by_area, + filter_confidence_threshold=..., + filter_time_threshold=..., + max_attempt_to_match=..., + max_frames_to_rematch=..., + save_to_txt=True, + file_path="your_file.txt") +``` + +In this case, candidates and switchers will be considerated for matching if they belong to the same zone. You can of course combine selection functions, for instance to selection only switchers and candidates that belong to the same area and belong to the same category. diff --git a/docs/index.md b/docs/index.md index 8013429..df25491 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,3 +1,3 @@ -# Welcome to the documentation! +# Welcome to the documentation For more information, make sure to check the [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/getting-started/) diff --git a/docs/quickstart_dev.md b/docs/quickstart_dev.md new file mode 100644 index 0000000..7a0f99e --- /dev/null +++ b/docs/quickstart_dev.md @@ -0,0 +1,57 @@ +# Quickstart developers + +## Installation + +First, clone the repository to your local machine: + +```bash +git clone https://github.com/artefactory-fr/track-reid.git +``` + +Then, navigate to the project directory: + +```bash +cd track-reid +``` + +To install the necessary dependencies, we use Poetry. If you don't have Poetry installed, you can download it using the following command: + +```bash +curl -sSL https://install.python-poetry.org | python3 - +``` + +Now, you can install the dependencies: + +```bash +make install +``` + +This will create a virtual environment and install the necessary dependencies. +To activate the virtual environment in your terminal, you can use the following command: + +```bash +poetry shell +``` + +You can also update the requirements using the following command: + +```bash +make update-requirements +``` + +Then, you are ready to go ! +For more detailed information, please refer to the `Makefile`. + +## Tests + +In this project, we have designed both integration tests and unit tests. These tests are located in the `tests` directory of the project. + +Integration tests are designed to test the interaction between different parts of the system, ensuring that they work together as expected. Those tests can be found in the `tests/integration_tests` directory of the project. + +Unit tests, on the other hand, are designed to test individual components of the system in isolation. We provided a bench of unit tests to test key functions of the project, those can be found in `tests/unit_tests`. + +To run all tests, you can use the following command: + +```bash +make run_tests +``` diff --git a/docs/quickstart_user.md b/docs/quickstart_user.md new file mode 100644 index 0000000..d5262cb --- /dev/null +++ b/docs/quickstart_user.md @@ -0,0 +1,117 @@ +# Using the ReidProcessor + +The `ReidProcessor` is the entry point of the `track-reid` library. It is used to process and reconcile tracking data, ensuring consistent and accurate tracking of objects over time. Here's a step-by-step guide on how to use it: + +## Step 1: Understand the Usage + +The reidentification process is applied to tracking results, which are derived from the application of a tracking algorithm on detection results for successive frames of a video. This reidentification process is applied iteratively on each tracking result, updating its internal states during the process. + +The `ReidProcessor` needs to be updated with the tracking results for each frame of your +sequence or video. This is done by calling the `update` method that takes 2 arguments: + +- `frame_id`: an integer specifying the current frame of the video +- `tracker_output`: a numpy array containing the tracking results for the current frame + +## Step 2: Understand the Data Format Requirements + +The `ReidProcessor` update function requires a numpy array of tracking results for the current frame as input. This data must meet specific criteria regarding data type and structure. + +All input data must be numeric, either integers or floats. +Here's an example of the expected input data format based on the schema: + +| bbox (0-3) | object_id (4) | category (5) | confidence (6) | +|-----------------|---------------|--------------|----------------| +| 50, 60, 120, 80 | 1 | 1 | 0.91 | +| 50, 60, 120, 80 | 2 | 0 | 0.54 | + +Each row corresponds to a tracked object. + +- The first four columns denote the **bounding box coordinates** in the format (x, y, width, height), +where x and y are the top left coordinates of the bounding box. These coordinates can be either normalized or in pixel units. +These values remain unchanged during the reidentification process. +- The fifth column is the **object ID** assigned by the tracker, which may be adjusted during the reidentification process. +- The sixth column indicates the **category** of the detected object, which may also be adjusted during the reidentification process. +- The seventh column is the confidence score of the detection, which is not modified by the reidentification process. + +For additional information, you can utilize `ReidProcessor.print_input_data_requirements()`. + +Here's a reformatted example of how the output data should appear, based on the schema: + +| frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | +|--------------|---------------|--------------|-----------------|----------------|---------------------|----------------| +| 1 | 1 | 1 | 50, 60, 120, 80 | 0.91 | 0.85 | 1 | +| 2 | 2 | 0 | 50, 60, 120, 80 | 0.54 | 0.60 | 2 | + +- The first column represents the **frame identifier**, indicating the frame for which the result is applicable. +- The second column is the **object ID** assigned by the reidentification process. +- The third column is the **category** of the detected object, which may be adjusted during the reidentification process. +- The next four columns represent the **bounding box coordinates**, which remain unchanged from the input data. +- The seventh column is the **confidence** of the object, which also remains unchanged from the input data. +- The eighth column indicates the **average confidence** of the detected object over its lifetime, from the beginning of the tracking to the current frame. +- The final column is the **object ID assigned by the tracking algorithm**, before the reidentification process. + +You can use `ReidProcessor.print_output_data_format_information()` for more insight. + +## Step 3: Understand Necessary Modules + +To make ReidProcessor work, several modules are necessary: + +- `TrackedObject`: This class represents a tracked object. It is used within the Matcher and ReidProcessor classes. +- `TrackedObjectMetadata`: This class is attached to a tracked object and represents informations and properties about the object. +- `TrackedObjectFilter`: This class is used to filter tracked objects based on certain criteria. It is used within the ReidProcessor class. +- `Matcher`: This class is used to match tracked objects based on a cost function and a selection function. It is initialized within the ReidProcessor class. + +The cost and selection functions are key components of the ReidProcessor, as they will drive the matching process between lost objects and new objects during the video. Those two functions are fully customizable and can be passed as arguments of the ReidProcessor at initialization. They both take 2 `TrackedObjects` as inputs, and perform computation based on their metadatas. + +- **cost function**: This function calculates the cost of matching two objects. It takes two TrackedObject instances as input and returns a numerical value representing the cost of matching these two objects. A lower cost indicates a higher likelihood of a match. The default cost function is `bounding_box_distance`. + +- **selection_function**: This function determines whether two objects should be considered for matching. It takes two TrackedObject instances as input and returns a binary value (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 indicates that the pair should not be considered. The default selection function is `select_by_category`. + +In summary, prior to the matching process, filtering on which objects should be considerated is applied thought the `TrackedObjectFilter`. All objects are represented by the `TrackedObject` class, with its attached metadata represented by `TrackedObjectMetadata`. The `ReidProcessor` then uses the `Matcher` class with a cost function and selection function to match objects. + +## Step 4: Initialize ReidProcessor + +If you do not want to provide custom cost and selection function, here is an example of ReidProcessor initialization: + +```python +reid_processor = ReidProcessor(filter_confidence_threshold=0.1, + filter_time_threshold=5, + cost_function_threshold=5000, + max_attempt_to_match=5, + max_frames_to_rematch=500, + save_to_txt=True, + file_path="your_file.txt") +``` + +Here is a brief explanation of each argument in the ReidProcessor function, and how you can monitor the `Matcher` and the `TrackedObjectFilter` behaviours: + +- `filter_confidence_threshold`: Float value that sets the **minimum average confidence level** for a tracked object to be considered valid. Tracked objects with average confidence levels below this threshold will be ignored. + +- `filter_time_threshold`: Integer that sets the **minimum number of frames** a tracked object must be seen with the same id to be considered valid. Tracked objects seen less frames that this threshold will be ignored. + +- `cost_function_threshold`: This is a float value that sets the **maximum cost for a match** between a detection and a track. If the cost of matching a detection to a track exceeds this threshold, the match will not be made. Set to None for no limitation. + +- `max_attempt_to_match`: This is an integer that sets the **maximum number of attempts to match a tracked object never seen before** to a lost tracked object. If this tracked object never seen before can't be matched within this number of attempts, it will be considered a new stable tracked object. + +- `max_frames_to_rematch`: This is an integer that sets the **maximum number of frames to try to rematch a tracked object that has been lost**. If a lost object can't be rematch within this number of frames, it will be considered as lost forever. + +- `save_to_txt`: This is a boolean value that determines whether the tracking results should be saved to a text file. If set to True, the results will be saved to a text file. + +- `file_path`: This is a string that specifies the path to the text file where the tracking results will be saved. This argument is only relevant if save_to_txt is set to True. + +For more information on how to design custom cost and selection functions, refer to [this guide](custom_cost_selection.md). + +## Step 5: Run reidentifiaction process + +Lets say you have a `dataset` iterable object, composed for each iteartion of a frame id and its associated tracking results. You can call the `ReidProcessor` update class using the following: + +```python +for frame_id, tracker_output in dataset: + corrected_results = reid_processor.update(frame_id = frame_id, tracker_output=tracker_output) +``` + +At the end of the for loop, information about the correction can be retrieved using the `ReidProcessor` properties. For instance, the list of tracked object can be accessed using: + +```python +reid_processor.seen_objects() +``` diff --git a/docs/reference/cost_functions.md b/docs/reference/cost_functions.md new file mode 100644 index 0000000..3a0a1c9 --- /dev/null +++ b/docs/reference/cost_functions.md @@ -0,0 +1,15 @@ +# Cost functions + +In the codebase, a cost function is used to measure the dissimilarity between two objects, specifically [TrackedObjects](tracked_object.md) instances. The cost function is a crucial part of the matching process in the [Matcher class](matcher.md). It calculates a cost matrix, where each element represents the cost of assigning a candidate to a switcher. + +The cost function affects the behavior of the matching process in the following ways: + +1. **Determining Matches**: The cost function is used to determine the best matches between candidates and switchers. The lower the cost, the higher the likelihood that two objects are the same. + +2. **Influencing Match Quality**: The choice of cost function can greatly influence the quality of the matches. For example, a cost function that calculates the Euclidean distance between the centers of bounding boxes might be more suitable for tracking objects in a video, while a cost function that calculates the absolute difference between confidence values might be more suitable for matching objects based on their detection confidence. + +3. **Setting Match Thresholds**: The cost function also plays a role in setting thresholds for matches. In the [Matcher class](matcher.md), if the cost exceeds a certain threshold, the match is discarded. + +You can provide a custom cost function to the reidentification process. For more information, please refer to [this documentation](../custom_cost_selection.md). + +:::trackreid.cost_functions diff --git a/docs/reference/matcher.md b/docs/reference/matcher.md new file mode 100644 index 0000000..396ffbf --- /dev/null +++ b/docs/reference/matcher.md @@ -0,0 +1,3 @@ +# Matcher + +:::trackreid.matcher diff --git a/docs/reference/reid_processor.md b/docs/reference/reid_processor.md new file mode 100644 index 0000000..c72dbdc --- /dev/null +++ b/docs/reference/reid_processor.md @@ -0,0 +1,3 @@ +# Reid processor + +:::trackreid.reid_processor diff --git a/docs/reference/selection_functions.md b/docs/reference/selection_functions.md new file mode 100644 index 0000000..6b4f998 --- /dev/null +++ b/docs/reference/selection_functions.md @@ -0,0 +1,17 @@ +# Selection Functions + +In the codebase, a selection function is used to determine whether two objects, specifically [TrackedObjects](tracked_object.md) instances, should be considered for matching. The selection function is a key part of the matching process in the [Matcher class](matcher.md). + +The selection function influences the behavior of the matching process in the following ways: + +1. **Filtering Candidates**: The selection function is used to filter out pairs of objects that should not be considered for matching. This can help reduce the computational complexity of the matching process by reducing the size of the cost matrix. + +2. **Customizing Matching Criteria**: The selection function allows you to customize the criteria for considering a pair of objects for matching. For example, you might want to only consider pairs of objects that belong to the same category, or pairs of objects that belong to the same area / zone. + +3. **Improving Match Quality**: By carefully choosing or designing a selection function, you can improve the quality of the matches. For example, a selection function that only considers pairs of objects with similar appearance features might lead to more accurate matches. + +The selection function should return a boolean value. A return value of `True` or `1` indicates that the pair of objects should be considered for matching, while a return value of `False` or `0` indicates that the pair should not be considered. + +You can provide a custom selection function to the reidentification process. For more information, please refer to [this documentation](../custom_cost_selection.md). + +:::trackreid.selection_functions diff --git a/docs/reference/tracked_object.md b/docs/reference/tracked_object.md new file mode 100644 index 0000000..d74024a --- /dev/null +++ b/docs/reference/tracked_object.md @@ -0,0 +1,3 @@ +# TrackedObject + +:::trackreid.tracked_object diff --git a/docs/reference/tracked_object_filter.md b/docs/reference/tracked_object_filter.md new file mode 100644 index 0000000..757ebe7 --- /dev/null +++ b/docs/reference/tracked_object_filter.md @@ -0,0 +1,3 @@ +# TrackedObjectFilter + +:::trackreid.tracked_object_filter diff --git a/docs/reference/tracked_object_metadata.md b/docs/reference/tracked_object_metadata.md new file mode 100644 index 0000000..b03870c --- /dev/null +++ b/docs/reference/tracked_object_metadata.md @@ -0,0 +1,3 @@ +# TrackedObjectMetadata + +:::trackreid.tracked_object_metadata diff --git a/mkdocs.yaml b/mkdocs.yaml deleted file mode 100644 index 449fd64..0000000 --- a/mkdocs.yaml +++ /dev/null @@ -1,31 +0,0 @@ -site_name: track-reid - -theme: - name: "material" - palette: - - media: "(prefers-color-scheme: dark)" - scheme: default - primary: teal - accent: amber - toggle: - icon: material/moon-waning-crescent - name: Switch to dark mode - - media: "(prefers-color-scheme: light)" - scheme: slate - primary: teal - accent: amber - toggle: - icon: material/white-balance-sunny - name: Switch to light mode - features: - - search.suggest - - search.highlight - - content.tabs.link - -plugins: - - mkdocstrings - - search - -nav: - - Home: index.md - - Source code: code.md diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..edc0dd2 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,46 @@ +site_name: track-reid + +theme: + name: "material" + palette: + - media: "(prefers-color-scheme: dark)" + scheme: default + primary: indigo + accent: pink + toggle: + icon: material/moon-waning-crescent + name: Switch to dark mode + - media: "(prefers-color-scheme: light)" + scheme: slate + primary: indigo + accent: pink + toggle: + icon: material/white-balance-sunny + name: Switch to light mode + features: + - search.suggest + - search.highlight + - content.tabs.link + +plugins: + - mkdocstrings + - search + +markdown_extensions: + - codehilite: + use_pygments: true + pygments_style: monokai + +nav: + - Home: index.md + - Quickstart users: quickstart_user.md + - Quickstart developers: quickstart_dev.md + - Custom cost and selection functions: custom_cost_selection.md + - Code Reference: + - ReidProcessor: reference/reid_processor.md + - TrackedObjectFilter: reference/tracked_object_filter.md + - Matcher: reference/matcher.md + - TrackedObjectMetadata: reference/tracked_object_metadata.md + - TrackedObject: reference/tracked_object.md + - Cost functions: reference/cost_functions.md + - Selection functions: reference/selection_functions.md diff --git a/pyproject.toml b/pyproject.toml index 122887e..1d64bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ pytest = "7.3.2" ipykernel = "6.24.0" mkdocs = "1.4.3" mkdocs-material = "9.1.15" -mkdocstrings-python = "1.1.2" +mkdocstrings = {extras = ["python-legacy"], version = "^0.24.0"} bandit = "1.7.5" nbstripout = "0.6.1" diff --git a/trackreid/matcher.py b/trackreid/matcher.py index 9148c2e..c57a9d9 100644 --- a/trackreid/matcher.py +++ b/trackreid/matcher.py @@ -18,22 +18,13 @@ def __init__( Initializes the Matcher object with the provided cost function, selection function, and cost function threshold. Args: - cost_function (Callable): A function that calculates the cost of matching two objects. This function should - take two TrackedObject instances as input and return a numerical value representing the cost of matching - these two objects. A lower cost indicates a higher likelihood of a match. - - selection_function (Callable): A function that determines whether two objects should be considered for - matching. This function should take two TrackedObject instances as input and return a binary value (0 or 1). - A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 - indicates that the pair should not be considered. - - cost_function_threshold (Optional[Union[int, float]]): An optional threshold value for the cost function. - If provided, any pair of objects with a matching cost greater than this threshold will not be considered for - matching. If not provided, all selected pairs will be considered regardless of their matching cost. + cost_function (Callable): A function that calculates the cost of matching two objects. This function should take two TrackedObject instances as input and return a numerical value representing the cost of matching these two objects. A lower cost indicates a higher likelihood of a match. + selection_function (Callable): A function that determines whether two objects should be considered for matching. This function should take two TrackedObject instances as input and return a binary value (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 indicates that the pair should not be considered. + cost_function_threshold (Optional[Union[int, float]]): An optional threshold value for the cost function. If provided, any pair of objects with a matching cost greater than this threshold will not be considered for matching. If not provided, all selected pairs will be considered regardless of their matching cost. Returns: None - """ + """ # noqa: E501 self.cost_function = cost_function self.selection_function = selection_function self.cost_function_threshold = cost_function_threshold diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index 8fde685..027f07e 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -51,40 +51,23 @@ class ReidProcessor: Args: - filter_confidence_threshold (float): Confidence threshold for the filter. The filter will only consider - tracked objects that have a mean confidence score during the all transaction above this threshold. + filter_confidence_threshold (float): Confidence threshold for the filter. The filter will only consider tracked objects that have a mean confidence score during the all transaction above this threshold. - filter_time_threshold (int): Time threshold for the filter. The filter will only consider tracked objects - that have been seen for a number of frames above this threshold. + filter_time_threshold (int): Time threshold for the filter. The filter will only consider tracked objects that have been seen for a number of frames above this threshold. - max_frames_to_rematch (int): Maximum number of frames to rematch. If a switcher is lost for a number of - frames greater than this value, it will be flagged as lost forever. + max_frames_to_rematch (int): Maximum number of frames to rematch. If a switcher is lost for a number of frames greater than this value, it will be flagged as lost forever. - max_attempt_to_match (int): Maximum number of attempts to match a candidate. If a candidate has not been - rematched despite a number of attempts equal to this value, it will be flagged as a stable object. + max_attempt_to_match (int): Maximum number of attempts to match a candidate. If a candidate has not been rematched despite a number of attempts equal to this value, it will be flagged as a stable object. - selection_function (Callable): A function that determines whether two objects should be considered for - matching. The selection function should take two TrackedObject instances as input and return a binary value - (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return - value of 0 indicates that the pair should not be considered. - Defaults to select_by_category. + selection_function (Callable): A function that determines whether two objects should be considered for matching. The selection function should take two TrackedObject instances as input and return a binary value (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 indicates that the pair should not be considered. - cost_function (Callable): A function that calculates the cost of matching two objects. The cost function - should take two TrackedObject instances as input and return a numerical value representing the cost of - matching these two objects. A lower cost indicates a higher likelihood of a match. - Defaults to bounding_box_distance. + cost_function (Callable): A function that calculates the cost of matching two objects. The cost function should take two TrackedObject instances as input and return a numerical value representing the cost of matching these two objects. A lower cost indicates a higher likelihood of a match. - cost_function_threshold (Optional[Union[int, float]]): An maximal threshold value for the cost function. - If provided, any pair of objects with a matching cost greater than this threshold will not be considered - for matching. If not provided, all selected pairs will be considered regardless of their matching cost. - Defaults to None. + cost_function_threshold (Optional[Union[int, float]]): An maximal threshold value for the cost function. If provided, any pair of objects with a matching cost greater than this threshold will not be considered for matching. If not provided, all selected pairs will be considered regardless of their matching cost.\n - save_to_txt (bool): A flag indicating whether to save the results to a text file. If set to True, the - results will be saved to a text file specified by the file_path parameter. - Default to False. + save_to_txt (bool): A flag indicating whether to save the results to a text file. If set to True, the results will be saved to a text file specified by the file_path parameter. file_path (str): The path to the text file where the results will be saved if save_to_txt is set to True. - Defaults to tracks.txt """ # noqa: E501 def __init__( diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index bd42a19..1bc84b0 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -44,14 +44,12 @@ class TrackedObject: Args: object_ids (Union[Union[float, int], sllist]): The unique identifiers for the object. state (int): The current state of the object. - metadata (Union[np.ndarray, TrackedObjectMetaData]): The metadata for the object. It can be either a - TrackedObjectMetaData object, or a data line, i.e. output of detection model. If metadata is initialized - with a TrackedObjectMetaData object, a frame_id must be given. + metadata (Union[np.ndarray, TrackedObjectMetaData]): The metadata for the object. It can be either a TrackedObjectMetaData object, or a data line, i.e. output of detection model. If metadata is initialized with a TrackedObjectMetaData object, a frame_id must be given. frame_id (Optional[int], optional): The frame ID where the object was first seen. Defaults to None. Raises: NameError: If the type of object_ids or metadata is unrecognized. - """ + """ # noqa: E501 def __init__( self, @@ -192,12 +190,10 @@ def update_metadata(self, data_line: np.ndarray, frame_id: int): the tracked object. Args: - data_line (np.ndarray): The detection data for a single frame. It contains information such as the - class name, bounding box coordinates, and confidence level of the detection. + data_line (np.ndarray): The detection data for a single frame. It contains information such as the class name, bounding box coordinates, and confidence level of the detection. - frame_id (int): The frame id where the object was detected. This is used to update the last frame id of - the tracked object. - """ + frame_id (int): The frame id where the object was detected. This is used to update the last frame id of the tracked object. + """ # noqa: E501 self.metadata.update(data_line=data_line, frame_id=frame_id) def __eq__(self, other): diff --git a/trackreid/tracked_object_filter.py b/trackreid/tracked_object_filter.py index c4a56fa..cf9b016 100644 --- a/trackreid/tracked_object_filter.py +++ b/trackreid/tracked_object_filter.py @@ -8,11 +8,9 @@ class TrackedObjectFilter: confidence and the number of frames they have been observed in. Args: - confidence_threshold (float): The minimum mean confidence level required for a tracked - object to be considered valid. - frames_seen_threshold (int): The minimum number of frames a tracked object - must be observed in to be considered valid. - """ + confidence_threshold (float): The minimum mean confidence level required for a tracked object to be considered valid. + frames_seen_threshold (int): The minimum number of frames a tracked object must be observed in to be considered valid. + """ # noqa: E501 def __init__(self, confidence_threshold: float, frames_seen_threshold: int): self.confidence_threshold = confidence_threshold diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index ea04bfc..874ec4b 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -31,28 +31,20 @@ def update(self, data_line: np.ndarray, frame_id: int): Updates the metadata of a tracked object based on new detection data. This method is used to update the metadata of a tracked object whenever new detection data is available. - It updates the last frame id, class counts, bounding box, confidence, confidence sum, and observations. + It updates the last frame id, class counts, bounding box, confidence, confidence sum, and observations: + - last_frame_id: Updated to the frame id where the object was detected + - class_counts: Incremented by 1 for the detected class + - bbox: Updated to the bounding box coordinates from the detection data + - confidence: Updated to the confidence level from the detection data + - confidence_sum: Incremented by the confidence level from the detection data + - observations: Incremented by 1 Args: - data_line (np.ndarra): The detection data for a single frame. It contains information such as the - class name, bounding box coordinates, and confidence level of the detection. + data_line (np.ndarra): The detection data for a single frame. It contains information such as the class name, bounding box coordinates, and confidence level of the detection. - frame_id (int): The frame id where the object was detected. This is used to update the last frame id of - the tracked object. + frame_id (int): The frame id where the object was detected. This is used to update the last frame id of the tracked object. - Updates: - last_frame_id: The last frame id is updated to the frame id where the object was detected. - - class_counts: The class counts are updated by incrementing the count of the detected class by 1. - - bbox: The bounding box is updated to the bounding box coordinates from the detection data. - - confidence: The confidence is updated to the confidence level from the detection data. - - confidence_sum: The confidence sum is updated by adding the confidence level from the detection data. - - observations: The total number of observations is incremented by 1. - """ + """ # noqa: E501 self.last_frame_id = frame_id class_name = int(data_line[input_data_positions.category]) @@ -66,22 +58,23 @@ class name, bounding box coordinates, and confidence level of the detection. def merge(self, other_object): """ Merges the metadata of another TrackedObjectMetaData instance into the current one. + Updates the current instance with the data from the other TrackedObjectMetaData instance. + + The following properties are updated: + - observations: Incremented by the observations of the other object. + - confidence_sum: Incremented by the confidence sum of the other object. + - confidence: Set to the confidence of the other object. + - bbox: Set to the bounding box of the other object. + - last_frame_id: Set to the last frame id of the other object. + - class_counts: For each class, the count is incremented by the count of the other object. Args: - other_object (TrackedObjectMetaData): The other TrackedObjectMetaData instance whose metadata - is to be merged with the current instance. + other_object (TrackedObjectMetaData): The other TrackedObjectMetaData instance whose metadata is to be merged with the current instance. Raises: TypeError: If the other_object is not an instance of TrackedObjectMetaData. - Updates: - observations: The total number of observations is updated by adding the observations of the other object. - confidence_sum: The total confidence sum is updated by adding the confidence sum of the other object. - confidence: The confidence is updated to the confidence of the other object. - bbox: The bounding box is updated to the bounding box of the other object. - last_frame_id: The last frame id is updated to the last frame id of the other object. - class_counts: The class counts are updated by adding the class counts of the other object for each class. - """ + """ # noqa: E501 if not isinstance(other_object, type(self)): raise TypeError("Can only merge with another TrackedObjectMetaData.") From ea575474784d58c49f081d693724c58dbf80cafd Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:11:01 +0100 Subject: [PATCH 11/13] fix deployment (#22) --- .github/workflows/deploy_docs.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/deploy_docs.yaml b/.github/workflows/deploy_docs.yaml index 8d4a6b6..adc0cff 100644 --- a/.github/workflows/deploy_docs.yaml +++ b/.github/workflows/deploy_docs.yaml @@ -23,5 +23,5 @@ jobs: make install - name: Deploying MkDocs documentation run: | - mkdocs build - mkdocs gh-deploy --force + poetry run mkdocs build + poetry run mkdocs gh-deploy --force From b3cff2dc6f3a0569a5972874d732b6bb8df73207 Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Thu, 23 Nov 2023 15:33:07 +0100 Subject: [PATCH 12/13] Tp/fix notebook (#23) --- notebooks/starter_kit_reid.ipynb | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index 685d1a6..af33a21 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -176,7 +176,7 @@ " return detection_outputs\n", "\n", " processed_detections = self._pre_process(detection_outputs)\n", - " tracked_objects = self.tracker.update(processed_detections, frame_id = frame_id)\n", + " tracked_objects = self.tracker.update(processed_detections, _ = frame_id)\n", " processed_tracked = self._post_process(tracked_objects)\n", " return processed_tracked\n", "\n", @@ -260,6 +260,15 @@ "\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reid_processor.seen_objects" + ] + }, { "cell_type": "code", "execution_count": null, @@ -317,6 +326,13 @@ " print(case)\n", " print(filtered_objects)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From c59c54e7cfd95e080b76d85e7d0f2556b09b2454 Mon Sep 17 00:00:00 2001 From: TomDarmon <36815861+TomDarmon@users.noreply.github.com> Date: Thu, 23 Nov 2023 16:38:02 +0100 Subject: [PATCH 13/13] Fix/commit history (#24) Co-authored-by: TomDarmon Co-authored-by: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Co-authored-by: github-actions --- notebooks/starter_kit_reid.ipynb | 190 +++++++++++++++++++++++++++++++ 1 file changed, 190 insertions(+) diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index af33a21..177b9f7 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -160,6 +160,196 @@ " return detection_output\n" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class TrackingHandler():\n", + " def __init__(self, tracker) -> None:\n", + " self.tracker = tracker\n", + "\n", + " def update(self, detection_outputs, frame_id):\n", + "\n", + " if not detection_outputs.size :\n", + " return detection_outputs\n", + "\n", + " processed_detections = self._pre_process(detection_outputs)\n", + " tracked_objects = self.tracker.update(processed_detections, frame_id = frame_id)\n", + " processed_tracked = self._post_process(tracked_objects)\n", + " return processed_tracked\n", + "\n", + " def _pre_process(self,detection_outputs : np.ndarray):\n", + " return detection_outputs\n", + "\n", + " def _post_process(self, tracked_objects : np.ndarray):\n", + "\n", + " if tracked_objects.size :\n", + " if tracked_objects.ndim == 1:\n", + " tracked_objects = np.expand_dims(tracked_objects, 0)\n", + "\n", + " return tracked_objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')\n", + "print(timestamp)\n", + "folder_save = os.path.join(PREDICTIONS_PATH,timestamp)\n", + "os.makedirs(folder_save, exist_ok=True)\n", + "GENERATE_VIDEOS = False\n", + "for sequence in tqdm(SEQUENCES) :\n", + " frame_path = get_sequence_frames(sequence)\n", + " test_sequence = Sequence(frame_path)\n", + " frame_id = 0\n", + " BaseTrack._count = 0\n", + " from datetime import datetime\n", + "\n", + " file_path = os.path.join(folder_save,sequence) + '.txt'\n", + " video_path = os.path.join(folder_save,sequence) + '.mp4'\n", + "\n", + " if GENERATE_VIDEOS:\n", + " fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", + " out = cv2.VideoWriter(video_path, fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", + "\n", + " detection_handler = DetectionHandler(image_shape=[2560, 1440])\n", + " tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n", + " reid_processor = ReidProcessor(filter_confidence_threshold=0.1,\n", + " filter_time_threshold=5,\n", + " cost_function=bounding_box_distance,\n", + " cost_function_threshold=5000, # max cost to rematch 2 objects\n", + " selection_function=select_by_category,\n", + " max_attempt_to_match=5,\n", + " max_frames_to_rematch=500,\n", + " save_to_txt=True,\n", + " file_path=file_path)\n", + "\n", + " for frame, detection in test_sequence:\n", + "\n", + " frame_id += 1\n", + "\n", + " processed_detections = detection_handler.process(detection)\n", + " processed_tracked = tracking_handler.update(processed_detections, frame_id)\n", + " reid_results = reid_processor.update(processed_tracked, frame_id)\n", + "\n", + " if GENERATE_VIDEOS and len(reid_results) > 0:\n", + " frame = np.array(frame)\n", + " for res in reid_results:\n", + " object_id = int(res[output_data_positions.object_id])\n", + " bbox = list(map(int, res[output_data_positions.bbox]))\n", + " class_id = int(res[output_data_positions.category])\n", + " tracker_id = int(res[output_data_positions.tracker_id])\n", + " mean_confidence = float(res[output_data_positions.mean_confidence])\n", + " x1, y1, x2, y2 = bbox\n", + " color = (0, 0, 255) if class_id else (0, 255, 0) # green for class 0, red for class 1\n", + " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", + " cv2.putText(frame, f\"{object_id} ({tracker_id}) : {round(mean_confidence,2)}\", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n", + "\n", + " if GENERATE_VIDEOS:\n", + " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", + " out.write(frame)\n", + "\n", + " if GENERATE_VIDEOS :\n", + " out.release()\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "\n", + "def count_occurrences(file_path, case):\n", + " object_counts = defaultdict(int)\n", + " class_counts = defaultdict(lambda: defaultdict(int))\n", + "\n", + " with open(file_path, 'r') as f:\n", + " for line in f:\n", + " data = line.split()\n", + "\n", + " if case != 'baseline':\n", + " object_id = int(data[1])\n", + " category = int(data[2])\n", + " else:\n", + " object_id = int(data[1])\n", + " category = int(data[-1])\n", + "\n", + " object_counts[object_id] += 1\n", + " class_counts[object_id][category] += 1\n", + "\n", + " return object_counts, class_counts\n", + "\n", + "def filter_counts(object_counts, class_counts, min_occurrences=10):\n", + " filtered_objects = {}\n", + "\n", + " for object_id, count in object_counts.items():\n", + " if count > min_occurrences and class_counts[object_id][0] > class_counts[object_id][1]:\n", + " filtered_objects[object_id] = count\n", + "\n", + " return filtered_objects\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "PATH_PREDICTIONS = f\"../data/predictions/{timestamp}\"\n", + "\n", + "for sequence in SEQUENCES:\n", + " print(\"-\"*50)\n", + " print(sequence)\n", + "\n", + " for case in [\"baseline\", timestamp]:\n", + " object_counts, class_counts = count_occurrences(f'../data/predictions/{case}/{sequence}.txt', case=case)\n", + " filtered_objects = filter_counts(object_counts, class_counts)\n", + "\n", + " print(case)\n", + " print(filtered_objects)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class DetectionHandler():\n", + " def __init__(self, image_shape) -> None:\n", + " self.image_shape = image_shape\n", + "\n", + " def process(self, detection_output):\n", + " if detection_output.size:\n", + " if detection_output.ndim == 1:\n", + " detection_output = np.expand_dims(detection_output, 0)\n", + "\n", + " processed_detection = np.zeros(detection_output.shape)\n", + "\n", + " for idx, detection in enumerate(detection_output):\n", + " clss = detection[0]\n", + " conf = detection[5]\n", + " bbox = detection[1:5]\n", + " xyxy_bbox = xy_center_to_xyxy(bbox)\n", + " rescaled_bbox = rescale_bbox(xyxy_bbox,self.image_shape)\n", + " processed_detection[idx,:4] = rescaled_bbox\n", + " processed_detection[idx,4] = conf\n", + " processed_detection[idx,5] = clss\n", + "\n", + " return processed_detection\n", + " else:\n", + " return detection_output\n" + ] + }, { "cell_type": "code", "execution_count": null,