diff --git a/sam/bot.py b/sam/bot.py index 40bdfc8..0e1e626 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -87,7 +87,7 @@ async def execute_run( assistant_id: str, thread_id: str, additional_instructions: str = None, - file_ids: list[str] = None, + file_search: bool = False, **context, ) -> str: """Run the assistant on the OpenAI thread.""" @@ -173,7 +173,7 @@ async def add_message( thread_id: str, content: str, files: [(str, bytes)] = None, -) -> tuple[list[str], bool]: +) -> tuple[bool, bool]: """Add a message to the thread.""" logger.info(f"Adding message to thread={thread_id}") client: openai.AsyncOpenAI = openai.AsyncOpenAI() @@ -208,7 +208,7 @@ async def add_message( for file_id in file_ids ], ) - return file_ids, voice_prompt + return bool(file_ids), voice_prompt async def tts(text: str) -> bytes: diff --git a/sam/slack.py b/sam/slack.py index 79e9f7c..977aeff 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -72,7 +72,7 @@ async def handle_message(event: {str, Any}, say: AsyncSay): redis.from_url(config.REDIS_URL) as redis_client, redis_client.lock(thread_id, timeout=10 * 60, thread_local=False), ): # 10 minutes - file_ids, voice_prompt = await bot.add_message( + has_attachments, has_audio = await bot.add_message( thread_id=thread_id, content=text, files=files, @@ -84,7 +84,9 @@ async def handle_message(event: {str, Any}, say: AsyncSay): or event.get("parent_user_id") == bot_id or random.random() < config.RANDOM_RUN_RATIO # nosec ): - await send_response(event, say, file_ids=file_ids, voice_prompt=voice_prompt) + await send_response( + event, say, file_search=has_attachments, voice_response=has_audio + ) @functools.lru_cache(maxsize=128) @@ -116,8 +118,8 @@ def get_user_specific_instructions(user_id: str) -> str: async def send_response( event: {str, Any}, say: AsyncSay, - file_ids: list[str] = None, - voice_prompt: bool = False, + file_search: bool = False, + voice_response: bool = False, ): """Send a response to a message event from Slack.""" logger.debug(f"process_run={json.dumps(event)}") @@ -144,7 +146,7 @@ async def send_response( thread_id=thread_id, assistant_id=config.OPENAI_ASSISTANT_ID, additional_instructions=get_user_specific_instructions(user_id), - file_ids=file_ids, + file_search=file_search, **get_user_profile(user_id), ) @@ -158,7 +160,7 @@ async def send_response( f"Sam responded to the User={user_id} in Channel={channel_id} via Text" ) - if voice_prompt: + if voice_response: await say.client.files_upload_v2( filename="response.mp3", title="Voice Response", diff --git a/tests/test_bot.py b/tests/test_bot.py index 30aa4d3..3dbf680 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -14,6 +14,45 @@ def client(monkeypatch): return client +@pytest.mark.asyncio +async def test_add_message(client): + assert await bot.add_message("thread-1", "Hello", []) == (False, False) + assert client.beta.threads.messages.create.called + assert client.beta.threads.messages.create.call_args == mock.call( + thread_id="thread-1", content="Hello", role="user", attachments=[] + ) + + +@pytest.mark.asyncio +async def test_add_message__file(client): + client.files.create.return_value = namedtuple("File", ["id"])(id="file-123") + assert await bot.add_message("thread-1", "Hello", [("file-1", b"Hello")]) == ( + True, + False, + ) + assert client.beta.threads.messages.create.called + assert client.beta.threads.messages.create.call_args == mock.call( + thread_id="thread-1", + content="Hello", + role="user", + attachments=[{"file_id": "file-123", "tools": [{"type": "file_search"}]}], + ) + + +@pytest.mark.asyncio +async def test_add_message__audio(client): + client.audio.transcriptions.create.return_value = namedtuple( + "Transcription", ["text"] + )(text="World") + assert await bot.add_message("thread-1", "Hello", [("file-1.mp3", b"World")]) == ( + False, + True, + ) + assert client.beta.threads.messages.create.call_args == mock.call( + thread_id="thread-1", content="Hello\nWorld", role="user", attachments=[] + ) + + @pytest.mark.asyncio async def test_complete_run__max_retries(client): client.beta.threads.runs.cancel = mock.AsyncMock() diff --git a/tests/test_slack.py b/tests/test_slack.py index 9d641f7..eac07d0 100644 --- a/tests/test_slack.py +++ b/tests/test_slack.py @@ -37,7 +37,7 @@ async def test_handle_message(monkeypatch): "text": "Hello", "files": [ { - "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "url_private": "https://example.com/file.mp3", "name": "file.mp3", } ], @@ -57,15 +57,12 @@ async def test_handle_message(monkeypatch): "user": "user-1", "text": "Hello", "files": [ - { - "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", - "name": "file.mp3", - } + {"url_private": "https://example.com/file.mp3", "name": "file.mp3"} ], }, say, - file_ids=["file-1"], - voice_prompt=False, + file_search=["file-1"], + voice_response=False, ) @@ -159,14 +156,14 @@ async def test_send_response(monkeypatch): } ], } - await slack.send_response(event, say, voice_prompt=True) + await slack.send_response(event, say, voice_response=True) assert execute_run.called assert execute_run.call_args == mock.call( thread_id="thread-1", assistant_id=None, additional_instructions="user_instructions", - file_ids=None, + file_search=False, name="Sam", ) assert tts.called @@ -210,14 +207,14 @@ async def test_send_response__thread(monkeypatch): } ], } - await slack.send_response(event, say, voice_prompt=True) + await slack.send_response(event, say, voice_response=True) assert execute_run.called assert execute_run.call_args == mock.call( thread_id="thread-1", assistant_id=None, additional_instructions="user_instructions", - file_ids=None, + file_search=False, name="Sam", ) assert tts.called