diff --git a/.github/workflows/deploy_docs.yaml b/.github/workflows/deploy_docs.yaml
index 8d4a6b6..adc0cff 100644
--- a/.github/workflows/deploy_docs.yaml
+++ b/.github/workflows/deploy_docs.yaml
@@ -23,5 +23,5 @@ jobs:
           make install
       - name: Deploying MkDocs documentation
         run: |
-          mkdocs build
-          mkdocs gh-deploy --force
+          poetry run mkdocs build
+          poetry run mkdocs gh-deploy --force
diff --git a/notebooks/starter_kit_reid.ipynb b/notebooks/starter_kit_reid.ipynb
index 685d1a6..177b9f7 100644
--- a/notebooks/starter_kit_reid.ipynb
+++ b/notebooks/starter_kit_reid.ipynb
@@ -317,6 +317,212 @@
     "    print(case)\n",
     "    print(filtered_objects)"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class DetectionHandler:\n",
+    "    def __init__(self, image_shape) -> None:\n",
+    "        self.image_shape = image_shape\n",
+    "\n",
+    "    def process(self, detection_output):\n",
+    "        if detection_output.size:\n",
+    "            if detection_output.ndim == 1:\n",
+    "                detection_output = np.expand_dims(detection_output, 0)\n",
+    "\n",
+    "            processed_detection = np.zeros(detection_output.shape)  # one [x1, y1, x2, y2, conf, class] row per detection\n",
+    "\n",
+    "            for idx, detection in enumerate(detection_output):\n",
+    "                clss = detection[0]  # class id\n",
+    "                conf = detection[5]  # confidence score\n",
+    "                bbox = detection[1:5]  # [x_center, y_center, w, h]\n",
+    "                xyxy_bbox = xy_center_to_xyxy(bbox)\n",
+    "                rescaled_bbox = rescale_bbox(xyxy_bbox, self.image_shape)\n",
+    "                processed_detection[idx, :4] = rescaled_bbox\n",
+    "                processed_detection[idx, 4] = conf\n",
+    "                processed_detection[idx, 5] = clss\n",
+    "\n",
+    "            return processed_detection\n",
+    "        else:\n",
+    "            return detection_output\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class TrackingHandler:\n",
+    "    def __init__(self, tracker) -> None:\n",
+    "        self.tracker = tracker\n",
+    "\n",
+    "    def update(self, detection_outputs, frame_id):\n",
+    "\n",
+    "        if not detection_outputs.size:\n",
+    "            return detection_outputs\n",
+    "\n",
+    "        processed_detections = self._pre_process(detection_outputs)\n",
+    "        tracked_objects = self.tracker.update(processed_detections, _=frame_id)\n",
+    "        processed_tracked = self._post_process(tracked_objects)\n",
+    "        return processed_tracked\n",
+    "\n",
+    "    def _pre_process(self, detection_outputs: np.ndarray):\n",
+    "        return detection_outputs\n",
+    "\n",
+    "    def _post_process(self, tracked_objects: np.ndarray):\n",
+    "\n",
+    "        if tracked_objects.size:\n",
+    "            if tracked_objects.ndim == 1:\n",
+    "                tracked_objects = np.expand_dims(tracked_objects, 0)  # keep a 2D array even for a single object\n",
+    "\n",
+    "        return tracked_objects"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datetime import datetime\n",
+    "\n",
+    "timestamp = datetime.now().strftime('%Y-%m-%d-%H-%M')\n",
+    "print(timestamp)\n",
+    "folder_save = os.path.join(PREDICTIONS_PATH, timestamp)\n",
+    "os.makedirs(folder_save, exist_ok=True)\n",
+    "GENERATE_VIDEOS = False\n",
+    "for sequence in tqdm(SEQUENCES):\n",
+    "    frame_path = get_sequence_frames(sequence)\n",
+    "    test_sequence = Sequence(frame_path)\n",
+    "    frame_id = 0\n",
+    "    BaseTrack._count = 0  # reset the tracker's global id counter between sequences\n",
+    "    file_path = os.path.join(folder_save, sequence) + '.txt'\n",
+    "    video_path = os.path.join(folder_save, sequence) + '.mp4'\n",
+    "\n",
+    "    if GENERATE_VIDEOS:\n",
+    "        fourcc = cv2.VideoWriter_fourcc(*'avc1')  # or use 'x264'\n",
+    "        out = cv2.VideoWriter(video_path, fourcc, 20.0, (2560, 1440))  # adjust the frame size (2560, 1440) to your frames\n",
+    "\n",
+    "    detection_handler = DetectionHandler(image_shape=[2560, 1440])\n",
+    "    tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh=0.3, track_buffer=5, match_thresh=0.85, frame_rate=30))\n",
+    "    reid_processor = ReidProcessor(filter_confidence_threshold=0.1,\n",
+    "                                   filter_time_threshold=5,\n",
+    "                                   cost_function=bounding_box_distance,\n",
+    "                                   cost_function_threshold=5000,  # max cost to rematch 2 objects\n",
+    "                                   selection_function=select_by_category,\n",
+    "                                   max_attempt_to_match=5,\n",
+    "                                   max_frames_to_rematch=500,\n",
+    "                                   save_to_txt=True,\n",
+    "                                   file_path=file_path)\n",
+    "\n",
+    "    for frame, detection in test_sequence:\n",
+    "\n",
+    "        frame_id += 1\n",
+    "\n",
+    "        processed_detections = detection_handler.process(detection)\n",
+    "        processed_tracked = tracking_handler.update(processed_detections, frame_id)\n",
+    "        reid_results = reid_processor.update(processed_tracked, frame_id)\n",
+    "\n",
+    "        if GENERATE_VIDEOS and len(reid_results) > 0:\n",
+    "            frame = np.array(frame)\n",
+    "            for res in reid_results:\n",
+    "                object_id = int(res[output_data_positions.object_id])\n",
+    "                bbox = list(map(int, res[output_data_positions.bbox]))\n",
+    "                class_id = int(res[output_data_positions.category])\n",
+    "                tracker_id = int(res[output_data_positions.tracker_id])\n",
+    "                mean_confidence = float(res[output_data_positions.mean_confidence])\n",
+    "                x1, y1, x2, y2 = bbox\n",
+    "                color = (0, 0, 255) if class_id else (0, 255, 0)  # BGR: green for class 0, red otherwise\n",
+    "                cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)\n",
+    "                cv2.putText(frame, f\"{object_id} ({tracker_id}) : {round(mean_confidence, 2)}\", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)\n",
+    "\n",
+    "        if GENERATE_VIDEOS:\n",
+    "            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)  # VideoWriter expects BGR frames\n",
+    "            out.write(frame)\n",
+    "\n",
+    "    if GENERATE_VIDEOS:\n",
+    "        out.release()\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "reid_processor.seen_objects"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from collections import defaultdict\n",
+    "\n",
+    "def count_occurrences(file_path, case):\n",
+    "    object_counts = defaultdict(int)\n",
+    "    class_counts = defaultdict(lambda: defaultdict(int))\n",
+    "\n",
+    "    with open(file_path, 'r') as f:\n",
+    "        for line in f:\n",
+    "            data = line.split()\n",
+    "\n",
+    "            if case != 'baseline':\n",
+    "                object_id = int(data[1])\n",
+    "                category = int(data[2])\n",
+    "            else:\n",
+    "                object_id = int(data[1])\n",
+    "                category = int(data[-1])  # baseline files keep the category in the last column\n",
+    "\n",
+    "            object_counts[object_id] += 1\n",
+    "            class_counts[object_id][category] += 1\n",
+    "\n",
+    "    return object_counts, class_counts\n",
+    "\n",
+    "def filter_counts(object_counts, class_counts, min_occurrences=10):\n",
+    "    filtered_objects = {}\n",
+    "\n",
+    "    for object_id, count in object_counts.items():\n",
+    "        if count > min_occurrences and class_counts[object_id][0] > class_counts[object_id][1]:\n",
+    "            filtered_objects[object_id] = count\n",
+    "\n",
+    "    return filtered_objects\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PATH_PREDICTIONS = f\"../data/predictions/{timestamp}\"\n",
+    "\n",
+    "for sequence in SEQUENCES:\n",
+    "    print(\"-\" * 50)\n",
+    "    print(sequence)\n",
+    "\n",
+    "    for case in [\"baseline\", timestamp]:\n",
+    "        object_counts, class_counts = count_occurrences(f'../data/predictions/{case}/{sequence}.txt', case=case)\n",
+    "        filtered_objects = filter_counts(object_counts, class_counts)\n",
+    "\n",
+    "        print(case)\n",
+    "        print(filtered_objects)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
   }
  ],
  "metadata": {