From e1641282891179b76be2eaf44bce03ba1799cca3 Mon Sep 17 00:00:00 2001 From: phact Date: Thu, 21 Nov 2024 22:32:22 -0500 Subject: [PATCH] handle attachments in run --- .../astra-assistants/test_run_retreival_v2.py | 17 +++- .../test_streaming_run_retrieval_v2.py | 47 ++++++++++- client/tests/fixtures/hudson.txt | 1 + impl/model_v2/create_message_request.py | 10 +++ .../message_object_attatchments_inner.py | 9 +++ impl/routes_v2/threads_v2.py | 78 +++++++++++++------ 6 files changed, 136 insertions(+), 26 deletions(-) create mode 100644 client/tests/fixtures/hudson.txt create mode 100644 impl/model_v2/create_message_request.py create mode 100644 impl/model_v2/message_object_attatchments_inner.py diff --git a/client/tests/astra-assistants/test_run_retreival_v2.py b/client/tests/astra-assistants/test_run_retreival_v2.py index 38d0597..09932b0 100644 --- a/client/tests/astra-assistants/test_run_retreival_v2.py +++ b/client/tests/astra-assistants/test_run_retreival_v2.py @@ -71,8 +71,23 @@ def run_with_assistant(assistant, client, file_path, embedding_model): user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences." 
logger.info("creating persistent thread and message") thread = client.beta.threads.create() + + # Create a message with an attachment that has file_search enabled + file2 = client.files.create( + file=open( + "./tests/fixtures/hudson.txt", + "rb", + ), + purpose="assistants", + embedding_model=embedding_model, + ) + client.beta.threads.messages.create( - thread_id=thread.id, role="user", content=user_message + thread_id=thread.id, + role="user", + content=user_message, + file_ids=[file2.id], + tools=[{"type": "file_search"}] ) logger.info(f"> {user_message}") diff --git a/client/tests/astra-assistants/test_streaming_run_retrieval_v2.py b/client/tests/astra-assistants/test_streaming_run_retrieval_v2.py index 735e103..e4b3c57 100644 --- a/client/tests/astra-assistants/test_streaming_run_retrieval_v2.py +++ b/client/tests/astra-assistants/test_streaming_run_retrieval_v2.py @@ -4,6 +4,7 @@ import time from openai.lib.streaming import AssistantEventHandler +from openai.types.beta.threads.message_create_params import Attachment from typing_extensions import override def run_with_assistant(assistant, client): @@ -53,11 +54,26 @@ def run_with_assistant(assistant, client): tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, ) print(f"updated assistant: {assistant}") - user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences." print("creating persistent thread and message") thread = client.beta.threads.create() + + # Create a message with an attachment that has file_search enabled + file2 = client.files.create( + file=open( + "./tests/fixtures/hudson.txt", + "rb", + ), + purpose="assistants", + ) + + user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences." 
client.beta.threads.messages.create( - thread_id=thread.id, role="user", content=user_message + thread_id=thread.id, + role="user", + content=user_message, + attachments=[ + Attachment(file_id=file2.id, tools=[{"type": "file_search"}]), + ] ) print(f"> {user_message}") @@ -74,7 +90,7 @@ def on_run_step_done(self, run_step) -> None: for tool_call in run_step.step_details.tool_calls: matches = tool_call.file_search print(tool_call.file_search) - assert len(matches) > 0, "No matches found" + assert len(matches.chunks) > 0, "No matches found" @override def on_text_created(self, text) -> None: @@ -102,6 +118,31 @@ def on_text_delta(self, delta, snapshot): assert event_handler.on_text_created_count > 0, "No text created" assert event_handler.on_text_delta_count > 0, "No text delta" + user_message = "What is the name of my dog" + client.beta.threads.messages.create( + thread_id=thread.id, + role="user", + content=user_message, + attachments=[ + Attachment(file_id=file2.id, tools=[{"type": "file_search"}]), + ] + ) + print(f"> {user_message}") + + event_handler = EventHandler() + print(f"creating run") + with client.beta.threads.runs.create_and_stream( + thread_id=thread.id, + assistant_id=assistant.id, + event_handler=event_handler, + ) as stream: + for part in stream.text_deltas: + print(part) + + assert event_handler.on_text_created_count > 0, "No text created" + assert event_handler.on_text_delta_count > 0, "No text delta" + + instructions = "You are a personal math tutor. Answer thoroughly. The system will provide relevant context from files, use the context to respond." 
diff --git a/client/tests/fixtures/hudson.txt b/client/tests/fixtures/hudson.txt new file mode 100644 index 0000000..27e2c12 --- /dev/null +++ b/client/tests/fixtures/hudson.txt @@ -0,0 +1 @@ +the name of my dog is Hudson diff --git a/impl/model_v2/create_message_request.py b/impl/model_v2/create_message_request.py new file mode 100644 index 0000000..f2f93ab --- /dev/null +++ b/impl/model_v2/create_message_request.py @@ -0,0 +1,10 @@ +from typing import Optional, List + +from pydantic import Field + +from model_v2.message_object_attatchments_inner import MessageObjectAttachmentsInner +from openapi_server_v2.models.create_message_request import CreateMessageRequest as CreateMessageRequestGenerated + + +class CreateMessageRequest(CreateMessageRequestGenerated): + attachments: Optional[List[MessageObjectAttachmentsInner]] = Field(default=None, description="A list of files attached to the message, and the tools they should be added to.") diff --git a/impl/model_v2/message_object_attatchments_inner.py b/impl/model_v2/message_object_attatchments_inner.py new file mode 100644 index 0000000..2ee7884 --- /dev/null +++ b/impl/model_v2/message_object_attatchments_inner.py @@ -0,0 +1,9 @@ +from typing import Optional, Any, List + +from pydantic import Field + +from openapi_server_v2.models.message_object_attachments_inner import MessageObjectAttachmentsInner as MessageObjectAttachmentsInnerGenerated + + +class MessageObjectAttachmentsInner(MessageObjectAttachmentsInnerGenerated): + tools: Optional[List[Any]] = Field(default=None, description="The tools to add this file to.") diff --git a/impl/routes_v2/threads_v2.py b/impl/routes_v2/threads_v2.py index 97e7f7f..5fee478 100644 --- a/impl/routes_v2/threads_v2.py +++ b/impl/routes_v2/threads_v2.py @@ -25,9 +25,13 @@ from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response from impl.utils import map_model, store_object, read_object, read_objects, generate_id from 
impl.model_v2.create_thread_and_run_request import CreateThreadAndRunRequest +from model_v2.create_message_request import CreateMessageRequest from openapi_server_v2.models.assistants_api_response_format_option import AssistantsApiResponseFormatOption from openapi_server_v2.models.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption from openapi_server_v2.models.message_delta_object_delta_content_inner import MessageDeltaObjectDeltaContentInner +from openapi_server_v2.models.message_object_attachments_inner import MessageObjectAttachmentsInner +from openapi_server_v2.models.message_object_attachments_inner_tools_inner import \ + MessageObjectAttachmentsInnerToolsInner from openapi_server_v2.models.message_stream_event import MessageStreamEvent from openapi_server_v2.models.run_step_delta_object_delta_step_details import RunStepDeltaObjectDeltaStepDetails from openapi_server_v2.models.run_step_delta_step_details_tool_calls_object_tool_calls_inner import \ @@ -37,7 +41,6 @@ from openapi_server_v2.models.truncation_object import TruncationObject from openapi_server_v2.models.assistant_stream_event import AssistantStreamEvent -from openapi_server_v2.models.create_message_request import CreateMessageRequest from openapi_server_v2.models.create_thread_request import CreateThreadRequest from openapi_server_v2.models.delete_message_response import DeleteMessageResponse from openapi_server_v2.models.delete_thread_response import DeleteThreadResponse @@ -224,7 +227,7 @@ async def delete_thread( }, tags=["Assistants"], summary="Create a message.", - response_model=MessageObject + response_model=None ) async def create_message( thread_id: str = Path( @@ -233,7 +236,7 @@ async def create_message( ), create_message_request: CreateMessageRequest = Body(None, description=""), astradb: CassandraClient = Depends(verify_db_client), -) -> MessageObject: +) -> Any: created_at = int(time.mktime(datetime.now().timetuple()) * 1000) message_id = generate_id("msg") @@ 
-253,7 +256,26 @@ async def create_message( "object": "thread.message", "content": [content] } - return await store_object(astradb=astradb, obj=create_message_request, target_class=MessageObject, table_name="messages_v2", extra_fields=extra_fields) + + # Handle attachments if present + if create_message_request.attachments: + attachments = [] + for attachment in create_message_request.attachments: + tools = [] + for tool in attachment.tools: + if tool is None: + raise HTTPException(status_code=400, detail="Tool is required for attachment. Pass , \"tools\": [" + "{\"type\": \"file_search\"}]") + tool_obj = MessageObjectAttachmentsInnerToolsInner(actual_instance=tool) + tools.append(tool_obj) + attachment_obj = MessageObjectAttachmentsInner(file_id=attachment.file_id, tools=tools) + attachments.append(attachment_obj) + extra_fields["attachments"] = attachments + + create_message_request.attachments = None + + stored_message = await store_object(astradb=astradb, obj=create_message_request, target_class=MessageObject, table_name="messages_v2", extra_fields=extra_fields) + return stored_message.to_dict() @@ -287,14 +309,21 @@ async def get_message( def messages_json_to_objects(raw_messages): messages = [] for raw_message in raw_messages: - if 'content' in raw_message and raw_message['content'] is not None: - content_array = raw_message['content'].copy() + message_copy = raw_message.copy() + if 'content' in message_copy and message_copy['content'] is not None: i=0 - for raw_content in content_array: + for raw_content in message_copy['content']: content = MessageContentTextObject.from_json(raw_content) - raw_message['content'][i] = content + message_copy['content'][i] = content i+=1 - message = MessageObject(**raw_message) + i=0 + attachments = message_copy['attachments'] + if attachments is not None: + for raw_attachment in attachments: + attachment = MessageObjectAttachmentsInner.from_json(raw_attachment) + message_copy['attachments'][i] = attachment + i+=1 + message = 
MessageObject(**message_copy) messages.append(message) return messages @@ -464,7 +493,6 @@ async def run_event_stream(run, message_id, astradb): # data.delta.step_details_tool_calls step_details = RunStepObjectStepDetails( actual_instance=RunStepDetailsToolCallsObject(type="tool_calls", tool_calls=[]) - ) run_step_id = message_id.replace("msg_", "step_") @@ -482,8 +510,8 @@ async def run_event_stream(run, message_id, astradb): last_error=None, expired_at=None, cancelled_at=None, - completed_at=None, failed_at=None, + completed_at=None, metadata=None, usage=None, ) @@ -711,13 +739,12 @@ async def stream_message_events(astradb, thread_id, limit, order, after, before, last_message.content = message.content break except Exception as e: - logger.error(f"Error in stream message events, dbid: {astradb.dbid}, error: {e}") + logger.error(f"Error in stream message events, dbid: {astradb.dbid}, error: {e}\ntrace: {traceback.format_exc()}") # TODO - cancel run, mark message incomplete # yield f"data: []" -# TODO - add attachments? 
async def init_message(thread_id, assistant_id, run_id, astradb, created_at, content=None): if content is None: content = [] @@ -851,7 +878,7 @@ async def create_run( for tool_obj in tools: tool = tool_obj.actual_instance if tool.type == "file_search": - created_at = int(time.mktime(datetime.now().timetuple()) * 1000) + created_at = int(time.mktime(datetime.now().timetuple())*1000) # initialize message message_id = await init_message(thread_id=thread_id, assistant_id=assistant.id, run_id=run_id, astradb=astradb, created_at=created_at) @@ -952,7 +979,7 @@ async def create_run( except Exception as e: status = "completed" message_id = generate_id("msg") - created_at = int(time.mktime(datetime.now().timetuple()) * 1000) + created_at = int(time.mktime(datetime.now().timetuple())*1000) content = MessageContentTextObject( text=MessageContentTextObjectText( @@ -987,7 +1014,7 @@ async def create_run( status = "requires_action" message_id = generate_id("msg") - created_at = int(time.mktime(datetime.now().timetuple()) * 1000) + created_at = int(time.mktime(datetime.now().timetuple())*1000) # groq can't handle an assistant call with no content and perplexity can't handle non-alternating user/assistant messages if message.content is None: @@ -1133,8 +1160,6 @@ def summarize_message_content(instructions, messages, filter_user_messages=False return message_content # maybe trim message history? - - # https://platform.openai.com/docs/assistants/tools/file-search/how-it-works async def process_rag( run_id, thread_id, tool_resources, messages, model, instructions, astradb, litellm_kwargs, embedding_model, @@ -1146,7 +1171,7 @@ async def process_rag( message_content = [] if run_step_id is not None: message_content = summarize_message_content(instructions, messages, True) - search_string_messages = message_content.copy() + search_string_messages = [message_content[len(message_content) - 1]].copy() # TODO: enforce this with instructor? 
search_string_prompt = "There's a corpus of files that are relevant to your task. You can search these with semantic search. Based on the conversation so far what search string would you search for to better inform your next response (REPLY ONLY WITH THE SEARCH STRING)?" @@ -1164,6 +1189,14 @@ async def process_rag( search_string = search_completion_response.content logger.debug(f"ANN search_string {search_string}") + # Get file IDs from message attachments + message_attachment_file_ids = [] + for message in messages: + if message.attachments: + for attachment in message.attachments: + if any(tool.actual_instance.type == "file_search" for tool in attachment.tools): + message_attachment_file_ids.append(attachment.file_id) + file_ids = [] if tool_resources.file_search is not None: if tool_resources.file_search.vector_store_ids is not None: @@ -1171,6 +1204,10 @@ async def process_rag( vector_store_files = await read_vsf(vector_store_id=vector_store_id, astradb=astradb) for vector_store_file in vector_store_files: file_ids.append(vector_store_file.id) + + # Add message attachment file IDs + file_ids.extend(message_attachment_file_ids) + if len(file_ids) > 0: created_at = int(time.mktime(datetime.now().timetuple())*1000) context_json = astradb.annSearch( @@ -1865,6 +1902,3 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id): ) ) return message_delta - - -