Skip to content

Commit

Permalink
fix: commit histories (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDarmon authored Nov 23, 2023
2 parents 0151f3e + d82a733 commit e479f49
Show file tree
Hide file tree
Showing 2 changed files with 208 additions and 2 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/deploy_docs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ jobs:
make install
- name: Deploying MkDocs documentation
run: |
mkdocs build
mkdocs gh-deploy --force
poetry run mkdocs build
poetry run mkdocs gh-deploy --force
206 changes: 206 additions & 0 deletions notebooks/starter_kit_reid.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,212 @@
" print(case)\n",
" print(filtered_objects)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
class DetectionHandler():
    """Normalise raw detector rows into rescaled [x1, y1, x2, y2, conf, class] rows."""

    def __init__(self, image_shape) -> None:
        # Target shape the bounding boxes are rescaled to.
        self.image_shape = image_shape

    def process(self, detection_output):
        """Return an (N, 6) array of [x1, y1, x2, y2, conf, class].

        A 1-D input row is promoted to 2-D; an empty input is returned
        untouched.
        """
        # Empty detections: nothing to convert, hand back as-is.
        if not detection_output.size:
            return detection_output

        if detection_output.ndim == 1:
            detection_output = np.expand_dims(detection_output, 0)

        converted = np.zeros(detection_output.shape)
        for row, raw in enumerate(detection_output):
            category = raw[0]
            confidence = raw[5]
            center_bbox = raw[1:5]
            # Convert center-format box to corner format, then rescale.
            corner_bbox = xy_center_to_xyxy(center_bbox)
            converted[row, :4] = rescale_bbox(corner_bbox, self.image_shape)
            converted[row, 4] = confidence
            converted[row, 5] = category

        return converted
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class TrackingHandler():\n",
" def __init__(self, tracker) -> None:\n",
" self.tracker = tracker\n",
"\n",
" def update(self, detection_outputs, frame_id):\n",
"\n",
" if not detection_outputs.size :\n",
" return detection_outputs\n",
"\n",
" processed_detections = self._pre_process(detection_outputs)\n",
" tracked_objects = self.tracker.update(processed_detections, _ = frame_id)\n",
" processed_tracked = self._post_process(tracked_objects)\n",
" return processed_tracked\n",
"\n",
" def _pre_process(self,detection_outputs : np.ndarray):\n",
" return detection_outputs\n",
"\n",
" def _post_process(self, tracked_objects : np.ndarray):\n",
"\n",
" if tracked_objects.size :\n",
" if tracked_objects.ndim == 1:\n",
" tracked_objects = np.expand_dims(tracked_objects, 0)\n",
"\n",
" return tracked_objects"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
# Run detection -> tracking -> re-identification over every sequence and save
# per-sequence txt results (and optionally an annotated video) under a
# timestamped predictions folder.
# Fix: dropped the redundant `from datetime import datetime` that sat inside
# the per-sequence loop — `datetime.now()` is already used above, so the name
# is necessarily in scope, and re-importing every iteration was dead code.
timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')
print(timestamp)
folder_save = os.path.join(PREDICTIONS_PATH, timestamp)
os.makedirs(folder_save, exist_ok=True)
GENERATE_VIDEOS = False
for sequence in tqdm(SEQUENCES):
    frame_path = get_sequence_frames(sequence)
    test_sequence = Sequence(frame_path)
    frame_id = 0
    # Reset the tracker's global id counter so track ids restart per sequence.
    BaseTrack._count = 0

    file_path = os.path.join(folder_save, sequence) + '.txt'
    video_path = os.path.join(folder_save, sequence) + '.mp4'

    if GENERATE_VIDEOS:
        fourcc = cv2.VideoWriter_fourcc(*'avc1')  # or use 'x264'
        # Frame size must match the sequence resolution; adjust as needed.
        out = cv2.VideoWriter(video_path, fourcc, 20.0, (2560, 1440))

    detection_handler = DetectionHandler(image_shape=[2560, 1440])
    tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh=0.3, track_buffer=5, match_thresh=0.85, frame_rate=30))
    reid_processor = ReidProcessor(filter_confidence_threshold=0.1,
                                   filter_time_threshold=5,
                                   cost_function=bounding_box_distance,
                                   cost_function_threshold=5000,  # max cost to rematch 2 objects
                                   selection_function=select_by_category,
                                   max_attempt_to_match=5,
                                   max_frames_to_rematch=500,
                                   save_to_txt=True,
                                   file_path=file_path)

    for frame, detection in test_sequence:

        frame_id += 1

        # Pipeline: raw detections -> rescaled boxes -> tracked -> re-identified.
        processed_detections = detection_handler.process(detection)
        processed_tracked = tracking_handler.update(processed_detections, frame_id)
        reid_results = reid_processor.update(processed_tracked, frame_id)

        if GENERATE_VIDEOS and len(reid_results) > 0:
            frame = np.array(frame)
            for res in reid_results:
                object_id = int(res[output_data_positions.object_id])
                bbox = list(map(int, res[output_data_positions.bbox]))
                class_id = int(res[output_data_positions.category])
                tracker_id = int(res[output_data_positions.tracker_id])
                mean_confidence = float(res[output_data_positions.mean_confidence])
                x1, y1, x2, y2 = bbox
                color = (0, 0, 255) if class_id else (0, 255, 0)  # green for class 0, red for class 1
                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
                cv2.putText(frame, f"{object_id} ({tracker_id}) : {round(mean_confidence,2)}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

        if GENERATE_VIDEOS:
            # NOTE(review): swaps channel order before writing; confirm frames
            # arrive as RGB here, since cv2.VideoWriter expects BGR.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            out.write(frame)

    if GENERATE_VIDEOS:
        out.release()
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reid_processor.seen_objects"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
def count_occurrences(file_path, case):
    """Tally per-object line counts and per-object class votes from a results file.

    Lines are whitespace-separated; the object id sits at field index 1.
    The category field is the last field for the 'baseline' format and
    field index 2 otherwise.

    Returns (object_counts, class_counts) where class_counts[obj][cat] is
    the number of lines voting category ``cat`` for object ``obj``.
    """
    object_counts = defaultdict(int)
    class_counts = defaultdict(lambda: defaultdict(int))

    with open(file_path, 'r') as f:
        for line in f:
            fields = line.split()
            object_id = int(fields[1])
            # Only the category column differs between the two formats.
            category = int(fields[-1]) if case == 'baseline' else int(fields[2])

            object_counts[object_id] += 1
            class_counts[object_id][category] += 1

    return object_counts, class_counts
"\n",
def filter_counts(object_counts, class_counts, min_occurrences=10):
    """Keep objects seen more than ``min_occurrences`` times whose class-0
    votes strictly outnumber their class-1 votes.

    Returns {object_id: occurrence_count} for the objects that pass.
    """
    return {
        obj_id: seen
        for obj_id, seen in object_counts.items()
        if seen > min_occurrences
        and class_counts[obj_id][0] > class_counts[obj_id][1]
    }
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"PATH_PREDICTIONS = f\"../data/predictions/{timestamp}\"\n",
"\n",
"for sequence in SEQUENCES:\n",
" print(\"-\"*50)\n",
" print(sequence)\n",
"\n",
" for case in [\"baseline\", timestamp]:\n",
" object_counts, class_counts = count_occurrences(f'../data/predictions/{case}/{sequence}.txt', case=case)\n",
" filtered_objects = filter_counts(object_counts, class_counts)\n",
"\n",
" print(case)\n",
" print(filtered_objects)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit e479f49

Please sign in to comment.