From 70d758cf8dba2bbd50fefd9d9692b9d713f44ebf Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Thu, 18 Apr 2024 10:57:54 +0200 Subject: [PATCH 01/19] Revert OpenAI update and latest commits --- pyproject.toml | 2 +- sam/bot.py | 14 +++++--------- sam/config.py | 3 --- sam/slack.py | 10 +++++----- sam/tools.py | 14 ++++++-------- sam/utils.py | 18 +++++++++--------- 6 files changed, 26 insertions(+), 35 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1f59040..88730bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "redis", "requests", "slack-bolt", - "openai>=1.21.0", + "openai==1.20.0", "pyyaml", "sentry-sdk", ] diff --git a/sam/bot.py b/sam/bot.py index 8354b31..71f7ef8 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -3,13 +3,13 @@ import openai -from . import config, tools, utils +from . import tools, utils from .typing import Roles, RunStatus logger = logging.getLogger(__name__) -async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context): +async def complete_run(run_id: str, thread_id: str, retry: int = 0): """ Wait for the run to complete. @@ -56,7 +56,7 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context tool_outputs.append( { "tool_call_id": tool_call.id, # noqa - "output": fn(**kwargs, **context), + "output": fn(**kwargs), } ) logger.info("Submitting tool outputs for run %s", run_id) @@ -70,10 +70,7 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context async def run( - assistant_id: str, - thread_id: str, - additional_instructions: str = None, - **context, + assistant_id: str, thread_id: str, additional_instructions: str = None ) -> str: """Run the assistant on the OpenAI thread.""" logger.info( @@ -87,7 +84,6 @@ async def run( thread_id=thread_id, assistant_id=assistant_id, additional_instructions=additional_instructions, - max_prompt_tokens=config.MAX_PROMPT_TOKENS, tools=[ utils.func_to_tool(tools.send_email), utils.func_to_tool(tools.web_search), @@ -96,7 +92,7 @@ async def run( utils.func_to_tool(tools.create_github_issue), ], ) - await complete_run(_run.id, thread_id, **context) + await complete_run(_run.id, thread_id) messages = await client.beta.threads.messages.list(thread_id=thread_id) for message in messages.data: diff --git a/sam/config.py b/sam/config.py index 2a9cfce..92d9ec7 100644 --- a/sam/config.py +++ b/sam/config.py @@ -5,9 +5,6 @@ SLACK_APP_TOKEN = os.getenv("SLACK_APP_TOKEN") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") OPENAI_ASSISTANT_ID = os.getenv("OPENAI_ASSISTANT_ID") - -MAX_PROMPT_TOKENS = int(os.getenv("MAX_PROMPT_TOKENS", "2048")) - REDIS_URL = os.getenv("REDIS_URL", "redis:///") RANDOM_RUN_RATIO = float(os.getenv("RANDOM_RUN_RATIO", "0")) TIMEZONE = ZoneInfo(os.getenv("TIMEZONE", "UTC")) diff --git a/sam/slack.py b/sam/slack.py index f5ab712..c9e0d67 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -72,6 +72,7 @@ async def handle_message(event: {str, Any}, say: AsyncSay): thread_id=thread_id, content=text, role="user", + file_ids=file_ids, ) logger.info( f"User={user_id} added Message={client_msg_id} added to Thread={thread_id}" @@ -108,10 +109,10 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal logger.debug(f"process_run={json.dumps(event)}") channel_id = event["channel"] user_id = event["user"] - user = await say.client.users_profile_get(user=user_id) - name = user["profile"]["display_name"] - email = user["profile"]["email"] - pronouns = user["profile"].get("pronouns") + profile = await say.client.users_profile_get(user=user_id) + name = profile["profile"]["display_name"] + email = profile["profile"]["email"] + pronouns = profile["profile"].get("pronouns") additional_instructions = ( f"You MUST ALWAYS address the user as <@{user_id}>.\n" f"You may refer to the user as {name}.\n" @@ -139,7 +140,6 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal thread_id=thread_id, assistant_id=config.OPENAI_ASSISTANT_ID, additional_instructions=additional_instructions, - **user["profile"], ) msg = await say( diff --git a/sam/tools.py b/sam/tools.py index b4ad52c..5b57aea 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -19,12 +19,12 @@ from sam.utils import logger -def send_email(to: str, subject: str, body: str, **_context): +def send_email(to: str, subject: str, body: str): """ - Send an email the given recipients. The user is always cc'd on the email. + Write and send email. Args: - to: Comma separated list of email addresses. + to: The recipient of the email, e.g. john.doe@voiio.de. subject: The subject of the email. body: The body of the email. """ @@ -39,8 +39,6 @@ def send_email(to: str, subject: str, body: str, **_context): msg = MIMEMultipart() msg["From"] = f"Sam <{from_email}>" msg["To"] = to - if cc := _context.get("email"): - msg["Cc"] = cc msg["Subject"] = subject msg.attach(MIMEText(body, "plain")) try: @@ -57,7 +55,7 @@ def send_email(to: str, subject: str, body: str, **_context): return "Email sent successfully!" -def web_search(query: str, **_context) -> str: +def web_search(query: str) -> str: """ Search the internet for information that matches the given query. @@ -88,7 +86,7 @@ def web_search(query: str, **_context) -> str: ) -def fetch_website(url: str, **_context) -> str: +def fetch_website(url: str) -> str: """ Fetch the website for the given URL and return the content as Markdown. @@ -113,7 +111,7 @@ def fetch_website(url: str, **_context) -> str: return "failed to parse website" -def fetch_coworker_emails(**_context) -> str: +def fetch_coworker_emails() -> str: """ Fetch profile data about your coworkers from Slack. diff --git a/sam/utils.py b/sam/utils.py index 121072c..36f0b61 100644 --- a/sam/utils.py +++ b/sam/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import datetime import enum import inspect import logging @@ -40,12 +41,7 @@ def func_to_tool(fn: callable) -> dict: The docstring should be formatted using the Google Napolean style. """ signature: inspect.Signature = inspect.signature(fn) - params = [ - param - for param in signature.parameters.values() - if not param.name.startswith("_") - ] - if params: + if signature.parameters: description, args = fn.__doc__.split("Args:") doc_data = yaml.safe_load(args.split("Returns:")[0]) else: @@ -64,11 +60,11 @@ def func_to_tool(fn: callable) -> dict: "type": type_map[param.annotation], "description": doc_data[param.name], } - for param in params + for param in signature.parameters.values() }, "required": [ param.name - for param in params + for param in signature.parameters.values() if param.default is inspect.Parameter.empty ], }, @@ -99,6 +95,10 @@ async def get_thread_id(slack_id) -> str: thread = await openai.AsyncOpenAI().beta.threads.create() thread_id = thread.id - await redis_client.set(slack_id, thread_id) + midnight = datetime.datetime.combine( + datetime.date.today(), datetime.time.max, tzinfo=config.TIMEZONE + ) + + await redis_client.set(slack_id, thread_id, exat=int(midnight.timestamp())) return thread_id From 5b5b4eaf793e69a6536b46596323d7e3367bbcf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Benesch?= <14031765+herrbenesch@users.noreply.github.com> Date: Thu, 18 Apr 2024 12:22:48 +0200 Subject: [PATCH 02/19] Fix #38 -- Add platform search to Sam (#41) --- pyproject.toml | 1 + sam/bot.py | 1 + sam/contrib/algolia/__init__.py | 74 +++++++++++++++++++++++++++++++++ sam/tools.py | 37 ++++++++++++++++- tests/test_tools.py | 26 ++++++++++++ 5 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 sam/contrib/algolia/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 88730bc..fdbcf3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "slack-bolt", "openai==1.20.0", "pyyaml", + "algoliasearch", "sentry-sdk", ] diff --git a/sam/bot.py b/sam/bot.py index 71f7ef8..afad243 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -87,6 +87,7 @@ async def run( tools=[ utils.func_to_tool(tools.send_email), utils.func_to_tool(tools.web_search), + utils.func_to_tool(tools.platform_search), utils.func_to_tool(tools.fetch_website), utils.func_to_tool(tools.fetch_coworker_emails), utils.func_to_tool(tools.create_github_issue), diff --git a/sam/contrib/algolia/__init__.py b/sam/contrib/algolia/__init__.py new file mode 100644 index 0000000..8f480fd --- /dev/null +++ b/sam/contrib/algolia/__init__.py @@ -0,0 +1,74 @@ +"""AlgoliaSearch Search API client to perform searches on Algolia.""" + +import abc +import os + +import requests +from algoliasearch.search_client import SearchClient + +__all__ = ["get_client", "AlgoliaSearchAPIError"] + + +class AlgoliaSearchAPIError(requests.HTTPError): + pass + + +class AbstractAlgoliaSearch(abc.ABC): # pragma: no cover + @abc.abstractmethod + def search(self, query): + return NotImplemented + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + pass + + +class AlgoliaSearch(AbstractAlgoliaSearch): # pragma: no cover + def __init__(self, application_id, api_key, index): + super().__init__() + self.api_key = api_key + client = SearchClient.create(application_id, api_key) + self.index = client.init_index(index) + self.params = {} + + def search(self, query): + try: + return self.index.search( + query, + request_options={ + **self.params, + "length": 5, + }, + ) + except requests.HTTPError as e: + raise AlgoliaSearchAPIError("The Algolia search API call failed.") from e + + +class AlgoliaSearchStub(AbstractAlgoliaSearch): + def __init__(self): + self.headers = {} + self._objects = [ + { + "title": "Deutschland", + "parent_object_title": "Ferienangebote", + "public_url": "https://www.schulferien.org/deutschland/ferien/", + }, + ] + self.params = {} + + def search(self, query): + return {"hits": self._objects, "nbPages": 1} + + +def get_client(index=None) -> AbstractAlgoliaSearch: + index = index or os.getenv("ALGOLIA_SEARCH_INDEX", "event") + if api_key := os.getenv("ALGOLIA_SEARCH_API_KEY", None): # pragma: no cover + return AlgoliaSearch( + application_id=os.getenv("ALGOLIA_APPLICATION_ID"), + api_key=api_key, + index=index, + ) + else: + return AlgoliaSearchStub() diff --git a/sam/tools.py b/sam/tools.py index 5b57aea..d9f18b5 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -8,6 +8,7 @@ import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText +from urllib.parse import urljoin import requests from bs4 import ParserRejectedMarkup @@ -15,7 +16,7 @@ from slack_sdk import WebClient, errors from sam import config -from sam.contrib import brave, github +from sam.contrib import algolia, brave, github from sam.utils import logger @@ -173,3 +174,37 @@ def create_github_issue(title: str, body: str) -> str: return "failed to create issue" else: return response["url"] + + +def platform_search(query: str) -> str: + """Search the platform for information that matches the given query. + + Return the title and URL of the matching objects in a user friendly format. + + Args: + query: The query to search for. + """ + with algolia.get_client() as api: + api.params.update( + { + "filters": "is_published:true", + "attributesToRetrieve": ["title", "parent_object_title", "public_url"], + } + ) + try: + results = api.search(query)["hits"] + except algolia.AlgoliaSearchAPIError: + logger.exception("Failed to search the platform for query: %s", query) + return "search failed" + else: + if not results: + logger.warning("No platform results found for query: %s", query) + return "no results found" + return json.dumps( + { + f"{hit['parent_object_title']}: {hit['title']}": urljoin( + "https://www.voiio.app", hit["public_url"] + ) + for hit in results + } + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index d8efb66..8e1ac07 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,9 +1,12 @@ +import json import smtplib from unittest import mock import pytest +import requests from sam import tools +from sam.contrib import algolia @pytest.fixture @@ -53,3 +56,26 @@ def test_create_github_issue(): tools.create_github_issue("title", "body") == "https://www.youtube.com/watch?v=dQw4w9WgXcQ" ) + + +def test_platform_search(): + assert tools.platform_search("ferien") == json.dumps( + { + "Ferienangebote: Deutschland": "https://www.schulferien.org/deutschland/ferien/" + } + ) + + +def test_platform_search_with_error(): + with mock.patch( + "sam.contrib.algolia.AlgoliaSearchStub.search", + side_effect=algolia.AlgoliaSearchAPIError, + ): + assert tools.platform_search("something") == "search failed" + + +def test_platform_search_no_results(): + with mock.patch( + "sam.contrib.algolia.AlgoliaSearchStub.search", return_value={"hits": []} + ): + assert tools.platform_search("something") == "no results found" From 521a2b85e6a404bacd1a5cacc069e8188e23f5c7 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Thu, 18 Apr 2024 13:59:36 +0200 Subject: [PATCH 03/19] Add funciton contexts and cc users on emails (#52) --- pyproject.toml | 2 +- sam/bot.py | 13 ++++++++----- sam/slack.py | 10 +++++----- sam/tools.py | 17 ++++++++++------- sam/utils.py | 11 ++++++++--- 5 files changed, 32 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fdbcf3e..2b91bdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "redis", "requests", "slack-bolt", - "openai==1.20.0", + "openai~=1.20.0", "pyyaml", "algoliasearch", "sentry-sdk", diff --git a/sam/bot.py b/sam/bot.py index afad243..9b00585 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -3,13 +3,13 @@ import openai -from . import tools, utils +from . import config, tools, utils from .typing import Roles, RunStatus logger = logging.getLogger(__name__) -async def complete_run(run_id: str, thread_id: str, retry: int = 0): +async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context): """ Wait for the run to complete. @@ -56,7 +56,7 @@ async def complete_run(run_id: str, thread_id: str, retry: int = 0): tool_outputs.append( { "tool_call_id": tool_call.id, # noqa - "output": fn(**kwargs), + "output": fn(**kwargs, **context), } ) logger.info("Submitting tool outputs for run %s", run_id) @@ -70,7 +70,10 @@ async def complete_run(run_id: str, thread_id: str, retry: int = 0): async def run( - assistant_id: str, thread_id: str, additional_instructions: str = None + assistant_id: str, + thread_id: str, + additional_instructions: str = None, + **context, ) -> str: """Run the assistant on the OpenAI thread.""" logger.info( @@ -93,7 +96,7 @@ async def run( utils.func_to_tool(tools.create_github_issue), ], ) - await complete_run(_run.id, thread_id) + await complete_run(_run.id, thread_id, **context) messages = await client.beta.threads.messages.list(thread_id=thread_id) for message in messages.data: diff --git a/sam/slack.py b/sam/slack.py index c9e0d67..a021ab2 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -72,7 +72,6 @@ async def handle_message(event: {str, Any}, say: AsyncSay): thread_id=thread_id, content=text, role="user", - file_ids=file_ids, ) logger.info( f"User={user_id} added Message={client_msg_id} added to Thread={thread_id}" @@ -109,10 +108,10 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal logger.debug(f"process_run={json.dumps(event)}") channel_id = event["channel"] user_id = event["user"] - profile = await say.client.users_profile_get(user=user_id) - name = profile["profile"]["display_name"] - email = profile["profile"]["email"] - pronouns = profile["profile"].get("pronouns") + profile = (await say.client.users_profile_get(user=user_id))["profile"] + name = profile["display_name"] + email = profile["email"] + pronouns = profile.get("pronouns") additional_instructions = ( f"You MUST ALWAYS address the user as <@{user_id}>.\n" f"You may refer to the user as {name}.\n" @@ -140,6 +139,7 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal thread_id=thread_id, assistant_id=config.OPENAI_ASSISTANT_ID, additional_instructions=additional_instructions, + **profile, ) msg = await say( diff --git a/sam/tools.py b/sam/tools.py index d9f18b5..1e2369d 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -20,12 +20,12 @@ from sam.utils import logger -def send_email(to: str, subject: str, body: str): +def send_email(to: str, subject: str, body: str, **_context): """ - Write and send email. + Send an email the given recipients. The user is always cc'd on the email. Args: - to: The recipient of the email, e.g. john.doe@voiio.de. + to: Comma separated list of email addresses. subject: The subject of the email. body: The body of the email. """ @@ -40,15 +40,18 @@ def send_email(to: str, subject: str, body: str): msg = MIMEMultipart() msg["From"] = f"Sam <{from_email}>" msg["To"] = to + if cc := _context.get("email"): + msg["Cc"] = cc msg["Subject"] = subject msg.attach(MIMEText(body, "plain")) + to_addr = to.split(",") + [cc] if cc else [] try: with smtplib.SMTP(url.hostname, url.port) as server: server.ehlo() server.starttls(context=context) server.ehlo() server.login(url.username, url.password) - server.sendmail(from_email, [to], msg.as_string()) + server.sendmail(from_email, to_addr, msg.as_string()) except smtplib.SMTPException: logger.exception("Failed to send email to: %s", to) return "Email not sent. An error occurred." @@ -56,7 +59,7 @@ def send_email(to: str, subject: str, body: str): return "Email sent successfully!" -def web_search(query: str) -> str: +def web_search(query: str, **_context) -> str: """ Search the internet for information that matches the given query. @@ -87,7 +90,7 @@ def web_search(query: str) -> str: ) -def fetch_website(url: str) -> str: +def fetch_website(url: str, **_context) -> str: """ Fetch the website for the given URL and return the content as Markdown. @@ -112,7 +115,7 @@ def fetch_website(url: str) -> str: return "failed to parse website" -def fetch_coworker_emails() -> str: +def fetch_coworker_emails(**_context) -> str: """ Fetch profile data about your coworkers from Slack. diff --git a/sam/utils.py b/sam/utils.py index 36f0b61..c92a9ed 100644 --- a/sam/utils.py +++ b/sam/utils.py @@ -41,7 +41,12 @@ def func_to_tool(fn: callable) -> dict: The docstring should be formatted using the Google Napolean style. """ signature: inspect.Signature = inspect.signature(fn) - if signature.parameters: + params = [ + param + for param in signature.parameters.values() + if not param.name.startswith("_") + ] + if params: description, args = fn.__doc__.split("Args:") doc_data = yaml.safe_load(args.split("Returns:")[0]) else: @@ -60,11 +65,11 @@ def func_to_tool(fn: callable) -> dict: "type": type_map[param.annotation], "description": doc_data[param.name], } - for param in signature.parameters.values() + for param in params }, "required": [ param.name - for param in signature.parameters.values() + for param in params if param.default is inspect.Parameter.empty ], }, From c2b2957737c8992a25b8d78f3e26ff148a416247 Mon Sep 17 00:00:00 2001 From: mostafa-anm Date: Thu, 18 Apr 2024 14:13:43 +0200 Subject: [PATCH 04/19] Ref #36 -- Add time of day as additional instruction (#47) --- sam/slack.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sam/slack.py b/sam/slack.py index a021ab2..0644026 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -3,6 +3,7 @@ import logging import random # nosec import urllib.request +from datetime import datetime from typing import Any import redis.asyncio as redis @@ -112,10 +113,12 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal name = profile["display_name"] email = profile["email"] pronouns = profile.get("pronouns") + local_time = datetime.now().astimezone(config.TIMEZONE).strftime("%H:%M") additional_instructions = ( f"You MUST ALWAYS address the user as <@{user_id}>.\n" f"You may refer to the user as {name}.\n" f"The user's email is {email}.\n" + f"The time is {local_time}.\n" ) if pronouns: additional_instructions += f"The user's pronouns are {pronouns}.\n" From 4b5e2113e1db44e591833872997aefab52430ef3 Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Thu, 18 Apr 2024 14:45:36 +0200 Subject: [PATCH 05/19] Use HTML URL for creating GitHub issues (#53) `url` returns the API URL, which is useless for the end user. --- sam/contrib/github/__init__.py | 2 +- sam/tools.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sam/contrib/github/__init__.py b/sam/contrib/github/__init__.py index 125f3a1..74141fc 100644 --- a/sam/contrib/github/__init__.py +++ b/sam/contrib/github/__init__.py @@ -68,7 +68,7 @@ def create_issue(self, title, body): return { "title": title, "body": body, - "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", + "html_url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", } diff --git a/sam/tools.py b/sam/tools.py index 1e2369d..6effc6a 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -176,7 +176,7 @@ def create_github_issue(title: str, body: str) -> str: logger.exception("Failed to create issue on GitHub") return "failed to create issue" else: - return response["url"] + return response["html_url"] def platform_search(query: str) -> str: From 4aa786fd062af52dc04314e3d9296b859bf4311a Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Thu, 18 Apr 2024 14:49:04 +0200 Subject: [PATCH 06/19] Fix #55 -- await file creation coroutine --- sam/slack.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sam/slack.py b/sam/slack.py index 0644026..3ea03b2 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -60,12 +60,11 @@ async def handle_message(event: {str, Any}, say: AsyncSay): logger.info(f"User={user_id} added Audio={file['id']}") voice_prompt = True else: - file_ids.append( - client.files.create( - file=(file["name"], response.read()), - purpose="assistants", - ).id + new_file = await client.files.create( + file=(file["name"], response.read()), + purpose="assistants", ) + file_ids.append(new_file.id) logger.info( f"User={user_id} added File={file_ids[-1]} to Thread={thread_id}" ) From 82c7c7d25d6fccfcb5b9f6ee84a7e067f4bc453a Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Thu, 18 Apr 2024 19:47:38 +0200 Subject: [PATCH 07/19] Fix function call contexts (#57) --- sam/bot.py | 21 ++++++++++++++++----- sam/tools.py | 3 ++- sam/typing.py | 3 ++- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/sam/bot.py b/sam/bot.py index 9b00585..d92d751 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -27,7 +27,7 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context match run.status: case status if status in [RunStatus.QUEUED, RunStatus.IN_PROGRESS]: await utils.backoff(retry) - await complete_run(run_id, thread_id, retry=retry + 1) + await complete_run(run_id, thread_id, retry=retry + 1, **context) case RunStatus.REQUIRES_ACTION: if ( run.required_action @@ -66,7 +66,13 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context thread_id=thread_id, tool_outputs=tool_outputs, ) - await complete_run(run_id, thread_id) # we reset the retry counter + await complete_run( + run_id, thread_id, **context + ) # we reset the retry counter + case RunStatus.COMPLETED: + return + case _: + raise IOError(f"Run {run.id} failed with status {run.status}") async def run( @@ -77,11 +83,12 @@ async def run( ) -> str: """Run the assistant on the OpenAI thread.""" logger.info( - "Running assistant %s in thread %s with additional instructions: %r", + "Running assistant %s in thread %s with additional instructions", assistant_id, # noqa thread_id, - additional_instructions, ) + logger.debug("Additional instructions: %r", additional_instructions) + logger.debug("Context: %r", context) client: openai.AsyncOpenAI = openai.AsyncOpenAI() _run = await client.beta.threads.runs.create( thread_id=thread_id, @@ -96,7 +103,11 @@ async def run( utils.func_to_tool(tools.create_github_issue), ], ) - await complete_run(_run.id, thread_id, **context) + try: + await complete_run(_run.id, thread_id, **context) + except IOError: + logger.exception("Run %s failed", _run.id) + return "🤯" messages = await client.beta.threads.messages.list(thread_id=thread_id) for message in messages.data: diff --git a/sam/tools.py b/sam/tools.py index 6effc6a..7efbe55 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -40,11 +40,12 @@ def send_email(to: str, subject: str, body: str, **_context): msg = MIMEMultipart() msg["From"] = f"Sam <{from_email}>" msg["To"] = to + to_addr = to.split(",") if cc := _context.get("email"): msg["Cc"] = cc + to_addr.append(cc) msg["Subject"] = subject msg.attach(MIMEText(body, "plain")) - to_addr = to.split(",") + [cc] if cc else [] try: with smtplib.SMTP(url.hostname, url.port) as server: server.ehlo() diff --git a/sam/typing.py b/sam/typing.py index 18ecb3b..b0aea09 100644 --- a/sam/typing.py +++ b/sam/typing.py @@ -16,9 +16,10 @@ class RunStatus(enum.StrEnum): QUEUED = "queued" IN_PROGRESS = "in_progress" + COMPLETED = "completed" REQUIRES_ACTION = "requires_action" CANCELLING = "cancelling" CANCELLED = "cancelled" FAILED = "failed" - COMPLETED = "completed" EXPIRED = "expired" + INCOMPLETE = "incomplete" From 4a680f24f0513478c7c02df0e4f1551a54b61f43 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Thu, 18 Apr 2024 19:53:01 +0200 Subject: [PATCH 08/19] Add support for multiple GH repos --- pyproject.toml | 6 +++++- sam/bot.py | 4 ++-- sam/config.py | 5 +++++ sam/contrib/github/__init__.py | 8 ++++---- sam/slack.py | 1 + sam/tools.py | 22 ++++++++++++++++------ sam/utils.py | 32 +++++++++++++++++++++++--------- tests/test_tools.py | 6 +++++- tests/test_utils.py | 18 ++++++++++++++++-- 9 files changed, 77 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2b91bdb..03b83e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,9 @@ sam = "sam.__main__:cli" [project.optional-dependencies] test = [ "pytest", - "pytest-cov", "pytest-asyncio", + "pytest-cov", + "pytest-env", ] lint = [ "bandit==1.7.8", @@ -58,6 +59,9 @@ minversion = "6.0" addopts = "--cov --tb=short -rxs" testpaths = ["tests"] +[tool.pytest_env] +GITHUB_REPOS = 'voiio/sam' + [tool.coverage.run] source = ["sam"] diff --git a/sam/bot.py b/sam/bot.py index d92d751..5530a1c 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -55,8 +55,8 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context ) tool_outputs.append( { - "tool_call_id": tool_call.id, # noqa - "output": fn(**kwargs, **context), + "tool_call_id": tool_call.id, + "output": fn(**kwargs, _context={**context}), } ) logger.info("Submitting tool outputs for run %s", run_id) diff --git a/sam/config.py b/sam/config.py index 92d9ec7..398293f 100644 --- a/sam/config.py +++ b/sam/config.py @@ -1,3 +1,4 @@ +import enum import os from zoneinfo import ZoneInfo @@ -14,3 +15,7 @@ SENTRY_DSN = os.getenv("SENTRY_DSN") GITHUB_ORG = os.getenv("GITHUB_ORG") GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY") +GITHUB_REPOS = enum.StrEnum( + "GITHUB_REPOS", + {repo: repo for repo in os.getenv("GITHUB_REPOS", "").split(",") if repo}, +) diff --git a/sam/contrib/github/__init__.py b/sam/contrib/github/__init__.py index 74141fc..04a2084 100644 --- a/sam/contrib/github/__init__.py +++ b/sam/contrib/github/__init__.py @@ -16,7 +16,7 @@ class GitHubAPIError(requests.HTTPError): class AbstractGitHubAPIWrapper(abc.ABC): # pragma: no cover @abc.abstractmethod - def create_issue(self, title, body): + def create_issue(self, title, body, repo): return NotImplemented @abc.abstractmethod @@ -42,9 +42,9 @@ def __init__(self, token): } ) - def create_issue(self, title, body): + def create_issue(self, title, body, repo): response = self.post( - f"{self.endpoint}/repos/{config.GITHUB_ORG}/{config.GITHUB_REPOSITORY}/issues", + f"{self.endpoint}/repos/{repo}/issues", json={"title": title, "body": body}, ) try: @@ -64,7 +64,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_val, exc_tb): pass - def create_issue(self, title, body): + def create_issue(self, title, body, repo): return { "title": title, "body": body, diff --git a/sam/slack.py b/sam/slack.py index 3ea03b2..22742bd 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -1,4 +1,5 @@ import asyncio +import enum import json import logging import random # nosec diff --git a/sam/tools.py b/sam/tools.py index 7efbe55..e886a00 100644 --- a/sam/tools.py +++ b/sam/tools.py @@ -15,12 +15,13 @@ from markdownify import markdownify as md from slack_sdk import WebClient, errors +import sam.config from sam import config from sam.contrib import algolia, brave, github from sam.utils import logger -def send_email(to: str, subject: str, body: str, **_context): +def send_email(to: str, subject: str, body: str, _context=None): """ Send an email the given recipients. The user is always cc'd on the email. @@ -29,6 +30,7 @@ def send_email(to: str, subject: str, body: str, **_context): subject: The subject of the email. body: The body of the email. """ + _context = _context or {} email_url = os.getenv("EMAIL_URL") from_email = os.getenv("FROM_EMAIL", "sam@voiio.de") email_white_list = os.getenv("EMAIL_WHITE_LIST") @@ -60,7 +62,7 @@ def send_email(to: str, subject: str, body: str, **_context): return "Email sent successfully!" -def web_search(query: str, **_context) -> str: +def web_search(query: str, _context=None) -> str: """ Search the internet for information that matches the given query. @@ -91,7 +93,7 @@ def web_search(query: str, **_context) -> str: ) -def fetch_website(url: str, **_context) -> str: +def fetch_website(url: str, _context=None) -> str: """ Fetch the website for the given URL and return the content as Markdown. @@ -116,7 +118,7 @@ def fetch_website(url: str, **_context) -> str: return "failed to parse website" -def fetch_coworker_emails(**_context) -> str: +def fetch_coworker_emails(_context=None) -> str: """ Fetch profile data about your coworkers from Slack. @@ -156,7 +158,9 @@ def fetch_coworker_emails(**_context) -> str: return json.dumps(profiles) -def create_github_issue(title: str, body: str) -> str: +def create_github_issue( + title: str, body: str, repo: "sam.config.GITHUB_REPOS", _context=None +) -> str: """ Create an issue on GitHub with the given title and body. @@ -166,13 +170,19 @@ def create_github_issue(title: str, body: str) -> str: You should provide ideas for a potential solution, including code snippet examples in a Markdown code block. + You MUST ALWAYS write the issue in English. + Args: title: The title of the issue. body: The body of the issue, markdown supported. + repo: The repository to create the issue in. """ + if repo not in config.GITHUB_REPOS.__members__: + logger.warning("Invalid repo: %s", repo) + return "invalid repo" with github.get_client() as api: try: - response = api.create_issue(title, body) + response = api.create_issue(title, body, repo) except github.GitHubAPIError: logger.exception("Failed to create issue on GitHub") return "failed to create issue" diff --git a/sam/utils.py b/sam/utils.py index c92a9ed..c55328c 100644 --- a/sam/utils.py +++ b/sam/utils.py @@ -6,6 +6,7 @@ import inspect import logging import random +import typing import openai import redis.asyncio as redis @@ -29,8 +30,6 @@ float: "number", list: "array", dict: "object", - enum.StrEnum: "string", - enum.IntEnum: "integer", } @@ -51,6 +50,8 @@ def func_to_tool(fn: callable) -> dict: doc_data = yaml.safe_load(args.split("Returns:")[0]) else: description = fn.__doc__ + doc_data = {} + return { "type": "function", "function": { @@ -60,13 +61,7 @@ def func_to_tool(fn: callable) -> dict: ), "parameters": { "type": "object", - "properties": { - param.name: { - "type": type_map[param.annotation], - "description": doc_data[param.name], - } - for param in params - }, + "properties": dict(params_to_props(fn, params, doc_data)), "required": [ param.name for param in params @@ -77,6 +72,25 @@ def func_to_tool(fn: callable) -> dict: } +def params_to_props(fn, params, doc_data): + types = typing.get_type_hints(fn) + for param in params: + if param.name.startswith("_"): + continue + param_type = types[param.name] + if param_type in type_map: + yield param.name, { + "type": type_map[types[param.name]], + "description": doc_data[param.name], + } + elif issubclass(param_type, enum.StrEnum): + yield param.name, { + "type": "string", + "enum": [value.value for value in param_type], + "description": doc_data[param.name], + } + + async def backoff(retries: int, max_jitter: int = 10): """Exponential backoff timer with a random jitter.""" await asyncio.sleep(2**retries + random.random() * max_jitter) # nosec diff --git a/tests/test_tools.py b/tests/test_tools.py index 8e1ac07..50707a1 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -53,11 +53,15 @@ def test_web_search__with_coordinates(): def test_create_github_issue(): assert ( - tools.create_github_issue("title", "body") + tools.create_github_issue("title", "body", "voiio/sam") == "https://www.youtube.com/watch?v=dQw4w9WgXcQ" ) +def test_create_github_issue__invalid_repo(): + assert tools.create_github_issue("title", "body", "not-valid") == "invalid repo" + + def test_platform_search(): assert tools.platform_search("ferien") == json.dumps( { diff --git a/tests/test_utils.py b/tests/test_utils.py index 168d117..0e25d5c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,10 @@ from __future__ import annotations +import enum + import pytest +import tests.test_tools from sam import utils @@ -10,13 +13,19 @@ async def test_backoff(): await utils.backoff(0, max_jitter=0) +BloodTypes = enum.StrEnum("BloodTypes", {"A": "A", "B": "B"}) + + def test_func_to_tool(): - def fn(a: int, b: str) -> int: + def fn( + a: int, b: str, blood_types: "tests.test_utils.BloodTypes", _context=None + ) -> str: """Function description. Args: a: Description of a. b: Description of b. + blood_types: Description of bool_types. Returns: Description of return value. @@ -40,8 +49,13 @@ def fn(a: int, b: str) -> int: "type": "string", "description": "Description of b.", }, + "blood_types": { + "type": "string", + "enum": ["A", "B"], + "description": "Description of bool_types.", + }, }, - "required": ["a", "b"], + "required": ["a", "b", "blood_types"], }, }, } From 9d177a64b6a61184aa49d279bc9179deec2aecfc Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Fri, 19 Apr 2024 12:24:44 +0200 Subject: [PATCH 09/19] Fix #61 -- Add certifi to prevent SSL issues --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 03b83e4..da1bcf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ requires-python = ">=3.11" dependencies = [ "aiohttp", "click", + "certifi", "markdownify", "redis", "requests", From 2d10a70fc3458fc5f3149392e2fd8dd0a0574503 Mon Sep 17 00:00:00 2001 From: Rust Saiargaliev Date: Fri, 19 Apr 2024 14:50:52 +0200 Subject: [PATCH 10/19] Drop obsolete config vars (#60) Usage was dropped in #58 --- sam/config.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sam/config.py b/sam/config.py index 398293f..178e895 100644 --- a/sam/config.py +++ b/sam/config.py @@ -13,8 +13,6 @@ BRAVE_SEARCH_LONGITUDE = os.getenv("BRAVE_SEARCH_LONGITUDE") BRAVE_SEARCH_LATITUDE = os.getenv("BRAVE_SEARCH_LATITUDE") SENTRY_DSN = os.getenv("SENTRY_DSN") -GITHUB_ORG = os.getenv("GITHUB_ORG") -GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY") GITHUB_REPOS = enum.StrEnum( "GITHUB_REPOS", {repo: repo for repo in os.getenv("GITHUB_REPOS", "").split(",") if repo}, From e6dcbdecf10b620a33ce427b18266c5b89baac66 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Fri, 19 Apr 2024 17:03:22 +0200 Subject: [PATCH 11/19] Add instruction upload functionality to SAM (#63) --- sam/__main__.py | 63 ++++++++++++++++++++++++++++++----- sam/config.py | 28 ++++++++++++++++ tests/fixtures/harry.md | 1 + tests/fixtures/pyproject.toml | 8 +++++ tests/fixtures/security.md | 1 + tests/test_config.py | 15 +++++++++ tests/test_main.py | 40 ++++++++++++++++++++++ 7 files changed, 147 insertions(+), 9 deletions(-) create mode 100644 tests/fixtures/harry.md create mode 100644 tests/fixtures/pyproject.toml create mode 100644 tests/fixtures/security.md create mode 100644 tests/test_config.py create mode 100644 tests/test_main.py diff --git a/sam/__main__.py b/sam/__main__.py index 37c65c1..d2890a4 100644 --- a/sam/__main__.py +++ b/sam/__main__.py @@ -1,8 +1,10 @@ import asyncio import logging +import os import sys import click +import openai import sentry_sdk from . import config @@ -11,19 +13,19 @@ @click.group() -@click.option("-v", "--verbose", is_flag=True, help="Enables verbose mode.") -def cli(verbose): +def cli(): """Sam – cuz your company is nothing with Sam.""" - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) - logger = logging.getLogger("sam") - logger.addHandler(handler) - logger.setLevel(logging.DEBUG if verbose else logging.INFO) @cli.group(chain=True) -def run(): - """Run an assistent bot, currently only Slack is supported.""" +@click.option("-v", "--verbose", is_flag=True, help="Enables verbose mode.") +def run(verbose): + """Run an assistent bot.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) + logging.basicConfig( + handlers=[handler], level=logging.DEBUG if verbose else logging.INFO + ) @run.command() @@ -40,5 +42,48 @@ def slack(): ) +@cli.group(chain=True) +@click.option("-v", "--verbose", is_flag=True, help="Enables verbose mode.") +def assistants(verbose): + """Manage OpenAI assistants.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(logging.Formatter("%(message)s")) + logging.basicConfig( + handlers=[handler], level=logging.DEBUG if verbose else logging.INFO + ) + + +@assistants.command(name="list") +def _list(): + """List all assistants configured in your project.""" + assistant_list = list(config.load_assistants()) + for assistant_config in assistant_list: + click.echo( + f"{assistant_config.name} ({assistant_config.project}): {assistant_config.assistant_id}" + ) + if not assistant_list: + click.echo("No assistants configured.") + + +@assistants.command() +def upload(): + """Compile and upload all assistants system prompts to OpenAI.""" + assistant_list = list(config.load_assistants()) + for assistant_config in assistant_list: + click.echo(f"Uploading {assistant_config.name}...", nl=False) + project_api_key_name = ( + f"OPENAI_{assistant_config.project.replace('-', '_').upper()}_API_KEY" + ) + project_api_key = os.getenv(project_api_key_name) + with openai.OpenAI(api_key=project_api_key) as client: + client.beta.assistants.update( + assistant_id=assistant_config.assistant_id, + instructions=assistant_config.system_prompt, + ) + click.echo(" Done!") + if not assistant_list: + click.echo("No assistants configured.") + + if __name__ == "__main__": cli() diff --git a/sam/config.py b/sam/config.py index 178e895..62cfe78 100644 --- a/sam/config.py +++ b/sam/config.py @@ -1,5 +1,11 @@ +from __future__ import annotations + import enum import os +import tomllib +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path from zoneinfo import ZoneInfo SLACK_BOT_TOKEN = os.getenv("SLACK_BOT_TOKEN") @@ -17,3 +23,25 @@ "GITHUB_REPOS", {repo: repo for repo in os.getenv("GITHUB_REPOS", "").split(",") if repo}, ) + + +@dataclass +class AssistantConfig: + name: str + assistant_id: str + instructions: list[str] + project: str + + @cached_property + def system_prompt(self): + return "\n\n".join( + Path(instruction).read_text() for instruction in self.instructions + ) + + +def load_assistants(): + with Path("pyproject.toml").open("rb") as fs: + for assistant in ( + tomllib.load(fs).get("tool", {}).get("sam", {}).get("assistants", []) + ): + yield AssistantConfig(**assistant) diff --git a/tests/fixtures/harry.md b/tests/fixtures/harry.md new file mode 100644 index 0000000..ab8912c --- /dev/null +++ b/tests/fixtures/harry.md @@ -0,0 +1 @@ +You are a wizard, Harry. \ No newline at end of file diff --git a/tests/fixtures/pyproject.toml b/tests/fixtures/pyproject.toml new file mode 100644 index 0000000..f918cd7 --- /dev/null +++ b/tests/fixtures/pyproject.toml @@ -0,0 +1,8 @@ +[[tool.sam.assistants]] +name = "Harry" +assistant_id = "asst_1234057341258907" +project="default-project" +instructions = [ + "harry.md", + "security.md", +] diff --git a/tests/fixtures/security.md b/tests/fixtures/security.md new file mode 100644 index 0000000..2bfaa35 --- /dev/null +++ b/tests/fixtures/security.md @@ -0,0 +1 @@ +You mustn't tell lies. \ No newline at end of file diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..04ca049 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,15 @@ +from sam.config import AssistantConfig + + +class TestAssistantConfig: + + def test_system_prompt(self): + assert ( + AssistantConfig( + name="Test", + assistant_id="test", + project="test", + instructions=["tests/fixtures/harry.md", "tests/fixtures/security.md"], + ).system_prompt + == "You are a wizard, Harry.\n\nYou mustn't tell lies." + ) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..66f3513 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,40 @@ +import os +from unittest import mock + +from click.testing import CliRunner + +from sam.__main__ import cli + + +class TestCli: + + def test_list(self, monkeypatch): + monkeypatch.chdir("tests/fixtures") + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "list"]) + assert result.exit_code == 0 + assert "Harry (default-project): asst_1234057341258907" in result.output + + def test_list__empty(self): + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "list"]) + assert result.exit_code == 0 + assert "No assistants configured." in result.output + + def test_upload(self, monkeypatch): + monkeypatch.chdir("tests/fixtures") + client = mock.MagicMock() + monkeypatch.setattr("openai.OpenAI", lambda api_key: client) + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "upload"]) + assert result.exit_code == 0 + assert "Uploading Harry... Done!" in result.output + assert client.__enter__().beta.assistants.update.called + + def test_upload__empty(self, monkeypatch): + client = mock.MagicMock() + monkeypatch.setattr("openai.OpenAI", lambda api_key: client) + runner = CliRunner() + result = runner.invoke(cli, ["assistants", "upload"]) + assert result.exit_code == 0 + assert "No assistants configured." in result.output From 85bd1f118b3cc263a8b5d07938020db4188af81a Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Fri, 19 Apr 2024 17:26:01 +0200 Subject: [PATCH 12/19] Add PyPi package setup --- .github/workflows/ci.yml | 15 +++++++++++++++ .github/workflows/release.yml | 21 ++++++++++++++++++++ README.md | 4 ++++ pyproject.toml | 36 +++++++++++++++++++++++++++++------ 4 files changed, 70 insertions(+), 6 deletions(-) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ad2a396..5a0a550 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,6 +27,21 @@ jobs: - run: python -m pip install -e .[lint] - run: ${{ matrix.lint-command }} + + dist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - run: python -m pip install --upgrade pip build wheel twine + - run: python -m build --sdist --wheel + - run: python -m twine check dist/* + - uses: actions/upload-artifact@v4 + with: + path: dist/* + pytest: runs-on: ubuntu-latest steps: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..803231a --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,21 @@ +name: Release + +on: + release: + types: [published] + +jobs: + + PyPi: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.x" + - run: python -m pip install --upgrade pip build wheel twine + - run: python -m build --sdist --wheel + - run: python -m twine upload dist/* + env: + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} diff --git a/README.md b/README.md index 5a1476b..7e75fa7 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,7 @@ +[![PyPi Version](https://img.shields.io/pypi/v/opensam.svg)](https://pypi.python.org/pypi/opensam/) +[![Test Coverage](https://codecov.io/gh/voiio/sam/branch/main/graph/badge.svg)](https://codecov.io/gh/voiio/sam) +[![GitHub License](https://img.shields.io/github/license/voiio/sam)](https://raw.githubusercontent.com/voiio/sam/master/LICENSE) + # Sam – cuz your company is nothing without Sam ![meme](https://repository-images.githubusercontent.com/726003479/24d020ac-3ac5-401c-beae-7a6103c4e7ae) diff --git a/pyproject.toml b/pyproject.toml index da1bcf0..60f957c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,9 @@ +[build-system] +requires = ["flit_core>=3.2", "flit_scm", "wheel"] +build-backend = "flit_scm:buildapi" + [project] -name = "Sam" +name = "OpenSam" authors = [ { name = "Johannes Maron", email = "johannes@maron.family" }, ] @@ -7,8 +11,32 @@ readme = "README.md" license = { file = "LICENSE" } keywords = ["GPT", "AI", "Slack", "OpenAI", "bot"] dynamic = ["version", "description"] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Programming Language :: Python", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Software Development", + "Topic :: Communications :: Chat", + "Topic :: Internet :: WWW/HTTP", + "Topic :: Communications :: Email", + "Topic :: Games/Entertainment :: Role-Playing", + "Topic :: Multimedia :: Sound/Audio :: Conversion", + "Topic :: Multimedia :: Sound/Audio :: Speech", + "Topic :: Office/Business", + "Topic :: Office/Business :: Groupware", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Scientific/Engineering :: Human Machine Interfaces", + "Topic :: Text Processing :: Markup :: Markdown", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.12", +] +requires-python = ">=3.12" -requires-python = ">=3.11" dependencies = [ "aiohttp", "click", @@ -44,10 +72,6 @@ lint = [ Project-URL = "https://github.com/voiio/Sam" Changelog = "https://github.com/voiio/Sam/releases" -[build-system] -requires = ["flit-scm", "wheel"] -build-backend = "flit_scm:buildapi" - [tool.flit.module] name = "sam" From 905a35791420e6e9ddf0e1f032170d3054f2a82d Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Fri, 19 Apr 2024 17:32:00 +0200 Subject: [PATCH 13/19] Add Python 3.11 version support --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 60f957c..09869ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,10 @@ classifiers = [ "Topic :: Text Processing :: Markup :: Markdown", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", ] -requires-python = ">=3.12" +requires-python = ">=3.11" dependencies = [ "aiohttp", From efb1d6668530f9c039012b4dff184b9c41dea48a Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Fri, 19 Apr 2024 17:41:26 +0200 Subject: [PATCH 14/19] Update release processes to OpenID Connect in PyPI --- .github/workflows/release.yml | 36 ++++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 803231a..2ea5558 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,16 +6,30 @@ on: jobs: - PyPi: + release-build: runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 + with: + python-version: "3.x" + - run: python -m build + - uses: actions/upload-artifact@v4 + with: + name: release-dists + path: dist/ + + pypi-publish: + runs-on: ubuntu-latest + needs: + - release-build + permissions: + id-token: write + steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: "3.x" - - run: python -m pip install --upgrade pip build wheel twine - - run: python -m build --sdist --wheel - - run: python -m twine upload dist/* - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} + - uses: actions/download-artifact@v4 + with: + name: release-dists + path: dist/ + - uses: pypa/gh-action-pypi-publish@release/v1 From 8b73624425bebcffc9af996429d78c93ed2121a0 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Fri, 19 Apr 2024 17:44:56 +0200 Subject: [PATCH 15/19] Update build dependencies --- .github/workflows/ci.yml | 2 +- .github/workflows/release.yml | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5a0a550..8319d93 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: "3.x" - - run: python -m pip install --upgrade pip build wheel twine + - run: python -m pip install --upgrade pip build twine - run: python -m build --sdist --wheel - run: python -m twine check dist/* - uses: actions/upload-artifact@v4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2ea5558..5119b56 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,8 @@ jobs: - uses: actions/setup-python@v4 with: python-version: "3.x" - - run: python -m build + - run: python -m pip install --upgrade pip build + - run: python -m build --sdist --wheel - uses: actions/upload-artifact@v4 with: name: release-dists From 5ea0c176c05205c8b6c9cd89d279a7ff0841e0f6 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Sat, 20 Apr 2024 00:19:05 +0200 Subject: [PATCH 16/19] Add more tests --- .github/workflows/ci.yml | 8 ++ pyproject.toml | 2 +- sam/__main__.py | 8 +- sam/bot.py | 217 +++++++++++++++++++++++++++------------ sam/config.py | 2 + sam/slack.py | 211 ++++++++++++++++++------------------- sam/typing.py | 3 + tests/test_bot.py | 106 +++++++++++++++++-- tests/test_slack.py | 207 +++++++++++++++++++++++++++++++++++++ 9 files changed, 584 insertions(+), 180 deletions(-) create mode 100644 tests/test_slack.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8319d93..202fc36 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -44,6 +44,14 @@ jobs: pytest: runs-on: ubuntu-latest + services: + redis: + image: redis + ports: + - 6379:6379 + options: --entrypoint redis-server + env: + REDIS_URL: redis:/// steps: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 diff --git a/pyproject.toml b/pyproject.toml index 09869ae..75cc388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ fallback_version = "1.0.0" [tool.pytest.ini_options] minversion = "6.0" -addopts = "--cov --tb=short -rxs" +addopts = "--tb=short -rxs" testpaths = ["tests"] [tool.pytest_env] diff --git a/sam/__main__.py b/sam/__main__.py index d2890a4..4806674 100644 --- a/sam/__main__.py +++ b/sam/__main__.py @@ -22,7 +22,9 @@ def cli(): def run(verbose): """Run an assistent bot.""" handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(message)s")) + handler.setFormatter( + logging.Formatter("%(asctime)s %(levelname)7s %(name)s - %(message)s") + ) logging.basicConfig( handlers=[handler], level=logging.DEBUG if verbose else logging.INFO ) @@ -33,12 +35,12 @@ def slack(): """Run the Slack bot demon.""" from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler - from .slack import app + from .slack import get_app loop = asyncio.get_event_loop() loop.run_until_complete( - AsyncSocketModeHandler(app, config.SLACK_APP_TOKEN).start_async() + AsyncSocketModeHandler(get_app(), config.SLACK_APP_TOKEN).start_async() ) diff --git a/sam/bot.py b/sam/bot.py index 5530a1c..dec78f1 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -1,10 +1,13 @@ +from __future__ import annotations + import json import logging +from pathlib import Path import openai from . import config, tools, utils -from .typing import Roles, RunStatus +from .typing import AUDIO_FORMATS, Roles, RunStatus logger = logging.getLogger(__name__) @@ -17,6 +20,8 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context Raises: RecursionError: If the run status is not "completed" after 10 retries. + IOError: If the run status is not "completed" or "requires_action". + ValueError: If the run requires tools but none are provided. """ client: openai.AsyncOpenAI = openai.AsyncOpenAI() if retry > 10: @@ -29,56 +34,60 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context await utils.backoff(retry) await complete_run(run_id, thread_id, retry=retry + 1, **context) case RunStatus.REQUIRES_ACTION: - if ( - run.required_action - and run.required_action.submit_tool_outputs - and run.required_action.submit_tool_outputs.tool_calls - ): - tool_outputs = [] - for tool_call in run.required_action.submit_tool_outputs.tool_calls: - kwargs = json.loads(tool_call.function.arguments) - try: - fn = getattr(tools, tool_call.function.name) - except KeyError: - logger.exception( - "Tool %s not found, cancelling run %s", - tool_call.function.name, - run_id, - ) - await client.beta.threads.runs.cancel( - run_id=run_id, thread_id=thread_id - ) - return - logger.info("Running tool %s", tool_call.function.name) - logger.debug( - "Tool %s arguments: %r", tool_call.function.name, kwargs - ) - tool_outputs.append( - { - "tool_call_id": tool_call.id, - "output": fn(**kwargs, _context={**context}), - } - ) - logger.info("Submitting tool outputs for run %s", run_id) - logger.debug("Tool outputs: %r", tool_outputs) - await client.beta.threads.runs.submit_tool_outputs( - run.id, # noqa - thread_id=thread_id, - tool_outputs=tool_outputs, - ) - await complete_run( - run_id, thread_id, **context - ) # we reset the retry counter + await call_tools(run, **context) + # after we submit the tool outputs, we reset the retry counter + await complete_run(run_id, thread_id, **context) case RunStatus.COMPLETED: return case _: raise IOError(f"Run {run.id} failed with status {run.status}") -async def run( +async def call_tools(run: openai.types.beta.threads.Run, **context) -> None: + """ + Call the tools required by the run. + + Raises: + IOError: If a tool is not found. + ValueError: If the run does not require any tools. + """ + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + if not (run.required_action and run.required_action.submit_tool_outputs): + raise ValueError(f"Run {run.id} does not require any tools") + tool_outputs = [] + for tool_call in run.required_action.submit_tool_outputs.tool_calls: + kwargs = json.loads(tool_call.function.arguments) + try: + fn = getattr(tools, tool_call.function.name) + except KeyError as e: + await client.beta.threads.runs.cancel( + run_id=run.id, thread_id=run.thread_id + ) + raise IOError( + f"Tool {tool_call.function.name} not found, cancelling run {run.id}" + ) from e + logger.info("Running tool %s", tool_call.function.name) + logger.debug("Tool %s arguments: %r", tool_call.function.name, kwargs) + tool_outputs.append( + { + "tool_call_id": tool_call.id, + "output": fn(**kwargs, _context={**context}), + } + ) + logger.info("Submitting tool outputs for run %s", run.id) + logger.debug("Tool outputs: %r", tool_outputs) + await client.beta.threads.runs.submit_tool_outputs( + run.id, # noqa + thread_id=run.thread_id, + tool_outputs=tool_outputs, + ) + + +async def execute_run( assistant_id: str, thread_id: str, additional_instructions: str = None, + file_ids: list[str] = None, **context, ) -> str: """Run the assistant on the OpenAI thread.""" @@ -90,7 +99,7 @@ async def run( logger.debug("Additional instructions: %r", additional_instructions) logger.debug("Context: %r", context) client: openai.AsyncOpenAI = openai.AsyncOpenAI() - _run = await client.beta.threads.runs.create( + run = await client.beta.threads.runs.create( thread_id=thread_id, assistant_id=assistant_id, additional_instructions=additional_instructions, @@ -104,35 +113,115 @@ async def run( ], ) try: - await complete_run(_run.id, thread_id, **context) - except IOError: - logger.exception("Run %s failed", _run.id) + await complete_run(run.id, thread_id, **context) + except (RecursionError, IOError, ValueError): + logger.exception("Run %s failed", run.id) + return "🤯" + + try: + return await fetch_latest_assistant_message(thread_id) + except ValueError: + logger.exception("No assistant message found") return "🤯" + +async def fetch_latest_assistant_message(thread_id: str) -> str: + """ + Fetch the latest assistant message from the thread. + + Raises: + ValueError: If no assistant message is found. + """ + client: openai.AsyncOpenAI = openai.AsyncOpenAI() messages = await client.beta.threads.messages.list(thread_id=thread_id) for message in messages.data: if message.role == Roles.ASSISTANT: - message_content = message.content[0].text + try: + return await annotate_citations(message.content[0].text) + except IndexError as e: + raise ValueError("No assistant message found") from e + - annotations = message_content.annotations - citations = [] +async def annotate_citations( + message_content: openai.types.beta.threads.TextContentBlock, +) -> str: + """Annotate citations in the text using footnotes and the file metadata.""" + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + citations = [] - # Iterate over the annotations and add footnotes - for index, annotation in enumerate(annotations): - message_content.value = message_content.value.replace( - annotation.text, f" [{index}]" - ) + # Iterate over the annotations and add footnotes + for index, annotation in enumerate(message_content.annotations): + message_content.value = message_content.value.replace( + annotation.text, f" [{index}]" + ) - if file_citation := getattr(annotation, "file_citation", None): - cited_file = client.files.retrieve(file_citation.file_id) - citations.append( - f"[{index}] {file_citation.quote} — {cited_file.filename}" + if file_citation := getattr(annotation, "file_citation", None): + cited_file = await client.files.retrieve(file_citation.file_id) + citations.append(f"[{index}] {file_citation.quote} — {cited_file.filename}") + elif file_path := getattr(annotation, "file_path", None): + cited_file = await client.files.retrieve(file_path.file_id) + citations.append(f"[{index}]({cited_file.filename})") + + # Add footnotes to the end of the message before displaying to user + message_content.value += "\n" + "\n".join(citations) + return message_content.value + + +async def add_message( + thread_id: str, + content: str, + files: [(str, bytes)] = None, +) -> tuple[list[str], bool]: + """Add a message to the thread.""" + logger.info(f"Adding message to thread={thread_id}") + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + file_ids = [] + voice_prompt = False + for file_name, file_content in files or []: + if Path(file_name).suffix.lstrip(".") in AUDIO_FORMATS: + logger.debug("Transcribing audio file %s", file_name) + content += ( + "\n" + + ( + await client.audio.transcriptions.create( + model="whisper-1", + file=(file_name, file_content), ) - elif file_path := getattr(annotation, "file_path", None): - cited_file = client.files.retrieve(file_path.file_id) - citations.append(f"[{index}]({cited_file.filename})") + ).text + ) + voice_prompt = True + else: + logger.debug("Uploading file %s", file_name) + new_file = await client.files.create( + file=(file_name, file_content), + purpose="assistants", + ) + file_ids.append(new_file.id) + await client.beta.threads.messages.create( + thread_id=thread_id, + content=content, + role="assistant", + file_ids=file_ids, + ) + return file_ids, voice_prompt - # Add footnotes to the end of the message before displaying to user - message_content.value += "\n" + "\n".join(citations) - return message_content +async def tts(text: str) -> bytes: + """Convert text to speech using the OpenAI API.""" + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + response = await client.audio.speech.create( + model=config.TTS_MODEL, + voice=config.TTS_VOICE, + input=text, + ) + return response.read() + + +async def stt(audio: bytes) -> str: + """Convert speech to text using the OpenAI API.""" + client: openai.AsyncOpenAI = openai.AsyncOpenAI() + response = await client.audio.transcriptions.create( + model="whisper-1", + file=audio, + ) + return response.text diff --git a/sam/config.py b/sam/config.py index 62cfe78..7f85dac 100644 --- a/sam/config.py +++ b/sam/config.py @@ -12,6 +12,8 @@ SLACK_APP_TOKEN = os.getenv("SLACK_APP_TOKEN") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") OPENAI_ASSISTANT_ID = os.getenv("OPENAI_ASSISTANT_ID") +TTS_VOICE = os.getenv("TTS_VOICE", "alloy") +TTS_MODEL = os.getenv("TTS_MODEL", "tts-1-hd") REDIS_URL = os.getenv("REDIS_URL", "redis:///") RANDOM_RUN_RATIO = float(os.getenv("RANDOM_RUN_RATIO", "0")) TIMEZONE = ZoneInfo(os.getenv("TIMEZONE", "UTC")) diff --git a/sam/slack.py b/sam/slack.py index 22742bd..de9cf1c 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -1,5 +1,4 @@ -import asyncio -import enum +import functools import json import logging import random # nosec @@ -8,125 +7,125 @@ from typing import Any import redis.asyncio as redis -from openai import AsyncOpenAI -from slack_bolt.async_app import AsyncApp, AsyncSay +from slack_bolt.async_app import AsyncSay +from slack_sdk.web.async_client import AsyncWebClient +from slack_sdk.web.client import WebClient from . import bot, config, utils -logger = logging.getLogger("sam") +logger = logging.getLogger(__name__) -client = AsyncOpenAI() -app = AsyncApp(token=config.SLACK_BOT_TOKEN) +_USER_HANDLE = None + +ACKNOWLEDGMENT_SMILEYS = [ + "thumbsup", + "ok_hand", + "eyes", + "wave", + "robot_face", + "saluting_face", + "v", + "100", + "muscle", + "thought_balloon", + "speech_balloon", + "space_invader", + "call_me_hand", +] -USER_HANDLE = None -AUDIO_FORMATS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"] +async def get_bot_user_id(): + """Get the Slack bot's user id.""" + client = AsyncWebClient(token=config.SLACK_BOT_TOKEN) + global _USER_HANDLE + if _USER_HANDLE is None: + logger.debug("Fetching the bot's user id") + response = await client.auth_test() + _USER_HANDLE = response["user_id"] + logger.debug(f"Bot's user id is {_USER_HANDLE}") + return _USER_HANDLE async def handle_message(event: {str, Any}, say: AsyncSay): + """Handle a message event from Slack.""" logger.debug(f"handle_message={json.dumps(event)}") - global USER_HANDLE - if USER_HANDLE is None: - logger.debug("Fetching the bot's user id") - response = await say.client.auth_test() - USER_HANDLE = response["user_id"] - logger.debug(f"Bot's user id is {USER_HANDLE}") + bot_id = await get_bot_user_id() channel_id = event["channel"] - client_msg_id = event["client_msg_id"] channel_type = event["channel_type"] - user_id = event["user"] text = event["text"] - text = text.replace(f"<@{USER_HANDLE}>", "Sam") + text = text.replace(f"<@{bot_id}>", "Sam") thread_id = await utils.get_thread_id(channel_id) # We may only add messages to a thread while the assistant is not running + files = [] + for file in event.get("files", []): + req = urllib.request.Request( + file["url_private"], + headers={"Authorization": f"Bearer {config.SLACK_BOT_TOKEN}"}, + ) + with urllib.request.urlopen(req) as response: # nosec + files.append((file["name"], response.read())) + async with ( redis.from_url(config.REDIS_URL) as redis_client, redis_client.lock(thread_id, timeout=10 * 60, thread_local=False), ): # 10 minutes - file_ids = [] - voice_prompt = False - if "files" in event: - for file in event["files"]: - req = urllib.request.Request( - file["url_private"], - headers={"Authorization": f"Bearer {config.SLACK_BOT_TOKEN}"}, - ) - with urllib.request.urlopen(req) as response: # nosec - if file["filetype"] in AUDIO_FORMATS: - text += "\n" + await client.audio.transcriptions.create( - model="whisper-1", - file=(file["name"], response.read()), - response_format="text", - ) - logger.info(f"User={user_id} added Audio={file['id']}") - voice_prompt = True - else: - new_file = await client.files.create( - file=(file["name"], response.read()), - purpose="assistants", - ) - file_ids.append(new_file.id) - logger.info( - f"User={user_id} added File={file_ids[-1]} to Thread={thread_id}" - ) - await client.beta.threads.messages.create( + file_ids, voice_prompt = await bot.add_message( thread_id=thread_id, content=text, - role="user", + files=files, ) - logger.info( - f"User={user_id} added Message={client_msg_id} added to Thread={thread_id}" - ) - if ( - channel_type == "im" - or event.get("parent_user_id") == USER_HANDLE - or random.random() < config.RANDOM_RUN_RATIO # nosec - ): - # we need to run the assistant in a separate thread, otherwise we will - # block the main thread: - # process_run(event, say, voice_prompt=voice_prompt) - asyncio.create_task(process_run(event, say, voice_prompt=voice_prompt)) + # we need to release the lock before starting a new run + if ( + channel_type == "im" + or event.get("parent_user_id") == bot_id + or random.random() < config.RANDOM_RUN_RATIO # nosec + ): + await send_response(event, say, file_ids=file_ids, voice_prompt=voice_prompt) -ACKNOWLEDGMENT_SMILEYS = [ - "thumbsup", - "ok_hand", - "eyes", - "wave", - "robot_face", - "saluting_face", - "v", - "100", - "muscle", - "thought_balloon", - "speech_balloon", - "space_invader", - "call_me_hand", -] +@functools.lru_cache(maxsize=128) +def get_user_profile(user_id: str) -> dict[str, Any]: + """Get the profile of a user.""" + client = WebClient(token=config.SLACK_BOT_TOKEN) + return client.users_profile_get(user=user_id)["profile"] -async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = False): - logger.debug(f"process_run={json.dumps(event)}") - channel_id = event["channel"] - user_id = event["user"] - profile = (await say.client.users_profile_get(user=user_id))["profile"] + +@functools.lru_cache(maxsize=128) +def get_user_specific_instructions(user_id: str) -> str: + """Get the user-specific instructions.""" + profile = get_user_profile(user_id) name = profile["display_name"] email = profile["email"] pronouns = profile.get("pronouns") - local_time = datetime.now().astimezone(config.TIMEZONE).strftime("%H:%M") - additional_instructions = ( - f"You MUST ALWAYS address the user as <@{user_id}>.\n" - f"You may refer to the user as {name}.\n" - f"The user's email is {email}.\n" - f"The time is {local_time}.\n" - ) + local_time = datetime.now(tz=config.TIMEZONE) + instructions = [ + f"You MUST ALWAYS address the user as <@{user_id}>.", + f"You may refer to the user as {name}.", + f"The user's email is {email}.", + f"The time is {local_time.isoformat()}.", + ] if pronouns: - additional_instructions += f"The user's pronouns are {pronouns}.\n" + instructions.append(f"The user's pronouns are {pronouns}.") + return "\n".join(instructions) + + +async def send_response( + event: {str, Any}, + say: AsyncSay, + file_ids: list[str] = None, + voice_prompt: bool = False, +): + """Send a response to a message event from Slack.""" + logger.debug(f"process_run={json.dumps(event)}") + channel_id = event["channel"] + user_id = event["user"] try: - ts = event["ts"] + timestamp = event["ts"] except KeyError: - ts = event["thread_ts"] + timestamp = event["thread_ts"] thread_id = await utils.get_thread_id(channel_id) + # We may wait for the messages being processed, before starting a new run async with ( redis.from_url(config.REDIS_URL) as redis_client, @@ -136,31 +135,32 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal await say.client.reactions_add( channel=channel_id, name=random.choice(ACKNOWLEDGMENT_SMILEYS), # nosec - timestamp=ts, + timestamp=timestamp, ) - message_content = await bot.run( + text_response = await bot.execute_run( thread_id=thread_id, assistant_id=config.OPENAI_ASSISTANT_ID, - additional_instructions=additional_instructions, - **profile, + additional_instructions=get_user_specific_instructions(user_id), + file_ids=file_ids, + **get_user_profile(user_id), ) msg = await say( channel=say.channel, - text=message_content.value, + text=text_response, mrkdwn=True, thread_ts=event.get("thread_ts", None), ) + logger.info( + f"Sam responded to the User={user_id} in Channel={channel_id} via Text" + ) if voice_prompt: - response = await client.audio.speech.create( - model="tts-1-hd", - voice="alloy", - input=message_content.value, - ) await say.client.files_upload_v2( - content=response.read(), - channels=say.channel, + filename="response.mp3", + title="Voice Response", + content=await bot.tts(text_response), + channel=say.channel, thread_ts=event.get("thread_ts", None), ts=msg["ts"], ) @@ -168,10 +168,11 @@ async def process_run(event: {str, Any}, say: AsyncSay, voice_prompt: bool = Fal f"Sam responded to the User={user_id} in Channel={channel_id} via Voice" ) - logger.info( - f"Sam responded to the User={user_id} in Channel={channel_id} via Text" - ) +def get_app(): # pragma: no cover + from slack_bolt.async_app import AsyncApp -app.event("message")(handle_message) -app.event("app_mention")(process_run) + app = AsyncApp(token=config.SLACK_BOT_TOKEN) + app.event("message")(handle_message) + app.event("app_mention")(send_response) + return app diff --git a/sam/typing.py b/sam/typing.py index b0aea09..9d0dcfb 100644 --- a/sam/typing.py +++ b/sam/typing.py @@ -23,3 +23,6 @@ class RunStatus(enum.StrEnum): FAILED = "failed" EXPIRED = "expired" INCOMPLETE = "incomplete" + + +AUDIO_FORMATS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"] diff --git a/tests/test_bot.py b/tests/test_bot.py index bd3ffaa..30aa4d3 100644 --- a/tests/test_bot.py +++ b/tests/test_bot.py @@ -33,8 +33,13 @@ async def test_complete_run__requires_action(client, monkeypatch): tool_call.function.name = "web_search" required_action.submit_tool_outputs.tool_calls = [tool_call] client.beta.threads.runs.retrieve.return_value = namedtuple( - "Run", ["id", "status", "required_action"] - )(id="run-1", status="requires_action", required_action=required_action) + "Run", ["id", "thread_id", "status", "required_action"] + )( + id="run-1", + thread_id="thread-1", + status="requires_action", + required_action=required_action, + ) with pytest.raises(RecursionError): await bot.complete_run(run_id="run-1", thread_id="thread-1") assert web_search.called @@ -46,14 +51,33 @@ async def test_complete_run__queued(monkeypatch, client): client.beta.threads.runs.retrieve.return_value = namedtuple( "Run", ["id", "status"] )(id="run-1", status="queued") - with pytest.raises(Exception) as e: + with pytest.raises(RecursionError) as e: await bot.complete_run(run_id="run-1", thread_id="thread-1") assert "Max retries exceeded" in str(e.value) @pytest.mark.asyncio -async def test_run(monkeypatch, client): +async def test_complete_run__completed(monkeypatch, client): + monkeypatch.setattr("sam.utils.backoff", mock.AsyncMock()) + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="completed") + await bot.complete_run(run_id="run-1", thread_id="thread-1") + + +@pytest.mark.asyncio +async def test_complete_run__unexpected_status(monkeypatch, client): + monkeypatch.setattr("sam.utils.backoff", mock.AsyncMock()) + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="failed") + with pytest.raises(IOError): + await bot.complete_run(run_id="run-1", thread_id="thread-1") + + +@pytest.mark.asyncio +async def test_execute_run(monkeypatch, client): client.beta.threads.runs.retrieve.return_value = namedtuple( "Run", ["id", "status"] )(id="run-1", status="queued") @@ -76,6 +100,74 @@ async def test_run(monkeypatch, client): ), ] ) - bot.complete_run = mock.AsyncMock() - await bot.run(thread_id="thread-1", assistant_id="assistant-1") - assert bot.complete_run.called + complete_run = mock.AsyncMock() + monkeypatch.setattr(bot, "complete_run", complete_run) + await bot.execute_run(thread_id="thread-1", assistant_id="assistant-1") + assert complete_run.called + + +@pytest.mark.asyncio +async def test_execute_run__no_completed(monkeypatch, client): + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="queued") + client.beta.threads.messages.list.return_value = namedtuple("Response", ["data"])( + data=[ + Message( + id="msg-1", + content=[ + TextContentBlock( + type="text", text=Text(value="Hello", annotations=[]) + ) + ], + status="completed", + role="assistant", + created_at=123, + files=[], + file_ids=[], + object="thread.message", + thread_id="thread-4", + ), + ] + ) + complete_run = mock.AsyncMock(side_effect=[RecursionError, None]) + monkeypatch.setattr(bot, "complete_run", complete_run) + response = await bot.execute_run(thread_id="thread-1", assistant_id="assistant-1") + assert complete_run.called + assert response == "🤯" + + +@pytest.mark.asyncio +async def test_execute_run__no_message(monkeypatch, client): + client.beta.threads.runs.retrieve.return_value = namedtuple( + "Run", ["id", "status"] + )(id="run-1", status="queued") + client.beta.threads.messages.list.return_value = namedtuple("Response", ["data"])( + data=[ + Message( + id="msg-1", + content=[ + TextContentBlock( + type="text", text=Text(value="Hello", annotations=[]) + ) + ], + status="completed", + role="assistant", + created_at=123, + files=[], + file_ids=[], + object="thread.message", + thread_id="thread-4", + ), + ] + ) + complete_run = mock.AsyncMock() + monkeypatch.setattr(bot, "complete_run", complete_run) + fetch_latest_assistant_message = mock.AsyncMock(side_effect=[ValueError, None]) + monkeypatch.setattr( + bot, "fetch_latest_assistant_message", fetch_latest_assistant_message + ) + response = await bot.execute_run(thread_id="thread-1", assistant_id="assistant-1") + assert complete_run.called + assert fetch_latest_assistant_message.called + assert response == "🤯" diff --git a/tests/test_slack.py b/tests/test_slack.py new file mode 100644 index 0000000..a00b6e4 --- /dev/null +++ b/tests/test_slack.py @@ -0,0 +1,207 @@ +from unittest import mock + +import pytest + +from sam import bot, slack + + +@pytest.mark.asyncio +async def test_get_bot_user_id(monkeypatch): + auth_test = mock.AsyncMock(return_value={"user_id": "bot-1"}) + monkeypatch.setattr(slack.AsyncWebClient, "auth_test", auth_test) + assert await slack.get_bot_user_id() == "bot-1" + assert auth_test.called + + +@pytest.mark.asyncio +async def test_handle_message(monkeypatch): + urlopen = mock.AsyncMock() + urlopen.__enter__().read.return_value = b"Hello" + monkeypatch.setattr("urllib.request.urlopen", lambda *args, **kwargs: urlopen) + add_message = mock.AsyncMock(return_value=(["file-1"], False)) + monkeypatch.setattr(bot, "add_message", add_message) + send_response = mock.AsyncMock() + monkeypatch.setattr(slack, "send_response", send_response) + get_bot_user_id = mock.AsyncMock(return_value="bot-1") + monkeypatch.setattr(slack, "get_bot_user_id", get_bot_user_id) + monkeypatch.setattr( + "sam.utils.get_thread_id", mock.AsyncMock(return_value="thread-1") + ) + say = mock.AsyncMock() + event = { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "text": "Hello", + "files": [ + { + "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "name": "file.mp3", + } + ], + } + await slack.handle_message(event, say) + assert add_message.called + assert add_message.call_args == mock.call( + thread_id="thread-1", content="Hello", files=[("file.mp3", b"Hello")] + ) + assert urlopen.__enter__().read.called + assert send_response.called + assert send_response.call_args == mock.call( + { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "text": "Hello", + "files": [ + { + "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "name": "file.mp3", + } + ], + }, + say, + file_ids=["file-1"], + voice_prompt=False, + ) + + +def test_get_user_profile(monkeypatch): + client = mock.MagicMock() + client.users_profile_get.return_value = { + "profile": { + "display_name": "Spidy", + "status_text": "With great power comes great responsibility", + "pronouns": "spider/superhero", + "email": "peter.parker@avengers.com", + } + } + monkeypatch.setattr(slack.WebClient, "users_profile_get", client.users_profile_get) + assert slack.get_user_profile("user-1") == { + "display_name": "Spidy", + "status_text": "With great power comes great responsibility", + "pronouns": "spider/superhero", + "email": "peter.parker@avengers.com", + } + + +def test_get_user_specific_instructions(monkeypatch): + client = mock.MagicMock() + client.users_profile_get.return_value = { + "profile": { + "display_name": "Spidy", + "status_text": "With great power comes great responsibility", + "pronouns": "spider/superhero", + "email": "peter.parker@avengers.com", + } + } + monkeypatch.setattr(slack.WebClient, "users_profile_get", client.users_profile_get) + instructions = slack.get_user_specific_instructions("user-1") + assert "You MUST ALWAYS address the user as <@user-1>." in instructions + assert "You may refer to the user as Spidy." in instructions + assert "The user's email is peter.parker@avengers.com." in instructions + assert "The user's pronouns are spider/superhero." in instructions + + +@pytest.mark.asyncio +async def test_send_response(monkeypatch): + urlopen = mock.AsyncMock() + urlopen.__enter__().read.return_value = b"Hello" + monkeypatch.setattr("urllib.request.urlopen", lambda *args, **kwargs: urlopen) + execute_run = mock.AsyncMock(return_value="Hello World!") + monkeypatch.setattr(bot, "execute_run", execute_run) + tts = mock.AsyncMock(return_value=b"Hello") + monkeypatch.setattr(bot, "tts", tts) + get_bot_user_id = mock.AsyncMock(return_value="bot-1") + monkeypatch.setattr(slack, "get_bot_user_id", get_bot_user_id) + monkeypatch.setattr( + slack, + "get_user_specific_instructions", + lambda *args, **kwargs: "user_instructions", + ) + monkeypatch.setattr( + slack, "get_user_profile", lambda *args, **kwargs: {"name": "Sam"} + ) + monkeypatch.setattr( + "sam.utils.get_thread_id", mock.AsyncMock(return_value="thread-1") + ) + say = mock.AsyncMock() + event = { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "ts": 12321345, + "text": "Hello", + "files": [ + { + "url_private": "https://example.com/file.mp3", + "name": "file.mp3", + } + ], + } + await slack.send_response(event, say, voice_prompt=True) + + assert execute_run.called + assert execute_run.call_args == mock.call( + thread_id="thread-1", + assistant_id=None, + additional_instructions="user_instructions", + file_ids=None, + name="Sam", + ) + assert tts.called + assert tts.call_args == mock.call("Hello World!") + + +@pytest.mark.asyncio +async def test_send_response__thread(monkeypatch): + urlopen = mock.AsyncMock() + urlopen.__enter__().read.return_value = b"Hello" + monkeypatch.setattr("urllib.request.urlopen", lambda *args, **kwargs: urlopen) + execute_run = mock.AsyncMock(return_value="Hello World!") + monkeypatch.setattr(bot, "execute_run", execute_run) + tts = mock.AsyncMock(return_value=b"Hello") + monkeypatch.setattr(bot, "tts", tts) + get_bot_user_id = mock.AsyncMock(return_value="bot-1") + monkeypatch.setattr(slack, "get_bot_user_id", get_bot_user_id) + monkeypatch.setattr( + slack, + "get_user_specific_instructions", + lambda *args, **kwargs: "user_instructions", + ) + monkeypatch.setattr( + slack, "get_user_profile", lambda *args, **kwargs: {"name": "Sam"} + ) + monkeypatch.setattr( + "sam.utils.get_thread_id", mock.AsyncMock(return_value="thread-1") + ) + say = mock.AsyncMock() + event = { + "channel": "channel-1", + "client_msg_id": "client-msg-1", + "channel_type": "im", + "user": "user-1", + "thread_ts": 12321345, + "text": "Hello", + "files": [ + { + "url_private": "https://audio-samples.github.io/samples/mp3/blizzard_tts_unbiased/sample-0/real.mp3", + "name": "file.mp3", + } + ], + } + await slack.send_response(event, say, voice_prompt=True) + + assert execute_run.called + assert execute_run.call_args == mock.call( + thread_id="thread-1", + assistant_id=None, + additional_instructions="user_instructions", + file_ids=None, + name="Sam", + ) + assert tts.called + assert tts.call_args == mock.call("Hello World!") From 6958b2ec3c2fcb0c9c6c5285989e229bfe948fe6 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Sat, 20 Apr 2024 14:32:10 +0200 Subject: [PATCH 17/19] Hotfix -- Fix sentry support --- sam/__main__.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/sam/__main__.py b/sam/__main__.py index 4806674..66b4e97 100644 --- a/sam/__main__.py +++ b/sam/__main__.py @@ -6,10 +6,17 @@ import click import openai import sentry_sdk +from sentry_sdk.integrations.asyncio import AsyncioIntegration from . import config -sentry_sdk.init(config.SENTRY_DSN) +sentry_sdk.init( + dsn=config.SENTRY_DSN, + enable_tracing=True, + integrations=[ + AsyncioIntegration(), + ], +) @click.group() From b3ec8f35fda7ffe2e4c3ea66043800ced92afc81 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Sat, 20 Apr 2024 15:04:39 +0200 Subject: [PATCH 18/19] Hotfix -- Fix message role --- sam/bot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sam/bot.py b/sam/bot.py index dec78f1..cc55c5e 100644 --- a/sam/bot.py +++ b/sam/bot.py @@ -200,7 +200,7 @@ async def add_message( await client.beta.threads.messages.create( thread_id=thread_id, content=content, - role="assistant", + role=Roles.USER, file_ids=file_ids, ) return file_ids, voice_prompt From 512628b047c1331b29cd436edc77afd4d33b8713 Mon Sep 17 00:00:00 2001 From: Johannes Maron Date: Sat, 20 Apr 2024 15:16:03 +0200 Subject: [PATCH 19/19] Fix #66 -- Handle special subtyptes See also: https://api.slack.com/events/message#hidden_subtypes --- sam/slack.py | 3 +++ tests/test_slack.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/sam/slack.py b/sam/slack.py index de9cf1c..79e9f7c 100644 --- a/sam/slack.py +++ b/sam/slack.py @@ -49,6 +49,9 @@ async def get_bot_user_id(): async def handle_message(event: {str, Any}, say: AsyncSay): """Handle a message event from Slack.""" logger.debug(f"handle_message={json.dumps(event)}") + if event.get("subtype") == "message_deleted": + logger.debug("Ignoring message_deleted event %s", event) + return # https://api.slack.com/events/message#hidden_subtypes bot_id = await get_bot_user_id() channel_id = event["channel"] channel_type = event["channel_type"] diff --git a/tests/test_slack.py b/tests/test_slack.py index a00b6e4..9d641f7 100644 --- a/tests/test_slack.py +++ b/tests/test_slack.py @@ -1,3 +1,4 @@ +import logging from unittest import mock import pytest @@ -68,6 +69,22 @@ async def test_handle_message(monkeypatch): ) +@pytest.mark.asyncio +async def test_handle_message__subtype_deleted(caplog): + event = { + "type": "message", + "subtype": "message_deleted", + "hidden": True, + "channel": "C123ABC456", + "ts": "1358878755.000001", + "deleted_ts": "1358878749.000002", + "event_ts": "1358878755.000002", + } + with caplog.at_level(logging.DEBUG): + await slack.handle_message(event, None) + assert "Ignoring message_deleted event" in caplog.text + + def test_get_user_profile(monkeypatch): client = mock.MagicMock() client.users_profile_get.return_value = {