Skip to content

Commit

Permalink
⚡️ Remove last call recursivity
Browse files Browse the repository at this point in the history
Make call graph branches less compact but more readable
Update graph look to remove little cornrer abberations
  • Loading branch information
MathisFederico committed Feb 2, 2024
1 parent 5447dff commit 5abde84
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions src/hebg/call_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,13 @@ def call_nodes(
# Search for name reference in all_behaviors
if node.name in heb_graph.all_behaviors:
node = heb_graph.all_behaviors[node.name]
action = node(observation, self)
if not hasattr(node, "graph"):
action = node(observation)
break
self._extend_frontiere(node.graph.roots, heb_graph=node.graph)
elif node.type == "action":
action = node(observation)
break
elif node.type == "feature_condition":
if node in self._known_fc:
next_edge_index = self._known_fc[node]
Expand Down Expand Up @@ -202,8 +206,9 @@ def draw(
max_x, max_y = pos_arr.max(axis=0)
min_x, min_y = pos_arr.min(axis=0)
y_range = max_y - min_y
x_range = max_x - min_x
ax.set_ylim([min_y - 0.1 * y_range, max_y + 0.1 * y_range])
ax.set_xlim([min_x - 0.1 * y_range, min_x + y_range + 0.1 * y_range])
ax.set_xlim([min_x - 0.1 * x_range, max_x + 0.1 * x_range])

nodes_complexity = np.array(
[node_data["heb_node"].complexity for _, node_data in self.nodes(data=True)]
Expand All @@ -230,19 +235,20 @@ def draw(
draw_networkx_labels(
self,
labels={
node: f"{node_data['label']}\n{node_data['heb_node'].complexity:.0f}"
node: f"{node_data['label']}"
for node, node_data in self.nodes(data=True)
},
ax=ax,
horizontalalignment="center",
verticalalignment="center",
font_size=8,
pos=pos,
**nodes_kwargs,
)
if edges_kwargs is None:
edges_kwargs = {}
if "connectionstyle" not in edges_kwargs:
edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=10")
edges_kwargs.update(connectionstyle="angle,angleA=0,angleB=90,rad=5")
draw_networkx_edges(
self,
ax=ax,
Expand Down Expand Up @@ -300,18 +306,10 @@ def _call_graph_pos(call_graph: "CallGraph") -> Dict[str, Tuple[float, float]]:
branches = all_simple_paths(call_graph, roots[0], leafs)
branches = sorted(branches, key=lambda x: -len(x))

branches_per_rank: Dict[int, List[int]] = {}
for branch_id, nodes_in_branch in enumerate(branches):
for node in nodes_in_branch:
if node in pos:
continue
rank = node.rank
if rank not in branches_per_rank:
branches_per_rank[rank] = []

if branch_id not in branches_per_rank[rank]:
branches_per_rank[rank].append(branch_id)

display_branch = branches_per_rank[rank].index(branch_id)
pos[node] = [display_branch, -rank]
pos[node] = [branch_id, -rank]
return pos

0 comments on commit 5abde84

Please sign in to comment.