Skip to content

Commit

Permalink
Hilo threshold (#311)
Browse files Browse the repository at this point in the history
* planner and coder to handle interaction type

* add interaction tool choice function

* vision agent to handle interaction types

* pass hilo parameter down to planner

* fix minor issues with interaction in planner

* redisplay tool calls so outside world can see them

* format fix

* fix planner for interaction

* fix vision agent for interactions

* better search for function call

* added box_threshold replace

* add box threshold set to get_tool_for_task

* update example code to handle hilo with threshold

* remove merge message

* lower default threshold

* type ignore
  • Loading branch information
dillonalaird authored Dec 2, 2024
1 parent 720a7f4 commit 39f893f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 44 deletions.
41 changes: 29 additions & 12 deletions examples/chat/chat-app/src/components/ChatSection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ const formatAssistantContent = (role: string, content: string) => {
const pythonMatch = content.match(/<execute_python>(.*?)<\/execute_python>/s);
const finalPlanJsonMatch = content.match(/<json>(.*?)<\/json>/s);
const interactionMatch = content.match(/<interaction>(.*?)<\/interaction>/s);
const interactionJson = JSON.parse(interactionMatch ? interactionMatch[1] : "{}");
const interactionJson = JSON.parse(
interactionMatch ? interactionMatch[1] : "{}",
);

const finalPlanJson = JSON.parse(
finalPlanJsonMatch ? finalPlanJsonMatch[1] : "{}",
Expand All @@ -88,7 +90,8 @@ const formatAssistantContent = (role: string, content: string) => {
return (
<>
<div>
<strong className="text-gray-700">[{role.toUpperCase()}]</strong> {finalPlanJson.plan}
<strong className="text-gray-700">[{role.toUpperCase()}]</strong>{" "}
{finalPlanJson.plan}
</div>
<pre className="bg-gray-800 text-white p-1.5 rounded mt-2 overflow-x-auto text-xs">
<code style={{ whiteSpace: "pre-wrap" }}>
Expand All @@ -105,30 +108,36 @@ const formatAssistantContent = (role: string, content: string) => {
return (
<>
<div>
<strong className="text-gray-700">[{role.toUpperCase()}]</strong> Function calls:
<strong className="text-gray-700">[{role.toUpperCase()}]</strong>{" "}
Function calls:
</div>
<pre className="bg-gray-800 text-white p-1.5 rounded mt-2 overflow-x-auto text-xs">
<code style={{ whiteSpace: "pre-wrap" }}>
{interactionJson.map((interaction: { request: { function_name: string } }) =>
`- ${interaction.request.function_name}`
).join('\n')}
{interactionJson
.map(
(interaction: { request: { function_name: string } }) =>
`- ${interaction.request.function_name}`,
)
.join("\n")}
</code>
</pre>
</>
)
);
}

if (responseMatch || thinkingMatch || pythonMatch) {
return (
<>
{thinkingMatch && (
<div>
<strong className="text-gray-700">[{role.toUpperCase()}]</strong> {thinkingMatch[1]}
<strong className="text-gray-700">[{role.toUpperCase()}]</strong>{" "}
{thinkingMatch[1]}
</div>
)}
{responseMatch && (
<div>
<strong className="text-gray-700">[{role.toUpperCase()}]</strong> {responseMatch[1]}
<strong className="text-gray-700">[{role.toUpperCase()}]</strong>{" "}
{responseMatch[1]}
</div>
)}
{pythonMatch && (
Expand All @@ -146,7 +155,7 @@ export function MessageBubble({ message }: MessageBubbleProps) {
return (
<div
className={`mb-4 ${
(message.role === "user" || message.role === "interaction_response")
message.role === "user" || message.role === "interaction_response"
? "ml-auto bg-primary text-primary-foreground"
: message.role === "assistant"
? "mr-auto bg-muted"
Expand Down Expand Up @@ -185,10 +194,18 @@ export function ChatSection({

if (input.value.trim()) {
let userMessage: Message;
if (messages.length > 0 && messages[messages.length - 1].role === "interaction") {
if (
messages.length > 0 &&
messages[messages.length - 1].role === "interaction"
) {
const function_name = input.value.split(",")[0].trim();
const box_threshold = input.value.split(",")[1].trim();
userMessage = {
role: "interaction_response",
content: JSON.stringify({function_name: input.value})
content: JSON.stringify({
function_name: function_name,
box_threshold: box_threshold,
}),
} as Message;
} else {
userMessage = { role: "user", content: input.value } as Message;
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/test_planner_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from vision_agent.tools.planner_tools import check_function_call, replace_box_threshold


def test_check_function_call():
    """A plain top-level call is detected; a name that is never called is not."""
    code = """
test_function('one', image1)
"""
    # Identity checks keep the comparison strict (a truthy non-bool would
    # fail) while avoiding the `== True` / `== False` anti-pattern (E712).
    assert check_function_call(code, "test_function") is True
    assert check_function_call(code, "test_function2") is False


def test_check_function_call_try_catch():
    """A call nested inside a try/except body is still detected."""
    # The snippet must be valid Python on its own, so the bodies of the
    # try/except clauses are indented inside the string literal.
    code = """
try:
    test_function('one', image1)
except Exception as e:
    pass
"""
    # Identity checks keep the comparison strict (a truthy non-bool would
    # fail) while avoiding the `== True` / `== False` anti-pattern (E712).
    assert check_function_call(code, "test_function") is True
    assert check_function_call(code, "test_function2") is False


def test_replace_box_threshold():
    """An existing box_threshold keyword argument gets its value overwritten."""
    source = """
test_function('one', image1, box_threshold=0.1)
"""
    wanted = """
test_function('one', image1, box_threshold=0.5)
"""
    rewritten = replace_box_threshold(source, ["test_function"], 0.5)
    assert rewritten == wanted


def test_replace_box_threshold_in_function():
    """A matching call inside a function body is rewritten as well."""
    # Both snippets must be valid Python on their own, so the body of
    # test_function_outer is indented inside the string literals.
    code = """
def test_function_outer():
    test_function('one', image1, box_threshold=0.1)
"""
    expected_code = """
def test_function_outer():
    test_function('one', image1, box_threshold=0.5)
"""
    assert replace_box_threshold(code, ["test_function"], 0.5) == expected_code


def test_replace_box_threshold_no_arg():
    """A call without box_threshold gets the keyword argument appended."""
    source = """
test_function('one', image1)
"""
    wanted = """
test_function('one', image1, box_threshold=0.5)
"""
    rewritten = replace_box_threshold(source, ["test_function"], 0.5)
    assert rewritten == wanted


def test_replace_box_threshold_no_func():
    """Calls to functions that are not listed are left untouched."""
    source = """
test_function2('one', image1)
"""
    wanted = """
test_function2('one', image1)
"""
    rewritten = replace_box_threshold(source, ["test_function"], 0.5)
    assert rewritten == wanted

4 changes: 3 additions & 1 deletion vision_agent/agent/vision_agent_planner_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,10 @@ def replace_interaction_with_obs(chat: List[AgentMessage]) -> List[AgentMessage]
response = json.loads(chat[i + 1].content)
function_name = response["function_name"]
tool_doc = get_tool_documentation(function_name)
if "box_threshold" in response:
tool_doc = f"Use the following function with box_threshold={response['box_threshold']}\n\n{tool_doc}"
new_chat.append(AgentMessage(role="observation", content=tool_doc))
except json.JSONDecodeError:
except (json.JSONDecodeError, KeyError):
raise ValueError(f"Invalid JSON in interaction response: {chat_i}")
else:
new_chat.append(chat_i)
Expand Down
90 changes: 59 additions & 31 deletions vision_agent/tools/planner_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
import shutil
import tempfile
Expand Down Expand Up @@ -63,12 +64,55 @@ def extract_tool_info(
return tool, tool_thoughts, tool_docstring, ""


def replace_box_threshold(code: str, functions: List[str], box_threshold: float) -> str:
    """Force ``box_threshold=<value>`` on every call to the listed functions.

    The source is parsed with libcst and every call expression is inspected. A
    call matches when its callee is a bare name (``f(...)``) or an attribute
    access (``mod.f(...)``) whose final name is in ``functions``. If the call
    already passes a ``box_threshold`` keyword argument, its value is replaced;
    otherwise ``box_threshold=<value>`` is appended as a new keyword argument.
    Calls to other functions are returned unchanged.

    NOTE(review): a ``box_threshold`` passed positionally is not detected, so
    such a call would end up with the keyword appended as well.

    Parameters:
        code: str: The Python source code to rewrite.
        functions: List[str]: Names of the functions whose calls are rewritten.
        box_threshold: float: The threshold value written into each call.

    Returns:
        str: The rewritten source code.
    """
    target_names = frozenset(functions)
    # cst.Float only accepts float literals. repr(float(x)) equals str(x) for
    # real floats, and also turns an int such as 1 into "1.0" instead of the
    # invalid float literal "1".
    threshold_literal = repr(float(box_threshold))

    class ReplaceBoxThresholdTransformer(cst.CSTTransformer):
        def leave_Call(
            self, original_node: cst.Call, updated_node: cst.Call
        ) -> cst.Call:
            # Resolve the called name for `f(...)` and `obj.f(...)`; any other
            # callee shape (subscript, nested call, ...) is left alone.
            callee = updated_node.func
            if isinstance(callee, cst.Name):
                called_name = callee.value
            elif isinstance(callee, cst.Attribute):
                called_name = callee.attr.value
            else:
                return updated_node
            if called_name not in target_names:
                return updated_node

            new_args = []
            found = False
            for arg in updated_node.args:
                if arg.keyword is not None and arg.keyword.value == "box_threshold":
                    arg = arg.with_changes(value=cst.Float(threshold_literal))
                    found = True
                new_args.append(arg)

            if not found:
                new_args.append(
                    cst.Arg(
                        keyword=cst.Name("box_threshold"),
                        value=cst.Float(threshold_literal),
                        # Render as `box_threshold=0.5`, with no spaces
                        # around the equals sign.
                        equal=cst.AssignEqual(
                            whitespace_before=cst.SimpleWhitespace(""),
                            whitespace_after=cst.SimpleWhitespace(""),
                        ),
                    )
                )
            return updated_node.with_changes(args=new_args)

    tree = cst.parse_module(code)
    new_tree = tree.visit(ReplaceBoxThresholdTransformer())
    return new_tree.code


def run_tool_testing(
task: str,
image_paths: List[str],
lmm: LMM,
exclude_tools: Optional[List[str]],
code_interpreter: CodeInterpreter,
process_code: Callable[[str], str] = lambda x: x,
) -> tuple[str, str, Execution]:
"""Helper function to generate and run tool testing code."""
query = lmm.generate(CATEGORIZE_TOOL_REQUEST.format(task=task))
Expand Down Expand Up @@ -101,6 +145,7 @@ def run_tool_testing(
code = extract_tag(response, "code") # type: ignore
if code is None:
raise ValueError(f"Could not extract code from response: {response}")
code = process_code(code)
tool_output = code_interpreter.exec_isolation(DefaultImports.prepend_imports(code))
tool_output_str = tool_output.text(include_results=False).strip()

Expand All @@ -119,6 +164,7 @@ def run_tool_testing(
media=str(image_paths),
)
code = extract_code(lmm.generate(prompt, media=image_paths)) # type: ignore
code = process_code(code)
tool_output = code_interpreter.exec_isolation(
DefaultImports.prepend_imports(code)
)
Expand Down Expand Up @@ -221,36 +267,7 @@ def get_tool_documentation(tool_name: str) -> str:
def get_tool_for_task_human_reviewer(
task: str, images: List[np.ndarray], exclude_tools: Optional[List[str]] = None
) -> None:
# NOTE: this should be the same documentation as get_tool_for_task
"""Given a task and one or more images this function will find a tool to accomplish
the jobs. It prints the tool documentation and thoughts on why it chose the tool.
It can produce tools for the following types of tasks:
- Object detection and counting
- Classification
- Segmentation
- OCR
- VQA
- Depth and pose estimation
- Video object tracking
Wait until the documentation is printed to use the function so you know what the
input and output signatures are.
Parameters:
task: str: The task to accomplish.
images: List[np.ndarray]: The images to use for the task.
exclude_tools: Optional[List[str]]: A list of tool names to exclude from the
recommendations. This is helpful if you are calling get_tool_for_task twice
and do not want the same tool recommended.
Returns:
The tool to use for the task is printed to stdout
Examples
--------
>>> get_tool_for_task("Give me an OCR model that can find 'hot chocolate' in the image", [image])
"""
# NOTE: this will have the same documentation as get_tool_for_task
lmm = AnthropicLMM()

with (
Expand All @@ -263,8 +280,19 @@ def get_tool_for_task_human_reviewer(
Image.fromarray(image).save(image_path)
image_paths.append(image_path)

tools = [
t.__name__
for t in T.TOOLS
if inspect.signature(t).parameters.get("box_threshold") # type: ignore
]

_, _, tool_output = run_tool_testing(
task, image_paths, lmm, exclude_tools, code_interpreter
task,
image_paths,
lmm,
exclude_tools,
code_interpreter,
process_code=lambda x: replace_box_threshold(x, tools, 0.05),
)

# need to re-display results for the outer notebook to see them
Expand Down

0 comments on commit 39f893f

Please sign in to comment.