From f39c192e675cdf6183310064150cfb97a10ec2da Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Wed, 11 Aug 2021 09:35:30 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Add=20typing=20to=20Q?= =?UTF-8?q?ueryBuilder?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In preparation for migration to sqlalchemy 1.4 --- .pre-commit-config.yaml | 1 + aiida/orm/implementation/backends.py | 93 ++---- aiida/orm/implementation/querybuilder.py | 6 +- aiida/orm/querybuilder.py | 369 +++++++++++------------ docs/source/nitpick-exceptions | 3 + 5 files changed, 221 insertions(+), 251 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 07ea0ad2f4..164f8139a4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,6 +68,7 @@ repos: aiida/engine/.*py| aiida/manage/manager.py| aiida/manage/database/delete/nodes.py| + aiida/orm/querybuilder.py| aiida/orm/nodes/node.py| aiida/orm/nodes/process/.*py| aiida/repository/.*py| diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index f0dfd50fe2..f2463d1b9d 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,6 +9,16 @@ ########################################################################### """Generic backend related objects""" import abc +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session + + from aiida.orm.implementation import ( + BackendAuthInfoCollection, BackendCommentCollection, BackendComputerCollection, BackendGroupCollection, + BackendLogCollection, BackendNodeCollection, BackendQueryBuilder, BackendUserCollection + ) + from aiida.backends.general.abstractqueries import AbstractQueryManager __all__ = ('Backend',) @@ -21,85 +31,40 @@ def migrate(self): """Migrate the database to the latest schema generation or version.""" @abc.abstractproperty - def authinfos(self): - """ - Return the collection of authorisation information objects - - :return: the authinfo collection - :rtype: :class:`aiida.orm.implementation.BackendAuthInfoCollection` - """ + def authinfos(self) -> 'BackendAuthInfoCollection': + """Return the collection of authorisation information objects""" @abc.abstractproperty - def comments(self): - """ - Return the collection of comments - - :return: the comment collection - :rtype: :class:`aiida.orm.implementation.BackendCommentCollection` - """ + def comments(self) -> 'BackendCommentCollection': + """Return the collection of comments""" @abc.abstractproperty - def computers(self): - """ - Return the collection of computers - - :return: the computers collection - :rtype: :class:`aiida.orm.implementation.BackendComputerCollection` - """ + def computers(self) -> 'BackendComputerCollection': + """Return the collection of computers""" @abc.abstractproperty - def groups(self): - """ - Return the collection of groups - - :return: the groups collection - :rtype: :class:`aiida.orm.implementation.BackendGroupCollection` - """ + def groups(self) -> 'BackendGroupCollection': + """Return the collection of groups""" @abc.abstractproperty - def logs(self): - """ - Return the collection of logs - - :return: the log collection - :rtype: :class:`aiida.orm.implementation.BackendLogCollection` - """ + def logs(self) -> 'BackendLogCollection': + """Return the collection of logs""" @abc.abstractproperty - def nodes(self): - """ - Return the collection of nodes - - :return: the nodes collection - :rtype: :class:`aiida.orm.implementation.BackendNodeCollection` - """ + def nodes(self) -> 'BackendNodeCollection': + """Return the collection of nodes""" @abc.abstractproperty - def query_manager(self): - """ - Return the query manager for the objects stored in the backend - - :return: The query manger - :rtype: :class:`aiida.backends.general.abstractqueries.AbstractQueryManager` - """ + def query_manager(self) -> 'AbstractQueryManager': + """Return the query manager for the objects stored in the backend""" @abc.abstractmethod - def query(self): - """ - Return an instance of a query builder implementation for this backend - - :return: a new query builder instance - :rtype: :class:`aiida.orm.implementation.BackendQueryBuilder` - """ + def query(self) -> 'BackendQueryBuilder': + """Return an instance of a query builder implementation for this backend""" @abc.abstractproperty - def users(self): - """ - Return the collection of users - - :return: the users collection - :rtype: :class:`aiida.orm.implementation.BackendUserCollection` - """ + def users(self) -> 'BackendUserCollection': + """Return the collection of users""" @abc.abstractmethod def transaction(self): @@ -112,7 +77,7 @@ def transaction(self): """ @abc.abstractmethod - def get_session(self): + def get_session(self) -> 'Session': """Return a database session that can be used by the `QueryBuilder` to perform its query. :return: an instance of :class:`sqlalchemy.orm.session.Session` diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index acb1fa9af9..de266e524f 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -16,6 +16,7 @@ likely be moved to a `SqlAlchemyBasedQueryBuilder` class and restore this abstract class to being a pure agnostic one. """ import abc +from typing import TYPE_CHECKING import uuid # pylint: disable=no-name-in-module,import-error @@ -25,6 +26,9 @@ from aiida.common.lang import type_check +if TYPE_CHECKING: + from sqlalchemy.orm.session import Session # pylint: disable=ungrouped-imports + __all__ = ('BackendQueryBuilder',) @@ -111,7 +115,7 @@ def AiidaNode(self): from aiida.orm import Node return Node - def get_session(self): + def get_session(self) -> 'Session': """ :returns: a valid session, an instance of :class:`sqlalchemy.orm.session.Session` """ diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 859180c508..07b80071f5 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -22,6 +22,7 @@ from inspect import isclass as inspect_isclass import copy import logging +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union, TYPE_CHECKING import warnings from sqlalchemy import and_, or_, not_, func as sa_func, select, join @@ -43,6 +44,10 @@ from . import entities from . import convert +if TYPE_CHECKING: + from sqlalchemy.orm import Query # pylint: disable=ungrouped-imports + from aiida.orm.implementation import Backend # pylint: disable=ungrouped-imports + __all__ = ('QueryBuilder',) _LOGGER = logging.getLogger(__name__) @@ -55,6 +60,26 @@ # subclassing for any entity type. This workaround should then be able to be removed. GROUP_ENTITY_TYPE_PREFIX = 'group.' +NODE_CLS_TYPE = Type[Any] # pylint: disable=invalid-name +PROJECT_TYPE = Union[str, dict, Sequence[Union[str, dict]]] # pylint: disable=invalid-name +FILTER_TYPE = Dict[str, Any] # pylint: disable=invalid-name + +try: + # new in python 3.8 + from typing import TypedDict # pylint: disable=ungrouped-imports + + class PathItemType(TypedDict): + """An item on the query path""" + + entity_type: Any + tag: str + joining_keyword: str + joining_value: str + outerjoin: bool + edge_tag: str +except ImportError: + PathItemType = Dict[str, Any] # type: ignore + def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invalid-name """ @@ -73,7 +98,7 @@ def get_querybuilder_classifiers_from_cls(cls, query): # pylint: disable=invali from aiida.engine import Process from aiida.orm.utils.node import is_valid_node_type_string - classifiers = {} + classifiers: Dict[str, Optional[str]] = {} classifiers['process_type_string'] = None @@ -166,12 +191,12 @@ def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pyli Same as get_querybuilder_classifiers_from_cls, but accepts a string instead of a class. """ from aiida.orm.utils.node import is_valid_node_type_string - classifiers = {} + classifiers: Dict[str, Optional[str]] = {} classifiers['process_type_string'] = None classifiers['ormclass_type_string'] = ormclass_type_string.lower() - if classifiers['ormclass_type_string'].startswith(GROUP_ENTITY_TYPE_PREFIX): + if ormclass_type_string.lower().startswith(GROUP_ENTITY_TYPE_PREFIX): classifiers['ormclass_type_string'] = 'group.core' ormclass = query.Group elif classifiers['ormclass_type_string'] == 'computer': @@ -190,7 +215,7 @@ def get_querybuilder_classifiers_from_type(ormclass_type_string, query): # pyli return ormclass, classifiers -def get_node_type_filter(classifiers, subclassing): +def get_node_type_filter(classifiers: dict, subclassing: bool) -> dict: """ Return filter dictionaries given a set of classifiers. @@ -215,7 +240,7 @@ def get_node_type_filter(classifiers, subclassing): return filters -def get_process_type_filter(classifiers, subclassing): +def get_process_type_filter(classifiers: dict, subclassing: bool) -> dict: """ Return filter dictionaries given a set of classifiers. @@ -282,7 +307,7 @@ def get_process_type_filter(classifiers, subclassing): return filters -def get_group_type_filter(classifiers, subclassing): +def get_group_type_filter(classifiers: dict, subclassing: bool) -> dict: """Return filter dictionaries for `Group.type_string` given a set of classifiers. :param classifiers: a dictionary with classifiers (note: does *not* support lists) @@ -332,17 +357,28 @@ class QueryBuilder: _EDGE_TAG_DELIM = '--' _VALID_PROJECTION_KEYS = ('func', 'cast') - def __init__(self, backend=None, **kwargs): + def __init__( + self, + backend: Optional['Backend'] = None, + *, + debug: bool = False, + path: Optional[Sequence[Union[str, Dict[str, Any], NODE_CLS_TYPE]]] = (), + filters: Optional[Dict[str, FILTER_TYPE]] = None, + project: Optional[Dict[str, PROJECT_TYPE]] = None, + limit: Optional[int] = None, + offset: Optional[int] = None, + order_by: Optional[Any] = None, + ) -> None: """ Instantiates a QueryBuilder instance. Which backend is used decided here based on backend-settings (taken from the user profile). - This cannot be overriden so far by the user. + This cannot be overridden so far by the user. - :param bool debug: + :param debug: Turn on debug mode. This feature prints information on the screen about the stages of the QueryBuilder. Does not affect results. - :param list path: + :param path: A list of the vertices to traverse. Leave empty if you plan on using the method :func:`QueryBuilder.append`. :param filters: @@ -353,10 +389,10 @@ def __init__(self, backend=None, **kwargs): The projections to apply. You can specify the projections here, when appending to the query using :func:`QueryBuilder.append` or even later using :func:`QueryBuilder.add_projection`. Latter gives you API-details. - :param int limit: + :param limit: Limit the number of rows to this number. Check :func:`QueryBuilder.limit` for more information. - :param int offset: + :param offset: Set an offset for the results returned. Details in :func:`QueryBuilder.offset`. :param order_by: How to order the results. As the 2 above, can be set also at later stage, @@ -367,26 +403,23 @@ def __init__(self, backend=None, **kwargs): self._impl = backend.query() # A list storing the path being traversed by the query - self._path = [] - - # A list of unique aliases in same order as path - self._aliased_path = [] + self._path: List[PathItemType] = [] # A dictionary tag:alias of ormclass # redundant but makes life easier - self.tag_to_alias_map = {} - self.tag_to_projected_property_dict = {} + self.tag_to_alias_map: Dict[str, Any] = {} + self.tag_to_projected_property_dict: Dict[str, dict] = {} # A dictionary tag: filter specification for this alias - self._filters = {} + self._filters: Dict[str, FILTER_TYPE] = {} # A dictionary tag: projections for this alias - self._projections = {} + self._projections: Dict[str, List[dict]] = {} self.nr_of_projections = 0 - self._attrkeys_as_in_sql_result = None + self._attrkeys_as_in_sql_result: Optional[dict] = None - self._query = None + self._query: 'Query' = None # A dictionary for classes passed to the tag given to them # Everything is specified with unique tags, which are strings. @@ -402,10 +435,10 @@ def __init__(self, backend=None, **kwargs): # {PwCalculation:'PwCalculation', StructureData:'StructureData'} # Keep in mind that it needs to be checked (and this is done) whether the class # is used twice. In that case, the user has to provide a tag! - self._cls_to_tag_map = {} + self._cls_to_tag_map: Dict[Any, str] = {} - # Hashing the the internal queryhelp allows me to avoid to build a query again - self._hash = None + # Hashing the internal queryhelp avoids rebuild a query + self._hash: Optional[str] = None # The hash being None implies that the query will be build (Check the code in .get_query # The user can inject a query, this keyword stores whether this was done. @@ -413,11 +446,10 @@ def __init__(self, backend=None, **kwargs): self._injected = False # Setting debug levels: - self.set_debug(kwargs.pop('debug', False)) + self.set_debug(debug) # One can apply the path as a keyword. Allows for jsons to be given to the QueryBuilder. - path = kwargs.pop('path', []) - if not isinstance(path, (tuple, list)): + if not isinstance(path, (list, tuple)): raise TypeError('Path needs to be a tuple or a list') # If the user specified a path, I use the append method to analyze, see QueryBuilder.append for path_spec in path: @@ -433,43 +465,31 @@ def __init__(self, backend=None, **kwargs): # Projections. The user provides a dictionary, but the specific checks is # left to QueryBuilder.add_project. - projection_dict = kwargs.pop('project', {}) + projection_dict = project or {} if not isinstance(projection_dict, dict): raise TypeError('You need to provide the projections as dictionary') for key, val in projection_dict.items(): self.add_projection(key, val) # For filters, I also expect a dictionary, and the checks are done lower. - filter_dict = kwargs.pop('filters', {}) + filter_dict = filters or {} if not isinstance(filter_dict, dict): raise TypeError('You need to provide the filters as dictionary') for key, val in filter_dict.items(): self.add_filter(key, val) # The limit is caps the number of results returned, and can also be set with QueryBuilder.limit - self.limit(kwargs.pop('limit', None)) + self.limit(limit) # The offset returns results after the offset - self.offset(kwargs.pop('offset', None)) + self.offset(offset) # The user can also specify the order. - self._order_by = {} - order_spec = kwargs.pop('order_by', None) - if order_spec: - self.order_by(order_spec) - - # I've gone through all the keywords, popping each item - # If kwargs is not empty, there is a problem: - if kwargs: - valid_keys = ('path', 'filters', 'project', 'limit', 'offset', 'order_by') - raise ValueError( - 'Received additional keywords: {}' - '\nwhich I cannot process' - '\nValid keywords are: {}' - ''.format(list(kwargs.keys()), valid_keys) - ) + self._order_by: List[dict] = [] + if order_by: + self.order_by(order_by) - def __str__(self): + def __str__(self) -> str: """ When somebody hits: print(QueryBuilder) or print(str(QueryBuilder)) I want to print the SQL-query. Because it looks cool... @@ -531,11 +551,11 @@ def _get_ormclass(self, cls, ormclass_type_string): return ormclass, classifiers - def _get_unique_tag(self, classifiers): + def _get_unique_tag(self, classifiers) -> str: """ Using the function get_tag_from_type, I get a tag. I increment an index that is appended to that tag until I have an unused tag. - This function is called in :func:`QueryBuilder.append` when autotag is set to True. + This function is called in :func:`QueryBuilder.append` when no tag is given. :param dict classifiers: Classifiers, containing the string that defines the type of the AiiDA ORM class. @@ -580,18 +600,20 @@ def get_tag_from_type(classifiers): def append( self, - cls=None, - entity_type=None, - tag=None, - filters=None, - project=None, - subclassing=True, - edge_tag=None, - edge_filters=None, - edge_project=None, - outerjoin=False, - **kwargs - ): + cls: Optional[Union[NODE_CLS_TYPE, Sequence[NODE_CLS_TYPE]]] = None, + entity_type: Optional[Union[str, Sequence[str]]] = None, + tag: Optional[str] = None, + filters: Optional[FILTER_TYPE] = None, + project: Optional[PROJECT_TYPE] = None, + subclassing: bool = True, + edge_tag: Optional[str] = None, + edge_filters: Optional[FILTER_TYPE] = None, + edge_project: Optional[PROJECT_TYPE] = None, + outerjoin: bool = False, + joining_keyword: Optional[str] = None, + joining_value: Optional[Any] = None, + **kwargs: Any + ) -> 'QueryBuilder': """ Any iterative procedure to build the path for a graph query needs to invoke this method to append to the path. @@ -609,9 +631,7 @@ def append( cls=(Group, Node) :param entity_type: The node type of the class, if cls is not given. Also here, a tuple or list is accepted. - :type type: str - :param bool autotag: Whether to find automatically a unique tag. If this is set to True (default False), - :param str tag: + :param tag: A unique tag. If none is given, I will create a unique tag myself. :param filters: Filters to apply for this vertex. @@ -620,20 +640,27 @@ def append( Projections to apply. See usage examples for details. More information also in :meth:`.add_projection`. :param bool subclassing: - Whether to include subclasses of the given class - (default **True**). - E.g. Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. - :param bool outerjoin: - If True, (default is False), will do a left outerjoin - instead of an inner join - :param str edge_tag: + Whether to include subclasses of the given class (default **True**). + E.g. Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. + :param edge_tag: The tag that the edge will get. If nothing is specified (and there is a meaningful edge) the default is tag1--tag2 with tag1 being the entity joining from and tag2 being the entity joining to (this entity). - :param str edge_filters: + :param edge_filters: The filters to apply on the edge. Also here, details in :meth:`.add_filter`. - :param str edge_project: + :param edge_project: The project from the edges. API-details in :meth:`.add_projection`. + :param bool outerjoin: + If True, (default is False), will do a left outerjoin + instead of an inner join + + Joining can be specified in two ways: + + - Specifying the 'joining_keyword' and 'joining_value' arguments + - Specify a single keyword argument + + The joining keyword wil be ``with_*`` or ``direction``, depending on the joining entity type. + The joining value is the tag name or class of the entity to join to. A small usage example how this can be invoked:: @@ -648,14 +675,11 @@ def append( ) :return: self - :rtype: :class:`aiida.orm.QueryBuilder` """ # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements # INPUT CHECKS ########################## - # This function can be called by users, so I am checking the - # input now. - # First of all, let's make sure the specified - # the class or the type (not both) + # This function can be called by users, so I am checking the input now. + # First of all, let's make sure the specified the class or the type (not both) if cls is not None and entity_type is not None: raise ValueError(f'You cannot specify both a class ({cls}) and a entity_type ({entity_type})') @@ -665,21 +689,19 @@ def append( # Let's check if it is a valid class or type if cls: - if isinstance(cls, (tuple, list, set)): + if isinstance(cls, (list, tuple)): for sub_cls in cls: if not inspect_isclass(sub_cls): raise TypeError(f"{sub_cls} was passed with kw 'cls', but is not a class") - else: - if not inspect_isclass(cls): - raise TypeError(f"{cls} was passed with kw 'cls', but is not a class") + elif not inspect_isclass(cls): + raise TypeError(f"{cls} was passed with kw 'cls', but is not a class") elif entity_type is not None: - if isinstance(entity_type, (tuple, list, set)): + if isinstance(entity_type, (list, tuple)): for sub_type in entity_type: if not isinstance(sub_type, str): raise TypeError(f'{sub_type} was passed as entity_type, but is not a string') - else: - if not isinstance(entity_type, str): - raise TypeError(f'{entity_type} was passed as entity_type, but is not a string') + elif not isinstance(entity_type, str): + raise TypeError(f'{entity_type} was passed as entity_type, but is not a string') ormclass, classifiers = self._get_ormclass(cls, entity_type) @@ -711,7 +733,7 @@ def append( if isinstance(cls, (list, set)): tag_key = tuple(cls) else: - tag_key = cls + tag_key = cls # type: ignore[assignment] if tag_key in self._cls_to_tag_map.keys(): # In this case, this class already stands for another @@ -784,48 +806,42 @@ def append( # pylint: disable=too-many-nested-blocks try: # Get the functions that are implemented: - spec_to_function_map = [] + # 'direction 'was an old implementation, which is now converted below to with_outgoing or with_incoming + spec_to_function_map = {'direction'} for secondary_dict in self._get_function_map().values(): - for key in secondary_dict.keys(): - if key not in spec_to_function_map: - spec_to_function_map.append(key) - joining_keyword = kwargs.pop('joining_keyword', None) - joining_value = kwargs.pop('joining_value', None) + spec_to_function_map.update(secondary_dict.keys()) for key, val in kwargs.items(): if key not in spec_to_function_map: raise ValueError( - '{} is not a valid keyword ' - 'for joining specification\n' - 'Valid keywords are: ' - '{}'.format( - key, spec_to_function_map + ['cls', 'type', 'tag', 'autotag', 'filters', 'project'] - ) + f"'{key}' is not a valid keyword for joining specification\n" + f'Valid keywords are: {spec_to_function_map!r}' ) - elif joining_keyword: + if joining_keyword: raise ValueError( 'You already specified joining specification {}\n' 'But you now also want to specify {}' ''.format(joining_keyword, key) ) + + joining_keyword = key + if joining_keyword == 'direction': + if not isinstance(val, int): + raise TypeError('direction=n expects n to be an integer') + try: + if val < 0: + joining_keyword = 'with_outgoing' + elif val > 0: + joining_keyword = 'with_incoming' + else: + raise ValueError('direction=0 is not valid') + joining_value = self._path[-abs(val)]['tag'] + except IndexError as exc: + raise ValueError( + f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' + ) else: - joining_keyword = key - if joining_keyword == 'direction': - if not isinstance(val, int): - raise TypeError('direction=n expects n to be an integer') - try: - if val < 0: - joining_keyword = 'with_outgoing' - elif val > 0: - joining_keyword = 'with_incoming' - else: - raise ValueError('direction=0 is not valid') - joining_value = self._path[-abs(val)]['tag'] - except IndexError as exc: - raise ValueError( - f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' - ) - else: - joining_value = self._get_tag_from_specification(val) + joining_value = self._get_tag_from_specification(val) + # the default is that this vertice is 'with_incoming' as the previous one if joining_keyword is None and len(self._path) > 0: joining_keyword = 'with_incoming' @@ -902,16 +918,19 @@ def append( dict( entity_type=path_type, tag=tag, - joining_keyword=joining_keyword, - joining_value=joining_value, + # for the first item joining_keyword/joining_value can be None, + # but after they always default to 'with_incoming' of the previous item + joining_keyword=joining_keyword, # type: ignore + joining_value=joining_value, # type: ignore + # same for edge_tag for which a default is applied + edge_tag=edge_tag, # type: ignore outerjoin=outerjoin, - edge_tag=edge_tag ) ) return self - def order_by(self, order_by): + def order_by(self, order_by: Union[dict, List[dict], Tuple[dict, ...]]) -> 'QueryBuilder': """ Set the entity to order by @@ -961,7 +980,7 @@ def order_by(self, order_by): '[columns to sort]' ''.format(order_spec) ) - _order_spec = {} + _order_spec: dict = {} for tagspec, items_to_order_by in order_spec.items(): if not isinstance(items_to_order_by, (tuple, list)): items_to_order_by = [items_to_order_by] @@ -1013,7 +1032,7 @@ def order_by(self, order_by): self._order_by.append(_order_spec) return self - def add_filter(self, tagspec, filter_spec): + def add_filter(self, tagspec: str, filter_spec: FILTER_TYPE) -> None: """ Adding a filter to my filters. @@ -1036,7 +1055,7 @@ def add_filter(self, tagspec, filter_spec): self._filters[tag].update(filters) @staticmethod - def _process_filters(filters): + def _process_filters(filters: FILTER_TYPE) -> dict: """Process filters.""" if not isinstance(filters, dict): raise TypeError('Filters have to be passed as dictionaries') @@ -1063,7 +1082,7 @@ def _add_node_type_filter(self, tagspec, classifiers, subclassing): """ if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - entity_type_filter = {'or': []} + entity_type_filter: dict = {'or': []} for classifier in classifiers: entity_type_filter['or'].append(get_node_type_filter(classifier, subclassing)) else: @@ -1083,7 +1102,7 @@ def _add_process_type_filter(self, tagspec, classifiers, subclassing): """ if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - process_type_filter = {'or': []} + process_type_filter: dict = {'or': []} for classifier in classifiers: if classifier['process_type_string'] is not None: process_type_filter['or'].append(get_process_type_filter(classifier, subclassing)) @@ -1106,7 +1125,7 @@ def _add_group_type_filter(self, tagspec, classifiers, subclassing): """ if isinstance(classifiers, list): # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - type_string_filter = {'or': []} + type_string_filter: dict = {'or': []} for classifier in classifiers: type_string_filter['or'].append(get_group_type_filter(classifier, subclassing)) else: @@ -1114,7 +1133,7 @@ def _add_group_type_filter(self, tagspec, classifiers, subclassing): self.add_filter(tagspec, {'type_string': type_string_filter}) - def add_projection(self, tag_spec, projection_spec): + def add_projection(self, tag_spec: str, projection_spec: PROJECT_TYPE) -> None: r""" Adds a projection @@ -1167,7 +1186,7 @@ def add_projection(self, tag_spec, projection_spec): print('DEBUG: Adding projection of', tag_spec) print(' projection', projection_spec) if not isinstance(projection_spec, (list, tuple)): - projection_spec = [projection_spec] + projection_spec = [projection_spec] # type: ignore for projection in projection_spec: if isinstance(projection, dict): _thisprojection = projection @@ -1289,7 +1308,7 @@ def _get_tag_from_specification(self, specification): ) return tag - def set_debug(self, debug): + def set_debug(self, debug: bool) -> 'QueryBuilder': """ Run in debug mode. This does not affect functionality, but prints intermediate stages when creating a query on screen. @@ -1302,7 +1321,7 @@ def set_debug(self, debug): return self - def limit(self, limit): + def limit(self, limit: Optional[int]) -> 'QueryBuilder': """ Set the limit (nr of rows to return) @@ -1314,7 +1333,7 @@ def limit(self, limit): self._limit = limit return self - def offset(self, offset): + def offset(self, offset: Optional[int]) -> 'QueryBuilder': """ Set the offset. If offset is set, that many rows are skipped before returning. *offset* = 0 is the same as omitting setting the offset. @@ -1716,7 +1735,7 @@ def _join_comment_user(self, joined_entity, entity_to_join, isouterjoin): self._check_dbentities((joined_entity, self._impl.Comment), (entity_to_join, self._impl.User), 'with_comment') self._query = self._query.join(entity_to_join, joined_entity.user_id == entity_to_join.id, isouter=isouterjoin) - def _get_function_map(self): + def _get_function_map(self) -> Dict[str, Dict[str, Callable[[Any, Any, bool], None]]]: """ Map relationship type keywords to functions The new mapping (since 1.0.0a5) is a two level dictionary. The first level defines the entity which has been @@ -1733,37 +1752,31 @@ def _get_function_map(self): 'with_computer': self._join_to_computer_used, 'with_user': self._join_created_by, 'with_group': self._join_group_members, - 'direction': None, }, 'computer': { 'with_node': self._join_computer, - 'direction': None, }, 'user': { 'with_comment': self._join_comment_user, 'with_node': self._join_creator_of, 'with_group': self._join_group_user, - 'direction': None, }, 'group': { 'with_node': self._join_groups, 'with_user': self._join_user_group, - 'direction': None, }, 'comment': { 'with_user': self._join_user_comment, 'with_node': self._join_node_comment, - 'direction': None }, 'log': { 'with_node': self._join_node_log, - 'direction': None } } - return mapping + return mapping # type: ignore - def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, **kwargs): + def _get_connecting_node(self, index: int, joining_keyword: str, joining_value: str, **_: Any): """ :param querydict: A dictionary specifying how the current node @@ -1783,32 +1796,22 @@ def _get_connecting_node(self, index, joining_keyword=None, joining_value=None, else: calling_entity = entity_type - if joining_keyword == 'direction': - if joining_value > 0: - returnval = self._aliased_path[index - joining_value], self._join_outputs - elif joining_value < 0: - returnval = self._aliased_path[index + joining_value], self._join_inputs - else: - raise Exception('Direction 0 is not valid') - else: + try: + func = self._get_function_map()[calling_entity][joining_keyword] + except KeyError: + raise ValueError(f"'{joining_keyword}' is not a valid joining keyword for a '{calling_entity}' type entity") + + if isinstance(joining_value, str): try: - func = self._get_function_map()[calling_entity][joining_keyword] + return self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func except KeyError: raise ValueError( - f"'{joining_keyword}' is not a valid joining keyword for a '{calling_entity}' type entity" + 'Key {} is unknown to the types I know about:\n' + '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys()) ) - - if isinstance(joining_value, int): - returnval = (self._aliased_path[joining_value], func) - elif isinstance(joining_value, str): - try: - returnval = self.tag_to_alias_map[self._get_tag_from_specification(joining_value)], func - except KeyError: - raise ValueError( - 'Key {} is unknown to the types I know about:\n' - '{}'.format(self._get_tag_from_specification(joining_value), self.tag_to_alias_map.keys()) - ) - return returnval + raise ValueError( + f'Key {self._get_tag_from_specification(joining_value)} value is not a string:\n{joining_value}' + ) @property def queryhelp(self): @@ -1936,7 +1939,7 @@ def _build(self): # LINK-PROJECTIONS ######################### for vertex in self._path[1:]: - edge_tag = vertex.get('edge_tag', None) + edge_tag = vertex.get('edge_tag', None) # type: ignore if self._debug: print('DEBUG: Checking projections for edges:') print( @@ -1990,12 +1993,6 @@ def _build(self): return self._query - def get_aliases(self): - """ - :returns: the list of aliases - """ - return self._aliased_path - def get_alias(self, tag): """ In order to continue a query by the user, this utility function @@ -2093,7 +2090,7 @@ def inject_query(self, query): self._query = query self._injected = True - def distinct(self): + def distinct(self) -> 'QueryBuilder': """ Asks for distinct rows, which is the same as asking the backend to remove duplicates. @@ -2134,7 +2131,7 @@ def first(self): if not isinstance(result, (list, tuple)): result = [result] - if len(result) != len(self._attrkeys_as_in_sql_result): + if not self._attrkeys_as_in_sql_result or len(result) != len(self._attrkeys_as_in_sql_result): raise Exception('length of query result does not match the number of specified projections') return [self.get_aiida_entity_res(self._impl.get_aiida_res(rowitem)) for colindex, rowitem in enumerate(result)] @@ -2154,7 +2151,7 @@ def one(self): raise NotExistent('No result was found') return res[0] - def count(self): + def count(self) -> int: """ Counts the number of rows returned by the backend. @@ -2186,7 +2183,7 @@ def iterall(self, batch_size=100): yield item - def iterdict(self, batch_size=100): + def iterdict(self, batch_size: Optional[int] = 100) -> Iterable[Dict]: """ Same as :meth:`.dict`, but returns a generator. Be aware that this is only safe if no commit will take place during this @@ -2208,7 +2205,7 @@ def iterdict(self, batch_size=100): yield item - def all(self, batch_size=None, flat=False): + def all(self, batch_size: Optional[int] = None, flat: bool = False): """Executes the full query with the order of the rows as returned by the backend. The order inside each row is given by the order of the vertices in the path and the order of the projections for @@ -2227,7 +2224,7 @@ def all(self, batch_size=None, flat=False): return [projection for entry in matches for projection in entry] - def dict(self, batch_size=None): + def dict(self, batch_size: Optional[int] = None) -> List[dict]: """ Executes the full query with the order of the rows as returned by the backend. the order inside each row is given by the order of the vertices in the path @@ -2281,7 +2278,7 @@ def dict(self, batch_size=None): """ return list(self.iterdict(batch_size=batch_size)) - def inputs(self, **kwargs): + def inputs(self, **kwargs: Any) -> 'QueryBuilder': """ Join to inputs of previous vertice in path. @@ -2290,10 +2287,10 @@ def inputs(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_outgoing=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_outgoing=join_to, **kwargs) return self - def outputs(self, **kwargs): + def outputs(self, **kwargs: Any) -> 'QueryBuilder': """ Join to outputs of previous vertice in path. @@ -2302,10 +2299,10 @@ def outputs(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_incoming=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_incoming=join_to, **kwargs) return self - def children(self, **kwargs): + def children(self, **kwargs: Any) -> 'QueryBuilder': """ Join to children/descendants of previous vertice in path. @@ -2314,10 +2311,10 @@ def children(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_ancestors=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_ancestors=join_to, **kwargs) return self - def parents(self, **kwargs): + def parents(self, **kwargs: Any) -> 'QueryBuilder': """ Join to parents/ancestors of previous vertice in path. @@ -2326,5 +2323,5 @@ def parents(self, **kwargs): from aiida.orm import Node join_to = self._path[-1]['tag'] cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_descendants=join_to, autotag=True, **kwargs) + self.append(cls=cls, with_descendants=join_to, **kwargs) return self diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 3dae986e7d..1529e9a9f3 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -161,3 +161,6 @@ py:class pgsu.PGSU py:meth pgsu.PGSU.__init__ py:class jsonschema.exceptions._Error + +py:class Session +py:class BackendQueryBuilder