From 8ec079c1b17dbdc372ffe80ad53f0fffc444f427 Mon Sep 17 00:00:00 2001 From: Austin Zadoks Date: Wed, 18 Dec 2019 13:00:01 +0100 Subject: [PATCH] Add traverse_graph / AGE engine for visualization The graph visualization feature now uses the traverse_graph function (with AGE as the main engine) to collect the requested nodes to be visualized. This was implemented in the methods of the graph class: previously, `recurse_descendants` and `recurse_ancestors` used to work by calling `add_incoming` and `add_outgoing` many times, which in turn have to load nodes during the procedure. Now these are all independent and they all call the traverse_graph function, so the information is obtained directly from the query projections and no nodes are loaded. So these changes are not only important as a first step to homogenize graph traversal throughout the whole code: an improvement in the visualization procedure is expected as well. --- aiida/tools/visualization/graph.py | 362 +++++++++++++++++++---------- 1 file changed, 244 insertions(+), 118 deletions(-) diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py index e858bb1c3d..38609ed6d4 100644 --- a/aiida/tools/visualization/graph.py +++ b/aiida/tools/visualization/graph.py @@ -12,11 +12,13 @@ """ import os + from graphviz import Digraph -from aiida.orm import load_node, Data, ProcessNode -from aiida.orm.querybuilder import QueryBuilder + +from aiida import orm from aiida.common import LinkType from aiida.orm.utils.links import LinkPair +from aiida.tools.graph.graph_traversers import traverse_graph __all__ = ('Graph', 'default_link_styles', 'default_node_styles', 'pstate_node_styles', 'default_node_sublabels') @@ -32,8 +34,8 @@ def default_link_styles(link_pair, add_label, add_type): :param add_type: include link type :type add_type: bool :rtype: dict - """ + style = { LinkType.INPUT_CALC: { 'style': 'solid', @@ -77,8 +79,8 @@ def default_node_styles(node): :param node: the node to map :type node: aiida.orm.nodes.node.Node :rtype: dict - """ + class_node_type = node.class_node_type try: @@ -135,8 +137,8 @@ def pstate_node_styles(node): :param node: the node to map :type node: aiida.orm.nodes.node.Node :rtype: dict - """ + class_node_type = node.class_node_type default = {'shape': 'rectangle', 'pencolor': 'black'} @@ -172,7 +174,7 @@ def pstate_node_styles(node): node_style = process_map.get(class_node_type, default) - if isinstance(node, ProcessNode): + if isinstance(node, orm.ProcessNode): # style process node, based on success/failure of process if node.is_failed or node.is_excepted or node.is_killed: node_style['fillcolor'] = '#de707fff' # red @@ -192,7 +194,6 @@ def default_node_sublabels(node): :param node: the node to map :type node: aiida.orm.nodes.node.Node :rtype: str - """ # pylint: disable=too-many-branches @@ -224,7 +225,7 @@ def default_node_sublabels(node): sublabel = '; '.join(sublabel_lines) elif class_node_type == 'data.upf.UpfData.': sublabel = '{}'.format(node.get_attribute('element', '')) - elif isinstance(node, ProcessNode): + elif isinstance(node, orm.ProcessNode): sublabel = [] if node.process_state is not None: sublabel.append('State: {}'.format(node.process_state.value)) @@ -276,17 +277,14 @@ def _add_graphviz_node( For subclasses of ProcessNode, we choose styles to distinguish between types, and also color the nodes for successful/failed processes - """ # pylint: disable=too-many-arguments node_style = {} - if isinstance(node, Data): - + if isinstance(node, orm.Data): node_style = node_style_func(node) label = ['{} ({})'.format(node.__class__.__name__, get_node_id_label(node, id_type))] - elif isinstance(node, ProcessNode): - + elif isinstance(node, orm.ProcessNode): node_style = node_style_func(node) label = [ @@ -320,8 +318,8 @@ def _add_graphviz_edge(graph, in_node, out_node, style=None): :param out_node: the tail node :param style: the graphviz style (Default value = None) :type style: dict or None - """ + if style is None: style = {} @@ -370,9 +368,9 @@ def __init__( node_sublabel_fn(node) -> str (Default value = None) :param node_id_type: the type of identifier to within the node text ('pk', 'uuid' or 'label') :type node_id_type: str - """ # pylint: disable=too-many-arguments + self._graph = Digraph(engine=engine, graph_attr=graph_attr) self._nodes = set() self._edges = set() @@ -406,12 +404,29 @@ def _load_node(node): :param node: node or node pk/uuid :type node: int or str or aiida.orm.nodes.node.Node :returns: aiida.orm.nodes.node.Node - """ if isinstance(node, (int, str)): - return load_node(node) + return orm.load_node(node) return node + @staticmethod + def _default_link_types(link_types): + """If link_types is empty, it will return all the links_types + + :param links: iterable with the link_types () + :returns: list of :py:class:`aiida.common.links.LinkType` + """ + if not link_types: + all_link_types = [LinkType.CREATE] + all_link_types.append(LinkType.RETURN) + all_link_types.append(LinkType.INPUT_CALC) + all_link_types.append(LinkType.INPUT_WORK) + all_link_types.append(LinkType.CALL_CALC) + all_link_types.append(LinkType.CALL_WORK) + return all_link_types + + return link_types + def add_node(self, node, style_override=None, overwrite=False): """add single node to the graph @@ -421,7 +436,6 @@ def add_node(self, node, style_override=None, overwrite=False): :type style_override: dict or None :param overwrite: whether to overrite an existing node (Default value = False) :type overwrite: bool - """ node = self._load_node(node) style = {} if style_override is None else style_override @@ -452,7 +466,6 @@ def add_edge(self, in_node, out_node, link_pair=None, style=None, overwrite=Fals :type style: dict or None :param overwrite: whether to overrite existing edge (Default value = False) :type overwrite: bool - """ in_node = self._load_node(in_node) if in_node.pk not in self._nodes: @@ -472,8 +485,7 @@ def add_edge(self, in_node, out_node, link_pair=None, style=None, overwrite=Fals @staticmethod def _convert_link_types(link_types): - """ convert link types, which may be strings, to a member of LinkType - """ + """convert link types, which may be strings, to a member of LinkType""" if link_types is None: return None if isinstance(link_types, str): @@ -493,24 +505,50 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True) :type return_pks: bool :returns: list of nodes or node pks - """ if annotate_links not in [None, False, 'label', 'type', 'both']: - raise AssertionError('annotate_links must be one of False, "label", "type" or "both"') + raise ValueError( + 'annotate_links must be one of False, "label", "type" or "both"\ninstead, it is: {}'. + format(annotate_links) + ) + + # incoming nodes are found traversing backwards + node_pk = node if isinstance(node, int) else node.pk + valid_link_types = self._default_link_types(link_types) + valid_link_types = self._convert_link_types(valid_link_types) + traversed_graph = traverse_graph( + (node_pk,), + max_iterations=1, + get_links=True, + links_backward=valid_link_types, + ) + + traversed_nodes = orm.QueryBuilder().append( + orm.Node, + filters={'id': { + 'in': traversed_graph['nodes'] + }}, + project=['id', '*'], + tag='node', + ) + traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} - node = self.add_node(node) + for _, traversed_node in traversed_nodes.items(): + self.add_node(traversed_node, style_override=None) - nodes = [] - for link_triple in node.get_incoming(link_type=self._convert_link_types(link_types)).link_triples: - self.add_node(link_triple.node) - link_pair = LinkPair(link_triple.link_type, link_triple.link_label) - style = self._link_styles( + for link in traversed_graph['links']: + source_node = traversed_nodes[link.source_id] + target_node = traversed_nodes[link.target_id] + link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) + link_style = self._link_styles( link_pair, add_label=annotate_links in ['label', 'both'], add_type=annotate_links in ['type', 'both'] ) - self.add_edge(link_triple.node, node, link_pair, style=style) - nodes.append(link_triple.node.pk if return_pks else link_triple.node) + self.add_edge(source_node, target_node, link_pair, style=link_style) - return nodes + if return_pks: + return list(traversed_nodes.keys()) + # else: + return list(traversed_nodes.values()) def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True): """add nodes and edges for outgoing links to a node @@ -524,24 +562,50 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True :param return_pks: whether to return a list of nodes, or list of node pks (Default value = True) :type return_pks: bool :returns: list of nodes or node pks - """ if annotate_links not in [None, False, 'label', 'type', 'both']: - raise AssertionError('annotate_links must be one of False, "label", "type" or "both"') + raise ValueError( + 'annotate_links must be one of False, "label", "type" or "both"\ninstead, it is: {}'. + format(annotate_links) + ) + + # outgoing nodes are found traversing forwards + node_pk = node if isinstance(node, int) else node.pk + valid_link_types = self._default_link_types(link_types) + valid_link_types = self._convert_link_types(valid_link_types) + traversed_graph = traverse_graph( + (node_pk,), + max_iterations=1, + get_links=True, + links_forward=valid_link_types, + ) + + traversed_nodes = orm.QueryBuilder().append( + orm.Node, + filters={'id': { + 'in': traversed_graph['nodes'] + }}, + project=['id', '*'], + tag='node', + ) + traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} - node = self.add_node(node) + for _, traversed_node in traversed_nodes.items(): + self.add_node(traversed_node, style_override=None) - nodes = [] - for link_triple in node.get_outgoing(link_type=self._convert_link_types(link_types)).link_triples: - self.add_node(link_triple.node) - link_pair = LinkPair(link_triple.link_type, link_triple.link_label) - style = self._link_styles( + for link in traversed_graph['links']: + source_node = traversed_nodes[link.source_id] + target_node = traversed_nodes[link.target_id] + link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) + link_style = self._link_styles( link_pair, add_label=annotate_links in ['label', 'both'], add_type=annotate_links in ['type', 'both'] ) - self.add_edge(node, link_triple.node, link_pair, style=style) - nodes.append(link_triple.node.pk if return_pks else link_triple.node) + self.add_edge(source_node, target_node, link_pair, style=link_style) - return nodes + if return_pks: + return list(traversed_nodes.keys()) + # else: + return list(traversed_nodes.values()) def recurse_descendants( self, @@ -549,7 +613,7 @@ def recurse_descendants( depth=None, link_types=(), annotate_links=False, - origin_style=(), + origin_style=None, include_process_inputs=False, print_func=None ): @@ -564,47 +628,79 @@ def recurse_descendants( :type link_types: tuple or str :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) :type annotate_links: bool or str - :param origin_style: node style map for origin node (Default value = ()) - :type origin_style: dict or tuple + :param origin_style: node style map for origin node (Default value = None) + :type origin_style: None or dict :param include_calculation_inputs: include incoming links for all processes (Default value = False) :type include_calculation_inputs: bool - :param print_func: a function to stream information to, i.e. print_func(str) + :param print_func: + a function to stream information to, i.e. print_func(str) + (this feature is deprecated since `v1.1.0` and will be removed in `v2.0.0`) """ - # pylint: disable=too-many-arguments - origin_node = self._load_node(origin) + # pylint: disable=too-many-arguments,too-many-locals + import warnings + from aiida.common.warnings import AiidaDeprecationWarning + if print_func: + warnings.warn( # pylint: disable=no-member + '`print_func` is deprecated because graph traversal has been refactored', AiidaDeprecationWarning + ) - self.add_node(origin_node, style_override=dict(origin_style)) + # Get graph traversal rules where the given link types and direction are all set to True, + # and all others are set to False + origin_pk = origin if isinstance(origin, int) else origin.pk + valid_link_types = self._default_link_types(link_types) + valid_link_types = self._convert_link_types(valid_link_types) + traversed_graph = traverse_graph( + (origin_pk,), + max_iterations=depth, + get_links=True, + links_forward=valid_link_types, + ) - leaf_nodes = [origin_node] - traversed_pks = [origin_node.pk] - cur_depth = 0 - while leaf_nodes: - cur_depth += 1 - # checking of maximum descendant depth is set and applies. - if depth is not None and cur_depth > depth: - break - if print_func: - print_func('- Depth: {}'.format(cur_depth)) - new_nodes = [] - for node in leaf_nodes: - outgoing_nodes = self.add_outgoing( - node, link_types=link_types, annotate_links=annotate_links, return_pks=False - ) - if outgoing_nodes and print_func: - print_func(' {} -> {}'.format(node.pk, [on.pk for on in outgoing_nodes])) - new_nodes.extend(outgoing_nodes) - - if include_process_inputs and isinstance(node, ProcessNode): - self.add_incoming(node, link_types=link_types, annotate_links=annotate_links) - - # ensure the same path isn't traversed multiple times - leaf_nodes = [] - for new_node in new_nodes: - if new_node.pk in traversed_pks: - continue - leaf_nodes.append(new_node) - traversed_pks.append(new_node.pk) + # Traverse backward along input_work and input_calc links from all nodes traversed in the previous step + # and join the result with the original traversed graph. This includes calculation inputs in the Graph + if include_process_inputs: + traversed_outputs = traverse_graph( + traversed_graph['nodes'], + max_iterations=1, + get_links=True, + links_backward=[LinkType.INPUT_WORK, LinkType.INPUT_CALC] + ) + traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) + traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) + + # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary + traversed_nodes = orm.QueryBuilder().append( + orm.Node, + filters={'id': { + 'in': traversed_graph['nodes'] + }}, + project=['id', '*'], + tag='node', + ) + traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} + + # Pop the origin node and add it to the graph, applying custom styling + origin_node = traversed_nodes.pop(origin_pk) + self.add_node(origin_node, style_override=origin_style) + + # Add all traversed nodes to the graph with default styling + for _, traversed_node in traversed_nodes.items(): + self.add_node(traversed_node, style_override={}) + + # Add the origin node back into traversed nodes so it can be found for adding edges + traversed_nodes[origin_pk] = origin_node + + # Add all links to the Graph, using the {id: Node} dictionary for queryless Node retrieval, applying + # appropriate styling + for link in traversed_graph['links']: + source_node = traversed_nodes[link.source_id] + target_node = traversed_nodes[link.target_id] + link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) + link_style = self._link_styles( + link_pair, add_label=annotate_links in ['label', 'both'], add_type=annotate_links in ['type', 'both'] + ) + self.add_edge(source_node, target_node, link_pair, style=link_style) def recurse_ancestors( self, @@ -612,7 +708,7 @@ def recurse_ancestors( depth=None, link_types=(), annotate_links=False, - origin_style=(), + origin_style=None, include_process_outputs=False, print_func=None ): @@ -627,47 +723,79 @@ def recurse_ancestors( :type link_types: tuple or str :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) :type annotate_links: bool - :param origin_style: node style map for origin node (Default value = ()) - :type origin_style: dict or tuple + :param origin_style: node style map for origin node (Default value = None) + :type origin_style: None or dict :param include_process_outputs: include outgoing links for all processes (Default value = False) :type include_process_outputs: bool :param print_func: a function to stream information to, i.e. print_func(str) + .. deprecated:: 1.1.0 + `print_func` will be removed in `v2.0.0` """ - # pylint: disable=too-many-arguments - origin_node = self._load_node(origin) + # pylint: disable=too-many-arguments,too-many-locals + import warnings + from aiida.common.warnings import AiidaDeprecationWarning + if print_func: + warnings.warn( # pylint: disable=no-member + '`print_func` is deprecated because graph traversal has been refactored', AiidaDeprecationWarning + ) - self.add_node(origin_node, style_override=dict(origin_style)) + # Get graph traversal rules where the given link types and direction are all set to True, + # and all others are set to False + origin_pk = origin if isinstance(origin, int) else origin.pk + valid_link_types = self._default_link_types(link_types) + valid_link_types = self._convert_link_types(valid_link_types) + traversed_graph = traverse_graph( + (origin_pk,), + max_iterations=depth, + get_links=True, + links_backward=valid_link_types, + ) - last_nodes = [origin_node] - traversed_pks = [origin_node.pk] - cur_depth = 0 - while last_nodes: - cur_depth += 1 - # checking of maximum descendant depth is set and applies. - if depth is not None and cur_depth > depth: - break - if print_func: - print_func('- Depth: {}'.format(cur_depth)) - new_nodes = [] - for node in last_nodes: - incoming_nodes = self.add_incoming( - node, link_types=link_types, annotate_links=annotate_links, return_pks=False - ) - if incoming_nodes and print_func: - print_func(' {} -> {}'.format(node.pk, [n.pk for n in incoming_nodes])) - new_nodes.extend(incoming_nodes) - - if include_process_outputs and isinstance(node, ProcessNode): - self.add_outgoing(node, link_types=link_types, annotate_links=annotate_links) - - # ensure the same path isn't traversed multiple times - last_nodes = [] - for new_node in new_nodes: - if new_node.pk in traversed_pks: - continue - last_nodes.append(new_node) - traversed_pks.append(new_node.pk) + # Traverse forward along input_work and input_calc links from all nodes traversed in the previous step + # and join the result with the original traversed graph. This includes calculation outputs in the Graph + if include_process_outputs: + traversed_outputs = traverse_graph( + traversed_graph['nodes'], + max_iterations=1, + get_links=True, + links_forward=[LinkType.CREATE, LinkType.RETURN] + ) + traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) + traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) + + # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary + traversed_nodes = orm.QueryBuilder().append( + orm.Node, + filters={'id': { + 'in': traversed_graph['nodes'] + }}, + project=['id', '*'], + tag='node', + ) + traversed_nodes = {query_result[0]: query_result[1] for query_result in traversed_nodes.all()} + + # Pop the origin node and add it to the graph, applying custom styling + origin_node = traversed_nodes.pop(origin_pk) + self.add_node(origin_node, style_override=origin_style) + + # Add all traversed nodes to the graph with default styling + for _, traversed_node in traversed_nodes.items(): + self.add_node(traversed_node, style_override=None) + + # Add the origin node back into traversed nodes so it can be found for adding edges + traversed_nodes[origin_pk] = origin_node + + # Add all links to the Graph, using the {id: Node} dictionary for queryless Node retrieval, applying + # appropriate styling + for link in traversed_graph['links']: + source_node = traversed_nodes[link.source_id] + target_node = traversed_nodes[link.target_id] + link_pair = LinkPair(self._convert_link_types(link.link_type)[0], link.link_label) + link_style = self._link_styles( + link_pair, add_label=annotate_links in ['label', 'both'], add_type=annotate_links in ['type', 'both'] + ) + self.add_edge(source_node, target_node, link_pair, style=link_style) def add_origin_to_targets( self, @@ -694,7 +822,6 @@ def add_origin_to_targets( :type origin_style: dict or tuple :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) :type annotate_links: bool - """ # pylint: disable=too-many-arguments origin_node = self._load_node(origin) @@ -704,7 +831,7 @@ def add_origin_to_targets( self.add_node(origin_node, style_override=dict(origin_style)) - query = QueryBuilder( + query = orm.QueryBuilder( **{ 'path': [{ 'cls': origin_node.__class__, @@ -759,13 +886,12 @@ def add_origins_to_targets( :type origin_style: dict or tuple :param annotate_links: label edges with the link 'label', 'type' or 'both' (Default value = False) :type annotate_links: bool - """ # pylint: disable=too-many-arguments if origin_filters is None: origin_filters = {} - query = QueryBuilder( + query = orm.QueryBuilder( **{'path': [{ 'cls': origin_cls, 'filters': origin_filters,