Skip to content

Commit

Permalink
fixed display call
Browse files Browse the repository at this point in the history
  • Loading branch information
dillonalaird committed Dec 6, 2024
1 parent 70ed2cd commit fe0c41a
Showing 1 changed file with 36 additions and 9 deletions.
45 changes: 36 additions & 9 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,10 +467,22 @@ def florence2_sam2_video_tracking(
return_data.append(return_frame_data)
return_data = add_bboxes_from_masks(return_data)
return_data = nms(return_data, iou_threshold=0.95)

_display_tool_trace(
florence2_sam2_video_tracking.__name__,
payload,
detections[0],
[
[
{
"label": e["label"],
"score": e["score"],
"bbox": denormalize_bbox(e["bbox"], frames[0].shape[:2]),
"mask": rle_encode_array(e["mask"]),
}
for e in lst
]
for lst in return_data
],
files,
)
return return_data
Expand Down Expand Up @@ -686,21 +698,29 @@ def _run_countgd(prompt: str) -> List[Dict[str, Any]]:
)
for bbox in bboxes
]

# TODO: remove this once we start to use the confidence on countgd
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
return_data = [bbox.model_dump() for bbox in filtered_bboxes]
return_data = single_nms(return_data, iou_threshold=0.80)
_display_tool_trace(
countgd_object_detection.__name__,
{
"prompts": prompt,
"confidence": box_threshold,
"model": "countgd",
},
bboxes,
[
{
"label": e["label"],
"score": e["score"],
"bbox": denormalize_bbox(e["bbox"], image_size),
}
for e in return_data
],
files,
)

# TODO: remove this once we start to use the confidence on countgd
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
return_data = [bbox.model_dump() for bbox in filtered_bboxes]
return single_nms(return_data, iou_threshold=0.80)
return return_data


def countgd_sam2_object_detection(
Expand Down Expand Up @@ -830,14 +850,21 @@ def countgd_example_based_counting(
)
for bbox in bboxes_per_frame
]
filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
_display_tool_trace(
countgd_example_based_counting.__name__,
payload,
detections[0],
[
{
"label": e.label,
"score": e.score,
"bbox": denormalize_bbox(e.bbox, image_size), # type: ignore
}
for e in filtered_bboxes
],
files,
)

filtered_bboxes = filter_bboxes_by_threshold(bboxes_formatted, box_threshold)
return [bbox.model_dump() for bbox in filtered_bboxes]


Expand Down

0 comments on commit fe0c41a

Please sign in to comment.