diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 6a0203f..d0fecef 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -8,7 +8,7 @@ jobs: runs-on: windows-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10'] + python-version: ['3.10', '3.11', '3.12'] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df0cbe5..d986d22 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,8 +15,8 @@ repos: hooks: - id: pytest-fast-check name: pytest-fast-check - entry: ./venv/Scripts/python.exe -m pytest -m "not slow" - stages: ["commit"] + entry: pytest -m "not slow" + stages: ["pre-commit"] language: system pass_filenames: false always_run: true @@ -24,8 +24,8 @@ repos: hooks: - id: pytest-check name: pytest-check - entry: ./venv/Scripts/python.exe -m pytest - stages: ["push"] + entry: pytest + stages: ["pre-push"] language: system pass_filenames: false always_run: true diff --git a/README.rst b/README.rst index e4f03dc..42bc25c 100644 --- a/README.rst +++ b/README.rst @@ -90,7 +90,7 @@ Here is an example to show how could we hierarchicaly build an explanable behavi def __init__(self, hand) -> None: super().__init__(name="Is hand near the cat ?") self.hand = hand - def __call__(self, observation): + def __call__(self, observation) -> int: # Could be a very complex function that returns 1 is the hand is near the cat else 0. if observation["cat"] == observation[self.hand]: return int(True) # 1 @@ -119,7 +119,7 @@ Here is an example to show how could we hierarchicaly build an explanable behavi class IsThereACatAround(FeatureCondition): def __init__(self) -> None: super().__init__(name="Is there a cat around ?") - def __call__(self, observation): + def __call__(self, observation) -> int: # Could be a very complex function that returns 1 is there is a cat around else 0. 
if "cat" in observation: return int(True) # 1 @@ -217,7 +217,7 @@ Will generate the code bellow: # Require 'Look for a nearby cat' behavior to be given. # Require 'Move slowly your hand near the cat' behavior to be given. class PetTheCat(GeneratedBehavior): - def __call__(self, observation): + def __call__(self, observation) -> Any: edge_index = self.feature_conditions['Is there a cat around ?'](observation) if edge_index == 0: return self.known_behaviors['Look for a nearby cat'](observation) diff --git a/commands/coverage.ps1 b/commands/coverage.ps1 new file mode 100644 index 0000000..a578912 --- /dev/null +++ b/commands/coverage.ps1 @@ -0,0 +1 @@ +pytest --cov=src --cov-report=html --cov-report=term diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..b024da8 --- /dev/null +++ b/setup.py @@ -0,0 +1,4 @@ +from setuptools import setup + + +setup() diff --git a/src/hebg/__init__.py b/src/hebg/__init__.py index 0a5dc99..8fbff98 100644 --- a/src/hebg/__init__.py +++ b/src/hebg/__init__.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """A structure for explainable hierarchical reinforcement learning""" diff --git a/src/hebg/behavior.py b/src/hebg/behavior.py index fc68d8b..aa804b4 100644 --- a/src/hebg/behavior.py +++ b/src/hebg/behavior.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Module for base Behavior.""" @@ -17,11 +17,11 @@ class Behavior(Node): """Abstract class for a Behavior as Node""" - def __init__(self, name: str, image=None) -> None: - super().__init__(name, "behavior", image=image) + def __init__(self, name: str, image=None, **kwargs) -> None: + super().__init__(name, "behavior", image=image, **kwargs) self._graph = None - def __call__(self, observation, *args, **kwargs): + def __call__(self, 
observation, *args, **kwargs) -> None: """Use the behavior to get next actions. By default, uses the HEBGraph if it can be built. diff --git a/src/hebg/call_graph.py b/src/hebg/call_graph.py new file mode 100644 index 0000000..5c9e261 --- /dev/null +++ b/src/hebg/call_graph.py @@ -0,0 +1,315 @@ +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + NamedTuple, + Optional, + Tuple, + TypeVar, + Union, +) +from matplotlib import pyplot as plt +from matplotlib.axes import Axes + +from networkx import ( + DiGraph, + all_simple_paths, + draw_networkx_edges, + draw_networkx_labels, + draw_networkx_nodes, + ancestors, +) +import numpy as np +from hebg.behavior import Behavior +from hebg.graph import get_successors_with_index +from hebg.node import Action, FeatureCondition, Node + +if TYPE_CHECKING: + from hebg.heb_graph import HEBGraph + +EnvAction = TypeVar("EnvAction") + + +class CallGraph(DiGraph): + def __init__(self, **attr) -> None: + super().__init__(incoming_graph_data=None, **attr) + self.graph["n_branches"] = 0 + self.graph["n_calls"] = 0 + self.graph["frontiere"] = [] + self._known_fc: Dict[FeatureCondition, Any] = {} + self._current_node = CallNode(0, 0) + + def add_root(self, heb_node: "Node", heb_graph: "HEBGraph", **kwargs) -> None: + self.add_node( + self._current_node, heb_node=heb_node, heb_graph=heb_graph, **kwargs + ) + + def call_nodes( + self, nodes: List["Node"], observation, heb_graph: "HEBGraph" + ) -> EnvAction: + self._extend_frontiere(nodes, heb_graph) + action = None + + while len(self.graph["frontiere"]) > 0 and action is None: + next_call_node = self._pop_from_frontiere() + if next_call_node is None: + break + + node: "Node" = self.nodes[next_call_node]["heb_node"] + heb_graph: "HEBGraph" = self.nodes[next_call_node]["heb_graph"] + + if node.type == "behavior": + # Search for name reference in all_behaviors + if node.name in heb_graph.all_behaviors: + node = heb_graph.all_behaviors[node.name] + if not 
hasattr(node, "_graph") or node._graph is None: + action = node(observation) + break + self._extend_frontiere(node.graph.roots, heb_graph=node.graph) + elif node.type == "action": + action = node(observation) + break + elif node.type == "feature_condition": + if node in self._known_fc: + next_edge_index = self._known_fc[node] + else: + next_edge_index = int(node(observation)) + self._known_fc[node] = next_edge_index + next_nodes = get_successors_with_index(heb_graph, node, next_edge_index) + self._extend_frontiere(next_nodes, heb_graph) + elif node.type == "empty": + self._extend_frontiere(list(heb_graph.successors(node)), heb_graph) + else: + raise ValueError( + f"Unknowed value {node.type} for node.type with node: {node}." + ) + + if action is None: + raise ValueError("No valid frontiere left in call_graph") + + return action + + def call_edge_labels(self) -> list[tuple]: + return [ + (self.nodes[u]["label"], self.nodes[v]["label"]) for u, v in self.edges() + ] + + def add_node( + self, node_for_adding, heb_node: "Node", heb_graph: "HEBGraph", **attr + ): + super().add_node( + node_for_adding, + heb_graph=heb_graph, + heb_node=heb_node, + label=heb_node.name, + **attr, + ) + + def add_edge( + self, + u_of_edge, + v_of_edge, + status: "CallEdgeStatus", + **attr, + ): + return super().add_edge(u_of_edge, v_of_edge, status=status.value, **attr) + + def _make_new_branch(self) -> int: + self.graph["n_branches"] += 1 + return self.graph["n_branches"] + + def _extend_frontiere(self, nodes: List["Node"], heb_graph: "HEBGraph") -> None: + frontiere: List[CallNode] = self.graph["frontiere"] + + parent = self._current_node + call_nodes = [] + + for i, node in enumerate(nodes): + if i > 0: + branch_id = self._make_new_branch() + else: + branch_id = parent.branch + call_node = CallNode(branch_id, parent.rank + 1) + + if node.name in heb_graph.all_behaviors: + node = heb_graph.all_behaviors[node.name] + self.add_node(call_node, heb_node=node, heb_graph=heb_graph) + 
self.add_edge(parent, call_node, CallEdgeStatus.UNEXPLORED) + call_nodes.append(call_node) + + frontiere.extend(call_nodes) + + def _heb_node_from_call_node(self, node: "CallNode") -> "Node": + return self.nodes[node]["heb_node"] + + def _pop_from_frontiere(self) -> Optional["CallNode"]: + frontiere: List["CallNode"] = self.graph["frontiere"] + + next_node = None + + while next_node is None: + if not frontiere: + return None + + next_call_node = frontiere.pop( + np.argmin( + [ + self._heb_node_from_call_node(node).complexity + for node in frontiere + ] + ) + ) + maybe_next_node = self._heb_node_from_call_node(next_call_node) + # Nodes should only have one parent + parent = list(self.predecessors(next_call_node))[0] + + if isinstance(maybe_next_node, Behavior) and maybe_next_node in [ + self._heb_node_from_call_node(node) + for node in ancestors(self, next_call_node) + ]: + self._update_edge_status(parent, next_call_node, CallEdgeStatus.FAILURE) + continue + + next_node = maybe_next_node + + self.graph["n_calls"] += 1 + self.nodes[next_call_node]["call_rank"] = 1 + self._update_edge_status(parent, next_call_node, CallEdgeStatus.CALLED) + self._current_node = next_call_node + return next_call_node + + def _update_edge_status( + self, start: "Node", end: "Node", status: Union["CallEdgeStatus", str] + ): + status = CallEdgeStatus(status) + self.edges[start, end]["status"] = status.value + + def draw( + self, + ax: Optional[Axes] = None, + pos: Optional[Dict[str, Tuple[float, float]]] = None, + nodes_kwargs: Optional[dict] = None, + label_kwargs: Optional[dict] = None, + edges_kwargs: Optional[dict] = None, + ): + if pos is None: + pos = _call_graph_pos(self) + if nodes_kwargs is None: + nodes_kwargs = {} + + if ax is None: + ax = plt.gca() + + pos_arr = np.array(list(pos.values())) + max_x, max_y = pos_arr.max(axis=0) + min_x, min_y = pos_arr.min(axis=0) + y_range = max_y - min_y + x_range = max_x - min_x + ax.set_ylim([min_y - 0.1 * y_range, max_y + 0.1 * y_range]) + 
ax.set_xlim([min_x - 0.1 * x_range, max_x + 0.1 * x_range]) + + nodes_complexity = np.array( + [node_data["heb_node"].complexity for _, node_data in self.nodes(data=True)] + ) + complexity_range = nodes_complexity.max() - nodes_complexity.min() + + nodes_complexity_scaled = ( + 50 + 600 * (nodes_complexity - nodes_complexity.min()) / complexity_range + ) + + draw_networkx_nodes( + self, + node_color=[ + _node_color(node_data["heb_node"]) + for _, node_data in self.nodes(data=True) + ], + node_size=nodes_complexity_scaled, + ax=ax, + pos=pos, + **nodes_kwargs, + ) + if label_kwargs is None: + label_kwargs = {} + draw_networkx_labels( + self, + labels={ + node: f"{node_data['label']}" + for node, node_data in self.nodes(data=True) + }, + ax=ax, + horizontalalignment="center", + verticalalignment="center", + font_size=8, + pos=pos, + **nodes_kwargs, + ) + if edges_kwargs is None: + edges_kwargs = {} + if "connectionstyle" not in edges_kwargs: + edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=5") + draw_networkx_edges( + self, + ax=ax, + pos=pos, + arrowstyle="-", + alpha=0.5, + width=3, + node_size=1, + edge_color=[ + _call_status_to_color(status) + for _, _, status in self.edges(data="status") + ], + **edges_kwargs, + ) + + +class CallNode(NamedTuple): + branch: int + rank: int + + +class CallEdgeStatus(Enum): + UNEXPLORED = "unexplored" + CALLED = "called" + FAILURE = "failure" + + +def _node_color(node: Union[Action, FeatureCondition, Behavior]) -> str: + if isinstance(node, Action): + return "red" + if isinstance(node, FeatureCondition): + return "blue" + if isinstance(node, Behavior): + return "orange" + raise NotImplementedError + + +def _call_status_to_color(status: Union[str, "CallEdgeStatus"]) -> str: + status = CallEdgeStatus(status) + if status is CallEdgeStatus.UNEXPLORED: + return "black" + if status is CallEdgeStatus.CALLED: + return "green" + if status is CallEdgeStatus.FAILURE: + return "red" + raise NotImplementedError + + +def 
_call_graph_pos(call_graph: "CallGraph") -> Dict[str, Tuple[float, float]]: + pos = {} + + roots = [n for (n, d) in call_graph.in_degree if d == 0] + leafs = [n for (n, d) in call_graph.out_degree if d == 0] + + branches = all_simple_paths(call_graph, roots[0], leafs) + branches = sorted(branches, key=lambda x: -len(x)) + + for branch_id, nodes_in_branch in enumerate(branches): + for node in nodes_in_branch: + if node in pos: + continue + rank = node.rank + pos[node] = [branch_id, -rank] + return pos diff --git a/src/hebg/codegen.py b/src/hebg/codegen.py index b58987d..fec9e51 100644 --- a/src/hebg/codegen.py +++ b/src/hebg/codegen.py @@ -1,3 +1,6 @@ +# HEBGraph for explainable hierarchical reinforcement learning +# Copyright (C) 2021-2024 Mathïs FEDERICO + """Module for code generation from HEBGraph.""" from re import sub diff --git a/src/hebg/draw.py b/src/hebg/draw.py index cd4d6ad..80efd7d 100644 --- a/src/hebg/draw.py +++ b/src/hebg/draw.py @@ -1,3 +1,6 @@ +# HEBGraph for explainable hierarchical reinforcement learning +# Copyright (C) 2021-2024 Mathïs FEDERICO + import math from typing import TYPE_CHECKING, Dict, Optional, Tuple @@ -7,7 +10,7 @@ from matplotlib.axes import Axes from matplotlib.legend import Legend from matplotlib.legend_handler import HandlerPatch -from networkx import draw_networkx_edges +from networkx import draw_networkx_edges, spring_layout from scipy.spatial import ConvexHull # pylint: disable=no-name-in-module from hebg.graph import draw_networkx_nodes_images @@ -34,7 +37,10 @@ def draw_hebgraph( plt.setp(ax.spines.values(), color="orange") if pos is None: - pos = staircase_layout(graph) + if len(graph.roots) > 0: + pos = staircase_layout(graph) + else: + pos = spring_layout(graph) draw_networkx_nodes_images(graph, pos, ax=ax, img_zoom=0.5) draw_networkx_edges( diff --git a/src/hebg/graph.py b/src/hebg/graph.py index 21e3693..4269bf4 100644 --- a/src/hebg/graph.py +++ b/src/hebg/graph.py @@ -1,5 +1,5 @@ # HEBGraph for explainable 
hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO # pylint: disable=protected-access """Additional utility functions for networkx graphs.""" diff --git a/src/hebg/heb_graph.py b/src/hebg/heb_graph.py index e077f69..cbb42a4 100644 --- a/src/hebg/heb_graph.py +++ b/src/hebg/heb_graph.py @@ -1,28 +1,25 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO # pylint: disable=arguments-differ """Module containing the HEBGraph base class.""" from __future__ import annotations -from typing import Any, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Dict, List, Optional, Tuple -import numpy as np from matplotlib.axes import Axes from networkx import DiGraph from hebg.behavior import Behavior +from hebg.call_graph import CallGraph from hebg.codegen import get_hebg_source from hebg.draw import draw_hebgraph -from hebg.graph import get_roots, get_successors_with_index +from hebg.graph import get_roots from hebg.node import Node from hebg.unrolling import unroll_graph -Action = TypeVar("Action") - - class HEBGraph(DiGraph): """Base class for Hierchical Explanation of Behavior as Graphs. @@ -45,7 +42,6 @@ class HEBGraph(DiGraph): behavior: The Behavior object from which this graph is built. all_behaviors: A dictionary of behavior, this can be used to avoid cirular definitions using the behavior names as anchor instead of the behavior object itself. - any_mode: How to choose path, when multiple path are valid. incoming_graph_data: Additional data to include in the graph. 
""" @@ -60,24 +56,19 @@ class HEBGraph(DiGraph): 5: "cyan", 6: "gray", } - ANY_MODES = ("first", "last", "random") def __init__( self, behavior: Behavior, all_behaviors: Dict[str, Behavior] = None, incoming_graph_data=None, - any_mode: str = "first", **attr, ): self.behavior = behavior self.all_behaviors = all_behaviors if all_behaviors is not None else {} self._unrolled_graph = None - self.last_call_behaviors_stack = None - - assert any_mode in self.ANY_MODES, f"Unknowed any_mode: {any_mode}" - self.any_mode = any_mode + self.call_graph: Optional[CallGraph] = None super().__init__(incoming_graph_data=incoming_graph_data, **attr) @@ -106,62 +97,6 @@ def add_edge(self, u_of_edge: Node, v_of_edge: Node, index: int = 1, **attr): color = "black" super().add_edge(u_of_edge, v_of_edge, index=index, color=color, **attr) - def _get_options( - self, - nodes: List[Node], - observation, - behaviors_in_search: list, - last_call_behaviors_stack: Optional[list] = None, - parent_name: Optional[str] = None, - ) -> List[Action]: - actions = [] - for node in nodes: - node_action = self._get_action( - node, - observation, - behaviors_in_search, - last_call_behaviors_stack=last_call_behaviors_stack, - ) - if node_action is None: - return None - actions.append(node_action) - - options = remove_duplicate_actions( - [action for action in actions if action != "Impossible"] - ) - - if parent_name is None: - parent_name = self.behavior.name - if ( - (len(nodes) > 1 or self.behavior.name) - and options - and last_call_behaviors_stack is not None - ): - last_call_behaviors_stack.insert( - 0, (self.behavior.name, [n.name for n in self.roots], options) - ) - - return options - - def _choose_action(self, actions: Optional[List[Action]]) -> Action: - if actions is None: - return None - if len(actions) == 0: - return "Impossible" - if self.any_mode == "first" or len(actions) == 1: - return actions[0] - if self.any_mode == "last": - return actions[-1] - if self.any_mode == "random": - return 
np.random.choice(actions) - - def _get_any_action( - self, nodes: List[Node], observation, behaviors_in_search: list - ): - return self._choose_action( - self._get_options(nodes, observation, behaviors_in_search) - ) - @property def unrolled_graph(self) -> HEBGraph: """Access to the unrolled behavior graph. @@ -179,69 +114,16 @@ def unrolled_graph(self) -> HEBGraph: self._unrolled_graph = unroll_graph(self) return self._unrolled_graph - def _get_action( - self, - node: Node, - observation: Any, - behaviors_in_search: List[str], - last_call_behaviors_stack: Optional[list] = None, - ): - # Behavior - if node.type == "behavior": - # To avoid cycling definitions - if node.name in behaviors_in_search: - return "Impossible" - - # Search for name reference in all_behaviors - if node.name in self.all_behaviors: - node = self.all_behaviors[node.name] - - return node(observation, behaviors_in_search, last_call_behaviors_stack) - - # Action - if node.type == "action": - return node(observation) - # Feature Condition - if node.type == "feature_condition": - next_edge_index = int(node(observation)) - next_nodes = get_successors_with_index(self, node, next_edge_index) - options = self._get_options( - next_nodes, - observation, - behaviors_in_search, - last_call_behaviors_stack=last_call_behaviors_stack, - parent_name=node.name, - ) - return self._choose_action(options) - # Empty - if node.type == "empty": - next_node = self.successors(node).__next__() - return self._get_action( - next_node, - observation, - behaviors_in_search, - last_call_behaviors_stack=last_call_behaviors_stack, - ) - raise ValueError(f"Unknowed value {node.type} for node.type with node: {node}.") - def __call__( self, observation, - behaviors_in_search: Optional[List[str]] = None, - last_call_behaviors_stack: Optional[list] = None, + call_graph: Optional[CallGraph] = None, ) -> Any: - if behaviors_in_search is None: - behaviors_in_search = [] - last_call_behaviors_stack = [] - 
behaviors_in_search.append(self.behavior.name) - options = self._get_options( - self.roots, - observation, - behaviors_in_search, - last_call_behaviors_stack=last_call_behaviors_stack, - ) - self.last_call_behaviors_stack = last_call_behaviors_stack - return self._choose_action(options) + if call_graph is None: + call_graph = CallGraph() + call_graph.add_root(heb_node=self.behavior, heb_graph=self) + self.call_graph = call_graph + return self.call_graph.call_nodes(self.roots, observation, heb_graph=self) @property def roots(self) -> List[Node]: @@ -268,9 +150,3 @@ def draw( """ return draw_hebgraph(self, ax, **kwargs) - - -def remove_duplicate_actions(actions: List[Action]) -> List[Action]: - seen = set() - seen_add = seen.add - return [a for a in actions if not (a in seen or seen_add(a))] diff --git a/src/hebg/layouts/__init__.py b/src/hebg/layouts/__init__.py index 7bf586e..915138d 100644 --- a/src/hebg/layouts/__init__.py +++ b/src/hebg/layouts/__init__.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Module containing layouts to draw graphs.""" diff --git a/src/hebg/layouts/deterministic.py b/src/hebg/layouts/deterministic.py index d7ff6db..477f81b 100644 --- a/src/hebg/layouts/deterministic.py +++ b/src/hebg/layouts/deterministic.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO # pylint: disable=protected-access """Deterministic layouts""" diff --git a/src/hebg/layouts/metabased.py b/src/hebg/layouts/metabased.py index b6060fd..1843db4 100644 --- a/src/hebg/layouts/metabased.py +++ b/src/hebg/layouts/metabased.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO # pylint: disable=protected-access 
"""Metaheuristics based layouts""" diff --git a/src/hebg/layouts/metaheuristics.py b/src/hebg/layouts/metaheuristics.py index f1282bf..3ee3292 100644 --- a/src/hebg/layouts/metaheuristics.py +++ b/src/hebg/layouts/metaheuristics.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Metaheuristics used for building layouts""" diff --git a/src/hebg/metrics/__init__.py b/src/hebg/metrics/__init__.py index c86285c..7ce4b9c 100644 --- a/src/hebg/metrics/__init__.py +++ b/src/hebg/metrics/__init__.py @@ -1,4 +1,4 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Module containing HEBGraph metrics.""" diff --git a/src/hebg/metrics/complexity/__init__.py b/src/hebg/metrics/complexity/__init__.py index caf67d4..dc0ae8d 100644 --- a/src/hebg/metrics/complexity/__init__.py +++ b/src/hebg/metrics/complexity/__init__.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Module for complexity computation methods.""" diff --git a/src/hebg/metrics/complexity/complexities.py b/src/hebg/metrics/complexity/complexities.py index a00da8d..992307b 100644 --- a/src/hebg/metrics/complexity/complexities.py +++ b/src/hebg/metrics/complexity/complexities.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """General complexity.""" diff --git a/src/hebg/metrics/complexity/utils.py b/src/hebg/metrics/complexity/utils.py index cad8903..fe4021e 100644 --- a/src/hebg/metrics/complexity/utils.py +++ b/src/hebg/metrics/complexity/utils.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs 
FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Utility functions for complexity computation.""" diff --git a/src/hebg/metrics/histograms.py b/src/hebg/metrics/histograms.py index 6ba45f9..24ef27b 100644 --- a/src/hebg/metrics/histograms.py +++ b/src/hebg/metrics/histograms.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """HEBGraph used nodes histograms computation.""" @@ -308,9 +308,9 @@ def _get_node_histogram_complexity( if behaviors_in_search is not None and str(node) in behaviors_in_search: return {}, np.inf if node.type in ("action", "feature_condition", "behavior"): - try: + if node.complexity is not None: node_complexity = node.complexity - except AttributeError: + else: node_complexity = default_node_complexity return {node: 1}, node_complexity if node.type == "empty": diff --git a/src/hebg/metrics/utility/__init__.py b/src/hebg/metrics/utility/__init__.py index 98a3e1a..12adad7 100644 --- a/src/hebg/metrics/utility/__init__.py +++ b/src/hebg/metrics/utility/__init__.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Module for utility computation methods.""" diff --git a/src/hebg/metrics/utility/binary_utility.py b/src/hebg/metrics/utility/binary_utility.py index 91e9833..a533c50 100644 --- a/src/hebg/metrics/utility/binary_utility.py +++ b/src/hebg/metrics/utility/binary_utility.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Simplest binary utility for HEBGraph.""" diff --git a/src/hebg/node.py b/src/hebg/node.py index 2779bae..a398b04 100644 --- a/src/hebg/node.py +++ b/src/hebg/node.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 
2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Module for base Node classes.""" @@ -45,11 +45,7 @@ def __init__( f"not in authorised node_types ({self.NODE_TYPES})." ) self.type = node_type - if complexity is not None: - self.complexity = complexity - else: - self.complexity = bytecode_complexity(self.__init__) - self.complexity += bytecode_complexity(self.__call__) + self.complexity = complexity def __call__(self, observation: Any) -> Any: raise NotImplementedError diff --git a/src/hebg/requirements_graph.py b/src/hebg/requirements_graph.py index 1520939..73557fb 100644 --- a/src/hebg/requirements_graph.py +++ b/src/hebg/requirements_graph.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO # pylint: disable=arguments-differ """Module for building underlying requirement graphs based on a set of behaviors.""" diff --git a/src/hebg/unrolling.py b/src/hebg/unrolling.py index 0eb0250..7874443 100644 --- a/src/hebg/unrolling.py +++ b/src/hebg/unrolling.py @@ -1,3 +1,6 @@ +# HEBGraph for explainable hierarchical reinforcement learning +# Copyright (C) 2021-2024 Mathïs FEDERICO + """Module to unroll HEBGraph. Unrolling means expanding each sub-behavior node as it's own graph in the global HEBGraph. @@ -20,7 +23,7 @@ def unroll_graph( graph: "HEBGraph", - add_prefix=False, + add_prefix: bool = False, cut_looping_alternatives: bool = False, ) -> "HEBGraph": """Build the the unrolled HEBGraph. 
@@ -48,7 +51,7 @@ def unroll_graph( def _unroll_graph( graph: "HEBGraph", - add_prefix=False, + add_prefix: bool = False, cut_looping_alternatives: bool = False, _current_alternatives: Optional[List[Union["Action", "Behavior"]]] = None, _unrolled_behaviors: Optional[Dict[str, Optional["HEBGraph"]]] = None, @@ -56,7 +59,7 @@ def _unroll_graph( if _unrolled_behaviors is None: _unrolled_behaviors = {} if _current_alternatives is None: - _current_alternatives = [] + _current_alternatives = {0: []} is_looping = False _unrolled_behaviors[graph.behavior.name] = None @@ -65,14 +68,9 @@ def _unroll_graph( for node in list(unrolled_graph.nodes()): if not isinstance(node, Behavior): continue - new_alternatives = [] - for pred, _node, data in graph.in_edges(node, data=True): - index = data["index"] - for _pred, alternative, alt_index in graph.out_edges(pred, data="index"): - if index == alt_index and alternative != node: - new_alternatives.append(alternative) - if new_alternatives: - _current_alternatives = new_alternatives + + _current_alternatives[0] = _direct_alternatives(node, graph) + _current_alternatives[1] = _roots_alternatives(node, graph) unrolled_graph, behavior_is_looping = _unroll_behavior( unrolled_graph, node, @@ -87,6 +85,25 @@ def _unroll_graph( return unrolled_graph, is_looping +def _direct_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]: + alternatives = [] + for pred, _node, data in graph.in_edges(node, data=True): + index = data["index"] + for _pred, alternative, alt_index in graph.out_edges(pred, data="index"): + if index != alt_index or alternative == node: + continue + alternatives.append(alternative) + return alternatives + + +def _roots_alternatives(node: "Node", graph: "HEBGraph") -> list["Node"]: + alternatives = [] + for pred, _node, _data in graph.in_edges(node, data=True): + if pred in graph.roots: + alternatives.extend([r for r in graph.roots if r != pred]) + return alternatives + + def _unroll_behavior( graph: "HEBGraph", 
behavior: "Behavior", @@ -119,13 +136,24 @@ def _unroll_behavior( ) if is_looping and cut_looping_alternatives: - if not _current_alternatives: - return graph, is_looping - for alternative in _current_alternatives: + for alternative in _current_alternatives[0]: for last_condition, _, data in graph.in_edges(behavior, data=True): graph.add_edge(last_condition, alternative, **data) - graph.remove_node(behavior) - return graph, False + if _current_alternatives[0]: + graph.remove_node(behavior) + return graph, False + if _current_alternatives[1]: + predecessors = list(graph.predecessors(behavior)) + for last_condition in predecessors: + successors = list(graph.successors(last_condition)) + for descendant in successors: + graph.remove_edge(last_condition, descendant) + if graph.neighbors(descendant) == 0: + graph.remove_node(descendant) + graph.remove_node(last_condition) + graph.remove_node(behavior) + return graph, False + raise NotImplementedError() if node_graph is None: # If we cannot get the node's graph, we keep it as is. @@ -151,7 +179,7 @@ def _unrolled_behavior_graph( cut_looping_alternatives: bool, _current_alternatives: List[Union["Action", "Behavior"]], _unrolled_behaviors: Dict[str, Optional["HEBGraph"]], -) -> Optional["HEBGraph"]: +) -> Tuple[Optional["HEBGraph"], bool]: """Get the unrolled sub-graph of a behavior. 
Args: @@ -188,7 +216,7 @@ def _add_prefix_to_graph(graph: "HEBGraph", prefix: str) -> None: if prefix is None: return graph - def rename(node: "Node"): + def rename(node: "Node") -> None: new_node = copy(node) new_node.name = prefix + node.name return new_node @@ -216,14 +244,14 @@ def group_behaviors_points( for i in range(len(groups[:-1])): key = tuple(groups[: -1 - i]) point = pos[node] - try: + if key in points_grouped_by_behavior: points_grouped_by_behavior[key].append(point) - except KeyError: + else: points_grouped_by_behavior[key] = [point] return points_grouped_by_behavior -def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph"): +def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph") -> None: """Returns a new_graph of graph_of_reference composed with other_graph. Composition is the simple union of the node sets and edge sets. @@ -237,9 +265,7 @@ def compose_heb_graphs(graph_of_reference: "HEBGraph", other_graph: "HEBGraph"): """ new_graph = graph_of_reference.__class__( - graph_of_reference.behavior, - all_behaviors=graph_of_reference.all_behaviors, - any_mode=graph_of_reference.any_mode, + graph_of_reference.behavior, all_behaviors=graph_of_reference.all_behaviors ) # add graph attributes, H attributes take precedent over G attributes new_graph.graph.update(graph_of_reference.graph) diff --git a/tests/__init__.py b/tests/__init__.py index a4983ea..66b76f8 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,22 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Tests for the heb_graph package.""" + +from typing import Protocol +from matplotlib import pyplot as plt + + +class Graph(Protocol): + def draw(self, ax, pos): + """Draw the graph on a matplotlib axes.""" + + def nodes(self) -> list: + """Return a list of nodes""" + + +def plot_graph(graph: Graph, **kwargs): + _, ax = plt.subplots() + 
graph.draw(ax, **kwargs) + plt.axis("off") # turn off axis + plt.show() diff --git a/tests/examples/behaviors/__init__.py b/tests/examples/behaviors/__init__.py index e0a7bc3..d339dd3 100644 --- a/tests/examples/behaviors/__init__.py +++ b/tests/examples/behaviors/__init__.py @@ -1,4 +1,32 @@ -from tests.examples.behaviors.basic import * -from tests.examples.behaviors.basic_empty import * - +from tests.examples.behaviors.basic import ( + FundamentalBehavior, + F_A_Behavior, + F_AA_Behavior, + F_F_A_Behavior, +) +from tests.examples.behaviors.basic_empty import ( + E_A_Behavior, + E_F_A_Behavior, + F_E_A_Behavior, + E_E_A_Behavior, +) from tests.examples.behaviors.binary_sum import build_binary_sum_behavior +from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors +from tests.examples.behaviors.loop_without_alternative import ( + build_looping_behaviors_without_direct_alternatives, +) + + +__all__ = [ + "FundamentalBehavior", + "F_A_Behavior", + "F_AA_Behavior", + "F_F_A_Behavior", + "E_A_Behavior", + "E_F_A_Behavior", + "F_E_A_Behavior", + "E_E_A_Behavior", + "build_binary_sum_behavior", + "build_looping_behaviors", + "build_looping_behaviors_without_direct_alternatives", +] diff --git a/tests/examples/behaviors/basic.py b/tests/examples/behaviors/basic.py index 8d1c88b..9bad5fb 100644 --- a/tests/examples/behaviors/basic.py +++ b/tests/examples/behaviors/basic.py @@ -67,6 +67,9 @@ def build_graph(self) -> HEBGraph: class F_F_A_Behavior(Behavior): """Double layer feature conditions behavior""" + def __init__(self, name: str = "F_F_A", *args, **kwargs) -> None: + super().__init__(name=name, *args, **kwargs) + def build_graph(self) -> HEBGraph: graph = HEBGraph(self) @@ -89,12 +92,11 @@ def build_graph(self) -> HEBGraph: class F_AA_Behavior(Behavior): """Feature condition with mutliple actions on same index.""" - def __init__(self, name: str, any_mode: str) -> None: + def __init__(self, name: str = "F_AA") -> None: super().__init__(name, 
image=None) - self.any_mode = any_mode def build_graph(self) -> HEBGraph: - graph = HEBGraph(self, any_mode=self.any_mode) + graph = HEBGraph(self) feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) graph.add_edge(feature_condition, Action(0), index=int(True)) diff --git a/tests/examples/behaviors/loop_with_alternative.py b/tests/examples/behaviors/loop_with_alternative.py index 4fec202..dfd462c 100644 --- a/tests/examples/behaviors/loop_with_alternative.py +++ b/tests/examples/behaviors/loop_with_alternative.py @@ -1,8 +1,17 @@ -from typing import List +from typing import Any, List from hebg import HEBGraph, Action, FeatureCondition, Behavior +class HasItem(FeatureCondition): + def __init__(self, item_name: str) -> None: + self.item_name = item_name + super().__init__(name=f"Has {item_name} ?", complexity=1.0) + + def __call__(self, observation: Any) -> int: + return self.item_name in observation + + class GatherWood(Behavior): """Gather wood""" @@ -12,10 +21,10 @@ def __init__(self) -> None: def build_graph(self) -> HEBGraph: graph = HEBGraph(self) - feature = FeatureCondition("Has an axe") - graph.add_edge(feature, Action("Punch tree"), index=False) - graph.add_edge(feature, Behavior("Get new axe"), index=False) - graph.add_edge(feature, Action("Use axe on tree"), index=True) + has_axe = HasItem("axe") + graph.add_edge(has_axe, Action("Punch tree", complexity=2.0), index=False) + graph.add_edge(has_axe, Behavior("Get new axe", complexity=1.0), index=False) + graph.add_edge(has_axe, Action("Use axe on tree", complexity=1.0), index=True) return graph @@ -28,9 +37,12 @@ def __init__(self) -> None: def build_graph(self) -> HEBGraph: graph = HEBGraph(self) - feature = FeatureCondition("Has wood") - graph.add_edge(feature, Behavior("Gather wood"), index=False) - graph.add_edge(feature, Action("Craft axe"), index=True) + has_wood = HasItem("wood") + graph.add_edge(has_wood, Behavior("Gather wood", complexity=1.0), index=False) + graph.add_edge( + 
has_wood, Action("Summon axe out of thin air", complexity=10.0), index=False + ) + graph.add_edge(has_wood, Action("Craft axe", complexity=1.0), index=True) return graph @@ -39,4 +51,5 @@ def build_looping_behaviors() -> List[Behavior]: all_behaviors = {behavior.name: behavior for behavior in behaviors} for behavior in behaviors: behavior.graph.all_behaviors = all_behaviors + behavior.complexity = 5 return behaviors diff --git a/tests/examples/behaviors/loop_without_alternative.py b/tests/examples/behaviors/loop_without_alternative.py index 31e0d0d..94ee273 100644 --- a/tests/examples/behaviors/loop_without_alternative.py +++ b/tests/examples/behaviors/loop_without_alternative.py @@ -57,7 +57,7 @@ def build_graph(self) -> HEBGraph: return graph -def build_looping_behaviors() -> List[Behavior]: +def build_looping_behaviors_without_direct_alternatives() -> List[Behavior]: behaviors: List[Behavior] = [ ReachForest(), ReachOtherZone(), diff --git a/tests/examples/feature_conditions/scalar.py b/tests/examples/feature_conditions.py similarity index 95% rename from tests/examples/feature_conditions/scalar.py rename to tests/examples/feature_conditions.py index 75da674..4d38cdc 100644 --- a/tests/examples/feature_conditions/scalar.py +++ b/tests/examples/feature_conditions.py @@ -12,7 +12,7 @@ class Relation(Enum): LESSER_THAN = "<" def __init__( - self, relation: Union[Relation, str] = ">=", threshold: float = 0 + self, relation: Union[Relation, str] = ">=", threshold: float = 0, **kwargs ) -> None: """Threshold-based feature condition for scalar feature.""" self.relation = relation @@ -20,7 +20,7 @@ def __init__( self._relation = self.Relation(relation) display_name = self._relation.name.capitalize().replace("_", " ") name = f"{display_name} {threshold} ?" 
- super().__init__(name=name, image=None) + super().__init__(name=name, **kwargs) def __call__(self, observation: float) -> int: conditions = { diff --git a/tests/examples/feature_conditions/__init__.py b/tests/examples/feature_conditions/__init__.py deleted file mode 100644 index 255c52e..0000000 --- a/tests/examples/feature_conditions/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from tests.examples.feature_conditions.scalar import ( - ThresholdFeatureCondition, - IsDivisibleFeatureCondition, -) - - -__all__ = ["ThresholdFeatureCondition", "IsDivisibleFeatureCondition"] diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py deleted file mode 100644 index 489eafd..0000000 --- a/tests/integration/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Integration tests for the heb_graph package.""" diff --git a/tests/integration/test_behavior.py b/tests/integration/test_behavior.py deleted file mode 100644 index d76f24d..0000000 --- a/tests/integration/test_behavior.py +++ /dev/null @@ -1,73 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Behavior of HEBGraphs when called.""" - -import pytest_check as check - -from hebg.node import Action -from tests.examples.behaviors import ( - FundamentalBehavior, - AA_Behavior, - F_A_Behavior, - F_F_A_Behavior, - AF_A_Behavior, - F_AA_Behavior, -) -from tests.examples.feature_conditions import ThresholdFeatureCondition - - -def test_a_graph(): - """(A) Fundamental behaviors (single action) should work properly.""" - action_id = 42 - behavior = FundamentalBehavior(Action(action_id)) - check.equal(behavior(None), action_id) - - -def test_f_a_graph(): - """(F-A) Feature condition should orient path properly.""" - feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) - actions = {0: Action(0), 1: Action(1)} - behavior = 
F_A_Behavior("F_A", feature_condition, actions) - check.equal(behavior(1), 1) - check.equal(behavior(-1), 0) - - -def test_f_f_a_graph(): - """(F-F-A) Feature condition should orient path properly in double chain.""" - behavior = F_F_A_Behavior("F_F_A") - check.equal(behavior(-2), 0) - check.equal(behavior(-1), 1) - check.equal(behavior(1), 2) - check.equal(behavior(2), 3) - - -def test_aa_graph(): - """(AA) Should choose between roots depending on 'any_mode'.""" - behavior = AA_Behavior("AA", any_mode="first") - check.equal(behavior(None), 0) - - behavior = AA_Behavior("AA", any_mode="last") - check.equal(behavior(None), 1) - - -def test_af_a_graph(): - """(AF-A) Should choose between roots depending on 'any_mode'.""" - behavior = AF_A_Behavior("AF_A", any_mode="first") - check.equal(behavior(1), 0) - check.equal(behavior(-1), 0) - - behavior = AF_A_Behavior("AF_A", any_mode="last") - check.equal(behavior(1), 1) - check.equal(behavior(-1), 2) - - -def test_f_af_a_graph(): - """(F-AA) Should choose between condition edges depending on 'any_mode'.""" - behavior = F_AA_Behavior("F_AA", any_mode="first") - check.equal(behavior(1), 0) - check.equal(behavior(-1), 1) - - behavior = F_AA_Behavior("F_AA", any_mode="last") - check.equal(behavior(1), 0) - check.equal(behavior(-1), 2) diff --git a/tests/integration/test_loop_with_alternative.py b/tests/integration/test_loop_with_alternative.py deleted file mode 100644 index 641d6b2..0000000 --- a/tests/integration/test_loop_with_alternative.py +++ /dev/null @@ -1,93 +0,0 @@ -import pytest -import pytest_check as check - -import networkx as nx -from hebg import HEBGraph -from hebg.unrolling import unroll_graph - -from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors - -import matplotlib.pyplot as plt - - -class TestLoop: - """Tests for the loop example""" - - @pytest.fixture(autouse=True) - def setup_method(self): - self.gather_wood, self.get_new_axe = build_looping_behaviors() - - def 
test_unroll_gather_wood(self): - draw = False - unrolled_graph = unroll_graph(self.gather_wood.graph) - if draw: - _plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() - expected_graph.add_edge("Has axe", "Punch tree") - expected_graph.add_edge("Has axe", "Cut tree with axe") - expected_graph.add_edge("Has axe", "Has wood") - - # Expected sub-behavior - expected_graph.add_edge("Has wood", "Gather wood") - expected_graph.add_edge("Has wood", "Craft axe") - check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) - - def test_unroll_get_new_axe(self): - draw = False - unrolled_graph = unroll_graph(self.get_new_axe.graph) - if draw: - _plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() - expected_graph.add_edge("Has wood", "Has axe") - expected_graph.add_edge("Has wood", "Craft new axe") - - # Expected sub-behavior - expected_graph.add_edge("Has axe", "Punch tree") - expected_graph.add_edge("Has axe", "Cut tree with axe") - expected_graph.add_edge("Has axe", "Get new axe") - check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) - - def test_unroll_gather_wood_cutting_alternatives(self): - draw = False - unrolled_graph = unroll_graph( - self.gather_wood.graph, - cut_looping_alternatives=True, - ) - if draw: - _plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() - expected_graph.add_edge("Has axe", "Punch tree") - expected_graph.add_edge("Has axe", "Cut tree with axe") - expected_graph.add_edge("Has axe", "Has wood") - - # Expected sub-behavior - expected_graph.add_edge("Has wood", "Punch tree") - expected_graph.add_edge("Has wood", "Craft axe") - check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) - - def test_unroll_get_new_axe_cutting_alternatives(self): - draw = False - unrolled_graph = unroll_graph( - self.get_new_axe.graph, - cut_looping_alternatives=True, - ) - if draw: - _plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() - expected_graph.add_edge("Has wood", "Has axe") - 
expected_graph.add_edge("Has wood", "Craft new axe") - - # Expected sub-behavior - expected_graph.add_edge("Has axe", "Punch tree") - expected_graph.add_edge("Has axe", "Cut tree with axe") - check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) - - -def _plot_graph(graph: "HEBGraph"): - _, ax = plt.subplots() - graph.draw(ax) - plt.show() diff --git a/tests/integration/test_loop_without_alternative.py b/tests/integration/test_loop_without_alternative.py deleted file mode 100644 index a6e7490..0000000 --- a/tests/integration/test_loop_without_alternative.py +++ /dev/null @@ -1,45 +0,0 @@ -import pytest -import pytest_check as check - -import networkx as nx -from hebg import HEBGraph -from hebg.unrolling import unroll_graph - -from tests.examples.behaviors.loop_without_alternative import build_looping_behaviors - -import matplotlib.pyplot as plt - - -class TestLoop: - """Tests for the loop example""" - - @pytest.fixture(autouse=True) - def setup_method(self): - ( - self.reach_forest, - self.reach_other_zone, - self.reach_meadow, - ) = build_looping_behaviors() - - @pytest.mark.xfail - def test_unroll_reach_forest(self): - draw = False - unrolled_graph = unroll_graph( - self.reach_forest.graph, - add_prefix=True, - cut_looping_alternatives=True, - ) - if draw: - _plot_graph(unrolled_graph) - - expected_graph = nx.DiGraph() - check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) - - -def _plot_graph(graph: "HEBGraph"): - _, ax = plt.subplots() - pos = None - if len(graph.roots) == 0: - pos = nx.spring_layout(graph) - graph.draw(ax, pos=pos) - plt.show() diff --git a/tests/test_behavior.py b/tests/test_behavior.py new file mode 100644 index 0000000..b5de95e --- /dev/null +++ b/tests/test_behavior.py @@ -0,0 +1,138 @@ +# HEBGraph for explainable hierarchical reinforcement learning +# Copyright (C) 2021-2024 Mathïs FEDERICO + +"""Behavior of HEBGraphs when called.""" + +import pytest +import pytest_check as check +from pytest_mock import MockerFixture 
+ +from hebg.behavior import Behavior +from hebg.heb_graph import HEBGraph +from hebg.node import Action + +from tests.examples.behaviors import FundamentalBehavior, F_A_Behavior, F_F_A_Behavior +from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors +from tests.examples.feature_conditions import ThresholdFeatureCondition + + +class TestBehavior: + + """Behavior""" + + @pytest.fixture(autouse=True) + def setup(self): + """Initialize variables.""" + self.node = Behavior("behavior_name") + + def test_node_type(self): + """should have 'behavior' as node_type.""" + check.equal(self.node.type, "behavior") + + def test_node_call(self, mocker: MockerFixture): + """should use graph on call.""" + mocker.patch("hebg.behavior.Behavior.graph") + self.node(None) + check.is_true(self.node.graph.called) + + def test_build_graph(self): + """should raise NotImplementedError when build_graph is called.""" + with pytest.raises(NotImplementedError): + self.node.build_graph() + + def test_graph(self, mocker: MockerFixture): + """should build graph and compute its levels if, and only if, + the graph is not yet built. 
+ """ + mocker.patch("hebg.behavior.Behavior.build_graph") + mocker.patch("hebg.behavior.compute_levels") + self.node.graph + check.is_true(self.node.build_graph.called) + check.is_true(self.node.build_graph.called) + + mocker.patch("hebg.behavior.Behavior.build_graph") + mocker.patch("hebg.behavior.compute_levels") + self.node.graph + check.is_false(self.node.build_graph.called) + check.is_false(self.node.build_graph.called) + + +class TestPathfinding: + def test_fundamental_behavior(self): + """Fundamental behavior (single action) should return its action.""" + action_id = 42 + behavior = FundamentalBehavior(Action(action_id)) + check.equal(behavior(None), action_id) + + def test_feature_condition_single(self): + """Feature condition should orient path properly.""" + feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) + actions = {0: Action(0), 1: Action(1)} + behavior = F_A_Behavior("F_A", feature_condition, actions) + check.equal(behavior(1), 1) + check.equal(behavior(-1), 0) + + def test_feature_conditions_chained(self): + """Feature condition should orient path properly in double chain.""" + behavior = F_F_A_Behavior("F_F_A") + check.equal(behavior(-2), 0) + check.equal(behavior(-1), 1) + check.equal(behavior(1), 2) + check.equal(behavior(2), 3) + + def test_looping_resolve(self): + """Loops with alternatives should be ignored.""" + _gather_wood, get_axe = build_looping_behaviors() + check.equal(get_axe({}), "Punch tree") + + +class TestCostBehavior: + def test_choose_root_of_lesser_cost(self): + """Should choose root of lesser cost.""" + + expected_action = "EXPECTED" + + class AAA_Behavior(Behavior): + def __init__(self) -> None: + super().__init__("AAA") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_node(Action(0, complexity=2)) + graph.add_node(Action(expected_action, complexity=1)) + graph.add_node(Action(2, complexity=3)) + return graph + + behavior = AAA_Behavior() + check.equal(behavior(None), 
expected_action) + + def test_not_path_of_least_cost(self): + """Should choose path of larger complexity if individual costs lead to it.""" + + class AF_A_Behavior(Behavior): + + """Double root with feature condition and action""" + + def __init__(self) -> None: + super().__init__("AF_A") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + + graph.add_node(Action(0, complexity=1.5)) + feature_condition = ThresholdFeatureCondition( + relation=">=", threshold=0, complexity=1.0 + ) + + graph.add_edge( + feature_condition, Action(1, complexity=1.0), index=int(True) + ) + graph.add_edge( + feature_condition, Action(2, complexity=1.0), index=int(False) + ) + + return graph + + behavior = AF_A_Behavior() + check.equal(behavior(1), 1) + check.equal(behavior(-1), 2) diff --git a/tests/integration/test_behavior_empty.py b/tests/test_behavior_empty.py similarity index 94% rename from tests/integration/test_behavior_empty.py rename to tests/test_behavior_empty.py index 7efd053..6c06b19 100644 --- a/tests/integration/test_behavior_empty.py +++ b/tests/test_behavior_empty.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Behavior of HEBGraphs with empty nodes.""" diff --git a/tests/test_call_graph.py b/tests/test_call_graph.py new file mode 100644 index 0000000..4e297ed --- /dev/null +++ b/tests/test_call_graph.py @@ -0,0 +1,284 @@ +from networkx import DiGraph +import pytest + +from hebg.behavior import Behavior +from hebg.call_graph import CallEdgeStatus, CallGraph, CallNode, _call_graph_pos +from hebg.heb_graph import HEBGraph +from hebg.node import Action, FeatureCondition + +from pytest_mock import MockerFixture +import pytest_check as check + +from tests import plot_graph + +from tests.examples.behaviors import F_F_A_Behavior +from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors +from 
tests.examples.feature_conditions import ThresholdFeatureCondition + + +class TestCall: + """Ensure that the call graph is faithful for debugging and efficient breadth first search.""" + + def test_call_stack_without_branches(self) -> None: + """When there is no branches, the graph should be a simple sequence of the call stack.""" + f_f_a_behavior = F_F_A_Behavior() + + draw = False + if draw: + plot_graph(f_f_a_behavior.graph.unrolled_graph) + f_f_a_behavior(observation=-2) + + expected_graph = DiGraph( + [ + ("F_F_A", "Greater or equal to 0 ?"), + ("Greater or equal to 0 ?", "Greater or equal to -1 ?"), + ("Greater or equal to -1 ?", "Action(0)"), + ] + ) + + call_graph = f_f_a_behavior.graph.call_graph + assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) + + def test_split_on_same_fc_index(self, mocker: MockerFixture) -> None: + """When there are multiple indexes on the same feature condition, + a branch should be created.""" + + expected_action = Action("EXPECTED", complexity=1) + + forbidden_value = "FORBIDDEN" + forbidden_action = Action(forbidden_value, complexity=2) + forbidden_action.__call__ = mocker.MagicMock(return_value=forbidden_value) + + class F_AA_Behavior(Behavior): + """Feature condition with mutliple actions on same index.""" + + def __init__(self) -> None: + super().__init__("F_AA") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + feature_condition = ThresholdFeatureCondition( + relation=">=", threshold=0 + ) + graph.add_edge( + feature_condition, Action(0, complexity=0), index=int(True) + ) + graph.add_edge(feature_condition, forbidden_action, index=int(False)) + graph.add_edge(feature_condition, expected_action, index=int(False)) + + return graph + + f_aa_behavior = F_AA_Behavior() + draw = False + if draw: + plot_graph(f_aa_behavior.graph.unrolled_graph) + + # Sanity check that the right action should be called and not the forbidden one. 
+ assert f_aa_behavior(observation=-1) == expected_action.action + forbidden_action.__call__.assert_not_called() + + # Graph should have the good split + call_graph = f_aa_behavior.graph.call_graph + expected_graph = DiGraph( + [ + ("F_AA", "Greater or equal to 0 ?"), + ("Greater or equal to 0 ?", "Action(EXPECTED)"), + ("Greater or equal to 0 ?", "Action(FORBIDDEN)"), + ] + ) + assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) + + @pytest.mark.xfail + def test_multiple_call_to_same_fc(self, mocker: MockerFixture) -> None: + """Call graph should allow for the same feature condition + to be called multiple times in the same branch (in different behaviors).""" + expected_action = Action("EXPECTED") + unexpected_action = Action("UNEXPECTED") + + feature_condition = ThresholdFeatureCondition(relation=">=", threshold=0) + + class SubBehavior(Behavior): + def __init__(self) -> None: + super().__init__("SubBehavior") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_edge(feature_condition, expected_action, index=int(True)) + graph.add_edge(feature_condition, unexpected_action, index=int(False)) + return graph + + class RootBehavior(Behavior): + """Feature condition with mutliple actions on same index.""" + + def __init__(self) -> None: + super().__init__("RootBehavior") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_edge(feature_condition, SubBehavior(), index=int(True)) + graph.add_edge(feature_condition, unexpected_action, index=int(False)) + + return graph + + root_behavior = RootBehavior() + draw = False + if draw: + plot_graph(root_behavior.graph.unrolled_graph) + + # Sanity check that the right action should be called and not the forbidden one. 
+ assert root_behavior(observation=2) == expected_action.action + + # Graph should have the good split + call_graph = root_behavior.graph.call_graph + expected_graph = DiGraph( + [ + ("RootBehavior", "Greater or equal to 0 ?"), + ("Greater or equal to 0 ?", "SubBehavior"), + ("SubBehavior", "Greater or equal to 0 ?"), + ("Greater or equal to 0 ?", "Action(EXPECTED)"), + ] + ) + assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) + + expected_labels = { + CallNode(0, 0): "RootBehavior", + CallNode(0, 1): "Greater or equal to 0 ?", + CallNode(0, 2): "SubBehavior", + CallNode(0, 3): "Greater or equal to 0 ?", + CallNode(0, 4): "Action(EXPECTED)", + } + for node, label in call_graph.nodes(data="label"): + check.equal(label, expected_labels[node]) + + def test_chain_behaviors(self, mocker: MockerFixture) -> None: + """When sub-behaviors with a graph are called recursively, + the call graph should still find their nodes.""" + + expected_action = "EXPECTED" + + class DummyBehavior(Behavior): + __call__ = mocker.MagicMock(return_value=expected_action) + + class SubBehavior(Behavior): + def __init__(self) -> None: + super().__init__("SubBehavior") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_node(DummyBehavior("Dummy")) + return graph + + class RootBehavior(Behavior): + def __init__(self) -> None: + super().__init__("RootBehavior") + + def build_graph(self) -> HEBGraph: + graph = HEBGraph(self) + graph.add_node(Behavior("SubBehavior")) + return graph + + sub_behavior = SubBehavior() + sub_behavior.graph + + root_behavior = RootBehavior() + root_behavior.graph.all_behaviors["SubBehavior"] = sub_behavior + + # Sanity check that the right action should be called. 
+ assert root_behavior(observation=-1) == expected_action + + call_graph = root_behavior.graph.call_graph + expected_graph = DiGraph( + [ + ("RootBehavior", "SubBehavior"), + ("SubBehavior", "Dummy"), + ] + ) + assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) + + def test_looping_goback(self) -> None: + """Loops with alternatives should be ignored.""" + draw = False + _gather_wood, get_axe = build_looping_behaviors() + assert get_axe({}) == "Punch tree" + + call_graph = get_axe.graph.call_graph + + if draw: + plot_graph(call_graph) + + expected_labels = { + CallNode(0, 0): "Get new axe", + CallNode(0, 1): "Has wood ?", + CallNode(1, 2): "Action(Summon axe out of thin air)", + CallNode(0, 2): "Gather wood", + CallNode(0, 3): "Has axe ?", + CallNode(2, 4): "Get new axe", + CallNode(0, 4): "Action(Punch tree)", + } + for node, label in call_graph.nodes(data="label"): + check.equal(label, expected_labels[node]) + + expected_graph = DiGraph( + [ + ("Get new axe", "Has wood ?"), + ("Has wood ?", "Action(Summon axe out of thin air)"), + ("Has wood ?", "Gather wood"), + ("Gather wood", "Has axe ?"), + ("Has axe ?", "Get new axe"), + ("Has axe ?", "Action(Punch tree)"), + ] + ) + + assert set(call_graph.call_edge_labels()) == set(expected_graph.edges()) + + +class TestDraw: + """Ensures that the graph is readable even in complex situations.""" + + def test_result_on_first_branch(self) -> None: + """Resulting action should always be on the first branch.""" + draw = False + root_behavior = Behavior("Root", complexity=20) + call_graph = CallGraph() + call_graph.add_root(root_behavior, None) + + nodes = [ + (CallNode(0, 1), FeatureCondition("FC1", complexity=1)), + (CallNode(0, 2), FeatureCondition("FC2", complexity=1)), + (CallNode(0, 3), root_behavior), + (CallNode(1, 1), FeatureCondition("FC3", complexity=1)), + (CallNode(1, 2), FeatureCondition("FC4", complexity=1)), + (CallNode(1, 3), FeatureCondition("FC5", complexity=1)), + (CallNode(1, 4), 
Action("A", complexity=1)), + ] + + for node, heb_node in nodes: + call_graph.add_node(node, heb_node, None) + + edges = [ + (CallNode(0, 0), CallNode(0, 1), CallEdgeStatus.CALLED), + (CallNode(0, 1), CallNode(0, 2), CallEdgeStatus.CALLED), + (CallNode(0, 2), CallNode(0, 3), CallEdgeStatus.FAILURE), + (CallNode(0, 0), CallNode(1, 1), CallEdgeStatus.CALLED), + (CallNode(1, 1), CallNode(1, 2), CallEdgeStatus.CALLED), + (CallNode(1, 2), CallNode(1, 3), CallEdgeStatus.CALLED), + (CallNode(1, 3), CallNode(1, 4), CallEdgeStatus.CALLED), + ] + + for start, end, status in edges: + call_graph.add_edge(start, end, status) + + expected_poses = { + CallNode(0, 0): [0, 0], + CallNode(1, 1): [0, -1], + CallNode(1, 2): [0, -2], + CallNode(1, 3): [0, -3], + CallNode(1, 4): [0, -4], + CallNode(0, 1): [1, -1], + CallNode(0, 2): [1, -2], + CallNode(0, 3): [1, -3], + } + if draw: + plot_graph(call_graph) + + assert _call_graph_pos(call_graph) == expected_poses diff --git a/tests/integration/test_code_generation.py b/tests/test_code_generation.py similarity index 100% rename from tests/integration/test_code_generation.py rename to tests/test_code_generation.py diff --git a/tests/test_loop.py b/tests/test_loop.py new file mode 100644 index 0000000..5ab85db --- /dev/null +++ b/tests/test_loop.py @@ -0,0 +1,122 @@ +import pytest +import pytest_check as check + +import networkx as nx +from hebg.unrolling import unroll_graph +from tests import plot_graph + +from tests.examples.behaviors.loop_with_alternative import build_looping_behaviors +from tests.examples.behaviors.loop_without_alternative import ( + build_looping_behaviors_without_direct_alternatives, +) + + +class TestLoopAlternative: + """Tests for the loop with alternative example""" + + @pytest.fixture(autouse=True) + def setup_method(self): + self.gather_wood, self.get_new_axe = build_looping_behaviors() + + def test_unroll_gather_wood(self): + draw = False + unrolled_graph = unroll_graph(self.gather_wood.graph, add_prefix=True) + 
if draw: + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph() + expected_graph.add_edge("Has axe", "Punch tree") + expected_graph.add_edge("Has axe", "Cut tree with axe") + expected_graph.add_edge("Has axe", "Has wood") + + # Expected sub-behavior + expected_graph.add_edge("Has wood", "Gather wood") + expected_graph.add_edge("Has wood", "Craft axe") + expected_graph.add_edge("Has wood", "Summon axe out of thin air") + check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) + + def test_unroll_get_new_axe(self): + draw = False + unrolled_graph = unroll_graph(self.get_new_axe.graph, add_prefix=True) + if draw: + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph() + expected_graph.add_edge("Has wood", "Has axe") + expected_graph.add_edge("Has wood", "Craft new axe") + expected_graph.add_edge("Has wood", "Summon axe out of thin air") + + # Expected sub-behavior + expected_graph.add_edge("Has axe", "Punch tree") + expected_graph.add_edge("Has axe", "Cut tree with axe") + expected_graph.add_edge("Has axe", "Get new axe") + check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) + + def test_unroll_gather_wood_cutting_alternatives(self): + draw = False + unrolled_graph = unroll_graph( + self.gather_wood.graph, add_prefix=True, cut_looping_alternatives=True + ) + if draw: + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph() + expected_graph.add_edge("Has axe", "Punch tree") + expected_graph.add_edge("Has axe", "Has wood") + expected_graph.add_edge("Has axe", "Use axe") + + # Expected sub-behavior + expected_graph.add_edge("Has wood", "Summon axe of out thin air") + expected_graph.add_edge("Has wood", "Craft axe") + + check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) + + def test_unroll_get_new_axe_cutting_alternatives(self): + draw = False + unrolled_graph = unroll_graph( + self.get_new_axe.graph, 
add_prefix=True, cut_looping_alternatives=True + ) + if draw: + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph( + [ + ("Has wood", "Has axe"), + ("Has wood", "Craft new axe"), + ("Has wood", "Summon axe out of thin air"), + # Expected sub-behavior + ("Has axe", "Punch tree"), + ("Has axe", "Cut tree with axe"), + ] + ) + check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) + + @pytest.mark.xfail + def test_unroll_root_alternative_reach_forest(self): + ( + reach_forest, + _reach_other_zone, + _reach_meadow, + ) = build_looping_behaviors_without_direct_alternatives() + draw = False + unrolled_graph = unroll_graph( + reach_forest.graph, + add_prefix=True, + cut_looping_alternatives=True, + ) + if draw: + plot_graph(unrolled_graph, draw_hulls=True, show_all_hulls=True) + + expected_graph = nx.DiGraph( + [ + # ("Root", "Is in other zone ?"), + # ("Root", "Is in meadow ?"), + ("Is in other zone ?", "Reach other zone"), + ("Is in other zone ?", "Go to forest"), + ("Is in meadow ?", "Go to forest"), + ("Is in meadow ?", "Reach meadow>Is in other zone ?"), + ("Reach meadow>Is in other zone ?", "Reach meadow>Reach other zone"), + ("Reach meadow>Is in other zone ?", "Reach meadow>Go to forest"), + ] + ) + check.is_true(nx.is_isomorphic(unrolled_graph, expected_graph)) diff --git a/tests/unit/layouts/test_metaheuristics.py b/tests/test_metaheuristics.py similarity index 91% rename from tests/unit/layouts/test_metaheuristics.py rename to tests/test_metaheuristics.py index c9f096e..a2c3517 100644 --- a/tests/unit/layouts/test_metaheuristics.py +++ b/tests/test_metaheuristics.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Unit tests for the hebg.metaheuristics module.""" diff --git a/tests/unit/test_node.py b/tests/test_node.py similarity index 97% rename from tests/unit/test_node.py rename to
tests/test_node.py index 2f13d77..e034918 100644 --- a/tests/unit/test_node.py +++ b/tests/test_node.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Unit tests for the hebg.node module.""" diff --git a/tests/integration/test_paper_basic_example.py b/tests/test_paper_basic_example.py similarity index 99% rename from tests/integration/test_paper_basic_example.py rename to tests/test_paper_basic_example.py index 70ad159..7d70eff 100644 --- a/tests/integration/test_paper_basic_example.py +++ b/tests/test_paper_basic_example.py @@ -1,5 +1,5 @@ # HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO +# Copyright (C) 2021-2024 Mathïs FEDERICO """Integration tests for the initial paper examples.""" diff --git a/tests/integration/test_pet_a_cat.py b/tests/test_pet_a_cat.py similarity index 98% rename from tests/integration/test_pet_a_cat.py rename to tests/test_pet_a_cat.py index dd936e2..2919336 100644 --- a/tests/integration/test_pet_a_cat.py +++ b/tests/test_pet_a_cat.py @@ -27,7 +27,7 @@ from hebg import HEBGraph, Action, FeatureCondition, Behavior from hebg.unrolling import unroll_graph -from tests.integration.test_code_generation import _unidiff_output +from tests.test_code_generation import _unidiff_output class Pet(Action): diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py deleted file mode 100644 index 8c9c1f0..0000000 --- a/tests/unit/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Unit tests for the heb_graph package.""" diff --git a/tests/unit/layouts/__init__.py b/tests/unit/layouts/__init__.py deleted file mode 100644 index f0cb5ff..0000000 --- a/tests/unit/layouts/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# 
Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Unit tests for the hebg.layouts submodules.""" diff --git a/tests/unit/metrics/__init__.py b/tests/unit/metrics/__init__.py deleted file mode 100644 index 3e53eae..0000000 --- a/tests/unit/metrics/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Unit tests for the hebg.metrics module.""" diff --git a/tests/unit/metrics/complexity/__init__.py b/tests/unit/metrics/complexity/__init__.py deleted file mode 100644 index d670a73..0000000 --- a/tests/unit/metrics/complexity/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Unit tests for the hebg.metrics.complexity module.""" diff --git a/tests/unit/metrics/complexity/test_complexities.py b/tests/unit/metrics/complexity/test_complexities.py deleted file mode 100644 index f1fb124..0000000 --- a/tests/unit/metrics/complexity/test_complexities.py +++ /dev/null @@ -1,14 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Integration tests for the hebg.metrics.complexity.complexities module.""" - -import pytest - - -class TestComplexities: - """Complexities""" - - @pytest.fixture(autouse=True) - def setup(self): - """Initialize variables.""" diff --git a/tests/unit/test_option.py b/tests/unit/test_option.py deleted file mode 100644 index d6e6f3d..0000000 --- a/tests/unit/test_option.py +++ /dev/null @@ -1,50 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO - -"""Unit tests for the hebg.behavior module.""" - -import pytest -import pytest_check as check -from pytest_mock import MockerFixture - -from hebg.behavior import Behavior - - -class TestBehavior: - """Behavior""" - - @pytest.fixture(autouse=True) - def setup(self): - """Initialize 
variables.""" - self.node = Behavior("behavior_name") - - def test_node_type(self): - """should have 'behavior' as node_type.""" - check.equal(self.node.type, "behavior") - - def test_node_call(self, mocker: MockerFixture): - """should use graph on call.""" - mocker.patch("hebg.behavior.Behavior.graph") - self.node(None) - check.is_true(self.node.graph.called) - - def test_build_graph(self): - """should raise NotImplementedError when build_graph is called.""" - with pytest.raises(NotImplementedError): - self.node.build_graph() - - def test_graph(self, mocker: MockerFixture): - """should build graph and compute its levels if, and only if, - the graph is not yet built. - """ - mocker.patch("hebg.behavior.Behavior.build_graph") - mocker.patch("hebg.behavior.compute_levels") - self.node.graph - check.is_true(self.node.build_graph.called) - check.is_true(self.node.build_graph.called) - - mocker.patch("hebg.behavior.Behavior.build_graph") - mocker.patch("hebg.behavior.compute_levels") - self.node.graph - check.is_false(self.node.build_graph.called) - check.is_false(self.node.build_graph.called) diff --git a/tests/unit/test_option_graph.py b/tests/unit/test_option_graph.py deleted file mode 100644 index 7af386c..0000000 --- a/tests/unit/test_option_graph.py +++ /dev/null @@ -1,450 +0,0 @@ -# HEBGraph for explainable hierarchical reinforcement learning -# Copyright (C) 2021-2022 Mathïs FEDERICO -# pylint: disable=protected-access, unused-argument, missing-function-docstring - -"""Unit tests for the hebg.behavior module.""" - -from copy import deepcopy - -import pytest -import pytest_check as check -from pytest_mock import MockerFixture - -from hebg.heb_graph import HEBGraph, DiGraph, Behavior - - -class TestHEBGraph: - """HEBGraph""" - - @pytest.fixture(autouse=True) - def setup(self): - """Initialize variables.""" - self.behavior = Behavior("base_behavior_name") - self.heb_graph = HEBGraph(self.behavior) - - def test_init(self): - """should instanciate correctly.""" - 
graph = self.heb_graph - check.equal(graph.behavior, self.behavior) - check.equal(graph.all_behaviors, {}) - check.equal(graph.any_mode, "first") - check.is_true(isinstance(graph, DiGraph)) - - def test_add_node(self, mocker: MockerFixture): - """should add Node to the graph correctly.""" - mocker.patch("hebg.heb_graph.DiGraph.add_node") - - class DummyNode: - """DummyNode""" - - name = "node_name" - type = "node_type" - image = "node_image" - - def __str__(self) -> str: - return self.name - - node = DummyNode() - expected_kwargs = {"type": "node_type", "image": "node_image", "color": None} - expected_args = (node,) - - self.heb_graph.add_node(node) - args, kwargs = DiGraph.add_node.call_args - check.equal(kwargs, expected_kwargs) - check.equal(args, expected_args) - - def test_add_edge_only(self, mocker: MockerFixture): - """should add edges to the graph correctly if nodes already exists.""" - mocker.patch("hebg.heb_graph.DiGraph.add_edge") - - class DummyNode: - """DummyNode""" - - type = "node_type" - image = "node_image" - - def __init__(self, i) -> None: - self.name = f"node_name_{i}" - - def __str__(self) -> str: - return self.name - - node_0, node_1 = DummyNode(0), DummyNode(1) - self.heb_graph.add_node(node_0) - self.heb_graph.add_node(node_1) - - mocker.patch("hebg.heb_graph.DiGraph.add_node") - expected_kwargs = {"index": 42, "color": "black"} - expected_args = (node_0, node_1) - - self.heb_graph.add_edge(node_0, node_1, index=42) - args, kwargs = DiGraph.add_edge.call_args - check.equal(kwargs, expected_kwargs) - check.equal(args, expected_args) - check.is_false(DiGraph.add_node.called) - - def test_add_edge_and_nodes(self, mocker: MockerFixture): - """should add edges and nodes to the graph correctly if nodes are not in the graph yet.""" - mocker.patch("hebg.heb_graph.DiGraph.add_edge") - mocker.patch("hebg.heb_graph.DiGraph.add_node") - - class DummyNode: - """DummyNode""" - - type = "node_type" - image = "node_image" - - def __init__(self, i) -> None: 
- self.name = f"node_name_{i}" - - def __str__(self) -> str: - return self.name - - node_0, node_1 = DummyNode(0), DummyNode(1) - - expected_nodes_args = ((node_0,), (node_1,)) - expected_edge_kwargs = {"index": 42, "color": "black"} - expected_edge_args = (node_0, node_1) - - self.heb_graph.add_edge(node_0, node_1, index=42) - args, kwargs = DiGraph.add_edge.call_args - check.equal(kwargs, expected_edge_kwargs) - check.equal(args, expected_edge_args) - - for i, (args, _) in enumerate(DiGraph.add_node.call_args_list): - check.equal(args, expected_nodes_args[i]) - - def test_call(self, mocker: MockerFixture): - """should return roots action on call.""" - roots = "roots" - observation = "obs" - mocker.patch("hebg.heb_graph.HEBGraph._get_options") - mocker.patch("hebg.heb_graph.HEBGraph.roots", roots) - self.heb_graph(observation) - args, _ = HEBGraph._get_options.call_args - check.equal(args[0], roots) - check.equal(args[1], observation) - check.equal(args[2], [self.behavior.name]) - - def test_roots(self, mocker: MockerFixture): - """should have roots as property.""" - - nodes = ["A", "B", "C", "AA", "AB"] - predecessors = {"A": [], "B": [], "C": [], "AA": ["A"], "AB": ["A", "B"]} - - mocker.patch("hebg.heb_graph.HEBGraph.nodes", lambda self: nodes) - mocker.patch( - "hebg.heb_graph.HEBGraph.predecessors", - lambda self, node: predecessors[node], - ) - - check.equal(self.heb_graph.roots, ["A", "B", "C"]) - - -class TestHEBGraphGetAnyAction: - """HEBGraph._get_any_action""" - - @pytest.fixture(autouse=True) - def setup(self): - """Initialize variables.""" - self.behavior = Behavior("behavior_name") - self.heb_graph = HEBGraph(self.behavior) - - def test_none_in_actions(self, mocker: MockerFixture): - """should return None if any node returns None.""" - - actions = [0, "Impossible", 2, None, 3] - - _actions = deepcopy(actions) - - def mocked_get_action(*args, **kwargs): - return _actions.pop(0) - - mocker.patch("hebg.heb_graph.HEBGraph._get_action", mocked_get_action) 
- action = self.heb_graph._get_any_action(range(5), None, None) - check.is_none(action) - - def test_no_actions(self, mocker: MockerFixture): - """should return 'Impossible' if no action is possible.""" - - actions = ["Impossible", "Impossible", "Impossible", "Impossible", "Impossible"] - - _actions = deepcopy(actions) - - def mocked_get_action(*args, **kwargs): - return _actions.pop(0) - - mocker.patch("hebg.heb_graph.HEBGraph._get_action", mocked_get_action) - action = self.heb_graph._get_any_action(range(5), None, None) - check.equal(action, "Impossible") - - action = self.heb_graph._get_any_action([], None, None) - check.equal(action, "Impossible") - - @pytest.mark.parametrize("any_mode", HEBGraph.ANY_MODES) - def test_any_mode_(self, any_mode, mocker: MockerFixture): - actions = [0, "Impossible", 2, 3, "Impossible"] - - _actions = deepcopy(actions) - - def mocked_get_action(*args, **kwargs): - return _actions.pop(0) - - expected_actions = {"first": (0,), "last": (3,), "random": (0, 2, 3)} - - mocker.patch("hebg.heb_graph.HEBGraph._get_action", mocked_get_action) - self.heb_graph.any_mode = any_mode - action = self.heb_graph._get_any_action(range(5), None, None) - check.is_in(action, expected_actions[any_mode]) - - -class TestHEBGraphGetAction: - """HEBGraph._get_action""" - - @pytest.fixture(autouse=True) - def setup(self): - """Initialize variables.""" - self.behavior = Behavior("behavior_name") - self.heb_graph = HEBGraph(self.behavior) - - def test_action(self): - """should return action given by an Action node.""" - expected_action = "action_action" - - class DummyAction: - """DummyNode""" - - type = "action" - - def __call__(self, observation): - return expected_action - - action_node = DummyAction() - action = self.heb_graph._get_action(action_node, None, None) - check.equal(action, expected_action) - - def test_empty(self, mocker: MockerFixture): - """should return next successor returns when given an Empty node.""" - - expected_action = "action_action" 
- - class DummyAction: - """DummyNode""" - - type = "action" - - def __call__(self, observation): - return expected_action - - mocker.patch( - "hebg.heb_graph.HEBGraph.successors", - lambda self, node: iter([DummyAction()]), - ) - - class DummyEmpty: - """DummyNode""" - - type = "empty" - - empty_node = DummyEmpty() - action = self.heb_graph._get_action(empty_node, None, None) - check.equal(action, expected_action) - - def test_unknowed_node_type(self): - """should raise ValueError if node.type is unknowed.""" - - class DummyNode: - """DummyNode""" - - type = "random_type_error" - - node = DummyNode() - with pytest.raises(ValueError): - self.heb_graph._get_action(node, None, None) - - def test_behavior_in_search(self): - """should return 'Impossible' if behavior is already in search to avoid cycles.""" - - class DummyBehavior: - """DummyBehavior""" - - type = "behavior" - name = "behavior_already_in_search" - - def __str__(self) -> str: - return self.name - - node = DummyBehavior() - action = self.heb_graph._get_action(node, None, ["behavior_already_in_search"]) - check.equal(action, "Impossible") - - def test_behavior_call(self): - """should return behavior's return if behavior can be called.""" - expected_action = "behavior_action" - - class DummyBehavior: - """DummyBehavior""" - - type = "behavior" - name = "behavior_name" - - def __str__(self) -> str: - return self.name - - def __call__(self, *args) -> str: - return expected_action - - node = DummyBehavior() - action = self.heb_graph._get_action(node, None, []) - check.equal(action, expected_action) - - def test_behavior_by_all_behaviors(self): - """should use all_behaviors if behavior cannot be called.""" - expected_action = "behavior_action" - - class DummyTrueBehavior: - """DummyBehavior""" - - type = "behavior" - name = "behavior_name" - - def __str__(self) -> str: - return self.name - - def __call__(self, *args): - return expected_action - - class DummyBehaviorInGraph: - """DummyBehavior""" - - type = 
"behavior" - name = "behavior_name" - - def __str__(self) -> str: - return self.name - - def __call__(self, *args): - raise NotImplementedError - - true_node = DummyTrueBehavior() - self.heb_graph.all_behaviors = {true_node.name: true_node} - - node_in_graph = DummyBehaviorInGraph() - action = self.heb_graph._get_action( - node_in_graph, - observation=None, - behaviors_in_search=[], - ) - check.equal(action, expected_action) - - def test_feature_condition(self, mocker: MockerFixture): - """should use FeatureCondition's given index to orient in graph.""" - - class DummyAction: - """DummyAction""" - - type = "action" - name = "action_name" - - def __init__(self, i) -> None: - self.index = i - self.action = f"action_{i}" - - def __str__(self) -> str: - return self.name - - def __call__(self, *args) -> str: - return self.action - - class DummyFeatureCondition: - """DummyFeatureCondition""" - - type = "feature_condition" - name = "feature_condition_name" - - def __init__(self, i) -> None: - self.fc_index = i - - def __str__(self) -> str: - return self.name - - def __call__(self, *args) -> int: - return self.fc_index - - actions = [DummyAction(i) for i in range(3)] - mocker.patch( - "hebg.heb_graph.HEBGraph.successors", - lambda self, node: actions, - ) - - class DummyEdges: - """DummyEdges""" - - def __getitem__(self, e): - return {"index": e[1].index} - - mocker.patch("hebg.heb_graph.HEBGraph.edges", DummyEdges()) - mocker.patch( - "hebg.heb_graph.HEBGraph._get_any_action", - lambda self, next_nodes, observation, behaviors_in_search: next_nodes[0]( - observation - ), - ) - - for fc_index in range(3): - node = DummyFeatureCondition(fc_index) - action = self.heb_graph._get_action(node, None, []) - expected_action = f"action_{fc_index}" - check.equal(action, expected_action) - - def test_feature_condition_index_value_error(self, mocker: MockerFixture): - """should raise ValueError if FeatureCondition's given index represents no successor.""" - - class DummyAction: - 
"""DummyAction""" - - type = "action" - name = "action_name" - - def __init__(self, i) -> None: - self.index = i - self.action = f"action_{i}" - - def __str__(self) -> str: - return self.name - - def __call__(self, *args) -> str: - return self.action - - class DummyFeatureCondition: - """DummyFeatureCondition""" - - type = "feature_condition" - name = "feature_condition_name" - - def __init__(self, i) -> None: - self.fc_index = i - - def __str__(self) -> str: - return self.name - - def __call__(self, *args) -> int: - return self.fc_index - - actions = [DummyAction(i) for i in range(3)] - mocker.patch( - "hebg.heb_graph.HEBGraph.successors", - lambda self, node: actions, - ) - - class DummyEdges: - """DummyEdges""" - - def __getitem__(self, e): - return {"index": e[1].index} - - mocker.patch("hebg.heb_graph.HEBGraph.edges", DummyEdges()) - - node = DummyFeatureCondition(4) - with pytest.raises(ValueError): - self.heb_graph._get_action(node, None, [])