From 2f2bdc38191632f30b2c0be4f3445190dc896d90 Mon Sep 17 00:00:00 2001
From: Chris Sewell
Date: Fri, 22 Oct 2021 11:24:22 +0200
Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Ensure=20`QueryBuilde?=
 =?UTF-8?q?r`=20is=20passed=20`Backend`=20(#5186)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR ensures core code always calls `QueryBuilder` with a specific
`Backend`, rather than assuming the currently loaded `Backend`. This will
allow multiple backends to be used at the same time (for example, a profile
database alongside export archives), for features including graph traversal
and visualisation. (A short usage sketch of the new calling pattern is
appended after the diff.)
---
 aiida/cmdline/utils/common.py                 |  8 +++--
 aiida/orm/implementation/django/comments.py   |  2 +-
 aiida/orm/implementation/django/logs.py       |  2 +-
 .../orm/implementation/sqlalchemy/comments.py |  2 +-
 aiida/orm/implementation/sqlalchemy/logs.py   |  2 +-
 aiida/orm/nodes/data/array/bands.py           |  4 +--
 aiida/orm/nodes/data/cif.py                   |  4 +--
 aiida/orm/nodes/data/code.py                  |  8 ++---
 aiida/orm/nodes/data/upf.py                   | 14 ++++----
 aiida/orm/nodes/node.py                       |  8 ++---
 aiida/orm/querybuilder.py                     | 11 +++++--
 aiida/orm/utils/links.py                      |  8 ++---
 aiida/orm/utils/remote.py                     |  6 ++--
 aiida/tools/graph/age_entities.py             | 11 -------
 aiida/tools/graph/age_rules.py                |  7 +++--
 aiida/tools/graph/deletions.py                | 20 ++++++------
 aiida/tools/graph/graph_traversers.py         | 24 +++++++++-----
 aiida/tools/visualization/graph.py            | 31 +++++++++++++++----
 18 files changed, 101 insertions(+), 71 deletions(-)

diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py
index 36fa393170..9c57980ab9 100644
--- a/aiida/cmdline/utils/common.py
+++ b/aiida/cmdline/utils/common.py
@@ -11,11 +11,15 @@
 import logging
 import os
 import sys
+from typing import TYPE_CHECKING
 
 from tabulate import tabulate
 
 from . import echo
 
+if TYPE_CHECKING:
+    from aiida.orm import WorkChainNode
+
 __all__ = ('is_verbose',)
 
 
@@ -306,7 +310,7 @@ def get_process_function_report(node):
     return '\n'.join(report)
 
 
-def get_workchain_report(node, levelname, indent_size=4, max_depth=None):
+def get_workchain_report(node: 'WorkChainNode', levelname, indent_size=4, max_depth=None):
     """
     Return a multi line string representation of the log messages and output of a given workchain
 
@@ -333,7 +337,7 @@ def get_subtree(uuid, level=0):
         Get a nested tree of work calculation nodes and their nesting level starting from this uuid.
         The result is a list of uuid of these nodes.
""" - builder = orm.QueryBuilder() + builder = orm.QueryBuilder(backend=node.backend) builder.append(cls=orm.WorkChainNode, filters={'uuid': uuid}, tag='workcalculation') builder.append( cls=orm.WorkChainNode, diff --git a/aiida/orm/implementation/django/comments.py b/aiida/orm/implementation/django/comments.py index be7fe71b9d..ab874b6b52 100644 --- a/aiida/orm/implementation/django/comments.py +++ b/aiida/orm/implementation/django/comments.py @@ -168,7 +168,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Comment, filters=filters, project='id').all() + builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id').all() entities_to_delete = [_[0] for _ in builder] for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/django/logs.py b/aiida/orm/implementation/django/logs.py index 7b3b725c2c..4ddd8fe10f 100644 --- a/aiida/orm/implementation/django/logs.py +++ b/aiida/orm/implementation/django/logs.py @@ -144,7 +144,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Log, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/sqlalchemy/comments.py b/aiida/orm/implementation/sqlalchemy/comments.py index da100140dd..618aa021bf 100644 --- a/aiida/orm/implementation/sqlalchemy/comments.py +++ b/aiida/orm/implementation/sqlalchemy/comments.py @@ -171,7 +171,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filters must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Comment, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Comment, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/implementation/sqlalchemy/logs.py b/aiida/orm/implementation/sqlalchemy/logs.py index b4d75ad6ac..62a973171d 100644 --- a/aiida/orm/implementation/sqlalchemy/logs.py +++ b/aiida/orm/implementation/sqlalchemy/logs.py @@ -153,7 +153,7 @@ def delete_many(self, filters): raise exceptions.ValidationError('filter must not be empty') # Apply filter and delete found entities - builder = QueryBuilder().append(Log, filters=filters, project='id') + builder = QueryBuilder(backend=self.backend).append(Log, filters=filters, project='id') entities_to_delete = builder.all(flat=True) for entity in entities_to_delete: self.delete(entity) diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py index 83484b0983..ec6adffeb3 100644 --- a/aiida/orm/nodes/data/array/bands.py +++ b/aiida/orm/nodes/data/array/bands.py @@ -1803,7 +1803,7 @@ def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=un MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""") -def get_bands_and_parents_structure(args): +def get_bands_and_parents_structure(args, backend=None): """Search for bands and return bands and the closest structure that is a parent of the instance. 
:returns: @@ -1817,7 +1817,7 @@ def get_bands_and_parents_structure(args): from aiida import orm from aiida.common import timezone - q_build = orm.QueryBuilder() + q_build = orm.QueryBuilder(backend=backend) if args.all_users is False: q_build.append(orm.User, tag='creator', filters={'email': orm.User.objects.get_default().email}) else: diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py index 5b0696e103..f0278d0724 100644 --- a/aiida/orm/nodes/data/cif.py +++ b/aiida/orm/nodes/data/cif.py @@ -329,7 +329,7 @@ def read_cif(fileobj, index=-1, **kwargs): return struct_list[index] @classmethod - def from_md5(cls, md5): + def from_md5(cls, md5, backend=None): """ Return a list of all CIF files that match a given MD5 hash. @@ -337,7 +337,7 @@ def from_md5(cls, md5): otherwise the CIF file will not be found. """ from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) diff --git a/aiida/orm/nodes/data/code.py b/aiida/orm/nodes/data/code.py index 9bd8787b13..c936e0ef16 100644 --- a/aiida/orm/nodes/data/code.py +++ b/aiida/orm/nodes/data/code.py @@ -151,7 +151,7 @@ def get_description(self): return f'{self.description}' @classmethod - def get_code_helper(cls, label, machinename=None): + def get_code_helper(cls, label, machinename=None, backend=None): """ :param label: the code label identifying the code to load :param machinename: the machine name where code is setup @@ -164,7 +164,7 @@ def get_code_helper(cls, label, machinename=None): from aiida.orm.computers import Computer from aiida.orm.querybuilder import QueryBuilder - query = QueryBuilder() + query = QueryBuilder(backend=backend) query.append(cls, filters={'label': label}, project='*', tag='code') if machinename: query.append(Computer, filters={'label': machinename}, with_node='code') @@ -249,7 +249,7 @@ def get_from_string(cls, code_string): raise MultipleObjectsError(f'{code_string} could not be uniquely resolved') @classmethod - def list_for_plugin(cls, plugin, labels=True): + def list_for_plugin(cls, plugin, labels=True, backend=None): """ Return a list of valid code strings for a given plugin. @@ -260,7 +260,7 @@ def list_for_plugin(cls, plugin, labels=True): otherwise a list of integers with the code PKs. """ from aiida.orm.querybuilder import QueryBuilder - query = QueryBuilder() + query = QueryBuilder(backend=backend) query.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) valid_codes = query.all(flat=True) diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py index 1ad082dd37..b212327ba2 100644 --- a/aiida/orm/nodes/data/upf.py +++ b/aiida/orm/nodes/data/upf.py @@ -70,7 +70,7 @@ def get_pseudos_from_structure(structure, family_name): return pseudo_list -def upload_upf_family(folder, group_label, group_description, stop_if_existing=True): +def upload_upf_family(folder, group_label, group_description, stop_if_existing=True, backend=None): """Upload a set of UPF files in a given group. :param folder: a path containing all UPF files to be added. 
@@ -120,7 +120,7 @@ def upload_upf_family(folder, group_label, group_description, stop_if_existing=T for filename in filenames: md5sum = md5_file(filename) - builder = orm.QueryBuilder() + builder = orm.QueryBuilder(backend=backend) builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}}) existing_upf = builder.first() @@ -321,7 +321,7 @@ def store(self, *args, **kwargs): # pylint: disable=signature-differs return super().store(*args, **kwargs) @classmethod - def from_md5(cls, md5): + def from_md5(cls, md5, backend=None): """Return a list of all `UpfData` that match the given md5 hash. .. note:: assumes hash of stored `UpfData` nodes is stored in the `md5` attribute @@ -330,7 +330,7 @@ def from_md5(cls, md5): :return: list of existing `UpfData` nodes that have the same md5 hash """ from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(cls, filters={'attributes.md5': {'==': md5}}) return builder.all(flat=True) @@ -366,7 +366,7 @@ def get_upf_family_names(self): """Get the list of all upf family names to which the pseudo belongs.""" from aiida.orm import QueryBuilder, UpfFamily - query = QueryBuilder() + query = QueryBuilder(backend=self.backend) query.append(UpfFamily, tag='group', project='label') query.append(UpfData, filters={'id': {'==': self.id}}, with_group='group') return query.all(flat=True) @@ -448,7 +448,7 @@ def get_upf_group(cls, group_label): return UpfFamily.get(label=group_label) @classmethod - def get_upf_groups(cls, filter_elements=None, user=None): + def get_upf_groups(cls, filter_elements=None, user=None, backend=None): """Return all names of groups of type UpfFamily, possibly with some filters. :param filter_elements: A string or a list of strings. 
@@ -460,7 +460,7 @@ def get_upf_groups(cls, filter_elements=None, user=None): """ from aiida.orm import QueryBuilder, UpfFamily, User - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(UpfFamily, tag='group', project='*') if user: diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py index 723705f109..c8694d2405 100644 --- a/aiida/orm/nodes/node.py +++ b/aiida/orm/nodes/node.py @@ -456,11 +456,11 @@ def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str """ from aiida.orm.utils.links import validate_link - validate_link(source, self, link_type, link_label) + validate_link(source, self, link_type, link_label, backend=self.backend) # Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]: - builder = QueryBuilder().append( + builder = QueryBuilder(backend=self.backend).append( Node, filters={'id': self.pk}, tag='parent').append( Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable if builder.count() > 0: @@ -537,7 +537,7 @@ def get_stored_link_triples( if link_label_filter: edge_filters['label'] = {'like': link_label_filter} - builder = QueryBuilder() + builder = QueryBuilder(backend=self.backend) builder.append(Node, filters=node_filters, tag='main') node_project = ['uuid'] if only_uuid else ['*'] @@ -894,7 +894,7 @@ def _iter_all_same_nodes(self, allow_before_store=False) -> Iterator['Node']: if not node_hash or not self._cachable: return iter(()) - builder = QueryBuilder() + builder = QueryBuilder(backend=self.backend) builder.append(self.__class__, filters={'extras._aiida_hash': node_hash}, project='*', subclassing=False) nodes_identical = (n[0] for n in builder.iterall()) diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 82444e5d55..d3c04ebc56 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -136,8 +136,8 @@ def __init__( :param distinct: Whether to return de-duplicated rows """ - backend = backend or get_manager().get_backend() - self._impl: BackendQueryBuilder = backend.query() + self._backend = backend or get_manager().get_backend() + self._impl: BackendQueryBuilder = self._backend.query() # SERIALISABLE ATTRIBUTES # A list storing the path being traversed by the query @@ -189,6 +189,11 @@ def __init__( if order_by: self.order_by(order_by) + @property + def backend(self) -> 'Backend': + """Return the backend used by the QueryBuilder.""" + return self._backend + def as_dict(self, copy: bool = True) -> QueryDictType: """Convert to a JSON serialisable dictionary representation of the query.""" data: QueryDictType = { @@ -225,7 +230,7 @@ def __str__(self) -> str: def __deepcopy__(self, memo) -> 'QueryBuilder': """Create deep copy of the instance.""" - return type(self)(**self.as_dict()) # type: ignore + return type(self)(backend=self.backend, **self.as_dict()) # type: ignore def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: """Returns a list of all the vertices that are being used. 
diff --git a/aiida/orm/utils/links.py b/aiida/orm/utils/links.py index 535ca0caa5..f79667777f 100644 --- a/aiida/orm/utils/links.py +++ b/aiida/orm/utils/links.py @@ -21,7 +21,7 @@ LinkQuadruple = namedtuple('LinkQuadruple', ['source_id', 'target_id', 'link_type', 'link_label']) -def link_triple_exists(source, target, link_type, link_label): +def link_triple_exists(source, target, link_type, link_label, backend=None): """Return whether a link with the given type and label exists between the given source and target node. :param source: node from which the link is outgoing @@ -42,7 +42,7 @@ def link_triple_exists(source, target, link_type, link_label): # Here we have two stored nodes, so we need to check if the same link already exists in the database. # Finding just a single match is sufficient so we can use the `limit` clause for efficiency - builder = QueryBuilder() + builder = QueryBuilder(backend=backend) builder.append(Node, filters={'id': source.id}, project=['id']) builder.append(Node, filters={'id': target.id}, edge_filters={'type': link_type.value, 'label': link_label}) builder.limit(1) @@ -50,7 +50,7 @@ def link_triple_exists(source, target, link_type, link_label): return builder.count() != 0 -def validate_link(source, target, link_type, link_label): +def validate_link(source, target, link_type, link_label, backend=None): """ Validate adding a link of the given type and label from a given node to ourself. @@ -153,7 +153,7 @@ def validate_link(source, target, link_type, link_label): if outdegree == 'unique_triple' or indegree == 'unique_triple': # For a `unique_triple` degree we just have to check if an identical triple already exist, either in the cache # or stored, in which case, the new proposed link is a duplicate and thus illegal - duplicate_link_triple = link_triple_exists(source, target, link_type, link_label) + duplicate_link_triple = link_triple_exists(source, target, link_type, link_label, backend) # If the outdegree is `unique` there cannot already be any other outgoing link of that type if outdegree == 'unique' and source.get_outgoing(link_type=link_type, only_uuid=True).all(): diff --git a/aiida/orm/utils/remote.py b/aiida/orm/utils/remote.py index 71f3e339d3..deb40ab874 100644 --- a/aiida/orm/utils/remote.py +++ b/aiida/orm/utils/remote.py @@ -37,13 +37,13 @@ def clean_remote(transport, path): pass -def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None): +def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computers=None, user=None, backend=None): """ Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of calcjobs will be determined by a query with filters based on the pks, past_days, older_than, computers and user arguments. 
- :param pks: onlu include calcjobs with a pk in this list + :param pks: only include calcjobs with a pk in this list :param past_days: only include calcjobs created since past_days :param older_than: only include calcjobs older than :param computers: only include calcjobs that were ran on these computers @@ -74,7 +74,7 @@ def get_calcjob_remote_paths(pks=None, past_days=None, older_than=None, computer if pks: filters_calc['id'] = {'in': pks} - query = orm.QueryBuilder() + query = orm.QueryBuilder(backend=backend) query.append(CalcJobNode, tag='calc', project=['attributes.remote_workdir'], filters=filters_calc) query.append(orm.Computer, with_node='calc', tag='computer', project=['*'], filters=filters_computer) query.append(orm.User, with_node='calc', filters={'email': user.email}) diff --git a/aiida/tools/graph/age_entities.py b/aiida/tools/graph/age_entities.py index a729d58ee7..de5fcec0d1 100644 --- a/aiida/tools/graph/age_entities.py +++ b/aiida/tools/graph/age_entities.py @@ -225,17 +225,6 @@ def aiida_cls(self): """Class of nodes contained in the entity set (node or group)""" return self._aiida_cls - def get_entities(self): - """Iterator that returns the AiiDA entities""" - for entity, in orm.QueryBuilder().append( - self._aiida_cls, project='*', filters={ - self._identifier: { - 'in': self.keyset - } - } - ).iterall(): - yield entity - class DirectedEdgeSet(AbstractSetContainer): """Extension of AbstractSetContainer diff --git a/aiida/tools/graph/age_rules.py b/aiida/tools/graph/age_rules.py index f768d8d07a..973d334909 100644 --- a/aiida/tools/graph/age_rules.py +++ b/aiida/tools/graph/age_rules.py @@ -11,6 +11,7 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict +from copy import deepcopy import numpy as np @@ -65,7 +66,7 @@ class QueryRule(Operation, metaclass=ABCMeta): found in the last iteration of the query (ReplaceRule). """ - def __init__(self, querybuilder, max_iterations=1, track_edges=False): + def __init__(self, querybuilder: orm.QueryBuilder, max_iterations=1, track_edges=False): """Initialization method :param querybuilder: an instance of the QueryBuilder class from which to take the @@ -107,7 +108,7 @@ def get_spec_from_path(query_dict, idx): for pathspec in query_dict['path']: if not pathspec['entity_type']: pathspec['entity_type'] = 'node.Node.' 
- self._qbtemplate = orm.QueryBuilder(**query_dict) + self._qbtemplate = deepcopy(querybuilder) query_dict = self._qbtemplate.as_dict() self._first_tag = query_dict['path'][0]['tag'] self._last_tag = query_dict['path'][-1]['tag'] @@ -163,7 +164,7 @@ def _init_run(self, operational_set): # Copying qbtemplate so there's no problem if it is used again in a later run: query_dict = self._qbtemplate.as_dict() - self._querybuilder = orm.QueryBuilder.from_dict(query_dict) + self._querybuilder = deepcopy(self._qbtemplate) self._entity_to_identifier = operational_set[self._entity_to].identifier diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index d14d9c7dd5..61e0454f1d 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -71,17 +71,18 @@ def _missing_callback(_pks: Iterable[int]): for _pk in _pks: DELETE_LOGGER.warning(f'warning: node with pk<{_pk}> does not exist, skipping') - pks_set_to_delete = get_nodes_delete(pks, get_links=False, missing_callback=_missing_callback, - **traversal_rules)['nodes'] + pks_set_to_delete = get_nodes_delete( + pks, get_links=False, missing_callback=_missing_callback, backend=backend, **traversal_rules + )['nodes'] DELETE_LOGGER.report('%s Node(s) marked for deletion', len(pks_set_to_delete)) if pks_set_to_delete and DELETE_LOGGER.level == logging.DEBUG: - builder = QueryBuilder().append( - Node, filters={'id': { - 'in': pks_set_to_delete - }}, project=('uuid', 'id', 'node_type', 'label') - ) + builder = QueryBuilder( + backend=backend + ).append(Node, filters={'id': { + 'in': pks_set_to_delete + }}, project=('uuid', 'id', 'node_type', 'label')) DELETE_LOGGER.debug('Node(s) to delete:') for uuid, pk, type_string, label in builder.iterall(): try: @@ -113,6 +114,7 @@ def _missing_callback(_pks: Iterable[int]): def delete_group_nodes( pks: Iterable[int], dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + backend=None, **traversal_rules: bool ) -> Tuple[Set[int], bool]: """Delete nodes contained in a list of groups (not the groups themselves!). 
@@ -149,7 +151,7 @@ def delete_group_nodes(
 
     :returns: (node pks to delete, whether they were deleted)
     """
-    group_node_query = QueryBuilder().append(
+    group_node_query = QueryBuilder(backend=backend).append(
         Group,
         filters={
             'id': {
@@ -160,4 +162,4 @@
     ).append(Node, project='id', with_group='groups')
     group_node_query.distinct()
     node_pks = group_node_query.all(flat=True)
-    return delete_nodes(node_pks, dry_run=dry_run, **traversal_rules)
+    return delete_nodes(node_pks, dry_run=dry_run, backend=backend, **traversal_rules)
diff --git a/aiida/tools/graph/graph_traversers.py b/aiida/tools/graph/graph_traversers.py
index 6468ead76e..8f9f0c0f6d 100644
--- a/aiida/tools/graph/graph_traversers.py
+++ b/aiida/tools/graph/graph_traversers.py
@@ -9,7 +9,7 @@
 ###########################################################################
 """Module for functions to traverse AiiDA graphs."""
 import sys
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, cast
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, cast
 
 from numpy import inf
 
@@ -20,6 +20,9 @@
 from aiida.tools.graph.age_entities import Basket
 from aiida.tools.graph.age_rules import RuleSaveWalkers, RuleSequence, RuleSetWalkers, UpdateRule
 
+if TYPE_CHECKING:
+    from aiida.orm.implementation import Backend
+
 if sys.version_info >= (3, 8):
     from typing import TypedDict
 
@@ -35,6 +38,7 @@ def get_nodes_delete(
     starting_pks: Iterable[int],
     get_links: bool = False,
     missing_callback: Optional[Callable[[Iterable[int]], None]] = None,
+    backend: Optional['Backend'] = None,
     **traversal_rules: bool
 ) -> TraverseGraphOutput:
     """
@@ -59,9 +63,10 @@ def get_nodes_delete(
     traverse_output = traverse_graph(
         starting_pks,
         get_links=get_links,
+        backend=backend,
         links_forward=traverse_links['forward'],
         links_backward=traverse_links['backward'],
-        missing_callback=missing_callback
+        missing_callback=missing_callback,
     )
 
     function_output = {
@@ -74,7 +79,10 @@ def get_nodes_delete(
 
 
 def get_nodes_export(
-    starting_pks: Iterable[int], get_links: bool = False, **traversal_rules: bool
+    starting_pks: Iterable[int],
+    get_links: bool = False,
+    backend: Optional['Backend'] = None,
+    **traversal_rules: bool
 ) -> TraverseGraphOutput:
     """
     This function will return the set of all nodes that can be connected
@@ -99,6 +107,7 @@ def get_nodes_export(
     traverse_output = traverse_graph(
         starting_pks,
         get_links=get_links,
+        backend=backend,
         links_forward=traverse_links['forward'],
         links_backward=traverse_links['backward']
     )
@@ -186,7 +195,8 @@ def traverse_graph(
     get_links: bool = False,
     links_forward: Iterable[LinkType] = (),
     links_backward: Iterable[LinkType] = (),
-    missing_callback: Optional[Callable[[Iterable[int]], None]] = None
+    missing_callback: Optional[Callable[[Iterable[int]], None]] = None,
+    backend: Optional['Backend'] = None
 ) -> TraverseGraphOutput:
     """
     This function will return the set of all nodes that can be connected
@@ -239,7 +249,7 @@
             return {'nodes': set(), 'links': set()}
         return {'nodes': set(), 'links': None}
 
-    query_nodes = orm.QueryBuilder()
+    query_nodes = orm.QueryBuilder(backend=backend)
     query_nodes.append(orm.Node, project=['id'], filters={'id': {'in': operational_set}})
     existing_pks = set(query_nodes.all(flat=True))
     missing_pks = operational_set.difference(existing_pks)
@@ -266,7 +276,7 @@
         rules += [RuleSaveWalkers(stash)]
 
     if links_forward:
-        query_outgoing = orm.QueryBuilder()
+        query_outgoing = orm.QueryBuilder(backend=backend)
         query_outgoing.append(orm.Node, tag='sources')
         query_outgoing.append(orm.Node, edge_filters=filters_forwards, with_incoming='sources')
         rule_outgoing = UpdateRule(query_outgoing, max_iterations=1, track_edges=get_links)
@@ -276,7 +286,7 @@
         rules += [RuleSetWalkers(stash)]
 
     if links_backward:
-        query_incoming = orm.QueryBuilder()
+        query_incoming = orm.QueryBuilder(backend=backend)
         query_incoming.append(orm.Node, tag='sources')
         query_incoming.append(orm.Node, edge_filters=filters_backwards, with_outgoing='sources')
         rule_incoming = UpdateRule(query_incoming, max_iterations=1, track_edges=get_links)
diff --git a/aiida/tools/visualization/graph.py b/aiida/tools/visualization/graph.py
index 2793ce9ce0..b864ad28bb 100644
--- a/aiida/tools/visualization/graph.py
+++ b/aiida/tools/visualization/graph.py
@@ -10,17 +10,21 @@
 """ provides functionality to create graphs of the AiiDa data providence,
 *via* graphviz.
 """
-
 import os
 from types import MappingProxyType  # pylint: disable=no-name-in-module,useless-suppression
+from typing import TYPE_CHECKING, Optional
 
 from graphviz import Digraph
 
 from aiida import orm
 from aiida.common import LinkType
+from aiida.manage.manager import get_manager
 from aiida.orm.utils.links import LinkPair
 from aiida.tools.graph.graph_traversers import traverse_graph
 
+if TYPE_CHECKING:
+    from aiida.orm.implementation import Backend
+
 __all__ = ('Graph', 'default_link_styles', 'default_node_styles', 'pstate_node_styles', 'default_node_sublabels')
 
 
@@ -359,7 +363,8 @@ def __init__(
         link_style_fn=None,
         node_style_fn=None,
         node_sublabel_fn=None,
-        node_id_type='pk'
+        node_id_type='pk',
+        backend: Optional['Backend'] = None
     ):
         """a class to create graphviz graphs of the AiiDA node provenance
 
@@ -398,10 +403,16 @@ def __init__(
         self._node_styles = node_style_fn or default_node_styles
         self._node_sublabels = node_sublabel_fn or default_node_sublabels
         self._node_id_type = node_id_type
+        self._backend = backend or get_manager().get_backend()
 
         self._ignore_node_style = _OVERRIDE_STYLES_DICT['ignore_node']
         self._origin_node_style = _OVERRIDE_STYLES_DICT['origin_node']
 
+    @property
+    def backend(self) -> 'Backend':
+        """The backend used to create the graph"""
+        return self._backend
+
     @property
     def graphviz(self):
         """return a copy of the graphviz.Digraph"""
@@ -539,10 +550,11 @@ def add_incoming(self, node, link_types=(), annotate_links=None, return_pks=True
             (node_pk,),
             max_iterations=1,
             get_links=True,
+            backend=self.backend,
             links_backward=valid_link_types,
         )
 
-        traversed_nodes = orm.QueryBuilder().append(
+        traversed_nodes = orm.QueryBuilder(backend=self.backend).append(
             orm.Node,
             filters={'id': {
                 'in': traversed_graph['nodes']
@@ -595,10 +607,11 @@ def add_outgoing(self, node, link_types=(), annotate_links=None, return_pks=True
             (node_pk,),
             max_iterations=1,
             get_links=True,
+            backend=self.backend,
             links_forward=valid_link_types,
         )
 
-        traversed_nodes = orm.QueryBuilder().append(
+        traversed_nodes = orm.QueryBuilder(backend=self.backend).append(
            orm.Node,
            filters={'id': {
                'in': traversed_graph['nodes']
@@ -664,6 +677,7 @@ def recurse_descendants(
             (origin_pk,),
             max_iterations=depth,
             get_links=True,
+            backend=self.backend,
             links_forward=valid_link_types,
         )
 
@@ -674,13 +688,14 @@ def recurse_descendants(
                 traversed_graph['nodes'],
                 max_iterations=1,
                 get_links=True,
+                backend=self.backend,
                 links_backward=[LinkType.INPUT_WORK, LinkType.INPUT_CALC]
             )
             traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes'])
traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary - traversed_nodes = orm.QueryBuilder().append( + traversed_nodes = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -755,6 +770,7 @@ def recurse_ancestors( (origin_pk,), max_iterations=depth, get_links=True, + backend=self.backend, links_backward=valid_link_types, ) @@ -765,13 +781,14 @@ def recurse_ancestors( traversed_graph['nodes'], max_iterations=1, get_links=True, + backend=self.backend, links_forward=[LinkType.CREATE, LinkType.RETURN] ) traversed_graph['nodes'] = traversed_graph['nodes'].union(traversed_outputs['nodes']) traversed_graph['links'] = traversed_graph['links'].union(traversed_outputs['links']) # Do one central query for all nodes in the Graph and generate a {id: Node} dictionary - traversed_nodes = orm.QueryBuilder().append( + traversed_nodes = orm.QueryBuilder(backend=self.backend).append( orm.Node, filters={'id': { 'in': traversed_graph['nodes'] @@ -842,6 +859,7 @@ def add_origin_to_targets( self.add_node(origin_node, style_override=dict(origin_style)) query = orm.QueryBuilder( + backend=self.backend, **{ 'path': [{ 'cls': origin_node.__class__, @@ -902,6 +920,7 @@ def add_origins_to_targets( origin_filters = {} query = orm.QueryBuilder( + backend=self.backend, **{'path': [{ 'cls': origin_cls, 'filters': origin_filters,
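
For orientation, a minimal sketch of the calling pattern this patch enforces
(not part of the diff): it assumes a configured AiiDA profile with aiida-core
at the state of this patch, and the node pks below are placeholders for nodes
that actually exist in the backend. `get_manager().get_backend()` is the same
default that `QueryBuilder` falls back to when no backend is passed.

    from aiida import orm
    from aiida.common import LinkType
    from aiida.manage.manager import get_manager
    from aiida.tools.graph.graph_traversers import traverse_graph

    # Obtain the current profile's storage backend explicitly, instead of
    # letting each QueryBuilder assume the globally loaded one.
    backend = get_manager().get_backend()

    # Scope a query to that backend; the new `backend` property exposes it.
    builder = orm.QueryBuilder(backend=backend)
    builder.append(orm.Node, project=['id'])
    assert builder.backend is backend

    # The graph traversal utilities accept the same argument after this
    # patch. The pks are placeholders: traverse_graph raises for pks that do
    # not exist when no missing_callback is given.
    result = traverse_graph([1, 2, 3], backend=backend, links_forward=[LinkType.CREATE])
    print(sorted(result['nodes']))

Passing the backend explicitly keeps a query pinned to one storage backend even
when several (for example a profile database and an export archive) are open in
the same Python process, which is the scenario this patch prepares for.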