Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
codingjoe committed Apr 17, 2024
1 parent 55278f4 commit 08ab029
Show file tree
Hide file tree
Showing 11 changed files with 543 additions and 259 deletions.
16 changes: 16 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,19 @@ jobs:
cache-dependency-path: 'pyproject.toml'
- run: python -m pip install -e .[lint]
- run: ${{ matrix.lint-command }}

pytest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
cache: 'pip'
cache-dependency-path: 'pyproject.toml'
- run: python -m pip install -e .[test]
- run: python -m pytest --cov=sam
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: voiio/sam
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ sam = "sam.__main__:cli"
test = [
"pytest",
"pytest-cov",
"pytest-asyncio",
]
lint = [
"bandit==1.7.8",
Expand Down
10 changes: 8 additions & 2 deletions sam/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio

Check warning on line 1 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L1

Added line #L1 was not covered by tests
import logging
import sys

import click
from slack_bolt.adapter.socket_mode import SocketModeHandler

from . import config

Expand All @@ -26,9 +26,15 @@ def run():
@run.command()
def slack():
"""Run the Slack bot demon."""
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler

Check warning on line 29 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L29

Added line #L29 was not covered by tests

from .slack import app

SocketModeHandler(app, config.SLACK_APP_TOKEN).start()
loop = asyncio.get_event_loop()

Check warning on line 33 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L33

Added line #L33 was not covered by tests

loop.run_until_complete(

Check warning on line 35 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L35

Added line #L35 was not covered by tests
AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN).start_async()
)


if __name__ == "__main__":
Expand Down
115 changes: 115 additions & 0 deletions sam/bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import logging

import openai

from . import tools, utils
from .typing import Roles, RunStatus

logger = logging.getLogger(__name__)


async def complete_run(run_id: str, thread_id: str, retry: int = 0):
"""
Wait for the run to complete.
Run and submit tool outputs if required.
Raises:
RecursionError: If the run status is not "completed" after 10 retries.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
if retry > 10:
await client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id)
raise RecursionError("Max retries exceeded")
run = await client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
logger.info("Run %s status: %s", run.id, run.status) # noqa
match run.status:
case status if status in [RunStatus.QUEUED, RunStatus.IN_PROGRESS]:
await utils.backoff(retry)
await complete_run(run_id, thread_id, retry=retry + 1)
case RunStatus.REQUIRES_ACTION:
if (
run.required_action
and run.required_action.submit_tool_outputs
and run.required_action.submit_tool_outputs.tool_calls
):
tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
kwargs = json.loads(tool_call.function.arguments)
try:
fn = getattr(tools, tool_call.function.name)
except KeyError:
logger.exception(

Check warning on line 43 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L42-L43

Added lines #L42 - L43 were not covered by tests
"Tool %s not found, cancelling run %s",
tool_call.function.name,
run_id,
)
await client.beta.threads.runs.cancel(

Check warning on line 48 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L48

Added line #L48 was not covered by tests
run_id=run_id, thread_id=thread_id
)
tool_outputs.append(
{
"tool_call_id": tool_call.id, # noqa
"output": await fn(**kwargs),
}
)

await client.beta.threads.runs.submit_tool_outputs(
run.id, # noqa
thread_id=thread_id,
tool_outputs=tool_outputs,
)
await complete_run(run_id, thread_id) # we reset the retry counter


async def run(
assistant_id: str, thread_id: str, additional_instructions: str = None
) -> str:
"""Run the assistant on the OpenAI thread."""
logger.info(
"Running assistant %s in thread %s with additional instructions: %s",
assistant_id, # noqa
thread_id,
additional_instructions,
)
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
_run = await client.beta.threads.runs.create(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
tools=[
utils.func_to_tool(tools.send_email),
utils.func_to_tool(tools.web_search),
utils.func_to_tool(tools.fetch_website),
],
)
await complete_run(_run.id, thread_id)

messages = await client.beta.threads.messages.list(thread_id=thread_id)
for message in messages.data:
if message.role == Roles.ASSISTANT:
message_content = message.content[0].text

annotations = message_content.annotations
citations = []

# Iterate over the annotations and add footnotes
for index, annotation in enumerate(annotations):
message_content.value = message_content.value.replace(

Check warning on line 99 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L99

Added line #L99 was not covered by tests
annotation.text, f" [{index}]"
)

if file_citation := getattr(annotation, "file_citation", None):
cited_file = client.files.retrieve(file_citation.file_id)
citations.append(

Check warning on line 105 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L103-L105

Added lines #L103 - L105 were not covered by tests
f"[{index}] {file_citation.quote}{cited_file.filename}"
)
elif file_path := getattr(annotation, "file_path", None):
cited_file = client.files.retrieve(file_path.file_id)
citations.append(f"[{index}]({cited_file.filename})")

Check warning on line 110 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L108-L110

Added lines #L108 - L110 were not covered by tests

# Add footnotes to the end of the message before displaying to user
message_content.value += "\n" + "\n".join(citations)

return message_content
Loading

0 comments on commit 08ab029

Please sign in to comment.