Skip to content

Commit

Permalink
Add option to save graph as PNG (#523)
Browse files Browse the repository at this point in the history
* Add option to save graph as PNG

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove file after test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
jan-janssen and pre-commit-ci[bot] authored Dec 17, 2024
1 parent baec2a1 commit c3a0ae7
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 15 deletions.
1 change: 0 additions & 1 deletion .ci_support/environment-mpich.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =4.0.1
- pyzmq =26.2.0
- h5py =3.12.1
- matplotlib =3.10.0
- networkx =3.4.2
- pygraphviz =1.14
- ipython =8.30.0
Expand Down
1 change: 0 additions & 1 deletion .ci_support/environment-old.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =3.1.4
- pyzmq =25.0.0
- h5py =3.6.0
- matplotlib =3.5.3
- networkx =2.8.8
- ipython =7.33.0
- pygraphviz =1.10
1 change: 0 additions & 1 deletion .ci_support/environment-openmpi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =4.0.1
- pyzmq =26.2.0
- h5py =3.12.1
- matplotlib =3.10.0
- networkx =3.4.2
- pygraphviz =1.14
- pysqa =0.2.2
Expand Down
1 change: 0 additions & 1 deletion .ci_support/environment-win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- mpi4py =4.0.1
- pyzmq =26.2.0
- h5py =3.12.1
- matplotlib =3.10.0
- networkx =3.4.2
- pygraphviz =1.14
- ipython =8.30.0
Expand Down
5 changes: 5 additions & 0 deletions executorlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class Executor:
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
plot_dependency_graph (bool): Plot the dependencies of multiple future objects without executing them. For
debugging purposes and to get an overview of the specified dependencies.
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.
Examples:
```
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
disable_dependencies: bool = False,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
plot_dependency_graph_filename: Optional[str] = None,
):
# Use __new__() instead of __init__(). This function is only implemented to enable auto-completion.
pass
Expand All @@ -122,6 +124,7 @@ def __new__(
disable_dependencies: bool = False,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
plot_dependency_graph_filename: Optional[str] = None,
):
"""
Instead of returning a executorlib.Executor object this function returns either a executorlib.mpi.PyMPIExecutor,
Expand Down Expand Up @@ -167,6 +170,7 @@ def __new__(
refresh_rate (float): Set the refresh rate in seconds, how frequently the input queue is checked.
plot_dependency_graph (bool): Plot the dependencies of multiple future objects without executing them. For
debugging purposes and to get an overview of the specified dependencies.
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.
"""
default_resource_dict = {
Expand Down Expand Up @@ -216,6 +220,7 @@ def __new__(
init_function=init_function,
refresh_rate=refresh_rate,
plot_dependency_graph=plot_dependency_graph,
plot_dependency_graph_filename=plot_dependency_graph_filename,
)
else:
_check_pysqa_config_directory(pysqa_config_directory=pysqa_config_directory)
Expand Down
15 changes: 13 additions & 2 deletions executorlib/interactive/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,15 @@ class ExecutorWithDependencies(ExecutorBase):
Args:
refresh_rate (float, optional): The refresh rate for updating the executor queue. Defaults to 0.01.
plot_dependency_graph (bool, optional): Whether to generate and plot the dependency graph. Defaults to False.
plot_dependency_graph_filename (str): Name of the file to store the plotted graph in.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Attributes:
_future_hash_dict (Dict[str, Future]): A dictionary mapping task hash to future object.
_task_hash_dict (Dict[str, Dict]): A dictionary mapping task hash to task dictionary.
_generate_dependency_graph (bool): Whether to generate the dependency graph.
_generate_dependency_graph (str): Name of the file to store the plotted graph in.
"""

Expand All @@ -57,6 +59,7 @@ def __init__(
*args: Any,
refresh_rate: float = 0.01,
plot_dependency_graph: bool = False,
plot_dependency_graph_filename: Optional[str] = None,
**kwargs: Any,
) -> None:
super().__init__(max_cores=kwargs.get("max_cores", None))
Expand All @@ -75,7 +78,11 @@ def __init__(
)
self._future_hash_dict = {}
self._task_hash_dict = {}
self._generate_dependency_graph = plot_dependency_graph
self._plot_dependency_graph_filename = plot_dependency_graph_filename
if plot_dependency_graph_filename is None:
self._generate_dependency_graph = plot_dependency_graph
else:
self._generate_dependency_graph = True

def submit(
self,
Expand Down Expand Up @@ -142,7 +149,11 @@ def __exit__(
v: k for k, v in self._future_hash_dict.items()
},
)
return draw(node_lst=node_lst, edge_lst=edge_lst)
return draw(
node_lst=node_lst,
edge_lst=edge_lst,
filename=self._plot_dependency_graph_filename,
)


def create_executor(
Expand Down
19 changes: 12 additions & 7 deletions executorlib/standalone/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os.path
from concurrent.futures import Future
from typing import Tuple
from typing import Optional, Tuple

import cloudpickle

Expand Down Expand Up @@ -106,23 +107,27 @@ def convert_arg(arg, future_hash_inverse_dict):
)


def draw(node_lst: list, edge_lst: list):
def draw(node_lst: list, edge_lst: list, filename: Optional[str] = None):
"""
Draw the graph visualization of nodes and edges.
Args:
node_lst (list): List of nodes.
edge_lst (list): List of edges.
filename (str): Name of the file to store the plotted graph in.
"""
from IPython.display import SVG, display # noqa
import matplotlib.pyplot as plt # noqa
import networkx as nx # noqa

graph = nx.DiGraph()
for node in node_lst:
graph.add_node(node["id"], label=node["name"], shape=node["shape"])
for edge in edge_lst:
graph.add_edge(edge["start"], edge["end"], label=edge["label"])
svg = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")
display(SVG(svg))
plt.show()
if filename is not None:
file_format = os.path.splitext(filename)[-1][1:]
with open(filename, "wb") as f:
f.write(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format=file_format))
else:
from IPython.display import SVG, display # noqa

display(SVG(nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg")))
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ Repository = "https://github.com/pyiron/executorlib"
cache = ["h5py==3.12.1"]
graph = [
"pygraphviz==1.14",
"matplotlib==3.10.0",
"networkx==3.4.2",
]
graphnotebook = [
"pygraphviz==1.14",
"networkx==3.4.2",
"ipython==8.30.0",
]
Expand All @@ -53,7 +56,6 @@ all = [
"pysqa==0.2.2",
"h5py==3.12.1",
"pygraphviz==1.14",
"matplotlib==3.10.0",
"networkx==3.4.2",
"ipython==8.30.0",
]
Expand Down
21 changes: 21 additions & 0 deletions tests/test_dependencies_executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from concurrent.futures import Future
import os
import unittest
from time import sleep
from queue import Queue
Expand Down Expand Up @@ -73,6 +74,26 @@ def test_executor_dependency_plot(self):
self.assertEqual(len(nodes), 5)
self.assertEqual(len(edges), 4)

@unittest.skipIf(
skip_graphviz_test,
"graphviz is not installed, so the plot_dependency_graph tests are skipped.",
)
def test_executor_dependency_plot_filename(self):
graph_file = os.path.join(os.path.dirname(__file__), "test.png")
with Executor(
max_cores=1,
backend="local",
plot_dependency_graph=False,
plot_dependency_graph_filename=graph_file,
) as exe:
cloudpickle_register(ind=1)
future_1 = exe.submit(add_function, 1, parameter_2=2)
future_2 = exe.submit(add_function, 1, parameter_2=future_1)
self.assertTrue(future_1.done())
self.assertTrue(future_2.done())
self.assertTrue(os.path.exists(graph_file))
os.remove(graph_file)

def test_create_executor_error(self):
with self.assertRaises(ValueError):
create_executor(backend="toast", resource_dict={"cores": 1})
Expand Down

0 comments on commit c3a0ae7

Please sign in to comment.