From cd1748f88cae2acdbce026b755d0c2f991bc610e Mon Sep 17 00:00:00 2001 From: Joe Zuntz Date: Fri, 31 Jan 2025 13:07:08 +0000 Subject: [PATCH] collect groups of file nodes together --- ceci/pipeline.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/ceci/pipeline.py b/ceci/pipeline.py index fe85798..c1eb5ba 100644 --- a/ceci/pipeline.py +++ b/ceci/pipeline.py @@ -6,7 +6,7 @@ import yaml import shutil from abc import abstractmethod -import warnings +import collections from .stage import PipelineStage from . import minirunner @@ -1066,6 +1066,10 @@ def make_flow_chart(self, filename): # Nodes we have already added seen = set() + # Dictionary to track nodes by their inputs and outputs + node_groups = {} + + # Add overall pipeline inputs for inp in self.overall_inputs.keys(): graph.add_node(inp, shape="box", color="gold", style="filled") @@ -1091,6 +1095,43 @@ def make_flow_chart(self, filename): seen.add(out) graph.add_edge(stage.instance_name, out, color="black") + # We want to group together all the files that all created + # by the same stage and also all used by the same stages, to + # reduce the number of nodes in the graph and make it more readable. + # First we find that grouping. + node_groups = collections.defaultdict(list) + for node in graph.nodes_iter(): + # only affect the nodes representing files + if node.attr['color'] != "skyblue": + continue + # Find the stage node that created this file, + # and all the stage nodes that make use of it + edge_in = graph.in_edges(node)[0] + creator = edge_in[0] + users = [] + for edge in graph.out_edges(node): + users.append(edge[1]) + key = (creator, tuple(users)) + node_groups[key].append(node) + + # Now we remove all the groups of nodes with more than one in + # and replace them with a single node + for key, nodes in node_groups.items(): + if len(nodes) > 1: + if len(nodes) > 4: + # make a string with two nodes per line + node_names = [] + for i in range(0, len(nodes), 2): + node_names.append(", ".join(nodes[i:i+2])) + new_node = "\n".join(node_names) + else: + new_node = "\n".join(nodes) + graph.remove_nodes_from(nodes) + graph.add_node(new_node, shape="box", color="skyblue", style="filled") + graph.add_edge(key[0], new_node, color="black") + for user in key[1]: + graph.add_edge(new_node, user, color="black") + # finally, output the stage to file if filename.endswith(".dot"): graph.write(filename)