feat: countgd sam2 video support (#318)
* feat: countgd sam2 video
* added tests
* handle empty chunk length
* fixed mypy issues
* updated docstring
* allow multiple od tools
* added owlv2
* fixed import
* fixed probability
* fixed return example
* added florence2 support
hrnn authored Dec 13, 2024
1 parent 63eab86 commit 421dd1e
Showing 3 changed files with 246 additions and 2 deletions.
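For context, here is a minimal usage sketch of the three entry points this commit adds, mirroring the integration tests below. It is illustrative only: the zero-filled placeholder clip is an assumption, and any list of RGB frames as numpy arrays works.

import numpy as np

from vision_agent.tools import (
    countgd_sam2_video_tracking,
    od_sam2_video_tracking,
    owlv2_sam2_video_tracking,
)

# Placeholder clip: ten black 480x640 RGB frames (use real decoded video).
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(10)]

# CountGD detections every 10th frame; SAM2 tracks the frames in between.
tracks = countgd_sam2_video_tracking("coin", frames, chunk_length=10)

# The same pipeline with OWLv2 as the detector, which also accepts a
# fine_tune_id for fine-tuned models.
tracks = owlv2_sam2_video_tracking("coin", frames, chunk_length=10)

# Or select the detector explicitly through the shared entry point.
tracks = od_sam2_video_tracking(od_model="florence2", prompt="coin", frames=frames)

# One list of detections per input frame; labels are "<tracking id>: <name>".
for detection in tracks[0]:
    print(detection["label"], detection["bbox"], detection["score"])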
51 changes: 50 additions & 1 deletion tests/integ/test_tools.py
@@ -4,9 +4,10 @@

from vision_agent.tools import (
closest_mask_distance,
countgd_example_based_counting,
countgd_object_detection,
countgd_sam2_object_detection,
countgd_sam2_video_tracking,
depth_anything_v2,
detr_segmentation,
florence2_ocr,
@@ -17,8 +18,10 @@
flux_image_inpainting,
generate_pose_image,
ocr,
od_sam2_video_tracking,
owl_v2_image,
owl_v2_video,
owlv2_sam2_video_tracking,
qwen2_vl_images_vqa,
qwen2_vl_video_vqa,
siglip_classification,
@@ -471,3 +474,49 @@ def test_flux_image_inpainting_resizing_big_image():

assert result.shape[0] == 512
assert result.shape[1] == 208


def test_video_tracking_with_countgd():

frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = countgd_sam2_video_tracking(
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert len([res["label"] for res in result[0]]) == 24
assert len([res["mask"] for res in result[0]]) == 24


def test_video_tracking_with_owlv2():

frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = owlv2_sam2_video_tracking(
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert len([res["label"] for res in result[0]]) == 24
assert len([res["mask"] for res in result[0]]) == 24


def test_video_tracking_by_given_model():

frames = [
np.array(Image.fromarray(ski.data.coins()).convert("RGB")) for _ in range(10)
]
result = od_sam2_video_tracking(
od_model="florence2",
prompt="coin",
frames=frames,
)

assert len(result) == 10
assert len([res["label"] for res in result[0]]) == 24
assert len([res["mask"] for res in result[0]]) == 24
5 changes: 4 additions & 1 deletion vision_agent/tools/__init__.py
@@ -26,9 +26,10 @@
claude35_text_extraction,
closest_box_distance,
closest_mask_distance,
countgd_example_based_counting,
countgd_object_detection,
countgd_sam2_object_detection,
countgd_sam2_video_tracking,
depth_anything_v2,
detr_segmentation,
extract_frames_and_timestamps,
@@ -46,11 +47,13 @@
load_image,
minimum_distance,
ocr,
od_sam2_video_tracking,
overlay_bounding_boxes,
overlay_heat_map,
overlay_segmentation_masks,
owl_v2_image,
owl_v2_video,
owlv2_sam2_video_tracking,
qwen2_vl_images_vqa,
qwen2_vl_video_vqa,
sam2,
192 changes: 192 additions & 0 deletions vision_agent/tools/tools.py
@@ -6,6 +6,7 @@
import urllib.request
from base64 import b64encode
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import lru_cache
from importlib import resources
from pathlib import Path
@@ -2394,6 +2395,197 @@ def _plot_counting(
return image


class ODModels(str, Enum):
COUNTGD = "countgd"
FLORENCE2 = "florence2"
OWLV2 = "owlv2"


def od_sam2_video_tracking(
od_model: ODModels,
prompt: str,
frames: List[np.ndarray],
chunk_length: Optional[int] = 10,
fine_tune_id: Optional[str] = None,
) -> List[List[Dict[str, Any]]]:
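    """Runs the object detector selected by 'od_model' every 'chunk_length'
    frames and tracks the detected objects across all of 'frames' via the SAM2
    video task. Shared backend for the *_sam2_video_tracking tools defined
    below, which also document the return format."""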

results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames)

if chunk_length is None:
step = 1 # Process every frame
elif chunk_length <= 0:
raise ValueError("chunk_length must be a positive integer or None.")
else:
step = chunk_length # Process frames with the specified step size

for idx in range(0, len(frames), step):
if od_model == ODModels.COUNTGD:
results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx])
function_name = "countgd_object_detection"
elif od_model == ODModels.OWLV2:
results[idx] = owl_v2_image(
prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
)
function_name = "owl_v2_image"
elif od_model == ODModels.FLORENCE2:
results[idx] = florence2_sam2_image(
prompt=prompt, image=frames[idx], fine_tune_id=fine_tune_id
)
function_name = "florence2_sam2_image"
else:
raise NotImplementedError(
f"Object detection model '{od_model}' is not implemented."
)

image_size = frames[0].shape[:2]

def _transform_detections(
input_list: List[Optional[List[Dict[str, Any]]]]
) -> List[Optional[Dict[str, Any]]]:
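        """Converts per-frame detections into the payload format expected by
        the SAM2 video task: one entry per frame, either None or a dict of
        labels and denormalized (absolute pixel) bounding boxes."""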
output_list: List[Optional[Dict[str, Any]]] = []

for idx, frame in enumerate(input_list):
if frame is not None:
labels = [detection["label"] for detection in frame]
bboxes = [
denormalize_bbox(detection["bbox"], image_size)
for detection in frame
]

output_list.append(
{
"labels": labels,
"bboxes": bboxes,
}
)
else:
output_list.append(None)

return output_list

output = _transform_detections(results)

buffer_bytes = frames_to_bytes(frames)
files = [("video", buffer_bytes)]
payload = {"bboxes": json.dumps(output), "chunk_length": chunk_length}
metadata = {"function_name": function_name}

detections = send_task_inference_request(
payload,
"sam2",
files=files,
metadata=metadata,
)

    # Decode the RLE-encoded masks and prefix each label with its tracking id;
    # the score is fixed at 1.0 for tracked objects.
    return_data = []
for frame in detections:
return_frame_data = []
for detection in frame:
mask = rle_decode_array(detection["mask"])
label = str(detection["id"]) + ": " + detection["label"]
return_frame_data.append({"label": label, "mask": mask, "score": 1.0})
return_data.append(return_frame_data)
return_data = add_bboxes_from_masks(return_data)
return nms(return_data, iou_threshold=0.95)


def countgd_sam2_video_tracking(
prompt: str,
frames: List[np.ndarray],
chunk_length: Optional[int] = 10,
) -> List[List[Dict[str, Any]]]:
"""'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text
prompt such as category names or referring expressions. The categories in the text
prompt are separated by commas. It returns a list of bounding boxes, label names,
mask file names and associated probability scores.
Parameters:
prompt (str): The prompt to ground to the image.
image (np.ndarray): The image to ground the prompt to.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
bounding box, and mask of the detected objects with normalized coordinates
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
and xmax and ymax are the coordinates of the bottom-right of the bounding box.
The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
the background.
Example
-------
>>> countgd_sam2_video_tracking("car, dinosaur", frames)
[
[
{
'label': '0: dinosaur',
'bbox': [0.1, 0.11, 0.35, 0.4],
'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
},
],
...
]
"""

return od_sam2_video_tracking(
ODModels.COUNTGD, prompt=prompt, frames=frames, chunk_length=chunk_length
)


def owlv2_sam2_video_tracking(
prompt: str,
frames: List[np.ndarray],
chunk_length: Optional[int] = 10,
fine_tune_id: Optional[str] = None,
) -> List[List[Dict[str, Any]]]:
"""'owlv2_sam2_video_tracking' is a tool that can segment multiple objects given a text
prompt such as category names or referring expressions. The categories in the text
prompt are separated by commas. It returns a list of bounding boxes, label names,
mask file names and associated probability scores.
Parameters:
prompt (str): The prompt to ground to the image.
image (np.ndarray): The image to ground the prompt to.
Returns:
List[Dict[str, Any]]: A list of dictionaries containing the score, label,
bounding box, and mask of the detected objects with normalized coordinates
(xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
and xmax and ymax are the coordinates of the bottom-right of the bounding box.
The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
the background.
Example
-------
>>> countgd_sam2_video_tracking("car, dinosaur", frames)
[
[
{
'label': '0: dinosaur',
'bbox': [0.1, 0.11, 0.35, 0.4],
'mask': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
},
],
...
]
"""

return od_sam2_video_tracking(
ODModels.OWLV2,
prompt=prompt,
frames=frames,
chunk_length=chunk_length,
fine_tune_id=fine_tune_id,
)


FUNCTION_TOOLS = [
owl_v2_image,
owl_v2_video,
