Skip to content

Commit

Permalink
✅ ♻️ Put all call responsability to call graph
Browse files Browse the repository at this point in the history
Test that multiple_call to same feature condition calls it only once and records all calls order
Distinguish call_order from exploration_order in call graph
  • Loading branch information
MathisFederico committed Jan 30, 2024
1 parent cc11b6c commit fd558c0
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 94 deletions.
115 changes: 89 additions & 26 deletions src/hebg/call_graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from enum import Enum
from re import S
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, TypeVar, Union
from matplotlib.axes import Axes

from networkx import (
Expand All @@ -10,24 +9,73 @@
draw_networkx_nodes,
)
import numpy as np
from hebg.behavior import Behavior
from hebg.graph import get_successors_with_index
from hebg.node import FeatureCondition, Node

from hebg.node import Node
if TYPE_CHECKING:
from hebg.heb_graph import HEBGraph


class CallEdgeStatus(Enum):
UNEXPLORED = "unexplored"
CALLED = "called"
FAILURE = "failure"
Action = TypeVar("Action")


class CallGraph(DiGraph):
def __init__(self, initial_node: Node, **attr):
def __init__(self, initial_node: "Node", **attr):
super().__init__(incoming_graph_data=None, **attr)
self.graph["n_calls"] = 0
self.graph["frontiere"] = []
self.add_node(initial_node.name, order=0)
self._known_fc: Dict[FeatureCondition, Any] = {}
self.add_node(initial_node.name, exploration_order=0, calls_order=[0])

def call_nodes(
self,
nodes: List["Node"],
observation,
hebgraph: "HEBGraph",
parent: "Node" = None,
) -> Action:
self._extend_frontiere(nodes, parent)
next_node = self._pop_from_frontiere(parent)
if next_node is None:
raise ValueError("No valid frontiere left in call_graph")
return self._call_node(next_node, observation, hebgraph)

def _call_node(
self,
node: "Node",
observation: Any,
hebgraph: "HEBGraph",
) -> Action:
if node.type == "behavior":
# Search for name reference in all_behaviors
if node.name in hebgraph.all_behaviors:
node = hebgraph.all_behaviors[node.name]
return node(observation, self)
elif node.type == "action":
return node(observation)
elif node.type == "feature_condition":
if node in self._known_fc:
next_edge_index = self._known_fc[node]
else:
next_edge_index = int(node(observation))
self._known_fc[node] = next_edge_index
next_nodes = get_successors_with_index(hebgraph, node, next_edge_index)
elif node.type == "empty":
next_nodes = list(hebgraph.successors(node))
else:
raise ValueError(
f"Unknowed value {node.type} for node.type with node: {node}."
)

return self.call_nodes(
next_nodes,
observation,
hebgraph=hebgraph,
parent=node,
)

def extend_frontiere(self, nodes: List[Node], parent: Node):
frontiere: List[Node] = self.graph["frontiere"]
def _extend_frontiere(self, nodes: List["Node"], parent: "Node"):
frontiere: List["Node"] = self.graph["frontiere"]
frontiere.extend(nodes)

for node in nodes:
Expand All @@ -36,11 +84,11 @@ def extend_frontiere(self, nodes: List[Node], parent: Node):
)
node_data = self.nodes[node.name]
parent_data = self.nodes[parent.name]
if "order" not in node_data:
node_data["order"] = parent_data["order"] + 1
if "exploration_order" not in node_data:
node_data["exploration_order"] = parent_data["exploration_order"] + 1

def pop_from_frontiere(self, parent: Node) -> Optional[Node]:
frontiere: List[Node] = self.graph["frontiere"]
def _pop_from_frontiere(self, parent: "Node") -> Optional["Node"]:
frontiere: List["Node"] = self.graph["frontiere"]

next_node = None

Expand All @@ -49,17 +97,26 @@ def pop_from_frontiere(self, parent: Node) -> Optional[Node]:
return None
_next_node = frontiere.pop(np.argmin([node.cost for node in frontiere]))

if len(list(self.successors(_next_node))) > 0:
self.update_edge_status(parent, _next_node, CallEdgeStatus.FAILURE)
if (
isinstance(_next_node, Behavior)
and len(list(self.successors(_next_node))) > 0
):
self._update_edge_status(parent, _next_node, CallEdgeStatus.FAILURE)
continue

self.update_edge_status(parent, _next_node, CallEdgeStatus.CALLED)
next_node = _next_node

self.graph["n_calls"] += 1
calls_order = self.nodes[next_node.name].get("calls_order", None)
if calls_order is None:
calls_order = []
calls_order.append(self.graph["n_calls"])
self.nodes[next_node.name]["calls_order"] = calls_order
self._update_edge_status(parent, next_node, CallEdgeStatus.CALLED)
return next_node

def update_edge_status(
self, start: Node, end: Node, status: Union[CallEdgeStatus, str]
def _update_edge_status(
self, start: "Node", end: "Node", status: Union["CallEdgeStatus", str]
):
status = CallEdgeStatus(status)
self.edges[start.name, end.name]["status"] = status.value
Expand All @@ -73,7 +130,7 @@ def draw(
edges_kwargs: Optional[dict] = None,
):
if pos is None:
pos = call_graph_pos(self)
pos = _call_graph_pos(self)
if nodes_kwargs is None:
nodes_kwargs = {}
draw_networkx_nodes(self, ax=ax, pos=pos, **nodes_kwargs)
Expand All @@ -89,14 +146,20 @@ def draw(
ax=ax,
pos=pos,
edge_color=[
call_status_to_color(status)
_call_status_to_color(status)
for _, _, status in self.edges(data="status")
],
**edges_kwargs,
)


def call_status_to_color(status: Union[str, CallEdgeStatus]):
class CallEdgeStatus(Enum):
UNEXPLORED = "unexplored"
CALLED = "called"
FAILURE = "failure"


def _call_status_to_color(status: Union[str, "CallEdgeStatus"]):
status = CallEdgeStatus(status)
if status is CallEdgeStatus.UNEXPLORED:
return "black"
Expand All @@ -107,11 +170,11 @@ def call_status_to_color(status: Union[str, CallEdgeStatus]):
raise NotImplementedError


def call_graph_pos(call_graph: DiGraph) -> Dict[str, Tuple[float, float]]:
def _call_graph_pos(call_graph: DiGraph) -> Dict[str, Tuple[float, float]]:
pos = {}
amount_by_order = {}
for node, node_data in call_graph.nodes(data=True):
order: int = node_data["order"]
order: int = node_data["exploration_order"]
if order not in amount_by_order:
amount_by_order[order] = 0
else:
Expand Down
67 changes: 5 additions & 62 deletions src/hebg/heb_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,20 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple, TypeVar
from typing import Any, Dict, List, Optional, Tuple

from matplotlib.axes import Axes
from networkx import DiGraph

from hebg.behavior import Behavior
from hebg.call_graph import CallEdgeStatus, CallGraph
from hebg.call_graph import CallGraph
from hebg.codegen import get_hebg_source
from hebg.draw import draw_hebgraph
from hebg.graph import get_roots, get_successors_with_index
from hebg.graph import get_roots
from hebg.node import Node
from hebg.unrolling import unroll_graph


Action = TypeVar("Action")


class HEBGraph(DiGraph):

"""Base class for Hierchical Explanation of Behavior as Graphs.
Expand Down Expand Up @@ -125,59 +122,11 @@ def __call__(
) -> Any:
if call_graph is None:
call_graph = CallGraph(initial_node=self.behavior)

self.call_graph = call_graph
return self._split_call_between_nodes(
self.roots, observation, call_graph=call_graph
return self.call_graph.call_nodes(
self.roots, observation, hebgraph=self, parent=self.behavior
)

def _get_action(self, node: Node, observation: Any, call_graph: DiGraph):
# Behavior
if node.type == "behavior":
# Search for name reference in all_behaviors
if node.name in self.all_behaviors:
node = self.all_behaviors[node.name]

return node(observation, call_graph)

# Action
if node.type == "action":
return node(observation)

# Feature Condition
if node.type == "feature_condition":
next_edge_index = int(node(observation))
next_nodes = get_successors_with_index(self, node, next_edge_index)
return self._split_call_between_nodes(
next_nodes, observation, call_graph=call_graph, parent=node
)
# Empty
if node.type == "empty":
return self._split_call_between_nodes(
list(self.successors(node)),
observation,
call_graph=call_graph,
parent=node,
)
raise ValueError(f"Unknowed value {node.type} for node.type with node: {node}.")

def _split_call_between_nodes(
self,
nodes: List[Node],
observation,
call_graph: CallGraph,
parent: Optional[Node] = None,
) -> List[Action]:
if parent is None:
parent = self.behavior

call_graph.extend_frontiere(nodes, parent)
next_node = call_graph.pop_from_frontiere(parent)
if next_node is None:
raise ValueError("No valid frontiere left in call_graph")
action = self._get_action(next_node, observation, call_graph)
return action

@property
def roots(self) -> List[Node]:
"""Roots of the behavior graph (nodes without predecessors)."""
Expand All @@ -203,9 +152,3 @@ def draw(
"""
return draw_hebgraph(self, ax, **kwargs)


def remove_duplicate_actions(actions: List[Action]) -> List[Action]:
seen = set()
seen_add = seen.add
return [a for a in actions if not (a in seen or seen_add(a))]
Loading

0 comments on commit fd558c0

Please sign in to comment.