Add more tests
codingjoe committed Apr 20, 2024
1 parent 8b73624 commit 41848be
Showing 9 changed files with 584 additions and 180 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/ci.yml
@@ -44,6 +44,14 @@ jobs:

pytest:
runs-on: ubuntu-latest
services:
redis:
image: redis
ports:
- 6379:6379
options: --entrypoint redis-server
env:
REDIS_URL: redis:///
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
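The new `services` block starts a throwaway Redis container for the pytest job and points the suite at it via `REDIS_URL`. A minimal sketch of how a test could reach that service (the test file, fixture-free layout, and use of redis-py's asyncio client are illustrative assumptions, not part of this commit):

```python
# tests/test_redis_service.py (hypothetical)
import os

import pytest
import redis.asyncio as redis  # assumes redis-py is available in the test env


@pytest.mark.asyncio  # assumes pytest-asyncio is configured
async def test_redis_service_is_reachable():
    # "redis:///" falls back to localhost:6379, matching the CI service above.
    client = redis.Redis.from_url(os.environ.get("REDIS_URL", "redis:///"))
    assert await client.ping()
    await client.aclose()  # redis-py >= 5; older versions use close()
```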
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -82,7 +82,7 @@ fallback_version = "1.0.0"

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "--cov --tb=short -rxs"
addopts = "--tb=short -rxs"
testpaths = ["tests"]

[tool.pytest_env]
8 changes: 5 additions & 3 deletions sam/__main__.py
@@ -22,7 +22,9 @@ def cli():
def run(verbose):
"""Run an assistent bot."""
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s"))
handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)7s %(name)s - %(message)s")
)
logging.basicConfig(
handlers=[handler], level=logging.DEBUG if verbose else logging.INFO
)
@@ -33,12 +35,12 @@ def slack():
"""Run the Slack bot demon."""
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler

from .slack import app
from .slack import get_app

loop = asyncio.get_event_loop()

loop.run_until_complete(
AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN).start_async()
AsyncSocketModeHandler(get_app(), config.SLACK_APP_TOKEN).start_async()
)


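Replacing the module-level `app` import with a `get_app()` factory means the Slack application is only constructed when the daemon actually starts, which keeps imports side-effect free and easier to exercise in tests. A rough sketch of such a test (the `AsyncApp` patch target and token name are assumptions about `sam/slack.py`, not taken from this commit):

```python
from unittest import mock


def test_get_app_is_lazy(monkeypatch):
    # Hypothetical: assumes sam.slack builds a slack_bolt AsyncApp inside get_app().
    monkeypatch.setenv("SLACK_BOT_TOKEN", "xoxb-dummy")
    with mock.patch("sam.slack.AsyncApp") as async_app:
        from sam.slack import get_app

        app = get_app()
    async_app.assert_called_once()
    assert app is async_app.return_value
```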
217 changes: 153 additions & 64 deletions sam/bot.py
@@ -1,10 +1,13 @@
from __future__ import annotations

import json
import logging
from pathlib import Path

import openai

from . import config, tools, utils
from .typing import Roles, RunStatus
from .typing import AUDIO_FORMATS, Roles, RunStatus

logger = logging.getLogger(__name__)

@@ -17,6 +20,8 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context
Raises:
RecursionError: If the run status is not "completed" after 10 retries.
IOError: If the run status is not "completed" or "requires_action".
ValueError: If the run requires tools but none are provided.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
if retry > 10:
@@ -29,56 +34,60 @@
await utils.backoff(retry)
await complete_run(run_id, thread_id, retry=retry + 1, **context)
case RunStatus.REQUIRES_ACTION:
if (
run.required_action
and run.required_action.submit_tool_outputs
and run.required_action.submit_tool_outputs.tool_calls
):
tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
kwargs = json.loads(tool_call.function.arguments)
try:
fn = getattr(tools, tool_call.function.name)
except KeyError:
logger.exception(
"Tool %s not found, cancelling run %s",
tool_call.function.name,
run_id,
)
await client.beta.threads.runs.cancel(
run_id=run_id, thread_id=thread_id
)
return
logger.info("Running tool %s", tool_call.function.name)
logger.debug(
"Tool %s arguments: %r", tool_call.function.name, kwargs
)
tool_outputs.append(
{
"tool_call_id": tool_call.id,
"output": fn(**kwargs, _context={**context}),
}
)
logger.info("Submitting tool outputs for run %s", run_id)
logger.debug("Tool outputs: %r", tool_outputs)
await client.beta.threads.runs.submit_tool_outputs(
run.id, # noqa
thread_id=thread_id,
tool_outputs=tool_outputs,
)
await complete_run(
run_id, thread_id, **context
) # we reset the retry counter
await call_tools(run, **context)
# after we submit the tool outputs, we reset the retry counter
await complete_run(run_id, thread_id, **context)
case RunStatus.COMPLETED:
return
case _:
raise IOError(f"Run {run.id} failed with status {run.status}")


async def run(
async def call_tools(run: openai.types.beta.threads.Run, **context) -> None:
"""
Call the tools required by the run.
Raises:
IOError: If a tool is not found.
ValueError: If the run does not require any tools.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
if not (run.required_action and run.required_action.submit_tool_outputs):
raise ValueError(f"Run {run.id} does not require any tools")

tool_outputs = []
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
kwargs = json.loads(tool_call.function.arguments)
try:
fn = getattr(tools, tool_call.function.name)
except AttributeError as e:
await client.beta.threads.runs.cancel(
run_id=run.id, thread_id=run.thread_id
)
raise IOError(
f"Tool {tool_call.function.name} not found, cancelling run {run.id}"
) from e
logger.info("Running tool %s", tool_call.function.name)
logger.debug("Tool %s arguments: %r", tool_call.function.name, kwargs)
tool_outputs.append(
{
"tool_call_id": tool_call.id,
"output": fn(**kwargs, _context={**context}),
}
)
logger.info("Submitting tool outputs for run %s", run.id)
logger.debug("Tool outputs: %r", tool_outputs)
await client.beta.threads.runs.submit_tool_outputs(
run.id, # noqa
thread_id=run.thread_id,
tool_outputs=tool_outputs,
)


async def execute_run(
assistant_id: str,
thread_id: str,
additional_instructions: str = None,
file_ids: list[str] = None,
**context,
) -> str:
"""Run the assistant on the OpenAI thread."""
@@ -90,7 +99,7 @@ async def run(
logger.debug("Additional instructions: %r", additional_instructions)
logger.debug("Context: %r", context)
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
_run = await client.beta.threads.runs.create(
run = await client.beta.threads.runs.create(
thread_id=thread_id,
assistant_id=assistant_id,
additional_instructions=additional_instructions,
@@ -104,35 +113,115 @@
],
)
try:
await complete_run(_run.id, thread_id, **context)
except IOError:
logger.exception("Run %s failed", _run.id)
await complete_run(run.id, thread_id, **context)
except (RecursionError, IOError, ValueError):
logger.exception("Run %s failed", run.id)
return "🤯"

try:
return await fetch_latest_assistant_message(thread_id)
except ValueError:
logger.exception("No assistant message found")
return "🤯"


async def fetch_latest_assistant_message(thread_id: str) -> str:
"""
Fetch the latest assistant message from the thread.
Raises:
ValueError: If no assistant message is found.
"""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
messages = await client.beta.threads.messages.list(thread_id=thread_id)
for message in messages.data:
if message.role == Roles.ASSISTANT:
message_content = message.content[0].text
try:
return await annotate_citations(message.content[0].text)
except IndexError as e:
raise ValueError("No assistant message found") from e

annotations = message_content.annotations
citations = []
async def annotate_citations(
message_content: openai.types.beta.threads.TextContentBlock,
) -> str:
"""Annotate citations in the text using footnotes and the file metadata."""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
citations = []

# Iterate over the annotations and add footnotes
for index, annotation in enumerate(annotations):
message_content.value = message_content.value.replace(
annotation.text, f" [{index}]"
)
# Iterate over the annotations and add footnotes
for index, annotation in enumerate(message_content.annotations):
message_content.value = message_content.value.replace(
annotation.text, f" [{index}]"
)

if file_citation := getattr(annotation, "file_citation", None):
cited_file = client.files.retrieve(file_citation.file_id)
citations.append(
f"[{index}] {file_citation.quote}{cited_file.filename}"
if file_citation := getattr(annotation, "file_citation", None):
cited_file = await client.files.retrieve(file_citation.file_id)
citations.append(f"[{index}] {file_citation.quote}{cited_file.filename}")
elif file_path := getattr(annotation, "file_path", None):
cited_file = await client.files.retrieve(file_path.file_id)
citations.append(f"[{index}]({cited_file.filename})")

# Add footnotes to the end of the message before displaying to user
message_content.value += "\n" + "\n".join(citations)
return message_content.value


async def add_message(
thread_id: str,
content: str,
files: [(str, bytes)] = None,
) -> tuple[list[str], bool]:
"""Add a message to the thread."""
logger.info(f"Adding message to thread={thread_id}")
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
file_ids = []
voice_prompt = False
for file_name, file_content in files or []:
if Path(file_name).suffix.lstrip(".") in AUDIO_FORMATS:
logger.debug("Transcribing audio file %s", file_name)
content += (
"\n"
+ (
await client.audio.transcriptions.create(
model="whisper-1",
file=(file_name, file_content),
)
elif file_path := getattr(annotation, "file_path", None):
cited_file = client.files.retrieve(file_path.file_id)
citations.append(f"[{index}]({cited_file.filename})")
).text
)
voice_prompt = True
else:
logger.debug("Uploading file %s", file_name)
new_file = await client.files.create(
file=(file_name, file_content),
purpose="assistants",
)
file_ids.append(new_file.id)
await client.beta.threads.messages.create(
thread_id=thread_id,
content=content,
role="assistant",
file_ids=file_ids,
)
return file_ids, voice_prompt

# Add footnotes to the end of the message before displaying to user
message_content.value += "\n" + "\n".join(citations)

return message_content
async def tts(text: str) -> bytes:
"""Convert text to speech using the OpenAI API."""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
response = await client.audio.speech.create(
model=config.TTS_MODEL,
voice=config.TTS_VOICE,
input=text,
)
return response.read()


async def stt(audio: bytes) -> str:
"""Convert speech to text using the OpenAI API."""
client: openai.AsyncOpenAI = openai.AsyncOpenAI()
response = await client.audio.transcriptions.create(
model="whisper-1",
file=audio,
)
return response.text

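Splitting the tool dispatch out of `complete_run` into `call_tools` gives each failure mode (unknown tool, no tools requested) an explicit exception, which is what the added tests can target directly. A minimal sketch of one such test (the mock layout and pytest-asyncio marker are assumptions, not copied from this commit's test suite):

```python
from unittest import mock

import pytest

from sam import bot


@pytest.mark.asyncio  # assumes pytest-asyncio is configured
async def test_call_tools_rejects_run_without_required_action(monkeypatch):
    # AsyncOpenAI() is constructed inside call_tools, so it needs an API key.
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
    run = mock.MagicMock(id="run_1", required_action=None)
    with pytest.raises(ValueError, match="does not require any tools"):
        await bot.call_tools(run)
```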
2 changes: 2 additions & 0 deletions sam/config.py
@@ -12,6 +12,8 @@
SLACK_APP_TOKEN = os.getenv("SLACK_APP_TOKEN")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
OPENAI_ASSISTANT_ID = os.getenv("OPENAI_ASSISTANT_ID")
TTS_VOICE = os.getenv("TTS_VOICE", "alloy")
TTS_MODEL = os.getenv("TTS_MODEL", "tts-1-hd")
REDIS_URL = os.getenv("REDIS_URL", "redis:///")
RANDOM_RUN_RATIO = float(os.getenv("RANDOM_RUN_RATIO", "0"))
TIMEZONE = ZoneInfo(os.getenv("TIMEZONE", "UTC"))
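`TTS_VOICE` and `TTS_MODEL` feed straight into `bot.tts()` above and default to OpenAI's `alloy` voice and `tts-1-hd` model. A small sketch of a test pinning that wiring down (the patch target assumes the openai >= 1.x client layout and is not part of this commit):

```python
from unittest import mock

import pytest

from sam import bot, config


@pytest.mark.asyncio  # assumes pytest-asyncio is configured
async def test_tts_uses_configured_model_and_voice(monkeypatch):
    monkeypatch.setenv("OPENAI_API_KEY", "sk-test")
    fake_response = mock.Mock(read=mock.Mock(return_value=b"audio-bytes"))
    with mock.patch(
        # Assumption: openai>=1.x exposes the speech endpoint here.
        "openai.resources.audio.speech.AsyncSpeech.create",
        new=mock.AsyncMock(return_value=fake_response),
    ) as create:
        assert await bot.tts("hello") == b"audio-bytes"
    create.assert_called_once_with(
        model=config.TTS_MODEL, voice=config.TTS_VOICE, input="hello"
    )
```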