Skip to content

Commit

Permalink
fix(backend): phase one pr updates
Browse files Browse the repository at this point in the history
  • Loading branch information
ntindle committed Oct 18, 2024
1 parent 83fda8c commit 9c07633
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 95 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import secrets
from datetime import datetime, timedelta, timezone
from typing import cast
from backend.data import db
from prisma import Json, Prisma
from prisma.models import User
import json
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from redis import Redis
from backend.executor.database import DatabaseManager

from autogpt_libs.utils.synchronize import RedisKeyedMutex

Expand All @@ -22,46 +18,42 @@


class SupabaseIntegrationCredentialsStore:
def __init__(self, redis: "Redis"):
self.prisma: Prisma = Prisma()
def __init__(self, redis: "Redis", db: "DatabaseManager"):
self.db_manager: DatabaseManager = db
self.locks = RedisKeyedMutex(redis)

async def add_creds(self, user_id: str, credentials: Credentials) -> None:
def add_creds(self, user_id: str, credentials: Credentials) -> None:
with self.locked_user_metadata(user_id):
if await self.get_creds_by_id(user_id, credentials.id):
if self.get_creds_by_id(user_id, credentials.id):
raise ValueError(
f"Can not re-create existing credentials #{credentials.id} "
f"for user #{user_id}"
)
await self._set_user_integration_creds(
user_id, [*await self.get_all_creds(user_id), credentials]
self._set_user_integration_creds(
user_id, [*self.get_all_creds(user_id), credentials]
)

async def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = await self._get_user_metadata(user_id)
def get_all_creds(self, user_id: str) -> list[Credentials]:
user_metadata = self._get_user_metadata(user_id)
return UserMetadata.model_validate(
user_metadata.model_dump()
).integration_credentials

async def get_creds_by_id(
self, user_id: str, credentials_id: str
) -> Credentials | None:
all_credentials = await self.get_all_creds(user_id)
def get_creds_by_id(self, user_id: str, credentials_id: str) -> Credentials | None:
all_credentials = self.get_all_creds(user_id)
return next((c for c in all_credentials if c.id == credentials_id), None)

async def get_creds_by_provider(
self, user_id: str, provider: str
) -> list[Credentials]:
credentials = await self.get_all_creds(user_id)
def get_creds_by_provider(self, user_id: str, provider: str) -> list[Credentials]:
credentials = self.get_all_creds(user_id)
return [c for c in credentials if c.provider == provider]

async def get_authorized_providers(self, user_id: str) -> list[str]:
credentials = await self.get_all_creds(user_id)
def get_authorized_providers(self, user_id: str) -> list[str]:
credentials = self.get_all_creds(user_id)
return list(set(c.provider for c in credentials))

async def update_creds(self, user_id: str, updated: Credentials) -> None:
def update_creds(self, user_id: str, updated: Credentials) -> None:
with self.locked_user_metadata(user_id):
current = await self.get_creds_by_id(user_id, updated.id)
current = self.get_creds_by_id(user_id, updated.id)
if not current:
raise ValueError(
f"Credentials with ID {updated.id} "
Expand Down Expand Up @@ -89,20 +81,18 @@ async def update_creds(self, user_id: str, updated: Credentials) -> None:
# Update the credentials
updated_credentials_list = [
updated if c.id == updated.id else c
for c in await self.get_all_creds(user_id)
for c in self.get_all_creds(user_id)
]
await self._set_user_integration_creds(user_id, updated_credentials_list)
self._set_user_integration_creds(user_id, updated_credentials_list)

async def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
def delete_creds_by_id(self, user_id: str, credentials_id: str) -> None:
with self.locked_user_metadata(user_id):
filtered_credentials = [
c for c in await self.get_all_creds(user_id) if c.id != credentials_id
c for c in self.get_all_creds(user_id) if c.id != credentials_id
]
await self._set_user_integration_creds(user_id, filtered_credentials)
self._set_user_integration_creds(user_id, filtered_credentials)

async def store_state_token(
self, user_id: str, provider: str, scopes: list[str]
) -> str:
def store_state_token(self, user_id: str, provider: str, scopes: list[str]) -> str:
token = secrets.token_urlsafe(32)
expires_at = datetime.now(timezone.utc) + timedelta(minutes=10)

Expand All @@ -114,21 +104,18 @@ async def store_state_token(
)

with self.locked_user_metadata(user_id):
user_metadata = await self._get_user_metadata(user_id)
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.integration_oauth_states
oauth_states.append(state.model_dump())
user_metadata.integration_oauth_states = oauth_states

if not self.prisma.is_connected():
await self.prisma.connect()
await self.prisma.user.update(
where={"id": user_id},
data={"metadata": Json(user_metadata.model_dump())},
self.db_manager.update_user_metadata(
user_id=user_id, metadata=user_metadata
)

return token

async def get_any_valid_scopes_from_state_token(
def get_any_valid_scopes_from_state_token(
self, user_id: str, token: str, provider: str
) -> list[str]:
"""
Expand All @@ -138,7 +125,7 @@ async def get_any_valid_scopes_from_state_token(
IS TO CHECK IF THE USER HAS GIVEN PERMISSIONS TO THE APPLICATION BEFORE EXCHANGING
THE CODE FOR TOKENS.
"""
user_metadata = await self._get_user_metadata(user_id)
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.integration_oauth_states

now = datetime.now(timezone.utc)
Expand All @@ -158,9 +145,9 @@ async def get_any_valid_scopes_from_state_token(

return []

async def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
def verify_state_token(self, user_id: str, token: str, provider: str) -> bool:
with self.locked_user_metadata(user_id):
user_metadata = await self._get_user_metadata(user_id)
user_metadata = self._get_user_metadata(user_id)
oauth_states = user_metadata.integration_oauth_states

now = datetime.now(timezone.utc)
Expand All @@ -179,39 +166,26 @@ async def verify_state_token(self, user_id: str, token: str, provider: str) -> b
# Remove the used state
oauth_states.remove(valid_state)
user_metadata.integration_oauth_states = oauth_states
if not self.prisma.is_connected():
await self.prisma.connect()
await self.prisma.user.update(
where={"id": user_id},
data={"metadata": Json(user_metadata.model_dump())},
)
self.db_manager.update_user_metadata(user_id, user_metadata)
return True

return False

async def _set_user_integration_creds(
def _set_user_integration_creds(
self, user_id: str, credentials: list[Credentials]
) -> None:
raw_metadata = await self._get_user_metadata(user_id)
raw_metadata = self._get_user_metadata(user_id)
raw_metadata.integration_credentials = [c.model_dump() for c in credentials]
if not self.prisma.is_connected():
await self.prisma.connect()
await self.prisma.user.update(
where={"id": user_id}, data={"metadata": Json(raw_metadata.model_dump())}
)
self.db_manager.update_user_metadata(user_id, raw_metadata)

async def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
if not self.prisma.is_connected():
await self.prisma.connect()
user = await self.prisma.user.find_unique(where={"id": user_id})
if not user:
raise ValueError(f"User with ID {user_id} not found")
def _get_user_metadata(self, user_id: str) -> UserMetadataRaw:
user = self.db_manager.get_user(user_id=user_id)
return (
UserMetadataRaw.model_validate(user.metadata)
if user.metadata
else UserMetadataRaw()
)

def locked_user_metadata(self, user_id: str):
key = (self.prisma, f"user:{user_id}", "metadata")
key = (self.db_manager, f"user:{user_id}", "metadata")
return self.locks.locked(key)
17 changes: 17 additions & 0 deletions autogpt_platform/backend/backend/data/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
from multiprocessing import Manager
from typing import Any, Generic, TypeVar

from autogpt_libs.supabase_integration_credentials_store.types import UserMetadataRaw
from prisma import Json
from prisma.enums import AgentExecutionStatus
from prisma.models import (
AgentGraphExecution,
AgentNodeExecution,
AgentNodeExecutionInputOutput,
User,
)
from prisma.types import (
AgentGraphExecutionInclude,
Expand Down Expand Up @@ -477,3 +480,17 @@ async def get_incomplete_executions(
include=EXECUTION_RESULT_INCLUDE,
)
return [ExecutionResult.from_db(execution) for execution in executions]


async def get_user(user_id: str) -> User:
user = await User.prisma().find_unique_or_raise(
where={"id": user_id},
)
return user


async def update_user_metadata(user_id: str, metadata: UserMetadataRaw):
await User.prisma().update(
where={"id": user_id},
data={"metadata": Json(metadata.model_dump())},
)
6 changes: 6 additions & 0 deletions autogpt_platform/backend/backend/executor/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
get_execution_results,
get_incomplete_executions,
get_latest_execution,
get_user,
update_execution_status,
update_graph_execution_stats,
update_node_execution_stats,
update_user_metadata,
upsert_execution_input,
upsert_execution_output,
)
Expand Down Expand Up @@ -73,3 +75,7 @@ def wrapper(self, *args: P.args, **kwargs: P.kwargs) -> R:
Callable[[Any, str, int, str, dict[str, str], float, float], int],
exposed_run_and_wait(user_credit_model.spend_credits),
)

# User + User Metadata
get_user = exposed_run_and_wait(get_user)
update_user_metadata = exposed_run_and_wait(update_user_metadata)
12 changes: 7 additions & 5 deletions autogpt_platform/backend/backend/executor/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def on_node_executor_start(cls):
redis.connect()
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
cls.creds_manager = IntegrationCredentialsManager(db_manager=cls.db_client)

# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
Expand Down Expand Up @@ -669,7 +669,9 @@ def run_service(self):
SupabaseIntegrationCredentialsStore,
)

self.credentials_store = SupabaseIntegrationCredentialsStore(redis.get_redis())
self.credentials_store = SupabaseIntegrationCredentialsStore(
redis=redis.get_redis(), db=self.db_client
)
self.executor = ProcessPoolExecutor(
max_workers=self.pool_size,
initializer=Executor.on_graph_executor_start,
Expand Down Expand Up @@ -712,7 +714,7 @@ def add_execution(
raise Exception(f"Graph #{graph_id} not found.")

graph.validate_graph(for_run=True)
self.run_and_wait(self._validate_node_input_credentials(graph, user_id))
self._validate_node_input_credentials(graph, user_id)

nodes_input = []
for node in graph.starting_nodes:
Expand Down Expand Up @@ -806,7 +808,7 @@ def cancel_execution(self, graph_exec_id: str) -> None:
)
self.db_client.send_execution_update(exec_update.model_dump())

async def _validate_node_input_credentials(self, graph: Graph, user_id: str):
def _validate_node_input_credentials(self, graph: Graph, user_id: str):
"""Checks all credentials for all nodes of the graph"""

for node in graph.nodes:
Expand All @@ -828,7 +830,7 @@ async def _validate_node_input_credentials(self, graph: Graph, user_id: str):
node.input_default[CREDENTIALS_FIELD_NAME]
)
# Fetch the corresponding Credentials and perform sanity checks
credentials = await self.credentials_store.get_creds_by_id(
credentials = self.credentials_store.get_creds_by_id(
user_id, credentials_meta.id
)
if not credentials:
Expand Down
32 changes: 17 additions & 15 deletions autogpt_platform/backend/backend/integrations/creds_manager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import logging
from contextlib import contextmanager
from datetime import datetime
Expand All @@ -11,6 +10,7 @@
from redis.lock import Lock as RedisLock

from backend.data import redis
from backend.executor.database import DatabaseManager
from backend.integrations.oauth import HANDLERS_BY_NAME, BaseOAuthHandler
from backend.util.settings import Settings

Expand Down Expand Up @@ -50,21 +50,23 @@ class IntegrationCredentialsManager:
cause so much latency that it's worth implementing.
"""

def __init__(self):
def __init__(self, db_manager: DatabaseManager):
redis_conn = redis.get_redis()
self._locks = RedisKeyedMutex(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(redis_conn)
self.store = SupabaseIntegrationCredentialsStore(
redis=redis_conn, db=db_manager
)

async def create(self, user_id: str, credentials: Credentials) -> None:
return await self.store.add_creds(user_id, credentials)
def create(self, user_id: str, credentials: Credentials) -> None:
return self.store.add_creds(user_id, credentials)

async def exists(self, user_id: str, credentials_id: str) -> bool:
def exists(self, user_id: str, credentials_id: str) -> bool:
return self.store.get_creds_by_id(user_id, credentials_id) is not None

async def get(
def get(
self, user_id: str, credentials_id: str, lock: bool = True
) -> Credentials | None:
credentials = await self.store.get_creds_by_id(user_id, credentials_id)
credentials = self.store.get_creds_by_id(user_id, credentials_id)
if not credentials:
return None

Expand All @@ -89,7 +91,7 @@ async def get(
_lock = self._acquire_lock(user_id, credentials_id)

fresh_credentials = oauth_handler.refresh_tokens(credentials)
await self.store.update_creds(user_id, fresh_credentials)
self.store.update_creds(user_id, fresh_credentials)
if _lock:
_lock.release()

Expand All @@ -111,26 +113,26 @@ def acquire(
# to allow priority access for refreshing/updating the tokens.
with self._locked(user_id, credentials_id, "!time_sensitive"):
lock = self._acquire_lock(user_id, credentials_id)
credentials = asyncio.run(self.get(user_id, credentials_id, lock=False))
credentials = self.get(user_id, credentials_id, lock=False)
if not credentials:
raise ValueError(
f"Credentials #{credentials_id} for user #{user_id} not found"
)
return credentials, lock

async def update(self, user_id: str, updated: Credentials) -> None:
def update(self, user_id: str, updated: Credentials) -> None:
with self._locked(user_id, updated.id):
await self.store.update_creds(user_id, updated)
self.store.update_creds(user_id, updated)

async def delete(self, user_id: str, credentials_id: str) -> None:
def delete(self, user_id: str, credentials_id: str) -> None:
with self._locked(user_id, credentials_id):
await self.store.delete_creds_by_id(user_id, credentials_id)
self.store.delete_creds_by_id(user_id, credentials_id)

# -- Locking utilities -- #

def _acquire_lock(self, user_id: str, credentials_id: str, *args: str) -> RedisLock:
key = (
self.store.prisma,
self.store.db_manager,
f"user:{user_id}",
f"credentials:{credentials_id}",
*args,
Expand Down
Loading

0 comments on commit 9c07633

Please sign in to comment.