allow multiple od tools
hrnn committed Dec 5, 2024
1 parent 984eaa3 commit 87a787c
Showing 1 changed file with 59 additions and 33 deletions.
92 changes: 59 additions & 33 deletions vision_agent/tools/tools.py
@@ -5,6 +5,7 @@
import tempfile
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from functools import lru_cache
from importlib import resources
from pathlib import Path
@@ -2625,44 +2626,25 @@ def _plot_counting(
    return image


def countgd_sam2_video_tracking(
class ODModels(str, Enum):
    COUNTGD = "countgd"


def od_sam2_video_tracking(
    od_model: ODModels,
    prompt: str,
    frames: List[np.ndarray],
    chunk_length: Optional[int] = 10,
    fine_tune_id: Optional[str] = None,
) -> List[List[Dict[str, Any]]]:
    """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text
    prompt such as category names or referring expressions. The categories in the text
    prompt are separated by commas. It returns a list of bounding boxes, label names,
    mask file names and associated probability scores of 1.0.
    Parameters:
        prompt (str): The prompt to ground to the image.
        image (np.ndarray): The image to ground the prompt to.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries containing the score, label,
        bounding box, and mask of the detected objects with normalized coordinates
        (xmin, ymin, xmax, ymax). xmin and ymin are the coordinates of the top-left
        and xmax and ymax are the coordinates of the bottom-right of the bounding box.
        The mask is binary 2D numpy array where 1 indicates the object and 0 indicates
        the background.
    if od_model == ODModels.COUNTGD:
        detection_function = countgd_object_detection
    else:
        raise NotImplementedError(
            f"Object detection model '{od_model.value}' is not implemented."
        )

    Example
    -------
    >>> countgd_sam2_video_tracking("car, dinosaur", image)
    [
        {
            'score': 1.0,
            'label': 'dinosaur',
            'bbox': [0.1, 0.11, 0.35, 0.4],
            'mask': array([[0, 0, 0, ..., 0, 0, 0],
                [0, 0, 0, ..., 0, 0, 0],
                ...,
                [0, 0, 0, ..., 0, 0, 0],
                [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
        },
    ]
    """
    results: List[Optional[List[Dict[str, Any]]]] = [None] * len(frames)

    if chunk_length is None:
@@ -2673,7 +2655,7 @@ def countgd_sam2_video_tracking(
        step = chunk_length  # Process frames with the specified step size

    for idx in range(0, len(frames), step):
        results[idx] = countgd_object_detection(prompt=prompt, image=frames[idx])
        results[idx] = detection_function(prompt=prompt, image=frames[idx])

    image_size = frames[0].shape[:2]
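
A side note on the dispatch shown above: od_sam2_video_tracking selects the per-frame detector from the ODModels enum, so supporting another detector should only require a new enum member and a new branch. The following sketch (not part of the commit) illustrates that extension, with stub implementations standing in for the module's real countgd_object_detection and a hypothetical florence2_object_detection helper:

from enum import Enum
from typing import Any, Callable, Dict, List

import numpy as np


class ODModels(str, Enum):
    COUNTGD = "countgd"
    FLORENCE2 = "florence2"  # hypothetical second backend, not added by this commit


# Stand-ins for the real per-frame detectors; both share the (prompt, image) interface.
def countgd_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
    return [{"score": 1.0, "label": prompt, "bbox": [0.1, 0.1, 0.2, 0.2]}]


def florence2_object_detection(prompt: str, image: np.ndarray) -> List[Dict[str, Any]]:
    return [{"score": 1.0, "label": prompt, "bbox": [0.3, 0.3, 0.4, 0.4]}]


def select_detector(od_model: ODModels) -> Callable[..., List[Dict[str, Any]]]:
    # Mirrors the if/elif/raise dispatch used inside od_sam2_video_tracking.
    if od_model == ODModels.COUNTGD:
        return countgd_object_detection
    elif od_model == ODModels.FLORENCE2:
        return florence2_object_detection
    raise NotImplementedError(
        f"Object detection model '{od_model.value}' is not implemented."
    )


detector = select_detector(ODModels.FLORENCE2)
print(detector(prompt="car, dinosaur", image=np.zeros((480, 640, 3), dtype=np.uint8)))

Because ODModels subclasses str, its members also compare equal to their plain string values (ODModels.COUNTGD == "countgd"), which keeps the enum convenient for JSON or config round-trips.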

@@ -2727,6 +2709,50 @@ def _transform_detections(
    return nms(return_data, iou_threshold=0.95)


def countgd_sam2_video_tracking(
    prompt: str,
    frames: List[np.ndarray],
    chunk_length: Optional[int] = 10,
) -> List[List[Dict[str, Any]]]:
    """'countgd_sam2_video_tracking' is a tool that can segment multiple objects given a text
    prompt such as category names or referring expressions. The categories in the text
    prompt are separated by commas. It returns a list of bounding boxes, label names,
    mask file names and associated probability scores of 1.0.

    Parameters:
        prompt (str): The prompt to ground to the frames.
        frames (List[np.ndarray]): The list of frames to ground the prompt to.
        chunk_length (Optional[int]): The number of frames between each run of the
            object detector; detections are tracked across the frames in between.

    Returns:
        List[List[Dict[str, Any]]]: A list of detections per frame, where each detection
        is a dictionary containing the score, label, bounding box, and mask of the
        detected object with normalized coordinates (xmin, ymin, xmax, ymax). xmin and
        ymin are the coordinates of the top-left and xmax and ymax are the coordinates
        of the bottom-right of the bounding box. The mask is a binary 2D numpy array
        where 1 indicates the object and 0 indicates the background.

    Example
    -------
    >>> countgd_sam2_video_tracking("car, dinosaur", frames)
    [
        {
            'score': 1.0,
            'label': 'dinosaur',
            'bbox': [0.1, 0.11, 0.35, 0.4],
            'mask': array([[0, 0, 0, ..., 0, 0, 0],
                [0, 0, 0, ..., 0, 0, 0],
                ...,
                [0, 0, 0, ..., 0, 0, 0],
                [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),
        },
    ]
    """

    return od_sam2_video_tracking(
        ODModels.COUNTGD, prompt=prompt, frames=frames, chunk_length=chunk_length
    )


FUNCTION_TOOLS = [
    owl_v2_image,
    owl_v2_video,

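For completeness, a minimal usage sketch of the new countgd_sam2_video_tracking wrapper. It assumes vision_agent at this commit is importable and uses OpenCV to decode a local clip into RGB frames; the clip name and the frame-loading helper are illustrative, not part of the commit:

from typing import List

import cv2
import numpy as np

from vision_agent.tools.tools import countgd_sam2_video_tracking


def load_frames(path: str, max_frames: int = 50) -> List[np.ndarray]:
    # Decode up to max_frames frames and convert them from BGR (OpenCV) to RGB.
    cap = cv2.VideoCapture(path)
    frames: List[np.ndarray] = []
    while len(frames) < max_frames:
        ok, bgr = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    cap.release()
    return frames


frames = load_frames("street.mp4")  # hypothetical input clip
# chunk_length=10: the detector is re-run on every 10th frame; masks are tracked in between.
tracks = countgd_sam2_video_tracking("car, dinosaur", frames, chunk_length=10)
for idx, detections in enumerate(tracks):
    print(idx, [d["label"] for d in detections])

A smaller chunk_length trades speed for more frequent re-detection, since od_sam2_video_tracking only calls the underlying detector on every chunk_length-th frame.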