From 59b467eb78933c3e1411a1cf4ba4cc729b5225b2 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Thu, 17 Oct 2024 02:09:03 +0200 Subject: [PATCH] Add credentials injection to `Graph.on_update` etc hooks --- .../backend/backend/blocks/github/triggers.py | 299 ++++++++++++++++++ .../backend/backend/data/block.py | 7 +- .../backend/backend/data/graph.py | 60 +++- .../backend/backend/server/rest_api.py | 25 +- 4 files changed, 373 insertions(+), 18 deletions(-) create mode 100644 autogpt_platform/backend/backend/blocks/github/triggers.py diff --git a/autogpt_platform/backend/backend/blocks/github/triggers.py b/autogpt_platform/backend/backend/blocks/github/triggers.py new file mode 100644 index 000000000000..1675e5d7144d --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/github/triggers.py @@ -0,0 +1,299 @@ +from typing import Optional + +import requests +from pydantic import BaseModel + +from backend.data.block import ( + Block, + BlockCategory, + BlockInput, + BlockOutput, + BlockSchema, +) +from backend.data.model import SchemaField +from backend.integrations.creds_manager import IntegrationCredentialsManager + +from ._auth import ( + TEST_CREDENTIALS, + TEST_CREDENTIALS_INPUT, + GithubCredentials, + GithubCredentialsField, + GithubCredentialsInput, +) + +creds_manager = IntegrationCredentialsManager() + + +class GitHubBaseTriggerBlock(Block): + class Input(BlockSchema): + credentials: GithubCredentialsInput = GithubCredentialsField("repo") + repo: str = SchemaField( + description="Repository to subscribe to", + placeholder="{owner}/{repo}", + ) + payload: dict = SchemaField(description="Webhook payload", exclude=True) + + class Output(BlockSchema): + event: str = SchemaField(description="The event that triggered the webhook") + payload: dict = SchemaField(description="Full payload of the event") + sender: dict = SchemaField( + description="Object representing the user who triggered the event" + ) + error: str = SchemaField( + description="Error message if the payload could not be processed" + ) + + @classmethod + def create_webhook( + cls, credentials: GithubCredentials, repo: str, events: list[str] + ): + # TODO: Create webhook in DB + + # Create webhook on GitHub + api_url = f"https://api.github.com/repos/{repo}/hooks" + headers = { + "Authorization": credentials.bearer(), + "Accept": "application/vnd.github.v3+json", + } + payload = { + "name": "web", + "active": True, + "events": events, + "config": { + "url": "YOUR_WEBHOOK_URL", # Replace with actual webhook URL + "content_type": "json", + "insecure_ssl": "0", + }, + } + response = requests.post(api_url, headers=headers, json=payload) + response.raise_for_status() + + @classmethod + def update_webhook( + cls, + credentials: GithubCredentials, + repo: str, + events: list[str], + webhook_id: str, + ): + # TODO: Update webhook in DB + + # Update webhook on GitHub + api_url = f"https://api.github.com/repos/{repo}/hooks/{webhook_id}" + headers = { + "Authorization": credentials.bearer(), + "Accept": "application/vnd.github.v3+json", + } + payload = { + "active": True, + "events": events, + "config": { + "url": "YOUR_WEBHOOK_URL", # Replace with actual webhook URL + "content_type": "json", + "insecure_ssl": "0", + }, + } + response = requests.patch(api_url, headers=headers, json=payload) + response.raise_for_status() + + @classmethod + def delete_webhook(cls, credentials: GithubCredentials, repo: str, webhook_id: str): + # TODO: Delete webhook from DB + + # Delete webhook from GitHub + api_url = 
f"https://api.github.com/repos/{repo}/hooks/{webhook_id}" + headers = { + "Authorization": credentials.bearer(), + "Accept": "application/vnd.github.v3+json", + } + response = requests.delete(api_url, headers=headers) + response.raise_for_status() + + +class GithubPullRequestTriggerBlock(GitHubBaseTriggerBlock): + class Input(GitHubBaseTriggerBlock.Input): + class EventsFilter(BaseModel): + """ + https://docs.github.com/en/webhooks/webhook-events-and-payloads#pull_request + """ + + opened: bool = False + edited: bool = False + closed: bool = False + reopened: bool = False + synchronize: bool = False + assigned: bool = False + unassigned: bool = False + labeled: bool = False + unlabeled: bool = False + converted_to_draft: bool = False + locked: bool = False + unlocked: bool = False + enqueued: bool = False + dequeued: bool = False + milestoned: bool = False + demilestoned: bool = False + ready_for_review: bool = False + review_requested: bool = False + review_request_removed: bool = False + auto_merge_enabled: bool = False + auto_merge_disabled: bool = False + + events: EventsFilter = SchemaField(description="The events to subscribe to") + + class Output(GitHubBaseTriggerBlock.Output): + number: int = SchemaField(description="The number of the affected pull request") + pull_request: dict = SchemaField( + description="Object representing the pull request" + ) + + def __init__(self): + super().__init__( + id="6c60ec01-8128-419e-988f-96a063ee2fea", + description="This block triggers on pull request events and outputs the event type and payload.", + categories={BlockCategory.DEVELOPER_TOOLS, BlockCategory.INPUT}, + input_schema=GithubPullRequestTriggerBlock.Input, + output_schema=GithubPullRequestTriggerBlock.Output, + test_input={ + "repo": "owner/repo", + "events": {"opened": True, "synchronize": True}, + "credentials": TEST_CREDENTIALS_INPUT, + }, + test_credentials=TEST_CREDENTIALS, + test_output=[ + # ("title", "Title of the pull request"), + # ("body", "This is the body of the pull request."), + # ("author", "username"), + # ("changes", "List of changes made in the pull request."), + ], + test_mock={ + # "read_pr": lambda *args, **kwargs: ( + # "Title of the pull request", + # "This is the body of the pull request.", + # "username", + # ), + # "read_pr_changes": lambda *args, **kwargs: "List of changes made in the pull request.", + }, + ) + + @staticmethod + def read_pr_changes(credentials: GithubCredentials, pr_url: str) -> str: + api_url = ( + pr_url.replace("github.com", "api.github.com/repos").replace( + "/pull/", "/pulls/" + ) + + "/files" + ) + + headers = { + "Authorization": credentials.bearer(), + "Accept": "application/vnd.github.v3+json", + } + + response = requests.get(api_url, headers=headers) + response.raise_for_status() + + files = response.json() + changes = [] + for file in files: + filename = file.get("filename") + patch = file.get("patch") + if filename and patch: + changes.append(f"File: {filename}\n{patch}") + + return "\n\n".join(changes) + + def run( + self, + input_data: Input, + **kwargs, + ) -> BlockOutput: + # title, body, author = self.read_pr( + # credentials, + # input_data.pr_url, + # ) + # yield "title", title + # yield "body", body + # yield "author", author + + # if input_data.include_pr_changes: + # changes = self.read_pr_changes( + # credentials, + # input_data.pr_url, + # ) + # yield "changes", changes + yield "payload", input_data.payload + yield "event", input_data.payload["action"] + yield "sender", input_data.payload["sender"] + yield "number", 
input_data.payload["number"] + yield "pull_request", input_data.payload["pull_request"] + + def on_node_update( + self, + new_preset_inputs: BlockInput, + old_preset_inputs: Optional[BlockInput] = None, + *, + new_credentials: Optional[GithubCredentials] = None, + old_credentials: Optional[GithubCredentials] = None, + ) -> None: + old_has_all = ( + old_credentials + and old_preset_inputs + and all(key in old_preset_inputs for key in ["repo", "events"]) + ) + new_has_all = new_credentials and all( + key in new_preset_inputs for key in ["repo", "events"] + ) + + if new_has_all and old_has_all: + # Input was and is complete -> update webhook to new config + + # Pyright doesn't get that old_has_all can be used for type narrowing here + assert old_preset_inputs + assert old_credentials is not None + assert new_credentials is not None + + # TODO: Get webhook_id from DB + webhook_id = "WEBHOOK_ID" # Replace with actual webhook ID + + if new_credentials != old_credentials: + # Credentials were replaced -> recreate webhook with new credentials + self.delete_webhook( + old_credentials, + old_preset_inputs["repo"], + webhook_id, + ) + self.create_webhook( + new_credentials, + new_preset_inputs["repo"], + new_preset_inputs["events"], + ) + else: + self.update_webhook( + new_credentials, + new_preset_inputs["repo"], + new_preset_inputs["events"], + webhook_id, + ) + elif new_has_all and not old_has_all: + # Input was incomplete -> create new webhook + assert new_credentials is not None + self.create_webhook( + new_credentials, + new_preset_inputs["repo"], + new_preset_inputs["events"], + ) + elif not new_has_all and old_has_all: + # Input has become incomplete -> delete webhook + + assert old_preset_inputs + assert old_credentials is not None + + # TODO: Get webhook_id from DB + webhook_id = "WEBHOOK_ID" # Replace with actual webhook ID + + self.delete_webhook( + old_credentials, + old_preset_inputs["repo"], + webhook_id, + ) diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index a94a02f6e212..890e730cd319 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -282,11 +282,16 @@ def on_node_update( self, new_preset_inputs: BlockInput, old_preset_inputs: Optional[BlockInput] = None, + *, + new_credentials: Optional[Credentials] = None, + old_credentials: Optional[Credentials] = None, ) -> None: """Hook to be called when the preset inputs change or the block is created""" pass - def on_node_delete(self, preset_inputs: BlockInput) -> None: + def on_node_delete( + self, preset_inputs: BlockInput, *, credentials: Optional[Credentials] = None + ) -> None: """Hook to be called when the block is deleted""" pass diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 7dc8d85986e3..4bd3e563153c 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -2,7 +2,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional import prisma.types from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink @@ -14,8 +14,12 @@ from backend.data.block import BlockInput, get_block, get_blocks from backend.data.db import BaseDbModel, transaction from backend.data.execution import ExecutionStatus +from backend.data.model import CREDENTIALS_FIELD_NAME 
 from backend.util import json
 
+if TYPE_CHECKING:
+    from autogpt_libs.supabase_integration_credentials_store.types import Credentials
+
 logger = logging.getLogger(__name__)
 
 
@@ -282,19 +286,37 @@ def is_input_output_block(nid: str) -> bool:
 
         # TODO: Add type compatibility check here.
 
-    def on_update(self, previous_graph_version: Optional["Graph"] = None):
+    def on_update(
+        self,
+        previous_graph_version: Optional["Graph"] = None,
+        *,
+        get_credentials: Callable[[str], Credentials | None],
+    ):
         """
-        Hook for graph creation/updation.
+        Hook for graph creation/update.
 
         Compares nodes and their preset inputs with a previous graph version, and
         calls the `on_node_update` and `on_node_delete` hooks of the corresponding
         blocks where applicable.
+
+        Params:
+            previous_graph_version: The previous graph version to compare to
+            get_credentials: `credentials_id` -> Credentials
         """
         # Compare nodes in new_graph_version with previous_graph_version
         for new_node in self.nodes:
             if not (block := get_block(new_node.block_id)):
                 raise ValueError(f"Block #{new_node.block_id} not found")
 
+            new_credentials = None
+            if creds_meta := new_node.input_default.get(CREDENTIALS_FIELD_NAME):
+                new_credentials = get_credentials(creds_meta["id"])
+                if not new_credentials:
+                    raise ValueError(
+                        f"Node #{new_node.id} updated with non-existent "
+                        f"credentials #{creds_meta['id']}"
+                    )
+
             if previous_graph_version and (
                 old_node := next(
                     (
@@ -305,15 +327,30 @@ def on_update(self, previous_graph_version: Optional["Graph"] = None):
                     None,
                 )
             ):
+                old_credentials = None
+                if creds_meta := old_node.input_default.get(
+                    CREDENTIALS_FIELD_NAME
+                ):
+                    old_credentials = get_credentials(creds_meta["id"])
+                    if not old_credentials:
+                        logger.error(
+                            f"Node #{old_node.id} referenced non-existent "
+                            f"credentials #{creds_meta['id']}"
+                        )
+
                 if new_node.input_default != old_node.input_default:
                     # Input default has changed, call on_node_update
                     block.on_node_update(
                         new_node.input_default,
                         old_node.input_default,
+                        new_credentials=new_credentials,
+                        old_credentials=old_credentials,
                     )
             else:
                 # New node added, call on_node_update with only new inputs
-                block.on_node_update(new_node.input_default)
+                block.on_node_update(
+                    new_node.input_default, new_credentials=new_credentials
+                )
 
         if previous_graph_version:
             # Check for deleted nodes
@@ -321,11 +358,24 @@ def on_update(self, previous_graph_version: Optional["Graph"] = None):
                 if not any(node.id == old_node.id for node in self.nodes):
                     # Node was deleted, call on_node_delete
                     if block := get_block(old_node.block_id):
-                        block.on_node_delete(preset_inputs=old_node.input_default)
+                        credentials = None
+                        if creds_meta := old_node.input_default.get(
+                            CREDENTIALS_FIELD_NAME
+                        ):
+                            credentials = get_credentials(creds_meta["id"])
+                            if not credentials:
+                                logger.error(
+                                    f"Node #{old_node.id} referenced non-existent "
+                                    f"credentials #{creds_meta['id']}"
+                                )
+                        block.on_node_delete(
+                            preset_inputs=old_node.input_default,
+                            credentials=credentials,
+                        )
                     else:
                         logger.warning(
                             f"Can not handle node #{old_node.id} deletion: "
-                            f"block #{new_node.block_id} not found"
+                            f"block #{old_node.block_id} not found"
                         )
 
     def get_input_schema(self) -> list[InputSchemaItem]:
diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py
index 415c2eaf4f87..6cdec271be41 100644
--- a/autogpt_platform/backend/backend/server/rest_api.py
+++ b/autogpt_platform/backend/backend/server/rest_api.py
@@ -397,17 +397,15 @@ async def get_graph_all_versions(
         raise HTTPException(status_code=404, detail=f"Graph 
#{graph_id} not found.") return graphs - @classmethod async def create_new_graph( - cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)] + self, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)] ) -> graph_db.Graph: - return await cls.create_graph(create_graph, is_template=False, user_id=user_id) + return await self.create_graph(create_graph, is_template=False, user_id=user_id) - @classmethod async def create_new_template( - cls, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)] + self, create_graph: CreateGraph, user_id: Annotated[str, Depends(get_user_id)] ) -> graph_db.Graph: - return await cls.create_graph(create_graph, is_template=True, user_id=user_id) + return await self.create_graph(create_graph, is_template=True, user_id=user_id) class DeleteGraphResponse(TypedDict): version_counts: int @@ -420,9 +418,8 @@ async def delete_graph( "version_counts": await graph_db.delete_graph(graph_id, user_id=user_id) } - @classmethod async def create_graph( - cls, + self, create_graph: CreateGraph, is_template: bool, # user_id doesn't have to be annotated like on other endpoints, @@ -454,12 +451,13 @@ async def create_graph( graph.reassign_ids(reassign_graph_id=True) graph = await graph_db.create_graph(graph, user_id=user_id) - graph.on_update() + graph.on_update( + get_credentials=lambda id: self.integration_creds_manager.get(user_id, id) + ) return graph - @classmethod async def update_graph( - cls, + self, graph_id: str, graph: graph_db.Graph, user_id: Annotated[str, Depends(get_user_id)], @@ -488,7 +486,10 @@ async def update_graph( graph.reassign_ids() new_graph_version = await graph_db.create_graph(graph, user_id=user_id) - new_graph_version.on_update(latest_version_graph) + new_graph_version.on_update( + previous_graph_version=latest_version_graph, + get_credentials=lambda id: self.integration_creds_manager.get(user_id, id), + ) if new_graph_version.is_active: # Ensure new version is the only active version
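
Note on the new hook wiring: `Graph.on_update` now takes a required keyword-only `get_credentials` callable that maps a `credentials_id` to a `Credentials` object (or `None`), so the `on_node_update`/`on_node_delete` block hooks receive resolved credentials rather than bare IDs. A minimal sketch of a caller outside `rest_api.py`, assuming an `IntegrationCredentialsManager` with the `get(user_id, credentials_id)` method used above (the `save_graph` helper itself is hypothetical, not part of this patch):

    from backend.data import graph as graph_db
    from backend.integrations.creds_manager import IntegrationCredentialsManager

    creds_manager = IntegrationCredentialsManager()

    async def save_graph(graph: graph_db.Graph, user_id: str) -> graph_db.Graph:
        # Persist the new graph version, then fire the update hooks with a
        # user-scoped credentials resolver so blocks can (re)create webhooks.
        created = await graph_db.create_graph(graph, user_id=user_id)
        created.on_update(
            get_credentials=lambda creds_id: creds_manager.get(user_id, creds_id)
        )
        return created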
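
Reviewer note on the events plumbing: `create_webhook`/`update_webhook` expect `events: list[str]`, while `GithubPullRequestTriggerBlock.on_node_update` passes `new_preset_inputs["events"]`, i.e. the serialized `EventsFilter` (a mapping of flag names to booleans). One possible conversion, sketched under the assumption that the filter arrives as a plain dict; note that GitHub webhooks subscribe to event types (e.g. "pull_request"), so the per-action flags would more likely be applied when filtering incoming payloads:

    def enabled_flags(events_filter: dict[str, bool]) -> list[str]:
        # e.g. {"opened": True, "synchronize": False} -> ["opened"]
        return [name for name, on in events_filter.items() if on]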