diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78b1615e1c..2e8baeb6d0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,8 @@ repos: files: >- (?x)^( aiida/common/progress_reporter.py| - aiida/engine/processes/calcjobs/calcjob.py| + aiida/manage/manager.py| + aiida/engine/.*py| aiida/manage/database/delete/nodes.py| aiida/tools/graph/graph_traversers.py| aiida/tools/groups/paths.py| diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py index 41e147e19e..984ff61866 100644 --- a/aiida/engine/__init__.py +++ b/aiida/engine/__init__.py @@ -14,4 +14,4 @@ from .processes import * from .utils import * -__all__ = (launch.__all__ + processes.__all__ + utils.__all__) +__all__ = (launch.__all__ + processes.__all__ + utils.__all__) # type: ignore[name-defined] diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py index 32d96466bf..1215bb5056 100644 --- a/aiida/engine/daemon/client.py +++ b/aiida/engine/daemon/client.py @@ -16,13 +16,21 @@ import shutil import socket import tempfile +from typing import Any, Dict, Optional, TYPE_CHECKING from aiida.manage.configuration import get_config, get_config_option +from aiida.manage.configuration.profile import Profile + +if TYPE_CHECKING: + from circus.client import CircusClient VERDI_BIN = shutil.which('verdi') # Recent versions of virtualenv create the environment variable VIRTUAL_ENV VIRTUALENV = os.environ.get('VIRTUAL_ENV', None) +# see https://github.com/python/typing/issues/182 +JsonDictType = Dict[str, Any] + class ControllerProtocol(enum.Enum): """ @@ -33,13 +41,13 @@ class ControllerProtocol(enum.Enum): TCP = 1 -def get_daemon_client(profile_name=None): +def get_daemon_client(profile_name: Optional[str] = None) -> 'DaemonClient': """ Return the daemon client for the given profile or the current profile if not specified. 
:param profile_name: the profile name, will use the current profile if None :return: the daemon client - :rtype: :class:`aiida.engine.daemon.client.DaemonClient` + :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found :raises aiida.common.ProfileConfigurationError: if the given profile does not exist """ @@ -65,7 +73,7 @@ class DaemonClient: # pylint: disable=too-many-public-methods _DAEMON_NAME = 'aiida-{name}' _ENDPOINT_PROTOCOL = ControllerProtocol.IPC - def __init__(self, profile): + def __init__(self, profile: Profile): """ Construct a DaemonClient instance for a given profile @@ -73,22 +81,22 @@ def __init__(self, profile): """ config = get_config() self._profile = profile - self._SOCKET_DIRECTORY = None # pylint: disable=invalid-name - self._DAEMON_TIMEOUT = config.get_option('daemon.timeout') # pylint: disable=invalid-name + self._SOCKET_DIRECTORY: Optional[str] = None # pylint: disable=invalid-name + self._DAEMON_TIMEOUT: int = config.get_option('daemon.timeout') # pylint: disable=invalid-name @property - def profile(self): + def profile(self) -> Profile: return self._profile @property - def daemon_name(self): + def daemon_name(self) -> str: """ Get the daemon name which is tied to the profile name """ return self._DAEMON_NAME.format(name=self.profile.name) @property - def cmd_string(self): + def cmd_string(self) -> str: """ Return the command string to start the AiiDA daemon """ @@ -101,42 +109,42 @@ def cmd_string(self): return f'{VERDI_BIN} -p {self.profile.name} devel run_daemon' @property - def loglevel(self): + def loglevel(self) -> str: return get_config_option('logging.circus_loglevel') @property - def virtualenv(self): + def virtualenv(self) -> Optional[str]: return VIRTUALENV @property - def circus_log_file(self): + def circus_log_file(self) -> str: return self.profile.filepaths['circus']['log'] @property - def circus_pid_file(self): + def circus_pid_file(self) -> str: return self.profile.filepaths['circus']['pid'] @property - def circus_port_file(self): + def circus_port_file(self) -> str: return self.profile.filepaths['circus']['port'] @property - def circus_socket_file(self): + def circus_socket_file(self) -> str: return self.profile.filepaths['circus']['socket']['file'] @property - def circus_socket_endpoints(self): + def circus_socket_endpoints(self) -> Dict[str, str]: return self.profile.filepaths['circus']['socket'] @property - def daemon_log_file(self): + def daemon_log_file(self) -> str: return self.profile.filepaths['daemon']['log'] @property - def daemon_pid_file(self): + def daemon_pid_file(self) -> str: return self.profile.filepaths['daemon']['pid'] - def get_circus_port(self): + def get_circus_port(self) -> int: """ Retrieve the port for the circus controller, which should be written to the circus port file. If the daemon is running, the port file should exist and contain the port to which the controller is connected. @@ -158,7 +166,7 @@ def get_circus_port(self): return port - def get_circus_socket_directory(self): + def get_circus_socket_directory(self) -> str: """ Retrieve the absolute path of the directory where the circus sockets are stored if the IPC protocol is used and the daemon is running. 
If the daemon is running, the sockets file should exist and contain the @@ -192,7 +200,7 @@ def get_circus_socket_directory(self): self._SOCKET_DIRECTORY = socket_dir_path return socket_dir_path - def get_daemon_pid(self): + def get_daemon_pid(self) -> Optional[int]: """ Get the daemon pid which should be written in the daemon pid file specific to the profile @@ -207,7 +215,7 @@ def get_daemon_pid(self): return None @property - def is_daemon_running(self): + def is_daemon_running(self) -> bool: """ Return whether the daemon is running, which is determined by seeing if the daemon pid file is present @@ -215,7 +223,7 @@ def is_daemon_running(self): """ return self.get_daemon_pid() is not None - def delete_circus_socket_directory(self): + def delete_circus_socket_directory(self) -> None: """ Attempt to delete the directory used to store the circus endpoint sockets. Will not raise if the directory does not exist @@ -321,7 +329,7 @@ def get_tcp_endpoint(self, port=None): return endpoint @property - def client(self): + def client(self) -> 'CircusClient': """ Return an instance of the CircusClient with the endpoint defined by the controller endpoint, which used the port that was written to the port file upon starting of the daemon @@ -334,7 +342,7 @@ def client(self): from circus.client import CircusClient return CircusClient(endpoint=self.get_controller_endpoint(), timeout=self._DAEMON_TIMEOUT) - def call_client(self, command): + def call_client(self, command: JsonDictType) -> JsonDictType: """ Call the client with a specific command. Will check whether the daemon is running first by checking for the pid file. When the pid is found yet the call still fails with a @@ -358,47 +366,51 @@ def call_client(self, command): return result - def get_status(self): + def get_status(self) -> JsonDictType: """ Get the daemon running status :return: the client call response + If successful, will contain 'status' key """ command = {'command': 'status', 'properties': {'name': self.daemon_name}} return self.call_client(command) - def get_numprocesses(self): + def get_numprocesses(self) -> JsonDictType: """ Get the number of running daemon processes :return: the client call response + If successful, will contain 'numprocesses' key """ command = {'command': 'numprocesses', 'properties': {'name': self.daemon_name}} return self.call_client(command) - def get_worker_info(self): + def get_worker_info(self) -> JsonDictType: """ Get workers statistics for this daemon :return: the client call response + If successful, will contain 'info' key """ command = {'command': 'stats', 'properties': {'name': self.daemon_name}} return self.call_client(command) - def get_daemon_info(self): + def get_daemon_info(self) -> JsonDictType: """ Get statistics about this daemon itself :return: the client call response + If successful, will contain 'info' key """ command = {'command': 'dstats', 'properties': {}} return self.call_client(command) - def increase_workers(self, number): + def increase_workers(self, number: int) -> JsonDictType: """ Increase the number of workers @@ -409,7 +421,7 @@ def increase_workers(self, number): return self.call_client(command) - def decrease_workers(self, number): + def decrease_workers(self, number: int) -> JsonDictType: """ Decrease the number of workers @@ -420,7 +432,7 @@ def decrease_workers(self, number): return self.call_client(command) - def stop_daemon(self, wait): + def stop_daemon(self, wait: bool) -> JsonDictType: """ Stop the daemon @@ -436,7 +448,7 @@ def stop_daemon(self, wait): return
result - def restart_daemon(self, wait): + def restart_daemon(self, wait: bool) -> JsonDictType: """ Restart the daemon diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py index 5f8a136589..6d1eecf5ca 100644 --- a/aiida/engine/daemon/execmanager.py +++ b/aiida/engine/daemon/execmanager.py @@ -13,23 +13,56 @@ the routines make reference to the suitable plugins for all plugin-specific operations. """ +from collections.abc import Mapping +from logging import LoggerAdapter import os import shutil +from tempfile import NamedTemporaryFile +from typing import Any, List, Optional, Mapping as MappingType, Tuple, Union from aiida.common import AIIDA_LOGGER, exceptions +from aiida.common.datastructures import CalcInfo from aiida.common.folders import SandboxFolder from aiida.common.links import LinkType -from aiida.orm import FolderData, Node +from aiida.orm import load_node, CalcJobNode, Code, FolderData, Node, RemoteData from aiida.orm.utils.log import get_dblogger_extra from aiida.plugins import DataFactory from aiida.schedulers.datastructures import JobState +from aiida.transports import Transport REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' execlogger = AIIDA_LOGGER.getChild('execmanager') -def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run=False): +def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: + """Find and return the node with the given UUID from a nested mapping of input nodes. + + :param inputs: (nested) mapping of nodes + :param uuid: UUID of the node to find + :return: instance of `Node` or `None` if not found + """ + data_node = None + + for input_node in inputs.values(): + if isinstance(input_node, Mapping): + data_node = _find_data_node(input_node, uuid) + elif isinstance(input_node, Node) and input_node.uuid == uuid: + data_node = input_node + if data_node is not None: + break + + return data_node + + +def upload_calculation( + node: CalcJobNode, + transport: Transport, + calc_info: CalcInfo, + folder: SandboxFolder, + inputs: Optional[MappingType[str, Any]] = None, + dry_run: bool = False +) -> None: """Upload a `CalcJob` instance :param node: the `CalcJobNode`. @@ -38,9 +71,6 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` """ # pylint: disable=too-many-locals,too-many-branches,too-many-statements - from logging import LoggerAdapter - from tempfile import NamedTemporaryFile - from aiida.orm import load_node, Code, RemoteData # If the calculation already has a `remote_folder`, simply return. The upload was apparently already completed # before, which can happen if the daemon is restarted and it shuts down after uploading but before getting the @@ -162,30 +192,10 @@ def upload_calculation(node, transport, calc_info, folder, inputs=None, dry_run= for uuid, filename, target in local_copy_list: logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}') - def find_data_node(inputs, uuid): - """Find and return the node with the given UUID from a nested mapping of input nodes. 
- - :param inputs: (nested) mapping of nodes - :param uuid: UUID of the node to find - :return: instance of `Node` or `None` if not found - """ - from collections.abc import Mapping - data_node = None - - for input_node in inputs.values(): - if isinstance(input_node, Mapping): - data_node = find_data_node(input_node, uuid) - elif isinstance(input_node, Node) and input_node.uuid == uuid: - data_node = input_node - if data_node is not None: - break - - return data_node - try: data_node = load_node(uuid=uuid) except exceptions.NotExistent: - data_node = find_data_node(inputs, uuid) + data_node = _find_data_node(inputs, uuid) if inputs else None if data_node is None: logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`') @@ -294,7 +304,7 @@ def find_data_node(inputs, uuid): remotedata.store() -def submit_calculation(calculation, transport): +def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str: """Submit a previously uploaded `CalcJob` to the scheduler. :param calculation: the instance of CalcJobNode to submit. @@ -322,7 +332,7 @@ def submit_calculation(calculation, transport): return job_id -def retrieve_calculation(calculation, transport, retrieved_temporary_folder): +def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str) -> None: """Retrieve all the files of a completed job calculation using the given transport. If the job defined anything in the `retrieve_temporary_list`, those entries will be stored in the @@ -394,7 +404,7 @@ def retrieve_calculation(calculation, transport, retrieved_temporary_folder): retrieved_files.add_incoming(calculation, link_type=LinkType.CREATE, link_label=calculation.link_label_retrieved) -def kill_calculation(calculation, transport): +def kill_calculation(calculation: CalcJobNode, transport: Transport) -> bool: """ Kill the calculation through the scheduler @@ -425,7 +435,13 @@ def kill_calculation(calculation, transport): return True -def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_extra=None): +def _retrieve_singlefiles( + job: CalcJobNode, + transport: Transport, + folder: SandboxFolder, + retrieve_file_list: List[Tuple[str, str, str]], + logger_extra: Optional[dict] = None +): """Retrieve files specified through the singlefile list mechanism.""" singlefile_list = [] for (linkname, subclassname, filename) in retrieve_file_list: @@ -454,7 +470,10 @@ def _retrieve_singlefiles(job, transport, folder, retrieve_file_list, logger_ext fil.store() -def retrieve_files_from_list(calculation, transport, folder, retrieve_list): +def retrieve_files_from_list( + calculation: CalcJobNode, transport: Transport, folder: str, retrieve_list: List[Union[str, Tuple[str, str, int], + list]] +) -> None: """ Retrieve all the files in the retrieve_list from the remote into the local folder instance through the transport. 
The entries in the retrieve_list diff --git a/aiida/engine/daemon/runner.py b/aiida/engine/daemon/runner.py index 7085c15167..1826da38c2 100644 --- a/aiida/engine/daemon/runner.py +++ b/aiida/engine/daemon/runner.py @@ -14,12 +14,13 @@ from aiida.common.log import configure_logging from aiida.engine.daemon.client import get_daemon_client +from aiida.engine.runners import Runner from aiida.manage.manager import get_manager LOGGER = logging.getLogger(__name__) -async def shutdown_runner(runner): +async def shutdown_runner(runner: Runner) -> None: """Cleanup tasks tied to the service's shutdown.""" LOGGER.info('Received signal to shut down the daemon runner') @@ -40,8 +41,10 @@ async def shutdown_runner(runner): await asyncio.gather(*tasks, return_exceptions=True) runner.close() + LOGGER.info('Daemon runner stopped') + -def start_daemon(): +def start_daemon() -> None: """Start a daemon runner for the currently configured profile.""" daemon_client = get_daemon_client() configure_logging(daemon=True, daemon_log_file=daemon_client.daemon_log_file) @@ -65,4 +68,4 @@ def start_daemon(): LOGGER.info('Received a SystemError: %s', exception) runner.close() - LOGGER.info('Daemon runner stopped') + LOGGER.info('Daemon runner started') diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py index 8cdcafbb9f..6026ac4731 100644 --- a/aiida/engine/launch.py +++ b/aiida/engine/launch.py @@ -8,27 +8,30 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Top level functions that can be used to launch a Process.""" +from typing import Any, Dict, Tuple, Type, Union from aiida.common import InvalidOperation from aiida.manage import manager +from aiida.orm import ProcessNode from .processes.functions import FunctionProcess -from .processes.process import Process +from .processes.process import Process, ProcessBuilder from .utils import is_process_scoped, instantiate_process __all__ = ('run', 'run_get_pk', 'run_get_node', 'submit') +TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name +# run can also be process function, but it is not clear what type this should be +TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name -def run(process, *args, **inputs): + +def run(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Dict[str, Any]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. :param process: the process class or process function to run - :type process: :class:`aiida.engine.Process` - :param inputs: the inputs to be passed to the process - :type inputs: dict :return: the outputs of the process - :rtype: dict + """ if isinstance(process, Process): runner = process.runner @@ -38,17 +41,13 @@ def run(process, *args, **inputs): return runner.run(process, *args, **inputs) -def run_get_node(process, *args, **inputs): +def run_get_node(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], ProcessNode]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. 
- :param process: the process class or process function to run - :type process: :class:`aiida.engine.Process` - + :param process: the process class, instance, builder or function to run :param inputs: the inputs to be passed to the process - :type inputs: dict :return: tuple of the outputs of the process and the process node - :rtype: (dict, :class:`aiida.orm.ProcessNode`) """ if isinstance(process, Process): @@ -59,17 +58,14 @@ def run_get_node(process, *args, **inputs): return runner.run_get_node(process, *args, **inputs) -def run_get_pk(process, *args, **inputs): +def run_get_pk(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], int]: """Run the process with the supplied inputs in a local runner that will block until the process is completed. - :param process: the process class or process function to run - :type process: :class:`aiida.engine.Process` - + :param process: the process class, instance, builder or function to run :param inputs: the inputs to be passed to the process - :type inputs: dict :return: tuple of the outputs of the process and process node pk - :rtype: (dict, int) + """ if isinstance(process, Process): runner = process.runner @@ -79,7 +75,7 @@ def run_get_pk(process, *args, **inputs): return runner.run_get_pk(process, *args, **inputs) -def submit(process, **inputs): +def submit(process: TYPE_SUBMIT_PROCESS, **inputs: Any) -> ProcessNode: """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. .. warning: this should not be used within another process. Instead, there one should use the `submit` method of @@ -87,14 +83,11 @@ def submit(process, **inputs): .. warning: submission of processes requires `store_provenance=True` - :param process: the process class to submit - :type process: :class:`aiida.engine.Process` - + :param process: the process class, instance or builder to submit :param inputs: the inputs to be passed to the process - :type inputs: dict :return: the calculation node of the process - :rtype: :class:`aiida.orm.ProcessNode` + """ # Submitting from within another process requires `self.submit` unless it is a work function, in which case the # current process in the scope should be an instance of `FunctionProcess` @@ -102,27 +95,29 @@ def submit(process, **inputs): raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead') runner = manager.get_manager().get_runner() + assert runner.persister is not None, 'runner does not have a persister' + assert runner.controller is not None, 'runner does not have a controller' - process = instantiate_process(runner, process, **inputs) + process_inited = instantiate_process(runner, process, **inputs) # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this # instead of raising, because in this way the user does not have to change the launcher when testing.
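A minimal usage sketch of the launch functions typed above, assuming a loaded AiiDA profile; the `add` calcfunction is a hypothetical example, and any `Process` subclass or `ProcessBuilder` accepted by `TYPE_RUN_PROCESS`/`TYPE_SUBMIT_PROCESS` could be passed instead:

from aiida import orm
from aiida.engine import calcfunction, run, run_get_node

@calcfunction
def add(x, y):
    # hypothetical process function used only for illustration
    return orm.Int(x.value + y.value)

results = run(add, x=orm.Int(1), y=orm.Int(2))                 # -> Dict[str, Any]
results, node = run_get_node(add, x=orm.Int(1), y=orm.Int(2))  # -> (Dict[str, Any], ProcessNode)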
- if process.metadata.get('dry_run', False): - _, node = run_get_node(process) + if process_inited.metadata.get('dry_run', False): + _, node = run_get_node(process_inited) return node - if not process.metadata.store_provenance: + if not process_inited.metadata.store_provenance: raise InvalidOperation('cannot submit a process with `store_provenance=False`') - runner.persister.save_checkpoint(process) - process.close() + runner.persister.save_checkpoint(process_inited) + process_inited.close() # Do not wait for the future's result, because in the case of a single worker this would cock-block itself - runner.controller.continue_process(process.pid, nowait=False, no_reply=True) + runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) - return process.node + return process_inited.node # Allow one to also use run.get_node and run.get_pk as a shortcut, without having to import the functions themselves -run.get_node = run_get_node -run.get_pk = run_get_pk +run.get_node = run_get_node # type: ignore[attr-defined] +run.get_pk = run_get_pk # type: ignore[attr-defined] diff --git a/aiida/engine/persistence.py b/aiida/engine/persistence.py index 5aedd9d386..2ccdac03c1 100644 --- a/aiida/engine/persistence.py +++ b/aiida/engine/persistence.py @@ -13,21 +13,27 @@ import importlib import logging import traceback +from typing import Any, Hashable, Optional, TYPE_CHECKING -import plumpy +import plumpy.persistence +import plumpy.loaders +from plumpy.exceptions import PersistenceError from aiida.orm.utils import serialize +if TYPE_CHECKING: + from aiida.engine.processes.process import Process + __all__ = ('AiiDAPersister', 'ObjectLoader', 'get_object_loader') LOGGER = logging.getLogger(__name__) OBJECT_LOADER = None -class ObjectLoader(plumpy.DefaultObjectLoader): +class ObjectLoader(plumpy.loaders.DefaultObjectLoader): """Custom object loader for `aiida-core`.""" - def load_object(self, identifier): + def load_object(self, identifier: str) -> Any: # pylint: disable=no-self-use """Attempt to load the object identified by the given `identifier`. .. note:: We override the `plumpy.DefaultObjectLoader` to be able to throw an `ImportError` instead of a @@ -37,11 +43,11 @@ def load_object(self, identifier): :return: loaded object :raises ImportError: if the object cannot be loaded """ - module, name = identifier.split(':') + module_name, name = identifier.split(':') try: - module = importlib.import_module(module) + module = importlib.import_module(module_name) except ImportError: - raise ImportError(f"module '{module}' from identifier '{identifier}' could not be loaded") + raise ImportError(f"module '{module_name}' from identifier '{identifier}' could not be loaded") try: return getattr(module, name) @@ -49,11 +55,11 @@ def load_object(self, identifier): raise ImportError(f"object '{name}' from identifier '{identifier}' could not be loaded") -def get_object_loader(): +def get_object_loader() -> ObjectLoader: """Return the global AiiDA object loader. 
:return: The global object loader - :rtype: :class:`plumpy.ObjectLoader` + """ global OBJECT_LOADER if OBJECT_LOADER is None: @@ -61,15 +67,15 @@ def get_object_loader(): return OBJECT_LOADER -class AiiDAPersister(plumpy.Persister): +class AiiDAPersister(plumpy.persistence.Persister): """Persister to take saved process instance states and persisting them to the database.""" - def save_checkpoint(self, process, tag=None): + def save_checkpoint(self, process: 'Process', tag: Optional[str] = None): # type: ignore[override] # pylint: disable=no-self-use """Persist a Process instance. :param process: :class:`aiida.engine.Process` :param tag: optional checkpoint identifier to allow distinguishing multiple checkpoints for the same process - :raises: :class:`plumpy.PersistenceError` Raised if there was a problem saving the checkpoint + :raises: :class:`PersistenceError` Raised if there was a problem saving the checkpoint """ LOGGER.debug('Persisting process<%d>', process.pid) @@ -77,26 +83,26 @@ def save_checkpoint(self, process, tag=None): raise NotImplementedError('Checkpoint tags not supported yet') try: - bundle = plumpy.Bundle(process, plumpy.LoadSaveContext(loader=get_object_loader())) + bundle = plumpy.persistence.Bundle(process, plumpy.persistence.LoadSaveContext(loader=get_object_loader())) except ImportError: # Couldn't create the bundle - raise plumpy.PersistenceError(f"Failed to create a bundle for '{process}': {traceback.format_exc()}") + raise PersistenceError(f"Failed to create a bundle for '{process}': {traceback.format_exc()}") try: process.node.set_checkpoint(serialize.serialize(bundle)) except Exception: - raise plumpy.PersistenceError(f"Failed to store a checkpoint for '{process}': {traceback.format_exc()}") + raise PersistenceError(f"Failed to store a checkpoint for '{process}': {traceback.format_exc()}") return bundle - def load_checkpoint(self, pid, tag=None): + def load_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> plumpy.persistence.Bundle: # pylint: disable=no-self-use """Load a process from a persisted checkpoint by its process id. 
:param pid: the process id of the :class:`plumpy.Process` :param tag: optional checkpoint identifier to allow retrieving a specific sub checkpoint :return: a bundle with the process state :rtype: :class:`plumpy.Bundle` - :raises: :class:`plumpy.PersistenceError` Raised if there was a problem loading the checkpoint + :raises: :class:`PersistenceError` Raised if there was a problem loading the checkpoint """ from aiida.common.exceptions import MultipleObjectsError, NotExistent from aiida.orm import load_node @@ -107,17 +113,17 @@ def load_checkpoint(self, pid, tag=None): try: calculation = load_node(pid) except (MultipleObjectsError, NotExistent): - raise plumpy.PersistenceError(f'Failed to load the node for process<{pid}>: {traceback.format_exc()}') + raise PersistenceError(f'Failed to load the node for process<{pid}>: {traceback.format_exc()}') checkpoint = calculation.checkpoint if checkpoint is None: - raise plumpy.PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint') + raise PersistenceError(f'Calculation<{calculation.pk}> does not have a saved checkpoint') try: bundle = serialize.deserialize(checkpoint) except Exception: - raise plumpy.PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}') + raise PersistenceError(f'Failed to load the checkpoint for process<{pid}>: {traceback.format_exc()}') return bundle @@ -127,14 +133,14 @@ def get_checkpoints(self): :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag. """ - def get_process_checkpoints(self, pid): + def get_process_checkpoints(self, pid: Hashable): """Return a list of all the current persisted process checkpoints for the specified process. :param pid: the process pid :return: list of PersistedCheckpoint tuples with element containing the process id and optional checkpoint tag. """ - def delete_checkpoint(self, pid, tag=None): + def delete_checkpoint(self, pid: Hashable, tag: Optional[str] = None) -> None: # pylint: disable=no-self-use,unused-argument """Delete a persisted process checkpoint, where no error will be raised if the checkpoint does not exist. :param pid: the process id of the :class:`plumpy.Process` @@ -145,7 +151,7 @@ def delete_checkpoint(self, pid, tag=None): calc = load_node(pid) calc.delete_checkpoint() - def delete_process_checkpoints(self, pid): + def delete_process_checkpoints(self, pid: Hashable): """Delete all persisted checkpoints related to the given process id. 
:param pid: the process id of the :class:`aiida.engine.processes.process.Process` diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py index de5a86cf18..b3045dcfd4 100644 --- a/aiida/engine/processes/__init__.py +++ b/aiida/engine/processes/__init__.py @@ -19,6 +19,6 @@ from .workchains import * __all__ = ( - builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + ports.__all__ + process.__all__ + - process_spec.__all__ + workchains.__all__ + builder.__all__ + calcjobs.__all__ + exit_code.__all__ + functions.__all__ + # type: ignore[name-defined] + ports.__all__ + process.__all__ + process_spec.__all__ + workchains.__all__ # type: ignore[name-defined] ) diff --git a/aiida/engine/processes/builder.py b/aiida/engine/processes/builder.py index 9a620244b4..c7f6939918 100644 --- a/aiida/engine/processes/builder.py +++ b/aiida/engine/processes/builder.py @@ -9,10 +9,14 @@ ########################################################################### """Convenience classes to help building the input dictionaries for Processes.""" import collections +from typing import Any, Type, TYPE_CHECKING from aiida.orm import Node from aiida.engine.processes.ports import PortNamespace +if TYPE_CHECKING: + from aiida.engine.processes.process import Process + __all__ = ('ProcessBuilder', 'ProcessBuilderNamespace') @@ -22,7 +26,7 @@ class ProcessBuilderNamespace(collections.abc.MutableMapping): Dynamically generates the getters and setters for the input ports of a given PortNamespace """ - def __init__(self, port_namespace): + def __init__(self, port_namespace: PortNamespace) -> None: """Dynamically construct the get and set properties for the ports of the given port namespace. For each port in the given port namespace a get and set property will be constructed dynamically @@ -30,7 +34,7 @@ def __init__(self, port_namespace): by calling str() on the Port, which should return the description of the Port. :param port_namespace: the inputs PortNamespace for which to construct the builder - :type port_namespace: str + """ # pylint: disable=super-init-not-called self._port_namespace = port_namespace @@ -52,7 +56,7 @@ def fgetter(self, name=name): return self._data.get(name) elif port.has_default(): - def fgetter(self, name=name, default=port.default): # pylint: disable=cell-var-from-loop + def fgetter(self, name=name, default=port.default): # type: ignore # pylint: disable=cell-var-from-loop return self._data.get(name, default) else: @@ -67,16 +71,12 @@ def fsetter(self, value, name=name): getter.setter(fsetter) # pylint: disable=too-many-function-args setattr(self.__class__, name, getter) - def __setattr__(self, attr, value): + def __setattr__(self, attr: str, value: Any) -> None: """Assign the given value to the port with key `attr`. .. 
note:: Any attributes without a leading underscore being set correspond to inputs and should hence be validated with respect to the corresponding input port from the process spec - :param attr: attribute - :type attr: str - - :param value: value """ if attr.startswith('_'): object.__setattr__(self, attr, value) @@ -87,7 +87,7 @@ def __setattr__(self, attr, value): if not self._port_namespace.dynamic: raise AttributeError(f'Unknown builder parameter: {attr}') else: - value = port.serialize(value) + value = port.serialize(value) # type: ignore[union-attr] validation_error = port.validate(value) if validation_error: raise ValueError(f'invalid attribute value {validation_error.message}') @@ -126,10 +126,8 @@ def _update(self, *args, **kwds): principle the method functions just as `collections.abc.MutableMapping.update`. :param args: a single mapping that should be mapped on the namespace - :type args: list :param kwds: keyword value pairs that should be mapped onto the ports - :type kwds: dict """ if len(args) > 1: raise TypeError(f'update expected at most 1 arguments, got {int(len(args))}') @@ -147,7 +145,7 @@ def _update(self, *args, **kwds): else: self.__setattr__(key, value) - def _inputs(self, prune=False): + def _inputs(self, prune: bool = False) -> dict: """Return the entire mapping of inputs specified for this builder. :param prune: boolean, when True, will prune nested namespaces that contain no actual values whatsoever @@ -182,7 +180,7 @@ def _prune(self, value): class ProcessBuilder(ProcessBuilderNamespace): # pylint: disable=too-many-ancestors """A process builder that helps setting up the inputs for creating a new process.""" - def __init__(self, process_class): + def __init__(self, process_class: Type['Process']): """Construct a `ProcessBuilder` instance for the given `Process` class. 
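A minimal sketch of how the `ProcessBuilder` constructed here is typically used, assuming the `ArithmeticAddCalculation` plugin that ships with aiida-core and a hypothetical code label 'add@localhost':

from aiida import orm
from aiida.engine import submit
from aiida.plugins import CalculationFactory

ArithmeticAddCalculation = CalculationFactory('arithmetic.add')
builder = ArithmeticAddCalculation.get_builder()
builder.code = orm.load_code('add@localhost')  # hypothetical code label
builder.x = orm.Int(1)
builder.y = orm.Int(2)
builder.metadata.options.resources = {'num_machines': 1, 'num_mpiprocs_per_machine': 1}
node = submit(builder)  # invalid attributes or values raise immediately via the port validators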
:param process_class: the `Process` subclass @@ -192,6 +190,6 @@ def __init__(self, process_class): super().__init__(self._process_spec.inputs) @property - def process_class(self): + def process_class(self) -> Type['Process']: """Return the process class for which this builder is constructed.""" return self._process_class diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py index dc7c275880..57d4777ae7 100644 --- a/aiida/engine/processes/calcjobs/__init__.py +++ b/aiida/engine/processes/calcjobs/__init__.py @@ -12,4 +12,4 @@ from .calcjob import * -__all__ = (calcjob.__all__) +__all__ = (calcjob.__all__) # type: ignore[name-defined] diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py index b86ff8ee32..7fafe77a7c 100644 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ b/aiida/engine/processes/calcjobs/calcjob.py @@ -9,8 +9,12 @@ ########################################################################### """Implementation of the CalcJob process.""" import io +import os +import shutil +from typing import Any, Dict, Hashable, Optional, Type, Union -import plumpy +import plumpy.ports +import plumpy.process_states from aiida import orm from aiida.common import exceptions, AttributeDict @@ -20,6 +24,7 @@ from aiida.common.links import LinkType from ..exit_code import ExitCode +from ..ports import PortNamespace from ..process import Process, ProcessState from ..process_spec import CalcJobProcessSpec from .tasks import Waiting, UPLOAD_COMMAND @@ -27,7 +32,7 @@ __all__ = ('CalcJob',) -def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statements +def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pylint: disable=too-many-return-statements """Validate the entire set of inputs passed to the `CalcJob` constructor. Reasons that will cause this validation to raise an `InputValidationError`: @@ -43,7 +48,7 @@ def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statement ctx.get_port('metadata.computer') except ValueError: # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation - return + return None code = inputs.get('code', None) computer_from_code = code.computer @@ -69,11 +74,11 @@ def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statement try: resources_port = ctx.get_port('metadata.options.resources') except ValueError: - return + return None # If the resources port exists but is not required, we don't need to validate it against the computer's scheduler if not resources_port.required: - return + return None computer = computer_from_code or computer_from_metadata scheduler = computer.get_scheduler() @@ -89,43 +94,47 @@ def validate_calc_job(inputs, ctx): # pylint: disable=too-many-return-statement except ValueError as exception: return f'input `metadata.options.resources` is not valid for the `{scheduler}` scheduler: {exception}' + return None -def validate_parser(parser_name, _): + +def validate_parser(parser_name: Any, _: Any) -> Optional[str]: """Validate the parser. 
:return: string with error message in case the inputs are invalid """ from aiida.plugins import ParserFactory - if parser_name is not plumpy.UNSPECIFIED: + if parser_name is not plumpy.ports.UNSPECIFIED: try: ParserFactory(parser_name) except exceptions.EntryPointError as exception: return f'invalid parser specified: {exception}' + return None + -def validate_additional_retrieve_list(additional_retrieve_list, _): +def validate_additional_retrieve_list(additional_retrieve_list: Any, _: Any) -> Optional[str]: """Validate the additional retrieve list. :return: string with error message in case the input is invalid. """ - import os - - if additional_retrieve_list is plumpy.UNSPECIFIED: - return + if additional_retrieve_list is plumpy.ports.UNSPECIFIED: + return None if any(not isinstance(value, str) or os.path.isabs(value) for value in additional_retrieve_list): return f'`additional_retrieve_list` should only contain relative filepaths but got: {additional_retrieve_list}' + return None + class CalcJob(Process): """Implementation of the CalcJob process.""" _node_class = orm.CalcJobNode _spec_class = CalcJobProcessSpec - link_label_retrieved = 'retrieved' + link_label_retrieved: str = 'retrieved' - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """Construct a CalcJob instance. Construct the instance only if it is a sub class of `CalcJob`, otherwise raise `InvalidOperation`. @@ -138,14 +147,18 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @classmethod - def define(cls, spec: CalcJobProcessSpec): - # yapf: disable + def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] """Define the process specification, including its inputs, outputs and known exit codes. + Ports are added to the `metadata` input namespace (inherited from the base Process), + and a `code` input Port, a `remote_folder` output Port and retrieved folder output Port + are added. + :param spec: the calculation job process spec to define. """ + # yapf: disable super().define(spec) - spec.inputs.validator = validate_calc_job + spec.inputs.validator = validate_calc_job # type: ignore[assignment] # takes only PortNamespace not Port spec.input('code', valid_type=orm.Code, help='The `Code` to use for this job.') spec.input('metadata.dry_run', valid_type=bool, default=False, help='When set to `True` will prepare the calculation job for submission but not actually launch it.') @@ -217,6 +230,7 @@ def define(cls, spec: CalcJobProcessSpec): message='The job ran out of memory.') spec.exit_code(120, 'ERROR_SCHEDULER_OUT_OF_WALLTIME', message='The job ran out of walltime.') + # yapf: enable @classproperty def spec_options(cls): # pylint: disable=no-self-argument @@ -228,11 +242,11 @@ def spec_options(cls): # pylint: disable=no-self-argument return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object @property - def options(self): + def options(self) -> AttributeDict: """Return the options of the metadata that were specified when this process instance was launched. :return: options dictionary - :rtype: dict + """ try: return self.metadata.options @@ -240,14 +254,18 @@ def options(self): return AttributeDict() @classmethod - def get_state_classes(cls): + def get_state_classes(cls) -> Dict[Hashable, Type[plumpy.process_states.State]]: + """A mapping of the State constants to the corresponding state class. + + Overrides the waiting state with the Calcjob specific version. 
+ """ # Overwrite the waiting state states_map = super().get_state_classes() states_map[ProcessState.WAITING] = Waiting return states_map @override - def on_terminated(self): + def on_terminated(self) -> None: """Cleanup the node by deleting the calulation job state. .. note:: This has to be done before calling the super because that will seal the node after we cannot change it @@ -256,13 +274,17 @@ def on_terminated(self): super().on_terminated() @override - def run(self): + def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: """Run the calculation job. This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the process in the `Wait` state, waiting for the `UPLOAD` transport task to be started. + + :returns: the `Stop` command if a dry run, int if the process has an exit status, + `Wait` command if the calcjob is to be uploaded + """ - if self.inputs.metadata.dry_run: + if self.inputs.metadata.dry_run: # type: ignore[union-attr] from aiida.common.folders import SubmitTestFolder from aiida.engine.daemon.execmanager import upload_calculation from aiida.transports.plugins.local import LocalTransport @@ -276,7 +298,7 @@ def run(self): 'folder': folder.abspath, 'script_filename': self.node.get_option('submit_script_filename') } - return plumpy.Stop(None, True) + return plumpy.process_states.Stop(None, True) # The following conditional is required for the caching to properly work. Even if the source node has a process # state of `Finished` the cached process will still enter the running state. The process state will have then @@ -286,7 +308,7 @@ def run(self): return self.node.exit_status # Launch the upload operation - return plumpy.Wait(msg='Waiting to upload', data=UPLOAD_COMMAND) + return plumpy.process_states.Wait(msg='Waiting to upload', data=UPLOAD_COMMAND) def prepare_for_submission(self, folder: Folder) -> CalcInfo: """Prepare the calculation for submission. @@ -301,13 +323,14 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: """ raise NotImplementedError - def parse(self, retrieved_temporary_folder=None): + def parse(self, retrieved_temporary_folder: Optional[str] = None) -> ExitCode: """Parse a retrieved job calculation. This is called once it's finished waiting for the calculation to be finished and the data has been retrieved. 
- """ - import shutil + :param retrieved_temporary_folder: The path to the temporary folder + + """ try: retrieved = self.node.outputs.retrieved except exceptions.NotExistent: @@ -337,6 +360,7 @@ def parse(self, retrieved_temporary_folder=None): self.logger.warning(msg) # The final exit code is that of the scheduler, unless the output parser returned one + exit_code: Optional[ExitCode] if exit_code_retrieved is not None: exit_code = exit_code_retrieved else: @@ -346,9 +370,9 @@ def parse(self, retrieved_temporary_folder=None): for entry in self.node.get_outgoing(): self.out(entry.link_label, entry.node) - return exit_code or ExitCode(0) + return exit_code or ExitCode(0) # type: ignore[call-arg] - def parse_scheduler_output(self, retrieved): + def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" scheduler = self.node.computer.get_scheduler() filename_stderr = self.node.get_option('scheduler_stderr') @@ -376,16 +400,16 @@ def parse_scheduler_output(self, retrieved): # Only attempt to call the scheduler parser if all three resources of information are available if any(entry is None for entry in [detailed_job_info, scheduler_stderr, scheduler_stdout]): - return + return None try: exit_code = scheduler.parse_output(detailed_job_info, scheduler_stdout, scheduler_stderr) except exceptions.FeatureNotAvailable: self.logger.info(f'`{scheduler.__class__.__name__}` does not implement scheduler output parsing') - return + return None except Exception as exception: # pylint: disable=broad-except self.logger.error(f'the `parse_output` method of the scheduler excepted: {exception}') - return + return None if exit_code is not None and not isinstance(exit_code, ExitCode): args = (scheduler.__class__.__name__, type(exit_code)) @@ -393,12 +417,12 @@ def parse_scheduler_output(self, retrieved): return exit_code - def parse_retrieved_output(self, retrieved_temporary_folder=None): + def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = None) -> Optional[ExitCode]: """Parse the retrieved data by calling the parser plugin if it was defined in the inputs.""" parser_class = self.node.get_parser_class() if parser_class is None: - return + return None parser = parser_class(self.node) parse_kwargs = parser.get_outputs_for_parsing() @@ -422,18 +446,15 @@ def parse_retrieved_output(self, retrieved_temporary_folder=None): return exit_code - def presubmit(self, folder): + def presubmit(self, folder: Folder) -> CalcInfo: """Prepares the calculation folder with all inputs, ready to be copied to the cluster. :param folder: a SandboxFolder that can be used to write calculation input files and the scheduling script. - :type folder: :class:`aiida.common.folders.Folder` :return calcinfo: the CalcInfo object containing the information needed by the daemon to handle operations. 
- :rtype calcinfo: :class:`aiida.common.CalcInfo` + """ # pylint: disable=too-many-locals,too-many-statements,too-many-branches - import os - from aiida.common.exceptions import PluginInternalError, ValidationError, InvalidOperation, InputValidationError from aiida.common import json from aiida.common.utils import validate_list_of_string_tuples @@ -445,19 +466,23 @@ def presubmit(self, folder): computer = self.node.computer inputs = self.node.get_incoming(link_type=LinkType.INPUT_CALC) - if not self.inputs.metadata.dry_run and self.node.has_cached_links(): + if not self.inputs.metadata.dry_run and self.node.has_cached_links(): # type: ignore[union-attr] raise InvalidOperation('calculation node has unstored links in cache') codes = [_ for _ in inputs.all_nodes() if isinstance(_, Code)] for code in codes: if not code.can_run_on(computer): - raise InputValidationError('The selected code {} for calculation {} cannot run on computer {}'.format( - code.pk, self.node.pk, computer.label)) + raise InputValidationError( + 'The selected code {} for calculation {} cannot run on computer {}'.format( + code.pk, self.node.pk, computer.label + ) + ) if code.is_local() and code.get_local_executable() in folder.get_content_list(): - raise PluginInternalError('The plugin created a file {} that is also the executable name!'.format( - code.get_local_executable())) + raise PluginInternalError( + f'The plugin created a file {code.get_local_executable()} that is also the executable name!' + ) calc_info = self.prepare_for_submission(folder) calc_info.uuid = str(self.node.uuid) @@ -495,7 +520,8 @@ def presubmit(self, folder): if not issubclass(file_sub_class, orm.SinglefileData): raise PluginInternalError( '[presubmission of calc {}] retrieve_singlefile_list subclass problem: {} is ' - 'not subclass of SinglefileData'.format(self.node.pk, file_sub_class.__name__)) + 'not subclass of SinglefileData'.format(self.node.pk, file_sub_class.__name__) + ) if retrieve_singlefile_list: self.node.set_retrieve_singlefile_list(retrieve_singlefile_list) @@ -553,11 +579,13 @@ def presubmit(self, folder): this_withmpi = self.node.get_option('withmpi') if this_withmpi: - this_argv = (mpi_args + extra_mpirun_params + [this_code.get_execname()] + - (code_info.cmdline_params if code_info.cmdline_params is not None else [])) + this_argv = ( + mpi_args + extra_mpirun_params + [this_code.get_execname()] + + (code_info.cmdline_params if code_info.cmdline_params is not None else []) + ) else: - this_argv = [this_code.get_execname()] + (code_info.cmdline_params - if code_info.cmdline_params is not None else []) + this_argv = [this_code.get_execname() + ] + (code_info.cmdline_params if code_info.cmdline_params is not None else []) # overwrite the old cmdline_params and add codename and mpirun stuff code_info.cmdline_params = this_argv @@ -642,13 +670,17 @@ def presubmit(self, folder): try: Computer.objects.get(uuid=remote_computer_uuid) # pylint: disable=unused-variable except exceptions.NotExistent as exc: - raise PluginInternalError('[presubmission of calc {}] ' - 'The remote copy requires a computer with UUID={}' - 'but no such computer was found in the ' - 'database'.format(this_pk, remote_computer_uuid)) from exc + raise PluginInternalError( + '[presubmission of calc {}] ' + 'The remote copy requires a computer with UUID={}' + 'but no such computer was found in the ' + 'database'.format(this_pk, remote_computer_uuid) + ) from exc if os.path.isabs(dest_rel_path): - raise PluginInternalError('[presubmission of calc {}] ' - 'The 
destination path of the remote copy ' - 'is absolute! ({})'.format(this_pk, dest_rel_path)) + raise PluginInternalError( + '[presubmission of calc {}] ' + 'The destination path of the remote copy ' + 'is absolute! ({})'.format(this_pk, dest_rel_path) + ) return calc_info diff --git a/aiida/engine/processes/calcjobs/manager.py b/aiida/engine/processes/calcjobs/manager.py index 4191f2494f..46d3c057e6 100644 --- a/aiida/engine/processes/calcjobs/manager.py +++ b/aiida/engine/processes/calcjobs/manager.py @@ -8,12 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Module containing utilities and classes relating to job calculations running on systems that require transport.""" +import asyncio import contextlib import logging import time -import asyncio +from typing import Any, Dict, Hashable, Iterator, List, Optional, TYPE_CHECKING from aiida.common import lang +from aiida.orm import AuthInfo + +if TYPE_CHECKING: + from aiida.engine.transports import TransportQueue + from aiida.schedulers.datastructures import JobInfo __all__ = ('JobsList', 'JobManager') @@ -35,15 +41,13 @@ class JobsList: See the :py:class:`~aiida.engine.processes.calcjobs.manager.JobManager` for example usage. """ - def __init__(self, authinfo, transport_queue, last_updated=None): + def __init__(self, authinfo: AuthInfo, transport_queue: 'TransportQueue', last_updated: Optional[float] = None): """Construct an instance for the given authinfo and transport queue. :param authinfo: The authinfo used to check the jobs list - :type authinfo: :class:`aiida.orm.AuthInfo` :param transport_queue: A transport queue - :type: :class:`aiida.engine.transports.TransportQueue` :param last_updated: initialize the last updated timestamp - :type: float + """ lang.type_check(last_updated, float, allow_none=True) @@ -52,41 +56,41 @@ def __init__(self, authinfo, transport_queue, last_updated=None): self._loop = transport_queue.loop self._logger = logging.getLogger(__name__) - self._jobs_cache = {} - self._job_update_requests = {} # Mapping: {job_id: Future} + self._jobs_cache: Dict[Hashable, 'JobInfo'] = {} + self._job_update_requests: Dict[Hashable, asyncio.Future] = {} # Mapping: {job_id: Future} self._last_updated = last_updated - self._update_handle = None + self._update_handle: Optional[asyncio.TimerHandle] = None @property - def logger(self): + def logger(self) -> logging.Logger: """Return the logger configured for this instance. :return: the logger """ return self._logger - def get_minimum_update_interval(self): + def get_minimum_update_interval(self) -> float: """Get the minimum interval that should be respected between updates of the list. :return: the minimum interval - :rtype: float + """ return self._authinfo.computer.get_minimum_job_poll_interval() @property - def last_updated(self): + def last_updated(self) -> Optional[float]: """Get the timestamp of when the list was last updated as produced by `time.time()` :return: The last update point - :rtype: float + """ return self._last_updated - async def _get_jobs_from_scheduler(self): + async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']: """Get the current jobs list from the scheduler. 
:return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances - :rtype: dict + """ with self._transport_queue.request_transport(self._authinfo) as request: self.logger.info('waiting for transport') @@ -95,7 +99,7 @@ async def _get_jobs_from_scheduler(self): scheduler = self._authinfo.computer.get_scheduler() scheduler.set_transport(transport) - kwargs = {'as_dict': True} + kwargs: Dict[str, Any] = {'as_dict': True} if scheduler.get_feature('can_query_by_user'): kwargs['user'] = '$USER' else: @@ -113,7 +117,7 @@ async def _get_jobs_from_scheduler(self): return jobs_cache - async def _update_job_info(self): + async def _update_job_info(self) -> None: """Update all of the job information objects. This will set the futures for all pending update requests where the corresponding job has a new status compared @@ -146,7 +150,7 @@ async def _update_job_info(self): self._job_update_requests = {} @contextlib.contextmanager - def request_job_info_update(self, job_id): + def request_job_info_update(self, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']: """Request job info about a job when the job next changes state. If the job is not found in the jobs list at the update, the future will resolve to `None`. @@ -164,7 +168,7 @@ def request_job_info_update(self, job_id): finally: pass - def _ensure_updating(self): + def _ensure_updating(self) -> None: """Ensure that we are updating the job list from the remote resource. This will automatically stop if there are no outstanding requests. @@ -188,12 +192,10 @@ async def updating(): ) @staticmethod - def _has_job_state_changed(old, new): + def _has_job_state_changed(old: Optional['JobInfo'], new: Optional['JobInfo']) -> bool: """Return whether the states `old` and `new` are different. - :type old: :class:`aiida.schedulers.JobInfo` or `None` - :type new: :class:`aiida.schedulers.JobInfo` or `None` - :rtype: bool + """ if old is None and new is None: return False @@ -204,14 +206,14 @@ def _has_job_state_changed(old, new): return old.job_state != new.job_state or old.job_substate != new.job_substate - def _get_next_update_delay(self): + def _get_next_update_delay(self) -> float: """Calculate when we are next allowed to poll the scheduler. This delay is calculated as the minimum polling interval defined by the authentication info for this instance, minus time elapsed since the last update. :return: delay (in seconds) after which the scheduler may be polled again - :rtype: float + """ if self.last_updated is None: # Never updated, so do it straight away @@ -225,10 +227,10 @@ def _get_next_update_delay(self): return delay - def _update_requests_outstanding(self): + def _update_requests_outstanding(self) -> bool: return any(not request.done() for request in self._job_update_requests.values()) - def _get_jobs_with_scheduler(self): + def _get_jobs_with_scheduler(self) -> List[str]: """Get all the jobs that are currently with scheduler. :return: the list of jobs with the scheduler @@ -252,11 +254,11 @@ class JobManager: only hold per runner. """ - def __init__(self, transport_queue): + def __init__(self, transport_queue: 'TransportQueue') -> None: self._transport_queue = transport_queue - self._job_lists = {} + self._job_lists: Dict[Hashable, JobsList] = {} - def get_jobs_list(self, authinfo): + def get_jobs_list(self, authinfo: AuthInfo) -> JobsList: """Get or create a new `JobLists` instance for the given authinfo.
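A minimal sketch of the request/await pattern that the typed `request_job_info_update` context manager supports, assuming `job_manager` is the runner's `JobManager`, `authinfo` the `AuthInfo` of the computer/user pair, and `job_id` the scheduler id of a submitted calculation:

async def wait_for_state_change(job_manager, authinfo, job_id):
    with job_manager.request_job_info_update(authinfo, job_id) as request:
        # resolves to a JobInfo once the scheduler reports a new state, or None if the job disappeared
        return await request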
:param authinfo: the `AuthInfo` @@ -268,13 +270,11 @@ def get_jobs_list(self, authinfo): return self._job_lists[authinfo.id] @contextlib.contextmanager - def request_job_info_update(self, authinfo, job_id): + def request_job_info_update(self, authinfo: AuthInfo, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']: """Get a future that will resolve to information about a given job. This is a context manager so that if the user leaves the context the request is automatically cancelled. - :return: A tuple containing the `JobInfo` object and detailed job info. Both can be None. - :rtype: :class:`asyncio.Future` """ with self.get_jobs_list(authinfo).request_job_info_update(job_id) as request: try: diff --git a/aiida/engine/processes/calcjobs/tasks.py b/aiida/engine/processes/calcjobs/tasks.py index 293065ff6c..8a837a189c 100644 --- a/aiida/engine/processes/calcjobs/tasks.py +++ b/aiida/engine/processes/calcjobs/tasks.py @@ -11,19 +11,27 @@ import functools import logging import tempfile +from typing import Any, Callable, Optional, TYPE_CHECKING import plumpy +import plumpy.process_states +import plumpy.futures from aiida.common.datastructures import CalcJobState from aiida.common.exceptions import FeatureNotAvailable, TransportTaskException from aiida.common.folders import SandboxFolder from aiida.engine.daemon import execmanager -from aiida.engine.utils import exponential_backoff_retry, interruptable_task +from aiida.engine.transports import TransportQueue +from aiida.engine.utils import exponential_backoff_retry, interruptable_task, InterruptableFuture +from aiida.orm.nodes.process.calculation.calcjob import CalcJobNode from aiida.schedulers.datastructures import JobState from aiida.manage.configuration import get_config_option from ..process import ProcessState +if TYPE_CHECKING: + from .calcjob import CalcJob + UPLOAD_COMMAND = 'upload' SUBMIT_COMMAND = 'submit' UPDATE_COMMAND = 'update' @@ -40,7 +48,7 @@ class PreSubmitException(Exception): """Raise in the `do_upload` coroutine when an exception is raised in `CalcJob.presubmit`.""" -async def task_upload_job(process, transport_queue, cancellable): +async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, cancellable: InterruptableFuture): """Transport task that will attempt to upload the files of a job calculation to the remote. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager @@ -48,10 +56,10 @@ async def task_upload_job(process, transport_queue, cancellable): retry after an interval that increases exponentially with the number of retries, for a maximum number of retries. 
If all retries fail, the task will raise a TransportTaskException - :param node: the node that represents the job calculation + :param process: the job calculation :param transport_queue: the TransportQueue from which to request a Transport :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ node = process.node @@ -83,13 +91,13 @@ async def do_upload(): try: logger.info(f'scheduled request to upload CalcJob<{node.pk}>') - ignore_exceptions = (plumpy.CancelledError, PreSubmitException) + ignore_exceptions = (plumpy.futures.CancelledError, PreSubmitException) skip_submit = await exponential_backoff_retry( do_upload, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=ignore_exceptions ) except PreSubmitException: raise - except plumpy.CancelledError: + except plumpy.futures.CancelledError: pass except Exception: logger.warning(f'uploading CalcJob<{node.pk}> failed') @@ -100,7 +108,7 @@ async def do_upload(): return skip_submit -async def task_submit_job(node, transport_queue, cancellable): +async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): """Transport task that will attempt to submit a job calculation. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager @@ -111,7 +119,7 @@ async def task_submit_job(node, transport_queue, cancellable): :param node: the node that represents the job calculation :param transport_queue: the TransportQueue from which to request a Transport :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ if node.get_state() == CalcJobState.WITHSCHEDULER: @@ -132,9 +140,13 @@ async def do_submit(): try: logger.info(f'scheduled request to submit CalcJob<{node.pk}>') result = await exponential_backoff_retry( - do_submit, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + do_submit, + initial_interval, + max_attempts, + logger=node.logger, + ignore_exceptions=plumpy.process_states.Interruption ) - except plumpy.Interruption: + except plumpy.process_states.Interruption: pass except Exception: logger.warning(f'submitting CalcJob<{node.pk}> failed') @@ -145,7 +157,7 @@ async def do_submit(): return result -async def task_update_job(node, job_manager, cancellable): +async def task_update_job(node: CalcJobNode, job_manager, cancellable: InterruptableFuture): """Transport task that will attempt to update the scheduler status of the job calculation. The task will first request a transport from the queue. 
Once the transport is yielded, the relevant execmanager @@ -190,9 +202,13 @@ async def do_update(): try: logger.info(f'scheduled request to update CalcJob<{node.pk}>') job_done = await exponential_backoff_retry( - do_update, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + do_update, + initial_interval, + max_attempts, + logger=node.logger, + ignore_exceptions=plumpy.process_states.Interruption ) - except plumpy.Interruption: + except plumpy.process_states.Interruption: raise except Exception: logger.warning(f'updating CalcJob<{node.pk}> failed') @@ -205,7 +221,10 @@ async def do_update(): return job_done -async def task_retrieve_job(node, transport_queue, retrieved_temporary_folder, cancellable): +async def task_retrieve_job( + node: CalcJobNode, transport_queue: TransportQueue, retrieved_temporary_folder: str, + cancellable: InterruptableFuture +): """Transport task that will attempt to retrieve all files of a completed job calculation. The task will first request a transport from the queue. Once the transport is yielded, the relevant execmanager @@ -215,8 +234,9 @@ async def task_retrieve_job(node, transport_queue, retrieved_temporary_folder, c :param node: the node that represents the job calculation :param transport_queue: the TransportQueue from which to request a Transport + :param retrieved_temporary_folder: the absolute path to a directory to store files :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ if node.get_state() == CalcJobState.PARSING: @@ -251,9 +271,13 @@ async def do_retrieve(): try: logger.info(f'scheduled request to retrieve CalcJob<{node.pk}>') result = await exponential_backoff_retry( - do_retrieve, initial_interval, max_attempts, logger=node.logger, ignore_exceptions=plumpy.Interruption + do_retrieve, + initial_interval, + max_attempts, + logger=node.logger, + ignore_exceptions=plumpy.process_states.Interruption ) - except plumpy.Interruption: + except plumpy.process_states.Interruption: raise except Exception: logger.warning(f'retrieving CalcJob<{node.pk}> failed') @@ -264,7 +288,7 @@ async def do_retrieve(): return result -async def task_kill_job(node, transport_queue, cancellable): +async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, cancellable: InterruptableFuture): """Transport task that will attempt to kill a job calculation. The task will first request a transport from the queue. 
Once the transport is yielded, the relevant execmanager @@ -275,7 +299,7 @@ async def task_kill_job(node, transport_queue, cancellable): :param node: the node that represents the job calculation :param transport_queue: the TransportQueue from which to request a Transport :param cancellable: the cancelled flag that will be queried to determine whether the task was cancelled - :type cancellable: :class:`aiida.engine.utils.InterruptableFuture` + :raises: TransportTaskException if after the maximum number of retries the transport task still excepted """ initial_interval = get_config_option(RETRY_INTERVAL_OPTION) @@ -295,7 +319,7 @@ async def do_kill(): try: logger.info(f'scheduled request to kill CalcJob<{node.pk}>') result = await exponential_backoff_retry(do_kill, initial_interval, max_attempts, logger=node.logger) - except plumpy.Interruption: + except plumpy.process_states.Interruption: raise except Exception: logger.warning(f'killing CalcJob<{node.pk}> failed') @@ -306,29 +330,42 @@ async def do_kill(): return result -class Waiting(plumpy.Waiting): +class Waiting(plumpy.process_states.Waiting): """The waiting state for the `CalcJob` process.""" - def __init__(self, process, done_callback, msg=None, data=None): + def __init__( + self, + process: 'CalcJob', + done_callback: Optional[Callable[..., Any]], + msg: Optional[str] = None, + data: Optional[Any] = None + ): """ - :param :class:`~plumpy.base.state_machine.StateMachine` process: The process this state belongs to + :param process: The process this state belongs to """ super().__init__(process, done_callback, msg, data) - self._task = None - self._killing = None + self._task: Optional[InterruptableFuture] = None + self._killing: Optional[plumpy.futures.Future] = None + + @property + def process(self) -> 'CalcJob': + """ + :return: The process + """ + return self.state_machine # type: ignore[return-value] def load_instance_state(self, saved_state, load_context): super().load_instance_state(saved_state, load_context) self._task = None self._killing = None - async def execute(self): # pylint: disable=invalid-overridden-method + async def execute(self) -> plumpy.process_states.State: # type: ignore[override] # pylint: disable=invalid-overridden-method """Override the execute coroutine of the base `Waiting` state.""" # pylint: disable=too-many-branches, too-many-statements node = self.process.node transport_queue = self.process.runner.transport command = self.data - result = self + result: plumpy.process_states.State = self process_status = f'Waiting for transport task: {command}' @@ -370,12 +407,15 @@ async def execute(self): # pylint: disable=invalid-overridden-method raise RuntimeError('Unknown waiting command') except TransportTaskException as exception: - raise plumpy.PauseInterruption(f'Pausing after failed transport task: {exception}') - except plumpy.KillInterruption: + raise plumpy.process_states.PauseInterruption(f'Pausing after failed transport task: {exception}') + except plumpy.process_states.KillInterruption: await self._launch_task(task_kill_job, node, transport_queue) - self._killing.set_result(True) + if self._killing is not None: + self._killing.set_result(True) + else: + logger.warning(f'killed CalcJob<{node.pk}> but async future was None') raise - except (plumpy.Interruption, plumpy.CancelledError): + except (plumpy.process_states.Interruption, plumpy.futures.CancelledError): node.set_process_status(f'Transport task {command} was interrupted') raise else: @@ -396,39 +436,45 @@ async def _launch_task(self, coro, *args, 
**kwargs): finally: self._task = None - def upload(self): + def upload(self) -> 'Waiting': """Return the `Waiting` state that will `upload` the `CalcJob`.""" msg = 'Waiting for calculation folder upload' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPLOAD_COMMAND) + return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPLOAD_COMMAND) # type: ignore[return-value] - def submit(self): + def submit(self) -> 'Waiting': """Return the `Waiting` state that will `submit` the `CalcJob`.""" msg = 'Waiting for scheduler submission' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=SUBMIT_COMMAND) + return self.create_state(ProcessState.WAITING, None, msg=msg, data=SUBMIT_COMMAND) # type: ignore[return-value] - def update(self): + def update(self) -> 'Waiting': """Return the `Waiting` state that will `update` the `CalcJob`.""" msg = 'Waiting for scheduler update' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPDATE_COMMAND) + return self.create_state(ProcessState.WAITING, None, msg=msg, data=UPDATE_COMMAND) # type: ignore[return-value] - def retrieve(self): + def retrieve(self) -> 'Waiting': """Return the `Waiting` state that will `retrieve` the `CalcJob`.""" msg = 'Waiting to retrieve' - return self.create_state(ProcessState.WAITING, None, msg=msg, data=RETRIEVE_COMMAND) + return self.create_state( + ProcessState.WAITING, None, msg=msg, data=RETRIEVE_COMMAND + ) # type: ignore[return-value] - def parse(self, retrieved_temporary_folder): + def parse(self, retrieved_temporary_folder: str) -> plumpy.process_states.Running: """Return the `Running` state that will parse the `CalcJob`. :param retrieved_temporary_folder: temporary folder used in retrieving that can be used during parsing. """ - return self.create_state(ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder) + return self.create_state( + ProcessState.RUNNING, self.process.parse, retrieved_temporary_folder + ) # type: ignore[return-value] - def interrupt(self, reason): + def interrupt(self, reason: Any) -> Optional[plumpy.futures.Future]: # type: ignore[override] """Interrupt the `Waiting` state by calling interrupt on the transport task `InterruptableFuture`.""" if self._task is not None: self._task.interrupt(reason) - if isinstance(reason, plumpy.KillInterruption): + if isinstance(reason, plumpy.process_states.KillInterruption): if self._killing is None: - self._killing = plumpy.Future() + self._killing = plumpy.futures.Future() return self._killing + + return None diff --git a/aiida/engine/processes/exit_code.py b/aiida/engine/processes/exit_code.py index 0c54a5be72..cb13b0a765 100644 --- a/aiida/engine/processes/exit_code.py +++ b/aiida/engine/processes/exit_code.py @@ -34,7 +34,7 @@ class ExitCode(namedtuple('ExitCode', ['status', 'message', 'invalidates_cache'] :type invalidates_cache: bool """ - def format(self, **kwargs): + def format(self, **kwargs: str) -> 'ExitCode': """Create a clone of this exit code where the template message is replaced by the keyword arguments. :param kwargs: replacement parameters for the template message @@ -50,7 +50,7 @@ def format(self, **kwargs): # Set the defaults for the `ExitCode` attributes -ExitCode.__new__.__defaults__ = (0, None, False) +ExitCode.__new__.__defaults__ = (0, None, False) # type: ignore[attr-defined] class ExitCodesNamespace(AttributeDict): @@ -60,15 +60,13 @@ class ExitCodesNamespace(AttributeDict): `ExitCode` that needs to be retrieved or the key in the collection. 
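# Illustrative sketch (not part of the diff): how the ``ExitCode`` defaults and ``format``
# annotated above behave, assuming the message is a ``str.format``-style template as the
# docstring of ``format`` suggests.
from aiida.engine import ExitCode

generic = ExitCode()  # falls back to the defaults set above: status=0, message=None, invalidates_cache=False
template = ExitCode(418, 'the input {name} is invalid')
concrete = template.format(name='x')  # clone of the namedtuple with the message template filled in
# concrete.status == 418, concrete.message == 'the input x is invalid' (assuming str.format templates)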
""" - def __call__(self, identifier): + def __call__(self, identifier: str) -> ExitCode: """Return a specific exit code identified by either its exit status or label. :param identifier: the identifier of the exit code. If the type is integer, it will be interpreted as the exit code status, otherwise it be interpreted as the exit code label - :type identifier: str :returns: an `ExitCode` instance - :rtype: :class:`aiida.engine.ExitCode` :raises ValueError: if no exit code with the given label is defined for this process """ diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py index a08e0ef012..0dd6ef4759 100644 --- a/aiida/engine/processes/functions.py +++ b/aiida/engine/processes/functions.py @@ -13,18 +13,24 @@ import inspect import logging import signal +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, TYPE_CHECKING from aiida.common.lang import override from aiida.manage.manager import get_manager +from aiida.orm import CalcFunctionNode, Data, ProcessNode, WorkFunctionNode +from aiida.orm.utils.mixins import FunctionCalculationMixin from .process import Process +if TYPE_CHECKING: + from .exit_code import ExitCode + __all__ = ('calcfunction', 'workfunction', 'FunctionProcess') LOGGER = logging.getLogger(__name__) -def calcfunction(function): +def calcfunction(function: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to turn a standard python function into a calcfunction. Example usage: @@ -51,11 +57,10 @@ def calcfunction(function): :return: The decorated function. :rtype: callable """ - from aiida.orm import CalcFunctionNode return process_function(node_class=CalcFunctionNode)(function) -def workfunction(function): +def workfunction(function: Callable[..., Any]) -> Callable[..., Any]: """ A decorator to turn a standard python function into a workfunction. Example usage: @@ -80,13 +85,12 @@ def workfunction(function): :type function: callable :return: The decorated function. - :rtype: callable - """ - from aiida.orm import WorkFunctionNode + + """ return process_function(node_class=WorkFunctionNode)(function) -def process_function(node_class): +def process_function(node_class: Type['ProcessNode']) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """ The base function decorator to create a FunctionProcess out of a normal python function. @@ -94,7 +98,7 @@ def process_function(node_class): :type node_class: :class:`aiida.orm.ProcessNode` """ - def decorator(function): + def decorator(function: Callable[..., Any]) -> Callable[..., Any]: """ Turn the decorated function into a FunctionProcess. @@ -103,14 +107,14 @@ def decorator(function): """ process_class = FunctionProcess.build(function, node_class=node_class) - def run_get_node(*args, **kwargs): + def run_get_node(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], 'ProcessNode']: """ Run the FunctionProcess with the supplied inputs in a local runner. :param args: input arguments to construct the FunctionProcess :param kwargs: input keyword arguments to construct the FunctionProcess - :return: tuple of the outputs of the process and the process node pk - :rtype: (dict, int) + :return: tuple of the outputs of the process and the process node + """ manager = get_manager() runner = manager.get_runner() @@ -158,13 +162,13 @@ def kill_process(_num, _frame): return result, process.node - def run_get_pk(*args, **kwargs): + def run_get_pk(*args, **kwargs) -> Tuple[Optional[Dict[str, Any]], int]: """Recreate the `run_get_pk` utility launcher. 
:param args: input arguments to construct the FunctionProcess :param kwargs: input keyword arguments to construct the FunctionProcess :return: tuple of the outputs of the process and the process node pk - :rtype: (dict, int) + """ result, node = run_get_node(*args, **kwargs) return result, node.pk @@ -175,14 +179,14 @@ def decorated_function(*args, **kwargs): result, _ = run_get_node(*args, **kwargs) return result - decorated_function.run = decorated_function - decorated_function.run_get_pk = run_get_pk - decorated_function.run_get_node = run_get_node - decorated_function.is_process_function = True - decorated_function.node_class = node_class - decorated_function.process_class = process_class - decorated_function.recreate_from = process_class.recreate_from - decorated_function.spec = process_class.spec + decorated_function.run = decorated_function # type: ignore[attr-defined] + decorated_function.run_get_pk = run_get_pk # type: ignore[attr-defined] + decorated_function.run_get_node = run_get_node # type: ignore[attr-defined] + decorated_function.is_process_function = True # type: ignore[attr-defined] + decorated_function.node_class = node_class # type: ignore[attr-defined] + decorated_function.process_class = process_class # type: ignore[attr-defined] + decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined] + decorated_function.spec = process_class.spec # type: ignore[attr-defined] return decorated_function @@ -192,10 +196,10 @@ def decorated_function(*args, **kwargs): class FunctionProcess(Process): """Function process class used for turning functions into a Process""" - _func_args = None + _func_args: Sequence[str] = () @staticmethod - def _func(*_args, **_kwargs): + def _func(*_args, **_kwargs) -> dict: """ This is used internally to store the actual function that is being wrapped and will be replaced by the build method. @@ -203,7 +207,7 @@ def _func(*_args, **_kwargs): return {} @staticmethod - def build(func, node_class): + def build(func: Callable[..., Any], node_class: Type['ProcessNode']) -> Type['FunctionProcess']: """ Build a Process from the given function. @@ -211,19 +215,13 @@ def build(func, node_class): these will also become inputs. :param func: The function to build a process from - :type func: callable - :param node_class: Provide a custom node class to be used, has to be constructable with no arguments. It has to be a sub class of `ProcessNode` and the mixin :class:`~aiida.orm.utils.mixins.FunctionCalculationMixin`. 
- :type node_class: :class:`aiida.orm.nodes.process.process.ProcessNode` :return: A Process class that represents the function - :rtype: :class:`FunctionProcess` - """ - from aiida import orm - from aiida.orm.utils.mixins import FunctionCalculationMixin - if not issubclass(node_class, orm.ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): + """ + if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') args, varargs, keywords, defaults, _, _, _ = inspect.getfullargspec(func) @@ -240,7 +238,7 @@ def _define(cls, spec): # pylint: disable=unused-argument for i, arg in enumerate(args): default = () - if i >= first_default_pos: + if defaults and i >= first_default_pos: default = defaults[i - first_default_pos] # If the keyword was already specified, simply override the default @@ -251,9 +249,9 @@ def _define(cls, spec): # pylint: disable=unused-argument # Note that we cannot use `None` because the validation will call `isinstance` which does not work # when passing `None`, but it does work with `NoneType` which is returned by calling `type(None)` if default is None: - valid_type = (orm.Data, type(None)) + valid_type = (Data, type(None)) else: - valid_type = (orm.Data,) + valid_type = (Data,) spec.input(arg, valid_type=valid_type, default=default) @@ -269,7 +267,7 @@ def _define(cls, spec): # pylint: disable=unused-argument # Function processes must have a dynamic output namespace since we do not know beforehand what outputs # will be returned and the valid types for the value should be `Data` nodes as well as a dictionary because # the output namespace can be nested. - spec.outputs.valid_type = (orm.Data, dict) + spec.outputs.valid_type = (Data, dict) return type( func.__name__, (FunctionProcess,), { @@ -283,7 +281,7 @@ def _define(cls, spec): # pylint: disable=unused-argument ) @classmethod - def validate_inputs(cls, *args, **kwargs): # pylint: disable=unused-argument + def validate_inputs(cls, *args: Any, **kwargs: Any) -> None: # pylint: disable=unused-argument """ Validate the positional and keyword arguments passed in the function call. @@ -302,11 +300,8 @@ def validate_inputs(cls, *args, **kwargs): # pylint: disable=unused-argument raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') @classmethod - def create_inputs(cls, *args, **kwargs): - """Create the input args for the FunctionProcess. - - :rtype: dict - """ + def create_inputs(cls, *args: Any, **kwargs: Any) -> Dict[str, Any]: + """Create the input args for the FunctionProcess.""" cls.validate_inputs(*args, **kwargs) ins = {} @@ -317,29 +312,28 @@ def create_inputs(cls, *args, **kwargs): return ins @classmethod - def args_to_dict(cls, *args): + def args_to_dict(cls, *args: Any) -> Dict[str, Any]: """ Create an input dictionary (of form label -> value) from supplied args. 
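# Illustrative sketch (not part of the diff): the effect of the default-argument handling in
# the hunks above. A function argument with a ``None`` default becomes an optional input that
# accepts either a ``Data`` node or ``None`` (the ``valid_type=(Data, type(None))`` branch).
# A configured profile and the imports from the previous sketch are assumed.
@calcfunction
def multiply(x, y, z=None):
    if z is None:
        return x * y
    return x * y * z

# multiply(orm.Int(2), orm.Int(3)).value == 6        -> ``z`` may simply be omitted
# multiply(orm.Int(2), orm.Int(3), orm.Int(4)).value == 24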
:param args: The values to use for the dictionary - :type args: list :return: A label -> value dictionary - :rtype: dict + """ return dict(list(zip(cls._func_args, args))) @classmethod - def get_or_create_db_record(cls): + def get_or_create_db_record(cls) -> 'ProcessNode': return cls._node_class() - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if kwargs.get('enable_persistence', False): raise RuntimeError('Cannot persist a function process') - super().__init__(enable_persistence=False, *args, **kwargs) + super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore @property - def process_class(self): + def process_class(self) -> Callable[..., Any]: """ Return the class that represents this Process, for the FunctionProcess this is the function itself. @@ -348,33 +342,29 @@ def process_class(self): class that really represents what was being executed. :return: A Process class that represents the function - :rtype: :class:`FunctionProcess` + """ return self._func - def execute(self): + def execute(self) -> Optional[Dict[str, Any]]: """Execute the process.""" result = super().execute() # FunctionProcesses can return a single value as output, and not a dictionary, so we should also return that - if len(result) == 1 and self.SINGLE_OUTPUT_LINKNAME in result: + if result and len(result) == 1 and self.SINGLE_OUTPUT_LINKNAME in result: return result[self.SINGLE_OUTPUT_LINKNAME] return result @override - def _setup_db_record(self): + def _setup_db_record(self) -> None: """Set up the database record for the process.""" super()._setup_db_record() self.node.store_source_info(self._func) @override - def run(self): - """Run the process. - - :rtype: :class:`aiida.engine.ExitCode` - """ - from aiida.orm import Data + def run(self) -> Optional['ExitCode']: + """Run the process.""" from .exit_code import ExitCode # The following conditional is required for the caching to properly work. Even if the source node has a process @@ -388,9 +378,9 @@ def run(self): args = [None] * len(self._func_args) kwargs = {} - for name, value in self.inputs.items(): + for name, value in (self.inputs or {}).items(): try: - if self.spec().inputs[name].non_db: + if self.spec().inputs[name].non_db: # type: ignore[union-attr] # Don't consider non-database inputs continue except KeyError: @@ -418,4 +408,4 @@ def run(self): 'Must be a Data type or a mapping of {{string: Data}}'.format(result.__class__) ) - return ExitCode() + return ExitCode() # type: ignore[call-arg] diff --git a/aiida/engine/processes/futures.py b/aiida/engine/processes/futures.py index cf1d500cc8..1c3d06b67b 100644 --- a/aiida/engine/processes/futures.py +++ b/aiida/engine/processes/futures.py @@ -10,9 +10,12 @@ # pylint: disable=cyclic-import """Futures that can poll or receive broadcasted messages while waiting for a task to be completed.""" import asyncio +from typing import Optional, Union import kiwipy +from aiida.orm import Node, load_node + __all__ = ('ProcessFuture',) @@ -21,18 +24,23 @@ class ProcessFuture(asyncio.Future): _filtered = None - def __init__(self, pk, loop=None, poll_interval=None, communicator=None): + def __init__( + self, + pk: int, + loop: Optional[asyncio.AbstractEventLoop] = None, + poll_interval: Union[None, int, float] = None, + communicator: Optional[kiwipy.Communicator] = None + ): """Construct a future for a process node being finished. - If a None poll_interval is supplied polling will not be used. If a communicator is supplied it will be used - to listen for broadcast messages. 
+ If a None poll_interval is supplied polling will not be used. + If a communicator is supplied it will be used to listen for broadcast messages. :param pk: process pk :param loop: An event loop :param poll_interval: optional polling interval, if None, polling is not activated. :param communicator: optional communicator, if None, will not subscribe to broadcasts. """ - from aiida.orm import load_node from .process import ProcessState # create future in specified event loop @@ -60,14 +68,14 @@ def __init__(self, pk, loop=None, poll_interval=None, communicator=None): if poll_interval is not None: loop.create_task(self._poll_process(node, poll_interval)) - def cleanup(self): + def cleanup(self) -> None: """Clean up the future by removing broadcast subscribers from the communicator if it still exists.""" if self._communicator is not None: self._communicator.remove_broadcast_subscriber(self._broadcast_identifier) self._communicator = None self._broadcast_identifier = None - async def _poll_process(self, node, poll_interval): + async def _poll_process(self, node: Node, poll_interval: Union[int, float]) -> None: """Poll whether the process node has reached a terminal state.""" while not self.done() and not node.is_terminated: await asyncio.sleep(poll_interval) diff --git a/aiida/engine/processes/ports.py b/aiida/engine/processes/ports.py index 1613d2169d..b288747138 100644 --- a/aiida/engine/processes/ports.py +++ b/aiida/engine/processes/ports.py @@ -10,9 +10,14 @@ """AiiDA specific implementation of plumpy Ports and PortNamespaces for the ProcessSpec.""" import collections import re +from typing import Any, Callable, Dict, Optional, Sequence import warnings from plumpy import ports +from plumpy.ports import breadcrumbs_to_port + +from aiida.common.links import validate_link_label +from aiida.orm import Data, Node __all__ = ( 'PortNamespace', 'InputPort', 'OutputPort', 'CalcJobOutputPort', 'WithNonDb', 'WithSerialize', @@ -26,21 +31,21 @@ class WithNonDb: """ - A mixin that adds support to a port to flag a that should not be stored + A mixin that adds support to a port to flag that it should not be stored in the database using the non_db=True flag. The mixins have to go before the main port class in the superclass order to make sure the mixin has the chance to strip out the non_db keyword. """ - def __init__(self, *args, **kwargs): - self._non_db_explicitly_set = bool('non_db' in kwargs) + def __init__(self, *args, **kwargs) -> None: + self._non_db_explicitly_set: bool = bool('non_db' in kwargs) non_db = kwargs.pop('non_db', False) - super().__init__(*args, **kwargs) - self._non_db = non_db + super().__init__(*args, **kwargs) # type: ignore[call-arg] + self._non_db: bool = non_db @property - def non_db_explicitly_set(self): + def non_db_explicitly_set(self) -> bool: """Return whether the a value for `non_db` was explicitly passed in the construction of the `Port`. :return: boolean, True if `non_db` was explicitly defined during construction, False otherwise @@ -48,7 +53,7 @@ def non_db_explicitly_set(self): return self._non_db_explicitly_set @property - def non_db(self): + def non_db(self) -> bool: """Return whether the value of this `Port` should be stored as a `Node` in the database. :return: boolean, True if it should be storable as a `Node`, False otherwise @@ -56,10 +61,8 @@ def non_db(self): return self._non_db @non_db.setter - def non_db(self, non_db): + def non_db(self, non_db: bool) -> None: """Set whether the value of this `Port` should be stored as a `Node` in the database. 
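# Illustrative sketch (not part of the diff): the ``non_db`` and ``serializer`` port features
# typed above, as they appear in a process spec. ``ExampleWorkChain`` is hypothetical and its
# outline is omitted for brevity.
from aiida import orm
from aiida.engine import WorkChain

class ExampleWorkChain(WorkChain):

    @classmethod
    def define(cls, spec):
        super().define(spec)
        # stored in the database and linked in the provenance graph as usual
        spec.input('x', valid_type=orm.Int)
        # never stored as a node: useful for pure run-time settings
        spec.input('options', valid_type=dict, non_db=True)
        # plain Python values are converted to a Data node by the serializer before validation
        spec.input('y', valid_type=orm.Float, serializer=lambda value: orm.Float(value))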
- - :param non_db: boolean """ self._non_db_explicitly_set = True self._non_db = non_db @@ -71,19 +74,17 @@ class WithSerialize: that are not AiiDA data types. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: serializer = kwargs.pop('serializer', None) - super().__init__(*args, **kwargs) - self._serializer = serializer + super().__init__(*args, **kwargs) # type: ignore[call-arg] + self._serializer: Callable[[Any], 'Data'] = serializer - def serialize(self, value): + def serialize(self, value: Any) -> 'Data': """Serialize the given value if it is not already a Data type and a serializer function is defined :param value: the value to be serialized :returns: a serialized version of the value or the unchanged value """ - from aiida.orm import Data - if self._serializer is None or isinstance(value, Data): return value @@ -96,11 +97,9 @@ class InputPort(WithSerialize, WithNonDb, ports.InputPort): value serialization to database storable types and support non database storable input types as well. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """Override the constructor to check the type of the default if set and warn if not immutable.""" # pylint: disable=redefined-builtin,too-many-arguments - from aiida.orm import Node - if 'default' in kwargs: default = kwargs['default'] # If the default is specified and it is a node instance, raise a warning. This is to try and prevent that @@ -112,7 +111,7 @@ def __init__(self, *args, **kwargs): super(InputPort, self).__init__(*args, **kwargs) - def get_description(self): + def get_description(self) -> Dict[str, str]: """ Return a description of the InputPort, which will be a dictionary of its attributes @@ -127,13 +126,13 @@ def get_description(self): class CalcJobOutputPort(ports.OutputPort): """Sub class of plumpy.OutputPort which adds the `_pass_to_parser` attribute.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: pass_to_parser = kwargs.pop('pass_to_parser', False) super().__init__(*args, **kwargs) - self._pass_to_parser = pass_to_parser + self._pass_to_parser: bool = pass_to_parser @property - def pass_to_parser(self): + def pass_to_parser(self) -> bool: return self._pass_to_parser @@ -143,7 +142,7 @@ class PortNamespace(WithNonDb, ports.PortNamespace): serialization of a given mapping onto the ports of the PortNamespace. """ - def __setitem__(self, key, port): + def __setitem__(self, key: str, port: ports.Port) -> None: """Ensure that a `Port` being added inherits the `non_db` attribute if not explicitly defined at construction. The reasoning is that if a `PortNamespace` has `non_db=True`, which is different from the default value, very @@ -157,13 +156,13 @@ def __setitem__(self, key, port): self.validate_port_name(key) - if hasattr(port, 'non_db_explicitly_set') and not port.non_db_explicitly_set: - port.non_db = self.non_db + if hasattr(port, 'non_db_explicitly_set') and not port.non_db_explicitly_set: # type: ignore[attr-defined] + port.non_db = self.non_db # type: ignore[attr-defined] super().__setitem__(key, port) @staticmethod - def validate_port_name(port_name): + def validate_port_name(port_name: str) -> None: """Validate the given port name. 
Valid port names adhere to the following restrictions: @@ -181,8 +180,6 @@ def validate_port_name(port_name): :raise TypeError: if the port name is not a string type :raise ValueError: if the port name is invalid """ - from aiida.common.links import validate_link_label - try: validate_link_label(port_name) except ValueError as exception: @@ -195,7 +192,7 @@ def validate_port_name(port_name): if any([len(entry) > PORT_NAME_MAX_CONSECUTIVE_UNDERSCORES for entry in consecutive_underscores]): raise ValueError(f'invalid port name `{port_name}`: more than two consecutive underscores') - def serialize(self, mapping, breadcrumbs=()): + def serialize(self, mapping: Optional[Dict[str, Any]], breadcrumbs: Sequence[str] = ()) -> Optional[Dict[str, Any]]: """Serialize the given mapping onto this `Portnamespace`. It will recursively call this function on any nested `PortNamespace` or the serialize function on any `Ports`. @@ -204,26 +201,27 @@ def serialize(self, mapping, breadcrumbs=()): :param breadcrumbs: a tuple with the namespaces of parent namespaces :returns: the serialized mapping """ - from plumpy.ports import breadcrumbs_to_port - if mapping is None: return None - breadcrumbs += (self.name,) + breadcrumbs = (*breadcrumbs, self.name) if not isinstance(mapping, collections.Mapping): - port = breadcrumbs_to_port(breadcrumbs) - raise TypeError(f'port namespace `{port}` received `{type(mapping)}` instead of a dictionary') + port_name = breadcrumbs_to_port(breadcrumbs) + raise TypeError(f'port namespace `{port_name}` received `{type(mapping)}` instead of a dictionary') result = {} for name, value in mapping.items(): if name in self: + port = self[name] if isinstance(port, PortNamespace): result[name] = port.serialize(value, breadcrumbs) - else: + elif isinstance(port, InputPort): result[name] = port.serialize(value) + else: + raise AssertionError(f'port does not have a serialize method: {port}') else: result[name] = value diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py index 4c2da21c8d..01cb51da30 100644 --- a/aiida/engine/processes/process.py +++ b/aiida/engine/processes/process.py @@ -8,36 +8,47 @@ # For further information please visit http://www.aiida.net # ########################################################################### """The AiiDA process class""" +import asyncio import collections import enum import inspect -import uuid +import logging +from uuid import UUID import traceback -import asyncio -from typing import Union +from types import TracebackType +from typing import ( + Any, cast, Dict, Iterable, Iterator, List, MutableMapping, Optional, Type, Tuple, Union, TYPE_CHECKING +) from aio_pika.exceptions import ConnectionClosed -import plumpy -from plumpy import ProcessState +import plumpy.exceptions +import plumpy.futures +import plumpy.processes +import plumpy.persistence +from plumpy.process_states import ProcessState, Finished from kiwipy.communications import UnroutableError from aiida import orm +from aiida.orm.utils import serialize from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict from aiida.common.lang import classproperty, override from aiida.common.links import LinkType from aiida.common.log import LOG_LEVEL_REPORT -from .exit_code import ExitCode +from .exit_code import ExitCode, ExitCodesNamespace from .builder import ProcessBuilder from .ports import InputPort, OutputPort, PortNamespace, PORT_NAMESPACE_SEPARATOR from .process_spec import ProcessSpec +if TYPE_CHECKING: + from aiida.engine.runners 
import Runner + __all__ = ('Process', 'ProcessState') -@plumpy.auto_persist('_parent_pid', '_enable_persistence') -class Process(plumpy.Process): +@plumpy.persistence.auto_persist('_parent_pid', '_enable_persistence') +class Process(plumpy.processes.Process): """ This class represents an AiiDA process which can be executed and will have full provenance saved in the database. @@ -47,88 +58,109 @@ class Process(plumpy.Process): _node_class = orm.ProcessNode _spec_class = ProcessSpec - SINGLE_OUTPUT_LINKNAME = 'result' + SINGLE_OUTPUT_LINKNAME: str = 'result' class SaveKeys(enum.Enum): """ Keys used to identify things in the saved instance state bundle. """ - CALC_ID = 'calc_id' + CALC_ID: str = 'calc_id' @classmethod - def define(cls, spec): - # yapf: disable + def spec(cls) -> ProcessSpec: + return super().spec() # type: ignore[return-value] + + @classmethod + def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] + """Define the specification of the process, including its inputs, outputs and known exit codes. + + A `metadata` input namespace is defined, with optional ports that are not stored in the database. + + """ super().define(spec) spec.input_namespace(spec.metadata_key, required=False, non_db=True) - spec.input(f'{spec.metadata_key}.store_provenance', valid_type=bool, default=True, - help='If set to `False` provenance will not be stored in the database.') - spec.input(f'{spec.metadata_key}.description', valid_type=str, required=False, - help='Description to set on the process node.') - spec.input(f'{spec.metadata_key}.label', valid_type=str, required=False, - help='Label to set on the process node.') - spec.input(f'{spec.metadata_key}.call_link_label', valid_type=str, default='CALL', - help='The label to use for the `CALL` link if the process is called by another process.') + spec.input( + f'{spec.metadata_key}.store_provenance', + valid_type=bool, + default=True, + help='If set to `False` provenance will not be stored in the database.' + ) + spec.input( + f'{spec.metadata_key}.description', + valid_type=str, + required=False, + help='Description to set on the process node.' + ) + spec.input( + f'{spec.metadata_key}.label', valid_type=str, required=False, help='Label to set on the process node.' + ) + spec.input( + f'{spec.metadata_key}.call_link_label', + valid_type=str, + default='CALL', + help='The label to use for the `CALL` link if the process is called by another process.' + ) spec.exit_code(1, 'ERROR_UNSPECIFIED', message='The process has failed with an unspecified error.') spec.exit_code(2, 'ERROR_LEGACY_FAILURE', message='The process failed with legacy failure mode.') spec.exit_code(10, 'ERROR_INVALID_OUTPUT', message='The process returned an invalid output.') spec.exit_code(11, 'ERROR_MISSING_OUTPUT', message='The process did not register a required output.') @classmethod - def get_builder(cls): + def get_builder(cls) -> ProcessBuilder: return ProcessBuilder(cls) @classmethod - def get_or_create_db_record(cls): + def get_or_create_db_record(cls) -> orm.ProcessNode: """ Create a process node that represents what happened in this process. 
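# Illustrative sketch (not part of the diff): setting the ``metadata`` inputs declared in the
# ``define`` hunk above from the caller side, via the builder returned by ``get_builder``.
# ``ExampleWorkChain`` and ``orm`` are taken from the earlier hypothetical port sketch.
builder = ExampleWorkChain.get_builder()
builder.x = orm.Int(1)
builder.metadata.label = 'example run'
builder.metadata.description = 'quick smoke test'
builder.metadata.store_provenance = True      # default; set to False to skip storing provenance
builder.metadata.call_link_label = 'example'  # label of the CALL link if launched by a parent process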
:return: A process node - :rtype: :class:`aiida.orm.ProcessNode` """ return cls._node_class() - def __init__(self, inputs=None, logger=None, runner=None, parent_pid=None, enable_persistence=True): + def __init__( + self, + inputs: Optional[Dict[str, Any]] = None, + logger: Optional[logging.Logger] = None, + runner: Optional['Runner'] = None, + parent_pid: Optional[int] = None, + enable_persistence: bool = True + ) -> None: """ Process constructor. :param inputs: process inputs - :type inputs: dict - :param logger: aiida logger - :type logger: :class:`logging.Logger` - :param runner: process runner - :type: :class:`aiida.engine.runners.Runner` - :param parent_pid: id of parent process - :type parent_pid: int - :param enable_persistence: whether to persist this process - :type enable_persistence: bool + """ from aiida.manage import manager self._runner = runner if runner is not None else manager.get_manager().get_runner() + assert self._runner.communicator is not None, 'communicator not set for runner' super().__init__( inputs=self.spec().inputs.serialize(inputs), logger=logger, loop=self._runner.loop, - communicator=self.runner.communicator) + communicator=self._runner.communicator + ) - self._node = None + self._node: Optional[orm.ProcessNode] = None self._parent_pid = parent_pid self._enable_persistence = enable_persistence if self._enable_persistence and self.runner.persister is None: self.logger.warning('Disabling persistence, runner does not have a persister') self._enable_persistence = False - def init(self): + def init(self) -> None: super().init() if self._logger is None: self.set_logger(self.node.logger) @classmethod - def get_exit_statuses(cls, exit_code_labels): + def get_exit_statuses(cls, exit_code_labels: Iterable[str]) -> List[int]: """Return the exit status (integers) for the given exit code labels. :param exit_code_labels: a list of strings that reference exit code labels of this process class @@ -139,37 +171,34 @@ def get_exit_statuses(cls, exit_code_labels): return [getattr(exit_codes, label).status for label in exit_code_labels] @classproperty - def exit_codes(cls): # pylint: disable=no-self-argument + def exit_codes(cls) -> ExitCodesNamespace: # pylint: disable=no-self-argument """Return the namespace of exit codes defined for this WorkChain through its ProcessSpec. The namespace supports getitem and getattr operations with an ExitCode label to retrieve a specific code. Additionally, the namespace can also be called with either the exit code integer status to retrieve it. :returns: ExitCodesNamespace of ExitCode named tuples - :rtype: :class:`aiida.engine.ExitCodesNamespace` + """ return cls.spec().exit_codes @classproperty - def spec_metadata(cls): # pylint: disable=no-self-argument - """Return the metadata port namespace of the process specification of this process. - - :return: metadata dictionary - :rtype: dict - """ - return cls.spec().inputs['metadata'] + def spec_metadata(cls) -> PortNamespace: # pylint: disable=no-self-argument + """Return the metadata port namespace of the process specification of this process.""" + return cls.spec().inputs['metadata'] # type: ignore[return-value] @property - def node(self): + def node(self) -> orm.ProcessNode: """Return the ProcessNode used by this process to represent itself in the database. 
:return: instance of sub class of ProcessNode - :rtype: :class:`aiida.orm.ProcessNode` + """ + assert self._node is not None return self._node @property - def uuid(self): + def uuid(self) -> str: # type: ignore[override] """Return the UUID of the process which corresponds to the UUID of its associated `ProcessNode`. :return: the UUID associated to this process instance @@ -177,32 +206,43 @@ def uuid(self): return self.node.uuid @property - def metadata(self): + def metadata(self) -> AttributeDict: """Return the metadata that were specified when this process instance was launched. :return: metadata dictionary - :rtype: dict + """ try: + assert self.inputs is not None return self.inputs.metadata - except AttributeError: + except (AssertionError, AttributeError): return AttributeDict() - def _save_checkpoint(self): + def _save_checkpoint(self) -> None: """ Save the current state in a chechpoint if persistence is enabled and the process state is not terminal If the persistence call excepts with a PersistenceError, it will be caught and a warning will be logged. """ if self._enable_persistence and not self._state.is_terminal(): + if self.runner.persister is None: + self.logger.exception( + 'No persister set to save checkpoint, this means you will ' + 'not be able to restart in case of a crash until the next successful checkpoint.' + ) + return None try: self.runner.persister.save_checkpoint(self) - except plumpy.PersistenceError: - self.logger.exception('Exception trying to save checkpoint, this means you will ' - 'not be able to restart in case of a crash until the next successful checkpoint.') + except plumpy.exceptions.PersistenceError: + self.logger.exception( + 'Exception trying to save checkpoint, this means you will ' + 'not be able to restart in case of a crash until the next successful checkpoint.' + ) @override - def save_instance_state(self, out_state, save_context): + def save_instance_state( + self, out_state: MutableMapping[str, Any], save_context: Optional[plumpy.persistence.LoadSaveContext] + ) -> None: """Save instance state. See documentation of :meth:`!plumpy.processes.Process.save_instance_state`. @@ -214,21 +254,23 @@ def save_instance_state(self, out_state, save_context): out_state[self.SaveKeys.CALC_ID.value] = self.pid - def get_provenance_inputs_iterator(self): + def get_provenance_inputs_iterator(self) -> Iterator[Tuple[str, Union[InputPort, PortNamespace]]]: """Get provenance input iterator. :rtype: filter """ + assert self.inputs is not None return filter(lambda kv: not kv[0].startswith('_'), self.inputs.items()) @override - def load_instance_state(self, saved_state, load_context): + def load_instance_state( + self, saved_state: MutableMapping[str, Any], load_context: plumpy.persistence.LoadSaveContext + ) -> None: """Load instance state. 
:param saved_state: saved instance state - :param load_context: - :type load_context: :class:`!plumpy.persistence.LoadSaveContext` + """ from aiida.manage import manager @@ -242,13 +284,13 @@ def load_instance_state(self, saved_state, load_context): if self.SaveKeys.CALC_ID.value in saved_state: self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value]) - self._pid = self.node.pk + self._pid = self.node.pk # pylint: disable=attribute-defined-outside-init else: - self._pid = self._create_and_setup_db_record() + self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state') - def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.Future]: + def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]: """ Kill the process and all the children calculations it called @@ -264,9 +306,12 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.Future]: if result is not False and not had_been_terminated: killing = [] for child in self.node.called: + if self.runner.controller is None: + self.logger.info('no controller available to kill child<%s>', child.pk) + continue try: result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>') - result = asyncio.wrap_future(result) + result = asyncio.wrap_future(result) # type: ignore[arg-type] if asyncio.isfuture(result): killing.append(result) except ConnectionClosed: @@ -276,31 +321,30 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.Future]: if asyncio.isfuture(result): # We ourselves are waiting to be killed so add it to the list - killing.append(result) + killing.append(result) # type: ignore[arg-type] if killing: # We are waiting for things to be killed, so return the 'gathered' future - kill_future = plumpy.gather(*killing) + kill_future = plumpy.futures.gather(*killing) result = self.loop.create_future() - def done(done_future: plumpy.Future): + def done(done_future: plumpy.futures.Future): is_all_killed = all(done_future.result()) - result.set_result(is_all_killed) + result.set_result(is_all_killed) # type: ignore[union-attr] kill_future.add_done_callback(done) return result @override - def out(self, output_port, value=None): + def out(self, output_port: str, value: Any = None) -> None: """Attach output to output port. The name of the port will be used as the link label. :param output_port: name of output port - :type output_port: str - :param value: value to put inside output port + """ if value is None: # In this case assume that output_port is the actual value and there is just one return value @@ -309,7 +353,7 @@ def out(self, output_port, value=None): return super().out(output_port, value) - def out_many(self, out_dict): + def out_many(self, out_dict: Dict[str, Any]) -> None: """Attach outputs to multiple output ports. Keys of the dictionary will be used as output port names, values as outputs. 
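# Illustrative sketch (not part of the diff): attaching outputs through the ``out`` and
# ``out_many`` methods typed above, e.g. from the final step of a work chain. The port names
# are hypothetical; the values are assumed to be ``Data`` nodes created earlier in the outline
# (for instance by a calcfunction), so that provenance is preserved.
def results(self):
    self.out('sum', self.ctx.sum)                 # one output on the port named 'sum'
    self.out_many({'product': self.ctx.product,   # several outputs at once, keyed by port name
                   'remainder': self.ctx.remainder})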
@@ -320,39 +364,40 @@ def out_many(self, out_dict): for key, value in out_dict.items(): self.out(key, value) - def on_create(self): + def on_create(self) -> None: """Called when a Process is created.""" super().on_create() # If parent PID hasn't been supplied try to get it from the stack if self._parent_pid is None and Process.current(): current = Process.current() if isinstance(current, Process): - self._parent_pid = current.pid - self._pid = self._create_and_setup_db_record() + self._parent_pid = current.pid # type: ignore[assignment] + self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init @override - def on_entering(self, state): + def on_entering(self, state: plumpy.process_states.State) -> None: super().on_entering(state) # Update the node attributes every time we enter a new state - def on_entered(self, from_state): + def on_entered(self, from_state: Optional[plumpy.process_states.State]) -> None: + """After entering a new state, save a checkpoint and update the latest process state change timestamp.""" # pylint: disable=cyclic-import from aiida.engine.utils import set_process_state_change_timestamp self.update_node_state(self._state) self._save_checkpoint() - # Update the latest process state change timestamp set_process_state_change_timestamp(self) super().on_entered(from_state) @override - def on_terminated(self): + def on_terminated(self) -> None: """Called when a Process enters a terminal state.""" super().on_terminated() if self._enable_persistence: try: + assert self.runner.persister is not None self.runner.persister.delete_checkpoint(self.pid) - except Exception: # pylint: disable=broad-except - self.logger.exception('Failed to delete checkpoint') + except Exception as error: # pylint: disable=broad-except + self.logger.exception('Failed to delete checkpoint: %s', error) try: self.node.seal() @@ -360,7 +405,7 @@ def on_terminated(self): pass @override - def on_except(self, exc_info): + def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: """ Log the exception by calling the report method with formatted stack trace from exception info object and store the exception string as a node attribute @@ -372,14 +417,12 @@ def on_except(self, exc_info): self.report(''.join(traceback.format_exception(*exc_info))) @override - def on_finish(self, result, successful): + def on_finish(self, result: Union[int, ExitCode], successful: bool) -> None: """ Set the finish status on the process node. 
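# Illustrative sketch (not part of the diff): what drives the ``on_finish`` handling above from
# user code. An outline step (or ``CalcJob.parse``) returns one of the exit codes registered on
# the spec; the engine then stores its status and message on the process node.
# ``ERROR_BAD_INPUT`` is a hypothetical label, assumed to be declared via ``spec.exit_code``.
def validate(self):
    if self.inputs.x.value < 0:
        return self.exit_codes.ERROR_BAD_INPUT  # ExitCode namedtuple looked up by its label
    return None  # ``None`` or a zero-status exit code leaves the process on the success path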
:param result: result of the process - :type result: int or :class:`aiida.engine.ExitCode` - :param successful: whether execution was successful - :type successful: bool + """ super().on_finish(result, successful) @@ -395,23 +438,24 @@ def on_finish(self, result, successful): self.node.set_exit_status(result.status) self.node.set_exit_message(result.message) else: - raise ValueError('the result should be an integer, ExitCode or None, got {} {} {}'.format( - type(result), result, self.pid)) + raise ValueError( + f'the result should be an integer, ExitCode or None, got {type(result)} {result} {self.pid}' + ) @override - def on_paused(self, msg=None): + def on_paused(self, msg: Optional[str] = None) -> None: """ The Process was paused so set the paused attribute on the process node :param msg: message - :type msg: str + """ super().on_paused(msg) self._save_checkpoint() self.node.pause() @override - def on_playing(self): + def on_playing(self) -> None: """ The Process was unpaused so remove the paused attribute on the process node """ @@ -419,14 +463,13 @@ def on_playing(self): self.node.unpause() @override - def on_output_emitting(self, output_port, value): + def on_output_emitting(self, output_port: str, value: Any) -> None: """ The process has emitted a value on the given output port. :param output_port: The output port name the value was emitted on - :type output_port: str - :param value: The value emitted + """ super().on_output_emitting(output_port, value) @@ -434,39 +477,36 @@ def on_output_emitting(self, output_port, value): if isinstance(output_port, OutputPort) and not isinstance(value, orm.Data): raise TypeError(f'Processes can only return `orm.Data` instances as output, got {value.__class__}') - def set_status(self, status): + def set_status(self, status: Optional[str]) -> None: """ The status of the Process is about to be changed, so we reflect this is in node's attribute proxy. :param status: the status message - :type status: str + """ super().set_status(status) self.node.set_process_status(status) - def submit(self, process, *args, **kwargs): + def submit(self, process: Type['Process'], *args, **kwargs) -> orm.ProcessNode: """Submit process for execution. :param process: process - :type process: :class:`aiida.engine.Process` + :return: the calculation node of the process """ return self.runner.submit(process, *args, **kwargs) @property - def runner(self): - """Get process runner. - - :rtype: :class:`aiida.engine.runners.Runner` - """ + def runner(self) -> 'Runner': + """Get process runner.""" return self._runner - def get_parent_calc(self): + def get_parent_calc(self) -> Optional[orm.ProcessNode]: """ Get the parent process node :return: the parent process node if there is one - :rtype: :class:`aiida.orm.ProcessNode` + """ # Can't get it if we don't know our parent if self._parent_pid is None: @@ -475,12 +515,11 @@ def get_parent_calc(self): return orm.load_node(pk=self._parent_pid) @classmethod - def build_process_type(cls): + def build_process_type(cls) -> str: """ The process type. 
:return: string of the process type - :rtype: str Note: This could be made into a property 'process_type' but in order to have it be a property of the class it would need to be defined in the metaclass, see https://bugs.python.org/issue20659 @@ -499,29 +538,25 @@ def build_process_type(cls): return process_type - def report(self, msg, *args, **kwargs): + def report(self, msg: str, *args, **kwargs) -> None: """Log a message to the logger, which should get saved to the database through the attached DbLogHandler. The pk, class name and function name of the caller are prepended to the given message :param msg: message to log - :type msg: str - :param args: args to pass to the log call - :type args: list - :param kwargs: kwargs to pass to the log call - :type kwargs: dict + """ message = f'[{self.node.pk}|{self.__class__.__name__}|{inspect.stack()[1][3]}]: {msg}' self.logger.log(LOG_LEVEL_REPORT, message, *args, **kwargs) - def _create_and_setup_db_record(self): + def _create_and_setup_db_record(self) -> Union[int, UUID]: """ Create and setup the database record for this process - :return: the uuid of the process - :rtype: :class:`!uuid.UUID` + :return: the uuid or pk of the process + """ self._node = self.get_or_create_db_record() self._setup_db_record() @@ -529,7 +564,7 @@ def _create_and_setup_db_record(self): try: self.node.store_all() if self.node.is_finished_ok: - self._state = ProcessState.FINISHED + self._state = Finished(self, None, True) # pylint: disable=attribute-defined-outside-init for entry in self.node.get_outgoing(link_type=LinkType.RETURN): if entry.link_label.endswith(f'_{entry.node.pk}'): continue @@ -548,35 +583,33 @@ def _create_and_setup_db_record(self): if self.node.pk is not None: return self.node.pk - return uuid.UUID(self.node.uuid) + return UUID(self.node.uuid) @override - def encode_input_args(self, inputs): + def encode_input_args(self, inputs: Dict[str, Any]) -> str: # pylint: disable=no-self-use """ Encode input arguments such that they may be saved in a Bundle :param inputs: A mapping of the inputs as passed to the process :return: The encoded (serialized) inputs """ - from aiida.orm.utils import serialize return serialize.serialize(inputs) @override - def decode_input_args(self, encoded): + def decode_input_args(self, encoded: str) -> Dict[str, Any]: # pylint: disable=no-self-use """ Decode saved input arguments as they came from the saved instance state Bundle :param encoded: encoded (serialized) inputs :return: The decoded input args """ - from aiida.orm.utils import serialize return serialize.deserialize(encoded) - def update_node_state(self, state): + def update_node_state(self, state: plumpy.process_states.State) -> None: self.update_outputs() self.node.set_process_state(state.LABEL) - def update_outputs(self): + def update_outputs(self) -> None: """Attach new outputs to the node since the last call. Does nothing, if self.metadata.store_provenance is False. 
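# Illustrative sketch (not part of the diff): the flattening that ``_setup_inputs`` above and
# the ``_flat_inputs``/``_flatten_inputs`` helpers in the following hunks perform when turning
# nested input namespaces into link labels. This standalone version ignores ports and the
# ``non_db`` filtering, and assumes the double-underscore ``PORT_NAMESPACE_SEPARATOR``.
def flatten(mapping, parent_name='', separator='__'):
    items = []
    for name, value in mapping.items():
        key = f'{parent_name}{separator}{name}' if parent_name else name
        if isinstance(value, dict):
            items.extend(flatten(value, key, separator))  # recurse into nested namespaces
        else:
            items.append((key, value))
    return items

# flatten({'x': 1, 'nested': {'a': 2}}) == [('x', 1), ('nested__a', 2)]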
@@ -600,7 +633,7 @@ def update_outputs(self): output.store() - def _setup_db_record(self): + def _setup_db_record(self) -> None: """ Create the database record for this process and the links with respect to its inputs @@ -637,7 +670,7 @@ def _setup_db_record(self): self._setup_metadata() self._setup_inputs() - def _setup_metadata(self): + def _setup_metadata(self) -> None: """Store the metadata on the ProcessNode.""" version_info = self.runner.plugin_version_provider.get_version_info(self) self.node.set_attribute_many(version_info) @@ -658,7 +691,7 @@ def _setup_metadata(self): else: raise RuntimeError(f'unsupported metadata key: {name}') - def _setup_inputs(self): + def _setup_inputs(self) -> None: """Create the links between the input nodes and the ProcessNode that represents this process.""" for name, node in self._flat_inputs().items(): @@ -677,7 +710,7 @@ def _setup_inputs(self): elif isinstance(self.node, orm.WorkflowNode): self.node.add_incoming(node, LinkType.INPUT_WORK, name) - def _flat_inputs(self): + def _flat_inputs(self) -> Dict[str, Any]: """ Return a flattened version of the parsed inputs dictionary. @@ -685,12 +718,13 @@ def _flat_inputs(self): is not passed, as those are dealt with separately in `_setup_metadata`. :return: flat dictionary of parsed inputs - :rtype: dict + """ + assert self.inputs is not None inputs = {key: value for key, value in self.inputs.items() if key != self.spec().metadata_key} return dict(self._flatten_inputs(self.spec().inputs, inputs)) - def _flat_outputs(self): + def _flat_outputs(self) -> Dict[str, Any]: """ Return a flattened version of the registered outputs dictionary. @@ -700,24 +734,23 @@ def _flat_outputs(self): """ return dict(self._flatten_outputs(self.spec().outputs, self.outputs)) - def _flatten_inputs(self, port, port_value, parent_name='', separator=PORT_NAMESPACE_SEPARATOR): + def _flatten_inputs( + self, + port: Union[None, InputPort, PortNamespace], + port_value: Any, + parent_name: str = '', + separator: str = PORT_NAMESPACE_SEPARATOR + ) -> List[Tuple[str, Any]]: """ Function that will recursively flatten the inputs dictionary, omitting inputs for ports that are marked as being non database storable :param port: port against which to map the port value, can be InputPort or PortNamespace - :type port: :class:`plumpy.ports.Port` - :param port_value: value for the current port, can be a Mapping - :param parent_name: the parent key with which to prefix the keys - :type parent_name: str - :param separator: character to use for the concatenation of keys - :type separator: str - :return: flat list of inputs - :rtype: list + """ if (port is None and isinstance(port_value, orm.Node)) or (isinstance(port, InputPort) and not port.non_db): return [(parent_name, port_value)] @@ -729,36 +762,36 @@ def _flatten_inputs(self, port, port_value, parent_name='', separator=PORT_NAMES prefixed_key = parent_name + separator + name if parent_name else name try: - nested_port = port[name] + nested_port = cast(Union[InputPort, PortNamespace], port[name]) if port else None except (KeyError, TypeError): nested_port = None sub_items = self._flatten_inputs( - port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator) + port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator + ) items.extend(sub_items) return items assert (port is None) or (isinstance(port, InputPort) and port.non_db) return [] - def _flatten_outputs(self, port, port_value, parent_name='', separator=PORT_NAMESPACE_SEPARATOR): + def 
_flatten_outputs( + self, + port: Union[None, OutputPort, PortNamespace], + port_value: Any, + parent_name: str = '', + separator: str = PORT_NAMESPACE_SEPARATOR + ) -> List[Tuple[str, Any]]: """ Function that will recursively flatten the outputs dictionary. :param port: port against which to map the port value, can be OutputPort or PortNamespace - :type port: :class:`plumpy.ports.Port` - :param port_value: value for the current port, can be a Mapping - :type parent_name: str - :param parent_name: the parent key with which to prefix the keys - :type parent_name: str - :param separator: character to use for the concatenation of keys - :type separator: str :return: flat list of outputs - :rtype: list + """ if port is None and isinstance(port_value, orm.Node) or isinstance(port, OutputPort): return [(parent_name, port_value)] @@ -770,34 +803,34 @@ def _flatten_outputs(self, port, port_value, parent_name='', separator=PORT_NAME prefixed_key = parent_name + separator + name if parent_name else name try: - nested_port = port[name] + nested_port = cast(Union[OutputPort, PortNamespace], port[name]) if port else None except (KeyError, TypeError): nested_port = None sub_items = self._flatten_outputs( - port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator) + port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator + ) items.extend(sub_items) return items assert port is None, port return [] - def exposed_inputs(self, process_class, namespace=None, agglomerate=True): - """ - Gather a dictionary of the inputs that were exposed for a given Process class under an optional namespace. + def exposed_inputs( + self, + process_class: Type['Process'], + namespace: Optional[str] = None, + agglomerate: bool = True + ) -> AttributeDict: + """Gather a dictionary of the inputs that were exposed for a given Process class under an optional namespace. :param process_class: Process class whose inputs to try and retrieve - :type process_class: :class:`aiida.engine.Process` - :param namespace: PortNamespace in which to look for the inputs - :type namespace: str - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be searched for inputs. Inputs in lower-lying namespaces take precedence. 
- :type agglomerate: bool :returns: exposed inputs - :rtype: dict + """ exposed_inputs = {} @@ -811,9 +844,9 @@ def exposed_inputs(self, process_class, namespace=None, agglomerate=True): else: inputs = self.inputs for part in sub_namespace.split('.'): - inputs = inputs[part] + inputs = inputs[part] # type: ignore[index] try: - port_namespace = self.spec().inputs.get_port(sub_namespace) + port_namespace = self.spec().inputs.get_port(sub_namespace) # type: ignore[assignment] except KeyError: raise ValueError(f'this process does not contain the "{sub_namespace}" input namespace') @@ -821,26 +854,26 @@ def exposed_inputs(self, process_class, namespace=None, agglomerate=True): exposed_inputs_list = self.spec()._exposed_inputs[sub_namespace][process_class] # pylint: disable=protected-access for name in port_namespace.ports.keys(): - if name in inputs and name in exposed_inputs_list: + if inputs and name in inputs and name in exposed_inputs_list: exposed_inputs[name] = inputs[name] return AttributeDict(exposed_inputs) - def exposed_outputs(self, node, process_class, namespace=None, agglomerate=True): + def exposed_outputs( + self, + node: orm.ProcessNode, + process_class: Type['Process'], + namespace: Optional[str] = None, + agglomerate: bool = True + ) -> AttributeDict: """Return the outputs which were exposed from the ``process_class`` and emitted by the specific ``node`` :param node: process node whose outputs to try and retrieve - :type node: :class:`aiida.orm.nodes.process.ProcessNode` - :param namespace: Namespace in which to search for exposed outputs. - :type namespace: str - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be searched for outputs. Outputs in lower-lying namespaces take precedence. - :type agglomerate: bool :returns: exposed outputs - :rtype: dict """ namespace_separator = self.spec().namespace_separator @@ -849,9 +882,7 @@ def exposed_outputs(self, node, process_class, namespace=None, agglomerate=True) # maps the exposed name to all outputs that belong to it top_namespace_map = collections.defaultdict(list) link_types = (LinkType.CREATE, LinkType.RETURN) - process_outputs_dict = { - entry.link_label: entry.node for entry in node.get_outgoing(link_type=link_types) - } + process_outputs_dict = {entry.link_label: entry.node for entry in node.get_outgoing(link_type=link_types)} for port_name in process_outputs_dict: top_namespace = port_name.split(namespace_separator)[0] @@ -876,30 +907,27 @@ def exposed_outputs(self, node, process_class, namespace=None, agglomerate=True) return AttributeDict(result) @staticmethod - def _get_namespace_list(namespace=None, agglomerate=True): + def _get_namespace_list(namespace: Optional[str] = None, agglomerate: bool = True) -> List[Optional[str]]: """Get the list of namespaces in a given namespace. :param namespace: name space - :type namespace: str - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be searched. - :type agglomerate: bool :returns: namespace list - :rtype: list + """ if not agglomerate: return [namespace] - namespace_list = [None] + namespace_list: List[Optional[str]] = [None] if namespace is not None: split_ns = namespace.split('.') namespace_list.extend(['.'.join(split_ns[:i]) for i in range(1, len(split_ns) + 1)]) return namespace_list @classmethod - def is_valid_cache(cls, node): + def is_valid_cache(cls, node: orm.ProcessNode) -> bool: """Check if the given node can be cached from. .. 
warning :: When overriding this method, make sure to call @@ -915,7 +943,7 @@ def is_valid_cache(cls, node): return True -def get_query_string_from_process_type_string(process_type_string): # pylint: disable=invalid-name +def get_query_string_from_process_type_string(process_type_string: str) -> str: # pylint: disable=invalid-name """ Take the process type string of a Node and create the queryable type string. diff --git a/aiida/engine/processes/process_spec.py b/aiida/engine/processes/process_spec.py index 334a8e0794..4e73005f2a 100644 --- a/aiida/engine/processes/process_spec.py +++ b/aiida/engine/processes/process_spec.py @@ -8,7 +8,11 @@ # For further information please visit http://www.aiida.net # ########################################################################### """AiiDA specific implementation of plumpy's ProcessSpec.""" -import plumpy +from typing import Optional + +import plumpy.process_spec + +from aiida.orm import Dict from .exit_code import ExitCode, ExitCodesNamespace from .ports import InputPort, PortNamespace, CalcJobOutputPort @@ -16,32 +20,32 @@ __all__ = ('ProcessSpec', 'CalcJobProcessSpec') -class ProcessSpec(plumpy.ProcessSpec): +class ProcessSpec(plumpy.process_spec.ProcessSpec): """Default process spec for process classes defined in `aiida-core`. This sub class defines custom classes for input ports and port namespaces. It also adds support for the definition of exit codes and retrieving them subsequently. """ - METADATA_KEY = 'metadata' - METADATA_OPTIONS_KEY = 'options' + METADATA_KEY: str = 'metadata' + METADATA_OPTIONS_KEY: str = 'options' INPUT_PORT_TYPE = InputPort PORT_NAMESPACE_TYPE = PortNamespace - def __init__(self): + def __init__(self) -> None: super().__init__() self._exit_codes = ExitCodesNamespace() @property - def metadata_key(self): + def metadata_key(self) -> str: return self.METADATA_KEY @property - def options_key(self): + def options_key(self) -> str: return self.METADATA_OPTIONS_KEY @property - def exit_codes(self): + def exit_codes(self) -> ExitCodesNamespace: """ Return the namespace of exit codes defined for this ProcessSpec @@ -49,7 +53,7 @@ def exit_codes(self): """ return self._exit_codes - def exit_code(self, status, label, message, invalidates_cache=False): + def exit_code(self, status: int, label: str, message: str, invalidates_cache: bool = False) -> None: """ Add an exit code to the ProcessSpec @@ -76,24 +80,36 @@ def exit_code(self, status, label, message, invalidates_cache=False): self._exit_codes[label] = ExitCode(status, message, invalidates_cache=invalidates_cache) + # override return type to aiida's PortNamespace subclass + + @property + def ports(self) -> PortNamespace: + return super().ports # type: ignore[return-value] + + @property + def inputs(self) -> PortNamespace: + return super().inputs # type: ignore[return-value] + + @property + def outputs(self) -> PortNamespace: + return super().outputs # type: ignore[return-value] + class CalcJobProcessSpec(ProcessSpec): """Process spec intended for the `CalcJob` process class.""" OUTPUT_PORT_TYPE = CalcJobOutputPort - def __init__(self): + def __init__(self) -> None: super().__init__() - self._default_output_node = None + self._default_output_node: Optional[str] = None @property - def default_output_node(self): + def default_output_node(self) -> Optional[str]: return self._default_output_node @default_output_node.setter - def default_output_node(self, port_name): - from aiida.orm import Dict - + def default_output_node(self, port_name: str) -> None: if port_name not in 
self.outputs: raise ValueError(f'{port_name} is not a registered output port') diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py index bea66d0a5b..9b0cf508c9 100644 --- a/aiida/engine/processes/workchains/__init__.py +++ b/aiida/engine/processes/workchains/__init__.py @@ -14,4 +14,4 @@ from .utils import * from .workchain import * -__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) +__all__ = (context.__all__ + restart.__all__ + utils.__all__ + workchain.__all__) # type: ignore[name-defined] diff --git a/aiida/engine/processes/workchains/awaitable.py b/aiida/engine/processes/workchains/awaitable.py index fee97be995..ea8954ae92 100644 --- a/aiida/engine/processes/workchains/awaitable.py +++ b/aiida/engine/processes/workchains/awaitable.py @@ -9,6 +9,7 @@ ########################################################################### """Enums and function for the awaitables of Processes.""" from enum import Enum +from typing import Union from plumpy.utils import AttributesDict from aiida.orm import ProcessNode @@ -31,7 +32,7 @@ class AwaitableAction(Enum): APPEND = 'append' -def construct_awaitable(target): +def construct_awaitable(target: Union[Awaitable, ProcessNode]) -> Awaitable: """ Construct an instance of the Awaitable class that will contain the information related to the action to be taken with respect to the context once the awaitable diff --git a/aiida/engine/processes/workchains/context.py b/aiida/engine/processes/workchains/context.py index c0c9f31bb4..a22bc0cc02 100644 --- a/aiida/engine/processes/workchains/context.py +++ b/aiida/engine/processes/workchains/context.py @@ -8,14 +8,17 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Convenience functions to add awaitables to the Context of a WorkChain.""" -from .awaitable import construct_awaitable, AwaitableAction +from typing import Union + +from aiida.orm import ProcessNode +from .awaitable import construct_awaitable, Awaitable, AwaitableAction __all__ = ('ToContext', 'assign_', 'append_') ToContext = dict -def assign_(target): +def assign_(target: Union[Awaitable, ProcessNode]) -> Awaitable: """ Convenience function that will construct an Awaitable for a given class instance with the context action set to ASSIGN. When the awaitable target is completed @@ -24,14 +27,14 @@ def assign_(target): :param target: an instance of a Process or Awaitable :returns: the awaitable - :rtype: Awaitable + """ awaitable = construct_awaitable(target) awaitable.action = AwaitableAction.ASSIGN return awaitable -def append_(target): +def append_(target: Union[Awaitable, ProcessNode]) -> Awaitable: """ Convenience function that will construct an Awaitable for a given class instance with the context action set to APPEND. 
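As an illustrative sketch, a work chain step can mix both context actions in the dictionary it returns; `SomeWorkChain` is a placeholder process class::

    def submit_children(self):
        node_a = self.submit(SomeWorkChain)  # placeholder
        node_b = self.submit(SomeWorkChain)
        # `ctx.main` will hold node_a once it terminates; node_b is appended to the list `ctx.extras`
        return ToContext(main=assign_(node_a), extras=append_(node_b))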
When the awaitable target is completed @@ -40,7 +43,7 @@ def append_(target): :param target: an instance of a Process or Awaitable :returns: the awaitable - :rtype: Awaitable + """ awaitable = construct_awaitable(target) awaitable.action = AwaitableAction.APPEND diff --git a/aiida/engine/processes/workchains/restart.py b/aiida/engine/processes/workchains/restart.py index 7bf5d368bd..5719e1496f 100644 --- a/aiida/engine/processes/workchains/restart.py +++ b/aiida/engine/processes/workchains/restart.py @@ -9,6 +9,9 @@ ########################################################################### """Base implementation of `WorkChain` class that implements a simple automated restart mechanism for sub processes.""" import functools +from inspect import getmembers +from types import FunctionType +from typing import Any, Dict, List, Optional, Type, Union, TYPE_CHECKING from aiida import orm from aiida.common import AttributeDict @@ -17,10 +20,17 @@ from .workchain import WorkChain from .utils import ProcessHandlerReport, process_handler +if TYPE_CHECKING: + from aiida.engine.processes import ExitCode, PortNamespace, Process, ProcessSpec + __all__ = ('BaseRestartWorkChain',) -def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint: disable=unused-argument +def validate_handler_overrides( + process_class: 'BaseRestartWorkChain', + handler_overrides: Optional[orm.Dict], + ctx: 'PortNamespace' # pylint: disable=unused-argument +) -> Optional[str]: """Validator for the `handler_overrides` input port of the `BaseRestartWorkChain. The `handler_overrides` should be a dictionary where keys are strings that are the name of a process handler, i.e. a @@ -36,7 +46,7 @@ def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint :param ctx: the `PortNamespace` in which the port is embedded """ if not handler_overrides: - return + return None for handler, override in handler_overrides.get_dict().items(): if not isinstance(handler, str): @@ -48,6 +58,8 @@ def validate_handler_overrides(process_class, handler_overrides, ctx): # pylint if not isinstance(override, bool): return f'The value of key `{handler}` is not a boolean.' + return None + class BaseRestartWorkChain(WorkChain): """Base restart work chain. @@ -101,11 +113,19 @@ def handle_problem(self, node): `inspect_process`. Refer to their respective documentation for details. 
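In practice the `handler_overrides` port receives a `Dict` node keyed by handler method names, for example to disable the `handle_problem` handler mentioned above for a single run (sketch; `inputs` is the placeholder inputs dictionary being assembled)::

    from aiida import orm

    inputs['handler_overrides'] = orm.Dict(dict={'handle_problem': False})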
""" - _process_class = None + _process_class: Optional[Type['Process']] = None _considered_handlers_extra = 'considered_handlers' + @property + def process_class(self) -> Type['Process']: + """Return the process class to run in the loop.""" + from ..process import Process # pylint: disable=cyclic-import + if self._process_class is None or not issubclass(self._process_class, Process): + raise ValueError('no valid Process class defined for `_process_class` attribute') + return self._process_class + @classmethod - def define(cls, spec): + def define(cls, spec: 'ProcessSpec') -> None: # type: ignore[override] """Define the process specification.""" # yapf: disable super().define(spec) @@ -126,25 +146,28 @@ def define(cls, spec): message='The maximum number of iterations was exceeded.') spec.exit_code(402, 'ERROR_SECOND_CONSECUTIVE_UNHANDLED_FAILURE', message='The process failed for an unknown reason, twice in a row.') + # yapf: enable - def setup(self): + def setup(self) -> None: """Initialize context variables that are used during the logical flow of the `BaseRestartWorkChain`.""" - overrides = self.inputs.handler_overrides.get_dict() if 'handler_overrides' in self.inputs else {} + overrides = self.inputs.handler_overrides.get_dict() if (self.inputs and + 'handler_overrides' in self.inputs) else {} self.ctx.handler_overrides = overrides - self.ctx.process_name = self._process_class.__name__ + self.ctx.process_name = self.process_class.__name__ self.ctx.unhandled_failure = False self.ctx.is_finished = False self.ctx.iteration = 0 - def should_run_process(self): + def should_run_process(self) -> bool: """Return whether a new process should be run. This is the case as long as the last process has not finished successfully and the maximum number of restarts has not yet been exceeded. """ - return not self.ctx.is_finished and self.ctx.iteration < self.inputs.max_iterations.value + max_iterations = self.inputs.max_iterations.value # type: ignore[union-attr] + return not self.ctx.is_finished and self.ctx.iteration < max_iterations - def run_process(self): + def run_process(self) -> ToContext: """Run the next process, taking the input dictionary from the context at `self.ctx.inputs`.""" self.ctx.iteration += 1 @@ -156,8 +179,8 @@ def run_process(self): # Set the `CALL` link label unwrapped_inputs.setdefault('metadata', {})['call_link_label'] = f'iteration_{self.ctx.iteration:02d}' - inputs = self._wrap_bare_dict_inputs(self._process_class.spec().inputs, unwrapped_inputs) - node = self.submit(self._process_class, **inputs) + inputs = self._wrap_bare_dict_inputs(self.process_class.spec().inputs, unwrapped_inputs) + node = self.submit(self.process_class, **inputs) # Add a new empty list to the `BaseRestartWorkChain._considered_handlers_extra` extra. This will contain the # name and return value of all class methods, decorated with `process_handler`, that are called during @@ -170,7 +193,7 @@ def run_process(self): return ToContext(children=append_(node)) - def inspect_process(self): # pylint: disable=too-many-branches + def inspect_process(self) -> Optional['ExitCode']: # pylint: disable=too-many-branches """Analyse the results of the previous process and call the handlers when necessary. If the process is excepted or killed, the work chain will abort. 
Otherwise any attached handlers will be called @@ -202,10 +225,11 @@ def inspect_process(self): # pylint: disable=too-many-branches last_report = None # Sort the handlers with a priority defined, based on their priority in reverse order - for handler in sorted(self.get_process_handlers(), key=lambda handler: handler.priority, reverse=True): + get_priority = lambda handler: handler.priority + for handler in sorted(self.get_process_handlers(), key=get_priority, reverse=True): # Skip if the handler is enabled, either explicitly through `handler_overrides` or by default - if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled): + if not self.ctx.handler_overrides.get(handler.__name__, handler.enabled): # type: ignore[attr-defined] continue # Even though the `handler` is an instance method, the `get_process_handlers` method returns unbound methods @@ -236,7 +260,7 @@ def inspect_process(self): # pylint: disable=too-many-branches self.ctx.unhandled_failure = True self.report('{}<{}> failed and error was not handled, restarting once more'.format(*report_args)) - return + return None # Here either the process finished successful or at least one handler returned a report so it can no longer be # considered to be an unhandled failed process and therefore we reset the flag @@ -260,16 +284,21 @@ def inspect_process(self): # pylint: disable=too-many-branches # Otherwise the process was successful and no handler returned anything so we consider the work done self.ctx.is_finished = True - def results(self): + return None + + def results(self) -> Optional['ExitCode']: """Attach the outputs specified in the output specification from the last completed process.""" node = self.ctx.children[self.ctx.iteration - 1] # We check the `is_finished` attribute of the work chain and not the successfulness of the last process # because the error handlers in the last iteration can have qualified a "failed" process as satisfactory # for the outcome of the work chain and so have marked it as `is_finished=True`. - if not self.ctx.is_finished and self.ctx.iteration >= self.inputs.max_iterations.value: - self.report('reached the maximum number of iterations {}: last ran {}<{}>'.format( - self.inputs.max_iterations.value, self.ctx.process_name, node.pk)) + max_iterations = self.inputs.max_iterations.value # type: ignore[union-attr] + if not self.ctx.is_finished and self.ctx.iteration >= max_iterations: + self.report( + f'reached the maximum number of iterations {max_iterations}: ' + f'last ran {self.ctx.process_name}<{node.pk}>' + ) return self.exit_codes.ERROR_MAXIMUM_ITERATIONS_EXCEEDED # pylint: disable=no-member self.report(f'work chain completed after {self.ctx.iteration} iterations') @@ -284,16 +313,17 @@ def results(self): else: self.out(name, output) - def __init__(self, *args, **kwargs): + return None + + def __init__(self, *args, **kwargs) -> None: """Construct the instance.""" - from ..process import Process # pylint: disable=cyclic-import super().__init__(*args, **kwargs) - if self._process_class is None or not issubclass(self._process_class, Process): - raise ValueError('no valid Process class defined for `_process_class` attribute') + # try retrieving process class + self.process_class # pylint: disable=pointless-statement @classmethod - def is_process_handler(cls, process_handler_name): + def is_process_handler(cls, process_handler_name: Union[str, FunctionType]) -> bool: """Return whether the given method name corresponds to a process handler of this class. 
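Taken together, a minimal concrete subclass only has to set `_process_class`, expose the wrapped inputs and populate `ctx.inputs` in `setup`; a sketch assuming a hypothetical `MyCalculation` process class::

    from aiida.common import AttributeDict
    from aiida.engine import BaseRestartWorkChain, while_

    class MyRestartWorkChain(BaseRestartWorkChain):

        _process_class = MyCalculation  # placeholder

        @classmethod
        def define(cls, spec):
            super().define(spec)
            spec.expose_inputs(MyCalculation, namespace='calc')
            spec.expose_outputs(MyCalculation)
            spec.outline(
                cls.setup,
                while_(cls.should_run_process)(
                    cls.run_process,
                    cls.inspect_process,
                ),
                cls.results,
            )

        def setup(self):
            super().setup()
            self.ctx.inputs = AttributeDict(self.exposed_inputs(MyCalculation, 'calc'))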
:param process_handler_name: string name of the instance method @@ -308,15 +338,14 @@ def is_process_handler(cls, process_handler_name): return getattr(handler, 'decorator', None) == process_handler @classmethod - def get_process_handlers(cls): - from inspect import getmembers + def get_process_handlers(cls) -> List[FunctionType]: return [method[1] for method in getmembers(cls) if cls.is_process_handler(method[1])] def on_terminated(self): """Clean the working directories of all child calculation jobs if `clean_workdir=True` in the inputs.""" super().on_terminated() - if self.inputs.clean_workdir.value is False: + if self.inputs.clean_workdir.value is False: # type: ignore[union-attr] self.report('remote folders will not be cleaned') return @@ -333,7 +362,7 @@ def on_terminated(self): if cleaned_calcs: self.report(f"cleaned remote folders of calculations: {' '.join(cleaned_calcs)}") - def _wrap_bare_dict_inputs(self, port_namespace, inputs): + def _wrap_bare_dict_inputs(self, port_namespace: 'PortNamespace', inputs: Dict[str, Any]) -> AttributeDict: """Wrap bare dictionaries in `inputs` in a `Dict` node if dictated by the corresponding inputs portnamespace. :param port_namespace: a `PortNamespace` diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py index 53dceb3a60..b25f15de20 100644 --- a/aiida/engine/processes/workchains/utils.py +++ b/aiida/engine/processes/workchains/utils.py @@ -12,6 +12,7 @@ from functools import partial from inspect import getfullargspec from types import FunctionType # pylint: disable=no-name-in-module +from typing import List, Optional, Union from wrapt import decorator from ..exit_code import ExitCode @@ -19,7 +20,7 @@ __all__ = ('ProcessHandlerReport', 'process_handler') ProcessHandlerReport = namedtuple('ProcessHandlerReport', 'do_break exit_code') -ProcessHandlerReport.__new__.__defaults__ = (False, ExitCode()) +ProcessHandlerReport.__new__.__defaults__ = (False, ExitCode()) # type: ignore[attr-defined,call-arg] """A namedtuple to define a process handler report for a :class:`aiida.engine.BaseRestartWorkChain`. This namedtuple should be returned by a process handler of a work chain instance if the condition of the handler was @@ -36,7 +37,13 @@ """ -def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): +def process_handler( + wrapped: Optional[FunctionType] = None, + *, + priority: int = 0, + exit_codes: Union[None, ExitCode, List[ExitCode]] = None, + enabled: bool = True +) -> FunctionType: """Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler. The decorator will validate the `priority` and `exit_codes` optional keyword arguments and then add itself as an @@ -55,7 +62,7 @@ def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): `do_break` attribute should be set to `True`. If the work chain is to be aborted entirely, the `exit_code` of the report can be set to an `ExitCode` instance with a non-zero status. - :param cls: the work chain class to register the process handler with + :param wrapped: the work chain method to register the process handler with :param priority: optional integer that defines the order in which registered handlers will be called during the handling of a finished process. Higher priorities will be handled first. Default value is `0`. Multiple handlers with the same priority is allowed, but the order of those is not well defined. 
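For example, a handler on such a work chain might be registered as follows; the method name and exit status are placeholders::

    from aiida.engine import ExitCode, ProcessHandlerReport, process_handler

    # inside a BaseRestartWorkChain subclass:
    @process_handler(priority=500, exit_codes=ExitCode(410))
    def handle_out_of_resources(self, node):
        """Called only when the child process ended with exit status 410."""
        self.report(f'{node.process_label}<{node.pk}> ran out of resources, restarting')
        return ProcessHandlerReport(do_break=True)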
@@ -67,7 +74,9 @@ def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): basis through the input `handler_overrides`. """ if wrapped is None: - return partial(process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled) + return partial( + process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled + ) # type: ignore[return-value] if not isinstance(wrapped, FunctionType): raise TypeError('first argument can only be an instance method, use keywords for decorator arguments.') @@ -89,9 +98,9 @@ def process_handler(wrapped=None, *, priority=0, exit_codes=None, enabled=True): if len(handler_args) != 2: raise TypeError(f'process handler `{wrapped.__name__}` has invalid signature: should be (self, node)') - wrapped.decorator = process_handler - wrapped.priority = priority - wrapped.enabled = enabled + wrapped.decorator = process_handler # type: ignore[attr-defined] + wrapped.priority = priority # type: ignore[attr-defined] + wrapped.enabled = enabled # type: ignore[attr-defined] @decorator def wrapper(wrapped, instance, args, kwargs): diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py index 00f0f479f2..698ad9de44 100644 --- a/aiida/engine/processes/workchains/workchain.py +++ b/aiida/engine/processes/workchains/workchain.py @@ -10,15 +10,17 @@ """Components for the WorkChain concept of the workflow engine.""" import collections import functools +import logging +from typing import Any, List, Optional, Sequence, Union, TYPE_CHECKING -import plumpy -from plumpy import auto_persist, Wait, Continue -from plumpy.workchains import if_, while_, return_, _PropagateReturn +from plumpy.persistence import auto_persist +from plumpy.process_states import Wait, Continue +from plumpy.workchains import if_, while_, return_, _PropagateReturn, Stepper, WorkChainSpec as PlumpyWorkChainSpec from aiida.common import exceptions from aiida.common.extendeddicts import AttributeDict from aiida.common.lang import override -from aiida.orm import Node, WorkChainNode +from aiida.orm import Node, ProcessNode, WorkChainNode from aiida.orm.utils import load_node from ..exit_code import ExitCode @@ -26,10 +28,13 @@ from ..process import Process, ProcessState from .awaitable import Awaitable, AwaitableTarget, AwaitableAction, construct_awaitable +if TYPE_CHECKING: + from aiida.engine.runners import Runner + __all__ = ('WorkChain', 'if_', 'while_', 'return_') -class WorkChainSpec(ProcessSpec, plumpy.WorkChainSpec): +class WorkChainSpec(ProcessSpec, PlumpyWorkChainSpec): pass @@ -42,22 +47,21 @@ class WorkChain(Process): _STEPPER_STATE = 'stepper_state' _CONTEXT = 'CONTEXT' - def __init__(self, inputs=None, logger=None, runner=None, enable_persistence=True): + def __init__( + self, + inputs: Optional[dict] = None, + logger: Optional[logging.Logger] = None, + runner: Optional['Runner'] = None, + enable_persistence: bool = True + ) -> None: """Construct a WorkChain instance. Construct the instance only if it is a sub class of `WorkChain`, otherwise raise `InvalidOperation`. 
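The re-exported `if_`, `while_` and `return_` helpers are what a subclass feeds to `spec.outline`; schematically (step names are placeholders)::

    spec.outline(
        cls.setup,
        while_(cls.should_continue)(
            cls.run_step,
        ),
        if_(cls.needs_refinement)(
            cls.refine,
        ),
        cls.finalize,
    )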
:param inputs: work chain inputs - :type inputs: dict - :param logger: aiida logger - :type logger: :class:`logging.Logger` - :param runner: work chain runner - :type: :class:`aiida.engine.runners.Runner` - :param enable_persistence: whether to persist this work chain - :type enable_persistence: bool """ if self.__class__ == WorkChain: @@ -65,21 +69,22 @@ def __init__(self, inputs=None, logger=None, runner=None, enable_persistence=Tru super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) - self._stepper = None - self._awaitables = [] + self._stepper: Optional[Stepper] = None + self._awaitables: List[Awaitable] = [] self._context = AttributeDict() - @property - def ctx(self): - """Get context. + @classmethod + def spec(cls) -> WorkChainSpec: + return super().spec() # type: ignore[return-value] - :rtype: :class:`aiida.common.extendeddicts.AttributeDict` - """ + @property + def ctx(self) -> AttributeDict: + """Get the context.""" return self._context @override def save_instance_state(self, out_state, save_context): - """Save instance stace. + """Save instance state. :param out_state: state to save in @@ -105,7 +110,7 @@ def load_instance_state(self, saved_state, load_context): self._stepper = None stepper_state = saved_state.get(self._STEPPER_STATE, None) if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) + self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) # type: ignore[arg-type] self.set_logger(self.node.logger) @@ -116,7 +121,7 @@ def on_run(self): super().on_run() self.node.set_stepper_state_info(str(self._stepper)) - def insert_awaitable(self, awaitable): + def insert_awaitable(self, awaitable: Awaitable) -> None: """Insert an awaitable that should be terminated before continuing to the next step. :param awaitable: the thing to await @@ -137,7 +142,7 @@ def insert_awaitable(self, awaitable): self._update_process_status() - def resolve_awaitable(self, awaitable, value): + def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: """Resolve an awaitable. Precondition: must be an awaitable that was previously inserted. @@ -164,7 +169,7 @@ def resolve_awaitable(self, awaitable, value): self._update_process_status() - def to_context(self, **kwargs): + def to_context(self, **kwargs: Union[Awaitable, ProcessNode]) -> None: """Add a dictionary of awaitables to the context. This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will @@ -175,7 +180,7 @@ def to_context(self, **kwargs): awaitable.key = key self.insert_awaitable(awaitable) - def _update_process_status(self): + def _update_process_status(self) -> None: """Set the process status with a message accounting the current sub processes that we are waiting for.""" if self._awaitables: status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" @@ -184,11 +189,11 @@ def _update_process_status(self): self.node.set_process_status(None) @override - def run(self): - self._stepper = self.spec().get_outline().create_stepper(self) + def run(self) -> Any: + self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type] return self._do_step() - def _do_step(self): + def _do_step(self) -> Any: """Execute the next step in the outline and return the result.
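`to_context` is the imperative counterpart of returning a `ToContext` dictionary from a step; a sketch with a placeholder `SomeWorkChain`::

    def submit_batch(self):
        for _ in range(3):
            node = self.submit(SomeWorkChain)
            self.to_context(children=append_(node))  # ctx.children ends up holding all three nodes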
If the stepper returns a non-finished status and the return value is of type ToContext, the contents of the @@ -199,16 +204,17 @@ def _do_step(self): from .context import ToContext self._awaitables = [] - result = None + result: Any = None try: + assert self._stepper is not None finished, stepper_result = self._stepper.step() except _PropagateReturn as exception: finished, result = True, exception.exit_code else: # Set result to None unless stepper_result was non-zero positive integer or ExitCode with similar status if isinstance(stepper_result, int) and stepper_result > 0: - result = ExitCode(stepper_result) + result = ExitCode(stepper_result) # type: ignore[call-arg] elif isinstance(stepper_result, ExitCode) and stepper_result.status > 0: result = stepper_result else: @@ -226,7 +232,7 @@ def _do_step(self): return Continue(self._do_step) - def _store_nodes(self, data): + def _store_nodes(self, data: Any) -> None: """Recurse through a data structure and store any unstored nodes that are found along the way :param data: a data structure potentially containing unstored nodes @@ -241,7 +247,7 @@ def _store_nodes(self, data): self._store_nodes(value) @override - def on_exiting(self): + def on_exiting(self) -> None: """Ensure that any unstored nodes in the context are stored, before the state is exited After the state is exited the next state will be entered and if persistence is enabled, a checkpoint will @@ -254,14 +260,15 @@ def on_exiting(self): # An uncaught exception here will have bizarre and disastrous consequences self.logger.exception('exception in _store_nodes called in on_exiting') - def on_wait(self, awaitables): + def on_wait(self, awaitables: Sequence[Awaitable]): + """Entering the WAITING state.""" super().on_wait(awaitables) if self._awaitables: self.action_awaitables() else: self.call_soon(self.resume) - def action_awaitables(self): + def action_awaitables(self) -> None: """Handle the awaitables that are currently registered with the work chain. Depending on the class type of the awaitable's target a different callback @@ -275,7 +282,7 @@ def action_awaitables(self): else: assert f"invalid awaitable target '{awaitable.target}'" - def on_process_finished(self, awaitable): + def on_process_finished(self, awaitable: Awaitable) -> None: """Callback function called by the runner when the process instance identified by pk is completed. The awaitable will be effectuated on the context of the work chain and removed from the internal list. 
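Concretely, this means an outline step can abort the work chain by returning a non-zero exit code defined on the spec; a sketch with a placeholder exit code label::

    # in define():  spec.exit_code(400, 'ERROR_INVALID_INPUT', message='The input was invalid.')

    def validate_inputs(self):
        if 'x' not in self.inputs:  # placeholder check
            return self.exit_codes.ERROR_INVALID_INPUT  # _do_step turns this into the final result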
If all diff --git a/aiida/engine/runners.py b/aiida/engine/runners.py index be2a3d377b..7708c851b5 100644 --- a/aiida/engine/runners.py +++ b/aiida/engine/runners.py @@ -9,23 +9,25 @@ ########################################################################### # pylint: disable=global-statement """Runners that can run and submit processes.""" -import collections +import asyncio import functools import logging import signal import threading +from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union import uuid -import asyncio import kiwipy -import plumpy -from plumpy import set_event_loop_policy, reset_event_loop_policy +from plumpy.persistence import Persister +from plumpy.process_comms import RemoteProcessThreadController +from plumpy.events import set_event_loop_policy, reset_event_loop_policy +from plumpy.communications import wrap_communicator from aiida.common import exceptions -from aiida.orm import load_node +from aiida.orm import load_node, ProcessNode from aiida.plugins.utils import PluginVersionProvider -from .processes import futures, ProcessState +from .processes import futures, Process, ProcessBuilder, ProcessState from .processes.calcjobs import manager from . import transports from . import utils @@ -34,28 +36,46 @@ LOGGER = logging.getLogger(__name__) -ResultAndNode = collections.namedtuple('ResultAndNode', ['result', 'node']) -ResultAndPk = collections.namedtuple('ResultAndPk', ['result', 'pk']) + +class ResultAndNode(NamedTuple): + node: ProcessNode + result: Dict[str, Any] + + +class ResultAndPk(NamedTuple): + node: ProcessNode + pk: int + + +TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name +# run can also be process function, but it is not clear what type this should be +TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name class Runner: # pylint: disable=too-many-public-methods """Class that can launch processes by running in the current interpreter or by submitting them to the daemon.""" - _persister = None - _communicator = None - _controller = None - _closed = False - - def __init__(self, poll_interval=0, loop=None, communicator=None, rmq_submit=False, persister=None): + _persister: Optional[Persister] = None + _communicator: Optional[kiwipy.Communicator] = None + _controller: Optional[RemoteProcessThreadController] = None + _closed: bool = False + + def __init__( + self, + poll_interval: Union[int, float] = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, + communicator: Optional[kiwipy.Communicator] = None, + rmq_submit: bool = False, + persister: Optional[Persister] = None + ): """Construct a new runner. 
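A sketch of the two usual ways to obtain a runner::

    from aiida.manage.manager import get_manager
    runner = get_manager().get_runner()  # the global runner of the loaded profile

    # or an isolated, RabbitMQ-less runner for local blocking runs:
    from aiida.engine.runners import Runner
    local_runner = Runner(poll_interval=0.5, rmq_submit=False)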
:param poll_interval: interval in seconds between polling for status of active sub processes :param loop: an asyncio event loop, if none is supplied a new one will be created :param communicator: the communicator to use - :type communicator: :class:`kiwipy.Communicator` :param rmq_submit: if True, processes will be submitted to RabbitMQ, otherwise they will be scheduled here :param persister: the persister to use to persist processes - :type persister: :class:`plumpy.Persister` + """ assert not (rmq_submit and persister is None), \ 'Must supply a persister if you want to submit using communicator' @@ -70,94 +90,86 @@ def __init__(self, poll_interval=0, loop=None, communicator=None, rmq_submit=Fal self._plugin_version_provider = PluginVersionProvider() if communicator is not None: - self._communicator = plumpy.wrap_communicator(communicator, self._loop) - self._controller = plumpy.RemoteProcessThreadController(communicator) + self._communicator = wrap_communicator(communicator, self._loop) + self._controller = RemoteProcessThreadController(communicator) elif self._rmq_submit: LOGGER.warning('Disabling RabbitMQ submission, no communicator provided') self._rmq_submit = False - def __enter__(self): + def __enter__(self) -> 'Runner': return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() @property - def loop(self): - """ - Get the event loop of this runner - - :return: the asyncio event loop - """ + def loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop of this runner.""" return self._loop @property - def transport(self): + def transport(self) -> transports.TransportQueue: return self._transport @property - def persister(self): + def persister(self) -> Optional[Persister]: + """Get the persister used by this runner.""" return self._persister @property - def communicator(self): - """ - Get the communicator used by this runner - - :return: the communicator - :rtype: :class:`kiwipy.Communicator` - """ + def communicator(self) -> Optional[kiwipy.Communicator]: + """Get the communicator used by this runner.""" return self._communicator @property - def plugin_version_provider(self): + def plugin_version_provider(self) -> PluginVersionProvider: return self._plugin_version_provider @property - def job_manager(self): + def job_manager(self) -> manager.JobManager: return self._job_manager @property - def controller(self): + def controller(self) -> Optional[RemoteProcessThreadController]: + """Get the controller used by this runner.""" return self._controller @property - def is_daemon_runner(self): + def is_daemon_runner(self) -> bool: """Return whether the runner is a daemon runner, which means it submits processes over RabbitMQ.
:return: True if the runner is a daemon runner - :rtype: bool """ return self._rmq_submit - def is_closed(self): + def is_closed(self) -> bool: return self._closed - def start(self): + def start(self) -> None: """Start the internal event loop.""" self._loop.run_forever() - def stop(self): + def stop(self) -> None: """Stop the internal event loop.""" self._loop.stop() - def run_until_complete(self, future): + def run_until_complete(self, future: asyncio.Future) -> Any: """Run the loop until the future has finished and return the result.""" with utils.loop_scope(self._loop): return self._loop.run_until_complete(future) - def close(self): + def close(self) -> None: """Close the runner by stopping the loop.""" assert not self._closed self.stop() reset_event_loop_policy() self._closed = True - def instantiate_process(self, process, *args, **inputs): + def instantiate_process(self, process: TYPE_RUN_PROCESS, *args, **inputs): from .utils import instantiate_process return instantiate_process(self, process, *args, **inputs) - def submit(self, process, *args, **inputs): + def submit(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any): """ Submit the process with the supplied inputs to this runner immediately returning control to the interpreter. The return value will be the calculation node of the submitted process @@ -169,24 +181,26 @@ def submit(self, process, *args, **inputs): assert not utils.is_process_function(process), 'Cannot submit a process function' assert not self._closed - process = self.instantiate_process(process, *args, **inputs) + process_inited = self.instantiate_process(process, *args, **inputs) - if not process.metadata.store_provenance: + if not process_inited.metadata.store_provenance: raise exceptions.InvalidOperation('cannot submit a process with `store_provenance=False`') - if process.metadata.get('dry_run', False): + if process_inited.metadata.get('dry_run', False): raise exceptions.InvalidOperation('cannot submit a process from within another with `dry_run=True`') if self._rmq_submit: - self.persister.save_checkpoint(process) - process.close() - self.controller.continue_process(process.pid, nowait=False, no_reply=True) + assert self.persister is not None, 'runner does not have a persister' + assert self.controller is not None, 'runner does not have a controller' + self.persister.save_checkpoint(process_inited) + process_inited.close() + self.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) else: - self.loop.create_task(process.step_until_terminated()) + self.loop.create_task(process_inited.step_until_terminated()) - return process.node + return process_inited.node - def schedule(self, process, *args, **inputs): + def schedule(self, process: TYPE_SUBMIT_PROCESS, *args: Any, **inputs: Any) -> ProcessNode: """ Schedule a process to be executed by this runner @@ -197,11 +211,11 @@ def schedule(self, process, *args, **inputs): assert not utils.is_process_function(process), 'Cannot submit a process function' assert not self._closed - process = self.instantiate_process(process, *args, **inputs) - self.loop.create_task(process.step_until_terminated()) - return process.node + process_inited = self.instantiate_process(process, *args, **inputs) + self.loop.create_task(process_inited.step_until_terminated()) + return process_inited.node - def _run(self, process, *args, **inputs): + def _run(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], ProcessNode]: """ Run the process with the supplied inputs in this 
runner that will block until the process is completed. The return value will be the results of the completed process @@ -213,24 +227,24 @@ def _run(self, process, *args, **inputs): assert not self._closed if utils.is_process_function(process): - result, node = process.run_get_node(*args, **inputs) + result, node = process.run_get_node(*args, **inputs) # type: ignore[union-attr] return result, node with utils.loop_scope(self.loop): - process = self.instantiate_process(process, *args, **inputs) + process_inited = self.instantiate_process(process, *args, **inputs) def kill_process(_num, _frame): """Send the kill signal to the process in the current scope.""" - LOGGER.critical('runner received interrupt, killing process %s', process.pid) - process.kill(msg='Process was killed because the runner received an interrupt') + LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid) + process_inited.kill(msg='Process was killed because the runner received an interrupt') signal.signal(signal.SIGINT, kill_process) signal.signal(signal.SIGTERM, kill_process) - process.execute() - return process.outputs, process.node + process_inited.execute() + return process_inited.outputs, process_inited.node - def run(self, process, *args, **inputs): + def run(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Dict[str, Any]: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -242,7 +256,7 @@ def run(self, process, *args, **inputs): result, _ = self._run(process, *args, **inputs) return result - def run_get_node(self, process, *args, **inputs): + def run_get_node(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndNode: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -254,7 +268,7 @@ def run_get_node(self, process, *args, **inputs): result, node = self._run(process, *args, **inputs) return ResultAndNode(result, node) - def run_get_pk(self, process, *args, **inputs): + def run_get_pk(self, process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> ResultAndPk: """ Run the process with the supplied inputs in this runner that will block until the process is completed. The return value will be the results of the completed process @@ -266,7 +280,7 @@ def run_get_pk(self, process, *args, **inputs): result, node = self._run(process, *args, **inputs) return ResultAndPk(result, node.pk) - def call_on_process_finish(self, pk, callback): + def call_on_process_finish(self, pk: int, callback: Callable[[], Any]) -> None: """Schedule a callback when the process of the given pk is terminated. 
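Schematically, the difference between the blocking and non-blocking entry points (`SomeProcess` and `inputs` are placeholders)::

    results = runner.run(SomeProcess, **inputs)                  # blocks, returns the outputs dictionary
    results, node = runner.run_get_node(SomeProcess, **inputs)   # blocks, also returns the ProcessNode
    node = runner.submit(SomeProcess, **inputs)                  # returns the node immediately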
This method will add a broadcast subscriber that will listen for state changes of the target process to be @@ -276,6 +290,8 @@ def call_on_process_finish(self, pk, callback): :param pk: pk of the process :param callback: function to be called upon process termination """ + assert self.communicator is not None, 'communicator not set for runner' + node = load_node(pk=pk) subscriber_identifier = str(uuid.uuid4()) event = threading.Event() @@ -293,17 +309,17 @@ def inline_callback(event, *args, **kwargs): # pylint: disable=unused-argument callback() finally: event.set() - self._communicator.remove_broadcast_subscriber(subscriber_identifier) + self.communicator.remove_broadcast_subscriber(subscriber_identifier) # type: ignore[union-attr] broadcast_filter = kiwipy.BroadcastFilter(functools.partial(inline_callback, event), sender=pk) for state in [ProcessState.FINISHED, ProcessState.KILLED, ProcessState.EXCEPTED]: broadcast_filter.add_subject_filter(f'state_changed.*.{state.value}') LOGGER.info('adding subscriber for broadcasts of %d', pk) - self._communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier) + self.communicator.add_broadcast_subscriber(broadcast_filter, subscriber_identifier) self._poll_process(node, functools.partial(inline_callback, event)) - def get_process_future(self, pk): + def get_process_future(self, pk: int) -> futures.ProcessFuture: """Return a future for a process. The future will have the process node as the result when finished. diff --git a/aiida/engine/transports.py b/aiida/engine/transports.py index be028adb4f..8cd0204d40 100644 --- a/aiida/engine/transports.py +++ b/aiida/engine/transports.py @@ -12,8 +12,12 @@ import contextlib import logging import traceback +from typing import Awaitable, Dict, Hashable, Iterator, Optional import asyncio +from aiida.orm import AuthInfo +from aiida.transports import Transport + _LOGGER = logging.getLogger(__name__) @@ -22,7 +26,7 @@ class TransportRequest: def __init__(self): super().__init__() - self.future = asyncio.Future() + self.future: asyncio.Future = asyncio.Future() self.count = 0 @@ -39,20 +43,20 @@ class TransportQueue: """ AuthInfoEntry = namedtuple('AuthInfoEntry', ['authinfo', 'transport', 'callbacks', 'callback_handle']) - def __init__(self, loop=None): + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None): """ :param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied """ self._loop = loop if loop is not None else asyncio.get_event_loop() - self._transport_requests = {} + self._transport_requests: Dict[Hashable, TransportRequest] = {} @property - def loop(self): + def loop(self) -> asyncio.AbstractEventLoop: """ Get the loop being used by this transport queue """ return self._loop @contextlib.contextmanager - def request_transport(self, authinfo): + def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable[Transport]]: """ Request a transport from an authinfo. 
Because the client is not allowed to request a transport immediately they will instead be given back a future @@ -79,7 +83,7 @@ async def transport_task(transport_queue, authinfo): def do_open(): """ Actually open the transport """ - if transport_request.count > 0: + if transport_request and transport_request.count > 0: # The user still wants the transport so open it _LOGGER.debug('Transport request opening transport for %s', authinfo) try: diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py index 5130903966..d76d55443e 100644 --- a/aiida/engine/utils.py +++ b/aiida/engine/utils.py @@ -10,9 +10,15 @@ # pylint: disable=invalid-name """Utilities for the workflow engine.""" +import asyncio import contextlib +from datetime import datetime import logging -import asyncio +from typing import Any, Awaitable, Callable, Iterator, List, Optional, Type, Union, TYPE_CHECKING + +if TYPE_CHECKING: + from .processes import Process, ProcessBuilder + from .runners import Runner __all__ = ('interruptable_task', 'InterruptableFuture', 'is_process_function') @@ -21,7 +27,9 @@ PROCESS_STATE_CHANGE_DESCRIPTION = 'The last time a process of type {}, changed state' -def instantiate_process(runner, process, *args, **inputs): +def instantiate_process( + runner: 'Runner', process: Union['Process', Type['Process'], 'ProcessBuilder'], *args, **inputs +) -> 'Process': """ Return an instance of the process with the given inputs. The function can deal with various types of the `process`: @@ -48,7 +56,7 @@ def instantiate_process(runner, process, *args, **inputs): process_class = builder.process_class inputs.update(**builder._inputs(prune=True)) # pylint: disable=protected-access elif is_process_function(process): - process_class = process.process_class + process_class = process.process_class # type: ignore[attr-defined] elif issubclass(process, Process): process_class = process else: @@ -62,11 +70,11 @@ def instantiate_process(runner, process, *args, **inputs): class InterruptableFuture(asyncio.Future): """A future that can be interrupted by calling `interrupt`.""" - def interrupt(self, reason): + def interrupt(self, reason: Exception) -> None: """This method should be called to interrupt the coroutine represented by this InterruptableFuture.""" self.set_exception(reason) - async def with_interrupt(self, coro): + async def with_interrupt(self, coro: Awaitable[Any]) -> Any: """ return result of a coroutine which will be interrupted if this future is interrupted :: @@ -91,7 +99,10 @@ async def with_interrupt(self, coro): return result -def interruptable_task(coro, loop=None): +def interruptable_task( + coro: Callable[[InterruptableFuture], Awaitable[Any]], + loop: Optional[asyncio.AbstractEventLoop] = None +) -> InterruptableFuture: """ Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it. 
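A rough sketch of how the interruption mechanism is used from a coroutine::

    import asyncio

    from aiida.engine.utils import InterruptableFuture

    async def wait_unless_interrupted(future: InterruptableFuture):
        # returns the result of the wrapped coroutine, unless `future.interrupt(exc)`
        # is called first, in which case `exc` is raised here instead
        return await future.with_interrupt(asyncio.sleep(10.0))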
@@ -126,7 +137,7 @@ async def execute_coroutine(): return future -def ensure_coroutine(fct): +def ensure_coroutine(fct: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: """ Ensure that the given function ``fct`` is a coroutine @@ -144,7 +155,13 @@ async def wrapper(*args, **kwargs): return wrapper -async def exponential_backoff_retry(fct, initial_interval=10.0, max_attempts=5, logger=None, ignore_exceptions=None): +async def exponential_backoff_retry( + fct: Callable[..., Any], + initial_interval: Union[int, float] = 10.0, + max_attempts: int = 5, + logger: Optional[logging.Logger] = None, + ignore_exceptions=None +) -> Any: """ Coroutine to call a function, recalling it with an exponential backoff in the case of an exception @@ -162,7 +179,7 @@ async def exponential_backoff_retry(fct, initial_interval=10.0, max_attempts=5, if logger is None: logger = LOGGER - result = None + result: Any = None coro = ensure_coroutine(fct) interval = initial_interval @@ -191,7 +208,7 @@ async def exponential_backoff_retry(fct, initial_interval=10.0, max_attempts=5, return result -def is_process_function(function): +def is_process_function(function: Any) -> bool: """Return whether the given function is a process function :param function: a function @@ -203,7 +220,7 @@ def is_process_function(function): return False -def is_process_scoped(): +def is_process_scoped() -> bool: """Return whether the current scope is within a process. :returns: True if the current scope is within a nested process, False otherwise @@ -213,7 +230,7 @@ def is_process_scoped(): @contextlib.contextmanager -def loop_scope(loop): +def loop_scope(loop) -> Iterator[None]: """ Make an event loop current for the scope of the context @@ -229,7 +246,7 @@ def loop_scope(loop): asyncio.set_event_loop(current) -def set_process_state_change_timestamp(process): +def set_process_state_change_timestamp(process: 'Process') -> None: """ Set the global setting that reflects the last time a process changed state, for the process type of the given process, to the current timestamp. The process type will be determined based on @@ -263,7 +280,7 @@ def set_process_state_change_timestamp(process): process.logger.debug(f'could not update the {key} setting because of a UniquenessError: {exception}') -def get_process_state_change_timestamp(process_type=None): +def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Optional[datetime]: """ Get the global setting that reflects the last time a process of the given process type changed its state. The returned value will be the corresponding timestamp or None if the setting does not exist. 
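For example, a transport task might wrap a flaky operation like this (`do_upload` is a placeholder)::

    import logging

    from aiida.engine.utils import exponential_backoff_retry

    async def upload_with_retries():
        return await exponential_backoff_retry(
            do_upload,              # placeholder: any callable or coroutine function
            initial_interval=10.0,  # first retry after 10 seconds, doubling on every further failure
            max_attempts=5,
            logger=logging.getLogger(__name__),
        )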
@@ -288,7 +305,7 @@ def get_process_state_change_timestamp(process_type=None): else: process_types = [process_type] - timestamps = [] + timestamps: List[datetime] = [] for process_type_key in process_types: key = PROCESS_STATE_CHANGE_KEY.format(process_type_key) diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py index 94251cb8c1..1c88029764 100644 --- a/aiida/manage/manager.py +++ b/aiida/manage/manager.py @@ -9,11 +9,23 @@ ########################################################################### # pylint: disable=cyclic-import """AiiDA manager for global settings""" +import asyncio import functools +from typing import Any, Optional, TYPE_CHECKING -__all__ = ('get_manager', 'reset_manager') +if TYPE_CHECKING: + from kiwipy.rmq import RmqThreadCommunicator + from plumpy.process_comms import RemoteProcessThreadController + + from aiida.backends.manager import BackendManager + from aiida.engine.daemon.client import DaemonClient + from aiida.engine.runners import Runner + from aiida.manage.configuration.config import Config + from aiida.manage.configuration.profile import Profile + from aiida.orm.implementation import Backend + from aiida.engine.persistence import AiiDAPersister -MANAGER = None +__all__ = ('get_manager', 'reset_manager') class Manager: @@ -32,34 +44,62 @@ class Manager: * reset manager cache when loading a new profile """ + def __init__(self) -> None: + self._backend: Optional['Backend'] = None + self._backend_manager: Optional['BackendManager'] = None + self._config: Optional['Config'] = None + self._daemon_client: Optional['DaemonClient'] = None + self._profile: Optional['Profile'] = None + self._communicator: Optional['RmqThreadCommunicator'] = None + self._process_controller: Optional['RemoteProcessThreadController'] = None + self._persister: Optional['AiiDAPersister'] = None + self._runner: Optional['Runner'] = None + + def close(self) -> None: + """Reset the global settings entirely and release any global objects.""" + if self._communicator is not None: + self._communicator.close() + if self._runner is not None: + self._runner.stop() + + self._backend = None + self._backend_manager = None + self._config = None + self._profile = None + self._communicator = None + self._daemon_client = None + self._process_controller = None + self._persister = None + self._runner = None + @staticmethod - def get_config(): + def get_config() -> 'Config': """Return the current config. :return: current loaded config instance - :rtype: :class:`~aiida.manage.configuration.config.Config` :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized + """ from .configuration import get_config return get_config() @staticmethod - def get_profile(): + def get_profile() -> Optional['Profile']: """Return the current loaded profile, if any :return: current loaded profile instance - :rtype: :class:`~aiida.manage.configuration.profile.Profile` or None + """ from .configuration import get_profile return get_profile() - def unload_backend(self): + def unload_backend(self) -> None: """Unload the current backend and its corresponding database environment.""" manager = self.get_backend_manager() manager.reset_backend_environment() self._backend = None - def _load_backend(self, schema_check=True): + def _load_backend(self, schema_check: bool = True) -> 'Backend': """Load the backend for the currently configured profile and return it. .. 
note:: this will reconstruct the `Backend` instance in `self._backend` so the preferred method to load the @@ -67,7 +107,7 @@ def _load_backend(self, schema_check=True): :param schema_check: force a database schema check if the database environment has not yet been loaded :return: the database backend - :rtype: :class:`aiida.orm.implementation.Backend` + """ from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA from aiida.common import ConfigurationError, InvalidOperation @@ -87,7 +127,7 @@ def _load_backend(self, schema_check=True): # Do NOT reload the backend environment if already loaded, simply reload the backend instance after if configuration.BACKEND_UUID is None: from aiida.backends import get_backend_manager - backend_manager = get_backend_manager(self.get_profile().database_backend) + backend_manager = get_backend_manager(profile.database_backend) backend_manager.load_backend_environment(profile, validate_schema=schema_check) configuration.BACKEND_UUID = profile.uuid @@ -108,46 +148,52 @@ def _load_backend(self, schema_check=True): return self._backend @property - def backend_loaded(self): + def backend_loaded(self) -> bool: """Return whether a database backend has been loaded. :return: boolean, True if database backend is currently loaded, False otherwise """ return self._backend is not None - def get_backend_manager(self): + def get_backend_manager(self) -> 'BackendManager': """Return the database backend manager. .. note:: this is not the actual backend, but a manager class that is necessary for database operations that go around the actual ORM. For example when the schema version has not yet been validated. :return: the database backend manager - :rtype: :class:`aiida.backend.manager.BackendManager` + """ from aiida.backends import get_backend_manager + from aiida.common import ConfigurationError if self._backend_manager is None: self._load_backend() - self._backend_manager = get_backend_manager(self.get_profile().database_backend) + profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' + ) + self._backend_manager = get_backend_manager(profile.database_backend) return self._backend_manager - def get_backend(self): + def get_backend(self) -> 'Backend': """Return the database backend :return: the database backend - :rtype: :class:`aiida.orm.implementation.Backend` + """ if self._backend is None: self._load_backend() return self._backend - def get_persister(self): + def get_persister(self) -> 'AiiDAPersister': """Return the persister :return: the current persister instance - :rtype: :class:`plumpy.Persister` + """ from aiida.engine import persistence @@ -156,18 +202,20 @@ def get_persister(self): return self._persister - def get_communicator(self): + def get_communicator(self) -> 'RmqThreadCommunicator': """Return the communicator :return: a global communicator instance - :rtype: :class:`kiwipy.Communicator` + """ if self._communicator is None: self._communicator = self.create_communicator() return self._communicator - def create_communicator(self, task_prefetch_count=None, with_orm=True): + def create_communicator( + self, task_prefetch_count: Optional[int] = None, with_orm: bool = True + ) -> 'RmqThreadCommunicator': """Create a Communicator. 
:param task_prefetch_count: optional specify how many tasks this communicator take simultaneously @@ -175,12 +223,17 @@ def create_communicator(self, task_prefetch_count=None, with_orm=True): This is used by verdi status to get a communicator without needing to load the dbenv. :return: the communicator instance - :rtype: :class:`~kiwipy.rmq.communicator.RmqThreadCommunicator` + """ + from aiida.common import ConfigurationError from aiida.manage.external import rmq import kiwipy.rmq profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' + ) if task_prefetch_count is None: task_prefetch_count = self.get_config().get_option('daemon.worker_process_slots', profile.name) @@ -210,11 +263,11 @@ def create_communicator(self, task_prefetch_count=None, with_orm=True): testing_mode=profile.is_test_profile, ) - def get_daemon_client(self): + def get_daemon_client(self) -> 'DaemonClient': """Return the daemon client for the current profile. :return: the daemon client - :rtype: :class:`aiida.daemon.client.DaemonClient` + :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found :raises aiida.common.ProfileConfigurationError: if the given profile does not exist """ @@ -225,52 +278,57 @@ def get_daemon_client(self): return self._daemon_client - def get_process_controller(self): + def get_process_controller(self) -> 'RemoteProcessThreadController': """Return the process controller :return: the process controller instance - :rtype: :class:`plumpy.RemoteProcessThreadController` + """ - import plumpy + from plumpy.process_comms import RemoteProcessThreadController if self._process_controller is None: - self._process_controller = plumpy.RemoteProcessThreadController(self.get_communicator()) + self._process_controller = RemoteProcessThreadController(self.get_communicator()) return self._process_controller - def get_runner(self, **kwargs): + def get_runner(self, **kwargs) -> 'Runner': """Return a runner that is based on the current profile settings and can be used globally by the code. :return: the global runner - :rtype: :class:`aiida.engine.runners.Runner` + """ if self._runner is None: self._runner = self.create_runner(**kwargs) return self._runner - def set_runner(self, new_runner): + def set_runner(self, new_runner: 'Runner') -> None: """Set the currently used runner :param new_runner: the new runner to use - :type new_runner: :class:`aiida.engine.runners.Runner` + """ if self._runner is not None: self._runner.close() self._runner = new_runner - def create_runner(self, with_persistence=True, **kwargs): + def create_runner(self, with_persistence: bool = True, **kwargs: Any) -> 'Runner': """Create and return a new runner :param with_persistence: create a runner with persistence enabled - :type with_persistence: bool + :return: a new runner instance - :rtype: :class:`aiida.engine.runners.Runner` + """ + from aiida.common import ConfigurationError from aiida.engine import runners config = self.get_config() profile = self.get_profile() + if profile is None: + raise ConfigurationError( + 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' 
+ ) poll_interval = 0.0 if profile.is_test_profile else config.get_option('runner.poll.interval', profile.name) settings = {'rmq_submit': False, 'poll_interval': poll_interval} @@ -285,17 +343,17 @@ def create_runner(self, with_persistence=True, **kwargs): return runners.Runner(**settings) - def create_daemon_runner(self, loop=None): + def create_daemon_runner(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Runner': """Create and return a new daemon runner. This is used by workers when the daemon is running and in testing. :param loop: the (optional) asyncio event loop to use - :type loop: the asyncio event loop + :return: a runner configured to work in the daemon configuration - :rtype: :class:`aiida.engine.runners.Runner` + """ - import plumpy + from plumpy.persistence import LoadSaveContext from aiida.engine import persistence from aiida.manage.external import rmq @@ -306,52 +364,27 @@ def create_daemon_runner(self, loop=None): task_receiver = rmq.ProcessLauncher( loop=runner_loop, persister=self.get_persister(), - load_context=plumpy.LoadSaveContext(runner=runner), + load_context=LoadSaveContext(runner=runner), loader=persistence.get_object_loader() ) + assert runner.communicator is not None, 'communicator not set for runner' runner.communicator.add_task_subscriber(task_receiver) return runner - def close(self): - """Reset the global settings entirely and release any global objects.""" - if self._communicator is not None: - self._communicator.close() - if self._runner is not None: - self._runner.stop() - - self._backend = None - self._backend_manager = None - self._config = None - self._profile = None - self._communicator = None - self._daemon_client = None - self._process_controller = None - self._persister = None - self._runner = None - def __init__(self): - super().__init__() - self._backend = None # type: aiida.orm.implementation.Backend - self._backend_manager = None # type: aiida.backend.manager.BackendManager - self._config = None # type: aiida.manage.configuration.config.Config - self._daemon_client = None # type: aiida.daemon.client.DaemonClient - self._profile = None # type: aiida.manage.configuration.profile.Profile - self._communicator = None # type: kiwipy.rmq.RmqThreadCommunicator - self._process_controller = None # type: plumpy.RemoteProcessThreadController - self._persister = None # type: aiida.engine.persistence.AiiDAPersister - self._runner = None # type: aiida.engine.runners.Runner +MANAGER: Optional[Manager] = None -def get_manager(): +def get_manager() -> Manager: global MANAGER # pylint: disable=global-statement if MANAGER is None: MANAGER = Manager() return MANAGER -def reset_manager(): +def reset_manager() -> None: global MANAGER # pylint: disable=global-statement if MANAGER is not None: MANAGER.close() diff --git a/aiida/sphinxext/process.py b/aiida/sphinxext/process.py index 4c80cefbeb..077b49af1c 100644 --- a/aiida/sphinxext/process.py +++ b/aiida/sphinxext/process.py @@ -117,9 +117,12 @@ def build_content(self): content += self.build_doctree(title='Outputs:', port_namespace=self.process_spec.outputs) if hasattr(self.process_spec, 'get_outline'): - outline = self.process_spec.get_outline() - if outline is not None: - content += self.build_outline_doctree(outline=outline) + try: + outline = self.process_spec.get_outline() + if outline is not None: + content += self.build_outline_doctree(outline=outline) + except AssertionError: + pass return content def build_doctree(self, title, port_namespace): diff --git a/docs/source/nitpick-exceptions 
b/docs/source/nitpick-exceptions index 9b290f994d..ecfdb106bb 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -19,20 +19,45 @@ py:class builtins.str py:class builtins.dict # typing -py:class traceback +py:class asyncio.events.AbstractEventLoop +py:class EntityType +py:class function py:class IO +py:class traceback ### AiiDA # issues with order of object processing and type hinting -py:class WorkChainSpec +py:class aiida.engine.runners.ResultAndNode +py:class aiida.engine.runners.ResultAndPk +py:class aiida.engine.processes.workchains.workchain.WorkChainSpec +py:class aiida.manage.manager.Manager +py:class aiida.orm.utils.links.LinkQuadruple py:class aiida.tools.importexport.dbexport.ExportReport py:class aiida.tools.importexport.dbexport.ArchiveData - -py:class EntityType py:class aiida.tools.groups.paths.WalkNodeResult -py:class aiida.orm.utils.links.LinkQuadruple +py:class Node +py:class ProcessSpec +py:class CalcJobNode +py:class ExitCode +py:class Process +py:class AuthInfo +py:class ProcessNode +py:class PortNamespace +py:class Runner +py:class TransportQueue +py:class PersistenceError +py:class Port +py:class Data +py:class JobInfo +py:class CalcJob +py:class WorkChainSpec + +py:class kiwipy.communications.Communicator +py:class plumpy.process_states.State +py:class plumpy.workchains._If +py:class plumpy.workchains._While ### python packages # Note: These exceptions are needed if diff --git a/environment.yml b/environment.yml index 6701fa52ce..6fea17cf94 100644 --- a/environment.yml +++ b/environment.yml @@ -25,7 +25,7 @@ dependencies: - numpy~=1.17 - pamqp~=2.3 - paramiko~=2.7 -- plumpy~=0.18.1 +- plumpy~=0.18.4 - pgsu~=0.1.0 - psutil~=5.6 - psycopg2>=2.8.3,~=2.8 diff --git a/mypy.ini b/mypy.ini index 481b14c29f..073863bcca 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,6 +2,8 @@ [mypy] +show_error_codes = True + check_untyped_defs = True scripts_are_modules = True warn_unused_ignores = True @@ -37,13 +39,16 @@ follow_imports = skip [mypy-tests.*] check_untyped_defs = False +[mypy-circus.*] +ignore_missing_imports = True + [mypy-django.*] ignore_missing_imports = True -[mypy-numpy.*] +[mypy-kiwipy.*] ignore_missing_imports = True -[mypy-plumpy.*] +[mypy-numpy.*] ignore_missing_imports = True [mypy-scipy.*] @@ -54,3 +59,6 @@ ignore_missing_imports = True [mypy-tqdm.*] ignore_missing_imports = True + +[mypy-wrapt.*] +ignore_missing_imports = True diff --git a/pyproject.toml b/pyproject.toml index a2a957cce8..97975b5c3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,24 +75,19 @@ envlist = py37-django [testenv] usedevelop=True - -[testenv:py{36,37,38,39}-{django,sqla}] deps = py36: -rrequirements/requirements-py-3.6.txt py37: -rrequirements/requirements-py-3.7.txt py38: -rrequirements/requirements-py-3.8.txt py39: -rrequirements/requirements-py-3.9.txt + +[testenv:py{36,37,38,39}-{django,sqla}] setenv = django: AIIDA_TEST_BACKEND = django sqla: AIIDA_TEST_BACKEND = sqlalchemy commands = pytest {posargs} [testenv:py{36,37,38,39}-verdi] -deps = - py36: -rrequirements/requirements-py-3.6.txt - py37: -rrequirements/requirements-py-3.7.txt - py38: -rrequirements/requirements-py-3.8.txt - py39: -rrequirements/requirements-py-3.9.txt setenv = AIIDA_TEST_BACKEND = django commands = verdi {posargs} @@ -101,11 +96,6 @@ commands = verdi {posargs} description = clean: Build the documentation (remove any existing build) update: Build the documentation (modify any existing build) -deps = - py36: -rrequirements/requirements-py-3.6.txt - py37: 
-rrequirements/requirements-py-3.7.txt - py38: -rrequirements/requirements-py-3.8.txt - py39: -rrequirements/requirements-py-3.9.txt passenv = RUN_APIDOC setenv = update: RUN_APIDOC = False @@ -134,6 +124,6 @@ commands = [testenv:py{36,37,38,39}-pre-commit] description = Run the pre-commit checks -extras = all +extras = pre-commit commands = pre-commit run {posargs} """ diff --git a/requirements/requirements-py-3.6.txt b/requirements/requirements-py-3.6.txt index cb5064c3e2..f1fc33ef04 100644 --- a/requirements/requirements-py-3.6.txt +++ b/requirements/requirements-py-3.6.txt @@ -86,7 +86,7 @@ pgsu==0.1.0 pgtest==1.3.2 pickleshare==0.7.5 pluggy==0.13.1 -plumpy==0.18.1 +plumpy==0.18.4 prometheus-client==0.7.1 prompt-toolkit==3.0.4 psutil==5.7.0 diff --git a/requirements/requirements-py-3.7.txt b/requirements/requirements-py-3.7.txt index 68ca66e169..703102b09d 100644 --- a/requirements/requirements-py-3.7.txt +++ b/requirements/requirements-py-3.7.txt @@ -85,7 +85,7 @@ pgsu==0.1.0 pgtest==1.3.2 pickleshare==0.7.5 pluggy==0.13.1 -plumpy==0.18.1 +plumpy==0.18.4 prometheus-client==0.7.1 prompt-toolkit==3.0.4 psutil==5.7.0 diff --git a/requirements/requirements-py-3.8.txt b/requirements/requirements-py-3.8.txt index 0bb77ebddc..c665a43992 100644 --- a/requirements/requirements-py-3.8.txt +++ b/requirements/requirements-py-3.8.txt @@ -80,7 +80,7 @@ pgsu==0.1.0 pgtest==1.3.2 pickleshare==0.7.5 pluggy==0.13.1 -plumpy==0.18.1 +plumpy==0.18.4 prometheus-client==0.7.1 prompt-toolkit==3.0.4 psutil==5.7.0 diff --git a/requirements/requirements-py-3.9.txt b/requirements/requirements-py-3.9.txt index 32564e4c39..d8eec0286d 100644 --- a/requirements/requirements-py-3.9.txt +++ b/requirements/requirements-py-3.9.txt @@ -81,7 +81,7 @@ pika==1.1.0 Pillow==8.0.1 plotly==4.12.0 pluggy==0.13.1 -plumpy==0.18.1 +plumpy==0.18.4 prometheus-client==0.8.0 prompt-toolkit==3.0.8 psutil==5.7.3 diff --git a/setup.json b/setup.json index ff7303dc57..5df80dc5a1 100644 --- a/setup.json +++ b/setup.json @@ -40,7 +40,7 @@ "numpy~=1.17", "pamqp~=2.3", "paramiko~=2.7", - "plumpy~=0.18.1", + "plumpy~=0.18.4", "pgsu~=0.1.0", "psutil~=5.6", "psycopg2-binary~=2.8,>=2.8.3", diff --git a/tests/sphinxext/reference_results/workchain.xml b/tests/sphinxext/reference_results/workchain.xml index e7fae40799..8d487bb2f8 100644 --- a/tests/sphinxext/reference_results/workchain.xml +++ b/tests/sphinxext/reference_results/workchain.xml @@ -1,7 +1,7 @@ - +
sphinx-aiida demo This is a demo documentation to show off the features of the sphinx-aiida extension. @@ -74,16 +74,32 @@ finalize This module defines an example workchain for the aiida-workchain documentation directive. - class demo_workchain.DemoWorkChain*args**kwargs + class demo_workchain.DemoWorkChain*args: Any**kwargs: Any A demo workchain to show how the workchain auto-documentation works. + + + classmethod definespec + + Define the specification of the process, including its inputs, outputs and known exit codes. + A metadata input namespace is defined, with optional ports that are not stored in the database. + + - class demo_workchain.EmptyOutlineWorkChain*args**kwargs + class demo_workchain.EmptyOutlineWorkChain*args: Any**kwargs: Any Here we check that the directive works even if the outline is empty. + + + classmethod definespec + + Define the specification of the process, including its inputs, outputs and known exit codes. + A metadata input namespace is defined, with optional ports that are not stored in the database. + +
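
A short sketch of how the reworked aiida.manage.manager module is typically used after this refactoring; the sequence below assumes an already configured profile and is illustrative only:

    from aiida import load_profile
    from aiida.manage.manager import get_manager, reset_manager

    load_profile()                        # without a loaded profile the manager raises ConfigurationError
    manager = get_manager()               # lazily instantiates the module-level MANAGER singleton
    runner = manager.get_runner()         # runner, communicator, persister, ... are created on demand and cached
    daemon_client = manager.get_daemon_client()
    reset_manager()                       # closes the communicator and stops the runner via Manager.close()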