From 0151f3eb62885c14b3fe5b3ea2df9f1c5ae24374 Mon Sep 17 00:00:00 2001 From: TomDarmon <36815861+TomDarmon@users.noreply.github.com> Date: Thu, 23 Nov 2023 14:51:48 +0100 Subject: [PATCH] Dev (#21) Co-authored-by: TomDarmon Co-authored-by: Tristan Pepin <122389133+tristanpepinartefact@users.noreply.github.com> Co-authored-by: github-actions --- .github/workflows/ci.yaml | 5 +- .github/workflows/deploy_docs.yaml | 6 +- .gitignore | 3 +- .pre-commit-config.yaml | 8 - Makefile | 21 +- README.md | 31 +- bin/download_sample_sequences.sh | 2 + bin/install_with_conda.sh | 18 - bin/install_with_venv.sh | 20 - docs/code.md | 1 - docs/custom_cost_selection.md | 148 ++++ docs/index.md | 2 +- docs/quickstart_dev.md | 57 ++ docs/quickstart_user.md | 117 +++ docs/reference/cost_functions.md | 15 + docs/reference/matcher.md | 3 + docs/reference/reid_processor.md | 3 + docs/reference/selection_functions.md | 17 + docs/reference/tracked_object.md | 3 + docs/reference/tracked_object_filter.md | 3 + docs/reference/tracked_object_metadata.md | 3 + lib/norfair_helper/utils.py | 29 +- lib/norfair_helper/video.py | 16 +- lib/sequence.py | 1 + mkdocs.yaml | 31 - mkdocs.yml | 46 ++ notebooks/norfair_starter_kit.ipynb | 43 +- notebooks/starter_kit_reid.ipynb | 297 +++++-- pyproject.toml | 14 +- .../unit_tests/tracked_objects/object_1.json | 28 + .../unit_tests/tracked_objects/object_24.json | 24 + .../unit_tests/tracked_objects/object_4.json | 25 + tests/unit_tests/test_matcher.py | 80 ++ tests/unit_tests/test_metadata.py | 52 ++ tests/unit_tests/test_placeholder.py | 6 - tests/unit_tests/test_tracked_objects.py | 96 +++ tests/unit_tests/test_utils.py | 92 +++ trackreid/args/reid_args.py | 1 - trackreid/configs/input_data_positions.py | 25 + trackreid/configs/output_data_positions.py | 35 + .../{constants => configs}/reid_constants.py | 19 +- trackreid/cost_functions/__init__.py | 1 + .../cost_functions/bounding_box_distance.py | 28 + trackreid/matcher.py | 121 ++- trackreid/reid_processor.py | 734 ++++++++++++++---- trackreid/selection_functions/__init__.py | 1 + .../selection_functions/select_by_category.py | 18 + trackreid/tracked_object.py | 231 +++++- trackreid/tracked_object_filter.py | 36 +- trackreid/tracked_object_metadata.py | 224 +++++- trackreid/utils.py | 107 ++- 51 files changed, 2499 insertions(+), 448 deletions(-) delete mode 100644 bin/install_with_conda.sh delete mode 100644 bin/install_with_venv.sh delete mode 100644 docs/code.md create mode 100644 docs/custom_cost_selection.md create mode 100644 docs/quickstart_dev.md create mode 100644 docs/quickstart_user.md create mode 100644 docs/reference/cost_functions.md create mode 100644 docs/reference/matcher.md create mode 100644 docs/reference/reid_processor.md create mode 100644 docs/reference/selection_functions.md create mode 100644 docs/reference/tracked_object.md create mode 100644 docs/reference/tracked_object_filter.md create mode 100644 docs/reference/tracked_object_metadata.md delete mode 100644 mkdocs.yaml create mode 100644 mkdocs.yml create mode 100644 tests/data/unit_tests/tracked_objects/object_1.json create mode 100644 tests/data/unit_tests/tracked_objects/object_24.json create mode 100644 tests/data/unit_tests/tracked_objects/object_4.json create mode 100644 tests/unit_tests/test_matcher.py create mode 100644 tests/unit_tests/test_metadata.py delete mode 100644 tests/unit_tests/test_placeholder.py create mode 100644 tests/unit_tests/test_tracked_objects.py create mode 100644 tests/unit_tests/test_utils.py delete mode 100644 trackreid/args/reid_args.py create mode 100644 trackreid/configs/input_data_positions.py create mode 100644 trackreid/configs/output_data_positions.py rename trackreid/{constants => configs}/reid_constants.py (51%) create mode 100644 trackreid/cost_functions/__init__.py create mode 100644 trackreid/cost_functions/bounding_box_distance.py create mode 100644 trackreid/selection_functions/__init__.py create mode 100644 trackreid/selection_functions/select_by_category.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5e4111b..8d94afd 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,7 +29,10 @@ jobs: - name: Install requirements run: | - poetry install + make install - name: Run Pre commit hooks run: make format-code + + - name: Test with pytest + run: make run-tests diff --git a/.github/workflows/deploy_docs.yaml b/.github/workflows/deploy_docs.yaml index 3ac59df..8d4a6b6 100644 --- a/.github/workflows/deploy_docs.yaml +++ b/.github/workflows/deploy_docs.yaml @@ -15,10 +15,12 @@ jobs: uses: actions/setup-python@v2 with: python-version: "3.10" - + - name: Install poetry + run: | + make download-poetry - name: Install requirements run: | - make install_project_requirements + make install - name: Deploying MkDocs documentation run: | mkdocs build diff --git a/.gitignore b/.gitignore index 8234a9f..fe7c188 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,7 @@ secrets/* # Data ignore everythin data/detections and data/frames data/detections/* data/frames/* - +*.mp4 +*.txt # poetry poetry.lock diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6643f2c..2572435 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,12 +31,4 @@ repos: types: [file] files: (.ipynb)$ language: system - - id: pytest-check - name: Tests (pytest) - stages: [push] - entry: pytest tests/ - types: [python] - language: system - pass_filenames: false - always_run: true exclude: ^(.svn|CVS|.bzr|.hg|.git|__pycache__|.tox|.ipynb_checkpoints|assets|tests/assets/|venv/|.venv/) diff --git a/Makefile b/Makefile index 1385ca7..4de9643 100644 --- a/Makefile +++ b/Makefile @@ -1,10 +1,4 @@ -PYTHON_VERSION = 3.10 -USE_CONDA ?= 1 -INSTALL_SCRIPT = install_with_conda.sh -ifeq (false,$(USE_CONDA)) - INSTALL_SCRIPT = install_with_venv.sh -endif - +PYTHON_VERSION = 3.10.13 # help: help - Display this makefile's help information .PHONY: help help: @@ -17,6 +11,7 @@ download-poetry: # help: install - Install python dependencies using poetry .PHONY: install install: + @poetry config virtualenvs.create true @poetry env use $(PYTHON_VERSION) @poetry lock -n @poetry install -n @@ -27,12 +22,6 @@ install: install-requirements: @poetry install -n - -.PHONY: install-dev-requirements -# help : install-dev-requirements - Install Python Dependencies for development -install-dev-requirements: - @poetry install -n --with dev - .PHONY: update-requirements #help: update-requirements - Update Python Dependencies (requirements.txt and requirements-dev.txt) update-requirements: @@ -43,6 +32,12 @@ update-requirements: format-code: @poetry run pre-commit run -a +.PHONY: run-tests +#help: run-tests - Run all tests with pytest +run-tests: + @export PYTHONPATH=. + @poetry run pytest + # help: deploy_docs - Deploy documentation to GitHub Pages .PHONY: deploy_docs deploy_docs: diff --git a/README.md b/README.md index 276c58b..f65be64 100644 --- a/README.md +++ b/README.md @@ -12,8 +12,6 @@ [![Pre-commit](https://img.shields.io/badge/pre--commit-enabled-informational?logo=pre-commit&logoColor=white)](https://github.com/artefactory-fr/track-reid/blob/main/.pre-commit-config.yaml) -TODO: if not done already, check out the [Skaff documentation](https://artefact.roadie.so/catalog/default/component/repo-builder-ds/docs/) for more information about the generated repository. - This Git repository is dedicated to the development of a Python library aimed at correcting the results of tracking algorithms. The primary goal of this library is to reconcile and reassign lost or misidentified IDs, ensuring a consistent and accurate tracking of objects over time. ## Table of Contents @@ -27,13 +25,23 @@ This Git repository is dedicated to the development of a Python library aimed at ## Installation +First, install poetry: + +```bash +make download-poetry +``` + To install the required packages in a virtual environment, run the following command: ```bash make install ``` -TODO: Choose between conda and venv if necessary or let the Makefile as is and copy/paste the [MORE INFO installation section](MORE_INFO.md#eased-installation) to explain how to choose between conda and venv. +You can then activate the env with the following command: + +```bash +poetry shell +``` A complete list of available commands can be found using the following command: @@ -43,11 +51,22 @@ make help ## Usage -TODO: Add usage instructions here +For a quickstart, please refer to the documentation [here](https://artefactory-fr.github.io/track-reid/quickstart_user/). You also have at disposal a demo notebook in `notebooks/starer_kit_reid.ipynb`. -## Documentation +Lets say you have a `dataset` iterable object, composed for each iteartion of a frame id and its associated tracking results. You can call the `ReidProcessor` update class using the following: + +```python +for frame_id, tracker_output in dataset: + corrected_results = reid_processor.update(frame_id = frame_id, tracker_output=tracker_output) +``` + +At the end of the for loop, information about the correction can be retrieved using the `ReidProcessor` properties. For instance, the list of tracked object can be accessed using: -TODO: Github pages is not enabled by default, you need to enable it in the repository settings: Settings > Pages > Source: "Deploy from a branch" / Branch: "gh-pages" / Folder: "/(root)" +```python +reid_processor.seen_objects() +``` + +## Documentation A detailed documentation of this project is available [here](https://artefactory-fr.github.io/track-reid/) diff --git a/bin/download_sample_sequences.sh b/bin/download_sample_sequences.sh index d357733..b7e5b14 100644 --- a/bin/download_sample_sequences.sh +++ b/bin/download_sample_sequences.sh @@ -9,6 +9,8 @@ sequences_frames=$(gsutil ls gs://data-track-reid/frames | head -$N_SEQUENCES) sequences_detections=$(echo "$sequences_detections" | tail -n +2) sequences_frames=$(echo "$sequences_frames" | tail -n +2) +mkdir -p data/detections +mkdir -p data/frames # download the sequences to data/detections and data/frames for sequence in $sequences_detections; do diff --git a/bin/install_with_conda.sh b/bin/install_with_conda.sh deleted file mode 100644 index 58b6a47..0000000 --- a/bin/install_with_conda.sh +++ /dev/null @@ -1,18 +0,0 @@ -#!/bin/bash -e - -read -p "Want to install conda env named 'track-reid'? (y/n)" answer -if [ "$answer" = "y" ]; then - echo "Installing conda env..." - conda create -n track-reid python=3.10 -y - source $(conda info --base)/etc/profile.d/conda.sh - conda activate track-reid - echo "Installing requirements..." - make install_project_requirements - python3 -m ipykernel install --user --name=track-reid - conda install -c conda-forge --name track-reid notebook -y - echo "Installing pre-commit..." - make install_precommit - echo "Installation complete!"; -else - echo "Installation of conda env aborted!"; -fi diff --git a/bin/install_with_venv.sh b/bin/install_with_venv.sh deleted file mode 100644 index b3389db..0000000 --- a/bin/install_with_venv.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash -e - -read -p "Want to install virtual env named 'venv' in this project ? (y/n)" answer -if [ "$answer" = "y" ]; then - echo "Installing virtual env..." - declare VENV_DIR=$(pwd)/venv - if ! [ -d "$VENV_DIR" ]; then - python3 -m venv $VENV_DIR - fi - - source $VENV_DIR/bin/activate - echo "Installing requirements..." - make install_project_requirements - python3 -m ipykernel install --user --name=venv - echo "Installing pre-commit..." - make install_precommit - echo "Installation complete!"; -else - echo "Installation of virtual env aborted!"; -fi diff --git a/docs/code.md b/docs/code.md deleted file mode 100644 index aa45473..0000000 --- a/docs/code.md +++ /dev/null @@ -1 +0,0 @@ -# Code diff --git a/docs/custom_cost_selection.md b/docs/custom_cost_selection.md new file mode 100644 index 0000000..71bb9c1 --- /dev/null +++ b/docs/custom_cost_selection.md @@ -0,0 +1,148 @@ +# Designing custom cost and selection functions + +## Custom cost function + +In our codebase, a cost function is utilized to quantify the dissimilarity between two objects, specifically instances of [TrackedObjects](reference/tracked_object.md). The cost function plays a pivotal role in the matching process within the [Matcher class](reference/matcher.md), where it computes a cost matrix. Each element in this matrix represents the cost of assigning a candidate to a switcher. For a deeper understanding of cost functions, please refer to the [related documentation](reference/cost_functions.md). + +When initializing the [ReidProcessor](reference/reid_processor.md), you have the option to provide a custom cost function. The requirements for designing one are as follows: + +- The cost function must accept 2 [TrackedObjects](reference/tracked_object.md) instances: a candidate (a new object that appears and can potentially be matched), and a switcher (an object that has been lost and can potentially be re-matched). +- All the [metadata](reference/tracked_object_metadata.md) of each [TrackedObject](reference/tracked_object.md) can be utilized to compute a cost. +- If additional metadata is required, you should modify the [metadata](reference/tracked_object_metadata.md) class accordingly. Please refer to the [developer quickstart documentation](quickstart_dev.md) if needed. + +Here is an example of an Intersection over Union (IoU) distance function that you can use: + +```python +def bounding_box_iou_distance(candidate: TrackedObject, switcher: TrackedObject) -> float: + """ + Calculates the Intersection over Union (IoU) between the bounding boxes of two TrackedObjects. + This measure is used as a measure of similarity between the two objects, with a higher IoU + indicating a higher likelihood of the objects being the same. + + Args: + candidate (TrackedObject): The first TrackedObject. + switcher (TrackedObject): The second TrackedObject. + + Returns: + float: The IoU between the bounding boxes of the two TrackedObjects. + """ + # Get the bounding boxes from the Metadata of each TrackedObject + bbox1 = candidate.metadata.bbox + bbox2 = switcher.metadata.bbox + + # Calculate the intersection of the bounding boxes + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + # If the bounding boxes do not overlap, return 0 + if x2 < x1 or y2 < y1: + return 0.0 + + # Calculate the area of the intersection + intersection_area = (x2 - x1) * (y2 - y1) + + # Calculate the area of each bounding box + bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + bbox2_area = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + + # Calculate the IoU + iou = intersection_area / float(bbox1_area + bbox2_area - intersection_area) + + return 1 - iou + +``` + +Next, pass this function during the initialization of your [ReidProcessor](reference/reid_processor.md): + +```python +reid_processor = ReidProcessor(cost_function_threshold=0.3, + cost_function = bounding_box_iou_distance, + filter_confidence_threshold=..., + filter_time_threshold=..., + max_attempt_to_match=..., + max_frames_to_rematch=..., + save_to_txt=True, + file_path="your_file.txt") +``` + +In this case, candidates and switchers with bounding boxes will be matched if their IoU is below 0.7. Among possible matches, the two bounding boxes with the lowest cost (i.e., larger IoU) will be matched. You can use all the available metadata. For instance, here is an example of a cost function based on the difference in confidence: + +```python +def confidence_difference(candidate: TrackedObject, switcher: TrackedObject) -> float: + """ + Calculates the absolute difference between the confidence values of two TrackedObjects. + This measure is used as a measure of dissimilarity between the two objects, with a smaller difference + indicating a higher likelihood of the objects being the same. + + Args: + candidate (TrackedObject): The first TrackedObject. + switcher (TrackedObject): The second TrackedObject. + + Returns: + float: The absolute difference between the confidence values of the two TrackedObjects. + """ + # Get the confidence values from the Metadata of each TrackedObject + confidence1 = candidate.metadata.confidence + confidence2 = switcher.metadata.confidence + + # Calculate the absolute difference between the confidence values + difference = abs(confidence1 - confidence2) + + return difference + +``` + +Then, pass this function during the initialization of your [ReidProcessor](reference/reid_processor.md): + +```python +reid_processor = ReidProcessor(cost_function_threshold=0.1, + cost_function = confidence_difference, + filter_confidence_threshold=..., + filter_time_threshold=..., + max_attempt_to_match=..., + max_frames_to_rematch=..., + save_to_txt=True, + file_path="your_file.txt") +``` + +In this case, candidates and switchers will be matched if their confidence is similar, with a threshold acceptance of 0.1. Among possible matches, the two objects with the lowest cost (i.e., lower confidence difference) will be matched. + +## Custom Selection function + +In the codebase, a selection function is used to determine whether two objects, specifically [TrackedObjects](reference/tracked_object.md) instances, should be considered for matching. The selection function is a key part of the matching process in the [Matcher class](reference/matcher.md). For a deeper understanding of selection functions, please refer to the [related documentation](reference/selection_functions.md). + +Here is an example of a selection function per zone that you can use: + +```python + +# Define the area of interest, [x_min, y_min, x_max, y_max] +AREA_OF_INTEREST = [0, 0, 500, 500] + +def select_by_area(candidate: TrackedObject, switcher: TrackedObject) -> int: + + # Check if both objects are inside the area of interest + if (candidate.bbox[0] > AREA_OF_INTEREST[0] and candidate.bbox[1] > AREA_OF_INTEREST[1] and + candidate.bbox[0] + candidate.bbox[2] < AREA_OF_INTEREST[2] and candidate.bbox[1] + candidate.bbox[3] < AREA_OF_INTEREST[3] and + switcher.bbox[0] > AREA_OF_INTEREST[0] and switcher.bbox[1] > AREA_OF_INTEREST[1] and + switcher.bbox[0] + switcher.bbox[2] < AREA_OF_INTEREST[2] and switcher.bbox[1] + switcher.bbox[3] < AREA_OF_INTEREST[3]): + return 1 + else: + return 0 + +``` + +Then, pass this function during the initialization of your [ReidProcessor](reference/reid_processor.md): + +```python +reid_processor = ReidProcessor(selection_function = select_by_area, + filter_confidence_threshold=..., + filter_time_threshold=..., + max_attempt_to_match=..., + max_frames_to_rematch=..., + save_to_txt=True, + file_path="your_file.txt") +``` + +In this case, candidates and switchers will be considerated for matching if they belong to the same zone. You can of course combine selection functions, for instance to selection only switchers and candidates that belong to the same area and belong to the same category. diff --git a/docs/index.md b/docs/index.md index 8013429..df25491 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,3 +1,3 @@ -# Welcome to the documentation! +# Welcome to the documentation For more information, make sure to check the [Material for MkDocs documentation](https://squidfunk.github.io/mkdocs-material/getting-started/) diff --git a/docs/quickstart_dev.md b/docs/quickstart_dev.md new file mode 100644 index 0000000..7a0f99e --- /dev/null +++ b/docs/quickstart_dev.md @@ -0,0 +1,57 @@ +# Quickstart developers + +## Installation + +First, clone the repository to your local machine: + +```bash +git clone https://github.com/artefactory-fr/track-reid.git +``` + +Then, navigate to the project directory: + +```bash +cd track-reid +``` + +To install the necessary dependencies, we use Poetry. If you don't have Poetry installed, you can download it using the following command: + +```bash +curl -sSL https://install.python-poetry.org | python3 - +``` + +Now, you can install the dependencies: + +```bash +make install +``` + +This will create a virtual environment and install the necessary dependencies. +To activate the virtual environment in your terminal, you can use the following command: + +```bash +poetry shell +``` + +You can also update the requirements using the following command: + +```bash +make update-requirements +``` + +Then, you are ready to go ! +For more detailed information, please refer to the `Makefile`. + +## Tests + +In this project, we have designed both integration tests and unit tests. These tests are located in the `tests` directory of the project. + +Integration tests are designed to test the interaction between different parts of the system, ensuring that they work together as expected. Those tests can be found in the `tests/integration_tests` directory of the project. + +Unit tests, on the other hand, are designed to test individual components of the system in isolation. We provided a bench of unit tests to test key functions of the project, those can be found in `tests/unit_tests`. + +To run all tests, you can use the following command: + +```bash +make run_tests +``` diff --git a/docs/quickstart_user.md b/docs/quickstart_user.md new file mode 100644 index 0000000..d5262cb --- /dev/null +++ b/docs/quickstart_user.md @@ -0,0 +1,117 @@ +# Using the ReidProcessor + +The `ReidProcessor` is the entry point of the `track-reid` library. It is used to process and reconcile tracking data, ensuring consistent and accurate tracking of objects over time. Here's a step-by-step guide on how to use it: + +## Step 1: Understand the Usage + +The reidentification process is applied to tracking results, which are derived from the application of a tracking algorithm on detection results for successive frames of a video. This reidentification process is applied iteratively on each tracking result, updating its internal states during the process. + +The `ReidProcessor` needs to be updated with the tracking results for each frame of your +sequence or video. This is done by calling the `update` method that takes 2 arguments: + +- `frame_id`: an integer specifying the current frame of the video +- `tracker_output`: a numpy array containing the tracking results for the current frame + +## Step 2: Understand the Data Format Requirements + +The `ReidProcessor` update function requires a numpy array of tracking results for the current frame as input. This data must meet specific criteria regarding data type and structure. + +All input data must be numeric, either integers or floats. +Here's an example of the expected input data format based on the schema: + +| bbox (0-3) | object_id (4) | category (5) | confidence (6) | +|-----------------|---------------|--------------|----------------| +| 50, 60, 120, 80 | 1 | 1 | 0.91 | +| 50, 60, 120, 80 | 2 | 0 | 0.54 | + +Each row corresponds to a tracked object. + +- The first four columns denote the **bounding box coordinates** in the format (x, y, width, height), +where x and y are the top left coordinates of the bounding box. These coordinates can be either normalized or in pixel units. +These values remain unchanged during the reidentification process. +- The fifth column is the **object ID** assigned by the tracker, which may be adjusted during the reidentification process. +- The sixth column indicates the **category** of the detected object, which may also be adjusted during the reidentification process. +- The seventh column is the confidence score of the detection, which is not modified by the reidentification process. + +For additional information, you can utilize `ReidProcessor.print_input_data_requirements()`. + +Here's a reformatted example of how the output data should appear, based on the schema: + +| frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | +|--------------|---------------|--------------|-----------------|----------------|---------------------|----------------| +| 1 | 1 | 1 | 50, 60, 120, 80 | 0.91 | 0.85 | 1 | +| 2 | 2 | 0 | 50, 60, 120, 80 | 0.54 | 0.60 | 2 | + +- The first column represents the **frame identifier**, indicating the frame for which the result is applicable. +- The second column is the **object ID** assigned by the reidentification process. +- The third column is the **category** of the detected object, which may be adjusted during the reidentification process. +- The next four columns represent the **bounding box coordinates**, which remain unchanged from the input data. +- The seventh column is the **confidence** of the object, which also remains unchanged from the input data. +- The eighth column indicates the **average confidence** of the detected object over its lifetime, from the beginning of the tracking to the current frame. +- The final column is the **object ID assigned by the tracking algorithm**, before the reidentification process. + +You can use `ReidProcessor.print_output_data_format_information()` for more insight. + +## Step 3: Understand Necessary Modules + +To make ReidProcessor work, several modules are necessary: + +- `TrackedObject`: This class represents a tracked object. It is used within the Matcher and ReidProcessor classes. +- `TrackedObjectMetadata`: This class is attached to a tracked object and represents informations and properties about the object. +- `TrackedObjectFilter`: This class is used to filter tracked objects based on certain criteria. It is used within the ReidProcessor class. +- `Matcher`: This class is used to match tracked objects based on a cost function and a selection function. It is initialized within the ReidProcessor class. + +The cost and selection functions are key components of the ReidProcessor, as they will drive the matching process between lost objects and new objects during the video. Those two functions are fully customizable and can be passed as arguments of the ReidProcessor at initialization. They both take 2 `TrackedObjects` as inputs, and perform computation based on their metadatas. + +- **cost function**: This function calculates the cost of matching two objects. It takes two TrackedObject instances as input and returns a numerical value representing the cost of matching these two objects. A lower cost indicates a higher likelihood of a match. The default cost function is `bounding_box_distance`. + +- **selection_function**: This function determines whether two objects should be considered for matching. It takes two TrackedObject instances as input and returns a binary value (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 indicates that the pair should not be considered. The default selection function is `select_by_category`. + +In summary, prior to the matching process, filtering on which objects should be considerated is applied thought the `TrackedObjectFilter`. All objects are represented by the `TrackedObject` class, with its attached metadata represented by `TrackedObjectMetadata`. The `ReidProcessor` then uses the `Matcher` class with a cost function and selection function to match objects. + +## Step 4: Initialize ReidProcessor + +If you do not want to provide custom cost and selection function, here is an example of ReidProcessor initialization: + +```python +reid_processor = ReidProcessor(filter_confidence_threshold=0.1, + filter_time_threshold=5, + cost_function_threshold=5000, + max_attempt_to_match=5, + max_frames_to_rematch=500, + save_to_txt=True, + file_path="your_file.txt") +``` + +Here is a brief explanation of each argument in the ReidProcessor function, and how you can monitor the `Matcher` and the `TrackedObjectFilter` behaviours: + +- `filter_confidence_threshold`: Float value that sets the **minimum average confidence level** for a tracked object to be considered valid. Tracked objects with average confidence levels below this threshold will be ignored. + +- `filter_time_threshold`: Integer that sets the **minimum number of frames** a tracked object must be seen with the same id to be considered valid. Tracked objects seen less frames that this threshold will be ignored. + +- `cost_function_threshold`: This is a float value that sets the **maximum cost for a match** between a detection and a track. If the cost of matching a detection to a track exceeds this threshold, the match will not be made. Set to None for no limitation. + +- `max_attempt_to_match`: This is an integer that sets the **maximum number of attempts to match a tracked object never seen before** to a lost tracked object. If this tracked object never seen before can't be matched within this number of attempts, it will be considered a new stable tracked object. + +- `max_frames_to_rematch`: This is an integer that sets the **maximum number of frames to try to rematch a tracked object that has been lost**. If a lost object can't be rematch within this number of frames, it will be considered as lost forever. + +- `save_to_txt`: This is a boolean value that determines whether the tracking results should be saved to a text file. If set to True, the results will be saved to a text file. + +- `file_path`: This is a string that specifies the path to the text file where the tracking results will be saved. This argument is only relevant if save_to_txt is set to True. + +For more information on how to design custom cost and selection functions, refer to [this guide](custom_cost_selection.md). + +## Step 5: Run reidentifiaction process + +Lets say you have a `dataset` iterable object, composed for each iteartion of a frame id and its associated tracking results. You can call the `ReidProcessor` update class using the following: + +```python +for frame_id, tracker_output in dataset: + corrected_results = reid_processor.update(frame_id = frame_id, tracker_output=tracker_output) +``` + +At the end of the for loop, information about the correction can be retrieved using the `ReidProcessor` properties. For instance, the list of tracked object can be accessed using: + +```python +reid_processor.seen_objects() +``` diff --git a/docs/reference/cost_functions.md b/docs/reference/cost_functions.md new file mode 100644 index 0000000..3a0a1c9 --- /dev/null +++ b/docs/reference/cost_functions.md @@ -0,0 +1,15 @@ +# Cost functions + +In the codebase, a cost function is used to measure the dissimilarity between two objects, specifically [TrackedObjects](tracked_object.md) instances. The cost function is a crucial part of the matching process in the [Matcher class](matcher.md). It calculates a cost matrix, where each element represents the cost of assigning a candidate to a switcher. + +The cost function affects the behavior of the matching process in the following ways: + +1. **Determining Matches**: The cost function is used to determine the best matches between candidates and switchers. The lower the cost, the higher the likelihood that two objects are the same. + +2. **Influencing Match Quality**: The choice of cost function can greatly influence the quality of the matches. For example, a cost function that calculates the Euclidean distance between the centers of bounding boxes might be more suitable for tracking objects in a video, while a cost function that calculates the absolute difference between confidence values might be more suitable for matching objects based on their detection confidence. + +3. **Setting Match Thresholds**: The cost function also plays a role in setting thresholds for matches. In the [Matcher class](matcher.md), if the cost exceeds a certain threshold, the match is discarded. + +You can provide a custom cost function to the reidentification process. For more information, please refer to [this documentation](../custom_cost_selection.md). + +:::trackreid.cost_functions diff --git a/docs/reference/matcher.md b/docs/reference/matcher.md new file mode 100644 index 0000000..396ffbf --- /dev/null +++ b/docs/reference/matcher.md @@ -0,0 +1,3 @@ +# Matcher + +:::trackreid.matcher diff --git a/docs/reference/reid_processor.md b/docs/reference/reid_processor.md new file mode 100644 index 0000000..c72dbdc --- /dev/null +++ b/docs/reference/reid_processor.md @@ -0,0 +1,3 @@ +# Reid processor + +:::trackreid.reid_processor diff --git a/docs/reference/selection_functions.md b/docs/reference/selection_functions.md new file mode 100644 index 0000000..6b4f998 --- /dev/null +++ b/docs/reference/selection_functions.md @@ -0,0 +1,17 @@ +# Selection Functions + +In the codebase, a selection function is used to determine whether two objects, specifically [TrackedObjects](tracked_object.md) instances, should be considered for matching. The selection function is a key part of the matching process in the [Matcher class](matcher.md). + +The selection function influences the behavior of the matching process in the following ways: + +1. **Filtering Candidates**: The selection function is used to filter out pairs of objects that should not be considered for matching. This can help reduce the computational complexity of the matching process by reducing the size of the cost matrix. + +2. **Customizing Matching Criteria**: The selection function allows you to customize the criteria for considering a pair of objects for matching. For example, you might want to only consider pairs of objects that belong to the same category, or pairs of objects that belong to the same area / zone. + +3. **Improving Match Quality**: By carefully choosing or designing a selection function, you can improve the quality of the matches. For example, a selection function that only considers pairs of objects with similar appearance features might lead to more accurate matches. + +The selection function should return a boolean value. A return value of `True` or `1` indicates that the pair of objects should be considered for matching, while a return value of `False` or `0` indicates that the pair should not be considered. + +You can provide a custom selection function to the reidentification process. For more information, please refer to [this documentation](../custom_cost_selection.md). + +:::trackreid.selection_functions diff --git a/docs/reference/tracked_object.md b/docs/reference/tracked_object.md new file mode 100644 index 0000000..d74024a --- /dev/null +++ b/docs/reference/tracked_object.md @@ -0,0 +1,3 @@ +# TrackedObject + +:::trackreid.tracked_object diff --git a/docs/reference/tracked_object_filter.md b/docs/reference/tracked_object_filter.md new file mode 100644 index 0000000..757ebe7 --- /dev/null +++ b/docs/reference/tracked_object_filter.md @@ -0,0 +1,3 @@ +# TrackedObjectFilter + +:::trackreid.tracked_object_filter diff --git a/docs/reference/tracked_object_metadata.md b/docs/reference/tracked_object_metadata.md new file mode 100644 index 0000000..b03870c --- /dev/null +++ b/docs/reference/tracked_object_metadata.md @@ -0,0 +1,3 @@ +# TrackedObjectMetadata + +:::trackreid.tracked_object_metadata diff --git a/lib/norfair_helper/utils.py b/lib/norfair_helper/utils.py index 11bf4af..aff22c8 100644 --- a/lib/norfair_helper/utils.py +++ b/lib/norfair_helper/utils.py @@ -1,13 +1,15 @@ from typing import List +import cv2 import numpy as np -from norfair import Detection +from norfair import Detection, get_cutout from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy def yolo_to_norfair_detection( - yolo_detections: np.array, original_img_size: tuple + yolo_detections: np.array, + original_img_size: tuple, ) -> List[Detection]: """convert detections_as_xywh to norfair detections""" norfair_detections: List[Detection] = [] @@ -23,3 +25,26 @@ def yolo_to_norfair_detection( scores = np.array([detection_output[5].item(), detection_output[5].item()]) norfair_detections.append(Detection(points=bbox, scores=scores, label=detection_output[0])) return norfair_detections + + +def compute_embeddings(norfair_detections: List[Detection], image: np.array): + """ + Add embedding attribute to all Detection objects in norfair_detections. + """ + for detection in norfair_detections: + object = get_cutout(detection.points, image) + if object.shape[0] > 0 and object.shape[1] > 0: + detection.embedding = get_hist(object) + return norfair_detections + + +def get_hist(image: np.array): + """Compute an embedding with histograms""" + hist = cv2.calcHist( + [cv2.cvtColor(image, cv2.COLOR_BGR2Lab)], + [0, 1], + None, + [128, 128], + [0, 256, 0, 256], + ) + return cv2.normalize(hist, hist).flatten() diff --git a/lib/norfair_helper/video.py b/lib/norfair_helper/video.py index 3ea8530..d84e09d 100644 --- a/lib/norfair_helper/video.py +++ b/lib/norfair_helper/video.py @@ -2,12 +2,16 @@ import numpy as np from norfair import Tracker, draw_boxes -from lib.norfair_helper.utils import yolo_to_norfair_detection +from lib.norfair_helper.utils import compute_embeddings, yolo_to_norfair_detection from lib.sequence import Sequence def generate_tracking_video( - sequence: Sequence, tracker: Tracker, frame_size: tuple, output_path: str + sequence: Sequence, + tracker: Tracker, + frame_size: tuple, + output_path: str, + add_embedding: bool = False, ) -> str: """ Generate a video with the tracking results. @@ -17,6 +21,7 @@ def generate_tracking_video( tracker: The tracker to use. frame_size: The size of the frames. output_path: The path to save the video to. + add_embedding: Whether to add the embedding to the video. Returns: The path to the video. @@ -26,11 +31,12 @@ def generate_tracking_video( out = cv2.VideoWriter(output_path, fourcc, 20.0, frame_size) # Changed file extension to .mp4 for frame, detection in sequence: + frame = np.array(frame) detections_list = yolo_to_norfair_detection(detection, frame_size) + if add_embedding: + detections_list = compute_embeddings(detections_list, frame) tracked_objects = tracker.update(detections=detections_list) - frame_detected = draw_boxes( - np.array(frame), tracked_objects, draw_ids=True, color="by_label" - ) + frame_detected = draw_boxes(frame, tracked_objects, draw_ids=True, color="by_label") frame_detected = cv2.cvtColor(frame_detected, cv2.COLOR_BGR2RGB) out.write(frame_detected) out.release() diff --git a/lib/sequence.py b/lib/sequence.py index 97b765a..b458f0b 100644 --- a/lib/sequence.py +++ b/lib/sequence.py @@ -30,6 +30,7 @@ def __next__(self): raise StopIteration frame = Image.open(self.frame_paths[self.index]) + try: detection = np.loadtxt(self.detection_paths[self.index], dtype="float") except OSError: # file doesn't exist not detection return empty file diff --git a/mkdocs.yaml b/mkdocs.yaml deleted file mode 100644 index 449fd64..0000000 --- a/mkdocs.yaml +++ /dev/null @@ -1,31 +0,0 @@ -site_name: track-reid - -theme: - name: "material" - palette: - - media: "(prefers-color-scheme: dark)" - scheme: default - primary: teal - accent: amber - toggle: - icon: material/moon-waning-crescent - name: Switch to dark mode - - media: "(prefers-color-scheme: light)" - scheme: slate - primary: teal - accent: amber - toggle: - icon: material/white-balance-sunny - name: Switch to light mode - features: - - search.suggest - - search.highlight - - content.tabs.link - -plugins: - - mkdocstrings - - search - -nav: - - Home: index.md - - Source code: code.md diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..edc0dd2 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,46 @@ +site_name: track-reid + +theme: + name: "material" + palette: + - media: "(prefers-color-scheme: dark)" + scheme: default + primary: indigo + accent: pink + toggle: + icon: material/moon-waning-crescent + name: Switch to dark mode + - media: "(prefers-color-scheme: light)" + scheme: slate + primary: indigo + accent: pink + toggle: + icon: material/white-balance-sunny + name: Switch to light mode + features: + - search.suggest + - search.highlight + - content.tabs.link + +plugins: + - mkdocstrings + - search + +markdown_extensions: + - codehilite: + use_pygments: true + pygments_style: monokai + +nav: + - Home: index.md + - Quickstart users: quickstart_user.md + - Quickstart developers: quickstart_dev.md + - Custom cost and selection functions: custom_cost_selection.md + - Code Reference: + - ReidProcessor: reference/reid_processor.md + - TrackedObjectFilter: reference/tracked_object_filter.md + - Matcher: reference/matcher.md + - TrackedObjectMetadata: reference/tracked_object_metadata.md + - TrackedObject: reference/tracked_object.md + - Cost functions: reference/cost_functions.md + - Selection functions: reference/selection_functions.md diff --git a/notebooks/norfair_starter_kit.ipynb b/notebooks/norfair_starter_kit.ipynb index ff9ff59..c911c52 100644 --- a/notebooks/norfair_starter_kit.ipynb +++ b/notebooks/norfair_starter_kit.ipynb @@ -33,7 +33,8 @@ "import sys; sys.path.append('..')\n", "import os\n", "\n", - "from norfair import Tracker\n", + "import cv2\n", + "from norfair import Tracker, OptimizedKalmanFilterFactory\n", "\n", "from lib.sequence import Sequence\n", "from lib.norfair_helper.video import generate_tracking_video\n" @@ -96,7 +97,7 @@ " detections.sort()\n", " return detections\n", "\n", - "frame_path = get_sequence_frames(SEQUENCES[1])\n", + "frame_path = get_sequence_frames(SEQUENCES[3])\n", "test_sequence = Sequence(frame_path)\n", "test_sequence" ] @@ -158,6 +159,7 @@ " tracker=basic_tracker,\n", " frame_size=(2560, 1440),\n", " output_path=os.path.join(VIDEO_OUTPUT_PATH, \"basic_tracking.mp4\"),\n", + " add_embedding=False,\n", ")\n", "video_path" ] @@ -175,8 +177,29 @@ "metadata": {}, "outputs": [], "source": [ - "def reid_distance_advanced(new_object, unmatched_object):\n", - " return 0 # ALWAYS MATCH" + "def always_match(new_object, unmatched_object):\n", + " return 0 # ALWAYS MATCH\n", + "\n", + "\n", + "def embedding_distance(matched_not_init_trackers, unmatched_trackers):\n", + " snd_embedding = unmatched_trackers.last_detection.embedding\n", + "\n", + " # Find last non-empty embedding if current is None\n", + " if snd_embedding is None:\n", + " snd_embedding = next((detection.embedding for detection in reversed(unmatched_trackers.past_detections) if detection.embedding is not None), None)\n", + "\n", + " if snd_embedding is None:\n", + " return 1 # No match if no embedding is found\n", + "\n", + " # Iterate over past detections and calculate distance\n", + " for detection_fst in matched_not_init_trackers.past_detections:\n", + " if detection_fst.embedding is not None:\n", + " distance = 1 - cv2.compareHist(snd_embedding, detection_fst.embedding, cv2.HISTCMP_CORREL)\n", + " # If similar a tiny bit similar, we return the distance to the tracker\n", + " if distance < 0.9:\n", + " return distance\n", + "\n", + " return 1 # No match if no matching embedding is found between the 2" ] }, { @@ -187,12 +210,13 @@ "source": [ "advanced_tracker = Tracker(\n", " distance_function=\"sqeuclidean\",\n", + " filter_factory = OptimizedKalmanFilterFactory(R=5, Q=0.05),\n", " distance_threshold=350, # Higher value means objects further away will be matched\n", - " initialization_delay=10, # Wait 15 frames before an object is starts to be tracked\n", - " hit_counter_max=20, # Inertia, higher values means an object will take time to enter in reid phase\n", - " reid_distance_function=reid_distance_advanced, # function to decide on which metric to reid\n", - " reid_distance_threshold=0.5, # If the distance is below 0.5 the object is matched\n", - " reid_hit_counter_max=200, # inertia, higher values means an object will enter reid phase longer\n", + " initialization_delay=12, # Wait 15 frames before an object is starts to be tracked\n", + " hit_counter_max=15, # Inertia, higher values means an object will take time to enter in reid phase\n", + " reid_distance_function=embedding_distance, # function to decide on which metric to reid\n", + " reid_distance_threshold=0.9, # If the distance is below the object is matched\n", + " reid_hit_counter_max=200, #higher values means an object will stay reid phase longer\n", " )" ] }, @@ -207,6 +231,7 @@ " tracker=advanced_tracker,\n", " frame_size=(2560, 1440),\n", " output_path=os.path.join(VIDEO_OUTPUT_PATH, \"advance_tracking.mp4\"),\n", + " add_embedding=True,\n", ")\n", "video_path" ] diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb index 58605c0..685d1a6 100644 --- a/notebooks/starter_kit_reid.ipynb +++ b/notebooks/starter_kit_reid.ipynb @@ -6,18 +6,53 @@ "metadata": {}, "outputs": [], "source": [ - "%load_ext autoreload \n", + "%load_ext autoreload\n", "%autoreload 2" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this demo, you have to install bytetrack. You can do so by typing the following command : \n", + "```bash\n", + "pip install git+https://github.com/artefactory-fr/bytetrack.git@main\n", + "````\n", + "\n", + "Baseline data can be found in `gs://data-track-reid/predictions/baseline`. You can copy them in `../data/predictions/` using the following commands (in a terminal at the root of the project):\n", + "\n", + "```bash\n", + "mkdir -p ./data/predictions/\n", + "gsutil -m cp -r gs://data-track-reid/predictions/baseline ./data/predictions/\n", + "```\n", + "\n", + "Then you can reoganize the data using the following : \n", + "```bash \n", + "find ./data/predictions/baseline -mindepth 2 -type f -name \"*.txt\" -exec sh -c 'mv \"$0\" \"${0%/*/*}/$(basename \"${0%/*}\").txt\"' {} \\; && find ./data/predictions/baseline -mindepth 1 -type d -empty -delete\n", + "```" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "import sys \n", - "sys.path.append(\"..\")" + "import os\n", + "import sys\n", + "from datetime import datetime\n", + "\n", + "import cv2\n", + "import numpy as np\n", + "from bytetracker import BYTETracker\n", + "from bytetracker.basetrack import BaseTrack\n", + "from tqdm import tqdm\n", + "\n", + "from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy\n", + "from lib.sequence import Sequence\n", + "from trackreid.configs.output_data_positions import output_data_positions\n", + "\n", + "sys.path.append(\"..\")\n" ] }, { @@ -26,9 +61,52 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "from track_reid.reid_processor import ReidProcessor" + "\n", + "from trackreid.reid_processor import ReidProcessor\n", + "from trackreid.cost_functions import bounding_box_distance\n", + "from trackreid.selection_functions import select_by_category\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ReidProcessor.print_input_data_format_requirements()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ReidProcessor.print_output_data_format_information()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Real life data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DATA_PATH = \"../data\"\n", + "DETECTION_PATH = f\"{DATA_PATH}/detections\"\n", + "FRAME_PATH = f\"{DATA_PATH}/frames\"\n", + "PREDICTIONS_PATH = f\"{DATA_PATH}/predictions\"\n", + "VIDEO_OUTPUT_PATH = \"private\"\n", + "\n", + "SEQUENCES = os.listdir(DETECTION_PATH)\n", + "GENERATE_VIDEOS = False\n" ] }, { @@ -37,21 +115,49 @@ "metadata": {}, "outputs": [], "source": [ - "def bounding_box_distance(obj1, obj2):\n", - " # Get the bounding boxes from the Metadata of each TrackedObject\n", - " bbox1 = obj1.metadata.bbox\n", - " bbox2 = obj2.metadata.bbox\n", + "def get_sequence_frames(sequence):\n", + " frames = os.listdir(f\"{FRAME_PATH}/{sequence}\")\n", + " frames = [os.path.join(f\"{FRAME_PATH}/{sequence}\", frame) for frame in frames]\n", + " frames.sort()\n", + " return frames\n", "\n", - " # Calculate the Euclidean distance between the centers of the bounding boxes\n", - " center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2)\n", - " center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2)\n", - " distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)\n", + "def get_sequence_detections(sequence):\n", + " detections = os.listdir(f\"{DETECTION_PATH}/{sequence}\")\n", + " detections = [os.path.join(f\"{DETECTION_PATH}/{sequence}\", detection) for detection in detections]\n", + " detections.sort()\n", + " return detections" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class DetectionHandler():\n", + " def __init__(self, image_shape) -> None:\n", + " self.image_shape = image_shape\n", + "\n", + " def process(self, detection_output):\n", + " if detection_output.size:\n", + " if detection_output.ndim == 1:\n", + " detection_output = np.expand_dims(detection_output, 0)\n", "\n", - " return distance\n", + " processed_detection = np.zeros(detection_output.shape)\n", "\n", - "def select_by_category(obj1, obj2):\n", - " # Compare the categories of the two objects\n", - " return 1 if obj1.category == obj2.category else 0" + " for idx, detection in enumerate(detection_output):\n", + " clss = detection[0]\n", + " conf = detection[5]\n", + " bbox = detection[1:5]\n", + " xyxy_bbox = xy_center_to_xyxy(bbox)\n", + " rescaled_bbox = rescale_bbox(xyxy_bbox,self.image_shape)\n", + " processed_detection[idx,:4] = rescaled_bbox\n", + " processed_detection[idx,4] = conf\n", + " processed_detection[idx,5] = clss\n", + "\n", + " return processed_detection\n", + " else:\n", + " return detection_output\n" ] }, { @@ -60,23 +166,30 @@ "metadata": {}, "outputs": [], "source": [ - "# Example usage:\n", - "data = np.array([\n", - " [1, 1, \"car\", 100, 200, 300, 400, 0.9],\n", - " [1, 2, \"person\", 50, 150, 200, 400, 0.8],\n", - " [2, 1, \"truck\", 120, 220, 320, 420, 0.95],\n", - " [2, 2, \"person\", 60, 160, 220, 420, 0.85],\n", - " [3, 1, \"car\", 110, 210, 310, 410, 0.91],\n", - " [3, 3, \"person\", 61, 170, 220, 420, 0.91],\n", - " [3, 4, \"car\", 60, 160, 220, 420, 0.91],\n", - " [3, 6, \"person\", 60, 160, 220, 420, 0.91],\n", - " [4, 1, \"truck\", 130, 230, 330, 430, 0.92],\n", - " [4, 2, \"person\", 65, 165, 225, 425, 0.83],\n", - " [5, 1, \"car\", 115, 215, 315, 415, 0.93],\n", - " [5, 2, \"person\", 57, 157, 207, 407, 0.84],\n", - " [5, 4, \"car\", 60, 160, 220, 420, 0.91],\n", - " [5, 8, \"person\", 60, 160, 220, 420, 0.91],\n", - "])\n" + "class TrackingHandler():\n", + " def __init__(self, tracker) -> None:\n", + " self.tracker = tracker\n", + "\n", + " def update(self, detection_outputs, frame_id):\n", + "\n", + " if not detection_outputs.size :\n", + " return detection_outputs\n", + "\n", + " processed_detections = self._pre_process(detection_outputs)\n", + " tracked_objects = self.tracker.update(processed_detections, frame_id = frame_id)\n", + " processed_tracked = self._post_process(tracked_objects)\n", + " return processed_tracked\n", + "\n", + " def _pre_process(self,detection_outputs : np.ndarray):\n", + " return detection_outputs\n", + "\n", + " def _post_process(self, tracked_objects : np.ndarray):\n", + "\n", + " if tracked_objects.size :\n", + " if tracked_objects.ndim == 1:\n", + " tracked_objects = np.expand_dims(tracked_objects, 0)\n", + "\n", + " return tracked_objects" ] }, { @@ -85,28 +198,66 @@ "metadata": {}, "outputs": [], "source": [ - "processor = ReidProcessor(filter_confidence_threshold=0.4, \n", - " filter_time_threshold=0,\n", - " cost_function=bounding_box_distance,\n", - " selection_function=select_by_category,\n", - " max_attempt_to_rematch=0,\n", - " max_frames_to_rematch=100)\n", + "timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')\n", + "print(timestamp)\n", + "folder_save = os.path.join(PREDICTIONS_PATH,timestamp)\n", + "os.makedirs(folder_save, exist_ok=True)\n", + "GENERATE_VIDEOS = False\n", + "for sequence in tqdm(SEQUENCES) :\n", + " frame_path = get_sequence_frames(sequence)\n", + " test_sequence = Sequence(frame_path)\n", + " frame_id = 0\n", + " BaseTrack._count = 0\n", + " from datetime import datetime\n", + "\n", + " file_path = os.path.join(folder_save,sequence) + '.txt'\n", + " video_path = os.path.join(folder_save,sequence) + '.mp4'\n", "\n", + " if GENERATE_VIDEOS:\n", + " fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n", + " out = cv2.VideoWriter(video_path, fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n", "\n", - "columns = ['frame_id', 'object_id', 'category', 'x1', 'y1', 'x2', 'y2', 'confidence']\n", - "df = pd.DataFrame(data, columns=columns)\n", - "# Convert numerical columns to the appropriate data type\n", - "df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']] = df[['frame_id', 'object_id', 'x1', 'y1', 'x2', 'y2']].astype(int)\n", - "df['confidence'] = df['confidence'].astype(float)\n", + " detection_handler = DetectionHandler(image_shape=[2560, 1440])\n", + " tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n", + " reid_processor = ReidProcessor(filter_confidence_threshold=0.1,\n", + " filter_time_threshold=5,\n", + " cost_function=bounding_box_distance,\n", + " cost_function_threshold=5000, # max cost to rematch 2 objects\n", + " selection_function=select_by_category,\n", + " max_attempt_to_match=5,\n", + " max_frames_to_rematch=500,\n", + " save_to_txt=True,\n", + " file_path=file_path)\n", "\n", + " for frame, detection in test_sequence:\n", "\n", - "for frame_id, frame_data in df.groupby(\"frame_id\"):\n", + " frame_id += 1\n", "\n", - " bytetrack_output = frame_data.values\n", - " if bytetrack_output.ndim == 1 : \n", - " bytetrack_output = np.expand_dims(bytetrack_output, 0)\n", + " processed_detections = detection_handler.process(detection)\n", + " processed_tracked = tracking_handler.update(processed_detections, frame_id)\n", + " reid_results = reid_processor.update(processed_tracked, frame_id)\n", "\n", - " results = processor.update(bytetrack_output)\n" + " if GENERATE_VIDEOS and len(reid_results) > 0:\n", + " frame = np.array(frame)\n", + " for res in reid_results:\n", + " object_id = int(res[output_data_positions.object_id])\n", + " bbox = list(map(int, res[output_data_positions.bbox]))\n", + " class_id = int(res[output_data_positions.category])\n", + " tracker_id = int(res[output_data_positions.tracker_id])\n", + " mean_confidence = float(res[output_data_positions.mean_confidence])\n", + " x1, y1, x2, y2 = bbox\n", + " color = (0, 0, 255) if class_id else (0, 255, 0) # green for class 0, red for class 1\n", + " cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n", + " cv2.putText(frame, f\"{object_id} ({tracker_id}) : {round(mean_confidence,2)}\", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n", + "\n", + " if GENERATE_VIDEOS:\n", + " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", + " out.write(frame)\n", + "\n", + " if GENERATE_VIDEOS :\n", + " out.release()\n", + "\n", + "\n" ] }, { @@ -115,7 +266,36 @@ "metadata": {}, "outputs": [], "source": [ - "print(processor.all_tracked_objects[0])" + "from collections import defaultdict\n", + "\n", + "def count_occurrences(file_path, case):\n", + " object_counts = defaultdict(int)\n", + " class_counts = defaultdict(lambda: defaultdict(int))\n", + "\n", + " with open(file_path, 'r') as f:\n", + " for line in f:\n", + " data = line.split()\n", + "\n", + " if case != 'baseline':\n", + " object_id = int(data[1])\n", + " category = int(data[2])\n", + " else:\n", + " object_id = int(data[1])\n", + " category = int(data[-1])\n", + "\n", + " object_counts[object_id] += 1\n", + " class_counts[object_id][category] += 1\n", + "\n", + " return object_counts, class_counts\n", + "\n", + "def filter_counts(object_counts, class_counts, min_occurrences=10):\n", + " filtered_objects = {}\n", + "\n", + " for object_id, count in object_counts.items():\n", + " if count > min_occurrences and class_counts[object_id][0] > class_counts[object_id][1]:\n", + " filtered_objects[object_id] = count\n", + "\n", + " return filtered_objects\n" ] }, { @@ -124,7 +304,18 @@ "metadata": {}, "outputs": [], "source": [ - "print(processor.all_tracked_objects[0].metadata)" + "PATH_PREDICTIONS = f\"../data/predictions/{timestamp}\"\n", + "\n", + "for sequence in SEQUENCES:\n", + " print(\"-\"*50)\n", + " print(sequence)\n", + "\n", + " for case in [\"baseline\", timestamp]:\n", + " object_counts, class_counts = count_occurrences(f'../data/predictions/{case}/{sequence}.txt', case=case)\n", + " filtered_objects = filter_counts(object_counts, class_counts)\n", + "\n", + " print(case)\n", + " print(filtered_objects)" ] } ], diff --git a/pyproject.toml b/pyproject.toml index a7209d1..1d64bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,12 +7,15 @@ readme = "README.md" [tool.poetry.dependencies] -python = ">=3.8, <3.11.0" -pandas = "1.5.3" +python = "3.10.13" numpy = "1.24.2" llist = "0.7.1" pydantic = "2.4.2" -bytetracker = { git = "https://github.com/TomDarmon/bytetrack-pip.git", branch = "main" } + +lapx = "^0.5.5" +opencv-python = "^4.8.1.78" +tqdm = "^4.66.1" +pillow = "^10.1.0" [tool.poetry.group.dev.dependencies] black = "22.10.0" @@ -20,12 +23,13 @@ ruff = "0.0.272" isort = "5.12.0" pre-commit = "3.3.3" pytest = "7.3.2" +ipykernel = "6.24.0" mkdocs = "1.4.3" mkdocs-material = "9.1.15" -mkdocstrings-python = "1.1.2" +mkdocstrings = {extras = ["python-legacy"], version = "^0.24.0"} bandit = "1.7.5" nbstripout = "0.6.1" -ipykernel = "6.24.0" + [build-system] diff --git a/tests/data/unit_tests/tracked_objects/object_1.json b/tests/data/unit_tests/tracked_objects/object_1.json new file mode 100644 index 0000000..9609a5b --- /dev/null +++ b/tests/data/unit_tests/tracked_objects/object_1.json @@ -0,0 +1,28 @@ +{ + "object_id": 1.0, + "state": 0, + "re_id_chain": [ + 1.0, + 2.0, + 14.0, + 18.0, + 21.0 + ], + "metadata": { + "first_frame_id": 15, + "last_frame_id": 251, + "class_counts": { + "0": 175, + "1": 0 + }, + "bbox": [ + 598, + 208, + 814, + 447 + ], + "confidence": 0.610211, + "confidence_sum": 111.30582399999996, + "observations": 175 + } +} diff --git a/tests/data/unit_tests/tracked_objects/object_24.json b/tests/data/unit_tests/tracked_objects/object_24.json new file mode 100644 index 0000000..aef62df --- /dev/null +++ b/tests/data/unit_tests/tracked_objects/object_24.json @@ -0,0 +1,24 @@ +{ + "object_id": 24.0, + "state": -2, + "re_id_chain": [ + 24.0 + ], + "metadata": { + "first_frame_id": 154, + "last_frame_id": 251, + "class_counts": { + "0": 2, + "1": 0 + }, + "bbox": [ + 1430, + 664, + 1531, + 830 + ], + "confidence": 0.48447, + "confidence_sum": 1.108755, + "observations": 2 + } +} diff --git a/tests/data/unit_tests/tracked_objects/object_4.json b/tests/data/unit_tests/tracked_objects/object_4.json new file mode 100644 index 0000000..ecb65dd --- /dev/null +++ b/tests/data/unit_tests/tracked_objects/object_4.json @@ -0,0 +1,25 @@ +{ + "object_id": 4.0, + "state": 0, + "re_id_chain": [ + 4.0, + 13.0 + ], + "metadata": { + "first_frame_id": 38, + "last_frame_id": 251, + "class_counts": { + "0": 0, + "1": 216 + }, + "bbox": [ + 548, + 455, + 846, + 645 + ], + "confidence": 0.700626, + "confidence_sum": 149.68236100000004, + "observations": 216 + } +} diff --git a/tests/unit_tests/test_matcher.py b/tests/unit_tests/test_matcher.py new file mode 100644 index 0000000..1e21867 --- /dev/null +++ b/tests/unit_tests/test_matcher.py @@ -0,0 +1,80 @@ +import json +from pathlib import Path + +from trackreid.matcher import Matcher +from trackreid.tracked_object import TrackedObject + +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_OBJECTS = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_OBJECTS.append(TrackedObject.from_dict(json.load(file))) + + +def test_matcher_no_match(): + def dummy_cost_function(candidate, switcher): + return abs(candidate.object_id - switcher.object_id) + + def dummy_selection_function(candidate, switcher): # noqa: ARG001 + return 0 + + matcher = Matcher(dummy_cost_function, dummy_selection_function) + + candidates = [] + switchers = [] + for obj in ALL_TRACKED_OBJECTS: + candidates.append(obj) + switchers.append(obj) + + matches = matcher.match(candidates, switchers) + + assert len(matches) == 0 + + +def test_matcher_all_match(): + def dummy_cost_function(candidate, switcher): + return abs(candidate.object_id - switcher.object_id) + + def dummy_selection_function(candidate, switcher): # noqa: ARG001 + return 1 + + matcher = Matcher(dummy_cost_function, dummy_selection_function) + + candidates = [] + switchers = [] + for obj in ALL_TRACKED_OBJECTS: + candidates.append(obj) + switchers.append(obj) + + matches = matcher.match(candidates, switchers) + + assert len(matches) == 3 + for i in range(3): + assert matches[i][candidates[i]] == switchers[i] + + +def test_matcher_middle_case(): + def dummy_cost_function(candidate, switcher): + return abs(candidate.object_id - switcher.object_id) + + def dummy_selection_function(candidate, switcher): + return (candidate.object_id % 2 == switcher.object_id % 2) and ( + candidate.object_id != switcher.object_id + ) + + matcher = Matcher(dummy_cost_function, dummy_selection_function) + + candidates = [] + switchers = [] + for obj in ALL_TRACKED_OBJECTS: + candidates.append(obj) + switchers.append(obj) + + matches = matcher.match(candidates, switchers) + + assert len(matches) == 2 + for match in matches: + for candidate, switcher in match.items(): + assert candidate.object_id % 2 == switcher.object_id % 2 diff --git a/tests/unit_tests/test_metadata.py b/tests/unit_tests/test_metadata.py new file mode 100644 index 0000000..7b195ac --- /dev/null +++ b/tests/unit_tests/test_metadata.py @@ -0,0 +1,52 @@ +import json +from pathlib import Path + +from trackreid.tracked_object_metadata import TrackedObjectMetaData + +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_METADATA = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_METADATA.append(TrackedObjectMetaData.from_dict(json.load(file)["metadata"])) + + +def test_tracked_metadata_copy(): + tracked_metadata = ALL_TRACKED_METADATA[0].copy() + copied_metadata = tracked_metadata.copy() + assert copied_metadata.first_frame_id == 15 + assert copied_metadata.last_frame_id == 251 + assert copied_metadata.class_counts == {0: 175, 1: 0} + assert copied_metadata.bbox == [598, 208, 814, 447] + assert copied_metadata.confidence == 0.610211 + assert copied_metadata.confidence_sum == 111.30582399999996 + assert copied_metadata.observations == 175 + + assert round(copied_metadata.percentage_of_time_seen(251), 2) == 73.84 + class_proportions = copied_metadata.class_proportions() + assert round(class_proportions.get(0), 2) == 1.0 + assert round(class_proportions.get(1), 2) == 0.0 + + tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy() + tracked_metadata.merge(tracked_metadata_2) + # test impact of merge inplace in a copy, should be none + + assert copied_metadata.class_counts == {0: 175, 1: 0} + assert copied_metadata.bbox == [598, 208, 814, 447] + assert copied_metadata.confidence == 0.610211 + assert copied_metadata.confidence_sum == 111.30582399999996 + assert copied_metadata.observations == 175 + + +def test_tracked_metadata_merge(): + tracked_metadata_1 = ALL_TRACKED_METADATA[0].copy() + tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy() + tracked_metadata_1.merge(tracked_metadata_2) + assert tracked_metadata_1.last_frame_id == 251 + assert tracked_metadata_1.class_counts.get(0) == 175 + assert tracked_metadata_1.class_counts.get(1) == 216 + assert tracked_metadata_1.bbox == [548, 455, 846, 645] + assert tracked_metadata_1.confidence == 0.700626 + assert tracked_metadata_1.confidence_sum == 260.988185 + assert tracked_metadata_1.observations == 391 diff --git a/tests/unit_tests/test_placeholder.py b/tests/unit_tests/test_placeholder.py deleted file mode 100644 index 338a8e0..0000000 --- a/tests/unit_tests/test_placeholder.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Placeholder test file for unit tests. To be replaced with actual tests.""" - - -def test_placeholder() -> None: - """To be replaced with actual tests.""" - pass diff --git a/tests/unit_tests/test_tracked_objects.py b/tests/unit_tests/test_tracked_objects.py new file mode 100644 index 0000000..7c70952 --- /dev/null +++ b/tests/unit_tests/test_tracked_objects.py @@ -0,0 +1,96 @@ +import json +from pathlib import Path + +from trackreid.tracked_object import TrackedObject + +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_OBJECTS = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_OBJECTS.append(TrackedObject.from_dict(json.load(file))) + + +def test_tracked_object_copy(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + copied_object = tracked_object.copy() + assert copied_object.object_id == tracked_object.object_id + assert copied_object.state == tracked_object.state + assert copied_object.category == tracked_object.category + assert round(copied_object.confidence, 2) == round(tracked_object.confidence, 2) + assert round(copied_object.mean_confidence, 2) == round(tracked_object.mean_confidence, 2) + assert copied_object.bbox == tracked_object.bbox + assert copied_object.nb_ids == tracked_object.nb_ids + assert copied_object.nb_corrections == tracked_object.nb_corrections + + tracked_object_2 = ALL_TRACKED_OBJECTS[1].copy() + tracked_object.merge(tracked_object_2) + + assert round(copied_object.confidence, 2) != round(tracked_object.confidence, 2) + assert round(copied_object.mean_confidence, 2) != round(tracked_object.mean_confidence, 2) + assert copied_object.bbox != tracked_object.bbox + assert copied_object.nb_ids != tracked_object.nb_ids + assert copied_object.nb_corrections != tracked_object.nb_corrections + + +def test_tracked_object_properties(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.object_id == 1.0 + assert tracked_object.state == 0 + assert tracked_object.category == 0 + assert round(tracked_object.confidence, 2) == 0.61 + assert round(tracked_object.mean_confidence, 2) == 0.64 + assert tracked_object.bbox == [598, 208, 814, 447] + assert tracked_object.nb_ids == 5 + assert tracked_object.nb_corrections == 4 + + +def test_tracked_object_merge(): + tracked_object_1 = ALL_TRACKED_OBJECTS[0].copy() + tracked_object_2 = ALL_TRACKED_OBJECTS[1].copy() + tracked_object_1.merge(tracked_object_2) + assert tracked_object_1.object_id == 1.0 + assert tracked_object_1.state == 0 + assert tracked_object_1.category == 1 + assert round(tracked_object_1.confidence, 2) == 0.70 + assert round(tracked_object_1.mean_confidence, 2) == 0.67 + assert tracked_object_1.bbox == [548, 455, 846, 645] + assert tracked_object_1.nb_ids == 7 + assert tracked_object_1.nb_corrections == 6 + + +def test_tracked_object_cut(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + new_object, cut_object = tracked_object.cut(2.0) + assert new_object.object_id == 14.0 + assert new_object.state == 0 + assert new_object.category == 0 + assert round(new_object.confidence, 2) == 0.61 + assert round(new_object.mean_confidence, 2) == 0.64 + assert new_object.bbox == [598, 208, 814, 447] + assert new_object.nb_ids == 3 + assert new_object.nb_corrections == 2 + assert cut_object.object_id == 1.0 + assert cut_object.state == 0 + assert cut_object.category == 0 + assert round(cut_object.confidence, 2) == 0.61 + assert round(cut_object.mean_confidence, 2) == 0.64 + assert cut_object.bbox == [598, 208, 814, 447] + assert cut_object.nb_ids == 2 + assert cut_object.nb_corrections == 1 + + +def test_get_age(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.get_age(100) == 85 + + +def test_get_nb_frames_since_last_appearance(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.get_nb_frames_since_last_appearance(300) == 49 + + +def test_get_state(): + tracked_object = ALL_TRACKED_OBJECTS[0].copy() + assert tracked_object.get_state() == 0 diff --git a/tests/unit_tests/test_utils.py b/tests/unit_tests/test_utils.py new file mode 100644 index 0000000..c96bb56 --- /dev/null +++ b/tests/unit_tests/test_utils.py @@ -0,0 +1,92 @@ +import json +from pathlib import Path + +import numpy as np +from llist import sllist + +from trackreid import utils +from trackreid.configs.output_data_positions import OutputDataPositions +from trackreid.tracked_object import TrackedObject + +# Load tracked object data +INPUT_FOLDER = Path("tests/data/unit_tests/tracked_objects") +LIST_TRACKED_OBJECTS = ["object_1.json", "object_4.json", "object_24.json"] + +ALL_TRACKED_OBJECTS = [] +for tracked_object in LIST_TRACKED_OBJECTS: + with Path.open(INPUT_FOLDER / tracked_object) as file: + ALL_TRACKED_OBJECTS.append(TrackedObject.from_dict(json.load(file))) + + +# Define tests +def test_get_top_list_correction(): + top_list_correction = utils.get_top_list_correction(ALL_TRACKED_OBJECTS) + assert top_list_correction == [21.0, 13.0, 24.0] + + +def test_split_list_around_value_1(): + my_list = sllist([1, 2, 3, 4, 5]) + value_to_split = 3 + before, after = utils.split_list_around_value(my_list, value_to_split) + assert list(before) == [1, 2, 3] + assert list(after) == [4, 5] + + +def test_split_list_around_value_2(): + my_list = sllist([1, 2, 3, 4, 5]) + value_to_split = 1 + before, after = utils.split_list_around_value(my_list, value_to_split) + assert list(before) == [1] + assert list(after) == [2, 3, 4, 5] + + +def test_split_list_around_value_3(): + my_list = sllist([1, 2, 3, 4, 5]) + value_to_split = 4 + before, after = utils.split_list_around_value(my_list, value_to_split) + assert list(before) == [1, 2, 3, 4] + assert list(after) == [5] + + +def test_filter_objects_by_state(): + states = 0 + assert utils.filter_objects_by_state(ALL_TRACKED_OBJECTS, states, exclusion=False) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[1], + ] + + +def test_filter_objects_by_state_2(): + states = -2 + assert utils.filter_objects_by_state(ALL_TRACKED_OBJECTS, states, exclusion=True) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[1], + ] + + +def test_filter_objects_by_category(): + category = 0 + assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=False) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[2], + ] + + +def test_filter_objects_by_category_2(): + category = 1 + assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=True) == [ + ALL_TRACKED_OBJECTS[0], + ALL_TRACKED_OBJECTS[2], + ] + + +def test_reshape_tracker_result(): + tracker_output = np.array([1, 1, 3, 4, 5, 6, 7]) + assert np.array_equal( + utils.reshape_tracker_result(tracker_output), np.array([[1, 1, 3, 4, 5, 6, 7]]) + ) + + +def test_get_nb_output_cols(): + output_positions = OutputDataPositions() + assert utils.get_nb_output_cols(output_positions) == 10 diff --git a/trackreid/args/reid_args.py b/trackreid/args/reid_args.py deleted file mode 100644 index bab95c1..0000000 --- a/trackreid/args/reid_args.py +++ /dev/null @@ -1 +0,0 @@ -POSSIBLE_CLASSES = ["car", "person", "truck", "animal"] diff --git a/trackreid/configs/input_data_positions.py b/trackreid/configs/input_data_positions.py new file mode 100644 index 0000000..ec750ce --- /dev/null +++ b/trackreid/configs/input_data_positions.py @@ -0,0 +1,25 @@ +from pydantic import BaseModel, Field + + +class InputDataPositions(BaseModel): + bbox: list = Field( + [0, 1, 2, 3], + description="List of bounding box coordinate positions in the input (numpy array)." + + "Coordinates are in the format x,y,w,h by default.", + ) + object_id: int = Field( + 4, + description="Position of the ID assigned by the tracker to each item in the input (numpy array)", + ) + category: int = Field( + 5, + description="Position of the category assigned to each detected object in the input (numpy array)", + ) + confidence: int = Field( + 6, + description="Position of the confidence score (range [0, 1]) for each" + + "detected object in the input (numpy array)", + ) + + +input_data_positions = InputDataPositions() diff --git a/trackreid/configs/output_data_positions.py b/trackreid/configs/output_data_positions.py new file mode 100644 index 0000000..e010f5c --- /dev/null +++ b/trackreid/configs/output_data_positions.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class OutputDataPositions(BaseModel): + frame_id: int = Field(0, description="Position of the frame id in the output (numpy array)") + object_id: int = Field( + 1, + description="Position of the ID assigned by the reid processor to each item in the output (numpy array)", + ) + category: int = Field( + 2, + description="Position of the category assigned to each detected object in the output (numpy array)", + ) + bbox: list = Field( + [3, 4, 5, 6], + description="List of bounding box coordinate positions in the output (numpy array)." + + "Coordinates are in the format x,y,w,h by default.", + ) + confidence: int = Field( + 7, + description="Position of the confidence score (range [0, 1]) for each" + + " detected object in the output (numpy array)", + ) + mean_confidence: int = Field( + 8, + description="Position of the mean confidence score over object life time (range [0, 1]) for each" + + " tracked object in the output (numpy array)", + ) + tracker_id: int = Field( + 9, + description="Position of the id assigned to the tracker to each object (prior re-identification).", + ) + + +output_data_positions = OutputDataPositions() diff --git a/trackreid/constants/reid_constants.py b/trackreid/configs/reid_constants.py similarity index 51% rename from trackreid/constants/reid_constants.py rename to trackreid/configs/reid_constants.py index 950bf8b..9606e06 100644 --- a/trackreid/constants/reid_constants.py +++ b/trackreid/configs/reid_constants.py @@ -3,20 +3,31 @@ from pydantic import BaseModel -class ReidConstants(BaseModel): - BYETRACK_OUTPUT: int = -2 +class States(BaseModel): + LOST_FOREVER: int = -3 + TRACKER_OUTPUT: int = -2 FILTERED_OUTPUT: int = -1 STABLE: int = 0 SWITCHER: int = 1 CANDIDATE: int = 2 DESCRIPTION: ClassVar[dict] = { - BYETRACK_OUTPUT: "bytetrack output not in reid process", - FILTERED_OUTPUT: "bytetrack output entering reid process", + LOST_FOREVER: "switcher never rematched", + TRACKER_OUTPUT: "tracker output not in reid process", + FILTERED_OUTPUT: "tracker output entering reid process", STABLE: "stable object", SWITCHER: "lost object to be re-matched", CANDIDATE: "new object to be matched", } +class Matches(BaseModel): + DISALLOWED_MATCH: int = 1e6 + + +class ReidConstants(BaseModel): + STATES: States = States() + MATCHES: Matches = Matches() + + reid_constants = ReidConstants() diff --git a/trackreid/cost_functions/__init__.py b/trackreid/cost_functions/__init__.py new file mode 100644 index 0000000..8c0f3bd --- /dev/null +++ b/trackreid/cost_functions/__init__.py @@ -0,0 +1 @@ +from .bounding_box_distance import bounding_box_distance # noqa: F401 diff --git a/trackreid/cost_functions/bounding_box_distance.py b/trackreid/cost_functions/bounding_box_distance.py new file mode 100644 index 0000000..8deff92 --- /dev/null +++ b/trackreid/cost_functions/bounding_box_distance.py @@ -0,0 +1,28 @@ +import numpy as np + +from trackreid.tracked_object import TrackedObject + + +def bounding_box_distance(candidate: TrackedObject, switcher: TrackedObject) -> float: + """ + Calculates the Euclidean distance between the centers of the bounding boxes of two TrackedObjects. + This distance is used as a measure of dissimilarity between the two objects, with a smaller distance + indicating a higher likelihood of the objects being the same. + + Args: + candidate (TrackedObject): The first TrackedObject. + switcher (TrackedObject): The second TrackedObject. + + Returns: + float: The Euclidean distance between the centers of the bounding boxes of the two TrackedObjects. + """ + # Get the bounding boxes from the Metadata of each TrackedObject + bbox1 = candidate.metadata.bbox + bbox2 = switcher.metadata.bbox + + # Calculate the Euclidean distance between the centers of the bounding boxes + center1 = ((bbox1[0] + bbox1[2]) / 2, (bbox1[1] + bbox1[3]) / 2) + center2 = ((bbox2[0] + bbox2[2]) / 2, (bbox2[1] + bbox2[3]) / 2) + distance = np.sqrt((center1[0] - center2[0]) ** 2 + (center1[1] - center2[1]) ** 2) + + return distance diff --git a/trackreid/matcher.py b/trackreid/matcher.py index 18702a4..c57a9d9 100644 --- a/trackreid/matcher.py +++ b/trackreid/matcher.py @@ -1,90 +1,141 @@ -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Optional, Union +import lap import numpy as np -from scipy.optimize import linear_sum_assignment -from track_reid.tracked_object import TrackedObject + +from trackreid.configs.reid_constants import reid_constants +from trackreid.tracked_object import TrackedObject class Matcher: - def __init__(self, cost_function: Callable, selection_function: Callable) -> None: + def __init__( + self, + cost_function: Callable, + selection_function: Callable, + cost_function_threshold: Optional[Union[int, float]] = None, + ) -> None: + """ + Initializes the Matcher object with the provided cost function, selection function, and cost function threshold. + + Args: + cost_function (Callable): A function that calculates the cost of matching two objects. This function should take two TrackedObject instances as input and return a numerical value representing the cost of matching these two objects. A lower cost indicates a higher likelihood of a match. + selection_function (Callable): A function that determines whether two objects should be considered for matching. This function should take two TrackedObject instances as input and return a binary value (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 indicates that the pair should not be considered. + cost_function_threshold (Optional[Union[int, float]]): An optional threshold value for the cost function. If provided, any pair of objects with a matching cost greater than this threshold will not be considered for matching. If not provided, all selected pairs will be considered regardless of their matching cost. + + Returns: + None + """ # noqa: E501 self.cost_function = cost_function self.selection_function = selection_function + self.cost_function_threshold = cost_function_threshold def compute_cost_matrix( - self, objects1: List[TrackedObject], objects2: List[TrackedObject] + self, candidates: List[TrackedObject], switchers: List[TrackedObject] ) -> np.ndarray: - """Computes a cost matrix of size [M, N] between a list of M TrackedObjects objects1, - and a list of N TrackedObjects objects2. + """Computes a cost matrix of size [M, N] between a list of M TrackedObjects candidates, + and a list of N TrackedObjects switchers. Args: - objects1 (List[TrackedObject]): list of objects to be matched. - objects2 (List[TrackedObject]): list of candidates for matches. + candidates (List[TrackedObject]): list of candidates for matches. + switchers (List[TrackedObject]): list of objects to be matched. Returns: np.ndarray: cost to match each pair of objects. """ - if not objects1 or not objects2: + if not candidates or not switchers: return np.array([]) # Return an empty array if either list is empty - # Create matrices with all combinations of objects1 and objects2 - objects1_matrix, objects2_matrix = np.meshgrid(objects1, objects2) + # Create matrices with all combinations of candidates and switchers + candidates_matrix, switchers_matrix = np.meshgrid(candidates, switchers) # Use np.vectorize to apply the scoring function to all combinations - cost_matrix = np.vectorize(self.cost_function)(objects1_matrix, objects2_matrix) + cost_matrix = np.vectorize(self.cost_function)(candidates_matrix, switchers_matrix) return cost_matrix def compute_selection_matrix( - self, objects1: List[TrackedObject], objects2: List[TrackedObject] + self, candidates: List[TrackedObject], switchers: List[TrackedObject] ) -> np.ndarray: - """Computes a selection matrix of size [M, N] between a list of M TrackedObjects objects1, - and a list of N TrackedObjects objects2. + """Computes a selection matrix of size [M, N] between a list of M TrackedObjects candidates, + and a list of N TrackedObjects switchers. Args: - objects1 (List[TrackedObject]): list of objects to be matched. - objects2 (List[TrackedObject]): list of candidates for matches. + candidates (List[TrackedObject]): list of candidates for matches. + switchers (List[TrackedObject]): list of objects to be rematched. Returns: np.ndarray: cost each pair of objects be matched or not ? """ - if not objects1 or not objects2: + if not candidates or not switchers: return np.array([]) # Return an empty array if either list is empty - # Create matrices with all combinations of objects1 and objects2 - objects1_matrix, objects2_matrix = np.meshgrid(objects1, objects2) + # Create matrices with all combinations of candidates and switchers + candidates_matrix, switchers_matrix = np.meshgrid(candidates, switchers) # Use np.vectorize to apply the scoring function to all combinations - selection_matrix = np.vectorize(self.selection_function)(objects1_matrix, objects2_matrix) + selection_matrix = np.vectorize(self.selection_function)( + candidates_matrix, switchers_matrix + ) return selection_matrix def match( - self, objects1: List[TrackedObject], objects2: List[TrackedObject] + self, candidates: List[TrackedObject], switchers: List[TrackedObject] ) -> List[Dict[TrackedObject, TrackedObject]]: - """Computes a dict of matching between objects in list objects1 and objects in objects2. + """Computes a dict of matching between objects in list candidates and objects in switchers. Args: - objects1 (List[TrackedObject]): list of objects to be matched. - objects2 (List[TrackedObject]): list of candidates for matches. + candidates (List[TrackedObject]): list of candidates for matches. + switchers (List[TrackedObject]): list of objects to be matched. Returns: List[Dict[TrackedObject, TrackedObject]]: list of pairs of TrackedObjects if there is a match. """ - if not objects1 or not objects2: + if not candidates or not switchers: return [] # Return an empty array if either list is empty - cost_matrix = self.compute_cost_matrix(objects1, objects2) - selection_matrix = self.compute_selection_matrix(objects1, objects2) + cost_matrix = self.compute_cost_matrix(candidates, switchers) + selection_matrix = self.compute_selection_matrix(candidates, switchers) + + # Set a elements values to be discard at DISALLOWED_MATCH value, large cost + cost_matrix[selection_matrix == 0] = reid_constants.MATCHES.DISALLOWED_MATCH + if self.cost_function_threshold is not None: + cost_matrix[ + cost_matrix > self.cost_function_threshold + ] = reid_constants.MATCHES.DISALLOWED_MATCH - # Set a large cost value for elements to be discarded - cost_matrix[selection_matrix == 0] = 1e3 + matches = self.linear_assigment(cost_matrix, candidates=candidates, switchers=switchers) + + return matches - # Find the best matches using the linear sum assignment - row_indices, col_indices = linear_sum_assignment(cost_matrix, maximize=False) + @staticmethod + def linear_assigment( + cost_matrix: np.ndarray, candidates: List[TrackedObject], switchers: List[TrackedObject] + ) -> List[Dict[TrackedObject, TrackedObject]]: + """ + Performs linear assignment on the cost matrix to find the optimal match between candidates and switchers. + + The function uses the Jonker-Volgenant algorithm to solve the linear assignment problem. The algorithm finds the + optimal assignment (minimum total cost) for the given cost matrix. The cost matrix is a 2D numpy array where + each cell represents the cost of assigning a candidate to a switcher. + + Args: + cost_matrix (np.ndarray): A 2D array representing the cost of assigning each candidate to each switcher. + candidates (List[TrackedObject]): A list of candidate TrackedObjects for matching. + switchers (List[TrackedObject]): A list of switcher TrackedObjects to be matched. + + Returns: + List[Dict[TrackedObject, TrackedObject]]: A list of dictionaries where each dictionary represents a match. + The key is a candidate and the value is the corresponding switcher. + """ + _, _, row_cols = lap.lapjv( + cost_matrix, extend_cost=True, cost_limit=reid_constants.MATCHES.DISALLOWED_MATCH - 0.1 + ) matches = [] - for row, col in zip(row_indices, col_indices): - matches.append({objects1[col]: objects2[row]}) + for candidate_idx, switcher_idx in enumerate(row_cols): + if switcher_idx >= 0: + matches.append({candidates[candidate_idx]: switchers[switcher_idx]}) return matches diff --git a/trackreid/reid_processor.py b/trackreid/reid_processor.py index 57d18a3..027f07e 100644 --- a/trackreid/reid_processor.py +++ b/trackreid/reid_processor.py @@ -1,26 +1,92 @@ from __future__ import annotations -from typing import Dict, List, Set +from typing import Callable, Dict, List, Optional, Set, Union import numpy as np -from track_reid.constants.reid_constants import reid_constants -from track_reid.matcher import Matcher -from track_reid.tracked_object import TrackedObject -from track_reid.tracked_object_filter import TrackedObjectFilter -from track_reid.utils import filter_objects_by_state, get_top_list_correction + +from trackreid.configs.input_data_positions import input_data_positions +from trackreid.configs.output_data_positions import output_data_positions +from trackreid.configs.reid_constants import reid_constants +from trackreid.cost_functions import bounding_box_distance +from trackreid.matcher import Matcher +from trackreid.selection_functions import select_by_category +from trackreid.tracked_object import TrackedObject +from trackreid.tracked_object_filter import TrackedObjectFilter +from trackreid.utils import ( + filter_objects_by_state, + get_nb_output_cols, + get_top_list_correction, + reshape_tracker_result, +) class ReidProcessor: + """ + The ReidProcessor class is designed to correct the results of tracking algorithms by reconciling and reassigning + lost or misidentified IDs. This ensures a consistent and accurate tracking of objects over time. + + All input data should be of numeric type, either integers or floats. + Here's an example of how the input data should look like based on the schema: + + | bbox (0-3) | object_id (4) | category (5) | confidence (6) | + |-----------------|---------------|--------------|----------------| + | 50, 60, 120, 80 | 1 | 1 | 0.91 | + | 50, 60, 120, 80 | 2 | 0 | 0.54 | + + Each row represents a detected object. The first four columns represent the bounding box coordinates + (x, y, width, height), the fifth column represents the object ID assigned by the tracker, + the sixth column represents the category of the detected object, and the seventh column represents + the confidence score of the detection. + + You can use ReidProcessor.print_input_data_requirements() for more insight. + + Here's an example of how the output data looks like based on the schema: + + | frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | + |--------------|---------------|--------------|-----------------|----------------|---------------------|----------------| + | 1 | 1 | 1 | 50, 60, 120, 80 | 0.91 | 0.85 | 1 | + | 2 | 2 | 0 | 50, 60, 120, 80 | 0.54 | 0.60 | 2 | + + You can use ReidProcessor.print_output_data_format_information() for more insight. + + + Args: + filter_confidence_threshold (float): Confidence threshold for the filter. The filter will only consider tracked objects that have a mean confidence score during the all transaction above this threshold. + + filter_time_threshold (int): Time threshold for the filter. The filter will only consider tracked objects that have been seen for a number of frames above this threshold. + + max_frames_to_rematch (int): Maximum number of frames to rematch. If a switcher is lost for a number of frames greater than this value, it will be flagged as lost forever. + + max_attempt_to_match (int): Maximum number of attempts to match a candidate. If a candidate has not been rematched despite a number of attempts equal to this value, it will be flagged as a stable object. + + selection_function (Callable): A function that determines whether two objects should be considered for matching. The selection function should take two TrackedObject instances as input and return a binary value (0 or 1). A return value of 1 indicates that the pair should be considered for matching, while a return value of 0 indicates that the pair should not be considered. + + cost_function (Callable): A function that calculates the cost of matching two objects. The cost function should take two TrackedObject instances as input and return a numerical value representing the cost of matching these two objects. A lower cost indicates a higher likelihood of a match. + + cost_function_threshold (Optional[Union[int, float]]): An maximal threshold value for the cost function. If provided, any pair of objects with a matching cost greater than this threshold will not be considered for matching. If not provided, all selected pairs will be considered regardless of their matching cost.\n + + save_to_txt (bool): A flag indicating whether to save the results to a text file. If set to True, the results will be saved to a text file specified by the file_path parameter. + + file_path (str): The path to the text file where the results will be saved if save_to_txt is set to True. + """ # noqa: E501 + def __init__( self, - filter_confidence_threshold, - filter_time_threshold, - cost_function, - selection_function, - max_frames_to_rematch: int = 100, - max_attempt_to_rematch: int = 1, + filter_confidence_threshold: float, + filter_time_threshold: int, + max_frames_to_rematch: int, + max_attempt_to_match: int, + selection_function: Callable = select_by_category, + cost_function: Callable = bounding_box_distance, + cost_function_threshold: Optional[Union[int, float]] = None, + save_to_txt: bool = False, + file_path: str = "tracks.txt", ) -> None: - self.matcher = Matcher(cost_function=cost_function, selection_function=selection_function) + self.matcher = Matcher( + cost_function=cost_function, + selection_function=selection_function, + cost_function_threshold=cost_function_threshold, + ) self.tracked_filter = TrackedObjectFilter( confidence_threshold=filter_confidence_threshold, @@ -28,205 +94,589 @@ def __init__( ) self.all_tracked_objects: List[TrackedObject] = [] - self.switchers: List[TrackedObject] = [] - self.candidates: List[TrackedObject] = [] - - self.last_tracker_ids: Set[int] = set() + self.last_frame_tracked_objects: Set[TrackedObject] = set() self.max_frames_to_rematch = max_frames_to_rematch - self.max_attempt_to_rematch = max_attempt_to_rematch + self.max_attempt_to_match = max_attempt_to_match self.frame_id = 0 + self.nb_output_cols = get_nb_output_cols(output_positions=output_data_positions) + + self.save_to_txt = save_to_txt + self.file_path = file_path + + def set_file_path(self, new_file_path: str) -> None: + """ + Sets a new file path for saving txt data. + + Args: + new_file_path (str): The new file path. + """ + self.file_path = new_file_path + + @property + def nb_corrections(self) -> int: + """ + Calculates and returns the total number of corrections made across all tracked objects. + + Returns: + int: Total number of corrections. + """ + nb_corrections = 0 + for obj in self.all_tracked_objects: + nb_corrections += obj.nb_corrections + return nb_corrections + + @property + def nb_tracker_ids(self) -> int: + """ + Calculates and returns the total number of tracker IDs across all tracked objects. + + Returns: + int: Total number of tracker IDs. + """ + tracker_ids = 0 + for obj in self.all_tracked_objects: + tracker_ids += obj.nb_ids + return tracker_ids + + @property + def corrected_objects(self) -> List["TrackedObject"]: + """ + Returns a list of tracked objects that have been corrected. + + Returns: + List[TrackedObject]: List of corrected tracked objects. + """ + return [obj for obj in self.all_tracked_objects if obj.nb_corrections] + + @property + def seen_objects(self) -> List["TrackedObject"]: + """ + Returns a list of tracked objects that have been seen, excluding those in the + states TRACKER_OUTPUT and FILTERED_OUTPUT. + + Returns: + List[TrackedObject]: List of seen tracked objects. + """ + return filter_objects_by_state( + tracked_objects=self.all_tracked_objects, + states=[reid_constants.STATES.TRACKER_OUTPUT, reid_constants.STATES.FILTERED_OUTPUT], + exclusion=True, + ) + + @property + def mean_nb_corrections(self) -> float: + """ + Calculates and returns the mean number of corrections across all tracked objects. + + Returns: + float: Mean number of corrections. + """ + return self.nb_corrections / len(self.all_tracked_objects) + + def update(self, tracker_output: np.ndarray, frame_id: int) -> np.ndarray: + """ + Processes the tracker output and updates internal states. + + All input data should be of numeric type, either integers or floats. + Here's an example of how the input data should look like based on the schema: + + | bbox (0-3) | object_id (4) | category (5) | confidence (6) | + |-----------------|---------------|--------------|----------------| + | 50, 60, 120, 80 | 1 | 1 | 0.91 | + | 50, 60, 120, 80 | 2 | 0 | 0.54 | + + Each row represents a detected object. The first four columns represent the bounding box coordinates + (x, y, width, height), the fifth column represents the object ID assigned by the tracker, + the sixth column represents the category of the detected object, and the seventh column represents + the confidence score of the detection. + + You can use ReidProcessor.print_input_data_requirements() for more insight. + + Here's an example of how the output data looks like based on the schema: + + | frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | + |--------------|---------------|--------------|-----------------|----------------|---------------------|----------------| + | 1 | 1 | 1 | 50, 60, 120, 80 | 0.91 | 0.85 | 1 | + | 2 | 2 | 0 | 50, 60, 120, 80 | 0.54 | 0.60 | 2 | + + You can use ReidProcessor.print_output_data_format_information() for more insight. + + Args: + tracker_output (np.ndarray): The tracker output. + frame_id (int): The frame id. + + Returns: + np.ndarray: The processed output. + """ # noqa: E501 + if tracker_output.size: # empty tracking + self.all_tracked_objects, current_tracker_ids = self._preprocess( + tracker_output=tracker_output, frame_id=frame_id + ) + self._perform_reid_process(current_tracker_ids=current_tracker_ids) + reid_output = self._postprocess(current_tracker_ids=current_tracker_ids) + + else: + reid_output = tracker_output + + if self.save_to_txt: + self.save_results_to_txt(file_path=self.file_path, reid_output=reid_output) - def update(self, tracker_output: np.ndarray): - reshaped_tracker_output = self._reshape_input(tracker_output) - self._preprocess(tracker_output=reshaped_tracker_output) - self._perform_reid_process(tracker_output=reshaped_tracker_output) - reid_output = self._postprocess(tracker_output=tracker_output) return reid_output - def _preprocess(self, tracker_output: np.ndarray): - self.all_tracked_objects = self._update_tracked_objects(tracker_output=tracker_output) - self.all_tracked_objects = self._apply_filtering() + def _preprocess(self, tracker_output: np.ndarray, frame_id: int) -> List["TrackedObject"]: + """ + Preprocesses the tracker output. - def _update_tracked_objects(self, tracker_output: np.ndarray): - self.frame_id = tracker_output[0, 0] - for object_id, data_line in zip(tracker_output[:, 1], tracker_output): + Args: + tracker_output (np.ndarray): The tracker output. + frame_id (int): The frame id. + + Returns: + List["TrackedObject"]: The preprocessed output. + """ + reshaped_tracker_output = reshape_tracker_result(tracker_output=tracker_output) + current_tracker_ids = list(reshaped_tracker_output[:, input_data_positions.object_id]) + + self.all_tracked_objects = self._update_tracked_objects( + tracker_output=reshaped_tracker_output, frame_id=frame_id + ) + self.all_tracked_objects = self._apply_filtering() + return self.all_tracked_objects, current_tracker_ids + + def _update_tracked_objects( + self, tracker_output: np.ndarray, frame_id: int + ) -> List[TrackedObject]: + """ + Updates the tracked objects. + + Args: + tracker_output (np.ndarray): The tracker output. + frame_id (int): The frame id. + + Returns: + List[TrackedObject]: The updated tracked objects. + """ + self.frame_id = frame_id + for object_id, data_line in zip( + tracker_output[:, input_data_positions.object_id], tracker_output + ): if object_id not in self.all_tracked_objects: new_tracked_object = TrackedObject( - object_ids=object_id, state=reid_constants.BYETRACK_OUTPUT, metadata=data_line + object_ids=object_id, + state=reid_constants.STATES.TRACKER_OUTPUT, + frame_id=frame_id, + metadata=data_line, ) self.all_tracked_objects.append(new_tracked_object) else: self.all_tracked_objects[self.all_tracked_objects.index(object_id)].update_metadata( - data_line + data_line, frame_id=frame_id ) return self.all_tracked_objects - @staticmethod - def _reshape_input(bytetrack_output: np.ndarray): - if bytetrack_output.ndim == 1: - bytetrack_output = np.expand_dims(bytetrack_output, 0) - return bytetrack_output + def _get_current_frame_tracked_objects( + self, current_tracker_ids: Set[Union[int, float]] + ) -> Set[Union[int, float]]: + """ + Retrieves the tracked objects for the current frame. + + Args: + current_tracker_ids (Set[Union[int, float]]): The set of current tracker IDs. - def _apply_filtering(self): + Returns: + Set[Union[int, float]]: The set of tracked objects for the current frame. + """ + tracked_objects = filter_objects_by_state( + self.all_tracked_objects, states=reid_constants.STATES.TRACKER_OUTPUT, exclusion=True + ) + + current_frame_tracked_objects = set( + [tracked_id for tracked_id in tracked_objects if tracked_id in current_tracker_ids] + ) + + return current_frame_tracked_objects + + def _apply_filtering(self) -> List[TrackedObject]: + """ + Applies filtering to the tracked objects. + + Returns: + List[TrackedObject]: The filtered tracked objects. + """ for tracked_object in self.all_tracked_objects: self.tracked_filter.update(tracked_object) return self.all_tracked_objects - def _perform_reid_process(self, tracker_output: np.ndarray): - tracked_ids = filter_objects_by_state( - self.all_tracked_objects, states=reid_constants.BYETRACK_OUTPUT, exclusion=True + def _perform_reid_process(self, current_tracker_ids: List[Union[int, float]]) -> None: + """ + Performs the re-identification process on tracked objects. + + This method is responsible for managing the state of tracked objects and identifying potential + candidates for re-identification. It follows these steps: + + 1. correct_reid_chains: Corrects the re-identification chains of all tracked objects + based on the current tracker IDs. This avoids potential duplicates. + 2. update_switchers_states: Updates the states of switchers (objects that have switched IDs) + based on the current frame's tracked objects, the maximum number of frames to rematch, and the current frame ID. + 3. update_candidates_states: Updates the states of candidate objects (potential matches for re-identification) + based on the maximum number of attempts to match and the current frame ID. + 4. identify_switchers: Identifies switchers based on the current and last frame's tracked objects and + updates the state of all tracked objects accordingly. + 5. identify_candidates: Identifies candidates for re-identification and updates the state of all + tracked objects accordingly. + 6. match: Matches candidates with switchers using Jonker-Volgenant algorithm. + 7. process_matches: Processes the matches and updates the state of all tracked objects accordingly. + + Args: + current_tracker_ids (List[Union[int, float]]): The current tracker IDs. + """ + + self.all_tracked_objects = self.correct_reid_chains( + all_tracked_objects=self.all_tracked_objects, current_tracker_ids=current_tracker_ids ) - current_tracker_ids = set(tracker_output[:, 1]).intersection(set(tracked_ids)) - - self.compute_stable_objects( - current_tracker_ids=current_tracker_ids, tracked_ids=self.all_tracked_objects + current_frame_tracked_objects = self._get_current_frame_tracked_objects( + current_tracker_ids=current_tracker_ids ) - self.switchers = self.drop_switchers( - self.switchers, - current_tracker_ids, + self.all_tracked_objects = self.update_switchers_states( + all_tracked_objects=self.all_tracked_objects, + current_frame_tracked_objects=current_frame_tracked_objects, max_frames_to_rematch=self.max_frames_to_rematch, frame_id=self.frame_id, ) - self.candidates.extend(self.identify_candidates(tracked_ids=tracked_ids)) + self.all_tracked_objects = self.update_candidates_states( + all_tracked_objects=self.all_tracked_objects, + max_attempt_to_match=self.max_attempt_to_match, + frame_id=self.frame_id, + ) + + self.all_tracked_objects = self.identify_switchers( + current_frame_tracked_objects=current_frame_tracked_objects, + last_frame_tracked_objects=self.last_frame_tracked_objects, + all_tracked_objects=self.all_tracked_objects, + ) - self.switchers.extend( - self.identify_switchers( - current_tracker_ids=current_tracker_ids, - last_bytetrack_ids=self.last_tracker_ids, - tracked_ids=tracked_ids, - ) + self.all_tracked_objects = self.identify_candidates( + all_tracked_objects=self.all_tracked_objects ) - matches = self.matcher.match(self.candidates, self.switchers) + candidates = filter_objects_by_state( + self.all_tracked_objects, states=reid_constants.STATES.CANDIDATE, exclusion=False + ) + switchers = filter_objects_by_state( + self.all_tracked_objects, states=reid_constants.STATES.SWITCHER, exclusion=False + ) + + matches = self.matcher.match(candidates, switchers) - self.process_matches( + self.all_tracked_objects = self.process_matches( all_tracked_objects=self.all_tracked_objects, matches=matches, - candidates=self.candidates, - switchers=self.switchers, - current_tracker_ids=current_tracker_ids, ) - self.candidates = self.drop_candidates( - self.candidates, + current_frame_tracked_objects = self._get_current_frame_tracked_objects( + current_tracker_ids=current_tracker_ids ) - self.last_tracker_ids = current_tracker_ids.copy() + self.last_frame_tracked_objects = current_frame_tracked_objects.copy() @staticmethod def identify_switchers( - tracked_ids: List["TrackedObject"], - current_tracker_ids: Set[int], - last_bytetrack_ids: Set[int], - ): - switchers = [] - lost_ids = last_bytetrack_ids - current_tracker_ids - - for tracked_id in tracked_ids: - if tracked_id in lost_ids: - switchers.append(tracked_id) - tracked_id.state = reid_constants.SWITCHER - - return switchers + all_tracked_objects: List["TrackedObject"], + current_frame_tracked_objects: Set["TrackedObject"], + last_frame_tracked_objects: Set["TrackedObject"], + ) -> List["TrackedObject"]: + """ + Identifies switchers in the list of all tracked objects, and + update their states. A switcher is an object that is lost, and probably + needs to be rematched. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + current_frame_tracked_objects (Set["TrackedObject"]): Set of currently tracked objects. + last_frame_tracked_objects Set["TrackedObject"]: Set of last timestep tracked objects. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + lost_objects = last_frame_tracked_objects - current_frame_tracked_objects + + for tracked_object in all_tracked_objects: + if tracked_object in lost_objects: + tracked_object.state = reid_constants.STATES.SWITCHER + + return all_tracked_objects @staticmethod - def identify_candidates(tracked_ids: List["TrackedObject"]): - candidates = [] - for current_object in tracked_ids: - if current_object.state == reid_constants.FILTERED_OUTPUT: - current_object.state = reid_constants.CANDIDATE - candidates.append(current_object) - return candidates + def identify_candidates(all_tracked_objects: List["TrackedObject"]) -> List["TrackedObject"]: + """ + Identifies candidates in the list of all tracked objects, and + update their states. A candidate is an object that was never seen before and + that probably needs to be rematched. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + tracked_objects = filter_objects_by_state( + all_tracked_objects, states=reid_constants.STATES.TRACKER_OUTPUT, exclusion=True + ) + for current_object in tracked_objects: + if current_object.state == reid_constants.STATES.FILTERED_OUTPUT: + current_object.state = reid_constants.STATES.CANDIDATE + return all_tracked_objects @staticmethod - def compute_stable_objects(tracked_ids: list, current_tracker_ids: Set[int]): - top_list_correction = get_top_list_correction(tracked_ids) - - for current_object in current_tracker_ids: - tracked_id = tracked_ids[tracked_ids.index(current_object)] - if current_object not in top_list_correction: - tracked_ids.remove(tracked_id) - new_object, tracked_id = tracked_id.cut(current_object) - - new_object.state = reid_constants.STABLE - tracked_id.state = reid_constants.STABLE - - tracked_ids.append(new_object) - tracked_ids.append(tracked_id) + def correct_reid_chains( + all_tracked_objects: List["TrackedObject"], + current_tracker_ids: List[Union[int, float]], + ) -> List["TrackedObject"]: + """ + Corrects the reid chains to prevent duplicates when an object reappears with a corrected id. + For instance, if an object has a reid chain [1, 3, 6, 7], only the id 7 should be in the tracker's output. + If another id from the chain (e.g., 3) is in the tracker's output, the reid chain is split into two: + [1, 3] and [6, 7]. The first object's state is set to stable as 3 is in the current tracker output, + and a new object with reid chain [6, 7] is created. + The new object's state can be: + - stable, if the tracker output is in the new reid chain + - switcher, if not + - nothing, if this is a singleton object, in which case the reid process is performed automatically. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + current_tracker_ids (List[Union[int, float]]): The current tracker IDs. + + Returns: + List["TrackedObject"]: The corrected tracked objects. + """ + top_list_correction = get_top_list_correction(all_tracked_objects) + to_correct = set(current_tracker_ids) - set(top_list_correction) + + for current_object in to_correct: + tracked_id = all_tracked_objects[all_tracked_objects.index(current_object)] + all_tracked_objects.remove(tracked_id) + new_object, tracked_id = tracked_id.cut(current_object) + + tracked_id.state = reid_constants.STATES.STABLE + all_tracked_objects.append(tracked_id) + + if new_object in current_tracker_ids: + new_object.state = reid_constants.STATES.CANDIDATE + all_tracked_objects.append(new_object) + + elif new_object.nb_corrections > 1: + new_object.state = reid_constants.STATES.SWITCHER + all_tracked_objects.append(new_object) + + return all_tracked_objects @staticmethod def process_matches( all_tracked_objects: List["TrackedObject"], matches: Dict["TrackedObject", "TrackedObject"], - switchers: List["TrackedObject"], - candidates: List["TrackedObject"], - current_tracker_ids: Set[int], - ): + ) -> List["TrackedObject"]: + """ + Processes the matches. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + matches (Dict["TrackedObject", "TrackedObject"]): The matches. + + Returns: + List["TrackedObject"]: The processed tracked objects. + """ for match in matches: candidate_match, switcher_match = match.popitem() - switcher_match.merge(candidate_match) + switcher_match.state = reid_constants.STATES.STABLE all_tracked_objects.remove(candidate_match) - switchers.remove(switcher_match) - candidates.remove(candidate_match) - current_tracker_ids.discard(candidate_match.id) - current_tracker_ids.add(switcher_match.id) + return all_tracked_objects @staticmethod - def drop_switchers( - switchers: List["TrackedObject"], - current_tracker_ids: Set[int], + def update_switchers_states( + all_tracked_objects: List["TrackedObject"], + current_frame_tracked_objects: Set["TrackedObject"], max_frames_to_rematch: int, frame_id: int, - ): - switchers_to_drop = set(switchers).intersection(current_tracker_ids) - filtered_switchers = switchers.copy() + ) -> List["TrackedObject"]: + """ + Updates the state of switchers in the list of all tracked objects: + - If a switcher is lost for too long, it will be flaged as lost forever + - If a switcher reapears in the tracking output, it will be flaged as + a stable object. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + current_frame_tracked_objects (Set["TrackedObject"]): Set of currently tracked objects. + max_frames_to_rematch (int): Maximum number of frames to rematch. + frame_id (int): Current frame id. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + switchers = filter_objects_by_state( + all_tracked_objects, reid_constants.STATES.SWITCHER, exclusion=False + ) + switchers_to_drop = set(switchers).intersection(current_frame_tracked_objects) for switcher in switchers: if switcher in switchers_to_drop: - switcher.state = reid_constants.STABLE - filtered_switchers.remove(switcher) + switcher.state = reid_constants.STATES.STABLE elif switcher.get_nb_frames_since_last_appearance(frame_id) > max_frames_to_rematch: - filtered_switchers.remove(switcher) + switcher.state = reid_constants.STATES.LOST_FOREVER - return filtered_switchers + return all_tracked_objects @staticmethod - def drop_candidates(candidates: List["TrackedObject"]): - # for now drop candidates if there was no match - for candidate in candidates: - candidate.state = reid_constants.STABLE - return [] - - def _postprocess(self, tracker_output: np.ndarray): - filtered_objects = list( - filter( - lambda obj: obj.get_state() == reid_constants.STABLE - and obj in tracker_output[:, 1], - self.all_tracked_objects, - ) + def update_candidates_states( + all_tracked_objects: List["TrackedObject"], max_attempt_to_match: int, frame_id: int + ) -> List["TrackedObject"]: + """ + Updates the state of candidates in the list of all tracked objects. + If a candidate has not been rematched despite max_attempt_to_match attempts, + if will be flaged as a stable object. + + Args: + all_tracked_objects (List["TrackedObject"]): List of all objects being tracked. + max_attempt_to_match (int): Maximum attempt to match a candidate. + frame_id (int): Current frame id. + + Returns: + List["TrackedObject"]: Updated list of all tracked objects after state changes. + """ + candidates = filter_objects_by_state( + tracked_objects=all_tracked_objects, + states=reid_constants.STATES.CANDIDATE, + exclusion=False, ) - reid_output = [] - for object in filtered_objects: - reid_output.append( - [ - self.frame_id, - object.id, - object.category, - object.bbox[0], - object.bbox[1], - object.bbox[2], - object.bbox[3], - object.confidence, - ] - ) + + for candidate in candidates: + if candidate.get_age(frame_id) >= max_attempt_to_match: + candidate.state = reid_constants.STATES.STABLE + return all_tracked_objects + + def _postprocess( + self, + current_tracker_ids: List[Union[int, float]], + ) -> np.ndarray: + """ + Postprocesses the current tracker IDs. + It selects the stable TrackedObjects, and formats their datas in the output + to match requirements. + + Args: + current_tracker_ids (List[Union[int, float]]): The current tracker IDs. + + Returns: + np.ndarray: The postprocessed output. + """ + stable_objects = [ + obj + for obj in self.all_tracked_objects + if obj.get_state() == reid_constants.STATES.STABLE and obj in current_tracker_ids + ] + + reid_output = np.zeros((len(stable_objects), self.nb_output_cols)) + + for idx, stable_object in enumerate(stable_objects): + for required_variable in output_data_positions.model_json_schema()["properties"].keys(): + output = ( + self.frame_id + if required_variable == "frame_id" + else getattr(stable_object, required_variable, None) + ) + if output is None: + raise NameError( + f"Attribute {required_variable} not in TrackedObject. Check your required output names." + ) + reid_output[idx, getattr(output_data_positions, required_variable)] = output return reid_output + + def save_results_to_txt(self, file_path: str, reid_output: np.ndarray) -> None: + """ + Saves the reid_output to a txt file. + + Args: + file_path (str): The path to the txt file. + reid_output (np.ndarray): The output of _post_process. + """ + with open(file_path, "a") as f: # noqa: PTH123 + for row in reid_output: + line = " ".join( + str(int(val)) if val.is_integer() else "{:.6f}".format(val) for val in row + ) + f.write(line + "\n") + + def to_dict(self) -> Dict: + """ + Converts the tracked objects to a dictionary. + + Returns: + Dict: The dictionary representation of the tracked objects. + """ + data = dict() + for tracked_object in self.all_tracked_objects: + data[tracked_object.object_id] = tracked_object.to_dict() + return data + + @staticmethod + def print_input_data_format_requirements(): + """ + + Prints the input data format requirements. + + All input data should be of numeric type, either integers or floats. + Here's an example of how the input data should look like based on the schema: + + | bbox (0-3) | object_id (4) | category (5) | confidence (6) | + |-----------------|---------------|--------------|----------------| + | 50, 60, 120, 80 | 1 | 1 | 0.91 | + | 50, 60, 120, 80 | 2 | 0 | 0.54 | + + Each row represents a detected object. The first four columns represent the bounding box coordinates + (x, y, width, height), the fifth column represents the object ID assigned by the tracker, + the sixth column represents the category of the detected object, and the seventh column represents + the confidence score of the detection. + """ + input_schema = input_data_positions.model_json_schema() + + print("Input Data Format Requirements:") + for name, properties in input_schema["properties"].items(): + print("-" * 50) + print(f"{name}: {properties['description']}") + print( + f"{name} (position of {name} in the input array must be): {properties['default']}" + ) + + @staticmethod + def print_output_data_format_information(): + """ + Prints the output data format information. + + Here's an example of how the output data looks like based on the schema: + + | frame_id (0) | object_id (1) | category (2) | bbox (3-6) | confidence (7) | mean_confidence (8) | tracker_id (9) | + |--------------|---------------|--------------|------------|----------------|-------------------|------------------| + | 1 | 1 | 1 | 50,60,120,80 | 0.91 | 0.85 | 1 | + | 2 | 2 | 0 | 50,60,120,80 | 0.54 | 0.60 | 2 | + + """ # noqa: E501 + output_schema = output_data_positions.model_json_schema() + + print("\nOutput Data Format:") + for name, properties in output_schema["properties"].items(): + print("-" * 50) + print(f"{name}: {properties['description']}") + print( + f"{name} (position of {name} in the output array will be): {properties['default']}" + ) diff --git a/trackreid/selection_functions/__init__.py b/trackreid/selection_functions/__init__.py new file mode 100644 index 0000000..b16c7b2 --- /dev/null +++ b/trackreid/selection_functions/__init__.py @@ -0,0 +1 @@ +from .select_by_category import select_by_category # noqa: F401 diff --git a/trackreid/selection_functions/select_by_category.py b/trackreid/selection_functions/select_by_category.py new file mode 100644 index 0000000..38aec83 --- /dev/null +++ b/trackreid/selection_functions/select_by_category.py @@ -0,0 +1,18 @@ +from trackreid.tracked_object import TrackedObject + + +def select_by_category(candidate: TrackedObject, switcher: TrackedObject) -> int: + """ + Compares the categories of two TrackedObject instances. + This selection function is used as a measure of similarity between the two objects, + matches are discard if this function returns 0. + + Args: + candidate (TrackedObject): The first TrackedObject instance. + switcher (TrackedObject): The second TrackedObject instance. + + Returns: + int: Returns 1 if the categories of the two objects are the same, otherwise returns 0. + """ + # Compare the categories of the two objects + return 1 if candidate.category == switcher.category else 0 diff --git a/trackreid/tracked_object.py b/trackreid/tracked_object.py index b2dd1dc..1bc84b0 100644 --- a/trackreid/tracked_object.py +++ b/trackreid/tracked_object.py @@ -1,102 +1,223 @@ from __future__ import annotations -from typing import Union +import json +from typing import Optional, Union import numpy as np from llist import sllist -from track_reid.constants.reid_constants import reid_constants -from track_reid.tracked_object_metadata import TrackedObjectMetaData -from track_reid.utils import split_list_around_value + +from trackreid.configs.reid_constants import reid_constants +from trackreid.tracked_object_metadata import TrackedObjectMetaData +from trackreid.utils import split_list_around_value class TrackedObject: + """ + The TrackedObject class represents an object that is being tracked in a video frame. + It contains information about the object's state, its unique identifiers, and metadata. + + The object's state is an integer that represents the current state of the object in the + reid process. The states can take the following values: + + - LOST_FOREVER (-3): "Switcher never rematched" + - TRACKER_OUTPUT (-2): "Tracker output not in reid process" + - FILTERED_OUTPUT (-1): "Tracker output entering reid process" + - STABLE (0): "Stable object" + - SWITCHER (1): "Lost object to be re-matched" + - CANDIDATE (2): "New object to be matched" + + The object's unique identifiers are stored in a singly linked list (sllist) called re_id_chain. The re_id_chain + is a crucial component in the codebase. It stores the history of the object's unique identifiers, allowing for + tracking of the object across different frames. The first value in the re_id_chain + is the original object ID, while the last value is the most recent tracker ID assigned to the object. + + The metadata is an instance of the TrackedObjectMetaData class, which contains additional information + about the object. + + The TrackedObject class provides several methods for manipulating and accessing the data it contains. + These include methods for merging two TrackedObject instances, updating the metadata, and converting the + TrackedObject instance to a dictionary or JSON string. + + The TrackedObject class also provides several properties for accessing specific pieces of data, such as the object's + unique identifier, its state, and its metadata. + + Args: + object_ids (Union[Union[float, int], sllist]): The unique identifiers for the object. + state (int): The current state of the object. + metadata (Union[np.ndarray, TrackedObjectMetaData]): The metadata for the object. It can be either a TrackedObjectMetaData object, or a data line, i.e. output of detection model. If metadata is initialized with a TrackedObjectMetaData object, a frame_id must be given. + frame_id (Optional[int], optional): The frame ID where the object was first seen. Defaults to None. + + Raises: + NameError: If the type of object_ids or metadata is unrecognized. + """ # noqa: E501 + def __init__( self, - object_ids: Union[int, sllist], + object_ids: Union[Union[float, int], sllist], state: int, metadata: Union[np.ndarray, TrackedObjectMetaData], + frame_id: Optional[int] = None, ): self.state = state - if isinstance(object_ids, int): + if isinstance(object_ids, Union[float, int]): self.re_id_chain = sllist([object_ids]) elif isinstance(object_ids, sllist): - self.re_id_chain = object_ids + self.re_id_chain = sllist(object_ids) else: raise NameError("unrocognized type for object_ids.") if isinstance(metadata, np.ndarray): - self.metadata = TrackedObjectMetaData(metadata) + assert ( + frame_id is not None + ), "Please provide a frame_id for TrackedObject initialization" + self.metadata = TrackedObjectMetaData(metadata, frame_id) elif isinstance(metadata, TrackedObjectMetaData): self.metadata = metadata.copy() else: raise NameError("unrocognized type for metadata.") - def merge(self, other_object): + def copy(self): + return TrackedObject(object_ids=self.re_id_chain, state=self.state, metadata=self.metadata) + + def merge(self, other_object: TrackedObject): if not isinstance(other_object, TrackedObject): raise TypeError("Can only merge with another TrackedObject.") # Merge the re_id_chains self.re_id_chain.extend(other_object.re_id_chain) - - # Merge the metadata (you should implement a proper merge logic in TrackedObjectMetaData) self.metadata.merge(other_object.metadata) - self.state = reid_constants.STABLE + self.state = other_object.state # Return the merged object return self @property - def id(self): + def object_id(self): + """ + Returns the first value in the re_id_chain which represents the object id. + """ return self.re_id_chain.first.value + @property + def tracker_id(self): + """ + Returns the last value in the re_id_chain which represents the last tracker id. + """ + return self.re_id_chain.last.value + @property def category(self): + """ + Returns the category with the maximum count in the class_counts dictionary of the metadata. + """ return max(self.metadata.class_counts, key=self.metadata.class_counts.get) @property def confidence(self): + """ + Returns the confidence value from the metadata. + """ return self.metadata.confidence @property def mean_confidence(self): + """ + Returns the mean confidence value from the metadata. + """ return self.metadata.mean_confidence() @property def bbox(self): + """ + Returns the bounding box coordinates from the metadata. + """ return self.metadata.bbox - def get_age(self, frame_id): + @property + def nb_ids(self): + """ + Returns the number of ids in the re_id_chain. + """ + return len(self.re_id_chain) + + @property + def nb_corrections(self): + """ + Returns the number of corrections which is the number of ids in the re_id_chain minus one. + """ + return self.nb_ids - 1 + + def get_age(self, frame_id: int): + """ + Calculates and returns the age of the tracked object based on the given frame id. + Age is defined as the difference between the current frame id and the first frame id where + the object was detected. + """ return frame_id - self.metadata.first_frame_id - def get_nb_frames_since_last_appearance(self, frame_id): + def get_nb_frames_since_last_appearance(self, frame_id: int): + """ + Calculates and returns the number of frames since the last appearance of the tracked object. + This is computed as the difference between the current frame id and the last frame id where + the object was detected. + """ return frame_id - self.metadata.last_frame_id def get_state(self): + """ + Returns the current state of the tracked object. + """ return self.state def __hash__(self): - return hash(self.id) + return hash(self.object_id) def __repr__(self): return ( - f"TrackedObject(current_id={self.id}, re_id_chain={list(self.re_id_chain)}" - + f", state={self.state}: {reid_constants.DESCRIPTION[self.state]})" + f"TrackedObject(current_id={self.object_id}, re_id_chain={list(self.re_id_chain)}" + + f", state={self.state}: {reid_constants.STATES.DESCRIPTION[self.state]})" ) def __str__(self): return f"{self.__repr__()}, metadata : {self.metadata}" - def update_metadata(self, data_line: np.ndarray): - self.metadata.update(data_line) + def update_metadata(self, data_line: np.ndarray, frame_id: int): + """ + Updates the metadata of the tracked object based on new detection data. + + This method is used to update the metadata of a tracked object whenever new detection data is available. + It updates the metadata by calling the update method of the TrackedObjectMetaData instance associated with + the tracked object. + + Args: + data_line (np.ndarray): The detection data for a single frame. It contains information such as the class name, bounding box coordinates, and confidence level of the detection. + + frame_id (int): The frame id where the object was detected. This is used to update the last frame id of the tracked object. + """ # noqa: E501 + self.metadata.update(data_line=data_line, frame_id=frame_id) def __eq__(self, other): - if isinstance(other, int): + if isinstance(other, Union[float, int]): return other in self.re_id_chain elif isinstance(other, TrackedObject): return self.re_id_chain == other.re_id_chain return False def cut(self, object_id: int): + """ + Splits the re_id_chain of the tracked object at the specified object_id and creates a new TrackedObject + instance with the remaining part of the re_id_chain. The original TrackedObject instance retains the part + of the re_id_chain before the specified object_id. + + Args: + object_id (int): The object_id at which to split the re_id_chain. + + Raises: + NameError: If the specified object_id is not found in the re_id_chain of the tracked object. + + Returns: + tuple: A tuple containing the new TrackedObject instance and the original TrackedObject instance. + """ if object_id not in self.re_id_chain: raise NameError( f"Trying to cut object {self} with {object_id} that is not in the re-id chain." @@ -106,17 +227,63 @@ def cut(self, object_id: int): self.re_id_chain = before new_object = TrackedObject( - state=reid_constants.STABLE, object_ids=after, metadata=self.metadata + state=reid_constants.STATES.STABLE, object_ids=after, metadata=self.metadata ) + # set potential age 0 for new object + new_object.metadata.first_frame_id = new_object.metadata.last_frame_id return new_object, self - def format_data(self): - return [ - self.id, - self.category, - self.bbox[0], - self.bbox[1], - self.bbox[2], - self.bbox[3], - self.confidence, - ] + def to_dict(self): + """ + Converts the TrackedObject instance to a dictionary. + + Returns: + dict: A dictionary representation of the TrackedObject instance. + """ + data = { + "object_id": float(self.object_id), + "state": int(self.state), + "re_id_chain": list(self.re_id_chain), + "metadata": self.metadata.to_dict(), + } + return data + + def to_json(self): + """ + Converts the TrackedObject instance to a JSON string. + + Returns: + str: A JSON string representation of the TrackedObject instance. + """ + return json.dumps(self.to_dict(), indent=4) + + @classmethod + def from_dict(cls, data: dict): + """ + Creates a new TrackedObject instance from a dictionary. + + Args: + data (dict): A dictionary containing the data for the TrackedObject instance. + + Returns: + TrackedObject: A new TrackedObject instance created from the dictionary. + """ + obj = cls.__new__(cls) + obj.state = data["state"] + obj.re_id_chain = sllist(data["re_id_chain"]) + obj.metadata = TrackedObjectMetaData.from_dict(data["metadata"]) + return obj + + @classmethod + def from_json(cls, json_str: str): + """ + Creates a new TrackedObject instance from a JSON string. + + Args: + json_str (str): A JSON string containing the data for the TrackedObject instance. + + Returns: + TrackedObject: A new TrackedObject instance created from the JSON string. + """ + data = json.loads(json_str) + return cls.from_dict(data) diff --git a/trackreid/tracked_object_filter.py b/trackreid/tracked_object_filter.py index bc936b0..cf9b016 100644 --- a/trackreid/tracked_object_filter.py +++ b/trackreid/tracked_object_filter.py @@ -1,18 +1,42 @@ -from track_reid.constants.reid_constants import reid_constants +from trackreid.configs.reid_constants import reid_constants +from trackreid.tracked_object import TrackedObject class TrackedObjectFilter: - def __init__(self, confidence_threshold, frames_seen_threshold): + """ + The TrackedObjectFilter class is used to filter tracked objects based on their + confidence and the number of frames they have been observed in. + + Args: + confidence_threshold (float): The minimum mean confidence level required for a tracked object to be considered valid. + frames_seen_threshold (int): The minimum number of frames a tracked object must be observed in to be considered valid. + """ # noqa: E501 + + def __init__(self, confidence_threshold: float, frames_seen_threshold: int): self.confidence_threshold = confidence_threshold self.frames_seen_threshold = frames_seen_threshold - def update(self, tracked_object): - if tracked_object.get_state() == reid_constants.BYETRACK_OUTPUT: + def update(self, tracked_object: TrackedObject): + """ + The update method is used to update the state of a tracked object based on its confidence + and the number of frames it has been observed in. + + If the tracked object's state is TRACKER_OUTPUT, and its mean confidence is greater than the + confidence_threshold, and it has been observed in more frames than the frames_seen_threshold, + its state is updated to FILTERED_OUTPUT. + + If the tracked object's mean confidence is less than the confidence_threshold, its state is + updated to TRACKER_OUTPUT. + + Args: + tracked_object (TrackedObject): The tracked object to update. + """ + if tracked_object.get_state() == reid_constants.STATES.TRACKER_OUTPUT: if ( tracked_object.metadata.mean_confidence() > self.confidence_threshold and tracked_object.metadata.observations >= self.frames_seen_threshold ): - tracked_object.state = reid_constants.FILTERED_OUTPUT + tracked_object.state = reid_constants.STATES.FILTERED_OUTPUT elif tracked_object.metadata.mean_confidence() < self.confidence_threshold: - tracked_object.state = reid_constants.BYETRACK_OUTPUT + tracked_object.state = reid_constants.STATES.TRACKER_OUTPUT diff --git a/trackreid/tracked_object_metadata.py b/trackreid/tracked_object_metadata.py index dc255f1..874ec4b 100644 --- a/trackreid/tracked_object_metadata.py +++ b/trackreid/tracked_object_metadata.py @@ -1,30 +1,81 @@ import json -from pathlib import Path -from track_reid.args.reid_args import POSSIBLE_CLASSES +import numpy as np + +from trackreid.configs.input_data_positions import input_data_positions class TrackedObjectMetaData: - def __init__(self, data_line): - self.first_frame_id = int(data_line[0]) - self.class_counts = {class_name: 0 for class_name in POSSIBLE_CLASSES} + """ + The TrackedObjectMetaData class is used to store and manage metadata for tracked objects in a video frame. + This metadata includes information such as the frame ID where the object was first seen, the class counts + (how many times each class was detected), the bounding box coordinates, and the confidence level of the detection. + + This metadata is then use in selection and cost functions to compute likelihood of a match between two objects. + + Usage: + An instance of TrackedObjectMetaData is created by passing a data_line (which contains the detection data + for a single frame) and a frame_id (which identifies the frame where the object was detected). + """ + + def __init__(self, data_line: np.ndarray, frame_id: int): + self.first_frame_id = frame_id + self.class_counts = {} self.observations = 0 self.confidence_sum = 0 self.confidence = 0 - self.update(data_line) + self.update(data_line, frame_id) + + def update(self, data_line: np.ndarray, frame_id: int): + """ + Updates the metadata of a tracked object based on new detection data. + + This method is used to update the metadata of a tracked object whenever new detection data is available. + It updates the last frame id, class counts, bounding box, confidence, confidence sum, and observations: + - last_frame_id: Updated to the frame id where the object was detected + - class_counts: Incremented by 1 for the detected class + - bbox: Updated to the bounding box coordinates from the detection data + - confidence: Updated to the confidence level from the detection data + - confidence_sum: Incremented by the confidence level from the detection data + - observations: Incremented by 1 - def update(self, data_line): - self.last_frame_id = int(data_line[0]) - class_name = data_line[2] + Args: + data_line (np.ndarra): The detection data for a single frame. It contains information such as the class name, bounding box coordinates, and confidence level of the detection. + + frame_id (int): The frame id where the object was detected. This is used to update the last frame id of the tracked object. + + """ # noqa: E501 + self.last_frame_id = frame_id + + class_name = int(data_line[input_data_positions.category]) self.class_counts[class_name] = self.class_counts.get(class_name, 0) + 1 - self.bbox = list(map(int, data_line[3:7])) - confidence = float(data_line[7]) + self.bbox = list(map(int, data_line[input_data_positions.bbox])) + confidence = float(data_line[input_data_positions.confidence]) self.confidence = confidence self.confidence_sum += confidence self.observations += 1 def merge(self, other_object): - if not isinstance(other_object, TrackedObjectMetaData): + """ + Merges the metadata of another TrackedObjectMetaData instance into the current one. + Updates the current instance with the data from the other TrackedObjectMetaData instance. + + The following properties are updated: + - observations: Incremented by the observations of the other object. + - confidence_sum: Incremented by the confidence sum of the other object. + - confidence: Set to the confidence of the other object. + - bbox: Set to the bounding box of the other object. + - last_frame_id: Set to the last frame id of the other object. + - class_counts: For each class, the count is incremented by the count of the other object. + + Args: + other_object (TrackedObjectMetaData): The other TrackedObjectMetaData instance whose metadata is to be merged with the current instance. + + Raises: + TypeError: If the other_object is not an instance of TrackedObjectMetaData. + + """ # noqa: E501 + if not isinstance(other_object, type(self)): raise TypeError("Can only merge with another TrackedObjectMetaData.") self.observations += other_object.observations @@ -32,61 +83,132 @@ def merge(self, other_object): self.confidence = other_object.confidence self.bbox = other_object.bbox self.last_frame_id = other_object.last_frame_id - for class_name in POSSIBLE_CLASSES: + for class_name in other_object.class_counts.keys(): self.class_counts[class_name] = self.class_counts.get( class_name, 0 ) + other_object.class_counts.get(class_name, 0) def copy(self): - # Create a new instance of TrackedObjectMetaData with the same data - copy_obj = TrackedObjectMetaData( - [self.first_frame_id, 0, list(self.class_counts.keys())[0], *self.bbox, self.confidence] - ) - # Update the copied instance with the actual class counts and observations + """ + Creates a copy of the current TrackedObjectMetaData instance. + + Returns: + TrackedObjectMetaData: A new instance of TrackedObjectMetaData with the same + properties as the current instance. + """ + copy_obj = TrackedObjectMetaData.__new__(TrackedObjectMetaData) + copy_obj.bbox = self.bbox.copy() copy_obj.class_counts = self.class_counts.copy() copy_obj.observations = self.observations copy_obj.confidence_sum = self.confidence_sum copy_obj.confidence = self.confidence + copy_obj.first_frame_id = self.first_frame_id + copy_obj.last_frame_id = self.last_frame_id return copy_obj - def save_to_json(self, filename): + def to_dict(self): + """ + Converts the TrackedObjectMetaData instance to a dictionary. + + The class_counts dictionary is converted to a string-keyed dictionary. + The bounding box list is converted to a list of integers. + The first_frame_id, last_frame_id, confidence, confidence_sum, and observations are converted to their + respective types. + + Returns: + dict: A dictionary representation of the TrackedObjectMetaData instance. + """ + class_counts_str = { + str(class_name): count for class_name, count in self.class_counts.items() + } data = { - "first_frame_id": self.first_frame_id, - "class_counts": self.class_counts, - "bbox": self.bbox, - "confidence": self.confidence, - "confidence_sum": self.confidence_sum, - "observations": self.observations, + "first_frame_id": int(self.first_frame_id), + "last_frame_id": int(self.last_frame_id), + "class_counts": class_counts_str, + "bbox": [int(i) for i in self.bbox], + "confidence": float(self.confidence), + "confidence_sum": float(self.confidence_sum), + "observations": int(self.observations), } + return data + + def to_json(self): + """ + Converts the TrackedObjectMetaData instance to a JSON string. + + Returns: + str: A JSON string representation of the TrackedObjectMetaData instance. + """ + return json.dumps(self.to_dict(), indent=4) - with Path.open(filename, "w") as file: - json.dump(data, file) + @classmethod + def from_dict(cls, data: dict): + """ + Creates a new instance of the class from a dictionary. + + The dictionary should contain the following keys: "first_frame_id", "last_frame_id", "class_counts", + "bbox", "confidence", "confidence_sum", and "observations". The "class_counts" key should map to a + dictionary where the keys are class names (as integers) and the values are counts. + + Args: + data (dict): A dictionary containing the data to populate the new instance. + + Returns: + TrackedObjectMetaData: A new instance of TrackedObjectMetaData populated with the data from the dictionary. + """ + class_counts_str = data["class_counts"] + class_counts = {int(class_name): count for class_name, count in class_counts_str.items()} + obj = cls.__new__(cls) + obj.first_frame_id = data["first_frame_id"] + obj.last_frame_id = data["last_frame_id"] + obj.class_counts = class_counts + obj.bbox = data["bbox"] + obj.confidence = data["confidence"] + obj.confidence_sum = data["confidence_sum"] + obj.observations = data["observations"] + return obj @classmethod - def load_from_json(cls, filename): - with Path.open(filename, "r") as file: - data = json.load(file) - obj = cls.__new__(cls) - obj.first_frame_id = data["first_frame_id"] - obj.class_counts = data["class_counts"] - obj.bbox = data["bbox"] - obj.confidence = data["confidence"] - obj.confidence_sum = data["confidence_sum"] - obj.observations = data["observations"] - return obj + def from_json(cls, json_str: str): + """ + Creates a new instance of the class from a JSON string. + + Args: + json_str (str): A JSON string representation of the TrackedObjectMetaData instance. + + Returns: + TrackedObjectMetaData: A new instance of TrackedObjectMetaData populated with the data from the JSON string. + """ + data = json.loads(json_str) + return cls.from_dict(data) def class_proportions(self): + """ + Calculates the proportions of each class in the tracked object. + + Returns: + dict: A dictionary where the keys are class names and the values are the proportions of each class. + """ if self.observations > 0: proportions = { class_name: count / self.observations for class_name, count in self.class_counts.items() } else: - proportions = {class_name: 0.0 for class_name in POSSIBLE_CLASSES} + proportions = None return proportions - def percentage_of_time_seen(self, frame_id): + def percentage_of_time_seen(self, frame_id: int): + """ + Calculates the percentage of time the tracked object has been seen. + + Args: + frame_id (int): The current frame id. + + Returns: + float: The percentage of time the tracked object has been seen. + """ if self.observations > 0: percentage = (self.observations / (frame_id - self.first_frame_id + 1)) * 100 else: @@ -94,17 +216,35 @@ def percentage_of_time_seen(self, frame_id): return percentage def mean_confidence(self): + """ + Calculates the mean confidence of the tracked object. + + Returns: + float: The mean confidence of the tracked object. + """ if self.observations > 0: return self.confidence_sum / self.observations else: return 0.0 def __repr__(self) -> str: + """ + Returns a string representation of the TrackedObjectMetaData instance. + + Returns: + str: A string representation of the TrackedObjectMetaData instance. + """ return f"TrackedObjectMetaData(bbox={self.bbox})" def __str__(self): + """ + Returns a string representation of the TrackedObjectMetaData instance. + + Returns: + str: A string representation of the TrackedObjectMetaData instance. + """ return ( f"First frame seen: {self.first_frame_id}, nb observations: {self.observations}, " - + "class Proportions: {self.class_proportions()}, Bounding Box: {self.bbox}, " - + "Mean Confidence: {self.mean_confidence()}" + + f"class proportions: {self.class_proportions()}, bbox: {self.bbox}, " + + f"mean confidence: {self.mean_confidence()}" ) diff --git a/trackreid/utils.py b/trackreid/utils.py index 7c559f3..65f8bcb 100644 --- a/trackreid/utils.py +++ b/trackreid/utils.py @@ -1,17 +1,42 @@ from typing import List, Union +import numpy as np from llist import sllist +from trackreid.configs.output_data_positions import OutputDataPositions -def get_top_list_correction(tracked_ids: list): + +def get_top_list_correction(tracked_ids: List): + """ + Function to get the last value of each re_id_chain in tracked_ids. + + Args: + tracked_ids (list): List of tracked ids. + + Returns: + list: List of last values of each re_id_chain in tracked_ids. + """ top_list_correction = [tracked_id.re_id_chain.last.value for tracked_id in tracked_ids] return top_list_correction -def split_list_around_value(my_list: sllist, value_to_split: int): +def split_list_around_value(my_list: sllist, value_to_split: float): + """ + Function to split a list around a given value. + + Args: + my_list (sllist): The list to split. + value_to_split (float): The value to split the list around. + + Returns: + tuple: Two lists, before and after the split value. + """ if value_to_split == my_list.last.value: raise NameError("split on the last") + if value_to_split not in my_list: + raise NameError(f"{value_to_split} is not in the list") + before = sllist() after = sllist() @@ -21,9 +46,9 @@ def split_list_around_value(my_list: sllist, value_to_split: int): before.append(current.value) if current.value == value_to_split: break + current = current.next current = current.next - while current: after.append(current.value) current = current.next @@ -31,7 +56,18 @@ def split_list_around_value(my_list: sllist, value_to_split: int): return before, after -def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exclusion=False): +def filter_objects_by_state(tracked_objects: List, states: Union[int, List[int]], exclusion=False): + """ + Function to filter tracked objects by their state. + + Args: + tracked_objects (List): List of tracked objects. + states (Union[int, list]): State or list of states to filter by. + exclusion (bool, optional): If True, exclude objects with the given states. Defaults to False. + + Returns: + list: List of filtered tracked objects. + """ if isinstance(states, int): states = [states] if exclusion: @@ -39,3 +75,66 @@ def filter_objects_by_state(tracked_objects: List, states: Union[int, list], exc else: filtered_objects = [obj for obj in tracked_objects if obj.state in states] return filtered_objects + + +def filter_objects_by_category( + tracked_objects: List, + category: Union[Union[float, int], List[Union[float, int]]], + exclusion=False, +): + """ + Function to filter tracked objects by their category. + + Args: + tracked_objects (List): List of tracked objects. + category (Union[Union[float, int], list]): Category or list of categories to filter by. + exclusion (bool, optional): If True, exclude objects with the given categories. Defaults to False. + + Returns: + list: List of filtered tracked objects. + """ + if isinstance(category, Union[float, int]): + category = [category] + if exclusion: + filtered_objects = [obj for obj in tracked_objects if obj.category not in category] + else: + filtered_objects = [obj for obj in tracked_objects if obj.category in category] + return filtered_objects + + +def reshape_tracker_result(tracker_output: np.ndarray): + """ + Function to reshape the tracker output if it has only one dimension. + + Args: + tracker_output (np.ndarray): The tracker output to reshape. + + Returns: + np.ndarray: The reshaped tracker output. + """ + if tracker_output.ndim == 1: + tracker_output = np.expand_dims(tracker_output, 0) + return tracker_output + + +def get_nb_output_cols(output_positions: OutputDataPositions): + """ + Function to get the number of output columns based on the model json schema. + + Args: + output_positions (OutputDataPositions): The output data positions. + + Returns: + int: The number of output columns. + """ + schema = output_positions.model_json_schema() + nb_cols = 0 + for feature in schema["properties"]: + if schema["properties"][feature]["type"] == "integer": + nb_cols += 1 + elif schema["properties"][feature]["type"] == "array": + nb_cols += len(schema["properties"][feature]["default"]) + else: + raise TypeError("Unknown type in required output positions.") + + return nb_cols