Skip to content

Commit

Permalink
🔧 Add types for DefaultFieldsAttributeDict subclasses (#6059)
Browse files Browse the repository at this point in the history
Improves type checking and LSP completion.
  • Loading branch information
chrisjsewell authored Jun 17, 2023
1 parent 3976344 commit afed5dc
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 2 deletions.
38 changes: 38 additions & 0 deletions aiida/common/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module to define commonly used data structures."""
from __future__ import annotations

from enum import Enum, IntEnum
from typing import TYPE_CHECKING

from .extendeddicts import DefaultFieldsAttributeDict

Expand Down Expand Up @@ -93,6 +96,31 @@ class CalcInfo(DefaultFieldsAttributeDict):
'provenance_exclude_list', 'codes_info', 'codes_run_mode', 'skip_submit'
)

if TYPE_CHECKING:

job_environment: None | dict[str, str]
email: None | str
email_on_started: bool
email_on_terminated: bool
uuid: None | str
prepend_text: None | str
append_text: None | str
num_machines: None | int
num_mpiprocs_per_machine: None | int
priority: None | int
max_wallclock_seconds: None | int
max_memory_kb: None | int
rerunnable: bool
retrieve_list: None | list[str | tuple[str, str, str]]
retrieve_temporary_list: None | list[str | tuple[str, str, str]]
local_copy_list: None | list[tuple[str, str, str]]
remote_copy_list: None | list[tuple[str, str, str]]
remote_symlink_list: None | list[tuple[str, str, str]]
provenance_exclude_list: None | list[str]
codes_info: None | list[CodeInfo]
codes_run_mode: None | CodeRunMode
skip_submit: None | bool


class CodeInfo(DefaultFieldsAttributeDict):
"""
Expand Down Expand Up @@ -148,6 +176,16 @@ class CodeInfo(DefaultFieldsAttributeDict):
'code_uuid'
)

if TYPE_CHECKING:

cmdline_params: None | list[str]
stdin_name: None | str
stdout_name: None | str
stderr_name: None | str
join_files: None | bool
withmpi: None | bool
code_uuid: None | str


class CodeRunMode(IntEnum):
"""Enum to indicate the way the codes of a calculation should be run.
Expand Down
2 changes: 1 addition & 1 deletion aiida/engine/processes/calcjobs/calcjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def presubmit(self, folder: Folder) -> CalcInfo:
tmpl_code_info.stdin_name = code_info.stdin_name
tmpl_code_info.stdout_name = code_info.stdout_name
tmpl_code_info.stderr_name = code_info.stderr_name
tmpl_code_info.join_files = code_info.join_files
tmpl_code_info.join_files = code_info.join_files or False

tmpl_codes_info.append(tmpl_code_info)
job_tmpl.codes_info = tmpl_codes_info
Expand Down
64 changes: 63 additions & 1 deletion aiida/schedulers/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@
from datetime import datetime, timezone
import enum
import json
from typing import TYPE_CHECKING

from aiida.common import AIIDA_LOGGER
from aiida.common import AIIDA_LOGGER, CodeRunMode
from aiida.common.extendeddicts import AttributeDict, DefaultFieldsAttributeDict
from aiida.common.timezone import make_aware, timezone_from_name

Expand Down Expand Up @@ -108,6 +109,12 @@ class NodeNumberJobResource(JobResource):
'num_cores_per_mpiproc',
)

if TYPE_CHECKING:
num_machines: int
num_mpiprocs_per_machine: int
num_cores_per_machine: int
num_cores_per_mpiproc: int

@classmethod
def validate_resources(cls, **kwargs):
"""Validate the resources against the job resource class of this scheduler.
Expand Down Expand Up @@ -193,6 +200,10 @@ class ParEnvJobResource(JobResource):
'tot_num_mpiprocs',
)

if TYPE_CHECKING:
parallel_env: str
tot_num_mpiprocs: int

@classmethod
def validate_resources(cls, **kwargs):
"""Validate the resources against the job resource class of this scheduler.
Expand Down Expand Up @@ -366,6 +377,34 @@ class JobTemplate(DefaultFieldsAttributeDict): # pylint: disable=too-many-insta
'codes_info',
)

if TYPE_CHECKING:
shebang: str
submit_as_hold: bool
rerunnable: bool
job_environment: dict[str, str] | None
environment_variables_double_quotes: bool | None
working_directory: str
email: str
email_on_started: bool
email_on_terminated: bool
job_name: str
sched_output_path: str
sched_error_path: str
sched_join_files: bool
queue_name: str
account: str
qos: str
job_resource: JobResource
priority: str
max_memory_kb: int | None
max_wallclock_seconds: int
custom_scheduler_commands: str
prepend_text: str
append_text: str
import_sys_environment: bool | None
codes_run_mode: CodeRunMode
codes_info: list[JobTemplateCodeInfo]


@dataclass
class JobTemplateCodeInfo:
Expand Down Expand Up @@ -474,6 +513,29 @@ class JobInfo(DefaultFieldsAttributeDict): # pylint: disable=too-many-instance-
'finish_time'
)

if TYPE_CHECKING:
job_id: str
title: str
exit_status: int
terminating_signal: int
annotation: str
job_state: JobState
job_substate: str
allocated_machines: list[MachineInfo]
job_owner: str
num_mpiprocs: int
num_cpus: int
num_machines: int
queue_name: str
account: str
qos: str
wallclock_time_seconds: int
requested_wallclock_time_seconds: int
cpu_time: int
submission_time: datetime
dispatch_time: datetime
finish_time: datetime

# If some fields require special serializers, specify them here.
# You then need to define also the respective _serialize_FIELDTYPE and
# _deserialize_FIELDTYPE methods
Expand Down

0 comments on commit afed5dc

Please sign in to comment.