From 857ae69184675b9ad2c60dfc0f3c3e6eb7a9777d Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Thu, 17 Oct 2024 00:09:44 +0200 Subject: [PATCH] Add `Graph.on_update`, `Block.on_node_update`, `Block.on_node_delete` hooks --- .../backend/backend/data/block.py | 12 +++++ .../backend/backend/data/graph.py | 48 ++++++++++++++++++- .../backend/backend/server/rest_api.py | 5 +- 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 594fd10e7681..a94a02f6e212 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -278,6 +278,18 @@ def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput: raise ValueError(f"Block produced an invalid output data: {error}") yield output_name, output_data + def on_node_update( + self, + new_preset_inputs: BlockInput, + old_preset_inputs: Optional[BlockInput] = None, + ) -> None: + """Hook to be called when the preset inputs change or the block is created""" + pass + + def on_node_delete(self, preset_inputs: BlockInput) -> None: + """Hook to be called when the block is deleted""" + pass + # ======================= Block Helper Functions ======================= # diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 4f1be1de1ed8..7dc8d85986e3 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -2,7 +2,7 @@ import logging import uuid from datetime import datetime, timezone -from typing import Any, Literal +from typing import Any, Literal, Optional import prisma.types from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink @@ -282,6 +282,52 @@ def is_input_output_block(nid: str) -> bool: # TODO: Add type compatibility check here. + def on_update(self, previous_graph_version: Optional["Graph"] = None): + """ + Hook for graph creation/updation. + + Compares nodes and their preset inputs with a previous graph version, and calls + the `on_node_update` and `on_node_delete` hooks of the corresponding blocks + where applicable. + """ + # Compare nodes in new_graph_version with previous_graph_version + for new_node in self.nodes: + if not (block := get_block(new_node.block_id)): + raise ValueError(f"Block #{new_node.block_id} not found") + + if previous_graph_version and ( + old_node := next( + ( + node + for node in previous_graph_version.nodes + if node.id == new_node.id + ), + None, + ) + ): + if new_node.input_default != old_node.input_default: + # Input default has changed, call on_node_update + block.on_node_update( + new_node.input_default, + old_node.input_default, + ) + else: + # New node added, call on_node_update with only new inputs + block.on_node_update(new_node.input_default) + + if previous_graph_version: + # Check for deleted nodes + for old_node in previous_graph_version.nodes: + if not any(node.id == old_node.id for node in self.nodes): + # Node was deleted, call on_node_delete + if block := get_block(old_node.block_id): + block.on_node_delete(preset_inputs=old_node.input_default) + else: + logger.warning( + f"Can not handle node #{old_node.id} deletion: " + f"block #{new_node.block_id} not found" + ) + def get_input_schema(self) -> list[InputSchemaItem]: """ Walks the graph and returns all the inputs that are either not: diff --git a/autogpt_platform/backend/backend/server/rest_api.py b/autogpt_platform/backend/backend/server/rest_api.py index 6860f3f60fee..415c2eaf4f87 100644 --- a/autogpt_platform/backend/backend/server/rest_api.py +++ b/autogpt_platform/backend/backend/server/rest_api.py @@ -453,7 +453,9 @@ async def create_graph( graph.is_active = not is_template graph.reassign_ids(reassign_graph_id=True) - return await graph_db.create_graph(graph, user_id=user_id) + graph = await graph_db.create_graph(graph, user_id=user_id) + graph.on_update() + return graph @classmethod async def update_graph( @@ -486,6 +488,7 @@ async def update_graph( graph.reassign_ids() new_graph_version = await graph_db.create_graph(graph, user_id=user_id) + new_graph_version.on_update(latest_version_graph) if new_graph_version.is_active: # Ensure new version is the only active version