diff --git a/src/hebg/call_graph.py b/src/hebg/call_graph.py index 1be9c94..67d2913 100644 --- a/src/hebg/call_graph.py +++ b/src/hebg/call_graph.py @@ -1,6 +1,5 @@ from enum import Enum -from re import S -from typing import Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar, Union from matplotlib.axes import Axes from networkx import ( @@ -10,24 +9,73 @@ draw_networkx_nodes, ) import numpy as np +from hebg.behavior import Behavior +from hebg.graph import get_successors_with_index +from hebg.node import FeatureCondition, Node -from hebg.node import Node +if TYPE_CHECKING: + from hebg.heb_graph import HEBGraph - -class CallEdgeStatus(Enum): - UNEXPLORED = "unexplored" - CALLED = "called" - FAILURE = "failure" +Action = TypeVar("Action") class CallGraph(DiGraph): - def __init__(self, initial_node: Node, **attr): + def __init__(self, initial_node: "Node", **attr): super().__init__(incoming_graph_data=None, **attr) + self.graph["n_calls"] = 0 self.graph["frontiere"] = [] - self.add_node(initial_node.name, order=0) + self._known_fc: Dict[FeatureCondition, Any] = {} + self.add_node(initial_node.name, exploration_order=0, calls_order=[0]) + + def call_nodes( + self, + nodes: List["Node"], + observation, + hebgraph: "HEBGraph", + parent: "Node" = None, + ) -> Action: + self._extend_frontiere(nodes, parent) + next_node = self._pop_from_frontiere(parent) + if next_node is None: + raise ValueError("No valid frontiere left in call_graph") + return self._call_node(next_node, observation, hebgraph) + + def _call_node( + self, + node: "Node", + observation: Any, + hebgraph: "HEBGraph", + ) -> Action: + if node.type == "behavior": + # Search for name reference in all_behaviors + if node.name in hebgraph.all_behaviors: + node = hebgraph.all_behaviors[node.name] + return node(observation, self) + elif node.type == "action": + return node(observation) + elif node.type == "feature_condition": + if node in self._known_fc: + next_edge_index = self._known_fc[node] + else: + next_edge_index = int(node(observation)) + self._known_fc[node] = next_edge_index + next_nodes = get_successors_with_index(hebgraph, node, next_edge_index) + elif node.type == "empty": + next_nodes = list(hebgraph.successors(node)) + else: + raise ValueError( + f"Unknowed value {node.type} for node.type with node: {node}." + ) + + return self.call_nodes( + next_nodes, + observation, + hebgraph=hebgraph, + parent=node, + ) - def extend_frontiere(self, nodes: List[Node], parent: Node): - frontiere: List[Node] = self.graph["frontiere"] + def _extend_frontiere(self, nodes: List["Node"], parent: "Node"): + frontiere: List["Node"] = self.graph["frontiere"] frontiere.extend(nodes) for node in nodes: @@ -36,11 +84,11 @@ def extend_frontiere(self, nodes: List[Node], parent: Node): ) node_data = self.nodes[node.name] parent_data = self.nodes[parent.name] - if "order" not in node_data: - node_data["order"] = parent_data["order"] + 1 + if "exploration_order" not in node_data: + node_data["exploration_order"] = parent_data["exploration_order"] + 1 - def pop_from_frontiere(self, parent: Node) -> Optional[Node]: - frontiere: List[Node] = self.graph["frontiere"] + def _pop_from_frontiere(self, parent: "Node") -> Optional["Node"]: + frontiere: List["Node"] = self.graph["frontiere"] next_node = None @@ -49,17 +97,26 @@ def pop_from_frontiere(self, parent: Node) -> Optional[Node]: return None _next_node = frontiere.pop(np.argmin([node.cost for node in frontiere])) - if len(list(self.successors(_next_node))) > 0: - self.update_edge_status(parent, _next_node, CallEdgeStatus.FAILURE) + if ( + isinstance(_next_node, Behavior) + and len(list(self.successors(_next_node))) > 0 + ): + self._update_edge_status(parent, _next_node, CallEdgeStatus.FAILURE) continue - self.update_edge_status(parent, _next_node, CallEdgeStatus.CALLED) next_node = _next_node + self.graph["n_calls"] += 1 + calls_order = self.nodes[next_node.name].get("calls_order", None) + if calls_order is None: + calls_order = [] + calls_order.append(self.graph["n_calls"]) + self.nodes[next_node.name]["calls_order"] = calls_order + self._update_edge_status(parent, next_node, CallEdgeStatus.CALLED) return next_node - def update_edge_status( - self, start: Node, end: Node, status: Union[CallEdgeStatus, str] + def _update_edge_status( + self, start: "Node", end: "Node", status: Union["CallEdgeStatus", str] ): status = CallEdgeStatus(status) self.edges[start.name, end.name]["status"] = status.value @@ -73,7 +130,7 @@ def draw( edges_kwargs: Optional[dict] = None, ): if pos is None: - pos = call_graph_pos(self) + pos = _call_graph_pos(self) if nodes_kwargs is None: nodes_kwargs = {} draw_networkx_nodes(self, ax=ax, pos=pos, **nodes_kwargs) @@ -89,14 +146,20 @@ def draw( ax=ax, pos=pos, edge_color=[ - call_status_to_color(status) + _call_status_to_color(status) for _, _, status in self.edges(data="status") ], **edges_kwargs, ) -def call_status_to_color(status: Union[str, CallEdgeStatus]): +class CallEdgeStatus(Enum): + UNEXPLORED = "unexplored" + CALLED = "called" + FAILURE = "failure" + + +def _call_status_to_color(status: Union[str, "CallEdgeStatus"]): status = CallEdgeStatus(status) if status is CallEdgeStatus.UNEXPLORED: return "black" @@ -107,11 +170,11 @@ def call_status_to_color(status: Union[str, CallEdgeStatus]): raise NotImplementedError -def call_graph_pos(call_graph: DiGraph) -> Dict[str, Tuple[float, float]]: +def _call_graph_pos(call_graph: DiGraph) -> Dict[str, Tuple[float, float]]: pos = {} amount_by_order = {} for node, node_data in call_graph.nodes(data=True): - order: int = node_data["order"] + order: int = node_data["exploration_order"] if order not in amount_by_order: amount_by_order[order] = 0 else: diff --git a/src/hebg/heb_graph.py b/src/hebg/heb_graph.py index 61991e3..ec14068 100644 --- a/src/hebg/heb_graph.py +++ b/src/hebg/heb_graph.py @@ -6,23 +6,20 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Dict, List, Optional, Tuple from matplotlib.axes import Axes from networkx import DiGraph from hebg.behavior import Behavior -from hebg.call_graph import CallEdgeStatus, CallGraph +from hebg.call_graph import CallGraph from hebg.codegen import get_hebg_source from hebg.draw import draw_hebgraph -from hebg.graph import get_roots, get_successors_with_index +from hebg.graph import get_roots from hebg.node import Node from hebg.unrolling import unroll_graph -Action = TypeVar("Action") - - class HEBGraph(DiGraph): """Base class for Hierchical Explanation of Behavior as Graphs. @@ -125,59 +122,11 @@ def __call__( ) -> Any: if call_graph is None: call_graph = CallGraph(initial_node=self.behavior) - self.call_graph = call_graph - return self._split_call_between_nodes( - self.roots, observation, call_graph=call_graph + return self.call_graph.call_nodes( + self.roots, observation, hebgraph=self, parent=self.behavior ) - def _get_action(self, node: Node, observation: Any, call_graph: DiGraph): - # Behavior - if node.type == "behavior": - # Search for name reference in all_behaviors - if node.name in self.all_behaviors: - node = self.all_behaviors[node.name] - - return node(observation, call_graph) - - # Action - if node.type == "action": - return node(observation) - - # Feature Condition - if node.type == "feature_condition": - next_edge_index = int(node(observation)) - next_nodes = get_successors_with_index(self, node, next_edge_index) - return self._split_call_between_nodes( - next_nodes, observation, call_graph=call_graph, parent=node - ) - # Empty - if node.type == "empty": - return self._split_call_between_nodes( - list(self.successors(node)), - observation, - call_graph=call_graph, - parent=node, - ) - raise ValueError(f"Unknowed value {node.type} for node.type with node: {node}.") - - def _split_call_between_nodes( - self, - nodes: List[Node], - observation, - call_graph: CallGraph, - parent: Optional[Node] = None, - ) -> List[Action]: - if parent is None: - parent = self.behavior - - call_graph.extend_frontiere(nodes, parent) - next_node = call_graph.pop_from_frontiere(parent) - if next_node is None: - raise ValueError("No valid frontiere left in call_graph") - action = self._get_action(next_node, observation, call_graph) - return action - @property def roots(self) -> List[Node]: """Roots of the behavior graph (nodes without predecessors).""" @@ -203,9 +152,3 @@ def draw( """ return draw_hebgraph(self, ax, **kwargs) - - -def remove_duplicate_actions(actions: List[Action]) -> List[Action]: - seen = set() - seen_add = seen.add - return [a for a in actions if not (a in seen or seen_add(a))] diff --git a/tests/test_call_graph.py b/tests/test_call_graph.py index 8155f20..2790ac2 100644 --- a/tests/test_call_graph.py +++ b/tests/test_call_graph.py @@ -5,6 +5,7 @@ from hebg.node import Action from pytest_mock import MockerFixture +import pytest_check as check from tests import plot_graph @@ -84,8 +85,77 @@ def build_graph(self) -> HEBGraph: ) assert set(call_graph.edges()) == set(expected_graph.edges()) + def test_multiple_call_to_same_fc(self, mocker: MockerFixture): + """Call graph should allow for the same feature condition + to be called multiple times in the same branch (in different behaviors).""" + expected_action = Action("EXPECTED") + unexpected_action = Action("UNEXPECTED") + + feature_condition_call = mocker.patch( + "tests.examples.feature_conditions.ThresholdFeatureCondition.__call__", + return_value=True, + ) + feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) + + class SubBehavior(Behavior): + def __init__(self) -> None: + super().__init__("SubBehavior") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_edge(feature_condition, expected_action, index=int(True)) + graph.add_edge(feature_condition, unexpected_action, index=int(False)) + return graph + + class RootBehavior(Behavior): + + """Feature condition with mutliple actions on same index.""" + + def __init__(self) -> None: + super().__init__("RootBehavior") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_edge(feature_condition, SubBehavior(), index=int(True)) + graph.add_edge(feature_condition, unexpected_action, index=int(False)) + + return graph + + root_behavior = RootBehavior() + draw = False + if draw: + plot_graph(root_behavior.graph.unrolled_graph) + + # Sanity check that the right action should be called and not the forbidden one. + assert root_behavior(observation=2) == expected_action.action + + # Feature condition should only be called once on the same input + assert len(feature_condition_call.call_args_list) == 1 + + # Graph should have the good split + call_graph = root_behavior.graph.call_graph + expected_graph = DiGraph( + [ + ("RootBehavior", "Greater or equal to 0 ?"), + ("Greater or equal to 0 ?", "SubBehavior"), + ("SubBehavior", "Greater or equal to 0 ?"), + ("Greater or equal to 0 ?", "Action(EXPECTED)"), + ] + ) + assert set(call_graph.edges()) == set(expected_graph.edges()) + + expected_calls_order = { + "RootBehavior": [0], + "Greater or equal to 0 ?": [1, 3], + "SubBehavior": [2], + "Action(EXPECTED)": [4], + } + for node, node_calls_order in call_graph.nodes(data="calls_order"): + check.equal(node_calls_order, expected_calls_order[node]) + def test_chain_behaviors(self, mocker: MockerFixture): - """When sub-behaviors are chained they should be in the call graph.""" + """When sub-behaviors with a graph are called recursively, + the call graph should still find their nodes.""" expected_action = "EXPECTED" @@ -107,15 +177,18 @@ def __init__(self) -> None: def build_graph(self) -> HEBGraph: graph = HEBGraph(self) - graph.add_node(SubBehavior()) + graph.add_node(Behavior("SubBehavior")) return graph - f_aa_behavior = RootBehavior() + sub_behavior = SubBehavior() + + root_behavior = RootBehavior() + root_behavior.graph.all_behaviors["SubBehavior"] = sub_behavior # Sanity check that the right action should be called. - assert f_aa_behavior(observation=-1) == expected_action + assert root_behavior(observation=-1) == expected_action - call_graph = f_aa_behavior.graph.call_graph + call_graph = root_behavior.graph.call_graph expected_graph = DiGraph( [ ("RootBehavior", "SubBehavior"), @@ -144,7 +217,10 @@ def test_looping_goback(self): "Action(Punch tree)", ] nodes_by_order = sorted( - [(node, order) for (node, order) in call_graph.nodes(data="order")], + [ + (node, order) + for (node, order) in call_graph.nodes(data="exploration_order") + ], key=lambda x: x[1], ) assert [node for node, _order in nodes_by_order] == expected_order