diff --git a/autogpt_platform/backend/backend/data/api_key.py b/autogpt_platform/backend/backend/data/api_key.py index ae919ccf4b0f..91fb33622e28 100644 --- a/autogpt_platform/backend/backend/data/api_key.py +++ b/autogpt_platform/backend/backend/data/api_key.py @@ -1,26 +1,49 @@ -import uuid import logging +import uuid from datetime import datetime, timezone -from typing import Optional, List, Union -from enum import Enum -from pydantic import BaseModel +from typing import List, Optional -from prisma.models import APIKey as PrismaAPIKey -from backend.data.db import BaseDbModel, transaction from autogpt_libs.api_key.key_manager import APIKeyManager +from prisma.enums import APIKeyPermission, APIKeyStatus +from prisma.errors import PrismaError +from prisma.models import APIKey as PrismaAPIKey +from prisma.types import ( + APIKeyCreateInput, + APIKeyUpdateInput, + APIKeyWhereInput, + APIKeyWhereUniqueInput, +) +from pydantic import BaseModel + +from backend.data.db import BaseDbModel logger = logging.getLogger(__name__) -class APIKeyPermission(str, Enum): - EXECUTE_GRAPH = "EXECUTE_GRAPH" - READ_GRAPH = "READ_GRAPH" - EXECUTE_BLOCK = "EXECUTE_BLOCK" - READ_BLOCK = "READ_BLOCK" -class APIKeyStatus(str, Enum): - ACTIVE = "ACTIVE" - REVOKED = "REVOKED" - SUSPENDED = "SUSPENDED" +# Some basic exceptions +class APIKeyError(Exception): + """Base exception for API key operations""" + + pass + + +class APIKeyNotFoundError(APIKeyError): + """Raised when an API key is not found""" + + pass + + +class APIKeyPermissionError(APIKeyError): + """Raised when there are permission issues with API key operations""" + + pass + + +class APIKeyValidationError(APIKeyError): + """Raised when API key validation fails""" + + pass + class APIKey(BaseDbModel): name: str @@ -37,20 +60,25 @@ class APIKey(BaseDbModel): @staticmethod def from_db(api_key: PrismaAPIKey): - return APIKey( - id=api_key.id, - name=api_key.name, - prefix=api_key.prefix, - postfix=api_key.postfix, - key=api_key.key, - status=APIKeyStatus(api_key.status), - permissions=[APIKeyPermission(p) for p in api_key.permissions], - created_at=api_key.createdAt, - last_used_at=api_key.lastUsedAt, - revoked_at=api_key.revokedAt, - description=api_key.description, - user_id=api_key.userId - ) + try: + return APIKey( + id=api_key.id, + name=api_key.name, + prefix=api_key.prefix, + postfix=api_key.postfix, + key=api_key.key, + status=APIKeyStatus(api_key.status), + permissions=[APIKeyPermission(p) for p in api_key.permissions], + created_at=api_key.createdAt, + last_used_at=api_key.lastUsedAt, + revoked_at=api_key.revokedAt, + description=api_key.description, + user_id=api_key.userId, + ) + except Exception as e: + logger.error(f"Error creating APIKey from db: {str(e)}") + raise APIKeyError(f"Failed to create API key object: {str(e)}") + class APIKeyWithoutHash(BaseModel): id: str @@ -67,137 +95,231 @@ class APIKeyWithoutHash(BaseModel): @staticmethod def from_db(api_key: PrismaAPIKey): - return APIKeyWithoutHash( - id=api_key.id, - name=api_key.name, - prefix=api_key.prefix, - postfix=api_key.postfix, - status=APIKeyStatus(api_key.status), - permissions=[APIKeyPermission(p) for p in api_key.permissions], - created_at=api_key.createdAt, - last_used_at=api_key.lastUsedAt, - revoked_at=api_key.revokedAt, - description=api_key.description, - user_id=api_key.userId - ) + try: + return APIKeyWithoutHash( + id=api_key.id, + name=api_key.name, + prefix=api_key.prefix, + postfix=api_key.postfix, + status=APIKeyStatus(api_key.status), + permissions=[APIKeyPermission(p) for p in api_key.permissions], + created_at=api_key.createdAt, + last_used_at=api_key.lastUsedAt, + revoked_at=api_key.revokedAt, + description=api_key.description, + user_id=api_key.userId, + ) + except Exception as e: + logger.error(f"Error creating APIKeyWithoutHash from db: {str(e)}") + raise APIKeyError(f"Failed to create API key object: {str(e)}") -# --------------------- Model functions --------------------- # async def generate_api_key( name: str, user_id: str, permissions: List[APIKeyPermission], - description: Optional[str] = None + description: Optional[str] = None, ) -> tuple[APIKeyWithoutHash, str]: """ Generate a new API key and store it in the database. Returns the API key object (without hash) and the plain text key. """ - api_manager = APIKeyManager() - key = api_manager.generate_api_key() - - api_key = await PrismaAPIKey.prisma().create( - data={ - "id": str(uuid.uuid4()), - "name": name, - "prefix": key.prefix, - "postfix": key.postfix, - "key": key.hash, - "permissions": [p.value for p in permissions], - "description": description, - "userId": user_id - } - ) - - api_key_without_hash = APIKeyWithoutHash.from_db(api_key) - return api_key_without_hash, key.raw + try: + api_manager = APIKeyManager() + key = api_manager.generate_api_key() + + api_key = await PrismaAPIKey.prisma().create( + data=APIKeyCreateInput( + id=str(uuid.uuid4()), + name=name, + prefix=key.prefix, + postfix=key.postfix, + key=key.hash, + permissions=[p for p in permissions], + description=description, + userId=user_id, + ) + ) + + api_key_without_hash = APIKeyWithoutHash.from_db(api_key) + return api_key_without_hash, key.raw + except PrismaError as e: + logger.error(f"Database error while generating API key: {str(e)}") + raise APIKeyError(f"Failed to generate API key: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while generating API key: {str(e)}") + raise APIKeyError(f"Failed to generate API key: {str(e)}") + async def validate_api_key(plain_text_key: str) -> Optional[APIKey]: """ Validate an API key and return the API key object if valid. """ - if not plain_text_key.startswith(APIKeyManager.PREFIX): - return None + try: + if not plain_text_key.startswith(APIKeyManager.PREFIX): + logger.warning("Invalid API key format") + return None - prefix = plain_text_key[:APIKeyManager.PREFIX_LENGTH] - api_manager = APIKeyManager() + prefix = plain_text_key[: APIKeyManager.PREFIX_LENGTH] + api_manager = APIKeyManager() - api_key = await PrismaAPIKey.prisma().find_first( - where={ - "prefix": prefix, - "status": APIKeyStatus.ACTIVE.value - } - ) + api_key = await PrismaAPIKey.prisma().find_first( + where=APIKeyWhereInput(prefix=prefix, status=(APIKeyStatus.ACTIVE)) + ) - if not api_key: - return None + if not api_key: + logger.warning(f"No active API key found with prefix {prefix}") + return None - is_valid = api_manager.verify_api_key(plain_text_key, api_key.key) - if not is_valid: - return None + is_valid = api_manager.verify_api_key(plain_text_key, api_key.key) + if not is_valid: + logger.warning("API key verification failed") + return None + + return APIKey.from_db(api_key) + except Exception as e: + logger.error(f"Error validating API key: {str(e)}") + raise APIKeyValidationError(f"Failed to validate API key: {str(e)}") + + +async def revoke_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]: + try: + api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id}) - return APIKey.from_db(api_key) + if not api_key: + raise APIKeyNotFoundError(f"API key with id {key_id} not found") -async def revoke_api_key(key_id: str, user_id: str) -> APIKeyWithoutHash: - api_key = await PrismaAPIKey.prisma().update( - where={ - "id": key_id, - "userId": user_id - }, - data={ - "status": APIKeyStatus.REVOKED.value, - "revokedAt": datetime.now(timezone.utc) - } - ) + if api_key.userId != user_id: + raise APIKeyPermissionError( + "You do not have permission to revoke this API key." + ) + + where_clause: APIKeyWhereUniqueInput = {"id": key_id} + updated_api_key = await PrismaAPIKey.prisma().update( + where=where_clause, + data=APIKeyUpdateInput( + status=APIKeyStatus.REVOKED, revokedAt=datetime.now(timezone.utc) + ), + ) + + if updated_api_key: + return APIKeyWithoutHash.from_db(updated_api_key) + return None + except (APIKeyNotFoundError, APIKeyPermissionError) as e: + raise e + except PrismaError as e: + logger.error(f"Database error while revoking API key: {str(e)}") + raise APIKeyError(f"Failed to revoke API key: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while revoking API key: {str(e)}") + raise APIKeyError(f"Failed to revoke API key: {str(e)}") - return APIKeyWithoutHash.from_db(api_key) async def list_user_api_keys(user_id: str) -> List[APIKeyWithoutHash]: - api_keys = await PrismaAPIKey.prisma().find_many( - where={"userId": user_id}, - order={"createdAt": "desc"} - ) + try: + where_clause: APIKeyWhereInput = {"userId": user_id} - return [APIKeyWithoutHash.from_db(key) for key in api_keys] + api_keys = await PrismaAPIKey.prisma().find_many( + where=where_clause, order={"createdAt": "desc"} + ) + + return [APIKeyWithoutHash.from_db(key) for key in api_keys] + except PrismaError as e: + logger.error(f"Database error while listing API keys: {str(e)}") + raise APIKeyError(f"Failed to list API keys: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while listing API keys: {str(e)}") + raise APIKeyError(f"Failed to list API keys: {str(e)}") + + +async def suspend_api_key(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]: + try: + api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id}) -async def suspend_api_key(key_id: str, user_id: str) -> APIKeyWithoutHash: - api_key = await PrismaAPIKey.prisma().update( - where={ - "id": key_id, - "userId": user_id - }, - data={"status": APIKeyStatus.SUSPENDED.value} - ) + if not api_key: + raise APIKeyNotFoundError(f"API key with id {key_id} not found") + + if api_key.userId != user_id: + raise APIKeyPermissionError( + "You do not have permission to suspend this API key." + ) + + where_clause: APIKeyWhereUniqueInput = {"id": key_id} + updated_api_key = await PrismaAPIKey.prisma().update( + where=where_clause, + data=APIKeyUpdateInput(status=APIKeyStatus.SUSPENDED), + ) + + if updated_api_key: + return APIKeyWithoutHash.from_db(updated_api_key) + return None + except (APIKeyNotFoundError, APIKeyPermissionError) as e: + raise e + except PrismaError as e: + logger.error(f"Database error while suspending API key: {str(e)}") + raise APIKeyError(f"Failed to suspend API key: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while suspending API key: {str(e)}") + raise APIKeyError(f"Failed to suspend API key: {str(e)}") - return APIKeyWithoutHash.from_db(api_key) def has_permission(api_key: APIKey, required_permission: APIKeyPermission) -> bool: - return required_permission in api_key.permissions + try: + return required_permission in api_key.permissions + except Exception as e: + logger.error(f"Error checking API key permissions: {str(e)}") + return False + async def get_api_key_by_id(key_id: str, user_id: str) -> Optional[APIKeyWithoutHash]: - api_key = await PrismaAPIKey.prisma().find_first( - where={ - "id": key_id, - "userId": user_id - } - ) + try: + api_key = await PrismaAPIKey.prisma().find_first( + where=APIKeyWhereInput(id=key_id, userId=user_id) + ) + + if not api_key: + return None + + return APIKeyWithoutHash.from_db(api_key) + except PrismaError as e: + logger.error(f"Database error while getting API key: {str(e)}") + raise APIKeyError(f"Failed to get API key: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while getting API key: {str(e)}") + raise APIKeyError(f"Failed to get API key: {str(e)}") - return APIKeyWithoutHash.from_db(api_key) if api_key else None async def update_api_key_permissions( - key_id: str, - user_id: str, - permissions: List[APIKeyPermission] -) -> APIKeyWithoutHash: + key_id: str, user_id: str, permissions: List[APIKeyPermission] +) -> Optional[APIKeyWithoutHash]: """ Update the permissions of an API key. """ - api_key = await PrismaAPIKey.prisma().update( - where={ - "id": key_id, - "userId": user_id - }, - data={"permissions": [p.value for p in permissions]} - ) - - return APIKeyWithoutHash.from_db(api_key) + try: + api_key = await PrismaAPIKey.prisma().find_unique(where={"id": key_id}) + + if api_key is None: + raise APIKeyNotFoundError("No such API key found.") + + if api_key.userId != user_id: + raise APIKeyPermissionError( + "You do not have permission to update this API key." + ) + + where_clause: APIKeyWhereUniqueInput = {"id": key_id} + updated_api_key = await PrismaAPIKey.prisma().update( + where=where_clause, + data=APIKeyUpdateInput(permissions=permissions), + ) + + if updated_api_key: + return APIKeyWithoutHash.from_db(updated_api_key) + return None + except (APIKeyNotFoundError, APIKeyPermissionError) as e: + raise e + except PrismaError as e: + logger.error(f"Database error while updating API key permissions: {str(e)}") + raise APIKeyError(f"Failed to update API key permissions: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error while updating API key permissions: {str(e)}") + raise APIKeyError(f"Failed to update API key permissions: {str(e)}") diff --git a/autogpt_platform/backend/backend/server/model.py b/autogpt_platform/backend/backend/server/model.py index d41f08492c33..4891c2c9cbb9 100644 --- a/autogpt_platform/backend/backend/server/model.py +++ b/autogpt_platform/backend/backend/server/model.py @@ -6,6 +6,7 @@ import backend.data.graph from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash + class Methods(enum.Enum): SUBSCRIBE = "subscribe" UNSUBSCRIBE = "unsubscribe" @@ -42,12 +43,15 @@ class CreateAPIKeyRequest(pydantic.BaseModel): permissions: typing.List[APIKeyPermission] description: typing.Optional[str] = None + class CreateAPIKeyResponse(pydantic.BaseModel): api_key: APIKeyWithoutHash plain_text_key: str + class SetGraphActiveVersion(pydantic.BaseModel): active_graph_version: int + class UpdatePermissionsRequest(pydantic.BaseModel): permissions: typing.List[APIKeyPermission] diff --git a/autogpt_platform/backend/backend/server/routers/v1.py b/autogpt_platform/backend/backend/server/routers/v1.py index d2463c856b32..d88061bf1f3f 100644 --- a/autogpt_platform/backend/backend/server/routers/v1.py +++ b/autogpt_platform/backend/backend/server/routers/v1.py @@ -1,25 +1,41 @@ import asyncio import logging from collections import defaultdict -from typing import Annotated, Any, Dict, List, Optional +from typing import Annotated, Any, Dict, List from autogpt_libs.auth.middleware import auth_middleware from autogpt_libs.utils.cache import thread_cached from fastapi import APIRouter, Body, Depends, HTTPException -from typing_extensions import TypedDict -from pydantic import BaseModel +from typing_extensions import Optional, TypedDict import backend.data.block import backend.server.integrations.router import backend.server.routers.analytics from backend.data import execution as execution_db from backend.data import graph as graph_db +from backend.data.api_key import ( + APIKeyError, + APIKeyNotFoundError, + APIKeyPermissionError, + APIKeyWithoutHash, + generate_api_key, + get_api_key_by_id, + list_user_api_keys, + revoke_api_key, + suspend_api_key, + update_api_key_permissions, +) from backend.data.block import BlockInput, CompletedBlockOutput from backend.data.credit import get_block_costs, get_user_credit_model from backend.data.user import get_or_create_user -from backend.data.api_key import APIKeyPermission, APIKeyWithoutHash, generate_api_key, get_api_key_by_id, list_user_api_keys, revoke_api_key, suspend_api_key, update_api_key_permissions from backend.executor import ExecutionManager, ExecutionScheduler -from backend.server.model import CreateAPIKeyRequest, CreateAPIKeyResponse, CreateGraph, SetGraphActiveVersion, UpdatePermissionsRequest +from backend.server.model import ( + CreateAPIKeyRequest, + CreateAPIKeyResponse, + CreateGraph, + SetGraphActiveVersion, + UpdatePermissionsRequest, +) from backend.server.utils import get_user_id from backend.util.service import get_service_client from backend.util.settings import Settings @@ -529,6 +545,7 @@ async def update_configuration( ##################### API KEY ############################## ######################################################## + @v1_router.post( "/api-keys", response_model=CreateAPIKeyResponse, @@ -536,18 +553,21 @@ async def update_configuration( dependencies=[Depends(auth_middleware)], ) async def create_api_key( - request: CreateAPIKeyRequest, - user_id: Annotated[str, Depends(get_user_id)] + request: CreateAPIKeyRequest, user_id: Annotated[str, Depends(get_user_id)] ) -> CreateAPIKeyResponse: """Create a new API key""" - api_key, plain_text = await generate_api_key( - name=request.name, - user_id=user_id, - permissions=request.permissions, - description=request.description - ) + try: + api_key, plain_text = await generate_api_key( + name=request.name, + user_id=user_id, + permissions=request.permissions, + description=request.description, + ) + return CreateAPIKeyResponse(api_key=api_key, plain_text_key=plain_text) + except APIKeyError as e: + logger.error(f"Failed to create API key: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) - return CreateAPIKeyResponse(api_key=api_key, plain_text_key=plain_text) @v1_router.get( "/api-keys", @@ -559,7 +579,12 @@ async def get_api_keys( user_id: Annotated[str, Depends(get_user_id)] ) -> List[APIKeyWithoutHash]: """List all API keys for the user""" - return await list_user_api_keys(user_id) + try: + return await list_user_api_keys(user_id) + except APIKeyError as e: + logger.error(f"Failed to list API keys: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + @v1_router.get( "/api-keys/{key_id}", @@ -568,14 +593,18 @@ async def get_api_keys( dependencies=[Depends(auth_middleware)], ) async def get_api_key( - key_id: str, - user_id: Annotated[str, Depends(get_user_id)] + key_id: str, user_id: Annotated[str, Depends(get_user_id)] ) -> APIKeyWithoutHash: """Get a specific API key""" - api_key = await get_api_key_by_id(key_id, user_id) - if not api_key: - raise HTTPException(status_code=404, detail="API key not found") - return api_key + try: + api_key = await get_api_key_by_id(key_id, user_id) + if not api_key: + raise HTTPException(status_code=404, detail="API key not found") + return api_key + except APIKeyError as e: + logger.error(f"Failed to get API key: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + @v1_router.delete( "/api-keys/{key_id}", @@ -584,11 +613,19 @@ async def get_api_key( dependencies=[Depends(auth_middleware)], ) async def delete_api_key( - key_id: str, - user_id: Annotated[str, Depends(get_user_id)] -) -> APIKeyWithoutHash: + key_id: str, user_id: Annotated[str, Depends(get_user_id)] +) -> Optional[APIKeyWithoutHash]: """Revoke an API key""" - return await revoke_api_key(key_id, user_id) + try: + return await revoke_api_key(key_id, user_id) + except APIKeyNotFoundError: + raise HTTPException(status_code=404, detail="API key not found") + except APIKeyPermissionError: + raise HTTPException(status_code=403, detail="Permission denied") + except APIKeyError as e: + logger.error(f"Failed to revoke API key: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + @v1_router.post( "/api-keys/{key_id}/suspend", @@ -597,11 +634,19 @@ async def delete_api_key( dependencies=[Depends(auth_middleware)], ) async def suspend_key( - key_id: str, - user_id: Annotated[str, Depends(get_user_id)] -) -> APIKeyWithoutHash: + key_id: str, user_id: Annotated[str, Depends(get_user_id)] +) -> Optional[APIKeyWithoutHash]: """Suspend an API key""" - return await suspend_api_key(key_id, user_id) + try: + return await suspend_api_key(key_id, user_id) + except APIKeyNotFoundError: + raise HTTPException(status_code=404, detail="API key not found") + except APIKeyPermissionError: + raise HTTPException(status_code=403, detail="Permission denied") + except APIKeyError as e: + logger.error(f"Failed to suspend API key: {str(e)}") + raise HTTPException(status_code=400, detail=str(e)) + @v1_router.put( "/api-keys/{key_id}/permissions", @@ -611,8 +656,16 @@ async def suspend_key( ) async def update_permissions( key_id: str, - request: UpdatePermissionsRequest , - user_id: Annotated[str, Depends(get_user_id)] -) -> APIKeyWithoutHash: + request: UpdatePermissionsRequest, + user_id: Annotated[str, Depends(get_user_id)], +) -> Optional[APIKeyWithoutHash]: """Update API key permissions""" - return await update_api_key_permissions(key_id, user_id, request.permissions) + try: + return await update_api_key_permissions(key_id, user_id, request.permissions) + except APIKeyNotFoundError: + raise HTTPException(status_code=404, detail="API key not found") + except APIKeyPermissionError: + raise HTTPException(status_code=403, detail="Permission denied") + except APIKeyError as e: + logger.error(f"Failed to update API key permissions: {str(e)}") + raise HTTPException(status_code=400, detail=str(e))