Skip to content

Commit

Permalink
Add support for multiple GH repos
Browse files Browse the repository at this point in the history
  • Loading branch information
codingjoe committed Apr 18, 2024
1 parent 82c7c7d commit 4a680f2
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 25 deletions.
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ sam = "sam.__main__:cli"
[project.optional-dependencies]
test = [
"pytest",
"pytest-cov",
"pytest-asyncio",
"pytest-cov",
"pytest-env",
]
lint = [
"bandit==1.7.8",
Expand Down Expand Up @@ -58,6 +59,9 @@ minversion = "6.0"
addopts = "--cov --tb=short -rxs"
testpaths = ["tests"]

[tool.pytest_env]
GITHUB_REPOS = 'voiio/sam'

[tool.coverage.run]
source = ["sam"]

Expand Down
4 changes: 2 additions & 2 deletions sam/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ async def complete_run(run_id: str, thread_id: str, *, retry: int = 0, **context
)
tool_outputs.append(
{
"tool_call_id": tool_call.id, # noqa
"output": fn(**kwargs, **context),
"tool_call_id": tool_call.id,
"output": fn(**kwargs, _context={**context}),
}
)
logger.info("Submitting tool outputs for run %s", run_id)
Expand Down
5 changes: 5 additions & 0 deletions sam/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
import os
from zoneinfo import ZoneInfo

Expand All @@ -14,3 +15,7 @@
SENTRY_DSN = os.getenv("SENTRY_DSN")
GITHUB_ORG = os.getenv("GITHUB_ORG")
GITHUB_REPOSITORY = os.getenv("GITHUB_REPOSITORY")
GITHUB_REPOS = enum.StrEnum(
"GITHUB_REPOS",
{repo: repo for repo in os.getenv("GITHUB_REPOS", "").split(",") if repo},
)
8 changes: 4 additions & 4 deletions sam/contrib/github/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class GitHubAPIError(requests.HTTPError):

class AbstractGitHubAPIWrapper(abc.ABC): # pragma: no cover
@abc.abstractmethod
def create_issue(self, title, body):
def create_issue(self, title, body, repo):
return NotImplemented

@abc.abstractmethod
Expand All @@ -42,9 +42,9 @@ def __init__(self, token):
}
)

def create_issue(self, title, body):
def create_issue(self, title, body, repo):
response = self.post(
f"{self.endpoint}/repos/{config.GITHUB_ORG}/{config.GITHUB_REPOSITORY}/issues",
f"{self.endpoint}/repos/{repo}/issues",
json={"title": title, "body": body},
)
try:
Expand All @@ -64,7 +64,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
pass

def create_issue(self, title, body):
def create_issue(self, title, body, repo):
return {
"title": title,
"body": body,
Expand Down
1 change: 1 addition & 0 deletions sam/slack.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import enum
import json
import logging
import random # nosec
Expand Down
22 changes: 16 additions & 6 deletions sam/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
from markdownify import markdownify as md
from slack_sdk import WebClient, errors

import sam.config
from sam import config
from sam.contrib import algolia, brave, github
from sam.utils import logger


def send_email(to: str, subject: str, body: str, **_context):
def send_email(to: str, subject: str, body: str, _context=None):
"""
Send an email the given recipients. The user is always cc'd on the email.
Expand All @@ -29,6 +30,7 @@ def send_email(to: str, subject: str, body: str, **_context):
subject: The subject of the email.
body: The body of the email.
"""
_context = _context or {}
email_url = os.getenv("EMAIL_URL")
from_email = os.getenv("FROM_EMAIL", "sam@voiio.de")
email_white_list = os.getenv("EMAIL_WHITE_LIST")
Expand Down Expand Up @@ -60,7 +62,7 @@ def send_email(to: str, subject: str, body: str, **_context):
return "Email sent successfully!"


def web_search(query: str, **_context) -> str:
def web_search(query: str, _context=None) -> str:
"""
Search the internet for information that matches the given query.
Expand Down Expand Up @@ -91,7 +93,7 @@ def web_search(query: str, **_context) -> str:
)


def fetch_website(url: str, **_context) -> str:
def fetch_website(url: str, _context=None) -> str:
"""
Fetch the website for the given URL and return the content as Markdown.
Expand All @@ -116,7 +118,7 @@ def fetch_website(url: str, **_context) -> str:
return "failed to parse website"


def fetch_coworker_emails(**_context) -> str:
def fetch_coworker_emails(_context=None) -> str:
"""
Fetch profile data about your coworkers from Slack.
Expand Down Expand Up @@ -156,7 +158,9 @@ def fetch_coworker_emails(**_context) -> str:
return json.dumps(profiles)


def create_github_issue(title: str, body: str) -> str:
def create_github_issue(
title: str, body: str, repo: "sam.config.GITHUB_REPOS", _context=None
) -> str:
"""
Create an issue on GitHub with the given title and body.
Expand All @@ -166,13 +170,19 @@ def create_github_issue(title: str, body: str) -> str:
You should provide ideas for a potential solution,
including code snippet examples in a Markdown code block.
You MUST ALWAYS write the issue in English.
Args:
title: The title of the issue.
body: The body of the issue, markdown supported.
repo: The repository to create the issue in.
"""
if repo not in config.GITHUB_REPOS.__members__:
logger.warning("Invalid repo: %s", repo)
return "invalid repo"
with github.get_client() as api:
try:
response = api.create_issue(title, body)
response = api.create_issue(title, body, repo)
except github.GitHubAPIError:
logger.exception("Failed to create issue on GitHub")
return "failed to create issue"
Expand Down
32 changes: 23 additions & 9 deletions sam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import inspect
import logging
import random
import typing

import openai
import redis.asyncio as redis
Expand All @@ -29,8 +30,6 @@
float: "number",
list: "array",
dict: "object",
enum.StrEnum: "string",
enum.IntEnum: "integer",
}


Expand All @@ -51,6 +50,8 @@ def func_to_tool(fn: callable) -> dict:
doc_data = yaml.safe_load(args.split("Returns:")[0])
else:
description = fn.__doc__
doc_data = {}

return {
"type": "function",
"function": {
Expand All @@ -60,13 +61,7 @@ def func_to_tool(fn: callable) -> dict:
),
"parameters": {
"type": "object",
"properties": {
param.name: {
"type": type_map[param.annotation],
"description": doc_data[param.name],
}
for param in params
},
"properties": dict(params_to_props(fn, params, doc_data)),
"required": [
param.name
for param in params
Expand All @@ -77,6 +72,25 @@ def func_to_tool(fn: callable) -> dict:
}


def params_to_props(fn, params, doc_data):
types = typing.get_type_hints(fn)
for param in params:
if param.name.startswith("_"):
continue
param_type = types[param.name]
if param_type in type_map:
yield param.name, {
"type": type_map[types[param.name]],
"description": doc_data[param.name],
}
elif issubclass(param_type, enum.StrEnum):
yield param.name, {
"type": "string",
"enum": [value.value for value in param_type],
"description": doc_data[param.name],
}


async def backoff(retries: int, max_jitter: int = 10):
"""Exponential backoff timer with a random jitter."""
await asyncio.sleep(2**retries + random.random() * max_jitter) # nosec
Expand Down
6 changes: 5 additions & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,15 @@ def test_web_search__with_coordinates():

def test_create_github_issue():
assert (
tools.create_github_issue("title", "body")
tools.create_github_issue("title", "body", "voiio/sam")
== "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
)


def test_create_github_issue__invalid_repo():
assert tools.create_github_issue("title", "body", "not-valid") == "invalid repo"


def test_platform_search():
assert tools.platform_search("ferien") == json.dumps(
{
Expand Down
18 changes: 16 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import enum

import pytest

import tests.test_tools
from sam import utils


Expand All @@ -10,13 +13,19 @@ async def test_backoff():
await utils.backoff(0, max_jitter=0)


BloodTypes = enum.StrEnum("BloodTypes", {"A": "A", "B": "B"})


def test_func_to_tool():
def fn(a: int, b: str) -> int:
def fn(
a: int, b: str, blood_types: "tests.test_utils.BloodTypes", _context=None
) -> str:
"""Function description.
Args:
a: Description of a.
b: Description of b.
blood_types: Description of bool_types.
Returns:
Description of return value.
Expand All @@ -40,8 +49,13 @@ def fn(a: int, b: str) -> int:
"type": "string",
"description": "Description of b.",
},
"blood_types": {
"type": "string",
"enum": ["A", "B"],
"description": "Description of bool_types.",
},
},
"required": ["a", "b"],
"required": ["a", "b", "blood_types"],
},
},
}

0 comments on commit 4a680f2

Please sign in to comment.