Skip to content

Commit

Permalink
Tp/add to txt saving (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
tristanpepinartefact authored Nov 16, 2023
1 parent 7bc79a6 commit c661fab
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,6 @@ secrets/*
data/detections/*
data/frames/*
*.mp4

*.txt
# poetry
poetry.lock
128 changes: 110 additions & 18 deletions notebooks/starter_kit_reid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"source": [
"import os\n",
"import sys\n",
"from datetime import datetime\n",
"\n",
"import cv2\n",
"import numpy as np\n",
Expand All @@ -33,6 +34,28 @@
"sys.path.append(\"..\")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this demo, you have to install bytetrack. You can do so by typing the following command : \n",
"```bash\n",
"pip install git+https://github.com/artefactory-fr/bytetrack.git@main\n",
"```\n",
"\n",
"Baseline data can be found in `gs://data-track-reid/predictions/baseline`. You can copy them in `../data/predictions/` using the following commands (in a terminal at the root of the project):\n",
"\n",
"```bash\n",
"mkdir -p ./data/predictions/\n",
"gsutil -m cp -r gs://data-track-reid/predictions/baseline ./data/predictions/\n",
"```\n",
"\n",
"Then you can reorganize the data using the following : \n",
"```bash \n",
"find ./data/predictions/baseline -mindepth 2 -type f -name \"*.txt\" -exec sh -c 'mv \"$0\" \"${0%/*/*}/$(basename \"${0%/*}\").txt\"' {} \\; && find ./data/predictions/baseline -mindepth 1 -type d -empty -delete\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -49,9 +72,11 @@
"DATA_PATH = \"../data\"\n",
"DETECTION_PATH = f\"{DATA_PATH}/detections\"\n",
"FRAME_PATH = f\"{DATA_PATH}/frames\"\n",
"PREDICTIONS_PATH = f\"{DATA_PATH}/predictions\"\n",
"VIDEO_OUTPUT_PATH = \"private\"\n",
"\n",
"SEQUENCES = os.listdir(DETECTION_PATH)\n"
"SEQUENCES = os.listdir(DETECTION_PATH)\n",
"GENERATE_VIDEOS = False\n"
]
},
{
Expand Down Expand Up @@ -167,57 +192,124 @@
"metadata": {},
"outputs": [],
"source": [
"for sequence in SEQUENCES :\n",
"timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')\n",
"print(timestamp)\n",
"folder_save = os.path.join(PREDICTIONS_PATH,timestamp)\n",
"os.makedirs(folder_save, exist_ok=True)\n",
"GENERATE_VIDEOS = False\n",
"for sequence in tqdm(SEQUENCES) :\n",
" frame_path = get_sequence_frames(sequence)\n",
" test_sequence = Sequence(frame_path)\n",
" test_sequence\n",
" frame_id = 0\n",
" BaseTrack._count = 0\n",
" from datetime import datetime\n",
"\n",
" # Define the codec using VideoWriter_fourcc() and create a VideoWriter object\n",
" fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n",
" out = cv2.VideoWriter(f'{sequence}.mp4', fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n",
" file_path = os.path.join(folder_save,sequence) + '.txt'\n",
" video_path = os.path.join(folder_save,sequence) + '.mp4'\n",
"\n",
" if GENERATE_VIDEOS:\n",
" fourcc = cv2.VideoWriter_fourcc(*'avc1') # or use 'x264'\n",
" out = cv2.VideoWriter(video_path, fourcc, 20.0, (2560, 1440)) # adjust the frame size (640, 480) as per your needs\n",
"\n",
" detection_handler = DetectionHandler(image_shape=[2560, 1440])\n",
" tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh= 0.3, track_buffer = 5, match_thresh = 0.85, frame_rate= 30))\n",
" reid_processor = ReidProcessor(filter_confidence_threshold=0.1,\n",
" filter_time_threshold=5,\n",
" cost_function=bounding_box_distance,\n",
" #cost_function_threshold=500, # max cost to rematch 2 objects\n",
" cost_function_threshold=5000, # max cost to rematch 2 objects\n",
" selection_function=select_by_category,\n",
" max_attempt_to_match=5,\n",
" max_frames_to_rematch=500)\n",
" max_frames_to_rematch=500,\n",
" save_to_txt=True,\n",
" file_path=file_path)\n",
"\n",
" for frame, detection in tqdm(test_sequence):\n",
" frame = np.array(frame)\n",
" for frame, detection in test_sequence:\n",
"\n",
" frame_id += 1\n",
"\n",
" processed_detections = detection_handler.process(detection)\n",
" processed_tracked = tracking_handler.update(processed_detections, frame_id)\n",
" reid_results = reid_processor.update(processed_tracked, frame_id)\n",
"\n",
" if len(reid_results) > 0:\n",
" if GENERATE_VIDEOS and len(reid_results) > 0:\n",
" frame = np.array(frame)\n",
" for res in reid_results:\n",
" object_id = int(res[OUTPUT_POSITIONS[\"object_id\"]])\n",
" bbox = list(map(int, res[OUTPUT_POSITIONS[\"bbox\"]]))\n",
" class_id = int(res[OUTPUT_POSITIONS[\"category\"]])\n",
" tracker_id = int(res[OUTPUT_POSITIONS[\"tracker_id\"]])\n",
" mean_confidence = float(res[OUTPUT_POSITIONS[\"mean_confidence\"]])\n",
" #mean_confidence_per_object[object_id].append((frame_id, mean_confidence))\n",
" x1, y1, x2, y2 = bbox\n",
" color = (0, 0, 255) if class_id else (0, 255, 0) # green for class 0, red for class 1\n",
" cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n",
" cv2.putText(frame, f\"{object_id} ({tracker_id}) : {round(mean_confidence,2)}\", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n",
" frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
"\n",
" # Write the frame to the video file\n",
" out.write(frame)\n",
" out.release()\n",
" if GENERATE_VIDEOS:\n",
" frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
" out.write(frame)\n",
"\n",
" if GENERATE_VIDEOS :\n",
" out.release()\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"def count_occurrences(file_path, case):\n",
" object_counts = defaultdict(int)\n",
" class_counts = defaultdict(lambda: defaultdict(int))\n",
"\n",
" with open(file_path, 'r') as f:\n",
" for line in f:\n",
" data = line.split()\n",
"\n",
" if case != 'baseline':\n",
" object_id = int(data[1])\n",
" category = int(data[2])\n",
" else:\n",
" object_id = int(data[1])\n",
" category = int(data[-1])\n",
"\n",
" object_counts[object_id] += 1\n",
" class_counts[object_id][category] += 1\n",
"\n",
" return object_counts, class_counts\n",
"\n",
"def filter_counts(object_counts, class_counts, min_occurrences=10):\n",
" filtered_objects = {}\n",
"\n",
" for object_id, count in object_counts.items():\n",
" if count > min_occurrences and class_counts[object_id][0] > class_counts[object_id][1]:\n",
" filtered_objects[object_id] = count\n",
"\n",
" return filtered_objects\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PATH_PREDICTIONS = f\"../data/predictions/{timestamp}\"\n",
"\n",
"for sequence in SEQUENCES:\n",
" print(\"-\"*50)\n",
" print(sequence)\n",
"\n",
" for case in [\"baseline\", timestamp]:\n",
" object_counts, class_counts = count_occurrences(f'../data/predictions/{case}/{sequence}.txt', case=case)\n",
" filtered_objects = filter_counts(object_counts, class_counts)\n",
"\n",
" print(sequence, len(reid_processor.seen_objects),reid_processor.nb_corrections)\n",
" print(reid_processor.seen_objects)\n"
" print(case)\n",
" print(filtered_objects)"
]
}
],
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ lapx = "^0.5.5"
opencv-python = "^4.8.1.78"
tqdm = "^4.66.1"
pillow = "^10.1.0"
bytetracker = {git = "git@github.com:artefactory-fr/bytetrack.git", branch = "main"}


[tool.poetry.group.dev.dependencies]
Expand Down
4 changes: 0 additions & 4 deletions tests/unit_tests/test_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,6 @@ def dummy_selection_function(candidate, switcher): # noqa: ARG001
candidates.append(obj)
switchers.append(obj)

print(candidates)
print(switchers)

matches = matcher.match(candidates, switchers)

assert len(matches) == 3
Expand All @@ -76,7 +73,6 @@ def dummy_selection_function(candidate, switcher):
switchers.append(obj)

matches = matcher.match(candidates, switchers)
print(matches)

assert len(matches) == 2
for match in matches:
Expand Down
12 changes: 6 additions & 6 deletions tests/unit_tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@ def test_tracked_metadata_copy():
copied_metadata = tracked_metadata.copy()
assert copied_metadata.first_frame_id == 15
assert copied_metadata.last_frame_id == 251
assert copied_metadata.class_counts == {"shop_item": 175, "personal_item": 0}
assert copied_metadata.class_counts == {0: 175, 1: 0}
assert copied_metadata.bbox == [598, 208, 814, 447]
assert copied_metadata.confidence == 0.610211
assert copied_metadata.confidence_sum == 111.30582399999996
assert copied_metadata.observations == 175

assert round(copied_metadata.percentage_of_time_seen(251), 2) == 73.84
class_proportions = copied_metadata.class_proportions()
assert round(class_proportions["shop_item"], 2) == 1.0
assert round(class_proportions["personal_item"], 2) == 0.0
assert round(class_proportions.get(0), 2) == 1.0
assert round(class_proportions.get(1), 2) == 0.0

tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy()
tracked_metadata.merge(tracked_metadata_2)
# test impact of merge inplace in a copy, should be none

assert copied_metadata.class_counts == {"shop_item": 175, "personal_item": 0}
assert copied_metadata.class_counts == {0: 175, 1: 0}
assert copied_metadata.bbox == [598, 208, 814, 447]
assert copied_metadata.confidence == 0.610211
assert copied_metadata.confidence_sum == 111.30582399999996
Expand All @@ -44,8 +44,8 @@ def test_tracked_metadata_merge():
tracked_metadata_2 = ALL_TRACKED_METADATA[1].copy()
tracked_metadata_1.merge(tracked_metadata_2)
assert tracked_metadata_1.last_frame_id == 251
assert tracked_metadata_1.class_counts["shop_item"] == 175
assert tracked_metadata_1.class_counts["personal_item"] == 216
assert tracked_metadata_1.class_counts.get(0) == 175
assert tracked_metadata_1.class_counts.get(1) == 216
assert tracked_metadata_1.bbox == [548, 455, 846, 645]
assert tracked_metadata_1.confidence == 0.700626
assert tracked_metadata_1.confidence_sum == 260.988185
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_tests/test_tracked_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_tracked_object_properties():
tracked_object = ALL_TRACKED_OBJECTS[0].copy()
assert tracked_object.object_id == 1.0
assert tracked_object.state == 0
assert tracked_object.category == "shop_item"
assert tracked_object.category == 0
assert round(tracked_object.confidence, 2) == 0.61
assert round(tracked_object.mean_confidence, 2) == 0.64
assert tracked_object.bbox == [598, 208, 814, 447]
Expand All @@ -52,7 +52,7 @@ def test_tracked_object_merge():
tracked_object_1.merge(tracked_object_2)
assert tracked_object_1.object_id == 1.0
assert tracked_object_1.state == 0
assert tracked_object_1.category == "personal_item"
assert tracked_object_1.category == 1
assert round(tracked_object_1.confidence, 2) == 0.70
assert round(tracked_object_1.mean_confidence, 2) == 0.67
assert tracked_object_1.bbox == [548, 455, 846, 645]
Expand All @@ -65,15 +65,15 @@ def test_tracked_object_cut():
new_object, cut_object = tracked_object.cut(2.0)
assert new_object.object_id == 14.0
assert new_object.state == 0
assert new_object.category == "shop_item"
assert new_object.category == 0
assert round(new_object.confidence, 2) == 0.61
assert round(new_object.mean_confidence, 2) == 0.64
assert new_object.bbox == [598, 208, 814, 447]
assert new_object.nb_ids == 3
assert new_object.nb_corrections == 2
assert cut_object.object_id == 1.0
assert cut_object.state == 0
assert cut_object.category == "shop_item"
assert cut_object.category == 0
assert round(cut_object.confidence, 2) == 0.61
assert round(cut_object.mean_confidence, 2) == 0.64
assert cut_object.bbox == [598, 208, 814, 447]
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ def test_filter_objects_by_state_2():


def test_filter_objects_by_category():
category = "shop_item"
category = 0
assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=False) == [
ALL_TRACKED_OBJECTS[0],
ALL_TRACKED_OBJECTS[2],
]


def test_filter_objects_by_category_2():
category = "personal_item"
category = 1
assert utils.filter_objects_by_category(ALL_TRACKED_OBJECTS, category, exclusion=True) == [
ALL_TRACKED_OBJECTS[0],
ALL_TRACKED_OBJECTS[2],
Expand Down
4 changes: 2 additions & 2 deletions trackreid/args/reid_args.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
POSSIBLE_CLASSES = [0.0, 1.0]
MAPPING_CLASSES = {0.0: "shop_item", 1.0: "personal_item"}
POSSIBLE_CLASSES = [0, 1]
MAPPING_CLASSES = {0: "shop_item", 1: "personal_item"}

INPUT_POSITIONS = {
"object_id": 4,
Expand Down
Loading

0 comments on commit c661fab

Please sign in to comment.