Skip to content

Commit

Permalink
Add Graph.on_update, Block.on_node_update, Block.on_node_delete
Browse files Browse the repository at this point in the history
… hooks
  • Loading branch information
Pwuts committed Oct 16, 2024
1 parent d6d2820 commit 857ae69
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
12 changes: 12 additions & 0 deletions autogpt_platform/backend/backend/data/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,18 @@ def execute(self, input_data: BlockInput, **kwargs) -> BlockOutput:
raise ValueError(f"Block produced an invalid output data: {error}")
yield output_name, output_data

def on_node_update(
self,
new_preset_inputs: BlockInput,
old_preset_inputs: Optional[BlockInput] = None,
) -> None:
"""Hook to be called when the preset inputs change or the block is created"""
pass

def on_node_delete(self, preset_inputs: BlockInput) -> None:
"""Hook to be called when the block is deleted"""
pass


# ======================= Block Helper Functions ======================= #

Expand Down
48 changes: 47 additions & 1 deletion autogpt_platform/backend/backend/data/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Literal
from typing import Any, Literal, Optional

import prisma.types
from prisma.models import AgentGraph, AgentGraphExecution, AgentNode, AgentNodeLink
Expand Down Expand Up @@ -282,6 +282,52 @@ def is_input_output_block(nid: str) -> bool:

# TODO: Add type compatibility check here.

def on_update(self, previous_graph_version: Optional["Graph"] = None):
"""
Hook for graph creation/updation.
Compares nodes and their preset inputs with a previous graph version, and calls
the `on_node_update` and `on_node_delete` hooks of the corresponding blocks
where applicable.
"""
# Compare nodes in new_graph_version with previous_graph_version
for new_node in self.nodes:
if not (block := get_block(new_node.block_id)):
raise ValueError(f"Block #{new_node.block_id} not found")

if previous_graph_version and (
old_node := next(
(
node
for node in previous_graph_version.nodes
if node.id == new_node.id
),
None,
)
):
if new_node.input_default != old_node.input_default:
# Input default has changed, call on_node_update
block.on_node_update(
new_node.input_default,
old_node.input_default,
)
else:
# New node added, call on_node_update with only new inputs
block.on_node_update(new_node.input_default)

if previous_graph_version:
# Check for deleted nodes
for old_node in previous_graph_version.nodes:
if not any(node.id == old_node.id for node in self.nodes):
# Node was deleted, call on_node_delete
if block := get_block(old_node.block_id):
block.on_node_delete(preset_inputs=old_node.input_default)
else:
logger.warning(
f"Can not handle node #{old_node.id} deletion: "
f"block #{new_node.block_id} not found"
)

def get_input_schema(self) -> list[InputSchemaItem]:
"""
Walks the graph and returns all the inputs that are either not:
Expand Down
5 changes: 4 additions & 1 deletion autogpt_platform/backend/backend/server/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,9 @@ async def create_graph(
graph.is_active = not is_template
graph.reassign_ids(reassign_graph_id=True)

return await graph_db.create_graph(graph, user_id=user_id)
graph = await graph_db.create_graph(graph, user_id=user_id)
graph.on_update()
return graph

@classmethod
async def update_graph(
Expand Down Expand Up @@ -486,6 +488,7 @@ async def update_graph(
graph.reassign_ids()

new_graph_version = await graph_db.create_graph(graph, user_id=user_id)
new_graph_version.on_update(latest_version_graph)

if new_graph_version.is_active:
# Ensure new version is the only active version
Expand Down

0 comments on commit 857ae69

Please sign in to comment.