diff --git a/src/hebg/call_graph.py b/src/hebg/call_graph.py index fec6d84..28e24eb 100644 --- a/src/hebg/call_graph.py +++ b/src/hebg/call_graph.py @@ -64,9 +64,13 @@ def call_nodes( # Search for name reference in all_behaviors if node.name in heb_graph.all_behaviors: node = heb_graph.all_behaviors[node.name] - action = node(observation, self) + if not hasattr(node, "graph"): + action = node(observation) + break + self._extend_frontiere(node.graph.roots, heb_graph=node.graph) elif node.type == "action": action = node(observation) + break elif node.type == "feature_condition": if node in self._known_fc: next_edge_index = self._known_fc[node] @@ -202,8 +206,9 @@ def draw( max_x, max_y = pos_arr.max(axis=0) min_x, min_y = pos_arr.min(axis=0) y_range = max_y - min_y + x_range = max_x - min_x ax.set_ylim([min_y - 0.1 * y_range, max_y + 0.1 * y_range]) - ax.set_xlim([min_x - 0.1 * y_range, min_x + y_range + 0.1 * y_range]) + ax.set_xlim([min_x - 0.1 * x_range, max_x + 0.1 * x_range]) nodes_complexity = np.array( [node_data["heb_node"].complexity for _, node_data in self.nodes(data=True)] @@ -230,19 +235,20 @@ def draw( draw_networkx_labels( self, labels={ - node: f"{node_data['label']}\n{node_data['heb_node'].complexity:.0f}" + node: f"{node_data['label']}" for node, node_data in self.nodes(data=True) }, ax=ax, horizontalalignment="center", verticalalignment="center", + font_size=8, pos=pos, **nodes_kwargs, ) if edges_kwargs is None: edges_kwargs = {} if "connectionstyle" not in edges_kwargs: - edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=10") + edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=5") draw_networkx_edges( self, ax=ax, @@ -300,18 +306,10 @@ def _call_graph_pos(call_graph: "CallGraph") -> Dict[str, Tuple[float, float]]: branches = all_simple_paths(call_graph, roots[0], leafs) branches = sorted(branches, key=lambda x: -len(x)) - branches_per_rank: Dict[int, List[int]] = {} for branch_id, nodes_in_branch in enumerate(branches): for node in nodes_in_branch: if node in pos: continue rank = node.rank - if rank not in branches_per_rank: - branches_per_rank[rank] = [] - - if branch_id not in branches_per_rank[rank]: - branches_per_rank[rank].append(branch_id) - - display_branch = branches_per_rank[rank].index(branch_id) - pos[node] = [display_branch, -rank] + pos[node] = [branch_id, -rank] return pos