Skip to content

Commit

Permalink
handle attachments in run
Browse files Browse the repository at this point in the history
  • Loading branch information
phact committed Nov 22, 2024
1 parent a752ff0 commit e164128
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 26 deletions.
17 changes: 16 additions & 1 deletion client/tests/astra-assistants/test_run_retreival_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,23 @@ def run_with_assistant(assistant, client, file_path, embedding_model):
user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences."
logger.info("creating persistent thread and message")
thread = client.beta.threads.create()

# Create a message with an attachment that has file_search enabled
file2 = client.files.create(
file=open(
"./tests/fixtures/hudson.txt",
"rb",
),
purpose="assistants",
embedding_model=embedding_model,
)

client.beta.threads.messages.create(
thread_id=thread.id, role="user", content=user_message
thread_id=thread.id,
role="user",
content=user_message,
file_ids=[file2.id],
tools=[{"type": "file_search"}]
)
logger.info(f"> {user_message}")

Expand Down
47 changes: 44 additions & 3 deletions client/tests/astra-assistants/test_streaming_run_retrieval_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time

from openai.lib.streaming import AssistantEventHandler
from openai.types.beta.threads.message_create_params import Attachment
from typing_extensions import override

def run_with_assistant(assistant, client):
Expand Down Expand Up @@ -53,11 +54,26 @@ def run_with_assistant(assistant, client):
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
print(f"updated assistant: {assistant}")
user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences."
print("creating persistent thread and message")
thread = client.beta.threads.create()

# Create a message with an attachment that has file_search enabled
file2 = client.files.create(
file=open(
"./tests/fixtures/hudson.txt",
"rb",
),
purpose="assistants",
)

user_message = "What are some cool math concepts behind this ML paper pdf? Explain in two sentences."
client.beta.threads.messages.create(
thread_id=thread.id, role="user", content=user_message
thread_id=thread.id,
role="user",
content=user_message,
attachments=[
Attachment(file_id=file2.id, tools=[{"type": "file_search"}]),
]
)
print(f"> {user_message}")

Expand All @@ -74,7 +90,7 @@ def on_run_step_done(self, run_step) -> None:
for tool_call in run_step.step_details.tool_calls:
matches = tool_call.file_search
print(tool_call.file_search)
assert len(matches) > 0, "No matches found"
assert len(matches.chunks) > 0, "No matches found"

@override
def on_text_created(self, text) -> None:
Expand Down Expand Up @@ -102,6 +118,31 @@ def on_text_delta(self, delta, snapshot):
assert event_handler.on_text_created_count > 0, "No text created"
assert event_handler.on_text_delta_count > 0, "No text delta"

user_message = "What is the name of my dog"
client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=user_message,
attachments=[
Attachment(file_id=file2.id, tools=[{"type": "file_search"}]),
]
)
print(f"> {user_message}")

event_handler = EventHandler()
print(f"creating run")
with client.beta.threads.runs.create_and_stream(
thread_id=thread.id,
assistant_id=assistant.id,
event_handler=event_handler,
) as stream:
for part in stream.text_deltas:
print(part)

assert event_handler.on_text_created_count > 0, "No text created"
assert event_handler.on_text_delta_count > 0, "No text delta"



instructions = "You are a personal math tutor. Answer thoroughly. The system will provide relevant context from files, use the context to respond."

Expand Down
1 change: 1 addition & 0 deletions client/tests/fixtures/hudson.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
the name of my dog is Hudson
10 changes: 10 additions & 0 deletions impl/model_v2/create_message_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from typing import Optional, List

from pydantic import Field

from model_v2.message_object_attatchments_inner import MessageObjectAttachmentsInner
from openapi_server_v2.models.create_message_request import CreateMessageRequest as CreateMessageRequestGenerated


class CreateMessageRequest(CreateMessageRequestGenerated):
    # Subclass of the generated request model that redeclares ``attachments``
    # so its elements use the project's own MessageObjectAttachmentsInner type
    # (imported from model_v2) rather than the generated one.
    attachments: Optional[List[MessageObjectAttachmentsInner]] = Field(
        default=None,
        description="A list of files attached to the message, and the tools they should be added to.",
    )
9 changes: 9 additions & 0 deletions impl/model_v2/message_object_attatchments_inner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Optional, Any, List

from pydantic import Field

from openapi_server_v2.models.message_object_attachments_inner import MessageObjectAttachmentsInner as MessageObjectAttachmentsInnerGenerated


class MessageObjectAttachmentsInner(MessageObjectAttachmentsInnerGenerated):
    # Subclass of the generated attachment model that redeclares ``tools`` as
    # an optional list of arbitrary values.
    # NOTE(review): presumably this loosens the generated tool type so raw
    # request payloads are accepted as-is - confirm against the generated model.
    tools: Optional[List[Any]] = Field(
        default=None,
        description="The tools to add this file to.",
    )
78 changes: 56 additions & 22 deletions impl/routes_v2/threads_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@
from impl.services.inference_utils import get_chat_completion, get_async_chat_completion_response
from impl.utils import map_model, store_object, read_object, read_objects, generate_id
from impl.model_v2.create_thread_and_run_request import CreateThreadAndRunRequest
from model_v2.create_message_request import CreateMessageRequest
from openapi_server_v2.models.assistants_api_response_format_option import AssistantsApiResponseFormatOption
from openapi_server_v2.models.assistants_api_tool_choice_option import AssistantsApiToolChoiceOption
from openapi_server_v2.models.message_delta_object_delta_content_inner import MessageDeltaObjectDeltaContentInner
from openapi_server_v2.models.message_object_attachments_inner import MessageObjectAttachmentsInner
from openapi_server_v2.models.message_object_attachments_inner_tools_inner import \
MessageObjectAttachmentsInnerToolsInner
from openapi_server_v2.models.message_stream_event import MessageStreamEvent
from openapi_server_v2.models.run_step_delta_object_delta_step_details import RunStepDeltaObjectDeltaStepDetails
from openapi_server_v2.models.run_step_delta_step_details_tool_calls_object_tool_calls_inner import \
Expand All @@ -37,7 +41,6 @@

from openapi_server_v2.models.truncation_object import TruncationObject
from openapi_server_v2.models.assistant_stream_event import AssistantStreamEvent
from openapi_server_v2.models.create_message_request import CreateMessageRequest
from openapi_server_v2.models.create_thread_request import CreateThreadRequest
from openapi_server_v2.models.delete_message_response import DeleteMessageResponse
from openapi_server_v2.models.delete_thread_response import DeleteThreadResponse
Expand Down Expand Up @@ -224,7 +227,7 @@ async def delete_thread(
},
tags=["Assistants"],
summary="Create a message.",
response_model=MessageObject
response_model=None
)
async def create_message(
thread_id: str = Path(
Expand All @@ -233,7 +236,7 @@ async def create_message(
),
create_message_request: CreateMessageRequest = Body(None, description=""),
astradb: CassandraClient = Depends(verify_db_client),
) -> MessageObject:
) -> Any:
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
message_id = generate_id("msg")

Expand All @@ -253,7 +256,26 @@ async def create_message(
"object": "thread.message",
"content": [content]
}
return await store_object(astradb=astradb, obj=create_message_request, target_class=MessageObject, table_name="messages_v2", extra_fields=extra_fields)

# Handle attachments if present
if create_message_request.attachments:
attachments = []
for attachment in create_message_request.attachments:
tools = []
for tool in attachment.tools:
if tool is None:
raise HTTPException(status_code=400, detail="Tool is required for attachment. Pass , \"tools\": ["
"{\"type\": \"file_search\"}]")
tool_obj = MessageObjectAttachmentsInnerToolsInner(actual_instance=tool)
tools.append(tool_obj)
attachment_obj = MessageObjectAttachmentsInner(file_id=attachment.file_id, tools=tools)
attachments.append(attachment_obj)
extra_fields["attachments"] = attachments

create_message_request.attachments = None

stored_message = await store_object(astradb=astradb, obj=create_message_request, target_class=MessageObject, table_name="messages_v2", extra_fields=extra_fields)
return stored_message.to_dict()



Expand Down Expand Up @@ -287,14 +309,21 @@ async def get_message(
def messages_json_to_objects(raw_messages):
    """Build ``MessageObject`` instances from raw message mappings.

    For each raw message, every entry of its ``content`` list is parsed with
    ``MessageContentTextObject.from_json`` and every entry of its
    ``attachments`` list with ``MessageObjectAttachmentsInner.from_json``;
    the result is passed as keyword arguments to ``MessageObject``.

    Args:
        raw_messages: Iterable of dict-like messages. ``content`` and
            ``attachments`` may each be missing or ``None``, in which case
            that field is left untouched.

    Returns:
        A list of ``MessageObject`` in the same order as ``raw_messages``.
    """
    messages = []
    for raw_message in raw_messages:
        # ``copy()`` is shallow, so the nested lists are rebuilt below instead
        # of being assigned into by index; that keeps the caller's data intact.
        message_copy = raw_message.copy()

        raw_contents = message_copy.get('content')
        if raw_contents is not None:
            message_copy['content'] = [
                MessageContentTextObject.from_json(raw_content)
                for raw_content in raw_contents
            ]

        # ``get`` tolerates messages without an ``attachments`` key, where a
        # plain subscript would raise KeyError.
        raw_attachments = message_copy.get('attachments')
        if raw_attachments is not None:
            message_copy['attachments'] = [
                MessageObjectAttachmentsInner.from_json(raw_attachment)
                for raw_attachment in raw_attachments
            ]

        messages.append(MessageObject(**message_copy))
    return messages

Expand Down Expand Up @@ -464,7 +493,6 @@ async def run_event_stream(run, message_id, astradb):
# data.delta.step_details_tool_calls
step_details = RunStepObjectStepDetails(
actual_instance=RunStepDetailsToolCallsObject(type="tool_calls", tool_calls=[])

)

run_step_id = message_id.replace("msg_", "step_")
Expand All @@ -482,8 +510,8 @@ async def run_event_stream(run, message_id, astradb):
last_error=None,
expired_at=None,
cancelled_at=None,
completed_at=None,
failed_at=None,
completed_at=None,
metadata=None,
usage=None,
)
Expand Down Expand Up @@ -711,13 +739,12 @@ async def stream_message_events(astradb, thread_id, limit, order, after, before,
last_message.content = message.content
break
except Exception as e:
logger.error(f"Error in stream message events, dbid: {astradb.dbid}, error: {e}")
logger.error(f"Error in stream message events, dbid: {astradb.dbid}, error: {e}\ntrace: {traceback.format_exc()}")
# TODO - cancel run, mark message incomplete
# yield f"data: []"



# TODO - add attachments?
async def init_message(thread_id, assistant_id, run_id, astradb, created_at, content=None):
if content is None:
content = []
Expand Down Expand Up @@ -851,7 +878,7 @@ async def create_run(
for tool_obj in tools:
tool = tool_obj.actual_instance
if tool.type == "file_search":
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
created_at = int(time.mktime(datetime.now().timetuple())*1000)

# initialize message
message_id = await init_message(thread_id=thread_id, assistant_id=assistant.id, run_id=run_id, astradb=astradb, created_at=created_at)
Expand Down Expand Up @@ -952,7 +979,7 @@ async def create_run(
except Exception as e:
status = "completed"
message_id = generate_id("msg")
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
created_at = int(time.mktime(datetime.now().timetuple())*1000)

content = MessageContentTextObject(
text=MessageContentTextObjectText(
Expand Down Expand Up @@ -987,7 +1014,7 @@ async def create_run(
status = "requires_action"

message_id = generate_id("msg")
created_at = int(time.mktime(datetime.now().timetuple()) * 1000)
created_at = int(time.mktime(datetime.now().timetuple())*1000)

# groq can't handle an assistant call with no content and perplexity can't handle non-alternating user/assistant messages
if message.content is None:
Expand Down Expand Up @@ -1133,8 +1160,6 @@ def summarize_message_content(instructions, messages, filter_user_messages=False
return message_content # maybe trim message history?




# https://platform.openai.com/docs/assistants/tools/file-search/how-it-works
async def process_rag(
run_id, thread_id, tool_resources, messages, model, instructions, astradb, litellm_kwargs, embedding_model,
Expand All @@ -1146,7 +1171,7 @@ async def process_rag(
message_content = []
if run_step_id is not None:
message_content = summarize_message_content(instructions, messages, True)
search_string_messages = message_content.copy()
search_string_messages = [message_content[len(message_content) - 1]].copy()

# TODO: enforce this with instructor?
search_string_prompt = "There's a corpus of files that are relevant to your task. You can search these with semantic search. Based on the conversation so far what search string would you search for to better inform your next response (REPLY ONLY WITH THE SEARCH STRING)?"
Expand All @@ -1164,13 +1189,25 @@ async def process_rag(
search_string = search_completion_response.content
logger.debug(f"ANN search_string {search_string}")

# Get file IDs from message attachments
message_attachment_file_ids = []
for message in messages:
if message.attachments:
for attachment in message.attachments:
if any(tool.actual_instance.type == "file_search" for tool in attachment.tools):
message_attachment_file_ids.append(attachment.file_id)

file_ids = []
if tool_resources.file_search is not None:
if tool_resources.file_search.vector_store_ids is not None:
for vector_store_id in tool_resources.file_search.vector_store_ids:
vector_store_files = await read_vsf(vector_store_id=vector_store_id, astradb=astradb)
for vector_store_file in vector_store_files:
file_ids.append(vector_store_file.id)

# Add message attachment file IDs
file_ids.extend(message_attachment_file_ids)

if len(file_ids) > 0:
created_at = int(time.mktime(datetime.now().timetuple())*1000)
context_json = astradb.annSearch(
Expand Down Expand Up @@ -1865,6 +1902,3 @@ async def make_text_delta_obj_from_chunk(chunk, i, run, message_id):
)
)
return message_delta



0 comments on commit e164128

Please sign in to comment.