Skip to content

Commit

Permalink
overlay bboxes works with frames (#244)
Browse files Browse the repository at this point in the history
* overlay bboxes works with frames

* fix mkdocs

* fix return type
  • Loading branch information
dillonalaird authored Sep 25, 2024
1 parent f0a5c90 commit baf650a
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 32 deletions.
2 changes: 1 addition & 1 deletion docs/api/lmm.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

::: vision_agent.lmm.OllamaLMM

::: vision_agent.lmm.ClaudeSonnetLMM
::: vision_agent.lmm.AnthropicLMM
79 changes: 48 additions & 31 deletions vision_agent/tools/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,14 +1759,17 @@ def _save_video_to_result(video_uri: str) -> None:


def overlay_bounding_boxes(
image: np.ndarray, bboxes: List[Dict[str, Any]]
) -> np.ndarray:
medias: Union[np.ndarray, List[np.ndarray]],
bboxes: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
) -> Union[np.ndarray, List[np.ndarray]]:
"""'overlay_bounding_boxes' is a utility function that displays bounding boxes on
an image.
Parameters:
image (np.ndarray): The image to display the bounding boxes on.
bboxes (List[Dict[str, Any]]): A list of dictionaries containing the bounding
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display the
bounding boxes on.
bboxes (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries or a list of lists of dictionaries containing the bounding
boxes.
Returns:
Expand All @@ -1778,41 +1781,54 @@ def overlay_bounding_boxes(
image, [{'score': 0.99, 'label': 'dinosaur', 'bbox': [0.1, 0.11, 0.35, 0.4]}],
)
"""
pil_image = Image.fromarray(image.astype(np.uint8)).convert("RGB")

if len(set([box["label"] for box in bboxes])) > len(COLORS):
medias_int: List[np.ndarray] = (
[medias] if isinstance(medias, np.ndarray) else medias
)
bbox_int = [bboxes] if isinstance(bboxes[0], dict) else bboxes
bbox_int = cast(List[List[Dict[str, Any]]], bbox_int)
labels = set([bb["label"] for b in bbox_int for bb in b])

if len(labels) > len(COLORS):
_LOGGER.warning(
"Number of unique labels exceeds the number of available colors. Some labels may have the same color."
)

color = {
label: COLORS[i % len(COLORS)]
for i, label in enumerate(set([box["label"] for box in bboxes]))
}
bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)
color = {label: COLORS[i % len(COLORS)] for i, label in enumerate(labels)}

width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype(
str(resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")),
fontsize,
)
frame_out = []
for i, frame in enumerate(medias_int):
pil_image = Image.fromarray(frame.astype(np.uint8)).convert("RGB")

for elt in bboxes:
label = elt["label"]
box = elt["bbox"]
scores = elt["score"]
bboxes = bbox_int[i]
bboxes = sorted(bboxes, key=lambda x: x["label"], reverse=True)

# denormalize the box if it is normalized
box = denormalize_bbox(box, (height, width))
width, height = pil_image.size
fontsize = max(12, int(min(width, height) / 40))
draw = ImageDraw.Draw(pil_image)
font = ImageFont.truetype(
str(
resources.files("vision_agent.fonts").joinpath("default_font_ch_en.ttf")
),
fontsize,
)

draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle((box[0], box[1], text_box[2], text_box[3]), fill=color[label])
draw.text((box[0], box[1]), text, fill="black", font=font)
return np.array(pil_image)
for elt in bboxes:
label = elt["label"]
box = elt["bbox"]
scores = elt["score"]

# denormalize the box if it is normalized
box = denormalize_bbox(box, (height, width))
draw.rectangle(box, outline=color[label], width=4)
text = f"{label}: {scores:.2f}"
text_box = draw.textbbox((box[0], box[1]), text=text, font=font)
draw.rectangle(
(box[0], box[1], text_box[2], text_box[3]), fill=color[label]
)
draw.text((box[0], box[1]), text, fill="black", font=font)
frame_out.append(np.array(pil_image))
return frame_out[0] if len(frame_out) == 1 else frame_out


def _get_text_coords_from_mask(
Expand Down Expand Up @@ -1852,7 +1868,8 @@ def overlay_segmentation_masks(
medias (Union[np.ndarray, List[np.ndarray]]): The image or frames to display
the masks on.
masks (Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]): A list of
dictionaries containing the masks, labels and scores.
dictionaries or a list of lists of dictionaries containing the masks, labels
and scores.
draw_label (bool, optional): If True, the labels will be displayed on the image.
secondary_label_key (str, optional): The key to use for the secondary
tracking label which is needed in videos to display tracking information.
Expand Down

0 comments on commit baf650a

Please sign in to comment.