Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests #42

Merged
merged 3 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,19 @@ jobs:
cache-dependency-path: 'pyproject.toml'
- run: python -m pip install -e .[lint]
- run: ${{ matrix.lint-command }}

pytest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.x"
cache: 'pip'
cache-dependency-path: 'pyproject.toml'
- run: python -m pip install -e .[test]
- run: python -m pytest --cov=sam
- uses: codecov/codecov-action@v4
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: voiio/sam
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ sam = "sam.__main__:cli"
test = [
"pytest",
"pytest-cov",
"pytest-asyncio",
]
lint = [
"bandit==1.7.8",
Expand Down
10 changes: 8 additions & 2 deletions sam/__main__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio

Check warning on line 1 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L1

Added line #L1 was not covered by tests
import logging
import sys

import click
from slack_bolt.adapter.socket_mode import SocketModeHandler

from . import config

Expand All @@ -26,9 +26,15 @@
@run.command()
def slack():
"""Run the Slack bot demon."""
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler

Check warning on line 29 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L29

Added line #L29 was not covered by tests

from .slack import app

SocketModeHandler(app, config.SLACK_APP_TOKEN).start()
loop = asyncio.get_event_loop()

Check warning on line 33 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L33

Added line #L33 was not covered by tests

loop.run_until_complete(

Check warning on line 35 in sam/__main__.py

View check run for this annotation

Codecov / codecov/patch

sam/__main__.py#L35

Added line #L35 was not covered by tests
AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN).start_async()
)


if __name__ == "__main__":
Expand Down
115 changes: 115 additions & 0 deletions sam/bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import json
import logging

import openai

from . import tools, utils
from .typing import Roles, RunStatus

logger = logging.getLogger(__name__)


async def complete_run(run_id: str, thread_id: str, retry: int = 0):
"""
Wait for the run to complete.

Run and submit tool outputs if required.

Raises:
RecursionError: If the run status is not "completed" after 10 retries.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
if retry > 10:
await client.beta.threads.runs.cancel(run_id=run_id, thread_id=thread_id)
raise RecursionError("Max retries exceeded")
run = await client.beta.threads.runs.retrieve(thread_id=thread_id, run_id=run_id)
logger.info("Run %s status: %s", run.id, run.status) # noqa
match run.status:
case status if status in [RunStatus.QUEUED, RunStatus.IN_PROGRESS]:
await utils.backoff(retry)
await complete_run(run_id, thread_id, retry=retry + 1)
case RunStatus.REQUIRES_ACTION:
if (
run.required_action
and run.required_action.submit_tool_outputs
and run.required_action.submit_tool_outputs.tool_calls
):
tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
kwargs = json.loads(tool_call.function.arguments)
try:
fn = getattr(tools, tool_call.function.name)
except KeyError:
logger.exception(

Check warning on line 43 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L42-L43

Added lines #L42 - L43 were not covered by tests
"Tool %s not found, cancelling run %s",
tool_call.function.name,
run_id,
)
await client.beta.threads.runs.cancel(

Check warning on line 48 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L48

Added line #L48 was not covered by tests
run_id=run_id, thread_id=thread_id
)
tool_outputs.append(
{
"tool_call_id": tool_call.id, # noqa
"output": await fn(**kwargs),
}
)

await client.beta.threads.runs.submit_tool_outputs(
run.id, # noqa
thread_id=thread_id,
tool_outputs=tool_outputs,
)
await complete_run(run_id, thread_id) # we reset the retry counter


async def run(
assistant_id: str, thread_id: str, additional_instructions: str = None
) -> str:
"""Run the assistant on the OpenAI thread."""
logger.info(
"Running assistant %s in thread %s with additional instructions: %s",
assistant_id, # noqa
thread_id,
additional_instructions,
)
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
_run = await client.beta.threads.runs.create(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
tools=[
utils.func_to_tool(tools.send_email),
utils.func_to_tool(tools.web_search),
utils.func_to_tool(tools.fetch_website),
],
)
await complete_run(_run.id, thread_id)

messages = await client.beta.threads.messages.list(thread_id=thread_id)
for message in messages.data:
if message.role == Roles.ASSISTANT:
message_content = message.content[0].text

annotations = message_content.annotations
citations = []

# Iterate over the annotations and add footnotes
for index, annotation in enumerate(annotations):
message_content.value = message_content.value.replace(

Check warning on line 99 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L99

Added line #L99 was not covered by tests
annotation.text, f" [{index}]"
)

if file_citation := getattr(annotation, "file_citation", None):
cited_file = client.files.retrieve(file_citation.file_id)
citations.append(

Check warning on line 105 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L103-L105

Added lines #L103 - L105 were not covered by tests
f"[{index}] {file_citation.quote} — {cited_file.filename}"
)
elif file_path := getattr(annotation, "file_path", None):
cited_file = client.files.retrieve(file_path.file_id)
citations.append(f"[{index}]({cited_file.filename})")

Check warning on line 110 in sam/bot.py

View check run for this annotation

Codecov / codecov/patch

sam/bot.py#L108-L110

Added lines #L108 - L110 were not covered by tests

# Add footnotes to the end of the message before displaying to user
message_content.value += "\n" + "\n".join(citations)

return message_content
Loading