From 308be7b1a0febd48d90040f42aff97b86137720f Mon Sep 17 00:00:00 2001 From: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Date: Tue, 7 Nov 2023 15:08:41 +0100 Subject: [PATCH] Tp/add reid lib (#4) * add matcher * add reid processor * add tracked object filter * add metadata object * add object filterd modif * add tracked object class * add utils * add args & constants * fix drop duplicate * add notebook * modify pyproject --- notebooks/starter_kit_reid.ipynb | 157 ++++++++++++++++ pyproject.toml | 1 - track_reid/args/reid_args.py | 1 + track_reid/constants/reid_constants.py | 22 +++ track_reid/matcher.py | 91 ++++++++++ track_reid/reid_processor.py | 240 +++++++++++++++++++++++++ track_reid/tracked_object.py | 125 +++++++++++++ track_reid/tracked_object_filter.py | 18 ++ track_reid/tracked_object_metadata.py | 110 ++++++++++++ track_reid/utils.py | 43 +++++ 10 files changed, 807 insertions(+), 1 deletion(-) create mode 100644 notebooks/starter_kit_reid.ipynb create mode 100644 track_reid/args/reid_args.py create mode 100644 track_reid/constants/reid_constants.py create mode 100644 track_reid/matcher.py create mode 100644 track_reid/reid_processor.py create mode 100644 track_reid/tracked_object.py create mode 100644 track_reid/tracked_object_filter.py create mode 100644 track_reid/tracked_object_metadata.py create mode 100644 track_reid/utils.py diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb new file mode 100644 index 0000000..58605c0 --- /dev/null +++ b/notebooks/starter_kit_reid.ipynb @@ -0,0 +1,157 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload \n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys \n", + "sys.path.append(\"..\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "from track_reid.reid_processor import ReidProcessor" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def bounding_box_distance(obj1, obj2):\n", + " # Get the bounding boxes from the Metadata of each TrackedObject\n", + " bbox1 = obj1.metadata.bbox\n", + " bbox2 = obj2.metadata.bbox\n", + "\n", + " # Calculate the Euclidean distance between the centers of the bounding boxes\n", + " center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)\n", + " center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)\n", + " distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)\n", + "\n", + " return distance\n", + "\n", + "def select_by_category(obj1, obj2):\n", + " # Compare the categories of the two objects\n", + " return 1 if obj1.category == obj2.category else 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Example usage:\n", + "data = np.array([\n", + " [1, 1, \"car\", 100, 200, 300, 400, 0.9],\n", + " [1, 2, \"person\", 50, 150, 200, 400, 0.8],\n", + " [2, 1, \"truck\", 120, 220, 320, 420, 0.95],\n", + " [2, 2, \"person\", 60, 160, 220, 420, 0.85],\n", + " [3, 1, \"car\", 110, 210, 310, 410, 0.91],\n", + " [3, 3, \"person\", 61, 170, 220, 420, 0.91],\n", + " [3, 4, \"car\", 60, 160, 220, 420, 0.91],\n", + " [3, 6, \"person\", 60, 160, 220, 420, 0.91],\n", + " [4, 1, \"truck\", 
130, 230, 330, 430, 0.92],\n", + " [4, 2, \"person\", 65, 165, 225, 425, 0.83],\n", + " [5, 1, \"car\", 115, 215, 315, 415, 0.93],\n", + " [5, 2, \"person\", 57, 157, 207, 407, 0.84],\n", + " [5, 4, \"car\", 60, 160, 220, 420, 0.91],\n", + " [5, 8, \"person\", 60, 160, 220, 420, 0.91],\n", + "])\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "processor = ReidProcessor(filter_confidence_threshold=0.4, \n", + " filter_time_threshold=0,\n", + " cost_function=bounding_box_distance,\n", + " selection_function=select_by_category,\n", + " max_attempt_to_rematch=0,\n", + " max_frames_to_rematch=100)\n", + "\n", + "\n", + "columns = ['frame_id', 'object_id', 'category', 'x1', 'y1', 'x2', 'y2', 'confidence']\n", + "df = pd.DataFrame(data, columns=columns)\n", + "# Convert numerical columns to the appropriate data type\n", + "df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']] = df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']].astype(int)\n", + "df['confidence'] = df['confidence'].astype(float)\n", + "\n", + "\n", + "for frame_id, frame_data in df.groupby(\"frame_id\"):\n", + "\n", + " bytetrack_output = frame_data.values\n", + " if bytetrack_output.ndim == 1 : \n", + " bytetrack_output = np.expand_dims(bytetrack_output, 0)\n", + "\n", + " results = processor.update(bytetrack_output)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(processor.all_tracked_objects[0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(processor.all_tracked_objects[0].metadata)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "track-reid", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "vscode": { + "interpreter": { + "hash": "a7fd834062a85a1fb9d4482d7456bec56e0ff99e4dd054f5e10ff6e3cdc923c6" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 97e66fc..3209cb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ select = [ "I", "N", "Q", - "RET", "ARG", "PTH", "PD", diff --git a/track_reid/args/reid_args.py b/track_reid/args/reid_args.py new file mode 100644 index 0000000..bab95c1 --- /dev/null +++ b/track_reid/args/reid_args.py @@ -0,0 +1 @@ +POSSIBLE_CLASSES = ["car", "person", "truck", "animal"] diff --git a/track_reid/constants/reid_constants.py b/track_reid/constants/reid_constants.py new file mode 100644 index 0000000..950bf8b --- /dev/null +++ b/track_reid/constants/reid_constants.py @@ -0,0 +1,22 @@ +from typing import ClassVar + +from pydantic import BaseModel + + +class ReidConstants(BaseModel): + BYETRACK_OUTPUT: int = -2 + FILTERED_OUTPUT: int = -1 + STABLE: int = 0 + SWITCHER: int = 1 + CANDIDATE: int = 2 + + DESCRIPTION: ClassVar[dict] = { + BYETRACK_OUTPUT: "bytetrack output not in reid process", + FILTERED_OUTPUT: "bytetrack output entering reid process", + STABLE: "stable object", + SWITCHER: "lost object to be re-matched", + CANDIDATE: "new object to be matched", + } + + +reid_constants = ReidConstants() diff --git a/track_reid/matcher.py b/track_reid/matcher.py new file mode 100644 index 0000000..cb8b817 --- /dev/null +++ b/track_reid/matcher.py 
@@ -0,0 +1,91 @@
+from typing import Callable, Dict, List
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+from track_reid.tracked_object import TrackedObject
+
+
+class Matcher:
+    def __init__(self, cost_function: Callable, selection_function: Callable) -> None:
+        self.cost_function = cost_function
+        self.selection_function = selection_function
+
+    def compute_cost_matrix(
+        self, objects1: List[TrackedObject], objects2: List[TrackedObject]
+    ) -> np.ndarray:
+        """Computes a cost matrix of size [N, M] between a list of M TrackedObjects objects1,
+        and a list of N TrackedObjects objects2 (rows map to objects2, columns to objects1).
+
+        Args:
+            objects1 (List[TrackedObject]): list of objects to be matched.
+            objects2 (List[TrackedObject]): list of candidates for matches.
+
+        Returns:
+            np.ndarray: cost to match each pair of objects.
+        """
+        if not objects1 or not objects2:
+            return np.array([])  # Return an empty array if either list is empty
+
+        # Create matrices with all combinations of objects1 and objects2
+        objects1_matrix, objects2_matrix = np.meshgrid(objects1, objects2)
+
+        # Use np.vectorize to apply the cost function to all combinations
+        cost_matrix = np.vectorize(self.cost_function)(objects1_matrix, objects2_matrix)
+
+        return cost_matrix
+
+    def compute_selection_matrix(
+        self, objects1: List[TrackedObject], objects2: List[TrackedObject]
+    ) -> np.ndarray:
+        """Computes a selection matrix of size [N, M] between a list of M TrackedObjects objects1,
+        and a list of N TrackedObjects objects2.
+
+        Args:
+            objects1 (List[TrackedObject]): list of objects to be matched.
+            objects2 (List[TrackedObject]): list of candidates for matches.
+
+        Returns:
+            np.ndarray: selection matrix; non-zero if the pair can be matched, 0 otherwise.
+        """
+        if not objects1 or not objects2:
+            return np.array([])  # Return an empty array if either list is empty
+
+        # Create matrices with all combinations of objects1 and objects2
+        objects1_matrix, objects2_matrix = np.meshgrid(objects1, objects2)
+
+        # Use np.vectorize to apply the selection function to all combinations
+        selection_matrix = np.vectorize(self.selection_function)(objects1_matrix, objects2_matrix)
+
+        return selection_matrix
+
+    def match(
+        self, objects1: List[TrackedObject], objects2: List[TrackedObject]
+    ) -> List[Dict[TrackedObject, TrackedObject]]:
+        """Computes a list of one-to-one matches between objects in objects1 and objects in objects2.
+
+        Args:
+            objects1 (List[TrackedObject]): list of objects to be matched.
+            objects2 (List[TrackedObject]): list of candidates for matches.
+
+        Returns:
+            List[Dict[TrackedObject, TrackedObject]]: list of pairs of TrackedObjects
+            if there is a match.
+ """ + if not objects1 or not objects2: + return [] # Return an empty array if either list is empty + + cost_matrix = self.compute_cost_matrix(objects1, objects2) + selection_matrix = self.compute_selection_matrix(objects1, objects2) + + # Set a large cost value for elements to be discarded + cost_matrix[selection_matrix == 0] = 1e3 + + # Find the best matches using the linear sum assignment + row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=False) + + matches = [] + for row, col in zip(row_indices, col_indices): + matches.append({objects1[col]: objects2[row]}) + + return matches diff --git a/track_reid/reid_processor.py b/track_reid/reid_processor.py new file mode 100644 index 0000000..d345e1d --- /dev/null +++ b/track_reid/reid_processor.py @@ -0,0 +1,240 @@ +from __future__ import annotations + +from typing import Dict, List, Set + +import numpy as np + +from track_reid.constants.reid_constants import reid_constants +from track_reid.matcher import Matcher +from track_reid.tracked_object import TrackedObject +from track_reid.tracked_object_filter import TrackedObjectFilter +from track_reid.utils import filter_objects_by_state, get_top_list_correction + + +class ReidProcessor: + def __init__( + self, + filter_confidence_threshold, + filter_time_threshold, + cost_function, + selection_function, + max_frames_to_rematch: int = 100, + max_attempt_to_rematch: int = 1, + ) -> None: + + self.matcher = Matcher(cost_function=cost_function, selection_function=selection_function) + + self.tracked_filter = TrackedObjectFilter( + confidence_threshold=filter_confidence_threshold, + frames_seen_threshold=filter_time_threshold, + ) + + self.all_tracked_objects: List[TrackedObject] = [] + self.switchers: List[TrackedObject] = [] + self.candidates: List[TrackedObject] = [] + + self.last_tracker_ids: Set[int] = set() + + self.max_frames_to_rematch = max_frames_to_rematch + self.max_attempt_to_rematch = max_attempt_to_rematch + + self.frame_id = 0 + + def update(self, tracker_output: np.ndarray): + reshaped_tracker_output = self._reshape_input(tracker_output) + self._preprocess(tracker_output=reshaped_tracker_output) + self._perform_reid_process(tracker_output=reshaped_tracker_output) + reid_output = self._postprocess(tracker_output=tracker_output) + return reid_output + + def _preprocess(self, tracker_output: np.ndarray): + self.all_tracked_objects = self._update_tracked_objects(tracker_output=tracker_output) + self.all_tracked_objects = self._apply_filtering() + + def _update_tracked_objects(self, tracker_output: np.ndarray): + + self.frame_id = tracker_output[0, 0] + for object_id, data_line in zip(tracker_output[:, 1], tracker_output): + if object_id not in self.all_tracked_objects: + new_tracked_object = TrackedObject( + object_ids=object_id, state=reid_constants.BYETRACK_OUTPUT, metadata=data_line + ) + self.all_tracked_objects.append(new_tracked_object) + else: + self.all_tracked_objects[self.all_tracked_objects.index(object_id)].update_metadata( + data_line + ) + + return self.all_tracked_objects + + @staticmethod + def _reshape_input(bytetrack_output: np.ndarray): + if bytetrack_output.ndim == 1: + bytetrack_output = np.expand_dims(bytetrack_output, 0) + return bytetrack_output + + def _apply_filtering(self): + for tracked_object in self.all_tracked_objects: + self.tracked_filter.update(tracked_object) + + return self.all_tracked_objects + + def _perform_reid_process(self, tracker_output: np.ndarray): + + tracked_ids = filter_objects_by_state( + self.all_tracked_objects, 
states=reid_constants.BYETRACK_OUTPUT, exclusion=True + ) + + current_tracker_ids = set(tracker_output[:, 1]).intersection(set(tracked_ids)) + + self.compute_stable_objects( + current_tracker_ids=current_tracker_ids, tracked_ids=self.all_tracked_objects + ) + + self.switchers = self.drop_switchers( + self.switchers, + current_tracker_ids, + max_frames_to_rematch=self.max_frames_to_rematch, + frame_id=self.frame_id, + ) + + self.candidates.extend(self.identify_candidates(tracked_ids=tracked_ids)) + + self.switchers.extend( + self.identify_switchers( + current_tracker_ids=current_tracker_ids, + last_bytetrack_ids=self.last_tracker_ids, + tracked_ids=tracked_ids, + ) + ) + + matches = self.matcher.match(self.candidates, self.switchers) + + self.process_matches( + all_tracked_objects=self.all_tracked_objects, + matches=matches, + candidates=self.candidates, + switchers=self.switchers, + current_tracker_ids=current_tracker_ids, + ) + + self.candidates = self.drop_candidates( + self.candidates, + ) + + self.last_tracker_ids = current_tracker_ids.copy() + + @staticmethod + def identify_switchers( + tracked_ids: List["TrackedObject"], + current_tracker_ids: Set[int], + last_bytetrack_ids: Set[int], + ): + switchers = [] + lost_ids = last_bytetrack_ids - current_tracker_ids + + for tracked_id in tracked_ids: + if tracked_id in lost_ids: + switchers.append(tracked_id) + tracked_id.state = reid_constants.SWITCHER + + return switchers + + @staticmethod + def identify_candidates(tracked_ids: List["TrackedObject"]): + candidates = [] + for current_object in tracked_ids: + if current_object.state == reid_constants.FILTERED_OUTPUT: + current_object.state = reid_constants.CANDIDATE + candidates.append(current_object) + return candidates + + @staticmethod + def compute_stable_objects(tracked_ids: list, current_tracker_ids: Set[int]): + + top_list_correction = get_top_list_correction(tracked_ids) + + for current_object in current_tracker_ids: + tracked_id = tracked_ids[tracked_ids.index(current_object)] + if current_object not in top_list_correction: + + tracked_ids.remove(tracked_id) + new_object, tracked_id = tracked_id.cut(current_object) + + new_object.state = reid_constants.STABLE + tracked_id.state = reid_constants.STABLE + + tracked_ids.append(new_object) + tracked_ids.append(tracked_id) + + @staticmethod + def process_matches( + all_tracked_objects: List["TrackedObject"], + matches: Dict["TrackedObject", "TrackedObject"], + switchers: List["TrackedObject"], + candidates: List["TrackedObject"], + current_tracker_ids: Set[int], + ): + + for match in matches: + candidate_match, switcher_match = match.popitem() + + switcher_match.merge(candidate_match) + all_tracked_objects.remove(candidate_match) + switchers.remove(switcher_match) + candidates.remove(candidate_match) + + current_tracker_ids.discard(candidate_match.id) + current_tracker_ids.add(switcher_match.id) + + @staticmethod + def drop_switchers( + switchers: List["TrackedObject"], + current_tracker_ids: Set[int], + max_frames_to_rematch: int, + frame_id: int, + ): + + switchers_to_drop = set(switchers).intersection(current_tracker_ids) + filtered_switchers = switchers.copy() + + for switcher in switchers: + if switcher in switchers_to_drop: + switcher.state = reid_constants.STABLE + filtered_switchers.remove(switcher) + elif switcher.get_nb_frames_since_last_appearance(frame_id) > max_frames_to_rematch: + filtered_switchers.remove(switcher) + + return filtered_switchers + + @staticmethod + def drop_candidates(candidates: List["TrackedObject"]): + 
# For now, promote unmatched candidates to stable and clear the candidate list
+        for candidate in candidates:
+            candidate.state = reid_constants.STABLE
+        return []
+
+    def _postprocess(self, tracker_output: np.ndarray):
+        filtered_objects = list(
+            filter(
+                lambda obj: obj.get_state() == reid_constants.STABLE
+                and obj in tracker_output[:, 1],
+                self.all_tracked_objects,
+            )
+        )
+        reid_output = []
+        for tracked_object in filtered_objects:
+            reid_output.append(
+                [
+                    self.frame_id,
+                    tracked_object.id,
+                    tracked_object.category,
+                    tracked_object.bbox[0],
+                    tracked_object.bbox[1],
+                    tracked_object.bbox[2],
+                    tracked_object.bbox[3],
+                    tracked_object.confidence,
+                ]
+            )
+
+        return reid_output
diff --git a/track_reid/tracked_object.py b/track_reid/tracked_object.py
new file mode 100644
index 0000000..527b14c
--- /dev/null
+++ b/track_reid/tracked_object.py
@@ -0,0 +1,125 @@
+from __future__ import annotations
+
+from typing import Union
+
+import numpy as np
+from llist import sllist
+
+from track_reid.constants.reid_constants import reid_constants
+from track_reid.tracked_object_metadata import TrackedObjectMetaData
+from track_reid.utils import split_list_around_value
+
+
+class TrackedObject:
+    def __init__(
+        self,
+        object_ids: Union[int, sllist],
+        state: int,
+        metadata: Union[np.ndarray, TrackedObjectMetaData],
+    ):
+
+        self.state = state
+
+        if isinstance(object_ids, int):
+            self.re_id_chain = sllist([object_ids])
+        elif isinstance(object_ids, sllist):
+            self.re_id_chain = object_ids
+        else:
+            raise TypeError("Unrecognized type for object_ids.")
+        if isinstance(metadata, np.ndarray):
+            self.metadata = TrackedObjectMetaData(metadata)
+        elif isinstance(metadata, TrackedObjectMetaData):
+            self.metadata = metadata.copy()
+        else:
+            raise TypeError("Unrecognized type for metadata.")
+
+    def merge(self, other_object):
+        if not isinstance(other_object, TrackedObject):
+            raise TypeError("Can only merge with another TrackedObject.")
+
+        # Merge the re_id_chains
+        self.re_id_chain.extend(other_object.re_id_chain)
+
+        # Merge the metadata (delegated to TrackedObjectMetaData.merge)
+        self.metadata.merge(other_object.metadata)
+        self.state = reid_constants.STABLE
+
+        # Return the merged object
+        return self
+
+    @property
+    def id(self):
+        return self.re_id_chain.first.value
+
+    @property
+    def category(self):
+        return max(self.metadata.class_counts, key=self.metadata.class_counts.get)
+
+    @property
+    def confidence(self):
+        return self.metadata.confidence
+
+    @property
+    def mean_confidence(self):
+        return self.metadata.mean_confidence()
+
+    @property
+    def bbox(self):
+        return self.metadata.bbox
+
+    def get_age(self, frame_id):
+        return frame_id - self.metadata.first_frame_id
+
+    def get_nb_frames_since_last_appearance(self, frame_id):
+        return frame_id - self.metadata.last_frame_id
+
+    def get_state(self):
+        return self.state
+
+    def __hash__(self):
+        return hash(self.id)
+
+    def __repr__(self):
+        return (
+            f"TrackedObject(current_id={self.id}, re_id_chain={list(self.re_id_chain)}"
+            + f", state={self.state}: {reid_constants.DESCRIPTION[self.state]})"
+        )
+
+    def __str__(self):
+        return f"{self.__repr__()}, metadata : {self.metadata}"
+
+    def update_metadata(self, data_line: np.ndarray):
+        self.metadata.update(data_line)
+
+    def __eq__(self, other):
+        if isinstance(other, int):
+            return other in self.re_id_chain
+        elif isinstance(other, TrackedObject):
+            return self.re_id_chain == other.re_id_chain
+        return False
+
+    def cut(self, object_id: int):
+
+        if object_id not in self.re_id_chain:
+            raise ValueError(
+                f"Trying to cut object {self} with {object_id} that is 
not in the re-id chain." + ) + + before, after = split_list_around_value(self.re_id_chain, object_id) + self.re_id_chain = before + + new_object = TrackedObject( + state=reid_constants.STABLE, object_ids=after, metadata=self.metadata + ) + return new_object, self + + def format_data(self): + return [ + self.id, + self.category, + self.bbox[0], + self.bbox[1], + self.bbox[2], + self.bbox[3], + self.confidence, + ] diff --git a/track_reid/tracked_object_filter.py b/track_reid/tracked_object_filter.py new file mode 100644 index 0000000..bc936b0 --- /dev/null +++ b/track_reid/tracked_object_filter.py @@ -0,0 +1,18 @@ +from track_reid.constants.reid_constants import reid_constants + + +class TrackedObjectFilter: + def __init__(self, confidence_threshold, frames_seen_threshold): + self.confidence_threshold = confidence_threshold + self.frames_seen_threshold = frames_seen_threshold + + def update(self, tracked_object): + if tracked_object.get_state() == reid_constants.BYETRACK_OUTPUT: + if ( + tracked_object.metadata.mean_confidence() > self.confidence_threshold + and tracked_object.metadata.observations >= self.frames_seen_threshold + ): + tracked_object.state = reid_constants.FILTERED_OUTPUT + + elif tracked_object.metadata.mean_confidence() < self.confidence_threshold: + tracked_object.state = reid_constants.BYETRACK_OUTPUT diff --git a/track_reid/tracked_object_metadata.py b/track_reid/tracked_object_metadata.py new file mode 100644 index 0000000..dc255f1 --- /dev/null +++ b/track_reid/tracked_object_metadata.py @@ -0,0 +1,110 @@ +import json +from pathlib import Path + +from track_reid.args.reid_args import POSSIBLE_CLASSES + + +class TrackedObjectMetaData: + def __init__(self, data_line): + self.first_frame_id = int(data_line[0]) + self.class_counts = {class_name: 0 for class_name in POSSIBLE_CLASSES} + self.observations = 0 + self.confidence_sum = 0 + self.confidence = 0 + self.update(data_line) + + def update(self, data_line): + self.last_frame_id = int(data_line[0]) + class_name = data_line[2] + self.class_counts[class_name] = self.class_counts.get(class_name, 0) + 1 + self.bbox = list(map(int, data_line[3:7])) + confidence = float(data_line[7]) + self.confidence = confidence + self.confidence_sum += confidence + self.observations += 1 + + def merge(self, other_object): + if not isinstance(other_object, TrackedObjectMetaData): + raise TypeError("Can only merge with another TrackedObjectMetaData.") + + self.observations += other_object.observations + self.confidence_sum += other_object.confidence_sum + self.confidence = other_object.confidence + self.bbox = other_object.bbox + self.last_frame_id = other_object.last_frame_id + for class_name in POSSIBLE_CLASSES: + self.class_counts[class_name] = self.class_counts.get( + class_name, 0 + ) + other_object.class_counts.get(class_name, 0) + + def copy(self): + # Create a new instance of TrackedObjectMetaData with the same data + copy_obj = TrackedObjectMetaData( + [self.first_frame_id, 0, list(self.class_counts.keys())[0], *self.bbox, self.confidence] + ) + # Update the copied instance with the actual class counts and observations + copy_obj.class_counts = self.class_counts.copy() + copy_obj.observations = self.observations + copy_obj.confidence_sum = self.confidence_sum + copy_obj.confidence = self.confidence + + return copy_obj + + def save_to_json(self, filename): + data = { + "first_frame_id": self.first_frame_id, + "class_counts": self.class_counts, + "bbox": self.bbox, + "confidence": self.confidence, + "confidence_sum": 
self.confidence_sum,
+            "observations": self.observations,
+        }
+
+        with Path(filename).open("w") as file:
+            json.dump(data, file)
+
+    @classmethod
+    def load_from_json(cls, filename):
+        with Path(filename).open("r") as file:
+            data = json.load(file)
+        obj = cls.__new__(cls)
+        obj.first_frame_id = data["first_frame_id"]
+        obj.class_counts = data["class_counts"]
+        obj.bbox = data["bbox"]
+        obj.confidence = data["confidence"]
+        obj.confidence_sum = data["confidence_sum"]
+        obj.observations = data["observations"]
+        return obj
+
+    def class_proportions(self):
+        if self.observations > 0:
+            proportions = {
+                class_name: count / self.observations
+                for class_name, count in self.class_counts.items()
+            }
+        else:
+            proportions = {class_name: 0.0 for class_name in POSSIBLE_CLASSES}
+        return proportions
+
+    def percentage_of_time_seen(self, frame_id):
+        if self.observations > 0:
+            percentage = (self.observations / (frame_id - self.first_frame_id + 1)) * 100
+        else:
+            percentage = 0.0
+        return percentage
+
+    def mean_confidence(self):
+        if self.observations > 0:
+            return self.confidence_sum / self.observations
+        else:
+            return 0.0
+
+    def __repr__(self) -> str:
+        return f"TrackedObjectMetaData(bbox={self.bbox})"
+
+    def __str__(self):
+        return (
+            f"First frame seen: {self.first_frame_id}, nb observations: {self.observations}, "
+            + f"class Proportions: {self.class_proportions()}, Bounding Box: {self.bbox}, "
+            + f"Mean Confidence: {self.mean_confidence()}"
+        )
diff --git a/track_reid/utils.py b/track_reid/utils.py
new file mode 100644
index 0000000..8060d3e
--- /dev/null
+++ b/track_reid/utils.py
@@ -0,0 +1,43 @@
+from typing import List, Union
+
+from llist import sllist
+
+
+def get_top_list_correction(tracked_ids: list):
+
+    top_list_correction = [tracked_id.re_id_chain.last.value for tracked_id in tracked_ids]
+
+    return top_list_correction
+
+
+def split_list_around_value(my_list: sllist, value_to_split: int):
+
+    if value_to_split == my_list.last.value:
+        raise ValueError("Cannot split the list around its last value.")
+    before = sllist()
+    after = sllist()
+
+    current = my_list.first
+
+    while current:
+        before.append(current.value)
+        if current.value == value_to_split:
+            break
+
+        current = current.next
+
+    while current:
+        after.append(current.value)
+        current = current.next
+
+    return before, after
+
+
+def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exclusion=False):
+    if isinstance(states, int):
+        states = [states]
+    if exclusion:
+        filtered_objects = [obj for obj in tracked_objects if obj.state not in states]
+    else:
+        filtered_objects = [obj for obj in tracked_objects if obj.state in states]
+    return filtered_objects
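A minimal usage sketch of the filtering helper above (illustrative only, not part of the patch). It assumes a ReidProcessor named processor has already been updated with tracker output, as in the starter notebook, so processor.all_tracked_objects is populated:

from track_reid.constants.reid_constants import reid_constants
from track_reid.utils import filter_objects_by_state

# Objects currently considered stable by the reid process.
stable_objects = filter_objects_by_state(
    processor.all_tracked_objects, states=reid_constants.STABLE
)

# Objects that entered the reid process, i.e. excluding raw bytetrack outputs.
in_reid_process = filter_objects_by_state(
    processor.all_tracked_objects, states=reid_constants.BYETRACK_OUTPUT, exclusion=True
)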