Skip to content

Commit

Permalink
feat: add multimodal support for ChatMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
LastRemote committed Dec 4, 2024
1 parent e1aebb8 commit 6031934
Show file tree
Hide file tree
Showing 10 changed files with 605 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

import json
import logging
from base64 import b64encode
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

from haystack import component, default_from_dict
from haystack.dataclasses import StreamingChunk
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace

from haystack_experimental.dataclasses import ChatMessage, ToolCall
from haystack_experimental.dataclasses import ChatMessage, ToolCall, ByteStream
from haystack_experimental.dataclasses.chat_message import ChatRole, ToolCallResult
from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace

Expand All @@ -38,7 +39,9 @@
# - AnthropicChatGenerator fails with ImportError at init (due to anthropic_integration_import.check()).

if anthropic_integration_import.is_successful():
chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = AnthropicChatGeneratorBase
chatgenerator_base_class: Type[AnthropicChatGeneratorBase] = (
AnthropicChatGeneratorBase
)
else:
chatgenerator_base_class: Type[object] = object # type: ignore[no-redef]

Expand All @@ -57,7 +60,9 @@ def _update_anthropic_message_with_tool_call_results(

for tool_call_result in tool_call_results:
if tool_call_result.origin.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
raise ValueError(
"`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
)
anthropic_msg["content"].append(
{
"type": "tool_result",
Expand All @@ -68,7 +73,9 @@ def _update_anthropic_message_with_tool_call_results(
)


def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[Dict[str, Any]]:
def _convert_tool_calls_to_anthropic_format(
tool_calls: List[ToolCall],
) -> List[Dict[str, Any]]:
"""
Convert a list of tool calls to the format expected by Anthropic Chat API.
Expand All @@ -78,7 +85,9 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[
anthropic_tool_calls = []
for tc in tool_calls:
if tc.id is None:
raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with Anthropic.")
raise ValueError(
"`ToolCall` must have a non-null `id` attribute to be used with Anthropic."
)
anthropic_tool_calls.append(
{
"type": "tool_use",
Expand All @@ -90,6 +99,44 @@ def _convert_tool_calls_to_anthropic_format(tool_calls: List[ToolCall]) -> List[
return anthropic_tool_calls


def _convert_media_to_anthropic_format(media: List[ByteStream]) -> List[Dict[str, Any]]:
    """
    Build Anthropic content blocks from a list of media items.

    Images become ``image`` blocks and PDF files become ``document`` blocks.
    Both carry the payload as a base64-encoded string next to the item's
    MIME type.

    :param media: The list of ByteStreams to convert.
    :return: A list of dictionaries in the format expected by Anthropic API,
        in the same order as ``media``.
    :raises ValueError: If an item is neither an image nor a PDF.
    """
    content_blocks: List[Dict[str, Any]] = []
    for stream in media:
        # Images and PDFs share the same payload layout; only the block tag
        # differs, so pick the tag first and build the payload once.
        if stream.type == "image":
            block_type = "image"
        elif stream.type == "application" and stream.subtype == "pdf":
            block_type = "document"
        else:
            raise ValueError(
                f"Unsupported media type '{stream.mime_type}' for Anthropic completions."
            )

        source = {
            "type": "base64",
            "media_type": stream.mime_type,
            "data": b64encode(stream.data).decode("utf-8"),
        }
        content_blocks.append({"type": block_type, "source": source})
    return content_blocks


def _convert_messages_to_anthropic_format(
messages: List[ChatMessage],
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
Expand Down Expand Up @@ -119,10 +166,17 @@ def _convert_messages_to_anthropic_format(

anthropic_msg: Dict[str, Any] = {"role": message._role.value, "content": []}

if message.texts and message.texts[0]:
anthropic_msg["content"].append({"type": "text", "text": message.texts[0]})
if message.texts:
for item in message.texts:
anthropic_msg["content"].append({"type": "text", "text": item})
if message.media:
anthropic_msg["content"] += _convert_media_to_anthropic_format(
message.media
)
if message.tool_calls:
anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(message.tool_calls)
anthropic_msg["content"] += _convert_tool_calls_to_anthropic_format(
message.tool_calls
)

if message.tool_call_results:
results = message.tool_call_results.copy()
Expand All @@ -136,7 +190,8 @@ def _convert_messages_to_anthropic_format(

if not anthropic_msg["content"]:
raise ValueError(
"A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
"A `ChatMessage` must contain at least one `TextContent`, `MediaContent`, "
"`ToolCall`, or `ToolCallResult`."
)

anthropic_non_system_messages.append(anthropic_msg)
Expand Down Expand Up @@ -250,7 +305,9 @@ def to_dict(self) -> Dict[str, Any]:
The serialized component as a dictionary.
"""
serialized = super(AnthropicChatGenerator, self).to_dict()
serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None
serialized["init_parameters"]["tools"] = (
[tool.to_dict() for tool in self.tools] if self.tools else None
)
return serialized

@classmethod
Expand All @@ -267,11 +324,15 @@ def from_dict(cls, data: Dict[str, Any]) -> "AnthropicChatGenerator":
init_params = data.get("init_parameters", {})
serialized_callback_handler = init_params.get("streaming_callback")
if serialized_callback_handler:
data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
data["init_parameters"]["streaming_callback"] = deserialize_callable(
serialized_callback_handler
)

return default_from_dict(cls, data)

def _convert_chat_completion_to_chat_message(self, anthropic_response: Any) -> ChatMessage:
def _convert_chat_completion_to_chat_message(
self, anthropic_response: Any
) -> ChatMessage:
"""
Converts the response from the Anthropic API to a ChatMessage.
"""
Expand Down Expand Up @@ -343,15 +404,22 @@ def _convert_streaming_chunks_to_chat_message(
full_content += delta.get("text", "")
elif delta.get("type") == "input_json_delta" and current_tool_call:
current_tool_call["arguments"] += delta.get("partial_json", "")
elif chunk_type == "message_delta": # noqa: SIM102 (prefer nested if statement here for readability)
if chunk.meta.get("delta", {}).get("stop_reason") == "tool_use" and current_tool_call:
elif (
chunk_type == "message_delta"
): # noqa: SIM102 (prefer nested if statement here for readability)
if (
chunk.meta.get("delta", {}).get("stop_reason") == "tool_use"
and current_tool_call
):
try:
# arguments is a string, convert to json
tool_calls.append(
ToolCall(
id=current_tool_call.get("id"),
tool_name=str(current_tool_call.get("name")),
arguments=json.loads(current_tool_call.get("arguments", {})),
arguments=json.loads(
current_tool_call.get("arguments", {})
),
)
)
except json.JSONDecodeError:
Expand All @@ -370,7 +438,9 @@ def _convert_streaming_chunks_to_chat_message(
{
"model": model,
"index": 0,
"finish_reason": last_chunk_meta.get("delta", {}).get("stop_reason", None),
"finish_reason": last_chunk_meta.get("delta", {}).get(
"stop_reason", None
),
"usage": last_chunk_meta.get("usage", {}),
}
)
Expand Down Expand Up @@ -405,12 +475,16 @@ def run(
disallowed_params,
self.ALLOWED_PARAMS,
)
generation_kwargs = {k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS}
generation_kwargs = {
k: v for k, v in generation_kwargs.items() if k in self.ALLOWED_PARAMS
}
tools = tools or self.tools
if tools:
_check_duplicate_tool_names(tools)

system_messages, non_system_messages = _convert_messages_to_anthropic_format(messages)
system_messages, non_system_messages = _convert_messages_to_anthropic_format(
messages
)
anthropic_tools = (
[
{
Expand Down Expand Up @@ -447,12 +521,16 @@ def run(
"content_block_delta",
"message_delta",
]:
streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(chunk)
streaming_chunk = self._convert_anthropic_chunk_to_streaming_chunk(
chunk
)
chunks.append(streaming_chunk)
if streaming_callback:
streaming_callback(streaming_chunk)

completion = self._convert_streaming_chunks_to_chat_message(chunks, model)
return {"replies": [completion]}
else:
return {"replies": [self._convert_chat_completion_to_chat_message(response)]}
return {
"replies": [self._convert_chat_completion_to_chat_message(response)]
}
118 changes: 77 additions & 41 deletions haystack_experimental/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import json
import os
from base64 import b64encode
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
Expand All @@ -19,7 +20,14 @@
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice

from haystack_experimental.dataclasses import ChatMessage, Tool, ToolCall
from haystack_experimental.dataclasses import (
ChatMessage,
Tool,
ToolCall,
TextContent,
ChatRole,
MediaContent,
)
from haystack_experimental.dataclasses.streaming_chunk import (
AsyncStreamingCallbackT,
StreamingCallbackT,
Expand All @@ -34,53 +42,81 @@ def _convert_message_to_openai_format(message: ChatMessage) -> Dict[str, Any]:
"""
Convert a message to the format expected by OpenAI's Chat API.
"""
text_contents = message.texts
tool_calls = message.tool_calls
tool_call_results = message.tool_call_results

if not text_contents and not tool_calls and not tool_call_results:
raise ValueError(
"A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
)
elif len(text_contents) + len(tool_call_results) > 1:
openai_msg: Dict[str, Any] = {"role": message.role.value}
if len(message) == 0:
raise ValueError(
"A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`."
"ChatMessage must contain at least one `TextContent`, "
"`MediaContent`, `ToolCall`, or `ToolCallResult`."
)

openai_msg: Dict[str, Any] = {"role": message._role.value}

if tool_call_results:
result = tool_call_results[0]
if result.origin.id is None:
if len(message) == 1 and isinstance(message.content[0], TextContent):
openai_msg["content"] = message.content[0].text
elif message.tool_call_result:
# Tool call results should only be included for ChatRole.TOOL messages
# and should not include any other content
if message.role != ChatRole.TOOL:
raise ValueError(
"Tool call results should only be included for tool messages."
)
if len(message) > 1:
raise ValueError(
"Tool call results should not be included with other content."
)
if message.tool_call_result.origin.id is None:
raise ValueError(
"`ToolCall` must have a non-null `id` attribute to be used with OpenAI."
)
openai_msg["content"] = result.result
openai_msg["tool_call_id"] = result.origin.id
# OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field
return openai_msg

if text_contents:
openai_msg["content"] = text_contents[0]
if tool_calls:
openai_tool_calls = []
for tc in tool_calls:
if tc.id is None:
openai_msg["content"] = message.tool_call_result.result
openai_msg["tool_call_id"] = message.tool_call_result.origin.id
else:
openai_msg["content"] = []
for item in message.content:
if isinstance(item, TextContent):
openai_msg["content"].append({"type": "text", "text": item.text})
elif isinstance(item, MediaContent):
match item.media.type:
case "image":
base64_data = b64encode(item.media.data).decode("utf-8")
url = f"data:{item.media.mime_type};base64,{base64_data}"
openai_msg["content"].append(
{
"type": "image_url",
"image_url": {
"url": url,
"detail": item.media.meta.get("detail", "auto"),
},
}
)
case _:
raise ValueError(
f"Unsupported media type '{item.media.mime_type}' for OpenAI completions."
)
elif isinstance(item, ToolCall):
if message.role != ChatRole.ASSISTANT:
raise ValueError(
"Tool calls should only be included for assistant messages."
)
if item.id is None:
raise ValueError(
"`ToolCall` must have a non-null `id` attribute to be used with OpenAI."
)
openai_msg.setdefault("tool_calls", []).append(
{
"id": item.id,
"type": "function",
"function": {
"name": item.tool_name,
"arguments": json.dumps(item.arguments, ensure_ascii=False),
},
}
)
else:
raise ValueError(
"`ToolCall` must have a non-null `id` attribute to be used with OpenAI."
f"Unsupported content type '{type(item).__name__}' for OpenAI completions."
)
openai_tool_calls.append(
{
"id": tc.id,
"type": "function",
# We disable ensure_ascii so special chars like emojis are not converted
"function": {
"name": tc.tool_name,
"arguments": json.dumps(tc.arguments, ensure_ascii=False),
},
}
)
openai_msg["tool_calls"] = openai_tool_calls

if message.name:
openai_msg["name"] = message.name

return openai_msg


Expand Down
Loading

0 comments on commit 6031934

Please sign in to comment.