diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml
index 6a0203f..d0fecef 100644
--- a/.github/workflows/python-tests.yml
+++ b/.github/workflows/python-tests.yml
@@ -8,7 +8,7 @@ jobs:
runs-on: windows-latest
strategy:
matrix:
- python-version: ['3.8', '3.9', '3.10']
+ python-version: ['3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index df0cbe5..d986d22 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,8 +15,8 @@ repos:
hooks:
- id: pytest-fast-check
name: pytest-fast-check
- entry: ./venv/Scripts/python.exe -m pytest -m "not slow"
- stages: ["commit"]
+ entry: pytest -m "not slow"
+ stages: ["pre-commit"]
language: system
pass_filenames: false
always_run: true
@@ -24,8 +24,8 @@ repos:
hooks:
- id: pytest-check
name: pytest-check
- entry: ./venv/Scripts/python.exe -m pytest
- stages: ["push"]
+ entry: pytest
+ stages: ["pre-push"]
language: system
pass_filenames: false
always_run: true
diff --git a/README.rst b/README.rst
index e4f03dc..42bc25c 100644
--- a/README.rst
+++ b/README.rst
@@ -90,7 +90,7 @@ Here is an example to show how could we hierarchicaly build an explanable behavi
def __init__(self, hand) -> None:
super().__init__(name="Is hand near the cat ?")
self.hand = hand
- def __call__(self, observation):
+ def __call__(self, observation) -> int:
# Could be a very complex function that returns 1 is the hand is near the cat else 0.
if observation["cat"] == observation[self.hand]:
return int(True) # 1
@@ -119,7 +119,7 @@ Here is an example to show how could we hierarchicaly build an explanable behavi
class IsThereACatAround(FeatureCondition):
def __init__(self) -> None:
super().__init__(name="Is there a cat around ?")
- def __call__(self, observation):
+ def __call__(self, observation) -> int:
# Could be a very complex function that returns 1 is there is a cat around else 0.
if "cat" in observation:
return int(True) # 1
@@ -217,7 +217,7 @@ Will generate the code bellow:
# Require 'Look for a nearby cat' behavior to be given.
# Require 'Move slowly your hand near the cat' behavior to be given.
class PetTheCat(GeneratedBehavior):
- def __call__(self, observation):
+ def __call__(self, observation) -> Any:
edge_index = self.feature_conditions['Is there a cat around ?'](observation)
if edge_index == 0:
return self.known_behaviors['Look for a nearby cat'](observation)
diff --git a/commands/coverage.ps1 b/commands/coverage.ps1
new file mode 100644
index 0000000..a578912
--- /dev/null
+++ b/commands/coverage.ps1
@@ -0,0 +1 @@
+pytest --cov=src --cov-report=html --cov-report=term
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..b024da8
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,4 @@
+from setuptools import setup
+
+
+setup()
diff --git a/src/hebg/__init__.py b/src/hebg/__init__.py
index 0a5dc99..8fbff98 100644
--- a/src/hebg/__init__.py
+++ b/src/hebg/__init__.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""A structure for explainable hierarchical reinforcement learning"""
diff --git a/src/hebg/behavior.py b/src/hebg/behavior.py
index fc68d8b..aa804b4 100644
--- a/src/hebg/behavior.py
+++ b/src/hebg/behavior.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Module for base Behavior."""
@@ -17,11 +17,11 @@
class Behavior(Node):
"""Abstract class for a Behavior as Node"""
- def __init__(self, name: str, image=None) -> None:
- super().__init__(name, "behavior", image=image)
+ def __init__(self, name: str, image=None, **kwargs) -> None:
+ super().__init__(name, "behavior", image=image, **kwargs)
self._graph = None
- def __call__(self, observation, *args, **kwargs):
+ def __call__(self, observation, *args, **kwargs) -> None:
"""Use the behavior to get next actions.
By default, uses the HEBGraph if it can be built.
diff --git a/src/hebg/call_graph.py b/src/hebg/call_graph.py
new file mode 100644
index 0000000..5c9e261
--- /dev/null
+++ b/src/hebg/call_graph.py
@@ -0,0 +1,315 @@
+from enum import Enum
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+)
+from matplotlib import pyplot as plt
+from matplotlib.axes import Axes
+
+from networkx import (
+ DiGraph,
+ all_simple_paths,
+ draw_networkx_edges,
+ draw_networkx_labels,
+ draw_networkx_nodes,
+ ancestors,
+)
+import numpy as np
+from hebg.behavior import Behavior
+from hebg.graph import get_successors_with_index
+from hebg.node import Action, FeatureCondition, Node
+
+if TYPE_CHECKING:
+ from hebg.heb_graph import HEBGraph
+
+EnvAction = TypeVar("EnvAction")
+
+
+class CallGraph(DiGraph):
+ def __init__(self, **attr) -> None:
+ super().__init__(incoming_graph_data=None, **attr)
+ self.graph["n_branches"] = 0
+ self.graph["n_calls"] = 0
+ self.graph["frontiere"] = []
+ self._known_fc: Dict[FeatureCondition, Any] = {}
+ self._current_node = CallNode(0, 0)
+
+ def add_root(self, heb_node: "Node", heb_graph: "HEBGraph", **kwargs) -> None:
+ self.add_node(
+ self._current_node, heb_node=heb_node, heb_graph=heb_graph, **kwargs
+ )
+
+ def call_nodes(
+ self, nodes: List["Node"], observation, heb_graph: "HEBGraph"
+ ) -> EnvAction:
+ self._extend_frontiere(nodes, heb_graph)
+ action = None
+
+ while len(self.graph["frontiere"]) > 0 and action is None:
+ next_call_node = self._pop_from_frontiere()
+ if next_call_node is None:
+ break
+
+ node: "Node" = self.nodes[next_call_node]["heb_node"]
+ heb_graph: "HEBGraph" = self.nodes[next_call_node]["heb_graph"]
+
+ if node.type == "behavior":
+ # Search for name reference in all_behaviors
+ if node.name in heb_graph.all_behaviors:
+ node = heb_graph.all_behaviors[node.name]
+ if not hasattr(node, "_graph") or node._graph is None:
+ action = node(observation)
+ break
+ self._extend_frontiere(node.graph.roots, heb_graph=node.graph)
+ elif node.type == "action":
+ action = node(observation)
+ break
+ elif node.type == "feature_condition":
+ if node in self._known_fc:
+ next_edge_index = self._known_fc[node]
+ else:
+ next_edge_index = int(node(observation))
+ self._known_fc[node] = next_edge_index
+ next_nodes = get_successors_with_index(heb_graph, node, next_edge_index)
+ self._extend_frontiere(next_nodes, heb_graph)
+ elif node.type == "empty":
+ self._extend_frontiere(list(heb_graph.successors(node)), heb_graph)
+ else:
+ raise ValueError(
+ f"Unknowed value {node.type} for node.type with node: {node}."
+ )
+
+ if action is None:
+ raise ValueError("No valid frontiere left in call_graph")
+
+ return action
+
+ def call_edge_labels(self) -> list[tuple]:
+ return [
+ (self.nodes[u]["label"], self.nodes[v]["label"]) for u, v in self.edges()
+ ]
+
+ def add_node(
+ self, node_for_adding, heb_node: "Node", heb_graph: "HEBGraph", **attr
+ ):
+ super().add_node(
+ node_for_adding,
+ heb_graph=heb_graph,
+ heb_node=heb_node,
+ label=heb_node.name,
+ **attr,
+ )
+
+ def add_edge(
+ self,
+ u_of_edge,
+ v_of_edge,
+ status: "CallEdgeStatus",
+ **attr,
+ ):
+ return super().add_edge(u_of_edge, v_of_edge, status=status.value, **attr)
+
+ def _make_new_branch(self) -> int:
+ self.graph["n_branches"] += 1
+ return self.graph["n_branches"]
+
+ def _extend_frontiere(self, nodes: List["Node"], heb_graph: "HEBGraph") -> None:
+ frontiere: List[CallNode] = self.graph["frontiere"]
+
+ parent = self._current_node
+ call_nodes = []
+
+ for i, node in enumerate(nodes):
+ if i > 0:
+ branch_id = self._make_new_branch()
+ else:
+ branch_id = parent.branch
+ call_node = CallNode(branch_id, parent.rank + 1)
+
+ if node.name in heb_graph.all_behaviors:
+ node = heb_graph.all_behaviors[node.name]
+ self.add_node(call_node, heb_node=node, heb_graph=heb_graph)
+ self.add_edge(parent, call_node, CallEdgeStatus.UNEXPLORED)
+ call_nodes.append(call_node)
+
+ frontiere.extend(call_nodes)
+
+ def _heb_node_from_call_node(self, node: "CallNode") -> "Node":
+ return self.nodes[node]["heb_node"]
+
+ def _pop_from_frontiere(self) -> Optional["CallNode"]:
+ frontiere: List["CallNode"] = self.graph["frontiere"]
+
+ next_node = None
+
+ while next_node is None:
+ if not frontiere:
+ return None
+
+ next_call_node = frontiere.pop(
+ np.argmin(
+ [
+ self._heb_node_from_call_node(node).complexity
+ for node in frontiere
+ ]
+ )
+ )
+ maybe_next_node = self._heb_node_from_call_node(next_call_node)
+ # Nodes should only have one parent
+ parent = list(self.predecessors(next_call_node))[0]
+
+ if isinstance(maybe_next_node, Behavior) and maybe_next_node in [
+ self._heb_node_from_call_node(node)
+ for node in ancestors(self, next_call_node)
+ ]:
+ self._update_edge_status(parent, next_call_node, CallEdgeStatus.FAILURE)
+ continue
+
+ next_node = maybe_next_node
+
+ self.graph["n_calls"] += 1
+ self.nodes[next_call_node]["call_rank"] = 1
+ self._update_edge_status(parent, next_call_node, CallEdgeStatus.CALLED)
+ self._current_node = next_call_node
+ return next_call_node
+
+ def _update_edge_status(
+ self, start: "Node", end: "Node", status: Union["CallEdgeStatus", str]
+ ):
+ status = CallEdgeStatus(status)
+ self.edges[start, end]["status"] = status.value
+
+ def draw(
+ self,
+ ax: Optional[Axes] = None,
+ pos: Optional[Dict[str, Tuple[float, float]]] = None,
+ nodes_kwargs: Optional[dict] = None,
+ label_kwargs: Optional[dict] = None,
+ edges_kwargs: Optional[dict] = None,
+ ):
+ if pos is None:
+ pos = _call_graph_pos(self)
+ if nodes_kwargs is None:
+ nodes_kwargs = {}
+
+ if ax is None:
+ ax = plt.gca()
+
+ pos_arr = np.array(list(pos.values()))
+ max_x, max_y = pos_arr.max(axis=0)
+ min_x, min_y = pos_arr.min(axis=0)
+ y_range = max_y - min_y
+ x_range = max_x - min_x
+ ax.set_ylim([min_y - 0.1 * y_range, max_y + 0.1 * y_range])
+ ax.set_xlim([min_x - 0.1 * x_range, max_x + 0.1 * x_range])
+
+ nodes_complexity = np.array(
+ [node_data["heb_node"].complexity for _, node_data in self.nodes(data=True)]
+ )
+ complexity_range = nodes_complexity.max() - nodes_complexity.min()
+
+ nodes_complexity_scaled = (
+ 50 + 600 * (nodes_complexity - nodes_complexity.min()) / complexity_range
+ )
+
+ draw_networkx_nodes(
+ self,
+ node_color=[
+ _node_color(node_data["heb_node"])
+ for _, node_data in self.nodes(data=True)
+ ],
+ node_size=nodes_complexity_scaled,
+ ax=ax,
+ pos=pos,
+ **nodes_kwargs,
+ )
+ if label_kwargs is None:
+ label_kwargs = {}
+ draw_networkx_labels(
+ self,
+ labels={
+ node: f"{node_data['label']}"
+ for node, node_data in self.nodes(data=True)
+ },
+ ax=ax,
+ horizontalalignment="center",
+ verticalalignment="center",
+ font_size=8,
+ pos=pos,
+ **nodes_kwargs,
+ )
+ if edges_kwargs is None:
+ edges_kwargs = {}
+ if "connectionstyle" not in edges_kwargs:
+ edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=5")
+ draw_networkx_edges(
+ self,
+ ax=ax,
+ pos=pos,
+ arrowstyle="-",
+ alpha=0.5,
+ width=3,
+ node_size=1,
+ edge_color=[
+ _call_status_to_color(status)
+ for _, _, status in self.edges(data="status")
+ ],
+ **edges_kwargs,
+ )
+
+
+class CallNode(NamedTuple):
+ branch: int
+ rank: int
+
+
+class CallEdgeStatus(Enum):
+ UNEXPLORED = "unexplored"
+ CALLED = "called"
+ FAILURE = "failure"
+
+
+def _node_color(node: Union[Action, FeatureCondition, Behavior]) -> str:
+ if isinstance(node, Action):
+ return "red"
+ if isinstance(node, FeatureCondition):
+ return "blue"
+ if isinstance(node, Behavior):
+ return "orange"
+ raise NotImplementedError
+
+
+def _call_status_to_color(status: Union[str, "CallEdgeStatus"]) -> str:
+ status = CallEdgeStatus(status)
+ if status is CallEdgeStatus.UNEXPLORED:
+ return "black"
+ if status is CallEdgeStatus.CALLED:
+ return "green"
+ if status is CallEdgeStatus.FAILURE:
+ return "red"
+ raise NotImplementedError
+
+
+def _call_graph_pos(call_graph: "CallGraph") -> Dict[str, Tuple[float, float]]:
+ pos = {}
+
+ roots = [n for (n, d) in call_graph.in_degree if d == 0]
+ leafs = [n for (n, d) in call_graph.out_degree if d == 0]
+
+ branches = all_simple_paths(call_graph, roots[0], leafs)
+ branches = sorted(branches, key=lambda x: -len(x))
+
+ for branch_id, nodes_in_branch in enumerate(branches):
+ for node in nodes_in_branch:
+ if node in pos:
+ continue
+ rank = node.rank
+ pos[node] = [branch_id, -rank]
+ return pos
diff --git a/src/hebg/codegen.py b/src/hebg/codegen.py
index b58987d..fec9e51 100644
--- a/src/hebg/codegen.py
+++ b/src/hebg/codegen.py
@@ -1,3 +1,6 @@
+# HEBGraph for explainable hierarchical reinforcement learning
+# Copyright (C) 2021-2024 Mathïs FEDERICO
+
"""Module for code generation from HEBGraph."""
from re import sub
diff --git a/src/hebg/draw.py b/src/hebg/draw.py
index cd4d6ad..80efd7d 100644
--- a/src/hebg/draw.py
+++ b/src/hebg/draw.py
@@ -1,3 +1,6 @@
+# HEBGraph for explainable hierarchical reinforcement learning
+# Copyright (C) 2021-2024 Mathïs FEDERICO
+
import math
from typing import TYPE_CHECKING, Dict, Optional, Tuple
@@ -7,7 +10,7 @@
from matplotlib.axes import Axes
from matplotlib.legend import Legend
from matplotlib.legend_handler import HandlerPatch
-from networkx import draw_networkx_edges
+from networkx import draw_networkx_edges, spring_layout
from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module
from hebg.graph import draw_networkx_nodes_images
@@ -34,7 +37,10 @@ def draw_hebgraph(
plt.setp(ax.spines.values(), color="orange")
if pos is None:
- pos = staircase_layout(graph)
+ if len(graph.roots) > 0:
+ pos = staircase_layout(graph)
+ else:
+ pos = spring_layout(graph)
draw_networkx_nodes_images(graph, pos, ax=ax, img_zoom=0.5)
draw_networkx_edges(
diff --git a/src/hebg/graph.py b/src/hebg/graph.py
index 21e3693..4269bf4 100644
--- a/src/hebg/graph.py
+++ b/src/hebg/graph.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
# pylint: disable=protected-access
"""Additional utility functions for networkx graphs."""
diff --git a/src/hebg/heb_graph.py b/src/hebg/heb_graph.py
index e077f69..cbb42a4 100644
--- a/src/hebg/heb_graph.py
+++ b/src/hebg/heb_graph.py
@@ -1,28 +1,25 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
# pylint: disable=arguments-differ
"""Module containing the HEBGraph base class."""
from __future__ import annotations
-from typing import Any, Dict, List, Optional, Tuple, TypeVar
+from typing import Any, Dict, List, Optional, Tuple
-import numpy as np
from matplotlib.axes import Axes
from networkx import DiGraph
from hebg.behavior import Behavior
+from hebg.call_graph import CallGraph
from hebg.codegen import get_hebg_source
from hebg.draw import draw_hebgraph
-from hebg.graph import get_roots, get_successors_with_index
+from hebg.graph import get_roots
from hebg.node import Node
from hebg.unrolling import unroll_graph
-Action = TypeVar("Action")
-
-
class HEBGraph(DiGraph):
"""Base class for Hierchical Explanation of Behavior as Graphs.
@@ -45,7 +42,6 @@ class HEBGraph(DiGraph):
behavior: The Behavior object from which this graph is built.
all_behaviors: A dictionary of behavior, this can be used to avoid cirular definitions using
the behavior names as anchor instead of the behavior object itself.
- any_mode: How to choose path, when multiple path are valid.
incoming_graph_data: Additional data to include in the graph.
"""
@@ -60,24 +56,19 @@ class HEBGraph(DiGraph):
5: "cyan",
6: "gray",
}
- ANY_MODES = ("first", "last", "random")
def __init__(
self,
behavior: Behavior,
all_behaviors: Dict[str, Behavior] = None,
incoming_graph_data=None,
- any_mode: str = "first",
**attr,
):
self.behavior = behavior
self.all_behaviors = all_behaviors if all_behaviors is not None else {}
self._unrolled_graph = None
- self.last_call_behaviors_stack = None
-
- assert any_mode in self.ANY_MODES, f"Unknowed any_mode: {any_mode}"
- self.any_mode = any_mode
+ self.call_graph: Optional[CallGraph] = None
super().__init__(incoming_graph_data=incoming_graph_data, **attr)
@@ -106,62 +97,6 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, index: int = 1, **attr):
color = "black"
super().add_edge(u_of_edge, v_of_edge, index=index, color=color, **attr)
- def _get_options(
- self,
- nodes: List[Node],
- observation,
- behaviors_in_search: list,
- last_call_behaviors_stack: Optional[list] = None,
- parent_name: Optional[str] = None,
- ) -> List[Action]:
- actions = []
- for node in nodes:
- node_action = self._get_action(
- node,
- observation,
- behaviors_in_search,
- last_call_behaviors_stack=last_call_behaviors_stack,
- )
- if node_action is None:
- return None
- actions.append(node_action)
-
- options = remove_duplicate_actions(
- [action for action in actions if action != "Impossible"]
- )
-
- if parent_name is None:
- parent_name = self.behavior.name
- if (
- (len(nodes) > 1 or self.behavior.name)
- and options
- and last_call_behaviors_stack is not None
- ):
- last_call_behaviors_stack.insert(
- 0, (self.behavior.name, [n.name for n in self.roots], options)
- )
-
- return options
-
- def _choose_action(self, actions: Optional[List[Action]]) -> Action:
- if actions is None:
- return None
- if len(actions) == 0:
- return "Impossible"
- if self.any_mode == "first" or len(actions) == 1:
- return actions[0]
- if self.any_mode == "last":
- return actions[-1]
- if self.any_mode == "random":
- return np.random.choice(actions)
-
- def _get_any_action(
- self, nodes: List[Node], observation, behaviors_in_search: list
- ):
- return self._choose_action(
- self._get_options(nodes, observation, behaviors_in_search)
- )
-
@property
def unrolled_graph(self) -> HEBGraph:
"""Access to the unrolled behavior graph.
@@ -179,69 +114,16 @@ def unrolled_graph(self) -> HEBGraph:
self._unrolled_graph = unroll_graph(self)
return self._unrolled_graph
- def _get_action(
- self,
- node: Node,
- observation: Any,
- behaviors_in_search: List[str],
- last_call_behaviors_stack: Optional[list] = None,
- ):
- # Behavior
- if node.type == "behavior":
- # To avoid cycling definitions
- if node.name in behaviors_in_search:
- return "Impossible"
-
- # Search for name reference in all_behaviors
- if node.name in self.all_behaviors:
- node = self.all_behaviors[node.name]
-
- return node(observation, behaviors_in_search, last_call_behaviors_stack)
-
- # Action
- if node.type == "action":
- return node(observation)
- # Feature Condition
- if node.type == "feature_condition":
- next_edge_index = int(node(observation))
- next_nodes = get_successors_with_index(self, node, next_edge_index)
- options = self._get_options(
- next_nodes,
- observation,
- behaviors_in_search,
- last_call_behaviors_stack=last_call_behaviors_stack,
- parent_name=node.name,
- )
- return self._choose_action(options)
- # Empty
- if node.type == "empty":
- next_node = self.successors(node).__next__()
- return self._get_action(
- next_node,
- observation,
- behaviors_in_search,
- last_call_behaviors_stack=last_call_behaviors_stack,
- )
- raise ValueError(f"Unknowed value {node.type} for node.type with node: {node}.")
-
def __call__(
self,
observation,
- behaviors_in_search: Optional[List[str]] = None,
- last_call_behaviors_stack: Optional[list] = None,
+ call_graph: Optional[CallGraph] = None,
) -> Any:
- if behaviors_in_search is None:
- behaviors_in_search = []
- last_call_behaviors_stack = []
- behaviors_in_search.append(self.behavior.name)
- options = self._get_options(
- self.roots,
- observation,
- behaviors_in_search,
- last_call_behaviors_stack=last_call_behaviors_stack,
- )
- self.last_call_behaviors_stack = last_call_behaviors_stack
- return self._choose_action(options)
+ if call_graph is None:
+ call_graph = CallGraph()
+ call_graph.add_root(heb_node=self.behavior, heb_graph=self)
+ self.call_graph = call_graph
+ return self.call_graph.call_nodes(self.roots, observation, heb_graph=self)
@property
def roots(self) -> List[Node]:
@@ -268,9 +150,3 @@ def draw(
"""
return draw_hebgraph(self, ax, **kwargs)
-
-
-def remove_duplicate_actions(actions: List[Action]) -> List[Action]:
- seen = set()
- seen_add = seen.add
- return [a for a in actions if not (a in seen or seen_add(a))]
diff --git a/src/hebg/layouts/__init__.py b/src/hebg/layouts/__init__.py
index 7bf586e..915138d 100644
--- a/src/hebg/layouts/__init__.py
+++ b/src/hebg/layouts/__init__.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Module containing layouts to draw graphs."""
diff --git a/src/hebg/layouts/deterministic.py b/src/hebg/layouts/deterministic.py
index d7ff6db..477f81b 100644
--- a/src/hebg/layouts/deterministic.py
+++ b/src/hebg/layouts/deterministic.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
# pylint: disable=protected-access
"""Deterministic layouts"""
diff --git a/src/hebg/layouts/metabased.py b/src/hebg/layouts/metabased.py
index b6060fd..1843db4 100644
--- a/src/hebg/layouts/metabased.py
+++ b/src/hebg/layouts/metabased.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
# pylint: disable=protected-access
"""Metaheuristics based layouts"""
diff --git a/src/hebg/layouts/metaheuristics.py b/src/hebg/layouts/metaheuristics.py
index f1282bf..3ee3292 100644
--- a/src/hebg/layouts/metaheuristics.py
+++ b/src/hebg/layouts/metaheuristics.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Metaheuristics used for building layouts"""
diff --git a/src/hebg/metrics/__init__.py b/src/hebg/metrics/__init__.py
index c86285c..7ce4b9c 100644
--- a/src/hebg/metrics/__init__.py
+++ b/src/hebg/metrics/__init__.py
@@ -1,4 +1,4 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Module containing HEBGraph metrics."""
diff --git a/src/hebg/metrics/complexity/__init__.py b/src/hebg/metrics/complexity/__init__.py
index caf67d4..dc0ae8d 100644
--- a/src/hebg/metrics/complexity/__init__.py
+++ b/src/hebg/metrics/complexity/__init__.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Module for complexity computation methods."""
diff --git a/src/hebg/metrics/complexity/complexities.py b/src/hebg/metrics/complexity/complexities.py
index a00da8d..992307b 100644
--- a/src/hebg/metrics/complexity/complexities.py
+++ b/src/hebg/metrics/complexity/complexities.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""General complexity."""
diff --git a/src/hebg/metrics/complexity/utils.py b/src/hebg/metrics/complexity/utils.py
index cad8903..fe4021e 100644
--- a/src/hebg/metrics/complexity/utils.py
+++ b/src/hebg/metrics/complexity/utils.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Utility functions for complexity computation."""
diff --git a/src/hebg/metrics/histograms.py b/src/hebg/metrics/histograms.py
index 6ba45f9..24ef27b 100644
--- a/src/hebg/metrics/histograms.py
+++ b/src/hebg/metrics/histograms.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""HEBGraph used nodes histograms computation."""
@@ -308,9 +308,9 @@ def _get_node_histogram_complexity(
if behaviors_in_search is not None and str(node) in behaviors_in_search:
return {}, np.inf
if node.type in ("action", "feature_condition", "behavior"):
- try:
+ if node.complexity is not None:
node_complexity = node.complexity
- except AttributeError:
+ else:
node_complexity = default_node_complexity
return {node: 1}, node_complexity
if node.type == "empty":
diff --git a/src/hebg/metrics/utility/__init__.py b/src/hebg/metrics/utility/__init__.py
index 98a3e1a..12adad7 100644
--- a/src/hebg/metrics/utility/__init__.py
+++ b/src/hebg/metrics/utility/__init__.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Module for utility computation methods."""
diff --git a/src/hebg/metrics/utility/binary_utility.py b/src/hebg/metrics/utility/binary_utility.py
index 91e9833..a533c50 100644
--- a/src/hebg/metrics/utility/binary_utility.py
+++ b/src/hebg/metrics/utility/binary_utility.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Simplest binary utility for HEBGraph."""
diff --git a/src/hebg/node.py b/src/hebg/node.py
index 2779bae..a398b04 100644
--- a/src/hebg/node.py
+++ b/src/hebg/node.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Module for base Node classes."""
@@ -45,11 +45,7 @@ def __init__(
f"not in authorised node_types ({self.NODE_TYPES})."
)
self.type = node_type
- if complexity is not None:
- self.complexity = complexity
- else:
- self.complexity = bytecode_complexity(self.__init__)
- self.complexity += bytecode_complexity(self.__call__)
+ self.complexity = complexity
def __call__(self, observation: Any) -> Any:
raise NotImplementedError
diff --git a/src/hebg/requirements_graph.py b/src/hebg/requirements_graph.py
index 1520939..73557fb 100644
--- a/src/hebg/requirements_graph.py
+++ b/src/hebg/requirements_graph.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
# pylint: disable=arguments-differ
"""Module for building underlying requirement graphs based on a set of behaviors."""
diff --git a/src/hebg/unrolling.py b/src/hebg/unrolling.py
index 0eb0250..7874443 100644
--- a/src/hebg/unrolling.py
+++ b/src/hebg/unrolling.py
@@ -1,3 +1,6 @@
+# HEBGraph for explainable hierarchical reinforcement learning
+# Copyright (C) 2021-2024 Mathïs FEDERICO
+
"""Module to unroll HEBGraph.
Unrolling means expanding each sub-behavior node as it's own graph in the global HEBGraph.
@@ -20,7 +23,7 @@
def unroll_graph(
graph: "HEBGraph",
- add_prefix=False,
+ add_prefix: bool = False,
cut_looping_alternatives: bool = False,
) -> "HEBGraph":
"""Build the the unrolled HEBGraph.
@@ -48,7 +51,7 @@ def unroll_graph(
def _unroll_graph(
graph: "HEBGraph",
- add_prefix=False,
+ add_prefix: bool = False,
cut_looping_alternatives: bool = False,
_current_alternatives: Optional[List[Union["Action", "Behavior"]]] = None,
_unrolled_behaviors: Optional[Dict[str, Optional["HEBGraph"]]] = None,
@@ -56,7 +59,7 @@ def _unroll_graph(
if _unrolled_behaviors is None:
_unrolled_behaviors = {}
if _current_alternatives is None:
- _current_alternatives = []
+ _current_alternatives = {0: []}
is_looping = False
_unrolled_behaviors[graph.behavior.name] = None
@@ -65,14 +68,9 @@ def _unroll_graph(
for node in list(unrolled_graph.nodes()):
if not isinstance(node, Behavior):
continue
- new_alternatives = []
- for pred, _node, data in graph.in_edges(node, data=True):
- index = data["index"]
- for _pred, alternative, alt_index in graph.out_edges(pred, data="index"):
- if index == alt_index and alternative != node:
- new_alternatives.append(alternative)
- if new_alternatives:
- _current_alternatives = new_alternatives
+
+ _current_alternatives[0] = _direct_alternatives(node, graph)
+ _current_alternatives[1] = _roots_alternatives(node, graph)
unrolled_graph, behavior_is_looping = _unroll_behavior(
unrolled_graph,
node,
@@ -87,6 +85,25 @@ def _unroll_graph(
return unrolled_graph, is_looping
+def _direct_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]:
+ alternatives = []
+ for pred, _node, data in graph.in_edges(node, data=True):
+ index = data["index"]
+ for _pred, alternative, alt_index in graph.out_edges(pred, data="index"):
+ if index != alt_index or alternative == node:
+ continue
+ alternatives.append(alternative)
+ return alternatives
+
+
+def _roots_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]:
+ alternatives = []
+ for pred, _node, _data in graph.in_edges(node, data=True):
+ if pred in graph.roots:
+ alternatives.extend([r for r in graph.roots if r != pred])
+ return alternatives
+
+
def _unroll_behavior(
graph: "HEBGraph",
behavior: "Behavior",
@@ -119,13 +136,24 @@ def _unroll_behavior(
)
if is_looping and cut_looping_alternatives:
- if not _current_alternatives:
- return graph, is_looping
- for alternative in _current_alternatives:
+ for alternative in _current_alternatives[0]:
for last_condition, _, data in graph.in_edges(behavior, data=True):
graph.add_edge(last_condition, alternative, **data)
- graph.remove_node(behavior)
- return graph, False
+ if _current_alternatives[0]:
+ graph.remove_node(behavior)
+ return graph, False
+ if _current_alternatives[1]:
+ predecessors = list(graph.predecessors(behavior))
+ for last_condition in predecessors:
+ successors = list(graph.successors(last_condition))
+ for descendant in successors:
+ graph.remove_edge(last_condition, descendant)
+ if graph.neighbors(descendant) == 0:
+ graph.remove_node(descendant)
+ graph.remove_node(last_condition)
+ graph.remove_node(behavior)
+ return graph, False
+ raise NotImplementedError()
if node_graph is None:
# If we cannot get the node's graph, we keep it as is.
@@ -151,7 +179,7 @@ def _unrolled_behavior_graph(
cut_looping_alternatives: bool,
_current_alternatives: List[Union["Action", "Behavior"]],
_unrolled_behaviors: Dict[str, Optional["HEBGraph"]],
-) -> Optional["HEBGraph"]:
+) -> Tuple[Optional["HEBGraph"], bool]:
"""Get the unrolled sub-graph of a behavior.
Args:
@@ -188,7 +216,7 @@ def _add_prefix_to_graph(graph: "HEBGraph", prefix: str) -> None:
if prefix is None:
return graph
- def rename(node: "Node"):
+ def rename(node: "Node") -> None:
new_node = copy(node)
new_node.name = prefix + node.name
return new_node
@@ -216,14 +244,14 @@ def group_behaviors_points(
for i in range(len(groups[:-1])):
key = tuple(groups[: -1 - i])
point = pos[node]
- try:
+ if key in points_grouped_by_behavior:
points_grouped_by_behavior[key].append(point)
- except KeyError:
+ else:
points_grouped_by_behavior[key] = [point]
return points_grouped_by_behavior
-def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph"):
+def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph") -> None:
"""Returns a new_graph of graph_of_reference composed with other_graph.
Composition is the simple union of the node sets and edge sets.
@@ -237,9 +265,7 @@ def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph"):
"""
new_graph = graph_of_reference.__class__(
- graph_of_reference.behavior,
- all_behaviors=graph_of_reference.all_behaviors,
- any_mode=graph_of_reference.any_mode,
+ graph_of_reference.behavior, all_behaviors=graph_of_reference.all_behaviors
)
# add graph attributes, H attributes take precedent over G attributes
new_graph.graph.update(graph_of_reference.graph)
diff --git a/tests/__init__.py b/tests/__init__.py
index a4983ea..66b76f8 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -1,4 +1,22 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Tests for the heb_graph package."""
+
+from typing import Protocol
+from matplotlib import pyplot as plt
+
+
+class Graph(Protocol):
+ def draw(self, ax, pos):
+ """Draw the graph on a matplotlib axes."""
+
+ def nodes(self) -> list:
+ """Return a list of nodes"""
+
+
+def plot_graph(graph: Graph, **kwargs):
+ _, ax = plt.subplots()
+ graph.draw(ax, **kwargs)
+ plt.axis("off") # turn off axis
+ plt.show()
diff --git a/tests/examples/behaviors/__init__.py b/tests/examples/behaviors/__init__.py
index e0a7bc3..d339dd3 100644
--- a/tests/examples/behaviors/__init__.py
+++ b/tests/examples/behaviors/__init__.py
@@ -1,4 +1,32 @@
-from tests.examples.behaviors.basic import *
-from tests.examples.behaviors.basic_empty import *
-
+from tests.examples.behaviors.basic import (
+ FundamentalBehavior,
+ F_A_Behavior,
+ F_AA_Behavior,
+ F_F_A_Behavior,
+)
+from tests.examples.behaviors.basic_empty import (
+ E_A_Behavior,
+ E_F_A_Behavior,
+ F_E_A_Behavior,
+ E_E_A_Behavior,
+)
from tests.examples.behaviors.binary_sum import build_binary_sum_behavior
+from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
+from tests.examples.behaviors.loop_without_alternative import (
+ build_looping_behaviors_without_direct_alternatives,
+)
+
+
+__all__ = [
+ "FundamentalBehavior",
+ "F_A_Behavior",
+ "F_AA_Behavior",
+ "F_F_A_Behavior",
+ "E_A_Behavior",
+ "E_F_A_Behavior",
+ "F_E_A_Behavior",
+ "E_E_A_Behavior",
+ "build_binary_sum_behavior",
+ "build_looping_behaviors",
+ "build_looping_behaviors_without_direct_alternatives",
+]
diff --git a/tests/examples/behaviors/basic.py b/tests/examples/behaviors/basic.py
index 8d1c88b..9bad5fb 100644
--- a/tests/examples/behaviors/basic.py
+++ b/tests/examples/behaviors/basic.py
@@ -67,6 +67,9 @@ def build_graph(self) -> HEBGraph:
class F_F_A_Behavior(Behavior):
"""Double layer feature conditions behavior"""
+ def __init__(self, name: str = "F_F_A", *args, **kwargs) -> None:
+ super().__init__(name=name, *args, **kwargs)
+
def build_graph(self) -> HEBGraph:
graph = HEBGraph(self)
@@ -89,12 +92,11 @@ def build_graph(self) -> HEBGraph:
class F_AA_Behavior(Behavior):
"""Feature condition with mutliple actions on same index."""
- def __init__(self, name: str, any_mode: str) -> None:
+ def __init__(self, name: str = "F_AA") -> None:
super().__init__(name, image=None)
- self.any_mode = any_mode
def build_graph(self) -> HEBGraph:
- graph = HEBGraph(self, any_mode=self.any_mode)
+ graph = HEBGraph(self)
feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
graph.add_edge(feature_condition, Action(0), index=int(True))
diff --git a/tests/examples/behaviors/loop_with_alternative.py b/tests/examples/behaviors/loop_with_alternative.py
index 4fec202..dfd462c 100644
--- a/tests/examples/behaviors/loop_with_alternative.py
+++ b/tests/examples/behaviors/loop_with_alternative.py
@@ -1,8 +1,17 @@
-from typing import List
+from typing import Any, List
from hebg import HEBGraph, Action, FeatureCondition, Behavior
+class HasItem(FeatureCondition):
+ def __init__(self, item_name: str) -> None:
+ self.item_name = item_name
+ super().__init__(name=f"Has {item_name} ?", complexity=1.0)
+
+ def __call__(self, observation: Any) -> int:
+ return self.item_name in observation
+
+
class GatherWood(Behavior):
"""Gather wood"""
@@ -12,10 +21,10 @@ def __init__(self) -> None:
def build_graph(self) -> HEBGraph:
graph = HEBGraph(self)
- feature = FeatureCondition("Has an axe")
- graph.add_edge(feature, Action("Punch tree"), index=False)
- graph.add_edge(feature, Behavior("Get new axe"), index=False)
- graph.add_edge(feature, Action("Use axe on tree"), index=True)
+ has_axe = HasItem("axe")
+ graph.add_edge(has_axe, Action("Punch tree", complexity=2.0), index=False)
+ graph.add_edge(has_axe, Behavior("Get new axe", complexity=1.0), index=False)
+ graph.add_edge(has_axe, Action("Use axe on tree", complexity=1.0), index=True)
return graph
@@ -28,9 +37,12 @@ def __init__(self) -> None:
def build_graph(self) -> HEBGraph:
graph = HEBGraph(self)
- feature = FeatureCondition("Has wood")
- graph.add_edge(feature, Behavior("Gather wood"), index=False)
- graph.add_edge(feature, Action("Craft axe"), index=True)
+ has_wood = HasItem("wood")
+ graph.add_edge(has_wood, Behavior("Gather wood", complexity=1.0), index=False)
+ graph.add_edge(
+ has_wood, Action("Summon axe out of thin air", complexity=10.0), index=False
+ )
+ graph.add_edge(has_wood, Action("Craft axe", complexity=1.0), index=True)
return graph
@@ -39,4 +51,5 @@ def build_looping_behaviors() -> List[Behavior]:
all_behaviors = {behavior.name: behavior for behavior in behaviors}
for behavior in behaviors:
behavior.graph.all_behaviors = all_behaviors
+ behavior.complexity = 5
return behaviors
diff --git a/tests/examples/behaviors/loop_without_alternative.py b/tests/examples/behaviors/loop_without_alternative.py
index 31e0d0d..94ee273 100644
--- a/tests/examples/behaviors/loop_without_alternative.py
+++ b/tests/examples/behaviors/loop_without_alternative.py
@@ -57,7 +57,7 @@ def build_graph(self) -> HEBGraph:
return graph
-def build_looping_behaviors() -> List[Behavior]:
+def build_looping_behaviors_without_direct_alternatives() -> List[Behavior]:
behaviors: List[Behavior] = [
ReachForest(),
ReachOtherZone(),
diff --git a/tests/examples/feature_conditions/scalar.py b/tests/examples/feature_conditions.py
similarity index 95%
rename from tests/examples/feature_conditions/scalar.py
rename to tests/examples/feature_conditions.py
index 75da674..4d38cdc 100644
--- a/tests/examples/feature_conditions/scalar.py
+++ b/tests/examples/feature_conditions.py
@@ -12,7 +12,7 @@ class Relation(Enum):
LESSER_THAN = "<"
def __init__(
- self, relation: Union[Relation, str] = ">=", threshold: float = 0
+ self, relation: Union[Relation, str] = ">=", threshold: float = 0, **kwargs
) -> None:
"""Threshold-based feature condition for scalar feature."""
self.relation = relation
@@ -20,7 +20,7 @@ def __init__(
self._relation = self.Relation(relation)
display_name = self._relation.name.capitalize().replace("_", " ")
name = f"{display_name} {threshold} ?"
- super().__init__(name=name, image=None)
+ super().__init__(name=name, **kwargs)
def __call__(self, observation: float) -> int:
conditions = {
diff --git a/tests/examples/feature_conditions/__init__.py b/tests/examples/feature_conditions/__init__.py
deleted file mode 100644
index 255c52e..0000000
--- a/tests/examples/feature_conditions/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from tests.examples.feature_conditions.scalar import (
- ThresholdFeatureCondition,
- IsDivisibleFeatureCondition,
-)
-
-
-__all__ = ["ThresholdFeatureCondition", "IsDivisibleFeatureCondition"]
diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py
deleted file mode 100644
index 489eafd..0000000
--- a/tests/integration/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Integration tests for the heb_graph package."""
diff --git a/tests/integration/test_behavior.py b/tests/integration/test_behavior.py
deleted file mode 100644
index d76f24d..0000000
--- a/tests/integration/test_behavior.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Behavior of HEBGraphs when called."""
-
-import pytest_check as check
-
-from hebg.node import Action
-from tests.examples.behaviors import (
- FundamentalBehavior,
- AA_Behavior,
- F_A_Behavior,
- F_F_A_Behavior,
- AF_A_Behavior,
- F_AA_Behavior,
-)
-from tests.examples.feature_conditions import ThresholdFeatureCondition
-
-
-def test_a_graph():
- """(A) Fundamental behaviors (single action) should work properly."""
- action_id = 42
- behavior = FundamentalBehavior(Action(action_id))
- check.equal(behavior(None), action_id)
-
-
-def test_f_a_graph():
- """(F-A) Feature condition should orient path properly."""
- feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
- actions = {0: Action(0), 1: Action(1)}
- behavior = F_A_Behavior("F_A", feature_condition, actions)
- check.equal(behavior(1), 1)
- check.equal(behavior(-1), 0)
-
-
-def test_f_f_a_graph():
- """(F-F-A) Feature condition should orient path properly in double chain."""
- behavior = F_F_A_Behavior("F_F_A")
- check.equal(behavior(-2), 0)
- check.equal(behavior(-1), 1)
- check.equal(behavior(1), 2)
- check.equal(behavior(2), 3)
-
-
-def test_aa_graph():
- """(AA) Should choose between roots depending on 'any_mode'."""
- behavior = AA_Behavior("AA", any_mode="first")
- check.equal(behavior(None), 0)
-
- behavior = AA_Behavior("AA", any_mode="last")
- check.equal(behavior(None), 1)
-
-
-def test_af_a_graph():
- """(AF-A) Should choose between roots depending on 'any_mode'."""
- behavior = AF_A_Behavior("AF_A", any_mode="first")
- check.equal(behavior(1), 0)
- check.equal(behavior(-1), 0)
-
- behavior = AF_A_Behavior("AF_A", any_mode="last")
- check.equal(behavior(1), 1)
- check.equal(behavior(-1), 2)
-
-
-def test_f_af_a_graph():
- """(F-AA) Should choose between condition edges depending on 'any_mode'."""
- behavior = F_AA_Behavior("F_AA", any_mode="first")
- check.equal(behavior(1), 0)
- check.equal(behavior(-1), 1)
-
- behavior = F_AA_Behavior("F_AA", any_mode="last")
- check.equal(behavior(1), 0)
- check.equal(behavior(-1), 2)
diff --git a/tests/integration/test_loop_with_alternative.py b/tests/integration/test_loop_with_alternative.py
deleted file mode 100644
index 641d6b2..0000000
--- a/tests/integration/test_loop_with_alternative.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import pytest
-import pytest_check as check
-
-import networkx as nx
-from hebg import HEBGraph
-from hebg.unrolling import unroll_graph
-
-from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
-
-import matplotlib.pyplot as plt
-
-
-class TestLoop:
- """Tests for the loop example"""
-
- @pytest.fixture(autouse=True)
- def setup_method(self):
- self.gather_wood, self.get_new_axe = build_looping_behaviors()
-
- def test_unroll_gather_wood(self):
- draw = False
- unrolled_graph = unroll_graph(self.gather_wood.graph)
- if draw:
- _plot_graph(unrolled_graph)
-
- expected_graph = nx.DiGraph()
- expected_graph.add_edge("Has axe", "Punch tree")
- expected_graph.add_edge("Has axe", "Cut tree with axe")
- expected_graph.add_edge("Has axe", "Has wood")
-
- # Expected sub-behavior
- expected_graph.add_edge("Has wood", "Gather wood")
- expected_graph.add_edge("Has wood", "Craft axe")
- check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
-
- def test_unroll_get_new_axe(self):
- draw = False
- unrolled_graph = unroll_graph(self.get_new_axe.graph)
- if draw:
- _plot_graph(unrolled_graph)
-
- expected_graph = nx.DiGraph()
- expected_graph.add_edge("Has wood", "Has axe")
- expected_graph.add_edge("Has wood", "Craft new axe")
-
- # Expected sub-behavior
- expected_graph.add_edge("Has axe", "Punch tree")
- expected_graph.add_edge("Has axe", "Cut tree with axe")
- expected_graph.add_edge("Has axe", "Get new axe")
- check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
-
- def test_unroll_gather_wood_cutting_alternatives(self):
- draw = False
- unrolled_graph = unroll_graph(
- self.gather_wood.graph,
- cut_looping_alternatives=True,
- )
- if draw:
- _plot_graph(unrolled_graph)
-
- expected_graph = nx.DiGraph()
- expected_graph.add_edge("Has axe", "Punch tree")
- expected_graph.add_edge("Has axe", "Cut tree with axe")
- expected_graph.add_edge("Has axe", "Has wood")
-
- # Expected sub-behavior
- expected_graph.add_edge("Has wood", "Punch tree")
- expected_graph.add_edge("Has wood", "Craft axe")
- check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
-
- def test_unroll_get_new_axe_cutting_alternatives(self):
- draw = False
- unrolled_graph = unroll_graph(
- self.get_new_axe.graph,
- cut_looping_alternatives=True,
- )
- if draw:
- _plot_graph(unrolled_graph)
-
- expected_graph = nx.DiGraph()
- expected_graph.add_edge("Has wood", "Has axe")
- expected_graph.add_edge("Has wood", "Craft new axe")
-
- # Expected sub-behavior
- expected_graph.add_edge("Has axe", "Punch tree")
- expected_graph.add_edge("Has axe", "Cut tree with axe")
- check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
-
-
-def _plot_graph(graph: "HEBGraph"):
- _, ax = plt.subplots()
- graph.draw(ax)
- plt.show()
diff --git a/tests/integration/test_loop_without_alternative.py b/tests/integration/test_loop_without_alternative.py
deleted file mode 100644
index a6e7490..0000000
--- a/tests/integration/test_loop_without_alternative.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import pytest
-import pytest_check as check
-
-import networkx as nx
-from hebg import HEBGraph
-from hebg.unrolling import unroll_graph
-
-from tests.examples.behaviors.loop_without_alternative import build_looping_behaviors
-
-import matplotlib.pyplot as plt
-
-
-class TestLoop:
- """Tests for the loop example"""
-
- @pytest.fixture(autouse=True)
- def setup_method(self):
- (
- self.reach_forest,
- self.reach_other_zone,
- self.reach_meadow,
- ) = build_looping_behaviors()
-
- @pytest.mark.xfail
- def test_unroll_reach_forest(self):
- draw = False
- unrolled_graph = unroll_graph(
- self.reach_forest.graph,
- add_prefix=True,
- cut_looping_alternatives=True,
- )
- if draw:
- _plot_graph(unrolled_graph)
-
- expected_graph = nx.DiGraph()
- check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
-
-
-def _plot_graph(graph: "HEBGraph"):
- _, ax = plt.subplots()
- pos = None
- if len(graph.roots) == 0:
- pos = nx.spring_layout(graph)
- graph.draw(ax, pos=pos)
- plt.show()
diff --git a/tests/test_behavior.py b/tests/test_behavior.py
new file mode 100644
index 0000000..b5de95e
--- /dev/null
+++ b/tests/test_behavior.py
@@ -0,0 +1,138 @@
+# HEBGraph for explainable hierarchical reinforcement learning
+# Copyright (C) 2021-2024 Mathïs FEDERICO
+
+"""Behavior of HEBGraphs when called."""
+
+import pytest
+import pytest_check as check
+from pytest_mock import MockerFixture
+
+from hebg.behavior import Behavior
+from hebg.heb_graph import HEBGraph
+from hebg.node import Action
+
+from tests.examples.behaviors import FundamentalBehavior, F_A_Behavior, F_F_A_Behavior
+from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
+from tests.examples.feature_conditions import ThresholdFeatureCondition
+
+
+class TestBehavior:
+
+ """Behavior"""
+
+ @pytest.fixture(autouse=True)
+ def setup(self):
+ """Initialize variables."""
+ self.node = Behavior("behavior_name")
+
+ def test_node_type(self):
+ """should have 'behavior' as node_type."""
+ check.equal(self.node.type, "behavior")
+
+ def test_node_call(self, mocker: MockerFixture):
+ """should use graph on call."""
+ mocker.patch("hebg.behavior.Behavior.graph")
+ self.node(None)
+ check.is_true(self.node.graph.called)
+
+ def test_build_graph(self):
+ """should raise NotImplementedError when build_graph is called."""
+ with pytest.raises(NotImplementedError):
+ self.node.build_graph()
+
+ def test_graph(self, mocker: MockerFixture):
+ """should build graph and compute its levels if, and only if,
+ the graph is not yet built.
+ """
+ mocker.patch("hebg.behavior.Behavior.build_graph")
+ mocker.patch("hebg.behavior.compute_levels")
+ self.node.graph
+ check.is_true(self.node.build_graph.called)
+ check.is_true(self.node.build_graph.called)
+
+ mocker.patch("hebg.behavior.Behavior.build_graph")
+ mocker.patch("hebg.behavior.compute_levels")
+ self.node.graph
+ check.is_false(self.node.build_graph.called)
+ check.is_false(self.node.build_graph.called)
+
+
+class TestPathfinding:
+ def test_fundamental_behavior(self):
+ """Fundamental behavior (single action) should return its action."""
+ action_id = 42
+ behavior = FundamentalBehavior(Action(action_id))
+ check.equal(behavior(None), action_id)
+
+ def test_feature_condition_single(self):
+ """Feature condition should orient path properly."""
+ feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
+ actions = {0: Action(0), 1: Action(1)}
+ behavior = F_A_Behavior("F_A", feature_condition, actions)
+ check.equal(behavior(1), 1)
+ check.equal(behavior(-1), 0)
+
+ def test_feature_conditions_chained(self):
+ """Feature condition should orient path properly in double chain."""
+ behavior = F_F_A_Behavior("F_F_A")
+ check.equal(behavior(-2), 0)
+ check.equal(behavior(-1), 1)
+ check.equal(behavior(1), 2)
+ check.equal(behavior(2), 3)
+
+ def test_looping_resolve(self):
+ """Loops with alternatives should be ignored."""
+ _gather_wood, get_axe = build_looping_behaviors()
+ check.equal(get_axe({}), "Punch tree")
+
+
+class TestCostBehavior:
+ def test_choose_root_of_lesser_cost(self):
+ """Should choose root of lesser cost."""
+
+ expected_action = "EXPECTED"
+
+ class AAA_Behavior(Behavior):
+ def __init__(self) -> None:
+ super().__init__("AAA")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+ graph.add_node(Action(0, complexity=2))
+ graph.add_node(Action(expected_action, complexity=1))
+ graph.add_node(Action(2, complexity=3))
+ return graph
+
+ behavior = AAA_Behavior()
+ check.equal(behavior(None), expected_action)
+
+ def test_not_path_of_least_cost(self):
+ """Should choose path of larger complexity if individual costs lead to it."""
+
+ class AF_A_Behavior(Behavior):
+
+ """Double root with feature condition and action"""
+
+ def __init__(self) -> None:
+ super().__init__("AF_A")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+
+ graph.add_node(Action(0, complexity=1.5))
+ feature_condition = ThresholdFeatureCondition(
+ relation=">=", threshold=0, complexity=1.0
+ )
+
+ graph.add_edge(
+ feature_condition, Action(1, complexity=1.0), index=int(True)
+ )
+ graph.add_edge(
+ feature_condition, Action(2, complexity=1.0), index=int(False)
+ )
+
+ return graph
+
+ behavior = AF_A_Behavior()
+ check.equal(behavior(1), 1)
+ check.equal(behavior(-1), 2)
diff --git a/tests/integration/test_behavior_empty.py b/tests/test_behavior_empty.py
similarity index 94%
rename from tests/integration/test_behavior_empty.py
rename to tests/test_behavior_empty.py
index 7efd053..6c06b19 100644
--- a/tests/integration/test_behavior_empty.py
+++ b/tests/test_behavior_empty.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Behavior of HEBGraphs with empty nodes."""
diff --git a/tests/test_call_graph.py b/tests/test_call_graph.py
new file mode 100644
index 0000000..4e297ed
--- /dev/null
+++ b/tests/test_call_graph.py
@@ -0,0 +1,284 @@
+from networkx import DiGraph
+import pytest
+
+from hebg.behavior import Behavior
+from hebg.call_graph import CallEdgeStatus, CallGraph, CallNode, _call_graph_pos
+from hebg.heb_graph import HEBGraph
+from hebg.node import Action, FeatureCondition
+
+from pytest_mock import MockerFixture
+import pytest_check as check
+
+from tests import plot_graph
+
+from tests.examples.behaviors import F_F_A_Behavior
+from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
+from tests.examples.feature_conditions import ThresholdFeatureCondition
+
+
+class TestCall:
+ """Ensure that the call graph is faithful for debugging and efficient breadth first search."""
+
+ def test_call_stack_without_branches(self) -> None:
+ """When there is no branches, the graph should be a simple sequence of the call stack."""
+ f_f_a_behavior = F_F_A_Behavior()
+
+ draw = False
+ if draw:
+ plot_graph(f_f_a_behavior.graph.unrolled_graph)
+ f_f_a_behavior(observation=-2)
+
+ expected_graph = DiGraph(
+ [
+ ("F_F_A", "Greater or equal to 0 ?"),
+ ("Greater or equal to 0 ?", "Greater or equal to -1 ?"),
+ ("Greater or equal to -1 ?", "Action(0)"),
+ ]
+ )
+
+ call_graph = f_f_a_behavior.graph.call_graph
+ assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
+
+ def test_split_on_same_fc_index(self, mocker: MockerFixture) -> None:
+ """When there are multiple indexes on the same feature condition,
+ a branch should be created."""
+
+ expected_action = Action("EXPECTED", complexity=1)
+
+ forbidden_value = "FORBIDDEN"
+ forbidden_action = Action(forbidden_value, complexity=2)
+ forbidden_action.__call__ = mocker.MagicMock(return_value=forbidden_value)
+
+ class F_AA_Behavior(Behavior):
+ """Feature condition with mutliple actions on same index."""
+
+ def __init__(self) -> None:
+ super().__init__("F_AA")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+ feature_condition = ThresholdFeatureCondition(
+ relation=">=", threshold=0
+ )
+ graph.add_edge(
+ feature_condition, Action(0, complexity=0), index=int(True)
+ )
+ graph.add_edge(feature_condition, forbidden_action, index=int(False))
+ graph.add_edge(feature_condition, expected_action, index=int(False))
+
+ return graph
+
+ f_aa_behavior = F_AA_Behavior()
+ draw = False
+ if draw:
+ plot_graph(f_aa_behavior.graph.unrolled_graph)
+
+ # Sanity check that the right action should be called and not the forbidden one.
+ assert f_aa_behavior(observation=-1) == expected_action.action
+ forbidden_action.__call__.assert_not_called()
+
+ # Graph should have the good split
+ call_graph = f_aa_behavior.graph.call_graph
+ expected_graph = DiGraph(
+ [
+ ("F_AA", "Greater or equal to 0 ?"),
+ ("Greater or equal to 0 ?", "Action(EXPECTED)"),
+ ("Greater or equal to 0 ?", "Action(FORBIDDEN)"),
+ ]
+ )
+ assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
+
+ @pytest.mark.xfail
+ def test_multiple_call_to_same_fc(self, mocker: MockerFixture) -> None:
+ """Call graph should allow for the same feature condition
+ to be called multiple times in the same branch (in different behaviors)."""
+ expected_action = Action("EXPECTED")
+ unexpected_action = Action("UNEXPECTED")
+
+ feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0)
+
+ class SubBehavior(Behavior):
+ def __init__(self) -> None:
+ super().__init__("SubBehavior")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+ graph.add_edge(feature_condition, expected_action, index=int(True))
+ graph.add_edge(feature_condition, unexpected_action, index=int(False))
+ return graph
+
+ class RootBehavior(Behavior):
+ """Feature condition with mutliple actions on same index."""
+
+ def __init__(self) -> None:
+ super().__init__("RootBehavior")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+ graph.add_edge(feature_condition, SubBehavior(), index=int(True))
+ graph.add_edge(feature_condition, unexpected_action, index=int(False))
+
+ return graph
+
+ root_behavior = RootBehavior()
+ draw = False
+ if draw:
+ plot_graph(root_behavior.graph.unrolled_graph)
+
+ # Sanity check that the right action should be called and not the forbidden one.
+ assert root_behavior(observation=2) == expected_action.action
+
+ # Graph should have the good split
+ call_graph = root_behavior.graph.call_graph
+ expected_graph = DiGraph(
+ [
+ ("RootBehavior", "Greater or equal to 0 ?"),
+ ("Greater or equal to 0 ?", "SubBehavior"),
+ ("SubBehavior", "Greater or equal to 0 ?"),
+ ("Greater or equal to 0 ?", "Action(EXPECTED)"),
+ ]
+ )
+ assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
+
+ expected_labels = {
+ CallNode(0, 0): "RootBehavior",
+ CallNode(0, 1): "Greater or equal to 0 ?",
+ CallNode(0, 2): "SubBehavior",
+ CallNode(0, 3): "Greater or equal to 0 ?",
+ CallNode(0, 4): "Action(EXPECTED)",
+ }
+ for node, label in call_graph.nodes(data="label"):
+ check.equal(label, expected_labels[node])
+
+ def test_chain_behaviors(self, mocker: MockerFixture) -> None:
+ """When sub-behaviors with a graph are called recursively,
+ the call graph should still find their nodes."""
+
+ expected_action = "EXPECTED"
+
+ class DummyBehavior(Behavior):
+ __call__ = mocker.MagicMock(return_value=expected_action)
+
+ class SubBehavior(Behavior):
+ def __init__(self) -> None:
+ super().__init__("SubBehavior")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+ graph.add_node(DummyBehavior("Dummy"))
+ return graph
+
+ class RootBehavior(Behavior):
+ def __init__(self) -> None:
+ super().__init__("RootBehavior")
+
+ def build_graph(self) -> HEBGraph:
+ graph = HEBGraph(self)
+ graph.add_node(Behavior("SubBehavior"))
+ return graph
+
+ sub_behavior = SubBehavior()
+ sub_behavior.graph
+
+ root_behavior = RootBehavior()
+ root_behavior.graph.all_behaviors["SubBehavior"] = sub_behavior
+
+ # Sanity check that the right action should be called.
+ assert root_behavior(observation=-1) == expected_action
+
+ call_graph = root_behavior.graph.call_graph
+ expected_graph = DiGraph(
+ [
+ ("RootBehavior", "SubBehavior"),
+ ("SubBehavior", "Dummy"),
+ ]
+ )
+ assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
+
+ def test_looping_goback(self) -> None:
+ """Loops with alternatives should be ignored."""
+ draw = False
+ _gather_wood, get_axe = build_looping_behaviors()
+ assert get_axe({}) == "Punch tree"
+
+ call_graph = get_axe.graph.call_graph
+
+ if draw:
+ plot_graph(call_graph)
+
+ expected_labels = {
+ CallNode(0, 0): "Get new axe",
+ CallNode(0, 1): "Has wood ?",
+ CallNode(1, 2): "Action(Summon axe out of thin air)",
+ CallNode(0, 2): "Gather wood",
+ CallNode(0, 3): "Has axe ?",
+ CallNode(2, 4): "Get new axe",
+ CallNode(0, 4): "Action(Punch tree)",
+ }
+ for node, label in call_graph.nodes(data="label"):
+ check.equal(label, expected_labels[node])
+
+ expected_graph = DiGraph(
+ [
+ ("Get new axe", "Has wood ?"),
+ ("Has wood ?", "Action(Summon axe out of thin air)"),
+ ("Has wood ?", "Gather wood"),
+ ("Gather wood", "Has axe ?"),
+ ("Has axe ?", "Get new axe"),
+ ("Has axe ?", "Action(Punch tree)"),
+ ]
+ )
+
+ assert set(call_graph.call_edge_labels()) == set(expected_graph.edges())
+
+
+class TestDraw:
+ """Ensures that the graph is readable even in complex situations."""
+
+ def test_result_on_first_branch(self) -> None:
+ """Resulting action should always be on the first branch."""
+ draw = False
+ root_behavior = Behavior("Root", complexity=20)
+ call_graph = CallGraph()
+ call_graph.add_root(root_behavior, None)
+
+ nodes = [
+ (CallNode(0, 1), FeatureCondition("FC1", complexity=1)),
+ (CallNode(0, 2), FeatureCondition("FC2", complexity=1)),
+ (CallNode(0, 3), root_behavior),
+ (CallNode(1, 1), FeatureCondition("FC3", complexity=1)),
+ (CallNode(1, 2), FeatureCondition("FC4", complexity=1)),
+ (CallNode(1, 3), FeatureCondition("FC5", complexity=1)),
+ (CallNode(1, 4), Action("A", complexity=1)),
+ ]
+
+ for node, heb_node in nodes:
+ call_graph.add_node(node, heb_node, None)
+
+ edges = [
+ (CallNode(0, 0), CallNode(0, 1), CallEdgeStatus.CALLED),
+ (CallNode(0, 1), CallNode(0, 2), CallEdgeStatus.CALLED),
+ (CallNode(0, 2), CallNode(0, 3), CallEdgeStatus.FAILURE),
+ (CallNode(0, 0), CallNode(1, 1), CallEdgeStatus.CALLED),
+ (CallNode(1, 1), CallNode(1, 2), CallEdgeStatus.CALLED),
+ (CallNode(1, 2), CallNode(1, 3), CallEdgeStatus.CALLED),
+ (CallNode(1, 3), CallNode(1, 4), CallEdgeStatus.CALLED),
+ ]
+
+ for start, end, status in edges:
+ call_graph.add_edge(start, end, status)
+
+ expected_poses = {
+ CallNode(0, 0): [0, 0],
+ CallNode(1, 1): [0, -1],
+ CallNode(1, 2): [0, -2],
+ CallNode(1, 3): [0, -3],
+ CallNode(1, 4): [0, -4],
+ CallNode(0, 1): [1, -1],
+ CallNode(0, 2): [1, -2],
+ CallNode(0, 3): [1, -3],
+ }
+ if draw:
+ plot_graph(call_graph)
+
+ assert _call_graph_pos(call_graph) == expected_poses
diff --git a/tests/integration/test_code_generation.py b/tests/test_code_generation.py
similarity index 100%
rename from tests/integration/test_code_generation.py
rename to tests/test_code_generation.py
diff --git a/tests/test_loop.py b/tests/test_loop.py
new file mode 100644
index 0000000..5ab85db
--- /dev/null
+++ b/tests/test_loop.py
@@ -0,0 +1,122 @@
+import pytest
+import pytest_check as check
+
+import networkx as nx
+from hebg.unrolling import unroll_graph
+from tests import plot_graph
+
+from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors
+from tests.examples.behaviors.loop_without_alternative import (
+ build_looping_behaviors_without_direct_alternatives,
+)
+
+
+class TestLoopAlternative:
+ """Tests for the loop with alternative example"""
+
+ @pytest.fixture(autouse=True)
+ def setup_method(self):
+ self.gather_wood, self.get_new_axe = build_looping_behaviors()
+
+ def test_unroll_gather_wood(self):
+ draw = False
+ unrolled_graph = unroll_graph(self.gather_wood.graph, add_prefix=True)
+ if draw:
+ plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)
+
+ expected_graph = nx.DiGraph()
+ expected_graph.add_edge("Has axe", "Punch tree")
+ expected_graph.add_edge("Has axe", "Cut tree with axe")
+ expected_graph.add_edge("Has axe", "Has wood")
+
+ # Expected sub-behavior
+ expected_graph.add_edge("Has wood", "Gather wood")
+ expected_graph.add_edge("Has wood", "Craft axe")
+ expected_graph.add_edge("Has wood", "Summon axe out of thin air")
+ check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
+
+ def test_unroll_get_new_axe(self):
+ draw = False
+ unrolled_graph = unroll_graph(self.get_new_axe.graph, add_prefix=True)
+ if draw:
+ plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)
+
+ expected_graph = nx.DiGraph()
+ expected_graph.add_edge("Has wood", "Has axe")
+ expected_graph.add_edge("Has wood", "Craft new axe")
+ expected_graph.add_edge("Has wood", "Summon axe out of thin air")
+
+ # Expected sub-behavior
+ expected_graph.add_edge("Has axe", "Punch tree")
+ expected_graph.add_edge("Has axe", "Cut tree with axe")
+ expected_graph.add_edge("Has axe", "Get new axe")
+ check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
+
+ def test_unroll_gather_wood_cutting_alternatives(self):
+ draw = False
+ unrolled_graph = unroll_graph(
+ self.gather_wood.graph, add_prefix=True, cut_looping_alternatives=True
+ )
+ if draw:
+ plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)
+
+ expected_graph = nx.DiGraph()
+ expected_graph.add_edge("Has axe", "Punch tree")
+ expected_graph.add_edge("Has axe", "Has wood")
+ expected_graph.add_edge("Has axe", "Use axe")
+
+ # Expected sub-behavior
+ expected_graph.add_edge("Has wood", "Summon axe of out thin air")
+ expected_graph.add_edge("Has wood", "Craft axe")
+
+ check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
+
+ def test_unroll_get_new_axe_cutting_alternatives(self):
+ draw = False
+ unrolled_graph = unroll_graph(
+ self.get_new_axe.graph, add_prefix=True, cut_looping_alternatives=True
+ )
+ if draw:
+ plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)
+
+ expected_graph = nx.DiGraph(
+ [
+ ("Has wood", "Has axe"),
+ ("Has wood", "Craft new axe"),
+ ("Has wood", "Summon axe out of thin air"),
+ # Expected sub-behavior
+ ("Has axe", "Punch tree"),
+ ("Has axe", "Cut tree with axe"),
+ ]
+ )
+ check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
+
+ @pytest.mark.xfail
+ def test_unroll_root_alternative_reach_forest(self):
+ (
+ reach_forest,
+ _reach_other_zone,
+ _reach_meadow,
+ ) = build_looping_behaviors_without_direct_alternatives()
+ draw = False
+ unrolled_graph = unroll_graph(
+ reach_forest.graph,
+ add_prefix=True,
+ cut_looping_alternatives=True,
+ )
+ if draw:
+ plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True)
+
+ expected_graph = nx.DiGraph(
+ [
+ # ("Root", "Is in other zone ?"),
+ # ("Root", "Is in meadow ?"),
+ ("Is in other zone ?", "Reach other zone"),
+ ("Is in other zone ?", "Go to forest"),
+ ("Is in meadow ?", "Go to forest"),
+ ("Is in meadow ?", "Reach meadow>Is in other zones ?"),
+ ("Reach meadow>Is in other zone ?", "Reach meadow>Reach other zone"),
+ ("Reach meadow>Is in other zone ?", "Reach meadow>Go to forest"),
+ ]
+ )
+ check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph))
diff --git a/tests/unit/layouts/test_metaheuristics.py b/tests/test_metaheuristics.py
similarity index 91%
rename from tests/unit/layouts/test_metaheuristics.py
rename to tests/test_metaheuristics.py
index c9f096e..a2c3517 100644
--- a/tests/unit/layouts/test_metaheuristics.py
+++ b/tests/test_metaheuristics.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Unit tests for the hebg.metaheuristics module."""
diff --git a/tests/unit/test_node.py b/tests/test_node.py
similarity index 97%
rename from tests/unit/test_node.py
rename to tests/test_node.py
index 2f13d77..e034918 100644
--- a/tests/unit/test_node.py
+++ b/tests/test_node.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Unit tests for the hebg.node module."""
diff --git a/tests/integration/test_paper_basic_example.py b/tests/test_paper_basic_example.py
similarity index 99%
rename from tests/integration/test_paper_basic_example.py
rename to tests/test_paper_basic_example.py
index 70ad159..7d70eff 100644
--- a/tests/integration/test_paper_basic_example.py
+++ b/tests/test_paper_basic_example.py
@@ -1,5 +1,5 @@
# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
+# Copyright (C) 2021-2024 Mathïs FEDERICO
"""Integration tests for the initial paper examples."""
diff --git a/tests/integration/test_pet_a_cat.py b/tests/test_pet_a_cat.py
similarity index 98%
rename from tests/integration/test_pet_a_cat.py
rename to tests/test_pet_a_cat.py
index dd936e2..2919336 100644
--- a/tests/integration/test_pet_a_cat.py
+++ b/tests/test_pet_a_cat.py
@@ -27,7 +27,7 @@
from hebg import HEBGraph, Action, FeatureCondition, Behavior
from hebg.unrolling import unroll_graph
-from tests.integration.test_code_generation import _unidiff_output
+from tests.test_code_generation import _unidiff_output
class Pet(Action):
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
deleted file mode 100644
index 8c9c1f0..0000000
--- a/tests/unit/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Unit tests for the heb_graph package."""
diff --git a/tests/unit/layouts/__init__.py b/tests/unit/layouts/__init__.py
deleted file mode 100644
index f0cb5ff..0000000
--- a/tests/unit/layouts/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Unit tests for the hebg.layouts submodules."""
diff --git a/tests/unit/metrics/__init__.py b/tests/unit/metrics/__init__.py
deleted file mode 100644
index 3e53eae..0000000
--- a/tests/unit/metrics/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Unit tests for the hebg.metrics module."""
diff --git a/tests/unit/metrics/complexity/__init__.py b/tests/unit/metrics/complexity/__init__.py
deleted file mode 100644
index d670a73..0000000
--- a/tests/unit/metrics/complexity/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Unit tests for the hebg.metrics.complexity module."""
diff --git a/tests/unit/metrics/complexity/test_complexities.py b/tests/unit/metrics/complexity/test_complexities.py
deleted file mode 100644
index f1fb124..0000000
--- a/tests/unit/metrics/complexity/test_complexities.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Integration tests for the hebg.metrics.complexity.complexities module."""
-
-import pytest
-
-
-class TestComplexities:
- """Complexities"""
-
- @pytest.fixture(autouse=True)
- def setup(self):
- """Initialize variables."""
diff --git a/tests/unit/test_option.py b/tests/unit/test_option.py
deleted file mode 100644
index d6e6f3d..0000000
--- a/tests/unit/test_option.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-
-"""Unit tests for the hebg.behavior module."""
-
-import pytest
-import pytest_check as check
-from pytest_mock import MockerFixture
-
-from hebg.behavior import Behavior
-
-
-class TestBehavior:
- """Behavior"""
-
- @pytest.fixture(autouse=True)
- def setup(self):
- """Initialize variables."""
- self.node = Behavior("behavior_name")
-
- def test_node_type(self):
- """should have 'behavior' as node_type."""
- check.equal(self.node.type, "behavior")
-
- def test_node_call(self, mocker: MockerFixture):
- """should use graph on call."""
- mocker.patch("hebg.behavior.Behavior.graph")
- self.node(None)
- check.is_true(self.node.graph.called)
-
- def test_build_graph(self):
- """should raise NotImplementedError when build_graph is called."""
- with pytest.raises(NotImplementedError):
- self.node.build_graph()
-
- def test_graph(self, mocker: MockerFixture):
- """should build graph and compute its levels if, and only if,
- the graph is not yet built.
- """
- mocker.patch("hebg.behavior.Behavior.build_graph")
- mocker.patch("hebg.behavior.compute_levels")
- self.node.graph
- check.is_true(self.node.build_graph.called)
- check.is_true(self.node.build_graph.called)
-
- mocker.patch("hebg.behavior.Behavior.build_graph")
- mocker.patch("hebg.behavior.compute_levels")
- self.node.graph
- check.is_false(self.node.build_graph.called)
- check.is_false(self.node.build_graph.called)
diff --git a/tests/unit/test_option_graph.py b/tests/unit/test_option_graph.py
deleted file mode 100644
index 7af386c..0000000
--- a/tests/unit/test_option_graph.py
+++ /dev/null
@@ -1,450 +0,0 @@
-# HEBGraph for explainable hierarchical reinforcement learning
-# Copyright (C) 2021-2022 Mathïs FEDERICO
-# pylint: disable=protected-access, unused-argument, missing-function-docstring
-
-"""Unit tests for the hebg.behavior module."""
-
-from copy import deepcopy
-
-import pytest
-import pytest_check as check
-from pytest_mock import MockerFixture
-
-from hebg.heb_graph import HEBGraph, DiGraph, Behavior
-
-
-class TestHEBGraph:
- """HEBGraph"""
-
- @pytest.fixture(autouse=True)
- def setup(self):
- """Initialize variables."""
- self.behavior = Behavior("base_behavior_name")
- self.heb_graph = HEBGraph(self.behavior)
-
- def test_init(self):
- """should instanciate correctly."""
- graph = self.heb_graph
- check.equal(graph.behavior, self.behavior)
- check.equal(graph.all_behaviors, {})
- check.equal(graph.any_mode, "first")
- check.is_true(isinstance(graph, DiGraph))
-
- def test_add_node(self, mocker: MockerFixture):
- """should add Node to the graph correctly."""
- mocker.patch("hebg.heb_graph.DiGraph.add_node")
-
- class DummyNode:
- """DummyNode"""
-
- name = "node_name"
- type = "node_type"
- image = "node_image"
-
- def __str__(self) -> str:
- return self.name
-
- node = DummyNode()
- expected_kwargs = {"type": "node_type", "image": "node_image", "color": None}
- expected_args = (node,)
-
- self.heb_graph.add_node(node)
- args, kwargs = DiGraph.add_node.call_args
- check.equal(kwargs, expected_kwargs)
- check.equal(args, expected_args)
-
- def test_add_edge_only(self, mocker: MockerFixture):
- """should add edges to the graph correctly if nodes already exists."""
- mocker.patch("hebg.heb_graph.DiGraph.add_edge")
-
- class DummyNode:
- """DummyNode"""
-
- type = "node_type"
- image = "node_image"
-
- def __init__(self, i) -> None:
- self.name = f"node_name_{i}"
-
- def __str__(self) -> str:
- return self.name
-
- node_0, node_1 = DummyNode(0), DummyNode(1)
- self.heb_graph.add_node(node_0)
- self.heb_graph.add_node(node_1)
-
- mocker.patch("hebg.heb_graph.DiGraph.add_node")
- expected_kwargs = {"index": 42, "color": "black"}
- expected_args = (node_0, node_1)
-
- self.heb_graph.add_edge(node_0, node_1, index=42)
- args, kwargs = DiGraph.add_edge.call_args
- check.equal(kwargs, expected_kwargs)
- check.equal(args, expected_args)
- check.is_false(DiGraph.add_node.called)
-
- def test_add_edge_and_nodes(self, mocker: MockerFixture):
- """should add edges and nodes to the graph correctly if nodes are not in the graph yet."""
- mocker.patch("hebg.heb_graph.DiGraph.add_edge")
- mocker.patch("hebg.heb_graph.DiGraph.add_node")
-
- class DummyNode:
- """DummyNode"""
-
- type = "node_type"
- image = "node_image"
-
- def __init__(self, i) -> None:
- self.name = f"node_name_{i}"
-
- def __str__(self) -> str:
- return self.name
-
- node_0, node_1 = DummyNode(0), DummyNode(1)
-
- expected_nodes_args = ((node_0,), (node_1,))
- expected_edge_kwargs = {"index": 42, "color": "black"}
- expected_edge_args = (node_0, node_1)
-
- self.heb_graph.add_edge(node_0, node_1, index=42)
- args, kwargs = DiGraph.add_edge.call_args
- check.equal(kwargs, expected_edge_kwargs)
- check.equal(args, expected_edge_args)
-
- for i, (args, _) in enumerate(DiGraph.add_node.call_args_list):
- check.equal(args, expected_nodes_args[i])
-
- def test_call(self, mocker: MockerFixture):
- """should return roots action on call."""
- roots = "roots"
- observation = "obs"
- mocker.patch("hebg.heb_graph.HEBGraph._get_options")
- mocker.patch("hebg.heb_graph.HEBGraph.roots", roots)
- self.heb_graph(observation)
- args, _ = HEBGraph._get_options.call_args
- check.equal(args[0], roots)
- check.equal(args[1], observation)
- check.equal(args[2], [self.behavior.name])
-
- def test_roots(self, mocker: MockerFixture):
- """should have roots as property."""
-
- nodes = ["A", "B", "C", "AA", "AB"]
- predecessors = {"A": [], "B": [], "C": [], "AA": ["A"], "AB": ["A", "B"]}
-
- mocker.patch("hebg.heb_graph.HEBGraph.nodes", lambda self: nodes)
- mocker.patch(
- "hebg.heb_graph.HEBGraph.predecessors",
- lambda self, node: predecessors[node],
- )
-
- check.equal(self.heb_graph.roots, ["A", "B", "C"])
-
-
-class TestHEBGraphGetAnyAction:
- """HEBGraph._get_any_action"""
-
- @pytest.fixture(autouse=True)
- def setup(self):
- """Initialize variables."""
- self.behavior = Behavior("behavior_name")
- self.heb_graph = HEBGraph(self.behavior)
-
- def test_none_in_actions(self, mocker: MockerFixture):
- """should return None if any node returns None."""
-
- actions = [0, "Impossible", 2, None, 3]
-
- _actions = deepcopy(actions)
-
- def mocked_get_action(*args, **kwargs):
- return _actions.pop(0)
-
- mocker.patch("hebg.heb_graph.HEBGraph._get_action", mocked_get_action)
- action = self.heb_graph._get_any_action(range(5), None, None)
- check.is_none(action)
-
- def test_no_actions(self, mocker: MockerFixture):
- """should return 'Impossible' if no action is possible."""
-
- actions = ["Impossible", "Impossible", "Impossible", "Impossible", "Impossible"]
-
- _actions = deepcopy(actions)
-
- def mocked_get_action(*args, **kwargs):
- return _actions.pop(0)
-
- mocker.patch("hebg.heb_graph.HEBGraph._get_action", mocked_get_action)
- action = self.heb_graph._get_any_action(range(5), None, None)
- check.equal(action, "Impossible")
-
- action = self.heb_graph._get_any_action([], None, None)
- check.equal(action, "Impossible")
-
- @pytest.mark.parametrize("any_mode", HEBGraph.ANY_MODES)
- def test_any_mode_(self, any_mode, mocker: MockerFixture):
- actions = [0, "Impossible", 2, 3, "Impossible"]
-
- _actions = deepcopy(actions)
-
- def mocked_get_action(*args, **kwargs):
- return _actions.pop(0)
-
- expected_actions = {"first": (0,), "last": (3,), "random": (0, 2, 3)}
-
- mocker.patch("hebg.heb_graph.HEBGraph._get_action", mocked_get_action)
- self.heb_graph.any_mode = any_mode
- action = self.heb_graph._get_any_action(range(5), None, None)
- check.is_in(action, expected_actions[any_mode])
-
-
-class TestHEBGraphGetAction:
- """HEBGraph._get_action"""
-
- @pytest.fixture(autouse=True)
- def setup(self):
- """Initialize variables."""
- self.behavior = Behavior("behavior_name")
- self.heb_graph = HEBGraph(self.behavior)
-
- def test_action(self):
- """should return action given by an Action node."""
- expected_action = "action_action"
-
- class DummyAction:
- """DummyNode"""
-
- type = "action"
-
- def __call__(self, observation):
- return expected_action
-
- action_node = DummyAction()
- action = self.heb_graph._get_action(action_node, None, None)
- check.equal(action, expected_action)
-
- def test_empty(self, mocker: MockerFixture):
- """should return next successor returns when given an Empty node."""
-
- expected_action = "action_action"
-
- class DummyAction:
- """DummyNode"""
-
- type = "action"
-
- def __call__(self, observation):
- return expected_action
-
- mocker.patch(
- "hebg.heb_graph.HEBGraph.successors",
- lambda self, node: iter([DummyAction()]),
- )
-
- class DummyEmpty:
- """DummyNode"""
-
- type = "empty"
-
- empty_node = DummyEmpty()
- action = self.heb_graph._get_action(empty_node, None, None)
- check.equal(action, expected_action)
-
- def test_unknowed_node_type(self):
- """should raise ValueError if node.type is unknowed."""
-
- class DummyNode:
- """DummyNode"""
-
- type = "random_type_error"
-
- node = DummyNode()
- with pytest.raises(ValueError):
- self.heb_graph._get_action(node, None, None)
-
- def test_behavior_in_search(self):
- """should return 'Impossible' if behavior is already in search to avoid cycles."""
-
- class DummyBehavior:
- """DummyBehavior"""
-
- type = "behavior"
- name = "behavior_already_in_search"
-
- def __str__(self) -> str:
- return self.name
-
- node = DummyBehavior()
- action = self.heb_graph._get_action(node, None, ["behavior_already_in_search"])
- check.equal(action, "Impossible")
-
- def test_behavior_call(self):
- """should return behavior's return if behavior can be called."""
- expected_action = "behavior_action"
-
- class DummyBehavior:
- """DummyBehavior"""
-
- type = "behavior"
- name = "behavior_name"
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args) -> str:
- return expected_action
-
- node = DummyBehavior()
- action = self.heb_graph._get_action(node, None, [])
- check.equal(action, expected_action)
-
- def test_behavior_by_all_behaviors(self):
- """should use all_behaviors if behavior cannot be called."""
- expected_action = "behavior_action"
-
- class DummyTrueBehavior:
- """DummyBehavior"""
-
- type = "behavior"
- name = "behavior_name"
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args):
- return expected_action
-
- class DummyBehaviorInGraph:
- """DummyBehavior"""
-
- type = "behavior"
- name = "behavior_name"
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args):
- raise NotImplementedError
-
- true_node = DummyTrueBehavior()
- self.heb_graph.all_behaviors = {true_node.name: true_node}
-
- node_in_graph = DummyBehaviorInGraph()
- action = self.heb_graph._get_action(
- node_in_graph,
- observation=None,
- behaviors_in_search=[],
- )
- check.equal(action, expected_action)
-
- def test_feature_condition(self, mocker: MockerFixture):
- """should use FeatureCondition's given index to orient in graph."""
-
- class DummyAction:
- """DummyAction"""
-
- type = "action"
- name = "action_name"
-
- def __init__(self, i) -> None:
- self.index = i
- self.action = f"action_{i}"
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args) -> str:
- return self.action
-
- class DummyFeatureCondition:
- """DummyFeatureCondition"""
-
- type = "feature_condition"
- name = "feature_condition_name"
-
- def __init__(self, i) -> None:
- self.fc_index = i
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args) -> int:
- return self.fc_index
-
- actions = [DummyAction(i) for i in range(3)]
- mocker.patch(
- "hebg.heb_graph.HEBGraph.successors",
- lambda self, node: actions,
- )
-
- class DummyEdges:
- """DummyEdges"""
-
- def __getitem__(self, e):
- return {"index": e[1].index}
-
- mocker.patch("hebg.heb_graph.HEBGraph.edges", DummyEdges())
- mocker.patch(
- "hebg.heb_graph.HEBGraph._get_any_action",
- lambda self, next_nodes, observation, behaviors_in_search: next_nodes[0](
- observation
- ),
- )
-
- for fc_index in range(3):
- node = DummyFeatureCondition(fc_index)
- action = self.heb_graph._get_action(node, None, [])
- expected_action = f"action_{fc_index}"
- check.equal(action, expected_action)
-
- def test_feature_condition_index_value_error(self, mocker: MockerFixture):
- """should raise ValueError if FeatureCondition's given index represents no successor."""
-
- class DummyAction:
- """DummyAction"""
-
- type = "action"
- name = "action_name"
-
- def __init__(self, i) -> None:
- self.index = i
- self.action = f"action_{i}"
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args) -> str:
- return self.action
-
- class DummyFeatureCondition:
- """DummyFeatureCondition"""
-
- type = "feature_condition"
- name = "feature_condition_name"
-
- def __init__(self, i) -> None:
- self.fc_index = i
-
- def __str__(self) -> str:
- return self.name
-
- def __call__(self, *args) -> int:
- return self.fc_index
-
- actions = [DummyAction(i) for i in range(3)]
- mocker.patch(
- "hebg.heb_graph.HEBGraph.successors",
- lambda self, node: actions,
- )
-
- class DummyEdges:
- """DummyEdges"""
-
- def __getitem__(self, e):
- return {"index": e[1].index}
-
- mocker.patch("hebg.heb_graph.HEBGraph.edges", DummyEdges())
-
- node = DummyFeatureCondition(4)
- with pytest.raises(ValueError):
- self.heb_graph._get_action(node, None, [])