
Commit

Add Tool Display (#319)
* remove old tools

* use new display call

* remove old display call

* re-encode rle arrays

* clean up tools

* mypy fixes

* fixed display call

* add back in original api display call

* fix up tools

* fix type error
dillonalaird authored Dec 10, 2024
1 parent 3d6d625 commit 33dd203
Showing 5 changed files with 545 additions and 926 deletions.
153 changes: 0 additions & 153 deletions tests/integ/test_tools.py
@@ -3,32 +3,19 @@
from PIL import Image

from vision_agent.tools import (
blip_image_caption,
clip,
closest_mask_distance,
countgd_object_detection,
countgd_sam2_object_detection,
countgd_example_based_counting,
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
florence2_image_caption,
florence2_ocr,
florence2_phrase_grounding,
florence2_phrase_grounding_video,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video_tracking,
flux_image_inpainting,
generate_pose_image,
generate_soft_edge_image,
git_vqa_v2,
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
loca_visual_prompt_counting,
loca_zero_shot_counting,
ocr,
owl_v2_image,
owl_v2_video,
@@ -44,27 +31,6 @@
FINE_TUNE_ID = "65ebba4a-88b7-419f-9046-0750e30250da"


def test_grounding_dino():
img = ski.data.coins()
result = grounding_dino(
prompt="coin",
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_grounding_dino_tiny():
img = ski.data.coins()
result = grounding_dino(
prompt="coin",
image=img,
model_size="tiny",
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24


def test_owl_v2_image():
img = ski.data.coins()
result = owl_v2_image(
@@ -197,17 +163,6 @@ def test_template_match():
assert len(result) == 2


def test_grounding_sam():
img = ski.data.coins()
result = grounding_sam(
prompt="coin",
image=img,
)
assert len(result) == 24
assert [res["label"] for res in result] == ["coin"] * 24
assert len([res["mask"] for res in result]) == 24


def test_florence2_sam2_image():
img = ski.data.coins()
result = florence2_sam2_image(
@@ -285,24 +240,6 @@ def test_detr_segmentation_empty():
assert result == []


def test_clip():
img = ski.data.coins()
result = clip(
classes=["coins", "notes"],
image=img,
)
assert result["scores"] == [0.9999, 0.0001]


def test_clip_empty():
result = clip(
classes=["coins", "notes"],
image=np.zeros((0, 0, 3)).astype(np.uint8),
)
assert result["scores"] == []
assert result["labels"] == []


def test_vit_classification():
img = ski.data.coins()
result = vit_image_classification(
@@ -327,67 +264,6 @@ def test_nsfw_classification():
assert result["label"] == "normal"


def test_image_caption():
img = ski.data.rocket()
result = blip_image_caption(
image=img,
)
assert result.strip() == "a rocket on a stand"


def test_florence_image_caption():
img = ski.data.rocket()
result = florence2_image_caption(
image=img,
)
assert "The image shows a rocket on a launch pad at night" in result.strip()


def test_loca_zero_shot_counting():
img = ski.data.coins()

result = loca_zero_shot_counting(
image=img,
)
assert result["count"] == 21


def test_loca_visual_prompt_counting():
img = ski.data.coins()
result = loca_visual_prompt_counting(
visual_prompt={"bbox": [85, 106, 122, 145]},
image=img,
)
assert result["count"] == 25


def test_git_vqa_v2():
img = ski.data.rocket()
result = git_vqa_v2(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert result.strip() == "night"


def test_image_qa_with_context():
img = ski.data.rocket()
result = florence2_roberta_vqa(
prompt="Is the scene captured during day or night ?",
image=img,
)
assert "night" in result.strip()


def test_ixc25_image_vqa():
img = ski.data.cat()
result = ixc25_image_vqa(
prompt="What animal is in this image?",
image=img,
)
assert "cat" in result.strip()


def test_qwen2_vl_images_vqa():
img = ski.data.page()
result = qwen2_vl_images_vqa(
@@ -408,17 +284,6 @@ def test_qwen2_vl_video_vqa():
assert "cat" in result.strip()


def test_ixc25_video_vqa():
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
]
result = ixc25_video_vqa(
prompt="What animal is in this video?",
frames=frames,
)
assert "cat" in result.strip()


def test_video_temporal_localization():
frames = [
np.array(Image.fromarray(ski.data.cat()).convert("RGB")) for _ in range(10)
@@ -500,24 +365,6 @@ def test_generate_pose():
assert result.shape == img.shape + (3,)


def test_generate_normal():
img = ski.data.coins()
result = dpt_hybrid_midas(
image=img,
)

assert result.shape == img.shape + (3,)


def test_generate_hed():
img = ski.data.coins()
result = generate_soft_edge_image(
image=img,
)

assert result.shape == img.shape


def test_countgd_sam2_object_detection():
img = ski.data.coins()
result = countgd_sam2_object_detection(image=img, prompt="coin")
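The deleted tests covered tools dropped from the public toolset (grounding_dino, grounding_sam, clip, the LOCA counters, IXC-25 VQA, and others); the CountGD-based detectors and their tests remain. For orientation, here is a hedged sketch of how a retained detector is exercised, following the surviving test_countgd_sam2_object_detection pattern at the end of this file. It requires the vision_agent tool backend to be reachable, and the result schema (a list of dicts with "label" and "mask" keys) is inferred from the deleted grounding_sam assertions, so treat it as an assumption.

```python
# Sketch mirroring the surviving integration test (requires the vision_agent
# tool backend; the result schema and detection count are assumptions).
import skimage as ski

from vision_agent.tools import countgd_sam2_object_detection

img = ski.data.coins()
result = countgd_sam2_object_detection(image=img, prompt="coin")

# The deleted grounding_sam test asserted one dict per detection with
# "label" and "mask" keys; the CountGD + SAM2 tool is assumed to return a
# similar structure.
assert all("label" in det for det in result)
assert all("mask" in det for det in result)
print(f"detected {len(result)} coins")
```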
15 changes: 1 addition & 14 deletions vision_agent/tools/__init__.py
@@ -23,40 +23,27 @@
TOOLS_INFO,
UTIL_TOOLS,
UTILITIES_DOCSTRING,
blip_image_caption,
claude35_text_extraction,
clip,
closest_box_distance,
closest_mask_distance,
countgd_example_based_counting,
countgd_object_detection,
countgd_sam2_object_detection,
countgd_example_based_counting,
depth_anything_v2,
detr_segmentation,
dpt_hybrid_midas,
extract_frames_and_timestamps,
florence2_image_caption,
florence2_ocr,
florence2_phrase_grounding,
florence2_phrase_grounding_video,
florence2_roberta_vqa,
florence2_sam2_image,
florence2_sam2_video_tracking,
flux_image_inpainting,
generate_pose_image,
generate_soft_edge_image,
get_tool_documentation,
get_tool_recommender,
git_vqa_v2,
gpt4o_image_vqa,
gpt4o_video_vqa,
grounding_dino,
grounding_sam,
ixc25_image_vqa,
ixc25_video_vqa,
load_image,
loca_visual_prompt_counting,
loca_zero_shot_counting,
minimum_distance,
ocr,
overlay_bounding_boxes,
4 changes: 2 additions & 2 deletions vision_agent/tools/tool_utils.py
@@ -27,6 +27,7 @@

class ToolCallTrace(BaseModel):
endpoint_url: str
type: str
request: MutableMapping[str, Any]
response: MutableMapping[str, Any]
error: Optional[Error]
@@ -221,14 +222,14 @@ def _call_post(
else:
response = session.post(url, json=payload)

# make sure function_name is in the payload so we can display it
tool_call_trace_payload = (
payload
if "function_name" in payload
else {**payload, **{"function_name": function_name}}
)
tool_call_trace = ToolCallTrace(
endpoint_url=url,
type="tool_call",
request=tool_call_trace_payload,
response={},
error=None,
@@ -252,7 +253,6 @@ def _call_post(
finally:
if tool_call_trace is not None:
trace = tool_call_trace.model_dump()
trace["type"] = "tool_call"
display({MimeType.APPLICATION_JSON: trace}, raw=True)


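The substantive change in tool_utils.py above moves the trace type into the model itself: type="tool_call" is set when ToolCallTrace is constructed and function_name is injected into the request payload, instead of patching trace["type"] onto the dict after model_dump(). Below is a minimal, runnable sketch of that flow. It assumes pydantic v2 for BaseModel/model_dump; the display helper, MIME constant, endpoint URL, and payload are simplified stand-ins, not the library's actual definitions.

```python
# Runnable sketch of the new trace flow (assumptions: pydantic v2 for
# BaseModel/model_dump; display, MIME constant, URL, and payload are
# simplified stand-ins for the library's own display call).
import json
from typing import Any, Dict, Optional

from pydantic import BaseModel

APPLICATION_JSON = "application/json"  # stand-in for MimeType.APPLICATION_JSON


class ToolCallTrace(BaseModel):
    endpoint_url: str
    type: str  # new field: set at construction instead of patched in afterwards
    request: Dict[str, Any]
    response: Dict[str, Any]
    error: Optional[str]


def display(bundle: Dict[str, Any], raw: bool = True) -> None:
    # Stand-in for the notebook display call used in the diff.
    print(json.dumps(bundle[APPLICATION_JSON], indent=2))


def trace_tool_call(url: str, payload: Dict[str, Any], function_name: str) -> None:
    # Make sure function_name is in the payload so the front end can label the call.
    request = (
        payload
        if "function_name" in payload
        else {**payload, "function_name": function_name}
    )
    trace = ToolCallTrace(
        endpoint_url=url,
        type="tool_call",  # previously added via trace["type"] = "tool_call" after model_dump()
        request=request,
        response={},
        error=None,
    )
    display({APPLICATION_JSON: trace.model_dump()}, raw=True)


trace_tool_call(
    "https://example.com/tools/owl_v2",  # hypothetical endpoint
    {"prompt": "coin"},
    "owl_v2_image",
)
```

Setting the field on the model keeps the trace schema explicit and validated, rather than relying on a post-hoc dict mutation just before display.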