diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad2a396..202fc36 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,8 +27,31 @@ jobs: - run: python -m pip install -e .[lint] - run: ${{ matrix.lint-command }} + + dist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - run: python -m pip install --upgrade pip build twine + - run: python -m build --sdist --wheel + - run: python -m twine check dist/* + - uses: actions/upload-artifact@v4 + with: + path: dist/* + pytest: runs-on: ubuntu-latest + services: + redis: + image: redis + ports: + - 6379:6379 + options: --entrypoint redis-server + env: + REDIS_URL: redis:/// steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..5119b56 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,36 @@ +name: Release + +on: + release: + types: [published] + +jobs: + + release-build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - run: python -m pip install --upgrade pip build + - run: python -m build --sdist --wheel + - uses: actions/upload-artifact@v4 + with: + name: release-dists + path: dist/ + + pypi-publish: + runs-on: ubuntu-latest + needs: + - release-build + permissions: + id-token: write + + steps: + - uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/README.md b/README.md index 5a1476b..7e75fa7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +[![PyPi Version](https://img.shields.io/pypi/v/opensam.svg)](https://pypi.python.org/pypi/opensam/) +[![Test Coverage](https://codecov.io/gh/voiio/sam/branch/main/graph/badge.svg)](https://codecov.io/gh/voiio/sam) +[![GitHub License](https://img.shields.io/github/license/voiio/sam)](https://raw.githubusercontent.com/voiio/sam/master/LICENSE) + # Sam – cuz your company is nothing without Sam ![meme](https://repository-images.githubusercontent.com/726003479/24d020ac-3ac5-401c-beae-7a6103c4e7ae) diff --git a/pyproject.toml b/pyproject.toml index 1f59040..d2948a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,9 @@ +[build-system] +requires = ["flit_core>=3.2", "flit_scm", "wheel"] +build-backend = "flit_scm:buildapi" + [project] -name = "Sam" +name = "OpenSam" authors = [ { name = "Johannes Maron", email = "johannes@maron.family" }, ] @@ -7,17 +11,44 @@ readme = "README.md" license = { file = "LICENSE" } keywords = ["GPT", "AI", "Slack", "OpenAI", "bot"] dynamic = ["version", "description"] - +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Programming Language :: Python", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Software Development", + "Topic :: Communications :: Chat", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Communications :: Email", + "Topic :: Games/Entertainment :: Role-Playing", + "Topic :: Multimedia :: Sound/Audio :: Conversion", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Office/Business", + "Topic :: Office/Business :: Groupware", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Human Machine 
Interfaces", + "Topic :: Text Processing :: Markup :: Markdown", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] requires-python = ">=3.11" + dependencies = [ "aiohttp", "click", + "certifi", "markdownify", "redis", "requests", "slack-bolt", "openai>=1.21.0", "pyyaml", + "algoliasearch", "sentry-sdk", ] @@ -27,8 +58,9 @@ sam = "sam.__main__:cli" [project.optional-dependencies] test = [ "pytest", - "pytest-cov", "pytest-asyncio", + "pytest-cov", + "pytest-env", ] lint = [ "bandit==1.7.8", @@ -41,10 +73,6 @@ lint = [ Project-URL = "https://github.com/voiio/Sam" Changelog = "https://github.com/voiio/Sam/releases" -[build-system] -requires = ["flit-scm", "wheel"] -build-backend = "flit_scm:buildapi" - [tool.flit.module] name = "sam" @@ -54,9 +82,12 @@ fallback_version = "1.0.0" [tool.pytest.ini_options] minversion = "6.0" -addopts = "--cov --tb=short -rxs" +addopts = "--tb=short -rxs" testpaths = ["tests"] +[tool.pytest_env] +GITHUB_REPOS = 'voiio/sam' + [tool.coverage.run] source = ["sam"] diff --git a/sam/__main__.py b/sam/__main__.py index 37c65c1..66b4e97 100644 --- a/sam/__main__.py +++ b/sam/__main__.py @@ -1,29 +1,40 @@ import asyncio import logging +import os import sys import click +import openai import sentry_sdk +from sentry_sdk.integrations.asyncio import AsyncioIntegration from . import config -sentry_sdk.init(config.SENTRY_DSN) +sentry_sdk.init( + dsn=config.SENTRY_DSN, + enable_tracing=True, + integrations=[ + AsyncioIntegration(), + ], +) @click.group() -@click.option("-v", "--verbose", is_flag=True, help="Enables verbose mode.") -def cli(verbose): +def cli(): """Sam – cuz your company is nothing with Sam.""" - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) - logger = logging.getLogger("sam") - logger.addHandler(handler) - logger.setLevel(logging.DEBUG if verbose else logging.INFO) @cli.group(chain=True) -def run(): - """Run an assistent bot, currently only Slack is supported.""" +@click.option("-v", "--verbose", is_flag=True, help="Enables verbose mode.") +def run(verbose): + """Run an assistant bot.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)7s %(name)s - %(message)s") + ) + logging.basicConfig( + handlers=[handler], level=logging.DEBUG if verbose else logging.INFO + ) @run.command() @@ -31,14 +42,57 @@ def slack(): """Run the Slack bot demon.""" from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler - from .slack import app + from .slack import get_app loop = asyncio.get_event_loop() loop.run_until_complete( - AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN).start_async() + AsyncSocketModeHandler(get_app(), config.SLACK_APP_TOKEN).start_async() + ) + + +@cli.group(chain=True) +@click.option("-v", "--verbose", is_flag=True, help="Enables verbose mode.") +def assistants(verbose): + """Manage OpenAI assistants.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(message)s")) + logging.basicConfig( + handlers=[handler], level=logging.DEBUG if verbose else logging.INFO ) +@assistants.command(name="list") +def _list(): + """List all assistants configured in your project.""" + assistant_list = list(config.load_assistants()) + for assistant_config in assistant_list: + click.echo( + f"{assistant_config.name} 
({assistant_config.project}): {assistant_config.assistant_id}" + ) + if not assistant_list: + click.echo("No assistants configured.") + + +@assistants.command() +def upload(): + """Compile and upload all assistants system prompts to OpenAI.""" + assistant_list = list(config.load_assistants()) + for assistant_config in assistant_list: + click.echo(f"Uploading {assistant_config.name}...", nl=False) + project_api_key_name = ( + f"OPENAI_{assistant_config.project.replace('-', '_').upper()}_API_KEY" + ) + project_api_key = os.getenv(project_api_key_name) + with openai.OpenAI(api_key=project_api_key) as client: + client.beta.assistants.update( + assistant_id=assistant_config.assistant_id, + instructions=assistant_config.system_prompt, + ) + click.echo(" Done!") + if not assistant_list: + click.echo("No assistants configured.") + + if __name__ == "__main__": cli() diff --git a/sam/bot.py b/sam/bot.py index 94954c3..40bdfc8 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import json import logging +from pathlib import Path import openai from . import config, tools, utils -from .typing import Roles, RunStatus +from .typing import AUDIO_FORMATS, Roles, RunStatus logger = logging.getLogger(__name__) @@ -17,6 +20,8 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context Raises: RecursionError: If the run status is not "completed" after 10 retries. + IOError: If the run status is not "completed" or "requires_action". + ValueError: If the run requires tools but none are provided. """ client: openai.AsyncOpenAI = openai.AsyncOpenAI() if retry > 10: @@ -27,71 +32,81 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context match run.status: case status if status in [RunStatus.QUEUED, RunStatus.IN_PROGRESS]: await utils.backoff(retry) - await complete_run(run_id, thread_id, retry=retry + 1) + await complete_run(run_id, thread_id, retry=retry + 1, **context) case RunStatus.REQUIRES_ACTION: - if ( - run.required_action - and run.required_action.submit_tool_outputs - and run.required_action.submit_tool_outputs.tool_calls - ): - tool_outputs = [] - for tool_call in run.required_action.submit_tool_outputs.tool_calls: - kwargs = json.loads(tool_call.function.arguments) - try: - fn = getattr(tools, tool_call.function.name) - except KeyError: - logger.exception( - "Tool %s not found, cancelling run %s", - tool_call.function.name, - run_id, - ) - await client.beta.threads.runs.cancel( - run_id=run_id, thread_id=thread_id - ) - return - logger.info("Running tool %s", tool_call.function.name) - logger.debug( - "Tool %s arguments: %r", tool_call.function.name, kwargs - ) - tool_outputs.append( - { - "tool_call_id": tool_call.id, # noqa - "output": fn(**kwargs, **context), - } - ) - logger.info("Submitting tool outputs for run %s", run_id) - logger.debug("Tool outputs: %r", tool_outputs) - await client.beta.threads.runs.submit_tool_outputs( - run.id, # noqa - thread_id=thread_id, - tool_outputs=tool_outputs, - ) - await complete_run(run_id, thread_id) # we reset the retry counter + await call_tools(run, **context) + # after we submit the tool outputs, we reset the retry counter + await complete_run(run_id, thread_id, **context) + case RunStatus.COMPLETED: + return + case _: + raise IOError(f"Run {run.id} failed with status {run.status}") + + +async def call_tools(run: openai.types.beta.threads.Run, **context) -> None: + """ + Call the tools required by the run. + + Raises: + IOError: If a tool is not found. 
+ ValueError: If the run does not require any tools. + """ + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + if not (run.required_action and run.required_action.submit_tool_outputs): + raise ValueError(f"Run {run.id} does not require any tools") + tool_outputs = [] + for tool_call in run.required_action.submit_tool_outputs.tool_calls: + kwargs = json.loads(tool_call.function.arguments) + try: + fn = getattr(tools, tool_call.function.name) + except AttributeError as e: + await client.beta.threads.runs.cancel( + run_id=run.id, thread_id=run.thread_id + ) + raise IOError( + f"Tool {tool_call.function.name} not found, cancelling run {run.id}" + ) from e + logger.info("Running tool %s", tool_call.function.name) + logger.debug("Tool %s arguments: %r", tool_call.function.name, kwargs) + tool_outputs.append( + { + "tool_call_id": tool_call.id, + "output": fn(**kwargs, _context={**context}), + } + ) + logger.info("Submitting tool outputs for run %s", run.id) + logger.debug("Tool outputs: %r", tool_outputs) + await client.beta.threads.runs.submit_tool_outputs( + run.id, # noqa + thread_id=run.thread_id, + tool_outputs=tool_outputs, + ) -async def run( +async def execute_run( assistant_id: str, thread_id: str, additional_instructions: str = None, - file_search: bool = False, + file_ids: list[str] = None, **context, ) -> str: """Run the assistant on the OpenAI thread.""" logger.info( - "Running assistant %s in thread %s with additional instructions: %r", + "Running assistant %s in thread %s with additional instructions", assistant_id, # noqa thread_id, - additional_instructions, ) + logger.debug("Additional instructions: %r", additional_instructions) + logger.debug("Context: %r", context) client: openai.AsyncOpenAI = openai.AsyncOpenAI() - _run = await client.beta.threads.runs.create( + run = await client.beta.threads.runs.create( thread_id=thread_id, assistant_id=assistant_id, additional_instructions=additional_instructions, - max_prompt_tokens=config.MAX_PROMPT_TOKENS, tools=[ utils.func_to_tool(tools.send_email), utils.func_to_tool(tools.web_search), + utils.func_to_tool(tools.platform_search), utils.func_to_tool(tools.fetch_website), utils.func_to_tool(tools.fetch_coworker_emails), utils.func_to_tool(tools.create_github_issue), @@ -99,32 +114,119 @@ ], - tool_choice={"type": "file_search"} if file_search else "auto", + tool_choice={"type": "file_search"} if file_ids else "auto", ) - await complete_run(_run.id, thread_id, **context) + try: + await complete_run(run.id, thread_id, **context) + except (RecursionError, IOError, ValueError): + logger.exception("Run %s failed", run.id) + return "🤯" + + try: + return await fetch_latest_assistant_message(thread_id) + except ValueError: + logger.exception("No assistant message found") + return "🤯" + +async def fetch_latest_assistant_message(thread_id: str) -> str: + """ + Fetch the latest assistant message from the thread. + + Raises: + ValueError: If no assistant message is found. 
+ """ + client: openai.AsyncOpenAI = openai.AsyncOpenAI() messages = await client.beta.threads.messages.list(thread_id=thread_id) for message in messages.data: if message.role == Roles.ASSISTANT: - message_content = message.content[0].text + try: + return await annotate_citations(message.content[0].text) + except IndexError as e: + raise ValueError("No assistant message found") from e - annotations = message_content.annotations - citations = [] - # Iterate over the annotations and add footnotes - for index, annotation in enumerate(annotations): - message_content.value = message_content.value.replace( - annotation.text, f" [{index}]" - ) +async def annotate_citations( + message_content: openai.types.beta.threads.TextContentBlock, +) -> str: + """Annotate citations in the text using footnotes and the file metadata.""" + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + citations = [] + + # Iterate over the annotations and add footnotes + for index, annotation in enumerate(message_content.annotations): + message_content.value = message_content.value.replace( + annotation.text, f" [{index}]" + ) + + if file_citation := getattr(annotation, "file_citation", None): + cited_file = await client.files.retrieve(file_citation.file_id) + citations.append(f"[{index}] {file_citation.quote} — {cited_file.filename}") + elif file_path := getattr(annotation, "file_path", None): + cited_file = await client.files.retrieve(file_path.file_id) + citations.append(f"[{index}]({cited_file.filename})") + + # Add footnotes to the end of the message before displaying to user + message_content.value += "\n" + "\n".join(citations) + return message_content.value - if file_citation := getattr(annotation, "file_citation", None): - cited_file = client.files.retrieve(file_citation.file_id) - citations.append( - f"[{index}] {file_citation.quote} — {cited_file.filename}" + +async def add_message( + thread_id: str, + content: str, + files: [(str, bytes)] = None, +) -> tuple[list[str], bool]: + """Add a message to the thread.""" + logger.info(f"Adding message to thread={thread_id}") + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + file_ids = [] + voice_prompt = False + for file_name, file_content in files or []: + if Path(file_name).suffix.lstrip(".") in AUDIO_FORMATS: + logger.debug("Transcribing audio file %s", file_name) + content += ( + "\n" + + ( + await client.audio.transcriptions.create( + model="whisper-1", + file=(file_name, file_content), ) - elif file_path := getattr(annotation, "file_path", None): - cited_file = client.files.retrieve(file_path.file_id) - citations.append(f"[{index}]({cited_file.filename})") + ).text + ) + voice_prompt = True + else: + logger.debug("Uploading file %s", file_name) + new_file = await client.files.create( + file=(file_name, file_content), + purpose="assistants", + ) + file_ids.append(new_file.id) + await client.beta.threads.messages.create( + thread_id=thread_id, + content=content, + role=Roles.USER, + attachments=[ + {"file_id": file_id, "tools": [{"type": "file_search"}]} + for file_id in file_ids + ], + ) + return file_ids, voice_prompt + + +async def tts(text: str) -> bytes: + """Convert text to speech using the OpenAI API.""" + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + response = await client.audio.speech.create( + model=config.TTS_MODEL, + voice=config.TTS_VOICE, + input=text, + ) + return response.read() - # Add footnotes to the end of the message before displaying to user - message_content.value += "\n" + "\n".join(citations) - return message_content +async def 
stt(audio: bytes) -> str: + """Convert speech to text using the OpenAI API.""" + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + response = await client.audio.transcriptions.create( + model="whisper-1", + file=audio, + ) + return response.text diff --git a/sam/config.py b/sam/config.py index f336413..ea91fae 100644 --- a/sam/config.py +++ b/sam/config.py @@ -1,13 +1,20 @@ +from __future__ import annotations + +import enum import os +import tomllib +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path from zoneinfo import ZoneInfo SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") SLACK_APP_TOKEN = os.getenv("SLACK_APP_TOKEN") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") OPENAI_ASSISTANT_ID = os.getenv("OPENAI_ASSISTANT_ID") - +TTS_VOICE = os.getenv("TTS_VOICE", "alloy") +TTS_MODEL = os.getenv("TTS_MODEL", "tts-1-hd") MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "20480")) - REDIS_URL = os.getenv("REDIS_URL", "redis:///") RANDOM_RUN_RATIO = float(os.getenv("RANDOM_RUN_RATIO", "0")) TIMEZONE = ZoneInfo(os.getenv("TIMEZONE", "UTC")) @@ -15,5 +22,29 @@ BRAVE_SEARCH_LONGITUDE = os.getenv("BRAVE_SEARCH_LONGITUDE") BRAVE_SEARCH_LATITUDE = os.getenv("BRAVE_SEARCH_LATITUDE") SENTRY_DSN = os.getenv("SENTRY_DSN") -GITHUB_ORG = os.getenv("GITHUB_ORG") -GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY") +GITHUB_REPOS = enum.StrEnum( + "GITHUB_REPOS", + {repo: repo for repo in os.getenv("GITHUB_REPOS", "").split(",") if repo}, +) + + +@dataclass +class AssistantConfig: + name: str + assistant_id: str + instructions: list[str] + project: str + + @cached_property + def system_prompt(self): + return "\n\n".join( + Path(instruction).read_text() for instruction in self.instructions + ) + + +def load_assistants(): + with Path("pyproject.toml").open("rb") as fs: + for assistant in ( + tomllib.load(fs).get("tool", {}).get("sam", {}).get("assistants", []) + ): + yield AssistantConfig(**assistant) diff --git a/sam/contrib/algolia/__init__.py b/sam/contrib/algolia/__init__.py new file mode 100644 index 0000000..8f480fd --- /dev/null +++ b/sam/contrib/algolia/__init__.py @@ -0,0 +1,74 @@ +"""AlgoliaSearch Search API client to perform searches on Algolia.""" + +import abc +import os + +import requests +from algoliasearch.search_client import SearchClient + +__all__ = ["get_client", "AlgoliaSearchAPIError"] + + +class AlgoliaSearchAPIError(requests.HTTPError): + pass + + +class AbstractAlgoliaSearch(abc.ABC): # pragma: no cover + @abc.abstractmethod + def search(self, query): + return NotImplemented + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class AlgoliaSearch(AbstractAlgoliaSearch): # pragma: no cover + def __init__(self, application_id, api_key, index): + super().__init__() + self.api_key = api_key + client = SearchClient.create(application_id, api_key) + self.index = client.init_index(index) + self.params = {} + + def search(self, query): + try: + return self.index.search( + query, + request_options={ + **self.params, + "length": 5, + }, + ) + except requests.HTTPError as e: + raise AlgoliaSearchAPIError("The Algolia search API call failed.") from e + + +class AlgoliaSearchStub(AbstractAlgoliaSearch): + def __init__(self): + self.headers = {} + self._objects = [ + { + "title": "Deutschland", + "parent_object_title": "Ferienangebote", + "public_url": "https://www.schulferien.org/deutschland/ferien/", + }, + ] + self.params = {} + + def search(self, query): + return {"hits": self._objects, "nbPages": 1} 
+ + +def get_client(index=None) -> AbstractAlgoliaSearch: + index = index or os.getenv("ALGOLIA_SEARCH_INDEX", "event") + if api_key := os.getenv("ALGOLIA_SEARCH_API_KEY", None): # pragma: no cover + return AlgoliaSearch( + application_id=os.getenv("ALGOLIA_APPLICATION_ID"), + api_key=api_key, + index=index, + ) + else: + return AlgoliaSearchStub() diff --git a/sam/contrib/github/__init__.py b/sam/contrib/github/__init__.py index 125f3a1..04a2084 100644 --- a/sam/contrib/github/__init__.py +++ b/sam/contrib/github/__init__.py @@ -16,7 +16,7 @@ class GitHubAPIError(requests.HTTPError): class AbstractGitHubAPIWrapper(abc.ABC): # pragma: no cover @abc.abstractmethod - def create_issue(self, title, body): + def create_issue(self, title, body, repo): return NotImplemented @abc.abstractmethod @@ -42,9 +42,9 @@ def __init__(self, token): } ) - def create_issue(self, title, body): + def create_issue(self, title, body, repo): response = self.post( - f"{self.endpoint}/repos/{config.GITHUB_ORG}/{config.GITHUB_REPOSITORY}/issues", + f"{self.endpoint}/repos/{repo}/issues", json={"title": title, "body": body}, ) try: @@ -64,11 +64,11 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - def create_issue(self, title, body): + def create_issue(self, title, body, repo): return { "title": title, "body": body, - "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", + "html_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", } diff --git a/sam/slack.py b/sam/slack.py index 97c478c..79e9f7c 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -1,144 +1,134 @@ -import asyncio +import functools import json import logging import random # nosec import urllib.request +from datetime import datetime from typing import Any import redis.asyncio as redis -from openai import AsyncOpenAI -from slack_bolt.async_app import AsyncApp, AsyncSay +from slack_bolt.async_app import AsyncSay +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.client import WebClient from . 
import bot, config, utils -logger = logging.getLogger("sam") +logger = logging.getLogger(__name__) -client = AsyncOpenAI() -app = AsyncApp(token=config.SLACK_BOT_TOKEN) +_USER_HANDLE = None + +ACKNOWLEDGMENT_SMILEYS = [ + "thumbsup", + "ok_hand", + "eyes", + "wave", + "robot_face", + "saluting_face", + "v", + "100", + "muscle", + "thought_balloon", + "speech_balloon", + "space_invader", + "call_me_hand", +] -USER_HANDLE = None -AUDIO_FORMATS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"] +async def get_bot_user_id(): + """Get the Slack bot's user id.""" + client = AsyncWebClient(token=config.SLACK_BOT_TOKEN) + global _USER_HANDLE + if _USER_HANDLE is None: + logger.debug("Fetching the bot's user id") + response = await client.auth_test() + _USER_HANDLE = response["user_id"] + logger.debug(f"Bot's user id is {_USER_HANDLE}") + return _USER_HANDLE async def handle_message(event: {str, Any}, say: AsyncSay): + """Handle a message event from Slack.""" logger.debug(f"handle_message={json.dumps(event)}") - global USER_HANDLE - if USER_HANDLE is None: - logger.debug("Fetching the bot's user id") - response = await say.client.auth_test() - USER_HANDLE = response["user_id"] - logger.debug(f"Bot's user id is {USER_HANDLE}") + if event.get("subtype") == "message_deleted": + logger.debug("Ignoring message_deleted event %s", event) + return # https://api.slack.com/events/message#hidden_subtypes + bot_id = await get_bot_user_id() channel_id = event["channel"] - client_msg_id = event["client_msg_id"] channel_type = event["channel_type"] - user_id = event["user"] text = event["text"] - text = text.replace(f"<@{USER_HANDLE}>", "Sam") + text = text.replace(f"<@{bot_id}>", "Sam") thread_id = await utils.get_thread_id(channel_id) # We may only add messages to a thread while the assistant is not running + files = [] + for file in event.get("files", []): + req = urllib.request.Request( + file["url_private"], + headers={"Authorization": f"Bearer {config.SLACK_BOT_TOKEN}"}, + ) + with urllib.request.urlopen(req) as response: # nosec + files.append((file["name"], response.read())) + async with ( redis.from_url(config.REDIS_URL) as redis_client, redis_client.lock(thread_id, timeout=10 * 60, thread_local=False), ): # 10 minutes - file_ids = [] - voice_prompt = False - if "files" in event: - for file in event["files"]: - req = urllib.request.Request( - file["url_private"], - headers={"Authorization": f"Bearer {config.SLACK_BOT_TOKEN}"}, - ) - with urllib.request.urlopen(req) as response: # nosec - if file["filetype"] in AUDIO_FORMATS: - text += "\n" + await client.audio.transcriptions.create( - model="whisper-1", - file=(file["name"], response.read()), - response_format="text", - ) - logger.info(f"User={user_id} added Audio={file['id']}") - voice_prompt = True - else: - file_ids.append( - ( - await client.files.create( - file=(file["name"], response.read()), - purpose="assistants", - ) - ).id - ) - logger.info( - f"User={user_id} added File={file_ids[-1]} to Thread={thread_id}" - ) - await client.beta.threads.messages.create( + file_ids, voice_prompt = await bot.add_message( thread_id=thread_id, content=text, - role="user", - attachments=[ - {"file_id": file_id, "tools": [{"type": "file_search"}]} - for file_id in file_ids - ], - ) - logger.info( - f"User={user_id} added Message={client_msg_id} added to Thread={thread_id}" + files=files, ) - if ( - channel_type == "im" - or event.get("parent_user_id") == USER_HANDLE - or random.random() < config.RANDOM_RUN_RATIO # nosec - ): - # we need to run the assistant in 
a separate thread, otherwise we will - # block the main thread: - # process_run(event, say, voice_prompt=voice_prompt) - asyncio.create_task( - process_run( - event, say, voice_prompt=voice_prompt, file_search=bool(file_ids) - ) - ) - -ACKNOWLEDGMENT_SMILEYS = [ - "thumbsup", - "ok_hand", - "eyes", - "wave", - "robot_face", - "saluting_face", - "v", - "100", - "muscle", - "thought_balloon", - "speech_balloon", - "space_invader", - "call_me_hand", -] + # we need to release the lock before starting a new run + if ( + channel_type == "im" + or event.get("parent_user_id") == bot_id + or random.random() < config.RANDOM_RUN_RATIO # nosec + ): + await send_response(event, say, file_ids=file_ids, voice_prompt=voice_prompt) + + +@functools.lru_cache(maxsize=128) +def get_user_profile(user_id: str) -> dict[str, Any]: + """Get the profile of a user.""" + client = WebClient(token=config.SLACK_BOT_TOKEN) + return client.users_profile_get(user=user_id)["profile"] + + +@functools.lru_cache(maxsize=128) +def get_user_specific_instructions(user_id: str) -> str: + """Get the user-specific instructions.""" + profile = get_user_profile(user_id) + name = profile["display_name"] + email = profile["email"] + pronouns = profile.get("pronouns") + local_time = datetime.now(tz=config.TIMEZONE) + instructions = [ + f"You MUST ALWAYS address the user as <@{user_id}>.", + f"You may refer to the user as {name}.", + f"The user's email is {email}.", + f"The time is {local_time.isoformat()}.", + ] + if pronouns: + instructions.append(f"The user's pronouns are {pronouns}.") + return "\n".join(instructions) -async def process_run( +async def send_response( event: {str, Any}, say: AsyncSay, + file_ids: list[str] = None, voice_prompt: bool = False, - file_search: bool = False, ): + """Send a response to a message event from Slack.""" logger.debug(f"process_run={json.dumps(event)}") channel_id = event["channel"] user_id = event["user"] - user = await say.client.users_profile_get(user=user_id) - name = user["profile"]["display_name"] - email = user["profile"]["email"] - pronouns = user["profile"].get("pronouns") - additional_instructions = ( - f"You MUST ALWAYS address the user as <@{user_id}>.\n" - f"You may refer to the user as {name}.\n" - f"The user's email is {email}.\n" - ) - if pronouns: - additional_instructions += f"The user's pronouns are {pronouns}.\n" try: - ts = event["ts"] + timestamp = event["ts"] except KeyError: - ts = event["thread_ts"] + timestamp = event["thread_ts"] thread_id = await utils.get_thread_id(channel_id) + # We may wait for the messages being processed, before starting a new run async with ( redis.from_url(config.REDIS_URL) as redis_client, @@ -148,32 +138,32 @@ async def process_run( await say.client.reactions_add( channel=channel_id, name=random.choice(ACKNOWLEDGMENT_SMILEYS), # nosec - timestamp=ts, + timestamp=timestamp, ) - message_content = await bot.run( + text_response = await bot.execute_run( thread_id=thread_id, assistant_id=config.OPENAI_ASSISTANT_ID, - additional_instructions=additional_instructions, - file_search=file_search, - **user["profile"], + additional_instructions=get_user_specific_instructions(user_id), + file_ids=file_ids, + **get_user_profile(user_id), ) msg = await say( channel=say.channel, - text=message_content.value, + text=text_response, mrkdwn=True, thread_ts=event.get("thread_ts", None), ) + logger.info( + f"Sam responded to the User={user_id} in Channel={channel_id} via Text" + ) if voice_prompt: - response = await client.audio.speech.create( - model="tts-1-hd", - 
voice="alloy", - input=message_content.value, - ) await say.client.files_upload_v2( - content=response.read(), - channels=say.channel, + filename="response.mp3", + title="Voice Response", + content=await bot.tts(text_response), + channel=say.channel, thread_ts=event.get("thread_ts", None), ts=msg["ts"], ) @@ -181,10 +171,11 @@ async def process_run( f"Sam responded to the User={user_id} in Channel={channel_id} via Voice" ) - logger.info( - f"Sam responded to the User={user_id} in Channel={channel_id} via Text" - ) +def get_app(): # pragma: no cover + from slack_bolt.async_app import AsyncApp -app.event("message")(handle_message) -app.event("app_mention")(process_run) + app = AsyncApp(token=config.SLACK_BOT_TOKEN) + app.event("message")(handle_message) + app.event("app_mention")(send_response) + return app diff --git a/sam/tools.py b/sam/tools.py index b4ad52c..e886a00 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -8,18 +8,20 @@ import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from urllib.parse import urljoin import requests from bs4 import ParserRejectedMarkup from markdownify import markdownify as md from slack_sdk import WebClient, errors +import sam.config from sam import config -from sam.contrib import brave, github +from sam.contrib import algolia, brave, github from sam.utils import logger -def send_email(to: str, subject: str, body: str, **_context): +def send_email(to: str, subject: str, body: str, _context=None): """ Send an email the given recipients. The user is always cc'd on the email. @@ -28,6 +30,7 @@ def send_email(to: str, subject: str, body: str, **_context): subject: The subject of the email. body: The body of the email. """ + _context = _context or {} email_url = os.getenv("EMAIL_URL") from_email = os.getenv("FROM_EMAIL", "sam@voiio.de") email_white_list = os.getenv("EMAIL_WHITE_LIST") @@ -39,8 +42,10 @@ def send_email(to: str, subject: str, body: str, **_context): msg = MIMEMultipart() msg["From"] = f"Sam <{from_email}>" msg["To"] = to + to_addr = to.split(",") if cc := _context.get("email"): msg["Cc"] = cc + to_addr.append(cc) msg["Subject"] = subject msg.attach(MIMEText(body, "plain")) try: @@ -49,7 +54,7 @@ def send_email(to: str, subject: str, body: str, **_context): server.starttls(context=context) server.ehlo() server.login(url.username, url.password) - server.sendmail(from_email, [to], msg.as_string()) + server.sendmail(from_email, to_addr, msg.as_string()) except smtplib.SMTPException: logger.exception("Failed to send email to: %s", to) return "Email not sent. An error occurred." @@ -57,7 +62,7 @@ def send_email(to: str, subject: str, body: str, **_context): return "Email sent successfully!" -def web_search(query: str, **_context) -> str: +def web_search(query: str, _context=None) -> str: """ Search the internet for information that matches the given query. @@ -88,7 +93,7 @@ def web_search(query: str, **_context) -> str: ) -def fetch_website(url: str, **_context) -> str: +def fetch_website(url: str, _context=None) -> str: """ Fetch the website for the given URL and return the content as Markdown. @@ -113,7 +118,7 @@ def fetch_website(url: str, **_context) -> str: return "failed to parse website" -def fetch_coworker_emails(**_context) -> str: +def fetch_coworker_emails(_context=None) -> str: """ Fetch profile data about your coworkers from Slack. 
@@ -153,7 +158,9 @@ return json.dumps(profiles) -def create_github_issue(title: str, body: str) -> str: +def create_github_issue( + title: str, body: str, repo: "sam.config.GITHUB_REPOS", _context=None +) -> str: """ Create an issue on GitHub with the given title and body. @@ -163,15 +170,55 @@ You should provide ideas for a potential solution, including code snippet examples in a Markdown code block. + You MUST ALWAYS write the issue in English. + Args: title: The title of the issue. body: The body of the issue, markdown supported. + repo: The repository to create the issue in. """ + if repo not in config.GITHUB_REPOS.__members__: + logger.warning("Invalid repo: %s", repo) + return "invalid repo" with github.get_client() as api: try: - response = api.create_issue(title, body) + response = api.create_issue(title, body, repo) except github.GitHubAPIError: logger.exception("Failed to create issue on GitHub") return "failed to create issue" else: - return response["url"] + return response["html_url"] + + +def platform_search(query: str, _context=None) -> str: + """Search the platform for information that matches the given query. + + Return the title and URL of the matching objects in a user friendly format. + + Args: + query: The query to search for. + """ + with algolia.get_client() as api: + api.params.update( + { + "filters": "is_published:true", + "attributesToRetrieve": ["title", "parent_object_title", "public_url"], + } + ) + try: + results = api.search(query)["hits"] + except algolia.AlgoliaSearchAPIError: + logger.exception("Failed to search the platform for query: %s", query) + return "search failed" + else: + if not results: + logger.warning("No platform results found for query: %s", query) + return "no results found" + return json.dumps( + { + f"{hit['parent_object_title']}: {hit['title']}": urljoin( + "https://www.voiio.app", hit["public_url"] + ) + for hit in results + } + ) diff --git a/sam/typing.py b/sam/typing.py index 18ecb3b..9d0dcfb 100644 --- a/sam/typing.py +++ b/sam/typing.py @@ -16,9 +16,13 @@ class RunStatus(enum.StrEnum): QUEUED = "queued" IN_PROGRESS = "in_progress" + COMPLETED = "completed" REQUIRES_ACTION = "requires_action" CANCELLING = "cancelling" CANCELLED = "cancelled" FAILED = "failed" - COMPLETED = "completed" EXPIRED = "expired" + INCOMPLETE = "incomplete" + + +AUDIO_FORMATS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"] diff --git a/sam/utils.py b/sam/utils.py index 121072c..c55328c 100644 --- a/sam/utils.py +++ b/sam/utils.py @@ -1,10 +1,12 @@ from __future__ import annotations import asyncio +import datetime import enum import inspect import logging import random +import typing import openai import redis.asyncio as redis @@ -28,8 +30,6 @@ float: "number", list: "array", dict: "object", - enum.StrEnum: "string", - enum.IntEnum: "integer", } @@ -50,6 +50,8 @@ def func_to_tool(fn: callable) -> dict: doc_data = yaml.safe_load(args.split("Returns:")[0]) else: description = fn.__doc__ + doc_data = {} + return { "type": "function", "function": { @@ -59,13 +61,7 @@ ), "parameters": { "type": "object", - "properties": { - param.name: { - "type": type_map[param.annotation], - "description": doc_data[param.name], - } - for param in params - }, + "properties": dict(params_to_props(fn, params, doc_data)), "required": [ param.name for param in params @@ -76,6 +72,25 @@ } +def params_to_props(fn, 
params, doc_data): + types = typing.get_type_hints(fn) + for param in params: + if param.name.startswith("_"): + continue + param_type = types[param.name] + if param_type in type_map: + yield param.name, { + "type": type_map[types[param.name]], + "description": doc_data[param.name], + } + elif issubclass(param_type, enum.StrEnum): + yield param.name, { + "type": "string", + "enum": [value.value for value in param_type], + "description": doc_data[param.name], + } + + async def backoff(retries: int, max_jitter: int = 10): """Exponential backoff timer with a random jitter.""" await asyncio.sleep(2**retries + random.random() * max_jitter) # nosec @@ -99,6 +114,10 @@ async def get_thread_id(slack_id) -> str: thread = await openai.AsyncOpenAI().beta.threads.create() thread_id = thread.id - await redis_client.set(slack_id, thread_id) + midnight = datetime.datetime.combine( + datetime.date.today(), datetime.time.max, tzinfo=config.TIMEZONE + ) + + await redis_client.set(slack_id, thread_id, exat=int(midnight.timestamp())) return thread_id diff --git a/tests/fixtures/harry.md b/tests/fixtures/harry.md new file mode 100644 index 0000000..ab8912c --- /dev/null +++ b/tests/fixtures/harry.md @@ -0,0 +1 @@ +You are a wizard, Harry. \ No newline at end of file diff --git a/tests/fixtures/pyproject.toml b/tests/fixtures/pyproject.toml new file mode 100644 index 0000000..f918cd7 --- /dev/null +++ b/tests/fixtures/pyproject.toml @@ -0,0 +1,8 @@ +[[tool.sam.assistants]] +name = "Harry" +assistant_id = "asst_1234057341258907" +project="default-project" +instructions = [ + "harry.md", + "security.md", +] diff --git a/tests/fixtures/security.md b/tests/fixtures/security.md new file mode 100644 index 0000000..2bfaa35 --- /dev/null +++ b/tests/fixtures/security.md @@ -0,0 +1 @@ +You mustn't tell lies. 
\ No newline at end of file diff --git a/tests/test_bot.py b/tests/test_bot.py index bd3ffaa..30aa4d3 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -33,8 +33,13 @@ async def test_complete_run__requires_action(client, monkeypatch): tool_call.function.name = "web_search" required_action.submit_tool_outputs.tool_calls = [tool_call] client.beta.threads.runs.retrieve.return_value = namedtuple( - "Run", ["id", "status", "required_action"] - )(id="run-1", status="requires_action", required_action=required_action) + "Run", ["id", "thread_id", "status", "required_action"] + )( + id="run-1", + thread_id="thread-1", + status="requires_action", + required_action=required_action, + ) with pytest.raises(RecursionError): await bot.complete_run(run_id="run-1", thread_id="thread-1") assert web_search.called @@ -46,14 +51,33 @@ async def test_complete_run__queued(monkeypatch, client): client.beta.threads.runs.retrieve.return_value = namedtuple( "Run", ["id", "status"] )(id="run-1", status="queued") - with pytest.raises(Exception) as e: + with pytest.raises(RecursionError) as e: await bot.complete_run(run_id="run-1", thread_id="thread-1") assert "Max retries exceeded" in str(e.value) @pytest.mark.asyncio -async def test_run(monkeypatch, client): +async def test_complete_run__completed(monkeypatch, client): + monkeypatch.setattr("sam.utils.backoff", mock.AsyncMock()) + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="completed") + await bot.complete_run(run_id="run-1", thread_id="thread-1") + + +@pytest.mark.asyncio +async def test_complete_run__unexpected_status(monkeypatch, client): + monkeypatch.setattr("sam.utils.backoff", mock.AsyncMock()) + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="failed") + with pytest.raises(IOError): + await bot.complete_run(run_id="run-1", thread_id="thread-1") + + +@pytest.mark.asyncio +async def test_execute_run(monkeypatch, client): client.beta.threads.runs.retrieve.return_value = namedtuple( "Run", ["id", "status"] )(id="run-1", status="queued") @@ -76,6 +100,74 @@ async def test_run(monkeypatch, client): ), ] ) - bot.complete_run = mock.AsyncMock() - await bot.run(thread_id="thread-1", assistant_id="assistant-1") - assert bot.complete_run.called + complete_run = mock.AsyncMock() + monkeypatch.setattr(bot, "complete_run", complete_run) + await bot.execute_run(thread_id="thread-1", assistant_id="assistant-1") + assert complete_run.called + + +@pytest.mark.asyncio +async def test_execute_run__no_completed(monkeypatch, client): + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="queued") + client.beta.threads.messages.list.return_value = namedtuple("Response", ["data"])( + data=[ + Message( + id="msg-1", + content=[ + TextContentBlock( + type="text", text=Text(value="Hello", annotations=[]) + ) + ], + status="completed", + role="assistant", + created_at=123, + files=[], + file_ids=[], + object="thread.message", + thread_id="thread-4", + ), + ] + ) + complete_run = mock.AsyncMock(side_effect=[RecursionError, None]) + monkeypatch.setattr(bot, "complete_run", complete_run) + response = await bot.execute_run(thread_id="thread-1", assistant_id="assistant-1") + assert complete_run.called + assert response == "🤯" + + +@pytest.mark.asyncio +async def test_execute_run__no_message(monkeypatch, client): + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] 
+ )(id="run-1", status="queued") + client.beta.threads.messages.list.return_value = namedtuple("Response", ["data"])( + data=[ + Message( + id="msg-1", + content=[ + TextContentBlock( + type="text", text=Text(value="Hello", annotations=[]) + ) + ], + status="completed", + role="assistant", + created_at=123, + files=[], + file_ids=[], + object="thread.message", + thread_id="thread-4", + ), + ] + ) + complete_run = mock.AsyncMock() + monkeypatch.setattr(bot, "complete_run", complete_run) + fetch_latest_assistant_message = mock.AsyncMock(side_effect=[ValueError, None]) + monkeypatch.setattr( + bot, "fetch_latest_assistant_message", fetch_latest_assistant_message + ) + response = await bot.execute_run(thread_id="thread-1", assistant_id="assistant-1") + assert complete_run.called + assert fetch_latest_assistant_message.called + assert response == "🤯" diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..04ca049 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,15 @@ +from sam.config import AssistantConfig + + +class TestAssistantConfig: + + def test_system_prompt(self): + assert ( + AssistantConfig( + name="Test", + assistant_id="test", + project="test", + instructions=["tests/fixtures/harry.md", "tests/fixtures/security.md"], + ).system_prompt + == "You are a wizard, Harry.\n\nYou mustn't tell lies." + ) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..66f3513 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,40 @@ +import os +from unittest import mock + +from click.testing import CliRunner + +from sam.__main__ import cli + + +class TestCli: + + def test_list(self, monkeypatch): + monkeypatch.chdir("tests/fixtures") + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "list"]) + assert result.exit_code == 0 + assert "Harry (default-project): asst_1234057341258907" in result.output + + def test_list__empty(self): + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "list"]) + assert result.exit_code == 0 + assert "No assistants configured." in result.output + + def test_upload(self, monkeypatch): + monkeypatch.chdir("tests/fixtures") + client = mock.MagicMock() + monkeypatch.setattr("openai.OpenAI", lambda api_key: client) + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "upload"]) + assert result.exit_code == 0 + assert "Uploading Harry... Done!" in result.output + assert client.__enter__().beta.assistants.update.called + + def test_upload__empty(self, monkeypatch): + client = mock.MagicMock() + monkeypatch.setattr("openai.OpenAI", lambda api_key: client) + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "upload"]) + assert result.exit_code == 0 + assert "No assistants configured." 
in result.output diff --git a/tests/test_slack.py b/tests/test_slack.py new file mode 100644 index 0000000..9d641f7 --- /dev/null +++ b/tests/test_slack.py @@ -0,0 +1,224 @@ +import logging +from unittest import mock + +import pytest + +from sam import bot, slack + + +@pytest.mark.asyncio +async def test_get_bot_user_id(monkeypatch): + auth_test = mock.AsyncMock(return_value={"user_id": "bot-1"}) + monkeypatch.setattr(slack.AsyncWebClient, "auth_test", auth_test) + assert await slack.get_bot_user_id() == "bot-1" + assert auth_test.called + + +@pytest.mark.asyncio +async def test_handle_message(monkeypatch): + urlopen = mock.AsyncMock() + urlopen.__enter__().read.return_value = b"Hello" + monkeypatch.setattr("urllib.request.urlopen", lambda *args, **kwargs: urlopen) + add_message = mock.AsyncMock(return_value=(["file-1"], False)) + monkeypatch.setattr(bot, "add_message", add_message) + send_response = mock.AsyncMock() + monkeypatch.setattr(slack, "send_response", send_response) + get_bot_user_id = mock.AsyncMock(return_value="bot-1") + monkeypatch.setattr(slack, "get_bot_user_id", get_bot_user_id) + monkeypatch.setattr( + "sam.utils.get_thread_id", mock.AsyncMock(return_value="thread-1") + ) + say = mock.AsyncMock() + event = { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "text": "Hello", + "files": [ + { + "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "name": "file.mp3", + } + ], + } + await slack.handle_message(event, say) + assert add_message.called + assert add_message.call_args == mock.call( + thread_id="thread-1", content="Hello", files=[("file.mp3", b"Hello")] + ) + assert urlopen.__enter__().read.called + assert send_response.called + assert send_response.call_args == mock.call( + { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "text": "Hello", + "files": [ + { + "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "name": "file.mp3", + } + ], + }, + say, + file_ids=["file-1"], + voice_prompt=False, + ) + + +@pytest.mark.asyncio +async def test_handle_message__subtype_deleted(caplog): + event = { + "type": "message", + "subtype": "message_deleted", + "hidden": True, + "channel": "C123ABC456", + "ts": "1358878755.000001", + "deleted_ts": "1358878749.000002", + "event_ts": "1358878755.000002", + } + with caplog.at_level(logging.DEBUG): + await slack.handle_message(event, None) + assert "Ignoring message_deleted event" in caplog.text + + +def test_get_user_profile(monkeypatch): + client = mock.MagicMock() + client.users_profile_get.return_value = { + "profile": { + "display_name": "Spidy", + "status_text": "With great power comes great responsibility", + "pronouns": "spider/superhero", + "email": "peter.parker@avengers.com", + } + } + monkeypatch.setattr(slack.WebClient, "users_profile_get", client.users_profile_get) + assert slack.get_user_profile("user-1") == { + "display_name": "Spidy", + "status_text": "With great power comes great responsibility", + "pronouns": "spider/superhero", + "email": "peter.parker@avengers.com", + } + + +def test_get_user_specific_instructions(monkeypatch): + client = mock.MagicMock() + client.users_profile_get.return_value = { + "profile": { + "display_name": "Spidy", + "status_text": "With great power comes great responsibility", + "pronouns": "spider/superhero", + "email": "peter.parker@avengers.com", + } + } + 
monkeypatch.setattr(slack.WebClient, "users_profile_get", client.users_profile_get) + instructions = slack.get_user_specific_instructions("user-1") + assert "You MUST ALWAYS address the user as <@user-1>." in instructions + assert "You may refer to the user as Spidy." in instructions + assert "The user's email is peter.parker@avengers.com." in instructions + assert "The user's pronouns are spider/superhero." in instructions + + +@pytest.mark.asyncio +async def test_send_response(monkeypatch): + urlopen = mock.AsyncMock() + urlopen.__enter__().read.return_value = b"Hello" + monkeypatch.setattr("urllib.request.urlopen", lambda *args, **kwargs: urlopen) + execute_run = mock.AsyncMock(return_value="Hello World!") + monkeypatch.setattr(bot, "execute_run", execute_run) + tts = mock.AsyncMock(return_value=b"Hello") + monkeypatch.setattr(bot, "tts", tts) + get_bot_user_id = mock.AsyncMock(return_value="bot-1") + monkeypatch.setattr(slack, "get_bot_user_id", get_bot_user_id) + monkeypatch.setattr( + slack, + "get_user_specific_instructions", + lambda *args, **kwargs: "user_instructions", + ) + monkeypatch.setattr( + slack, "get_user_profile", lambda *args, **kwargs: {"name": "Sam"} + ) + monkeypatch.setattr( + "sam.utils.get_thread_id", mock.AsyncMock(return_value="thread-1") + ) + say = mock.AsyncMock() + event = { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "ts": 12321345, + "text": "Hello", + "files": [ + { + "url_private": "https://example.com/file.mp3", + "name": "file.mp3", + } + ], + } + await slack.send_response(event, say, voice_prompt=True) + + assert execute_run.called + assert execute_run.call_args == mock.call( + thread_id="thread-1", + assistant_id=None, + additional_instructions="user_instructions", + file_ids=None, + name="Sam", + ) + assert tts.called + assert tts.call_args == mock.call("Hello World!") + + +@pytest.mark.asyncio +async def test_send_response__thread(monkeypatch): + urlopen = mock.AsyncMock() + urlopen.__enter__().read.return_value = b"Hello" + monkeypatch.setattr("urllib.request.urlopen", lambda *args, **kwargs: urlopen) + execute_run = mock.AsyncMock(return_value="Hello World!") + monkeypatch.setattr(bot, "execute_run", execute_run) + tts = mock.AsyncMock(return_value=b"Hello") + monkeypatch.setattr(bot, "tts", tts) + get_bot_user_id = mock.AsyncMock(return_value="bot-1") + monkeypatch.setattr(slack, "get_bot_user_id", get_bot_user_id) + monkeypatch.setattr( + slack, + "get_user_specific_instructions", + lambda *args, **kwargs: "user_instructions", + ) + monkeypatch.setattr( + slack, "get_user_profile", lambda *args, **kwargs: {"name": "Sam"} + ) + monkeypatch.setattr( + "sam.utils.get_thread_id", mock.AsyncMock(return_value="thread-1") + ) + say = mock.AsyncMock() + event = { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "thread_ts": 12321345, + "text": "Hello", + "files": [ + { + "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "name": "file.mp3", + } + ], + } + await slack.send_response(event, say, voice_prompt=True) + + assert execute_run.called + assert execute_run.call_args == mock.call( + thread_id="thread-1", + assistant_id=None, + additional_instructions="user_instructions", + file_ids=None, + name="Sam", + ) + assert tts.called + assert tts.call_args == mock.call("Hello World!") diff --git a/tests/test_tools.py b/tests/test_tools.py index d8efb66..50707a1 100644 --- 
a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,9 +1,12 @@ +import json import smtplib from unittest import mock import pytest +import requests from sam import tools +from sam.contrib import algolia @pytest.fixture @@ -50,6 +53,33 @@ def test_web_search__with_coordinates(): def test_create_github_issue(): assert ( - tools.create_github_issue("title", "body") + tools.create_github_issue("title", "body", "voiio/sam") == "https://www.youtube.com/watch?v=dQw4w9WgXcQ" ) + + +def test_create_github_issue__invalid_repo(): + assert tools.create_github_issue("title", "body", "not-valid") == "invalid repo" + + +def test_platform_search(): + assert tools.platform_search("ferien") == json.dumps( + { + "Ferienangebote: Deutschland": "https://www.schulferien.org/deutschland/ferien/" + } + ) + + +def test_platform_search_with_error(): + with mock.patch( + "sam.contrib.algolia.AlgoliaSearchStub.search", + side_effect=algolia.AlgoliaSearchAPIError, + ): + assert tools.platform_search("something") == "search failed" + + +def test_platform_search_no_results(): + with mock.patch( + "sam.contrib.algolia.AlgoliaSearchStub.search", return_value={"hits": []} + ): + assert tools.platform_search("something") == "no results found" diff --git a/tests/test_utils.py b/tests/test_utils.py index 168d117..0e25d5c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,10 @@ from __future__ import annotations +import enum + import pytest +import tests.test_tools from sam import utils @@ -10,13 +13,19 @@ async def test_backoff(): await utils.backoff(0, max_jitter=0) +BloodTypes = enum.StrEnum("BloodTypes", {"A": "A", "B": "B"}) + + def test_func_to_tool(): - def fn(a: int, b: str) -> int: + def fn( + a: int, b: str, blood_types: "tests.test_utils.BloodTypes", _context=None + ) -> str: """Function description. Args: a: Description of a. b: Description of b. + blood_types: Description of bool_types. Returns: Description of return value. @@ -40,8 +49,13 @@ def fn(a: int, b: str) -> int: "type": "string", "description": "Description of b.", }, + "blood_types": { + "type": "string", + "enum": ["A", "B"], + "description": "Description of bool_types.", + }, }, - "required": ["a", "b"], + "required": ["a", "b", "blood_types"], }, }, }