diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile deleted file mode 100644 index bde97605f6..0000000000 --- a/.devcontainer/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -# Additions for dev container -FROM aiidateam/aiida-core:main - -# Add test dependencies (not installed in image) -RUN pip install ./aiida-core[tests,rest,docs,pre-commit] -# the `locate` command is needed by many tests -RUN apt-get update \ - && apt-get install -y mlocate \ - && rm -rf /var/lib/apt/lists/* - -# add aiida user -RUN /etc/my_init.d/10_create-system-user.sh - -# copy updated aiida configuration script -# this line can be deleted after the new script has been merged -COPY ../.docker/opt/configure-aiida.sh /opt/configure-aiida.sh diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 5bbcf39152..4dfd532583 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,9 +1,16 @@ { "dockerComposeFile": "docker-compose.yml", - "service": "aiida", - "workspaceFolder": "/home/aiida/aiida-core", - "postCreateCommand": "bash ./.devcontainer/post_create.sh", - "waitFor": "postCreateCommand", + "service": "daemon", + "workspaceFolder": "/workspaces/aiida-core", + "postCreateCommand": "/etc/init/aiida-prepare.sh", + "postStartCommand": "pip install -e /workspaces/aiida-core[tests,docs,rest,atomic_tools,pre-commit]", + "postAttachCommand": "verdi daemon start", + "waitFor": "postStartCommand", + "containerUser": "aiida", + "remoteUser": "aiida", + "remoteEnv": { + "HOME": "/home/aiida" + }, "customizations": { "vscode": { "extensions": ["ms-python.python", "eamodio.gitlens"] diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 380a0a4760..0c6a2afc67 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -2,53 +2,36 @@ version: '3.4' services: - rabbitmq: - image: rabbitmq:3.8.3-management + database: + image: postgres:15 environment: - RABBITMQ_DEFAULT_USER: guest - RABBITMQ_DEFAULT_PASS: guest - ports: - - '5672:5672' - - '15672:15672' + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + POSTGRES_HOST_AUTH_METHOD: trust + healthcheck: + test: [CMD-SHELL, pg_isready] + interval: 5s + timeout: 5s + retries: 10 + messaging: + image: rabbitmq:3.8.14-management + environment: + RABBITMQ_DEFAULT_USER: guest + RABBITMQ_DEFAULT_PASS: guest healthcheck: - test: rabbitmq-diagnostics -q ping + test: rabbitmq-diagnostics check_port_connectivity interval: 30s timeout: 30s - retries: 5 - networks: - - aiida - - postgres: - image: postgres:12 - ports: - - '5432:5432' - networks: - - aiida - environment: - POSTGRES_HOST_AUTH_METHOD: trust + retries: 10 - aiida: - #image: "aiidateam/aiida-core:main" - image: "aiida-core-dev" - build: - # need to add the parent directory to context to copy over new configure-aiida.sh - context: .. 
- dockerfile: .devcontainer/Dockerfile + daemon: + image: ghcr.io/aiidateam/aiida-core-base:edge user: aiida - environment: - DB_HOST: postgres - BROKER_HOST: rabbitmq - - # no need for /sbin/my_init entrypoint: tail -f /dev/null - volumes: - - ..:/home/aiida/aiida-core:cached - networks: - - aiida + environment: + SETUP_DEFAULT_AIIDA_PROFILE: 'true' + TZ: Europe/Zurich depends_on: - - rabbitmq - - postgres - -networks: - aiida: + database: + condition: service_healthy diff --git a/.devcontainer/post_create.sh b/.devcontainer/post_create.sh deleted file mode 100644 index 71f8330853..0000000000 --- a/.devcontainer/post_create.sh +++ /dev/null @@ -1,4 +0,0 @@ - #!/bin/bash - -# configure aiida -/opt/configure-aiida.sh diff --git a/.docker/README.md b/.docker/README.md new file mode 100644 index 0000000000..e69fdf82a8 --- /dev/null +++ b/.docker/README.md @@ -0,0 +1,30 @@ +# AiiDA Docker stacks + +### Build images locally + +To build the images, run the following command: (tested with _docker buildx_ version v0.8.2) + +```bash +docker buildx bake -f docker-bake.hcl -f build.json --load +``` + +The build system will attempt to detect the local architecture and automatically build images for it (tested with amd64 and arm64). + +You can also specify a custom platform with `--platform`, for example: + +```bash +docker buildx bake -f docker-bake.hcl -f build.json --set *.platform=linux/amd64 --load +``` + +### Test built images locally + +To test the images, run + +```bash +TAG=newly-baked python -m pytest -s tests +``` + +### Trigger a build on ghcr.io and Dockerhub + +- Only an open PR to the organization's repository will trigger a build on ghcr.io. +- A push to dockerhub is triggered when making a release on github. diff --git a/.docker/aiida-core-base/Dockerfile b/.docker/aiida-core-base/Dockerfile new file mode 100644 index 0000000000..591070709c --- /dev/null +++ b/.docker/aiida-core-base/Dockerfile @@ -0,0 +1,178 @@ +# syntax=docker/dockerfile:1 + +# Inspired by jupyter's docker-stacks-fundation image: +# https://github.com/jupyter/docker-stacks/tree/main/images/docker-stacks-foundation/Dockerfile + +ARG BASE=ubuntu:22.04 + +FROM $BASE + +LABEL maintainer="AiiDA Team " + +ARG SYSTEM_USER="aiida" +ARG SYSTEM_UID="1000" +ARG SYSTEM_GID="100" + + +# Fix: https://github.com/hadolint/hadolint/wiki/DL4006 +# Fix: https://github.com/koalaman/shellcheck/wiki/SC3014 +SHELL ["/bin/bash", "-o", "pipefail", "-c"] + +USER root + +ENV SYSTEM_USER="${SYSTEM_USER}" + +# Install all OS dependencies for notebook server that starts but lacks all +# features (e.g., download as all possible file formats) +ENV DEBIAN_FRONTEND=noninteractive +RUN apt-get update --yes && \ + # - apt-get upgrade is run to patch known vulnerabilities in apt-get packages as + # the ubuntu base image is rebuilt too seldom sometimes (less than once a month) + apt-get upgrade --yes && \ + apt-get install --yes --no-install-recommends \ + # - bzip2 is necessary to extract the micromamba executable. + bzip2 \ + # - xz-utils is necessary to extract the s6-overlay. + xz-utils \ + ca-certificates \ + locales \ + sudo \ + # development tools + git \ + openssh-client \ + rsync \ + graphviz \ + vim \ + # the gcc compiler need to build some python packages e.g. 
psutil and pymatgen + build-essential \ + wget && \ + apt-get clean && rm -rf /var/lib/apt/lists/* && \ + echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && \ + locale-gen + +# Install s6-overlay to handle startup and shutdown of services +ARG S6_OVERLAY_VERSION=3.1.5.0 +RUN wget --progress=dot:giga -O /tmp/s6-overlay-noarch.tar.xz \ + "https://github.com/just-containers/s6-overlay/releases/download/v${S6_OVERLAY_VERSION}/s6-overlay-noarch.tar.xz" && \ + tar -C / -Jxpf /tmp/s6-overlay-noarch.tar.xz && \ + rm /tmp/s6-overlay-noarch.tar.xz + +RUN set -x && \ + arch=$(uname -m) && \ + wget --progress=dot:giga -O /tmp/s6-overlay-binary.tar.xz \ + "https://github.com/just-containers/s6-overlay/releases/download/v${S6_OVERLAY_VERSION}/s6-overlay-${arch}.tar.xz" && \ + tar -C / -Jxpf /tmp/s6-overlay-binary.tar.xz && \ + rm /tmp/s6-overlay-binary.tar.xz + +# Configure environment +ENV CONDA_DIR=/opt/conda \ + SHELL=/bin/bash \ + SYSTEM_USER="${SYSTEM_USER}" \ + SYSTEM_UID=${SYSTEM_UID} \ + SYSTEM_GID=${SYSTEM_GID} \ + LC_ALL=en_US.UTF-8 \ + LANG=en_US.UTF-8 \ + LANGUAGE=en_US.UTF-8 +ENV PATH="${CONDA_DIR}/bin:${PATH}" \ + HOME="/home/${SYSTEM_USER}" + + +# Copy a script that we will use to correct permissions after running certain commands +COPY fix-permissions /usr/local/bin/fix-permissions +RUN chmod a+rx /usr/local/bin/fix-permissions + +# Enable prompt color in the skeleton .bashrc before creating the default SYSTEM_USER +# hadolint ignore=SC2016 +RUN sed -i 's/^#force_color_prompt=yes/force_color_prompt=yes/' /etc/skel/.bashrc && \ + # Add call to conda init script see https://stackoverflow.com/a/58081608/4413446 + echo 'eval "$(command conda shell.bash hook 2> /dev/null)"' >> /etc/skel/.bashrc + +# Create $SYSTEM_USER user with UID=1000 and 'users' group +# and make sure these dirs are writable by the `users` group. +RUN echo "auth requisite pam_deny.so" >> /etc/pam.d/su && \ + sed -i.bak -e 's/^%admin/#%admin/' /etc/sudoers && \ + sed -i.bak -e 's/^%sudo/#%sudo/' /etc/sudoers && \ + useradd -l -m -s /bin/bash -N -u "${SYSTEM_UID}" "${SYSTEM_USER}" && \ + mkdir -p "${CONDA_DIR}" && \ + chown "${SYSTEM_USER}:${SYSTEM_GID}" "${CONDA_DIR}" && \ + chmod g+w /etc/passwd && \ + fix-permissions "${HOME}" && \ + fix-permissions "${CONDA_DIR}" + +USER ${SYSTEM_UID} + +# Pin python version here +ARG PYTHON_VERSION + +# Download and install Micromamba, and initialize Conda prefix. 
+# +# Similar projects using Micromamba: +# - Micromamba-Docker: +# - repo2docker: +# Install Python, Mamba +# Cleanup temporary files and remove Micromamba +# Correct permissions +# Do all this in a single RUN command to avoid duplicating all of the +# files across image layers when the permissions change +COPY --chown="${SYSTEM_UID}:${SYSTEM_GID}" initial-condarc "${CONDA_DIR}/.condarc" +WORKDIR /tmp +RUN set -x && \ + arch=$(uname -m) && \ + if [ "${arch}" = "x86_64" ]; then \ + # Should be simpler, see + arch="64"; \ + fi && \ + wget --progress=dot:giga -O /tmp/micromamba.tar.bz2 \ + "https://micromamba.snakepit.net/api/micromamba/linux-${arch}/latest" && \ + tar -xvjf /tmp/micromamba.tar.bz2 --strip-components=1 bin/micromamba && \ + rm /tmp/micromamba.tar.bz2 && \ + PYTHON_SPECIFIER="python=${PYTHON_VERSION}" && \ + if [[ "${PYTHON_VERSION}" == "default" ]]; then PYTHON_SPECIFIER="python"; fi && \ + # Install the packages + ./micromamba install \ + --root-prefix="${CONDA_DIR}" \ + --prefix="${CONDA_DIR}" \ + --yes \ + "${PYTHON_SPECIFIER}" \ + mamba && \ + rm micromamba && \ + # Pin major.minor version of python + mamba list python | grep -oP 'python\s+\K[\d.]+' | tr -s ' ' | cut -d ' ' -f 1,2 >> "${CONDA_DIR}/conda-meta/pinned" && \ + mamba clean --all -f -y && \ + fix-permissions "${CONDA_DIR}" && \ + fix-permissions "/home/${SYSTEM_USER}" + +# Add ~/.local/bin to PATH where the dependencies get installed via pip +# This require the package installed with `--user` flag in pip, which we set as default. +ENV PATH=${PATH}:/home/${SYSTEM_USER}/.local/bin +ENV PIP_USER=1 + +# Switch to root to install AiiDA and set AiiDA as service +# Install AiiDA from source code +USER root +COPY --from=src . /tmp/aiida-core +RUN pip install /tmp/aiida-core --no-cache-dir && \ + rm -rf /tmp/aiida-core + +# Enable verdi autocompletion. +RUN mkdir -p "${CONDA_DIR}/etc/conda/activate.d" && \ + echo 'eval "$(_VERDI_COMPLETE=bash_source verdi)"' >> "${CONDA_DIR}/etc/conda/activate.d/activate_aiida_autocompletion.sh" && \ + chmod +x "${CONDA_DIR}/etc/conda/activate.d/activate_aiida_autocompletion.sh" && \ + fix-permissions "${CONDA_DIR}" + +# COPY AiiDA profile configuration for profile setup init script +COPY s6-assets/s6-rc.d /etc/s6-overlay/s6-rc.d +COPY s6-assets/init /etc/init +RUN mkdir /etc/init/run-before-daemon-start && \ + mkdir /etc/init/run-after-daemon-start + +# Otherwise will stuck on oneshot services +# https://github.com/just-containers/s6-overlay/issues/467 +ENV S6_CMD_WAIT_FOR_SERVICES_MAXTIME=0 + +# Switch back to USER aiida to avoid accidental container runs as root +USER ${SYSTEM_UID} + +ENTRYPOINT ["/init"] + +WORKDIR "${HOME}" diff --git a/.docker/aiida-core-base/fix-permissions b/.docker/aiida-core-base/fix-permissions new file mode 100644 index 0000000000..840173c605 --- /dev/null +++ b/.docker/aiida-core-base/fix-permissions @@ -0,0 +1,35 @@ +#!/bin/bash +# This is brought from jupyter docker-stacks: +# https://github.com/jupyter/docker-stacks/blob/main/docker-stacks-foundation/fix-permissions +# set permissions on a directory +# after any installation, if a directory needs to be (human) user-writable, +# run this script on it. +# It will make everything in the directory owned by the group ${SYSTEM_GID} +# and writable by that group. 
+ +# uses find to avoid touching files that already have the right permissions, +# which would cause massive image explosion + +# right permissions are: +# group=${SYSEM_GID} +# AND permissions include group rwX (directory-execute) +# AND directories have setuid,setgid bits set + +set -e + +for d in "$@"; do + find "${d}" \ + ! \( \ + -group "${SYSTEM_GID}" \ + -a -perm -g+rwX \ + \) \ + -exec chgrp "${SYSTEM_GID}" -- {} \+ \ + -exec chmod g+rwX -- {} \+ + # setuid, setgid *on directories only* + find "${d}" \ + \( \ + -type d \ + -a ! -perm -6000 \ + \) \ + -exec chmod +6000 -- {} \+ +done diff --git a/.docker/aiida-core-base/initial-condarc b/.docker/aiida-core-base/initial-condarc new file mode 100644 index 0000000000..383aad3cb0 --- /dev/null +++ b/.docker/aiida-core-base/initial-condarc @@ -0,0 +1,6 @@ +# Conda configuration see https://conda.io/projects/conda/en/latest/configuration.html + +auto_update_conda: false +show_channel_urls: true +channels: + - conda-forge diff --git a/.docker/aiida-core-base/s6-assets/init/aiida-daemon-start.sh b/.docker/aiida-core-base/s6-assets/init/aiida-daemon-start.sh new file mode 100755 index 0000000000..ecda5bba9d --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/init/aiida-daemon-start.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +verdi profile show +if [ $? == 0 ]; then + # Start the daemon + verdi daemon start +else + echo "The default profile is not set." +fi diff --git a/.docker/aiida-core-base/s6-assets/init/aiida-daemon-stop.sh b/.docker/aiida-core-base/s6-assets/init/aiida-daemon-stop.sh new file mode 100755 index 0000000000..57ee543e11 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/init/aiida-daemon-stop.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +verdi profile show +if [ $? == 0 ]; then + # Stop the daemon + verdi daemon stop +fi diff --git a/.docker/aiida-core-base/s6-assets/init/aiida-prepare.sh b/.docker/aiida-core-base/s6-assets/init/aiida-prepare.sh new file mode 100755 index 0000000000..a9e54142fa --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/init/aiida-prepare.sh @@ -0,0 +1,66 @@ +#!/bin/bash + +# This script is executed whenever the docker container is (re)started. + +# Environment. +export SHELL=/bin/bash + +# Supress rabbitmq version warning +# If it is built using RMQ version > 3.8.15 (as we did for the `aiida-core-with-services` image) which has the issue as described in +# https://github.com/aiidateam/aiida-core/wiki/RabbitMQ-version-to-use +# We explicitly set consumer_timeout to disabled in /etc/rabbitmq/rabbitmq.conf +verdi config set warnings.rabbitmq_version False + +# Supress verdi version warning because we are using a development version +verdi config set warnings.development_version False + +# Check if user requested to set up AiiDA profile (and if it exists already) +# If the environment variable `SETUP_DEFAULT_AIIDA_PROFILE` is not set, set it to `true`. +if [[ ${SETUP_DEFAULT_AIIDA_PROFILE:-true} == true ]] && ! verdi profile show ${AIIDA_PROFILE_NAME} &> /dev/null; then + + # Create AiiDA profile. + verdi presto \ + --verbosity info \ + --profile-name "${AIIDA_PROFILE_NAME:-default}" \ + --email "${AIIDA_USER_EMAIL:-aiida@localhost}" \ + --use-postgres \ + --postgres-hostname "${AIIDA_POSTGRES_HOSTNAME:-localhost}" \ + --postgres-password "${AIIDA_POSTGRES_PASSWORD:-password}" + + # Setup and configure local computer. + computer_name=localhost + + # Determine the number of physical cores as a default for the number of + # available MPI ranks on the localhost. 
We do not count "logical" cores, + # since MPI parallelization over hyper-threaded cores is typically + # associated with a significant performance penalty. We use the + # `psutil.cpu_count(logical=False)` function as opposed to simply + # `os.cpu_count()` since the latter would include hyperthreaded (logical + # cores). + NUM_PHYSICAL_CORES=$(python -c 'import psutil; print(int(psutil.cpu_count(logical=False)))' 2>/dev/null) + LOCALHOST_MPI_PROCS_PER_MACHINE=${LOCALHOST_MPI_PROCS_PER_MACHINE:-${NUM_PHYSICAL_CORES}} + + if [ -z $LOCALHOST_MPI_PROCS_PER_MACHINE ]; then + echo "Unable to automatically determine the number of logical CPUs on this " + echo "machine. Please set the LOCALHOST_MPI_PROCS_PER_MACHINE variable to " + echo "explicitly set the number of available MPI ranks." + exit 1 + fi + + verdi computer show ${computer_name} &> /dev/null || verdi computer setup \ + --non-interactive \ + --label "${computer_name}" \ + --description "container computer" \ + --hostname "${computer_name}" \ + --transport core.local \ + --scheduler core.direct \ + --work-dir /home/${SYSTEM_USER}/aiida_run/ \ + --mpirun-command "mpirun -np {tot_num_mpiprocs}" \ + --mpiprocs-per-machine ${LOCALHOST_MPI_PROCS_PER_MACHINE} && \ + verdi computer configure core.local "${computer_name}" \ + --non-interactive \ + --safe-interval 0.0 + + # Migration will run for the default profile. + verdi storage migrate --force +fi diff --git a/aiida/calculations/importers/__init__.py b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/dependencies.d/aiida-prepare similarity index 100% rename from aiida/calculations/importers/__init__.py rename to .docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/dependencies.d/aiida-prepare diff --git a/aiida/calculations/importers/arithmetic/__init__.py b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/dependencies.d/base similarity index 100% rename from aiida/calculations/importers/arithmetic/__init__.py rename to .docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/dependencies.d/base diff --git a/aiida/calculations/monitors/__init__.py b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/dependencies.d/run-before-daemon-start similarity index 100% rename from aiida/calculations/monitors/__init__.py rename to .docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/dependencies.d/run-before-daemon-start diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/down b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/down new file mode 100644 index 0000000000..21b0cb2a6a --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/down @@ -0,0 +1,6 @@ +#!/command/execlineb -S0 + +with-contenv + +foreground { s6-echo "Calling /etc/init/aiida-daemon-stop" } +/etc/init/aiida-daemon-stop.sh diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/timeout-up b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/type b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/up 
b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/up new file mode 100644 index 0000000000..4380526b94 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-daemon-start/up @@ -0,0 +1,6 @@ +#!/command/execlineb -S0 + +with-contenv + +foreground { s6-echo "Calling /etc/init/aiida-daemon-start" } +/etc/init/aiida-daemon-start.sh diff --git a/aiida/tools/query/__init__.py b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/base similarity index 100% rename from aiida/tools/query/__init__.py rename to .docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/base diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/timeout-up b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/type b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/up b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/up new file mode 100644 index 0000000000..b1045997fd --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/aiida-prepare/up @@ -0,0 +1,6 @@ +#!/command/execlineb -S0 + +with-contenv + +foreground { s6-echo "Calling /etc/init/aiida-prepare" } +/etc/init/aiida-prepare.sh diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/dependencies.d/aiida-daemon-start b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/dependencies.d/aiida-daemon-start new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/dependencies.d/base b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/timeout-up b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/type b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/up b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/up new file mode 100644 index 0000000000..d2e95d5190 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-after-daemon-start/up @@ -0,0 +1,6 @@ +#!/command/execlineb -P + +with-contenv + +foreground { s6-echo "Calling /etc/init/run-after-daemon-start" } +run-parts --regex=".*" /etc/init/run-after-daemon-start/ diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/dependencies.d/aiida-prepare b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/dependencies.d/aiida-prepare new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/dependencies.d/base b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/timeout-up b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/type b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/up b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/up new file mode 100644 index 0000000000..3ff7dc0360 --- /dev/null +++ b/.docker/aiida-core-base/s6-assets/s6-rc.d/run-before-daemon-start/up @@ -0,0 +1,6 @@ +#!/command/execlineb -P + +with-contenv + +foreground { s6-echo "Calling /etc/init/run-before-daemon-start" } +run-parts --regex=".*" /etc/init/run-before-daemon-start/ diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/aiida-daemon-start b/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/aiida-daemon-start new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/aiida-prepare b/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/aiida-prepare new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/run-after-daemon-start b/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/run-after-daemon-start new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/run-before-daemon-start b/.docker/aiida-core-base/s6-assets/s6-rc.d/user/contents.d/run-before-daemon-start new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-dev/Dockerfile b/.docker/aiida-core-dev/Dockerfile new file mode 100644 index 0000000000..3974b7d569 --- /dev/null +++ b/.docker/aiida-core-dev/Dockerfile @@ -0,0 +1,10 @@ +# syntax=docker/dockerfile:1 +FROM aiida-core-with-services + +LABEL maintainer="AiiDA Team " + +COPY aiida-clone-and-install.sh /etc/init/run-before-daemon-start/10-aiida-clone-and-install.sh +COPY --chown=${SYSTEM_UID}:${SYSTEM_GID} --from=src . /home/${SYSTEM_USER}/aiida-core + +USER ${SYSTEM_UID} +WORKDIR "/home/${SYSTEM_USER}" diff --git a/.docker/aiida-core-dev/aiida-clone-and-install.sh b/.docker/aiida-core-dev/aiida-clone-and-install.sh new file mode 100755 index 0000000000..cdcc3d80ea --- /dev/null +++ b/.docker/aiida-core-dev/aiida-clone-and-install.sh @@ -0,0 +1,12 @@ +#!/bin/bash + +REPO_PATH=/home/aiida/aiida-core + +# If the repo is not already existent, clone it +# This is only necessary if the image is run through k8s with persistent volume, as the volume will be empty +# For the docker, the container folder (i.e. `$HOME`) that mounted to the volume will be copied to the volume. +if [ ! 
-d "$REPO_PATH" ]; then + git clone https://github.com/aiidateam/aiida-core.git --origin upstream $REPO_PATH +fi + +pip install --user -e "$REPO_PATH/[pre-commit,atomic_tools,docs,rest,tests,tui]" tox diff --git a/.docker/aiida-core-with-services/Dockerfile b/.docker/aiida-core-with-services/Dockerfile new file mode 100644 index 0000000000..276186175a --- /dev/null +++ b/.docker/aiida-core-with-services/Dockerfile @@ -0,0 +1,40 @@ +# syntax=docker/dockerfile:1 +FROM aiida-core-base + +LABEL maintainer="AiiDA Team " + +USER root +WORKDIR /opt/ + +ARG PGSQL_VERSION +ARG RMQ_VERSION + +ENV PGSQL_VERSION=${PGSQL_VERSION} +ENV RMQ_VERSION=${RMQ_VERSION} + +RUN mamba install --yes \ + --channel conda-forge \ + postgresql=${PGSQL_VERSION} && \ + mamba clean --all -f -y && \ + fix-permissions "${CONDA_DIR}" && \ + fix-permissions "/home/${SYSTEM_USER}" + +# Install erlang. +RUN apt-get update --yes && \ + apt-get install --yes --no-install-recommends \ + erlang && \ + apt-get clean && rm -rf /var/lib/apt/lists/* && \ + # Install rabbitmq. + wget -c --no-check-certificate https://github.com/rabbitmq/rabbitmq-server/releases/download/v${RMQ_VERSION}/rabbitmq-server-generic-unix-${RMQ_VERSION}.tar.xz && \ + tar -xf rabbitmq-server-generic-unix-${RMQ_VERSION}.tar.xz && \ + rm rabbitmq-server-generic-unix-${RMQ_VERSION}.tar.xz && \ + ln -sf /opt/rabbitmq_server-${RMQ_VERSION}/sbin/* /usr/local/bin/ && \ + fix-permissions /opt/rabbitmq_server-${RMQ_VERSION} + +# s6-overlay to start services +COPY s6-assets/s6-rc.d /etc/s6-overlay/s6-rc.d +COPY s6-assets/init /etc/init + +USER ${SYSTEM_UID} + +WORKDIR "/home/${SYSTEM_USER}" diff --git a/.docker/aiida-core-with-services/s6-assets/init/postgresql-init.sh b/.docker/aiida-core-with-services/s6-assets/init/postgresql-init.sh new file mode 100755 index 0000000000..0d3556f453 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/init/postgresql-init.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# make DB directory, if not existent +if [ ! -d /home/${SYSTEM_USER}/.postgresql ]; then + mkdir /home/${SYSTEM_USER}/.postgresql + initdb -D /home/${SYSTEM_USER}/.postgresql + echo "unix_socket_directories = '/tmp'" >> /home/${SYSTEM_USER}/.postgresql/postgresql.conf +fi + +PSQL_STATUS_CMD="pg_ctl -D /home/${SYSTEM_USER}/.postgresql status" + +# Fix problem with kubernetes cluster that adds rws permissions to the group +# for more details see: https://github.com/materialscloud-org/aiidalab-z2jh-eosc/issues/5 +chmod g-rwxs /home/${SYSTEM_USER}/.postgresql -R + +# stores return value in $? +running=true +${PSQL_STATUS_CMD} > /dev/null 2>&1 || running=false + +# Postgresql was probably not shutdown properly. Cleaning up the mess... +if ! $running ; then + echo "" > /home/${SYSTEM_USER}/.postgresql/logfile # empty log files + rm -vf /home/${SYSTEM_USER}/.postgresql/postmaster.pid +fi diff --git a/.docker/aiida-core-with-services/s6-assets/init/postgresql-prepare.sh b/.docker/aiida-core-with-services/s6-assets/init/postgresql-prepare.sh new file mode 100755 index 0000000000..580ee47106 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/init/postgresql-prepare.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +PG_ISREADY=1 +while [ "$PG_ISREADY" != "0" ]; do + sleep 1 + pg_isready --quiet + PG_ISREADY=$? 
+done diff --git a/.docker/aiida-core-with-services/s6-assets/init/rabbitmq-init.sh b/.docker/aiida-core-with-services/s6-assets/init/rabbitmq-init.sh new file mode 100755 index 0000000000..ced09bd059 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/init/rabbitmq-init.sh @@ -0,0 +1,32 @@ +#!/bin/bash +RABBITMQ_DATA_DIR="/home/${SYSTEM_USER}/.rabbitmq" + +mkdir -p "${RABBITMQ_DATA_DIR}" +fix-permissions "${RABBITMQ_DATA_DIR}" + +# Fix issue where the erlang cookie permissions are corrupted. +chmod 400 "/home/${SYSTEM_USER}/.erlang.cookie" || echo "erlang cookie not created yet." + +# Set base directory for RabbitMQ to persist its data. This needs to be set to a folder in the system user's home +# directory as that is the only folder that is persisted outside of the container. +RMQ_ETC_DIR="/opt/rabbitmq_server-${RMQ_VERSION}/etc/rabbitmq" +echo MNESIA_BASE="${RABBITMQ_DATA_DIR}" >> "${RMQ_ETC_DIR}/rabbitmq-env.conf" +echo LOG_BASE="${RABBITMQ_DATA_DIR}/log" >> "${RMQ_ETC_DIR}/rabbitmq-env.conf" + +# using workaround from https://github.com/aiidateam/aiida-core/wiki/RabbitMQ-version-to-use +# setting the consumer_timeout to undefined disables the timeout +cat > "${RMQ_ETC_DIR}/advanced.config" <> "${RMQ_ETC_DIR}/rabbitmq-env.conf" diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/base b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/postgresql b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/postgresql new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/postgresql-prepare b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/postgresql-prepare new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/rabbitmq b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/aiida-prepare/dependencies.d/rabbitmq new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/dependencies.d/base b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/timeout-up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/type b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/up new file mode 100644 index 0000000000..6fc0f06f57 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-init/up @@ -0,0 +1,6 @@ +#!/command/execlineb -S0 + +with-contenv + +foreground { s6-echo "Calling 
/etc/init/postgresql-init" } +/etc/init/postgresql-init.sh diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/dependencies.d/base b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/timeout-up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/type b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/up new file mode 100644 index 0000000000..df5f5f83f9 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql-prepare/up @@ -0,0 +1,6 @@ +#!/command/execlineb -S0 + +with-contenv + +foreground { s6-echo "Calling /etc/init/postgresql-prepare" } +/etc/init/postgresql-prepare.sh diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/dependencies.d/base b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/dependencies.d/postgresql-init b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/dependencies.d/postgresql-init new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/down b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/down new file mode 100644 index 0000000000..f2cc3c69b8 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/down @@ -0,0 +1 @@ +pg_ctl -D /home/aiida/.postgresql stop diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/timeout-up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/type b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/up new file mode 100644 index 0000000000..776d110d6c --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/postgresql/up @@ -0,0 +1,5 @@ +#!/command/execlineb -P + +with-contenv + +pg_ctl -D /home/aiida/.postgresql -l /home/${SYSTEM_USER}/.postgresql/logfile start diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/dependencies.d/base b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 
diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/timeout-up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/timeout-up new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/timeout-up @@ -0,0 +1 @@ +0 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/type b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/type new file mode 100644 index 0000000000..bdd22a1850 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/type @@ -0,0 +1 @@ +oneshot diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/up b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/up new file mode 100644 index 0000000000..e574020053 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq-init/up @@ -0,0 +1,5 @@ +#!/command/execlineb -S0 +with-contenv + +foreground { s6-echo "Calling /etc/init/rabbitmq-init.sh" } +/etc/init/rabbitmq-init.sh diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/data/check b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/data/check new file mode 100755 index 0000000000..46eb70ea89 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/data/check @@ -0,0 +1,15 @@ +#!/bin/bash + +rabbitmq-diagnostics ping + +if [ $? -ne 0 ]; then + exit 1 +fi + +rabbitmq-diagnostics check_running + +if [ $? -ne 0 ]; then + exit 1 +fi + +exit 0 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/dependencies.d/base b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/dependencies.d/base new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/dependencies.d/rabbitmq-init b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/dependencies.d/rabbitmq-init new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/down-signal b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/down-signal new file mode 100644 index 0000000000..d751378e19 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/down-signal @@ -0,0 +1 @@ +SIGINT diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/notification-fd b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/notification-fd new file mode 100644 index 0000000000..00750edc07 --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/notification-fd @@ -0,0 +1 @@ +3 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/run b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/run new file mode 100644 index 0000000000..8a35acd20f --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/run @@ -0,0 +1,15 @@ +#!/command/execlineb -P + +with-contenv + +foreground { s6-echo "Starting RMQ server and notifying back when the service is ready" } + + +# For the container that includes the services, aiida-prepare.sh script is called as soon as the RabbitMQ startup script has +# been launched, but it can take a while for the RMQ service to come up. If ``verdi presto`` is called straight away +# it is possible it tries to connect to the service before that and it will configure the profile without a broker. 
+# Here we use s6-notifyoncheck to do the polling healthy check of the readyness of RMQ service. +# +# -w 500: 500 ms between two invocations of ./data/check + +s6-notifyoncheck -w 500 rabbitmq-server diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/type b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/type new file mode 100644 index 0000000000..5883cff0cd --- /dev/null +++ b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/rabbitmq/type @@ -0,0 +1 @@ +longrun diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/aiida-prepare b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/aiida-prepare new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/postgresql b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/postgresql new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/postgresql-init b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/postgresql-init new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/postgresql-prepare b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/postgresql-prepare new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/rabbitmq b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/rabbitmq new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/rabbitmq-init b/.docker/aiida-core-with-services/s6-assets/s6-rc.d/user/contents.d/rabbitmq-init new file mode 100644 index 0000000000..e69de29bb2 diff --git a/.docker/build.json b/.docker/build.json new file mode 100644 index 0000000000..df352a35e3 --- /dev/null +++ b/.docker/build.json @@ -0,0 +1,13 @@ +{ + "variable": { + "PYTHON_VERSION": { + "default": "3.10.13" + }, + "PGSQL_VERSION": { + "default": "15" + }, + "RMQ_VERSION": { + "default": "3.10.18" + } + } + } diff --git a/.docker/docker-bake.hcl b/.docker/docker-bake.hcl new file mode 100644 index 0000000000..12938b490c --- /dev/null +++ b/.docker/docker-bake.hcl @@ -0,0 +1,68 @@ +# docker-bake.hcl +variable "VERSION" { +} + +variable "PYTHON_VERSION" { +} + +variable "PGSQL_VERSION" { +} + +variable "ORGANIZATION" { + default = "aiidateam" +} + +variable "REGISTRY" { +} + +variable "PLATFORMS" { + default = ["linux/amd64"] +} + +variable "TARGETS" { + default = ["aiida-core-base", "aiida-core-with-services", "aiida-core-dev"] +} + +function "tags" { + params = [image] + result = [ + "${REGISTRY}${ORGANIZATION}/${image}" + ] +} + +group "default" { + targets = "${TARGETS}" +} + +target "aiida-core-base" { + tags = tags("aiida-core-base") + context = "aiida-core-base" + contexts = { + src = ".." + } + platforms = "${PLATFORMS}" + args = { + "PYTHON_VERSION" = "${PYTHON_VERSION}" + } +} +target "aiida-core-with-services" { + tags = tags("aiida-core-with-services") + context = "aiida-core-with-services" + contexts = { + aiida-core-base = "target:aiida-core-base" + } + platforms = "${PLATFORMS}" + args = { + "PGSQL_VERSION" = "${PGSQL_VERSION}" + "RMQ_VERSION" = "${RMQ_VERSION}" + } +} +target "aiida-core-dev" { + tags = tags("aiida-core-dev") + context = "aiida-core-dev" + contexts = { + src = ".." 
+ aiida-core-with-services = "target:aiida-core-with-services" + } + platforms = "${PLATFORMS}" +} diff --git a/.docker/docker-compose.aiida-core-base.yml b/.docker/docker-compose.aiida-core-base.yml new file mode 100644 index 0000000000..2ac9326ab1 --- /dev/null +++ b/.docker/docker-compose.aiida-core-base.yml @@ -0,0 +1,50 @@ +version: '3.4' + +services: + + database: + image: postgres:15 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: password + # volumes: + # - aiida-postgres-db:/var/lib/postgresql/data + healthcheck: + test: [CMD, pg_isready, -U, postgres] + interval: 5s + timeout: 5s + retries: 10 + + messaging: + image: rabbitmq:3.8.14-management + environment: + RABBITMQ_DEFAULT_USER: guest + RABBITMQ_DEFAULT_PASS: guest + # volumes: + # - aiida-rmq-data:/var/lib/rabbitmq/ + healthcheck: + test: [CMD, rabbitmq-diagnostics, check_running] + interval: 30s + timeout: 30s + retries: 10 + + aiida: + image: ${REGISTRY:-}${AIIDA_CORE_BASE_IMAGE:-aiidateam/aiida-core-base}${TAG:-} + environment: + AIIDA_POSTGRES_HOSTNAME: database + AIIDA_BROKER_HOST: messaging + RMQHOST: messaging + TZ: Europe/Zurich + SETUP_DEFAULT_AIIDA_PROFILE: 'true' + # volumes: + # - aiida-home-folder:/home/aiida + depends_on: + database: + condition: service_healthy + messaging: + condition: service_healthy + +#volumes: +# aiida-postgres-db: +# aiida-rmq-data: +# aiida-home-folder: diff --git a/.docker/docker-compose.aiida-core-dev.yml b/.docker/docker-compose.aiida-core-dev.yml new file mode 100644 index 0000000000..b59f6484a5 --- /dev/null +++ b/.docker/docker-compose.aiida-core-dev.yml @@ -0,0 +1,9 @@ +version: '3.4' + +services: + + aiida: + image: ${REGISTRY:-}${AIIDA_CORE_DEV_IMAGE:-aiidateam/aiida-core-dev}${TAG:-} + environment: + TZ: Europe/Zurich + SETUP_DEFAULT_AIIDA_PROFILE: 'true' diff --git a/.docker/docker-compose.aiida-core-with-services.yml b/.docker/docker-compose.aiida-core-with-services.yml new file mode 100644 index 0000000000..73cf83e7e3 --- /dev/null +++ b/.docker/docker-compose.aiida-core-with-services.yml @@ -0,0 +1,14 @@ +version: '3.4' + +services: + + aiida: + image: ${REGISTRY:-}${AIIDA_CORE_WITH_SERVICES_IMAGE:-aiidateam/aiida-core-with-services}${TAG:-} + environment: + TZ: Europe/Zurich + SETUP_DEFAULT_AIIDA_PROFILE: 'true' + #volumes: + # - aiida-home-folder:/home/aiida + +volumes: + aiida-home-folder: diff --git a/.docker/docker-rabbitmq.yml b/.docker/docker-rabbitmq.yml deleted file mode 100644 index da266790ff..0000000000 --- a/.docker/docker-rabbitmq.yml +++ /dev/null @@ -1,34 +0,0 @@ -# A small configuration for use in local CI testing, -# if you wish to control the rabbitmq used. 
- -# Simply install docker, then run: -# $ docker-compose -f .docker/docker-rabbitmq.yml up -d - -# and to power down, after testing: -# $ docker-compose -f .docker/docker-rabbitmq.yml down - -# you can monitor rabbitmq use at: http://localhost:15672 - -version: '3.4' - -services: - - rabbit: - image: rabbitmq:3.8.3-management - container_name: aiida-rmq - environment: - RABBITMQ_DEFAULT_USER: guest - RABBITMQ_DEFAULT_PASS: guest - ports: - - '5672:5672' - - '15672:15672' - healthcheck: - test: rabbitmq-diagnostics -q ping - interval: 30s - timeout: 30s - retries: 5 - networks: - - aiida-rmq - -networks: - aiida-rmq: diff --git a/.docker/my_init.d/configure-aiida.sh b/.docker/my_init.d/configure-aiida.sh deleted file mode 100755 index 7ac4476b07..0000000000 --- a/.docker/my_init.d/configure-aiida.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash -set -em - -su -c /opt/configure-aiida.sh ${SYSTEM_USER} diff --git a/.docker/opt/configure-aiida.sh b/.docker/opt/configure-aiida.sh deleted file mode 100755 index 92aa3ab45d..0000000000 --- a/.docker/opt/configure-aiida.sh +++ /dev/null @@ -1,90 +0,0 @@ -#!/bin/bash - -# This script is executed whenever the docker container is (re)started. - -# Debugging. -set -x - -# Environment. -export SHELL=/bin/bash - -# Setup AiiDA autocompletion. -grep _VERDI_COMPLETE /home/${SYSTEM_USER}/.bashrc &> /dev/null || echo 'eval "$(_VERDI_COMPLETE=source verdi)"' >> /home/${SYSTEM_USER}/.bashrc - -# Check if user requested to set up AiiDA profile (and if it exists already) -if [[ ${SETUP_DEFAULT_PROFILE} == true ]] && ! verdi profile show ${PROFILE_NAME} &> /dev/null; then - NEED_SETUP_PROFILE=true; -else - NEED_SETUP_PROFILE=false; -fi - -# Setup AiiDA profile if needed. -if [[ ${NEED_SETUP_PROFILE} == true ]]; then - - # Create AiiDA profile. - verdi quicksetup \ - --non-interactive \ - --profile "${PROFILE_NAME}" \ - --email "${USER_EMAIL}" \ - --first-name "${USER_FIRST_NAME}" \ - --last-name "${USER_LAST_NAME}" \ - --institution "${USER_INSTITUTION}" \ - --db-host "${DB_HOST:localhost}" \ - --broker-host "${BROKER_HOST:localhost}" - - # Setup and configure local computer. - computer_name=localhost - - # Determine the number of physical cores as a default for the number of - # available MPI ranks on the localhost. We do not count "logical" cores, - # since MPI parallelization over hyper-threaded cores is typically - # associated with a significant performance penalty. We use the - # `psutil.cpu_count(logical=False)` function as opposed to simply - # `os.cpu_count()` since the latter would include hyperthreaded (logical - # cores). - NUM_PHYSICAL_CORES=$(python -c 'import psutil; print(int(psutil.cpu_count(logical=False)))' 2>/dev/null) - LOCALHOST_MPI_PROCS_PER_MACHINE=${LOCALHOST_MPI_PROCS_PER_MACHINE:-${NUM_PHYSICAL_CORES}} - - if [ -z $LOCALHOST_MPI_PROCS_PER_MACHINE ]; then - echo "Unable to automatically determine the number of logical CPUs on this " - echo "machine. Please set the LOCALHOST_MPI_PROCS_PER_MACHINE variable to " - echo "explicitly set the number of available MPI ranks." 
- exit 1 - fi - - verdi computer show ${computer_name} || verdi computer setup \ - --non-interactive \ - --label "${computer_name}" \ - --description "this computer" \ - --hostname "${computer_name}" \ - --transport core.local \ - --scheduler core.direct \ - --work-dir /home/aiida/aiida_run/ \ - --mpirun-command "mpirun -np {tot_num_mpiprocs}" \ - --mpiprocs-per-machine ${LOCALHOST_MPI_PROCS_PER_MACHINE} && \ - verdi computer configure core.local "${computer_name}" \ - --non-interactive \ - --safe-interval 0.0 -fi - - -# Show the default profile -verdi profile show || echo "The default profile is not set." - -# Make sure that the daemon is not running, otherwise the migration will abort. -verdi daemon stop - -# Migration will run for the default profile. -verdi storage migrate --force || echo "Database migration failed." - -# Supress rabbitmq version warning for arm64 since -# the it build using latest version rabbitmq from apt install -# We explicitly set consumer_timeout to 100 hours in /etc/rabbitmq/rabbitmq.conf -export ARCH=`uname -m` -if [ "$ARCH" = "aarch64" ]; then \ - verdi config set warnings.rabbitmq_version False -fi - - -# Daemon will start only if the database exists and is migrated to the latest version. -verdi daemon start || echo "AiiDA daemon is not running." diff --git a/.docker/pytest.ini b/.docker/pytest.ini new file mode 100644 index 0000000000..e6c356c3eb --- /dev/null +++ b/.docker/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +minversion = 7.0 +addopts = -ra -q --strict-markers +testpaths = + tests diff --git a/.docker/requirements.txt b/.docker/requirements.txt new file mode 100644 index 0000000000..6c84d9fa30 --- /dev/null +++ b/.docker/requirements.txt @@ -0,0 +1,4 @@ +docker~=7.0.0 +pytest~=8.2.0 +requests~=2.32.0 +pytest-docker~=3.1.0 diff --git a/.docker/tests/conftest.py b/.docker/tests/conftest.py new file mode 100644 index 0000000000..eaf2eec743 --- /dev/null +++ b/.docker/tests/conftest.py @@ -0,0 +1,112 @@ +import json +from pathlib import Path + +import pytest + +TARGETS = ('aiida-core-base', 'aiida-core-with-services', 'aiida-core-dev') + + +def target_checker(value): + msg = f"Invalid image target '{value}', must be one of: {TARGETS}" + if value not in TARGETS: + raise pytest.UsageError(msg) + return value + + +def pytest_addoption(parser): + parser.addoption( + '--variant', + action='store', + required=True, + help='target (image name) of the docker-compose file to use.', + type=target_checker, + ) + + +@pytest.fixture(scope='session') +def variant(pytestconfig): + return pytestconfig.getoption('variant') + + +@pytest.fixture(scope='session') +def docker_compose_file(variant): + return f'docker-compose.{variant}.yml' + + +@pytest.fixture(scope='session') +def docker_compose(docker_services): + return docker_services._docker_compose + + +@pytest.fixture(scope='session', autouse=True) +def _docker_service_wait(docker_services): + """Container startup wait.""" + + # using `docker_compose` fixture would trigger a separate container + docker_compose = docker_services._docker_compose + + def is_container_ready(): + try: + output = docker_compose.execute('exec -T aiida verdi status').decode().strip() + except Exception: + return False + return '✔ broker:' in output and 'Daemon is running' in output + + try: + docker_services.wait_until_responsive( + timeout=300.0, + pause=10, + check=lambda: is_container_ready(), + ) + except Exception: + print('Timed out waiting for the profile and daemon to be up and running.') + + try: + docker_compose.execute('exec -T aiida 
verdi status').decode().strip() + except Exception as exception: + print(f'Output of `verdi status`:\n{exception}') + + try: + docker_compose.execute('exec -T aiida verdi profile show').decode().strip() + except Exception as exception: + print(f'Output of `verdi status`:\n{exception}') + + print(docker_compose.execute('logs').decode().strip()) + raise + + +@pytest.fixture +def container_user(): + return 'aiida' + + +@pytest.fixture +def aiida_exec(docker_compose): + def execute(command, user=None, **kwargs): + if user: + command = f'exec -T --user={user} aiida {command}' + else: + command = f'exec -T aiida {command}' + return docker_compose.execute(command, **kwargs) + + return execute + + +@pytest.fixture(scope='session') +def _build_config(): + return json.loads(Path('build.json').read_text(encoding='utf-8'))['variable'] + + +@pytest.fixture(scope='session') +def python_version(_build_config): + return _build_config['PYTHON_VERSION']['default'] + + +@pytest.fixture(scope='session') +def pgsql_version(_build_config): + return _build_config['PGSQL_VERSION']['default'] + + +@pytest.fixture(scope='session') +def rmq_version(_build_config): + return _build_config['RMQ_VERSION']['default'] diff --git a/.docker/tests/test_aiida.py b/.docker/tests/test_aiida.py new file mode 100644 index 0000000000..7f952bd855 --- /dev/null +++ b/.docker/tests/test_aiida.py @@ -0,0 +1,89 @@ +import json +import re + +import pytest +from packaging.version import parse + + +def test_correct_python_version_installed(aiida_exec, python_version): + info = json.loads(aiida_exec('mamba list --json --full-name python').decode())[0] + assert info['name'] == 'python' + assert parse(info['version']) == parse(python_version) + + +def test_correct_pgsql_version_installed(aiida_exec, pgsql_version, variant): + if variant == 'aiida-core-base': + pytest.skip('PostgreSQL is not installed in the base image') + + info = json.loads(aiida_exec('mamba list --json --full-name postgresql').decode())[0] + assert info['name'] == 'postgresql' + assert parse(info['version']).major == parse(pgsql_version).major + + +def test_rmq_consumer_timeout_unset(aiida_exec, variant): + if variant == 'aiida-core-base': + pytest.skip('RabbitMQ is not installed in the base image') + + output = aiida_exec('rabbitmqctl environment | grep consumer_timeout', user='root').decode().strip() + assert 'undefined' in output + + +def test_verdi_status(aiida_exec, container_user): + output = aiida_exec('verdi status', user=container_user).decode().strip() + assert '✔ broker:' in output + assert 'Daemon is running' in output + + # Check that we have suppressed the warnings coming from using an install from repo and newer RabbitMQ version. + # Make sure to match only lines that start with ``Warning:`` because otherwise deprecation warnings from other + # packages that we cannot control may fail the test. 
+ assert not re.match('^Warning:.*', output) + + +def test_computer_setup_success(aiida_exec, container_user): + output = aiida_exec('verdi computer test localhost', user=container_user).decode().strip() + + assert 'Success' in output + assert 'Failed' not in output + + +def test_clone_dir_exists(aiida_exec, variant): + """Test that the aiida-core repository is cloned in the aiida-core-dev image.""" + if variant != 'aiida-core-dev': + pytest.skip(f'aiida-core clone not available in {variant} image') + + output = aiida_exec('ls /home/aiida/').decode().strip() + + assert 'aiida-core' in output + + +def test_editable_install(aiida_exec, variant): + """Test that the aiida-core repository is installed in editable mode in the aiida-core-dev image.""" + if variant != 'aiida-core-dev': + pytest.skip(f'aiida-core clone not available in {variant} image') + + package = 'aiida-core' + + output = aiida_exec(f'pip show {package}').decode().strip() + + assert f'Editable project location: /home/aiida/{package}' in output + + +@pytest.mark.parametrize( + 'package', + [ + 'ase', + 'Sphinx', + 'pre-commit', + 'Flask', + 'pytest', + 'trogon', + ], +) +def test_optional_dependency_install(aiida_exec, package, variant): + """Test that optional dependencies are installed in the aiida-core-dev image.""" + if variant != 'aiida-core-dev': + pytest.skip(f'optional dependencies are not installed in {variant} image') + + output = aiida_exec(f'pip show {package}').decode().strip() + + assert f'Name: {package}' in output diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index dfe06bad59..0000000000 --- a/.dockerignore +++ /dev/null @@ -1,13 +0,0 @@ -.benchmarks -.cache -.coverage -.mypy_cache -.pytest_cache -.tox -.vscode -aiida_core.egg-info -docs/build -pip-wheel-metadata -**/.DS_Store -**/*.pyc -**/__pycache__ diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e908ba2fc9..3b1e3cb611 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,9 +2,7 @@ # currently active dependency manager (DM) to trigger an automatic review # request from the DM upon changes. 
Please see AEP-002 for details: # https://github.com/aiidateam/AEP/tree/master/002_dependency_management -setup.* @aiidateam/dependency-manager -environment.yml @aiidateam/dependency-manager -requirements*.txt @aiidateam/dependency-manager -pyproject.toml @aiidateam/dependency-manager -utils/dependency_management.py @aiidateam/dependency-manager -.github/workflows/dm.yml @aiidateam/dependency-manager +environment.yml @unkcpz @agoscinski +pyproject.toml @unkcpz @agoscinski +uv.lock @unkcpz @agoscinski +utils/dependency_management.py @unkcpz @agoscinski diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index fea729c5f5..f8bb3bcd73 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -9,8 +9,8 @@ assignees: '' -- [ ] [AiiDA Troubleshooting Documentation](https://aiida.readthedocs.io/projects/aiida-core/en/latest/intro/troubleshooting.html) -- [ ] [AiiDA Users Forum](https://groups.google.com/forum/#!forum/aiidausers) +- [ ] [AiiDA Troubleshooting Documentation](https://aiida.readthedocs.io/projects/aiida-core/en/stable/installation/troubleshooting.html) +- [ ] [AiiDA Discourse Forum](https://aiida.discourse.group/) ### Describe the bug diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index afef97b49a..b8d6ebc19f 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,7 +1,4 @@ contact_links: - - name: AiiDA Discussions - url: https://github.com/aiidateam/aiida-core/discussions - about: For aiida-core questions and discussion - - name: AiiDA Users Forum - url: http://www.aiida.net/mailing-list/ - about: For general questions and discussion +- name: AiiDA Discourse group + url: https://aiida.discourse.group + about: For general questions and discussion diff --git a/.github/actions/install-aiida-core/action.yml b/.github/actions/install-aiida-core/action.yml new file mode 100644 index 0000000000..26e6cac961 --- /dev/null +++ b/.github/actions/install-aiida-core/action.yml @@ -0,0 +1,48 @@ +name: Install aiida-core +description: Install aiida-core package and its Python dependencies + +inputs: + python-version: + description: Python version + default: '3.9' # Lowest supported version + required: false + extras: + description: list of optional dependencies + # NOTE: The default 'pre-commit' extra recursively contains + # other extras needed to run the tests. + default: pre-commit + required: false + # NOTE: Hard-learned lesson: we cannot use type=boolean here, apparently :-( + # https://stackoverflow.com/a/76294014 + # NOTE2: When installing from lockfile, aiida-core and its dependencies + # are installed in a virtual environment located in .venv directory. 
+ # Subsequent job steps must either activate the environment or use `uv run` + from-lock: + description: Install aiida-core dependencies from uv lock file + default: 'true' + required: false + +runs: + using: composite + steps: + - name: Set Up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + version: 0.5.x + python-version: ${{ inputs.python-version }} + + - name: Install dependencies from uv lock + if: ${{ inputs.from-lock == 'true' }} + # NOTE: We're asserting that the lockfile is up to date + run: uv sync --locked ${{ inputs.extras && format('--extra {0}', inputs.extras) || '' }} + shell: bash + + - name: Install aiida-core + if: ${{ inputs.from-lock != 'true' }} + run: uv pip install -e .${{ inputs.extras }} + shell: bash diff --git a/.github/config/add-containerized.yaml b/.github/config/add-containerized.yaml new file mode 100644 index 0000000000..cc8cd882b4 --- /dev/null +++ b/.github/config/add-containerized.yaml @@ -0,0 +1,9 @@ +label: add-singularity +description: Bash run in Docker image +default_calc_job_plugin: core.arithmetic.add +computer: localhost +filepath_executable: /bin/sh +image_name: alpine:3 +engine_command: docker run --user 1001:100 -v $PWD:$PWD -w $PWD -i {image_name} +prepend_text: ' ' +append_text: ' ' diff --git a/.github/config/add-singularity.yaml b/.github/config/add-singularity.yaml deleted file mode 100644 index 6bffe3482f..0000000000 --- a/.github/config/add-singularity.yaml +++ /dev/null @@ -1,10 +0,0 @@ ---- -label: add-singularity -description: Bash run in Docker image through Singularity -default_calc_job_plugin: core.arithmetic.add -computer: localhost -filepath_executable: /bin/sh -image_name: docker://alpine:3 -engine_command: singularity exec --bind $PWD:$PWD {image_name} -prepend_text: ' ' -append_text: ' ' diff --git a/.github/config/add.yaml b/.github/config/add.yaml index f416ec6910..979d1696f8 100644 --- a/.github/config/add.yaml +++ b/.github/config/add.yaml @@ -1,4 +1,3 @@ ---- label: add description: add default_calc_job_plugin: core.arithmetic.add diff --git a/.github/config/doubler.yaml b/.github/config/doubler.yaml index 242598f726..53e35c2612 100644 --- a/.github/config/doubler.yaml +++ b/.github/config/doubler.yaml @@ -1,4 +1,3 @@ ---- label: doubler description: doubler default_calc_job_plugin: core.templatereplacer diff --git a/.github/config/localhost-config.yaml b/.github/config/localhost-config.yaml index ef0ca22365..d152d600c4 100644 --- a/.github/config/localhost-config.yaml +++ b/.github/config/localhost-config.yaml @@ -1,3 +1,2 @@ ---- use_login_shell: true safe_interval: 0 diff --git a/.github/config/localhost.yaml b/.github/config/localhost.yaml index 307a478413..7c6924315e 100644 --- a/.github/config/localhost.yaml +++ b/.github/config/localhost.yaml @@ -1,4 +1,3 @@ ---- label: localhost description: localhost hostname: localhost diff --git a/.github/config/profile.yaml b/.github/config/profile.yaml index 2f07476462..d0e2c9eebf 100644 --- a/.github/config/profile.yaml +++ b/.github/config/profile.yaml @@ -1,11 +1,10 @@ ---- profile: test_aiida email: aiida@localhost first_name: Giuseppe last_name: Verdi institution: Khedivial db_backend: core.psql_dos -db_engine: postgresql_psycopg2 +db_engine: postgresql_psycopg db_host: localhost db_port: 5432 db_name: test_aiida diff --git a/.github/config/slurm-ssh-config.yaml b/.github/config/slurm-ssh-config.yaml index 48332209de..6c2ba585cf 100644 ---
a/.github/config/slurm-ssh-config.yaml +++ b/.github/config/slurm-ssh-config.yaml @@ -1,7 +1,6 @@ ---- safe_interval: 0 username: xenon look_for_keys: true -key_filename: "PLACEHOLDER_SSH_KEY" +key_filename: PLACEHOLDER_SSH_KEY key_policy: AutoAddPolicy port: 5001 diff --git a/.github/config/slurm-ssh.yaml b/.github/config/slurm-ssh.yaml index 7419e468cc..a14d7de42d 100644 --- a/.github/config/slurm-ssh.yaml +++ b/.github/config/slurm-ssh.yaml @@ -1,12 +1,11 @@ ---- label: slurm-ssh description: slurm container hostname: localhost transport: core.ssh scheduler: core.slurm -shebang: "#!/bin/bash" +shebang: '#!/bin/bash' work_dir: /home/{username}/workdir -mpirun_command: "mpirun -np {tot_num_mpiprocs}" +mpirun_command: mpirun -np {tot_num_mpiprocs} mpiprocs_per_machine: 1 -prepend_text: "" -append_text: "" +prepend_text: '' +append_text: '' diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..8f0af41bb2 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,11 @@ +version: 2 +updates: +# Maintain dependencies for GitHub Actions +- package-ecosystem: github-actions + directory: / + schedule: + interval: monthly + groups: + gha-dependencies: + patterns: + - '*' diff --git a/.github/system_tests/pytest/test_memory_leaks.py b/.github/system_tests/pytest/test_memory_leaks.py deleted file mode 100644 index 396793346f..0000000000 --- a/.github/system_tests/pytest/test_memory_leaks.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities for testing memory leakage.""" -from aiida import orm -from aiida.engine import processes, run_get_node -from aiida.plugins import CalculationFactory -from tests.utils import processes as test_processes # pylint: disable=no-name-in-module,import-error -from tests.utils.memory import get_instances # pylint: disable=no-name-in-module,import-error - -ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') - - -def run_finished_ok(*args, **kwargs): - """Convenience function to check that run worked fine.""" - _, node = run_get_node(*args, **kwargs) - assert node.is_finished_ok, (node.exit_status, node.exit_message) - - -def test_leak_run_process(): - """Test whether running a dummy process leaks memory.""" - inputs = {'a': orm.Int(2), 'b': orm.Str('test')} - run_finished_ok(test_processes.DummyProcess, **inputs) - - # check that no reference to the process is left in memory - # some delay is necessary in order to allow for all callbacks to finish - process_instances = get_instances(processes.Process, delay=0.2) - assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}' - - -def test_leak_local_calcjob(aiida_local_code_factory): - """Test whether running a local CalcJob leaks memory.""" - inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': aiida_local_code_factory('core.arithmetic.add', '/bin/bash')} - run_finished_ok(ArithmeticAddCalculation, **inputs) - - # check that no reference to the process is left in memory - # some delay is necessary in order to allow for all callbacks to finish - 
process_instances = get_instances(processes.Process, delay=0.2) - assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}' - - -def test_leak_ssh_calcjob(): - """Test whether running a CalcJob over SSH leaks memory. - - Note: This relies on the 'slurm-ssh' computer being set up. - """ - code = orm.InstalledCode( - default_calc_job_plugin='core.arithmetic.add', - computer=orm.load_computer('slurm-ssh'), - filepath_executable='/bin/bash' - ) - inputs = {'x': orm.Int(1), 'y': orm.Int(2), 'code': code} - run_finished_ok(ArithmeticAddCalculation, **inputs) - - # check that no reference to the process is left in memory - # some delay is necessary in order to allow for all callbacks to finish - process_instances = get_instances(processes.Process, delay=0.2) - assert not process_instances, f'Memory leak: process instances remain in memory: {process_instances}' diff --git a/.github/system_tests/test_containerized_code.py b/.github/system_tests/test_containerized_code.py index e452ad80bf..a4286d8ba7 100644 --- a/.github/system_tests/test_containerized_code.py +++ b/.github/system_tests/test_containerized_code.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. # @@ -8,6 +7,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Test running a :class:`~aiida.orm.nodes.data.codes.containerized.ContainerizedCode` code.""" + from aiida import orm from aiida.engine import run_get_node diff --git a/.github/system_tests/test_daemon.py b/.github/system_tests/test_daemon.py index fce3a827cc..c49cb1f750 100644 --- a/.github/system_tests/test_daemon.py +++ b/.github/system_tests/test_daemon.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. 
# @@ -7,8 +6,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=no-name-in-module """Tests to run with a running daemon.""" + import os import re import shutil @@ -37,9 +36,10 @@ from aiida.engine.processes import CalcJob, Process from aiida.manage.caching import enable_caching from aiida.orm import CalcJobNode, Dict, Int, List, Str, load_code, load_node +from aiida.orm.nodes.caching import NodeCaching from aiida.plugins import CalculationFactory, WorkflowFactory from aiida.workflows.arithmetic.add_multiply import add, add_multiply -from tests.utils.memory import get_instances # pylint: disable=import-error +from tests.utils.memory import get_instances CODENAME_ADD = 'add@localhost' CODENAME_DOUBLER = 'doubler@localhost' @@ -55,10 +55,12 @@ def print_daemon_log(): print(f"Output of 'cat {daemon_log}':") try: - print(subprocess.check_output( - ['cat', f'{daemon_log}'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['cat', f'{daemon_log}'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') @@ -80,10 +82,12 @@ def print_report(pk): """Print the process report for given pk.""" print(f"Output of 'verdi process report {pk}':") try: - print(subprocess.check_output( - ['verdi', 'process', 'report', f'{pk}'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['verdi', 'process', 'report', f'{pk}'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') @@ -192,12 +196,9 @@ def validate_workchains(expected_results): def validate_cached(cached_calcs): - """ - Check that the calculations with created with caching are indeed cached. 
- """ + """Check that the calculations created with caching are indeed cached.""" valid = True for calc in cached_calcs: - if not calc.is_finished_ok: print( 'Cached calculation<{}> not finished ok: process_state<{}> exit_status<{}>'.format( @@ -207,14 +208,15 @@ print_report(calc.pk) valid = False - if '_aiida_cached_from' not in calc.base.extras or calc.base.caching.get_hash( - ) != calc.base.extras.get('_aiida_hash'): + if NodeCaching.CACHED_FROM_KEY not in calc.base.extras or calc.base.caching.get_hash() != calc.base.extras.get( + '_aiida_hash' + ): print(f'Cached calculation<{calc.pk}> has invalid hash') print_report(calc.pk) valid = False if isinstance(calc, CalcJobNode): - original_calc = load_node(calc.base.extras.get('_aiida_cached_from')) + original_calc = load_node(calc.base.extras.get(NodeCaching.CACHED_FROM_KEY)) files_original = original_calc.base.repository.list_object_names() files_cached = calc.base.repository.list_object_names() @@ -269,9 +271,7 @@ def launch_workfunction(inputval): def launch_calculation(code, counter, inputval): - """ - Launch calculations to the daemon through the Process layer - """ + """Launch calculations to the daemon through the Process layer""" process, inputs, expected_result = create_calculation_process(code=code, inputval=inputval) calc = submit(process, **inputs) print(f'[{counter}] launched calculation {calc.uuid}, pk={calc.pk}') @@ -279,9 +279,7 @@ def launch_calculation(code, counter, inputval): def run_calculation(code, counter, inputval): - """ - Run a calculation through the Process layer. - """ + """Run a calculation through the Process layer.""" process, inputs, expected_result = create_calculation_process(code=code, inputval=inputval) _, calc = run.get_node(process, **inputs) print(f'[{counter}] ran calculation {calc.uuid}, pk={calc.pk}') @@ -289,28 +287,25 @@ def run_calculation(code, counter, inputval): def create_calculation_process(code, inputval): - """ - Create the process and inputs for a submitting / running a calculation. - """ - TemplatereplacerCalculation = CalculationFactory('core.templatereplacer') + """Create the process and inputs for submitting / running a calculation.""" parameters = Dict({'value': inputval}) - template = Dict({ - # The following line adds a significant sleep time.
+ # I set it to 1 second to speed up tests + # I keep it to a non-zero value because I want + # To test the case when AiiDA finds some calcs + # in a queued state + # 'cmdline_params': ["{}".format(counter % 3)], # Sleep time + 'cmdline_params': ['1'], + 'input_file_template': '{value}', # File just contains the value to double + 'input_file_name': 'value_to_double.txt', + 'output_file_name': 'output.txt', + 'retrieve_temporary_files': ['triple_value.tmp'], + } + ) options = { - 'resources': { - 'num_machines': 1 - }, + 'resources': {'num_machines': 1}, 'max_wallclock_seconds': 5 * 60, 'withmpi': False, 'parser_name': 'core.templatereplacer', @@ -324,15 +319,13 @@ def create_calculation_process(code, inputval): 'template': template, 'metadata': { 'options': options, - } + }, } - return TemplatereplacerCalculation, inputs, expected_result + return CalculationFactory('core.templatereplacer'), inputs, expected_result def run_arithmetic_add(): """Run the `ArithmeticAddCalculation`.""" - ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add') - code = load_code(CODENAME_ADD) inputs = { 'x': Int(1), @@ -341,7 +334,7 @@ def run_arithmetic_add(): } # Normal inputs should run just fine - results, node = run.get_node(ArithmeticAddCalculation, **inputs) + results, node = run.get_node(CalculationFactory('core.arithmetic.add'), **inputs) assert node.is_finished_ok, node.exit_status assert results['sum'] == 3 @@ -377,7 +370,7 @@ def run_base_restart_workchain(): inputs['add']['y'] = Int(10) results, node = run.get_node(ArithmeticAddBaseWorkChain, **inputs) assert not node.is_finished_ok, node.process_state - assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_TOO_BIG.status, node.exit_status # pylint: disable=no-member + assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_TOO_BIG.status, node.exit_status assert len(node.called) == 1 # Check that overriding default handler enabled status works @@ -385,14 +378,12 @@ def run_base_restart_workchain(): inputs['handler_overrides'] = Dict({'disabled_handler': True}) results, node = run.get_node(ArithmeticAddBaseWorkChain, **inputs) assert not node.is_finished_ok, node.process_state - assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_ENABLED_DOOM.status, node.exit_status # pylint: disable=no-member + assert node.exit_status == ArithmeticAddBaseWorkChain.exit_codes.ERROR_ENABLED_DOOM.status, node.exit_status assert len(node.called) == 1 def run_multiply_add_workchain(): """Run the `MultiplyAddWorkChain`.""" - MultiplyAddWorkChain = WorkflowFactory('core.arithmetic.multiply_add') - code = load_code(CODENAME_ADD) inputs = { 'x': Int(1), @@ -402,7 +393,7 @@ def run_multiply_add_workchain(): } # Normal inputs should run just fine - results, node = run.get_node(MultiplyAddWorkChain, **inputs) + results, node = run.get_node(WorkflowFactory('core.arithmetic.multiply_add'), **inputs) assert node.is_finished_ok, node.exit_status assert len(node.called) == 2 assert 'result' in results @@ -428,7 +419,6 @@ def launch_all(): :returns: dictionary with expected results and pks of all launched calculations and workchains """ - # pylint: disable=too-many-locals,too-many-statements expected_results_process_functions = {} expected_results_calculations = {} expected_results_workchains = {} @@ -450,7 +440,6 @@ def launch_all(): print('Testing the stashing functionality') process, inputs, expected_result = create_calculation_process(code=code_doubler, inputval=1) with tempfile.TemporaryDirectory() as tmpdir: - # 
Delete the temporary directory to test that the stashing functionality will create it if necessary shutil.rmtree(tmpdir, ignore_errors=True) @@ -570,8 +559,10 @@ def relaunch_cached(results): results['calculations'][calc.pk] = expected_result if not ( - validate_calculations(results['calculations']) and validate_workchains(results['workchains']) and - validate_cached(cached_calcs) and validate_process_functions(results['process_functions']) + validate_calculations(results['calculations']) + and validate_workchains(results['workchains']) + and validate_cached(cached_calcs) + and validate_process_functions(results['process_functions']) ): print_daemon_log() print('') @@ -585,7 +576,6 @@ def relaunch_cached(results): def main(): """Launch a bunch of calculation jobs and workchains.""" - results = launch_all() print('Waiting for end of execution...') @@ -602,19 +592,23 @@ def main(): print('#' * 78) print("Output of 'verdi process list -a':") try: - print(subprocess.check_output( - ['verdi', 'process', 'list', '-a'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['verdi', 'process', 'list', '-a'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') print("Output of 'verdi daemon status':") try: - print(subprocess.check_output( - ['verdi', 'daemon', 'status'], - stderr=subprocess.STDOUT, - )) + print( + subprocess.check_output( + ['verdi', 'daemon', 'status'], + stderr=subprocess.STDOUT, + ) + ) except subprocess.CalledProcessError as exception: print(f'Note: the command failed, message: {exception}') diff --git a/.github/system_tests/test_ipython_magics.py b/.github/system_tests/test_ipython_magics.py deleted file mode 100644 index 95ac6298bf..0000000000 --- a/.github/system_tests/test_ipython_magics.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Test the AiiDA iPython magics.""" -from IPython.testing.globalipapp import get_ipython - -from aiida.tools.ipython.ipython_magics import register_ipython_extension - - -def test_ipython_magics(): - """Test that the %aiida magic can be loaded and adds the QueryBuilder and Node variables.""" - ipy = get_ipython() - register_ipython_extension(ipy) - - cell = """ -%aiida -qb=QueryBuilder() -qb.append(Node) -qb.all() -Dict().store() -""" - result = ipy.run_cell(cell) - - assert result.success diff --git a/.github/system_tests/test_profile_manager.py b/.github/system_tests/test_profile_manager.py deleted file mode 100644 index 7c96c9b1cc..0000000000 --- a/.github/system_tests/test_profile_manager.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Unittests for TestManager""" -import os -import sys -import unittest -import warnings - -from pgtest import pgtest -import pytest - -from aiida.common.utils import Capturing -from aiida.manage.tests import TemporaryProfileManager, TestManagerError, get_test_backend_name - - -class TemporaryProfileManagerTestCase(unittest.TestCase): - """Test the TemporaryProfileManager class""" - - def setUp(self): - if sys.version_info[0] >= 3: - # tell unittest not to warn about running processes - warnings.simplefilter('ignore', ResourceWarning) # pylint: disable=no-member,undefined-variable - - self.backend = get_test_backend_name() - self.profile_manager = TemporaryProfileManager(backend=self.backend) - - def tearDown(self): - self.profile_manager.destroy_all() - - def test_create_db_cluster(self): - self.profile_manager.create_db_cluster() - self.assertTrue(pgtest.is_server_running(self.profile_manager.pg_cluster.cluster)) - - def test_create_aiida_db(self): - self.profile_manager.create_db_cluster() - self.profile_manager.create_aiida_db() - self.assertTrue(self.profile_manager.postgres.db_exists(self.profile_manager.profile_info['database_name'])) - - @pytest.mark.filterwarnings('ignore:Creating AiiDA configuration folder') - def test_create_use_destroy_profile2(self): - """ - Test temporary test profile creation - - * The profile gets created, the dbenv loaded - * Data can be stored in the db - * reset_db deletes all data added after profile creation - * destroy_all removes all traces of the test run - - Note: This test function loads the dbenv - i.e. you cannot run similar test functions (that create profiles) - in the same test session. 
aiida.manage.configuration.reset_profile() was not yet enough, see - https://github.com/aiidateam/aiida-core/issues/3482 - """ - with Capturing() as output: - self.profile_manager.create_profile() - - self.assertTrue(self.profile_manager.root_dir_ok, msg=output) - self.assertTrue(self.profile_manager.config_dir_ok, msg=output) - self.assertTrue(self.profile_manager.repo_ok, msg=output) - from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER - self.assertEqual(str(AIIDA_CONFIG_FOLDER), self.profile_manager.config_dir, msg=output) - - from aiida.orm import load_node - from aiida.plugins import DataFactory - data = DataFactory('core.dict')(dict={'key': 'value'}) - data.store() - data_pk = data.pk - self.assertTrue(load_node(data_pk)) - - with self.assertRaises(TestManagerError): - self.test_create_aiida_db() - - self.profile_manager.clear_profile() - with self.assertRaises(Exception): - load_node(data_pk) - - temp_dir = self.profile_manager.root_dir - self.profile_manager.destroy_all() - with self.assertRaises(Exception): - self.profile_manager.postgres.db_exists(self.profile_manager.dbinfo['db_name']) - self.assertFalse(os.path.exists(temp_dir)) - self.assertIsNone(self.profile_manager.root_dir) - self.assertIsNone(self.profile_manager.pg_cluster) - - -if __name__ == '__main__': - unittest.main() diff --git a/.github/system_tests/test_test_manager.py b/.github/system_tests/test_test_manager.py deleted file mode 100644 index f31ae6325d..0000000000 --- a/.github/system_tests/test_test_manager.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Unittests for TestManager""" -import sys -import unittest -import warnings - -import pytest - -from aiida.manage.tests import TestManager, get_test_backend_name - - -class TestManagerTestCase(unittest.TestCase): - """Test the TestManager class""" - - def setUp(self): - if sys.version_info[0] >= 3: - # tell unittest not to warn about running processes - warnings.simplefilter('ignore', ResourceWarning) # pylint: disable=no-member,undefined-variable - - self.backend = get_test_backend_name() - self.test_manager = TestManager() - - def tearDown(self): - self.test_manager.destroy_all() - - @pytest.mark.filterwarnings('ignore:Creating AiiDA configuration folder') - def test_pgtest_argument(self): - """ - Create a temporary profile, passing the pgtest argument. 
- """ - from pgtest.pgtest import which - - # this should fail - pgtest = {'pg_ctl': 'notapath'} - with self.assertRaises(AssertionError): - self.test_manager.use_temporary_profile(backend=self.backend, pgtest=pgtest) - - # pg_ctl is what PGTest also looks for (although it might be more clever) - pgtest = {'pg_ctl': which('pg_ctl')} - self.test_manager.use_temporary_profile(backend=self.backend, pgtest=pgtest) - - -if __name__ == '__main__': - unittest.main() diff --git a/.github/system_tests/test_verdi_load_time.sh b/.github/system_tests/test_verdi_load_time.sh deleted file mode 100755 index 07e8772ebe..0000000000 --- a/.github/system_tests/test_verdi_load_time.sh +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env bash -set -e - -# Test the loading time of `verdi`. This is and attempt to catch changes to the imports in `aiida.cmdline` that will -# indirectly load the `aiida.orm` module which will trigger loading of the backend environment. This slows down `verdi` -# significantly, making tab-completion unusable. -VERDI=`which verdi` - -# Typically, the loading time of `verdi` should be around ~0.2 seconds. When loading the database environment this -# tends to go towards ~0.8 seconds. Since these timings are obviously machine and environment dependent, typically these -# types of tests are fragile. But with a load limit of more than twice the ideal loading time, if exceeded, should give -# a reasonably sure indication that the loading of `verdi` is unacceptably slowed down. -LOAD_LIMIT=0.4 -MAX_NUMBER_ATTEMPTS=5 - -iteration=0 - -while true; do - - iteration=$((iteration+1)) - load_time=$(/usr/bin/time -q -f "%e" $VERDI 2>&1 > /dev/null) - - if (( $(echo "$load_time < $LOAD_LIMIT" | bc -l) )); then - echo "SUCCESS: loading time $load_time at iteration $iteration below $LOAD_LIMIT" - break - else - echo "WARNING: loading time $load_time at iteration $iteration above $LOAD_LIMIT" - - if [ $iteration -eq $MAX_NUMBER_ATTEMPTS ]; then - echo "ERROR: loading time exceeded the load limit $iteration consecutive times." - echo "ERROR: please check that 'aiida.cmdline' does not import 'aiida.orm' at module level, even indirectly" - echo "ERROR: also, the database backend environment should not be loaded." - exit 2 - fi - fi - -done diff --git a/.github/system_tests/workchains.py b/.github/system_tests/workchains.py index af2ef91c4f..34b29af58a 100644 --- a/.github/system_tests/workchains.py +++ b/.github/system_tests/workchains.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. 
# @@ -7,8 +6,8 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -# pylint: disable=invalid-name """Work chain implementations for testing purposes.""" + from aiida.common import AttributeDict from aiida.engine import ( BaseRestartWorkChain, @@ -64,15 +63,15 @@ def setup(self): def sanity_check_not_too_big(self, node): """My puny brain cannot deal with numbers that I cannot count on my hand.""" if node.is_finished_ok and node.outputs.sum > 10: - return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) # pylint: disable=no-member + return ProcessHandlerReport(True, self.exit_codes.ERROR_TOO_BIG) @process_handler(priority=460, enabled=False) - def disabled_handler(self, node): # pylint: disable=unused-argument + def disabled_handler(self, node): """By default this is not enabled and so should never be called, irrespective of exit codes of sub process.""" - return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) # pylint: disable=no-member + return ProcessHandlerReport(True, self.exit_codes.ERROR_ENABLED_DOOM) @process_handler(priority=450, exit_codes=ExitCode(1000, 'Unicorn encountered')) - def a_magic_unicorn_appeared(self, node): # pylint: disable=no-self-argument + def a_magic_unicorn_appeared(self, node): """As we all know unicorns do not exist so we should never have to deal with it.""" raise RuntimeError('this handler should never even have been called') @@ -85,9 +84,7 @@ def error_negative_sum(self, node): class NestedWorkChain(WorkChain): - """ - Nested workchain which creates a workflow where the nesting level is equal to its input. - """ + """Nested workchain which creates a workflow where the nesting level is equal to its input.""" @classmethod def define(cls, spec): @@ -216,9 +213,7 @@ def do_test(self): class CalcFunctionRunnerWorkChain(WorkChain): - """ - WorkChain which calls an InlineCalculation in its step. 
- """ + """WorkChain which calls an InlineCalculation in its step.""" @classmethod def define(cls, spec): @@ -234,9 +229,7 @@ def do_run(self): class WorkFunctionRunnerWorkChain(WorkChain): - """ - WorkChain which calls a workfunction in its step - """ + """WorkChain which calls a workfunction in its step""" @classmethod def define(cls, spec): diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 6c198943fa..c526293448 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -3,13 +3,16 @@ name: Performance benchmarks on: push: branches: [main] - paths-ignore: ['docs/**'] + paths-ignore: [docs/**] + pull_request: + branches-ignore: [gh-pages] + paths: [.github/workflows/benchmark*] # https://docs.github.com/en/actions/using-jobs/using-concurrency concurrency: # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: @@ -23,14 +26,14 @@ jobs: matrix: os: [ubuntu-22.04] postgres: ['12.14'] - rabbitmq: ['3.8.14-management'] + rabbitmq: [3.8.14-management] runs-on: ${{ matrix.os }} timeout-minutes: 60 services: postgres: - image: "postgres:${{ matrix.postgres }}" + image: postgres:${{ matrix.postgres }} env: POSTGRES_DB: test_aiida POSTGRES_PASSWORD: '' @@ -41,46 +44,37 @@ jobs: --health-timeout 5s --health-retries 5 ports: - - 5432:5432 + - 5432:5432 rabbitmq: - image: "rabbitmq:${{ matrix.rabbitmq }}" + image: rabbitmq:${{ matrix.rabbitmq }} ports: - - 5672:5672 + - 5672:5672 steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 + - uses: actions/checkout@v4 + + - name: Install aiida-core + uses: ./.github/actions/install-aiida-core with: python-version: '3.10' - - - name: Upgrade pip - run: | - pip install --upgrade pip - pip --version - - - name: Install python dependencies - run: | - pip install -r requirements/requirements-py-3.10.txt - pip install --no-deps -e . 
- pip freeze + from-lock: 'true' - name: Run benchmarks - run: pytest --benchmark-only --benchmark-json benchmark.json + run: uv run pytest --db-backend psql --benchmark-only --benchmark-json benchmark.json tests/ - name: Store benchmark result uses: aiidateam/github-action-benchmark@v3 with: - benchmark-data-dir-path: "dev/bench/${{ matrix.os }}/psql_dos" - name: "pytest-benchmarks:${{ matrix.os }},psql_dos" - metadata: "postgres:${{ matrix.postgres }}, rabbitmq:${{ matrix.rabbitmq }}" + benchmark-data-dir-path: dev/bench/${{ matrix.os }}/psql_dos + name: pytest-benchmarks:${{ matrix.os }},psql_dos + metadata: postgres:${{ matrix.postgres }}, rabbitmq:${{ matrix.rabbitmq }} output-file-path: benchmark.json render-json-path: .github/workflows/benchmark-config.json - commit-msg-append: "[ci skip]" + commit-msg-append: '[ci skip]' github-token: ${{ secrets.GITHUB_TOKEN }} auto-push: true # Show alert with commit comment on detecting possible performance regression - alert-threshold: '200%' + alert-threshold: 200% comment-on-alert: true fail-on-alert: false - alert-comment-cc-users: '@chrisjsewell,@giovannipizzi' + alert-comment-cc-users: '@giovannipizzi,@agoscinski,@GeigerJ2,@khsrali,@unkcpz' diff --git a/.github/workflows/build_and_test_docker_on_pr.yml b/.github/workflows/build_and_test_docker_on_pr.yml deleted file mode 100644 index 9078daefc9..0000000000 --- a/.github/workflows/build_and_test_docker_on_pr.yml +++ /dev/null @@ -1,65 +0,0 @@ -# Test the Docker image on every pull request. -# -# The steps are: -# 1. Build docker image using cached data. -# 2. Start the docker container. -# 3. Check that AiiDA is responsive. - -name: build-and-test-image-from-pull-request - -on: - pull_request: - path_ignore: - - 'docs/**' - -# https://docs.github.com/en/actions/using-jobs/using-concurrency -concurrency: - # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - - build-and-test: - - # Only run this job on the main repository and not on forks - if: github.repository == 'aiidateam/aiida-core' - - runs-on: ubuntu-latest - timeout-minutes: 30 - - steps: - - - uses: actions/checkout@v2 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v1 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 - - - name: Cache Docker layers - uses: actions/cache@v2 - with: - path: /tmp/.buildx-cache - key: ${{ runner.os }}-buildx-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-buildx- - - - name: Build image locally - uses: docker/build-push-action@v2 - with: - load: true - push: false - tags: aiida-core:latest - cache-from: type=local,src=/tmp/.buildx-cache - cache-to: type=local,dest=/tmp/.buildx-cache - - - name: Start and test the container - run: | - export DOCKERID=`docker run -d aiida-core:latest` - docker exec --tty $DOCKERID wait-for-services - docker logs $DOCKERID - docker exec --tty --user aiida $DOCKERID /bin/bash -l -c 'verdi profile show default' - docker exec --tty --user aiida $DOCKERID /bin/bash -l -c 'verdi computer show localhost' - docker exec --tty --user aiida $DOCKERID /bin/bash -l -c 'verdi daemon status' diff --git a/.github/workflows/check_release_tag.py b/.github/workflows/check_release_tag.py index 47b45865c5..bf31b06824 100644 --- a/.github/workflows/check_release_tag.py +++ b/.github/workflows/check_release_tag.py @@ -1,5 +1,5 @@ -# -*- coding: utf-8 -*- """Check that the GitHub release tag matches the package version.""" + 
import argparse import ast from pathlib import Path @@ -11,14 +11,17 @@ def get_version_from_module(content: str) -> str: try: module = ast.parse(content) except SyntaxError as exc: - raise IOError(f'Unable to parse module: {exc}') + raise OSError(f'Unable to parse module: {exc}') try: return next( - ast.literal_eval(statement.value) for statement in module.body if isinstance(statement, ast.Assign) - for target in statement.targets if isinstance(target, ast.Name) and target.id == '__version__' + ast.literal_eval(statement.value) + for statement in module.body + if isinstance(statement, ast.Assign) + for target in statement.targets + if isinstance(target, ast.Name) and target.id == '__version__' ) except StopIteration: - raise IOError('Unable to find __version__ in module') + raise OSError('Unable to find __version__ in module') if __name__ == '__main__': @@ -27,5 +30,5 @@ def get_version_from_module(content: str) -> str: args = parser.parse_args() assert args.GITHUB_REF.startswith('refs/tags/v'), f'GITHUB_REF should start with "refs/tags/v": {args.GITHUB_REF}' tag_version = args.GITHUB_REF[11:] - pypi_version = get_version_from_module(Path('aiida/__init__.py').read_text(encoding='utf-8')) + pypi_version = get_version_from_module(Path('src/aiida/__init__.py').read_text(encoding='utf-8')) assert tag_version == pypi_version, f'The tag version {tag_version} != {pypi_version} specified in `pyproject.toml`' diff --git a/.github/workflows/ci-code.yml b/.github/workflows/ci-code.yml index b93f0a4dad..2281c39f91 100644 --- a/.github/workflows/ci-code.yml +++ b/.github/workflows/ci-code.yml @@ -1,61 +1,36 @@ -name: continuous-integration-code +name: ci-code on: push: branches-ignore: [gh-pages] pull_request: branches-ignore: [gh-pages] - paths-ignore: ['docs/**'] + paths-ignore: [docs/**] # https://docs.github.com/en/actions/using-jobs/using-concurrency concurrency: # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true -jobs: - - check-requirements: - - runs-on: ubuntu-latest - timeout-minutes: 5 - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' - - - name: Install utils/ dependencies - run: pip install -r utils/requirements.txt - - - name: Check requirements files - id: check_reqs - run: python ./utils/dependency_management.py check-requirements DEFAULT - - - name: Create commit comment - if: failure() && steps.check_reqs.outputs.error - uses: peter-evans/commit-comment@v1 - with: - path: pyproject.toml - body: | - ${{ steps.check_reqs.outputs.error }} +env: + FORCE_COLOR: 1 - Click [here](https://github.com/aiidateam/aiida-core/wiki/AiiDA-Dependency-Management) for more information on dependency management. 
+jobs: tests: - needs: [check-requirements] - - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 45 strategy: fail-fast: false matrix: - python-version: ['3.9', '3.11'] + python-version: ['3.9', '3.12'] + database-backend: [psql] + include: + - python-version: '3.9' + database-backend: sqlite services: postgres: @@ -70,82 +45,96 @@ jobs: --health-timeout 5s --health-retries 5 ports: - - 5432:5432 + - 5432:5432 rabbitmq: image: rabbitmq:3.8.14-management ports: - - 5672:5672 - - 15672:15672 + - 5672:5672 + - 15672:15672 slurm: image: xenonmiddleware/slurm:17 ports: - - 5001:22 + - 5001:22 steps: - - uses: actions/checkout@v2 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} + - uses: actions/checkout@v4 - name: Install system dependencies run: sudo apt update && sudo apt install postgresql graphviz - - name: Upgrade pip and setuptools - # Install specific version of setuptools, because 65.6.0 breaks a number of packages, such as numpy - run: | - pip install --upgrade pip - pip install setuptools==65.5.0 - pip --version - - name: Install aiida-core - run: | - pip install -r requirements/requirements-py-${{ matrix.python-version }}.txt - pip install --no-deps -e . - pip freeze + uses: ./.github/actions/install-aiida-core + with: + python-version: ${{ matrix.python-version }} - name: Setup environment - run: - .github/workflows/setup.sh + run: .github/workflows/setup.sh - name: Run test suite env: + AIIDA_TEST_PROFILE: test_aiida AIIDA_WARN_v3: 1 - SQLALCHEMY_WARN_20: 1 - run: - .github/workflows/tests.sh + # NOTE1: Python 3.12 has a performance regression when running with code coverage + # so run code coverage only for python 3.9. + # TODO: Remove a workaround for VIRTUAL_ENV once the setup-uv action is updated + # https://github.com/astral-sh/setup-uv/issues/219 + run: | + ${{ matrix.python-version == '3.9' && 'VIRTUAL_ENV=$PWD/.venv' || '' }} + pytest -n auto --db-backend ${{ matrix.database-backend }} -m 'not nightly' tests/ ${{ matrix.python-version == '3.9' && '--cov aiida' || '' }} - name: Upload coverage report if: matrix.python-version == 3.9 && github.repository == 'aiidateam/aiida-core' - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v5 with: + token: ${{ secrets.CODECOV_TOKEN }} name: aiida-pytests-py3.9 file: ./coverage.xml fail_ci_if_error: false # don't fail job, if coverage upload fails - verdi: - - runs-on: ubuntu-latest - timeout-minutes: 15 + tests-presto: - strategy: - fail-fast: false - matrix: - python-version: ['3.9', '3.11'] + runs-on: ubuntu-24.04 + timeout-minutes: 20 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - name: Install graphviz + run: sudo apt update && sudo apt install graphviz + + - name: Install aiida-core + uses: ./.github/actions/install-aiida-core with: - python-version: ${{ matrix.python-version }} + python-version: '3.11' - - name: Install python dependencies - run: pip install -e . 
+ - name: Setup SSH on localhost + run: .github/workflows/setup_ssh.sh + + - name: Run test suite + env: + AIIDA_WARN_v3: 0 + run: uv run pytest -n auto -m 'presto' tests/ + + + verdi: + + runs-on: ubuntu-24.04 + timeout-minutes: 10 + + steps: + - uses: actions/checkout@v4 + + - name: Install aiida-core + uses: ./.github/actions/install-aiida-core + with: + python-version: '3.12' + from-lock: 'true' + # NOTE: The `verdi devel check-undesired-imports` fails if + # the 'tui' extra is installed. + extras: '' - - name: Run verdi + - name: Run verdi tests run: | verdi devel check-load-time + verdi devel check-undesired-imports .github/workflows/verdi.sh diff --git a/.github/workflows/ci-style.yml b/.github/workflows/ci-style.yml deleted file mode 100644 index f606ac45e6..0000000000 --- a/.github/workflows/ci-style.yml +++ /dev/null @@ -1,39 +0,0 @@ -name: continuous-integration-style - -on: - push: - branches-ignore: [gh-pages] - pull_request: - branches-ignore: [gh-pages] - -jobs: - - pre-commit: - - runs-on: ubuntu-latest - timeout-minutes: 30 - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' - - - name: Install system dependencies - # note libkrb5-dev is required as a dependency for the gssapi pip install - run: | - sudo apt update - sudo apt install libkrb5-dev ruby ruby-dev - - - name: Install python dependencies - run: | - pip install --upgrade pip - pip install -r requirements/requirements-py-3.9.txt - pip install -e .[pre-commit] - pip freeze - - - name: Run pre-commit - run: - pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) diff --git a/.github/workflows/daemon_tests.sh b/.github/workflows/daemon_tests.sh new file mode 100755 index 0000000000..8fff830eed --- /dev/null +++ b/.github/workflows/daemon_tests.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -ev + +# Make sure the folder containing the workchains is in the python path before the daemon is started +SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" +MODULE_POLISH="${GITHUB_WORKSPACE}/.molecule/default/files/polish" + +export PYTHONPATH="${PYTHONPATH}:${SYSTEM_TESTS}:${MODULE_POLISH}" + +verdi daemon start 4 +verdi -p test_aiida run ${SYSTEM_TESTS}/test_daemon.py +verdi -p test_aiida run ${SYSTEM_TESTS}/test_containerized_code.py +bash ${SYSTEM_TESTS}/test_polish_workchains.sh +verdi daemon stop diff --git a/.github/workflows/docker-build-test.yml b/.github/workflows/docker-build-test.yml new file mode 100644 index 0000000000..64edd7bd56 --- /dev/null +++ b/.github/workflows/docker-build-test.yml @@ -0,0 +1,73 @@ +# This workflow is only meant to be run on PRs from forked repositories. +# The full workflow that is run on pushes to origin is in docker.yml +# The difference here is that we do not upload to ghcr.io, +# and thus don't need a GITHUB_TOKEN secret.
+name: Build & Test Docker Images + +env: + BUILDKIT_PROGRESS: plain + FORCE_COLOR: 1 + +on: + pull_request: + paths-ignore: + - '**.md' + - '**.txt' + - docs/** + - tests/** + +# https://docs.github.com/en/actions/using-jobs/using-concurrency +concurrency: + # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + if: ${{ github.event.pull_request.head.repo.fork }} + name: build and test amd64 images + runs-on: ubuntu-24.04 + timeout-minutes: 60 + defaults: + run: + working-directory: .docker + + steps: + + - name: Checkout Repo + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Build images + uses: docker/bake-action@v5 + with: + # Load to Docker engine for testing + load: true + workdir: .docker/ + set: | + *.platform=amd64 + *.cache-to=type=gha,scope=${{ github.workflow }},mode=min + *.cache-from=type=gha,scope=${{ github.workflow }} + files: | + docker-bake.hcl + build.json + + - name: Set Up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: pip + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Test aiida-core-base + run: pytest -s --variant aiida-core-base tests/ + + - name: Test aiida-core-with-services + run: pytest -s --variant aiida-core-with-services tests/ + + - name: Test aiida-core-dev + run: pytest -s --variant aiida-core-dev tests/ diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 0000000000..b278ec8349 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,78 @@ +name: Build Docker images and upload them to ghcr.io + +env: + BUILDKIT_PROGRESS: plain + REGISTRY: ghcr.io/ + +on: + workflow_call: + inputs: + runsOn: + description: GitHub Actions Runner image + required: true + type: string + platforms: + description: Target platforms for the build (linux/amd64 and/or linux/arm64) + required: true + type: string + outputs: + images: + description: Images identified by digests + value: ${{ jobs.build.outputs.images }} + +jobs: + build: + name: ${{ inputs.platforms }} + runs-on: ${{ inputs.runsOn }} + timeout-minutes: 60 + defaults: + run: + # Make sure we fail if any command in a piped command sequence fails + shell: bash -e -o pipefail {0} + + outputs: + images: ${{ steps.bake_metadata.outputs.images }} + + steps: + + - name: Checkout Repo ⚡️ + uses: actions/checkout@v4 + + - name: Set up QEMU + if: ${{ inputs.platforms != 'linux/amd64' }} + uses: docker/setup-qemu-action@v3 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to GitHub Container Registry 🔑 + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and upload to ghcr.io 📤 + id: build + uses: docker/bake-action@v5 + with: + push: true + workdir: .docker/ + # Using provenance to disable default attestation so it will build only desired images: + # https://github.com/orgs/community/discussions/45969 + provenance: false + set: | + *.platform=${{ inputs.platforms }} + *.output=type=registry,push-by-digest=true,name-canonical=true + *.cache-to=type=gha,scope=${{ github.workflow }},mode=max + *.cache-from=type=gha,scope=${{ github.workflow }} + files: | + docker-bake.hcl + build.json + + - name: Set output variables + id: bake_metadata + run: | + 
.github/workflows/extract-docker-image-names.sh | tee -a "${GITHUB_OUTPUT}" + env: + BAKE_METADATA: ${{ steps.build.outputs.metadata }} diff --git a/.github/workflows/docker-publish.yml b/.github/workflows/docker-publish.yml new file mode 100644 index 0000000000..53969ae2d8 --- /dev/null +++ b/.github/workflows/docker-publish.yml @@ -0,0 +1,93 @@ +name: Publish images to Docker container registries + +env: + # https://github.com/docker/metadata-action?tab=readme-ov-file#environment-variables + DOCKER_METADATA_PR_HEAD_SHA: true + +on: + workflow_call: + inputs: + runsOn: + description: GitHub Actions Runner image + required: true + type: string + images: + description: Images built in build step + required: true + type: string + registry: + description: Docker container registry + required: true + type: string + +jobs: + + release: + runs-on: ${{ inputs.runsOn }} + timeout-minutes: 30 + strategy: + fail-fast: true + matrix: + target: [aiida-core-base, aiida-core-with-services, aiida-core-dev] + defaults: + run: + # Make sure we fail if any command in a piped command sequence fails + shell: bash -e -o pipefail {0} + working-directory: .docker + + steps: + - uses: actions/checkout@v4 + + - name: Login to GitHub Container Registry 🔑 + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Login to DockerHub 🔑 + uses: docker/login-action@v3 + if: inputs.registry == 'docker.io' + with: + registry: docker.io + username: ${{ secrets.DOCKER_USERNAME }} + password: ${{ secrets.DOCKER_TOKEN }} + + - name: Read build variables + id: build_vars + run: | + vars=$(cat build.json | jq -c '[.variable | to_entries[] | {"key": .key, "value": .value.default}] | from_entries') + echo "vars=$vars" | tee -a "${GITHUB_OUTPUT}" + + - id: get-version + if: ${{ github.ref_type == 'tag' && startsWith(github.ref_name, 'v') }} + run: | + tag="${{ github.ref_name }}" + echo "AIIDA_VERSION=${tag#v}" >> $GITHUB_OUTPUT + + - name: Docker meta + id: meta + uses: docker/metadata-action@v5 + env: ${{ fromJSON(steps.build_vars.outputs.vars) }} + with: + images: ${{ inputs.registry }}/${{ github.repository_owner }}/${{ matrix.target }} + tags: | + type=ref,event=pr + type=ref,event=branch,enable=${{ github.ref_name != 'main' }} + type=edge,enable={{is_default_branch}} + type=raw,value=aiida-${{ steps.get-version.outputs.AIIDA_VERSION }},enable=${{ github.ref_type == 'tag' && startsWith(github.ref_name, 'v') }} + type=raw,value=python-${{ env.PYTHON_VERSION }},enable=${{ github.ref_type == 'tag' && startsWith(github.ref_name, 'v') }} + type=raw,value=postgresql-${{ env.PGSQL_VERSION }},enable=${{ github.ref_type == 'tag' && startsWith(github.ref_name, 'v') }} + type=match,pattern=v(\d{4}\.\d{4}(-.+)?),group=1 + + - name: Determine source image + id: images + run: | + src=$(echo '${{ inputs.images }}'| jq -cr '.[("${{ matrix.target }}"|ascii_upcase|sub("-"; "_"; "g")) + "_IMAGE"]') + echo "src=$src" | tee -a "${GITHUB_OUTPUT}" + + - name: Push image + uses: akhilerm/tag-push-action@v2.2.0 + with: + src: ${{ steps.images.outputs.src }} + dst: ${{ steps.meta.outputs.tags }} diff --git a/.github/workflows/docker-test.yml b/.github/workflows/docker-test.yml new file mode 100644 index 0000000000..652c459102 --- /dev/null +++ b/.github/workflows/docker-test.yml @@ -0,0 +1,52 @@ +name: Test newly built images + +on: + workflow_call: + inputs: + runsOn: + description: GitHub Actions Runner image + required: true + type: string + images: + description: 
Images built in build step + required: true + type: string + target: + description: Target image for testing + required: true + type: string + +jobs: + + test: + runs-on: ${{ inputs.runsOn }} + timeout-minutes: 20 + defaults: + run: + working-directory: .docker + + steps: + + - name: Checkout Repo ⚡️ + uses: actions/checkout@v4 + + - name: Login to GitHub Container Registry 🔑 + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Set Up Python 🐍 + if: ${{ startsWith(inputs.runsOn, 'ubuntu') }} + uses: actions/setup-python@v5 + with: + python-version: '3.11' + cache: pip + + - name: Install dependencies 📦 + run: pip install -r requirements.txt + + - name: Run tests + run: pytest -s --variant ${{ inputs.target }} tests/ + env: ${{ fromJSON(inputs.images) }} diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml new file mode 100644 index 0000000000..60b139f298 --- /dev/null +++ b/.github/workflows/docker.yml @@ -0,0 +1,89 @@ +name: Docker Images + +# This workflow needs permissions to publish images to ghcr.io, +# so it does not work for forks. Therefore, we only trigger it +# on pushes to aiidateam/aiida-core repo. +on: + push: + branches: + - '*' + tags: + - v* + paths-ignore: + - '**.md' + - '**.txt' + - docs/** + - tests/** + workflow_dispatch: + +# https://docs.github.com/en/actions/using-jobs/using-concurrency +concurrency: + # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 1 + +jobs: + # We build only amd64 first to catch failures faster. + build-amd64: + if: ${{ github.repository == 'aiidateam/aiida-core' }} + uses: ./.github/workflows/docker-build.yml + with: + runsOn: ubuntu-22.04 + platforms: linux/amd64 + + test-amd64: + needs: build-amd64 + uses: ./.github/workflows/docker-test.yml + strategy: + matrix: + target: [aiida-core-base, aiida-core-with-services, aiida-core-dev] + fail-fast: false + with: + runsOn: ubuntu-22.04 + images: ${{ needs.build-amd64.outputs.images }} + target: ${{ matrix.target }} + + build: + needs: test-amd64 + uses: ./.github/workflows/docker-build.yml + with: + runsOn: ubuntu-22.04 + platforms: linux/amd64,linux/arm64 + + publish-ghcr: + needs: [build, test-amd64] + uses: ./.github/workflows/docker-publish.yml + secrets: inherit + with: + runsOn: ubuntu-22.04 + images: ${{ needs.build.outputs.images }} + registry: ghcr.io + + # IMPORTANT: To save arm64 runners resources, + # we run the test only when pushing to main. + # We also only test the aiida-core-dev image. 
+ test-arm64: + needs: build + if: >- + github.repository == 'aiidateam/aiida-core' + && (github.ref_type == 'tag' || github.ref_name == 'main') + uses: ./.github/workflows/docker-test.yml + with: + runsOn: buildjet-4vcpu-ubuntu-2204-arm + images: ${{ needs.build.outputs.images }} + target: aiida-core-dev + + publish-dockerhub: + if: >- + github.repository == 'aiidateam/aiida-core' + && (github.ref_type == 'tag' || github.ref_name == 'main') + needs: [build, test-amd64, test-arm64, publish-ghcr] + uses: ./.github/workflows/docker-publish.yml + secrets: inherit + with: + runsOn: ubuntu-22.04 + images: ${{ needs.build.outputs.images }} + registry: docker.io diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 8e3704e640..a380c91f85 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -5,25 +5,28 @@ on: branches-ignore: [gh-pages] pull_request: branches-ignore: [gh-pages] - paths: ['docs/**'] + paths: [docs/**] + +env: + FORCE_COLOR: 1 jobs: docs-linkcheck: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 30 steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + + - name: Install aiida-core and docs deps + uses: ./.github/actions/install-aiida-core with: python-version: '3.9' - - name: Install python dependencies - run: | - pip install --upgrade pip - pip install -e .[docs,tests,rest,atomic_tools] + extras: '[docs,tests,rest,atomic_tools]' + from-lock: 'false' + - name: Build HTML docs id: linkcheck run: | diff --git a/.github/workflows/extract-docker-image-names.sh b/.github/workflows/extract-docker-image-names.sh new file mode 100755 index 0000000000..e395432ddb --- /dev/null +++ b/.github/workflows/extract-docker-image-names.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +set -euo pipefail + +# Extract image names together with their sha256 digests +# from the docker/bake-action metadata output. +# These together uniquely identify newly built images. +# +# The input to this script is a JSON string passed via BAKE_METADATA env variable +# Here's example input (trimmed to relevant bits): +# BAKE_METADATA: { +# "aiida-core-base": { +# # ... +# "containerimage.descriptor": { +# "mediaType": "application/vnd.docker.distribution.manifest.v2+json", +# "digest": "sha256:8e57a52b924b67567314b8ed3c968859cad99ea13521e60bbef40457e16f391d", +# "size": 6170, +# }, +# "containerimage.digest": "sha256:8e57a52b924b67567314b8ed3c968859cad99ea13521e60bbef40457e16f391d", +# "image.name": "ghcr.io/aiidateam/aiida-core-base" +# }, +# "aiida-core-dev": { +# "containerimage.digest": "sha256:4d9be090da287fcdf2d4658bb82f78bad791ccd15dac9af594fb8306abe47e97", +# "...": ... +# "image.name": "ghcr.io/aiidateam/aiida-core-dev" +# }, +# "aiida-core-with-services": { +# "...": "" +# "containerimage.digest": "sha256:85ee91f61be1ea601591c785db038e5899d68d5fb89e07d66d9efbe8f352ee48", +# "image.name": "ghcr.io/aiidateam/aiida-core-with-services" +# }, +# "some-other-key": ... 
+# } +# +# Example output (real output is on one line): +# +# images={ +# "AIIDA_CORE_BASE_IMAGE": "ghcr.io/aiidateam/aiida-core-base@sha256:4c402a8bfd635650ad691674f8f29e7ddec5fa656fb425452067950415ee447f", +# "AIIDA_CORE_DEV_IMAGE": "ghcr.io/aiidateam/aiida-core-dev@sha256:f94c06e47f801e751f9829010b31532039b210aad2649d43205e16c08371b2ed", +# "AIIDA_CORE_WITH_SERVICES_IMAGE": "ghcr.io/aiidateam/aiida-core-with-services@sha256:bd8272f2a331af7eac3e83c44cc16d23b2e5f601a20ab4a865402659b758515e" +# } +# +# This json output is later turned to environment variables using fromJson() GHA builtin +# (e.g. AIIDA_CORE_BASE_IMAGE=ghcr.io/aiidateam/aiida-core-base@sha256:8e57a52b...) +# and these are in turn read in the docker-compose..yml files for tests. + +if [[ -z ${BAKE_METADATA-} ]];then + echo "ERROR: Environment variable BAKE_METADATA is not set!" + exit 1 +fi + +images=$(echo "${BAKE_METADATA}" | +jq -c 'to_entries | map(select(.key | startswith("aiida"))) | from_entries' | # filters out every key that does not start with aiida +jq -c '. as $base |[to_entries[] |{"key": (.key|ascii_upcase|sub("-"; "_"; "g") + "_IMAGE"), "value": [(.value."image.name"|split(",")[0]),.value."containerimage.digest"]|join("@")}] |from_entries') +echo "images=$images" diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 9ab273a6bd..787793a951 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -1,101 +1,153 @@ name: nightly on: - schedule: - - cron: '0 0 * * *' # Run every day at midnight - pull_request: - paths: - - '.github/workflows/nightly.yml' - - '.github/workflows/setup.sh' - - '.github/system_tests/test_daemon.py' - - '.molecule/default/files/**' - - 'aiida/storage/psql_dos/migrations/**' - - 'tests/storage/psql_dos/migrations/**' - workflow_dispatch: + schedule: + - cron: 0 0 * * * # Run every day at midnight + pull_request: + paths: + - .github/workflows/nightly.yml + - .github/workflows/setup.sh + - .github/system_tests/test_daemon.py + - .molecule/default/files/** + - aiida/storage/psql_dos/migrations/** + - tests/storage/psql_dos/migrations/** + workflow_dispatch: # https://docs.github.com/en/actions/using-jobs/using-concurrency concurrency: # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + FORCE_COLOR: 1 jobs: - tests: - - if: github.repository == 'aiidateam/aiida-core' # Prevent running the builds on forks as well - runs-on: ubuntu-latest - - strategy: - matrix: - python-version: ['3.10'] - - services: - postgres: - image: postgres:12 - env: - POSTGRES_DB: test_aiida - POSTGRES_PASSWORD: '' - POSTGRES_HOST_AUTH_METHOD: trust - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - rabbitmq: - image: rabbitmq:3.8.14-management - ports: - - 5672:5672 - - 15672:15672 - - steps: - - uses: actions/checkout@v2 - - uses: eWaterCycle/setup-singularity@v7 # for containerized code test - with: - singularity-version: 3.8.7 - - - name: Cache Python dependencies - uses: actions/cache@v1 - with: - path: ~/.cache/pip - key: pip-${{ matrix.python-version }}-tests-${{ hashFiles('**/setup.json') }} - restore-keys: - pip-${{ matrix.python-version }}-tests - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ 
matrix.python-version }} - - - name: Install system dependencies - run: sudo apt update && sudo apt install postgresql - - - name: Install aiida-core - id: install - run: | - pip install -r requirements/requirements-py-${{ matrix.python-version }}.txt - pip install --no-deps -e . - pip freeze - - - name: Setup environment - run: .github/workflows/setup.sh - - - name: Run tests - id: tests - run: .github/workflows/tests_nightly.sh - - - name: Slack notification - # Always run this step (otherwise it would be skipped if any of the previous steps fail) but only if the - # `install` or `tests` steps failed, and the `SLACK_WEBHOOK` is available. The latter is not the case for - # pull requests that come from forks. This is a limitation of secrets on GHA - if: always() && (steps.install.outcome == 'Failure' || steps.tests.outcome == 'Failure') && env.SLACK_WEBHOOK != null - uses: rtCamp/action-slack-notify@v2 - env: - SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} - SLACK_ICON: https://www.materialscloud.org/discover/images/0ba0a17d.aiida-logo-128.png - SLACK_CHANNEL: dev-aiida-core - SLACK_COLOR: b60205 - SLACK_TITLE: "Nightly build of `aiida-core/main` failed" - SLACK_MESSAGE: "The tests of the `nightly.yml` GHA worklow failed." + nightly-tests: + + if: github.repository == 'aiidateam/aiida-core' # Prevent running the builds on forks as well + runs-on: ubuntu-24.04 + + services: + postgres: + image: postgres:12 + env: + POSTGRES_DB: test_aiida + POSTGRES_PASSWORD: '' + POSTGRES_HOST_AUTH_METHOD: trust + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + rabbitmq: + image: rabbitmq:3.8.14-management + ports: + - 5672:5672 + - 15672:15672 + slurm: + image: xenonmiddleware/slurm:17 + ports: + - 5001:22 + + steps: + - uses: actions/checkout@v4 + + - name: Install system dependencies + run: sudo apt update && sudo apt install postgresql + + - name: Install aiida-core + id: install + uses: ./.github/actions/install-aiida-core + with: + python-version: '3.11' + from-lock: 'true' + + - name: Setup environment + run: .github/workflows/setup.sh + + - name: Run pytest nigthly tests + id: pytest-tests + env: + AIIDA_TEST_PROFILE: test_aiida + AIIDA_WARN_v3: 1 + run: | + pytest --db-backend psql -m nightly tests/ + + - name: Run daemon nightly tests + id: daemon-tests + run: .github/workflows/daemon_tests.sh + + - name: Slack notification + # Always run this step (otherwise it would be skipped if any of the previous steps fail) but only if the + # `install` or `tests` steps failed, and the `SLACK_WEBHOOK` is available. The latter is not the case for + # pull requests that come from forks. This is a limitation of secrets on GHA + if: always() && (steps.install.outcome == 'failure' || steps.pytest-tests.outcome == 'failure' || steps.daemon-tests.outcome == 'failure') && env.SLACK_WEBHOOK != null + uses: rtCamp/action-slack-notify@v2 + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} + SLACK_ICON: https://www.materialscloud.org/discover/images/0ba0a17d.aiida-logo-128.png + SLACK_CHANNEL: dev-aiida-core + SLACK_COLOR: b60205 + SLACK_TITLE: Nightly build of `aiida-core/main` failed + SLACK_MESSAGE: The tests of the `nightly.yml` GHA worklow failed. 
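Both nightly jobs select their tests through pytest markers: the job above runs `pytest --db-backend psql -m nightly tests/`, while the RabbitMQ compatibility job below selects with `-m 'requires_rmq'` (and the regular suite elsewhere in this changeset deselects with `-m 'not nightly'`). The sketch below shows how such markers are declared and applied; the marker names match the workflows, but the registration shown is illustrative only, since aiida-core registers its markers in its own pytest configuration.

```python
# Minimal illustration of pytest marker-based test selection, as used by the
# `-m nightly`, `-m 'not nightly'` and `-m 'requires_rmq'` invocations in the
# workflows above. Registration shown here is a sketch, not aiida-core's own
# configuration.

# conftest.py (hypothetical, for illustration)
def pytest_configure(config):
    config.addinivalue_line('markers', 'nightly: long-running tests executed only by the nightly workflow')
    config.addinivalue_line('markers', 'requires_rmq: tests that need a running RabbitMQ broker')


# test_example.py (hypothetical)
import pytest


@pytest.mark.nightly
def test_long_migration_chain():
    """Selected by ``pytest -m nightly``; skipped by ``pytest -m 'not nightly'``."""
    assert True


@pytest.mark.requires_rmq
def test_broker_roundtrip():
    """Selected by ``pytest -m requires_rmq`` in the RabbitMQ compatibility job."""
    assert True
```

A plain `pytest` run would collect both tests; the marker expressions let the nightly and RabbitMQ jobs each run only their slice of the suite.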
+ + + # Run a subset of test suite to ensure compatibility with latest RabbitMQ releases + rabbitmq-tests: + + runs-on: ubuntu-24.04 + timeout-minutes: 10 + + strategy: + fail-fast: false + matrix: + # Currently supported RMQ versions per: + # https://www.rabbitmq.com/docs/which-erlang#compatibility-matrix + rabbitmq-version: ['3.11', '3.12', '3.13', '4.0'] + + services: + rabbitmq: + image: rabbitmq:${{ matrix.rabbitmq-version }}-management + ports: + - 5672:5672 + - 15672:15672 + + steps: + - uses: actions/checkout@v4 + + - name: Install aiida-core + id: install + uses: ./.github/actions/install-aiida-core + with: + python-version: '3.11' + from-lock: 'true' + + - name: Setup SSH on localhost + run: source .venv/bin/activate && .github/workflows/setup_ssh.sh + + - name: Suppress RabbitMQ version warning + run: uv run verdi config set warnings.rabbitmq_version False + + - name: Run tests + id: tests + env: + AIIDA_WARN_v3: 0 + run: uv run pytest -s --db-backend sqlite -m 'requires_rmq' tests/ + + - name: Slack notification + # Always run this step (otherwise it would be skipped if any of the previous steps fail) but only if the + # `install` or `tests` steps failed, and the `SLACK_WEBHOOK` is available. The latter is not the case for + # pull requests that come from forks. This is a limitation of secrets on GHA + if: always() && (steps.install.outcome == 'failure' || steps.tests.outcome == 'failure') && env.SLACK_WEBHOOK != null + uses: rtCamp/action-slack-notify@v2 + env: + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} + SLACK_ICON: https://www.materialscloud.org/discover/images/0ba0a17d.aiida-logo-128.png + SLACK_CHANNEL: dev-aiida-core + SLACK_COLOR: b60205 + SLACK_TITLE: RabbitMQ nightly tests of `aiida-core/main` failed + SLACK_MESSAGE: The rabbitmq tests in the `nightly.yml` GHA worklow failed. diff --git a/.github/workflows/post-release.yml b/.github/workflows/post-release.yml deleted file mode 100644 index 8ec52d86e8..0000000000 --- a/.github/workflows/post-release.yml +++ /dev/null @@ -1,59 +0,0 @@ -name: post-release - -on: - release: - types: [published, edited] - -jobs: - - upload-transifex: - # Every time when a new version is released, - # upload the latest pot files to transifex services for team transilation. 
- # https://www.transifex.com/aiidateam/aiida-core/dashboard/ - - # Only run this job on the main repository and not on forks - if: github.repository == 'aiidateam/aiida-core' - - name: Upload pot files to transifex - runs-on: ubuntu-latest - timeout-minutes: 30 - - # Build doc to pot files and register them to `.tx/config` file - # Installation steps are modeled after the docs job in `ci.yml` - steps: - - uses: actions/checkout@v2 - - - name: Set up Python 3.10 - uses: actions/setup-python@v2 - with: - python-version: '3.10' - - - name: Install python dependencies - run: | - pip install -U -e .[docs,tests,rest,atomic_tools] - - - name: Build pot files - env: - READTHEDOCS: 'True' - RUN_APIDOC: 'True' - run: - make -C docs gettext - - - name: Install Transifex CLI - run: | - curl -o- https://raw.githubusercontent.com/transifex/cli/master/install.sh | bash -s -- v1.6.5 - mv tx /usr/local/bin/tx - - - name: Setting transifex configuration and upload pot files - env: - PROJECT_NAME: aiida-core - USER: ${{ secrets.TRANSIFEX_USER }} - PASSWD: ${{ secrets.TRANSIFEX_PASSWORD }} - run: | - sphinx-intl create-txconfig - sphinx-intl update-txconfig-resources --pot-dir docs/build/locale --transifex-project-name ${PROJECT_NAME} - echo $'[https://www.transifex.com]\nhostname = https://www.transifex.com\nusername = '"${USER}"$'\npassword = '"${PASSWD}"$'\n' > ~/.transifexrc - - - name: Push to transifex - run: | - tx push -t -s diff --git a/.github/workflows/push_image_to_dockerhub.yml b/.github/workflows/push_image_to_dockerhub.yml deleted file mode 100644 index 3178e78e04..0000000000 --- a/.github/workflows/push_image_to_dockerhub.yml +++ /dev/null @@ -1,54 +0,0 @@ -# Build the new Docker image on every commit to the main branch and on every new tag. -# No caching is involved for the image build. The new image is then pushed to the Docker Hub. 
- -name: build-and-push-to-dockerhub - -on: - push: - branches: - - main - tags: - - "v[0-9]+.[0-9]+.[0-9]+*" - -jobs: - - build-and-push: - - # Only run this job on the main repository and not on forks - if: github.repository == 'aiidateam/aiida-core' - - runs-on: ubuntu-latest - timeout-minutes: 30 - - steps: - - - uses: actions/checkout@v2 - - - name: Docker meta - id: meta - uses: docker/metadata-action@v3 - with: - images: ${{ github.repository }} - tags: | - type=ref,event=branch - type=semver,pattern={{version}} - - - name: Set up QEMU - uses: docker/setup-qemu-action@v1 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 - - - name: Login to DockerHub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKER_USERNAME }} - password: ${{ secrets.DOCKER_TOKEN }} - - - name: Build and push - id: docker_build - uses: docker/build-push-action@v2 - with: - push: true - platforms: linux/amd64, linux/arm64 - tags: ${{ steps.meta.outputs.tags }} diff --git a/.github/workflows/rabbitmq.yml b/.github/workflows/rabbitmq.yml deleted file mode 100644 index 4cab93b8ac..0000000000 --- a/.github/workflows/rabbitmq.yml +++ /dev/null @@ -1,71 +0,0 @@ -name: rabbitmq - -on: - push: - branches-ignore: [gh-pages] - pull_request: - branches-ignore: [gh-pages] - paths-ignore: ['docs/**'] - -# https://docs.github.com/en/actions/using-jobs/using-concurrency -concurrency: - # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - - tests: - - runs-on: ubuntu-latest - timeout-minutes: 30 - - strategy: - fail-fast: false - matrix: - rabbitmq-version: ['3.6', '3.7', '3.8'] - - services: - postgres: - image: postgres:10 - env: - POSTGRES_DB: test_aiida - POSTGRES_PASSWORD: '' - POSTGRES_HOST_AUTH_METHOD: trust - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 - rabbitmq: - image: rabbitmq:${{ matrix.rabbitmq-version }}-management - ports: - - 5672:5672 - - 15672:15672 - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' - - - name: Install system dependencies - run: sudo apt update && sudo apt install postgresql - - - name: Upgrade pip - run: | - pip install --upgrade pip - pip --version - - - name: Install aiida-core - run: | - pip install -r requirements/requirements-py-3.9.txt - pip install --no-deps -e . 
- pip freeze - - - name: Run tests - run: pytest -sv -k 'requires_rmq' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 686c4bc668..dd6806f521 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -7,7 +7,10 @@ name: release on: push: tags: - - "v[0-9]+.[0-9]+.[0-9]+*" + - v[0-9]+.[0-9]+.[0-9]+* + +env: + FORCE_COLOR: 1 jobs: @@ -15,12 +18,12 @@ jobs: # Only run this job on the main repository and not on forks if: github.repository == 'aiidateam/aiida-core' - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.9' - run: python .github/workflows/check_release_tag.py $GITHUB_REF @@ -28,77 +31,43 @@ jobs: pre-commit: needs: [check-release-tag] - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 30 steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + + - name: Install aiida-core and pre-commit + uses: ./.github/actions/install-aiida-core with: - python-version: '3.9' - - name: Install system dependencies - # note libkrb5-dev is required as a dependency for the gssapi pip install - run: | - sudo apt update - sudo apt install libkrb5-dev ruby ruby-dev - - name: Install python dependencies - run: | - pip install --upgrade pip - pip install -r requirements/requirements-py-3.9.txt - pip install -e .[pre-commit] - pip freeze + python-version: '3.11' + extras: '[pre-commit]' + from-lock: 'false' + - name: Run pre-commit run: pre-commit run --all-files || ( git status --short ; git diff ; exit 1 ) tests: needs: [check-release-tag] - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 30 services: - postgres: - image: postgres:10 - env: - POSTGRES_DB: test_aiida - POSTGRES_PASSWORD: '' - POSTGRES_HOST_AUTH_METHOD: trust - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - ports: - - 5432:5432 rabbitmq: image: rabbitmq:3.8.14-management ports: - - 5672:5672 - - 15672:15672 + - 5672:5672 + - 15672:15672 steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: '3.9' - - name: Install system dependencies - run: | - sudo apt update - sudo apt install postgresql graphviz - - - name: Upgrade pip - run: | - pip install --upgrade pip - pip --version + - uses: actions/checkout@v4 - name: Install aiida-core - run: | - pip install -r requirements/requirements-py-3.9.txt - pip install --no-deps -e . 
+ uses: ./.github/actions/install-aiida-core + - name: Run sub-set of test suite - run: pytest -sv -k 'requires_rmq' + run: pytest -s -m requires_rmq --db-backend=sqlite tests/ publish: @@ -106,13 +75,13 @@ jobs: needs: [check-release-tag, pre-commit, tests] - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - name: Checkout source - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.9' - name: install flit diff --git a/.github/workflows/setup.sh b/.github/workflows/setup.sh index 174e09b598..9f2b22f35c 100755 --- a/.github/workflows/setup.sh +++ b/.github/workflows/setup.sh @@ -1,12 +1,8 @@ #!/usr/bin/env bash set -ev -ssh-keygen -q -t rsa -b 4096 -N "" -f "${HOME}/.ssh/id_rsa" -ssh-keygen -y -f "${HOME}/.ssh/id_rsa" >> "${HOME}/.ssh/authorized_keys" -ssh-keyscan -H localhost >> "${HOME}/.ssh/known_hosts" - -# The permissions on the GitHub runner are 777 which will cause SSH to refuse the keys and cause authentication to fail -chmod 755 "${HOME}" +# Setup SSH on localhost +${GITHUB_WORKSPACE}/.github/workflows/setup_ssh.sh # Replace the placeholders in configuration files with actual values CONFIG="${GITHUB_WORKSPACE}/.github/config" @@ -23,7 +19,7 @@ verdi computer configure core.local localhost --config "${CONFIG}/localhost-conf verdi computer test localhost verdi code create core.code.installed --non-interactive --config "${CONFIG}/doubler.yaml" verdi code create core.code.installed --non-interactive --config "${CONFIG}/add.yaml" -verdi code create core.code.containerized --non-interactive --config "${CONFIG}/add-singularity.yaml" +verdi code create core.code.containerized --non-interactive --config "${CONFIG}/add-containerized.yaml" # set up slurm-ssh computer verdi computer setup --non-interactive --config "${CONFIG}/slurm-ssh.yaml" diff --git a/.github/workflows/setup_ssh.sh b/.github/workflows/setup_ssh.sh new file mode 100755 index 0000000000..a244f1e470 --- /dev/null +++ b/.github/workflows/setup_ssh.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash +set -ev + +ssh-keygen -q -t rsa -b 4096 -N "" -f "${HOME}/.ssh/id_rsa" +ssh-keygen -y -f "${HOME}/.ssh/id_rsa" >> "${HOME}/.ssh/authorized_keys" +ssh-keyscan -H localhost >> "${HOME}/.ssh/known_hosts" + +# The permissions on the GitHub runner are 777 which will cause SSH to refuse the keys and cause authentication to fail +chmod 755 "${HOME}" diff --git a/.github/workflows/test-install.yml b/.github/workflows/test-install.yml index 305daccade..1ae3de6bd9 100644 --- a/.github/workflows/test-install.yml +++ b/.github/workflows/test-install.yml @@ -3,20 +3,22 @@ name: test-install on: pull_request: paths: - - 'environment.yml' - - '**/requirements*.txt' - - 'pyproject.toml' - - 'util/dependency_management.py' - - '.github/workflows/test-install.yml' + - environment.yml + - pyproject.toml + - util/dependency_management.py + - .github/workflows/test-install.yml branches-ignore: [gh-pages] schedule: - - cron: '30 02 * * *' # nightly build + - cron: 30 02 * * * # nightly build + +env: + FORCE_COLOR: 1 # https://docs.github.com/en/actions/using-jobs/using-concurrency concurrency: - # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true + # only cancel in-progress jobs or runs for the current workflow - matches against branch & tags + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: 
@@ -24,75 +26,44 @@ jobs: # Note: The specification is also validated by the pre-commit hook. if: github.repository == 'aiidateam/aiida-core' - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 5 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - name: Set up Python 3.9 - uses: actions/setup-python@v2 + - name: Set up Python 3.11 + uses: actions/setup-python@v5 with: - python-version: '3.9' - - - name: Install utils/ dependencies - run: pip install -r utils/requirements.txt - - - name: Validate - run: | - python ./utils/dependency_management.py check-requirements - python ./utils/dependency_management.py validate-all - - resolve-pip-dependencies: - # Check whether the environments defined in the requirements/* files are - # resolvable. - # - # This job should use the planned `pip resolve` command once released: - # https://github.com/pypa/pip/issues/7819 - - needs: [validate-dependency-specification] - if: github.repository == 'aiidateam/aiida-core' - runs-on: ubuntu-latest - timeout-minutes: 5 - - strategy: - fail-fast: false - matrix: - python-version: ['3.9', '3.10', '3.11'] - - steps: - - uses: actions/checkout@v2 + python-version: '3.11' - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - name: Set up uv + uses: astral-sh/setup-uv@v5 with: - python-version: ${{ matrix.python-version }} + version: 0.5.x - - name: Upgrade pip and setuptools - # Install specific version of setuptools, because 65.6.0 breaks a number of packages, such as numpy - run: | - pip install --upgrade pip - pip install setuptools==65.5.0 - pip --version + - name: Install utils/ dependencies + run: uv pip install --system -r utils/requirements.txt - - name: Create environment from requirements file. - run: | - pip install -r requirements/requirements-py-${{ matrix.python-version }}.txt - pip freeze + - name: Validate uv lockfile + run: uv lock --check + + - name: Validate conda environment file + run: python ./utils/dependency_management.py validate-environment-yml create-conda-environment: # Verify that we can create a valid conda environment from the environment.yml file. needs: [validate-dependency-specification] if: github.repository == 'aiidateam/aiida-core' - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 5 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Conda - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: channels: conda-forge @@ -105,19 +76,19 @@ jobs: install-with-pip: if: github.repository == 'aiidateam/aiida-core' - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 15 strategy: fail-fast: false matrix: - extras: [ '', '[atomic_tools,docs,notebook,rest,tests]' ] + extras: ['', '[atomic_tools,docs,notebook,rest,tests,tui]'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.9' @@ -130,21 +101,20 @@ jobs: - name: Test importing aiida if: steps.pip_install.outcome == 'success' - run: - python -c "import aiida" + run: python -c "import aiida" install-with-conda: # Verify that we can install AiiDA with conda. 
if: github.repository == 'aiidateam/aiida-core' - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 timeout-minutes: 25 strategy: fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11', '3.12'] # Not being able to install with conda on a specific Python version is # not sufficient to fail the run, but something we want to be aware of. @@ -153,14 +123,14 @@ jobs: include: # Installing with conda without specyfing the Python version should # not fail since this is advocated as part of the user documentation. - - python-version: '' - optional: false + - python-version: '' + optional: false steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Setup Conda - uses: conda-incubator/setup-miniconda@v2 + uses: conda-incubator/setup-miniconda@v3 with: channels: conda-forge @@ -189,13 +159,13 @@ jobs: tests: needs: [install-with-pip] - runs-on: ubuntu-latest - timeout-minutes: 35 + runs-on: ubuntu-24.04 + timeout-minutes: 45 strategy: fail-fast: false matrix: - python-version: ['3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11', '3.12'] services: postgres: @@ -210,157 +180,39 @@ jobs: --health-timeout 5s --health-retries 5 ports: - - 5432:5432 + - 5432:5432 rabbitmq: image: rabbitmq:3.8.14-management ports: - - 5672:5672 - - 15672:15672 + - 5672:5672 + - 15672:15672 slurm: image: xenonmiddleware/slurm:17 ports: - - 5001:22 + - 5001:22 steps: - - uses: actions/checkout@v2 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} + - uses: actions/checkout@v4 - name: Install system dependencies run: sudo apt update && sudo apt install postgresql graphviz - - name: Upgrade pip and setuptools - # It is crucial to update `setuptools` or the installation of `pymatgen` can break - # Install specific version of setuptools, because 65.6.0 breaks a number of packages, such as numpy - run: | - pip install --upgrade pip - pip install setuptools==65.5.0 - pip --version - - name: Install aiida-core - run: | - pip install -e .[atomic_tools,docs,notebook,rest,tests] - - - run: pip freeze + uses: ./.github/actions/install-aiida-core + with: + python-version: ${{ matrix.python-version }} + extras: '[atomic_tools,docs,notebook,rest,tests,tui]' + from-lock: 'false' - name: Setup AiiDA environment - run: - .github/workflows/setup.sh + run: .github/workflows/setup.sh - name: Run test suite env: + AIIDA_TEST_PROFILE: test_aiida AIIDA_WARN_v3: 1 - SQLALCHEMY_WARN_20: 1 - run: - .github/workflows/tests.sh - - - name: Freeze test environment - run: pip freeze | sed '1d' | tee requirements-py-${{ matrix.python-version }}.txt - - # Add python-version specific requirements/ file to the requirements.txt artifact. - # This artifact can be used in the next step to automatically create a pull request - # updating the requirements (in case they are inconsistent with the pyproject.toml file). - - uses: actions/upload-artifact@v1 - with: - name: requirements.txt - path: requirements-py-${{ matrix.python-version }}.txt - -# Check whether the requirements/ files are consistent with the dependency specification in the pyproject.toml file. -# If the check fails, warn the user via a comment and try to automatically create a pull request to update the files -# (does not work on pull requests from forks). 
- - check-requirements: - - needs: tests - - runs-on: ubuntu-latest - timeout-minutes: 5 - - steps: - - uses: actions/checkout@v2 - - - name: Set up Python 3.9 - uses: actions/setup-python@v2 - with: - python-version: 3.9 - - - name: Install utils/ dependencies - run: pip install -r utils/requirements.txt - - - name: Check consistency of requirements/ files - id: check_reqs - continue-on-error: true - run: python ./utils/dependency_management.py check-requirements DEFAULT --no-github-annotate - -# -# The following steps are only executed if the consistency check failed. -# - - name: Create commit comment - if: steps.check_reqs.outcome == 'Failure' # only run if requirements/ are inconsistent - uses: peter-evans/commit-comment@v1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - path: pyproject.toml - body: | - The requirements/ files are inconsistent! - - # Check out the base branch so that we can prepare the pull request. - - name: Checkout base branch - if: steps.check_reqs.outcome == 'Failure' # only run if requirements/ are inconsistent - uses: actions/checkout@v2 - with: - ref: ${{ github.head_ref }} - clean: true - - - name: Download requirements.txt files - if: steps.check_reqs.outcome == 'Failure' # only run if requirements/ are inconsistent - uses: actions/download-artifact@v1 - with: - name: requirements.txt - path: requirements - - - name: Commit requirements files - if: steps.check_reqs.outcome == 'Failure' # only run if requirements/ are inconsistent + # TODO: Remove a workaround for VIRTUAL_ENV once the setup-uv action is updated + # https://github.com/astral-sh/setup-uv/issues/219 run: | - git add requirements/* - - - name: Create pull request for updated requirements files - if: steps.check_reqs.outcome == 'Failure' # only run if requirements/ are inconsistent - id: create_update_requirements_pr - continue-on-error: true - uses: peter-evans/create-pull-request@v3 - with: - branch: update-requirements - commit-message: "Automated update of requirements/ files." - title: "Update requirements/ files." - body: | - Update requirements files to ensure that they are consistent - with the dependencies specified in the 'pyproject.toml' file. - - Please note, that this pull request was likely created to - resolve the inconsistency for a specific dependency, however - other versions that have changed since the last update will - be included as part of this commit as well. - - Click [here](https://github.com/aiidateam/aiida-core/wiki/AiiDA-Dependency-Management) for more information. - - - name: Create PR comment on success - if: steps.create_update_requirements_pr.outcome == 'Success' - uses: peter-evans/create-or-update-comment@v1 - with: - issue-number: ${{ github.event.number }} - body: | - I automatically created a pull request (#${{ steps.create_update_requirements_pr.outputs.pr_number }}) that adapts the - requirements/ files according to the dependencies specified in the 'pyproject.toml' file. - - - name: Create PR comment on failure - if: steps.create_update_requirements_pr.outcome == 'Failure' - uses: peter-evans/create-or-update-comment@v1 - with: - issue-number: ${{ github.event.number }} - body: | - Please update the requirements/ files to ensure that they - are consistent with the dependencies specified in the 'pyproject.toml' file. 
+ ${{ matrix.python-version == '3.9' && 'VIRTUAL_ENV=$PWD/.venv' || '' }} + pytest -n auto --db-backend psql -m 'not nightly' tests/ diff --git a/.github/workflows/tests.sh b/.github/workflows/tests.sh deleted file mode 100755 index fe31c799f9..0000000000 --- a/.github/workflows/tests.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/usr/bin/env bash -set -ev - -# Make sure the folder containing the workchains is in the python path before the daemon is started -SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" - -# tests for the testing infrastructure -pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_test_manager.py -pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_ipython_magics.py -pytest --cov aiida --verbose --noconftest ${SYSTEM_TESTS}/test_profile_manager.py - -# Until the `${SYSTEM_TESTS}/pytest` tests are moved within `tests` we have to run them separately and pass in the path to the -# `conftest.py` explicitly, because otherwise it won't be able to find the fixtures it provides -AIIDA_TEST_PROFILE=test_aiida pytest --cov aiida --verbose tests/conftest.py ${SYSTEM_TESTS}/pytest - -# main aiida-core tests -AIIDA_TEST_PROFILE=test_aiida pytest --cov aiida --verbose tests -m 'not nightly' diff --git a/.github/workflows/tests_nightly.sh b/.github/workflows/tests_nightly.sh deleted file mode 100755 index 10f26f7f15..0000000000 --- a/.github/workflows/tests_nightly.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/usr/bin/env bash -set -ev - -# Make sure the folder containing the workchains is in the python path before the daemon is started -SYSTEM_TESTS="${GITHUB_WORKSPACE}/.github/system_tests" -MODULE_POLISH="${GITHUB_WORKSPACE}/.molecule/default/files/polish" - -export PYTHONPATH="${PYTHONPATH}:${SYSTEM_TESTS}:${MODULE_POLISH}" - -verdi daemon start 4 -verdi -p test_aiida run ${SYSTEM_TESTS}/test_daemon.py -verdi -p test_aiida run ${SYSTEM_TESTS}/test_containerized_code.py -bash ${SYSTEM_TESTS}/test_polish_workchains.sh -verdi daemon stop - -AIIDA_TEST_PROFILE=test_aiida pytest -v tests -m 'nightly' diff --git a/.github/workflows/verdi.sh b/.github/workflows/verdi.sh index ac9ecda4b6..1aaac1a0a4 100755 --- a/.github/workflows/verdi.sh +++ b/.github/workflows/verdi.sh @@ -1,15 +1,14 @@ #!/usr/bin/env bash -# Test the loading time of `verdi`. This is and attempt to catch changes to the imports in `aiida.cmdline` that will -# indirectly load the `aiida.orm` module which will trigger loading of the backend environment. This slows down `verdi` -# significantly, making tab-completion unusable. +# Test the loading time of `verdi`. This is an attempt to catch changes to the imports in `aiida.cmdline` that +# would slow down `verdi` invocations and make tab-completion unusable. VERDI=`which verdi` -# Typically, the loading time of `verdi` should be around ~0.2 seconds. When loading the database environment this -# tends to go towards ~0.8 seconds. Since these timings are obviously machine and environment dependent, typically these -# types of tests are fragile. But with a load limit of more than twice the ideal loading time, if exceeded, should give -# a reasonably sure indication that the loading of `verdi` is unacceptably slowed down. -LOAD_LIMIT=0.5 +# Typically, the loading time of `verdi` should be around ~0.2 seconds. +# Typically these types of tests are fragile. But with a load limit of more than twice +# the ideal loading time, if exceeded, should give a reasonably sure indication +# that the loading of `verdi` is unacceptably slowed down. 
+LOAD_LIMIT=0.4 MAX_NUMBER_ATTEMPTS=5 iteration=0 @@ -35,10 +34,6 @@ while true; do done -$VERDI devel check-load-time -$VERDI devel check-undesired-imports - - # Test that we can also run the CLI via `python -m aiida`, # that it returns a 0 exit code, and contains the expected stdout. echo "Invoking verdi via `python -m aiida`" diff --git a/.gitignore b/.gitignore index 3cf188d3f3..a4fdd01ebc 100644 --- a/.gitignore +++ b/.gitignore @@ -19,6 +19,8 @@ .tox Pipfile +.aiida + # files created by coverage .cache .pytest_cache @@ -38,3 +40,6 @@ docs/source/reference/apidoc _sandbox pplot_out/ + +# docker +docker-bake.override.json diff --git a/.molecule/default/config_local.yml b/.molecule/default/config_local.yml index 7e7eb34736..5a01fabe5f 100644 --- a/.molecule/default/config_local.yml +++ b/.molecule/default/config_local.yml @@ -3,28 +3,28 @@ scenario: create_sequence: - - create - - prepare + - create + - prepare converge_sequence: - - create - - prepare - - converge + - create + - prepare + - converge destroy_sequence: - - destroy + - destroy test_sequence: - - destroy - - create - - prepare - - converge - - verify - - destroy + - destroy + - create + - prepare + - converge + - verify + - destroy # configuration for building the isolated container driver: name: docker platforms: - name: molecule-aiida-${AIIDA_TEST_BACKEND:-psql_dos} image: molecule_tests - context: "../.." + context: ../.. command: /sbin/my_init healthcheck: test: wait-for-services @@ -52,7 +52,7 @@ provisioner: internal_poll_interval: 0.002 ssh_connection: # reduce network operations - pipelining: True + pipelining: true inventory: hosts: all: @@ -62,7 +62,7 @@ provisioner: aiida_core_dir: /aiida-core aiida_pip_cache: /home/.cache/pip venv_bin: /opt/conda/bin - ansible_python_interpreter: "{{ venv_bin }}/python" + ansible_python_interpreter: '{{ venv_bin }}/python' aiida_backend: ${AIIDA_TEST_BACKEND:-core.psql_dos} aiida_workers: ${AIIDA_TEST_WORKERS:-2} aiida_path: /tmp/.aiida_${AIIDA_TEST_BACKEND:-psql_dos} diff --git a/.molecule/default/create_docker.yml b/.molecule/default/create_docker.yml index 2bef943879..e96f898de7 100644 --- a/.molecule/default/create_docker.yml +++ b/.molecule/default/create_docker.yml @@ -4,7 +4,7 @@ hosts: localhost connection: local gather_facts: false - no_log: "{{ molecule_no_log }}" + no_log: '{{ molecule_no_log }}' vars: molecule_labels: owner: molecule @@ -12,40 +12,40 @@ - name: Discover local Docker images docker_image_info: - name: "molecule_local/{{ item.name }}" + name: molecule_local/{{ item.name }} docker_host: "{{ item.docker_host | default(lookup('env', 'DOCKER_HOST') or 'unix://var/run/docker.sock') }}" cacert_path: "{{ item.cacert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/ca.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" cert_path: "{{ item.cert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/cert.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" key_path: "{{ item.key_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/key.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" tls_verify: "{{ item.tls_verify | default(lookup('env', 'DOCKER_TLS_VERIFY')) or false }}" - with_items: "{{ molecule_yml.platforms }}" + with_items: '{{ molecule_yml.platforms }}' register: docker_images - name: Build the container image when: - - docker_images.results | map(attribute='images') | select('equalto', []) | list | count >= 0 + - docker_images.results | map(attribute='images') | select('equalto', []) | list | count >= 0 docker_image: build: - 
path: "{{ item.context | default(molecule_ephemeral_directory) }}" + path: '{{ item.context | default(molecule_ephemeral_directory) }}' dockerfile: "{{ item.dockerfile | default(molecule_scenario_directory + '/Dockerfile') }}" - pull: "{{ item.pull | default(true) }}" - network: "{{ item.network_mode | default(omit) }}" - args: "{{ item.buildargs | default(omit) }}" - name: "molecule_local/{{ item.image }}" + pull: '{{ item.pull | default(true) }}' + network: '{{ item.network_mode | default(omit) }}' + args: '{{ item.buildargs | default(omit) }}' + name: molecule_local/{{ item.image }} docker_host: "{{ item.docker_host | default(lookup('env', 'DOCKER_HOST') or 'unix://var/run/docker.sock') }}" cacert_path: "{{ item.cacert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/ca.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" cert_path: "{{ item.cert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/cert.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" key_path: "{{ item.key_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/key.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" tls_verify: "{{ item.tls_verify | default(lookup('env', 'DOCKER_TLS_VERIFY')) or false }}" - force_source: "{{ item.force | default(true) }}" + force_source: '{{ item.force | default(true) }}' source: build - with_items: "{{ molecule_yml.platforms }}" + with_items: '{{ molecule_yml.platforms }}' loop_control: - label: "molecule_local/{{ item.image }}" + label: molecule_local/{{ item.image }} no_log: false register: result until: result is not failed - retries: "{{ item.retries | default(3) }}" + retries: '{{ item.retries | default(3) }}' delay: 30 - debug: @@ -57,64 +57,64 @@ {{ command_directives_dict | default({}) | combine({ item.name: item.command | default('bash -c "while true; do sleep 10000; done"') }) }} - with_items: "{{ molecule_yml.platforms }}" + with_items: '{{ molecule_yml.platforms }}' when: item.override_command | default(true) - name: Create molecule instance(s) docker_container: - name: "{{ item.name }}" + name: '{{ item.name }}' docker_host: "{{ item.docker_host | default(lookup('env', 'DOCKER_HOST') or 'unix://var/run/docker.sock') }}" cacert_path: "{{ item.cacert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/ca.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" cert_path: "{{ item.cert_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/cert.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" key_path: "{{ item.key_path | default((lookup('env', 'DOCKER_CERT_PATH') + '/key.pem') if lookup('env', 'DOCKER_CERT_PATH') else omit) }}" tls_verify: "{{ item.tls_verify | default(lookup('env', 'DOCKER_TLS_VERIFY')) or false }}" - hostname: "{{ item.hostname | default(item.name) }}" + hostname: '{{ item.hostname | default(item.name) }}' image: "{{ item.pre_build_image | default(false) | ternary('', 'molecule_local/') }}{{ item.image }}" - pull: "{{ item.pull | default(omit) }}" - memory: "{{ item.memory | default(omit) }}" - memory_swap: "{{ item.memory_swap | default(omit) }}" + pull: '{{ item.pull | default(omit) }}' + memory: '{{ item.memory | default(omit) }}' + memory_swap: '{{ item.memory_swap | default(omit) }}' state: started recreate: false log_driver: json-file - command: "{{ (command_directives_dict | default({}))[item.name] | default(omit) }}" - user: "{{ item.user | default(omit) }}" - pid_mode: "{{ item.pid_mode | default(omit) }}" - privileged: "{{ item.privileged | default(omit) }}" - security_opts: "{{ item.security_opts | 
default(omit) }}" - devices: "{{ item.devices | default(omit) }}" - volumes: "{{ item.volumes | default(omit) }}" - tmpfs: "{{ item.tmpfs | default(omit) }}" - capabilities: "{{ item.capabilities | default(omit) }}" - sysctls: "{{ item.sysctls | default(omit) }}" - exposed_ports: "{{ item.exposed_ports | default(omit) }}" - published_ports: "{{ item.published_ports | default(omit) }}" - ulimits: "{{ item.ulimits | default(omit) }}" - networks: "{{ item.networks | default(omit) }}" - network_mode: "{{ item.network_mode | default(omit) }}" - networks_cli_compatible: "{{ item.networks_cli_compatible | default(true) }}" - purge_networks: "{{ item.purge_networks | default(omit) }}" - dns_servers: "{{ item.dns_servers | default(omit) }}" - etc_hosts: "{{ item.etc_hosts | default(omit) }}" - env: "{{ item.env | default(omit) }}" - restart_policy: "{{ item.restart_policy | default(omit) }}" - restart_retries: "{{ item.restart_retries | default(omit) }}" - tty: "{{ item.tty | default(omit) }}" - labels: "{{ molecule_labels | combine(item.labels | default({})) }}" + command: '{{ (command_directives_dict | default({}))[item.name] | default(omit) }}' + user: '{{ item.user | default(omit) }}' + pid_mode: '{{ item.pid_mode | default(omit) }}' + privileged: '{{ item.privileged | default(omit) }}' + security_opts: '{{ item.security_opts | default(omit) }}' + devices: '{{ item.devices | default(omit) }}' + volumes: '{{ item.volumes | default(omit) }}' + tmpfs: '{{ item.tmpfs | default(omit) }}' + capabilities: '{{ item.capabilities | default(omit) }}' + sysctls: '{{ item.sysctls | default(omit) }}' + exposed_ports: '{{ item.exposed_ports | default(omit) }}' + published_ports: '{{ item.published_ports | default(omit) }}' + ulimits: '{{ item.ulimits | default(omit) }}' + networks: '{{ item.networks | default(omit) }}' + network_mode: '{{ item.network_mode | default(omit) }}' + networks_cli_compatible: '{{ item.networks_cli_compatible | default(true) }}' + purge_networks: '{{ item.purge_networks | default(omit) }}' + dns_servers: '{{ item.dns_servers | default(omit) }}' + etc_hosts: '{{ item.etc_hosts | default(omit) }}' + env: '{{ item.env | default(omit) }}' + restart_policy: '{{ item.restart_policy | default(omit) }}' + restart_retries: '{{ item.restart_retries | default(omit) }}' + tty: '{{ item.tty | default(omit) }}' + labels: '{{ molecule_labels | combine(item.labels | default({})) }}' container_default_behavior: "{{ item.container_default_behavior | default('compatibility' if ansible_version.full is version_compare('2.10', '>=') else omit) }}" - healthcheck: "{{ item.healthcheck | default(omit) }}" + healthcheck: '{{ item.healthcheck | default(omit) }}' register: server - with_items: "{{ molecule_yml.platforms }}" + with_items: '{{ molecule_yml.platforms }}' loop_control: - label: "{{ item.name }}" + label: '{{ item.name }}' no_log: false async: 7200 poll: 0 - name: Wait for instance(s) creation to complete async_status: - jid: "{{ item.ansible_job_id }}" + jid: '{{ item.ansible_job_id }}' register: docker_jobs until: docker_jobs.finished retries: 300 - with_items: "{{ server.results }}" + with_items: '{{ server.results }}' no_log: false diff --git a/.molecule/default/files/polish/__init__.py b/.molecule/default/files/polish/__init__.py index 2776a55f97..c56ff0a1f8 100644 --- a/.molecule/default/files/polish/__init__.py +++ b/.molecule/default/files/polish/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), 
The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. # diff --git a/.molecule/default/files/polish/cli.py b/.molecule/default/files/polish/cli.py index 34f7ff0a5d..4ef0485b00 100755 --- a/.molecule/default/files/polish/cli.py +++ b/.molecule/default/files/polish/cli.py @@ -1,5 +1,4 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. # @@ -9,6 +8,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Command line interface to dynamically create and run a WorkChain that can evaluate a reversed polish expression.""" + import importlib import sys import time @@ -25,7 +25,7 @@ @options.CODE( type=types.CodeParamType(entry_point='core.arithmetic.add'), required=False, - help='Code to perform the add operations with. Required if -C flag is specified' + help='Code to perform the add operations with. Required if -C flag is specified', ) @click.option( '-C', @@ -33,7 +33,7 @@ is_flag=True, default=False, show_default=True, - help='Use job calculations to perform all additions' + help='Use job calculations to perform all additions', ) @click.option( '-F', @@ -41,7 +41,7 @@ is_flag=True, default=False, show_default=True, - help='Use calcfunctions to perform all substractions' + help='Use calcfunctions to perform all substractions', ) @click.option( '-s', @@ -49,7 +49,7 @@ type=click.INT, default=5, show_default=True, - help='When submitting to the daemon, the number of seconds to sleep between polling the workchain process state' + help='When submitting to the daemon, the number of seconds to sleep between polling the workchain process state', ) @click.option( '-t', @@ -57,7 +57,7 @@ type=click.INT, default=60, show_default=True, - help='When submitting to the daemon, the number of seconds to wait for a workchain to finish before timing out' + help='When submitting to the daemon, the number of seconds to wait for a workchain to finish before timing out', ) @click.option( '-m', @@ -65,19 +65,18 @@ type=click.INT, default=1000000, show_default=True, - help='Specify an integer to modulo all intermediate and the final result to avoid integer overflow' + help='Specify an integer to modulo all intermediate and the final result to avoid integer overflow', ) @click.option( '-n', '--dry-run', is_flag=True, default=False, - help='Only evaluate the expression and generate the workchain but do not launch it' + help='Only evaluate the expression and generate the workchain but do not launch it', ) @decorators.with_dbenv() def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout, modulo, dry_run, daemon): - """ - Evaluate the expression in Reverse Polish Notation in both a normal way and by procedurally generating + """Evaluate the expression in Reverse Polish Notation in both a normal way and by procedurally generating a workchain that encodes the sequence of operators and gets the stack of operands as an input. Multiplications are modelled by a 'while_' construct and addition will be done performed by an addition or a subtraction, depending on the sign, branched by the 'if_' construct. Powers will be simulated by nested workchains. 
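The docstring above describes how the generated workchain mirrors a plain Reverse Polish Notation evaluation, with `while_` for multiplications, `if_` for signed additions and nested workchains for powers. For orientation, a stack-based evaluation of the same restricted grammar (integers only; the operators `+`, `*` and `^`; an optional modulo, as in `lib/expression.py`) can be sketched as follows. This is an illustration, not the library implementation.

```python
# Rough sketch of plain RPN evaluation for the restricted grammar used by the
# polish test workchains: integers only, the operators +, * and ^, and an
# optional modulo applied to intermediate results. It mirrors what the
# generated workchain computes step by step, but is not the actual library code.
import operator

OPERATORS = {'+': operator.add, '*': operator.mul, '^': operator.pow}


def evaluate_rpn(expression, modulo=None):
    """Evaluate ``expression`` in Reverse Polish Notation and return an integer."""
    stack = []
    for symbol in expression.split():
        if symbol in OPERATORS:
            right, left = stack.pop(), stack.pop()
            result = OPERATORS[symbol](left, right)
        else:
            result = int(symbol)
        if modulo is not None:
            result %= modulo
        stack.append(result)

    if len(stack) != 1:
        raise ValueError(f'malformed expression: {expression}')

    return stack[0]


# Example: '2 3 + 4 *' corresponds to (2 + 3) * 4
assert evaluate_rpn('2 3 + 4 *') == 20
```

The CLI's `-m/--modulo` option plays the same role as the `modulo` argument here, keeping intermediate results small enough to avoid integer overflow.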
@@ -98,7 +97,6 @@ def launch(expression, code, use_calculations, use_calcfunctions, sleep, timeout If no expression is specified, a random one will be generated that adheres to these rules """ - # pylint: disable=too-many-arguments,too-many-locals,too-many-statements,too-many-branches from aiida.engine import run_get_node from aiida.orm import AbstractCode, Int, Str @@ -199,4 +197,4 @@ def run_via_daemon(workchains, inputs, sleep, timeout): if __name__ == '__main__': - launch() # pylint: disable=no-value-for-parameter + launch() diff --git a/.molecule/default/files/polish/lib/__init__.py b/.molecule/default/files/polish/lib/__init__.py index 2776a55f97..c56ff0a1f8 100644 --- a/.molecule/default/files/polish/lib/__init__.py +++ b/.molecule/default/files/polish/lib/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. # diff --git a/.molecule/default/files/polish/lib/expression.py b/.molecule/default/files/polish/lib/expression.py index 1bf2123970..369b14093e 100644 --- a/.molecule/default/files/polish/lib/expression.py +++ b/.molecule/default/files/polish/lib/expression.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. # @@ -8,6 +7,7 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Functions to dynamically generate reversed polish notation expressions.""" + import collections import operator as operators import random @@ -20,8 +20,7 @@ def generate(min_operator_count=3, max_operator_count=5, min_operand_value=-5, max_operand_value=5): - """ - Generate a random valid expression in Reverse Polish Notation. There are a few limitations: + """Generate a random valid expression in Reverse Polish Notation. There are a few limitations: * Only integers are supported * Only the addition, multiplication and power operators (+, * and ^, respectively) are supported @@ -58,8 +57,7 @@ def generate(min_operator_count=3, max_operator_count=5, min_operand_value=-5, m def validate(expression): - """ - Validate an expression in Reverse Polish Notation. In addition to normal rules, the following restrictions apply: + """Validate an expression in Reverse Polish Notation. In addition to normal rules, the following restrictions apply: * Only integers are supported * Only the addition, multiplication and power operators (+, * and ^, respectively) are supported @@ -69,7 +67,6 @@ def validate(expression): :param expression: the expression in Reverse Polish Notation :return: tuple(Bool, list) indicating whether expression is valid and if not a list of error messages """ - # pylint: disable=too-many-return-statements try: symbols = expression.split() except ValueError as exception: @@ -106,8 +103,7 @@ def validate(expression): def evaluate(expression, modulo=None): - """ - Evaluate an expression in Reverse Polish Notation. There are a few limitations: + """Evaluate an expression in Reverse Polish Notation. 
There are a few limitations: * Only integers are supported * Only the addition, multiplication and power operators (+, * and ^, respectively) are supported diff --git a/.molecule/default/files/polish/lib/workchain.py b/.molecule/default/files/polish/lib/workchain.py index a77e7f6b29..167c23a9fd 100644 --- a/.molecule/default/files/polish/lib/workchain.py +++ b/.molecule/default/files/polish/lib/workchain.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- ########################################################################### # Copyright (c), The AiiDA team. All rights reserved. # # This file is part of the AiiDA code. # @@ -15,7 +14,7 @@ from pathlib import Path from string import Template -from .expression import OPERATORS # pylint: disable=relative-beyond-top-level +from .expression import OPERATORS INDENTATION_WIDTH = 4 @@ -71,8 +70,7 @@ def generate_outlines(expression): - """ - For a given expression in Reverse Polish Notation, generate the nested symbolic structure of the outlines. + """For a given expression in Reverse Polish Notation, generate the nested symbolic structure of the outlines. :param expression: a valid expression :return: a nested list structure of strings representing the structure of the outlines @@ -82,7 +80,6 @@ def generate_outlines(expression): outline = [['add']] for part in expression.split(): - if part not in OPERATORS: stack.appendleft(part) values.append(part) @@ -107,8 +104,7 @@ def generate_outlines(expression): def format_outlines(outlines, use_calculations=False, use_calcfunctions=False): - """ - Given the symbolic structure of the workchain outlines produced by ``generate_outlines``, format the actual + """Given the symbolic structure of the workchain outlines produced by ``generate_outlines``, format the actual string form of those workchain outlines :param outlines: the list of symbolic outline structures @@ -119,7 +115,6 @@ def format_outlines(outlines, use_calculations=False, use_calcfunctions=False): outline_strings = [] for sub_outline in outlines: - outline_string = '' for instruction in sub_outline: @@ -140,8 +135,7 @@ def format_outlines(outlines, use_calculations=False, use_calcfunctions=False): def format_block(instruction, level=0, use_calculations=False, use_calcfunctions=False): - """ - Format the instruction into its proper string form + """Format the instruction into its proper string form :param use_calculations: use CalcJobs for the add operations :param use_calcfunctions: use calcfunctions for the subtract operations @@ -176,8 +170,7 @@ def format_block(instruction, level=0, use_calculations=False, use_calcfunctions def format_indent(level=0, width=INDENTATION_WIDTH): - """ - Format the indentation for the given indentation level and indentation width + """Format the indentation for the given indentation level and indentation width :param level: the level of indentation :param width: the width in spaces of a single indentation @@ -187,8 +180,7 @@ def format_indent(level=0, width=INDENTATION_WIDTH): def write_workchain(outlines, directory=None) -> Path: - """ - Given a list of string formatted outlines, write the corresponding workchains to file + """Given a list of string formatted outlines, write the corresponding workchains to file :returns: file path """ @@ -219,10 +211,9 @@ def write_workchain(outlines, directory=None) -> Path: counter = len(outlines) - 1 for outline in outlines: - outline_string = '' for subline in outline.split('\n'): - outline_string += f'\t\t\t{subline}\n' # pylint: disable=consider-using-join + outline_string 
+= f'\t\t\t{subline}\n' if counter == len(outlines) - 1: child_class = None diff --git a/.molecule/default/setup_aiida.yml b/.molecule/default/setup_aiida.yml index 2d1246f985..76b3b7a7be 100644 --- a/.molecule/default/setup_aiida.yml +++ b/.molecule/default/setup_aiida.yml @@ -4,17 +4,17 @@ # run as aiida user become: true - become_method: "{{ become_method }}" + become_method: '{{ become_method }}' become_user: "{{ aiida_user | default('aiida') }}" environment: - AIIDA_PATH: "{{ aiida_path }}" + AIIDA_PATH: '{{ aiida_path }}' tasks: - name: Create a new database with name "{{ aiida_backend }}" postgresql_db: - name: "{{ aiida_backend }}" + name: '{{ aiida_backend }}' login_host: localhost login_user: aiida login_password: '' @@ -30,7 +30,7 @@ login_host: localhost login_user: aiida login_password: '' - db: "{{ aiida_backend }}" + db: '{{ aiida_backend }}' - name: verdi setup for "{{ aiida_backend }}" command: > @@ -47,10 +47,10 @@ --db-username=aiida --db-password='' args: - creates: "{{ aiida_path }}/.aiida/config.json" + creates: '{{ aiida_path }}/.aiida/config.json' - - name: "Check if computer is already present" - command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} computer show localhost" + - name: Check if computer is already present + command: '{{ venv_bin }}/verdi -p {{ aiida_backend }} computer show localhost' ignore_errors: true changed_when: false no_log: true @@ -82,7 +82,7 @@ # command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} daemon start {{ aiida_workers }}" - name: get verdi status - command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} status" + command: '{{ venv_bin }}/verdi -p {{ aiida_backend }} status' register: verdi_status changed_when: false diff --git a/.molecule/default/setup_python.yml b/.molecule/default/setup_python.yml index 8da81bc91d..eb03064ec5 100644 --- a/.molecule/default/setup_python.yml +++ b/.molecule/default/setup_python.yml @@ -4,24 +4,24 @@ # run as root user become: true - become_method: "{{ become_method }}" + become_method: '{{ become_method }}' become_user: root tasks: - name: pip install aiida-core requirements pip: - chdir: "{{ aiida_core_dir }}" + chdir: '{{ aiida_core_dir }}' # TODO dynamically change for python version requirements: requirements/requirements-py-3.9.txt - executable: "{{ venv_bin }}/pip" + executable: '{{ venv_bin }}/pip' extra_args: --cache-dir {{ aiida_pip_cache }} register: pip_install_deps - name: pip install aiida-core pip: - chdir: "{{ aiida_core_dir }}" + chdir: '{{ aiida_core_dir }}' name: . 
- executable: "{{ venv_bin }}/pip" - editable: "{{ aiida_pip_editable | default(true) }}" + executable: '{{ venv_bin }}/pip' + editable: '{{ aiida_pip_editable | default(true) }}' extra_args: --no-deps diff --git a/.molecule/default/tasks/log_query_stats.yml b/.molecule/default/tasks/log_query_stats.yml index b62c53e2d5..8c437e5170 100644 --- a/.molecule/default/tasks/log_query_stats.yml +++ b/.molecule/default/tasks/log_query_stats.yml @@ -3,7 +3,7 @@ login_host: localhost login_user: "{{ aiida_user | default('aiida') }}" login_password: '' - db: "{{ aiida_backend }}" + db: '{{ aiida_backend }}' query: | SELECT CAST(sum(calls) AS INTEGER) as calls, @@ -21,7 +21,7 @@ login_host: localhost login_user: "{{ aiida_user | default('aiida') }}" login_password: '' - db: "{{ aiida_backend }}" + db: '{{ aiida_backend }}' query: | SELECT to_char(total_time, '9.99EEEE') AS time_ms, @@ -42,7 +42,7 @@ login_host: localhost login_user: "{{ aiida_user | default('aiida') }}" login_password: '' - db: "{{ aiida_backend }}" + db: '{{ aiida_backend }}' query: | SELECT to_char(total_time, '9.99EEEE') AS time_ms, diff --git a/.molecule/default/tasks/reset_query_stats.yml b/.molecule/default/tasks/reset_query_stats.yml index 44fd9e3827..c3a9f92418 100644 --- a/.molecule/default/tasks/reset_query_stats.yml +++ b/.molecule/default/tasks/reset_query_stats.yml @@ -3,5 +3,5 @@ login_host: localhost login_user: "{{ aiida_user | default('aiida') }}" login_password: '' - db: "{{ aiida_backend }}" + db: '{{ aiida_backend }}' query: SELECT pg_stat_statements_reset(); diff --git a/.molecule/default/test_polish_workchains.yml b/.molecule/default/test_polish_workchains.yml index b1649d1953..cabcf5596f 100644 --- a/.molecule/default/test_polish_workchains.yml +++ b/.molecule/default/test_polish_workchains.yml @@ -4,16 +4,16 @@ # run as aiida user become: true - become_method: "{{ become_method }}" + become_method: '{{ become_method }}' become_user: "{{ aiida_user | default('aiida') }}" environment: - AIIDA_PATH: "{{ aiida_path }}" + AIIDA_PATH: '{{ aiida_path }}' tasks: - - name: "Check if add code is already present" - command: "{{ venv_bin }}/verdi -p {{ aiida_backend }} code show add@localhost" + - name: Check if add code is already present + command: '{{ venv_bin }}/verdi -p {{ aiida_backend }} code show add@localhost' ignore_errors: true changed_when: false no_log: true @@ -30,14 +30,14 @@ - name: Copy workchain files copy: src: polish - dest: "${HOME}/{{ aiida_backend }}" + dest: ${HOME}/{{ aiida_backend }} - name: get python path including workchains command: echo "${PYTHONPATH}:${HOME}/{{ aiida_backend }}/polish" register: echo_pythonpath - set_fact: - aiida_pythonpath: "{{ echo_pythonpath.stdout }}" + aiida_pythonpath: '{{ echo_pythonpath.stdout }}' - name: Reset pythonpath of daemon ({{ aiida_workers }} workers) # note `verdi daemon restart` did not seem to update the environmental variables? 
@@ -45,12 +45,12 @@ {{ venv_bin }}/verdi -p {{ aiida_backend }} daemon stop {{ venv_bin }}/verdi -p {{ aiida_backend }} daemon start {{ aiida_workers }} environment: - PYTHONPATH: "{{ aiida_pythonpath }}" + PYTHONPATH: '{{ aiida_pythonpath }}' - when: aiida_query_stats | default(false) | bool include_tasks: tasks/reset_query_stats.yml - - name: "run polish workchains" + - name: run polish workchains # Note the exclamation point after the code is necessary to force the value to be interpreted as LABEL type identifier shell: | set -e @@ -61,22 +61,22 @@ args: executable: /bin/bash vars: - polish_script: "${HOME}/{{ aiida_backend }}/polish/cli.py" + polish_script: ${HOME}/{{ aiida_backend }}/polish/cli.py polish_timeout: 600 polish_expressions: - - "1 -2 -1 4 -5 -5 * * * * +" - - "2 1 3 3 -1 + ^ ^ +" - - "3 -5 -1 -4 + * ^" - - "2 4 2 -4 * * +" - - "3 1 1 5 ^ ^ ^" + - 1 -2 -1 4 -5 -5 * * * * + + - 2 1 3 3 -1 + ^ ^ + + - 3 -5 -1 -4 + * ^ + - 2 4 2 -4 * * + + - 3 1 1 5 ^ ^ ^ # - "3 1 3 4 -4 2 * + + ^ ^" # this takes a longer time to run environment: - PYTHONPATH: "{{ aiida_pythonpath }}" + PYTHONPATH: '{{ aiida_pythonpath }}' register: polish_output - name: print polish workchain output debug: - msg: "{{ polish_output.stdout }}" + msg: '{{ polish_output.stdout }}' - when: aiida_query_stats | default(false) | bool include_tasks: tasks/log_query_stats.yml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b21190151f..fb2019ac26 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,241 +1,236 @@ ci: - autoupdate_schedule: monthly - autofix_prs: true - skip: [mypy, pylint, dm-generate-all, dependencies, verdi-autodocs] + autofix_prs: true + autoupdate_commit_msg: 'Devops: Update pre-commit dependencies' + autoupdate_schedule: quarterly + skip: [mypy, check-uv-lock, generate-conda-environment, validate-conda-environment, verdi-autodocs] repos: -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: double-quote-string-fixer - - id: end-of-file-fixer - exclude: &exclude_pre_commit_hooks > - (?x)^( - tests/.*(? 
- (?x)^( - docs/.*| - )$ - args: ['-i'] - additional_dependencies: ['toml'] - -- repo: local - - hooks: - - - id: imports - name: imports - entry: python utils/make_all.py - language: python - types: [python] - require_serial: true - pass_filenames: false - files: aiida/.*py - - - id: mypy - name: mypy - entry: mypy - args: [--config-file=pyproject.toml] - language: python - types: [python] - require_serial: true - pass_filenames: true - exclude: >- - (?x)^( - .github/.*| - .molecule/.*| - docs/.*| - utils/.*| - - aiida/calculations/arithmetic/add.py| - aiida/calculations/diff_tutorial/calculations.py| - aiida/calculations/templatereplacer.py| - aiida/calculations/transfer.py| - aiida/cmdline/commands/cmd_archive.py| - aiida/cmdline/commands/cmd_calcjob.py| - aiida/cmdline/commands/cmd_code.py| - aiida/cmdline/commands/cmd_computer.py| - aiida/cmdline/commands/cmd_data/cmd_list.py| - aiida/cmdline/commands/cmd_data/cmd_upf.py| - aiida/cmdline/commands/cmd_devel.py| - aiida/cmdline/commands/cmd_group.py| - aiida/cmdline/commands/cmd_node.py| - aiida/cmdline/commands/cmd_shell.py| - aiida/cmdline/commands/cmd_storage.py| - aiida/cmdline/groups/dynamic.py| - aiida/cmdline/params/options/commands/setup.py| - aiida/cmdline/params/options/interactive.py| - aiida/cmdline/params/options/main.py| - aiida/cmdline/params/options/multivalue.py| - aiida/cmdline/params/types/group.py| - aiida/cmdline/params/types/plugin.py| - aiida/cmdline/utils/ascii_vis.py| - aiida/cmdline/utils/common.py| - aiida/cmdline/utils/echo.py| - aiida/common/extendeddicts.py| - aiida/common/hashing.py| - aiida/common/utils.py| - aiida/engine/daemon/execmanager.py| - aiida/engine/processes/calcjobs/manager.py| - aiida/engine/processes/calcjobs/monitors.py| - aiida/engine/processes/calcjobs/tasks.py| - aiida/engine/processes/control.py| - aiida/engine/processes/ports.py| - aiida/manage/configuration/__init__.py| - aiida/manage/configuration/config.py| - aiida/manage/configuration/profile.py| - aiida/manage/configuration/settings.py| - aiida/manage/external/rmq/launcher.py| - aiida/manage/tests/main.py| - aiida/manage/tests/pytest_fixtures.py| - aiida/orm/comments.py| - aiida/orm/computers.py| - aiida/orm/implementation/storage_backend.py| - aiida/orm/nodes/caching.py| - aiida/orm/nodes/comments.py| - aiida/orm/nodes/data/array/array.py| - aiida/orm/nodes/data/array/bands.py| - aiida/orm/nodes/data/array/trajectory.py| - aiida/orm/nodes/data/cif.py| - aiida/orm/nodes/data/remote/base.py| - aiida/orm/nodes/data/structure.py| - aiida/orm/nodes/data/upf.py| - aiida/orm/nodes/process/calculation/calcjob.py| - aiida/orm/nodes/process/process.py| - aiida/orm/utils/builders/code.py| - aiida/orm/utils/builders/computer.py| - aiida/orm/utils/calcjob.py| - aiida/orm/utils/node.py| - aiida/orm/utils/remote.py| - aiida/repository/backend/disk_object_store.py| - aiida/repository/backend/sandbox.py| - aiida/restapi/common/utils.py| - aiida/restapi/resources.py| - aiida/restapi/run_api.py| - aiida/restapi/translator/base.py| - aiida/restapi/translator/computer.py| - aiida/restapi/translator/group.py| - aiida/restapi/translator/nodes/.*| - aiida/restapi/translator/user.py| - aiida/schedulers/plugins/direct.py| - aiida/schedulers/plugins/lsf.py| - aiida/schedulers/plugins/pbsbaseclasses.py| - aiida/schedulers/plugins/sge.py| - aiida/schedulers/plugins/slurm.py| - aiida/storage/psql_dos/migrations/utils/integrity.py| - aiida/storage/psql_dos/migrations/utils/legacy_workflows.py| - aiida/storage/psql_dos/migrations/utils/migrate_repository.py| - 
aiida/storage/psql_dos/migrations/utils/parity.py| - aiida/storage/psql_dos/migrations/utils/reflect.py| - aiida/storage/psql_dos/migrations/utils/utils.py| - aiida/storage/psql_dos/migrations/versions/1de112340b16_django_parity_1.py| - aiida/storage/psql_dos/migrator.py| - aiida/storage/psql_dos/models/.*| - aiida/storage/psql_dos/orm/.*| - aiida/storage/sqlite_temp/backend.py| - aiida/storage/sqlite_zip/backend.py| - aiida/storage/sqlite_zip/migrations/legacy_to_main.py| - aiida/storage/sqlite_zip/migrator.py| - aiida/storage/sqlite_zip/models.py| - aiida/storage/sqlite_zip/orm.py| - aiida/tools/data/array/kpoints/legacy.py| - aiida/tools/data/array/kpoints/seekpath.py| - aiida/tools/data/orbital/orbital.py| - aiida/tools/data/orbital/realhydrogen.py| - aiida/tools/dbimporters/plugins/.*| - aiida/tools/graph/age_entities.py| - aiida/tools/graph/age_rules.py| - aiida/tools/graph/deletions.py| - aiida/tools/graph/graph_traversers.py| - aiida/tools/groups/paths.py| - aiida/tools/query/calculation.py| - aiida/tools/query/mapping.py| - aiida/transports/cli.py| - aiida/transports/plugins/local.py| - aiida/transports/plugins/ssh.py| - aiida/workflows/arithmetic/multiply_add.py| - - tests/conftest.py| - tests/repository/conftest.py| - tests/repository/test_repository.py| - tests/sphinxext/sources/workchain/conf.py| - tests/sphinxext/sources/workchain_broken/conf.py| - tests/storage/psql_dos/migrations/conftest.py| - tests/storage/psql_dos/migrations/django_branch/test_0026_0027_traj_data.py| - tests/test_calculation_node.py| - tests/test_nodes.py| - - )$ - - - id: pylint - name: pylint - entry: pylint - types: [python] - language: system - exclude: *exclude_files - - - id: dm-generate-all - name: Update all requirements files - entry: python ./utils/dependency_management.py generate-all - language: system - pass_filenames: false - files: >- - (?x)^( - pyproject.toml| - utils/dependency_management.py - )$ - - - id: dependencies - name: Validate environment.yml - entry: python ./utils/dependency_management.py validate-environment-yml - language: system - pass_filenames: false - files: >- - (?x)^( - pyproject.toml| - utils/dependency_management.py| - environment.yml| - )$ - - - id: verdi-autodocs - name: Automatically generating verdi docs - entry: python ./utils/validate_consistency.py verdi-autodocs - language: system - pass_filenames: false - files: >- - (?x)^( - aiida/cmdline/commands/.*| - aiida/cmdline/params/.*| - aiida/cmdline/params/types/.*| - utils/validate_consistency.py| - )$ +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-merge-conflict + - id: check-yaml + - id: double-quote-string-fixer + - id: end-of-file-fixer + exclude: &exclude_pre_commit_hooks > + (?x)^( + tests/.*(?- + (?x)^( + tests/.*| + environment.yml| + )$ + +- repo: local + + hooks: + + - id: imports + name: imports + entry: python utils/make_all.py + language: python + types: [python] + require_serial: true + pass_filenames: false + files: src/aiida/.*py + + - id: mypy + name: mypy + entry: mypy + args: [--config-file=pyproject.toml] + language: python + types: [python] + require_serial: true + pass_filenames: true + exclude: >- + (?x)^( + .github/.*| + .molecule/.*| + .docker/.*| + docs/.*| + utils/.*| + tests/.*| + + src/aiida/calculations/arithmetic/add.py| + src/aiida/calculations/diff_tutorial/calculations.py| + src/aiida/calculations/templatereplacer.py| + src/aiida/calculations/transfer.py| + src/aiida/cmdline/commands/cmd_archive.py| + 
src/aiida/cmdline/commands/cmd_calcjob.py| + src/aiida/cmdline/commands/cmd_code.py| + src/aiida/cmdline/commands/cmd_computer.py| + src/aiida/cmdline/commands/cmd_data/cmd_list.py| + src/aiida/cmdline/commands/cmd_data/cmd_upf.py| + src/aiida/cmdline/commands/cmd_devel.py| + src/aiida/cmdline/commands/cmd_group.py| + src/aiida/cmdline/commands/cmd_node.py| + src/aiida/cmdline/commands/cmd_shell.py| + src/aiida/cmdline/commands/cmd_storage.py| + src/aiida/cmdline/params/options/commands/setup.py| + src/aiida/cmdline/params/options/interactive.py| + src/aiida/cmdline/params/options/main.py| + src/aiida/cmdline/params/options/multivalue.py| + src/aiida/cmdline/params/types/group.py| + src/aiida/cmdline/utils/ascii_vis.py| + src/aiida/cmdline/utils/common.py| + src/aiida/cmdline/utils/echo.py| + src/aiida/common/extendeddicts.py| + src/aiida/common/utils.py| + src/aiida/engine/daemon/execmanager.py| + src/aiida/engine/processes/calcjobs/manager.py| + src/aiida/engine/processes/calcjobs/monitors.py| + src/aiida/engine/processes/calcjobs/tasks.py| + src/aiida/engine/processes/control.py| + src/aiida/engine/processes/ports.py| + src/aiida/manage/configuration/__init__.py| + src/aiida/manage/configuration/config.py| + src/aiida/manage/external/rmq/launcher.py| + src/aiida/manage/tests/main.py| + src/aiida/manage/tests/pytest_fixtures.py| + src/aiida/orm/comments.py| + src/aiida/orm/computers.py| + src/aiida/orm/implementation/storage_backend.py| + src/aiida/orm/nodes/comments.py| + src/aiida/orm/nodes/data/array/bands.py| + src/aiida/orm/nodes/data/array/trajectory.py| + src/aiida/orm/nodes/data/cif.py| + src/aiida/orm/nodes/data/remote/base.py| + src/aiida/orm/nodes/data/structure.py| + src/aiida/orm/nodes/data/upf.py| + src/aiida/orm/nodes/process/calculation/calcjob.py| + src/aiida/orm/nodes/process/process.py| + src/aiida/orm/utils/builders/code.py| + src/aiida/orm/utils/builders/computer.py| + src/aiida/orm/utils/calcjob.py| + src/aiida/orm/utils/node.py| + src/aiida/repository/backend/disk_object_store.py| + src/aiida/repository/backend/sandbox.py| + src/aiida/restapi/common/utils.py| + src/aiida/restapi/resources.py| + src/aiida/restapi/run_api.py| + src/aiida/restapi/translator/base.py| + src/aiida/restapi/translator/computer.py| + src/aiida/restapi/translator/group.py| + src/aiida/restapi/translator/nodes/.*| + src/aiida/restapi/translator/user.py| + src/aiida/schedulers/plugins/direct.py| + src/aiida/schedulers/plugins/lsf.py| + src/aiida/schedulers/plugins/pbsbaseclasses.py| + src/aiida/schedulers/plugins/sge.py| + src/aiida/schedulers/plugins/slurm.py| + src/aiida/storage/psql_dos/migrations/utils/integrity.py| + src/aiida/storage/psql_dos/migrations/utils/legacy_workflows.py| + src/aiida/storage/psql_dos/migrations/utils/migrate_repository.py| + src/aiida/storage/psql_dos/migrations/utils/parity.py| + src/aiida/storage/psql_dos/migrations/utils/reflect.py| + src/aiida/storage/psql_dos/migrations/utils/utils.py| + src/aiida/storage/psql_dos/migrations/versions/1de112340b16_django_parity_1.py| + src/aiida/storage/psql_dos/migrator.py| + src/aiida/storage/psql_dos/models/.*| + src/aiida/storage/psql_dos/orm/.*| + src/aiida/storage/sqlite_temp/backend.py| + src/aiida/storage/sqlite_zip/backend.py| + src/aiida/storage/sqlite_zip/migrations/legacy_to_main.py| + src/aiida/storage/sqlite_zip/migrator.py| + src/aiida/storage/sqlite_zip/models.py| + src/aiida/storage/sqlite_zip/orm.py| + src/aiida/tools/data/array/kpoints/legacy.py| + src/aiida/tools/data/array/kpoints/seekpath.py| + 
src/aiida/tools/data/orbital/orbital.py| + src/aiida/tools/data/orbital/realhydrogen.py| + src/aiida/tools/dbimporters/plugins/.*| + src/aiida/tools/graph/age_entities.py| + src/aiida/tools/graph/age_rules.py| + src/aiida/tools/graph/deletions.py| + src/aiida/tools/graph/graph_traversers.py| + src/aiida/tools/groups/paths.py| + src/aiida/tools/query/calculation.py| + src/aiida/tools/query/mapping.py| + src/aiida/transports/cli.py| + src/aiida/transports/plugins/local.py| + src/aiida/transports/plugins/ssh.py| + src/aiida/workflows/arithmetic/multiply_add.py| + )$ + + - id: check-uv-lock + name: Check uv lockfile up to date + # NOTE: This will not automatically update the lockfile + entry: uv lock --check + language: system + pass_filenames: false + files: >- + (?x)^( + pyproject.toml| + uv.lock| + )$ + + - id: generate-conda-environment + name: Update conda environment file + entry: python ./utils/dependency_management.py generate-environment-yml + language: system + pass_filenames: false + files: >- + (?x)^( + pyproject.toml| + utils/dependency_management.py + )$ + + - id: validate-conda-environment + name: Validate environment.yml + entry: python ./utils/dependency_management.py validate-environment-yml + language: system + pass_filenames: false + files: >- + (?x)^( + pyproject.toml| + utils/dependency_management.py| + environment.yml| + )$ + + - id: verdi-autodocs + name: Automatically generating verdi docs + entry: python ./utils/validate_consistency.py verdi-autodocs + language: system + pass_filenames: false + files: >- + (?x)^( + src/aiida/cmdline/commands/.*| + src/aiida/cmdline/params/.*| + src/aiida/cmdline/params/types/.*| + utils/validate_consistency.py| + )$ diff --git a/.readthedocs.yml b/.readthedocs.yml index d2b58b1aac..6df080b365 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -8,26 +8,25 @@ formats: [] build: apt_packages: - - graphviz + - graphviz os: ubuntu-22.04 tools: - python: "3.10" - -# Need to install the package itself such that the entry points are installed and the API doc can build properly -python: - install: - - method: pip - path: . 
- extra_requirements: - - docs - - tests - - rest - - atomic_tools + python: '3.11' + jobs: + # Use uv to speed up the build + # https://docs.readthedocs.io/en/stable/build-customization.html#install-dependencies-with-uv + pre_create_environment: + - asdf plugin add uv + - asdf install uv 0.2.9 + - asdf global uv 0.2.9 + post_install: + - VIRTUAL_ENV=$READTHEDOCS_VIRTUALENV_PATH uv pip install .[docs,tests,rest,atomic_tools] --preview # Let the build fail if there are any warnings sphinx: - builder: html - fail_on_warning: true + builder: html + configuration: docs/source/conf.py + fail_on_warning: true search: ranking: diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d51740c1a..04327d7105 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,688 @@ # Changelog +## v2.6.3 - 2024-11-6 + +### Fixes +- CLI: Fix exception for `verdi plugin list` (#6560) [[c3b10b7]](https://github.com/aiidateam/aiida-core/commit/c3b10b759a9cd062800ef120591d5c7fd0ae4ee7) +- `DirectScheduler`: Ensure killing child processes (#6572) [[fddffca]](https://github.com/aiidateam/aiida-core/commit/fddffca67b4f7e3b76b19df7db8e1511c449d2d9) +- Engine: Fix state change broadcast before process node is updated (#6580) [[867353c]](https://github.com/aiidateam/aiida-core/commit/867353c415c61d94a2427d5225dd5224a1b95fb9) + +### Devops +- Docker: Replace sleep with `s6-notifyoncheck` (#6475) [[9579378b]](https://github.com/aiidateam/aiida-core/commit/9579378ba063237baa5b73380eb8e9f0a28529ee) +- Fix failed docker CI using more reasoning grep regex to parse python version (#6581) [[332a4a91]](https://github.com/aiidateam/aiida-core/commit/332a4a915771afedcb144463b012558e4669e529) +- DevOps: Fix json query in reading the docker names to filter out fields not starting with aiida (#6573) [[e1467edc]](https://github.com/aiidateam/aiida-core/commit/e1467edca902867e53605e0e60b67f8767bf8d3e) + + +## v2.6.2 - 2024-08-07 + +### Fixes +- `LocalTransport`: Fix typo for `ignore_nonexisting` in `put` (#6471) [[ecda558d0]](https://github.com/aiidateam/aiida-core/commit/ecda558d08c5608880308f69a21c05fe918be89f) +- CLI: `verdi computer test` report correct failed tests (#6536) [[9c3f2bb58]](https://github.com/aiidateam/aiida-core/commit/9c3f2bb589f1a6cc920ed2fbf0627924d8fce954) +- CLI: Fix `verdi storage migrate` for profile without broker (#6550) [[389fc487d]](https://github.com/aiidateam/aiida-core/commit/389fc487d092c2e34713d228e38c8164608bab2d) +- CLI: Fix bug `verdi presto` when tab-completing without config (#6535) [[efcf75e40]](https://github.com/aiidateam/aiida-core/commit/efcf75e405dcef8ca8c51b19d6262ca81f4413c3) +- Engine: Change signature of `set_process_state_change_timestamp` [[923fd9f6e]](https://github.com/aiidateam/aiida-core/commit/923fd9f6ec3cb39b8985ac149985046e720438f8) +- Engine: Ensure node is sealed when process excepts (#6549) [[e3ed9a2f3]](https://github.com/aiidateam/aiida-core/commit/e3ed9a2f3bf84e72cdc4c90e21494dc2b9c1cd56) +- Engine: Fix bug in upload calculation for `PortableCode` with SSH (#6519) [[740ae2040]](https://github.com/aiidateam/aiida-core/commit/740ae20408e4e3047c26c91221de5b96b8d7afbe) +- Engine: Ignore failing process state change for `core.sqlite_dos` [[fb4f9815f]](https://github.com/aiidateam/aiida-core/commit/fb4f9815fbd3bfe053cc0d1e3abb5b86bbc9dffd) + +### Dependencies +- Pin requirement to minor version `sphinx~=7.2.0` (#6527) [[25cb73188]](https://github.com/aiidateam/aiida-core/commit/25cb731880d45045773f7674fee9367d792aeda9) + +### Documentation +- Add `PluggableSchemaValidator` to nitpick 
exceptions (#6515) [[0ce3c0025]](https://github.com/aiidateam/aiida-core/commit/0ce3c0025384e95006d43e98da2a96b2cfadde70) +- Add `robots.txt` to only allow indexing of `latest` and `stable` (#6517) [[a492e3492]](https://github.com/aiidateam/aiida-core/commit/a492e349288ee90896193575edf319c2b7ea4c85) +- Add succint overview of limitations of no-services profile [[d7ca5657b]](https://github.com/aiidateam/aiida-core/commit/d7ca5657b0dec83cb8a440d0b8e11c9c84e38149) +- Correct signature of `get_daemon_client` example snippet (#6554) [[92d391658]](https://github.com/aiidateam/aiida-core/commit/92d391658e5e92113ad9a32d61444e23342ee0ae) +- Fix typo in pytest plugins codeblock (#6513) [[eb23688fe]](https://github.com/aiidateam/aiida-core/commit/eb23688febacb5b828d0f45ebdeedee65d0c00fe) +- Move limitations warning to top of quick install [[493e529a7]](https://github.com/aiidateam/aiida-core/commit/493e529a7aec052f81d1bbfed33e5b44db9434e2) +- Remove note in installation guide regarding Python requirement [[9aa2044e4]](https://github.com/aiidateam/aiida-core/commit/9aa2044e4ce665d64bb3a1c26f166512287aa28b) +- Update `redirects.txt` for installation pages (#6509) [[508a9fb2a]](https://github.com/aiidateam/aiida-core/commit/508a9fb2a7f6eb6f31b40b3f28d13c96f7875503) + +### Devops +- Fix pymatgen import causing mypy to fail (#6540) [[813374fe1]](https://github.com/aiidateam/aiida-core/commit/813374fe1ee003c8c05f9aa191147ec928dfa918) + + +## v2.6.1 - 2024-07-01 + +### Fixes +- Fixtures: Make `pgtest` truly an optional dependency [[9fe8fd2e0]](https://github.com/aiidateam/aiida-core/commit/9fe8fd2e0b88e746ee2156eccb71b7adbab6b2c5) + + +## v2.6.0 - 2024-07-01 + +This minor release comes with a number of features that are focused on user friendliness and ease-of-use of the CLI and the API. +The caching mechanism has received a number of improvements guaranteeing even greater savings of computational time. +For existing calculations to be valid cache sources in the new version, their hash has to be regenerated (see [Improvements and changes to caching](#improvements-and-changes-to-caching) for details). + +- [Making RabbitMQ optional](#making-rabbitmq-optional) +- [Simplifying profile setup](#simplifying-profile-setup) +- [Improved test fixtures without services](#improved-test-fixtures-without-services) +- [Improvements and changes to caching](#improvements-and-changes-to-caching) +- [Programmatic syntax for query builder filters and projections](#programmatic-syntax-for-query-builder-filters-and-projections) +- [Automated profile storage backups](#automated-profile-storage-backups) +- [Full list of changes](#full-list-of-changes) + - [Features](#features) + - [Performance](#performance) + - [Changes](#changes) + - [Fixes](#fixes) + - [Deprecations](#deprecations) + - [Dependencies](#dependencies) + - [Refactoring](#refactoring) + - [Documentation](#documentation) + - [Devops](#devops) + + +### Making RabbitMQ optional + +The RabbitMQ message broker service is now optional for running AiiDA. +The requirement was added in AiiDA v1.0 when the engine was completely overhauled. +Although it significantly improved the scaling and responsiveness, it also made it more difficult to start using AiiDA. +As of v2.6, profiles can now be configured without RabbitMQ, at the cost that the daemon can not be used and all processes have to be run locally. 
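+
+To illustrate what such a broker-less profile still supports, here is a minimal sketch that runs a simple `calcfunction` locally in the current interpreter (the function and profile are illustrative; it assumes a profile without a broker has already been configured):
+```python
+from aiida import load_profile
+from aiida.engine import calcfunction
+from aiida.orm import Int
+
+load_profile()  # assumes the default profile was configured without RabbitMQ
+
+@calcfunction
+def add(x, y):
+    # Runs immediately in this interpreter; no daemon or broker is involved.
+    return x + y
+
+result = add(Int(1), Int(2))
+assert result == 3
+```
+Running processes through the daemon, on the other hand, remains unavailable on such a profile, as noted above.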
+ +### Simplifying profile setup + +With the removal of RabbitMQ as a hard requirement, combined with storage plugins that replace PostgreSQL with the serverless SQLite that were introduced in v2.5, it is now possible to set up a profile that requires no services. +A new command is introduced, `verdi presto`, that automatically creates a profile with sensible defaults. +This now in principle makes it possible to run just the following two commands on any operating system: +``` +pip install aiida-core +verdi presto +``` +and get a working AiiDA installation that is ready to go. +As a bonus, it also configures the localhost as a `Computer`. +See the [documentation for more details](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/installation/guide_quick.html). + +### Improved test fixtures without services + +Up till now, running tests would always require a fully functional profile, which meant that PostgreSQL and RabbitMQ had to be available. +As described in the section above, it is now possible to set up a profile without these services. +This new feature is leveraged to provide a set of `pytest` fixtures that create a test profile that can be used essentially on any system that just has AiiDA installed. +To start writing tests, simply create a `conftest.py` and import the fixtures with: +```python +pytest_plugins = 'aiida.tools.pytest_fixtures' +``` +The new fixtures include the `aiida_profile` fixture, which is session-scoped and automatically loaded. +The fixture creates a temporary test profile at the start of the test session and automatically deletes it when the session ends. +For more information and an overview of all available fixtures, please refer to [the documentation on `pytest` fixtures](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/topics/plugins.html#plugin-test-fixtures). + +### Improvements and changes to caching + +A number of fixes and changes to the caching mechanism were introduced (see the [changes](#changes) subsection of the [full list of changes](#full-list-of-changes) for a more detailed overview). +For existing calculations to be valid cache sources in the new version, their hash has to be regenerated by running `verdi node rehash`. +Note that this can take a while for large databases. + +Since its introduction, the cache would essentially be reset each time AiiDA or any of the plugin packages was updated, since the versions of these packages were included in the calculation of the node hashes. +This was originally done out of precaution, to err on the safe side and limit the possibility of false positives in cache hits. +However, this strategy has turned out to be unnecessarily cautious and severely limited the effectiveness of caching. + +The package version information is no longer included in the hash and therefore no longer impacts the caching. +This change does, however, make false positives possible if the implementation of a `CalcJob` or `Parser` plugin changes significantly. +Therefore, a mechanism is introduced to give these plugins control to effectively reset the cache of existing nodes. +Please refer to the [documentation on controlling caching](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/topics/provenance/caching.html#calculation-jobs-and-parsers) for more details.
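+
+As a minimal sketch of what this mechanism looks like from the plugin side, assuming the new `CACHE_VERSION` class attribute listed under [changes](#changes) below is the intended hook (the plugin class and version number here are purely illustrative):
+```python
+from aiida.engine import CalcJob
+
+class SomePluginCalculation(CalcJob):
+    """Hypothetical calculation plugin."""
+
+    # Bumping this integer changes the hash computed for new nodes, so nodes
+    # created with the previous implementation stop acting as cache sources.
+    CACHE_VERSION = 2
+```
+See the documentation linked above for the exact semantics.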
+ +### Programmatic syntax for query builder filters and projections + +In the `QueryBuilder`, fields to filter on or project always had to be provided with strings: +```python +QueryBuilder().append(Node, filters={'label': 'some-label'}, project=['id', 'ctime']) +``` +and it is not always trivial to know what fields exist that _can_ be filtered on and can be projected. +In addition, there was a discrepancy for some fields, most notably the `pk` property, which had to be converted to `id` in the query builder syntax. + +These limitations have been solved as each class in AiiDA's ORM now defines the `fields` property, which allows to discover these fields programmatically. +The example above would convert to: +```python +QueryBuilder().append(Node, filters={Node.fields.label: 'some-label'}, project=[Node.fields.pk, Node.fields.ctime]) +``` +The `fields` property provides tab-completion allowing easy discovery of available fields for an ORM class in IDEs and interactive shells. +The fields also allow to express logical conditions programmatically and more. +For more details, please refer to the [documentation on programmatic field syntax](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/howto/query.html#programmatic-syntax-for-filters). + +Data plugins can also define custom fields, adding on top of the fields inherited from their base class(es). +The [documentation on data plugin fields](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/topics/data_types.html#fields) provides more information, but the API is currently in beta and guaranteed to be changed in an upcoming version. +It is therefore recommended for plugin developers to hold off making use of this new API. + +### Automated profile storage backups + +A generic mechanism has been implemented to allow easily backing up the data of a profile. +The command `verdi storage backup` automatically maintains a directory structure of previous backups allowing efficient incremental backups. +Note that the exact details of the backup mechanism is dependent on the storage plugin that is used by the profile and not all storage plugins necessarily implement it. +For now the storage plugins `core.psql_dos`, and `core.sqlite_dos` implement the functionality. +For more information, please refer [to the documentation](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/howto/installation.html#backing-up-your-installation). +Please refer to [this section of the documentation](https://aiida.readthedocs.io/projects/aiida-core/en/v2.6.0/howto/installation.html#restoring-data-from-a-backup) for instructions to restore from a backup. 
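+
+By way of example, and assuming the command takes the backup destination as an argument as described in the linked documentation (the path below is purely illustrative), repeated invocations against the same destination let the storage plugin back up incrementally:
+```shell
+verdi storage backup /mnt/backups/my-profile
+# running the same command again later reuses unchanged data from the previous backup
+verdi storage backup /mnt/backups/my-profile
+```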
+ +### Full list of changes + +#### Features +- `CalcJob`: Allow to define order of copying of input files [[6898ff4d8]](https://github.com/aiidateam/aiida-core/commit/6898ff4d8c263cf08707c61411a005f6a7f731dd) +- `SqliteDosStorage`: Implement the backup functionality [[18e447c77]](https://github.com/aiidateam/aiida-core/commit/18e447c77f48a18f361e458186cd87b2355aea75) +- `SshTransport`: Return `FileNotFoundError` if destination parent does not exist [[d86bb38bf]](https://github.com/aiidateam/aiida-core/commit/d86bb38bf9a0ced8029f8a4b895e1a6be1ccb339) +- Add improved more configurable versions of `pytest` fixtures [[e3a60460e]](https://github.com/aiidateam/aiida-core/commit/e3a60460ef1208a5c46ecd6af35d891a88ee784e) +- Add the `orm.Entity.fields` interface for `QueryBuilder` [[4b9abe2bd]](https://github.com/aiidateam/aiida-core/commit/4b9abe2bd0bb82449547a3377c2b6dbc7c174123) +- CLI: `verdi computer test` make unexpected output check optional [[589a3b2c0]](https://github.com/aiidateam/aiida-core/commit/589a3b2c03d44cebd26e88243ca34fcdb0e23ff4) +- CLI: `verdi node graph generate` root nodes as arguments [[06f8f4cfb]](https://github.com/aiidateam/aiida-core/commit/06f8f4cfb0731ff699d5c01ad85418b6db0f6778) +- CLI: Add `--most-recent-node` option to `verdi process watch` [[72692fa5c]](https://github.com/aiidateam/aiida-core/commit/72692fa5cb667e2a7462770af18b7cedeaf8b3f0) +- CLI: Add `--sort/--no-sort` to `verdi code export` [[80c606890]](https://github.com/aiidateam/aiida-core/commit/80c60689063f1517c3de91d86eef80f7852667e3) +- CLI: Add `verdi process dump` and the `ProcessDumper` [[6291accf0]](https://github.com/aiidateam/aiida-core/commit/6291accf0538eafe7426e89bc4c1e9eb90ce0385) +- CLI: Add RabbitMQ options to `verdi profile setup` [[f553f805e]](https://github.com/aiidateam/aiida-core/commit/f553f805e86d766da6208eb1682f7cf12c7907ac) +- CLI: Add the `-M/--most-recent-node` option [[5aae874aa]](https://github.com/aiidateam/aiida-core/commit/5aae874aaa44459ce8cf3ddd3bf1a82d8a2e8d37) +- CLI: Add the `verdi computer export` command [[9e3ebf6ea]](https://github.com/aiidateam/aiida-core/commit/9e3ebf6ea55d883c7857a1dbafe398b9579cca03) +- CLI: Add the `verdi node list` command [[cf091e80f]](https://github.com/aiidateam/aiida-core/commit/cf091e80ff2b6aa03f41b56ba1976abb97298972) +- CLI: Add the `verdi presto` command [[6b6e1520f]](https://github.com/aiidateam/aiida-core/commit/6b6e1520f2d3807e366dd672e7917f381ea7b524) +- CLI: Add the `verdi profile configure-rabbitmq` command [[202a3ece9]](https://github.com/aiidateam/aiida-core/commit/202a3ece9705289a1f12c85e64cf90307ca85c39) +- CLI: Allow `verdi computer delete` to delete associated nodes [[348777571]](https://github.com/aiidateam/aiida-core/commit/3487775711e7412fb2cb82600fb266316d6ce12a) +- CLI: Allow multiple root nodes in `verdi node graph generate` [[f16c432af]](https://github.com/aiidateam/aiida-core/commit/f16c432af107b1f9c01a12e03cbd0a9ecc2744ad) +- Engine: Allow `CalcJob` monitors to return outputs [[b7e59a0db]](https://github.com/aiidateam/aiida-core/commit/b7e59a0dbc0dd629be5c8178e98c70e7a2c116e9) +- Make `postgres_cluster` and `config_psql_dos` fixtures configurable [[35d7ca63b]](https://github.com/aiidateam/aiida-core/commit/35d7ca63b44f051a26d3f96d84e043919eb3f101) +- Process: Add the `metadata.disable_cache` input [[4626b11f8]](https://github.com/aiidateam/aiida-core/commit/4626b11f85cd0d95a17d8f5766a90b88ddddd689) +- Storage: Add backup mechanism to the interface 
[[bf79f23ee]](https://github.com/aiidateam/aiida-core/commit/bf79f23eef66d362a34aac170577ba8f5c2088ba) +- Transports: fix overwrite behaviour for `puttree`/`gettree` [[a55451703]](https://github.com/aiidateam/aiida-core/commit/a55451703aa8f8d330b25bc5da95d41caf0db9ac) + +#### Performance +- CLI: Speed up tab-completion by lazily importing `Config` [[9524cda0b]](https://github.com/aiidateam/aiida-core/commit/9524cda0b8c742fb5bf740d7b0035e326eace28f) +- Improve import time of `aiida.orm` and `aiida.storage` [[fb9b6cc3b]](https://github.com/aiidateam/aiida-core/commit/fb9b6cc3b3df244549fdd78576c34f6d9dfd4568) +- ORM: Cache the logger adapter for `ProcessNode` [[1d104d06b]](https://github.com/aiidateam/aiida-core/commit/1d104d06b95da36c71cab132c7b6fec52a005e18) + +#### Changes +- Caching: `NodeCaching._get_objects_to_hash` return type to `dict` [[c9c7c4bd8]](https://github.com/aiidateam/aiida-core/commit/c9c7c4bd8e1cd306271b5cf267095d3cbd8aafe2) +- Caching: Add `CACHE_VERSION` attribute to `CalcJob` and `Parser` [[39d0f312d]](https://github.com/aiidateam/aiida-core/commit/39d0f312d212a642d1537ca89e7622e48a23e701) +- Caching: Include the node's class in objects to hash [[68ce11161]](https://github.com/aiidateam/aiida-core/commit/68ce111610c40e3d9146e128c0a698fc60b6e5e5) +- Caching: Make `NodeCaching._get_object_to_hash` public [[e33000402]](https://github.com/aiidateam/aiida-core/commit/e330004024ad5121f9bc82cbe972cd283f25fec8) +- Caching: Remove core and plugin information from hash calculation [[4c60bbef8]](https://github.com/aiidateam/aiida-core/commit/4c60bbef852eef55a06b48b813d3fbcc8fb5a43f) +- Caching: Rename `get_hash` to `compute_hash` [[b544f7cf9]](https://github.com/aiidateam/aiida-core/commit/b544f7cf95a0e6e698224f36c1bea57d1cd99e7d) +- CLI: Always do hard reset in `verdi daemon restart` [[8ac642410]](https://github.com/aiidateam/aiida-core/commit/8ac6424108d1528bd3279c81da62dd44855b6ebc) +- CLI: Change `--profile` to `-p/--profile-name` for `verdi profile setup` [[8ea203cd9]](https://github.com/aiidateam/aiida-core/commit/8ea203cd9b1d2fbb4a3b38ba67beec97bb8c7145) +- CLI: Let `-v/--verbosity` only affect `aiida` and `verdi` loggers [[487c6bf04]](https://github.com/aiidateam/aiida-core/commit/487c6bf047030ee19deed49d5fbf9a093253538e) +- Engine: Set the `to_aiida_type` as default inport port serializer [[2fa7a5305]](https://github.com/aiidateam/aiida-core/commit/2fa7a530511a94ead83d79669efed71706a0a472) +- `QueryBuilder`: Remove implementation for `has_key` in SQLite storage [[24cfbe27e]](https://github.com/aiidateam/aiida-core/commit/24cfbe27e7408b78fca8e6f69799ebad3659400b) + +#### Fixes +- `BandsData`: Use f-strings in `_prepare_gnuplot` [[dba117437]](https://github.com/aiidateam/aiida-core/commit/dba117437782abc6d11f9ef208923f7e70f79ed2) +- `BaseRestartWorkChain`: Fix handler overrides used only first iteration [[65786a6bd]](https://github.com/aiidateam/aiida-core/commit/65786a6bda1c74dfb4aea90becd0664de6b1abde) +- `SlurmScheduler`: Make detailed job info fields dynamic [[4f9774a68]](https://github.com/aiidateam/aiida-core/commit/4f9774a689b81a446fac37ad8281b2d854eefa7a) +- `SqliteDosStorage`: Fix exception when importing archive [[af0c260bb]](https://github.com/aiidateam/aiida-core/commit/af0c260bb1c32c3b33c50175d790907774561b3e) +- `StructureData`: Fix the pbc constraints of `get_pymatgen_structure` [[adcce4bcd]](https://github.com/aiidateam/aiida-core/commit/adcce4bcd0b59c8371be73058a060bedcaba40f6) +- Archive: Automatically create nested output directories 
[[212f6163b]](https://github.com/aiidateam/aiida-core/commit/212f6163b03b8762509ae2230c30172af8c02fed) +- Archive: Respect `filter_size` in query for existing nodes [[ef60b66aa]](https://github.com/aiidateam/aiida-core/commit/ef60b66aa3ce76d654abe5e7caafef3f221defd0) +- CLI: Ensure deprecation warnings are printed before any prompts [[deb293d0e]](https://github.com/aiidateam/aiida-core/commit/deb293d0e6a566256fac5069881de4846d77f6d1) +- CLI: Fix `verdi archive create --dry-run` for empty file repository [[cc96c9d04]](https://github.com/aiidateam/aiida-core/commit/cc96c9d043c6616a068a5498f557fa21a728eb96) +- CLI: Fix `verdi plugin list` incorrectly not displaying description [[e952d7717]](https://github.com/aiidateam/aiida-core/commit/e952d7717c1d8001555e8d19f54f4fa349da6c6e) +- CLI: Fix `verdi process [show|report|status|watch|call-root]` no output [[a56a1389d]](https://github.com/aiidateam/aiida-core/commit/a56a1389dee5cb9ae70a5511d77aad248ea21731) +- CLI: Fix `verdi process list` if no available workers [[b44afcb3c]](https://github.com/aiidateam/aiida-core/commit/b44afcb3c1a7efa452d4e72aa6f8a615f652aaa4) +- CLI: Fix `verdi quicksetup` when profiles exist where storage is not `core.psql_dos` [[6cb91c181]](https://github.com/aiidateam/aiida-core/commit/6cb91c18163ac6228ed4a64c1c467dfd0398a624) +- CLI: Fix dry-run resulting in critical error in `verdi archive import` [[36991c6c8]](https://github.com/aiidateam/aiida-core/commit/36991c6c84f4ba0b4553e8cd6689bbc1815dbd35) +- CLI: Fix logging not showing in `verdi daemon worker` [[9bd8585bd]](https://github.com/aiidateam/aiida-core/commit/9bd8585bd5e7989e24646a0018710e86836e5a9f) +- CLI: Fix the `ctx.obj.profile` attribute not being initialized [[8a286f26e]](https://github.com/aiidateam/aiida-core/commit/8a286f26e8d303c498ac2eabd49be5f1f4ced9ef) +- CLI: Hide misleading message for `verdi archive create --test-run` [[7e42d7aa7]](https://github.com/aiidateam/aiida-core/commit/7e42d7aa7d16fa9e81cbd300ada14e4dea2426ce) +- CLI: Improve error message of `PathOrUrl` and `FileOrUrl` [[ffc6e4f70]](https://github.com/aiidateam/aiida-core/commit/ffc6e4f706277854dbd454d6f3164cec31e7819a) +- CLI: Only configure logging in `set_log_level` callback once [[66a2dcedd]](https://github.com/aiidateam/aiida-core/commit/66a2dcedd0a9428b5b2218b8c82bad9c9aff4956) +- CLI: Unify help of `verdi process` commands [[d91e0a58d]](https://github.com/aiidateam/aiida-core/commit/d91e0a58dabfd242b5f886d692c8761499a6719c) +- Config: Set existing user as default for read-only storages [[e66592509]](https://github.com/aiidateam/aiida-core/commit/e665925097bb3344fde4bcc66ee185a2d9207ac3) +- Config: Use UUID in `Manager.load_profile` to identify profile [[b01038bf1]](https://github.com/aiidateam/aiida-core/commit/b01038bf1fca7d33c4915aee904acea89a847614) +- Daemon: Log the worker's path and Python interpreter [[ae2094169]](https://github.com/aiidateam/aiida-core/commit/ae209416996ec361c474aeaf0fa06f49dd59f296) +- Docker: Start and stop daemon only when a profile exists [[0a5b20023]](https://github.com/aiidateam/aiida-core/commit/0a5b200236419d8caf8e05bb04ba80d03a438e03) +- Engine: Add positional inputs for `Process.submit` [[d1131fe94]](https://github.com/aiidateam/aiida-core/commit/d1131fe9450972080207db6e9615784490b3252b) +- Engine: Catch `NotImplementedError`in `get_process_state_change_timestamp` [[04926fe20]](https://github.com/aiidateam/aiida-core/commit/04926fe20da15065f8f086f1ff3cb14cc163aa08) +- Engine: Fix paused work chains not showing it in process status 
[[40b22d593]](https://github.com/aiidateam/aiida-core/commit/40b22d593875b97355996bbfc15e2850ad1f0495) +- Fix passwords containing `@` not being accepted for Postgres databases [[d14c14db2]](https://github.com/aiidateam/aiida-core/commit/d14c14db2f82d3a678e9747bd463ec1a61642120) +- ORM: Correct field type of `InstalledCode` and `PortableCode` models [[0079cc1e4]](https://github.com/aiidateam/aiida-core/commit/0079cc1e4b46c61edf2323b2d42af46367fe04b6) +- ORM: Fix `ProcessNode.get_builder_restart` [[0dee9d8ef]](https://github.com/aiidateam/aiida-core/commit/0dee9d8efba5c48615e8510f5ada706724b4a2e8) +- ORM: Fix deprecation warning being shown for new code types [[a9155713b]](https://github.com/aiidateam/aiida-core/commit/a9155713bbb10e57fe91cd320e2a12391d098a46) +- Runner: Close event loop in `Runner.close()` [[53cc45837]](https://github.com/aiidateam/aiida-core/commit/53cc458377685e54179eb1e1b73bb0383c8dae13) + +#### Deprecations +- CLI: Deprecate `verdi profile setdefault` and rename to `verdi profile set-default` [[ab48a4f62]](https://github.com/aiidateam/aiida-core/commit/ab48a4f627b4c9eec9133b5efa9fb888ce2c4914) +- CLI: Deprecate accepting `0` for `default_mpiprocs_per_machine` [[acec0c190]](https://github.com/aiidateam/aiida-core/commit/acec0c190cbb45ba267c6eb8ee7ceba18cf3302b) +- CLI: Deprecate the `deprecated_command` decorator [[4c11c0616]](https://github.com/aiidateam/aiida-core/commit/4c11c0616c583236119f838a1780a606c58b4ee2) +- CLI: Remove the deprecated `verdi database` command [[3dbde9e31]](https://github.com/aiidateam/aiida-core/commit/3dbde9e311781509b738202ad6f1de3bbd4b7a82) +- ORM: Undo deprecation of `Code.get_description` [[1b13014b1]](https://github.com/aiidateam/aiida-core/commit/1b13014b14274024dcb6bb0a721eb62665567987) + +### Dependencies +- Update `tabulate>=0.8.0,<0.10.0` [[6db2f4060]](https://github.com/aiidateam/aiida-core/commit/6db2f4060d4ece4552f5fe757c0f7d938810f4d1) + +#### Refactoring +- Abstract message broker functionality [[69389e038]](https://github.com/aiidateam/aiida-core/commit/69389e0387369d8437e1219487b88430b7b2e679) +- Config: Refactor `get_configuration_directory_from_envvar` [[65739f524]](https://github.com/aiidateam/aiida-core/commit/65739f52446087439ba93158eb948b58ed081ce5) +- Config: Refactor the `create_profile` function and method [[905e93444]](https://github.com/aiidateam/aiida-core/commit/905e93444cf996461e679cd458511d1c471a7e02) +- Engine: Refactor handling of `remote_folder` and `retrieved` outputs [[28adacaf8]](https://github.com/aiidateam/aiida-core/commit/28adacaf8ae21357bf6e5a2a48c43ed56d3bd78b) +- ORM: Switch to `pydantic` for code schema definition [[06189d528]](https://github.com/aiidateam/aiida-core/commit/06189d528c2362516f42e0d48840882812b97fe4) +- Replace deprecated `IOError` with `OSError` [[7f9129fd1]](https://github.com/aiidateam/aiida-core/commit/7f9129fd193374bdbeaa7ba4dd8c3cdf706db97d) +- Storage: Move profile locking to the abstract base class [[ea5f51bcb]](https://github.com/aiidateam/aiida-core/commit/ea5f51bcb6af172eb1a754df3981003bf7bad959) + +#### Documentation +- Add more instructions on how to use docker image [[aaf44afcc]](https://github.com/aiidateam/aiida-core/commit/aaf44afcce0f90fff2eb38bc47d28b4adf87db24) +- Add the updated cheat sheet [[09f9058a7]](https://github.com/aiidateam/aiida-core/commit/09f9058a7444f3ac1d3f243b608fa3f24f771f27) +- Add tips for common problems with conda PostgreSQL setup [[cd5313825]](https://github.com/aiidateam/aiida-core/commit/cd5313825afdb1771ca19d899567e4ed4774a2bc) +- 
Customize the color scheme through custom style sheet [[a6cf7fc7e]](https://github.com/aiidateam/aiida-core/commit/a6cf7fc7e02a48a7e3b9c4ba6ce5e2cd413e6b23) +- Docs: Clarify `Transport.copy` requires `recursive=True` if source is a directory [[310ff1db7]](https://github.com/aiidateam/aiida-core/commit/310ff1db77bc75b6cadedf77394b96af05456f43) +- Fix example of the `entry_points` fixture [[081fc5547]](https://github.com/aiidateam/aiida-core/commit/081fc5547370e1b5a19b1fb507681091c632bb7a) +- Fixing several small issues [[6a3a59b29]](https://github.com/aiidateam/aiida-core/commit/6a3a59b29ba64401828d9ab51dc123060868278b) +- Minor cheatsheet update for v2.6 release [[c3cc169c4]](https://github.com/aiidateam/aiida-core/commit/c3cc169c487a88e2357b7377e897f0521c23f05a) +- Reorganize the tutorial content [[5bd960efa]](https://github.com/aiidateam/aiida-core/commit/5bd960efae5a7f916b978a420f5f43501c9bc529) +- Rework the installation section [[0ee0a0c6a]](https://github.com/aiidateam/aiida-core/commit/0ee0a0c6ae13588e82edf1cf9e8cb9857c94c31b) +- Standardize usage of `versionadded` directive [[bf5dac848]](https://github.com/aiidateam/aiida-core/commit/bf5dac8484638d7ba5c492e91975b5fcc0cc9770) +- Update twitter logo [[5e4f60d83]](https://github.com/aiidateam/aiida-core/commit/5e4f60d83160774ca83defe4bf1f6c6381aaa1a0) +- Use uv installer in readthedocs build [[be0db3cc4]](https://github.com/aiidateam/aiida-core/commit/be0db3cc49506294ae1845b6e746e40cd76f39a9) + +#### Devops +- Add `check-jsonschema` pre-commit hook for GHA workflows [[14c5bb0f7]](https://github.com/aiidateam/aiida-core/commit/14c5bb0f764f0fd7933df205aa22d61c85ec0cf2) +- Add Dependabot config for maintaining GH actions [[0812f4b9e]](https://github.com/aiidateam/aiida-core/commit/0812f4b9eeffdff5a8c3d0802aea94c8919d9922) +- Add docker image `aiida-core-dev` for development [[6d0984109]](https://github.com/aiidateam/aiida-core/commit/6d0984109478ec1c0fd96dfd1d3f2b54e0b75dd2) +- Add Python 3.12 tox environment [[6b0d43960]](https://github.com/aiidateam/aiida-core/commit/6b0d4396068a43b6823eca8c78b9048044b0b4b8) +- Add the `slurm` service to nightly workflow [[5460a0414]](https://github.com/aiidateam/aiida-core/commit/5460a0414d55e3531eb86e6906ee963a6b712aae) +- Add typing to `aiida.common.hashing` [[ba21ba1d4]](https://github.com/aiidateam/aiida-core/commit/ba21ba1d40a76df73a2e27ce6f1a4f68aba7fb9a) +- Add workflow to build Docker images on PRs from forks [[23d2aa5ee]](https://github.com/aiidateam/aiida-core/commit/23d2aa5ee3c08438cfc4b4734e9670e19c090150) +- Address internal deprecation warnings [[ceed7d55d]](https://github.com/aiidateam/aiida-core/commit/ceed7d55dfb7df8dbe52c4557d145593d83f788a) +- Allow unit test suite to be ran against SQLite [[0dc8bbcb2]](https://github.com/aiidateam/aiida-core/commit/0dc8bbcb261b745683bc542c1aced2412ebd66a0) +- Bump the gha-dependencies group with 4 updates [[ccb56286c]](https://github.com/aiidateam/aiida-core/commit/ccb56286c40f6be0d61a0c62442993e43faf1ba6) +- Dependencies: Update the requirements files [[61ae1a55b]](https://github.com/aiidateam/aiida-core/commit/61ae1a55b94c50979b4e47bb7572e1d4c9b2391f) +- Disable code coverage in `test-install.yml` [[4cecda517]](https://github.com/aiidateam/aiida-core/commit/4cecda5177c456cee252c16295416c3842bb5d2d) +- Do not pin the mamba version [[82bba1307]](https://github.com/aiidateam/aiida-core/commit/82bba130792f6c965f0ede8b221eee70fd01d9f1) +- Fix Docker build not defining `REGISTRY` 
[[e7953fd4d]](https://github.com/aiidateam/aiida-core/commit/e7953fd4dd14875e380b125b99f86c12ce15359b) +- Fix publishing to DockerHub using incorrect secret name [[9c9ff7986]](https://github.com/aiidateam/aiida-core/commit/9c9ff79865225b125ba5f9fe23969d4c2c8fb9b2) +- Fix Slack notification for nightly tests [[082589f45]](https://github.com/aiidateam/aiida-core/commit/082589f456201fbd79d3df809e2cfc5fb5f27922) +- Fix the `test-install.yml` workflow [[22ea06362]](https://github.com/aiidateam/aiida-core/commit/22ea06362e9de5d314f103332da2e25ae6080f61) +- Fix the Docker builds [[3404c0192]](https://github.com/aiidateam/aiida-core/commit/3404c01925da941c08f246e231b6587f53ce445b) +- Increase timeout for the `test-install` workflow [[e36a3f11f]](https://github.com/aiidateam/aiida-core/commit/e36a3f11fdd165eea3af9f3337382e1bbd181390) +- Move RabbitMQ CI to nightly and update versions [[b47a56698]](https://github.com/aiidateam/aiida-core/commit/b47a56698e8fdf350a10c7abfd8ba00443fabd8d) +- Refactor the GHA Docker build [[e47932ee9]](https://github.com/aiidateam/aiida-core/commit/e47932ee9e0833dca546c7c7b5b584f2687d9073) +- Remove `verdi tui` from CLI reference documentation [[1b4a19a44]](https://github.com/aiidateam/aiida-core/commit/1b4a19a44461271aea58e54acd93e896220b413d) +- Run Docker workflow only for pushes to origin [[b1a714155]](https://github.com/aiidateam/aiida-core/commit/b1a714155ec9e51263e453a4354934fa91d04f33) +- Tests: Convert hierarchy functions into fixtures [[a02abc470]](https://github.com/aiidateam/aiida-core/commit/a02abc4701e81b75284164510be966ff0fd04dab) +- Tests: extend `node_and_calc_info` fixture to `core.ssh` [[9cf28f208]](https://github.com/aiidateam/aiida-core/commit/9cf28f20875fe6c3b0f2844bff16415b1dfc7b6f) +- Tests: Remove test classes for transport plugins [[b77e51f8c]](https://github.com/aiidateam/aiida-core/commit/b77e51f8c15c7ddea80d7d6328737abb705c6ce8) +- Tests: Unskip test in `tests/cmdline/commands/test_archive_import.py` [[7b7958c7a]](https://github.com/aiidateam/aiida-core/commit/7b7958c7aee150162cb4db0562a7352764a94c04) +- Update codecov action [[fc2a84d9b]](https://github.com/aiidateam/aiida-core/commit/fc2a84d9bd045d48d46511153dfde389070bf552) +- Update deprecated `whitelist_externals` option in tox config [[8feef5189]](https://github.com/aiidateam/aiida-core/commit/8feef5189ab9a2ba4b358bb6937d5d7c3f555ad8) +- Update pre-commit hooks [[3dda84ff3]](https://github.com/aiidateam/aiida-core/commit/3dda84ff3057a97e422b809a6adc778cbf60c125) +- Update pre-commit requirement `ruff==0.3.5` [[acd54543d]](https://github.com/aiidateam/aiida-core/commit/acd54543dffca05df7189f36c71afd2bb0065f34) +- Update requirements `mypy` and `pre-commit`[[04b3260a0]](https://github.com/aiidateam/aiida-core/commit/04b3260a098f061301edb0f56f1675fe9283b41b) +- Update requirements to address deprecation warnings [[566f681f7]](https://github.com/aiidateam/aiida-core/commit/566f681f72426a9a08200ff1d86c604f4c37bbcf) +- Use `uv` to install package in CI and CD [[73a734ae3]](https://github.com/aiidateam/aiida-core/commit/73a734ae3cd0977a97c631f97ddb781fa293864a) +- Use recursive dependencies for `pre-commit` extra [[6564e78dd]](https://github.com/aiidateam/aiida-core/commit/6564e78ddb349b89f6a3e9bfa81ce357ce865961) + + +## v2.5.1 - 2024-01-31 + +This is a patch release with a few bug fixes, but mostly devops changes related to the package structure. 
+ +### Fixes +- CLI: Fix `verdi process repair` not actually repairing [[784ad6488]](https://github.com/aiidateam/aiida-core/commit/784ad64885e23c8c93fb21554dc3c7d1f6bdde0f) +- Docker: Allow default profile parameters to be configured through env variables [[06ea130df]](https://github.com/aiidateam/aiida-core/commit/06ea130df8854f621e25853af6ac723c37397ed0) + +### Dependencies +- Dependencies: Fix incompatibility with `spglib>=2.3` [[fa8b9275e]](https://github.com/aiidateam/aiida-core/commit/fa8b9275e74d16df7df4884b7c2eff4ad0cca1ce) + +### Devops +- Devops: Move the source directory into `src/` [[53748d4de]](https://github.com/aiidateam/aiida-core/commit/53748d4de609c79b37cf9c7e0170c913e8d6dd0d) +- Devops: Remove post release action for uploading pot to transifex [[9feda35eb]](https://github.com/aiidateam/aiida-core/commit/9feda35ebac0101c7fa16629cefcd411ed994425) +- Pre-commit: Add `ruff` as the new linter and formatter [[64c5e6a82]](https://github.com/aiidateam/aiida-core/commit/64c5e6a82d8bb515d07fea84b611dc35cce1263b) +- Pre-commit: Update a number of pre-commit hooks [[a4ced7a67]](https://github.com/aiidateam/aiida-core/commit/a4ced7a67e10e0e88a2fe09a4f5c5c597789d43a) +- Pre-commit: Add YAML and TOML formatters [[c27aa33f3]](https://github.com/aiidateam/aiida-core/commit/c27aa33f33a7417da5d0b571b1927668f6505707) +- Update pre-commit CI configuration [[cb95f0c4c]](https://github.com/aiidateam/aiida-core/commit/cb95f0c4cb5ac0f56b0a3ec6654409cb6f22b5ba) +- Update pre-commit dependencies [[8dfab0e09]](https://github.com/aiidateam/aiida-core/commit/8dfab0e0928da5b8bbe5182e97825a701ca0130b) +- Dependencies: Pin `mypy` to minor version `mypy~=1.7.1` [[d65fa3d2d]](https://github.com/aiidateam/aiida-core/commit/d65fa3d2d724b126c26631771fa7840a2583d1a4) + +### Documentation +- Streamline and fix typos in `docs/topics/processes/usage.rst` [[45ba27732]](https://github.com/aiidateam/aiida-core/commit/45ba27732bb8ff8d6714c6a6114bc2c00d14c18c) +- Update process function section on file deduplication [[f35d7ae98]](https://github.com/aiidateam/aiida-core/commit/f35d7ae9801423c55e04fca22500ca23bca90739) +- Correct a typo in `docs/source/topics/data_types.rst` [[6ee278ceb]](https://github.com/aiidateam/aiida-core/commit/6ee278cebe8fb58cd6e69517d678b97570d0d661) +- Fix the ADES paper citation [[80117f8f7]](https://github.com/aiidateam/aiida-core/commit/80117f8f7b36a0932bb8a2ec843d37f28bd41f87) + + +## v2.5.0 - 2023-12-20 + +This minor release comes with a number of features that are focused on user friendliness of the CLI and the API. +It also reduces the import time of modules, which makes the CLI faster to load and so tab-completion should be snappier. +The release adds support for Python 3.12 and a great number of bugs are fixed. + +- [Create profiles without a database server](#create-profiles-without-a-database-server) +- [Changes in process launch functions](#changes-in-process-launch-functions) +- [Improvements for built-in data types](#improvements-for-built-in-data-types) +- [Repository interface improvements](#repository-interface-improvements) +- [Full list of changes](#full-list-of-changes) + - [Features](#features) + - [Performance](#performance) + - [Changes](#changes) + - [Fixes](#fixes) + - [Deprecations](#deprecations) + - [Documentation](#documentation) + - [Dependencies](#dependencies) + - [Devops](#devops) + + +### Create profiles without a database server + +A new storage backend plugin has been added that uses [`SQLite`](https://www.sqlite.org/index.html) instead of PostgreSQL. 
+
+### Changes in process launch functions
+
+The `aiida.engine.submit` method now accepts the argument `wait`.
+When set to `True`, instead of returning the process node straight away, the function will wait for the process to terminate before returning.
+By default it is set to `False`, so the current behavior remains unchanged.
+```python
+from aiida.engine import submit
+node = submit(Process, wait=True)  # This call will block until the process is terminated
+assert node.is_terminated
+```
+
+This new feature is mostly useful for interactive demos and tutorials in notebooks.
+In these situations, it might be beneficial to use `aiida.engine.run` because the cell will block until the process is finished, indicating to the user that something is running.
+When using `submit`, the cell returns immediately, but the results are not ready yet and typically the next cell cannot yet be executed.
+Instead, the demo would have to redirect the user to something like `verdi process list` to query the status of the process.
+
+However, using `run` has downsides as well, most notably that the process will be lost if the notebook gets disconnected.
+For processes that are expected to run longer, this can be really problematic, and so `submit` will have to be used regardless.
+With the new `wait` argument, `submit` provides the best of both worlds.
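+
+As a rough sketch of how this could look in a notebook cell (assuming a `SomeProcess` class with integer inputs `x` and `y`, and assuming the related `wait_interval` argument mentioned below sets the polling interval in seconds):
+```python
+from aiida.engine import submit
+from aiida.orm import Int
+
+# Blocks like `run` would, but the process still goes through the daemon,
+# so it is not lost if the notebook disconnects.
+node = submit(SomeProcess, x=Int(1), y=Int(2), wait=True, wait_interval=2)
+print(node.process_state, node.exit_status)
+```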
+
+Although very useful, the introduction of this feature does break any processes that define `wait` or `wait_interval` as an input.
+Since the inputs to a process are defined as keyword arguments, these inputs would overlap with the arguments to the `submit` method.
+To solve this problem, inputs can now _also_ be passed as a dictionary. For example, where one would do before:
+```python
+submit(SomeProcess, x=Int(1), y=Int(2), code=load_code('some-code'))
+# or alternatively
+inputs = {
+    'x': Int(1),
+    'y': Int(2),
+    'code': load_code('some-code'),
+}
+submit(SomeProcess, **inputs)
+```
+The new syntax allows the following:
+```python
+inputs = {
+    'x': Int(1),
+    'y': Int(2),
+    'code': load_code('some-code'),
+}
+submit(SomeProcess, inputs)
+```
+Passing inputs as keyword arguments is still supported because sometimes that notation is more legible than defining an intermediate dictionary.
+However, if both an input dictionary and keyword arguments are defined, an exception is raised.
+
+### Improvements for built-in data types
+
+The `XyData` and `ArrayData` data plugins now allow the content to be passed directly in the constructor.
+This allows defining the complete node in a single line:
+```python
+import numpy as np
+from aiida.orm import ArrayData, XyData
+
+xy = XyData(np.array([1, 2]), np.array([3, 4]), x_name='x', x_units='E', y_names='y', y_units='F')
+assert all(xy.get_x()[1] == np.array([1, 2]))
+
+array = ArrayData({'a': np.array([1, 2]), 'b': np.array([3, 4])})
+assert all(array.get_array('a') == np.array([1, 2]))
+```
+It is also no longer required to specify the name in `ArrayData.get_array` as long as the node contains just a single array:
+```python
+import numpy as np
+from aiida.orm import ArrayData
+
+array = ArrayData(np.array([1, 2]))
+assert all(array.get_array() == np.array([1, 2]))
+```
+
+### Repository interface improvements
+
+As of `v2.0.0`, the repository interface of the `Node` class was moved to the `Node.base.repository` namespace.
+This was done to clean up the top-level namespace of the `Node` class, which was getting very crowded, and in most use cases a user never needs to access these methods directly.
+It is up to the data plugin to provide specific methods to retrieve data that might be stored in the repository.
+For example, with `ArrayData`, a user would now have to go through `ArrayData.base.repository.get_object_content` to retrieve an array from the repository, but the class provides `ArrayData.get_array` as a shortcut.
+
+A few data plugins that ship with `aiida-core` didn't respect this guideline, most notably the `FolderData` and `SinglefileData` plugins.
+This has been corrected in this release: for `FolderData`, all the repository methods are now once again directly available on the top-level namespace.
+The `SinglefileData` now makes it easier to get the content as bytes.
+Before, one had to do:
+```python
+from aiida.orm import SinglefileData
+node = SinglefileData.from_string('some content')
+with node.open(mode='rb') as handle:
+    byte_content = handle.read()
+```
+This can now be achieved with:
+```python
+from aiida.orm import SinglefileData
+node = SinglefileData.from_string('some content')
+byte_content = node.get_content(mode='rb')
+```
+
+As of v2.0, due to the repository redesign, it was no longer possible to access a file directly by a filepath on disk.
+The repository interface only interacts with file-like objects to stream the content.
+However, a lot of Python libraries expect filepaths on disk and do not support file-like objects.
+This would force an AiiDA user to write the file from the repository to a temporary file on disk, and pass that temporary filepath.
+For example, consider the `numpy.loadtxt` function, which requires a filepath; the code would look something like:
+```python
+import pathlib
+import shutil
+import tempfile
+
+import numpy
+
+with tempfile.TemporaryDirectory() as tmp_path:
+
+    # Copy the entire content to the temporary folder
+    dirpath = pathlib.Path(tmp_path)
+    node.base.repository.copy_tree(dirpath)
+
+    # Or copy the content of a single file. Streaming is used
+    # to avoid reading everything into memory at once
+    filepath = dirpath / 'some_file.txt'
+    with filepath.open('wb') as target:
+        with node.base.repository.open('some_file.txt', 'rb') as source:
+            shutil.copyfileobj(source, target)
+
+    # Now pass `filepath` to the library call, e.g.
+    numpy.loadtxt(filepath)
+```
+This burdensome boilerplate has now been made obsolete by the `as_path` method:
+```python
+with node.base.repository.as_path() as dirpath:
+    numpy.loadtxt(dirpath / 'some_file.txt')
+```
+For the `FolderData` and `SinglefileData` plugins, the method can of course also be accessed directly on the top-level namespace.
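+For instance, a minimal sketch of the top-level shortcut (assuming `folder_node` is a stored `FolderData` containing a file `data.txt`; both names are just placeholders):
+```python
+with folder_node.as_path() as dirpath:
+    # `dirpath` points at the node's repository content on disk for the
+    # duration of the `with` block, so regular filesystem tools work on it.
+    print((dirpath / 'data.txt').read_text())
+```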
+
+### Full list of changes
+
+#### Features
+- Add the `SqliteDosStorage` storage backend [[702f88788]](https://github.com/aiidateam/aiida-core/commit/702f8878829b8e2a65d81623cc2238eb40791bc6)
+- `XyData`: Allow defining array(s) on construction [[f11598dc6]](https://github.com/aiidateam/aiida-core/commit/f11598dc68a80bbfa026db064158aae64ac0e802)
+- `ArrayData`: Make `name` optional in `get_array` [[7fbe67cb6]](https://github.com/aiidateam/aiida-core/commit/7fbe67cb6273cf2bae4256cdbda284aeb89a9372)
+- `ArrayData`: Allow defining array(s) on construction [[35e669fe8]](https://github.com/aiidateam/aiida-core/commit/35e669fe86ca467e656f4e500f11d533f7492107)
+- `FolderData`: Expose repository API on top-level namespace [[3e1f87373]](https://github.com/aiidateam/aiida-core/commit/3e1f87373e3cf2c40e8a3134ac848d4c16b9dbcf)
+- Repository: Add the `as_path` context manager [[b0546e8ed]](https://github.com/aiidateam/aiida-core/commit/b0546e8ed12b0982617293ab4a03ba3ec2d8ea44)
+- Caching: Add the `strict` argument configuration validation [[f272e197e]](https://github.com/aiidateam/aiida-core/commit/f272e197e2992f445b2b51608a6ffe17a2a8f4c1)
+- Caching: Try to import an identifier if it is a class path [[2c56fc234]](https://github.com/aiidateam/aiida-core/commit/2c56fc234139e624eb1da5ee016c1761b7b1a70a)
+- CLI: Add the command `verdi profile setup` [[351021164]](https://github.com/aiidateam/aiida-core/commit/351021164d00aa3a2a78b5b6e43e8a87a8553151)
+- CLI: Add `cached` and `cached_from` projections to `verdi process list` [[3b445c4f1]](https://github.com/aiidateam/aiida-core/commit/3b445c4f1c793ecc9b5c2efce863620748610d61)
+- CLI: Add `--all` flag to `verdi process kill` [[db1375949]](https://github.com/aiidateam/aiida-core/commit/db1375949b9ec133ee3b06bc3bfe2f8185eceeb6)
+- CLI: Lazily validate entry points in parameter types [[d3807d422]](https://github.com/aiidateam/aiida-core/commit/d3807d42229ffbad4e74752b6842a60f66bbafed)
+- CLI: Add repair hint to `verdi process play/pause/kill` [[8bc31bfd1]](https://github.com/aiidateam/aiida-core/commit/8bc31bfd1dae84a2240470a8163b3407eb27ae03)
+- CLI: Add the `verdi process repair` command [[3e3d9b9f7]](https://github.com/aiidateam/aiida-core/commit/3e3d9b9f70bb1ae2f7ae86db06469b73c5ebdfae)
+- CLI: Validate strict in `verdi config set caching.disabled_for` [[9cff59232]](https://github.com//commit/9cff5923263cd349da731b02d309120e754c0b95)
+- 
`DynamicEntryPointCommandGroup`: Allow entry points to be excluded [[9e30ec8ba]](https://github.com//commit/9e30ec8baeee74ae6d1c08459cb6eacd46d12e8a) +- Add the `aiida.common.log.capture_logging` utility [[9006eef3a]](https://github.com/aiidateam/aiida-core/commit/9006eef3ac1bb7b47c8ced63766e2f5346d46e91) +- `Config`: Add the `create_profile` method [[ae7abe8a6]](https://github.com/aiidateam/aiida-core/commit/ae7abe8a6bddcf8d59b6ac213a73deeb65d4c056) +- Engine: Add the `await_processes` utility function [[45767f050]](https://github.com/aiidateam/aiida-core/commit/45767f0509513fecd287e334fb26299db2adf14b) +- Engine: Add the `wait` argument to `submit` [[8f5e929d1]](https://github.com/aiidateam/aiida-core/commit/8f5e929d1660b663894bac52f385874011e47872) +- ORM: Add the `User.is_default` property [[a43c4cd0f]](https://github.com/aiidateam/aiida-core/commit/a43c4cd0fcee252202f9a5a3016aef156a36ac29) +- ORM: Add `NodeCaching.CACHED_FROM_KEY` for `_aiida_cached_from` constant [[35fc3ae57]](https://github.com/aiidateam/aiida-core/commit/35fc3ae5790023022d4d78cf2fe7274a72b590d2) +- ORM: Add the `Entity.get_collection` classmethod [[305f1dbf4]](https://github.com/aiidateam/aiida-core/commit/305f1dbf4ccb3e0e2e79865aee8d248e5ad55b95) +- ORM: Add the `Dict.get` method [[184fcd16e]](https://github.com//commit/184fcd16e9a88fbf9d4e754870416f4a56de55b5) +- ORM: Register `numpy.ndarray` with the `to_aiida_type` to `ArrayData` [[d8dd776a6]](https://github.com/aiidateam/aiida-core/commit/d8dd776a68f438702aa07b58d754b35ab0745937) +- Manager: Add the `set_default_user_email` [[8f8f55807]](https://github.com/aiidateam/aiida-core/commit/8f8f55807fd02872e7a345b7bd10eb68f65cbcda) +- `CalcJob`: Add support for nested targets in `remote_symlink_list` [[0ec650c1a]](https://github.com/aiidateam/aiida-core/commit/0ec650c1ae31ac42f80940103ac81cb0eb53f06d) +- `RemoteData`: Add the `is_cleaned` property [[2a2353d3d]](https://github.com/aiidateam/aiida-core/commit/2a2353d3dd2712afda8f1ebbcf749c7cc99f06fd) +- `SqliteTempBackend`: Add support for reading from and writing to archives [[83fc5cf69]](https://github.com/aiidateam/aiida-core/commit/83fc5cf69e8fcecba1f4c47ccb6599e6d78ba9dc) +- `StorageBackend`: Add the `read_only` class attribute [[8a4303ff5]](https://github.com//commit/8a4303ff53ec0b14fe43fbf1f4e01b69efc689df) +- `SinglefileData`: Add `mode` keyword to `get_content` [[d082df7f1]](https://github.com/aiidateam/aiida-core/commit/d082df7f1b53057e15c8cbbc7e662ec808c27722) +- `BaseRestartWorkChain`: Factor out attachment of outputs [[d6093d101]](https://github.com/aiidateam/aiida-core/commit/d6093d101ddcdaba74a14b44bdd91eea95628903) +- Add support for `NodeLinksManager` to YAML serializer [[6905c134e]](https://github.com//commit/6905c134e737183a1f366d9f86d9f77dd4d74730) + +#### Performance +- CLI: Make loading of config lazy for improved responsiveness [[d533b7a54]](https://github.com/aiidateam/aiida-core/commit/d533b7a540ab9d420acec1833bb7e23f50d8a7c1) +- Cache the lookup of entry points [[12cc930db]](https://github.com/aiidateam/aiida-core/commit/12cc930dbf8f377527d89f6f39bc28a4638f8377) +- Refactor: Delay import of heavy packages to speed up import time [[5dda6fd97]](https://github.com/aiidateam/aiida-core/commit/5dda6fd9749a886585cebf9afc288ebc46f00429) +- Refactor: Delay import of heavy packages to speed up import time [[8e6e08dc7]](https://github.com/aiidateam/aiida-core/commit/8e6e08dc780152333e4a6b6966469a98e51fe061) +- Do not import `aiida.cmdline` in `aiida.orm` 
[[0879a4e27]](https://github.com/aiidateam/aiida-core/commit/0879a4e27559ac368545afd18a1f061e9c29b8c7) +- Lazily define `__type_string` in `orm.Group` [[ebf3101d9]](https://github.com/aiidateam/aiida-core/commit/ebf3101d9b2c6298070853bae6c7b06489a363ca) +- Lazily define `_plugin_type_string` and `_query_type_string of `Node` [[3a61a7003]](https://github.com/aiidateam/aiida-core/commit/3a61a70032d6ace3d27f1a701be048f3f2026b43) + +#### Changes +- CLI: `verdi profile delete` is now storage plugin agnostic [[5015f5fe1]](https://github.com//commit/5015f5fe12d93024ed0d7594d860f1f2cd977548) +- CLI: Usability improvements for interactive `verdi setup` [[c53ea20a4]](https://github.com/aiidateam/aiida-core/commit/c53ea20a497f66bc88f68d0603cf9a32614fc4c2) +- CLI: Do not load config in defaults and callbacks during tab-completion [[062058862]](https://github.com/aiidateam/aiida-core/commit/06205886204c142f771dab37f1a78f3bf0ba7251) +- Engine: Make process inputs in launchers positional [[6d18ccb86]](https://github.com//commit/6d18ccb8680f16e8da80deffe40808cc2e669de0) +- Remove `aiida.manage.configuration.load_documentation_profile` [[9941266ce]](https://github.com//commit/9941266ced93f31191152034606bf5b1e049cc79) +- ORM: `Sealable.seal()` return `self` instead of `None` [[16e3bd3b5]](https://github.com/aiidateam/aiida-core/commit/16e3bd3b5087b95d31983df2147d4c14bb331077) +- ORM: Move deprecation warnings from module level [[c4afdb9be]](https://github.com//commit/c4afdb9be5633b68d72121c36916dfc6791d8b29) +- Config: Switch from `jsonschema` to `pydantic` [[4203f162d]](https://github.com/aiidateam/aiida-core/commit/4203f162df803946b2396ca820e6b6139a3ecc61) +- `DynamicEntryPointCommandGroup`: Use `pydantic` to define config model [[1d8ea2a27]](https://github.com/aiidateam/aiida-core/commit/1d8ea2a27381feeabfe38f5a3647d22ac1b825e4) +- Config: Remove use of `NO_DEFAULT` for `Option.default` [[275718cc8]](https://github.com/aiidateam/aiida-core/commit/275718cc8dae866a6fc847fa898a3290672e9d7a) + +#### Fixes +- Add the `report` method to `logging.LoggerAdapter` [[7d6684ce1]](https://github.com/aiidateam/aiida-core/commit/7d6684ce1f46862e69c59e9b48da97ab63d9f786) +- `CalcJob`: Fix MPI behavior if `withmpi` option default is True [[84737506e]](https://github.com//commit/84737506e99860beb3ecfa329c1d1e9d4636cd16) +- `CalcJobNode`: Fix validation for `depth=None` in `retrieve_list` [[03c86d5c9]](https://github.com/aiidateam/aiida-core/commit/03c86d5c988d9d2e1f656ba28bd2b8292fc7b02d) +- CLI: Fix bug in `verdi data core.trajectory show` for various formats [[fd4c1269b]](https://github.com/aiidateam/aiida-core/commit/fd4c1269bf913602660b13bdb49c3bc15360448a) +- CLI: Add missing entry point groups for `verdi plugin list` [[ae637d8c4]](https://github.com/aiidateam/aiida-core/commit/ae637d8c474a0071031c6a9bf6f65d2a924f2e81) +- CLI: Remove loading backend for `verdi plugin list` [[34e564ad0]](https://github.com/aiidateam/aiida-core/commit/34e564ad081143a4739c58a7aaa499e55d4e4651) +- CLI: Fix `repository` being required for `verdi quicksetup` [[d4666009e]](https://github.com/aiidateam/aiida-core/commit/d4666009e82fc104a1fa7965b1f50934bec36f0f) +- CLI: Fix `verdi config set` when setting list option [[314917801]](https://github.com/aiidateam/aiida-core/commit/314917801181d163f0760ca5788c543103d96bf5) +- CLI: Keep list unique in `verdi config set --append` [[3844f86c6]](https://github.com/aiidateam/aiida-core/commit/3844f86c6bb7da1dfc40542210b450a70b8950c5) +- CLI: Improve the formatting of `verdi user list` 
[[806d7e236]](https://github.com/aiidateam/aiida-core/commit/806d7e2366225bbe16ed982c320a708dbbf323f5) +- CLI: Set defaults for user details in profile setup [[8b8887e55]](https://github.com/aiidateam/aiida-core/commit/8b8887e559e02eadac832a89f7012872040e1cbc) +- CLI: Reuse options in `verdi user configure` from setup [[1c0b702ba]](https://github.com/aiidateam/aiida-core/commit/1c0b702bafb56c6452c975ad7020796303742405) +- `InteractiveOption`: Fix validation being skipped if `!` provided [[c4b183bc6]](https://github.com/aiidateam/aiida-core/commit/c4b183bc6d6083dad0754e42de19e96a867ff8ed) +- ORM: Fix problem with detached `DbAuthInfo` instances [[ec2c6a8fe]](https://github.com//commit/ec2c6a8fe3b397ab9f7314c556551114ea15c7df) +- ORM: Check nodes are from same backend in `validate_link` [[7bd546ebe]](https://github.com/aiidateam/aiida-core/commit/7bd546ebe67845b47c0dc14567c1ef7a557c23ef) +- ORM: `ProcessNode.is_valid_cache` is `False` for unsealed nodes [[a1f456d43]](https://github.com/aiidateam/aiida-core/commit/a1f456d436fee6a54327e4ba9b0841a980998f52) +- ORM: Explicitly pass backend when constructing new entity [[96667c8c6]](https://github.com/aiidateam/aiida-core/commit/96667c8c63b0053e79c8a1531707890027f10e6a) +- ORM: Replace `.collection(backend)` with `.get_collection(backend)` [[bac2152c4]](https://github.com/aiidateam/aiida-core/commit/bac2152c450a83cb6332516db315147cfc982265) +- Make `warn_deprecation` respect the `warnings.showdeprecations` option [[6c28c63e9]](https://github.com//commit/6c28c63e95323a4e3ba8730ef720e1a708d91133) +- `PsqlDosBackend`: Fix changes not persisted after `iterall` and `iterdict` [[2ea5087c0]](https://github.com/aiidateam/aiida-core/commit/2ea5087c079417d6d0b37cbc0502ed7cab173c11) +- `PsqlDosBackend`: Fix `Node.store` excepting when inside a transaction [[624dcd9fc]](https://github.com/aiidateam/aiida-core/commit/624dcd9fcc1f0f9aadf54c59afa435fd78598ef7) +- `Parser.parse_from_node`: Validate outputs against process spec [[d16792f3d]](https://github.com/aiidateam/aiida-core/commit/d16792f3d80fb1c497840ff1b0f6f1e114a262da) +- Fix `QueryBuilder.count` for storage backends using sqlite [[5dc1555bc]](https://github.com/aiidateam/aiida-core/commit/5dc1555bc186a7b0205323801833037ae9a6bc36) +- Process functions: Fix bug with variable arguments [[ca8bbc67f]](https://github.com//commit/ca8bbc67fcb40d6cec4e1cae32ce114495c0eb1d) +- `SqliteZipBackend`: Return `self` in `store` [[6a43b3f15]](https://github.com/aiidateam/aiida-core/commit/6a43b3f15ca9cc2eab1a13f6670921f71809a956) +- `SqliteZipBackend`: Ensure the `filepath` is absolute and exists [[5eac8b49d]](https://github.com//commit/5eac8b49df33287c3dc6cfbf46eae491c3196fc4) +- Remove `with_dbenv` use in `aiida.orm` [[35c57b9eb]](https://github.com/aiidateam/aiida-core/commit/35c57b9eb63b42531111f27ac7cc76e129ccd14a) + +#### Deprecations +- Deprecated `aiida.orm.nodes.data.upf` and `verdi data core.upf` [[6625fd245]](https://github.com/aiidateam/aiida-core/commit/6625fd2456f4ee13297d797d08925a359474e30e) + +#### Documentation +- Add topic section on storage [[83dbe1ad9]](https://github.com//commit/83dbe1ad92be580fa26412e5db4d1f370ec91c7a) +- Add important note on using `iterall` and `iterdict` [[0aea7e41b]](https://github.com/aiidateam/aiida-core/commit/0aea7e41b24fb479b2a1bbc71ab72f43e823f3a7) +- Add links about "entry point" and "plugin" to tutorial [[517ffcb1c]](https://github.com/aiidateam/aiida-core/commit/517ffcb1c5ce32f281589432cde1d58588fa83e0) +- Disable the `warnings.showdeprecations` option 
[[4adb06c0c]](https://github.com//commit/4adb06c0ce32335fffe5d970febfd36dcd85edd5) +- Fix instructions for inspecting archive files [[0a9c2788e]](https://github.com//commit/0a9c2788ea54926a202c5c3393d9d34815bf4356) +- Changes are reverted if exception during `iterall` [[17c5d8724]](https://github.com/aiidateam/aiida-core/commit/17c5d872495fbb1b6a80d985cb71088095083bb9) +- Various minor fixes to `run_docker.rst` [[d3788adea]](https://github.com/aiidateam/aiida-core/commit/d3788adea220107bce3582d246bcc9674b5e1571) +- Update `pydata-sphinx-theme` and add Discourse links [[13df42c14]](https://github.com/aiidateam/aiida-core/commit/13df42c14abc6145da3880616288a98b2d5ecc74) +- Correct example of `verdi config unset` in troubleshooting [[d6143dbc8]](https://github.com/aiidateam/aiida-core/commit/d6143dbc87bbbb3b6d4758b3922a47741493897e) +- Improvements to sections containing recently added functionality [[836419f66]](https://github.com/aiidateam/aiida-core/commit/836419f6694e9d4d8e580f1b6fd71ffa27f635ef) +- Fix typo in `run_codes.rst` [[9bde86ec7]](https://github.com/aiidateam/aiida-core/commit/9bde86ec7700b3dd2df55c69fb8efb9887ed07d6) +- Fixtures: Fix `suppress_warnings` of `run_cli_command` [[9807cede4]](https://github.com//commit/9807cede4601349a50ac2bff72a32173a0e3d702) +- Update citation suggestions [[1dafdf2dd]](https://github.com/aiidateam/aiida-core/commit/1dafdf2ddb38c801d2075d9af9bbde9e0d26c8ca) + +#### Dependencies +- Add support for Python 3.12 [[c39b4fda4]](https://github.com/aiidateam/aiida-core/commit/c39b4fda40c88737f1c56f5ad6f42cbed974478b) +- Update to `sqlalchemy~=2.0` [[a216f5052]](https://github.com/aiidateam/aiida-core/commit/a216f5052c56bbbeffac296fcd59af177f703829) +- Update to `disk-objectstore~=1.0` [[56f9f6ca0]](https://github.com/aiidateam/aiida-core/commit/56f9f6ca03c7b69766e725449fd955848577055a) +- Add new extra `tui` that provides `verdi` as a TUI [[a42e09c02]](https://github.com/aiidateam/aiida-core/commit/a42e09c026e793e5670b88037d5f4863cc4097f0) +- Add upper limit `jedi<0.19` [[fae2a9cfd]](https://github.com/aiidateam/aiida-core/commit/fae2a9cfda461a26e80b648795e45087ea8133fd) +- Update requirement `mypy~=1.7` [[c2fcad4ab]](https://github.com/aiidateam/aiida-core/commit/c2fcad4ab3f6bc1899475af037e4b14f3497feec) +- Add compatibility for `pymatgen>=v2023.9.2` [[4e0e7d8e9]](https://github.com/aiidateam/aiida-core/commit/4e0e7d8e9fd10c4adc3630cf24cebdf749f95351) +- Bump `yapf` to `0.40.0` [[a8ae50853]](https://github.com/aiidateam/aiida-core/commit/a8ae508537d2b6e9ffa1de9beb140065282a30f8) +- Update pre-commit requirement `flynt==1.0.1` [[e01ea4b97]](https://github.com/aiidateam/aiida-core/commit/e01ea4b97d094f0543b0f0c631fa0463c8baf2f5) +- Docker: Pinning mamba version to 1.5.2 [[a6c2dbe1c]](https://github.com//commit/a6c2dbe1c434f0df7790e41632c5dc578edebb97) +- Docker: Bump Python version to 3.10.13 [[b168f2e12]](https://github.com//commit/b168f2e12776136a8601b42dd85d7b2bb4746e30) + +#### Devops +- CI: Use Python 3.10 for `pre-commit` in CI and CD workflows [[f41c8ac90]](https://github.com/aiidateam/aiida-core/commit/f41c8ac9061c379f72286631bfb1c486cc302dc8) +- CI: Using concurrency for CI actions [[4db54b7f8]](https://github.com/aiidateam/aiida-core/commit/4db54b7f833096e2d5f3d439683c28749467b20d) +- CI: Update tox to use Python 3.9 [[227390a52]](https://github.com/aiidateam/aiida-core/commit/227390a52a6dc77faa20cb1cc6372ec7f66e0409) +- Docker: Bump `upload-artifact` action to v4 for Docker workflow 
[[bfdb2828a]](https://github.com//commit/bfdb2828a823052df52cb5cf61599cbc07b0bb4b) +- Refactor: Replace `all` with `iterall` where beneficial [[8a2fece02]](https://github.com/aiidateam/aiida-core/commit/8a2fece02411c982eb16e8fed8991ffaf75fa76f) +- Pre-commit: Disable `no-member` and `no-name-in-module` for `aiida.orm` [[15379bbee]](https://github.com/aiidateam/aiida-core/commit/15379bbee2cbf9889772d497e1a6b77e230aaa2f) +- Tests: Move memory leak tests to main unit test suite [[561f93cef]](https://github.com/aiidateam/aiida-core/commit/561f93cef15355e08a3ec19173132deec031ed67) +- Tests: Move ipython magic tests to main unit test suite [[ce9acc312]](https://github.com/aiidateam/aiida-core/commit/ce9acc312c0cfe351f188d399046de6a4248cb16) +- Tests: Remove deprecated `aiida/manage/tests/main` module [[5b9da7d1e]](https://github.com/aiidateam/aiida-core/commit/5b9da7d1eeb3cb01474f2c95526148ba136c6f3c) +- Tests: Refactor transport tests from `unittest` to `pytest` [[ec64780c2]](https://github.com/aiidateam/aiida-core/commit/ec64780c206cdb040eee740b17865e6f0ff81cd8) +- Tests: Fix failing `tests/cmdline/commands/test_setup.py` [[b6f7ec188]](https://github.com/aiidateam/aiida-core/commit/b6f7ec18830d8495a76eefb3ef59e0069db49f99) +- Tests: Print stack trace if CLI command excepts with `run_cli_command` [[08cba0f78]](https://github.com/aiidateam/aiida-core/commit/08cba0f78acbf3da760f8d9110426b80df20ab3a) +- Tests: Make `PsqlDosStorage` profile unload test more robust [[1c72eac1f]](https://github.com/aiidateam/aiida-core/commit/1c72eac1f91e02bc464c66328ea74911762b94fd) +- Tests: Fix flaky work chain tests using `recwarn` fixture [[207151784]](https://github.com/aiidateam/aiida-core/commit/2071517849820e218a28d3968e45d211e8cd6247) +- Tests: Fix `StructureData` test breaking for recent `pymatgen` versions [[d1d64e800]](https://github.com/aiidateam/aiida-core/commit/d1d64e8004c31209488f71a160a4f4824d02c081) +- Typing: Improve annotations of process functions [[a85af4f0c]](https://github.com/aiidateam/aiida-core/commit/a85af4f0c017b8c03426ef7927163a33add08004) +- Typing: Add type hinting for `aiida.orm.nodes.data.array.xy` [[2eaa5449b]](https://github.com/aiidateam/aiida-core/commit/2eaa5449bca55ac87475900dd64ca086bddc0023) +- Typing: Add type hinting for `aiida.orm.nodes.data.array.array` [[c19b1423a]](https://github.com/aiidateam/aiida-core/commit/c19b1423adfb0b8490cdfb899cabd8e88e03237f) +- Typing: Add overload signatures for `open` [[0986f6b59]](https://github.com/aiidateam/aiida-core/commit/0986f6b59086e2e0947906654c1642cf264b462e) +- Typing: Add overload signatures for `get_object_content` [[d18eedc8b]](https://github.com/aiidateam/aiida-core/commit/d18eedc8be565af12f36e48bd8392e9b29438c15) +- Typing: Correct type annotation of `WorkChain.on_wait` [[923cc314c]](https://github.com/aiidateam/aiida-core/commit/923cc314c527a183e55819b96de8ae027c9f0612) +- Typing: Improve type hinting for `aiida.orm.nodes.data.singlefile` [[b9d087dd4]](https://github.com/aiidateam/aiida-core/commit/b9d087dd47c2b09878d078fc6a64cede0e1ce5e1) + + +## v2.4.2 - 2023-11-30 + +### Docker +- Disable the consumer timeout for RabbitMQ [[5ce1e7ec3]](https://github.com/aiidateam/aiida-core/commit/5ce1e7ec37207013a7733b9df943977a15e421e5) +- Add `rsync` and `graphviz` to system requirements [[c4799add4]](https://github.com/aiidateam/aiida-core/commit/c4799add41a29944dd02be2ca44756eaf8035b1c) + +### Dependencies +- Add upper limit `jedi<0.19` 
[[90e586fe3]](https://github.com/aiidateam/aiida-core/commit/90e586fe367daf8f9ebe953c2a976bc5c4d33903) + + +## v2.4.1 - 2023-11-15 + +This patch release comes with an improved set of Docker images and a few fixes to provide compatibility with recent versions of `pymatgen`. + +### Docker +- Improved Docker images [[fec4e3bc4]](https://github.com/aiidateam/aiida-core/commit/fec4e3bc4dffd7d15b63e7ef0f306a8034ca3816) +- Add folders that automatically run scripts before/after daemon start in Docker image [[fe4bc1d3d]](https://github.com/aiidateam/aiida-core/commit/fe4bc1d3d380686094021515baf31babf47388ac) +- Pass environment variable to `aiida-prepare` script in Docker image [[ea47668ea]](https://github.com/aiidateam/aiida-core/commit/ea47668ea9b38581fbe1b6c72e133824043a8d38) +- Update the `.devcontainer` to use the new docker stack [[413a0db65]](https://github.com/aiidateam/aiida-core/commit/413a0db65cb31156e6e794dac4f8d36e74b0b2cb) + +### Dependencies +- Add compatibility for `pymatgen>=v2023.9.2` [[1f6027f06]](https://github.com/aiidateam/aiida-core/commit/1f6027f062a9eca5d8006741df91545d8ec01ed3) + +### Devops +- Tests: Make `PsqlDosStorage` profile unload test more robust [[f392459bd]](https://github.com/aiidateam/aiida-core/commit/f392459bd417bec8a3ce184ee8f753649bcb77b8) +- Tests: Fix `StructureData` test breaking for recent `pymatgen` versions [[093037d48]](https://github.com/aiidateam/aiida-core/commit/093037d48a2d92cbb6f068c1111fe1564a4500c0) +- Trigger Docker image build when pushing to `support/*` branch [[5cf3d1d75]](https://github.com/aiidateam/aiida-core/commit/5cf3d1d75e8d22d6a3f0909c84aa63cc228bcf4b) +- Use `aiida-core-base` image from `ghcr.io` [[0e5b1c747]](https://github.com/aiidateam/aiida-core/commit/0e5b1c7473030dd5b5027ea4eb0a658db9174091) +- Loosen trigger conditions for Docker build CI workflow [[22e8a8069]](https://github.com/aiidateam/aiida-core/commit/22e8a80690747b792b70f96a0e332906f0e65e97) +- Follow-up docker build runner macOS-ARM64 [[1bd9bf03d]](https://github.com/aiidateam/aiida-core/commit/1bd9bf03d19dda4c462728fb87cf4712b74c5f39) +- Upload artifact by PR from forks for docker workflow [[afc2dad8a]](https://github.com/aiidateam/aiida-core/commit/afc2dad8a68e280f01e89fcb5b13e7a60c2fd072) +- Update the image name for docker image [[17507b410]](https://github.com/aiidateam/aiida-core/commit/17507b4108b5dd1cd6e074b08e0bc2535bf0a164) + + ## v2.4.0 - 2023-06-22 This minor release comes with a number of new features and improvements as well as a significant amount of bug fixes. @@ -996,7 +1679,7 @@ This command runs the `verdi` CLI using the currently loaded profile of the IPyt %verdi status ``` -See the [Basic Tutorial](docs/source/intro/tutorial.md) for example usage. +See the [Basic Tutorial](docs/source/tutorials/basic.md) for example usage. ### New `SqliteTempBackend` ✨ @@ -1021,7 +1704,7 @@ profile = load_profile( ) ``` -See the [Basic Tutorial](docs/source/intro/tutorial.md) for example usage. +See the [Basic Tutorial](docs/source/tutorials/basic.md) for example usage. ### Key Pull Requests diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 9085a5e3ab..0000000000 --- a/Dockerfile +++ /dev/null @@ -1,23 +0,0 @@ -FROM aiidateam/aiida-prerequisites:0.7.0 - -USER root - -ENV SETUP_DEFAULT_PROFILE true - -ENV PROFILE_NAME default -ENV USER_EMAIL aiida@localhost -ENV USER_FIRST_NAME Giuseppe -ENV USER_LAST_NAME Verdi -ENV USER_INSTITUTION Khedivial -ENV AIIDADB_BACKEND core.psql_dos - -# Copy and install AiiDA -COPY . 
aiida-core -RUN pip install ./aiida-core[atomic_tools] - -# Configure aiida for the user -COPY .docker/opt/configure-aiida.sh /opt/configure-aiida.sh -COPY .docker/my_init.d/configure-aiida.sh /etc/my_init.d/40_configure-aiida.sh - -# Use phusion baseimage docker init system. -CMD ["/sbin/my_init"] diff --git a/README.md b/README.md index d722b0dc31..31564c6815 100644 --- a/README.md +++ b/README.md @@ -5,10 +5,10 @@ AiiDA (www.aiida.net) is a workflow manager for computational science with a str | | | |-----|----------------------------------------------------------------------------| |Latest release| [![PyPI version](https://badge.fury.io/py/aiida-core.svg)](https://badge.fury.io/py/aiida-core) [![conda-forge](https://img.shields.io/conda/vn/conda-forge/aiida-core.svg?style=flat)](https://anaconda.org/conda-forge/aiida-core) [![PyPI pyversions](https://img.shields.io/pypi/pyversions/aiida-core.svg)](https://pypi.python.org/pypi/aiida-core/) | -|Getting help| [![Docs status](https://readthedocs.org/projects/aiida-core/badge)](http://aiida-core.readthedocs.io/) [![Google Group](https://img.shields.io/badge/-Google%20Group-lightgrey.svg)](https://groups.google.com/forum/#!forum/aiidausers) +|Getting help| [![Docs status](https://readthedocs.org/projects/aiida-core/badge)](http://aiida-core.readthedocs.io/) [![Discourse status](https://img.shields.io/discourse/status?server=https%3A%2F%2Faiida.discourse.group%2F)](https://aiida.discourse.group/) |Build status| [![Build Status](https://github.com/aiidateam/aiida-core/actions/workflows/ci-code.yml/badge.svg)](https://github.com/aiidateam/aiida-core/actions) [![Coverage Status](https://codecov.io/gh/aiidateam/aiida-core/branch/main/graph/badge.svg)](https://codecov.io/gh/aiidateam/aiida-core) [Benchmarks](https://aiidateam.github.io/aiida-core/dev/bench/ubuntu-22.04/psql_dos/) | |Activity| [![PyPI-downloads](https://img.shields.io/pypi/dm/aiida-core.svg?style=flat)](https://pypistats.org/packages/aiida-core) [![Commit Activity](https://img.shields.io/github/commit-activity/m/aiidateam/aiida-core.svg)](https://github.com/aiidateam/aiida-core/pulse) -|Community| [![Affiliated with NumFOCUS](https://img.shields.io/badge/NumFOCUS-affiliated%20project-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org/sponsored-projects/affiliated-projects) [![Twitter](https://img.shields.io/twitter/follow/aiidateam.svg?style=social&label=Follow)](https://twitter.com/aiidateam) +|Community| [![Discourse](https://img.shields.io/discourse/topics?server=https%3A%2F%2Faiida.discourse.group%2F&logo=discourse)](https://aiida.discourse.group/) [![Affiliated with NumFOCUS](https://img.shields.io/badge/NumFOCUS-affiliated%20project-orange.svg?style=flat&colorA=E1523D&colorB=007D8A)](https://numfocus.org/sponsored-projects/affiliated-projects) [![Twitter](https://img.shields.io/twitter/follow/aiidateam.svg?style=social&label=Follow)](https://twitter.com/aiidateam) ## Features @@ -48,14 +48,18 @@ Please see the [Contributor wiki](https://github.com/aiidateam/aiida-core/wiki) ## Frequently Asked Questions If you are experiencing problems with your AiiDA installation, please refer to the [FAQ page of the documentation](https://aiida-core.readthedocs.io/en/latest/howto/faq.html). +For any other questions, discussion and requests for support, please visit the [Discourse server](https://aiida.discourse.group/). ## How to cite If you use AiiDA in your research, please consider citing the following publications: - * **AiiDA >= 1.0**: S. P. 
Huber *et al.*, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: [10.1038/s41597-020-00638-4](https://doi.org/10.1038/s41597-020-00638-4) - * **AiiDA >= 1.0**: M. Uhrin *et al.*, *Workflows in AiiDA: Engineering a high-throughput, event-based engine for robust and modular computational workflows*, Computational Materials Science **187**, 110086 (2021); DOI: [10.1016/j.commatsci.2020.110086](https://doi.org/10.1016/j.commatsci.2020.110086) - * **AiiDA < 1.0**: Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari,and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database for computational science*, Computational Materials Science **111**, 218-230 (2016); DOI: [10.1016/j.commatsci.2015.09.013](https://doi.org/10.1016/j.commatsci.2015.09.013) + * S. P. Huber *et al.*, *AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and data provenance*, Scientific Data **7**, 300 (2020); DOI: [10.1038/s41597-020-00638-4](https://doi.org/10.1038/s41597-020-00638-4) + * M. Uhrin *et al.*, *Workflows in AiiDA: Engineering a high-throughput, event-based engine for robust and modular computational workflows*, Computational Materials Science **187**, 110086 (2021); DOI: [10.1016/j.commatsci.2020.110086](https://doi.org/10.1016/j.commatsci.2020.110086) + +If the ADES concepts are referenced, please also cite: + +* Giovanni Pizzi, Andrea Cepellotti, Riccardo Sabatini, Nicola Marzari,and Boris Kozinsky, *AiiDA: automated interactive infrastructure and database for computational science*, Computational Materials Science **111**, 218-230 (2016); DOI: [10.1016/j.commatsci.2015.09.013](https://doi.org/10.1016/j.commatsci.2015.09.013) ## License diff --git a/aiida/__init__.py b/aiida/__init__.py deleted file mode 100644 index df6a8b25b6..0000000000 --- a/aiida/__init__.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -AiiDA is a flexible and scalable informatics' infrastructure to manage, -preserve, and disseminate the simulations, data, and workflows of -modern-day computational science. - -Able to store the full provenance of each object, and based on a -tailored database built for efficient data mining of heterogeneous results, -AiiDA gives the user the ability to interact seamlessly with any number of -remote HPC resources and codes, thanks to its flexible plugin interface -and workflow engine for the automation of complex sequences of simulations. - -More information at http://www.aiida.net -""" -from aiida.common.log import configure_logging -from aiida.manage.configuration import get_config_option, get_profile, load_profile, profile_context - -__copyright__ = ( - 'Copyright (c), This file is part of the AiiDA platform. ' - 'For further information please visit http://www.aiida.net/. All rights reserved.' -) -__license__ = 'MIT license, see LICENSE.txt file.' -__version__ = '2.4.0.post0' -__authors__ = 'The AiiDA team.' -__paper__ = ( - 'S. P. 
Huber et al., "AiiDA 1.0, a scalable computational infrastructure for automated reproducible workflows and ' - 'data provenance", Scientific Data 7, 300 (2020); https://doi.org/10.1038/s41597-020-00638-4' -) -__paper_short__ = 'S. P. Huber et al., Scientific Data 7, 300 (2020).' - - -def get_strict_version(): - """ - Return a distutils StrictVersion instance with the current distribution version - - :returns: StrictVersion instance with the current version - :rtype: :class:`!distutils.version.StrictVersion` - """ - from distutils.version import StrictVersion - - from aiida.common.warnings import warn_deprecation - warn_deprecation( - 'This method is deprecated as the `distutils` package it uses will be removed in Python 3.12.', version=3 - ) - return StrictVersion(__version__) - - -def get_version() -> str: - """ - Return the current AiiDA distribution version - - :returns: the current version - """ - return __version__ - - -def _get_raw_file_header() -> str: - """ - Get the default header for source AiiDA source code files. - Note: is not preceded by comment character. - - :return: default AiiDA source file header - """ - return f"""This file has been created with AiiDA v. {__version__} -If you use AiiDA for publication purposes, please cite: -{__paper__} -""" - - -def get_file_header(comment_char: str = '# ') -> str: - """ - Get the default header for source AiiDA source code files. - - .. note:: - - Prepend by comment character. - - :param comment_char: string put in front of each line - - :return: default AiiDA source file header - """ - lines = _get_raw_file_header().splitlines() - return '\n'.join(f'{comment_char}{line}' for line in lines) - - -def load_ipython_extension(ipython): - """Load the AiiDA IPython extension, using ``%load_ext aiida``.""" - from .tools.ipython.ipython_magics import AiiDALoaderMagics - ipython.register_magics(AiiDALoaderMagics) diff --git a/aiida/calculations/__init__.py b/aiida/calculations/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/calculations/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/calculations/arithmetic/__init__.py b/aiida/calculations/arithmetic/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/calculations/arithmetic/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/calculations/arithmetic/add.py b/aiida/calculations/arithmetic/add.py deleted file mode 100644 index b3d327b19c..0000000000 --- a/aiida/calculations/arithmetic/add.py +++ /dev/null @@ -1,73 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`CalcJob` implementation to add two numbers using bash for testing and demonstration purposes.""" -from aiida import orm -from aiida.common.datastructures import CalcInfo, CodeInfo -from aiida.common.folders import Folder -from aiida.engine import CalcJob, CalcJobProcessSpec - - -class ArithmeticAddCalculation(CalcJob): - """`CalcJob` implementation to add two numbers using bash for testing and demonstration purposes.""" - - @classmethod - def define(cls, spec: CalcJobProcessSpec): - """Define the process specification, including its inputs, outputs and known exit codes. - - :param spec: the calculation job process spec to define. - """ - super().define(spec) - spec.input('x', valid_type=(orm.Int, orm.Float), help='The left operand.') - spec.input('y', valid_type=(orm.Int, orm.Float), help='The right operand.') - spec.output('sum', valid_type=(orm.Int, orm.Float), help='The sum of the left and right operand.') - spec.input('metadata.options.sleep', required=False, valid_type=int) - # set default options (optional) - spec.inputs['metadata']['options']['parser_name'].default = 'core.arithmetic.add' - spec.inputs['metadata']['options']['input_filename'].default = 'aiida.in' - spec.inputs['metadata']['options']['output_filename'].default = 'aiida.out' - spec.inputs['metadata']['options']['resources'].default = {'num_machines': 1, 'num_mpiprocs_per_machine': 1} - # start exit codes - marker for docs - spec.exit_code( - 310, 'ERROR_READING_OUTPUT_FILE', invalidates_cache=True, message='The output file could not be read.' - ) - spec.exit_code( - 320, 'ERROR_INVALID_OUTPUT', invalidates_cache=True, message='The output file contains invalid output.' - ) - spec.exit_code(410, 'ERROR_NEGATIVE_NUMBER', message='The sum of the operands is a negative number.') - # end exit codes - marker for docs - - def prepare_for_submission(self, folder: Folder) -> CalcInfo: - """Prepare the calculation for submission. - - Convert the input nodes into the corresponding input files in the format that the code will expect. In addition, - define and return a `CalcInfo` instance, which is a simple data structure that contains information for the - engine, for example, on what files to copy to the remote machine, what files to retrieve once it has completed, - specific scheduler settings and more. - - :param folder: a temporary folder on the local file system. 
- :returns: the `CalcInfo` instance - """ - with folder.open(self.options.input_filename, 'w', encoding='utf8') as handle: - if 'sleep' in self.options: - handle.write(f'sleep {self.options.sleep}\n') - handle.write(f'echo $(({self.inputs.x.value} + {self.inputs.y.value}))\n') - - codeinfo = CodeInfo() - codeinfo.stdin_name = self.options.input_filename - codeinfo.stdout_name = self.options.output_filename - - if 'code' in self.inputs: - codeinfo.code_uuid = self.inputs.code.uuid - - calcinfo = CalcInfo() - calcinfo.codes_info = [codeinfo] - calcinfo.retrieve_list = [self.options.output_filename] - - return calcinfo diff --git a/aiida/calculations/importers/arithmetic/add.py b/aiida/calculations/importers/arithmetic/add.py deleted file mode 100644 index a7865bee70..0000000000 --- a/aiida/calculations/importers/arithmetic/add.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- -"""Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" -from pathlib import Path -from re import match -from tempfile import NamedTemporaryFile -from typing import Dict, Union - -from aiida.engine import CalcJobImporter -from aiida.orm import Int, Node, RemoteData - - -class ArithmeticAddCalculationImporter(CalcJobImporter): - """Importer for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" - - @staticmethod - def parse_remote_data(remote_data: RemoteData, **kwargs) -> Dict[str, Union[Node, Dict]]: - """Parse the input nodes from the files in the provided ``RemoteData``. - - :param remote_data: the remote data node containing the raw input files. - :param kwargs: additional keyword arguments to control the parsing process. - :returns: a dictionary with the parsed inputs nodes that match the input spec of the associated ``CalcJob``. - """ - with NamedTemporaryFile('w+') as handle: - with remote_data.get_authinfo().get_transport() as transport: - filepath = Path(remote_data.get_remote_path()) / 'aiida.in' - transport.getfile(filepath, handle.name) - - handle.seek(0) - data = handle.read() - - matches = match(r'echo \$\(\(([0-9]+) \+ ([0-9]+)\)\).*', data.strip()) - - if matches is None: - raise ValueError(f'failed to parse the integers `x` and `y` from the input content: {data}') - - return { - 'x': Int(matches.group(1)), - 'y': Int(matches.group(2)), - } diff --git a/aiida/calculations/monitors/base.py b/aiida/calculations/monitors/base.py deleted file mode 100644 index 9b4f2fd55e..0000000000 --- a/aiida/calculations/monitors/base.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -"""Monitors for the :class:`aiida.calculations.arithmetic.add.ArithmeticAddCalculation` plugin.""" -from __future__ import annotations - -import tempfile - -from aiida.orm import CalcJobNode -from aiida.transports import Transport - - -def always_kill(node: CalcJobNode, transport: Transport) -> str | None: # pylint: disable=unused-argument - """Retrieve and inspect files in working directory of job to determine whether the job should be killed. - - This particular implementation is just for demonstration purposes and will kill the job as long as there is a - submission script that contains some content, which should always be the case. - - :param node: The node representing the calculation job. - :param transport: The transport that can be used to retrieve files from remote working directory. - :returns: A string if the job should be killed, `None` otherwise. 
- """ - with tempfile.NamedTemporaryFile('w+') as handle: - transport.getfile('_aiidasubmit.sh', handle.name) - handle.seek(0) - output = handle.read() - - if output: - return 'Detected a non-empty submission script' - - return None diff --git a/aiida/cmdline/__init__.py b/aiida/cmdline/__init__.py deleted file mode 100644 index f34e1f5f52..0000000000 --- a/aiida/cmdline/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""The command line interface of AiiDA.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .groups import * -from .params import * -from .utils import * - -__all__ = ( - 'AbsolutePathParamType', - 'CalculationParamType', - 'CodeParamType', - 'ComputerParamType', - 'ConfigOptionParamType', - 'DataParamType', - 'DynamicEntryPointCommandGroup', - 'EmailType', - 'EntryPointType', - 'FileOrUrl', - 'GroupParamType', - 'HostnameType', - 'IdentifierParamType', - 'LabelStringType', - 'LazyChoice', - 'MpirunCommandParamType', - 'MultipleValueParamType', - 'NodeParamType', - 'NonEmptyStringParamType', - 'PathOrUrl', - 'PluginParamType', - 'ProcessParamType', - 'ProfileParamType', - 'ShebangParamType', - 'UserParamType', - 'VerdiCommandGroup', - 'WorkflowParamType', - 'dbenv', - 'echo_critical', - 'echo_dictionary', - 'echo_error', - 'echo_info', - 'echo_report', - 'echo_success', - 'echo_warning', - 'format_call_graph', - 'is_verbose', - 'only_if_daemon_running', - 'with_dbenv', -) - -# yapf: enable diff --git a/aiida/cmdline/commands/__init__.py b/aiida/cmdline/commands/__init__.py deleted file mode 100644 index 8b99390a26..0000000000 --- a/aiida/cmdline/commands/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Sub commands of the ``verdi`` command line interface. - -The commands need to be imported here for them to be registered with the top-level command group. -""" -from aiida.cmdline.commands import ( - cmd_archive, - cmd_calcjob, - cmd_code, - cmd_computer, - cmd_config, - cmd_daemon, - cmd_data, - cmd_database, - cmd_devel, - cmd_group, - cmd_help, - cmd_node, - cmd_plugin, - cmd_process, - cmd_profile, - cmd_rabbitmq, - cmd_restapi, - cmd_run, - cmd_setup, - cmd_shell, - cmd_status, - cmd_storage, - cmd_user, -) diff --git a/aiida/cmdline/commands/cmd_data/__init__.py b/aiida/cmdline/commands/cmd_data/__init__.py deleted file mode 100644 index 1087e7a864..0000000000 --- a/aiida/cmdline/commands/cmd_data/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. 
All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""The `verdi data` command line interface.""" - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.utils.pluginable import Pluginable - - -@verdi.group('data', entry_point_group='aiida.cmdline.data', cls=Pluginable) -def verdi_data(): - """Inspect, create and manage data nodes.""" diff --git a/aiida/cmdline/commands/cmd_data/cmd_remote.py b/aiida/cmdline/commands/cmd_data/cmd_remote.py deleted file mode 100644 index 84e38ca37f..0000000000 --- a/aiida/cmdline/commands/cmd_data/cmd_remote.py +++ /dev/null @@ -1,87 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi data core.remote` command.""" -import stat - -import click - -from aiida.cmdline.commands.cmd_data import verdi_data -from aiida.cmdline.params import arguments, types -from aiida.cmdline.utils import echo - - -@verdi_data.group('core.remote') -def remote(): - """Manipulate RemoteData objects (reference to remote folders). - - A RemoteData can be thought as a "symbolic link" to a folder on one of the - Computers set up in AiiDA (e.g. where a CalcJob will run). - This folder is called "remote" in the sense that it is on a Computer and - not in the AiiDA repository. 
Note, however, that the "remote" computer - could also be "localhost".""" - - -@remote.command('ls') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.remote',))) -@click.option('-l', '--long', 'ls_long', is_flag=True, default=False, help='Display also file metadata.') -@click.option('-p', '--path', type=click.STRING, default='.', help='The folder to list.') -def remote_ls(ls_long, path, datum): - """List content of a (sub)directory in a RemoteData object.""" - import datetime - try: - content = datum.listdir_withattributes(path=path) - except (IOError, OSError) as err: - echo.echo_critical( - f'Unable to access the remote folder or file, check if it exists.\nOriginal error: {str(err)}' - ) - for metadata in content: - if ls_long: - mtime = datetime.datetime.fromtimestamp(metadata['attributes'].st_mtime) - pre_line = '{} {:10} {} '.format( - stat.filemode(metadata['attributes'].st_mode), metadata['attributes'].st_size, - mtime.strftime('%d %b %Y %H:%M') - ) - echo.echo(pre_line, nl=False) - if metadata['isdir']: - echo.echo(metadata['name'], fg=echo.COLORS['info']) - else: - echo.echo(metadata['name']) - - -@remote.command('cat') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.remote',))) -@click.argument('path', type=click.STRING) -def remote_cat(datum, path): - """Show content of a file in a RemoteData object.""" - import os - import sys - import tempfile - try: - with tempfile.NamedTemporaryFile(delete=False) as tmpf: - tmpf.close() - datum.getfile(path, tmpf.name) - with open(tmpf.name, encoding='utf8') as fhandle: - sys.stdout.write(fhandle.read()) - except IOError as err: - echo.echo_critical(f'{err.errno}: {str(err)}') - - try: - os.remove(tmpf.name) - except OSError: - # If you cannot delete, ignore (maybe I didn't manage to create it in the first place - pass - - -@remote.command('show') -@arguments.DATUM(type=types.DataParamType(sub_classes=('aiida.data:core.remote',))) -def remote_show(datum): - """Show information for a RemoteData object.""" - echo.echo(f'- Remote computer name: {datum.computer.label}') - echo.echo(f'- Remote folder full path: {datum.get_remote_path()}') diff --git a/aiida/cmdline/commands/cmd_data/cmd_show.py b/aiida/cmdline/commands/cmd_data/cmd_show.py deleted file mode 100644 index c69211d489..0000000000 --- a/aiida/cmdline/commands/cmd_data/cmd_show.py +++ /dev/null @@ -1,233 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -This allows to manage showfunctionality to all data types. 
-""" -import pathlib - -import click - -from aiida.cmdline.params import options -from aiida.cmdline.params.options.multivalue import MultipleValueOption -from aiida.cmdline.utils import echo -from aiida.common.exceptions import MultipleObjectsError - -SHOW_OPTIONS = [ - options.TRAJECTORY_INDEX(), - options.WITH_ELEMENTS(), - click.option('-c', '--contour', type=click.FLOAT, cls=MultipleValueOption, default=None, help='Isovalues to plot'), - click.option( - '--sampling-stepsize', - type=click.INT, - default=None, - help='Sample positions in plot every sampling_stepsize timestep' - ), - click.option( - '--stepsize', - type=click.INT, - default=None, - help='The stepsize for the trajectory, set it higher to reduce number of points' - ), - click.option('--mintime', type=click.INT, default=None, help='The time to plot from'), - click.option('--maxtime', type=click.INT, default=None, help='The time to plot to'), - click.option('--indices', type=click.INT, cls=MultipleValueOption, default=None, help='Show only these indices'), - click.option( - '--dont-block', 'block', is_flag=True, default=True, help="Don't block interpreter when showing plot." - ), -] - - -def show_options(func): - for option in reversed(SHOW_OPTIONS): - func = option(func) - - return func - - -def _show_jmol(exec_name, trajectory_list, **kwargs): - """ - Plugin for jmol - """ - import subprocess - import tempfile - - # pylint: disable=protected-access - with tempfile.NamedTemporaryFile(mode='w+b') as handle: - for trajectory in trajectory_list: - handle.write(trajectory._exportcontent('cif', **kwargs)[0]) - handle.flush() - - try: - subprocess.check_output([exec_name, handle.name]) - except subprocess.CalledProcessError: - # The program died: just print a message - echo.echo_error(f'the call to {exec_name} ended with an error.') - except OSError as err: - if err.errno == 2: - echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") - else: - raise - - -def _show_xcrysden(exec_name, object_list, **kwargs): - """ - Plugin for xcrysden - """ - import subprocess - import tempfile - - if len(object_list) > 1: - raise MultipleObjectsError('Visualization of multiple trajectories is not implemented') - obj = object_list[0] - - # pylint: disable=protected-access - with tempfile.NamedTemporaryFile(mode='w+b', suffix='.xsf') as tmpf: - tmpf.write(obj._exportcontent('xsf', **kwargs)[0]) - tmpf.flush() - - try: - subprocess.check_output([exec_name, '--xsf', tmpf.name]) - except subprocess.CalledProcessError: - # The program died: just print a message - echo.echo_error(f'the call to {exec_name} ended with an error.') - except OSError as err: - if err.errno == 2: - echo.echo_critical(f"No executable '{exec_name}' found. 
Add to the path, or try with an absolute path.") - else: - raise - - -# pylint: disable=unused-argument -def _show_mpl_pos(exec_name, trajectory_list, **kwargs): - """ - Produces a matplotlib plot of the trajectory - """ - for traj in trajectory_list: - traj.show_mpl_pos(**kwargs) - - -# pylint: disable=unused-argument -def _show_mpl_heatmap(exec_name, trajectory_list, **kwargs): - """ - Produces a matplotlib plot of the trajectory - """ - for traj in trajectory_list: - traj.show_mpl_heatmap(**kwargs) - - -# pylint: disable=unused-argument -def _show_ase(exec_name, structure_list): - """ - Plugin to show the structure with the ASE visualizer - """ - try: - from ase.visualize import view - for structure in structure_list: - view(structure.get_ase()) - except ImportError: # pylint: disable=try-except-raise - raise - - -def _show_vesta(exec_name, structure_list): - """ - Plugin for VESTA - This VESTA plugin was added by Yue-Wen FANG and Abel Carreras - at Kyoto University in the group of Prof. Isao Tanaka's lab - - """ - import subprocess - import tempfile - - # pylint: disable=protected-access - with tempfile.NamedTemporaryFile(mode='w+b', suffix='.cif') as tmpf: - for structure in structure_list: - tmpf.write(structure._exportcontent('cif')[0]) - tmpf.flush() - - try: - subprocess.check_output([exec_name, tmpf.name]) - except subprocess.CalledProcessError: - # The program died: just print a message - echo.echo_error(f'the call to {exec_name} ended with an error.') - except OSError as err: - if err.errno == 2: - echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") - else: - raise - - -def _show_vmd(exec_name, structure_list): - """ - Plugin for vmd - """ - import subprocess - import tempfile - - if len(structure_list) > 1: - raise MultipleObjectsError('Visualization of multiple objects is not implemented') - structure = structure_list[0] - - # pylint: disable=protected-access - with tempfile.NamedTemporaryFile(mode='w+b', suffix='.xsf') as tmpf: - tmpf.write(structure._exportcontent('xsf')[0]) - tmpf.flush() - - try: - subprocess.check_output([exec_name, tmpf.name]) - except subprocess.CalledProcessError: - # The program died: just print a message - echo.echo_error(f'the call to {exec_name} ended with an error.') - except OSError as err: - if err.errno == 2: - echo.echo_critical(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") - else: - raise - - -def _show_xmgrace(exec_name, list_bands): - """ - Plugin for showing the bands with the XMGrace plotting software. 
- """ - import subprocess - import sys - import tempfile - - from aiida.orm.nodes.data.array.bands import MAX_NUM_AGR_COLORS - - list_files = [] - current_band_number = 0 - - with tempfile.TemporaryDirectory() as tmpdir: - - dirpath = pathlib.Path(tmpdir) - - for iband, bnds in enumerate(list_bands): - # extract number of bands - nbnds = bnds.get_bands().shape[1] - text, _ = bnds._exportcontent( # pylint: disable=protected-access - 'agr', setnumber_offset=current_band_number, color_number=(iband + 1 % MAX_NUM_AGR_COLORS) - ) - # write a tempfile - filepath = dirpath / f'{iband}.agr' - filepath.write_bytes(text) - list_files.append(str(filepath)) - # update the number of bands already plotted - current_band_number += nbnds - - try: - subprocess.check_output([exec_name] + [str(filepath) for filepath in list_files]) - except subprocess.CalledProcessError: - print(f'Note: the call to {exec_name} ended with an error.') - except OSError as err: - if err.errno == 2: - print(f"No executable '{exec_name}' found. Add to the path, or try with an absolute path.") - sys.exit(1) - else: - raise diff --git a/aiida/cmdline/commands/cmd_database.py b/aiida/cmdline/commands/cmd_database.py deleted file mode 100644 index 2653759f81..0000000000 --- a/aiida/cmdline/commands/cmd_database.py +++ /dev/null @@ -1,131 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi database` commands.""" -# pylint: disable=unused-argument - -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options -from aiida.cmdline.utils import decorators - - -@verdi.group('database', hidden=True) -def verdi_database(): - """Inspect and manage the database. - - .. deprecated:: v2.0.0 - """ - - -@verdi_database.command('version') -@decorators.deprecated_command( - 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' - 'The same information is now available through `verdi storage version`.\n' -) -def database_version(): - """Show the version of the database. - - The database version is defined by the tuple of the schema generation and schema revision. - - .. deprecated:: v2.0.0 - """ - - -@verdi_database.command('migrate') -@options.FORCE() -@click.pass_context -@decorators.deprecated_command( - 'This command has been deprecated and will be removed soon (in v3.0). ' - 'Please call `verdi storage migrate` instead.\n' -) -def database_migrate(ctx, force): - """Migrate the database to the latest schema version. - - .. deprecated:: v2.0.0 - """ - from aiida.cmdline.commands.cmd_storage import storage_migrate - ctx.forward(storage_migrate) - - -@verdi_database.group('integrity') -def verdi_database_integrity(): - """Check the integrity of the database and fix potential issues. - - .. deprecated:: v2.0.0 - """ - - -@verdi_database_integrity.command('detect-duplicate-uuid') -@click.option( - '-t', - '--table', - default='db_dbnode', - type=click.Choice(('db_dbcomment', 'db_dbcomputer', 'db_dbgroup', 'db_dbnode')), - help='The database table to operate on.' 
-) -@click.option( - '-a', '--apply-patch', is_flag=True, help='Actually apply the proposed changes instead of performing a dry run.' -) -@decorators.deprecated_command( - 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' - 'For remaining available integrity checks, use `verdi storage integrity` instead.\n' -) -def detect_duplicate_uuid(table, apply_patch): - """Detect and fix entities with duplicate UUIDs. - - Before aiida-core v1.0.0, there was no uniqueness constraint on the UUID column of the node table in the database - and a few other tables as well. This made it possible to store multiple entities with identical UUIDs in the same - table without the database complaining. This bug was fixed in aiida-core=1.0.0 by putting an explicit uniqueness - constraint on UUIDs on the database level. However, this would leave databases created before this patch with - duplicate UUIDs in an inconsistent state. This command will run an analysis to detect duplicate UUIDs in a given - table and solve it by generating new UUIDs. Note that it will not delete or merge any rows. - - - .. deprecated:: v2.0.0 - """ - - -@verdi_database_integrity.command('detect-invalid-links') -@decorators.with_dbenv() -@decorators.deprecated_command( - 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' - 'For remaining available integrity checks, use `verdi storage integrity` instead.\n' -) -def detect_invalid_links(): - """Scan the database for invalid links. - - .. deprecated:: v2.0.0 - """ - - -@verdi_database_integrity.command('detect-invalid-nodes') -@decorators.with_dbenv() -@decorators.deprecated_command( - 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' - 'For remaining available integrity checks, use `verdi storage integrity` instead.\n' -) -def detect_invalid_nodes(): - """Scan the database for invalid nodes. - - .. deprecated:: v2.0.0 - """ - - -@verdi_database.command('summary') -@decorators.deprecated_command( - 'This command has been deprecated and no longer has any effect. It will be removed soon from the CLI (in v2.1).\n' - 'Please call `verdi storage info` instead.\n' -) -def database_summary(): - """Summarise the entities in the database. - - .. deprecated:: v2.0.0 - """ diff --git a/aiida/cmdline/commands/cmd_devel.py b/aiida/cmdline/commands/cmd_devel.py deleted file mode 100644 index 0f24dfd724..0000000000 --- a/aiida/cmdline/commands/cmd_devel.py +++ /dev/null @@ -1,199 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
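The deprecated `detect-duplicate-uuid` command existed because databases created before aiida-core 1.0.0 could hold rows with identical UUIDs. Purely as an illustration (this is not the removed implementation), such duplicates can be spotted with the query API and a counter:

```python
# Rough illustration: detect duplicate node UUIDs with the QueryBuilder.
from collections import Counter

from aiida import load_profile, orm

load_profile()

uuids = orm.QueryBuilder().append(orm.Node, project='uuid').all(flat=True)
duplicates = {uuid: count for uuid, count in Counter(uuids).items() if count > 1}

if duplicates:
    print(f'{len(duplicates)} UUIDs appear more than once:')
    for uuid, count in duplicates.items():
        print(f'  {uuid}: {count} rows')
else:
    print('No duplicate UUIDs found.')
```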
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi devel` commands.""" -import sys - -import click - -from aiida import get_profile -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options, types -from aiida.cmdline.utils import decorators, echo -from aiida.common import exceptions - - -@verdi.group('devel') -def verdi_devel(): - """Commands for developers.""" - - -@verdi_devel.command('check-load-time') -def devel_check_load_time(): - """Check for common indicators that slowdown `verdi`. - - Check for environment properties that negatively affect the responsiveness of the `verdi` command line interface. - Known pathways that increase load time: - - * the database environment is loaded when it doesn't need to be - * Unexpected `aiida.*` modules are imported - - If either of these conditions are true, the command will raise a critical error - """ - from aiida.manage import get_manager - - loaded_aiida_modules = [key for key in sys.modules if key.startswith('aiida.')] - aiida_modules_str = '\n- '.join(sorted(loaded_aiida_modules)) - echo.echo_info(f'aiida modules loaded:\n- {aiida_modules_str}') - - manager = get_manager() - - if manager.profile_storage_loaded: - echo.echo_critical('potential `verdi` speed problem: database backend is loaded.') - - allowed = ('aiida.cmdline', 'aiida.common', 'aiida.manage', 'aiida.plugins', 'aiida.restapi') - for loaded in loaded_aiida_modules: - if not any(loaded.startswith(mod) for mod in allowed): - echo.echo_critical( - f'potential `verdi` speed problem: `{loaded}` module is imported which is not in: {allowed}' - ) - - echo.echo_success('no issues detected') - - -@verdi_devel.command('check-undesired-imports') -def devel_check_undesired_imports(): - """Check that verdi does not import python modules it shouldn't. - - Note: The blacklist was taken from the list of packages in the 'atomic_tools' extra but can be extended. 
- """ - loaded_modules = 0 - - for modulename in ['seekpath', 'CifFile', 'ase', 'pymatgen', 'spglib', 'pymysql']: - if modulename in sys.modules: - echo.echo_warning(f'Detected loaded module "{modulename}"') - loaded_modules += 1 - - if loaded_modules > 0: - echo.echo_critical(f'Detected {loaded_modules} unwanted modules') - echo.echo_success('no issues detected') - - -@verdi_devel.command('validate-plugins') -@decorators.with_dbenv() -def devel_validate_plugins(): - """Validate all plugins by checking they can be loaded.""" - from aiida.common.exceptions import EntryPointError - from aiida.plugins.entry_point import validate_registered_entry_points - - try: - validate_registered_entry_points() - except EntryPointError as exception: - echo.echo_critical(str(exception)) - - echo.echo_success('all registered plugins could successfully loaded.') - - -@verdi_devel.command('run-sql') -@click.argument('sql', type=str) -def devel_run_sql(sql): - """Run a raw SQL command on the profile database (only available for 'core.psql_dos' storage).""" - from sqlalchemy import text - - from aiida.storage.psql_dos.utils import create_sqlalchemy_engine - assert get_profile().storage_backend == 'core.psql_dos' - with create_sqlalchemy_engine(get_profile().storage_config).connect() as connection: - result = connection.execute(text(sql)).fetchall() - - if isinstance(result, (list, tuple)): - for row in result: - echo.echo(str(row)) - else: - echo.echo(str(result)) - - -@verdi_devel.command('play', hidden=True) -def devel_play(): - """Play the Aida triumphal march by Giuseppe Verdi.""" - import webbrowser - webbrowser.open_new('http://upload.wikimedia.org/wikipedia/commons/3/32/Triumphal_March_from_Aida.ogg') - - -@verdi_devel.command('launch-add') -@options.CODE(type=types.CodeParamType(entry_point='core.arithmetic.add')) -@click.option('-d', '--daemon', is_flag=True, help='Submit to the daemon instead of running blockingly.') -@click.option('-s', '--sleep', type=int, help='Set the `sleep` input in seconds.') -def devel_launch_arithmetic_add(code, daemon, sleep): - """Launch an ``ArithmeticAddCalculation``. - - Unless specified with the option ``--code``, a suitable ``Code`` is automatically setup. By default the command - configures ``bash`` on the ``localhost``. If the localhost is not yet configured as a ``Computer``, that is also - done automatically. - """ - from shutil import which - - from aiida.engine import run, submit - from aiida.orm import InstalledCode, Int, load_code - - default_calc_job_plugin = 'core.arithmetic.add' - - if not code: - try: - code = load_code('bash@localhost') - except exceptions.NotExistent: - localhost = prepare_localhost() - code = InstalledCode( - label='bash', - computer=localhost, - filepath_executable=which('bash'), - default_calc_job_plugin=default_calc_job_plugin - ).store() - else: - assert code.default_calc_job_plugin == default_calc_job_plugin - - builder = code.get_builder() - builder.x = Int(1) - builder.y = Int(1) - - if sleep: - builder.metadata.options.sleep = sleep - - if daemon: - node = submit(builder) - echo.echo_success(f'Submitted calculation `{node}`') - else: - _, node = run.get_node(builder) - if node.is_finished_ok: - echo.echo_success(f'ArithmeticAddCalculation<{node.pk}> finished successfully.') - else: - echo.echo_warning(f'ArithmeticAddCalculation<{node.pk}> did not finish successfully.') - - -def prepare_localhost(): - """Prepare and return the localhost as ``Computer``. 
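`verdi devel launch-add` is essentially the standard process-builder workflow wrapped in a command. A minimal sketch of the same launch from Python, assuming a `bash@localhost` code has already been configured (the command above creates it automatically when it is missing):

```python
# Sketch of what `verdi devel launch-add` does under the hood, assuming an
# existing code with label `bash@localhost`.
from aiida import load_profile
from aiida.engine import run
from aiida.orm import Int, load_code

load_profile()

code = load_code('bash@localhost')
builder = code.get_builder()
builder.x = Int(1)
builder.y = Int(1)

# run blockingly and keep the process node to inspect the outcome
_, node = run.get_node(builder)
print(f'ArithmeticAddCalculation<{node.pk}> finished ok: {node.is_finished_ok}')
```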
- - If it doesn't already exist, the computer will be created, using ``core.local`` and ``core.direct`` as the entry - points for the transport and scheduler type, respectively. In that case, the safe transport interval and the minimum - job poll interval will both be set to 0 seconds in order to guarantee a throughput that is as fast as possible. - - :return: The localhost configured as a ``Computer``. - """ - import tempfile - - from aiida.orm import Computer, load_computer - - try: - computer = load_computer('localhost') - except exceptions.NotExistent: - echo.echo_warning('No `localhost` computer exists yet: creating and configuring the `localhost` computer.') - computer = Computer( - label='localhost', - hostname='localhost', - description='Localhost automatically created by `aiida.engine.launch_shell_job`', - transport_type='core.local', - scheduler_type='core.direct', - workdir=tempfile.gettempdir(), - ).store() - computer.configure(safe_interval=0.) - computer.set_minimum_job_poll_interval(0.) - - if not computer.is_configured: - computer.configure() - - return computer diff --git a/aiida/cmdline/commands/cmd_group.py b/aiida/cmdline/commands/cmd_group.py deleted file mode 100644 index 9b022bdcd2..0000000000 --- a/aiida/cmdline/commands/cmd_group.py +++ /dev/null @@ -1,511 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi group` commands""" -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options, types -from aiida.cmdline.utils import echo -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import UniquenessError -from aiida.common.links import GraphTraversalRules - - -@verdi.group('group') -def verdi_group(): - """Create, inspect and manage groups of nodes.""" - - -@verdi_group.command('add-nodes') -@options.GROUP(required=True) -@options.FORCE() -@arguments.NODES() -@with_dbenv() -def group_add_nodes(group, force, nodes): - """Add nodes to a group.""" - if not force: - click.confirm(f'Do you really want to add {len(nodes)} nodes to {group}?', abort=True) - - group.add_nodes(nodes) - - -@verdi_group.command('remove-nodes') -@options.GROUP(required=True) -@arguments.NODES() -@options.GROUP_CLEAR() -@options.FORCE() -@with_dbenv() -def group_remove_nodes(group, nodes, clear, force): - """Remove nodes from a group.""" - from aiida.orm import Group, Node, QueryBuilder - - if nodes and clear: - echo.echo_critical( - 'Specify either the `--clear` flag to remove all nodes or the identifiers of the nodes you want to remove.' 
- ) - - if not force: - - if nodes: - node_pks = [node.pk for node in nodes] - - query = QueryBuilder() - query.append(Group, filters={'id': group.pk}, tag='group') - query.append(Node, with_group='group', filters={'id': {'in': node_pks}}, project='id') - - group_node_pks = query.all(flat=True) - - if not group_node_pks: - echo.echo_critical(f'None of the specified nodes are in {group}.') - - if len(node_pks) > len(group_node_pks): - node_pks = set(node_pks).difference(set(group_node_pks)) - echo.echo_warning(f'{len(node_pks)} nodes with PK {node_pks} are not in {group}.') - - message = f'Are you sure you want to remove {len(group_node_pks)} nodes from {group}?' - - elif clear: - message = f'Are you sure you want to remove ALL the nodes from {group}?' - else: - echo.echo_critical(f'No nodes were provided for removal from {group}.') - - click.confirm(message, abort=True) - - if clear: - group.clear() - else: - group.remove_nodes(nodes) - - -@verdi_group.command('move-nodes') -@arguments.NODES() -@click.option('-s', '--source-group', type=types.GroupParamType(), required=True, help='The group whose nodes to move.') -@click.option( - '-t', '--target-group', type=types.GroupParamType(), required=True, help='The group to which the nodes are moved.' -) -@options.FORCE(help='Do not ask for confirmation and skip all checks.') -@options.ALL(help='Move all nodes from the source to the target group.') -@with_dbenv() -def group_move_nodes(source_group, target_group, force, nodes, all_entries): - """Move the specified NODES from one group to another.""" - from aiida.orm import Group, Node, QueryBuilder - - if source_group.pk == target_group.pk: - echo.echo_critical(f'Source and target group are the same: {source_group}.') - - if not nodes: - if all_entries: - nodes = list(source_group.nodes) - else: - echo.echo_critical('Neither NODES or the `-a, --all` option was specified.') - - node_pks = [node.pk for node in nodes] - - if not all_entries: - query = QueryBuilder() - query.append(Group, filters={'id': source_group.pk}, tag='group') - query.append(Node, with_group='group', filters={'id': {'in': node_pks}}, project='id') - - source_group_node_pks = query.all(flat=True) - - if not source_group_node_pks: - echo.echo_critical(f'None of the specified nodes are in {source_group}.') - - if len(node_pks) > len(source_group_node_pks): - absent_node_pks = set(node_pks).difference(set(source_group_node_pks)) - echo.echo_warning(f'{len(absent_node_pks)} nodes with PK {absent_node_pks} are not in {source_group}.') - nodes = [node for node in nodes if node.pk in source_group_node_pks] - node_pks = set(node_pks).difference(absent_node_pks) - - query = QueryBuilder() - query.append(Group, filters={'id': target_group.pk}, tag='group') - query.append(Node, with_group='group', filters={'id': {'in': node_pks}}, project='id') - - target_group_node_pks = query.all(flat=True) - - if target_group_node_pks: - echo.echo_warning( - f'{len(target_group_node_pks)} nodes with PK {set(target_group_node_pks)} are already in ' - f'{target_group}. These will still be removed from {source_group}.' - ) - - if not force: - click.confirm( - f'Are you sure you want to move {len(nodes)} nodes from {source_group} ' - f'to {target_group}?', abort=True - ) - - source_group.remove_nodes(nodes) - target_group.add_nodes(nodes) - - -@verdi_group.command('delete') -@arguments.GROUP() -@options.FORCE() -@click.option( - '--delete-nodes', is_flag=True, default=False, help='Delete all nodes in the group along with the group itself.' 
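Both `group remove-nodes` and `group move-nodes` rely on the same membership query before acting. A condensed sketch of that pattern, with a placeholder group label and node PKs:

```python
# Condensed sketch of the membership check: which of the given node PKs are
# actually in the group? Label and PKs are placeholders.
from aiida import load_profile, orm

load_profile()

group_label = 'my-group'      # placeholder
candidate_pks = [10, 11, 12]  # placeholder PKs

query = orm.QueryBuilder()
query.append(orm.Group, filters={'label': group_label}, tag='group')
query.append(orm.Node, with_group='group', filters={'id': {'in': candidate_pks}}, project='id')

member_pks = set(query.all(flat=True))
missing_pks = set(candidate_pks) - member_pks
print('In group:', sorted(member_pks))
print('Not in group:', sorted(missing_pks))
```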
-) -@options.graph_traversal_rules(GraphTraversalRules.DELETE.value) -@options.DRY_RUN() -@with_dbenv() -def group_delete(group, delete_nodes, dry_run, force, **traversal_rules): - """Delete a group and (optionally) the nodes it contains.""" - from aiida import orm - from aiida.tools import delete_group_nodes - - if not (force or dry_run): - click.confirm(f'Are you sure you want to delete {group}?', abort=True) - elif dry_run: - echo.echo_report(f'Would have deleted {group}.') - - if delete_nodes: - - def _dry_run_callback(pks): - if not pks or force: - return False - echo.echo_warning(f'YOU ARE ABOUT TO DELETE {len(pks)} NODES! THIS CANNOT BE UNDONE!') - return not click.confirm('Do you want to continue?', abort=True) - - _, nodes_deleted = delete_group_nodes([group.pk], dry_run=dry_run or _dry_run_callback, **traversal_rules) - if not nodes_deleted: - # don't delete the group if the nodes were not deleted - return - - if not dry_run: - group_str = str(group) - orm.Group.collection.delete(group.pk) - echo.echo_success(f'{group_str} deleted.') - - -@verdi_group.command('relabel') -@arguments.GROUP() -@click.argument('label', type=click.STRING) -@with_dbenv() -def group_relabel(group, label): - """Change the label of a group.""" - try: - group.label = label - except UniquenessError as exception: - echo.echo_critical(str(exception)) - else: - echo.echo_success(f"Label changed to '{label}'") - - -@verdi_group.command('description') -@arguments.GROUP() -@click.argument('description', type=click.STRING, required=False) -@with_dbenv() -def group_description(group, description): - """Change the description of a group. - - If no description is defined, the current description will simply be echoed. - """ - if description: - group.description = description - echo.echo_success(f'Changed the description of {group}.') - else: - echo.echo(group.description) - - -@verdi_group.command('show') -@options.RAW(help='Show only a space-separated list of PKs of the calculations in the group') -@options.LIMIT() -@click.option( - '-u', - '--uuid', - is_flag=True, - default=False, - help='Show UUIDs together with PKs. Note: if the --raw option is also passed, PKs are not printed, but only UUIDs.' 
-) -@arguments.GROUP() -@with_dbenv() -def group_show(group, raw, limit, uuid): - """Show information for a given group.""" - from tabulate import tabulate - - from aiida.common import timezone - from aiida.common.utils import str_timedelta - - if limit: - node_iterator = group.nodes[:limit] - else: - node_iterator = group.nodes - - if raw: - if uuid: - echo.echo(' '.join(str(_.uuid) for _ in node_iterator)) - else: - echo.echo(' '.join(str(_.pk) for _ in node_iterator)) - else: - type_string = group.type_string - desc = group.description - now = timezone.now() - - table = [] - table.append(['Group label', group.label]) - table.append(['Group type_string', type_string]) - table.append(['Group description', desc if desc else '']) - echo.echo(tabulate(table)) - - table = [] - header = [] - if uuid: - header.append('UUID') - header.extend(['PK', 'Type', 'Created']) - echo.echo('# Nodes:') - for node in node_iterator: - row = [] - if uuid: - row.append(node.uuid) - row.append(node.pk) - row.append(node.node_type.rsplit('.', 2)[1]) - row.append(str_timedelta(now - node.ctime, short=True, negative_to_zero=True)) - table.append(row) - echo.echo(tabulate(table, headers=header)) - - -@verdi_group.command('list') -@options.ALL_USERS(help='Show groups for all users, rather than only for the current user.') -@options.USER(help='Add a filter to show only groups belonging to a specific user.') -@options.ALL(help='Show groups of all types.') -@options.TYPE_STRING() -@click.option( - '-d', - '--with-description', - 'with_description', - is_flag=True, - default=False, - help='Show also the group description.' -) -@click.option('-C', '--count', is_flag=True, default=False, help='Show also the number of nodes in the group.') -@options.PAST_DAYS(help='Add a filter to show only groups created in the past N days.', default=None) -@click.option( - '-s', - '--startswith', - type=click.STRING, - default=None, - help='Add a filter to show only groups for which the label begins with STRING.' -) -@click.option( - '-e', - '--endswith', - type=click.STRING, - default=None, - help='Add a filter to show only groups for which the label ends with STRING.' -) -@click.option( - '-c', - '--contains', - type=click.STRING, - default=None, - help='Add a filter to show only groups for which the label contains STRING.' -) -@options.ORDER_BY(type=click.Choice(['id', 'label', 'ctime']), default='label') -@options.ORDER_DIRECTION() -@options.NODE(help='Show only the groups that contain the node.') -@with_dbenv() -def group_list( - all_users, user, all_entries, type_string, with_description, count, past_days, startswith, endswith, contains, - order_by, order_dir, node -): - """Show a list of existing groups.""" - # pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements - import datetime - - from tabulate import tabulate - - from aiida import orm - from aiida.common import timezone - from aiida.common.escaping import escape_for_sql_like - - builder = orm.QueryBuilder() - filters = {} - - # Have to specify the default for `type_string` here instead of directly in the option otherwise it will always - # raise above if the user specifies just the `--group-type` option. Once that option is removed, the default can - # be moved to the option itself. 
- if type_string is None: - type_string = 'core' - - if not all_entries: - if '%' in type_string or '_' in type_string: - filters['type_string'] = {'like': type_string} - else: - filters['type_string'] = type_string - - # Creation time - if past_days: - filters['time'] = {'>': timezone.now() - datetime.timedelta(days=past_days)} - - # Query for specific group labels - filters['or'] = [] - if startswith: - filters['or'].append({'label': {'like': f'{escape_for_sql_like(startswith)}%'}}) - if endswith: - filters['or'].append({'label': {'like': f'%{escape_for_sql_like(endswith)}'}}) - if contains: - filters['or'].append({'label': {'like': f'%{escape_for_sql_like(contains)}%'}}) - - builder.append(orm.Group, filters=filters, tag='group', project='*') - - # Query groups that belong to specific user - if user: - user_email = user.email - else: - # By default: only groups of this user - user_email = orm.User.collection.get_default().email - - # Query groups that belong to all users - if not all_users: - builder.append(orm.User, filters={'email': {'==': user_email}}, with_group='group') - - # Query groups that contain a particular node - if node: - builder.append(orm.Node, filters={'id': {'==': node.pk}}, with_group='group') - - builder.order_by({orm.Group: {order_by: order_dir}}) - result = builder.all() - - projection_lambdas = { - 'pk': lambda group: str(group.pk), - 'label': lambda group: group.label, - 'type_string': lambda group: group.type_string, - 'count': lambda group: group.count(), - 'user': lambda group: group.user.email.strip(), - 'description': lambda group: group.description - } - - table = [] - projection_header = ['PK', 'Label', 'Type string', 'User'] - projection_fields = ['pk', 'label', 'type_string', 'user'] - - if with_description: - projection_header.append('Description') - projection_fields.append('description') - - if count: - projection_header.append('Node count') - projection_fields.append('count') - - for group in result: - table.append([projection_lambdas[field](group[0]) for field in projection_fields]) - - if not all_entries: - echo.echo_report('To show groups of all types, use the `-a/--all` option.') - - if not table: - echo.echo_report('No groups found matching the specified criteria.') - else: - echo.echo(tabulate(table, headers=projection_header)) - - -@verdi_group.command('create') -@click.argument('group_label', nargs=1, type=click.STRING) -@with_dbenv() -def group_create(group_label): - """Create an empty group with a given label.""" - from aiida import orm - - group, created = orm.Group.collection.get_or_create(label=group_label) - - if created: - echo.echo_success(f"Group created with PK = {group.pk} and label '{group.label}'.") - else: - echo.echo_report(f"Group with label '{group.label}' already exists: {group}.") - - -@verdi_group.command('copy') -@arguments.GROUP('source_group') -@click.argument('destination_group', nargs=1, type=click.STRING) -@with_dbenv() -def group_copy(source_group, destination_group): - """Duplicate a group. - - More in detail, add all nodes from the source group to the destination group. 
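The `--startswith`, `--endswith` and `--contains` options of `verdi group list` translate directly into SQL `like` filters on the group label. A standalone sketch of that filter construction, using placeholder values:

```python
# Sketch of the label-filter construction used by `verdi group list`.
from aiida import load_profile, orm
from aiida.common.escaping import escape_for_sql_like

load_profile()

startswith, endswith, contains = 'calc_', None, None  # placeholder filter values

filters = {'type_string': 'core', 'or': []}
if startswith:
    filters['or'].append({'label': {'like': f'{escape_for_sql_like(startswith)}%'}})
if endswith:
    filters['or'].append({'label': {'like': f'%{escape_for_sql_like(endswith)}'}})
if contains:
    filters['or'].append({'label': {'like': f'%{escape_for_sql_like(contains)}%'}})

builder = orm.QueryBuilder().append(orm.Group, filters=filters, project='*')
builder.order_by({orm.Group: {'label': 'asc'}})

for group in builder.all(flat=True):
    print(group.pk, group.label)
```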
- Note that the destination group may not exist.""" - from aiida import orm - - dest_group, created = orm.Group.collection.get_or_create(label=destination_group) - - # Issue warning if destination group is not empty and get user confirmation to continue - if not created and not dest_group.is_empty: - echo.echo_warning(f'Destination {dest_group} already exists and is not empty.') - click.confirm('Do you wish to continue anyway?', abort=True) - - # Copy nodes - dest_group.add_nodes(list(source_group.nodes)) - echo.echo_success(f'Nodes copied from {source_group} to {dest_group}.') - - -@verdi_group.group('path') -def verdi_group_path(): - """Inspect groups of nodes, with delimited label paths.""" - - -@verdi_group_path.command('ls') -@click.argument('path', type=click.STRING, required=False) -@options.TYPE_STRING(default='core', help='Filter to only include groups of this type string.') -@click.option('-R', '--recursive', is_flag=True, default=False, help='Recursively list sub-paths encountered.') -@click.option('-l', '--long', 'as_table', is_flag=True, default=False, help='List as a table, with sub-group count.') -@click.option( - '-d', - '--with-description', - 'with_description', - is_flag=True, - default=False, - help='Show also the group description.' -) -@click.option( - '--no-virtual', - 'no_virtual', - is_flag=True, - default=False, - help='Only show paths that fully correspond to an existing group.' -) -@click.option('--no-warn', is_flag=True, default=False, help='Do not issue a warning if any paths are invalid.') -@with_dbenv() -def group_path_ls(path, type_string, recursive, as_table, no_virtual, with_description, no_warn): - # pylint: disable=too-many-arguments,too-many-branches - """Show a list of existing group paths.""" - from aiida.plugins import GroupFactory - from aiida.tools.groups.paths import GroupPath, InvalidPath - - try: - path = GroupPath(path or '', cls=GroupFactory(type_string), warn_invalid_child=not no_warn) - except InvalidPath as err: - echo.echo_critical(str(err)) - - if recursive: - children = path.walk() - else: - children = path.children - - if as_table or with_description: - from tabulate import tabulate - headers = ['Path', 'Sub-Groups'] - if with_description: - headers.append('Description') - rows = [] - for child in sorted(children): - if no_virtual and child.is_virtual: - continue - row = [ - child.path if child.is_virtual else click.style(child.path, bold=True), - len([c for c in child.walk() if not c.is_virtual]) - ] - if with_description: - row.append('-' if child.is_virtual else child.get_group().description) - rows.append(row) - echo.echo(tabulate(rows, headers=headers)) - else: - for child in sorted(children): - if no_virtual and child.is_virtual: - continue - echo.echo(child.path, bold=not child.is_virtual) diff --git a/aiida/cmdline/commands/cmd_process.py b/aiida/cmdline/commands/cmd_process.py deleted file mode 100644 index 7c3da143ec..0000000000 --- a/aiida/cmdline/commands/cmd_process.py +++ /dev/null @@ -1,321 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-arguments -"""`verdi process` command.""" -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options, types -from aiida.cmdline.utils import decorators, echo -from aiida.common.log import LOG_LEVELS -from aiida.manage import get_manager - - -def valid_projections(): - """Return list of valid projections for the ``--project`` option of ``verdi process list``. - - This indirection is necessary to prevent loading the imported module which slows down tab-completion. - """ - from aiida.tools.query.calculation import CalculationQueryBuilder - return CalculationQueryBuilder.valid_projections - - -def default_projections(): - """Return list of default projections for the ``--project`` option of ``verdi process list``. - - This indirection is necessary to prevent loading the imported module which slows down tab-completion. - """ - from aiida.tools.query.calculation import CalculationQueryBuilder - return CalculationQueryBuilder.default_projections - - -@verdi.group('process') -def verdi_process(): - """Inspect and manage processes.""" - - -@verdi_process.command('list') -@options.PROJECT(type=types.LazyChoice(valid_projections), default=lambda: default_projections()) # pylint: disable=unnecessary-lambda -@options.ORDER_BY() -@options.ORDER_DIRECTION() -@options.GROUP(help='Only include entries that are a member of this group.') -@options.ALL(help='Show all entries, regardless of their process state.') -@options.PROCESS_STATE() -@options.PROCESS_LABEL() -@options.PAUSED() -@options.EXIT_STATUS() -@options.FAILED() -@options.PAST_DAYS() -@options.LIMIT() -@options.RAW() -@click.pass_context -@decorators.with_dbenv() -def process_list( - ctx, all_entries, group, process_state, process_label, paused, exit_status, failed, past_days, limit, project, raw, - order_by, order_dir -): - """Show a list of running or terminated processes. - - By default, only those that are still running are shown, but there are options to show also the finished ones. 
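`valid_projections` and `default_projections` exist only to defer the import of the calculation query builder until the command actually runs, keeping `verdi` tab-completion responsive. The same pattern in isolation, with placeholder names:

```python
# Illustration of the lazy-default pattern: defer an expensive import until the
# option default is actually evaluated, so `--help` and shell tab-completion
# stay fast. Names and values here are placeholders.
import click


def default_projections():
    """Return default projections, importing the heavy module only when called."""
    # a heavy import would live here, e.g. the calculation query-builder module
    return ['pk', 'ctime', 'process_label']


@click.command()
@click.option('--project', multiple=True, default=lambda: default_projections())
def cli(project):
    click.echo('projecting: ' + ', '.join(project))


if __name__ == '__main__':
    cli()
```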
- """ - # pylint: disable=too-many-locals - from tabulate import tabulate - - from aiida.cmdline.commands.cmd_daemon import execute_client_command - from aiida.cmdline.utils.common import print_last_process_state_change - from aiida.engine.daemon.client import get_daemon_client - from aiida.orm import ProcessNode, QueryBuilder - from aiida.tools.query.calculation import CalculationQueryBuilder - - relationships = {} - - if group: - relationships['with_node'] = group - - builder = CalculationQueryBuilder() - filters = builder.get_filters(all_entries, process_state, process_label, paused, exit_status, failed) - query_set = builder.get_query_set( - relationships=relationships, filters=filters, order_by={order_by: order_dir}, past_days=past_days, limit=limit - ) - projected = builder.get_projected(query_set, projections=project) - headers = projected.pop(0) - - if raw: - tabulated = tabulate(projected, tablefmt='plain') - echo.echo(tabulated) - return - - tabulated = tabulate(projected, headers=headers) - echo.echo(tabulated) - echo.echo(f'\nTotal results: {len(projected)}\n') - print_last_process_state_change() - - if not get_daemon_client().is_daemon_running: - echo.echo_warning('The daemon is not running', bold=True) - return - - echo.echo_report('Checking daemon load... ', nl=False) - response = execute_client_command('get_numprocesses') - - if not response: - # Daemon could not be reached - return - - try: - active_workers = response['numprocesses'] - except KeyError: - echo.echo_report('No active daemon workers.') - else: - # Second query to get active process count. Currently this is slow but will be fixed with issue #2770. It is - # placed at the end of the command so that the user can Ctrl+C after getting the process table. - slots_per_worker = ctx.obj.config.get_option('daemon.worker_process_slots', scope=ctx.obj.profile.name) - active_processes = QueryBuilder().append( - ProcessNode, filters={ - 'attributes.process_state': { - 'in': ('created', 'waiting', 'running') - } - } - ).count() - available_slots = active_workers * slots_per_worker - percent_load = active_processes / available_slots - if percent_load > 0.9: # 90% - echo.echo_warning(f'{percent_load * 100:.0f}% of the available daemon worker slots have been used!') - echo.echo_warning('Increase the number of workers with `verdi daemon incr`.') - else: - echo.echo_report(f'Using {percent_load * 100:.0f}% of the available daemon worker slots.') - - -@verdi_process.command('show') -@arguments.PROCESSES() -@decorators.with_dbenv() -def process_show(processes): - """Show details for one or multiple processes.""" - from aiida.cmdline.utils.common import get_node_info - - for process in processes: - echo.echo(get_node_info(process)) - - -@verdi_process.command('call-root') -@arguments.PROCESSES() -@decorators.with_dbenv() -def process_call_root(processes): - """Show root process of the call stack for the given processes.""" - for process in processes: - - caller = process.caller - - if caller is None: - echo.echo(f'No callers found for Process<{process.pk}>') - continue - - while True: - next_caller = caller.caller - - if next_caller is None: - break - - caller = next_caller - - echo.echo(f'{caller.pk}') - - -@verdi_process.command('report') -@arguments.PROCESSES() -@click.option('-i', '--indent-size', type=int, default=2, help='Set the number of spaces to indent each level by.') -@click.option( - '-l', - '--levelname', - type=click.Choice(list(LOG_LEVELS)), - default='REPORT', - help='Filter the results by name of the log level.' 
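The daemon-load check at the end of `process list` reduces to a single ratio: active processes divided by the slots offered by the running workers. A worked sketch with placeholder numbers:

```python
# Worked example of the daemon-load check (placeholder numbers):
# percent_load = active_processes / (active_workers * slots_per_worker)
active_workers = 2        # reported by the daemon via `get_numprocesses`
slots_per_worker = 200    # config option `daemon.worker_process_slots`
active_processes = 370    # processes in state created/waiting/running

available_slots = active_workers * slots_per_worker
percent_load = active_processes / available_slots   # 370 / 400 = 0.925

if percent_load > 0.9:
    print(f'{percent_load * 100:.0f}% of the available daemon worker slots have been used!')
    print('Increase the number of workers with `verdi daemon incr`.')
else:
    print(f'Using {percent_load * 100:.0f}% of the available daemon worker slots.')
```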
-) -@click.option( - '-m', '--max-depth', 'max_depth', type=int, default=None, help='Limit the number of levels to be printed.' -) -@decorators.with_dbenv() -def process_report(processes, levelname, indent_size, max_depth): - """Show the log report for one or multiple processes.""" - from aiida.cmdline.utils.common import get_calcjob_report, get_process_function_report, get_workchain_report - from aiida.orm import CalcFunctionNode, CalcJobNode, WorkChainNode, WorkFunctionNode - - for process in processes: - if isinstance(process, CalcJobNode): - echo.echo(get_calcjob_report(process)) - elif isinstance(process, WorkChainNode): - echo.echo(get_workchain_report(process, levelname, indent_size, max_depth)) - elif isinstance(process, (CalcFunctionNode, WorkFunctionNode)): - echo.echo(get_process_function_report(process)) - else: - echo.echo(f'Nothing to show for node type {process.__class__}') - - -@verdi_process.command('status') -@click.option('-c', '--call-link-label', 'call_link_label', is_flag=True, help='Include the call link label if set.') -@click.option( - '-m', '--max-depth', 'max_depth', type=int, default=None, help='Limit the number of levels to be printed.' -) -@arguments.PROCESSES() -def process_status(call_link_label, max_depth, processes): - """Print the status of one or multiple processes.""" - from aiida.cmdline.utils.ascii_vis import format_call_graph - - for process in processes: - graph = format_call_graph(process, max_depth=max_depth, call_link_label=call_link_label) - echo.echo(graph) - - -@verdi_process.command('kill') -@arguments.PROCESSES() -@options.ALL(help='Kill all processes if no specific processes are specified.') -@options.TIMEOUT() -@options.WAIT() -@decorators.with_dbenv() -def process_kill(processes, all_entries, timeout, wait): - """Kill running processes.""" - from aiida.engine.processes import control - - if processes and all_entries: - raise click.BadOptionUsage('all', 'cannot specify individual processes and the `--all` flag at the same time.') - - if all_entries: - click.confirm('Are you sure you want to kill all processes?', abort=True) - - try: - message = 'Killed through `verdi process kill`' - control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message) - except control.ProcessTimeoutException as exception: - echo.echo_critical(str(exception) + '\nFrom the CLI you can call `verdi devel revive `.') - - -@verdi_process.command('pause') -@arguments.PROCESSES() -@options.ALL(help='Pause all active processes if no specific processes are specified.') -@options.TIMEOUT() -@options.WAIT() -@decorators.with_dbenv() -def process_pause(processes, all_entries, timeout, wait): - """Pause running processes.""" - from aiida.engine.processes import control - - if processes and all_entries: - raise click.BadOptionUsage('all', 'cannot specify individual processes and the `--all` flag at the same time.') - - try: - message = 'Paused through `verdi process pause`' - control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message) - except control.ProcessTimeoutException as exception: - echo.echo_critical(str(exception) + '\nFrom the CLI you can call `verdi devel revive `.') - - -@verdi_process.command('play') -@arguments.PROCESSES() -@options.ALL(help='Play all paused processes if no specific processes are specified.') -@options.TIMEOUT() -@options.WAIT() -@decorators.with_dbenv() -def process_play(processes, all_entries, timeout, wait): - """Play (unpause) paused processes.""" - from 
aiida.engine.processes import control - - if processes and all_entries: - raise click.BadOptionUsage('all', 'cannot specify individual processes and the `--all` flag at the same time.') - - try: - control.play_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait) - except control.ProcessTimeoutException as exception: - echo.echo_critical(str(exception) + '\nFrom the CLI you can call `verdi devel revive `.') - - -@verdi_process.command('watch') -@arguments.PROCESSES() -@decorators.with_dbenv() -@decorators.only_if_daemon_running(echo.echo_warning, 'daemon is not running, so process may not be reachable') -def process_watch(processes): - """Watch the state transitions for a process.""" - from time import sleep - - from kiwipy import BroadcastFilter - - def _print(communicator, body, sender, subject, correlation_id): # pylint: disable=unused-argument - """Format the incoming broadcast data into a message and echo it to stdout.""" - if body is None: - body = 'No message specified' - - if correlation_id is None: - correlation_id = '--' - - echo.echo(f'Process<{sender}> [{subject}|{correlation_id}]: {body}') - - communicator = get_manager().get_communicator() - echo.echo_report('watching for broadcasted messages, press CTRL+C to stop...') - - for process in processes: - - if process.is_terminated: - echo.echo_error(f'Process<{process.pk}> is already terminated') - continue - - communicator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk)) - - try: - # Block this thread indefinitely until interrupt - while True: - sleep(2) - except (SystemExit, KeyboardInterrupt): - echo.echo('') # add a new line after the interrupt character - echo.echo_report('received interrupt, exiting...') - try: - communicator.close() - except RuntimeError: - pass - - # Reraise to trigger clicks builtin abort sequence - raise diff --git a/aiida/cmdline/commands/cmd_profile.py b/aiida/cmdline/commands/cmd_profile.py deleted file mode 100644 index df7fc2331b..0000000000 --- a/aiida/cmdline/commands/cmd_profile.py +++ /dev/null @@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi profile` command.""" -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options -from aiida.cmdline.utils import defaults, echo -from aiida.common import exceptions -from aiida.manage.configuration import get_config - - -@verdi.group('profile') -def verdi_profile(): - """Inspect and manage the configured profiles.""" - - -@verdi_profile.command('list') -def profile_list(): - """Display a list of all available profiles.""" - try: - config = get_config() - except (exceptions.MissingConfigurationError, exceptions.ConfigurationError) as exception: - # This can happen for a fresh install and the `verdi setup` has not yet been run. In this case it is still nice - # to be able to see the configuration directory, for instance for those who have set `AIIDA_PATH`. This way - # they can at least verify that it is correctly set. 
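The `kill`, `pause` and `play` commands all delegate to `aiida.engine.processes.control`. A minimal sketch of calling those helpers directly from a script, with placeholder process PKs:

```python
# Minimal sketch: pause and resume processes through the same control
# functions the commands above delegate to. PKs are placeholders.
from aiida import load_profile, orm
from aiida.engine.processes import control

load_profile()

processes = [orm.load_node(pk) for pk in (101, 102)]  # placeholder PKs of active processes

control.pause_processes(processes, message='Paused from a script', wait=True)
control.play_processes(processes, wait=True)
```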
- from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER - echo.echo_report(f'configuration folder: {AIIDA_CONFIG_FOLDER}') - echo.echo_critical(str(exception)) - else: - echo.echo_report(f'configuration folder: {config.dirpath}') - - if not config.profiles: - echo.echo_warning('no profiles configured: run `verdi setup` to create one') - else: - sort = lambda profile: profile.name # pylint: disable=unnecessary-lambda-assignment - highlight = lambda profile: profile.name == config.default_profile_name # pylint: disable=unnecessary-lambda-assignment - echo.echo_formatted_list(config.profiles, ['name'], sort=sort, highlight=highlight) - - -def _strip_private_keys(dct: dict): - """Remove private keys (starting `_`) from the dictionary.""" - return { - key: _strip_private_keys(value) if isinstance(value, dict) else value - for key, value in dct.items() - if not key.startswith('_') - } - - -@verdi_profile.command('show') -@arguments.PROFILE(default=defaults.get_default_profile) -def profile_show(profile): - """Show details for a profile.""" - - if profile is None: - echo.echo_critical('no profile to show') - - echo.echo_report(f'Profile: {profile.name}') - config = _strip_private_keys(profile.dictionary) - echo.echo_dictionary(config, fmt='yaml') - - -@verdi_profile.command('setdefault') -@arguments.PROFILE(required=True, default=None) -def profile_setdefault(profile): - """Set a profile as the default one.""" - try: - config = get_config() - except (exceptions.MissingConfigurationError, exceptions.ConfigurationError) as exception: - echo.echo_critical(str(exception)) - - config.set_default_profile(profile.name, overwrite=True).store() - echo.echo_success(f'{profile.name} set as default profile') - - -@verdi_profile.command('delete') -@options.FORCE(help='to skip questions and warnings about loss of data') -@click.option( - '--include-config/--skip-config', - default=True, - show_default=True, - help='Include deletion of entry in configuration file.' -) -@click.option( - '--include-db/--skip-db', - 'include_database', - default=True, - show_default=True, - help='Include deletion of associated database.' -) -@click.option( - '--include-db-user/--skip-db-user', - 'include_database_user', - default=False, - show_default=True, - help='Include deletion of associated database user.' -) -@click.option( - '--include-repository/--skip-repository', - default=True, - show_default=True, - help='Include deletion of associated file repository.' -) -@arguments.PROFILES(required=True) -def profile_delete(force, include_config, include_database, include_database_user, include_repository, profiles): - """Delete one or more profiles. - - The PROFILES argument takes one or multiple profile names that will be deleted. Deletion here means that the profile - will be removed including its file repository and database. The various options can be used to control which parts - of the profile are deleted. - """ - if not include_config: - echo.echo_deprecated('the `--skip-config` option is deprecated and is no longer respected.') - - for profile in profiles: - - includes = { - 'database': include_database, - 'database user': include_database_user, - 'file repository': include_repository - } - - if not all(includes.values()): - excludes = [label for label, value in includes.items() if not value] - message_suffix = f' excluding: {", ".join(excludes)}.' - else: - message_suffix = '.' 
- - echo.echo_warning(f'deleting profile `{profile.name}`{message_suffix}') - echo.echo_warning('this operation cannot be undone, ', nl=False) - - if not force and not click.confirm('are you sure you want to continue?'): - echo.echo_report(f'deleting of `{profile.name} cancelled.') - continue - - get_config().delete_profile( - profile.name, - include_database=include_database, - include_database_user=include_database_user, - include_repository=include_repository - ) - echo.echo_success(f'profile `{profile.name}` was deleted{message_suffix}.') diff --git a/aiida/cmdline/commands/cmd_rabbitmq.py b/aiida/cmdline/commands/cmd_rabbitmq.py deleted file mode 100644 index ece4893dfd..0000000000 --- a/aiida/cmdline/commands/cmd_rabbitmq.py +++ /dev/null @@ -1,388 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi devel rabbitmq` commands.""" -from __future__ import annotations - -import collections -import re -import sys -import typing as t - -import click -import requests -import tabulate -import wrapt -import yaml - -from aiida.cmdline.commands.cmd_devel import verdi_devel -from aiida.cmdline.params import arguments, options -from aiida.cmdline.utils import decorators, echo - -if t.TYPE_CHECKING: - import kiwipy.rmq - - from aiida.manage.configuration.profile import Profile - - -@verdi_devel.group('rabbitmq') -def cmd_rabbitmq(): - """Commands to interact with RabbitMQ.""" - - -@cmd_rabbitmq.group('queues') -def cmd_queues(): - """Commands to interact with RabbitMQ queues.""" - - -@cmd_rabbitmq.group('tasks') -def cmd_tasks(): - """Commands related to process tasks.""" - - -AVAILABLE_PROJECTORS = ( - 'arguments', - 'auto_delete', - 'backing_queue_status', - 'consumer_utilisation', - 'consumers', - 'durable', - 'effective_policy_definition', - 'exclusive', - 'exclusive_consumer_tag', - 'garbage_collection', - 'head_message_timestamp', - 'idle_since', - 'memory', - 'message_bytes', - 'message_bytes_paged_out', - 'message_bytes_persistent', - 'message_bytes_ram', - 'message_bytes_ready', - 'message_bytes_unacknowledged', - 'messages', - 'messages_details', - 'messages_paged_out', - 'messages_persistent', - 'messages_ram', - 'messages_ready', - 'messages_ready_details', - 'messages_ready_ram', - 'messages_unacknowledged', - 'messages_unacknowledged_details', - 'messages_unacknowledged_ram', - 'name', - 'node', - 'operator_policy', - 'policy', - 'recoverable_slaves', - 'reductions', - 'reductions_details', - 'single_active_consumer_tag', - 'state', - 'type', - 'vhost', -) - - -def echo_response(response: requests.Response, exit_on_error: bool = True) -> None: - """Echo the response of a request. - - :param response: The response to the request. - :param exit_on_error: Boolean, if ``True``, call ``sys.exit`` with the status code of the response. 
- """ - try: - response.raise_for_status() - except requests.HTTPError: - click.secho(f'[{response.status_code}] ', fg='red', bold=True, nl=False) - click.secho(f'{response.reason}: {response.url}') - if exit_on_error: - sys.exit(response.status_code) - else: - click.secho(f'[{response.status_code}] ', fg='green', bold=True) - - -@wrapt.decorator -@click.pass_context -def with_client(ctx, wrapped, _, args, kwargs): - """Decorate a function injecting a :class:`aiida.manage.external.rmq.client.RabbitmqManagementClient`.""" - - from aiida.manage.external.rmq.client import RabbitmqManagementClient - config = ctx.obj.profile.process_control_config - client = RabbitmqManagementClient( - username=config['broker_username'], - password=config['broker_password'], - hostname=config['broker_host'], - virtual_host=config['broker_virtual_host'], - ) - - if not client.is_connected: - echo.echo_critical( - 'Could not connect to the management API. Make sure RabbitMQ is running and the management plugin is ' - 'installed using `sudo rabbitmq-plugins enable rabbitmq_management`. The API is served on port 15672, so ' - 'if you are connecting to RabbitMQ running in a Docker container, make sure that port is exposed.' - ) - - kwargs['client'] = client - return wrapped(*args, **kwargs) - - -@wrapt.decorator -def with_manager(wrapped, _, args, kwargs): - """Decorate a function injecting a :class:`kiwipy.rmq.communicator.RmqCommunicator`.""" - from aiida.manage import get_manager - kwargs['manager'] = get_manager() - return wrapped(*args, **kwargs) - - -@cmd_rabbitmq.command('server-properties') -@with_manager -def cmd_server_properties(manager): - """List the server properties.""" - data = {} - for key, value in manager.get_communicator().server_properties.items(): - data[key] = value.decode('utf-8') if isinstance(value, bytes) else value - click.echo(yaml.dump(data, indent=4)) - - -@cmd_queues.command('list') -@click.option( - '-P', - '--project', - type=click.Choice(AVAILABLE_PROJECTORS), - cls=options.MultipleValueOption, - default=('name', 'messages', 'state') -) -@options.RAW() -@click.option('-f', '--filter-name', type=str, help='Provide a regex pattern to filter queues based on their name. ') -@with_client -def cmd_queues_list(client, project, raw, filter_name): - """List all queues.""" - response = client.request('queues') - - if not response.ok: - echo_response(response) - - if filter_name and 'name' not in project: - raise click.BadParameter('cannot use `--filter-name` when not projecting `name`.') - - if filter_name: - try: - re.match(filter_name, '') - except re.error as exception: - raise click.BadParameter(f'invalid regex pattern: {exception}', param_hint='`--filter-name`') - - queues = [queue for queue in response.json() if re.match(filter_name or '', queue['name'])] - output = [ - list(map(lambda key, values=queue: values.get(key, ''), project)) # type: ignore[misc] - for queue in queues - ] - - if not output: - echo.echo_report('No queues matched.') - return - - headers = [name.capitalize() for name in project] if not raw else [] - tablefmt = None if not raw else 'plain' - echo.echo(tabulate.tabulate(output, headers=headers, tablefmt=tablefmt)) - - -@cmd_queues.command('create') -@click.argument('queues', nargs=-1) -@with_client -def cmd_queues_create(client, queues): - """Create new queues.""" - for queue in queues: - response = client.request('queues/{virtual_host}/{queue}', {'queue': queue}, method='PUT') - click.secho(f'Create `{queue}`... 
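As the connection error above points out, the management client talks to RabbitMQ's HTTP API on port 15672. A rough sketch of the underlying request using `requests` directly; the host, the default `guest` credentials and the default virtual host `/` are assumptions:

```python
# Rough sketch (assumptions: management plugin enabled, guest/guest
# credentials, default virtual host '/'): list queues and message counts via
# the HTTP API that the management client wraps.
import requests

base_url = 'http://localhost:15672/api'
auth = ('guest', 'guest')

response = requests.get(f'{base_url}/queues/%2F', auth=auth, timeout=10)  # %2F is the '/' vhost
response.raise_for_status()

for queue in response.json():
    print(f"{queue['name']:60} {queue.get('state', '?'):10} {queue.get('messages', 0)} messages")
```

The removed `RabbitmqManagementClient` wraps these endpoints and injects the credentials from the profile's process-control configuration.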
', nl=False) - echo_response(response) - - -@cmd_queues.command('delete') -@click.argument('queues', nargs=-1) -@with_client -def cmd_queues_delete(client, queues): - """Delete existing queues.""" - for queue in queues: - params = {'if-empty': True, 'if-unused': True} - response = client.request('queues/{virtual_host}/{queue}', {'queue': queue}, method='DELETE', params=params) - click.secho(f'Delete `{queue}`... ', nl=False) - echo_response(response, exit_on_error=False) - - -@cmd_tasks.command('list') -@with_manager -@decorators.only_if_daemon_not_running() -@click.pass_context -def cmd_tasks_list(ctx, manager): - """List all active process tasks. - - This command prints a list of process pk's for which there is an active process task with RabbitMQ. Since tasks can - only be seen when they are not currently with a daemon worker, this command can only be run when the daemon is not - running. - """ - for pk in get_process_tasks(ctx.obj.profile, manager.get_communicator()): - echo.echo(pk) - - -def get_active_processes() -> list[int]: - """Return the list of pks of active processes. - - An active process is defined as a process that has a node with its attribute ``process_state`` set to one of: - - * ``created`` - * ``waiting`` - * ``running`` - - :returns: A list of process pks that are marked as active in the database. - """ - from aiida.engine import ProcessState - from aiida.orm import ProcessNode, QueryBuilder - - return QueryBuilder().append( # type: ignore[return-value] - ProcessNode, - filters={ - 'attributes.process_state': { - 'in': [ProcessState.CREATED.value, ProcessState.WAITING.value, ProcessState.RUNNING.value] - } - }, - project='id' - ).all(flat=True) - - -def iterate_process_tasks( - profile: Profile, communicator: kiwipy.rmq.RmqCommunicator -) -> collections.abc.Iterator[kiwipy.rmq.RmqIncomingTask]: - """Return the list of process pks that have a process task in the RabbitMQ process queue. - - :returns: A list of process pks that have a corresponding process task with RabbitMQ. - """ - from aiida.manage.external.rmq import get_launch_queue_name - - launch_queue = get_launch_queue_name(profile.rmq_prefix) - - for task in communicator.task_queue(launch_queue): - yield task - - -def get_process_tasks(profile: Profile, communicator: kiwipy.rmq.RmqCommunicator) -> list[int]: - """Return the list of process pks that have a process task in the RabbitMQ process queue. - - :returns: A list of process pks that have a corresponding process task with RabbitMQ. - """ - pks = [] - - for task in iterate_process_tasks(profile, communicator): - try: - pks.append(task.body.get('args', {})['pid']) - except KeyError: - pass - - return pks - - -@cmd_tasks.command('analyze') -@click.option('--fix', is_flag=True, help='Attempt to fix the inconsistencies if any are detected.') -@with_manager -@decorators.only_if_daemon_not_running() -@click.pass_context -def cmd_tasks_analyze(ctx, manager, fix): - """Perform analysis of process tasks. - - This command will perform a query of the database to find all "active" processes, meaning those that haven't yet - reached a terminal state, and cross-references this with the active process tasks that are in the process queue of - RabbitMQ. Any active process that does not have a corresponding process task can be considered a zombie, as it will - never be picked up by a daemon worker to complete it and will effectively be "stuck". Any process task that does not - correspond to an active process is useless and should be discarded. 
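
For context, a minimal self-contained sketch of the cross-check the deleted ``analyze`` command describes here (including the duplicate-task case mentioned just below): the database pks and queue pks in this snippet are made-up illustration values, not real data.

```python
from collections import Counter

# Hypothetical pks of "active" processes queried from the database (created/waiting/running).
active_processes = [101, 102, 103, 105]
# Hypothetical pks extracted from the process tasks sitting in the RabbitMQ launch queue.
process_tasks = [101, 102, 102, 104]

# Tasks that appear more than once are duplicates.
duplicates = {pk for pk, count in Counter(process_tasks).items() if count > 1}
# Tasks without a matching active process are useless and can be discarded.
tasks_for_terminated = set(process_tasks) - set(active_processes)
# Active processes without a task are "zombies" that no worker will ever pick up.
zombies = set(active_processes) - set(process_tasks)

print(f'Duplicate process tasks:            {duplicates}')            # {102}
print(f'Tasks for terminated processes:     {tasks_for_terminated}')  # {104}
print(f'Active processes without a task:    {zombies}')               # {103, 105}
```
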
Finally, duplicate process tasks are also - problematic and duplicates should be discarded. - - Use ``-v INFO`` to be more verbose and print more information. - """ - active_processes = get_active_processes() - process_tasks = get_process_tasks(ctx.obj.profile, manager.get_communicator()) - - set_active_processes = set(active_processes) - set_process_tasks = set(process_tasks) - - echo.echo_info(f'Active processes: {active_processes}') - echo.echo_info(f'Process tasks: {process_tasks}') - - state_inconsistent = False - - if len(process_tasks) != len(set_process_tasks): - state_inconsistent = True - echo.echo_warning('There are duplicates process tasks: ', nl=False) - echo.echo(set(x for x in process_tasks if process_tasks.count(x) > 1)) - - if set_process_tasks.difference(set_active_processes): - state_inconsistent = True - echo.echo_warning('There are process tasks for terminated processes: ', nl=False) - echo.echo(set_process_tasks.difference(set_active_processes)) - - if set_active_processes.difference(set_process_tasks): - state_inconsistent = True - echo.echo_warning('There are active processes without process task: ', nl=False) - echo.echo(set_active_processes.difference(set_process_tasks)) - - if state_inconsistent and not fix: - echo.echo_critical( - 'Inconsistencies detected between database and RabbitMQ. Run again with `--fix` to address problems.' - ) - - if not state_inconsistent: - echo.echo_success('No inconsistencies detected between database and RabbitMQ.') - return - - # At this point we have either exited because of inconsistencies and ``--fix`` was not passed, or we returned - # because there were no inconsistencies, so all that is left is to address inconsistencies - echo.echo_info('Attempting to fix inconsistencies') - - # Eliminate duplicate tasks and tasks that correspond to terminated process - for task in iterate_process_tasks(ctx.obj.profile, manager.get_communicator()): - pid = task.body.get('args', {}).get('pid', None) - if pid not in set_active_processes: - with task.processing() as outcome: - outcome.set_result(False) - echo.echo_report(f'Acknowledged task `{pid}`') - - # Revive zombie processes that no longer have a process task - process_controller = manager.get_process_controller() - for pid in set_active_processes: - if pid not in set_process_tasks: - process_controller.continue_process(pid) - echo.echo_report(f'Revived process `{pid}`') - - -@cmd_tasks.command('revive') -@arguments.PROCESSES() -@options.FORCE() -@decorators.only_if_daemon_running(message='The daemon has to be running for this command to work.') -def cmd_tasks_revive(processes, force): - """Revive processes that seem stuck and are no longer reachable. - - Warning: Use only as a last resort after you've gone through the checklist below. - - \b - 1. Does ``verdi status`` indicate that both daemon and RabbitMQ are running properly? - If not, restart the daemon with ``verdi daemon restart --reset`` and restart RabbitMQ. - 2. Try ``verdi process play ``. - If you receive a message that the process is no longer reachable, - use ``verdi devel rabbitmq tasks revive ``. - - Details: When RabbitMQ loses the process task before the process has completed, the process is never picked up by - the daemon and will remain "stuck". ``verdi devel rabbitmq tasks revive`` recreates the task, which can lead to - multiple instances of the task being executed and should thus be used with caution. 
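
For reference, the revival that this command wraps can also be triggered from the Python API it imports below. A minimal sketch, assuming a configured AiiDA profile; pk ``123`` is a hypothetical stuck process, not a real one.

```python
# Minimal sketch of reviving a stuck process programmatically (use only as a last
# resort: recreating a task for a process that still has one can cause the task to
# be executed more than once).
from aiida import load_profile
from aiida.engine.processes.control import revive_processes
from aiida.orm import load_node

load_profile()

stuck_process = load_node(123)  # hypothetical pk of a process that is no longer reachable
revive_processes([stuck_process])
```
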
- """ - from aiida.engine.processes.control import revive_processes - - if not force: - echo.echo_warning('This command should only be used if you are absolutely sure the process task was lost.') - click.confirm(text='Do you want to continue?', abort=True) - - revive_processes(processes) diff --git a/aiida/cmdline/commands/cmd_status.py b/aiida/cmdline/commands/cmd_status.py deleted file mode 100644 index e4d3f68dee..0000000000 --- a/aiida/cmdline/commands/cmd_status.py +++ /dev/null @@ -1,174 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi status` command.""" -import enum -import sys - -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options -from aiida.cmdline.utils import echo -from aiida.common.exceptions import CorruptStorage, IncompatibleStorageSchema, UnreachableStorage -from aiida.common.log import override_log_level - -from ..utils.echo import ExitCode # pylint: disable=import-error,no-name-in-module - - -class ServiceStatus(enum.IntEnum): - """Describe status of services for 'verdi status' command.""" - UP = 0 # pylint: disable=invalid-name - ERROR = 1 - WARNING = 2 - DOWN = 3 - - -STATUS_SYMBOLS = { - ServiceStatus.UP: { - 'color': 'green', - 'string': '\u2714', - }, - ServiceStatus.ERROR: { - 'color': 'red', - 'string': '\u2718', - }, - ServiceStatus.WARNING: { - 'color': 'yellow', - 'string': '\u23FA', - }, - ServiceStatus.DOWN: { - 'color': 'red', - 'string': '\u2718', - }, -} - - -@verdi.command('status') -@options.PRINT_TRACEBACK() -@click.option('--no-rmq', is_flag=True, help='Do not check RabbitMQ status') -def verdi_status(print_traceback, no_rmq): - """Print status of AiiDA services.""" - # pylint: disable=broad-except,too-many-statements,too-many-branches,too-many-locals, - from aiida import __version__ - from aiida.common.utils import Capturing - from aiida.engine.daemon.client import DaemonException, DaemonNotRunningException - from aiida.manage.configuration.settings import AIIDA_CONFIG_FOLDER - from aiida.manage.manager import get_manager - - exit_code = ExitCode.SUCCESS - - print_status(ServiceStatus.UP, 'version', f'AiiDA v{__version__}') - print_status(ServiceStatus.UP, 'config', AIIDA_CONFIG_FOLDER) - - manager = get_manager() - - try: - profile = manager.get_profile() - - if profile is None: - print_status(ServiceStatus.WARNING, 'profile', 'no profile configured yet') - echo.echo_report('Configure a profile by running `verdi quicksetup` or `verdi setup`.') - return - - print_status(ServiceStatus.UP, 'profile', profile.name) - - except Exception as exc: - message = 'Unable to read AiiDA profile' - print_status(ServiceStatus.ERROR, 'profile', message, exception=exc, print_traceback=print_traceback) - sys.exit(ExitCode.CRITICAL) # stop here - without a profile we cannot access anything - - # Check the backend storage - storage_head_version = None - try: - with override_log_level(): # temporarily suppress noisy logging - storage_cls = profile.storage_cls - storage_head_version = storage_cls.version_head() - storage_backend = storage_cls(profile) - except 
UnreachableStorage as exc: - message = 'Unable to connect to profile\'s storage.' - print_status(ServiceStatus.DOWN, 'storage', message, exception=exc, print_traceback=print_traceback) - exit_code = ExitCode.CRITICAL - except IncompatibleStorageSchema as exc: - message = ( - f'Storage schema version is incompatible with the code version {storage_head_version!r}. ' - 'Run `verdi storage migrate` to solve this.' - ) - print_status(ServiceStatus.DOWN, 'storage', message) - exit_code = ExitCode.CRITICAL - except CorruptStorage as exc: - message = 'Storage is corrupted.' - print_status(ServiceStatus.DOWN, 'storage', message, exception=exc, print_traceback=print_traceback) - exit_code = ExitCode.CRITICAL - except Exception as exc: - message = 'Unable to instatiate profile\'s storage.' - print_status(ServiceStatus.ERROR, 'storage', message, exception=exc, print_traceback=print_traceback) - exit_code = ExitCode.CRITICAL - else: - message = str(storage_backend) - print_status(ServiceStatus.UP, 'storage', message) - - # Getting the rmq status - if not no_rmq: - rmq_url = '' - try: - rmq_url = profile.get_rmq_url() - with Capturing(capture_stderr=True): - with override_log_level(): # temporarily suppress noisy logging - comm = manager.get_communicator() - except Exception as exc: - message = f'Unable to connect to rabbitmq with URL: {rmq_url}' - print_status(ServiceStatus.ERROR, 'rabbitmq', message, exception=exc, print_traceback=print_traceback) - exit_code = ExitCode.CRITICAL - else: - version, supported = manager.check_rabbitmq_version(comm) - connection = f'Connected to RabbitMQ v{version} as {rmq_url}' - if supported: - print_status(ServiceStatus.UP, 'rabbitmq', connection) - else: - print_status(ServiceStatus.WARNING, 'rabbitmq', 'Incompatible RabbitMQ version detected! ' + connection) - - # Getting the daemon status - try: - status = manager.get_daemon_client().get_status() - except DaemonNotRunningException as exception: - print_status(ServiceStatus.WARNING, 'daemon', str(exception)) - except DaemonException as exception: - print_status(ServiceStatus.ERROR, 'daemon', str(exception)) - except Exception as exception: - message = 'Error getting daemon status' - print_status(ServiceStatus.ERROR, 'daemon', message, exception=exception, print_traceback=print_traceback) - exit_code = ExitCode.CRITICAL - else: - print_status(ServiceStatus.UP, 'daemon', f'Daemon is running with PID {status["pid"]}') - - # Note: click does not forward return values to the exit code, see https://github.com/pallets/click/issues/747 - if exit_code != ExitCode.SUCCESS: - sys.exit(exit_code) - - -def print_status(status, service, msg='', exception=None, print_traceback=False): - """Print status message. - - Includes colored indicator. - - :param status: a ServiceStatus code - :param service: string for service name - :param msg: message string - """ - symbol = STATUS_SYMBOLS[status] - echo.echo(f" {symbol['string']} ", fg=symbol['color'], nl=False) - echo.echo(f"{service + ':':12s} {msg}") - - if exception is not None: - echo.echo_error(f'{type(exception).__name__}: {exception}') - - if print_traceback: - import traceback - traceback.print_exc() diff --git a/aiida/cmdline/commands/cmd_storage.py b/aiida/cmdline/commands/cmd_storage.py deleted file mode 100644 index cf32e9bf99..0000000000 --- a/aiida/cmdline/commands/cmd_storage.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. 
# -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi storage` commands.""" -import click -from click_spinner import spinner - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import options -from aiida.cmdline.utils import decorators, echo -from aiida.common import exceptions - - -@verdi.group('storage') -def verdi_storage(): - """Inspect and manage stored data for a profile.""" - - -@verdi_storage.command('version') -def storage_version(): - """Print the current version of the storage schema.""" - from aiida import get_profile - profile = get_profile() - head_version = profile.storage_cls.version_head() - profile_version = profile.storage_cls.version_profile(profile) - echo.echo(f'Latest storage schema version: {head_version!r}') - echo.echo(f'Storage schema version of {profile.name!r}: {profile_version!r}') - - -@verdi_storage.command('migrate') -@options.FORCE() -def storage_migrate(force): - """Migrate the storage to the latest schema version.""" - from aiida.engine.daemon.client import get_daemon_client - from aiida.manage import get_manager - - client = get_daemon_client() - if client.is_daemon_running: - echo.echo_critical('Migration aborted, the daemon for the profile is still running.') - - manager = get_manager() - profile = manager.get_profile() - storage_cls = profile.storage_cls - - if not force: - - echo.echo_warning('Migrating your storage might take a while and is not reversible.') - echo.echo_warning('Before continuing, make sure you have completed the following steps:') - echo.echo_warning('') - echo.echo_warning(' 1. Make sure you have no active calculations and workflows.') - echo.echo_warning(' 2. If you do, revert the code to the previous version and finish running them first.') - echo.echo_warning(' 3. Stop the daemon using `verdi daemon stop`') - echo.echo_warning(' 4. 
Make a backup of your database and repository') - echo.echo_warning('') - echo.echo_warning('', nl=False) - - expected_answer = 'MIGRATE NOW' - confirm_message = 'If you have completed the steps above and want to migrate profile "{}", type {}'.format( - profile.name, expected_answer - ) - - try: - response = click.prompt(confirm_message) - while response != expected_answer: - response = click.prompt(confirm_message) - except click.Abort: - echo.echo('\n') - echo.echo_critical('Migration aborted, the data has not been affected.') - return - - try: - storage_cls.migrate(profile) - except (exceptions.ConfigurationError, exceptions.StorageMigrationError) as exception: - echo.echo_critical(str(exception)) - else: - echo.echo_success('migration completed') - - -@verdi_storage.group('integrity') -def storage_integrity(): - """Checks for the integrity of the data storage.""" - - -@verdi_storage.command('info') -@click.option('--detailed', is_flag=True, help='Provides more detailed information.') -@decorators.with_dbenv() -def storage_info(detailed): - """Summarise the contents of the storage.""" - from aiida.manage.manager import get_manager - - manager = get_manager() - storage = manager.get_profile_storage() - - with spinner(): - data = storage.get_info(detailed=detailed) - - echo.echo_dictionary(data, sort_keys=False, fmt='yaml') - - -@verdi_storage.command('maintain') -@click.option( - '--full', - is_flag=True, - help='Perform all maintenance tasks, including the ones that should not be executed while the profile is in use.' -) -@click.option( - '--no-repack', is_flag=True, help='Disable the repacking of the storage when running a `full maintenance`.' -) -@options.FORCE() -@click.option( - '--dry-run', - is_flag=True, - help= - 'Run the maintenance in dry-run mode which will print actions that would be taken without actually executing them.' -) -@click.option( - '--compress', is_flag=True, default=False, help='Use compression if possible when carrying out maintenance tasks.' -) -@decorators.with_dbenv() -@click.pass_context -def storage_maintain(ctx, full, no_repack, force, dry_run, compress): - """Performs maintenance tasks on the repository.""" - from aiida.common.exceptions import LockingProfileError - from aiida.manage.manager import get_manager - - manager = get_manager() - profile = ctx.obj.profile - storage = manager.get_profile_storage() - - if full: - echo.echo_warning( - '\nIn order to safely perform the full maintenance operations on the internal storage, the profile ' - f'{profile.name} needs to be locked. ' - 'This means that no other process will be able to access it and will fail instead. ' - 'Moreover, if any process is already using the profile, the locking attempt will fail and you will ' - 'have to either look for these processes and kill them or wait for them to stop by themselves. ' - 'Note that this includes verdi shells, daemon workers, scripts that manually load it, etc.\n' - 'For performing maintenance operations that are safe to run while actively using AiiDA, just run ' - '`verdi storage maintain` without the `--full` flag.\n' - ) - - else: - echo.echo_report( - '\nThis command will perform all maintenance operations on the internal storage that can be safely ' - 'executed while still running AiiDA. 
' - 'However, not all operations that are required to fully optimize disk usage and future performance ' - 'can be done in this way.\n' - 'Whenever you find the time or opportunity, please consider running `verdi storage maintain --full` ' - 'for a more complete optimization.\n' - ) - - if not dry_run and not force and not click.confirm('Are you sure you want continue in this mode?'): - return - - try: - if full and no_repack: - storage.maintain(full=full, dry_run=dry_run, do_repack=False, compress=compress) - else: - storage.maintain(full=full, dry_run=dry_run, compress=compress) - except LockingProfileError as exception: - echo.echo_critical(str(exception)) - echo.echo_success('Requested maintenance procedures finished.') diff --git a/aiida/cmdline/commands/cmd_user.py b/aiida/cmdline/commands/cmd_user.py deleted file mode 100644 index 28b4173e4a..0000000000 --- a/aiida/cmdline/commands/cmd_user.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`verdi user` command.""" - -from functools import partial - -import click - -from aiida.cmdline.commands.cmd_verdi import verdi -from aiida.cmdline.params import arguments, options, types -from aiida.cmdline.utils import decorators, echo - - -def set_default_user(profile, user): - """Set the user as the default user for the given profile. - - :param profile: the profile - :param user: the user - """ - from aiida.manage.configuration import get_config - config = get_config() - profile.default_user_email = user.email - config.update_profile(profile) - config.store() - - -def get_user_attribute_default(attribute, ctx): - """Return the default value for the given attribute of the user passed in the context. 
- - :param attribute: attribute for which to get the current value - :param ctx: click context which should contain the selected user - :return: user attribute default value if set, or None - """ - default = getattr(ctx.params['user'], attribute) - - # None or empty string means there is no default - if not default: - return None - - return default - - -@verdi.group('user') -def verdi_user(): - """Inspect and manage users.""" - - -@verdi_user.command('list') -@decorators.with_dbenv() -def user_list(): - """Show a list of all users.""" - from aiida.orm import User - - default_user = User.collection.get_default() - - if default_user is None: - echo.echo_warning('no default user has been configured') - - attributes = ['email', 'first_name', 'last_name'] - sort = lambda user: user.email - highlight = lambda x: x.email == default_user.email if default_user else None - echo.echo_formatted_list(User.collection.all(), attributes, sort=sort, highlight=highlight) - - -@verdi_user.command('configure') -@click.option( - '--email', - 'user', - prompt='User email', - help='Email address that serves as the user name and a way to identify data created by it.', - type=types.UserParamType(create=True), - cls=options.interactive.InteractiveOption -) -@click.option( - '--first-name', - prompt='First name', - help='First name of the user.', - type=click.STRING, - contextual_default=partial(get_user_attribute_default, 'first_name'), - cls=options.interactive.InteractiveOption -) -@click.option( - '--last-name', - prompt='Last name', - help='Last name of the user.', - type=click.STRING, - contextual_default=partial(get_user_attribute_default, 'last_name'), - cls=options.interactive.InteractiveOption -) -@click.option( - '--institution', - prompt='Institution', - help='Institution of the user.', - type=click.STRING, - contextual_default=partial(get_user_attribute_default, 'institution'), - cls=options.interactive.InteractiveOption -) -@click.option( - '--set-default', - prompt='Set as default?', - help='Set the user as the default user for the current profile.', - is_flag=True, - cls=options.interactive.InteractiveOption -) -@click.pass_context -@decorators.with_dbenv() -def user_configure(ctx, user, first_name, last_name, institution, set_default): - """Configure a new or existing user. - - An e-mail address is used as the user name. 
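
The ``contextual_default=partial(get_user_attribute_default, 'first_name')`` pattern used by the options above pre-binds the attribute name, so the interactive option machinery only has to supply the click context. A minimal sketch of that binding with stand-in objects (the ``SimpleNamespace`` context and user here are illustrative placeholders, not AiiDA API):

```python
from functools import partial
from types import SimpleNamespace


def get_user_attribute_default(attribute, ctx):
    """Return the current value of ``attribute`` for the user already parsed into the context, or None."""
    default = getattr(ctx.params['user'], attribute)
    # None or empty string means there is no default.
    return default or None


# Pre-bind the attribute name; the option later calls the result with just the context.
first_name_default = partial(get_user_attribute_default, 'first_name')

# Stand-in for a click.Context whose ``user`` parameter has already been parsed.
ctx = SimpleNamespace(params={'user': SimpleNamespace(first_name='Ada', last_name='')})

print(first_name_default(ctx))                                   # 'Ada'
print(partial(get_user_attribute_default, 'last_name')(ctx))     # None (empty string -> no default)
```
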
- """ - # pylint: disable=too-many-arguments - if first_name is not None: - user.first_name = first_name - if last_name is not None: - user.last_name = last_name - if institution is not None: - user.institution = institution - - action = 'updated' if user.is_stored else 'created' - - user.store() - - echo.echo_success(f'{user.email} successfully {action}') - - if set_default: - ctx.invoke(user_set_default, user=user) - - -@verdi_user.command('set-default') -@arguments.USER() -@click.pass_context -@decorators.with_dbenv() -def user_set_default(ctx, user): - """Set a user as the default user for the profile.""" - set_default_user(ctx.obj.profile, user) - echo.echo_success(f'set `{user.email}` as the new default user for profile `{ctx.obj.profile.name}`') diff --git a/aiida/cmdline/groups/__init__.py b/aiida/cmdline/groups/__init__.py deleted file mode 100644 index 3403f5c550..0000000000 --- a/aiida/cmdline/groups/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -*- coding: utf-8 -*- -"""Module with custom implementations of :class:`click.Group`.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .dynamic import * -from .verdi import * - -__all__ = ( - 'DynamicEntryPointCommandGroup', - 'VerdiCommandGroup', -) - -# yapf: enable diff --git a/aiida/cmdline/groups/dynamic.py b/aiida/cmdline/groups/dynamic.py deleted file mode 100644 index caead9d3ae..0000000000 --- a/aiida/cmdline/groups/dynamic.py +++ /dev/null @@ -1,148 +0,0 @@ -# -*- coding: utf-8 -*- -"""Subclass of :class:`click.Group` that loads subcommands dynamically from entry points.""" -from __future__ import annotations - -import copy -import functools -import re - -import click - -from aiida.common import exceptions -from aiida.plugins.entry_point import ENTRY_POINT_GROUP_FACTORY_MAPPING, get_entry_point_names - -from ..params import options -from ..params.options.interactive import InteractiveOption -from .verdi import VerdiCommandGroup - -__all__ = ('DynamicEntryPointCommandGroup',) - - -class DynamicEntryPointCommandGroup(VerdiCommandGroup): - """Subclass of :class:`click.Group` that loads subcommands dynamically from entry points. - - A command group using this class will automatically generate the sub commands from the entry points registered in - the given ``entry_point_group``. The entry points can be additionally filtered using a regex defined for the - ``entry_point_name_filter`` keyword. The actual command for each entry point is defined by ``command``, which should - take as a first argument the class that corresponds to the entry point. In addition, it should accept ``kwargs`` - which will be the values for the options passed when the command is invoked. The help string of the command will be - provided by the docstring of the class registered at the respective entry point. Example usage: - - .. 
code:: python - - def create_instance(cls, **kwargs): - instance = cls(**kwargs) - instance.store() - echo.echo_success(f'Created {cls.__name__}<{instance.pk}>') - - @click.group('create', cls=DynamicEntryPointCommandGroup, command=create_instance,) - def cmd_create(): - pass - - """ - - def __init__( - self, - command, - entry_point_group: str, - entry_point_name_filter: str = r'.*', - shared_options: list[click.Option] | None = None, - **kwargs - ): - super().__init__(**kwargs) - self.command = command - self.entry_point_group = entry_point_group - self.entry_point_name_filter = entry_point_name_filter - self.factory = ENTRY_POINT_GROUP_FACTORY_MAPPING[entry_point_group] - self.shared_options = shared_options - - def list_commands(self, ctx) -> list[str]: - """Return the sorted list of subcommands for this group. - - :param ctx: The :class:`click.Context`. - """ - commands = super().list_commands(ctx) - commands.extend([ - entry_point for entry_point in get_entry_point_names(self.entry_point_group) - if re.match(self.entry_point_name_filter, entry_point) - ]) - return sorted(commands) - - def get_command(self, ctx, cmd_name): - """Return the command with the given name. - - :param ctx: The :class:`click.Context`. - :param cmd_name: The name of the command. - :returns: The :class:`click.Command`. - """ - try: - command = self.create_command(ctx, cmd_name) - except exceptions.EntryPointError: - command = super().get_command(ctx, cmd_name) - return command - - def create_command(self, ctx, entry_point): - """Create a subcommand for the given ``entry_point``.""" - cls = self.factory(entry_point) - command = functools.partial(self.command, ctx, cls) - command.__doc__ = cls.__doc__ - return click.command(entry_point)(self.create_options(entry_point)(command)) - - def create_options(self, entry_point): - """Create the option decorators for the command function for the given entry point. - - :param entry_point: The entry point. - """ - - def apply_options(func): - """Decorate the command function with the appropriate options for the given entry point.""" - func = options.NON_INTERACTIVE()(func) - func = options.CONFIG_FILE()(func) - - options_list = self.list_options(entry_point) - options_list.reverse() - - for option in options_list: - func = option(func) - - shared_options = self.shared_options or [] - shared_options.reverse() - - for option in shared_options: - func = option(func) - - return func - - return apply_options - - def list_options(self, entry_point): - """Return the list of options that should be applied to the command for the given entry point. - - :param entry_point: The entry point. - """ - return [self.create_option(*item) for item in self.factory(entry_point).get_cli_options().items()] - - @staticmethod - def create_option(name, spec): - """Create a click option from a name and a specification.""" - spec = copy.deepcopy(spec) - - is_flag = spec.pop('is_flag', False) - default = spec.get('default') - name_dashed = name.replace('_', '-') - option_name = f'--{name_dashed}/--no-{name_dashed}' if is_flag else f'--{name_dashed}' - option_short_name = spec.pop('short_name', None) - - kwargs = {'cls': spec.pop('cls', InteractiveOption), 'show_default': True, 'is_flag': is_flag, **spec} - - # If the option is a flag with no default, make sure it is not prompted for, as that will force the user to - # specify it to be on or off, but cannot let it unspecified. 
- if kwargs['cls'] is InteractiveOption and is_flag and default is None: - kwargs['cls'] = functools.partial(InteractiveOption, prompt_fn=lambda ctx: False) - - if option_short_name: - option = click.option(option_short_name, option_name, **kwargs) - else: - option = click.option(option_name, **kwargs) - - return option diff --git a/aiida/cmdline/groups/verdi.py b/aiida/cmdline/groups/verdi.py deleted file mode 100644 index 64a08ce8c8..0000000000 --- a/aiida/cmdline/groups/verdi.py +++ /dev/null @@ -1,125 +0,0 @@ -# -*- coding: utf-8 -*- -"""Subclass of :class:`click.Group` for the ``verdi`` CLI.""" -from __future__ import annotations - -import base64 -import difflib -import gzip - -import click - -from aiida.common.exceptions import ConfigurationError -from aiida.common.extendeddicts import AttributeDict -from aiida.manage.configuration import get_config - -from ..params import options - -__all__ = ('VerdiCommandGroup',) - -GIU = ( - 'ABzY8%U8Kw0{@klyK?I~3`Ki?#qHQ&IIM|J;6yB`9_+{&w)p(JK}vokj-11jhve8xcx?dZ>+9nwrEF!x*S>9A+EWYrR?6GA-u?jFa+et65GF@1+D{%' - '8{C~xjt%>uVM4RTSS?j2M)XH%T#>M{K$lE2XGD`RS0T67213wbAs!SZmn+;(-m!>f(T@e%@oxd`yRBp9nu+9N`4xv8AS@O$CaQ;7FXzM=ug^$?3ta2551EDL`wK4|Cm' - '%RnJdS#0UFwVweDkcfdNjtUv1N^iSQui#TL(q!FmIeKb!yW4|L`@!@-4x6' - 'B6I^ptRdH+4o0ODM;1_f^}4@LMe@#_YHz0wQdq@d)@n)uYNtAb2OLo&fpBkct5{~3kbRag^_5QG%qrTksHMXAYAQoz1#2wtHCy0}h?CJtzv&@Q?^9r' - 'd&02;isB7NJMMr7F@>$!ELj(sbwzIR4)rnch=oVZrG;8)%R6}FUk*fv2O&!#ZA)$HloK9!es&4Eb+h=OIyWFha(8PPy9u?NqfkuPYg;GO1RVzBLX)7' - 'ORMM>1hEM`-96mGjJ+A!e-_}4X{M|4CkKE~uF4j+LW#6IsFa*_da_mLqzr)E<`%ikthkMO2>65cNMtpDE*VejqZV^MyewPJJAS*VM6jY;QY' - '#g7gOKgPbFg{@;YDL6Gbxxr|2T&BQunB?PBetq?X>jW1hFF7&>EaYkKYqIa_ld(Z@AJT' - '+lJ(Pd;+?<&&M>A0agti19^z3n4Z6_WG}c~_+XHyJI_iau7+V$#YA$pJ~H)yHEVy1D?5^Sw`tb@{nnNNo=eSMZLf0>m^A@7f{y$nb_HJWgLRtZ?x2?*>SwM?JoQ>p|-1ZRU0#+{^UhK22+~o' - 'R9k7rh(GH9y|jm){jY9_xAI4N_EfU#4' - 'taTUXFY4a4l$v=N-+f+w&wuH;Z(6p6#=n8XwlZ;*L&-rcL~T_vEm@#-Xi8&g06!MO+R( click.Command | None: - """Return the command that corresponds to the requested ``cmd_name``. - - This method is overridden from the base class in order to two functionalities: - - * If the command is found, automatically add the verbosity option. - * If the command is not found, attempt to provide a list of suggestions with existing commands that resemble - the requested command name. - - Note that if the command is not found and ``resilient_parsing`` is set to True on the context, then the latter - feature is disabled because most likely we are operating in tab-completion mode. - """ - if int(cmd_name.lower().encode('utf-8').hex(), 16) == 0x6769757365707065: - click.echo(gzip.decompress(base64.b85decode(GIU.encode('utf-8'))).decode('utf-8')) - return None - - cmd = super().get_command(ctx, cmd_name) - - if cmd is not None: - return self.add_verbosity_option(cmd) - - # If this command is called during tab-completion, we do not want to print an error message if the command can't - # be found, but instead we want to simply return here. However, in a normal command execution, we do want to - # execute the rest of this method to try and match commands that are similar in order to provide the user with - # some hints. The problem is that there is no one canonical way to determine whether the invocation is due to a - # normal command execution or a tab-complete operation. The `resilient_parsing` attribute of the `Context` is - # designed to allow things like tab-completion, however, it is not the only purpose. For now this is our best - # bet though to detect a tab-complete event. 
When `resilient_parsing` is switched on, we assume a tab-complete - # and do nothing in case the command name does not match an actual command. - if ctx.resilient_parsing: - return None - - self.fail_with_suggestions(ctx, cmd_name) - - return None - - def group(self, *args, **kwargs) -> click.Group: - """Ensure that sub command groups use the same class but do not override an explicitly set value.""" - kwargs.setdefault('cls', self.__class__) - return super().group(*args, **kwargs) diff --git a/aiida/cmdline/params/__init__.py b/aiida/cmdline/params/__init__.py deleted file mode 100644 index 128abf2797..0000000000 --- a/aiida/cmdline/params/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Commandline parameters.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .types import * - -__all__ = ( - 'AbsolutePathParamType', - 'CalculationParamType', - 'CodeParamType', - 'ComputerParamType', - 'ConfigOptionParamType', - 'DataParamType', - 'EmailType', - 'EntryPointType', - 'FileOrUrl', - 'GroupParamType', - 'HostnameType', - 'IdentifierParamType', - 'LabelStringType', - 'LazyChoice', - 'MpirunCommandParamType', - 'MultipleValueParamType', - 'NodeParamType', - 'NonEmptyStringParamType', - 'PathOrUrl', - 'PluginParamType', - 'ProcessParamType', - 'ProfileParamType', - 'ShebangParamType', - 'UserParamType', - 'WorkflowParamType', -) - -# yapf: enable diff --git a/aiida/cmdline/params/arguments/__init__.py b/aiida/cmdline/params/arguments/__init__.py deleted file mode 100644 index 0c891e6691..0000000000 --- a/aiida/cmdline/params/arguments/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# yapf: disable -"""Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .main import * -from .overridable import * - -__all__ = ( - 'CALCULATION', - 'CALCULATIONS', - 'CODE', - 'CODES', - 'COMPUTER', - 'COMPUTERS', - 'CONFIG_OPTION', - 'DATA', - 'DATUM', - 'GROUP', - 'GROUPS', - 'INPUT_FILE', - 'LABEL', - 'NODE', - 'NODES', - 'OUTPUT_FILE', - 'OverridableArgument', - 'PROCESS', - 'PROCESSES', - 'PROFILE', - 'PROFILES', - 'USER', - 'WORKFLOW', - 'WORKFLOWS', -) - -# yapf: enable diff --git a/aiida/cmdline/params/arguments/main.py b/aiida/cmdline/params/arguments/main.py deleted file mode 100644 index 71bb8c2544..0000000000 --- a/aiida/cmdline/params/arguments/main.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# yapf: disable -"""Module with pre-defined reusable commandline arguments that can be used as `click` decorators.""" - -import click - -from .. import types -from .overridable import OverridableArgument - -__all__ = ( - 'PROFILE', 'PROFILES', 'CALCULATION', 'CALCULATIONS', 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'DATUM', 'DATA', - 'GROUP', 'GROUPS', 'NODE', 'NODES', 'PROCESS', 'PROCESSES', 'WORKFLOW', 'WORKFLOWS', 'INPUT_FILE', 'OUTPUT_FILE', - 'LABEL', 'USER', 'CONFIG_OPTION' -) - - -PROFILE = OverridableArgument('profile', type=types.ProfileParamType()) - -PROFILES = OverridableArgument('profiles', type=types.ProfileParamType(), nargs=-1) - -CALCULATION = OverridableArgument('calculation', type=types.CalculationParamType()) - -CALCULATIONS = OverridableArgument('calculations', nargs=-1, type=types.CalculationParamType()) - -CODE = OverridableArgument('code', type=types.CodeParamType()) - -CODES = OverridableArgument('codes', nargs=-1, type=types.CodeParamType()) - -COMPUTER = OverridableArgument('computer', type=types.ComputerParamType()) - -COMPUTERS = OverridableArgument('computers', nargs=-1, type=types.ComputerParamType()) - -DATUM = OverridableArgument('datum', type=types.DataParamType()) - -DATA = OverridableArgument('data', nargs=-1, type=types.DataParamType()) - -GROUP = OverridableArgument('group', type=types.GroupParamType()) - -GROUPS = OverridableArgument('groups', nargs=-1, type=types.GroupParamType()) - -NODE = OverridableArgument('node', type=types.NodeParamType()) - -NODES = OverridableArgument('nodes', nargs=-1, type=types.NodeParamType()) - -PROCESS = OverridableArgument('process', type=types.ProcessParamType()) - -PROCESSES = OverridableArgument('processes', nargs=-1, type=types.ProcessParamType()) - -WORKFLOW = OverridableArgument('workflow', type=types.WorkflowParamType()) - -WORKFLOWS = OverridableArgument('workflows', nargs=-1, type=types.WorkflowParamType()) - 
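
These module-level arguments (the remaining definitions follow below) are meant to be applied as decorators on ``verdi`` commands and can be partially overridden per command, as the ``overridable`` module deleted further down documents. A minimal sketch of that reuse, assuming an installed AiiDA environment with the module layout shown in this diff (``aiida.cmdline.params.arguments``):

```python
import click

from aiida.cmdline.params import arguments
from aiida.cmdline.utils import decorators, echo


@click.command()
@arguments.CALCULATION()      # single CALCULATION argument with its predefined param type
@decorators.with_dbenv()
def show_calculation_pk(calculation):
    """Print the pk of a single calculation."""
    echo.echo(calculation.pk)


@click.command()
@arguments.CALCULATIONS()     # same argument reused in its pluralized nargs=-1 variant
@decorators.with_dbenv()
def show_calculation_pks(calculations):
    """Print the pks of any number of calculations."""
    echo.echo([calculation.pk for calculation in calculations])
```
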
-INPUT_FILE = OverridableArgument('input_file', metavar='INPUT_FILE', type=click.Path(exists=True)) - -OUTPUT_FILE = OverridableArgument('output_file', metavar='OUTPUT_FILE', type=click.Path()) - -LABEL = OverridableArgument('label', type=click.STRING) - -USER = OverridableArgument('user', metavar='USER', type=types.UserParamType()) - -CONFIG_OPTION = OverridableArgument('option', type=types.ConfigOptionParamType()) diff --git a/aiida/cmdline/params/arguments/overridable.py b/aiida/cmdline/params/arguments/overridable.py deleted file mode 100644 index 72ddff6ff7..0000000000 --- a/aiida/cmdline/params/arguments/overridable.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -.. py:module::overridable - :synopsis: Convenience class which can be used to defined a set of commonly used arguments that - can be easily reused and which improves consistency across the command line interface -""" -import click - -__all__ = ('OverridableArgument',) - - -class OverridableArgument: - """ - Wrapper around click.argument that increases reusability - - Once defined, the argument can be reused with a consistent name and sensible defaults while - other details can be customized on a per-command basis - - Example:: - - @click.command() - @CODE('code') - def print_code_pk(code): - click.echo(code.pk) - - @click.command() - @CODE('codes', nargs=-1) - def print_code_pks(codes): - click.echo([c.pk for c in codes]) - - Notice that the arguments, which are used to define the name of the argument and based on which - the function argument name is determined, can be overriden - """ - - def __init__(self, *args, **kwargs): - """ - Store the default args and kwargs - """ - self.args = args - self.kwargs = kwargs - - def __call__(self, *args, **kwargs): - """ - Override the stored kwargs with the passed kwargs and return the argument, using the stored args - only if they are not provided. This allows the user to override the variable name, which is - useful if for example they want to allow multiple value with nargs=-1 and want to pluralize - the function argument for consistency - """ - kw_copy = self.kwargs.copy() - kw_copy.update(kwargs) - - if args: - return click.argument(*args, **kw_copy) - - return click.argument(*self.args, **kw_copy) diff --git a/aiida/cmdline/params/options/__init__.py b/aiida/cmdline/params/options/__init__.py deleted file mode 100644 index b509d4e0ba..0000000000 --- a/aiida/cmdline/params/options/__init__.py +++ /dev/null @@ -1,116 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with pre-defined reusable commandline options that can be used as `click` decorators.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .config import * -from .main import * -from .multivalue import * -from .overridable import * - -__all__ = ( - 'ALL', - 'ALL_STATES', - 'ALL_USERS', - 'APPEND_TEXT', - 'ARCHIVE_FORMAT', - 'BROKER_HOST', - 'BROKER_PASSWORD', - 'BROKER_PORT', - 'BROKER_PROTOCOL', - 'BROKER_USERNAME', - 'BROKER_VIRTUAL_HOST', - 'CALCULATION', - 'CALCULATIONS', - 'CALC_JOB_STATE', - 'CODE', - 'CODES', - 'COMPUTER', - 'COMPUTERS', - 'CONFIG_FILE', - 'ConfigFileOption', - 'DATA', - 'DATUM', - 'DB_BACKEND', - 'DB_ENGINE', - 'DB_HOST', - 'DB_NAME', - 'DB_PASSWORD', - 'DB_PORT', - 'DB_USERNAME', - 'DEBUG', - 'DESCRIPTION', - 'DICT_FORMAT', - 'DICT_KEYS', - 'DRY_RUN', - 'EXIT_STATUS', - 'EXPORT_FORMAT', - 'FAILED', - 'FORCE', - 'FORMULA_MODE', - 'FREQUENCY', - 'GROUP', - 'GROUPS', - 'GROUP_CLEAR', - 'HOSTNAME', - 'IDENTIFIER', - 'INPUT_FORMAT', - 'INPUT_PLUGIN', - 'LABEL', - 'LIMIT', - 'MultipleValueOption', - 'NODE', - 'NODES', - 'NON_INTERACTIVE', - 'OLDER_THAN', - 'ORDER_BY', - 'ORDER_DIRECTION', - 'OverridableOption', - 'PAST_DAYS', - 'PAUSED', - 'PORT', - 'PREPEND_TEXT', - 'PRINT_TRACEBACK', - 'PROCESS_LABEL', - 'PROCESS_STATE', - 'PROFILE', - 'PROFILE_ONLY_CONFIG', - 'PROFILE_SET_DEFAULT', - 'PROJECT', - 'RAW', - 'REPOSITORY_PATH', - 'SCHEDULER', - 'SILENT', - 'TIMEOUT', - 'TRAJECTORY_INDEX', - 'TRANSPORT', - 'TRAVERSAL_RULE_HELP_STRING', - 'TYPE_STRING', - 'USER', - 'USER_EMAIL', - 'USER_FIRST_NAME', - 'USER_INSTITUTION', - 'USER_LAST_NAME', - 'VERBOSITY', - 'VISUALIZATION_FORMAT', - 'WAIT', - 'WITH_ELEMENTS', - 'WITH_ELEMENTS_EXCLUSIVE', - 'active_process_states', - 'graph_traversal_rules', - 'valid_calc_job_states', - 'valid_process_states', -) - -# yapf: enable diff --git a/aiida/cmdline/params/options/commands/__init__.py b/aiida/cmdline/params/options/commands/__init__.py deleted file mode 100644 index d0a4d73edb..0000000000 --- a/aiida/cmdline/params/options/commands/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module containing predefined command specific CLI options.""" diff --git a/aiida/cmdline/params/options/commands/code.py b/aiida/cmdline/params/options/commands/code.py deleted file mode 100644 index 98f0f8b3d8..0000000000 --- a/aiida/cmdline/params/options/commands/code.py +++ /dev/null @@ -1,174 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Reusable command line interface options for Code commands.""" -import click - -from aiida.cmdline.params import options, types -from aiida.cmdline.params.options.interactive import InteractiveOption, TemplateInteractiveOption -from aiida.cmdline.params.options.overridable import OverridableOption - - -def is_on_computer(ctx): - return bool(ctx.params.get('on_computer')) - - -def is_not_on_computer(ctx): - return bool(not is_on_computer(ctx)) - - -def validate_label_uniqueness(ctx, _, value): - """Validate the uniqueness of the label of the code. - - The exact uniqueness criterion depends on the type of the code, whether it is "local" or "remote". For the former, - the `label` itself should be unique, whereas for the latter it is the full label, i.e., `label@computer.label`. - - .. note:: For this to work in the case of the remote code, the computer parameter already needs to have been parsed - In interactive mode, this means that the computer parameter needs to be defined after the label parameter in the - command definition. For non-interactive mode, the parsing order will always be determined by the order the - parameters are specified by the caller and so this validator may get called before the computer is parsed. For - that reason, this validator should also be called in the command itself, to ensure it has both the label and - computer parameter available. - - """ - from aiida.common import exceptions - from aiida.orm import load_code - - computer = ctx.params.get('computer', None) - on_computer = ctx.params.get('on_computer', None) - - if on_computer is False: - try: - load_code(value) - except exceptions.NotExistent: - pass - except exceptions.MultipleObjectsError: - raise click.BadParameter(f'multiple copies of the remote code `{value}` already exist.') - else: - raise click.BadParameter(f'the code `{value}` already exists.') - - if computer is not None: - full_label = f'{value}@{computer.label}' - - try: - load_code(full_label) - except exceptions.NotExistent: - pass - except exceptions.MultipleObjectsError: - raise click.BadParameter(f'multiple copies of the local code `{full_label}` already exist.') - else: - raise click.BadParameter(f'the code `{full_label}` already exists.') - - return value - - -ON_COMPUTER = OverridableOption( - '--on-computer/--store-in-db', - is_eager=False, - default=True, - cls=InteractiveOption, - prompt='Installed on target computer?', - help='Whether the code is installed on the target computer, or should be copied to the target computer each time ' - 'from a local path.' -) - -REMOTE_ABS_PATH = OverridableOption( - '--remote-abs-path', - prompt='Remote absolute path', - required_fn=is_on_computer, - prompt_fn=is_on_computer, - type=types.AbsolutePathParamType(dir_okay=False), - cls=InteractiveOption, - help='[if --on-computer]: Absolute path to the executable on the target computer.' 
-) - -FOLDER = OverridableOption( - '--code-folder', - prompt='Local directory containing the code', - required_fn=is_not_on_computer, - prompt_fn=is_not_on_computer, - type=click.Path(file_okay=False, exists=True, readable=True), - cls=InteractiveOption, - help='[if --store-in-db]: Absolute path to directory containing the executable and all other files necessary for ' - 'running it (to be copied to target computer).' -) - -REL_PATH = OverridableOption( - '--code-rel-path', - prompt='Relative path of executable inside code folder', - required_fn=is_not_on_computer, - prompt_fn=is_not_on_computer, - type=click.Path(dir_okay=False), - cls=InteractiveOption, - help='[if --store-in-db]: Relative path of the executable inside the code-folder.' -) - -USE_DOUBLE_QUOTES = OverridableOption( - '--use-double-quotes/--not-use-double-quotes', - default=False, - cls=InteractiveOption, - prompt='Escape CLI arguments in double quotes', - help='Whether the executable and arguments of the code in the submission script should be escaped with single ' - 'or double quotes.' -) - -LABEL = options.LABEL.clone( - prompt='Label', - callback=validate_label_uniqueness, - cls=InteractiveOption, - help="This label can be used to identify the code (using 'label@computerlabel'), as long as labels are unique per " - 'computer.' -) - -DESCRIPTION = options.DESCRIPTION.clone( - prompt='Description', - cls=InteractiveOption, - help='A human-readable description of this code, ideally including version and compilation environment.' -) - -INPUT_PLUGIN = options.INPUT_PLUGIN.clone( - required=False, - prompt='Default calculation input plugin', - cls=InteractiveOption, - help="Entry point name of the default calculation plugin (as listed in 'verdi plugin list aiida.calculations')." -) - -COMPUTER = options.COMPUTER.clone( - prompt='Computer', - cls=InteractiveOption, - required_fn=is_on_computer, - prompt_fn=is_on_computer, - help='Name of the computer, on which the code is installed.' -) - -PREPEND_TEXT = OverridableOption( - '--prepend-text', - cls=TemplateInteractiveOption, - prompt='Prepend script', - type=click.STRING, - default='', - help='Bash commands that should be prepended to the executable call in all submit scripts for this code.', - extension='.bash', - header='PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call in all ' - 'submit scripts for this code, type that between the equal signs below and save the file.', - footer='All lines that start with `#=` will be ignored.' -) - -APPEND_TEXT = OverridableOption( - '--append-text', - cls=TemplateInteractiveOption, - prompt='Append script', - type=click.STRING, - default='', - help='Bash commands that should be appended to the executable call in all submit scripts for this code.', - extension='.bash', - header='APPEND_TEXT: if there is any bash commands that should be appended to the executable call in all ' - 'submit scripts for this code, type that between the equal signs below and save the file.', - footer='All lines that start with `#=` will be ignored.' -) diff --git a/aiida/cmdline/params/options/commands/computer.py b/aiida/cmdline/params/options/commands/computer.py deleted file mode 100644 index 5d419a9d69..0000000000 --- a/aiida/cmdline/params/options/commands/computer.py +++ /dev/null @@ -1,178 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Reusable command line interface options for Computer commands.""" -import click - -from aiida.cmdline.params import options, types -from aiida.cmdline.params.options.interactive import InteractiveOption, TemplateInteractiveOption -from aiida.cmdline.params.options.overridable import OverridableOption - - -def get_job_resource_cls(ctx): - """ - Return job resource cls from ctx. - """ - from aiida.common.exceptions import ValidationError - - scheduler_ep = ctx.params['scheduler'] - if scheduler_ep is not None: - try: - scheduler_cls = scheduler_ep.load() - except ImportError: - raise ImportError(f"Unable to load the '{scheduler_ep.name}' scheduler") - else: - raise ValidationError( - 'The should_call_... function should always be run (and prompted) AFTER asking for a scheduler' - ) - - return scheduler_cls.job_resource_class - - -def should_call_default_mpiprocs_per_machine(ctx): # pylint: disable=invalid-name - """ - Return whether the selected scheduler type accepts `default_mpiprocs_per_machine`. - - :return: `True` if the scheduler type accepts `default_mpiprocs_per_machine`, `False` - otherwise. If the scheduler class could not be loaded `False` is returned by default. - """ - job_resource_cls = get_job_resource_cls(ctx) - - if job_resource_cls is None: - # Odd situation... - return False - - return job_resource_cls.accepts_default_mpiprocs_per_machine() - - -def should_call_default_memory_per_machine(ctx): # pylint: disable=invalid-name - """ - Return whether the selected scheduler type accepts `default_memory_per_machine`. - - :return: `True` if the scheduler type accepts `default_memory_per_machine`, `False` - otherwise. If the scheduler class could not be loaded `False` is returned by default. - """ - job_resource_cls = get_job_resource_cls(ctx) - - if job_resource_cls is None: - # Odd situation... - return False - - return job_resource_cls.accepts_default_memory_per_machine() - - -LABEL = options.LABEL.clone( - prompt='Computer label', - cls=InteractiveOption, - required=True, - help='Unique, human-readable label for this computer.' -) - -HOSTNAME = options.HOSTNAME.clone( - prompt='Hostname', - cls=InteractiveOption, - required=True, - help='The fully qualified hostname of the computer (e.g. daint.cscs.ch). ' - 'Use "localhost" when setting up the computer that AiiDA is running on.', -) - -DESCRIPTION = options.DESCRIPTION.clone( - prompt='Description', cls=InteractiveOption, help='A human-readable description of this computer.' 
-) - -TRANSPORT = options.TRANSPORT.clone(prompt='Transport plugin', cls=InteractiveOption) - -SCHEDULER = options.SCHEDULER.clone(prompt='Scheduler plugin', cls=InteractiveOption) - -SHEBANG = OverridableOption( - '--shebang', - prompt='Shebang line (first line of each script, starting with #!)', - default='#!/bin/bash', - cls=InteractiveOption, - help='Specify the first line of the submission script for this computer (only the bash shell is supported).', - type=types.ShebangParamType() -) - -WORKDIR = OverridableOption( - '-w', - '--work-dir', - prompt='Work directory on the computer', - default='/scratch/{username}/aiida/', - cls=InteractiveOption, - help='The absolute path of the directory on the computer where AiiDA will ' - 'run the calculations (often a "scratch" directory).' - 'The {username} string will be replaced by your username on the remote computer.' -) - -MPI_RUN_COMMAND = OverridableOption( - '-m', - '--mpirun-command', - prompt='Mpirun command', - default='mpirun -np {tot_num_mpiprocs}', - cls=InteractiveOption, - help='The mpirun command needed on the cluster to run parallel MPI programs. The {tot_num_mpiprocs} string will be ' - 'replaced by the total number of cpus. See the scheduler docs for further scheduler-dependent template variables.', - type=types.MpirunCommandParamType() -) - -MPI_PROCS_PER_MACHINE = OverridableOption( - '--mpiprocs-per-machine', - prompt='Default number of CPUs per machine', - cls=InteractiveOption, - prompt_fn=should_call_default_mpiprocs_per_machine, - required_fn=False, - type=click.INT, - help='The default number of MPI processes that should be executed per machine (node), if not otherwise specified.' - 'Use 0 to specify no default value.', -) - -DEFAULT_MEMORY_PER_MACHINE = OverridableOption( - '--default-memory-per-machine', - prompt='Default amount of memory per machine (kB).', - cls=InteractiveOption, - prompt_fn=should_call_default_memory_per_machine, - required_fn=False, - type=click.INT, - help='The default amount of RAM (kB) that should be allocated per machine (node), if not otherwise specified.' -) - -USE_DOUBLE_QUOTES = OverridableOption( - '--use-double-quotes/--not-use-double-quotes', - default=False, - cls=InteractiveOption, - prompt='Escape CLI arguments in double quotes', - help='Whether the command line arguments before and after the executable in the submission script should be ' - 'escaped with single or double quotes.' -) - -PREPEND_TEXT = OverridableOption( - '--prepend-text', - cls=TemplateInteractiveOption, - prompt='Prepend script', - type=click.STRING, - default='', - help='Bash commands that should be prepended to the executable call in all submit scripts for this computer.', - extension='.bash', - header='PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call in all ' - 'submit scripts for this computer, type that between the equal signs below and save the file.', - footer='All lines that start with `#=` will be ignored.' -) - -APPEND_TEXT = OverridableOption( - '--append-text', - cls=TemplateInteractiveOption, - prompt='Append script', - type=click.STRING, - default='', - help='Bash commands that should be appended to the executable call in all submit scripts for this computer.', - extension='.bash', - header='APPEND_TEXT: if there is any bash commands that should be appended to the executable call in all ' - 'submit scripts for this computer, type that between the equal signs below and save the file.', - footer='All lines that start with `#=` will be ignored.' 
-) diff --git a/aiida/cmdline/params/options/config.py b/aiida/cmdline/params/options/config.py deleted file mode 100644 index c0694e9f4a..0000000000 --- a/aiida/cmdline/params/options/config.py +++ /dev/null @@ -1,187 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=cyclic-import -""" -.. py:module::config - :synopsis: Convenience class for configuration file option - -The functions :func:`configuration_callback` and :func:`configuration_option` were directly taken from the repository -https://github.com/phha/click_config_file/blob/7b93a20b4c79458987fac116418859f30a16d82a/click_config_file.py with a -minor modification to ``configuration_callback`` to add a check for unknown parameters in the configuration file and -the default provider is changed to :func:`yaml_config_file_provider`. -""" -from __future__ import annotations - -import functools -import os -import typing as t - -import click -import yaml - -from .overridable import OverridableOption - -__all__ = ('ConfigFileOption',) - - -def yaml_config_file_provider(handle, cmd_name): # pylint: disable=unused-argument - """Read yaml config file from file handle.""" - return yaml.safe_load(handle) - - -def configuration_callback( - cmd_name: str | None, - option_name: str, - config_file_name: str, - saved_callback: t.Callable[..., t.Any] | None, - provider: t.Callable[..., t.Any], - implicit: bool, - ctx: click.Context, - param: click.Parameter, - value: t.Any, -): - """Callback for reading the config file. - - Also takes care of calling user specified custom callback afterwards. - - :param cmd_name: The command name. This is used to determine the configuration directory. - :param option_name: The name of the option. This is used for error messages. - :param config_file_name: The name of the configuration file. - :param saved_callback: User-specified callback to be called later. - :param provider: A callable that parses the configuration file and returns a dictionary of the configuration - parameters. Will be called as ``provider(file_path, cmd_name)``. Default: ``yaml_config_file_provider``. - :param implicit: Whether a implicit value should be applied if no configuration option value was provided. - :param ctx: ``click`` context. - :param param: ``click`` parameters. - :param value: Value passed to the parameter. 
- """ - ctx.default_map = ctx.default_map or {} - cmd_name = cmd_name or ctx.info_name - - if implicit: - default_value = os.path.join(click.get_app_dir(cmd_name), config_file_name) # type: ignore[arg-type] - param.default = default_value - value = value or default_value - - if value: - try: - config = provider(value, cmd_name) - except Exception as exception: - raise click.BadOptionUsage(option_name, f'Error reading configuration file: {exception}', ctx) - - valid_params = [param.name for param in ctx.command.params if param.name != option_name] - specified_params = list(config.keys()) - unknown_params = set(specified_params).difference(set(valid_params)) - - if unknown_params: - raise click.BadParameter( - f'Invalid configuration file, the following keys are not supported: {unknown_params}', ctx, param - ) - - ctx.default_map.update(config) - - return saved_callback(ctx, param, value) if saved_callback else value - - -def configuration_option(*param_decls, **attrs): - """Adds configuration file support to a click application. - - This will create an option of type ``click.File`` expecting the path to a configuration file. When specified, it - overwrites the default values for all other click arguments or options with the corresponding value from the - configuration file. The default name of the option is ``--config``. By default, the configuration will be read from - a configuration directory as determined by ``click.get_app_dir``. This decorator accepts the same arguments as - ``click.option`` and ``click.Path``. In addition, the following keyword arguments are available: - - :param cmd_name: str - The command name. This is used to determine the configuration directory. Default: ``ctx.info_name``. - :param config_file_name: str - The name of the configuration file. Default: ``config``. - :param implicit: bool - If ``True`` then implicitly create a value for the configuration option using the above parameters. If a - configuration file exists in this path it will be applied even if no configuration option was suppplied as a - CLI argument or environment variable. If ``False`` only apply a configuration file that has been explicitely - specified. Default: ``False``. - :param provider: callable - A callable that parses the configuration file and returns a dictionary of the configuration parameters. Will be - called as ``provider(file_path, cmd_name)``. Default: ``yaml_config_file_provider``. 
- """ - param_decls = param_decls or ('--config',) - option_name = param_decls[0] - - def decorator(func): - attrs.setdefault('is_eager', True) - attrs.setdefault('help', 'Read configuration from FILE.') - attrs.setdefault('expose_value', False) - implicit = attrs.pop('implicit', True) - cmd_name = attrs.pop('cmd_name', None) - config_file_name = attrs.pop('config_file_name', 'config') - provider = attrs.pop('provider', yaml_config_file_provider) - path_default_params = { - 'exists': False, - 'file_okay': True, - 'dir_okay': False, - 'writable': False, - 'readable': True, - 'resolve_path': False - } - path_params = {k: attrs.pop(k, v) for k, v in path_default_params.items()} - attrs['type'] = attrs.get('type', click.Path(**path_params)) - saved_callback = attrs.pop('callback', None) - partial_callback = functools.partial( - configuration_callback, cmd_name, option_name, config_file_name, saved_callback, provider, implicit - ) - attrs['callback'] = partial_callback - return click.option(*param_decls, **attrs)(func) - - return decorator - - -class ConfigFileOption(OverridableOption): - """Reusable option that reads a configuration file containing values for other command parameters. - - Example:: - - CONFIG_FILE = ConfigFileOption('--config', help='A configuration file') - - @click.command() - @click.option('computer_name') - @CONFIG_FILE(help='Configuration file for computer_setup') - def computer_setup(computer_name): - click.echo(f"Setting up computer {computername}") - - computer_setup --config config.yml - - with config.yml:: - - --- - computer_name: computer1 - - """ - - def __init__(self, *args, **kwargs): - """Store the default args and kwargs. - - :param args: default arguments to be used for the option - :param kwargs: default keyword arguments to be used that can be overridden in the call - """ - kwargs.update({'provider': yaml_config_file_provider, 'implicit': False}) - super().__init__(*args, **kwargs) - - def __call__(self, **kwargs): - """Override the stored kwargs, (ignoring args as we do not allow option name changes) and return the option. - - :param kwargs: keyword arguments that will override those set in the construction - :return: click_config_file.configuration_option constructed with args and kwargs defined during construction - and call of this instance - """ - kw_copy = self.kwargs.copy() - kw_copy.update(kwargs) - - return configuration_option(*self.args, **kw_copy) diff --git a/aiida/cmdline/params/options/main.py b/aiida/cmdline/params/options/main.py deleted file mode 100644 index f7afd3ea50..0000000000 --- a/aiida/cmdline/params/options/main.py +++ /dev/null @@ -1,647 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with pre-defined reusable commandline options that can be used as `click` decorators.""" -import click -from pgsu import DEFAULT_DSN as DEFAULT_DBINFO # pylint: disable=no-name-in-module - -from aiida.common.log import LOG_LEVELS, configure_logging -from aiida.manage.external.rmq import BROKER_DEFAULTS - -from .. 
import types -from ...utils import defaults, echo # pylint: disable=no-name-in-module -from .config import ConfigFileOption -from .multivalue import MultipleValueOption -from .overridable import OverridableOption - -__all__ = ( - 'ALL', 'ALL_STATES', 'ALL_USERS', 'APPEND_TEXT', 'ARCHIVE_FORMAT', 'BROKER_HOST', 'BROKER_PASSWORD', 'BROKER_PORT', - 'BROKER_PROTOCOL', 'BROKER_USERNAME', 'BROKER_VIRTUAL_HOST', 'CALCULATION', 'CALCULATIONS', 'CALC_JOB_STATE', - 'CODE', 'CODES', 'COMPUTER', 'COMPUTERS', 'CONFIG_FILE', 'DATA', 'DATUM', 'DB_BACKEND', 'DB_ENGINE', 'DB_HOST', - 'DB_NAME', 'DB_PASSWORD', 'DB_PORT', 'DB_USERNAME', 'DEBUG', 'DESCRIPTION', 'DICT_FORMAT', 'DICT_KEYS', 'DRY_RUN', - 'EXIT_STATUS', 'EXPORT_FORMAT', 'FAILED', 'FORCE', 'FORMULA_MODE', 'FREQUENCY', 'GROUP', 'GROUPS', 'GROUP_CLEAR', - 'HOSTNAME', 'IDENTIFIER', 'INPUT_FORMAT', 'INPUT_PLUGIN', 'LABEL', 'LIMIT', 'NODE', 'NODES', 'NON_INTERACTIVE', - 'OLDER_THAN', 'ORDER_BY', 'ORDER_DIRECTION', 'PAST_DAYS', 'PAUSED', 'PORT', 'PREPEND_TEXT', 'PRINT_TRACEBACK', - 'PROCESS_LABEL', 'PROCESS_STATE', 'PROFILE', 'PROFILE_ONLY_CONFIG', 'PROFILE_SET_DEFAULT', 'PROJECT', 'RAW', - 'REPOSITORY_PATH', 'SCHEDULER', 'SILENT', 'TIMEOUT', 'TRAJECTORY_INDEX', 'TRANSPORT', 'TRAVERSAL_RULE_HELP_STRING', - 'TYPE_STRING', 'USER', 'USER_EMAIL', 'USER_FIRST_NAME', 'USER_INSTITUTION', 'USER_LAST_NAME', 'VERBOSITY', - 'VISUALIZATION_FORMAT', 'WAIT', 'WITH_ELEMENTS', 'WITH_ELEMENTS_EXCLUSIVE', 'active_process_states', - 'graph_traversal_rules', 'valid_calc_job_states', 'valid_process_states' -) - -TRAVERSAL_RULE_HELP_STRING = { - 'call_calc_backward': 'CALL links to calculations backwards', - 'call_calc_forward': 'CALL links to calculations forwards', - 'call_work_backward': 'CALL links to workflows backwards', - 'call_work_forward': 'CALL links to workflows forwards', - 'input_calc_backward': 'INPUT links to calculations backwards', - 'input_calc_forward': 'INPUT links to calculations forwards', - 'input_work_backward': 'INPUT links to workflows backwards', - 'input_work_forward': 'INPUT links to workflows forwards', - 'return_backward': 'RETURN links backwards', - 'return_forward': 'RETURN links forwards', - 'create_backward': 'CREATE links backwards', - 'create_forward': 'CREATE links forwards', -} - - -def valid_process_states(): - """Return a list of valid values for the ProcessState enum.""" - from plumpy import ProcessState - return tuple(state.value for state in ProcessState) - - -def valid_calc_job_states(): - """Return a list of valid values for the CalcState enum.""" - from aiida.common.datastructures import CalcJobState - return tuple(state.value for state in CalcJobState) - - -def active_process_states(): - """Return a list of process states that are considered active.""" - from plumpy import ProcessState - return ([ - ProcessState.CREATED.value, - ProcessState.WAITING.value, - ProcessState.RUNNING.value, - ]) - - -def graph_traversal_rules(rules): - """Apply the graph traversal rule options to the command.""" - - def decorator(command): - """Only apply to traversal rules if they are toggleable.""" - for name, traversal_rule in sorted(rules.items(), reverse=True): - if traversal_rule.toggleable: - option_name = name.replace('_', '-') - option_label = '--{option_name}/--no-{option_name}'.format(option_name=option_name) - help_string = f'Whether to expand the node set by following {TRAVERSAL_RULE_HELP_STRING[name]}.' 
- click.option(option_label, default=traversal_rule.default, show_default=True, help=help_string)(command) - - return command - - return decorator - - -def set_log_level(_ctx, _param, value): - """Configure the logging for the CLI command being executed. - - Note that we cannot use the most obvious approach of directly setting the level on the various loggers. The reason - is that after this callback is finished, the :meth:`aiida.common.log.configure_logging` method can be called again, - for example when the database backend is loaded, and this will undo this change. So instead, we set to globals in - the :mod:`aiida.common.log` module: ``CLI_ACTIVE`` and ``CLI_LOG_LEVEL``. The ``CLI_ACTIVE`` global is always set to - ``True``. The ``configure_logging`` function will interpret this as the code being executed through a ``verdi`` - call. The ``CLI_LOG_LEVEL`` global is only set if an explicit value is set for the ``--verbosity`` option. In this - case, it is set to the specified log level and ``configure_logging`` will then set this log level for all loggers. - - This approach tightly couples the generic :mod:`aiida.common.log` module to the :mod:`aiida.cmdline` module, which - is not the cleanest, but given that other module code can undo the logging configuration by calling that method, - there seems no easy way around this approach. - """ - from aiida.common import log - - log.CLI_ACTIVE = True - - # If the value is ``None``, it means the option was not specified, but we still configure logging for the CLI - if value is None: - configure_logging() - return None - - try: - log_level = value.upper() - except AttributeError: - raise click.BadParameter(f'`{value}` is not a string.') - - if log_level not in LOG_LEVELS: - raise click.BadParameter(f'`{log_level}` is not a valid log level.') - - log.CLI_LOG_LEVEL = log_level - - # Make sure the logging is configured, even if it may be undone in the future by another call to this method. - configure_logging() - - return log_level - - -VERBOSITY = OverridableOption( - '-v', - '--verbosity', - type=click.Choice(tuple(map(str.lower, LOG_LEVELS.keys())), case_sensitive=False), - callback=set_log_level, - expose_value=False, # Ensures that the option is not actually passed to the command, because it doesn't need it - help='Set the verbosity of the output.' -) - -PROFILE = OverridableOption( - '-p', - '--profile', - 'profile', - type=types.ProfileParamType(), - default=defaults.get_default_profile, - help='Execute the command for this profile instead of the default profile.' -) - -CALCULATION = OverridableOption( - '-C', - '--calculation', - 'calculation', - type=types.CalculationParamType(), - help='A single calculation identified by its ID or UUID.' -) - -CALCULATIONS = OverridableOption( - '-C', - '--calculations', - 'calculations', - type=types.CalculationParamType(), - cls=MultipleValueOption, - help='One or multiple calculations identified by their ID or UUID.' -) - -CODE = OverridableOption( - '-X', '--code', 'code', type=types.CodeParamType(), help='A single code identified by its ID, UUID or label.' -) - -CODES = OverridableOption( - '-X', - '--codes', - 'codes', - type=types.CodeParamType(), - cls=MultipleValueOption, - help='One or multiple codes identified by their ID, UUID or label.' -) - -COMPUTER = OverridableOption( - '-Y', - '--computer', - 'computer', - type=types.ComputerParamType(), - help='A single computer identified by its ID, UUID or label.' 
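``graph_traversal_rules`` above attaches one ``--rule/--no-rule`` toggle to the decorated command for every toggleable entry of the rules dictionary. A self-contained sketch of that dynamic-option technique, using a made-up stand-in for the AiiDA rules container:

```python
# Sketch of dynamically generated toggle options, as done by ``graph_traversal_rules``.
# ``Rule`` and ``RULES`` are illustrative stand-ins, not AiiDA data structures.
from collections import namedtuple

import click

Rule = namedtuple('Rule', 'toggleable default')
RULES = {
    'create_forward': Rule(toggleable=True, default=True),
    'call_calc_backward': Rule(toggleable=True, default=False),
}


def traversal_rule_options(command):
    """Attach one ``--name/--no-name`` flag per toggleable rule to ``command``."""
    for name, rule in sorted(RULES.items(), reverse=True):
        if rule.toggleable:
            flag = name.replace('_', '-')
            command = click.option(f'--{flag}/--no-{flag}', default=rule.default, show_default=True)(command)
    return command


@click.command()
@traversal_rule_options
def export(**kwargs):
    """Print the resolved traversal flags, e.g. {'create_forward': True, ...}."""
    click.echo(kwargs)


if __name__ == '__main__':
    export()
```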
-) - -COMPUTERS = OverridableOption( - '-Y', - '--computers', - 'computers', - type=types.ComputerParamType(), - cls=MultipleValueOption, - help='One or multiple computers identified by their ID, UUID or label.' -) - -DATUM = OverridableOption( - '-D', '--datum', 'datum', type=types.DataParamType(), help='A single datum identified by its ID, UUID or label.' -) - -DATA = OverridableOption( - '-D', - '--data', - 'data', - type=types.DataParamType(), - cls=MultipleValueOption, - help='One or multiple data identified by their ID, UUID or label.' -) - -GROUP = OverridableOption( - '-G', '--group', 'group', type=types.GroupParamType(), help='A single group identified by its ID, UUID or label.' -) - -GROUPS = OverridableOption( - '-G', - '--groups', - 'groups', - type=types.GroupParamType(), - cls=MultipleValueOption, - help='One or multiple groups identified by their ID, UUID or label.' -) - -NODE = OverridableOption( - '-N', '--node', 'node', type=types.NodeParamType(), help='A single node identified by its ID or UUID.' -) - -NODES = OverridableOption( - '-N', - '--nodes', - 'nodes', - type=types.NodeParamType(), - cls=MultipleValueOption, - help='One or multiple nodes identified by their ID or UUID.' -) - -FORCE = OverridableOption('-f', '--force', is_flag=True, default=False, help='Do not ask for confirmation.') - -SILENT = OverridableOption('-s', '--silent', is_flag=True, default=False, help='Suppress any output printed to stdout.') - -VISUALIZATION_FORMAT = OverridableOption( - '-F', '--format', 'fmt', show_default=True, help='Format of the visualized output.' -) - -INPUT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the input file.') - -EXPORT_FORMAT = OverridableOption('-F', '--format', 'fmt', show_default=True, help='Format of the exported file.') - -ARCHIVE_FORMAT = OverridableOption( - '-F', - '--archive-format', - type=click.Choice(['zip', 'zip-uncompressed', 'tar.gz']), - default='zip', - show_default=True, - help='The format of the archive file.' -) - -NON_INTERACTIVE = OverridableOption( - '-n', - '--non-interactive', - is_flag=True, - is_eager=True, - help='In non-interactive mode, the CLI never prompts but simply uses default values for options that define one.' -) - -DRY_RUN = OverridableOption('-n', '--dry-run', is_flag=True, help='Perform a dry run.') - -USER_EMAIL = OverridableOption( - '--email', - 'email', - type=types.EmailType(), - help='Email address associated with the data you generate. The email address is exported along with the data, ' - 'when sharing it.' -) - -USER_FIRST_NAME = OverridableOption( - '--first-name', type=types.NonEmptyStringParamType(), help='First name of the user.' -) - -USER_LAST_NAME = OverridableOption('--last-name', type=types.NonEmptyStringParamType(), help='Last name of the user.') - -USER_INSTITUTION = OverridableOption( - '--institution', type=types.NonEmptyStringParamType(), help='Institution of the user.' -) - -DB_ENGINE = OverridableOption( - '--db-engine', - help='Engine to use to connect to the database.', - default='postgresql_psycopg2', - type=click.Choice(['postgresql_psycopg2']) -) - -DB_BACKEND = OverridableOption( - '--db-backend', type=click.Choice(['core.psql_dos']), default='core.psql_dos', help='Database backend to use.' -) - -DB_HOST = OverridableOption( - '--db-host', - type=types.HostnameType(), - help='Database server host. 
Leave empty for "peer" authentication.', - default='localhost' -) - -DB_PORT = OverridableOption( - '--db-port', - type=click.INT, - help='Database server port.', - default=DEFAULT_DBINFO['port'], -) - -DB_USERNAME = OverridableOption( - '--db-username', type=types.NonEmptyStringParamType(), help='Name of the database user.' -) - -DB_PASSWORD = OverridableOption( - '--db-password', - type=click.STRING, - help='Password of the database user.', - hide_input=True, -) - -DB_NAME = OverridableOption('--db-name', type=types.NonEmptyStringParamType(), help='Database name.') - -BROKER_PROTOCOL = OverridableOption( - '--broker-protocol', - type=click.Choice(('amqp', 'amqps')), - default=BROKER_DEFAULTS.protocol, - show_default=True, - help='Protocol to use for the message broker.' -) - -BROKER_USERNAME = OverridableOption( - '--broker-username', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.username, - show_default=True, - help='Username to use for authentication with the message broker.' -) - -BROKER_PASSWORD = OverridableOption( - '--broker-password', - type=types.NonEmptyStringParamType(), - default=BROKER_DEFAULTS.password, - show_default=True, - help='Password to use for authentication with the message broker.', - hide_input=True, -) - -BROKER_HOST = OverridableOption( - '--broker-host', - type=types.HostnameType(), - default=BROKER_DEFAULTS.host, - show_default=True, - help='Hostname for the message broker.' -) - -BROKER_PORT = OverridableOption( - '--broker-port', - type=click.INT, - default=BROKER_DEFAULTS.port, - show_default=True, - help='Port for the message broker.', -) - -BROKER_VIRTUAL_HOST = OverridableOption( - '--broker-virtual-host', - type=click.types.StringParamType(), - default=BROKER_DEFAULTS.virtual_host, - show_default=True, - help='Name of the virtual host for the message broker without leading forward slash.' -) - -REPOSITORY_PATH = OverridableOption( - '--repository', type=click.Path(file_okay=False), help='Absolute path to the file repository.' -) - -PROFILE_ONLY_CONFIG = OverridableOption( - '--only-config', is_flag=True, default=False, help='Only configure the user and skip creating the database.' -) - -PROFILE_SET_DEFAULT = OverridableOption( - '--set-default', is_flag=True, default=False, help='Set the profile as the new default.' -) - -PREPEND_TEXT = OverridableOption( - '--prepend-text', type=click.STRING, default='', help='Bash script to be executed before an action.' -) - -APPEND_TEXT = OverridableOption( - '--append-text', type=click.STRING, default='', help='Bash script to be executed after an action has completed.' -) - -LABEL = OverridableOption('-L', '--label', type=click.STRING, metavar='LABEL', help='Short name to be used as a label.') - -DESCRIPTION = OverridableOption( - '-D', - '--description', - type=click.STRING, - metavar='DESCRIPTION', - default='', - required=False, - help='A detailed description.' -) - -INPUT_PLUGIN = OverridableOption( - '-P', - '--input-plugin', - type=types.PluginParamType(group='calculations', load=False), - help='Calculation input plugin string.' -) - -CALC_JOB_STATE = OverridableOption( - '-s', - '--calc-job-state', - 'calc_job_state', - type=types.LazyChoice(valid_calc_job_states), - cls=MultipleValueOption, - help='Only include entries with this calculation job state.' 
-) - -PROCESS_STATE = OverridableOption( - '-S', - '--process-state', - 'process_state', - type=types.LazyChoice(valid_process_states), - cls=MultipleValueOption, - default=active_process_states, - help='Only include entries with this process state.' -) - -PAUSED = OverridableOption('--paused', 'paused', is_flag=True, help='Only include entries that are paused.') - -PROCESS_LABEL = OverridableOption( - '-L', - '--process-label', - 'process_label', - type=click.STRING, - required=False, - help='Only include entries whose process label matches this filter.' -) - -TYPE_STRING = OverridableOption( - '-T', - '--type-string', - 'type_string', - type=click.STRING, - required=False, - help='Only include entries whose type string matches this filter. Can include `_` to match a single arbitrary ' - 'character or `%` to match any number of characters.' -) - -EXIT_STATUS = OverridableOption( - '-E', '--exit-status', 'exit_status', type=click.INT, help='Only include entries with this exit status.' -) - -FAILED = OverridableOption( - '-X', '--failed', 'failed', is_flag=True, default=False, help='Only include entries that have failed.' -) - -LIMIT = OverridableOption( - '-l', '--limit', 'limit', type=click.INT, default=None, help='Limit the number of entries to display.' -) - -PROJECT = OverridableOption( - '-P', '--project', 'project', cls=MultipleValueOption, help='Select the list of entity attributes to project.' -) - -ORDER_BY = OverridableOption( - '-O', - '--order-by', - 'order_by', - type=click.Choice(['id', 'ctime']), - default='ctime', - show_default=True, - help='Order the entries by this attribute.' -) - -ORDER_DIRECTION = OverridableOption( - '-D', - '--order-direction', - 'order_dir', - type=click.Choice(['asc', 'desc']), - default='asc', - show_default=True, - help='List the entries in ascending or descending order' -) - -PAST_DAYS = OverridableOption( - '-p', - '--past-days', - 'past_days', - type=click.INT, - metavar='PAST_DAYS', - help='Only include entries created in the last PAST_DAYS number of days.' -) - -OLDER_THAN = OverridableOption( - '-o', - '--older-than', - 'older_than', - type=click.INT, - metavar='OLDER_THAN', - help='Only include entries created before OLDER_THAN days ago.' -) - -ALL = OverridableOption( - '-a', - '--all', - 'all_entries', - is_flag=True, - default=False, - help='Include all entries, disregarding all other filter options and flags.' -) - -ALL_STATES = OverridableOption('-A', '--all-states', is_flag=True, help='Do not limit to items in running state.') - -ALL_USERS = OverridableOption( - '-A', '--all-users', 'all_users', is_flag=True, default=False, help='Include all entries regardless of the owner.' -) - -GROUP_CLEAR = OverridableOption( - '-c', '--clear', is_flag=True, default=False, help='Remove all the nodes from the group.' -) - -RAW = OverridableOption( - '-r', - '--raw', - 'raw', - is_flag=True, - default=False, - help='Display only raw query results, without any headers or footers.' -) - -HOSTNAME = OverridableOption('-H', '--hostname', type=types.HostnameType(), help='Hostname.') - -TRANSPORT = OverridableOption( - '-T', - '--transport', - type=types.PluginParamType(group='transports'), - required=True, - help='A transport plugin (as listed in `verdi plugin list aiida.transports`).' -) - -SCHEDULER = OverridableOption( - '-S', - '--scheduler', - type=types.PluginParamType(group='schedulers'), - required=True, - help='A scheduler plugin (as listed in `verdi plugin list aiida.schedulers`).' 
-) - -USER = OverridableOption('-u', '--user', 'user', type=types.UserParamType(), help='Email address of the user.') - -PORT = OverridableOption('-P', '--port', 'port', type=click.INT, help='Port number.') - -FREQUENCY = OverridableOption('-F', '--frequency', 'frequency', type=click.INT) - -TIMEOUT = OverridableOption( - '-t', - '--timeout', - type=click.FLOAT, - default=5.0, - show_default=True, - help='Time in seconds to wait for a response before timing out.' -) - -WAIT = OverridableOption( - '--wait/--no-wait', - default=False, - help='Wait for the action to be completed otherwise return as soon as it is scheduled.' -) - -FORMULA_MODE = OverridableOption( - '-f', - '--formula-mode', - type=click.Choice(['hill', 'hill_compact', 'reduce', 'group', 'count', 'count_compact']), - default='hill', - help='Mode for printing the chemical formula.' -) - -TRAJECTORY_INDEX = OverridableOption( - '-i', - '--trajectory-index', - 'trajectory_index', - type=click.INT, - default=None, - help='Specific step of the Trajectory to select.' -) - -WITH_ELEMENTS = OverridableOption( - '-e', - '--with-elements', - 'elements', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing these elements.' -) - -WITH_ELEMENTS_EXCLUSIVE = OverridableOption( - '-E', - '--with-elements-exclusive', - 'elements_exclusive', - type=click.STRING, - cls=MultipleValueOption, - default=None, - help='Only select objects containing only these and no other elements.' -) - -CONFIG_FILE = ConfigFileOption( - '--config', - type=types.FileOrUrl(), - help='Load option values from configuration file in yaml format (local path or URL).' -) - -IDENTIFIER = OverridableOption( - '-i', - '--identifier', - 'identifier', - help='The type of identifier used for specifying each node.', - default='pk', - type=click.Choice(['pk', 'uuid']) -) - -DICT_FORMAT = OverridableOption( - '-f', - '--format', - 'fmt', - type=click.Choice(list(echo.VALID_DICT_FORMATS_MAPPING.keys())), - default=list(echo.VALID_DICT_FORMATS_MAPPING.keys())[0], - help='The format of the output data.' -) - -DICT_KEYS = OverridableOption( - '-k', '--keys', type=click.STRING, cls=MultipleValueOption, help='Filter the output by one or more keys.' -) - -DEBUG = OverridableOption( - '--debug', is_flag=True, default=False, help='Show debug messages. Mostly relevant for developers.', hidden=True -) - -PRINT_TRACEBACK = OverridableOption( - '-t', - '--print-traceback', - is_flag=True, - help='Print the full traceback in case an exception is raised.', -) diff --git a/aiida/cmdline/params/options/overridable.py b/aiida/cmdline/params/options/overridable.py deleted file mode 100644 index fae2ca0aff..0000000000 --- a/aiida/cmdline/params/options/overridable.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=cyclic-import -""" -.. 
py:module::overridable - :synopsis: Convenience class which can be used to defined a set of commonly used options that - can be easily reused and which improves consistency across the command line interface -""" - -import click - -__all__ = ('OverridableOption',) - - -class OverridableOption: - """ - Wrapper around click option that increases reusability - - Click options are reusable already but sometimes it can improve the user interface to for example customize a - help message for an option on a per-command basis. Sometimes the option should be prompted for if it is not given - On some commands an option might take any folder path, while on another the path only has to exist. - - Overridable options store the arguments to click.option and only instantiate the click.Option on call, - kwargs given to ``__call__`` override the stored ones. - - Example:: - - FOLDER = OverridableOption('--folder', type=click.Path(file_okay=False), help='A folder') - - @click.command() - @FOLDER(help='A folder, will be created if it does not exist') - def ls_or_create(folder): - click.echo(os.listdir(folder)) - - @click.command() - @FOLDER(help='An existing folder', type=click.Path(exists=True, file_okay=False, readable=True) - def ls(folder) - click.echo(os.listdir(folder)) - """ - - def __init__(self, *args, **kwargs): - """ - Store the default args and kwargs. - - :param args: default arguments to be used for the click option - :param kwargs: default keyword arguments to be used that can be overridden in the call - """ - self.args = args - self.kwargs = kwargs - - def __call__(self, **kwargs): - """ - Override the stored kwargs, (ignoring args as we do not allow option name changes) and return the option. - - :param kwargs: keyword arguments that will override those set in the construction - :return: click option constructed with args and kwargs defined during construction and call of this instance - """ - kw_copy = self.kwargs.copy() - kw_copy.update(kwargs) - return click.option(*self.args, **kw_copy) - - def clone(self, **kwargs): - """ - Create a new instance of the OverridableOption by cloning it and updating the stored kwargs with those passed. - - This can be useful when an already predefined OverridableOption needs to be further specified and reused - by a set of sub commands. Example:: - - LABEL = OverridableOption('-l', '--label', required=False, help='The label of the node' - LABEL_COMPUTER = LABEL.clone(required=True, help='The label of the computer') - - If multiple computer related sub commands need the LABEL option, but the default help string and required - attribute need to be different, the `clone` method allows to override these and create a new OverridableOption - instance that can then be used as a decorator. - - :param kwargs: keyword arguments to update - :return: OverridableOption instance with stored keyword arguments updated - """ - import copy - clone = copy.deepcopy(self) - clone.kwargs.update(kwargs) - return clone diff --git a/aiida/cmdline/params/types/__init__.py b/aiida/cmdline/params/types/__init__.py deleted file mode 100644 index 4607b6dcbe..0000000000 --- a/aiida/cmdline/params/types/__init__.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
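The class above defers the actual ``click.option`` construction until the decorator is applied, so the stored keyword arguments can be overridden per command (and ``clone`` yields a new reusable baseline). A runnable, self-contained sketch of the same pattern; ``ReusableOption`` and the ``ls`` command are illustrative names, not part of the codebase:

```python
# Minimal stand-in for the OverridableOption pattern: store the click.option
# arguments once and let every command override individual keyword arguments.
import os

import click


class ReusableOption:
    """Build the real ``click.option`` only when the decorator is applied."""

    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

    def __call__(self, **overrides):
        kwargs = {**self.kwargs, **overrides}
        return click.option(*self.args, **kwargs)


FOLDER = ReusableOption('--folder', type=click.Path(file_okay=False), help='A folder.')


@click.command()
@FOLDER(help='An existing folder to list.', type=click.Path(exists=True, file_okay=False), default='.')
def ls(folder):
    """List the contents of FOLDER, overriding the generic help text and type."""
    click.echo('\n'.join(os.listdir(folder)))


if __name__ == '__main__':
    ls()
```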
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Provides all parameter types.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .calculation import * -from .choice import * -from .code import * -from .computer import * -from .config import * -from .data import * -from .group import * -from .identifier import * -from .multiple import * -from .node import * -from .path import * -from .plugin import * -from .process import * -from .profile import * -from .strings import * -from .user import * -from .workflow import * - -__all__ = ( - 'AbsolutePathParamType', - 'CalculationParamType', - 'CodeParamType', - 'ComputerParamType', - 'ConfigOptionParamType', - 'DataParamType', - 'EmailType', - 'EntryPointType', - 'FileOrUrl', - 'GroupParamType', - 'HostnameType', - 'IdentifierParamType', - 'LabelStringType', - 'LazyChoice', - 'MpirunCommandParamType', - 'MultipleValueParamType', - 'NodeParamType', - 'NonEmptyStringParamType', - 'PathOrUrl', - 'PluginParamType', - 'ProcessParamType', - 'ProfileParamType', - 'ShebangParamType', - 'UserParamType', - 'WorkflowParamType', -) - -# yapf: enable diff --git a/aiida/cmdline/params/types/calculation.py b/aiida/cmdline/params/types/calculation.py deleted file mode 100644 index 2e4c0d0750..0000000000 --- a/aiida/cmdline/params/types/calculation.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module for the calculation parameter type -""" - -from .identifier import IdentifierParamType - -__all__ = ('CalculationParamType',) - - -class CalculationParamType(IdentifierParamType): - """ - The ParamType for identifying Calculation entities or its subclasses - """ - - name = 'Calculation' - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import CalculationEntityLoader - return CalculationEntityLoader diff --git a/aiida/cmdline/params/types/code.py b/aiida/cmdline/params/types/code.py deleted file mode 100644 index 32367323ff..0000000000 --- a/aiida/cmdline/params/types/code.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module to define the custom click type for code.""" -import click - -from aiida.cmdline.utils import decorators - -from .identifier import IdentifierParamType - -__all__ = ('CodeParamType',) - - -class CodeParamType(IdentifierParamType): - """ - The ParamType for identifying Code entities or its subclasses - """ - - name = 'Code' - - def __init__(self, sub_classes=None, entry_point=None): - """Construct the param type - - :param sub_classes: specify a tuple of Code sub classes to narrow the query set - :param entry_point: specify an optional calculation entry point that the Code's input plugin should match - """ - super().__init__(sub_classes) - self._entry_point = entry_point - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import CodeEntityLoader - return CodeEntityLoader - - @decorators.with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """Return possible completions based on an incomplete value. - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - return [ - click.shell_completion.CompletionItem(option) - for option, in self.orm_class_loader.get_options(incomplete, project='label') - ] - - def convert(self, value, param, ctx): - code = super().convert(value, param, ctx) - - if code and self._entry_point is not None: - entry_point = code.default_calc_job_plugin - if entry_point != self._entry_point: - raise click.BadParameter( - 'the retrieved Code<{}> has plugin type "{}" while "{}" is required'.format( - code.pk, entry_point, self._entry_point - ) - ) - - return code diff --git a/aiida/cmdline/params/types/computer.py b/aiida/cmdline/params/types/computer.py deleted file mode 100644 index 97dcfdc2a0..0000000000 --- a/aiida/cmdline/params/types/computer.py +++ /dev/null @@ -1,110 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module for the custom click param type computer -""" -from click.shell_completion import CompletionItem -from click.types import StringParamType - -from ...utils import decorators # pylint: disable=no-name-in-module -from .identifier import IdentifierParamType - -__all__ = ('ComputerParamType', 'ShebangParamType', 'MpirunCommandParamType') - - -class ComputerParamType(IdentifierParamType): - """ - The ParamType for identifying Computer entities or its subclasses - """ - - name = 'Computer' - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. 
This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import ComputerEntityLoader - return ComputerEntityLoader - - @decorators.with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """Return possible completions based on an incomplete value. - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - return [CompletionItem(option) for option, in self.orm_class_loader.get_options(incomplete, project='label')] - - -class ShebangParamType(StringParamType): - """ - Custom click param type for shebang line - """ - name = 'shebangline' - - def convert(self, value, param, ctx): - newval = super().convert(value, param, ctx) - if newval is None: - return None - if not newval.startswith('#!'): - self.fail(f'The shebang line should start with the two caracters #!, it is instead: {newval}') - return newval - - def __repr__(self): - return 'SHEBANGLINE' - - -class MpirunCommandParamType(StringParamType): - """ - Custom click param type for mpirun-command - - .. note:: requires also a scheduler to be provided, and the scheduler - must be called first! - - Validate that the provided 'mpirun' command only contains replacement fields - (e.g. ``{tot_num_mpiprocs}``) that are known. - - Return a list of arguments (by using 'value.strip().split(" ") on the input string) - """ - name = 'mpiruncommandstring' - - def __repr__(self): - return 'MPIRUNCOMMANDSTRING' - - def convert(self, value, param, ctx): - newval = super().convert(value, param, ctx) - - scheduler_ep = ctx.params.get('scheduler', None) - if scheduler_ep is not None: - try: - job_resource_keys = scheduler_ep.load().job_resource_class.get_valid_keys() - except ImportError: - self.fail(f"Unable to load the '{scheduler_ep.name}' scheduler") - else: - self.fail( - 'Scheduler not specified for this computer! The mpirun-command must always be asked ' - 'after asking for the scheduler.' - ) - - # Prepare some substitution values to check if it is all ok - subst = {i: 'value' for i in job_resource_keys} - subst['tot_num_mpiprocs'] = 'value' - - try: - newval.format(**subst) - except KeyError as exc: - self.fail(f"In workdir there is an unknown replacement field '{exc.args[0]}'") - except ValueError as exc: - self.fail(f"Error in the string: '{exc}'") - - return newval diff --git a/aiida/cmdline/params/types/config.py b/aiida/cmdline/params/types/config.py deleted file mode 100644 index 195516554c..0000000000 --- a/aiida/cmdline/params/types/config.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
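``MpirunCommandParamType`` validates the template by substituting a dummy value for every resource key known to the selected scheduler and letting ``str.format`` surface any unrecognised placeholder. The same check in isolation, with an invented set of valid keys:

```python
# Standalone version of the placeholder check performed in ``convert`` above.
# ``valid_keys`` is an invented example; in AiiDA it comes from the scheduler's
# job resource class.
def validate_mpirun_template(command, valid_keys):
    """Raise ``ValueError`` if ``command`` uses a replacement field not in ``valid_keys``."""
    subst = {key: 'value' for key in valid_keys}
    subst['tot_num_mpiprocs'] = 'value'
    try:
        command.format(**subst)
    except KeyError as exc:
        raise ValueError(f"unknown replacement field '{exc.args[0]}' in: {command}")
    except ValueError as exc:
        raise ValueError(f'malformed template string: {exc}')
    return command


print(validate_mpirun_template('mpirun -np {tot_num_mpiprocs}', {'num_machines'}))
# validate_mpirun_template('srun -n {num_nodes}', {'num_machines'})  # raises ValueError
```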
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module to define the custom click type for code.""" -import click - -__all__ = ('ConfigOptionParamType',) - - -class ConfigOptionParamType(click.types.StringParamType): - """ParamType for configuration options.""" - - name = 'config option' - - def convert(self, value, param, ctx): - from aiida.manage.configuration.options import get_option, get_option_names - - if value not in get_option_names(): - raise click.BadParameter(f'{value} is not a valid configuration option') - - return get_option(value) - - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """ - Return possible completions based on an incomplete value - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - from aiida.manage.configuration.options import get_option_names - - return [ - click.shell_completion.CompletionItem(option_name) - for option_name in get_option_names() - if option_name.startswith(incomplete) - ] diff --git a/aiida/cmdline/params/types/data.py b/aiida/cmdline/params/types/data.py deleted file mode 100644 index 742dec10eb..0000000000 --- a/aiida/cmdline/params/types/data.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module for the custom click param type for data -""" -from .identifier import IdentifierParamType - -__all__ = ('DataParamType',) - - -class DataParamType(IdentifierParamType): - """ - The ParamType for identifying Data entities or its subclasses - """ - - name = 'Data' - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import DataEntityLoader - return DataEntityLoader diff --git a/aiida/cmdline/params/types/group.py b/aiida/cmdline/params/types/group.py deleted file mode 100644 index fe55c7694c..0000000000 --- a/aiida/cmdline/params/types/group.py +++ /dev/null @@ -1,84 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for custom click param type group.""" -import click - -from aiida.cmdline.utils import decorators -from aiida.common.lang import type_check - -from .identifier import IdentifierParamType - -__all__ = ('GroupParamType',) - - -class GroupParamType(IdentifierParamType): - """The ParamType for identifying Group entities or its subclasses.""" - - name = 'Group' - - def __init__(self, create_if_not_exist=False, sub_classes=('aiida.groups:core',)): - """Construct the parameter type. - - The `sub_classes` argument can be used to narrow the set of subclasses of `Group` that should be matched. By - default all subclasses of `Group` will be matched, otherwise it is restricted to the subclasses that correspond - to the entry point names in the tuple of `sub_classes`. - - To prevent having to load the database environment at import time, the actual loading of the entry points is - deferred until the call to `convert` is made. This is to keep the command line autocompletion light and - responsive. The entry point strings will be validated, however, to see if they correspond to known entry points. - - :param create_if_not_exist: boolean, if True, will create the group if it does not yet exist. By default the - group created will be of class `Group`, unless another subclass is specified through `sub_classes`. Note - that in this case, only a single entry point name can be specified - :param sub_classes: a tuple of entry point strings from the `aiida.groups` entry point group. - """ - type_check(sub_classes, tuple, allow_none=True) - - if create_if_not_exist and len(sub_classes) > 1: - raise ValueError('`sub_classes` can at most contain one entry point if `create_if_not_exist=True`') - - self._create_if_not_exist = create_if_not_exist - super().__init__(sub_classes=sub_classes) - - @property - def orm_class_loader(self): - """Return the orm entity loader class, which should be a subclass of `OrmEntityLoader`. - - This class is supposed to be used to load the entity for a given identifier. - - :return: the orm entity loader class for this `ParamType` - """ - from aiida.orm.utils.loaders import GroupEntityLoader - return GroupEntityLoader - - @decorators.with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """Return possible completions based on an incomplete value. - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - return [ - click.shell_completion.CompletionItem(option) - for option, in self.orm_class_loader.get_options(incomplete, project='label') - ] - - @decorators.with_dbenv() - def convert(self, value, param, ctx): - try: - group = super().convert(value, param, ctx) - except click.BadParameter: - if self._create_if_not_exist: - # The particular subclass to load will be stored in `_sub_classes` as loaded by `convert` of the super. 
- cls = self._sub_classes[0] - group = cls(label=value).store() - else: - raise - - return group diff --git a/aiida/cmdline/params/types/identifier.py b/aiida/cmdline/params/types/identifier.py deleted file mode 100644 index fc15539fdd..0000000000 --- a/aiida/cmdline/params/types/identifier.py +++ /dev/null @@ -1,131 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module for custom click param type identifier -""" -from abc import ABC, abstractmethod - -import click - -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.plugins.entry_point import get_entry_point_from_string - -__all__ = ('IdentifierParamType',) - - -class IdentifierParamType(click.ParamType, ABC): - """ - An extension of click.ParamType for a generic identifier parameter. In AiiDA, orm entities can often be - identified by either their ID, UUID or optionally some LABEL identifier. This parameter type implements - the convert method, which attempts to convert a value passed to the command for a parameter with this type, - to an orm entity. The actual loading of the entity is delegated to the orm class loader. Subclasses of this - parameter type should implement the `orm_class_loader` method to return the appropriate orm class loader, - which should be a subclass of `aiida.orm.utils.loaders.OrmEntityLoader` for the corresponding orm class. - """ - - def __init__(self, sub_classes=None): - """ - Construct the parameter type, optionally specifying a tuple of entry points that reference classes - that should be a sub class of the base orm class of the orm class loader. The classes pointed to by - these entry points will be passed to the OrmEntityLoader when converting an identifier and they will - restrict the query set by demanding that the class of the corresponding entity matches these sub classes. - - To prevent having to load the database environment at import time, the actual loading of the entry points - is deferred until the call to `convert` is made. This is to keep the command line autocompletion light - and responsive. The entry point strings will be validated, however, to see if the correspond to known - entry points. - - :param sub_classes: a tuple of entry point strings that can narrow the set of orm classes that values - will be mapped upon. 
These classes have to be strict sub classes of the base orm class defined - by the orm class loader - """ - from aiida.common import exceptions - - self._sub_classes = None - self._entry_points = [] - - if sub_classes is not None: - - if not isinstance(sub_classes, tuple): - raise TypeError('sub_classes should be a tuple of entry point strings') - - for entry_point_string in sub_classes: - - try: - entry_point = get_entry_point_from_string(entry_point_string) - except (ValueError, exceptions.EntryPointError) as exception: - raise ValueError(f'{entry_point_string} is not a valid entry point string: {exception}') - else: - self._entry_points.append(entry_point) - - @property - @abstractmethod - @with_dbenv() - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - - @with_dbenv() - def convert(self, value, param, ctx): - """ - Attempt to convert the given value to an instance of the orm class using the orm class loader. - - :return: the loaded orm entity - :raises click.BadParameter: if the value is ambiguous and leads to multiple entities - :raises click.BadParameter: if the value cannot be mapped onto any existing instance - :raises RuntimeError: if the defined orm class loader is not a subclass of the OrmEntityLoader class - """ - from aiida.common import exceptions - from aiida.orm.utils.loaders import OrmEntityLoader - - value = super().convert(value, param, ctx) - - if not value: - raise click.BadParameter('the value for the identifier cannot be empty') - - loader = self.orm_class_loader - - if not issubclass(loader, OrmEntityLoader): - raise RuntimeError('the orm class loader should be a subclass of OrmEntityLoader') - - # If entry points where in the constructor, we load their corresponding classes, validate that they are valid - # sub classes of the orm class loader and then pass it as the sub_class parameter to the load_entity call. - # We store the loaded entry points in an instance variable, such that the loading only has to be done once. - if self._entry_points and self._sub_classes is None: - - sub_classes = [] - - for entry_point in self._entry_points: - try: - sub_class = entry_point.load() - except ImportError as exception: - raise RuntimeError(f'failed to load the entry point {entry_point}: {exception}') - - if not issubclass(sub_class, loader.orm_base_class): - raise RuntimeError( - 'the class {} of entry point {} is not a sub class of {}'.format( - sub_class, entry_point, loader.orm_base_class - ) - ) - else: - sub_classes.append(sub_class) - - self._sub_classes = tuple(sub_classes) - - try: - entity = loader.load_entity(value, sub_classes=self._sub_classes) - except (exceptions.MultipleObjectsError, exceptions.NotExistent, ValueError) as exception: - raise click.BadParameter(str(exception)) - - return entity diff --git a/aiida/cmdline/params/types/node.py b/aiida/cmdline/params/types/node.py deleted file mode 100644 index 7642eb22d5..0000000000 --- a/aiida/cmdline/params/types/node.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
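``IdentifierParamType`` itself only validates and forwards the value; the resolution order (roughly: numeric ID, then UUID, then label) lives in the ORM entity loaders it delegates to. A toy standalone sketch of that dispatch, with dictionaries standing in for database queries:

```python
# Toy sketch of the ID / UUID / label resolution that the entity loaders perform;
# the dictionaries below stand in for database queries and are purely illustrative.
import uuid

BY_PK = {1: 'Computer<localhost>'}
BY_UUID = {'00000000-0000-0000-0000-000000000001': 'Computer<localhost>'}
BY_LABEL = {'localhost': 'Computer<localhost>'}


def load_entity(identifier: str):
    """Resolve a CLI identifier by trying pk, then UUID, then label."""
    if identifier.isdigit():
        return BY_PK[int(identifier)]
    try:
        return BY_UUID[str(uuid.UUID(identifier))]
    except ValueError:
        # Not a valid UUID, fall back to treating the value as a label.
        return BY_LABEL[identifier]


print(load_entity('1'), load_entity('localhost'))
```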
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module to define the custom click param type for node -""" -from .identifier import IdentifierParamType - -__all__ = ('NodeParamType',) - - -class NodeParamType(IdentifierParamType): - """ - The ParamType for identifying Node entities or its subclasses - """ - - name = 'Node' - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import NodeEntityLoader - return NodeEntityLoader diff --git a/aiida/cmdline/params/types/path.py b/aiida/cmdline/params/types/path.py deleted file mode 100644 index de016e42e9..0000000000 --- a/aiida/cmdline/params/types/path.py +++ /dev/null @@ -1,129 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Click parameter types for paths.""" -import os -from socket import timeout -import urllib.error -import urllib.request - -import click - -__all__ = ('AbsolutePathParamType', 'FileOrUrl', 'PathOrUrl') - -URL_TIMEOUT_SECONDS = 10 - - -def check_timeout_seconds(timeout_seconds): - """Raise if timeout is not within range [0;60]""" - try: - timeout_seconds = int(timeout_seconds) - except ValueError: - raise TypeError(f'timeout_seconds should be an integer but got: {type(timeout_seconds)}') - - if timeout_seconds < 0 or timeout_seconds > 60: - raise ValueError('timeout_seconds needs to be in the range [0;60].') - - return timeout_seconds - - -class AbsolutePathParamType(click.Path): - """The ParamType for identifying absolute Paths (derived from click.Path).""" - - name = 'AbsolutePath' - - def convert(self, value, param, ctx): - value = os.path.expanduser(value) - newval = super().convert(value, param, ctx) - if not os.path.isabs(newval): - raise click.BadParameter('path must be absolute') - return newval - - def __repr__(self): - return 'ABSOLUTEPATH' - - -class AbsolutePathOrEmptyParamType(AbsolutePathParamType): - """The ParamType for identifying absolute Paths, accepting also empty paths.""" - - name = 'AbsolutePathEmpty' - - def convert(self, value, param, ctx): - if not value: - return value - return super().convert(value, param, ctx) - - def __repr__(self): - return 'ABSOLUTEPATHEMPTY' - - -class PathOrUrl(click.Path): - """Extension of click's Path-type to include URLs. - - A PathOrUrl can either be a `click.Path`-type or a URL. - - :param int timeout_seconds: Maximum timeout accepted for URL response. - Must be an integer in the range [0;60]. 
- """ - - name = 'PathOrUrl' - - def __init__(self, timeout_seconds=URL_TIMEOUT_SECONDS, **kwargs): - super().__init__(**kwargs) - - self.timeout_seconds = check_timeout_seconds(timeout_seconds) - - def convert(self, value, param, ctx): - """Overwrite `convert` Check first if `click.Path`-type, then check if URL.""" - try: - return super().convert(value, param, ctx) - except click.exceptions.BadParameter: - return self.checks_url(value, param, ctx) - - def checks_url(self, url, param, ctx): - """Check whether URL is reachable within timeout.""" - try: - with urllib.request.urlopen(url, timeout=self.timeout_seconds): - pass - except (urllib.error.URLError, urllib.error.HTTPError, timeout): - self.fail(f'{self.name} "{url}" could not be reached within {self.timeout_seconds} s.\n', param, ctx) - - return url - - -class FileOrUrl(click.File): - """Extension of click's File-type to include URLs. - - Returns handle either to local file or to remote file fetched from URL. - - :param int timeout_seconds: Maximum timeout accepted for URL response. - Must be an integer in the range [0;60]. - """ - - name = 'FileOrUrl' - - def __init__(self, timeout_seconds=URL_TIMEOUT_SECONDS, **kwargs): - super().__init__(**kwargs) - - self.timeout_seconds = check_timeout_seconds(timeout_seconds) - - def convert(self, value, param, ctx): - """Return file handle.""" - try: - return super().convert(value, param, ctx) - except click.exceptions.BadParameter: - handle = self.get_url(value, param, ctx) - return handle - - def get_url(self, url, param, ctx): - """Retrieve file from URL.""" - try: - return urllib.request.urlopen(url, timeout=self.timeout_seconds) # pylint: disable=consider-using-with - except (urllib.error.URLError, urllib.error.HTTPError, timeout): - self.fail(f'{self.name} "{url}" could not be reached within {self.timeout_seconds} s.\n', param, ctx) diff --git a/aiida/cmdline/params/types/plugin.py b/aiida/cmdline/params/types/plugin.py deleted file mode 100644 index d0f684fc80..0000000000 --- a/aiida/cmdline/params/types/plugin.py +++ /dev/null @@ -1,262 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Click parameter type for AiiDA Plugins.""" -import functools - -import click -from importlib_metadata import EntryPoint - -from aiida.common import exceptions -from aiida.plugins import factories -from aiida.plugins.entry_point import ( - ENTRY_POINT_GROUP_PREFIX, - ENTRY_POINT_STRING_SEPARATOR, - EntryPointFormat, - format_entry_point_string, - get_entry_point, - get_entry_point_groups, - get_entry_point_string_format, - get_entry_points, -) - -from .strings import EntryPointType - -__all__ = ('PluginParamType',) - - -class PluginParamType(EntryPointType): - """ - AiiDA Plugin name parameter type. - - :param group: string or tuple of strings, where each is a valid entry point group. Adding the `aiida.` - prefix is optional. If it is not detected it will be prepended internally. - :param load: when set to True, convert will not return the entry point, but the loaded entry point - - Usage:: - - click.option(... 
type=PluginParamType(group='aiida.calculations') - - or:: - - click.option(... type=PluginParamType(group=('calculations', 'data')) - - """ - name = 'plugin' - - _factory_mapping = { - 'aiida.calculations': factories.CalculationFactory, - 'aiida.data': factories.DataFactory, - 'aiida.groups': factories.GroupFactory, - 'aiida.parsers': factories.ParserFactory, - 'aiida.schedulers': factories.SchedulerFactory, - 'aiida.transports': factories.TransportFactory, - 'aiida.tools.dbimporters': factories.DbImporterFactory, - 'aiida.tools.data.orbitals': factories.OrbitalFactory, - 'aiida.workflows': factories.WorkflowFactory, - } - - def __init__(self, group=None, load=False, *args, **kwargs): - """ - Validate that group is either a string or a tuple of valid entry point groups, or if it - is not specified use the tuple of all recognized entry point groups. - """ - # pylint: disable=keyword-arg-before-vararg - valid_entry_point_groups = get_entry_point_groups() - - if group is None: - self._groups = tuple(valid_entry_point_groups) - else: - if isinstance(group, str): - invalidated_groups = tuple([group]) - elif isinstance(group, tuple): - invalidated_groups = group - else: - raise ValueError('invalid type for group') - - groups = [] - - for grp in invalidated_groups: - - if not grp.startswith(ENTRY_POINT_GROUP_PREFIX): - grp = ENTRY_POINT_GROUP_PREFIX + grp - - if grp not in valid_entry_point_groups: - raise ValueError(f'entry point group {grp} is not recognized') - - groups.append(grp) - - self._groups = tuple(groups) - - self._init_entry_points() - self.load = load - - super().__init__(*args, **kwargs) - - def _init_entry_points(self): - """ - Populate entry point information that will be used later on. This should only be called - once in the constructor after setting self.groups because the groups should not be changed - after instantiation - """ - self._entry_points = [(group, entry_point) for group in self.groups for entry_point in get_entry_points(group)] - self._entry_point_names = [entry_point.name for group in self.groups for entry_point in get_entry_points(group)] - - @property - def groups(self): - return self._groups - - @property - def has_potential_ambiguity(self): - """ - Returns whether the set of supported entry point groups can lead to ambiguity when only an entry point name - is specified. This will happen if one ore more groups share an entry point with a common name - """ - return len(self._entry_point_names) != len(set(self._entry_point_names)) - - def get_valid_arguments(self): - """ - Return a list of all available plugins for the groups configured for this PluginParamType instance. - If the entry point names are not unique, because there are multiple groups that contain an entry - point that has an identical name, we need to prefix the names with the full group name - - :returns: list of valid entry point strings - """ - if self.has_potential_ambiguity: - fmt = EntryPointFormat.FULL - return sorted([format_entry_point_string(group, ep.name, fmt=fmt) for group, ep in self._entry_points]) - - return sorted(self._entry_point_names) - - def get_possibilities(self, incomplete=''): - """ - Return a list of plugins starting with incomplete - """ - if incomplete == '': - return self.get_valid_arguments() - - # If there is a chance of ambiguity we always return the entry point string in FULL format, otherwise - # return the possibilities in the same format as the incomplete. Note that this may have some unexpected - # effects. 
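A short sketch of the usage pattern from the docstring above (the option name is hypothetical); with `load=True` the converted value would be the loaded plugin class instead of the entry point.

```python
import click

from aiida.cmdline.params.types import PluginParamType


@click.command()
@click.option('--parser', type=PluginParamType(group='aiida.parsers'))
def inspect(parser):
    """Echo the entry point selected for ``--parser``."""
    click.echo(str(parser))
```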
For example if incomplete equals `aiida.` or `calculations` it will be detected as the MINIMAL - # format, even though they would also be the valid beginnings of a FULL or PARTIAL format, except that we - # cannot know that for sure at this time - if self.has_potential_ambiguity: - possibilites = [eps for eps in self.get_valid_arguments() if eps.startswith(incomplete)] - else: - possibilites = [] - fmt = get_entry_point_string_format(incomplete) - - for group, entry_point in self._entry_points: - entry_point_string = format_entry_point_string(group, entry_point.name, fmt=fmt) - if entry_point_string.startswith(incomplete): - possibilites.append(entry_point_string) - - return possibilites - - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """ - Return possible completions based on an incomplete value - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - return [click.shell_completion.CompletionItem(p) for p in self.get_possibilities(incomplete=incomplete)] - - def get_missing_message(self, param): # pylint: disable=unused-argument - return 'Possible arguments are:\n\n' + '\n'.join(self.get_valid_arguments()) - - def get_entry_point_from_string(self, entry_point_string): - """ - Validate a given entry point string, which means that it should have a valid entry point string format - and that the entry point unambiguously corresponds to an entry point in the groups configured for this - instance of PluginParameterType. - - :returns: the entry point if valid - :raises: ValueError if the entry point string is invalid - """ - group = None - name = None - - entry_point_format = get_entry_point_string_format(entry_point_string) - - if entry_point_format in (EntryPointFormat.FULL, EntryPointFormat.PARTIAL): - - group, name = entry_point_string.split(ENTRY_POINT_STRING_SEPARATOR) - - if entry_point_format == EntryPointFormat.PARTIAL: - group = ENTRY_POINT_GROUP_PREFIX + group - - self.validate_entry_point_group(group) - - elif entry_point_format == EntryPointFormat.MINIMAL: - - name = entry_point_string - matching_groups = {group for group, entry_point in self._entry_points if entry_point.name == name} - - if len(matching_groups) > 1: - raise ValueError( - "entry point '{}' matches more than one valid entry point group [{}], " - 'please specify an explicit group prefix: {}'.format( - name, ' '.join(matching_groups), self._entry_points - ) - ) - elif not matching_groups: - raise ValueError( - "entry point '{}' is not valid for any of the allowed " - 'entry point groups: {}'.format(name, ' '.join(self.groups)) - ) - - group = matching_groups.pop() - - else: - raise ValueError(f'invalid entry point string format: {entry_point_string}') - - # If there is a factory for the entry point group, use that, otherwise use ``get_entry_point`` - try: - get_entry_point_partial = functools.partial(self._factory_mapping[group], load=False) - except KeyError: - get_entry_point_partial = functools.partial(get_entry_point, group) - - try: - return get_entry_point_partial(name) - except exceptions.EntryPointError as exception: - raise ValueError(exception) - - def validate_entry_point_group(self, group): - if group not in self.groups: - raise ValueError(f'entry point group `{group}` is not supported by this parameter.') - - def convert(self, value, param, ctx): - """ - Convert the string value to an entry point instance, if the value can be successfully parsed - into an actual entry point. 
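As a concrete illustration of the three string formats handled above (FULL, PARTIAL and MINIMAL), assuming the `core.arithmetic.add` calculation entry point shipped with `aiida-core` is available:

```python
from aiida.cmdline.params.types import PluginParamType

param = PluginParamType(group='aiida.calculations')

# All three spellings resolve to the same entry point; MINIMAL only works
# because a single configured group contains an entry point with that name.
for string in (
    'aiida.calculations:core.arithmetic.add',  # FULL
    'calculations:core.arithmetic.add',        # PARTIAL
    'core.arithmetic.add',                     # MINIMAL
):
    print(param.get_entry_point_from_string(string).name)
```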
Will raise click.BadParameter if validation fails. - """ - # If the value is already of the expected return type, simply return it. This behavior is new in `click==8.0`: - # https://click.palletsprojects.com/en/8.0.x/parameters/#implementing-custom-types - if isinstance(value, EntryPoint): - try: - self.validate_entry_point_group(value.group) - except ValueError as exception: - raise click.BadParameter(str(exception)) - return value - - value = super().convert(value, param, ctx) - - try: - entry_point = self.get_entry_point_from_string(value) - self.validate_entry_point_group(entry_point.group) - except ValueError as exception: - raise click.BadParameter(str(exception)) - - if self.load: - try: - return entry_point.load() - except exceptions.LoadingEntryPointError as exception: - raise click.BadParameter(str(exception)) - else: - return entry_point diff --git a/aiida/cmdline/params/types/process.py b/aiida/cmdline/params/types/process.py deleted file mode 100644 index 0cbe5abf65..0000000000 --- a/aiida/cmdline/params/types/process.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module for the process node parameter type -""" - -from .identifier import IdentifierParamType - -__all__ = ('ProcessParamType',) - - -class ProcessParamType(IdentifierParamType): - """ - The ParamType for identifying ProcessNode entities or its subclasses - """ - - name = 'Process' - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import ProcessEntityLoader - return ProcessEntityLoader diff --git a/aiida/cmdline/params/types/profile.py b/aiida/cmdline/params/types/profile.py deleted file mode 100644 index 562bbcaefa..0000000000 --- a/aiida/cmdline/params/types/profile.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Profile param type for click.""" -from click.shell_completion import CompletionItem - -from .strings import LabelStringType - -__all__ = ('ProfileParamType',) - - -class ProfileParamType(LabelStringType): - """The profile parameter type for click. - - This parameter type requires the command that uses it to define the ``context_class`` class attribute to be the - :class:`aiida.cmdline.groups.verdi.VerdiContext` class, as that is responsible for creating the user defined object - ``obj`` on the context and loads the instance config. 
- """ - - name = 'profile' - - def __init__(self, *args, **kwargs): - self._cannot_exist = kwargs.pop('cannot_exist', False) - self._load_profile = kwargs.pop('load_profile', False) # If True, will load the profile converted from value - super().__init__(*args, **kwargs) - - @staticmethod - def deconvert_default(value): - return value.name - - def convert(self, value, param, ctx): - """Attempt to match the given value to a valid profile.""" - from aiida.common.exceptions import MissingConfigurationError, ProfileConfigurationError - from aiida.manage.configuration import Profile, load_profile - - try: - config = ctx.obj.config - except AttributeError: - raise RuntimeError( - 'The context does not contain a user defined object with the loaded AiiDA configuration. ' - 'Is your click command setting `context_class` to :class:`aiida.cmdline.groups.verdi.VerdiContext`?' - ) - - # If the value is already of the expected return type, simply return it. This behavior is new in `click==8.0`: - # https://click.palletsprojects.com/en/8.0.x/parameters/#implementing-custom-types - if isinstance(value, Profile): - return value - - value = super().convert(value, param, ctx) - - try: - profile = config.get_profile(value) - except (MissingConfigurationError, ProfileConfigurationError) as exception: - if not self._cannot_exist: - self.fail(str(exception)) - - # Create a new empty profile - profile = Profile(value, {}, validate=False) - else: - if self._cannot_exist: - self.fail(str(f'the profile `{value}` already exists')) - - if self._load_profile: - load_profile(profile.name) - - ctx.obj.profile = profile - - return profile - - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """Return possible completions based on an incomplete value - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - from aiida.common.exceptions import MissingConfigurationError - from aiida.manage.configuration import get_config - - if self._cannot_exist: - return [] - - try: - config = get_config() - except MissingConfigurationError: - return [] - - return [CompletionItem(profile.name) for profile in config.profiles if profile.name.startswith(incomplete)] diff --git a/aiida/cmdline/params/types/user.py b/aiida/cmdline/params/types/user.py deleted file mode 100644 index 69e9c24c30..0000000000 --- a/aiida/cmdline/params/types/user.py +++ /dev/null @@ -1,60 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""User param type for click.""" -import click - -from aiida.cmdline.utils.decorators import with_dbenv - -__all__ = ('UserParamType',) - - -class UserParamType(click.ParamType): - """ - The user parameter type for click. Can get or create a user. - """ - name = 'user' - - def __init__(self, create=False): - """ - :param create: If the user does not exist, create a new instance (unstored). 
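A rough sketch, not part of the patch, of how the `cannot_exist` flag might be exercised. It assumes a standalone command that sets `context_class` to the Verdi context, as the docstring above requires; the command itself is hypothetical.

```python
import click

from aiida.cmdline.groups.verdi import VerdiContext
from aiida.cmdline.params.types import ProfileParamType


class VerdiCommand(click.Command):
    """Command that provides the Verdi context so the parameter can reach the loaded config."""

    context_class = VerdiContext


@click.command(cls=VerdiCommand)
@click.argument('profile', type=ProfileParamType(cannot_exist=True))
def create_profile(profile):
    click.echo(f'would create new profile `{profile.name}`')
```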
- """ - self._create = create - - @with_dbenv() - def convert(self, value, param, ctx): - from aiida import orm - - results = orm.User.collection.find({'email': value}) - - if not results: - if self._create: - return orm.User(email=value) - - self.fail(f"User '{value}' not found", param, ctx) - - if len(results) > 1: - self.fail(f"Multiple users found with email '{value}': {results}") - - return results[0] - - @with_dbenv() - def shell_complete(self, ctx, param, incomplete): # pylint: disable=unused-argument - """ - Return possible completions based on an incomplete value - - :returns: list of tuples of valid entry points (matching incomplete) and a description - """ - from aiida import orm - - users = orm.User.collection.find() - - return [ - click.shell_completion.CompletionItem(user.email) for user in users if user.email.startswith(incomplete) - ] diff --git a/aiida/cmdline/params/types/workflow.py b/aiida/cmdline/params/types/workflow.py deleted file mode 100644 index 7403ff99f7..0000000000 --- a/aiida/cmdline/params/types/workflow.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module for the workflow parameter type -""" - -from .identifier import IdentifierParamType - -__all__ = ('WorkflowParamType',) - - -class WorkflowParamType(IdentifierParamType): - """ - The ParamType for identifying WorkflowNode entities or its subclasses - """ - - name = 'WorkflowNode' - - @property - def orm_class_loader(self): - """ - Return the orm entity loader class, which should be a subclass of OrmEntityLoader. This class is supposed - to be used to load the entity for a given identifier - - :return: the orm entity loader class for this ParamType - """ - from aiida.orm.utils.loaders import WorkflowEntityLoader - return WorkflowEntityLoader diff --git a/aiida/cmdline/utils/__init__.py b/aiida/cmdline/utils/__init__.py deleted file mode 100644 index a851adef0a..0000000000 --- a/aiida/cmdline/utils/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
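Illustrative only: with `create=True` the parameter returns an unstored `User` when no existing user matches the given email (the command name is hypothetical).

```python
import click

from aiida.cmdline.params.types import UserParamType


@click.command()
@click.argument('user', type=UserParamType(create=True))
def greet(user):
    click.echo(f'{user.email} (stored: {user.is_stored})')
```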
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Commandline utility functions.""" -# AUTO-GENERATED - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .ascii_vis import * -from .common import * -from .decorators import * -from .echo import * - -__all__ = ( - 'dbenv', - 'echo_critical', - 'echo_dictionary', - 'echo_error', - 'echo_info', - 'echo_report', - 'echo_success', - 'echo_warning', - 'format_call_graph', - 'is_verbose', - 'only_if_daemon_running', - 'with_dbenv', -) - -# yapf: enable diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py deleted file mode 100644 index 3073688d53..0000000000 --- a/aiida/cmdline/utils/common.py +++ /dev/null @@ -1,486 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Common utility functions for command line commands.""" -import logging -import os -import sys -import textwrap -from typing import TYPE_CHECKING - -from click import style -from tabulate import tabulate - -from . import echo - -if TYPE_CHECKING: - from aiida.orm import WorkChainNode - -__all__ = ('is_verbose',) - - -def is_verbose(): - """Return whether the configured logging verbosity is considered verbose, i.e., equal or lower to ``INFO`` level. - - .. note:: This checks the effective logging level that is set on the ``CMDLINE_LOGGER``. This means that it will - consider the logging level set on the parent ``AIIDA_LOGGER`` if not explicitly set on itself. The level of the - main logger can be manipulated from the command line through the ``VERBOSITY`` option that is available for all - commands. 
- - """ - return echo.CMDLINE_LOGGER.getEffectiveLevel() <= logging.INFO - - -def get_env_with_venv_bin(): - """Create a clone of the current running environment with the AIIDA_PATH variable set directory of the config.""" - from aiida.common.warnings import warn_deprecation - from aiida.manage.configuration import get_config - - warn_deprecation( - '`get_env_with_venv_bin` function is deprecated use `aiida.engine.daemon.client.DaemonClient.get_env` instead.', - version=3 - ) - - config = get_config() - - currenv = os.environ.copy() - currenv['PATH'] = f"{os.path.dirname(sys.executable)}:{currenv['PATH']}" - currenv['AIIDA_PATH'] = config.dirpath - currenv['PYTHONUNBUFFERED'] = 'True' - - return currenv - - -def format_local_time(timestamp, format_str='%Y-%m-%d %H:%M:%S'): - """ - Format a datetime object or UNIX timestamp in a human readable format - - :param timestamp: a datetime object or a float representing a UNIX timestamp - :param format_str: optional string format to pass to strftime - """ - from aiida.common import timezone - - if isinstance(timestamp, float): - return timezone.datetime.fromtimestamp(timestamp).strftime(format_str) - - return timestamp.strftime(format_str) - - -def print_last_process_state_change(process_type=None): - """ - Print the last time that a process of the specified type has changed its state. - - :param process_type: optional process type for which to get the latest state change timestamp. - Valid process types are either 'calculation' or 'work'. - """ - from aiida.cmdline.utils.echo import echo_report - from aiida.common import timezone - from aiida.common.utils import str_timedelta - from aiida.engine.utils import get_process_state_change_timestamp - - timestamp = get_process_state_change_timestamp(process_type) - - if timestamp is None: - echo_report('last time an entry changed state: never') - else: - timedelta = timezone.delta(timestamp) - formatted = format_local_time(timestamp, format_str='at %H:%M:%S on %Y-%m-%d') - relative = str_timedelta(timedelta, negative_to_zero=True, max_num_fields=1) - echo_report(f'last time an entry changed state: {relative} ({formatted})') - - -def get_node_summary(node): - """Return a multi line string with a pretty formatted summary of a Node. 
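A small usage sketch of `format_local_time` above; the import path is the module removed in this hunk, and the first output depends on the local timezone.

```python
import datetime

from aiida.cmdline.utils.common import format_local_time

print(format_local_time(0.0))                                   # '1970-01-01 00:00:00' if local time is UTC
print(format_local_time(datetime.datetime(2023, 5, 1, 12, 0)))  # '2023-05-01 12:00:00'
```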
- - :param node: a Node instance - :return: a string summary of the node - """ - from plumpy import ProcessState - - from aiida.orm import ProcessNode - - table_headers = ['Property', 'Value'] - table = [] - - if isinstance(node, ProcessNode): - table.append(['type', node.process_label]) - - try: - process_state = ProcessState(node.process_state) - except (AttributeError, ValueError): - pass - else: - process_state_string = process_state.value.capitalize() - - if process_state == ProcessState.FINISHED and node.exit_message: - table.append(['state', f'{process_state_string} [{node.exit_status}] {node.exit_message}']) - elif process_state == ProcessState.FINISHED: - table.append(['state', f'{process_state_string} [{node.exit_status}]']) - elif process_state == ProcessState.EXCEPTED: - table.append(['state', f'{process_state_string} <{node.exception}>']) - else: - table.append(['state', process_state_string]) - - else: - table.append(['type', node.__class__.__name__]) - - table.append(['pk', str(node.pk)]) - table.append(['uuid', str(node.uuid)]) - table.append(['label', node.label]) - table.append(['description', node.description]) - table.append(['ctime', node.ctime]) - table.append(['mtime', node.mtime]) - - try: - computer = node.computer - except AttributeError: - pass - else: - if computer is not None: - table.append(['computer', f'[{node.computer.pk}] {node.computer.label}']) - - return tabulate(table, headers=table_headers) - - -def get_node_info(node, include_summary=True): - """Return a multi line string of information about the given node, such as the incoming and outcoming links. - - :param include_summary: boolean, if True, also include a summary of node properties - :return: a string summary of the node including a description of all its links and log messages - """ - from aiida import orm - from aiida.common.links import LinkType - - if include_summary: - result = get_node_summary(node) - else: - result = '' - - nodes_caller = node.base.links.get_incoming(link_type=(LinkType.CALL_CALC, LinkType.CALL_WORK)) - nodes_called = node.base.links.get_outgoing(link_type=(LinkType.CALL_CALC, LinkType.CALL_WORK)) - nodes_input = node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)) - nodes_output = node.base.links.get_outgoing(link_type=(LinkType.CREATE, LinkType.RETURN)) - - if nodes_input: - result += f"\n{format_nested_links(nodes_input.nested(), headers=['Inputs', 'PK', 'Type'])}" - - if nodes_output: - result += f"\n{format_nested_links(nodes_output.nested(), headers=['Outputs', 'PK', 'Type'])}" - - if nodes_caller: - links = sorted(nodes_caller.all(), key=lambda x: x.node.ctime) - result += f"\n{format_flat_links(links, headers=['Caller', 'PK', 'Type'])}" - - if nodes_called: - links = sorted(nodes_called.all(), key=lambda x: x.node.ctime) - result += f"\n{format_flat_links(links, headers=['Called', 'PK', 'Type'])}" - - log_messages = orm.Log.collection.get_logs_for(node) - - if log_messages: - table = [] - table_headers = ['Log messages'] - table.append([f'There are {len(log_messages)} log messages for this calculation']) - table.append([f"Run 'verdi process report {node.pk}' to see them"]) - result += f'\n\n{tabulate(table, headers=table_headers)}' - - return result - - -def format_flat_links(links, headers): - """Given a flat list of LinkTriples, return a flat string representation. 
- - :param links: a list of LinkTriples - :param headers: headers to use - :return: formatted string - """ - table = [] - - for link_triple in links: - table.append([ - link_triple.link_label, link_triple.node.pk, - link_triple.node.base.attributes.get('process_label', '') - ]) - - result = f'\n{tabulate(table, headers=headers)}' - - return result - - -def format_nested_links(links, headers): - """Given a nested dictionary of nodes, return a nested string representation. - - :param links: a nested dictionary of nodes - :param headers: headers to use - :return: nested formatted string - """ - from collections.abc import Mapping - - import tabulate as tb - - tb.PRESERVE_WHITESPACE = True - - indent_size = 4 - - def format_recursive(links, depth=0): - """Recursively format a dictionary of nodes into indented strings.""" - rows = [] - for label, value in links.items(): - if isinstance(value, Mapping): - rows.append([depth, label, '', '']) - rows.extend(format_recursive(value, depth=depth + 1)) - else: - rows.append([depth, label, value.pk, value.__class__.__name__]) - return rows - - table = [] - - for depth, label, pk, class_name in format_recursive(links): - table.append([f"{' ' * (depth * indent_size)}{label}", pk, class_name]) - - result = f'\n{tabulate(table, headers=headers)}' - tb.PRESERVE_WHITESPACE = False - - return result - - -def get_calcjob_report(calcjob): - """ - Return a multi line string representation of the log messages and output of a given calcjob - - :param calcjob: the calcjob node - :return: a string representation of the log messages and scheduler output - """ - from aiida import orm - from aiida.common.datastructures import CalcJobState - - log_messages = orm.Log.collection.get_logs_for(calcjob) - scheduler_out = calcjob.get_scheduler_stdout() - scheduler_err = calcjob.get_scheduler_stderr() - calcjob_state = calcjob.get_state() - scheduler_state = calcjob.get_scheduler_state() - - report = [] - - if calcjob_state == CalcJobState.WITHSCHEDULER: - state_string = f"{calcjob_state}, scheduler state: {scheduler_state if scheduler_state else '(unknown)'}" - else: - state_string = f'{calcjob_state}' - - label_string = f' [{calcjob.label}]' if calcjob.label else '' - - report.append(f'*** {calcjob.pk}{label_string}: {state_string}') - - if scheduler_out is None: - report.append('*** Scheduler output: N/A') - elif scheduler_out: - report.append(f'*** Scheduler output:\n{scheduler_out}') - else: - report.append('*** (empty scheduler output file)') - - if scheduler_err is None: - report.append('*** Scheduler errors: N/A') - elif scheduler_err: - report.append(f'*** Scheduler errors:\n{scheduler_err}') - else: - report.append('*** (empty scheduler errors file)') - - if log_messages: - report.append(f'*** {len(log_messages)} LOG MESSAGES:') - else: - report.append('*** 0 LOG MESSAGES') - - for log in log_messages: - report.append(f'+-> {log.levelname} at {log.time}') - for message in log.message.splitlines(): - report.append(f' | {message}') - - return '\n'.join(report) - - -def get_process_function_report(node): - """ - Return a multi line string representation of the log messages and output of a given process function node - - :param node: the node - :return: a string representation of the log messages - """ - from aiida import orm - - report = [] - - for log in orm.Log.collection.get_logs_for(node): - report.append(f'{log.time:%Y-%m-%d %H:%M:%S} [{log.pk}]: {log.message}') - - return '\n'.join(report) - - -def get_workchain_report(node: 'WorkChainNode', levelname, 
indent_size=4, max_depth=None): - """ - Return a multi line string representation of the log messages and output of a given workchain - - :param node: the workchain node - :return: a nested string representation of the log messages - """ - # pylint: disable=too-many-locals - import itertools - - from aiida import orm - from aiida.common.log import LOG_LEVELS - - def get_report_messages(uuid, depth, levelname): - """Return list of log messages with given levelname and their depth for a node with a given uuid.""" - node_id = orm.load_node(uuid).pk - filters = {'dbnode_id': node_id} - - entries = orm.Log.collection.find(filters) - entries = [entry for entry in entries if LOG_LEVELS[entry.levelname] >= LOG_LEVELS[levelname]] - return [(_, depth) for _ in entries] - - def get_subtree(uuid, level=0): - """ - Get a nested tree of work calculation nodes and their nesting level starting from this uuid. - The result is a list of uuid of these nodes. - """ - builder = orm.QueryBuilder(backend=node.backend) - builder.append(cls=orm.WorkChainNode, filters={'uuid': uuid}, tag='workcalculation') - builder.append( - cls=orm.WorkChainNode, - project=['uuid'], - # In the future, we should specify here the type of link - # for now, CALL links are the only ones allowing calc-calc - # (we here really want instead to follow CALL links) - with_incoming='workcalculation', - tag='subworkchains' - ) - result = builder.all(flat=True) - - # This will return a single flat list of tuples, where the first element - # corresponds to the WorkChain pk and the second element is an integer - # that represents its level of nesting within the chain - return [(uuid, level)] + list(itertools.chain(*[get_subtree(subuuid, level=level + 1) for subuuid in result])) - - workchain_tree = get_subtree(node.uuid) - - if max_depth: - report_list = [ - get_report_messages(uuid, depth, levelname) for uuid, depth in workchain_tree if depth < max_depth - ] - else: - report_list = [get_report_messages(uuid, depth, levelname) for uuid, depth in workchain_tree] - - reports = list(itertools.chain(*report_list)) - reports.sort(key=lambda r: r[0].time) - - if not reports: - return 'No log messages recorded for this entry' - - log_ids = [entry[0].pk for entry in reports] - levelnames = [len(entry[0].levelname) for entry in reports] - width_id = len(str(max(log_ids))) - width_levelname = max(levelnames) - report = [] - - for entry, depth in reports: - line = '{time:%Y-%m-%d %H:%M:%S} [{id:<{width_id}} | {levelname:>{width_levelname}}]:{indent} {message}'.format( - id=entry.pk, - levelname=entry.levelname, - message=entry.message, - time=entry.time, - width_id=width_id, - width_levelname=width_levelname, - indent=' ' * (depth * indent_size) - ) - report.append(line) - - return '\n'.join(report) - - -def print_process_info(process): - """Print detailed information about a process class and its process specification. - - :param process: a :py:class:`~aiida.engine.processes.process.Process` class - """ - docstring = process.__doc__ - - if docstring is None or docstring.strip() is None: - docstring = 'No description available' - - echo.echo('Description:\n', fg=echo.COLORS['report'], bold=True) - echo.echo(textwrap.indent('\n'.join(textwrap.wrap(docstring, 100)), ' ')) - print_process_spec(process.spec()) - - -def print_process_spec(process_spec): - """Print the process spec in a human-readable formatted way. 
- - :param process_spec: a `ProcessSpec` instance - """ - - def build_entries(ports): - """Build a list of entries to be printed for a `PortNamespace. - - :param ports: the port namespace - :return: list of tuples with port name, required, valid types and info strings - """ - result = [] - - for name, port in sorted(ports.items(), key=lambda x: (not x[1].required, x[0])): - - if name.startswith('_'): - continue - - valid_types = port.valid_type if isinstance(port.valid_type, (list, tuple)) else (port.valid_type,) - valid_types = ', '.join([valid_type.__name__ for valid_type in valid_types if valid_type is not None]) - info = textwrap.wrap(port.help if port.help is not None else '', width=75) - result.append([name, port.required, valid_types, info]) - - return result - - inputs = build_entries(process_spec.inputs) - outputs = build_entries(process_spec.outputs) - - if process_spec.inputs: - echo.echo('\nInputs:', fg=echo.COLORS['report'], bold=True) - - table = [] - - for name, required, valid_types, info in inputs: - table.append((style(name, bold=required, fg='red' if required else 'white'), valid_types, '\n'.join(info))) - - if table: - echo.echo(tabulate(table, tablefmt='plain', colalign=('right',))) - echo.echo(style('\nRequired inputs are displayed in bold red.\n', italic=True)) - - if process_spec.outputs: - echo.echo('Outputs:', fg=echo.COLORS['report'], bold=True) - - table = [] - - for name, required, valid_types, info in outputs: - table.append((style(name, bold=required, fg='red' if required else 'white'), valid_types, '\n'.join(info))) - - if table: - echo.echo(tabulate(table, tablefmt='plain', colalign=('right',))) - echo.echo(style('\nRequired outputs are displayed in bold red.\n', italic=True)) - - if process_spec.exit_codes: - echo.echo('Exit codes:\n', fg=echo.COLORS['report'], bold=True) - - table = [('0', 'The process finished successfully.')] - - for exit_code in sorted(process_spec.exit_codes.values(), key=lambda exit_code: exit_code.status): - if exit_code.invalidates_cache: - status = style(exit_code.status, bold=True, fg='red') - else: - status = exit_code.status - table.append((status, '\n'.join(textwrap.wrap(exit_code.message, width=75)))) - - echo.echo(tabulate(table, tablefmt='plain')) - echo.echo(style('\nExit codes that invalidate the cache are marked in bold red.\n', italic=True)) diff --git a/aiida/cmdline/utils/defaults.py b/aiida/cmdline/utils/defaults.py deleted file mode 100644 index 43f825e827..0000000000 --- a/aiida/cmdline/utils/defaults.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Default values and lazy default get methods for command line options.""" - -from aiida.cmdline.utils import echo -from aiida.common import exceptions -from aiida.manage.configuration import get_config - - -def get_default_profile(): # pylint: disable=unused-argument - """Try to get the name of the default profile. - - This utility function should only be used for defaults or callbacks in command line interface parameters. 
- Otherwise, the preference should go to calling `get_config` to load the actual config and using - `config.default_profile_name` to get the default profile name. - - :raises click.UsageError: if the config could not be loaded or no default profile exists - :return: the default profile name or None if no default is defined in the configuration - """ - try: - config = get_config(create=True) - except exceptions.ConfigurationError as exception: - echo.echo_critical(str(exception)) - - try: - default_profile = config.get_profile(config.default_profile_name).name - except exceptions.ProfileConfigurationError: - default_profile = None - - return default_profile diff --git a/aiida/cmdline/utils/log.py b/aiida/cmdline/utils/log.py deleted file mode 100644 index 893004d15a..0000000000 --- a/aiida/cmdline/utils/log.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -"""Utilities for logging in the command line interface context.""" -import logging - -import click - -from .echo import COLORS - - -class CliHandler(logging.Handler): - """Handler for writing to the console using click.""" - - def emit(self, record): - """Emit log record via click. - - Can make use of special attributes 'nl' (whether to add newline) and 'err' (whether to print to stderr), which - can be set via the 'extra' dictionary parameter of the logging methods. - """ - try: - nl = record.nl - except AttributeError: - nl = True - - try: - err = record.err - except AttributeError: - err = False - - try: - prefix = record.prefix - except AttributeError: - prefix = True - - record.prefix = prefix - - try: - msg = self.format(record) - click.echo(msg, err=err, nl=nl) - except Exception: # pylint: disable=broad-except - self.handleError(record) - - -class CliFormatter(logging.Formatter): - """Formatter that automatically prefixes log messages with a colored version of the log level.""" - - def format(self, record): - """Format the record using the style required for the command line interface.""" - try: - fg = COLORS[record.levelname.lower()] - except KeyError: - fg = 'white' - - try: - prefix = record.prefix - except AttributeError: - prefix = None - - formatted = super().format(record) - - if prefix: - return f'{click.style(record.levelname.capitalize(), fg=fg, bold=True)}: {formatted}' - - return formatted diff --git a/aiida/cmdline/utils/query/__init__.py b/aiida/cmdline/utils/query/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/cmdline/utils/query/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/cmdline/utils/query/calculation.py b/aiida/cmdline/utils/query/calculation.py deleted file mode 100644 index 3cf8c9898f..0000000000 --- a/aiida/cmdline/utils/query/calculation.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
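A hedged sketch of how the handler/formatter pair above can be attached to a logger; the logger name and message are arbitrary.

```python
import logging

from aiida.cmdline.utils.log import CliFormatter, CliHandler

handler = CliHandler()
handler.setFormatter(CliFormatter('%(message)s'))

logger = logging.getLogger('demo')
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# The 'prefix' extra controls the colored "<Levelname>: " prefix added by CliFormatter.
logger.warning('disk quota almost full', extra={'prefix': True})
```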
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=unused-import -"""A utility module with a factory of standard QueryBuilder instances for Calculation nodes.""" -from aiida.common.warnings import warn_deprecation -from aiida.tools.query.calculation import CalculationQueryBuilder - -warn_deprecation('This module is deprecated, use `aiida.tools.query.calculation` instead.', version=3) diff --git a/aiida/cmdline/utils/query/formatting.py b/aiida/cmdline/utils/query/formatting.py deleted file mode 100644 index 67549a27a0..0000000000 --- a/aiida/cmdline/utils/query/formatting.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=unused-import -"""A utility module with simple functions to format variables into strings for cli outputs.""" -from aiida.common.warnings import warn_deprecation -from aiida.tools.query.formatting import format_process_state, format_relative_time, format_sealed, format_state - -warn_deprecation('This module is deprecated, use `aiida.tools.query.formatting` instead.', version=3) diff --git a/aiida/cmdline/utils/query/mapping.py b/aiida/cmdline/utils/query/mapping.py deleted file mode 100644 index 7c70831f4b..0000000000 --- a/aiida/cmdline/utils/query/mapping.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=unused-import -"""A utility module with mapper objects that map database entities projections on attributes and labels.""" -from aiida.common.warnings import warn_deprecation -from aiida.tools.query.mapping import CalculationProjectionMapper, ProjectionMapper - -warn_deprecation('This module is deprecated, use `aiida.tools.query.mapping` instead.', version=3) diff --git a/aiida/cmdline/utils/repository.py b/aiida/cmdline/utils/repository.py deleted file mode 100644 index d50f73f25a..0000000000 --- a/aiida/cmdline/utils/repository.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility functions for command line commands operating on the repository.""" -from aiida.cmdline.utils import echo - - -def list_repository_contents(node, path, color): - """Print the contents of the directory `path` in the repository of the given node to stdout. - - :param node: the node - :param path: directory path - :raises FileNotFoundError: if the `path` does not exist in the repository of the given node - """ - from aiida.repository import FileType - - for entry in node.base.repository.list_objects(path): - bold = bool(entry.file_type == FileType.DIRECTORY) - echo.echo( - entry.name, - bold=bold, - fg=echo.COLORS['report'] if color and entry.file_type == FileType.DIRECTORY else None - ) diff --git a/aiida/common/__init__.py b/aiida/common/__init__.py deleted file mode 100644 index 3c68731ff0..0000000000 --- a/aiida/common/__init__.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Common data structures, utility classes and functions - -.. note:: Modules in this sub package have to run without a loaded database environment - -""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .datastructures import * -from .exceptions import * -from .extendeddicts import * -from .links import * -from .log import * -from .progress_reporter import * - -__all__ = ( - 'AIIDA_LOGGER', - 'AiidaException', - 'AttributeDict', - 'CalcInfo', - 'CalcJobState', - 'ClosedStorage', - 'CodeInfo', - 'CodeRunMode', - 'ConfigurationError', - 'ConfigurationVersionError', - 'ContentNotExistent', - 'CorruptStorage', - 'DbContentError', - 'DefaultFieldsAttributeDict', - 'EntryPointError', - 'FailedError', - 'FeatureDisabled', - 'FeatureNotAvailable', - 'FixedFieldsAttributeDict', - 'GraphTraversalRule', - 'GraphTraversalRules', - 'HashingError', - 'IncompatibleStorageSchema', - 'InputValidationError', - 'IntegrityError', - 'InternalError', - 'InvalidEntryPointTypeError', - 'InvalidOperation', - 'LicensingException', - 'LinkType', - 'LoadingEntryPointError', - 'LockedProfileError', - 'LockingProfileError', - 'MissingConfigurationError', - 'MissingEntryPointError', - 'ModificationNotAllowed', - 'MultipleEntryPointError', - 'MultipleObjectsError', - 'NotExistent', - 'NotExistentAttributeError', - 'NotExistentKeyError', - 'OutputParsingError', - 'ParsingError', - 'PluginInternalError', - 'ProfileConfigurationError', - 'ProgressReporterAbstract', - 'RemoteOperationError', - 'StashMode', - 'StorageMigrationError', - 'StoringNotAllowed', - 'TQDM_BAR_FORMAT', - 'TestsNotAllowedError', - 'TransportTaskException', - 'UniquenessError', - 'UnsupportedSpeciesError', - 'ValidationError', - 'create_callback', - 'get_progress_reporter', - 'override_log_level', - 'set_progress_bar_tqdm', - 'set_progress_reporter', - 'validate_link_label', 
-) - -# yapf: enable diff --git a/aiida/common/constants.py b/aiida/common/constants.py deleted file mode 100644 index 7a80dade9e..0000000000 --- a/aiida/common/constants.py +++ /dev/null @@ -1,607 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module to define the (physical) constants used throughout the code.""" - -# This is the precision with which AiiDA internally will store float numbers -# In particular, before storing a number (in attributes/extras), -# when going via clean_value, AiiDA will first serialize to a string with -# this precision, and then bring it back to float. This is important to -# ensure consistency between the precision of different backends, and also -# for hashing - the hashing MUST use the same precision, so that a hash computed -# after clean_value but before storing to the database will be the same as the -# hash recomputed (e.g. via `verdi rehash`) after storing and reloading the node -# from the database. -# See also discussion on GitHub issue #2631 -# -# IMPORTANT! Changing this value requires to rehash the whole database! -# So if you plan to change this, think carefully, and you might need to add -# a migration. -AIIDA_FLOAT_PRECISION = 14 - -# Element table, from NIST (http://www.nist.gov/pml/data/index.cfm) -# Retrieved in October 2014 for atomic numbers 1-103, and in May 2016 or atomic numbers 104-112, 114 and 116. -# In addition, element X is added to support unknown elements. 
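A minimal sketch of the string round-trip described above; this reproduces the documented behaviour, not necessarily the exact implementation used internally.

```python
AIIDA_FLOAT_PRECISION = 14  # value defined in the module above


def clean_float(value: float) -> float:
    """Serialize to a string with 14 significant digits and back, as done before storing/hashing."""
    return float(f'{value:.{AIIDA_FLOAT_PRECISION}g}')


print(clean_float(1.0 / 3.0))  # 0.33333333333333
```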
-elements = { # pylint: disable=invalid-name - 0: { - 'mass': 1.00000, - 'name': 'Unknown', - 'symbol': 'X' - }, - 1: { - 'mass': 1.00794, - 'name': 'Hydrogen', - 'symbol': 'H' - }, - 2: { - 'mass': 4.002602, - 'name': 'Helium', - 'symbol': 'He' - }, - 3: { - 'mass': 6.941, - 'name': 'Lithium', - 'symbol': 'Li' - }, - 4: { - 'mass': 9.012182, - 'name': 'Beryllium', - 'symbol': 'Be' - }, - 5: { - 'mass': 10.811, - 'name': 'Boron', - 'symbol': 'B' - }, - 6: { - 'mass': 12.0107, - 'name': 'Carbon', - 'symbol': 'C' - }, - 7: { - 'mass': 14.0067, - 'name': 'Nitrogen', - 'symbol': 'N' - }, - 8: { - 'mass': 15.9994, - 'name': 'Oxygen', - 'symbol': 'O' - }, - 9: { - 'mass': 18.9984032, - 'name': 'Fluorine', - 'symbol': 'F' - }, - 10: { - 'mass': 20.1797, - 'name': 'Neon', - 'symbol': 'Ne' - }, - 11: { - 'mass': 22.98977, - 'name': 'Sodium', - 'symbol': 'Na' - }, - 12: { - 'mass': 24.305, - 'name': 'Magnesium', - 'symbol': 'Mg' - }, - 13: { - 'mass': 26.981538, - 'name': 'Aluminium', - 'symbol': 'Al' - }, - 14: { - 'mass': 28.0855, - 'name': 'Silicon', - 'symbol': 'Si' - }, - 15: { - 'mass': 30.973761, - 'name': 'Phosphorus', - 'symbol': 'P' - }, - 16: { - 'mass': 32.065, - 'name': 'Sulfur', - 'symbol': 'S' - }, - 17: { - 'mass': 35.453, - 'name': 'Chlorine', - 'symbol': 'Cl' - }, - 18: { - 'mass': 39.948, - 'name': 'Argon', - 'symbol': 'Ar' - }, - 19: { - 'mass': 39.0983, - 'name': 'Potassium', - 'symbol': 'K' - }, - 20: { - 'mass': 40.078, - 'name': 'Calcium', - 'symbol': 'Ca' - }, - 21: { - 'mass': 44.955912, - 'name': 'Scandium', - 'symbol': 'Sc' - }, - 22: { - 'mass': 47.867, - 'name': 'Titanium', - 'symbol': 'Ti' - }, - 23: { - 'mass': 50.9415, - 'name': 'Vanadium', - 'symbol': 'V' - }, - 24: { - 'mass': 51.9961, - 'name': 'Chromium', - 'symbol': 'Cr' - }, - 25: { - 'mass': 54.938045, - 'name': 'Manganese', - 'symbol': 'Mn' - }, - 26: { - 'mass': 55.845, - 'name': 'Iron', - 'symbol': 'Fe' - }, - 27: { - 'mass': 58.933195, - 'name': 'Cobalt', - 'symbol': 'Co' - }, - 28: { - 'mass': 58.6934, - 'name': 'Nickel', - 'symbol': 'Ni' - }, - 29: { - 'mass': 63.546, - 'name': 'Copper', - 'symbol': 'Cu' - }, - 30: { - 'mass': 65.38, - 'name': 'Zinc', - 'symbol': 'Zn' - }, - 31: { - 'mass': 69.723, - 'name': 'Gallium', - 'symbol': 'Ga' - }, - 32: { - 'mass': 72.64, - 'name': 'Germanium', - 'symbol': 'Ge' - }, - 33: { - 'mass': 74.9216, - 'name': 'Arsenic', - 'symbol': 'As' - }, - 34: { - 'mass': 78.96, - 'name': 'Selenium', - 'symbol': 'Se' - }, - 35: { - 'mass': 79.904, - 'name': 'Bromine', - 'symbol': 'Br' - }, - 36: { - 'mass': 83.798, - 'name': 'Krypton', - 'symbol': 'Kr' - }, - 37: { - 'mass': 85.4678, - 'name': 'Rubidium', - 'symbol': 'Rb' - }, - 38: { - 'mass': 87.62, - 'name': 'Strontium', - 'symbol': 'Sr' - }, - 39: { - 'mass': 88.90585, - 'name': 'Yttrium', - 'symbol': 'Y' - }, - 40: { - 'mass': 91.224, - 'name': 'Zirconium', - 'symbol': 'Zr' - }, - 41: { - 'mass': 92.90638, - 'name': 'Niobium', - 'symbol': 'Nb' - }, - 42: { - 'mass': 95.96, - 'name': 'Molybdenum', - 'symbol': 'Mo' - }, - 43: { - 'mass': 98.0, - 'name': 'Technetium', - 'symbol': 'Tc' - }, - 44: { - 'mass': 101.07, - 'name': 'Ruthenium', - 'symbol': 'Ru' - }, - 45: { - 'mass': 102.9055, - 'name': 'Rhodium', - 'symbol': 'Rh' - }, - 46: { - 'mass': 106.42, - 'name': 'Palladium', - 'symbol': 'Pd' - }, - 47: { - 'mass': 107.8682, - 'name': 'Silver', - 'symbol': 'Ag' - }, - 48: { - 'mass': 112.411, - 'name': 'Cadmium', - 'symbol': 'Cd' - }, - 49: { - 'mass': 114.818, - 'name': 'Indium', - 'symbol': 'In' - }, - 50: { - 'mass': 118.71, 
- 'name': 'Tin', - 'symbol': 'Sn' - }, - 51: { - 'mass': 121.76, - 'name': 'Antimony', - 'symbol': 'Sb' - }, - 52: { - 'mass': 127.6, - 'name': 'Tellurium', - 'symbol': 'Te' - }, - 53: { - 'mass': 126.90447, - 'name': 'Iodine', - 'symbol': 'I' - }, - 54: { - 'mass': 131.293, - 'name': 'Xenon', - 'symbol': 'Xe' - }, - 55: { - 'mass': 132.9054519, - 'name': 'Caesium', - 'symbol': 'Cs' - }, - 56: { - 'mass': 137.327, - 'name': 'Barium', - 'symbol': 'Ba' - }, - 57: { - 'mass': 138.90547, - 'name': 'Lanthanum', - 'symbol': 'La' - }, - 58: { - 'mass': 140.116, - 'name': 'Cerium', - 'symbol': 'Ce' - }, - 59: { - 'mass': 140.90765, - 'name': 'Praseodymium', - 'symbol': 'Pr' - }, - 60: { - 'mass': 144.242, - 'name': 'Neodymium', - 'symbol': 'Nd' - }, - 61: { - 'mass': 145.0, - 'name': 'Promethium', - 'symbol': 'Pm' - }, - 62: { - 'mass': 150.36, - 'name': 'Samarium', - 'symbol': 'Sm' - }, - 63: { - 'mass': 151.964, - 'name': 'Europium', - 'symbol': 'Eu' - }, - 64: { - 'mass': 157.25, - 'name': 'Gadolinium', - 'symbol': 'Gd' - }, - 65: { - 'mass': 158.92535, - 'name': 'Terbium', - 'symbol': 'Tb' - }, - 66: { - 'mass': 162.5, - 'name': 'Dysprosium', - 'symbol': 'Dy' - }, - 67: { - 'mass': 164.93032, - 'name': 'Holmium', - 'symbol': 'Ho' - }, - 68: { - 'mass': 167.259, - 'name': 'Erbium', - 'symbol': 'Er' - }, - 69: { - 'mass': 168.93421, - 'name': 'Thulium', - 'symbol': 'Tm' - }, - 70: { - 'mass': 173.054, - 'name': 'Ytterbium', - 'symbol': 'Yb' - }, - 71: { - 'mass': 174.9668, - 'name': 'Lutetium', - 'symbol': 'Lu' - }, - 72: { - 'mass': 178.49, - 'name': 'Hafnium', - 'symbol': 'Hf' - }, - 73: { - 'mass': 180.94788, - 'name': 'Tantalum', - 'symbol': 'Ta' - }, - 74: { - 'mass': 183.84, - 'name': 'Tungsten', - 'symbol': 'W' - }, - 75: { - 'mass': 186.207, - 'name': 'Rhenium', - 'symbol': 'Re' - }, - 76: { - 'mass': 190.23, - 'name': 'Osmium', - 'symbol': 'Os' - }, - 77: { - 'mass': 192.217, - 'name': 'Iridium', - 'symbol': 'Ir' - }, - 78: { - 'mass': 195.084, - 'name': 'Platinum', - 'symbol': 'Pt' - }, - 79: { - 'mass': 196.966569, - 'name': 'Gold', - 'symbol': 'Au' - }, - 80: { - 'mass': 200.59, - 'name': 'Mercury', - 'symbol': 'Hg' - }, - 81: { - 'mass': 204.3833, - 'name': 'Thallium', - 'symbol': 'Tl' - }, - 82: { - 'mass': 207.2, - 'name': 'Lead', - 'symbol': 'Pb' - }, - 83: { - 'mass': 208.9804, - 'name': 'Bismuth', - 'symbol': 'Bi' - }, - 84: { - 'mass': 209.0, - 'name': 'Polonium', - 'symbol': 'Po' - }, - 85: { - 'mass': 210.0, - 'name': 'Astatine', - 'symbol': 'At' - }, - 86: { - 'mass': 222.0, - 'name': 'Radon', - 'symbol': 'Rn' - }, - 87: { - 'mass': 223.0, - 'name': 'Francium', - 'symbol': 'Fr' - }, - 88: { - 'mass': 226.0, - 'name': 'Radium', - 'symbol': 'Ra' - }, - 89: { - 'mass': 227.0, - 'name': 'Actinium', - 'symbol': 'Ac' - }, - 90: { - 'mass': 232.03806, - 'name': 'Thorium', - 'symbol': 'Th' - }, - 91: { - 'mass': 231.03588, - 'name': 'Protactinium', - 'symbol': 'Pa' - }, - 92: { - 'mass': 238.02891, - 'name': 'Uranium', - 'symbol': 'U' - }, - 93: { - 'mass': 237.0, - 'name': 'Neptunium', - 'symbol': 'Np' - }, - 94: { - 'mass': 244.0, - 'name': 'Plutonium', - 'symbol': 'Pu' - }, - 95: { - 'mass': 243.0, - 'name': 'Americium', - 'symbol': 'Am' - }, - 96: { - 'mass': 247.0, - 'name': 'Curium', - 'symbol': 'Cm' - }, - 97: { - 'mass': 247.0, - 'name': 'Berkelium', - 'symbol': 'Bk' - }, - 98: { - 'mass': 251.0, - 'name': 'Californium', - 'symbol': 'Cf' - }, - 99: { - 'mass': 252.0, - 'name': 'Einsteinium', - 'symbol': 'Es' - }, - 100: { - 'mass': 257.0, - 'name': 'Fermium', - 'symbol': 
'Fm' - }, - 101: { - 'mass': 258.0, - 'name': 'Mendelevium', - 'symbol': 'Md' - }, - 102: { - 'mass': 259.0, - 'name': 'Nobelium', - 'symbol': 'No' - }, - 103: { - 'mass': 262.0, - 'name': 'Lawrencium', - 'symbol': 'Lr' - }, - 104: { - 'mass': 267.0, - 'name': 'Rutherfordium', - 'symbol': 'Rf' - }, - 105: { - 'mass': 268.0, - 'name': 'Dubnium', - 'symbol': 'Db' - }, - 106: { - 'mass': 271.0, - 'name': 'Seaborgium', - 'symbol': 'Sg' - }, - 107: { - 'mass': 272.0, - 'name': 'Bohrium', - 'symbol': 'Bh' - }, - 108: { - 'mass': 270.0, - 'name': 'Hassium', - 'symbol': 'Hs' - }, - 109: { - 'mass': 276.0, - 'name': 'Meitnerium', - 'symbol': 'Mt' - }, - 110: { - 'mass': 281.0, - 'name': 'Darmstadtium', - 'symbol': 'Ds' - }, - 111: { - 'mass': 280.0, - 'name': 'Roentgenium', - 'symbol': 'Rg' - }, - 112: { - 'mass': 285.0, - 'name': 'Copernicium', - 'symbol': 'Cn' - }, - 114: { - 'mass': 289.0, - 'name': 'Flerovium', - 'symbol': 'Fl' - }, - 116: { - 'mass': 293.0, - 'name': 'Livermorium', - 'symbol': 'Lv' - }, -} diff --git a/aiida/common/datastructures.py b/aiida/common/datastructures.py deleted file mode 100644 index 731237ec5d..0000000000 --- a/aiida/common/datastructures.py +++ /dev/null @@ -1,205 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module to define commonly used data structures.""" -from __future__ import annotations - -from enum import Enum, IntEnum -from typing import TYPE_CHECKING - -from .extendeddicts import DefaultFieldsAttributeDict - -__all__ = ('StashMode', 'CalcJobState', 'CalcInfo', 'CodeInfo', 'CodeRunMode') - - -class StashMode(Enum): - """Mode to use when stashing files from the working directory of a completed calculation job for safekeeping.""" - - COPY = 'copy' - - -class CalcJobState(Enum): - """The sub state of a CalcJobNode while its Process is in an active state (i.e. Running or Waiting).""" - - UPLOADING = 'uploading' - SUBMITTING = 'submitting' - WITHSCHEDULER = 'withscheduler' - STASHING = 'stashing' - RETRIEVING = 'retrieving' - PARSING = 'parsing' - - -class CalcInfo(DefaultFieldsAttributeDict): - """ - This object will store the data returned by the calculation plugin and to be - passed to the ExecManager. - - In the following descriptions all paths have to be considered relative - - * retrieve_list: a list of strings or tuples that indicate files that are to be retrieved from the remote after the - calculation has finished and stored in the ``retrieved_folder`` output node of type ``FolderData``. If the entry - in the list is just a string, it is assumed to be the filepath on the remote and it will be copied to the base - directory of the retrieved folder, where the name corresponds to the basename of the remote relative path. This - means that any remote folder hierarchy is ignored entirely. - - Remote folder hierarchy can be (partially) maintained by using a tuple instead, with the following format - - (source, target, depth) - - The ``source`` and ``target`` elements are relative filepaths in the remote and retrieved folder. 
The contents - of ``source`` (whether it is a file or folder) are copied in its entirety to the ``target`` subdirectory in the - retrieved folder. If no subdirectory should be created, ``'.'`` should be specified for ``target``. - - The ``source`` filepaths support glob patterns ``*`` in case the exact name of the files that are to be - retrieved are not know a priori. - - The ``depth`` element can be used to control what level of nesting of the source folder hierarchy should be - maintained. If ``depth`` equals ``0`` or ``1`` (they are equivalent), only the basename of the ``source`` - filepath is kept. For each additional level, another subdirectory of the remote hierarchy is kept. For example: - - ('path/sub/file.txt', '.', 2) - - will retrieve the ``file.txt`` and store it under the path: - - sub/file.txt - - * retrieve_temporary_list: a list of strings or tuples that indicate files that will be retrieved - and stored temporarily in a FolderData, that will be available only during the parsing call. - The format of the list is the same as that of 'retrieve_list' - - * local_copy_list: a list of tuples with format ('node_uuid', 'filename', relativedestpath') - * remote_copy_list: a list of tuples with format ('remotemachinename', 'remoteabspath', 'relativedestpath') - * remote_symlink_list: a list of tuples with format ('remotemachinename', 'remoteabspath', 'relativedestpath') - * provenance_exclude_list: a sequence of relative paths of files in the sandbox folder of a `CalcJob` instance that - should not be stored permanantly in the repository folder of the corresponding `CalcJobNode` that will be - created, but should only be copied to the remote working directory on the target computer. This is useful for - input files that should be copied to the working directory but should not be copied as well to the repository - either, for example, because they contain proprietary information or because they are big and their content is - already indirectly present in the repository through one of the data nodes passed as input to the calculation. - * codes_info: a list of dictionaries used to pass the info of the execution of a code - * codes_run_mode: the mode of execution in which the codes will be run (`CodeRunMode.SERIAL` by default, - but can also be `CodeRunMode.PARALLEL`) - * skip_submit: a flag that, when set to True, orders the engine to skip the submit/update steps (so no code will - run, it will only upload the files and then retrieve/parse). 
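A minimal sketch of how the `retrieve_list` formats documented above could be populated (filenames are hypothetical; the import path reflects this module as it existed before its removal in this diff):

```python
# Sketch: the three retrieve_list entry forms described in the CalcInfo docstring.
from aiida.common.datastructures import CalcInfo

calc_info = CalcInfo()
calc_info.retrieve_list = [
    # plain string: copied to the base of the retrieved FolderData, hierarchy ignored
    'aiida.out',
    # tuple (source, target, depth): depth=2 keeps 'sub/file.txt' under the retrieved folder root
    ('path/sub/file.txt', '.', 2),
    # glob in source: every match lands in the 'logs' subdirectory, basename only (depth=0)
    ('output/*.log', 'logs', 0),
]
```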
- """ - - _default_fields = ( - 'job_environment', 'email', 'email_on_started', 'email_on_terminated', 'uuid', 'prepend_text', 'append_text', - 'num_machines', 'num_mpiprocs_per_machine', 'priority', 'max_wallclock_seconds', 'max_memory_kb', 'rerunnable', - 'retrieve_list', 'retrieve_temporary_list', 'local_copy_list', 'remote_copy_list', 'remote_symlink_list', - 'provenance_exclude_list', 'codes_info', 'codes_run_mode', 'skip_submit' - ) - - if TYPE_CHECKING: - - job_environment: None | dict[str, str] - email: None | str - email_on_started: bool - email_on_terminated: bool - uuid: None | str - prepend_text: None | str - append_text: None | str - num_machines: None | int - num_mpiprocs_per_machine: None | int - priority: None | int - max_wallclock_seconds: None | int - max_memory_kb: None | int - rerunnable: bool - retrieve_list: None | list[str | tuple[str, str, str]] - retrieve_temporary_list: None | list[str | tuple[str, str, str]] - local_copy_list: None | list[tuple[str, str, str]] - remote_copy_list: None | list[tuple[str, str, str]] - remote_symlink_list: None | list[tuple[str, str, str]] - provenance_exclude_list: None | list[str] - codes_info: None | list[CodeInfo] - codes_run_mode: None | CodeRunMode - skip_submit: None | bool - - -class CodeInfo(DefaultFieldsAttributeDict): - """ - This attribute-dictionary contains the information needed to execute a code. - Possible attributes are: - - * ``cmdline_params``: a list of strings, containing parameters to be written on - the command line right after the call to the code, as for example:: - - code.x cmdline_params[0] cmdline_params[1] ... < stdin > stdout - - * ``stdin_name``: (optional) the name of the standard input file. Note, it is - only possible to use the stdin with the syntax:: - - code.x < stdin_name - - If no stdin_name is specified, the string "< stdin_name" will not be - passed to the code. - Note: it is not possible to substitute/remove the '<' if stdin_name is specified; - if that is needed, avoid stdin_name and use instead the cmdline_params to - specify a suitable syntax. - * ``stdout_name``: (optional) the name of the standard output file. Note, it is - only possible to pass output to stdout_name with the syntax:: - - code.x ... > stdout_name - - If no stdout_name is specified, the string "> stdout_name" will not be - passed to the code. - Note: it is not possible to substitute/remove the '>' if stdout_name is specified; - if that is needed, avoid stdout_name and use instead the cmdline_params to - specify a suitable syntax. - * ``stderr_name``: (optional) a string, the name of the error file of the code. - * ``join_files``: (optional) if True, redirects the error to the output file. - If join_files=True, the code will be called as:: - - code.x ... > stdout_name 2>&1 - - otherwise, if join_files=False and stderr is passed:: - - code.x ... > stdout_name 2> stderr_name - - * ``withmpi``: if True, executes the code with mpirun (or another MPI installed - on the remote computer) - * ``code_uuid``: the uuid of the code associated to the CodeInfo - """ - _default_fields = ( - 'cmdline_params', # as a list of strings - 'stdin_name', - 'stdout_name', - 'stderr_name', - 'join_files', - 'withmpi', - 'code_uuid' - ) - - if TYPE_CHECKING: - - cmdline_params: None | list[str] - stdin_name: None | str - stdout_name: None | str - stderr_name: None | str - join_files: None | bool - withmpi: None | bool - code_uuid: None | str - - -class CodeRunMode(IntEnum): - """Enum to indicate the way the codes of a calculation should be run. 
- - For PARALLEL, the codes for a given calculation will be run in parallel by running them in the background:: - - code1.x & - code2.x & - - For the SERIAL option, codes will be executed sequentially by running for example the following:: - - code1.x - code2.x - """ - - SERIAL = 0 - PARALLEL = 1 diff --git a/aiida/common/exceptions.py b/aiida/common/exceptions.py deleted file mode 100644 index eec8b94446..0000000000 --- a/aiida/common/exceptions.py +++ /dev/null @@ -1,301 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module that define the exceptions that are thrown by AiiDA's internal code.""" - -__all__ = ( - 'AiidaException', 'NotExistent', 'NotExistentAttributeError', 'NotExistentKeyError', 'MultipleObjectsError', - 'RemoteOperationError', 'ContentNotExistent', 'FailedError', 'StoringNotAllowed', 'ModificationNotAllowed', - 'IntegrityError', 'UniquenessError', 'EntryPointError', 'MissingEntryPointError', 'MultipleEntryPointError', - 'LoadingEntryPointError', 'InvalidEntryPointTypeError', 'InvalidOperation', 'ParsingError', 'InternalError', - 'PluginInternalError', 'ValidationError', 'ConfigurationError', 'ProfileConfigurationError', - 'MissingConfigurationError', 'ConfigurationVersionError', 'IncompatibleStorageSchema', 'CorruptStorage', - 'DbContentError', 'InputValidationError', 'FeatureNotAvailable', 'FeatureDisabled', 'LicensingException', - 'TestsNotAllowedError', 'UnsupportedSpeciesError', 'TransportTaskException', 'OutputParsingError', 'HashingError', - 'StorageMigrationError', 'LockedProfileError', 'LockingProfileError', 'ClosedStorage' -) - - -class AiidaException(Exception): - """ - Base class for all AiiDA exceptions. - - Each module will have its own subclass, inherited from this - (e.g. ExecManagerException, TransportException, ...) - """ - - -class NotExistent(AiidaException): - """ - Raised when the required entity does not exist. - """ - - -class NotExistentAttributeError(AttributeError, NotExistent): - """ - Raised when the required entity does not exist, when fetched as an attribute. - """ - - -class NotExistentKeyError(KeyError, NotExistent): - """ - Raised when the required entity does not exist, when fetched as a dictionary key. - """ - - -class MultipleObjectsError(AiidaException): - """ - Raised when more than one entity is found in the DB, but only one was - expected. - """ - - -class RemoteOperationError(AiidaException): - """ - Raised when an error in a remote operation occurs, as in a failed kill() - of a scheduler job. - """ - - -class ContentNotExistent(NotExistent): - """ - Raised when trying to access an attribute, a key or a file in the result - nodes that is not present - """ - - -class FailedError(AiidaException): - """ - Raised when accessing a calculation that is in the FAILED status - """ - - -class StoringNotAllowed(AiidaException): - """ - Raised when the user tries to store an unstorable node (e.g. a base Node class) - """ - - -class ModificationNotAllowed(AiidaException): - """ - Raised when the user tries to modify a field, object, property, ... that should not - be modified. 
- """ - - -class IntegrityError(AiidaException): - """ - Raised when there is an underlying data integrity error. This can be database related - or a general data integrity error. This can happen if, e.g., a foreign key check fails. - See PEP 249 for details. - """ - - -class UniquenessError(AiidaException): - """ - Raised when the user tries to violate a uniqueness constraint (on the - DB, for instance). - """ - - -class EntryPointError(AiidaException): - """Raised when an entry point cannot be uniquely resolved and imported.""" - - -class MissingEntryPointError(EntryPointError): - """Raised when the requested entry point is not registered with the entry point manager.""" - - -class MultipleEntryPointError(EntryPointError): - """Raised when the requested entry point cannot uniquely be resolved by the entry point manager.""" - - -class LoadingEntryPointError(EntryPointError): - """Raised when the resource corresponding to requested entry point cannot be imported.""" - - -class InvalidEntryPointTypeError(EntryPointError): - """Raised when a loaded entry point has a type that is not supported by the corresponding entry point group.""" - - -class InvalidOperation(AiidaException): - """ - The allowed operation is not valid (e.g., when trying to add a non-internal attribute - before saving the entry), or deleting an entry that is protected (e.g., - because it is referenced by foreign keys) - """ - - -class ParsingError(AiidaException): - """ - Generic error raised when there is a parsing error - """ - - -class InternalError(AiidaException): - """ - Error raised when there is an internal error of AiiDA. - """ - - -class PluginInternalError(InternalError): - """ - Error raised when there is an internal error which is due to a plugin - and not to the AiiDA infrastructure. - """ - - -class ValidationError(AiidaException): - """ - Error raised when there is an error during the validation phase - of a property. - """ - - -class ConfigurationError(AiidaException): - """ - Error raised when there is a configuration error in AiiDA. - """ - - -class ProfileConfigurationError(ConfigurationError): - """ - Configuration error raised when a wrong/inexistent profile is requested. - """ - - -class MissingConfigurationError(ConfigurationError): - """ - Configuration error raised when the configuration file is missing. - """ - - -class ConfigurationVersionError(ConfigurationError): - """ - Configuration error raised when the configuration file version is not - compatible with the current version. - """ - - -class ClosedStorage(AiidaException): - """Raised when trying to access data from a closed storage backend.""" - - -class UnreachableStorage(ConfigurationError): - """Raised when a connection to the storage backend fails.""" - - -class IncompatibleDatabaseSchema(ConfigurationError): - """Raised when the storage schema is incompatible with that of the code. - - Deprecated for ``IncompatibleStorageSchema`` - """ - - -class IncompatibleStorageSchema(IncompatibleDatabaseSchema): - """Raised when the storage schema is incompatible with that of the code.""" - - -class CorruptStorage(ConfigurationError): - """Raised when the storage is not found to be internally consistent on validation.""" - - -class DatabaseMigrationError(AiidaException): - """Raised if a critical error is encountered during a storage migration. 
- - Deprecated for ``StorageMigrationError`` - """ - - -class StorageMigrationError(DatabaseMigrationError): - """Raised if a critical error is encountered during a storage migration.""" - - -class DbContentError(AiidaException): - """ - Raised when the content of the DB is not valid. - This should never happen if the user does not play directly - with the DB. - """ - - -class InputValidationError(ValidationError): - """ - The input data for a calculation did not validate (e.g., missing - required input data, wrong data, ...) - """ - - -class FeatureNotAvailable(AiidaException): - """ - Raised when a feature is requested from a plugin, that is not available. - """ - - -class FeatureDisabled(AiidaException): - """ - Raised when a feature is requested, but the user has chosen to disable - it (e.g., for submissions on disabled computers). - """ - - -class LicensingException(AiidaException): - """ - Raised when requirements for data licensing are not met. - """ - - -class TestsNotAllowedError(AiidaException): - """ - Raised when tests are required to be run/loaded, but we are not in a testing environment. - - This is to prevent data loss. - """ - - -class UnsupportedSpeciesError(ValueError): - """ - Raised when StructureData operations are fed species that are not supported by AiiDA such as Deuterium - """ - - -class TransportTaskException(AiidaException): - """ - Raised when a TransportTask, an task to be completed by the engine that requires transport, fails - """ - - -class OutputParsingError(ParsingError): - """ - Can be raised by a Parser when it fails to parse the output generated by a `CalcJob` process. - """ - - -class CircusCallError(AiidaException): - """ - Raised when an attempt to contact Circus returns an error in the response - """ - - -class HashingError(AiidaException): - """ - Raised when an attempt to hash an object fails via a known failure mode - """ - - -class LockedProfileError(AiidaException): - """ - Raised if attempting to access a locked profile - """ - - -class LockingProfileError(AiidaException): - """ - Raised if the profile can`t be locked - """ diff --git a/aiida/common/hashing.py b/aiida/common/hashing.py deleted file mode 100644 index 7cb06a3f81..0000000000 --- a/aiida/common/hashing.py +++ /dev/null @@ -1,309 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Common password and hash generation functions.""" -from collections import OrderedDict, abc -from datetime import date, datetime, timezone -from decimal import Decimal -from functools import singledispatch -import hashlib -from itertools import chain -import numbers -from operator import itemgetter -import secrets -import string -import typing -import uuid - -from aiida.common.constants import AIIDA_FLOAT_PRECISION -from aiida.common.exceptions import HashingError -from aiida.common.utils import DatetimePrecision - -from .folders import Folder - - -def get_random_string(length: int = 12) -> str: - """Return a securely generated random string. 
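One small but useful property of the exception hierarchy removed above is the double inheritance of `NotExistentAttributeError` and `NotExistentKeyError`, which lets callers catch them either as stdlib errors or as AiiDA errors. A brief illustration:

```python
# Illustration: NotExistentAttributeError derives from both AttributeError and NotExistent.
from aiida.common.exceptions import NotExistent, NotExistentAttributeError

try:
    raise NotExistentAttributeError('no such attribute on this node')
except AttributeError as exc:            # caught via the stdlib base class ...
    assert isinstance(exc, NotExistent)  # ... while still recognisable as an AiiDA NotExistent
```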
- - The default length of 12 with the all ASCII letters and digits returns a 71-bit value: - - log_2((26+26+10)^12) =~ 71 bits - - :param length: The number of characters to use for the string. - """ - alphabet = string.ascii_letters + string.digits - return ''.join(secrets.choice(alphabet) for i in range(length)) - - -BLAKE2B_OPTIONS = { - 'fanout': 0, # unlimited fanout/depth mode - 'depth': 2, # has fixed depth of 2 - 'digest_size': 32, # we do not need a cryptographically relevant digest - 'inner_size': 64, # ... but still use 64 as the inner size -} - - -def chunked_file_hash( - handle: typing.BinaryIO, hash_cls: typing.Any, chunksize: int = 524288, **kwargs: typing.Any -) -> str: - """Return the hash for the given file handle - - Will read the file in chunks, which should be opened in 'rb' mode. - - :param handle: a file handle, opened in 'rb' mode. - :param hash_cls: a class implementing hashlib._Hash - :param chunksize: number of bytes to chunk the file read in - :param kwargs: arguments to pass to the hasher initialisation - :return: the hash hexdigest (the hash key) - """ - hasher = hash_cls(**kwargs) - while True: - chunk = handle.read(chunksize) - hasher.update(chunk) - - if not chunk: - # Empty returned value: EOF - break - - return hasher.hexdigest() - - -def make_hash(object_to_hash, **kwargs): - """ - Makes a hash from a dictionary, list, tuple or set to any level, that contains - only other hashable or nonhashable types (including lists, tuples, sets, and - dictionaries). - - :param object_to_hash: the object to hash - - :returns: a unique hash - - There are a lot of modules providing functionalities to create unique - hashes for hashable values. - However, getting hashes for nonhashable items like sets or dictionaries is - not easily doable because order is not fixed. - This leads to the peril of getting different hashes for the same - dictionary. - - This function avoids this by recursing through nonhashable items and - hashing iteratively. Uses python's sorted function to sort unsorted - sets and dictionaries by sorting the hashed keys. - """ - hashes = _make_hash(object_to_hash, **kwargs) # pylint: disable=assignment-from-no-return - - # use the Unlimited fanout hashing protocol outlined in - # https://blake2.net/blake2_20130129.pdf - final_hash = hashlib.blake2b(node_depth=1, last_node=True, **BLAKE2B_OPTIONS) - - for sub in hashes: - final_hash.update(sub) - - # add an empty last leaf node - final_hash.update(hashlib.blake2b(node_depth=0, last_node=True, **BLAKE2B_OPTIONS).digest()) - - return final_hash.hexdigest() - - -@singledispatch -def _make_hash(object_to_hash, **_): - """ - Implementation of the ``make_hash`` function. The hash is created as a - 28 byte integer, and only later converted to a string. 
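A short usage sketch for `chunked_file_hash` defined above (the file name is hypothetical; any hashlib hash class can be passed as `hash_cls`):

```python
# Sketch: hash a file in 512 KiB chunks without loading it into memory at once.
import hashlib

from aiida.common.hashing import chunked_file_hash

with open('results.dat', 'rb') as handle:          # must be opened in binary mode
    digest = chunked_file_hash(handle, hashlib.sha256)

print(digest)  # hex digest string
```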
- """ - raise HashingError(f'Value of type {type(object_to_hash)} cannot be hashed') - - -def _single_digest(obj_type, obj_bytes=b''): - return hashlib.blake2b(obj_bytes, person=obj_type.encode('ascii'), node_depth=0, **BLAKE2B_OPTIONS).digest() - - -_END_DIGEST = _single_digest(')') - - -@_make_hash.register(bytes) -def _(bytes_obj, **kwargs): - """Hash arbitrary byte strings.""" - return [_single_digest('str', bytes_obj)] - - -@_make_hash.register(str) -def _(val, **kwargs): - """Convert strings explicitly to bytes.""" - return [_single_digest('str', val.encode('utf-8'))] - - -@_make_hash.register(abc.Sequence) -def _(sequence_obj, **kwargs): - # unpack the list and use the elements - return [_single_digest('list(')] + list(chain.from_iterable(_make_hash(i, **kwargs) for i in sequence_obj) - ) + [_END_DIGEST] - - -@_make_hash.register(abc.Set) -def _(set_obj, **kwargs): - # turn the set objects into a list of hashes which are always sortable, - # then return a flattened list of the hashes - return [_single_digest('set(')] + list(chain.from_iterable(sorted(_make_hash(i, **kwargs) for i in set_obj)) - ) + [_END_DIGEST] - - -@_make_hash.register(abc.Mapping) -def _(mapping, **kwargs): - """Hashing arbitrary mapping containers (dict, OrderedDict) by first sorting by hashed keys""" - - def hashed_key_mapping(): - for key, value in mapping.items(): - yield (_make_hash(key, **kwargs), value) - - return [_single_digest('dict(')] + list( - chain.from_iterable( - (k_digest + _make_hash(val, **kwargs)) for k_digest, val in sorted(hashed_key_mapping(), key=itemgetter(0)) - ) - ) + [_END_DIGEST] - - -@_make_hash.register(OrderedDict) -def _(mapping, **kwargs): - """ - Hashing of OrderedDicts - - :param odict_as_unordered: hash OrderedDicts as normal dicts (mostly for testing) - """ - - if kwargs.get('odict_as_unordered', False): - return _make_hash.registry[abc.Mapping](mapping) - - return ([_single_digest('odict(')] + list( - chain.from_iterable((_make_hash(key, **kwargs) + _make_hash(val, **kwargs)) for key, val in mapping.items()) - ) + [_END_DIGEST]) - - -@_make_hash.register(numbers.Real) -def _(val, **kwargs): - """ - Before hashing a float, convert to a string (via rounding) and with a fixed number of digits after the comma. - Note that the `_single_digest` requires a bytes object so we need to encode the utf-8 string first - """ - return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] - - -@_make_hash.register(Decimal) -def _(val, **kwargs): - """ - While a decimal can be converted exactly to a string which captures all characteristics of the underlying - implementation, we also need compatibility with "equal" representations as int or float. Hence we are checking - for the exponent (which is negative if there is a fractional component, 0 otherwise) and get the same hash - as for a corresponding float or int. - """ - if val.as_tuple().exponent < 0: - return [_single_digest('float', float_to_text(val, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] - return [_single_digest('int', f'{val}'.encode('utf-8'))] - - -@_make_hash.register(numbers.Complex) -def _(val, **kwargs): - """ - In case of a complex number, use the same encoding of two floats and join them with a special symbol (a ! here). 
- """ - return [ - _single_digest( - 'complex', '{}!{}'.format( - float_to_text(val.real, sig=AIIDA_FLOAT_PRECISION), float_to_text(val.imag, sig=AIIDA_FLOAT_PRECISION) - ).encode('utf-8') - ) - ] - - -@_make_hash.register(numbers.Integral) -def _(val, **kwargs): - """get the hash of the little-endian signed long long representation of the integer""" - return [_single_digest('int', f'{val}'.encode('utf-8'))] - - -@_make_hash.register(bool) -def _(val, **kwargs): - return [_single_digest('bool', b'\x01' if val else b'\x00')] - - -@_make_hash.register(type(None)) -def _(val, **kwargs): - return [_single_digest('none')] - - -@_make_hash.register(datetime) -def _(val, **kwargs): - """hashes the little-endian rep of the float .""" - # see also https://stackoverflow.com/a/8778548 for an excellent elaboration - if val.tzinfo is None or val.utcoffset() is None: - val = val.replace(tzinfo=timezone.utc) - - timestamp = val.timestamp() - return [_single_digest('datetime', float_to_text(timestamp, sig=AIIDA_FLOAT_PRECISION).encode('utf-8'))] - - -@_make_hash.register(date) -def _(val, **kwargs): - """Hashes the string representation in ISO format of the `datetime.date` object.""" - return [_single_digest('date', val.isoformat().encode('utf-8'))] - - -@_make_hash.register(uuid.UUID) -def _(val, **kwargs): - return [_single_digest('uuid', val.bytes)] - - -@_make_hash.register(DatetimePrecision) -def _(datetime_precision, **kwargs): - """ Hashes for DatetimePrecision object - """ - return [_single_digest('dt_prec')] + list( - chain.from_iterable(_make_hash(i, **kwargs) for i in [datetime_precision.dtobj, datetime_precision.precision]) - ) + [_END_DIGEST] - - -@_make_hash.register(Folder) -def _(folder, **kwargs): - """ - Hash the content of a Folder object. The name of the folder itself is actually ignored - :param ignored_folder_content: list of filenames to be ignored for the hashing - """ - - ignored_folder_content = kwargs.get('ignored_folder_content', []) - - def folder_digests(subfolder): - """traverses the given folder and yields digests for the contained objects""" - for name, isfile in sorted(subfolder.get_content_list(only_paths=False), key=itemgetter(0)): - if name in ignored_folder_content: - continue - - if isfile: - yield _single_digest('fname', name.encode('utf-8')) - with subfolder.open(name, mode='rb') as fhandle: - yield _single_digest('fcontent', fhandle.read()) - else: - yield _single_digest('dir(', name.encode('utf-8')) - for digest in folder_digests(subfolder.get_subfolder(name)): - yield digest - yield _END_DIGEST - - return [_single_digest('folder')] + list(folder_digests(folder)) - - -def float_to_text(value, sig): - """ - Convert float to text string for computing hash. - Preseve up to N significant number given by sig. - - :param value: the float value to convert - :param sig: choose how many digits after the comma should be output - """ - if value == 0: - value = 0. # Identify value of -0. and overwrite with 0. - fmt = f'{{:.{sig}g}}' - return fmt.format(value) diff --git a/aiida/common/links.py b/aiida/common/links.py deleted file mode 100644 index 7e8b1fcb7b..0000000000 --- a/aiida/common/links.py +++ /dev/null @@ -1,124 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with utilities and data structures pertaining to links between nodes in the provenance graph.""" - -from collections import namedtuple -from enum import Enum - -from .lang import isidentifier, type_check - -__all__ = ('GraphTraversalRule', 'GraphTraversalRules', 'LinkType', 'validate_link_label') - - -class LinkType(Enum): - """A simple enum of allowed link types.""" - - CREATE = 'create' - RETURN = 'return' - INPUT_CALC = 'input_calc' - INPUT_WORK = 'input_work' - CALL_CALC = 'call_calc' - CALL_WORK = 'call_work' - - -GraphTraversalRule = namedtuple('GraphTraversalRule', ['link_type', 'direction', 'toggleable', 'default']) -"""A namedtuple that defines a graph traversal rule. - -When starting from a certain sub set of nodes, the graph traversal rules specify which links should be followed to -add adjacent nodes to finally arrive at a set of nodes that represent a valid and consistent sub graph. - -:param link_type: the `LinkType` that the rule applies to -:param direction: whether the link type should be followed backwards or forwards -:param toggleable: boolean to indicate whether the rule can be changed from the default value. If this is `False` it - means the default value can never be changed as it will result in an inconsistent graph. -:param default: boolean, the default value of the rule, if `True` means that the link type for the given direction - should be followed. -""" - - -class GraphTraversalRules(Enum): - """Graph traversal rules when deleting or exporting nodes.""" - - DEFAULT = { - 'input_calc_forward': GraphTraversalRule(LinkType.INPUT_CALC, 'forward', True, False), - 'input_calc_backward': GraphTraversalRule(LinkType.INPUT_CALC, 'backward', True, False), - 'create_forward': GraphTraversalRule(LinkType.CREATE, 'forward', True, False), - 'create_backward': GraphTraversalRule(LinkType.CREATE, 'backward', True, False), - 'return_forward': GraphTraversalRule(LinkType.RETURN, 'forward', True, False), - 'return_backward': GraphTraversalRule(LinkType.RETURN, 'backward', True, False), - 'input_work_forward': GraphTraversalRule(LinkType.INPUT_WORK, 'forward', True, False), - 'input_work_backward': GraphTraversalRule(LinkType.INPUT_WORK, 'backward', True, False), - 'call_calc_forward': GraphTraversalRule(LinkType.CALL_CALC, 'forward', True, False), - 'call_calc_backward': GraphTraversalRule(LinkType.CALL_CALC, 'backward', True, False), - 'call_work_forward': GraphTraversalRule(LinkType.CALL_WORK, 'forward', True, False), - 'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', True, False) - } - - DELETE = { - 'input_calc_forward': GraphTraversalRule(LinkType.INPUT_CALC, 'forward', False, True), - 'input_calc_backward': GraphTraversalRule(LinkType.INPUT_CALC, 'backward', False, False), - 'create_forward': GraphTraversalRule(LinkType.CREATE, 'forward', True, True), - 'create_backward': GraphTraversalRule(LinkType.CREATE, 'backward', False, True), - 'return_forward': GraphTraversalRule(LinkType.RETURN, 'forward', False, False), - 'return_backward': GraphTraversalRule(LinkType.RETURN, 'backward', False, True), - 'input_work_forward': GraphTraversalRule(LinkType.INPUT_WORK, 'forward', False, True), - 'input_work_backward': GraphTraversalRule(LinkType.INPUT_WORK, 
'backward', False, False), - 'call_calc_forward': GraphTraversalRule(LinkType.CALL_CALC, 'forward', True, True), - 'call_calc_backward': GraphTraversalRule(LinkType.CALL_CALC, 'backward', False, True), - 'call_work_forward': GraphTraversalRule(LinkType.CALL_WORK, 'forward', True, True), - 'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', False, True) - } - - EXPORT = { - 'input_calc_forward': GraphTraversalRule(LinkType.INPUT_CALC, 'forward', True, False), - 'input_calc_backward': GraphTraversalRule(LinkType.INPUT_CALC, 'backward', False, True), - 'create_forward': GraphTraversalRule(LinkType.CREATE, 'forward', False, True), - 'create_backward': GraphTraversalRule(LinkType.CREATE, 'backward', True, True), - 'return_forward': GraphTraversalRule(LinkType.RETURN, 'forward', False, True), - 'return_backward': GraphTraversalRule(LinkType.RETURN, 'backward', True, False), - 'input_work_forward': GraphTraversalRule(LinkType.INPUT_WORK, 'forward', True, False), - 'input_work_backward': GraphTraversalRule(LinkType.INPUT_WORK, 'backward', False, True), - 'call_calc_forward': GraphTraversalRule(LinkType.CALL_CALC, 'forward', False, True), - 'call_calc_backward': GraphTraversalRule(LinkType.CALL_CALC, 'backward', True, True), - 'call_work_forward': GraphTraversalRule(LinkType.CALL_WORK, 'forward', False, True), - 'call_work_backward': GraphTraversalRule(LinkType.CALL_WORK, 'backward', True, True) - } - - -def validate_link_label(link_label): - """Validate the given link label. - - Valid link labels adhere to the following restrictions: - - * Has to be a valid python identifier - * Can only contain alphanumeric characters and underscores - * Can not start or end with an underscore - - :raises TypeError: if the link label is not a string type - :raises ValueError: if the link label is invalid - """ - import re - - message = f'invalid link label `{link_label}`: should be string type but is instead: {type(link_label)}' - type_check(link_label, str, message) - - allowed_character_set = '[a-zA-Z0-9_]' - - if link_label.endswith('_'): - raise ValueError('cannot end with an underscore') - - if link_label.startswith('_'): - raise ValueError('cannot start with an underscore') - - if re.sub(allowed_character_set, '', link_label): - raise ValueError('only alphanumeric and underscores are allowed') - - if not isidentifier(link_label): - raise ValueError('not a valid python identifier') diff --git a/aiida/common/log.py b/aiida/common/log.py deleted file mode 100644 index 67a32cd7f7..0000000000 --- a/aiida/common/log.py +++ /dev/null @@ -1,247 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for all logging methods/classes that don't need the ORM.""" -from __future__ import annotations - -import collections -import contextlib -import logging -import types -from typing import cast - -__all__ = ('AIIDA_LOGGER', 'override_log_level') - -# Custom logging level, intended specifically for informative log messages reported during WorkChains. 
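The link-label rules listed in `validate_link_label` above are easiest to see in action; a brief sketch:

```python
# Sketch: link labels must be valid Python identifiers made of alphanumerics and
# underscores, without a leading or trailing underscore.
from aiida.common.links import validate_link_label

validate_link_label('output_structure')   # passes silently

try:
    validate_link_label('_private')
except ValueError as exc:
    print(exc)  # cannot start with an underscore
```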
-# We want the level between INFO(20) and WARNING(30) such that it will be logged for the default loglevel, however -# the value 25 is already reserved for SUBWARNING by the multiprocessing module. -LOG_LEVEL_REPORT = 23 - -# Add the custom log level to the :mod:`logging` module and add a corresponding report logging method. -logging.addLevelName(LOG_LEVEL_REPORT, 'REPORT') - - -def report(self: logging.Logger, msg, *args, **kwargs): - """Log a message at the ``REPORT`` level.""" - self.log(LOG_LEVEL_REPORT, msg, *args, **kwargs) - - -class AiidaLoggerType(logging.Logger): - - def report(self, msg: str, *args, **kwargs) -> None: - """Log a message at the ``REPORT`` level.""" - - -setattr(logging, 'REPORT', LOG_LEVEL_REPORT) -setattr(logging.Logger, 'report', report) - -# Convenience dictionary of available log level names and their log level integer -LOG_LEVELS = { - logging.getLevelName(logging.NOTSET): logging.NOTSET, - logging.getLevelName(logging.DEBUG): logging.DEBUG, - logging.getLevelName(logging.INFO): logging.INFO, - logging.getLevelName(LOG_LEVEL_REPORT): LOG_LEVEL_REPORT, - logging.getLevelName(logging.WARNING): logging.WARNING, - logging.getLevelName(logging.ERROR): logging.ERROR, - logging.getLevelName(logging.CRITICAL): logging.CRITICAL, -} - -AIIDA_LOGGER = cast(AiidaLoggerType, logging.getLogger('aiida')) - -CLI_ACTIVE: bool | None = None -"""Flag that is set to ``True`` if the module is imported by ``verdi`` being called.""" - -CLI_LOG_LEVEL: str | None = None -"""Set if ``verdi`` is called with ``--verbosity`` flag specified, and is set to corresponding log level.""" - - -# The default logging dictionary for AiiDA that can be used in conjunction -# with the config.dictConfig method of python's logging module -def get_logging_config(): - from aiida.manage.configuration import get_config_option - - return { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'verbose': { - 'format': '%(levelname)s %(asctime)s %(module)s %(process)d ' - '%(thread)d %(message)s', - }, - 'halfverbose': { - 'format': '%(asctime)s <%(process)d> %(name)s: [%(levelname)s] %(message)s', - 'datefmt': '%m/%d/%Y %I:%M:%S %p', - }, - 'cli': { - 'class': 'aiida.cmdline.utils.log.CliFormatter' - } - }, - 'handlers': { - 'console': { - 'class': 'logging.StreamHandler', - 'formatter': 'halfverbose', - }, - 'cli': { - 'class': 'aiida.cmdline.utils.log.CliHandler', - 'formatter': 'cli' - }, - }, - 'loggers': { - 'aiida': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiida_loglevel'), - 'propagate': True, - }, - 'verdi': { - 'handlers': ['cli'], - 'level': lambda: get_config_option('logging.verdi_loglevel'), - 'propagate': False, - }, - 'plumpy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.plumpy_loglevel'), - 'propagate': False, - }, - 'kiwipy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.kiwipy_loglevel'), - 'propagate': False, - }, - 'paramiko': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.paramiko_loglevel'), - 'propagate': False, - }, - 'alembic': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.alembic_loglevel'), - 'propagate': False, - }, - 'aio_pika': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.aiopika_loglevel'), - 'propagate': False, - }, - 'sqlalchemy': { - 'handlers': ['console'], - 'level': lambda: get_config_option('logging.sqlalchemy_loglevel'), - 'propagate': False, - 'qualname': 
'sqlalchemy.engine', - }, - 'py.warnings': { - 'handlers': ['console'], - }, - }, - } - - -def evaluate_logging_configuration(dictionary): - """Recursively evaluate the logging configuration, calling lambdas when encountered. - - This allows the configuration options that are dependent on the active profile to be loaded lazily. - - :return: evaluated logging configuration dictionary - """ - result = {} - - for key, value in dictionary.items(): - if isinstance(value, collections.abc.Mapping): - result[key] = evaluate_logging_configuration(value) - elif isinstance(value, types.LambdaType): # pylint: disable=no-member - result[key] = value() - else: - result[key] = value - - return result - - -def configure_logging(with_orm=False, daemon=False, daemon_log_file=None): - """ - Setup the logging by retrieving the LOGGING dictionary from aiida and passing it to - the python module logging.config.dictConfig. If the logging needs to be setup for the - daemon, set the argument 'daemon' to True and specify the path to the log file. This - will cause a 'daemon_handler' to be added to all the configured loggers, that is a - RotatingFileHandler that writes to the log file. - - :param with_orm: configure logging to the backend storage. - We don't configure this by default, since it would load the modules that slow the CLI - :param daemon: configure the logging for a daemon task by adding a file handler instead - of the default 'console' StreamHandler - :param daemon_log_file: absolute filepath of the log file for the RotatingFileHandler - """ - from logging.config import dictConfig - - # Evaluate the `LOGGING` configuration to resolve the lambdas that will retrieve the correct values based on the - # currently configured profile. - config = evaluate_logging_configuration(get_logging_config()) - daemon_handler_name = 'daemon_log_file' - - # Add the daemon file handler to all loggers if daemon=True - if daemon is True: - - # Daemon always needs to run with ORM enabled - with_orm = True - - if daemon_log_file is None: - raise ValueError('daemon_log_file has to be defined when configuring for the daemon') - - config.setdefault('handlers', {}) - config['handlers'][daemon_handler_name] = { - 'level': 'DEBUG', - 'formatter': 'halfverbose', - 'class': 'logging.handlers.RotatingFileHandler', - 'filename': daemon_log_file, - 'encoding': 'utf8', - 'maxBytes': 10000000, # 10 MB - 'backupCount': 10, - } - - for logger in config.get('loggers', {}).values(): - logger.setdefault('handlers', []).append(daemon_handler_name) - try: - # Remove the `console` stdout stream handler to prevent messages being duplicated in the daemon log file - logger['handlers'].remove('console') - except ValueError: - pass - - # If the ``CLI_ACTIVE`` is set, a ``verdi`` command is being executed, so we replace the ``console`` handler with - # the ``cli`` one for all loggers. - if CLI_ACTIVE is True and not daemon: - for logger in config['loggers'].values(): - handlers = logger['handlers'] - if 'console' in handlers: - handlers.remove('console') - handlers.append('cli') - - # If ``CLI_LOG_LEVEL`` is set, a ``verdi`` command is being executed with the ``--verbosity`` option. In this case - # we override the log levels of all loggers with the specified log level. 
- if CLI_LOG_LEVEL is not None: - for logger in config['loggers'].values(): - logger['level'] = CLI_LOG_LEVEL - - # Add the `DbLogHandler` if `with_orm` is `True` - if with_orm: - from aiida.manage.configuration import get_config_option - config['handlers']['db_logger'] = { - 'level': get_config_option('logging.db_loglevel'), - 'class': 'aiida.orm.utils.log.DBLogHandler' - } - config['loggers']['aiida']['handlers'].append('db_logger') - - dictConfig(config) - - -@contextlib.contextmanager -def override_log_level(level=logging.CRITICAL): - """Temporarily adjust the log-level of logger.""" - logging.disable(level=level) - try: - yield - finally: - logging.disable(level=logging.NOTSET) diff --git a/aiida/common/utils.py b/aiida/common/utils.py deleted file mode 100644 index 011ed06b7e..0000000000 --- a/aiida/common/utils.py +++ /dev/null @@ -1,615 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Miscellaneous generic utility functions and classes.""" -from datetime import datetime -import filecmp -import inspect -import io -import os -import re -import sys -from typing import Any, Dict -from uuid import UUID - -from .lang import classproperty - - -def get_new_uuid(): - """ - Return a new UUID (typically to be used for new nodes). - """ - import uuid - return str(uuid.uuid4()) - - -def validate_uuid(given_uuid: str) -> bool: - """A simple check for the UUID validity.""" - try: - parsed_uuid = UUID(given_uuid, version=4) - except ValueError: - # If not a valid UUID - return False - - # Check if there was any kind of conversion of the hex during - # the validation - return str(parsed_uuid) == given_uuid - - -def validate_list_of_string_tuples(val, tuple_length): - """ - Check that: - - 1. ``val`` is a list or tuple - 2. each element of the list: - - a. is a list or tuple - b. is of length equal to the parameter tuple_length - c. each of the two elements is a string - - Return if valid, raise ValidationError if invalid - """ - from aiida.common.exceptions import ValidationError - - err_msg = ( - 'the value must be a list (or tuple) ' - 'of length-N list (or tuples), whose elements are strings; ' - 'N={}'.format(tuple_length) - ) - - if not isinstance(val, (list, tuple)): - raise ValidationError(err_msg) - - for element in val: - if ( - not isinstance(element, (list, tuple)) or (len(element) != tuple_length) or - not all(isinstance(s, str) for s in element) - ): - raise ValidationError(err_msg) - - return True - - -def get_unique_filename(filename, list_of_filenames): - """ - Return a unique filename that can be added to the list_of_filenames. - - If filename is not in list_of_filenames, it simply returns the filename - string itself. Otherwise, it appends a integer number to the filename - (before the extension) until it finds a unique filename. - - :param filename: the filename to add - :param list_of_filenames: the list of filenames to which filename - should be added, without name duplicates - - :returns: Either filename or its modification, with a number appended - between the name and the extension. 
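Two small public helpers from the logging module removed above, `AIIDA_LOGGER` with its custom `REPORT` level and the `override_log_level` context manager, can be sketched as follows (output depends on how handlers are configured, e.g. via `configure_logging`):

```python
# Sketch: REPORT (23) sits between INFO and WARNING; override_log_level temporarily
# disables all records up to the given severity.
import logging

from aiida.common.log import AIIDA_LOGGER, override_log_level

AIIDA_LOGGER.report('emitted at the custom REPORT level (23)')

with override_log_level(logging.CRITICAL):
    AIIDA_LOGGER.warning('suppressed while inside the context manager')
```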
- """ - if filename not in list_of_filenames: - return filename - - basename, ext = os.path.splitext(filename) - - # Not optimized, but for the moment this should be fast enough - append_int = 1 - while True: - new_filename = f'{basename:s}-{append_int:d}{ext:s}' - if new_filename not in list_of_filenames: - break - append_int += 1 - return new_filename - - -def str_timedelta(dt, max_num_fields=3, short=False, negative_to_zero=False): # pylint: disable=invalid-name - """ - Given a dt in seconds, return it in a HH:MM:SS format. - - :param dt: a TimeDelta object - :param max_num_fields: maximum number of non-zero fields to show - (for instance if the number of days is non-zero, shows only - days, hours and minutes, but not seconds) - :param short: if False, print always ``max_num_fields`` fields, even - if they are zero. If True, do not print the first fields, if they - are zero. - :param negative_to_zero: if True, set dt = 0 if dt < 0. - """ - if max_num_fields <= 0: - raise ValueError('max_num_fields must be > 0') - - s_tot = dt.total_seconds() # Important to get more than 1 day, and for - # negative values. dt.seconds would give - # wrong results in these cases, see - # http://docs.python.org/2/library/datetime.html - s_tot = int(s_tot) - - if negative_to_zero: - s_tot = max(s_tot, 0) - - negative = s_tot < 0 - s_tot = abs(s_tot) - - negative_string = ' in the future' if negative else ' ago' - - # For the moment stay away from months and years, difficult to get - days, remainder = divmod(s_tot, 3600 * 24) - hours, remainder = divmod(remainder, 3600) - minutes, seconds = divmod(remainder, 60) - - all_fields = [(days, 'D'), (hours, 'h'), (minutes, 'm'), (seconds, 's')] - fields = [] - start_insert = False - counter = 0 - for idx, field in enumerate(all_fields): - if field[0] != 0: - start_insert = True - if (len(all_fields) - idx) <= max_num_fields: - start_insert = True - if start_insert: - if counter >= max_num_fields: - break - fields.append(field) - counter += 1 - - if short: - while len(fields) > 1: # at least one element has to remain - if fields[0][0] != 0: - break - fields.pop(0) # remove first element - - # Join the fields - raw_string = ':'.join(['{:02d}{}'.format(*f) for f in fields]) - - if raw_string.startswith('0'): - raw_string = raw_string[1:] - - # Return the resulting string, appending a suitable string if the time - # is negative - return f'{raw_string}{negative_string}' - - -def get_class_string(obj): - """ - Return the string identifying the class of the object (module + object name, - joined by dots). - - It works both for classes and for class instances. - """ - if inspect.isclass(obj): - return f'{obj.__module__}.{obj.__name__}' - - return f'{obj.__module__}.{obj.__class__.__name__}' - - -def get_object_from_string(class_string): - """ - Given a string identifying an object (as returned by the get_class_string - method) load and return the actual object. - """ - import importlib - - the_module, _, the_name = class_string.rpartition('.') - - return getattr(importlib.import_module(the_module), the_name) - - -def grouper(n, iterable): # pylint: disable=invalid-name - """ - Given an iterable, returns an iterable that returns tuples of groups of - elements from iterable of length n, except the last one that has the - required length to exaust iterable (i.e., there is no filling applied). 
- - :param n: length of each tuple (except the last one,that will have length - <= n - :param iterable: the iterable to divide in groups - """ - import itertools - - iterator = iter(iterable) - while True: - chunk = tuple(itertools.islice(iterator, n)) - if not chunk: - return - yield chunk - - -class ArrayCounter: - """ - A counter & a method that increments it and returns its value. - It is used in various tests. - """ - seq = None - - def __init__(self): - self.seq = -1 - - def array_counter(self): - self.seq += 1 - return self.seq - - -def are_dir_trees_equal(dir1, dir2): - """ - Compare two directories recursively. Files in each directory are - assumed to be equal if their names and contents are equal. - - @param dir1: First directory path - @param dir2: Second directory path - - @return: True if the directory trees are the same and - there were no errors while accessing the directories or files, - False otherwise. - """ - - # Directory comparison - dirs_cmp = filecmp.dircmp(dir1, dir2) - if dirs_cmp.left_only or dirs_cmp.right_only or dirs_cmp.funny_files: - return ( - False, 'Left directory: {}, right directory: {}, files only ' - 'in left directory: {}, files only in right directory: ' - '{}, not comparable files: {}'.format( - dir1, dir2, dirs_cmp.left_only, dirs_cmp.right_only, dirs_cmp.funny_files - ) - ) - - # If the directories contain the same files, compare the common files - (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) - if mismatch: - return (False, f"The following files in the directories {dir1} and {dir2} don't match: {mismatch}") - if errors: - return (False, f"The following files in the directories {dir1} and {dir2} aren't regular: {errors}") - - for common_dir in dirs_cmp.common_dirs: - new_dir1 = os.path.join(dir1, common_dir) - new_dir2 = os.path.join(dir2, common_dir) - res, msg = are_dir_trees_equal(new_dir1, new_dir2) - if not res: - return False, msg - - return True, f'The given directories ({dir1} and {dir2}) are equal' - - -class Prettifier: - """ - Class to manage prettifiers (typically for labels of kpoints - in band plots) - """ - - @classmethod - def _prettify_label_pass(cls, label): - """ - No-op prettifier, simply returns the same label - - :param label: a string to prettify - """ - return label - - @classmethod - def _prettify_label_agr(cls, label): - """ - Prettifier for XMGrace - - :param label: a string to prettify - """ - - label = ( - label - .replace('GAMMA', r'\xG\f{}') - .replace('DELTA', r'\xD\f{}') - .replace('LAMBDA', r'\xL\f{}') - .replace('SIGMA', r'\xS\f{}') - ) # yapf:disable - return re.sub(r'_(.?)', r'\\s\1\\N', label) - - @classmethod - def _prettify_label_agr_simple(cls, label): - """ - Prettifier for XMGrace (for old label names) - - :param label: a string to prettify - """ - - if label == 'G': - return r'\xG' - - return re.sub(r'(\d+)', r'\\s\1\\N', label) - - @classmethod - def _prettify_label_gnuplot(cls, label): - """ - Prettifier for Gnuplot - - :note: uses unicode, returns unicode strings (potentially, if needed) - - :param label: a string to prettify - """ - - label = ( - label - .replace('GAMMA', 'Γ') - .replace('DELTA', 'Δ') - .replace('LAMBDA', 'Λ') - .replace('SIGMA', 'Σ') - ) # yapf:disable - return re.sub(r'_(.?)', r'_{\1}', label) - - @classmethod - def _prettify_label_gnuplot_simple(cls, label): - """ - Prettifier for Gnuplot (for old label names) - - :note: uses unicode, returns unicode strings (potentially, if needed) - - :param label: a string to prettify - """ - - if 
label == 'G': - return 'Γ' - - return re.sub(r'(\d+)', r'_{\1}', label) - - @classmethod - def _prettify_label_latex(cls, label): - """ - Prettifier for matplotlib, using LaTeX syntax - - :param label: a string to prettify - """ - - label = ( - label - .replace('GAMMA', r'$\Gamma$') - .replace('DELTA', r'$\Delta$') - .replace('LAMBDA', r'$\Lambda$') - .replace('SIGMA', r'$\Sigma$') - ) # yapf:disable - label = re.sub(r'_(.?)', r'$_{\1}$', label) - - # label += r"$_{\vphantom{0}}$" - - return label - - @classmethod - def _prettify_label_latex_simple(cls, label): - """ - Prettifier for matplotlib, using LaTeX syntax (for old label names) - - :param label: a string to prettify - """ - if label == 'G': - return r'$\Gamma$' - - return re.sub(r'(\d+)', r'$_{\1}$', label) - - @classproperty - def prettifiers(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument - """ - Property that returns a dictionary that for each string associates - the function to prettify a label - - :return: a dictionary where keys are strings and values are functions - """ - return { - 'agr_seekpath': cls._prettify_label_agr, - 'agr_simple': cls._prettify_label_agr_simple, - 'latex_simple': cls._prettify_label_latex_simple, - 'latex_seekpath': cls._prettify_label_latex, - 'gnuplot_simple': cls._prettify_label_gnuplot_simple, - 'gnuplot_seekpath': cls._prettify_label_gnuplot, - 'pass': cls._prettify_label_pass, - } - - @classmethod - def get_prettifiers(cls): - """ - Return a list of valid prettifier strings - - :return: a list of strings - """ - return sorted(cls.prettifiers.keys()) - - def __init__(self, format): # pylint: disable=redefined-builtin - """ - Create a class to pretttify strings of a given format - - :param format: a string with the format to use to prettify. - Valid formats are obtained from self.prettifiers - """ - if format is None: - format = 'pass' - - try: - self._prettifier_f = self.prettifiers[format] # pylint: disable=unsubscriptable-object - except KeyError: - raise ValueError(f"Unknown prettifier format {format}; valid formats: {', '.join(self.get_prettifiers())}") - - def prettify(self, label): - """ - Prettify a label using the format passed in the initializer - - :param label: the string to prettify - :return: a prettified string - """ - return self._prettifier_f(label) - - -def prettify_labels(labels, format=None): # pylint: disable=redefined-builtin - """ - Prettify label for typesetting in various formats - - :param labels: a list of length-2 tuples, in the format(position, label) - :param format: a string with the format for the prettifier (e.g. 'agr', - 'matplotlib', ...) - :return: the same list as labels, but with the second value possibly replaced - with a prettified version that typesets nicely in the selected format - """ - prettifier = Prettifier(format) - - return [(pos, prettifier.prettify(label)) for pos, label in labels] - - -def join_labels(labels, join_symbol='|', threshold=1.e-6): - """ - Join labels with a joining symbol when they are very close - - :param labels: a list of length-2 tuples, in the format(position, label) - :param join_symbol: the string to use to join different paths. 
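A short sketch of the `Prettifier` machinery above, turning raw high-symmetry point labels into LaTeX-ready strings (the chosen format string is one of those listed in `prettifiers`):

```python
# Sketch: prettify k-point labels for a matplotlib/LaTeX band plot.
from aiida.common.utils import Prettifier, prettify_labels

prettifier = Prettifier('latex_seekpath')
print(prettifier.prettify('GAMMA'))   # -> $\Gamma$

labels = [(0.0, 'GAMMA'), (0.5, 'X_1')]
print(prettify_labels(labels, format='latex_seekpath'))
# -> [(0.0, '$\\Gamma$'), (0.5, 'X$_{1}$')]
```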
By default, a pipe - :param threshold: the threshold to decide if two float values are the same and should - be joined - :return: the same list as labels, but with the second value possibly replaced - with strings joined when close enough - """ - if labels: - new_labels = [list(labels[0])] - # modify labels when in overlapping position - j = 0 - for i in range(1, len(labels)): - if abs(labels[i][0] - labels[i - 1][0]) < threshold: - new_labels[j][1] += join_symbol + labels[i][1] - else: - new_labels.append(list(labels[i])) - j += 1 - else: - new_labels = [] - - return new_labels - - -def strip_prefix(full_string, prefix): - """ - Strip the prefix from the given string and return it. If the prefix is not present - the original string will be returned unaltered - - :param full_string: the string from which to remove the prefix - :param prefix: the prefix to remove - :return: the string with prefix removed - """ - if full_string.startswith(prefix): - return full_string.rsplit(prefix)[1] - - return full_string - - -class Capturing: - """ - This class captures stdout and returns it - (as a list, split by lines). - - Note: if you raise a SystemExit, you have to catch it outside. - E.g., in our tests, this works:: - - import sys - with self.assertRaises(SystemExit): - with Capturing() as output: - sys.exit() - - But out of the testing environment, the code instead just exits. - - To use it, access the obj.stdout_lines, or just iterate over the object - - :param capture_stderr: if True, also captures sys.stderr. To access the - lines, use obj.stderr_lines. If False, obj.stderr_lines is None. - """ - - # pylint: disable=attribute-defined-outside-init - - def __init__(self, capture_stderr=False): - """Construct a new instance.""" - self.stdout_lines = [] - super().__init__() - - self._capture_stderr = capture_stderr - if self._capture_stderr: - self.stderr_lines = [] - else: - self.stderr_lines = None - - def __enter__(self): - """Enter the context where all output is captured.""" - self._stdout = sys.stdout - self._stringioout = io.StringIO() - sys.stdout = self._stringioout - if self._capture_stderr: - self._stderr = sys.stderr - self._stringioerr = io.StringIO() - sys.stderr = self._stringioerr - return self - - def __exit__(self, *args): - """Exit the context where all output is captured.""" - self.stdout_lines.extend(self._stringioout.getvalue().splitlines()) - sys.stdout = self._stdout - del self._stringioout # free up some memory - if self._capture_stderr: - self.stderr_lines.extend(self._stringioerr.getvalue().splitlines()) - sys.stderr = self._stderr - del self._stringioerr # free up some memory - - def __str__(self): - return str(self.stdout_lines) - - def __iter__(self): - return iter(self.stdout_lines) - - -class ErrorAccumulator: - """ - Allows to run a number of functions and collect all the errors they raise - - This allows to validate multiple things and tell the user about all the - errors encountered at once. Works best if the individual functions do not depend on each other. - - Does not allow to trace the stack of each error, therefore do not use for debugging, but for - semantical checking with user friendly error messages. 
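The `Capturing` context manager above is used exactly as its docstring suggests; a minimal sketch:

```python
# Sketch: capture everything printed to stdout inside the with-block.
from aiida.common.utils import Capturing

with Capturing() as output:
    print('hello from inside')

assert list(output) == ['hello from inside']
print(output.stdout_lines)
```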
- """ - - def __init__(self, *error_cls): - self.error_cls = error_cls - self.errors = {k: [] for k in self.error_cls} - - def run(self, function, *args, **kwargs): - try: - function(*args, **kwargs) - except self.error_cls as err: - self.errors[err.__class__].append(err) - - def success(self): - return bool(not any(self.errors.values())) - - def result(self, raise_error=Exception): - if raise_error: - self.raise_errors(raise_error) - return self.success(), self.errors - - def raise_errors(self, raise_cls): - if not self.success(): - raise raise_cls(f'The following errors were encountered: {self.errors}') - - -class DatetimePrecision: - """ - A simple class which stores a datetime object with its precision. No - internal check is done (cause itis not possible). - - precision: 1 (only full date) - 2 (date plus hour) - 3 (date + hour + minute) - 4 (dare + hour + minute +second) - """ - - def __init__(self, dtobj, precision): - """ Constructor to check valid datetime object and precision """ - - if not isinstance(dtobj, datetime): - raise TypeError('dtobj argument has to be a datetime object') - - if not isinstance(precision, int): - raise TypeError('precision argument has to be an integer') - - self.dtobj = dtobj - self.precision = precision diff --git a/aiida/common/warnings.py b/aiida/common/warnings.py deleted file mode 100644 index 32b0793b48..0000000000 --- a/aiida/common/warnings.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Define warnings that can be thrown by AiiDA.""" -import os -import warnings - - -class AiidaDeprecationWarning(Warning): - """ - Class for AiiDA deprecations. - - It does *not* inherit, on purpose, from `DeprecationWarning` as - this would be filtered out by default. - Enabled by default, you can disable it by running in the shell:: - - verdi config warnings.showdeprecations False - """ - - -class AiidaEntryPointWarning(Warning): - """ - Class for warnings concerning AiiDA entry points. - """ - - -class AiidaTestWarning(Warning): - """ - Class for warnings concerning the AiiDA testing infrastructure. - """ - - -def warn_deprecation(message: str, version: int, stacklevel=2) -> None: - """Warns about a deprecation for a future aiida-core version. - - Warnings are activated if the `AIIDA_WARN_v{major}` environment variable is set to `True`. - - :param message: the message to be printed - :param version: the major version number of the future version - :param stacklevel: the stack level at which the warning is issued - """ - if os.environ.get(f'AIIDA_WARN_v{version}'): - message = f'{message} (this will be removed in v{version})' - warnings.warn(message, AiidaDeprecationWarning, stacklevel=stacklevel) diff --git a/aiida/engine/__init__.py b/aiida/engine/__init__.py deleted file mode 100644 index 68b6a6d626..0000000000 --- a/aiida/engine/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with all the internals that make up the engine of `aiida-core`.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .daemon import * -from .exceptions import * -from .launch import * -from .persistence import * -from .processes import * -from .runners import * -from .utils import * - -__all__ = ( - 'AiiDAPersister', - 'Awaitable', - 'AwaitableAction', - 'AwaitableTarget', - 'BaseRestartWorkChain', - 'CalcJob', - 'CalcJobImporter', - 'CalcJobOutputPort', - 'CalcJobProcessSpec', - 'DaemonClient', - 'ExitCode', - 'ExitCodesNamespace', - 'FunctionProcess', - 'InputPort', - 'InterruptableFuture', - 'JobManager', - 'JobsList', - 'ObjectLoader', - 'OutputPort', - 'PORT_NAMESPACE_SEPARATOR', - 'PastException', - 'PortNamespace', - 'Process', - 'ProcessBuilder', - 'ProcessBuilderNamespace', - 'ProcessFuture', - 'ProcessHandlerReport', - 'ProcessSpec', - 'ProcessState', - 'Runner', - 'ToContext', - 'WithNonDb', - 'WithSerialize', - 'WorkChain', - 'append_', - 'assign_', - 'calcfunction', - 'construct_awaitable', - 'get_daemon_client', - 'get_object_loader', - 'if_', - 'interruptable_task', - 'is_process_function', - 'process_handler', - 'return_', - 'run', - 'run_get_node', - 'run_get_pk', - 'submit', - 'while_', - 'workfunction', -) - -# yapf: enable diff --git a/aiida/engine/daemon/__init__.py b/aiida/engine/daemon/__init__.py deleted file mode 100644 index 4012ec5d62..0000000000 --- a/aiida/engine/daemon/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with resources for the daemon.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .client import * - -__all__ = ( - 'DaemonClient', - 'get_daemon_client', -) - -# yapf: enable diff --git a/aiida/engine/daemon/client.py b/aiida/engine/daemon/client.py deleted file mode 100644 index d5f804eee4..0000000000 --- a/aiida/engine/daemon/client.py +++ /dev/null @@ -1,738 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Client to interact with the daemon.""" -from __future__ import annotations - -import contextlib -import enum -import os -import pathlib -import shutil -import socket -import subprocess -import sys -import tempfile -import time -import typing as t -from typing import TYPE_CHECKING - -import psutil - -from aiida.common.exceptions import AiidaException, ConfigurationError -from aiida.common.lang import type_check -from aiida.common.log import AIIDA_LOGGER -from aiida.manage.configuration import get_config, get_config_option -from aiida.manage.configuration.profile import Profile -from aiida.manage.manager import get_manager - -if TYPE_CHECKING: - from circus.client import CircusClient - -LOGGER = AIIDA_LOGGER.getChild('engine.daemon.client') - -VERDI_BIN = shutil.which('verdi') -# Recent versions of virtualenv create the environment variable VIRTUAL_ENV -VIRTUALENV = os.environ.get('VIRTUAL_ENV', None) - -__all__ = ('DaemonClient', 'get_daemon_client') - - -class ControllerProtocol(enum.Enum): - """The protocol to use for the controller of the Circus daemon.""" - - IPC = 0 - TCP = 1 - - -class DaemonException(AiidaException): - """Base class for exceptions related to the daemon.""" - - -class DaemonNotRunningException(DaemonException): - """Raised when a connection to the daemon is attempted but it is not running.""" - - -class DaemonTimeoutException(DaemonException): - """Raised when a connection to the daemon is attempted but it times out.""" - - -class DaemonStalePidException(DaemonException): - """Raised when a connection to the daemon is attempted but it fails and the PID file appears to be stale.""" - - -def get_daemon_client(profile_name: str | None = None) -> 'DaemonClient': - """Return the daemon client for the given profile or the current profile if not specified. - - :param profile_name: Optional profile name to load. - :return: The daemon client. - - :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found. - :raises aiida.common.ProfileConfigurationError: if the given profile does not exist. - """ - profile = get_manager().load_profile(profile_name) - return DaemonClient(profile) - - -class DaemonClient: # pylint: disable=too-many-public-methods - """Client to interact with the daemon.""" - - _DAEMON_NAME = 'aiida-{name}' - _ENDPOINT_PROTOCOL = ControllerProtocol.IPC - - def __init__(self, profile: Profile): - """Construct an instance for a given profile. - - :param profile: The profile instance. - """ - type_check(profile, Profile) - config = get_config() - self._profile = profile - self._socket_directory: str | None = None - self._daemon_timeout: int = config.get_option('daemon.timeout', scope=profile.name) - - @property - def profile(self) -> Profile: - return self._profile - - @property - def daemon_name(self) -> str: - """Get the daemon name which is tied to the profile name.""" - return self._DAEMON_NAME.format(name=self.profile.name) - - @property - def _verdi_bin(self) -> str: - """Return the absolute path to the ``verdi`` binary. - - :raises ConfigurationError: If the path to ``verdi`` could not be found - """ - if VERDI_BIN is None: - raise ConfigurationError( - "Unable to find 'verdi' in the path. 
Make sure that you are working " - "in a virtual environment, or that at least the 'verdi' executable is on the PATH" - ) - - return VERDI_BIN - - def cmd_start_daemon(self, number_workers: int = 1, foreground: bool = False) -> list[str]: - """Return the command to start the daemon. - - :param number_workers: Number of daemon workers to start. - :param foreground: Whether to launch the subprocess in the background or not. - """ - command = [self._verdi_bin, '-p', self.profile.name, 'daemon', 'start-circus', str(number_workers)] - - if foreground: - command.append('--foreground') - - return command - - @property - def cmd_start_daemon_worker(self) -> list[str]: - """Return the command to start a daemon worker process.""" - return [self._verdi_bin, '-p', self.profile.name, 'daemon', 'worker'] - - @property - def loglevel(self) -> str: - return get_config_option('logging.circus_loglevel') - - @property - def virtualenv(self) -> str | None: - return VIRTUALENV - - @property - def circus_log_file(self) -> str: - return self.profile.filepaths['circus']['log'] - - @property - def circus_pid_file(self) -> str: - return self.profile.filepaths['circus']['pid'] - - @property - def circus_port_file(self) -> str: - return self.profile.filepaths['circus']['port'] - - @property - def circus_socket_file(self) -> str: - return self.profile.filepaths['circus']['socket']['file'] - - @property - def circus_socket_endpoints(self) -> dict[str, str]: - return self.profile.filepaths['circus']['socket'] - - @property - def daemon_log_file(self) -> str: - return self.profile.filepaths['daemon']['log'] - - @property - def daemon_pid_file(self) -> str: - return self.profile.filepaths['daemon']['pid'] - - def get_circus_port(self) -> int: - """Retrieve the port for the circus controller, which should be written to the circus port file. - - If the daemon is running, the port file should exist and contain the port to which the controller is connected. - If it cannot be read, a RuntimeError will be thrown. If the daemon is not running, an available port will be - requested from the operating system, written to the port file and returned. - - :return: The port for the circus controller. - """ - if self.is_daemon_running: - try: - with open(self.circus_port_file, 'r', encoding='utf8') as fhandle: - return int(fhandle.read().strip()) - except (ValueError, IOError): - raise RuntimeError('daemon is running so port file should have been there but could not read it') - else: - port = self.get_available_port() - with open(self.circus_port_file, 'w', encoding='utf8') as fhandle: - fhandle.write(str(port)) - - return port - - @staticmethod - def get_env() -> dict[str, str]: - """Return the environment for this current process. - - This method is used to pass variables from the environment of the current process to a subprocess that is - spawned when the daemon or a daemon worker is started. - - It replicates the ``PATH``, ``PYTHONPATH` and the ``AIIDA_PATH`` environment variables. The ``PYTHONPATH`` - variable ensures that all Python modules that can be imported by the parent process, are also importable by - the subprocess. The ``AIIDA_PATH`` variable ensures that the subprocess will use the same AiiDA configuration - directory as used by the current process. 
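For orientation, this is roughly how the client above is obtained and how the daemon start command is assembled; a sketch only, assuming a configured profile named `my_profile` exists on the machine:

```python
from aiida.engine.daemon.client import get_daemon_client

# Load the daemon client for a specific profile (None would use the current default)
client = get_daemon_client('my_profile')

print(client.daemon_name)        # e.g. 'aiida-my_profile'
print(client.is_daemon_running)  # True if the circus PID file is present

# Command and environment that would be used to spawn the daemon subprocess
command = client.cmd_start_daemon(number_workers=2, foreground=False)
env = client.get_env()
print(command)  # ['/path/to/verdi', '-p', 'my_profile', 'daemon', 'start-circus', '2']
```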
- """ - env = os.environ.copy() - env['PATH'] = ':'.join([os.path.dirname(sys.executable), env['PATH']]) - env['PYTHONPATH'] = ':'.join(sys.path) - env['AIIDA_PATH'] = get_config().dirpath - env['PYTHONUNBUFFERED'] = 'True' - return env - - def get_circus_socket_directory(self) -> str: - """Retrieve the absolute path of the directory where the circus sockets are stored. - - If the daemon is running, the sockets file should exist and contain the absolute path of the directory that - contains the sockets of the circus endpoints. If it cannot be read, a ``RuntimeError`` will be thrown. If the - daemon is not running, a temporary directory will be created and its path will be written to the sockets file - and returned. - - .. note:: A temporary folder needs to be used for the sockets because UNIX limits the filepath length to - 107 bytes. Placing the socket files in the AiiDA config folder might seem like the more logical choice - but that folder can be placed in an arbitrarily nested directory, the socket filename will exceed the - limit. The solution is therefore to always store them in the temporary directory of the operation system - whose base path is typically short enough as to not exceed the limit - - :return: The absolute path of directory to write the sockets to. - """ - if self.is_daemon_running: - try: - with open(self.circus_socket_file, 'r', encoding='utf8') as fhandle: - content = fhandle.read().strip() - return content - except (ValueError, IOError): - raise RuntimeError('daemon is running so sockets file should have been there but could not read it') - else: - - # The SOCKET_DIRECTORY is already set, a temporary directory was already created and the same should be used - if self._socket_directory is not None: - return self._socket_directory - - socket_dir_path = tempfile.mkdtemp() - with open(self.circus_socket_file, 'w', encoding='utf8') as fhandle: - fhandle.write(str(socket_dir_path)) - - self._socket_directory = socket_dir_path - return socket_dir_path - - def get_daemon_pid(self) -> int | None: - """Get the daemon pid which should be written in the daemon pid file specific to the profile. - - :return: The pid of the circus daemon process or None if not found. - """ - if os.path.isfile(self.circus_pid_file): - try: - with open(self.circus_pid_file, 'r', encoding='utf8') as fhandle: - content = fhandle.read().strip() - return int(content) - except (ValueError, IOError): - return None - else: - return None - - @property - def is_daemon_running(self) -> bool: - """Return whether the daemon is running, which is determined by seeing if the daemon pid file is present. - - :return: True if daemon is running, False otherwise. - """ - return self.get_daemon_pid() is not None - - def delete_circus_socket_directory(self) -> None: - """Attempt to delete the directory used to store the circus endpoint sockets. - - Will not raise if the directory does not exist. - """ - directory = self.get_circus_socket_directory() - - try: - shutil.rmtree(directory) - except OSError as exception: - if exception.errno == 2: - pass - else: - raise - - @classmethod - def get_available_port(cls): - """Get an available port from the operating system. - - :return: A currently available port. - """ - open_socket = socket.socket() - open_socket.bind(('', 0)) - return open_socket.getsockname()[1] - - def get_controller_endpoint(self): - """Get the endpoint string for the circus controller. 
- - For the IPC protocol a profile specific socket will be used, whereas for the TCP protocol an available port will - be found and saved in the profile specific port file. - - :return: The endpoint string. - """ - if self._ENDPOINT_PROTOCOL == ControllerProtocol.IPC: - endpoint = self.get_ipc_endpoint('controller') - elif self._ENDPOINT_PROTOCOL == ControllerProtocol.TCP: - endpoint = self.get_tcp_endpoint(self.get_circus_port()) - else: - raise ValueError(f'invalid controller protocol {self._ENDPOINT_PROTOCOL}') - - return endpoint - - def get_pubsub_endpoint(self): - """Get the endpoint string for the circus pubsub endpoint. - - For the IPC protocol a profile specific socket will be used, whereas for the TCP protocol any available port - will be used. - - :return: The endpoint string. - """ - if self._ENDPOINT_PROTOCOL == ControllerProtocol.IPC: - endpoint = self.get_ipc_endpoint('pubsub') - elif self._ENDPOINT_PROTOCOL == ControllerProtocol.TCP: - endpoint = self.get_tcp_endpoint() - else: - raise ValueError(f'invalid controller protocol {self._ENDPOINT_PROTOCOL}') - - return endpoint - - def get_stats_endpoint(self): - """Get the endpoint string for the circus stats endpoint. - - For the IPC protocol a profile specific socket will be used, whereas for the TCP protocol any available port - will be used. - - :return: The endpoint string. - """ - if self._ENDPOINT_PROTOCOL == ControllerProtocol.IPC: - endpoint = self.get_ipc_endpoint('stats') - elif self._ENDPOINT_PROTOCOL == ControllerProtocol.TCP: - endpoint = self.get_tcp_endpoint() - else: - raise ValueError(f'invalid controller protocol {self._ENDPOINT_PROTOCOL}') - - return endpoint - - def get_ipc_endpoint(self, endpoint): - """Get the ipc endpoint string for a circus daemon endpoint for a given socket. - - :param endpoint: The circus endpoint for which to return a socket. - :return: The ipc endpoint string. - """ - filepath = self.get_circus_socket_directory() - filename = self.circus_socket_endpoints[endpoint] - template = 'ipc://{filepath}/{filename}' - endpoint = template.format(filepath=filepath, filename=filename) - - return endpoint - - def get_tcp_endpoint(self, port=None): - """Get the tcp endpoint string for a circus daemon endpoint. - - If the port is unspecified, the operating system will be asked for a currently available port. - - :param port: A port to use for the endpoint. - :return: The tcp endpoint string. - """ - if port is None: - port = self.get_available_port() - - template = 'tcp://127.0.0.1:{port}' - endpoint = template.format(port=port) - - return endpoint - - @contextlib.contextmanager - def get_client(self, timeout: int | None = None) -> 'CircusClient': - """Return an instance of the CircusClient. - - The endpoint is defined by the controller endpoint, which used the port that was written to the port file upon - starting of the daemon. - - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: CircusClient instance - """ - from circus.client import CircusClient - - try: - client = CircusClient(endpoint=self.get_controller_endpoint(), timeout=timeout or self._daemon_timeout) - yield client - finally: - client.stop() - - def call_client(self, command: dict[str, t.Any], timeout: int | None = None) -> dict[str, t.Any]: - """Call the client with a specific command. - - Will check whether the daemon is running first by checking for the pid file. 
When the pid is found yet the call - still fails with a timeout, this means the daemon was actually not running and it was terminated unexpectedly - causing the pid file to not be cleaned up properly. - - :param command: Command to call the circus client with. - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: The result of the circus client call. - :raises DaemonException: If the daemon is not running or cannot be reached. - :raises DaemonTimeoutException: If the connection to the daemon timed out. - :raises DaemonException: If the connection to the daemon failed for any other reason. - """ - from circus.exc import CallError - - try: - with self.get_client(timeout=timeout) as client: - result = client.call(command) - except CallError as exception: - if self.get_daemon_pid() is None: - raise DaemonNotRunningException('The daemon is not running.') from exception - - if self._is_pid_file_stale: - raise DaemonStalePidException( - 'The daemon could not be reached, seemingly because of a stale PID file. Either stop or start the ' - 'daemon to remove it and restore the daemon to a functional state.' - ) from exception - - if str(exception) == 'Timed out.': - raise DaemonTimeoutException('Connection to the daemon timed out.') from exception - - raise DaemonException('Connection to the daemon failed.') from exception - - return result - - def get_status(self, timeout: int | None = None) -> dict[str, t.Any]: - """Return the status of the daemon. - - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :returns: The client call response. If successful, will contain 'pid' key. - """ - command = {'command': 'status', 'properties': {'name': self.daemon_name}} - response = self.call_client(command, timeout=timeout) - response['pid'] = self.get_daemon_pid() - return response - - def get_numprocesses(self, timeout: int | None = None) -> dict[str, t.Any]: - """Get the number of running daemon processes. - - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: The client call response. If successful, will contain 'numprocesses' key. - """ - command = {'command': 'numprocesses', 'properties': {'name': self.daemon_name}} - return self.call_client(command, timeout=timeout) - - def get_worker_info(self, timeout: int | None = None) -> dict[str, t.Any]: - """Get workers statistics for this daemon. - - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: The client call response. If successful, will contain 'info' key. - """ - command = {'command': 'stats', 'properties': {'name': self.daemon_name}} - return self.call_client(command, timeout=timeout) - - def get_daemon_info(self, timeout: int | None = None) -> dict[str, t.Any]: - """Get statistics about this daemon itself. - - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: The client call response. If successful, will contain 'info' key. 
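To illustrate the query methods above, a short sketch of inspecting the daemon through the client; the keys shown follow the docstrings of `get_status` and `get_worker_info`:

```python
from aiida.engine.daemon.client import DaemonNotRunningException, get_daemon_client

client = get_daemon_client()

try:
    status = client.get_status(timeout=5)        # response includes a 'pid' key on success
    workers = client.get_worker_info(timeout=5)  # response includes an 'info' key with per-worker stats
    print(status['pid'], workers.get('info'))
except DaemonNotRunningException:
    print('The daemon is not running for this profile.')
```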
- """ - command = {'command': 'dstats', 'properties': {}} - return self.call_client(command, timeout=timeout) - - def increase_workers(self, number: int, timeout: int | None = None) -> dict[str, t.Any]: - """Increase the number of workers. - - :param number: The number of workers to add. - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: The client call response. - """ - command = {'command': 'incr', 'properties': {'name': self.daemon_name, 'nb': number}} - return self.call_client(command, timeout=timeout) - - def decrease_workers(self, number: int, timeout: int | None = None) -> dict[str, t.Any]: - """Decrease the number of workers. - - :param number: The number of workers to remove. - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :return: The client call response. - """ - command = {'command': 'decr', 'properties': {'name': self.daemon_name, 'nb': number}} - return self.call_client(command, timeout=timeout) - - def start_daemon( - self, number_workers: int = 1, foreground: bool = False, wait: bool = True, timeout: int | None = None - ) -> None: - """Start the daemon in a sub process running in the background. - - :param number_workers: Number of daemon workers to start. - :param foreground: Whether to launch the subprocess in the background or not. - :param wait: Boolean to indicate whether to wait for the result of the command. - :param timeout: Optional timeout to set for trying to reach the circus daemon after the subprocess has started. - Default is set on the client upon instantiation taken from the ``daemon.timeout`` config option. - :raises DaemonException: If the command to start the daemon subprocess excepts. - :raises DaemonTimeoutException: If the daemon starts but then is unresponsive or in an unexpected state. - """ - self._clean_potentially_stale_pid_file() - - env = self.get_env() - command = self.cmd_start_daemon(number_workers, foreground) - timeout = timeout or self._daemon_timeout - - try: - subprocess.check_output(command, env=env, stderr=subprocess.STDOUT) # pylint: disable=unexpected-keyword-arg - except subprocess.CalledProcessError as exception: - raise DaemonException('The daemon failed to start.') from exception - - if not wait: - return - - self._await_condition( - lambda: self.is_daemon_running, - DaemonTimeoutException(f'The daemon failed to start or is unresponsive after {timeout} seconds.'), - timeout=timeout, - ) - - def restart_daemon(self, wait: bool = True, timeout: int | None = None) -> dict[str, t.Any]: - """Restart the daemon. - - :param wait: Boolean to indicate whether to wait for the result of the command. - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :returns: The client call response. - :raises DaemonException: If the daemon is not running or cannot be reached. - :raises DaemonTimeoutException: If the connection to the daemon timed out. - :raises DaemonException: If the connection to the daemon failed for any other reason. 
- """ - command = {'command': 'restart', 'properties': {'name': self.daemon_name, 'waiting': wait}} - return self.call_client(command, timeout=timeout) - - def stop_daemon(self, wait: bool = True, timeout: int | None = None) -> dict[str, t.Any]: - """Stop the daemon. - - :param wait: Boolean to indicate whether to wait for the result of the command. - :param timeout: Optional timeout to set for trying to reach the circus daemon. Default is set on the client upon - instantiation taken from the ``daemon.timeout`` config option. - :returns: The client call response. - :raises DaemonException: If the daemon is not running or cannot be reached. - :raises DaemonTimeoutException: If the connection to the daemon timed out. - :raises DaemonException: If the connection to the daemon failed for any other reason. - """ - self._clean_potentially_stale_pid_file() - - command = {'command': 'quit', 'properties': {'waiting': wait}} - response = self.call_client(command, timeout=timeout) - - if self._ENDPOINT_PROTOCOL == ControllerProtocol.IPC: - self.delete_circus_socket_directory() - - return response - - def _clean_potentially_stale_pid_file(self) -> None: - """Check the daemon PID file and delete it if it is likely to be stale.""" - try: - self._check_pid_file() - except DaemonException as exception: - pathlib.Path(self.circus_pid_file).unlink(missing_ok=True) - LOGGER.warning(f'Deleted apparently stale daemon PID file: {exception}') - - @property - def _is_pid_file_stale(self) -> bool: - """Return whether the daemon PID file is likely to be stale. - - :returns: ``True`` if the PID file is likely to be stale, ``False`` otherwise. - """ - try: - self._check_pid_file() - except DaemonException: - return True - - return False - - def _check_pid_file(self) -> None: - """Check that the daemon's PID file is not stale. - - Checks if the PID contained in the circus PID file matches a valid running ``verdi`` process. The PID file is - considered stale if any of the following conditions are true: - - * The process with the given PID no longer exists - * The process name does not match the command of the circus daemon - * The process username does not match the username of this Python interpreter - - In the latter two cases, the process with the PID of the PID file exists, but it is very likely that it is not - the original process that created the PID file, since the command or user is different, indicating the original - process died and the PID was recycled for a new process. - - The PID file can got stale if a system is shut down suddenly and so the process is killed but the PID file is - not deleted in time. When the `get_daemon_pid()` method is called, an incorrect PID is returned. Alternatively, - another process or the user may have meddled with the PID file in some way, corrupting it. - - :raises DaemonException: If the PID file is likely to be stale. - """ - pid = self.get_daemon_pid() - - if pid is None: - return - - try: - process = psutil.Process(pid) - - # The circus daemon process can appear as ``start-circus`` or ``circusd``. 
See this issue comment for - # details: https://github.com/aiidateam/aiida-core/issues/5336#issuecomment-1376093322 - if not any(cmd in process.cmdline() for cmd in ['start-circus', 'circusd']): - raise DaemonException( - f'process command `{process.cmdline()}` of PID `{pid}` does not match expected AiiDA daemon command' - ) - - process_user = process.username() - current_user = psutil.Process().username() - - if process_user != current_user: - raise DaemonException( - f'process user `{process_user}` of PID `{pid}` does not match current user `{current_user}`' - ) - - except (psutil.AccessDenied, psutil.NoSuchProcess, DaemonException) as exception: - raise DaemonException(exception) from exception - - @staticmethod - def _await_condition(condition: t.Callable, exception: Exception, timeout: int = 5, interval: float = 0.1): - """Await a condition to evaluate to ``True`` or raise the exception if the timeout is reached. - - :param condition: A callable that is waited for to return ``True``. - :param exception: Raise this exception if ``condition`` does not return ``True`` after ``timeout`` seconds. - :param timeout: Wait this number of seconds for ``condition`` to return ``True`` before raising. - :param interval: The time in seconds to wait between invocations of ``condition``. - :raises: The exception provided by ``exception`` if timeout is reached. - """ - start_time = time.time() - - while not condition(): - - time.sleep(interval) - - if time.time() - start_time > timeout: - raise exception - - def _start_daemon(self, number_workers: int = 1, foreground: bool = False) -> None: - """Start the daemon. - - .. warning:: This will daemonize the current process and put it in the background. It is most likely not what - you want to call if you want to start the daemon from the Python API. Instead you probably will want to use - the :meth:`aiida.engine.daemon.client.DaemonClient.start_daemon` function instead. - - :param number_workers: Number of daemon workers to start. - :param foreground: Whether to launch the subprocess in the background or not. 
- """ - from circus import get_arbiter - from circus import logger as circus_logger - from circus.circusd import daemonize - from circus.pidfile import Pidfile - from circus.util import check_future_exception_and_log, configure_logger - - if foreground and number_workers > 1: - raise ValueError('can only run a single worker when running in the foreground') - - loglevel = self.loglevel - logoutput = '-' - - if not foreground: - logoutput = self.circus_log_file - - arbiter_config = { - 'controller': self.get_controller_endpoint(), - 'pubsub_endpoint': self.get_pubsub_endpoint(), - 'stats_endpoint': self.get_stats_endpoint(), - 'logoutput': logoutput, - 'loglevel': loglevel, - 'debug': False, - 'statsd': True, - 'pidfile': self.circus_pid_file, - 'watchers': [{ - 'cmd': ' '.join(self.cmd_start_daemon_worker), - 'name': self.daemon_name, - 'numprocesses': number_workers, - 'virtualenv': self.virtualenv, - 'copy_env': True, - 'stdout_stream': { - 'class': 'FileStream', - 'filename': self.daemon_log_file, - }, - 'stderr_stream': { - 'class': 'FileStream', - 'filename': self.daemon_log_file, - }, - 'env': self.get_env(), - }] - } # yapf: disable - - if not foreground: - daemonize() - - arbiter = get_arbiter(**arbiter_config) - pidfile = Pidfile(arbiter.pidfile) - pidfile.create(os.getpid()) - - # Configure the logger - loggerconfig = None - loggerconfig = loggerconfig or arbiter.loggerconfig or None - configure_logger(circus_logger, loglevel, logoutput, loggerconfig) - - # Main loop - should_restart = True - - while should_restart: - try: - future = arbiter.start() - should_restart = False - if check_future_exception_and_log(future) is None: - should_restart = arbiter._restarting # pylint: disable=protected-access - except Exception as exception: - # Emergency stop - arbiter.loop.run_sync(arbiter._emergency_stop) # pylint: disable=protected-access - raise exception - except KeyboardInterrupt: - pass - finally: - arbiter = None - if pidfile is not None: - pidfile.unlink() diff --git a/aiida/engine/daemon/execmanager.py b/aiida/engine/daemon/execmanager.py deleted file mode 100644 index 74468eb34c..0000000000 --- a/aiida/engine/daemon/execmanager.py +++ /dev/null @@ -1,616 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -This file contains the main routines to submit, check and retrieve calculation -results. These are general and contain only the main logic; where appropriate, -the routines make reference to the suitable plugins for all -plugin-specific operations. 
-""" -from __future__ import annotations - -from collections.abc import Mapping -from logging import LoggerAdapter -import os -import pathlib -import shutil -from tempfile import NamedTemporaryFile -from typing import Any, List -from typing import Mapping as MappingType -from typing import Optional, Tuple, Union - -from aiida.common import AIIDA_LOGGER, exceptions -from aiida.common.datastructures import CalcInfo -from aiida.common.folders import SandboxFolder -from aiida.common.links import LinkType -from aiida.engine.processes.exit_code import ExitCode -from aiida.manage.configuration import get_config_option -from aiida.orm import CalcJobNode, Code, FolderData, Node, PortableCode, RemoteData, load_node -from aiida.orm.utils.log import get_dblogger_extra -from aiida.repository.common import FileType -from aiida.schedulers.datastructures import JobState -from aiida.transports import Transport - -REMOTE_WORK_DIRECTORY_LOST_FOUND = 'lost+found' - -EXEC_LOGGER = AIIDA_LOGGER.getChild('execmanager') - - -def _find_data_node(inputs: MappingType[str, Any], uuid: str) -> Optional[Node]: - """Find and return the node with the given UUID from a nested mapping of input nodes. - - :param inputs: (nested) mapping of nodes - :param uuid: UUID of the node to find - :return: instance of `Node` or `None` if not found - """ - data_node = None - - for input_node in inputs.values(): - if isinstance(input_node, Mapping): - data_node = _find_data_node(input_node, uuid) - elif isinstance(input_node, Node) and input_node.uuid == uuid: - data_node = input_node - if data_node is not None: - break - - return data_node - - -def upload_calculation( - node: CalcJobNode, - transport: Transport, - calc_info: CalcInfo, - folder: SandboxFolder, - inputs: Optional[MappingType[str, Any]] = None, - dry_run: bool = False -) -> None: - """Upload a `CalcJob` instance - - :param node: the `CalcJobNode`. - :param transport: an already opened transport to use to submit the calculation. - :param calc_info: the calculation info datastructure returned by `CalcJob.presubmit` - :param folder: temporary local file system folder containing the inputs written by `CalcJob.prepare_for_submission` - """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - - # If the calculation already has a `remote_folder`, simply return. The upload was apparently already completed - # before, which can happen if the daemon is restarted and it shuts down after uploading but before getting the - # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the upload. - link_label = 'remote_folder' - if node.base.links.get_outgoing(RemoteData, link_label_filter=link_label).first(): - EXEC_LOGGER.warning(f'CalcJobNode<{node.pk}> already has a `{link_label}` output: skipping upload') - return calc_info - - computer = node.computer - - codes_info = calc_info.codes_info - input_codes = [load_node(_.code_uuid, sub_classes=(Code,)) for _ in codes_info] - - logger_extra = get_dblogger_extra(node) - transport.set_logger_extra(logger_extra) - logger = LoggerAdapter(logger=EXEC_LOGGER, extra=logger_extra) - - if not dry_run and not node.is_stored: - raise ValueError( - f'Cannot submit calculation {node.pk} because it is not stored! If you just want to test the submission, ' - 'set `metadata.dry_run` to True in the inputs.' 
- ) - - # If we are performing a dry-run, the working directory should actually be a local folder that should already exist - if dry_run: - workdir = transport.getcwd() - else: - remote_user = transport.whoami() - remote_working_directory = computer.get_workdir().format(username=remote_user) - if not remote_working_directory.strip(): - raise exceptions.ConfigurationError( - f'[submission of calculation {node.pk}] No remote_working_directory ' - f"configured for computer '{computer.label}'" - ) - - # If it already exists, no exception is raised - try: - transport.chdir(remote_working_directory) - except IOError: - logger.debug( - f'[submission of calculation {node.pk}] Unable to ' - f'chdir in {remote_working_directory}, trying to create it' - ) - try: - transport.makedirs(remote_working_directory) - transport.chdir(remote_working_directory) - except EnvironmentError as exc: - raise exceptions.ConfigurationError( - f'[submission of calculation {node.pk}] ' - f'Unable to create the remote directory {remote_working_directory} on ' - f"computer '{computer.label}': {exc}" - ) - # Store remotely with sharding (here is where we choose - # the folder structure of remote jobs; then I store this - # in the calculation properties using _set_remote_dir - # and I do not have to know the logic, but I just need to - # read the absolute path from the calculation properties. - transport.mkdir(calc_info.uuid[:2], ignore_existing=True) - transport.chdir(calc_info.uuid[:2]) - transport.mkdir(calc_info.uuid[2:4], ignore_existing=True) - transport.chdir(calc_info.uuid[2:4]) - - try: - # The final directory may already exist, most likely because this function was already executed once, but - # failed and as a result was rescheduled by the eninge. In this case it would be fine to delete the folder - # and create it from scratch, except that we cannot be sure that this the actual case. Therefore, to err on - # the safe side, we move the folder to the lost+found directory before recreating the folder from scratch - transport.mkdir(calc_info.uuid[4:]) - except OSError: - # Move the existing directory to lost+found, log a warning and create a clean directory anyway - path_existing = os.path.join(transport.getcwd(), calc_info.uuid[4:]) - path_lost_found = os.path.join(remote_working_directory, REMOTE_WORK_DIRECTORY_LOST_FOUND) - path_target = os.path.join(path_lost_found, calc_info.uuid) - logger.warning( - f'tried to create path {path_existing} but it already exists, moving the entire folder to {path_target}' - ) - - # Make sure the lost+found directory exists, then copy the existing folder there and delete the original - transport.mkdir(path_lost_found, ignore_existing=True) - transport.copytree(path_existing, path_target) - transport.rmtree(path_existing) - - # Now we can create a clean folder for this calculation - transport.mkdir(calc_info.uuid[4:]) - finally: - transport.chdir(calc_info.uuid[4:]) - - # I store the workdir of the calculation for later file retrieval - workdir = transport.getcwd() - node.set_remote_workdir(workdir) - - # I first create the code files, so that the code can put - # default files to be overwritten by the plugin itself. - # Still, beware! The code file itself could be overwritten... - # But I checked for this earlier. 
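The sharding above splits the calculation UUID into a three-level directory hierarchy inside the remote working directory; a tiny worked example of the resulting path (the UUID value is made up):

```python
uuid = '2f3d5a6b-0000-4000-8000-123456789abc'

# The remote work directory is sharded as <first 2 chars>/<next 2 chars>/<remainder>
shards = (uuid[:2], uuid[2:4], uuid[4:])
print('/'.join(shards))
# 2f/3d/5a6b-0000-4000-8000-123456789abc
```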
- for code in input_codes: - if isinstance(code, PortableCode): - # Note: this will possibly overwrite files - for root, dirnames, filenames in code.base.repository.walk(): - # mkdir of root - transport.makedirs(root, ignore_existing=True) - - # remotely mkdir first - for dirname in dirnames: - transport.makedirs((root / dirname), ignore_existing=True) - - # Note, once #2579 is implemented, use the `node.open` method instead of the named temporary file in - # combination with the new `Transport.put_object_from_filelike` - # Since the content of the node could potentially be binary, we read the raw bytes and pass them on - for filename in filenames: - with NamedTemporaryFile(mode='wb+') as handle: - content = code.base.repository.get_object_content((pathlib.Path(root) / filename), mode='rb') - handle.write(content) - handle.flush() - transport.put(handle.name, (root / filename)) - transport.chmod(code.filepath_executable, 0o755) # rwxr-xr-x - - # local_copy_list is a list of tuples, each with (uuid, dest_path, rel_path) - # NOTE: validation of these lists are done inside calculation.presubmit() - local_copy_list = calc_info.local_copy_list or [] - remote_copy_list = calc_info.remote_copy_list or [] - remote_symlink_list = calc_info.remote_symlink_list or [] - provenance_exclude_list = calc_info.provenance_exclude_list or [] - - for uuid, filename, target in local_copy_list: - logger.debug(f'[submission of calculation {node.uuid}] copying local file/folder to {target}') - - try: - data_node = load_node(uuid=uuid) - except exceptions.NotExistent: - data_node = _find_data_node(inputs, uuid) if inputs else None - - if data_node is None: - logger.warning(f'failed to load Node<{uuid}> specified in the `local_copy_list`') - else: - - # If no explicit source filename is defined, we assume the top-level directory - filename_source = filename or '.' 
- filename_target = target or '' - - # Make the target filepath absolute and create any intermediate directories if they don't yet exist - filepath_target = pathlib.Path(folder.abspath) / filename_target - filepath_target.parent.mkdir(parents=True, exist_ok=True) - - if data_node.base.repository.get_object(filename_source).file_type == FileType.DIRECTORY: - # If the source object is a directory, we copy its entire contents - data_node.base.repository.copy_tree(filepath_target, filename_source) - sources = data_node.base.repository.list_object_names(filename_source) - if filename_target: - sources = [str(pathlib.Path(filename_target) / subpath) for subpath in sources] - provenance_exclude_list.extend(sources) - else: - # Otherwise, simply copy the file - with folder.open(target, 'wb') as handle: - with data_node.base.repository.open(filename, 'rb') as source: - shutil.copyfileobj(source, handle) - - provenance_exclude_list.append(target) - - # In a dry_run, the working directory is the raw input folder, which will already contain these resources - if not dry_run: - for filename in folder.get_content_list(): - logger.debug(f'[submission of calculation {node.pk}] copying file/folder {filename}...') - transport.put(folder.get_abs_path(filename), filename) - - for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_copy_list: - if remote_computer_uuid == computer.uuid: - logger.debug( - f'[submission of calculation {node.pk}] copying {dest_rel_path} ' - f'remotely, directly on the machine {computer.label}' - ) - try: - transport.copy(remote_abs_path, dest_rel_path) - except FileNotFoundError: - logger.warning( - f'[submission of calculation {node.pk}] Unable to copy remote ' - f'resource from {remote_abs_path} to {dest_rel_path}! NOT Stopping but just ignoring!.' - ) - except (IOError, OSError): - logger.warning( - f'[submission of calculation {node.pk}] Unable to copy remote ' - f'resource from {remote_abs_path} to {dest_rel_path}! Stopping.' - ) - raise - else: - raise NotImplementedError( - f'[submission of calculation {node.pk}] Remote copy between two different machines is ' - 'not implemented yet' - ) - - for (remote_computer_uuid, remote_abs_path, dest_rel_path) in remote_symlink_list: - if remote_computer_uuid == computer.uuid: - logger.debug( - f'[submission of calculation {node.pk}] copying {dest_rel_path} remotely, ' - f'directly on the machine {computer.label}' - ) - try: - transport.symlink(remote_abs_path, dest_rel_path) - except (IOError, OSError): - logger.warning( - f'[submission of calculation {node.pk}] Unable to create remote symlink ' - f'from {remote_abs_path} to {dest_rel_path}! Stopping.' 
- ) - raise - else: - raise IOError( - f'It is not possible to create a symlink between two different machines for calculation {node.pk}' - ) - else: - - if remote_copy_list: - filepath = os.path.join(workdir, '_aiida_remote_copy_list.txt') - with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] - for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_copy_list: - handle.write( - f'would have copied {remote_abs_path} to {dest_rel_path} in working ' - f'directory on remote {computer.label}' - ) - - if remote_symlink_list: - filepath = os.path.join(workdir, '_aiida_remote_symlink_list.txt') - with open(filepath, 'w', encoding='utf-8') as handle: # type: ignore[assignment] - for remote_computer_uuid, remote_abs_path, dest_rel_path in remote_symlink_list: - handle.write( - f'would have created symlinks from {remote_abs_path} to {dest_rel_path} in working' - f'directory on remote {computer.label}' - ) - - # Loop recursively over content of the sandbox folder copying all that are not in `provenance_exclude_list`. Note - # that directories are not created explicitly. The `node.put_object_from_filelike` call will create intermediate - # directories for nested files automatically when needed. This means though that empty folders in the sandbox or - # folders that would be empty when considering the `provenance_exclude_list` will *not* be copied to the repo. The - # advantage of this explicit copying instead of deleting the files from `provenance_exclude_list` from the sandbox - # first before moving the entire remaining content to the node's repository, is that in this way we are guaranteed - # not to accidentally move files to the repository that should not go there at all cost. Note that all entries in - # the provenance exclude list are normalized first, just as the paths that are in the sandbox folder, otherwise the - # direct equality test may fail, e.g.: './path/file.txt' != 'path/file.txt' even though they reference the same file - provenance_exclude_list = [os.path.normpath(entry) for entry in provenance_exclude_list] - - for root, _, filenames in os.walk(folder.abspath): - for filename in filenames: - filepath = os.path.join(root, filename) - relpath = os.path.normpath(os.path.relpath(filepath, folder.abspath)) - dirname = os.path.dirname(relpath) - - # Construct a list of all (partial) filepaths - # For example, if `relpath == 'some/sub/directory/file.txt'` then the list of relative directory paths is - # ['some', 'some/sub', 'some/sub/directory'] - # This is necessary, because if any of these paths is in the `provenance_exclude_list` the file should not - # be copied over. - components = dirname.split(os.sep) - dirnames = [os.path.join(*components[:i]) for i in range(1, len(components) + 1)] - if relpath not in provenance_exclude_list and all( - dirname not in provenance_exclude_list for dirname in dirnames - ): - with open(filepath, 'rb') as handle: # type: ignore[assignment] - node.base.repository._repository.put_object_from_filelike(handle, relpath) # pylint: disable=protected-access - - # Since the node is already stored, we cannot use the normal repository interface since it will raise a - # `ModificationNotAllowed` error. To bypass it, we go straight to the underlying repository instance to store the - # files, however, this means we have to manually update the node's repository metadata. 
- node.base.repository._update_repository_metadata() # pylint: disable=protected-access - - if not dry_run: - # Make sure that attaching the `remote_folder` with a link is the last thing we do. This gives the biggest - # chance of making this method idempotent. That is to say, if a runner gets interrupted during this action, it - # will simply retry the upload, unless we got here and managed to link it up, in which case we move to the next - # task. Because in that case, the check for the existence of this link at the top of this function will exit - # early from this command. - remotedata = RemoteData(computer=computer, remote_path=workdir) - remotedata.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label='remote_folder') - remotedata.store() - - -def submit_calculation(calculation: CalcJobNode, transport: Transport) -> str | ExitCode: - """Submit a previously uploaded `CalcJob` to the scheduler. - - :param calculation: the instance of CalcJobNode to submit. - :param transport: an already opened transport to use to submit the calculation. - :return: the job id as returned by the scheduler `submit_from_script` call - """ - job_id = calculation.get_job_id() - - # If the `job_id` attribute is already set, that means this function was already executed once and the scheduler - # submit command was successful as the job id it returned was set on the node. This scenario can happen when the - # daemon runner gets shutdown right after accomplishing the submission task, but before it gets the chance to - # finalize the state transition of the `CalcJob` to the `UPDATE` transport task. Since the job is already submitted - # we do not want to submit it a second time, so we simply return the existing job id here. - if job_id is not None: - return job_id - - scheduler = calculation.computer.get_scheduler() - scheduler.set_transport(transport) - - submit_script_filename = calculation.get_option('submit_script_filename') - workdir = calculation.get_remote_workdir() - result = scheduler.submit_from_script(workdir, submit_script_filename) - - if isinstance(result, str): - calculation.set_job_id(result) - - return result - - -def stash_calculation(calculation: CalcJobNode, transport: Transport) -> None: - """Stash files from the working directory of a completed calculation to a permanent remote folder. - - After a calculation has been completed, optionally stash files from the work directory to a storage location on the - same remote machine. This is useful if one wants to keep certain files from a completed calculation to be removed - from the scratch directory, because they are necessary for restarts, but that are too heavy to retrieve. - Instructions of which files to copy where are retrieved from the `stash.source_list` option. - - :param calculation: the calculation job node. - :param transport: an already opened transport. 
- """ - from aiida.common.datastructures import StashMode - from aiida.orm import RemoteStashFolderData - - logger_extra = get_dblogger_extra(calculation) - - stash_options = calculation.get_option('stash') - stash_mode = stash_options.get('mode', StashMode.COPY.value) - source_list = stash_options.get('source_list', []) - - if not source_list: - return - - if stash_mode != StashMode.COPY.value: - EXEC_LOGGER.warning(f'stashing mode {stash_mode} is not implemented yet.') - return - - cls = RemoteStashFolderData - - EXEC_LOGGER.debug(f'stashing files for calculation<{calculation.pk}>: {source_list}', extra=logger_extra) - - uuid = calculation.uuid - source_basepath = pathlib.Path(calculation.get_remote_workdir()) - target_basepath = pathlib.Path(stash_options['target_base']) / uuid[:2] / uuid[2:4] / uuid[4:] - - for source_filename in source_list: - - if transport.has_magic(source_filename): - copy_instructions = [] - for globbed_filename in transport.glob(str(source_basepath / source_filename)): - target_filepath = target_basepath / pathlib.Path(globbed_filename).relative_to(source_basepath) - copy_instructions.append((globbed_filename, target_filepath)) - else: - copy_instructions = [(source_basepath / source_filename, target_basepath / source_filename)] - - for source_filepath, target_filepath in copy_instructions: - # If the source file is in a (nested) directory, create those directories first in the target directory - target_dirname = target_filepath.parent - transport.makedirs(str(target_dirname), ignore_existing=True) - - try: - transport.copy(str(source_filepath), str(target_filepath)) - except (IOError, ValueError) as exception: - EXEC_LOGGER.warning(f'failed to stash {source_filepath} to {target_filepath}: {exception}') - else: - EXEC_LOGGER.debug(f'stashed {source_filepath} to {target_filepath}') - - remote_stash = cls( - computer=calculation.computer, - target_basepath=str(target_basepath), - stash_mode=StashMode(stash_mode), - source_list=source_list, - ).store() - remote_stash.base.links.add_incoming(calculation, link_type=LinkType.CREATE, link_label='remote_stash') - - -def retrieve_calculation(calculation: CalcJobNode, transport: Transport, retrieved_temporary_folder: str) -> None: - """Retrieve all the files of a completed job calculation using the given transport. - - If the job defined anything in the `retrieve_temporary_list`, those entries will be stored in the - `retrieved_temporary_folder`. The caller is responsible for creating and destroying this folder. - - :param calculation: the instance of CalcJobNode to update. - :param transport: an already opened transport to use for the retrieval. - :param retrieved_temporary_folder: the absolute path to a directory in which to store the files - listed, if any, in the `retrieved_temporary_folder` of the jobs CalcInfo - """ - logger_extra = get_dblogger_extra(calculation) - workdir = calculation.get_remote_workdir() - filepath_sandbox = get_config_option('storage.sandbox') or None - - EXEC_LOGGER.debug(f'Retrieving calc {calculation.pk}', extra=logger_extra) - EXEC_LOGGER.debug(f'[retrieval of calc {calculation.pk}] chdir {workdir}', extra=logger_extra) - - # If the calculation already has a `retrieved` folder, simply return. The retrieval was apparently already completed - # before, which can happen if the daemon is restarted and it shuts down after retrieving but before getting the - # chance to perform the state transition. Upon reloading this calculation, it will re-attempt the retrieval. 
- link_label = calculation.link_label_retrieved - if calculation.base.links.get_outgoing(FolderData, link_label_filter=link_label).first(): - EXEC_LOGGER.warning( - f'CalcJobNode<{calculation.pk}> already has a `{link_label}` output folder: skipping retrieval' - ) - return - - # Create the FolderData node into which to store the files that are to be retrieved - retrieved_files = FolderData() - - with transport: - transport.chdir(workdir) - - # First, retrieve the files of folderdata - retrieve_list = calculation.get_retrieve_list() - retrieve_temporary_list = calculation.get_retrieve_temporary_list() - - with SandboxFolder(filepath_sandbox) as folder: - retrieve_files_from_list(calculation, transport, folder.abspath, retrieve_list) - # Here I retrieved everything; now I store them inside the calculation - retrieved_files.base.repository.put_object_from_tree(folder.abspath) - - # Retrieve the temporary files in the retrieved_temporary_folder if any files were - # specified in the 'retrieve_temporary_list' key - if retrieve_temporary_list: - retrieve_files_from_list(calculation, transport, retrieved_temporary_folder, retrieve_temporary_list) - - # Log the files that were retrieved in the temporary folder - for filename in os.listdir(retrieved_temporary_folder): - EXEC_LOGGER.debug( - f"[retrieval of calc {calculation.pk}] Retrieved temporary file or folder '{filename}'", - extra=logger_extra - ) - - # Store everything - EXEC_LOGGER.debug( - f'[retrieval of calc {calculation.pk}] Storing retrieved_files={retrieved_files.pk}', extra=logger_extra - ) - retrieved_files.store() - - # Make sure that attaching the `retrieved` folder with a link is the last thing we do. This gives the biggest chance - # of making this method idempotent. That is to say, if a runner gets interrupted during this action, it will simply - # retry the retrieval, unless we got here and managed to link it up, in which case we move to the next task. - retrieved_files.base.links.add_incoming( - calculation, link_type=LinkType.CREATE, link_label=calculation.link_label_retrieved - ) - - -def kill_calculation(calculation: CalcJobNode, transport: Transport) -> None: - """ - Kill the calculation through the scheduler - - :param calculation: the instance of CalcJobNode to kill. 
-    :param transport: an already opened transport to use to address the scheduler
-    """
-    job_id = calculation.get_job_id()
-
-    if job_id is None:
-        # the calculation has not yet been submitted to the scheduler
-        return
-
-    # Get the scheduler plugin class and initialize it with the correct transport
-    scheduler = calculation.computer.get_scheduler()
-    scheduler.set_transport(transport)
-
-    # Call the proper kill method for the job ID of this calculation
-    result = scheduler.kill(job_id)
-
-    if result is not True:
-
-        # Failed to kill because the job might have already been completed
-        running_jobs = scheduler.get_jobs(jobs=[job_id], as_dict=True)
-        job = running_jobs.get(job_id, None)
-
-        # If the job is returned it is still running and the kill really failed, so we raise
-        if job is not None and job.job_state != JobState.DONE:
-            raise exceptions.RemoteOperationError(f'scheduler.kill({job_id}) was unsuccessful')
-        else:
-            EXEC_LOGGER.warning(
-                'scheduler.kill() failed but job<%s> no longer seems to be running regardless', job_id
-            )
-
-
-def retrieve_files_from_list(
-    calculation: CalcJobNode, transport: Transport, folder: str,
-    retrieve_list: List[Union[str, Tuple[str, str, int], list]]
-) -> None:
-    """
-    Retrieve all the files in the retrieve_list from the remote into the
-    local folder instance through the transport. The entries in the retrieve_list
-    can be of two types:
-
-        * a string
-        * a list
-
-    If it is a string, it represents the remote absolute filepath of the file.
-    If the item is a list, the elements will correspond to the following:
-
-        * remotepath
-        * localpath
-        * depth
-
-    If the remotepath contains file patterns with wildcards, the localpath will be
-    treated as the work directory of the folder and the depth integer determines
-    up to what level of the original remotepath nesting the files will be copied.
-
-    :param transport: the Transport instance.
-    :param folder: an absolute path to a folder that contains the files to copy.
-    :param retrieve_list: the list of files to retrieve.
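The `depth` element described above controls how much of the remote nesting is preserved locally. A small sketch mirroring the slicing used in the implementation that follows; the paths are made-up:

```python
import os
from typing import Optional


def local_name(remote_path: str, local_path: str, depth: Optional[int]) -> str:
    """Return where a retrieved file ends up relative to the retrieval folder."""
    if depth is None:
        return os.path.join(local_path, remote_path)  # keep the full remote nesting
    keep = remote_path.split(os.path.sep)[-depth:] if depth > 0 else []
    return os.path.sep.join([local_path] + keep)


print(local_name('some/nested/output.dat', 'parsed', 1))  # parsed/output.dat
print(local_name('some/nested/output.dat', 'parsed', 2))  # parsed/nested/output.dat
```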
- """ - # pylint: disable=too-many-branches - for item in retrieve_list: - if isinstance(item, (list, tuple)): - tmp_rname, tmp_lname, depth = item - # if there are more than one file I do something differently - if transport.has_magic(tmp_rname): - remote_names = transport.glob(tmp_rname) - local_names = [] - for rem in remote_names: - if depth is None: - local_names.append(os.path.join(tmp_lname, rem)) - else: - to_append = rem.split(os.path.sep)[-depth:] if depth > 0 else [] - local_names.append(os.path.sep.join([tmp_lname] + to_append)) - else: - remote_names = [tmp_rname] - to_append = tmp_rname.split(os.path.sep)[-depth:] if depth > 0 else [] - local_names = [os.path.sep.join([tmp_lname] + to_append)] - if depth is None or depth > 1: # create directories in the folder, if needed - for this_local_file in local_names: - new_folder = os.path.join(folder, os.path.split(this_local_file)[0]) - if not os.path.exists(new_folder): - os.makedirs(new_folder) - else: # it is a string - if transport.has_magic(item): - remote_names = transport.glob(item) - local_names = [os.path.split(rem)[1] for rem in remote_names] - else: - remote_names = [item] - local_names = [os.path.split(item)[1]] - - for rem, loc in zip(remote_names, local_names): - transport.logger.debug(f"[retrieval of calc {calculation.pk}] Trying to retrieve remote item '{rem}'") - transport.get(rem, os.path.join(folder, loc), ignore_nonexisting=True) diff --git a/aiida/engine/exceptions.py b/aiida/engine/exceptions.py deleted file mode 100644 index 127687fc13..0000000000 --- a/aiida/engine/exceptions.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Exceptions that can be thrown by parts of the workflow engine.""" - -from aiida.common.exceptions import AiidaException - -__all__ = ('PastException',) - - -class PastException(AiidaException): - """Raised when an attempt is made to continue a Process that has already excepted before.""" diff --git a/aiida/engine/launch.py b/aiida/engine/launch.py deleted file mode 100644 index 888536cd61..0000000000 --- a/aiida/engine/launch.py +++ /dev/null @@ -1,126 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Top level functions that can be used to launch a Process.""" -from typing import Any, Dict, Tuple, Type, Union - -from aiida.common import InvalidOperation -from aiida.manage import manager -from aiida.orm import ProcessNode - -from .processes.builder import ProcessBuilder -from .processes.functions import FunctionProcess -from .processes.process import Process -from .utils import instantiate_process, is_process_scoped # pylint: disable=no-name-in-module - -__all__ = ('run', 'run_get_pk', 'run_get_node', 'submit') - -TYPE_RUN_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name -# run can also be process function, but it is not clear what type this should be -TYPE_SUBMIT_PROCESS = Union[Process, Type[Process], ProcessBuilder] # pylint: disable=invalid-name - - -def run(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Dict[str, Any]: - """Run the process with the supplied inputs in a local runner that will block until the process is completed. - - :param process: the process class or process function to run - :param inputs: the inputs to be passed to the process - - :return: the outputs of the process - - """ - if isinstance(process, Process): - runner = process.runner - else: - runner = manager.get_manager().get_runner() - - return runner.run(process, *args, **inputs) - - -def run_get_node(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], ProcessNode]: - """Run the process with the supplied inputs in a local runner that will block until the process is completed. - - :param process: the process class, instance, builder or function to run - :param inputs: the inputs to be passed to the process - - :return: tuple of the outputs of the process and the process node - - """ - if isinstance(process, Process): - runner = process.runner - else: - runner = manager.get_manager().get_runner() - - return runner.run_get_node(process, *args, **inputs) - - -def run_get_pk(process: TYPE_RUN_PROCESS, *args: Any, **inputs: Any) -> Tuple[Dict[str, Any], int]: - """Run the process with the supplied inputs in a local runner that will block until the process is completed. - - :param process: the process class, instance, builder or function to run - :param inputs: the inputs to be passed to the process - - :return: tuple of the outputs of the process and process node pk - - """ - if isinstance(process, Process): - runner = process.runner - else: - runner = manager.get_manager().get_runner() - - return runner.run_get_pk(process, *args, **inputs) - - -def submit(process: TYPE_SUBMIT_PROCESS, **inputs: Any) -> ProcessNode: - """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. - - .. warning: this should not be used within another process. Instead, there one should use the `submit` method of - the wrapping process itself, i.e. use `self.submit`. - - .. 
warning: submission of processes requires `store_provenance=True`
-
-    :param process: the process class, instance or builder to submit
-    :param inputs: the inputs to be passed to the process
-
-    :return: the calculation node of the process
-
-    """
-    # Submitting from within another process requires `self.submit` unless it is a work function, in which case the
-    # current process in the scope should be an instance of `FunctionProcess`
-    if is_process_scoped() and not isinstance(Process.current(), FunctionProcess):
-        raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead')
-
-    runner = manager.get_manager().get_runner()
-    assert runner.persister is not None, 'runner does not have a persister'
-    assert runner.controller is not None, 'runner does not have a controller'
-
-    process_inited = instantiate_process(runner, process, **inputs)
-
-    # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this
-    # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes
-    # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation.
-    if process_inited.metadata.get('dry_run', False) or 'remote_folder' in inputs:
-        _, node = run_get_node(process_inited)
-        return node
-
-    if not process_inited.metadata.store_provenance:
-        raise InvalidOperation('cannot submit a process with `store_provenance=False`')
-
-    runner.persister.save_checkpoint(process_inited)
-    process_inited.close()
-
-    # Do not wait for the future's result, because in the case of a single worker this would deadlock itself
-    runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True)
-
-    return process_inited.node
-
-
-# Allow one to also use run.get_node and run.get_pk as a shortcut, without having to import the functions themselves
-run.get_node = run_get_node  # type: ignore[attr-defined]
-run.get_pk = run_get_pk  # type: ignore[attr-defined]
diff --git a/aiida/engine/processes/__init__.py b/aiida/engine/processes/__init__.py
deleted file mode 100644
index 20668be208..0000000000
--- a/aiida/engine/processes/__init__.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- coding: utf-8 -*-
-###########################################################################
-# Copyright (c), The AiiDA team. All rights reserved. #
-# This file is part of the AiiDA code.
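A hedged usage sketch of the launch functions deleted above (`run_get_node` and `submit`); the `core.arithmetic.add` entry point ships with aiida-core, but the `add@localhost` code label and a configured profile are assumptions:

```python
from aiida import load_profile, orm
from aiida.engine import run_get_node, submit
from aiida.plugins import CalculationFactory

load_profile()

builder = CalculationFactory('core.arithmetic.add').get_builder()
builder.code = orm.load_code('add@localhost')  # assumes such a code has been set up
builder.x = orm.Int(1)
builder.y = orm.Int(2)

results, node = run_get_node(builder)  # blocking run in the current interpreter
node = submit(builder)                 # non-blocking, handed off to the daemon
```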
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for processes and related utilities.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .builder import * -from .calcjobs import * -from .exit_code import * -from .functions import * -from .futures import * -from .ports import * -from .process import * -from .process_spec import * -from .workchains import * - -__all__ = ( - 'Awaitable', - 'AwaitableAction', - 'AwaitableTarget', - 'BaseRestartWorkChain', - 'CalcJob', - 'CalcJobImporter', - 'CalcJobOutputPort', - 'CalcJobProcessSpec', - 'ExitCode', - 'ExitCodesNamespace', - 'FunctionProcess', - 'InputPort', - 'JobManager', - 'JobsList', - 'OutputPort', - 'PORT_NAMESPACE_SEPARATOR', - 'PortNamespace', - 'Process', - 'ProcessBuilder', - 'ProcessBuilderNamespace', - 'ProcessFuture', - 'ProcessHandlerReport', - 'ProcessSpec', - 'ProcessState', - 'ToContext', - 'WithNonDb', - 'WithSerialize', - 'WorkChain', - 'append_', - 'assign_', - 'calcfunction', - 'construct_awaitable', - 'if_', - 'process_handler', - 'return_', - 'while_', - 'workfunction', -) - -# yapf: enable diff --git a/aiida/engine/processes/calcjobs/__init__.py b/aiida/engine/processes/calcjobs/__init__.py deleted file mode 100644 index 77686c9969..0000000000 --- a/aiida/engine/processes/calcjobs/__init__.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for the `CalcJob` process and related utilities.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .calcjob import * -from .importer import * -from .manager import * - -__all__ = ( - 'CalcJob', - 'CalcJobImporter', - 'JobManager', - 'JobsList', -) - -# yapf: enable diff --git a/aiida/engine/processes/calcjobs/calcjob.py b/aiida/engine/processes/calcjobs/calcjob.py deleted file mode 100644 index 475fa94e4d..0000000000 --- a/aiida/engine/processes/calcjobs/calcjob.py +++ /dev/null @@ -1,1064 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-lines -"""Implementation of the CalcJob process.""" -from __future__ import annotations - -import dataclasses -import io -import json -import os -import shutil -from typing import Any, Dict, Hashable, Optional, Type, Union - -import plumpy.ports -import plumpy.process_states - -from aiida import orm -from aiida.common import AttributeDict, exceptions -from aiida.common.datastructures import CalcInfo -from aiida.common.folders import Folder -from aiida.common.lang import classproperty, override -from aiida.common.links import LinkType - -from ..exit_code import ExitCode -from ..ports import PortNamespace -from ..process import Process, ProcessState -from ..process_spec import CalcJobProcessSpec -from .importer import CalcJobImporter -from .monitors import CalcJobMonitor -from .tasks import UPLOAD_COMMAND, Waiting - -__all__ = ('CalcJob',) - - -def validate_calc_job(inputs: Any, ctx: PortNamespace) -> Optional[str]: # pylint: disable=too-many-return-statements - """Validate the entire set of inputs passed to the `CalcJob` constructor. - - Reasons that will cause this validation to raise an `InputValidationError`: - - * No `Computer` has been specified, neither directly in `metadata.computer` nor indirectly through the `Code` input - * The specified computer is not stored - * The `Computer` specified in `metadata.computer` is not the same as that of the specified `Code` - * No `Code` has been specified and no `remote_folder` input has been specified, i.e. this is no import run - - :return: string with error message in case the inputs are invalid - """ - try: - ctx.get_port('code') - ctx.get_port('metadata.computer') - except ValueError: - # If the namespace no longer contains the `code` or `metadata.computer` ports we skip validation - return None - - remote_folder = inputs.get('remote_folder', None) - - if remote_folder is not None: - # The `remote_folder` input has been specified and so this concerns an import run, which means that neither - # a `Code` nor a `Computer` are required. However, they are allowed to be specified but will not be explicitly - # checked for consistency. - return None - - code = inputs.get('code', None) - computer_from_code = code.computer - computer_from_metadata = inputs.get('metadata', {}).get('computer', None) - - if not computer_from_code and not computer_from_metadata: - return 'no computer has been specified in `metadata.computer` nor via `code`.' 
- - if computer_from_code and not computer_from_code.is_stored: - return f'the Computer<{computer_from_code}> is not stored' - - if computer_from_metadata and not computer_from_metadata.is_stored: - return f'the Computer<{computer_from_metadata}> is not stored' - - if computer_from_code and computer_from_metadata and computer_from_code.uuid != computer_from_metadata.uuid: - return ( - 'Computer<{}> explicitly defined in `metadata.computer` is different from Computer<{}> which is the ' - 'computer of Code<{}> defined as the `code` input.'.format( - computer_from_metadata, computer_from_code, code - ) - ) - - try: - resources_port = ctx.get_port('metadata.options.resources') - except ValueError: - return None - - # If the resources port exists but is not required, we don't need to validate it against the computer's scheduler - if not resources_port.required: - return None - - computer = computer_from_code or computer_from_metadata - scheduler = computer.get_scheduler() - try: - resources = inputs['metadata']['options']['resources'] - except KeyError: - return 'input `metadata.options.resources` is required but is not specified' - - scheduler.preprocess_resources(resources, computer.get_default_mpiprocs_per_machine()) - - try: - scheduler.validate_resources(**resources) - except ValueError as exception: - return f'input `metadata.options.resources` is not valid for the `{scheduler}` scheduler: {exception}' - - return None - - -def validate_stash_options(stash_options: Any, _: Any) -> Optional[str]: - """Validate the ``stash`` options.""" - from aiida.common.datastructures import StashMode - - target_base = stash_options.get('target_base', None) - source_list = stash_options.get('source_list', None) - stash_mode = stash_options.get('mode', StashMode.COPY.value) - - if not isinstance(target_base, str) or not os.path.isabs(target_base): - return f'`metadata.options.stash.target_base` should be an absolute filepath, got: {target_base}' - - if ( - not isinstance(source_list, (list, tuple)) or - any(not isinstance(src, str) or os.path.isabs(src) for src in source_list) - ): - port = 'metadata.options.stash.source_list' - return f'`{port}` should be a list or tuple of relative filepaths, got: {source_list}' - - try: - StashMode(stash_mode) - except ValueError: - port = 'metadata.options.stash.mode' - return f'`{port}` should be a member of aiida.common.datastructures.StashMode, got: {stash_mode}' - - return None - - -def validate_monitors(monitors: Any, _: PortNamespace) -> Optional[str]: - """Validate the ``monitors`` input namespace.""" - for key, monitor_node in monitors.items(): - try: - CalcJobMonitor(**monitor_node.get_dict()) - except (exceptions.EntryPointError, TypeError, ValueError) as exception: - return f'`monitors.{key}` is invalid: {exception}' - return None - - -def validate_parser(parser_name: Any, _: PortNamespace) -> Optional[str]: - """Validate the parser. - - :return: string with error message in case the inputs are invalid - """ - from aiida.plugins import ParserFactory - - try: - ParserFactory(parser_name) - except exceptions.EntryPointError as exception: - return f'invalid parser specified: {exception}' - - return None - - -def validate_additional_retrieve_list(additional_retrieve_list: Any, _: Any) -> Optional[str]: - """Validate the additional retrieve list. - - :return: string with error message in case the input is invalid. 
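For reference, illustrative values that would pass or fail the stash-option validation above; the paths are made-up:

```python
valid_stash = {
    'target_base': '/scratch/stash',           # absolute path: accepted
    'source_list': ['out/aiida.out', 'data'],  # relative paths: accepted
}
invalid_stash = {
    'target_base': 'stash',                     # rejected: not an absolute path
    'source_list': ['/scratch/out/aiida.out'],  # rejected: absolute entry
}
```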
- """ - if any(not isinstance(value, str) or os.path.isabs(value) for value in additional_retrieve_list): - return f'`additional_retrieve_list` should only contain relative filepaths but got: {additional_retrieve_list}' - - return None - - -class CalcJob(Process): - """Implementation of the CalcJob process.""" - - _node_class = orm.CalcJobNode - _spec_class = CalcJobProcessSpec - link_label_retrieved: str = 'retrieved' - - def __init__(self, *args, **kwargs) -> None: - """Construct a CalcJob instance. - - Construct the instance only if it is a sub class of `CalcJob`, otherwise raise `InvalidOperation`. - - See documentation of :class:`aiida.engine.Process`. - """ - if self.__class__ == CalcJob: - raise exceptions.InvalidOperation('cannot construct or launch a base `CalcJob` class.') - - super().__init__(*args, **kwargs) - - @classmethod - def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] - """Define the process specification, including its inputs, outputs and known exit codes. - - Ports are added to the `metadata` input namespace (inherited from the base Process), - and a `code` input Port, a `remote_folder` output Port and retrieved folder output Port - are added. - - :param spec: the calculation job process spec to define. - """ - super().define(spec) - spec.inputs.validator = validate_calc_job # type: ignore[assignment] # takes only PortNamespace not Port - spec.input( - 'code', - valid_type=orm.AbstractCode, - required=False, - help='The `Code` to use for this job. This input is required, unless the `remote_folder` input is ' - 'specified, which means an existing job is being imported and no code will actually be run.' - ) - spec.input_namespace( - 'monitors', - valid_type=orm.Dict, - required=False, - validator=validate_monitors, - help='Add monitoring functions that can inspect output files while the job is running and decide to ' - 'prematurely terminate the job.' - ) - spec.input( - 'remote_folder', - valid_type=orm.RemoteData, - required=False, - help='Remote directory containing the results of an already completed calculation job without AiiDA. The ' - 'inputs should be passed to the `CalcJob` as normal but instead of launching the actual job, the ' - 'engine will recreate the input files and then proceed straight to the retrieve step where the files ' - 'of this `RemoteData` will be retrieved as if it had been actually launched through AiiDA. If a ' - 'parser is defined in the inputs, the results are parsed and attached as output nodes as usual.' - ) - spec.input( - 'metadata.dry_run', - valid_type=bool, - default=False, - help='When set to `True` will prepare the calculation job for submission but not actually launch it.' - ) - spec.input( - 'metadata.computer', - valid_type=orm.Computer, - required=False, - help='When using a "local" code, set the computer on which the calculation should be run.' - ) - spec.input_namespace(f'{spec.metadata_key}.{spec.options_key}', required=False) - spec.input( - 'metadata.options.input_filename', - valid_type=str, - required=False, - help='Filename to which the input for the code that is to be run is written.' - ) - spec.input( - 'metadata.options.output_filename', - valid_type=str, - required=False, - help='Filename to which the content of stdout of the code that is to be run is written.' - ) - spec.input( - 'metadata.options.submit_script_filename', - valid_type=str, - default='_aiidasubmit.sh', - help='Filename to which the job submission script is written.' 
- ) - spec.input( - 'metadata.options.scheduler_stdout', - valid_type=str, - default='_scheduler-stdout.txt', - help='Filename to which the content of stdout of the scheduler is written.' - ) - spec.input( - 'metadata.options.scheduler_stderr', - valid_type=str, - default='_scheduler-stderr.txt', - help='Filename to which the content of stderr of the scheduler is written.' - ) - spec.input( - 'metadata.options.resources', - valid_type=dict, - required=True, - help='Set the dictionary of resources to be used by the scheduler plugin, like the number of nodes, ' - 'cpus etc. This dictionary is scheduler-plugin dependent. Look at the documentation of the ' - 'scheduler for more details.' - ) - spec.input( - 'metadata.options.max_wallclock_seconds', - valid_type=int, - required=False, - help='Set the wallclock in seconds asked to the scheduler' - ) - spec.input( - 'metadata.options.custom_scheduler_commands', - valid_type=str, - default='', - help='Set a (possibly multiline) string with the commands that the user wants to manually set for the ' - 'scheduler. The difference of this option with respect to the `prepend_text` is the position in ' - 'the scheduler submission file where such text is inserted: with this option, the string is ' - 'inserted before any non-scheduler command' - ) - spec.input( - 'metadata.options.queue_name', - valid_type=str, - required=False, - help='Set the name of the queue on the remote computer' - ) - spec.input( - 'metadata.options.rerunnable', - valid_type=bool, - required=False, - help='Determines if the calculation can be requeued / rerun.' - ) - spec.input( - 'metadata.options.account', - valid_type=str, - required=False, - help='Set the account to use in for the queue on the remote computer' - ) - spec.input( - 'metadata.options.qos', - valid_type=str, - required=False, - help='Set the quality of service to use in for the queue on the remote computer' - ) - spec.input( - 'metadata.options.withmpi', - valid_type=bool, - required=False, - help='Set the calculation to use mpi', - ) - spec.input( - 'metadata.options.mpirun_extra_params', - valid_type=(list, tuple), - default=lambda: [], - help='Set the extra params to pass to the mpirun (or equivalent) command after the one provided in ' - 'computer.mpirun_command. Example: mpirun -np 8 extra_params[0] extra_params[1] ... 
exec.x', - ) - spec.input( - 'metadata.options.import_sys_environment', - valid_type=bool, - default=True, - help='If set to true, the submission script will load the system environment variables', - ) - spec.input( - 'metadata.options.environment_variables', - valid_type=dict, - default=lambda: {}, - help='Set a dictionary of custom environment variables for this calculation', - ) - spec.input( - 'metadata.options.environment_variables_double_quotes', - valid_type=bool, - default=False, - help='If set to True, use double quotes instead of single quotes to escape the environment variables ' - 'specified in ``environment_variables``.', - ) - spec.input( - 'metadata.options.priority', - valid_type=str, - required=False, - help='Set the priority of the job to be queued' - ) - spec.input( - 'metadata.options.max_memory_kb', - valid_type=int, - required=False, - help='Set the maximum memory (in KiloBytes) to be asked to the scheduler' - ) - spec.input( - 'metadata.options.prepend_text', - valid_type=str, - default='', - help='Set the calculation-specific prepend text, which is going to be prepended in the scheduler-job ' - 'script, just before the code execution', - ) - spec.input( - 'metadata.options.append_text', - valid_type=str, - default='', - help='Set the calculation-specific append text, which is going to be appended in the scheduler-job ' - 'script, just after the code execution', - ) - spec.input( - 'metadata.options.parser_name', - valid_type=str, - required=False, - validator=validate_parser, - help='Set a string for the output parser. Can be None if no output plugin is available or needed' - ) - spec.input( - 'metadata.options.additional_retrieve_list', - required=False, - valid_type=(list, tuple), - validator=validate_additional_retrieve_list, - help='List of relative file paths that should be retrieved in addition to what the plugin specifies.' - ) - spec.input_namespace( - 'metadata.options.stash', - required=False, - populate_defaults=False, - validator=validate_stash_options, - help='Optional directives to stash files after the calculation job has completed.' - ) - spec.input( - 'metadata.options.stash.target_base', - valid_type=str, - required=False, - help='The base location to where the files should be stashd. For example, for the `copy` stash mode, this ' - 'should be an absolute filepath on the remote computer.' - ) - spec.input( - 'metadata.options.stash.source_list', - valid_type=(tuple, list), - required=False, - help='Sequence of relative filepaths representing files in the remote directory that should be stashed.' - ) - spec.input( - 'metadata.options.stash.stash_mode', - valid_type=str, - required=False, - help='Mode with which to perform the stashing, should be value of `aiida.common.datastructures.StashMode`.' - ) - - spec.output( - 'remote_folder', - valid_type=orm.RemoteData, - help='Input files necessary to run the process will be stored in this folder node.' - ) - spec.output( - 'remote_stash', - valid_type=orm.RemoteStashData, - required=False, - help='Contents of the `stash.source_list` option are stored in this remote folder after job completion.' - ) - spec.output( - cls.link_label_retrieved, - valid_type=orm.FolderData, - pass_to_parser=True, - help='Files that are retrieved by the daemon will be stored in this node. By default the stdout and stderr ' - 'of the scheduler will be added, but one can add more by specifying them in `CalcInfo.retrieve_list`.' 
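A hedged sketch of a `metadata.options` dictionary exercising a few of the ports defined above; the queue name and scheduler directive are hypothetical:

```python
options = {
    'resources': {'num_machines': 1, 'num_mpiprocs_per_machine': 4},
    'max_wallclock_seconds': 3600,
    'queue_name': 'debug',                                   # hypothetical queue
    'custom_scheduler_commands': '#SBATCH --constraint=mc',  # hypothetical directive
    'prepend_text': 'export OMP_NUM_THREADS=1',
}
```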
- ) - - spec.exit_code( - 100, - 'ERROR_NO_RETRIEVED_FOLDER', - invalidates_cache=True, - message='The process did not have the required `retrieved` output.' - ) - spec.exit_code( - 110, 'ERROR_SCHEDULER_OUT_OF_MEMORY', invalidates_cache=True, message='The job ran out of memory.' - ) - spec.exit_code( - 120, 'ERROR_SCHEDULER_OUT_OF_WALLTIME', invalidates_cache=True, message='The job ran out of walltime.' - ) - spec.exit_code( - 131, 'ERROR_SCHEDULER_INVALID_ACCOUNT', invalidates_cache=True, message='The specified account is invalid.' - ) - spec.exit_code( - 140, 'ERROR_SCHEDULER_NODE_FAILURE', invalidates_cache=True, message='The node running the job failed.' - ) - spec.exit_code(150, 'STOPPED_BY_MONITOR', invalidates_cache=True, message='{message}') - - @classproperty - def spec_options(cls): # pylint: disable=no-self-argument - """Return the metadata options port namespace of the process specification of this process. - - :return: options dictionary - :rtype: dict - """ - return cls.spec_metadata['options'] # pylint: disable=unsubscriptable-object - - @classmethod - def get_importer(cls, entry_point_name: str | None = None) -> CalcJobImporter: - """Load the `CalcJobImporter` associated with this `CalcJob` if it exists. - - By default an importer with the same entry point as the ``CalcJob`` will be loaded, however, this can be - overridden using the ``entry_point_name`` argument. - - :param entry_point_name: optional entry point name of a ``CalcJobImporter`` to override the default. - :return: the loaded ``CalcJobImporter``. - :raises: if no importer class could be loaded. - """ - from aiida.plugins import CalcJobImporterFactory - from aiida.plugins.entry_point import get_entry_point_from_class - - if entry_point_name is None: - _, entry_point = get_entry_point_from_class(cls.__module__, cls.__name__) - if entry_point is not None: - entry_point_name = entry_point.name - - assert entry_point_name is not None - - return CalcJobImporterFactory(entry_point_name)() - - @property - def options(self) -> AttributeDict: - """Return the options of the metadata that were specified when this process instance was launched. - - :return: options dictionary - - """ - try: - return self.metadata.options - except AttributeError: - return AttributeDict() - - @classmethod - def get_state_classes(cls) -> Dict[Hashable, Type[plumpy.process_states.State]]: - """A mapping of the State constants to the corresponding state class. - - Overrides the waiting state with the Calcjob specific version. - """ - # Overwrite the waiting state - states_map = super().get_state_classes() - states_map[ProcessState.WAITING] = Waiting - return states_map - - @property - def node(self) -> orm.CalcJobNode: - return super().node # type: ignore - - @override - def on_terminated(self) -> None: - """Cleanup the node by deleting the calulation job state. - - .. note:: This has to be done before calling the super because that will seal the node after we cannot change it - """ - self.node.delete_state() - super().on_terminated() - - @override - def run(self) -> Union[plumpy.process_states.Stop, int, plumpy.process_states.Wait]: - """Run the calculation job. - - This means invoking the `presubmit` and storing the temporary folder in the node's repository. Then we move the - process in the `Wait` state, waiting for the `UPLOAD` transport task to be started. 
- - :returns: the `Stop` command if a dry run, int if the process has an exit status, - `Wait` command if the calcjob is to be uploaded - - """ - if self.inputs.metadata.dry_run: - self._perform_dry_run() - return plumpy.process_states.Stop(None, True) - - if 'remote_folder' in self.inputs: - exit_code = self._perform_import() - return exit_code - - # The following conditional is required for the caching to properly work. Even if the source node has a process - # state of `Finished` the cached process will still enter the running state. The process state will have then - # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other - # than `None`, it should mean this node was taken from the cache, so the process should not be rerun. - if self.node.exit_status is not None: - # Normally the outputs will be attached to the process by a ``Parser``, if defined in the inputs. But in - # this case, the parser will not be called. The outputs will already have been added to the process node - # though, so all that needs to be done here is just also assign them to the process instance. This such that - # when the process returns its results, it returns the actual outputs and not an empty dictionary. - self._outputs = self.node.get_outgoing(link_type=LinkType.CREATE).nested() # pylint: disable=attribute-defined-outside-init - return self.node.exit_status - - # Launch the upload operation - return plumpy.process_states.Wait(msg='Waiting to upload', data=UPLOAD_COMMAND) - - def prepare_for_submission(self, folder: Folder) -> CalcInfo: - """Prepare the calculation for submission. - - Convert the input nodes into the corresponding input files in the format that the code will expect. In addition, - define and return a `CalcInfo` instance, which is a simple data structure that contains information for the - engine, for example, on what files to copy to the remote machine, what files to retrieve once it has completed, - specific scheduler settings and more. - - :param folder: a temporary folder on the local file system. - :returns: the `CalcInfo` instance - """ - raise NotImplementedError() - - def _setup_metadata(self, metadata: dict) -> None: - """Store the metadata on the ProcessNode.""" - computer = metadata.pop('computer', None) - if computer is not None: - self.node.computer = computer - - options = metadata.pop('options', {}) - for option_name, option_value in options.items(): - self.node.set_option(option_name, option_value) - - super()._setup_metadata(metadata) - - def _setup_inputs(self) -> None: - """Create the links between the input nodes and the ProcessNode that represents this process.""" - super()._setup_inputs() - - # If a computer has not yet been set, which should have been done in ``_setup_metadata`` if it was specified - # in the ``metadata`` inputs, set the computer associated with the ``code`` input. Note that not all ``code``s - # will have an associated computer, but in that case the ``computer`` property should return ``None`` and - # nothing would change anyway. - if not self.node.computer: - self.node.computer = self.inputs.code.computer - - def _perform_dry_run(self): - """Perform a dry run. - - Instead of performing the normal sequence of steps, just the `presubmit` is called, which will call the method - `prepare_for_submission` of the plugin to generate the input files based on the inputs. Then the upload action - is called, but using a normal local transport that will copy the files to a local sandbox folder. 
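A hedged sketch of a plugin overriding the abstract `prepare_for_submission` defined above; the class name, file names and the `x`/`y` inputs are illustrative and assume a matching process spec:

```python
from aiida.common.datastructures import CalcInfo, CodeInfo
from aiida.engine import CalcJob


class ExampleCalculation(CalcJob):
    """Hypothetical plugin; assumes its spec defines `x` and `y` integer inputs."""

    def prepare_for_submission(self, folder):
        # Write the input file that the (hypothetical) executable expects.
        with folder.open('aiida.in', 'w') as handle:
            handle.write(f'{self.inputs.x.value} {self.inputs.y.value}\n')

        codeinfo = CodeInfo()
        codeinfo.code_uuid = self.inputs.code.uuid
        codeinfo.stdin_name = 'aiida.in'
        codeinfo.stdout_name = 'aiida.out'

        calcinfo = CalcInfo()
        calcinfo.codes_info = [codeinfo]
        calcinfo.retrieve_list = ['aiida.out']
        return calcinfo
```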
The generated - input script and the absolute path to the sandbox folder are stored in the `dry_run_info` attribute of the node - of this process. - """ - from aiida.common.folders import SubmitTestFolder - from aiida.engine.daemon.execmanager import upload_calculation - from aiida.transports.plugins.local import LocalTransport - - with LocalTransport() as transport: - with SubmitTestFolder() as folder: - calc_info = self.presubmit(folder) - transport.chdir(folder.abspath) - upload_calculation(self.node, transport, calc_info, folder, inputs=self.inputs, dry_run=True) - self.node.dry_run_info = { # type: ignore - 'folder': folder.abspath, - 'script_filename': self.node.get_option('submit_script_filename') - } - - def _perform_import(self): - """Perform the import of an already completed calculation. - - The inputs contained a `RemoteData` under the key `remote_folder` signalling that this is not supposed to be run - as a normal calculation job, but rather the results are already computed outside of AiiDA and merely need to be - imported. - """ - from aiida.common.datastructures import CalcJobState - from aiida.common.folders import SandboxFolder - from aiida.engine.daemon.execmanager import retrieve_calculation - from aiida.manage import get_config_option - from aiida.transports.plugins.local import LocalTransport - - filepath_sandbox = get_config_option('storage.sandbox') or None - - with LocalTransport() as transport: - with SandboxFolder(filepath_sandbox) as folder: - with SandboxFolder(filepath_sandbox) as retrieved_temporary_folder: - self.presubmit(folder) - self.node.set_remote_workdir(self.inputs.remote_folder.get_remote_path()) - retrieve_calculation(self.node, transport, retrieved_temporary_folder.abspath) - self.node.set_state(CalcJobState.PARSING) - self.node.base.attributes.set(orm.CalcJobNode.IMMIGRATED_KEY, True) - return self.parse(retrieved_temporary_folder.abspath) - - def parse( - self, retrieved_temporary_folder: Optional[str] = None, existing_exit_code: ExitCode | None = None - ) -> ExitCode: - """Parse a retrieved job calculation. - - This is called once it's finished waiting for the calculation to be finished and the data has been retrieved. - - :param retrieved_temporary_folder: The path to the temporary folder - - """ - try: - retrieved = self.node.outputs.retrieved - except exceptions.NotExistent: - return self.exit_codes.ERROR_NO_RETRIEVED_FOLDER # pylint: disable=no-member - - # Call the scheduler output parser - exit_code_scheduler = self.parse_scheduler_output(retrieved) - - if exit_code_scheduler is not None and exit_code_scheduler.status > 0: - # If an exit code is returned by the scheduler output parser, we log it and set it on the node. This will - # allow the actual `Parser` implementation, if defined in the inputs, to inspect it and decide to keep it, - # or override it with a more specific exit code, if applicable. 
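A hedged sketch of triggering the dry-run path implemented in `_perform_dry_run` above; it assumes a `builder` prepared as in the earlier launch sketch, and reads back the `dry_run_info` attribute set by that method:

```python
from aiida.engine import run_get_node

builder.metadata.dry_run = True            # `builder` from the earlier launch sketch
builder.metadata.store_provenance = False  # optional: do not store nodes for the test

_, node = run_get_node(builder)
print(node.dry_run_info['folder'])           # sandbox folder with the generated inputs
print(node.dry_run_info['script_filename'])  # name of the generated submit script
```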
- msg = f'scheduler parser returned exit code<{exit_code_scheduler.status}>: {exit_code_scheduler.message}' - self.logger.warning(msg) - self.node.set_exit_status(exit_code_scheduler.status) - self.node.set_exit_message(exit_code_scheduler.message) - - # Call the retrieved output parser - try: - exit_code_retrieved = self.parse_retrieved_output(retrieved_temporary_folder) - finally: - if retrieved_temporary_folder is not None: - shutil.rmtree(retrieved_temporary_folder, ignore_errors=True) - - if exit_code_retrieved is not None and exit_code_retrieved.status > 0: - msg = f'output parser returned exit code<{exit_code_retrieved.status}>: {exit_code_retrieved.message}' - self.logger.warning(msg) - - # The final exit code is that of the scheduler, unless the output parser returned one - exit_code: Optional[ExitCode] - if exit_code_retrieved is not None: - exit_code = exit_code_retrieved - else: - exit_code = exit_code_scheduler - - # Finally link up the outputs and we're done - for entry in self.node.base.links.get_outgoing(): - self.out(entry.link_label, entry.node) - - if existing_exit_code is not None: - return existing_exit_code - - return exit_code or ExitCode(0) - - @staticmethod - def terminate(exit_code: ExitCode) -> ExitCode: - """Terminate the process immediately and return the given exit code. - - This method is called by :meth:`aiida.engine.processes.calcjobs.tasks.Waiting.execute` if a monitor triggered - the job to be terminated and specified the parsing to be skipped. It will construct the running state and tell - this method to be run, which returns the given exit code which will cause the process to be terminated. - - :param exit_code: The exit code to return. - :returns: The provided exit code. - """ - return exit_code - - def parse_scheduler_output(self, retrieved: orm.Node) -> Optional[ExitCode]: - """Parse the output of the scheduler if that functionality has been implemented for the plugin.""" - computer = self.node.computer - - if computer is None: - self.logger.info( - 'no computer is defined for this calculation job which suggest that it is an imported job and so ' - 'scheduler output probably is not available or not in a format that can be reliably parsed, skipping..' 
- ) - return None - - scheduler = computer.get_scheduler() - filename_stderr = self.node.get_option('scheduler_stderr') - filename_stdout = self.node.get_option('scheduler_stdout') - - detailed_job_info = self.node.get_detailed_job_info() - - if detailed_job_info is None: - self.logger.info('could not parse scheduler output: the `detailed_job_info` attribute is missing') - elif detailed_job_info.get('retval', 0) != 0: - self.logger.info('could not parse scheduler output: return value of `detailed_job_info` is non-zero') - detailed_job_info = None - - if filename_stderr is None: - self.logger.warning('could not determine `stderr` filename because `scheduler_stderr` option was not set.') - else: - try: - scheduler_stderr = retrieved.base.repository.get_object_content(filename_stderr, mode='r') - except FileNotFoundError: - scheduler_stderr = None - self.logger.warning(f'could not parse scheduler output: the `{filename_stderr}` file is missing') - - if filename_stdout is None: - self.logger.warning('could not determine `stdout` filename because `scheduler_stdout` option was not set.') - else: - try: - scheduler_stdout = retrieved.base.repository.get_object_content(filename_stdout, mode='r') - except FileNotFoundError: - scheduler_stdout = None - self.logger.warning(f'could not parse scheduler output: the `{filename_stdout}` file is missing') - - try: - exit_code = scheduler.parse_output( - detailed_job_info, - scheduler_stdout or '', # type: ignore[arg-type] - scheduler_stderr or '', # type: ignore[arg-type] - ) - except exceptions.FeatureNotAvailable: - self.logger.info(f'`{scheduler.__class__.__name__}` does not implement scheduler output parsing') - return None - except Exception as exception: # pylint: disable=broad-except - self.logger.error(f'the `parse_output` method of the scheduler excepted: {exception}') - return None - - if exit_code is not None and not isinstance(exit_code, ExitCode): - args = (scheduler.__class__.__name__, type(exit_code)) - raise ValueError('`{}.parse_output` returned neither an `ExitCode` nor None, but: {}'.format(*args)) - - return exit_code - - def parse_retrieved_output(self, retrieved_temporary_folder: Optional[str] = None) -> Optional[ExitCode]: - """Parse the retrieved data by calling the parser plugin if it was defined in the inputs.""" - parser_class = self.node.get_parser_class() - - if parser_class is None: - return None - - parser = parser_class(self.node) - parse_kwargs = parser.get_outputs_for_parsing() - - if retrieved_temporary_folder: - parse_kwargs['retrieved_temporary_folder'] = retrieved_temporary_folder - - exit_code = parser.parse(**parse_kwargs) - - for link_label, node in parser.outputs.items(): - try: - self.out(link_label, node) - except ValueError as exception: - self.logger.error(f'invalid value {node} specified with label {link_label}: {exception}') - exit_code = self.exit_codes.ERROR_INVALID_OUTPUT # pylint: disable=no-member - break - - if exit_code is not None and not isinstance(exit_code, ExitCode): - args = (parser_class.__name__, type(exit_code)) - raise ValueError('`{}.parse` returned neither an `ExitCode` nor None, but: {}'.format(*args)) - - return exit_code - - def presubmit(self, folder: Folder) -> CalcInfo: - """Prepares the calculation folder with all inputs, ready to be copied to the cluster. - - :param folder: a SandboxFolder that can be used to write calculation input files and the scheduling script. - - :return calcinfo: the CalcInfo object containing the information needed by the daemon to handle operations. 
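`parse_scheduler_output` above delegates to the scheduler plugin's `parse_output` hook. A hedged sketch of such a hook, following the pattern of mapping scheduler errors to the exit codes defined on `CalcJob`; the `oom-kill` marker string is a made-up example:

```python
from aiida.engine import CalcJob


def parse_output(self, detailed_job_info, stdout, stderr):
    """Hypothetical scheduler-plugin hook: map scheduler stderr to a CalcJob exit code."""
    if 'oom-kill' in stderr:
        return CalcJob.exit_codes.ERROR_SCHEDULER_OUT_OF_MEMORY
    return None
```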
- - """ - # pylint: disable=too-many-locals,too-many-statements,too-many-branches - from aiida.common.datastructures import CodeInfo, CodeRunMode - from aiida.common.exceptions import InputValidationError, InvalidOperation, PluginInternalError, ValidationError - from aiida.common.utils import validate_list_of_string_tuples - from aiida.orm import AbstractCode, Computer, load_code - from aiida.schedulers.datastructures import JobTemplate, JobTemplateCodeInfo - - inputs = self.node.base.links.get_incoming(link_type=LinkType.INPUT_CALC) - - if not self.inputs.metadata.dry_run and not self.node.is_stored: - raise InvalidOperation('calculation node is not stored.') - - computer = self.node.computer - assert computer is not None - codes = [_ for _ in inputs.all_nodes() if isinstance(_, AbstractCode)] - - for code in codes: - if not code.can_run_on_computer(computer): - raise InputValidationError( - 'The selected code {} for calculation {} cannot run on computer {}'.format( - code.pk, self.node.pk, computer.label - ) - ) - - code.validate_working_directory(folder) - - calc_info = self.prepare_for_submission(folder) - calc_info.uuid = str(self.node.uuid) - - # I create the job template to pass to the scheduler - job_tmpl = JobTemplate() - job_tmpl.submit_as_hold = False - job_tmpl.rerunnable = self.options.get('rerunnable', False) - # 'email', 'email_on_started', 'email_on_terminated', - job_tmpl.job_name = f'aiida-{self.node.pk}' - job_tmpl.sched_output_path = self.options.scheduler_stdout - if computer is not None: - job_tmpl.shebang = computer.get_shebang() - if self.options.scheduler_stderr == self.options.scheduler_stdout: - job_tmpl.sched_join_files = True - else: - job_tmpl.sched_error_path = self.options.scheduler_stderr - job_tmpl.sched_join_files = False - - # Set retrieve path, add also scheduler STDOUT and STDERR - retrieve_list = calc_info.retrieve_list or [] - if (job_tmpl.sched_output_path is not None and job_tmpl.sched_output_path not in retrieve_list): - retrieve_list.append(job_tmpl.sched_output_path) - if not job_tmpl.sched_join_files: - if (job_tmpl.sched_error_path is not None and job_tmpl.sched_error_path not in retrieve_list): - retrieve_list.append(job_tmpl.sched_error_path) - retrieve_list.extend(self.node.get_option('additional_retrieve_list') or []) - self.node.set_retrieve_list(retrieve_list) - - # Handle the retrieve_temporary_list - retrieve_temporary_list = calc_info.retrieve_temporary_list or [] - self.node.set_retrieve_temporary_list(retrieve_temporary_list) - - # If the inputs contain a ``remote_folder`` input node, we are in an import scenario and can skip the rest - if 'remote_folder' in inputs.all_link_labels(): - return calc_info - - # The remaining code is only necessary for actual runs, for example, creating the submission script - scheduler = computer.get_scheduler() - - # the if is done so that if the method returns None, this is - # not added. 
This has two advantages: - # - it does not add too many \n\n if most of the prepend_text are empty - # - most importantly, skips the cases in which one of the methods - # would return None, in which case the join method would raise - # an exception - prepend_texts = [computer.get_prepend_text()] + \ - [code.prepend_text for code in codes] + \ - [calc_info.prepend_text, self.node.get_option('prepend_text')] - job_tmpl.prepend_text = '\n\n'.join(prepend_text for prepend_text in prepend_texts if prepend_text) - - append_texts = [self.node.get_option('append_text'), calc_info.append_text] + \ - [code.append_text for code in codes] + \ - [computer.get_append_text()] - job_tmpl.append_text = '\n\n'.join(append_text for append_text in append_texts if append_text) - - # Set resources, also with get_default_mpiprocs_per_machine - resources = self.node.get_option('resources') - scheduler.preprocess_resources(resources or {}, computer.get_default_mpiprocs_per_machine()) - job_tmpl.job_resource = scheduler.create_job_resource(**resources) # type: ignore - - subst_dict = {'tot_num_mpiprocs': job_tmpl.job_resource.get_tot_num_mpiprocs()} - - for key, value in job_tmpl.job_resource.items(): - subst_dict[key] = value - mpi_args = [arg.format(**subst_dict) for arg in computer.get_mpirun_command()] - extra_mpirun_params = self.node.get_option('mpirun_extra_params') # same for all codes in the same calc - - # set the codes_info - if not isinstance(calc_info.codes_info, (list, tuple)): - raise PluginInternalError('codes_info passed to CalcInfo must be a list of CalcInfo objects') - - tmpl_codes_info = [] - for code_info in calc_info.codes_info: - - if not isinstance(code_info, CodeInfo): - raise PluginInternalError('Invalid codes_info, must be a list of CodeInfo objects') - - if code_info.code_uuid is None: - raise PluginInternalError('CalcInfo should have the information of the code to be launched') - - code = load_code(code_info.code_uuid) - - # Here are the three values that will determine whether the code is to be run with MPI _if_ they are not - # ``None``. If any of them are explicitly defined but are not equivalent, an exception is raised. We use the - # ``self._raw_inputs`` to determine the actual value passed for ``metadata.options.withmpi`` and - # distinghuish it from the default. - raw_inputs = self._raw_inputs or {} # type: ignore[var-annotated] - with_mpi_option = raw_inputs.get('metadata', {}).get('options', {}).get('withmpi', None) - with_mpi_plugin = code_info.withmpi - with_mpi_code = code.with_mpi - - with_mpi_values = [with_mpi_option, with_mpi_plugin, with_mpi_code] - with_mpi_values_defined = [value for value in with_mpi_values if value is not None] - with_mpi_values_set = set(with_mpi_values_defined) - - # If more than one value is defined, they have to be identical, or we raise that a conflict is encountered - if len(with_mpi_values_set) > 1: - error = f'Inconsistent requirements as to whether code `{code}` should be run with or without MPI.' - if with_mpi_option is not None: - error += f'\nThe `metadata.options.withmpi` input was set to `{with_mpi_option}`.' - if with_mpi_plugin is not None: - error += f'\nThe plugin require `{with_mpi_plugin}`.' - if with_mpi_code is not None: - error += f'\nThe code `{code}` required `{with_mpi_code}`.' 
- raise RuntimeError(error) - - # At this point we know that the three explicit values agree if they are defined, so we simply set the value - if with_mpi_values_set: - with_mpi = with_mpi_values_set.pop() - else: - # Fall back to the default, which is to not use MPI - with_mpi = False - - if with_mpi: - prepend_cmdline_params = code.get_prepend_cmdline_params(mpi_args, extra_mpirun_params) - else: - prepend_cmdline_params = code.get_prepend_cmdline_params() - - cmdline_params = code.get_executable_cmdline_params(code_info.cmdline_params) - - tmpl_code_info = JobTemplateCodeInfo() - tmpl_code_info.prepend_cmdline_params = prepend_cmdline_params - tmpl_code_info.cmdline_params = cmdline_params - tmpl_code_info.use_double_quotes = [computer.get_use_double_quotes(), code.use_double_quotes] - tmpl_code_info.wrap_cmdline_params = code.wrap_cmdline_params - tmpl_code_info.stdin_name = code_info.stdin_name - tmpl_code_info.stdout_name = code_info.stdout_name - tmpl_code_info.stderr_name = code_info.stderr_name - tmpl_code_info.join_files = code_info.join_files or False - - tmpl_codes_info.append(tmpl_code_info) - job_tmpl.codes_info = tmpl_codes_info - - # set the codes execution mode, default set to `SERIAL` - codes_run_mode = CodeRunMode.SERIAL - if calc_info.codes_run_mode: - codes_run_mode = calc_info.codes_run_mode - - job_tmpl.codes_run_mode = codes_run_mode - ######################################################################## - - custom_sched_commands = self.node.get_option('custom_scheduler_commands') - if custom_sched_commands: - job_tmpl.custom_scheduler_commands = custom_sched_commands - - job_tmpl.import_sys_environment = self.node.get_option('import_sys_environment') - - job_tmpl.job_environment = self.node.get_option('environment_variables') - job_tmpl.environment_variables_double_quotes = self.node.get_option('environment_variables_double_quotes') - - queue_name = self.node.get_option('queue_name') - account = self.node.get_option('account') - qos = self.node.get_option('qos') - if queue_name is not None: - job_tmpl.queue_name = queue_name - if account is not None: - job_tmpl.account = account - if qos is not None: - job_tmpl.qos = qos - priority = self.node.get_option('priority') - if priority is not None: - job_tmpl.priority = priority - - job_tmpl.max_memory_kb = self.node.get_option('max_memory_kb') or computer.get_default_memory_per_machine() - - max_wallclock_seconds = self.node.get_option('max_wallclock_seconds') - if max_wallclock_seconds is not None: - job_tmpl.max_wallclock_seconds = max_wallclock_seconds - - submit_script_filename = self.node.get_option('submit_script_filename') - script_content = scheduler.get_submit_script(job_tmpl) - folder.create_file_from_filelike(io.StringIO(script_content), submit_script_filename, 'w', encoding='utf8') - - def encoder(obj): - if dataclasses.is_dataclass(obj): - return dataclasses.asdict(obj) - raise TypeError(f' {obj!r} is not JSON serializable') - - subfolder = folder.get_subfolder('.aiida', create=True) - subfolder.create_file_from_filelike( - io.StringIO(json.dumps(job_tmpl, default=encoder)), 'job_tmpl.json', 'w', encoding='utf8' - ) - subfolder.create_file_from_filelike(io.StringIO(json.dumps(calc_info)), 'calcinfo.json', 'w', encoding='utf8') - - if calc_info.local_copy_list is None: - calc_info.local_copy_list = [] - - if calc_info.remote_copy_list is None: - calc_info.remote_copy_list = [] - - # Some validation - this_pk = self.node.pk if self.node.pk is not None else '[UNSTORED]' - local_copy_list = 
calc_info.local_copy_list - try: - validate_list_of_string_tuples(local_copy_list, tuple_length=3) - except ValidationError as exception: - raise PluginInternalError( - f'[presubmission of calc {this_pk}] local_copy_list format problem: {exception}' - ) from exception - - remote_copy_list = calc_info.remote_copy_list - try: - validate_list_of_string_tuples(remote_copy_list, tuple_length=3) - except ValidationError as exception: - raise PluginInternalError( - f'[presubmission of calc {this_pk}] remote_copy_list format problem: {exception}' - ) from exception - - for (remote_computer_uuid, _, dest_rel_path) in remote_copy_list: - try: - Computer.collection.get(uuid=remote_computer_uuid) # pylint: disable=unused-variable - except exceptions.NotExistent as exception: - raise PluginInternalError( - '[presubmission of calc {}] ' - 'The remote copy requires a computer with UUID={}' - 'but no such computer was found in the ' - 'database'.format(this_pk, remote_computer_uuid) - ) from exception - if os.path.isabs(dest_rel_path): - raise PluginInternalError( - '[presubmission of calc {}] ' - 'The destination path of the remote copy ' - 'is absolute! ({})'.format(this_pk, dest_rel_path) - ) - - return calc_info diff --git a/aiida/engine/processes/calcjobs/manager.py b/aiida/engine/processes/calcjobs/manager.py deleted file mode 100644 index e75166cf41..0000000000 --- a/aiida/engine/processes/calcjobs/manager.py +++ /dev/null @@ -1,291 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module containing utilities and classes relating to job calculations running on systems that require transport.""" -import asyncio -import contextlib -import contextvars -import logging -import time -from typing import TYPE_CHECKING, Any, Dict, Hashable, Iterator, List, Optional - -from aiida.common import lang -from aiida.orm import AuthInfo - -if TYPE_CHECKING: - from aiida.engine.transports import TransportQueue - from aiida.schedulers.datastructures import JobInfo - -__all__ = ('JobsList', 'JobManager') - - -class JobsList: - """Manager of calculation jobs submitted with a specific ``AuthInfo``, i.e. computer configured for a specific user. - - This container of active calculation jobs is used to update their status periodically in batches, ensuring that - even when a lot of jobs are running, the scheduler update command is not triggered for each job individually. - - In addition, the :py:class:`~aiida.orm.computers.Computer` for which the :py:class:`~aiida.orm.authinfos.AuthInfo` - is configured, can define a minimum polling interval. This class will guarantee that the time between update calls - to the scheduler is larger or equal to that minimum interval. - - Note that since each instance operates on a specific authinfo, the guarantees of batching scheduler update calls - and the limiting of number of calls per unit time, through the minimum polling interval, is only applicable for jobs - launched with that particular authinfo. 
If multiple authinfo instances with the same computer, have active jobs - these limitations are not respected between them, since there is no communication between ``JobsList`` instances. - See the :py:class:`~aiida.engine.processes.calcjobs.manager.JobManager` for example usage. - """ - - def __init__(self, authinfo: AuthInfo, transport_queue: 'TransportQueue', last_updated: Optional[float] = None): - """Construct an instance for the given authinfo and transport queue. - - :param authinfo: The authinfo used to check the jobs list - :param transport_queue: A transport queue - :param last_updated: initialize the last updated timestamp - - """ - lang.type_check(last_updated, float, allow_none=True) - - self._authinfo = authinfo - self._transport_queue = transport_queue - self._loop = transport_queue.loop - self._logger = logging.getLogger(__name__) - - self._jobs_cache: Dict[Hashable, 'JobInfo'] = {} - self._job_update_requests: Dict[Hashable, asyncio.Future] = {} # Mapping: {job_id: Future} - self._last_updated = last_updated - self._update_handle: Optional[asyncio.TimerHandle] = None - - @property - def logger(self) -> logging.Logger: - """Return the logger configured for this instance. - - :return: the logger - """ - return self._logger - - def get_minimum_update_interval(self) -> float: - """Get the minimum interval that should be respected between updates of the list. - - :return: the minimum interval - - """ - return self._authinfo.computer.get_minimum_job_poll_interval() - - @property - def last_updated(self) -> Optional[float]: - """Get the timestamp of when the list was last updated as produced by `time.time()` - - :return: The last update point - - """ - return self._last_updated - - async def _get_jobs_from_scheduler(self) -> Dict[Hashable, 'JobInfo']: - """Get the current jobs list from the scheduler. - - :return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances - - """ - with self._transport_queue.request_transport(self._authinfo) as request: - self.logger.info('waiting for transport') - transport = await request - - scheduler = self._authinfo.computer.get_scheduler() - scheduler.set_transport(transport) - - kwargs: Dict[str, Any] = {'as_dict': True} - if scheduler.get_feature('can_query_by_user'): - kwargs['user'] = '$USER' - else: - kwargs['jobs'] = self._get_jobs_with_scheduler() - - scheduler_response = scheduler.get_jobs(**kwargs) - - # Update the last update time and clear the jobs cache - self._last_updated = time.time() - jobs_cache = {} - self.logger.info(f'AuthInfo<{self._authinfo.pk}>: successfully retrieved status of active jobs') - - for job_id, job_info in scheduler_response.items(): - jobs_cache[job_id] = job_info - - return jobs_cache - - async def _update_job_info(self) -> None: - """Update all of the job information objects. - - This will set the futures for all pending update requests where the corresponding job has a new status compared - to the last update. - """ - try: - if not self._update_requests_outstanding(): - return - - # Update our cache of the job states - self._jobs_cache = await self._get_jobs_from_scheduler() - except Exception as exception: - # Set the exception on all the update futures - for future in self._job_update_requests.values(): - if not future.done(): - future.set_exception(exception) - - # Reset the `_update_handle` manually. Normally this is done in the `updating` coroutine, but since we - # reraise this exception, that code path is never hit. 
If the next time a request comes in, the method - # `_ensure_updating` will falsely conclude we are still updating, since the handle is not `None` and so it - # will not schedule the next update, causing the job update futures to never be resolved. - self._update_handle = None - - raise - else: - for job_id, future in self._job_update_requests.items(): - if not future.done(): - future.set_result(self._jobs_cache.get(job_id, None)) - finally: - self._job_update_requests = {} - - @contextlib.contextmanager - def request_job_info_update(self, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']: - """Request job info about a job when the job next changes state. - - If the job is not found in the jobs list at the update, the future will resolve to `None`. - - :param job_id: job identifier - :return: future that will resolve to a `JobInfo` object when the job changes state - """ - # Get or create the future - request = self._job_update_requests.setdefault(job_id, asyncio.Future()) - assert not request.done(), 'Expected pending job info future, found in done state.' - - try: - self._ensure_updating() - yield request - finally: - pass - - def _ensure_updating(self) -> None: - """Ensure that we are updating the job list from the remote resource. - - This will automatically stop if there are no outstanding requests. - """ - - async def updating(): - """Do the actual update, stop if not requests left.""" - await self._update_job_info() - # Any outstanding requests? - if self._update_requests_outstanding(): - self._update_handle = self._loop.call_later( - self._get_next_update_delay(), - asyncio.ensure_future, - updating(), - context=contextvars.Context(), # type: ignore[call-arg] - ) - else: - self._update_handle = None - - # Check if we're already updating - if self._update_handle is None: - self._update_handle = self._loop.call_later( - self._get_next_update_delay(), - asyncio.ensure_future, - updating(), - context=contextvars.Context(), # type: ignore[call-arg] - ) - - @staticmethod - def _has_job_state_changed(old: Optional['JobInfo'], new: Optional['JobInfo']) -> bool: - """Return whether the states `old` and `new` are different. - - - """ - if old is None and new is None: - return False - - if old is None or new is None: - # One is None and the other isn't - return True - - return old.job_state != new.job_state or old.job_substate != new.job_substate - - def _get_next_update_delay(self) -> float: - """Calculate when we are next allowed to poll the scheduler. - - This delay is calculated as the minimum polling interval defined by the authentication info for this instance, - minus time elapsed since the last update. - - :return: delay (in seconds) after which the scheduler may be polled again - - """ - if self.last_updated is None: - # Never updated, so do it straight away - return 0. - - # Make sure to actually 'get' the minimum interval here, in case the user changed since last time - minimum_interval = self.get_minimum_update_interval() - elapsed = time.time() - self.last_updated - - delay = max(minimum_interval - elapsed, 0.) - - return delay - - def _update_requests_outstanding(self) -> bool: - return any(not request.done() for request in self._job_update_requests.values()) - - def _get_jobs_with_scheduler(self) -> List[str]: - """Get all the jobs that are currently with scheduler. 
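The delay computed in `_get_next_update_delay` above reduces to `max(minimum_interval - elapsed, 0)`. The following standalone sketch illustrates that throttling rule in isolation; the `PollThrottle` class and its method names are purely illustrative and are not part of the aiida-core API.

```python
import time
from typing import Optional


class PollThrottle:
    """Illustrative sketch of the throttling rule above: wait at least
    ``minimum_interval`` seconds between successive scheduler polls."""

    def __init__(self, minimum_interval: float) -> None:
        self.minimum_interval = minimum_interval
        self.last_updated: Optional[float] = None

    def next_delay(self) -> float:
        """Return the number of seconds to wait before polling again."""
        if self.last_updated is None:
            return 0.0  # never polled before, so poll straight away
        elapsed = time.time() - self.last_updated
        return max(self.minimum_interval - elapsed, 0.0)

    def mark_updated(self) -> None:
        """Record that a poll just completed."""
        self.last_updated = time.time()


throttle = PollThrottle(minimum_interval=10.0)
print(throttle.next_delay())  # 0.0: the first poll is allowed immediately
throttle.mark_updated()
print(throttle.next_delay())  # ~10.0: the next poll has to wait out the interval
```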
- - :return: the list of jobs with the scheduler - :rtype: list - """ - return [str(job_id) for job_id, _ in self._job_update_requests.items()] - - -class JobManager: - """A manager for :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob` submitted to ``Computer`` instances. - - When a calculation job is submitted to a :py:class:`~aiida.orm.computers.Computer`, it actually uses a specific - :py:class:`~aiida.orm.authinfos.AuthInfo`, which is a computer configured for a :py:class:`~aiida.orm.users.User`. - The ``JobManager`` maintains a mapping of :py:class:`~aiida.engine.processes.calcjobs.manager.JobsList` instances - for each authinfo that has active calculation jobs. These jobslist instances are then responsible for bundling - scheduler updates for all the jobs they maintain (i.e. that all share the same authinfo) and update their status. - - As long as a :py:class:`~aiida.engine.runners.Runner` will create a single ``JobManager`` instance and use that for - its lifetime, the guarantees made by the ``JobsList`` about respecting the minimum polling interval of the scheduler - will be maintained. Note, however, that since each ``Runner`` will create its own job manager, these guarantees - only hold per runner. - """ - - def __init__(self, transport_queue: 'TransportQueue') -> None: - self._transport_queue = transport_queue - self._job_lists: Dict[Hashable, 'JobsList'] = {} - - def get_jobs_list(self, authinfo: AuthInfo) -> JobsList: - """Get or create a new `JobsList` instance for the given authinfo. - - :param authinfo: the `AuthInfo` - :return: a `JobsList` instance - """ - if authinfo.pk not in self._job_lists: - self._job_lists[authinfo.pk] = JobsList(authinfo, self._transport_queue) - - return self._job_lists[authinfo.pk] - - @contextlib.contextmanager - def request_job_info_update(self, authinfo: AuthInfo, job_id: Hashable) -> Iterator['asyncio.Future[JobInfo]']: - """Get a future that will resolve to information about a given job. - - This is a context manager so that if the user leaves the context the request is automatically cancelled. - - """ - with self.get_jobs_list(authinfo).request_job_info_update(job_id) as request: - try: - yield request - finally: - if not request.done(): - request.cancel() diff --git a/aiida/engine/processes/functions.py b/aiida/engine/processes/functions.py deleted file mode 100644 index 8baf92c903..0000000000 --- a/aiida/engine/processes/functions.py +++ /dev/null @@ -1,603 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code.
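Taken together, `JobsList` and `JobManager` implement a simple pattern: per-job futures are collected and all resolved from a single batched scheduler query, and a request is cancelled if the caller leaves the context before it resolves. Below is a minimal asyncio sketch of that pattern; the class and function names are illustrative stand-ins, not aiida-core APIs.

```python
import asyncio
import contextlib
from typing import Dict, Hashable


class BatchedStatusRequests:
    """Illustrative sketch: collect per-job futures, resolve them all from one batched query."""

    def __init__(self) -> None:
        self._requests: Dict[Hashable, asyncio.Future] = {}

    def request(self, job_id: Hashable) -> asyncio.Future:
        # Reuse an existing pending future for the same job id, as the code above does.
        return self._requests.setdefault(job_id, asyncio.Future())

    def resolve_all(self, statuses: Dict[Hashable, str]) -> None:
        # A single scheduler query answers every outstanding request at once.
        for job_id, future in self._requests.items():
            if not future.done():
                future.set_result(statuses.get(job_id))
        self._requests = {}


@contextlib.contextmanager
def request_status(batch: BatchedStatusRequests, job_id: Hashable):
    future = batch.request(job_id)
    try:
        yield future
    finally:
        if not future.done():
            future.cancel()  # the caller left before the batched update arrived


async def main() -> None:
    batch = BatchedStatusRequests()
    with request_status(batch, 'job-1') as status:
        # Pretend the batched scheduler poll happened in the background:
        batch.resolve_all({'job-1': 'RUNNING', 'job-2': 'QUEUED'})
        print(await status)  # RUNNING


asyncio.run(main())
```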
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Class and decorators to generate processes out of simple python functions.""" -from __future__ import annotations - -import collections -import functools -import inspect -import itertools -import logging -import signal -import sys -import types -import typing as t -from typing import TYPE_CHECKING - -import docstring_parser - -from aiida.common.lang import override -from aiida.manage import get_manager -from aiida.orm import ( - Bool, - CalcFunctionNode, - Data, - Dict, - Float, - Int, - List, - ProcessNode, - Str, - WorkFunctionNode, - to_aiida_type, -) -from aiida.orm.utils.mixins import FunctionCalculationMixin - -from .process import Process - -try: - UnionType = types.UnionType # type: ignore[attr-defined] -except AttributeError: - # This type is not available for Python 3.9 and older - UnionType = None # pylint: disable=invalid-name - -try: - get_annotations = inspect.get_annotations # type: ignore[attr-defined] -except AttributeError: - # This is the backport for Python 3.9 and older - from get_annotations import get_annotations # type: ignore[no-redef] - -if TYPE_CHECKING: - from .exit_code import ExitCode - -__all__ = ('calcfunction', 'workfunction', 'FunctionProcess') - -LOGGER = logging.getLogger(__name__) - -FunctionType = t.TypeVar('FunctionType', bound=t.Callable[..., t.Any]) # pylint: disable=invalid-name - - -def get_stack_size(size: int = 2) -> int: # type: ignore[return] - """Return the stack size for the caller's frame. - - This solution is taken from https://stackoverflow.com/questions/34115298/ as a more performant alternative to the - naive ``len(inspect.stack())` solution. This implementation is about three orders of magnitude faster compared to - the naive solution and it scales especially well for larger stacks, which will be usually the case for the usage - of ``aiida-core``. However, it does use the internal ``_getframe`` of the ``sys`` standard library. It this ever - were to stop working, simply switch to using ``len(inspect.stack())``. - - :param size: Hint for the expected stack size. - :returns: The stack size for caller's frame. - """ - frame = sys._getframe(size) # pylint: disable=protected-access - try: - for size in itertools.count(size, 8): # pylint: disable=redefined-argument-from-local - frame = frame.f_back.f_back.f_back.f_back.f_back.f_back.f_back.f_back # type: ignore[assignment,union-attr] - except AttributeError: - while frame: - frame = frame.f_back # type: ignore[assignment] - size += 1 - return size - 1 - - -def calcfunction(function: FunctionType) -> FunctionType: - """ - A decorator to turn a standard python function into a calcfunction. - Example usage: - - >>> from aiida.orm import Int - >>> - >>> # Define the calcfunction - >>> @calcfunction - >>> def sum(a, b): - >>> return a + b - >>> # Run it with some input - >>> r = sum(Int(4), Int(5)) - >>> print(r) - 9 - >>> r.base.links.get_incoming().all() # doctest: +SKIP - [Neighbor(link_type='', link_label='result', - node=)] - >>> r.base.links.get_incoming().get_node_by_label('result').base.links.get_incoming().all_nodes() - [4, 5] - - :param function: The function to decorate. - :return: The decorated function. 
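The `get_stack_size` helper above avoids `len(inspect.stack())` because building the full `FrameInfo` list is expensive for deep stacks. The following rough, standalone comparison uses a plain `f_back` walk (without the unrolled eight-frame stride used above); the helper names and the chosen depth are illustrative only.

```python
import inspect
import sys
import timeit


def stack_size_naive() -> int:
    """Build the full frame list just to count it (slow for deep stacks)."""
    return len(inspect.stack())


def stack_size_walk() -> int:
    """Count frames by following ``f_back`` pointers, the idea behind ``get_stack_size``."""
    size = 0
    frame = sys._getframe(0)
    while frame is not None:
        frame = frame.f_back
        size += 1
    return size


def recurse(depth: int, fn):
    """Grow the stack artificially before calling ``fn``."""
    if depth == 0:
        return fn()
    return recurse(depth - 1, fn)


# Rough comparison at a stack depth of a few hundred frames.
print('naive:', timeit.timeit(lambda: recurse(300, stack_size_naive), number=5))
print('walk: ', timeit.timeit(lambda: recurse(300, stack_size_walk), number=5))
```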
- """ - return process_function(node_class=CalcFunctionNode)(function) - - -def workfunction(function: FunctionType) -> FunctionType: - """ - A decorator to turn a standard python function into a workfunction. - Example usage: - - >>> from aiida.orm import Int - >>> - >>> # Define the workfunction - >>> @workfunction - >>> def select(a, b): - >>> return a - >>> # Run it with some input - >>> r = select(Int(4), Int(5)) - >>> print(r) - 4 - >>> r.base.links.get_incoming().all() # doctest: +SKIP - [Neighbor(link_type='', link_label='result', - node=)] - >>> r.base.links.get_incoming().get_node_by_label('result').base.links.get_incoming().all_nodes() - [4, 5] - - :param function: The function to decorate. - :return: The decorated function. - """ - return process_function(node_class=WorkFunctionNode)(function) - - -def process_function(node_class: t.Type['ProcessNode']) -> t.Callable[[FunctionType], FunctionType]: - """ - The base function decorator to create a FunctionProcess out of a normal python function. - - :param node_class: the ORM class to be used as the Node record for the FunctionProcess - """ - - def decorator(function: FunctionType) -> FunctionType: - """ - Turn the decorated function into a FunctionProcess. - - :param callable function: the actual decorated function that the FunctionProcess represents - :return callable: The decorated function. - """ - process_class = FunctionProcess.build(function, node_class=node_class) - - def run_get_node(*args, **kwargs) -> tuple[dict[str, t.Any] | None, 'ProcessNode']: - """ - Run the FunctionProcess with the supplied inputs in a local runner. - - :param args: input arguments to construct the FunctionProcess - :param kwargs: input keyword arguments to construct the FunctionProcess - :return: tuple of the outputs of the process and the process node - """ - frame_delta = 1000 - frame_count = get_stack_size() - stack_limit = sys.getrecursionlimit() - LOGGER.info('Executing process function, current stack status: %d frames of %d', frame_count, stack_limit) - - # If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the - # stack limit by ``frame_delta``. - if frame_count > min(0.8 * stack_limit, stack_limit - 200): - LOGGER.warning( - 'Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d', - frame_count, stack_limit, frame_delta - ) - sys.setrecursionlimit(stack_limit + frame_delta) - - manager = get_manager() - runner = manager.get_runner() - inputs = process_class.create_inputs(*args, **kwargs) - - # Remove all the known inputs from the kwargs - for port in process_class.spec().inputs: - kwargs.pop(port, None) - - # If any kwargs remain, the spec should be dynamic, so we raise if it isn't - if kwargs and not process_class.spec().inputs.dynamic: - raise ValueError(f'{function.__name__} does not support these kwargs: {kwargs.keys()}') - - process = process_class(inputs=inputs, runner=runner) - - # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner. 
- # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown - current_runner = manager.get_runner() - original_handler = None - kill_signal = signal.SIGINT - - if not current_runner.is_daemon_runner: - - def kill_process(_num, _frame): - """Send the kill signal to the process in the current scope.""" - LOGGER.critical('runner received interrupt, killing process %s', process.pid) - result = process.kill(msg='Process was killed because the runner received an interrupt') - return result - - # Store the current handler on the signal such that it can be restored after process has terminated - original_handler = signal.getsignal(kill_signal) - signal.signal(kill_signal, kill_process) - - try: - result = process.execute() - finally: - # If the `original_handler` is set, that means the `kill_process` was bound, which needs to be reset - if original_handler: - signal.signal(signal.SIGINT, original_handler) - - store_provenance = inputs.get('metadata', {}).get('store_provenance', True) - if not store_provenance: - process.node._storable = False # pylint: disable=protected-access - process.node._unstorable_message = 'cannot store node because it was run with `store_provenance=False`' # pylint: disable=protected-access - - return result, process.node - - def run_get_pk(*args, **kwargs) -> tuple[dict[str, t.Any] | None, int]: - """Recreate the `run_get_pk` utility launcher. - - :param args: input arguments to construct the FunctionProcess - :param kwargs: input keyword arguments to construct the FunctionProcess - :return: tuple of the outputs of the process and the process node pk - - """ - result, node = run_get_node(*args, **kwargs) - return result, node.pk - - @functools.wraps(function) - def decorated_function(*args, **kwargs): - """This wrapper function is the actual function that is called.""" - result, _ = run_get_node(*args, **kwargs) - return result - - decorated_function.run = decorated_function # type: ignore[attr-defined] - decorated_function.run_get_pk = run_get_pk # type: ignore[attr-defined] - decorated_function.run_get_node = run_get_node # type: ignore[attr-defined] - decorated_function.is_process_function = True # type: ignore[attr-defined] - decorated_function.node_class = node_class # type: ignore[attr-defined] - decorated_function.process_class = process_class # type: ignore[attr-defined] - decorated_function.recreate_from = process_class.recreate_from # type: ignore[attr-defined] - decorated_function.spec = process_class.spec # type: ignore[attr-defined] - - return decorated_function # type: ignore[return-value] - - return decorator - - -def infer_valid_type_from_type_annotation(annotation: t.Any) -> tuple[t.Any, ...]: - """Infer the value for the ``valid_type`` of an input port from the given function argument annotation. - - :param annotation: The annotation of a function argument as returned by ``inspect.get_annotation``. - :returns: A tuple of valid types. If no valid types were defined or they could not be successfully parsed, an empty - tuple is returned. - """ - - def get_type_from_annotation(annotation): - valid_type_map = { - bool: Bool, - dict: Dict, - t.Dict: Dict, - float: Float, - int: Int, - list: List, - t.List: List, - str: Str, - } - - if inspect.isclass(annotation) and issubclass(annotation, Data): - return annotation - - return valid_type_map.get(annotation) - - inferred_valid_type: tuple[t.Any, ...] 
= () - - if inspect.isclass(annotation): - inferred_valid_type = (get_type_from_annotation(annotation),) - elif t.get_origin(annotation) is t.Union or t.get_origin(annotation) is UnionType: - inferred_valid_type = tuple(get_type_from_annotation(valid_type) for valid_type in t.get_args(annotation)) - elif t.get_origin(annotation) is t.Optional: - inferred_valid_type = (t.get_args(annotation),) - - return tuple(valid_type for valid_type in inferred_valid_type if valid_type is not None) - - -class FunctionProcess(Process): - """Function process class used for turning functions into a Process""" - - _func_args: t.Sequence[str] = () - _varargs: str | None = None - - @staticmethod - def _func(*_args, **_kwargs) -> dict: - """ - This is used internally to store the actual function that is being - wrapped and will be replaced by the build method. - """ - return {} - - @staticmethod - def build(func: FunctionType, node_class: t.Type['ProcessNode']) -> t.Type['FunctionProcess']: - """ - Build a Process from the given function. - - All function arguments will be assigned as process inputs. If keyword arguments are specified then - these will also become inputs. - - :param func: The function to build a process from - :param node_class: Provide a custom node class to be used, has to be constructable with no arguments. It has to - be a sub class of `ProcessNode` and the mixin :class:`~aiida.orm.utils.mixins.FunctionCalculationMixin`. - - :return: A Process class that represents the function - - """ - # pylint: disable=too-many-statements - if not issubclass(node_class, ProcessNode) or not issubclass(node_class, FunctionCalculationMixin): - raise TypeError('the node_class should be a sub class of `ProcessNode` and `FunctionCalculationMixin`') - - signature = inspect.signature(func) - - args: list[str] = [] - varargs: str | None = None - keywords: str | None = None - - try: - annotations = get_annotations(func, eval_str=True) - except Exception as exception: # pylint: disable=broad-except - # Since we are running with ``eval_str=True`` to unstringize the annotations, the call can except if the - # annotations are incorrect. In this case we simply want to log a warning and continue with type inference. 
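`infer_valid_type_from_type_annotation` above dispatches on `typing.get_origin`/`get_args` to map plain Python annotations onto AiiDA node types. The sketch below reproduces that dispatch with plain strings standing in for the AiiDA classes, so it runs without an AiiDA installation; the `VALID_TYPE_MAP` shown is a trimmed, illustrative subset.

```python
import typing as t

VALID_TYPE_MAP = {bool: 'Bool', dict: 'Dict', float: 'Float', int: 'Int', list: 'List', str: 'Str'}


def infer_valid_types(annotation: t.Any) -> tuple:
    """Sketch of the inference above: a plain class, or a Union of classes."""
    if annotation in VALID_TYPE_MAP:
        return (VALID_TYPE_MAP[annotation],)
    if t.get_origin(annotation) is t.Union:
        resolved = tuple(VALID_TYPE_MAP.get(arg) for arg in t.get_args(annotation))
        return tuple(valid for valid in resolved if valid is not None)
    return ()


print(infer_valid_types(int))                  # ('Int',)
print(infer_valid_types(t.Optional[str]))      # ('Str',)  -- NoneType is filtered out
print(infer_valid_types(t.Union[int, float]))  # ('Int', 'Float')
print(infer_valid_types(bytes))                # ()  -- unknown annotations fall back to Data in aiida
```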
- LOGGER.warning(f'function `{func.__name__}` has invalid type hints: {exception}') - annotations = {} - - try: - parsed_docstring = docstring_parser.parse(func.__doc__) - except Exception as exception: # pylint: disable=broad-except - LOGGER.warning(f'function `{func.__name__}` has a docstring that could not be parsed: {exception}') - param_help_string = {} - namespace_help_string = None - else: - param_help_string = {param.arg_name: param.description for param in parsed_docstring.params} - namespace_help_string = parsed_docstring.short_description if parsed_docstring.short_description else '' - if parsed_docstring.long_description is not None: - namespace_help_string += f'\n\n{parsed_docstring.long_description}' - - for key, parameter in signature.parameters.items(): - - if parameter.kind in [parameter.POSITIONAL_ONLY, parameter.POSITIONAL_OR_KEYWORD, parameter.KEYWORD_ONLY]: - args.append(key) - - if parameter.kind is parameter.VAR_POSITIONAL: - varargs = key - - if parameter.kind is parameter.VAR_KEYWORD: - varargs = key - - def _define(cls, spec): # pylint: disable=unused-argument - """Define the spec dynamically""" - from plumpy.ports import UNSPECIFIED - - super().define(spec) - - for parameter in signature.parameters.values(): - - if parameter.kind in [parameter.VAR_POSITIONAL, parameter.VAR_KEYWORD]: - continue - - annotation = annotations.get(parameter.name) - valid_type = infer_valid_type_from_type_annotation(annotation) or (Data,) - help_string = param_help_string.get(parameter.name, None) - - default = parameter.default if parameter.default is not parameter.empty else UNSPECIFIED - - # If the keyword was already specified, simply override the default - if spec.has_input(parameter.name): - spec.inputs[parameter.name].default = default - continue - - # If the default is ``None`` make sure that the port also accepts a ``NoneType``. Note that we cannot - # use ``None`` because the validation will call ``isinstance`` which does not work when passing ``None`` - # but it does work with ``NoneType`` which is returned by calling ``type(None)``. - if default is None: - valid_type += (type(None),) - - # If a default is defined and it is not a ``Data`` instance it should be serialized, but this should be - # done lazily using a lambda, just as any port defaults should not define node instances directly as is - # also checked by the ``spec.input`` call. - if ( - default is not None and default != UNSPECIFIED and not isinstance(default, Data) and - not callable(default) - ): - indirect_default = lambda value=default: to_aiida_type(value) # pylint: disable=unnecessary-lambda-assignment - else: - indirect_default = default - - spec.input( - parameter.name, - valid_type=valid_type, - default=indirect_default, - serializer=to_aiida_type, - help=help_string, - ) - - # Set defaults for label and description based on function name and docstring, if not explicitly defined - port_label = spec.inputs['metadata']['label'] - - if not port_label.has_default(): - port_label.default = func.__name__ - - spec.inputs.help = namespace_help_string - - # If the function supports varargs or kwargs then allow dynamic inputs, otherwise disallow - spec.inputs.dynamic = keywords is not None or varargs - - # Function processes must have a dynamic output namespace since we do not know beforehand what outputs - # will be returned and the valid types for the value should be `Data` nodes as well as a dictionary because - # the output namespace can be nested. 
- spec.outputs.valid_type = (Data, dict) - - return type( - func.__qualname__, (FunctionProcess,), { - '__module__': func.__module__, - '__name__': func.__name__, - '__qualname__': func.__qualname__, - '_func': staticmethod(func), - Process.define.__name__: classmethod(_define), - '_func_args': args, - '_varargs': varargs or None, - '_node_class': node_class - } - ) - - @classmethod - def validate_inputs(cls, *args: t.Any, **kwargs: t.Any) -> None: # pylint: disable=unused-argument - """ - Validate the positional and keyword arguments passed in the function call. - - :raises TypeError: if more positional arguments are passed than the function defines - """ - nargs = len(args) - nparameters = len(cls._func_args) - has_varargs = cls._varargs is not None - - # If the spec is dynamic, i.e. the function signature includes `**kwargs` and the number of positional arguments - # passed is larger than the number of explicitly defined parameters in the signature, the inputs are invalid and - # we should raise. If we don't, some of the passed arguments, intended to be positional arguments, will be - # misinterpreted as keyword arguments, but they won't have an explicit name to use for the link label, causing - # the input link to be completely lost. If the function supports variadic arguments, however, additional args - # should be accepted. - if cls.spec().inputs.dynamic and nargs > nparameters and not has_varargs: - name = cls._func.__name__ - raise TypeError(f'{name}() takes {nparameters} positional arguments but {nargs} were given') - - @classmethod - def create_inputs(cls, *args: t.Any, **kwargs: t.Any) -> dict[str, t.Any]: - """Create the input args for the FunctionProcess.""" - cls.validate_inputs(*args, **kwargs) - - ins = {} - if kwargs: - ins.update(kwargs) - if args: - ins.update(cls.args_to_dict(*args)) - return ins - - @classmethod - def args_to_dict(cls, *args: t.Any) -> dict[str, t.Any]: - """ - Create an input dictionary (of form label -> value) from supplied args. - - :param args: The values to use for the dictionary - - :return: A label -> value dictionary - - """ - dictionary = {} - values = list(args) - - for arg in cls._func_args: - try: - dictionary[arg] = values.pop(0) - except IndexError: - pass - - # If arguments remain and the function supports variadic arguments, add those as well. - if cls._varargs and args: - - # By default the prefix for variadic labels is the key with which the varargs were declared - variadic_prefix = cls._varargs - - for index, arg in enumerate(values): - label = f'{variadic_prefix}_{index}' - - # If the generated vararg label overlaps with a keyword argument, function signature should be changed - if label in dictionary: - raise RuntimeError( - f'variadic argument with index `{index}` would get the label `{label}` but this is already in ' - 'use by another function argument with the exact same name. To avoid this error, please change ' - f'the name of argument `{label}` to something else.' - ) - - dictionary[label] = arg - - return dictionary - - @classmethod - def get_or_create_db_record(cls) -> 'ProcessNode': - return cls._node_class() - - def __init__(self, *args, **kwargs) -> None: - if kwargs.get('enable_persistence', False): - raise RuntimeError('Cannot persist a function process') - super().__init__(enable_persistence=False, *args, **kwargs) # type: ignore - - @property - def process_class(self) -> t.Callable[..., t.Any]: - """ - Return the class that represents this Process, for the FunctionProcess this is the function itself. 
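`args_to_dict` above maps positional values onto the declared argument names and then labels any leftovers as `<varargs-name>_<index>`, raising if such a label collides with an explicit argument name. A standalone sketch of that labelling, with illustrative names:

```python
def args_to_labelled_dict(func_args, varargs_name, *values):
    """Sketch of the labelling above: named positionals first, then leftovers
    labelled '<varargs_name>_<index>'."""
    dictionary = {}
    remaining = list(values)

    # Consume values for the explicitly declared argument names.
    for name in func_args:
        if remaining:
            dictionary[name] = remaining.pop(0)

    # Label whatever is left as variadic arguments.
    for index, value in enumerate(remaining):
        label = f'{varargs_name}_{index}'
        if label in dictionary:
            raise RuntimeError(f'label {label!r} clashes with an explicit argument name')
        dictionary[label] = value

    return dictionary


# e.g. for ``def total(a, b, *numbers)`` called as total(1, 2, 3, 4):
print(args_to_labelled_dict(['a', 'b'], 'numbers', 1, 2, 3, 4))
# {'a': 1, 'b': 2, 'numbers_0': 3, 'numbers_1': 4}
```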
- - For a standard Process or sub class of Process, this is the class itself. However, for legacy reasons, - the Process class is a wrapper around another class. This function returns that original class, i.e. the - class that really represents what was being executed. - - :return: A Process class that represents the function - - """ - return self._func - - def execute(self) -> dict[str, t.Any] | None: - """Execute the process.""" - result = super().execute() - - # FunctionProcesses can return a single value as output, and not a dictionary, so we should also return that - if result and len(result) == 1 and self.SINGLE_OUTPUT_LINKNAME in result: - return result[self.SINGLE_OUTPUT_LINKNAME] - - return result - - @override - def _setup_db_record(self) -> None: - """Set up the database record for the process.""" - super()._setup_db_record() - self.node.store_source_info(self._func) - - @override - def run(self) -> 'ExitCode' | None: - """Run the process.""" - from .exit_code import ExitCode - - # The following conditional is required for the caching to properly work. Even if the source node has a process - # state of `Finished` the cached process will still enter the running state. The process state will have then - # been overridden by the engine to `Running` so we cannot check that, but if the `exit_status` is anything other - # than `None`, it should mean this node was taken from the cache, so the process should not be rerun. - if self.node.exit_status is not None: - return ExitCode(self.node.exit_status, self.node.exit_message) - - # Split the inputs into positional and keyword arguments - args = [None] * len(self._func_args) - kwargs = {} - - for name, value in (self.inputs or {}).items(): - try: - if self.spec().inputs[name].is_metadata: # type: ignore[union-attr] - # Don't consider ports that defined ``is_metadata=True`` - continue - except KeyError: - pass # No port found - - # Check if it is a positional arg, if not then keyword - try: - args[self._func_args.index(name)] = value - except ValueError: - if name.startswith(f'{self._varargs}_'): - args.append(value) - else: - kwargs[name] = value - - result = self._func(*args, **kwargs) - - if result is None or isinstance(result, ExitCode): - return result - - if isinstance(result, Data): - self.out(self.SINGLE_OUTPUT_LINKNAME, result) - elif isinstance(result, collections.abc.Mapping): - for name, value in result.items(): - self.out(name, value) - else: - raise TypeError( - "Function process returned an output with unsupported type '{}'\n" - 'Must be a Data type or a mapping of {{string: Data}}'.format(result.__class__) - ) - - return ExitCode() diff --git a/aiida/engine/processes/process.py b/aiida/engine/processes/process.py deleted file mode 100644 index eebfc3bc52..0000000000 --- a/aiida/engine/processes/process.py +++ /dev/null @@ -1,1085 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
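As the `execute` and `run` methods above show, a process function may return either a single `Data` node, which is attached under the default `result` link label, or a mapping, whose entries become individually linked outputs. A usage sketch, assuming a configured AiiDA profile; the function names are illustrative.

```python
from aiida import load_profile, orm
from aiida.engine import calcfunction

load_profile()  # assumes a configured default AiiDA profile


@calcfunction
def add_and_subtract(a, b):
    # A mapping return value attaches each entry as its own linked output node.
    return {'sum': orm.Int(a.value + b.value), 'difference': orm.Int(a.value - b.value)}


outputs = add_and_subtract(orm.Int(7), orm.Int(3))
print(outputs['sum'].value, outputs['difference'].value)  # 10 4


@calcfunction
def double(a):
    # A single Data return value is attached under the default 'result' link label.
    return orm.Int(2 * a.value)


print(double(orm.Int(7)).value)  # 14
```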
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-lines -"""The AiiDA process class""" -import asyncio -import collections -from collections.abc import Mapping -import copy -import enum -import inspect -import logging -import traceback -from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - Iterator, - List, - MutableMapping, - Optional, - Tuple, - Type, - Union, - cast, -) -from uuid import UUID - -from aio_pika.exceptions import ConnectionClosed -from kiwipy.communications import UnroutableError -import plumpy.exceptions -import plumpy.futures -import plumpy.persistence -from plumpy.process_states import Finished, ProcessState -import plumpy.processes -from plumpy.utils import AttributesFrozendict - -from aiida import orm -from aiida.common import exceptions -from aiida.common.extendeddicts import AttributeDict -from aiida.common.lang import classproperty, override -from aiida.common.links import LinkType -from aiida.common.log import LOG_LEVEL_REPORT -from aiida.orm.implementation.utils import clean_value -from aiida.orm.utils import serialize - -from .builder import ProcessBuilder -from .exit_code import ExitCode, ExitCodesNamespace -from .ports import PORT_NAMESPACE_SEPARATOR, InputPort, OutputPort, PortNamespace -from .process_spec import ProcessSpec -from .utils import prune_mapping - -if TYPE_CHECKING: - from aiida.engine.runners import Runner - -__all__ = ('Process', 'ProcessState') - - -@plumpy.persistence.auto_persist('_parent_pid', '_enable_persistence') -class Process(plumpy.processes.Process): - """ - This class represents an AiiDA process which can be executed and will - have full provenance saved in the database. - """ - # pylint: disable=too-many-public-methods - - _node_class = orm.ProcessNode - _spec_class = ProcessSpec - - SINGLE_OUTPUT_LINKNAME: str = 'result' - - class SaveKeys(enum.Enum): - """ - Keys used to identify things in the saved instance state bundle. - """ - CALC_ID: str = 'calc_id' - - @classmethod - def spec(cls) -> ProcessSpec: - return super().spec() # type: ignore[return-value] - - @classmethod - def define(cls, spec: ProcessSpec) -> None: # type: ignore[override] - """Define the specification of the process, including its inputs, outputs and known exit codes. - - A `metadata` input namespace is defined, with optional ports that are not stored in the database. - - """ - super().define(spec) - spec.input_namespace(spec.metadata_key, required=False, is_metadata=True) - spec.input( - f'{spec.metadata_key}.store_provenance', - valid_type=bool, - default=True, - help='If set to `False` provenance will not be stored in the database.' - ) - spec.input( - f'{spec.metadata_key}.description', - valid_type=str, - required=False, - help='Description to set on the process node.' - ) - spec.input( - f'{spec.metadata_key}.label', valid_type=str, required=False, help='Label to set on the process node.' - ) - spec.input( - f'{spec.metadata_key}.call_link_label', - valid_type=str, - default='CALL', - help='The label to use for the `CALL` link if the process is called by another process.' 
- ) - spec.inputs.valid_type = orm.Data - spec.inputs.dynamic = False # Settings a ``valid_type`` automatically makes it dynamic, so we reset it again - spec.exit_code( - 1, 'ERROR_UNSPECIFIED', invalidates_cache=True, message='The process has failed with an unspecified error.' - ) - spec.exit_code( - 2, 'ERROR_LEGACY_FAILURE', invalidates_cache=True, message='The process failed with legacy failure mode.' - ) - spec.exit_code( - 10, 'ERROR_INVALID_OUTPUT', invalidates_cache=True, message='The process returned an invalid output.' - ) - spec.exit_code( - 11, - 'ERROR_MISSING_OUTPUT', - invalidates_cache=True, - message='The process did not register a required output.' - ) - - @classmethod - def get_builder(cls) -> ProcessBuilder: - return ProcessBuilder(cls) - - @classmethod - def get_or_create_db_record(cls) -> orm.ProcessNode: - """ - Create a process node that represents what happened in this process. - - :return: A process node - """ - return cls._node_class() - - def __init__( - self, - inputs: Optional[Dict[str, Any]] = None, - logger: Optional[logging.Logger] = None, - runner: Optional['Runner'] = None, - parent_pid: Optional[int] = None, - enable_persistence: bool = True - ) -> None: - """ Process constructor. - - :param inputs: process inputs - :param logger: aiida logger - :param runner: process runner - :param parent_pid: id of parent process - :param enable_persistence: whether to persist this process - - """ - from aiida.manage import manager - - self._runner = runner if runner is not None else manager.get_manager().get_runner() - # assert self._runner.communicator is not None, 'communicator not set for runner' - - super().__init__( - inputs=self.spec().inputs.serialize(inputs), - logger=logger, - loop=self._runner.loop, - communicator=self._runner.communicator - ) - - self._node: Optional[orm.ProcessNode] = None - self._parent_pid = parent_pid - self._enable_persistence = enable_persistence - if self._enable_persistence and self.runner.persister is None: - self.logger.warning('Disabling persistence, runner does not have a persister') - self._enable_persistence = False - - def init(self) -> None: - super().init() - if self._logger is None: - self.set_logger(self.node.logger) - - @classmethod - def get_exit_statuses(cls, exit_code_labels: Iterable[str]) -> List[int]: - """Return the exit status (integers) for the given exit code labels. - - :param exit_code_labels: a list of strings that reference exit code labels of this process class - :return: list of exit status integers that correspond to the given exit code labels - :raises AttributeError: if at least one of the labels does not correspond to an existing exit code - """ - exit_codes = cls.exit_codes - return [getattr(exit_codes, label).status for label in exit_code_labels] - - @classproperty - def exit_codes(cls) -> ExitCodesNamespace: # pylint: disable=no-self-argument - """Return the namespace of exit codes defined for this WorkChain through its ProcessSpec. - - The namespace supports getitem and getattr operations with an ExitCode label to retrieve a specific code. - Additionally, the namespace can also be called with either the exit code integer status to retrieve it. 
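The generic exit codes declared in `define` above can be retrieved through the `exit_codes` namespace and translated from labels to integer statuses with `get_exit_statuses`. A small sketch, assuming only that aiida-core is installed; the statuses and messages are the ones declared in the spec above.

```python
from aiida.engine import Process

# Access an exit code by its label on the class.
print(Process.exit_codes.ERROR_INVALID_OUTPUT.status)   # 10
print(Process.exit_codes.ERROR_INVALID_OUTPUT.message)  # The process returned an invalid output.

# Translate a list of exit code labels into their integer statuses.
print(Process.get_exit_statuses(['ERROR_UNSPECIFIED', 'ERROR_MISSING_OUTPUT']))  # [1, 11]
```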
- - :returns: ExitCodesNamespace of ExitCode named tuples - - """ - return cls.spec().exit_codes - - @classproperty - def spec_metadata(cls) -> PortNamespace: # pylint: disable=no-self-argument - """Return the metadata port namespace of the process specification of this process.""" - return cls.spec().inputs['metadata'] # type: ignore[return-value] - - @property - def node(self) -> orm.ProcessNode: - """Return the ProcessNode used by this process to represent itself in the database. - - :return: instance of sub class of ProcessNode - """ - assert self._node is not None - return self._node - - @property - def uuid(self) -> str: # type: ignore[override] - """Return the UUID of the process which corresponds to the UUID of its associated `ProcessNode`. - - :return: the UUID associated to this process instance - """ - return self.node.uuid - - @property - def inputs(self) -> AttributesFrozendict: - """Return the inputs attribute dictionary or an empty one. - - This overrides the property of the base class because that can also return ``None``. This override ensures - calling functions that they will always get an instance of ``AttributesFrozenDict``. - """ - return super().inputs or AttributesFrozendict() - - @property - def metadata(self) -> AttributeDict: - """Return the metadata that were specified when this process instance was launched. - - :return: metadata dictionary - - """ - try: - assert self.inputs is not None - return self.inputs.metadata - except (AssertionError, AttributeError): - return AttributeDict() - - def _save_checkpoint(self) -> None: - """ - Save the current state in a chechpoint if persistence is enabled and the process state is not terminal - - If the persistence call excepts with a PersistenceError, it will be caught and a warning will be logged. - """ - if self._enable_persistence and not self._state.is_terminal(): - if self.runner.persister is None: - self.logger.exception( - 'No persister set to save checkpoint, this means you will ' - 'not be able to restart in case of a crash until the next successful checkpoint.' - ) - return None - try: - self.runner.persister.save_checkpoint(self) - except plumpy.exceptions.PersistenceError: - self.logger.exception( - 'Exception trying to save checkpoint, this means you will ' - 'not be able to restart in case of a crash until the next successful checkpoint.' - ) - - @override - def save_instance_state( - self, out_state: MutableMapping[str, Any], save_context: Optional[plumpy.persistence.LoadSaveContext] - ) -> None: - """Save instance state. - - See documentation of :meth:`!plumpy.processes.Process.save_instance_state`. - """ - super().save_instance_state(out_state, save_context) - - if self.metadata.store_provenance: - assert self.node.is_stored - - out_state[self.SaveKeys.CALC_ID.value] = self.pid - - def get_provenance_inputs_iterator(self) -> Iterator[Tuple[str, Union[InputPort, PortNamespace]]]: - """Get provenance input iterator. - - :rtype: filter - """ - assert self.inputs is not None - return filter(lambda kv: not kv[0].startswith('_'), self.inputs.items()) - - @override - def load_instance_state( - self, saved_state: MutableMapping[str, Any], load_context: plumpy.persistence.LoadSaveContext - ) -> None: - """Load instance state. 
- - :param saved_state: saved instance state - :param load_context: - - """ - from aiida.manage import manager - - if 'runner' in load_context: - self._runner = load_context.runner - else: - self._runner = manager.get_manager().get_runner() - - load_context = load_context.copyextend(loop=self._runner.loop, communicator=self._runner.communicator) - super().load_instance_state(saved_state, load_context) - - if self.SaveKeys.CALC_ID.value in saved_state: - self._node = orm.load_node(saved_state[self.SaveKeys.CALC_ID.value]) # type: ignore - self._pid = self.node.pk # pylint: disable=attribute-defined-outside-init - else: - self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init - - self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state') - - def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]: - """ - Kill the process and all the children calculations it called - - :param msg: message - """ - self.node.logger.info(f'Request to kill Process<{self.node.pk}>') - - had_been_terminated = self.has_terminated() - - result = super().kill(msg) - - # Only kill children if we could be killed ourselves - if result is not False and not had_been_terminated: - killing = [] - for child in self.node.called: - if self.runner.controller is None: - self.logger.info('no controller available to kill child<%s>', child.pk) - continue - try: - result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>') - result = asyncio.wrap_future(result) # type: ignore[arg-type] - if asyncio.isfuture(result): - killing.append(result) - except ConnectionClosed: - self.logger.info('no connection available to kill child<%s>', child.pk) - except UnroutableError: - self.logger.info('kill signal was unable to reach child<%s>', child.pk) - - if asyncio.isfuture(result): - # We ourselves are waiting to be killed so add it to the list - killing.append(result) - - if killing: - # We are waiting for things to be killed, so return the 'gathered' future - kill_future = plumpy.futures.gather(*killing) - result = self.loop.create_future() - - def done(done_future: plumpy.futures.Future): - is_all_killed = all(done_future.result()) - result.set_result(is_all_killed) # type: ignore[union-attr] - - kill_future.add_done_callback(done) - - return result - - @override - def out(self, output_port: str, value: Any = None) -> None: - """Attach output to output port. - - The name of the port will be used as the link label. - - :param output_port: name of output port - :param value: value to put inside output port - - """ - if value is None: - # In this case assume that output_port is the actual value and there is just one return value - value = output_port - output_port = self.SINGLE_OUTPUT_LINKNAME - - return super().out(output_port, value) - - def out_many(self, out_dict: Dict[str, Any]) -> None: - """Attach outputs to multiple output ports. - - Keys of the dictionary will be used as output port names, values as outputs. 
- - :param out_dict: output dictionary - :type out_dict: dict - """ - for key, value in out_dict.items(): - self.out(key, value) - - def on_create(self) -> None: - """Called when a Process is created.""" - super().on_create() - # If parent PID hasn't been supplied try to get it from the stack - if self._parent_pid is None and Process.current(): - current = Process.current() - if isinstance(current, Process): - self._parent_pid = current.pid # type: ignore[assignment] - self._pid = self._create_and_setup_db_record() # pylint: disable=attribute-defined-outside-init - - @override - def on_entered(self, from_state: Optional[plumpy.process_states.State]) -> None: - """After entering a new state, save a checkpoint and update the latest process state change timestamp.""" - # pylint: disable=cyclic-import - from aiida.engine.utils import set_process_state_change_timestamp - - # For reasons unknown, it is important to update the outputs first, before doing anything else, otherwise there - # is the risk that certain outputs do not get attached before the process reaches a terminal state. Nevertheless - # we need to guarantee that the process state gets updated even if the ``update_outputs`` call excepts, for - # example if the process implementation attaches an invalid output through ``Process.out``, and so we call the - # ``ProcessNode.set_process_state`` in the finally-clause. This way the state gets properly set on the node even - # if the process is transitioning to the terminal excepted state. - try: - self.update_outputs() - except ValueError: # pylint: disable=try-except-raise - raise - finally: - self.node.set_process_state(self._state.LABEL) # type: ignore - - self._save_checkpoint() - set_process_state_change_timestamp(self) - super().on_entered(from_state) - - @override - def on_terminated(self) -> None: - """Called when a Process enters a terminal state.""" - super().on_terminated() - if self._enable_persistence: - try: - assert self.runner.persister is not None - self.runner.persister.delete_checkpoint(self.pid) - except Exception as error: # pylint: disable=broad-except - self.logger.exception('Failed to delete checkpoint: %s', error) - - try: - self.node.seal() - except exceptions.ModificationNotAllowed: - pass - - @override - def on_except(self, exc_info: Tuple[Any, Exception, TracebackType]) -> None: - """ - Log the exception by calling the report method with formatted stack trace from exception info object - and store the exception string as a node attribute - - :param exc_info: the sys.exc_info() object (type, value, traceback) - """ - super().on_except(exc_info) - self.node.set_exception(''.join(traceback.format_exception(exc_info[0], exc_info[1], None)).rstrip()) - self.report(''.join(traceback.format_exception(*exc_info))) - - @override - def on_finish(self, result: Union[int, ExitCode], successful: bool) -> None: - """ Set the finish status on the process node. 
- - :param result: result of the process - :param successful: whether execution was successful - - """ - super().on_finish(result, successful) - - if result is None: - if not successful: - result = self.exit_codes.ERROR_MISSING_OUTPUT # pylint: disable=no-member - else: - result = ExitCode() - - if isinstance(result, int): - self.node.set_exit_status(result) - elif isinstance(result, ExitCode): - self.node.set_exit_status(result.status) - self.node.set_exit_message(result.message) - else: - raise ValueError( - f'the result should be an integer, ExitCode or None, got {type(result)} {result} {self.pid}' - ) - - @override - def on_paused(self, msg: Optional[str] = None) -> None: - """ - The Process was paused so set the paused attribute on the process node - - :param msg: message - - """ - super().on_paused(msg) - self._save_checkpoint() - self.node.pause() - - @override - def on_playing(self) -> None: - """ - The Process was unpaused so remove the paused attribute on the process node - """ - super().on_playing() - self.node.unpause() - - @override - def on_output_emitting(self, output_port: str, value: Any) -> None: - """ - The process has emitted a value on the given output port. - - :param output_port: The output port name the value was emitted on - :param value: The value emitted - - """ - super().on_output_emitting(output_port, value) - - # Note that `PortNamespaces` should be able to receive non `Data` types such as a normal dictionary - if isinstance(output_port, OutputPort) and not isinstance(value, orm.Data): - raise TypeError(f'Processes can only return `orm.Data` instances as output, got {value.__class__}') - - def set_status(self, status: Optional[str]) -> None: - """ - The status of the Process is about to be changed, so we reflect this is in node's attribute proxy. - - :param status: the status message - - """ - super().set_status(status) - self.node.set_process_status(status) - - def submit(self, process: Type['Process'], **kwargs) -> orm.ProcessNode: - """Submit process for execution. - - :param process: process - :return: the calculation node of the process - - """ - return self.runner.submit(process, **kwargs) - - @property - def runner(self) -> 'Runner': - """Get process runner.""" - return self._runner - - def get_parent_calc(self) -> Optional[orm.ProcessNode]: - """ - Get the parent process node - - :return: the parent process node if there is one - - """ - # Can't get it if we don't know our parent - if self._parent_pid is None: - return None - - return orm.load_node(pk=self._parent_pid) # type: ignore - - @classmethod - def build_process_type(cls) -> str: - """ - The process type. - - :return: string of the process type - - Note: This could be made into a property 'process_type' but in order to have it be a property of the class - it would need to be defined in the metaclass, see https://bugs.python.org/issue20659 - """ - from aiida.plugins.entry_point import get_entry_point_string_from_class - - class_module = cls.__module__ - class_name = cls.__name__ - - # If the process is a registered plugin the corresponding entry point will be used as process type - process_type = get_entry_point_string_from_class(class_module, class_name) - - # If no entry point was found, default to fully qualified path name - if process_type is None: - return f'{class_module}.{class_name}' - - return process_type - - def report(self, msg: str, *args, **kwargs) -> None: - """Log a message to the logger, which should get saved to the database through the attached DbLogHandler. 
- - The pk, class name and function name of the caller are prepended to the given message - - :param msg: message to log - :param args: args to pass to the log call - :param kwargs: kwargs to pass to the log call - - """ - message = f'[{self.node.pk}|{self.__class__.__name__}|{inspect.stack()[1][3]}]: {msg}' - self.logger.log(LOG_LEVEL_REPORT, message, *args, **kwargs) - - def _create_and_setup_db_record(self) -> Union[int, UUID]: - """ - Create and setup the database record for this process - - :return: the uuid or pk of the process - - """ - self._node = self.get_or_create_db_record() - self._setup_db_record() - if self.metadata.store_provenance: - try: - self.node.store_all() - if self.node.is_finished_ok: - self._state = Finished(self, None, True) # pylint: disable=attribute-defined-outside-init - for entry in self.node.base.links.get_outgoing(link_type=LinkType.RETURN): - if entry.link_label.endswith(f'_{entry.node.pk}'): - continue - label = entry.link_label.replace(PORT_NAMESPACE_SEPARATOR, self.spec().namespace_separator) - self.out(label, entry.node) - # This is needed for CalcJob. In that case, the outputs are - # returned regardless of whether they end in '_pk' - for entry in self.node.base.links.get_outgoing(link_type=LinkType.CREATE): - label = entry.link_label.replace(PORT_NAMESPACE_SEPARATOR, self.spec().namespace_separator) - self.out(label, entry.node) - except exceptions.ModificationNotAllowed: - # The calculation was already stored - pass - else: - # Cannot persist the process if were not storing provenance because that would require a stored node - self._enable_persistence = False - - if self.node.pk is not None: - return self.node.pk - - return UUID(self.node.uuid) - - @override - def encode_input_args(self, inputs: Dict[str, Any]) -> str: - """ - Encode input arguments such that they may be saved in a Bundle - - :param inputs: A mapping of the inputs as passed to the process - :return: The encoded (serialized) inputs - """ - return serialize.serialize(inputs) - - @override - def decode_input_args(self, encoded: str) -> Dict[str, Any]: - """ - Decode saved input arguments as they came from the saved instance state Bundle - - :param encoded: encoded (serialized) inputs - :return: The decoded input args - """ - return serialize.deserialize_unsafe(encoded) - - def update_outputs(self) -> None: - """Attach new outputs to the node since the last call. - - Does nothing, if self.metadata.store_provenance is False. - """ - if self.metadata.store_provenance is False: - return - - outputs_flat = self._flat_outputs() - outputs_stored = self.node.base.links.get_outgoing(link_type=(LinkType.CREATE, LinkType.RETURN) - ).all_link_labels() - outputs_new = set(outputs_flat.keys()) - set(outputs_stored) - - for link_label, output in outputs_flat.items(): - - if link_label not in outputs_new: - continue - - if isinstance(self.node, orm.CalculationNode): - output.base.links.add_incoming(self.node, LinkType.CREATE, link_label) - elif isinstance(self.node, orm.WorkflowNode): - output.base.links.add_incoming(self.node, LinkType.RETURN, link_label) - - output.store() - - def _build_process_label(self) -> str: - """Construct the process label that should be set on ``ProcessNode`` instances for this process class. - - .. note:: By default this returns the name of the process class itself. It can be overridden by ``Process`` - subclasses to provide a more specific label. - - :returns: The process label to use for ``ProcessNode`` instances. 
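The `report` method above prefixes every message with the node pk, the class name, and the name of the calling function obtained from `inspect.stack()[1][3]`. A standalone sketch of how that prefix is assembled; the `ReporterSketch` class is illustrative and prints instead of logging.

```python
import inspect


class ReporterSketch:
    """Standalone sketch of the report-message prefix built above."""

    def __init__(self, pk: int) -> None:
        self.pk = pk

    def report(self, msg: str) -> None:
        caller = inspect.stack()[1][3]  # name of the function that called report()
        print(f'[{self.pk}|{self.__class__.__name__}|{caller}]: {msg}')


def run_step():
    ReporterSketch(pk=42).report('step finished')


run_step()  # prints: [42|ReporterSketch|run_step]: step finished
```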
- """ - return self.__class__.__name__ - - def _setup_db_record(self) -> None: - """ - Create the database record for this process and the links with respect to its inputs - - This function will set various attributes on the node that serve as a proxy for attributes of the Process. - This is essential as otherwise this information could only be introspected through the Process itself, which - is only available to the interpreter that has it in memory. To make this data introspectable from any - interpreter, for example for the command line interface, certain Process attributes are proxied through the - calculation node. - - In addition, the parent calculation will be setup with a CALL link if applicable and all inputs will be - linked up as well. - """ - assert self.inputs is not None - assert not self.node.is_sealed, 'process node cannot be sealed when setting up the database record' - - # Store important process attributes in the node proxy - self.node.set_process_state(None) - self.node.set_process_label(self._build_process_label()) - self.node.set_process_type(self.__class__.build_process_type()) - - parent_calc = self.get_parent_calc() - - if parent_calc and self.metadata.store_provenance: - - if isinstance(parent_calc, orm.CalculationNode): - raise exceptions.InvalidOperation('calling processes from a calculation type process is forbidden.') - - if isinstance(self.node, orm.CalculationNode): - self.node.base.links.add_incoming(parent_calc, LinkType.CALL_CALC, self.metadata.call_link_label) - - elif isinstance(self.node, orm.WorkflowNode): - self.node.base.links.add_incoming(parent_calc, LinkType.CALL_WORK, self.metadata.call_link_label) - - self._setup_metadata(copy.copy(dict(self.inputs.metadata))) - self._setup_version_info() - self._setup_inputs() - - def _setup_version_info(self) -> None: - """Store relevant plugin version information.""" - from aiida.plugins.entry_point import format_entry_point_string - - if self.inputs is None: - return - - version_info = self.runner.plugin_version_provider.get_version_info(self.__class__) - - for key, monitor in self.inputs.get('monitors', {}).items(): - entry_point = monitor.base.attributes.get('entry_point') - entry_point_string = format_entry_point_string('aiida.calculations.monitors', entry_point) - monitor_version_info = self.runner.plugin_version_provider.get_version_info(entry_point_string) - version_info['version'].setdefault('monitors', {})[key] = monitor_version_info['version']['plugin'] - - self.node.base.attributes.set_many(version_info) - - def _setup_metadata(self, metadata: dict) -> None: - """Store the metadata on the ProcessNode.""" - for name, value in metadata.items(): - if name in ['store_provenance', 'dry_run', 'call_link_label']: - continue - - if name == 'label': - self.node.label = value - elif name == 'description': - self.node.description = value - else: - raise RuntimeError(f'unsupported metadata key: {name}') - - # Store JSON-serializable values of ``metadata`` ports in the node's attributes. Note that instead of passing in - # the ``metadata`` inputs directly, the entire namespace of raw inputs is passed. The reason is that although - # currently in ``aiida-core`` all input ports with ``is_metadata=True`` in the port specification are located - # within the ``metadata`` port namespace, this may not always be the case. 
The ``_filter_serializable_metadata`` - # method will filter out all ports that set ``is_metadata=True`` no matter where in the namespace they are - # defined so this approach is more robust for the future. - serializable_inputs = self._filter_serializable_metadata(self.spec().inputs, self.raw_inputs) - pruned = prune_mapping(serializable_inputs) - self.node.set_metadata_inputs(pruned) - - def _setup_inputs(self) -> None: - """Create the links between the input nodes and the ProcessNode that represents this process.""" - for name, node in self._flat_inputs().items(): - - # Certain processes allow to specify ports with `None` as acceptable values - if node is None: - continue - - # Need this special case for tests that use ProcessNodes as classes - if isinstance(self.node, orm.CalculationNode): - self.node.base.links.add_incoming(node, LinkType.INPUT_CALC, name) - - elif isinstance(self.node, orm.WorkflowNode): - self.node.base.links.add_incoming(node, LinkType.INPUT_WORK, name) - - def _filter_serializable_metadata( - self, - port: Union[None, InputPort, PortNamespace], - port_value: Any, - ) -> Union[Any, None]: - """Return the inputs that correspond to ports with ``is_metadata=True`` and that are JSON serializable. - - The function is called recursively for any port namespaces. - - :param port: An ``InputPort`` or ``PortNamespace``. If an ``InputPort`` that specifies ``is_metadata=True`` the - ``port_value`` is returned. For a ``PortNamespace`` this method is called recursively for the keys within - the namespace and the resulting dictionary is returned, omitting ``None`` values. If either ``port`` or - ``port_value`` is ``None``, ``None`` is returned. - :return: The ``port_value`` where all inputs that do no correspond to a metadata port or are not JSON - serializable, have been filtered out. - """ - if port is None or port_value is None: - return None - - if isinstance(port, InputPort): - if not port.is_metadata: - return None - - try: - clean_value(port_value) - except exceptions.ValidationError: - return None - return port_value - - result = {} - - for key, value in port_value.items(): - if key not in port: - continue - - metadata_value = self._filter_serializable_metadata(port[key], value) # type: ignore[arg-type] - - if metadata_value is None: - continue - - result[key] = metadata_value - - return result or None - - def _flat_inputs(self) -> Dict[str, Any]: - """ - Return a flattened version of the parsed inputs dictionary. - - The eventual keys will be a concatenation of the nested keys. Note that the `metadata` dictionary, if present, - is not passed, as those are dealt with separately in `_setup_metadata`. - - :return: flat dictionary of parsed inputs - - """ - assert self.inputs is not None - inputs = {key: value for key, value in self.inputs.items() if key != self.spec().metadata_key} - return dict(self._flatten_inputs(self.spec().inputs, inputs)) - - def _flat_outputs(self) -> Dict[str, Any]: - """ - Return a flattened version of the registered outputs dictionary. - - The eventual keys will be a concatenation of the nested keys. 
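To make the flattening concrete: nested port namespaces are joined into flat link labels using the port-namespace separator (a double underscore in aiida-core). A standalone sketch of the same idea using plain dictionaries, independent of the port classes involved here:

```python
from collections.abc import Mapping

SEPARATOR = '__'  # mirrors aiida's PORT_NAMESPACE_SEPARATOR


def flatten(mapping, parent='', separator=SEPARATOR):
    """Flatten nested mappings into single-level keys, as the engine does for input/output links."""
    flat = {}
    for key, value in mapping.items():
        full_key = f'{parent}{separator}{key}' if parent else key
        if isinstance(value, Mapping):
            flat.update(flatten(value, parent=full_key, separator=separator))
        else:
            flat[full_key] = value
    return flat


# {'base': {'parameters': 1}, 'structure': 2} -> {'base__parameters': 1, 'structure': 2}
print(flatten({'base': {'parameters': 1}, 'structure': 2}))
```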
- - :return: flat dictionary of parsed outputs - """ - return dict(self._flatten_outputs(self.spec().outputs, self.outputs)) - - def _flatten_inputs( - self, - port: Union[None, InputPort, PortNamespace], - port_value: Any, - parent_name: str = '', - separator: str = PORT_NAMESPACE_SEPARATOR - ) -> List[Tuple[str, Any]]: - """ - Function that will recursively flatten the inputs dictionary, omitting inputs for ports that - are marked as being non database storable - - :param port: port against which to map the port value, can be InputPort or PortNamespace - :param port_value: value for the current port, can be a Mapping - :param parent_name: the parent key with which to prefix the keys - :param separator: character to use for the concatenation of keys - :return: flat list of inputs - - """ - if (port is None and - isinstance(port_value, - orm.Node)) or (isinstance(port, InputPort) and not (port.is_metadata or port.non_db)): - return [(parent_name, port_value)] - - if port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace): - items = [] - for name, value in port_value.items(): - - prefixed_key = parent_name + separator + name if parent_name else name - - try: - nested_port = cast(Union[InputPort, PortNamespace], port[name]) if port else None - except (KeyError, TypeError): - nested_port = None - - sub_items = self._flatten_inputs( - port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator - ) - items.extend(sub_items) - return items - - assert (port is None) or (isinstance(port, InputPort) and (port.is_metadata or port.non_db)) - return [] - - def _flatten_outputs( - self, - port: Union[None, OutputPort, PortNamespace], - port_value: Any, - parent_name: str = '', - separator: str = PORT_NAMESPACE_SEPARATOR - ) -> List[Tuple[str, Any]]: - """ - Function that will recursively flatten the outputs dictionary. - - :param port: port against which to map the port value, can be OutputPort or PortNamespace - :param port_value: value for the current port, can be a Mapping - :param parent_name: the parent key with which to prefix the keys - :param separator: character to use for the concatenation of keys - - :return: flat list of outputs - - """ - if port is None and isinstance(port_value, orm.Node) or isinstance(port, OutputPort): - return [(parent_name, port_value)] - - if (port is None and isinstance(port_value, Mapping) or isinstance(port, PortNamespace)): - items = [] - for name, value in port_value.items(): - - prefixed_key = parent_name + separator + name if parent_name else name - - try: - nested_port = cast(Union[OutputPort, PortNamespace], port[name]) if port else None - except (KeyError, TypeError): - nested_port = None - - sub_items = self._flatten_outputs( - port=nested_port, port_value=value, parent_name=prefixed_key, separator=separator - ) - items.extend(sub_items) - return items - - assert port is None, port - return [] - - def exposed_inputs( - self, - process_class: Type['Process'], - namespace: Optional[str] = None, - agglomerate: bool = True - ) -> AttributeDict: - """Gather a dictionary of the inputs that were exposed for a given Process class under an optional namespace. - - :param process_class: Process class whose inputs to try and retrieve - :param namespace: PortNamespace in which to look for the inputs - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also be - searched for inputs. Inputs in lower-lying namespaces take precedence. 
- - :returns: exposed inputs - - """ - exposed_inputs = {} - - namespace_list = self._get_namespace_list(namespace=namespace, agglomerate=agglomerate) - for sub_namespace in namespace_list: - - # The sub_namespace None indicates the base level sub_namespace - if sub_namespace is None: - inputs = self.inputs - port_namespace = self.spec().inputs - else: - inputs = self.inputs - for part in sub_namespace.split('.'): - inputs = inputs[part] - try: - port_namespace = self.spec().inputs.get_port(sub_namespace) # type: ignore[assignment] - except KeyError: - raise ValueError(f'this process does not contain the "{sub_namespace}" input namespace') - - # Get the list of ports that were exposed for the given Process class in the current sub_namespace - exposed_inputs_list = self.spec()._exposed_inputs[sub_namespace][process_class] # pylint: disable=protected-access - - for name in port_namespace.ports.keys(): - if inputs and name in inputs and name in exposed_inputs_list: - exposed_inputs[name] = inputs[name] - - return AttributeDict(exposed_inputs) - - def exposed_outputs( - self, - node: orm.ProcessNode, - process_class: Type['Process'], - namespace: Optional[str] = None, - agglomerate: bool = True - ) -> AttributeDict: - """Return the outputs which were exposed from the ``process_class`` and emitted by the specific ``node`` - - :param node: process node whose outputs to try and retrieve - :param namespace: Namespace in which to search for exposed outputs. - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also - be searched for outputs. Outputs in lower-lying namespaces take precedence. - - :returns: exposed outputs - - """ - namespace_separator = self.spec().namespace_separator - - output_key_map = {} - # maps the exposed name to all outputs that belong to it - top_namespace_map = collections.defaultdict(list) - link_types = (LinkType.CREATE, LinkType.RETURN) - process_outputs_dict = node.base.links.get_outgoing(link_type=link_types).nested() - - for port_name in process_outputs_dict: - top_namespace = port_name.split(namespace_separator)[0] - top_namespace_map[top_namespace].append(port_name) - - for port_namespace in self._get_namespace_list(namespace=namespace, agglomerate=agglomerate): - # only the top-level key is stored in _exposed_outputs - for top_name in top_namespace_map: - if namespace is not None and namespace not in self.spec()._exposed_outputs: # pylint: disable=protected-access - raise KeyError(f'the namespace `{namespace}` is not an exposed namespace.') - if top_name in self.spec()._exposed_outputs[port_namespace][process_class]: # pylint: disable=protected-access - output_key_map[top_name] = port_namespace - - result = {} - - for top_name, port_namespace in output_key_map.items(): - # collect all outputs belonging to the given top_name - for port_name in top_namespace_map[top_name]: - if port_namespace is None: - result[port_name] = process_outputs_dict[port_name] - else: - result[port_namespace + namespace_separator + port_name] = process_outputs_dict[port_name] - - return AttributeDict(result) - - @staticmethod - def _get_namespace_list(namespace: Optional[str] = None, agglomerate: bool = True) -> List[Optional[str]]: - """Get the list of namespaces in a given namespace. - - :param namespace: name space - :param agglomerate: If set to true, all parent namespaces of the given ``namespace`` will also - be searched. 
- - :returns: namespace list - - """ - if not agglomerate: - return [namespace] - - namespace_list: List[Optional[str]] = [None] - if namespace is not None: - split_ns = namespace.split('.') - namespace_list.extend(['.'.join(split_ns[:i]) for i in range(1, len(split_ns) + 1)]) - return namespace_list - - @classmethod - def is_valid_cache(cls, node: orm.ProcessNode) -> bool: - """Check if the given node can be cached from. - - Overriding this method allows ``Process`` sub-classes to modify when - corresponding process nodes are considered as a cache. - - .. warning :: When overriding this method, make sure to return ``False`` - *at least* in all cases when ``super()._node.base.caching.is_valid_cache(node)`` - returns ``False``. Otherwise, the ``invalidates_cache`` keyword on exit - codes may have no effect. - - """ - exit_status = node.exit_status - if exit_status is None: - return True - try: - return not cls.spec().exit_codes(exit_status).invalidates_cache - except ValueError: - return True - - -def get_query_string_from_process_type_string(process_type_string: str) -> str: # pylint: disable=invalid-name - """ - Take the process type string of a Node and create the queryable type string. - - :param process_type_string: the process type string - :type process_type_string: str - - :return: string that can be used to query for subclasses of the process type using 'LIKE ' - :rtype: str - """ - if ':' in process_type_string: - return f'{process_type_string}.' - - path = process_type_string.rsplit('.', 2)[0] - return f'{path}.' diff --git a/aiida/engine/processes/utils.py b/aiida/engine/processes/utils.py deleted file mode 100644 index 340131a78b..0000000000 --- a/aiida/engine/processes/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -# -*- coding: utf-8 -*- -"""Module with utilities.""" -from collections.abc import Mapping - -from aiida.orm import Node - - -def prune_mapping(value): - """Prune a nested mapping from all mappings that are completely empty. - - .. note:: A nested mapping that is completely empty means it contains at most other empty mappings. Other null - values, such as `None` or empty lists, should not be pruned. - - :param value: A nested mapping of port values. - :return: The same mapping but without any nested namespace that is completely empty. - """ - if isinstance(value, Mapping) and not isinstance(value, Node): - result = {} - for key, sub_value in value.items(): - pruned = prune_mapping(sub_value) - # If `pruned` is an "empty'ish" mapping and not an instance of `Node`, skip it, otherwise keep it. - if not (isinstance(pruned, Mapping) and not pruned and not isinstance(pruned, Node)): - result[key] = pruned - return result - - return value diff --git a/aiida/engine/processes/workchains/__init__.py b/aiida/engine/processes/workchains/__init__.py deleted file mode 100644 index 56b6a94d2d..0000000000 --- a/aiida/engine/processes/workchains/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for the `WorkChain` process and related utilities.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .awaitable import * -from .context import * -from .restart import * -from .utils import * -from .workchain import * - -__all__ = ( - 'Awaitable', - 'AwaitableAction', - 'AwaitableTarget', - 'BaseRestartWorkChain', - 'ProcessHandlerReport', - 'ToContext', - 'WorkChain', - 'append_', - 'assign_', - 'construct_awaitable', - 'if_', - 'process_handler', - 'return_', - 'while_', -) - -# yapf: enable diff --git a/aiida/engine/processes/workchains/utils.py b/aiida/engine/processes/workchains/utils.py deleted file mode 100644 index c98d2dc9cc..0000000000 --- a/aiida/engine/processes/workchains/utils.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities for `WorkChain` implementations.""" -from functools import partial -from inspect import getfullargspec -from types import FunctionType # pylint: disable=no-name-in-module -from typing import List, NamedTuple, Optional, Union - -from wrapt import decorator - -from ..exit_code import ExitCode - -__all__ = ('ProcessHandlerReport', 'process_handler') - - -class ProcessHandlerReport(NamedTuple): - """A namedtuple to define a process handler report for a :class:`aiida.engine.BaseRestartWorkChain`. - - This namedtuple should be returned by a process handler of a work chain instance if the condition of the handler was - met by the completed process. If no further handling should be performed after this method the `do_break` field - should be set to `True`. - If the handler encountered a fatal error and the work chain needs to be terminated, an `ExitCode` with - non-zero exit status can be set. This exit code is what will be set on the work chain itself. This works because the - value of the `exit_code` field returned by the handler, will in turn be returned by the `inspect_process` step and - returning a non-zero exit code from any work chain step will instruct the engine to abort the work chain. - - :param do_break: boolean, set to `True` if no further process handlers should be called, default is `False` - :param exit_code: an instance of the :class:`~aiida.engine.processes.exit_code.ExitCode` tuple. - If not explicitly set, the default `ExitCode` will be instantiated, - which has status `0` meaning that the work chain step will be considered - successful and the work chain will continue to the next step. 
- """ - do_break: bool = False - exit_code: ExitCode = ExitCode() - - -def process_handler( - wrapped: Optional[FunctionType] = None, - *, - priority: int = 0, - exit_codes: Union[None, ExitCode, List[ExitCode]] = None, - enabled: bool = True -) -> FunctionType: - """Decorator to register a :class:`~aiida.engine.BaseRestartWorkChain` instance method as a process handler. - - The decorator will validate the `priority` and `exit_codes` optional keyword arguments and then add itself as an - attribute to the `wrapped` instance method. This is used in the `inspect_process` to return all instance methods of - the class that have been decorated by this function and therefore are considered to be process handlers. - - Requirements on the function signature of process handling functions. The function to which the decorator is applied - needs to take two arguments: - - * `self`: This is the instance of the work chain itself - * `node`: This is the process node that finished and is to be investigated - - The function body typically consists of a conditional that will check for a particular problem that might have - occurred for the sub process. If a particular problem is handled, the process handler should return an instance of - the :class:`aiida.engine.ProcessHandlerReport` tuple. If no other process handlers should be considered, the set - `do_break` attribute should be set to `True`. If the work chain is to be aborted entirely, the `exit_code` of the - report can be set to an `ExitCode` instance with a non-zero status. - - :param wrapped: the work chain method to register the process handler with - :param priority: optional integer that defines the order in which registered handlers will be called during the - handling of a finished process. Higher priorities will be handled first. Default value is `0`. Multiple handlers - with the same priority is allowed, but the order of those is not well defined. - :param exit_codes: single or list of `ExitCode` instances. If defined, the handler will return `None` if the exit - code set on the `node` does not appear in the `exit_codes`. This is useful to have a handler called only when - the process failed with a specific exit code. - :param enabled: boolean, by default True, which will cause the handler to be called during `inspect_process`. When - set to `False`, the handler will be skipped. This static value can be overridden on a per work chain instance - basis through the input `handler_overrides`. 
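The decorator's contract is easiest to see at a use site. Below is a minimal sketch of a ``BaseRestartWorkChain`` with a single handler, built around the ``ArithmeticAddCalculation`` shipped with aiida-core; the handler logic and the ``ERROR_NEGATIVE_SUM`` exit code are invented for illustration:

```python
from aiida.engine import BaseRestartWorkChain, ProcessHandlerReport, process_handler, while_
from aiida.plugins import CalculationFactory

ArithmeticAddCalculation = CalculationFactory('core.arithmetic.add')


class AddRestartWorkChain(BaseRestartWorkChain):
    """Illustrative restart work chain around the arithmetic add calculation."""

    _process_class = ArithmeticAddCalculation

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.expose_inputs(ArithmeticAddCalculation, namespace='add')
        spec.outline(
            cls.setup,
            while_(cls.should_run_process)(
                cls.run_process,
                cls.inspect_process,
            ),
            cls.results,
        )
        spec.exit_code(400, 'ERROR_NEGATIVE_SUM', message='The sum was negative.')

    def setup(self):
        super().setup()
        # The base class expects the inputs for each launch under ``self.ctx.inputs``.
        self.ctx.inputs = self.exposed_inputs(ArithmeticAddCalculation, 'add')

    @process_handler(priority=500)
    def handle_negative_sum(self, node):
        """Stop handling and abort the work chain if the sum came out negative."""
        if node.is_finished_ok and node.outputs.sum.value < 0:
            return ProcessHandlerReport(do_break=True, exit_code=self.exit_codes.ERROR_NEGATIVE_SUM)
        return None
```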
- """ - if wrapped is None: - return partial( - process_handler, priority=priority, exit_codes=exit_codes, enabled=enabled - ) # type: ignore[return-value] - - if not isinstance(wrapped, FunctionType): - raise TypeError('first argument can only be an instance method, use keywords for decorator arguments.') - - if not isinstance(priority, int): - raise TypeError('the `priority` keyword should be an integer.') - - if exit_codes is not None and not isinstance(exit_codes, list): - exit_codes = [exit_codes] - - if exit_codes and any(not isinstance(exit_code, ExitCode) for exit_code in exit_codes): - raise TypeError('`exit_codes` keyword should be an instance of `ExitCode` or list thereof.') - - if not isinstance(enabled, bool): - raise TypeError('the `enabled` keyword should be a boolean.') - - handler_args = getfullargspec(wrapped)[0] - - if len(handler_args) != 2: - raise TypeError(f'process handler `{wrapped.__name__}` has invalid signature: should be (self, node)') - - wrapped.decorator = process_handler # type: ignore[attr-defined] - wrapped.priority = priority # type: ignore[attr-defined] - wrapped.enabled = enabled # type: ignore[attr-defined] - - @decorator - def wrapper(wrapped, instance, args, kwargs): - - # When the handler will be called by the `BaseRestartWorkChain` it will pass the node as the only argument - node = args[0] - - if exit_codes is not None and node.exit_status not in [ - exit_code.status for exit_code in exit_codes # type: ignore[union-attr] - ]: - result = None - else: - result = wrapped(*args, **kwargs) - - # Append the name and return value of the current process handler to the `considered_handlers` extra. - try: - considered_handlers = instance.node.base.extras.get(instance._considered_handlers_extra, []) # pylint: disable=protected-access - current_process = considered_handlers[-1] - except IndexError: - # The extra was never initialized, so we skip this functionality - pass - else: - # Append the name of the handler to the last list in `considered_handlers` and save it - serialized = result - if isinstance(serialized, ProcessHandlerReport): - serialized = {'do_break': serialized.do_break, 'exit_status': serialized.exit_code.status} - current_process.append((wrapped.__name__, serialized)) - instance.node.base.extras.set(instance._considered_handlers_extra, considered_handlers) # pylint: disable=protected-access - - return result - - return wrapper(wrapped) # pylint: disable=no-value-for-parameter diff --git a/aiida/engine/processes/workchains/workchain.py b/aiida/engine/processes/workchains/workchain.py deleted file mode 100644 index e6ca21a4b4..0000000000 --- a/aiida/engine/processes/workchains/workchain.py +++ /dev/null @@ -1,416 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Components for the WorkChain concept of the workflow engine.""" -from __future__ import annotations - -import collections.abc -import functools -import logging -import typing as t - -from plumpy.persistence import auto_persist -from plumpy.process_states import Continue, Wait -from plumpy.processes import ProcessStateMachineMeta -from plumpy.workchains import Stepper -from plumpy.workchains import WorkChainSpec as PlumpyWorkChainSpec -from plumpy.workchains import _PropagateReturn, if_, return_, while_ - -from aiida.common import exceptions -from aiida.common.extendeddicts import AttributeDict -from aiida.common.lang import override -from aiida.orm import Node, ProcessNode, WorkChainNode -from aiida.orm.utils import load_node - -from ..exit_code import ExitCode -from ..process import Process, ProcessState -from ..process_spec import ProcessSpec -from .awaitable import Awaitable, AwaitableAction, AwaitableTarget, construct_awaitable - -if t.TYPE_CHECKING: - from aiida.engine.runners import Runner # pylint: disable=unused-import - -__all__ = ('WorkChain', 'if_', 'while_', 'return_') - - -class WorkChainSpec(ProcessSpec, PlumpyWorkChainSpec): - pass - - -MethodType = t.TypeVar('MethodType') - - -class Protect(ProcessStateMachineMeta): - """Metaclass that allows protecting class methods from being overridden by subclasses. - - Usage as follows:: - - class SomeClass(metaclass=Protect): - - @Protect.final - def private_method(self): - "This method cannot be overridden by a subclass." - - If a subclass is imported that overrides the subclass, a ``RuntimeError`` is raised. - """ - - __SENTINEL = object() - - def __new__(mcs, name, bases, namespace, **kwargs): - """Collect all methods that were marked as protected and raise if the subclass defines it. - - :raises RuntimeError: If the new class defines (i.e. overrides) a method that was decorated with ``final``. - """ - private = { - key for base in bases for key, value in vars(base).items() if callable(value) and mcs.__is_final(value) - } - for key in namespace: - if key in private: - raise RuntimeError(f'the method `{key}` is protected cannot be overridden.') - return super().__new__(mcs, name, bases, namespace, **kwargs) - - @classmethod - def __is_final(mcs, method) -> bool: # pylint: disable=unused-private-member - """Return whether the method has been decorated by the ``final`` classmethod. - - :return: Boolean, ``True`` if the method is marked as final, ``False`` otherwise. - """ - try: - return method.__final is mcs.__SENTINEL # pylint: disable=protected-access - except AttributeError: - return False - - @classmethod - def final(mcs, method: MethodType) -> MethodType: - """Decorate a method with this method to protect it from being overridden. - - Adds the ``__SENTINEL`` object as the ``__final`` private attribute to the given ``method`` and wraps it in - the ``typing.final`` decorator. The latter indicates to typing systems that it cannot be overridden in - subclasses. 
- """ - method.__final = mcs.__SENTINEL # type: ignore[attr-defined] # pylint: disable=protected-access,unused-private-member - return t.final(method) - - -@auto_persist('_awaitables') -class WorkChain(Process, metaclass=Protect): - """The `WorkChain` class is the principle component to implement workflows in AiiDA.""" - - _node_class = WorkChainNode - _spec_class = WorkChainSpec - _STEPPER_STATE = 'stepper_state' - _CONTEXT = 'CONTEXT' - - def __init__( - self, - inputs: dict | None = None, - logger: logging.Logger | None = None, - runner: 'Runner' | None = None, - enable_persistence: bool = True - ) -> None: - """Construct a WorkChain instance. - - Construct the instance only if it is a sub class of `WorkChain`, otherwise raise `InvalidOperation`. - - :param inputs: work chain inputs - :param logger: aiida logger - :param runner: work chain runner - :param enable_persistence: whether to persist this work chain - - """ - if self.__class__ == WorkChain: - raise exceptions.InvalidOperation('cannot construct or launch a base `WorkChain` class.') - - super().__init__(inputs, logger, runner, enable_persistence=enable_persistence) - - self._stepper: Stepper | None = None - self._awaitables: list[Awaitable] = [] - self._context = AttributeDict() - - @classmethod - def spec(cls) -> WorkChainSpec: - return super().spec() # type: ignore[return-value] - - @property - def node(self) -> WorkChainNode: - return super().node # type: ignore - - @property - def ctx(self) -> AttributeDict: - """Get the context.""" - return self._context - - @override - def save_instance_state(self, out_state, save_context): - """Save instance state. - - :param out_state: state to save in - - :param save_context: - :type save_context: :class:`!plumpy.persistence.LoadSaveContext` - - """ - super().save_instance_state(out_state, save_context) - # Save the context - out_state[self._CONTEXT] = self.ctx - - # Ask the stepper to save itself - if self._stepper is not None: - out_state[self._STEPPER_STATE] = self._stepper.save() - - @override - def load_instance_state(self, saved_state, load_context): - super().load_instance_state(saved_state, load_context) - # Load the context - self._context = saved_state[self._CONTEXT] - - # Recreate the stepper - self._stepper = None - stepper_state = saved_state.get(self._STEPPER_STATE, None) - if stepper_state is not None: - self._stepper = self.spec().get_outline().recreate_stepper(stepper_state, self) # type: ignore[arg-type] - - self.set_logger(self.node.logger) - - if self._awaitables: - self._action_awaitables() - - @Protect.final - def on_run(self): - super().on_run() - self.node.set_stepper_state_info(str(self._stepper)) - - def _resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: - """ - Returns a reference to a sub-dictionary of the context and the last key, - after resolving a potentially segmented key where required sub-dictionaries are created as needed. 
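This dotted-key resolution is what lets ``to_context`` target nested context entries. A sketch of that usage; ``SomeCalcJob`` stands in for a real ``CalcJob`` plugin and is assumed to be importable:

```python
from aiida import orm
from aiida.engine import WorkChain


class BatchWorkChain(WorkChain):
    """Illustrative only: launch two sub processes and park them under a nested context key."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input('x', valid_type=orm.Int)
        spec.outline(cls.launch, cls.inspect)

    def launch(self):
        for label in ('first', 'second'):
            # The dot in the key creates ``self.ctx.launched`` as a sub-dictionary on demand
            # and assigns the awaitable under ``launched.<label>``.
            self.to_context(**{f'launched.{label}': self.submit(SomeCalcJob, x=self.inputs.x)})

    def inspect(self):
        for label, node in self.ctx.launched.items():
            self.report(f'{label} finished with exit status {node.exit_status}')
```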
- - :param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary - """ - ctx = self.ctx - ctx_path = key.split('.') - - for index, path in enumerate(ctx_path[:-1]): - try: - ctx = ctx[path] - except KeyError: # see below why this is the only exception we have to catch here - ctx[path] = AttributeDict() # create the sub-dict and update the context - ctx = ctx[path] - continue - - # Notes: - # * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking - # * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables - # (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself - # * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable - # would be an AttributeDict we can append things to it since the order of tasks is maintained. - if type(ctx) != AttributeDict: # pylint: disable=C0123 - raise ValueError( - f'Can not update the context for key `{key}`:' - f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index+1])}`, expected AttributeDict' - ) - - return ctx, ctx_path[-1] - - def _insert_awaitable(self, awaitable: Awaitable) -> None: - """Insert an awaitable that should be terminated before before continuing to the next step. - - :param awaitable: the thing to await - """ - ctx, key = self._resolve_nested_context(awaitable.key) - - # Already assign the awaitable itself to the location in the context container where it is supposed to end up - # once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the - # order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the - # awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value. - if awaitable.action == AwaitableAction.ASSIGN: - ctx[key] = awaitable - elif awaitable.action == AwaitableAction.APPEND: - ctx.setdefault(key, []).append(awaitable) - else: - raise AssertionError(f'Unsupported awaitable action: {awaitable.action}') - - self._awaitables.append( - awaitable - ) # add only if everything went ok, otherwise we end up in an inconsistent state - self._update_process_status() - - def _resolve_awaitable(self, awaitable: Awaitable, value: t.Any) -> None: - """Resolve an awaitable. - - Precondition: must be an awaitable that was previously inserted. 
- - :param awaitable: the awaitable to resolve - """ - - ctx, key = self._resolve_nested_context(awaitable.key) - - if awaitable.action == AwaitableAction.ASSIGN: - ctx[key] = value - elif awaitable.action == AwaitableAction.APPEND: - # Find the same awaitable inserted in the context - container = ctx[key] - for index, placeholder in enumerate(container): - if isinstance(placeholder, Awaitable) and placeholder.pk == awaitable.pk: - container[index] = value - break - else: - raise AssertionError(f'Awaitable `{awaitable.pk} was not found in `ctx.{awaitable.key}`') - else: - raise AssertionError(f'Unsupported awaitable action: {awaitable.action}') - - awaitable.resolved = True - self._awaitables.remove(awaitable) # remove only if everything went ok, otherwise we may lose track - - if not self.has_terminated(): - # the process may be terminated, for example, if the process was killed or excepted - # then we should not try to update it - self._update_process_status() - - @Protect.final - def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: - """Add a dictionary of awaitables to the context. - - This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will - assign a certain value to the corresponding key in the context of the work chain. - """ - for key, value in kwargs.items(): - awaitable = construct_awaitable(value) - awaitable.key = key - self._insert_awaitable(awaitable) - - def _update_process_status(self) -> None: - """Set the process status with a message accounting the current sub processes that we are waiting for.""" - if self._awaitables: - status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" - self.node.set_process_status(status) - else: - self.node.set_process_status(None) - - @override - @Protect.final - def run(self) -> t.Any: - self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type] - return self._do_step() - - def _do_step(self) -> t.Any: - """Execute the next step in the outline and return the result. - - If the stepper returns a non-finished status and the return value is of type ToContext, the contents of the - ToContext container will be turned into awaitables if necessary. If any awaitables were created, the process - will enter in the Wait state, otherwise it will go to Continue. When the stepper returns that it is done, the - stepper result will be converted to None and returned, unless it is an integer or instance of ExitCode. 
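``_do_step`` is the hinge for the two most common outline-step patterns: returning a ``ToContext`` to wait on a sub process, and returning a non-zero exit code to abort. A compact sketch, with ``MyCalculation`` as a placeholder process class and a ``result`` output assumed for it:

```python
from aiida import orm
from aiida.engine import ToContext, WorkChain


class RunAndInspect(WorkChain):
    """Illustrative only; ``MyCalculation`` is a placeholder for a real process class."""

    @classmethod
    def define(cls, spec):
        super().define(spec)
        spec.input('x', valid_type=orm.Int)
        spec.output('result')
        spec.outline(cls.run_calculation, cls.inspect_calculation)
        spec.exit_code(400, 'ERROR_SUB_PROCESS_FAILED', message='The sub process did not finish ok.')

    def run_calculation(self):
        # Returning ToContext makes the step register an awaitable and enter the Wait state.
        return ToContext(calculation=self.submit(MyCalculation, x=self.inputs.x))

    def inspect_calculation(self):
        if not self.ctx.calculation.is_finished_ok:
            # A non-zero ExitCode returned from a step aborts the work chain.
            return self.exit_codes.ERROR_SUB_PROCESS_FAILED
        self.out('result', self.ctx.calculation.outputs.result)
```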
- """ - from .context import ToContext - - self._awaitables = [] - result: t.Any = None - - try: - assert self._stepper is not None - finished, stepper_result = self._stepper.step() - except _PropagateReturn as exception: - finished, result = True, exception.exit_code - else: - # Set result to None unless stepper_result was non-zero positive integer or ExitCode with similar status - if isinstance(stepper_result, int) and stepper_result > 0: - result = ExitCode(stepper_result) - elif isinstance(stepper_result, ExitCode) and stepper_result.status > 0: - result = stepper_result - else: - result = None - - # If the stepper said we are finished or the result is an ExitCode, we exit by returning - if finished or isinstance(result, ExitCode): - return result - - if isinstance(stepper_result, ToContext): - self.to_context(**stepper_result) - - if self._awaitables: - return Wait(self._do_step, 'Waiting before next step') - - return Continue(self._do_step) - - def _store_nodes(self, data: t.Any) -> None: - """Recurse through a data structure and store any unstored nodes that are found along the way - - :param data: a data structure potentially containing unstored nodes - """ - if isinstance(data, Node) and not data.is_stored: - data.store() - elif isinstance(data, collections.abc.Mapping): - for _, value in data.items(): - self._store_nodes(value) - elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str): - for value in data: - self._store_nodes(value) - - @override - @Protect.final - def on_exiting(self) -> None: - """Ensure that any unstored nodes in the context are stored, before the state is exited - - After the state is exited the next state will be entered and if persistence is enabled, a checkpoint will - be saved. If the context contains unstored nodes, the serialization necessary for checkpointing will fail. - """ - super().on_exiting() - try: - self._store_nodes(self.ctx) - except Exception: # pylint: disable=broad-except - # An uncaught exception here will have bizarre and disastrous consequences - self.logger.exception('exception in _store_nodes called in on_exiting') - - @Protect.final - def on_wait(self, awaitables: t.Sequence[t.Awaitable]): - """Entering the WAITING state.""" - super().on_wait(awaitables) - if self._awaitables: - self._action_awaitables() - else: - self.call_soon(self.resume) - - def _action_awaitables(self) -> None: - """Handle the awaitables that are currently registered with the work chain. - - Depending on the class type of the awaitable's target a different callback - function will be bound with the awaitable and the runner will be asked to - call it when the target is completed - """ - for awaitable in self._awaitables: - if awaitable.target == AwaitableTarget.PROCESS: - callback = functools.partial(self.call_soon, self._on_awaitable_finished, awaitable) - self.runner.call_on_process_finish(awaitable.pk, callback) - else: - assert f"invalid awaitable target '{awaitable.target}'" - - def _on_awaitable_finished(self, awaitable: Awaitable) -> None: - """Callback function, for when an awaitable process instance is completed. - - The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all - awaitables have been dealt with, the work chain process is resumed. 
- - :param awaitable: an Awaitable instance - """ - self.logger.info('received callback that awaitable %d has terminated', awaitable.pk) - - try: - node = load_node(awaitable.pk) - except (exceptions.MultipleObjectsError, exceptions.NotExistent): - raise ValueError(f'provided pk<{awaitable.pk}> could not be resolved to a valid Node instance') - - if awaitable.outputs: - value = {entry.link_label: entry.node for entry in node.base.links.get_outgoing()} - else: - value = node # type: ignore - - self._resolve_awaitable(awaitable, value) - - if self.state == ProcessState.WAITING and not self._awaitables: - self.resume() diff --git a/aiida/engine/utils.py b/aiida/engine/utils.py deleted file mode 100644 index 547ecd51cf..0000000000 --- a/aiida/engine/utils.py +++ /dev/null @@ -1,315 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name -"""Utilities for the workflow engine.""" -import asyncio -import contextlib -from datetime import datetime -import inspect -import logging -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union - -if TYPE_CHECKING: - from .processes import Process, ProcessBuilder - from .runners import Runner - -__all__ = ('interruptable_task', 'InterruptableFuture', 'is_process_function') - -LOGGER = logging.getLogger(__name__) -PROCESS_STATE_CHANGE_KEY = 'process|state_change|{}' -PROCESS_STATE_CHANGE_DESCRIPTION = 'The last time a process of type {}, changed state' - - -def instantiate_process( - runner: 'Runner', process: Union['Process', Type['Process'], 'ProcessBuilder'], **inputs -) -> 'Process': - """ - Return an instance of the process with the given inputs. 
The function can deal with various types - of the `process`: - - * Process instance: will simply return the instance - * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it - * Process class: will instantiate with the specified inputs - - If anything else is passed, a ValueError will be raised - - :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance - :param inputs: the inputs for the process to be instantiated with - """ - from .processes import Process, ProcessBuilder - - if isinstance(process, Process): - assert not inputs - assert runner is process.runner - return process - - if isinstance(process, ProcessBuilder): - builder = process - process_class = builder.process_class - inputs.update(**builder._inputs(prune=True)) # pylint: disable=protected-access - elif is_process_function(process): - process_class = process.process_class # type: ignore[attr-defined] - elif inspect.isclass(process) and issubclass(process, Process): - process_class = process - else: - raise ValueError(f'invalid process {type(process)}, needs to be Process or ProcessBuilder') - - process = process_class(runner=runner, inputs=inputs) - - return process - - -class InterruptableFuture(asyncio.Future): - """A future that can be interrupted by calling `interrupt`.""" - - def interrupt(self, reason: Exception) -> None: - """This method should be called to interrupt the coroutine represented by this InterruptableFuture.""" - self.set_exception(reason) - - async def with_interrupt(self, coro: Awaitable[Any]) -> Any: - """ - return result of a coroutine which will be interrupted if this future is interrupted :: - - import asyncio - loop = asyncio.get_event_loop() - - interruptable = InterutableFuture() - loop.call_soon(interruptable.interrupt, RuntimeError("STOP")) - loop.run_until_complete(interruptable.with_interrupt(asyncio.sleep(2.))) - >>> RuntimeError: STOP - - - :param coro: The coroutine that can be interrupted - :return: The result of the coroutine - """ - task = asyncio.ensure_future(coro) - wait_iter = asyncio.as_completed({self, task}) - result = await next(wait_iter) - if self.done(): - raise RuntimeError(f"This interruptible future had it's result set unexpectedly to '{result}'") - - return result - - -def interruptable_task( - coro: Callable[[InterruptableFuture], Awaitable[Any]], - loop: Optional[asyncio.AbstractEventLoop] = None -) -> InterruptableFuture: - """ - Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it. - - :param coro: the coroutine that should be made interruptable with object of InterutableFuture as last paramenter - :param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop() - :return: an InterruptableFuture - """ - - loop = loop or asyncio.get_event_loop() - future = InterruptableFuture() - - async def execute_coroutine(): - """Coroutine that wraps the original coroutine and sets it result on the future only if not already set.""" - try: - result = await coro(future) - except Exception as exception: # pylint: disable=broad-except - if not future.done(): - future.set_exception(exception) - else: - LOGGER.warning( - 'Interruptable future set to %s before its coro %s is done. %s', future.result(), coro.__name__, - str(exception) - ) - else: - # If the future has not been set elsewhere, i.e. 
by the interrupt call, by the time that the coroutine - # is executed, set the future's result to the result of the coroutine - if not future.done(): - future.set_result(result) - - loop.create_task(execute_coroutine()) - - return future - - -def ensure_coroutine(fct: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: - """ - Ensure that the given function ``fct`` is a coroutine - - If the passed function is not already a coroutine, it will be made to be a coroutine - - :param fct: the function - :returns: the coroutine - """ - if asyncio.iscoroutinefunction(fct): - return fct - - async def wrapper(*args, **kwargs): - return fct(*args, **kwargs) - - return wrapper - - -async def exponential_backoff_retry( - fct: Callable[..., Any], - initial_interval: Union[int, float] = 10.0, - max_attempts: int = 5, - logger: Optional[logging.Logger] = None, - ignore_exceptions: Union[None, Type[Exception], Tuple[Type[Exception], ...]] = None -) -> Any: - """ - Coroutine to call a function, recalling it with an exponential backoff in the case of an exception - - This coroutine will loop ``max_attempts`` times, calling the ``fct`` function, breaking immediately when the call - finished without raising an exception, at which point the result will be returned. If an exception is caught, the - function will await a ``asyncio.sleep`` with a time interval equal to the ``initial_interval`` multiplied by - ``2 ** (N - 1)`` where ``N`` is the number of excepted calls. - - :param fct: the function to call, which will be turned into a coroutine first if it is not already - :param initial_interval: the time to wait after the first caught exception before calling the coroutine again - :param max_attempts: the maximum number of times to call the coroutine before re-raising the exception - :param ignore_exceptions: exceptions to ignore, i.e. when caught do nothing and simply re-raise - :return: result if the ``coro`` call completes within ``max_attempts`` retries without raising - """ - if logger is None: - logger = LOGGER - - result: Any = None - coro = ensure_coroutine(fct) - interval = initial_interval - - for iteration in range(max_attempts): - try: - result = await coro() - break # Finished successfully - except Exception as exception: # pylint: disable=broad-except - - # Re-raise exceptions that should be ignored - if ignore_exceptions is not None and isinstance(exception, ignore_exceptions): - raise - - count = iteration + 1 - coro_name = coro.__name__ - - if iteration == max_attempts - 1: - logger.exception('iteration %d of %s excepted', count, coro_name) - logger.warning('maximum attempts %d of calling %s, exceeded', max_attempts, coro_name) - raise - else: - logger.exception('iteration %d of %s excepted, retrying after %d seconds', count, coro_name, interval) - await asyncio.sleep(interval) - interval *= 2 - - return result - - -def is_process_function(function: Any) -> bool: - """Return whether the given function is a process function - - :param function: a function - :returns: True if the function is a wrapped process function, False otherwise - """ - try: - return function.is_process_function is True - except AttributeError: - return False - - -def is_process_scoped() -> bool: - """Return whether the current scope is within a process. 
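``exponential_backoff_retry`` is awaited by the engine around flaky operations such as transport calls. A self-contained usage sketch; the failing function and the small intervals are invented for illustration:

```python
import asyncio

from aiida.engine.utils import exponential_backoff_retry

attempts = {'count': 0}


def flaky_operation():
    """Pretend remote call that fails the first two times it is invoked."""
    attempts['count'] += 1
    if attempts['count'] < 3:
        raise ConnectionError('temporary failure')
    return 'ok'


async def main():
    # Sleeps 0.1 s after the first failure and 0.2 s after the second; re-raises once
    # ``max_attempts`` calls have failed.
    result = await exponential_backoff_retry(flaky_operation, initial_interval=0.1, max_attempts=5)
    print(result)


asyncio.run(main())
```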
- - :returns: True if the current scope is within a nested process, False otherwise - """ - from .processes.process import Process - return Process.current() is not None - - -@contextlib.contextmanager -def loop_scope(loop) -> Iterator[None]: - """ - Make an event loop current for the scope of the context - - :param loop: The event loop to make current for the duration of the scope - """ - current = asyncio.get_event_loop() - - try: - asyncio.set_event_loop(loop) - yield - finally: - asyncio.set_event_loop(current) - - -def set_process_state_change_timestamp(process: 'Process') -> None: - """ - Set the global setting that reflects the last time a process changed state, for the process type - of the given process, to the current timestamp. The process type will be determined based on - the class of the calculation node it has as its database container. - - :param process: the Process instance that changed its state - """ - from aiida.common import timezone - from aiida.manage import get_manager # pylint: disable=cyclic-import - from aiida.orm import CalculationNode, ProcessNode, WorkflowNode - - if isinstance(process.node, CalculationNode): - process_type = 'calculation' - elif isinstance(process.node, WorkflowNode): - process_type = 'work' - elif isinstance(process.node, ProcessNode): - # This will only occur for testing, as in general users cannot launch plain Process classes - return - else: - raise ValueError(f'unsupported calculation node type {type(process.node)}') - - key = PROCESS_STATE_CHANGE_KEY.format(process_type) - description = PROCESS_STATE_CHANGE_DESCRIPTION.format(process_type) - value = timezone.now().isoformat() - - backend = get_manager().get_profile_storage() - backend.set_global_variable(key, value, description) - - -def get_process_state_change_timestamp(process_type: Optional[str] = None) -> Optional[datetime]: - """ - Get the global setting that reflects the last time a process of the given process type changed its state. - The returned value will be the corresponding timestamp or None if the setting does not exist. - - :param process_type: optional process type for which to get the latest state change timestamp. - Valid process types are either 'calculation' or 'work'. If not specified, last timestamp for all - known process types will be returned. - :return: a timestamp or None - """ - from aiida.manage import get_manager # pylint: disable=cyclic-import - - valid_process_types = ['calculation', 'work'] - - if process_type is not None and process_type not in valid_process_types: - raise ValueError(f"invalid value for process_type, valid values are {', '.join(valid_process_types)}") - - if process_type is None: - process_types = valid_process_types - else: - process_types = [process_type] - - timestamps: List[datetime] = [] - - backend = get_manager().get_profile_storage() - - for process_type_key in process_types: - key = PROCESS_STATE_CHANGE_KEY.format(process_type_key) - try: - time_stamp = backend.get_global_variable(key) - if time_stamp is not None: - timestamps.append(datetime.fromisoformat(str(time_stamp))) - except KeyError: - continue - - if not timestamps: - return None - - return max(timestamps) diff --git a/aiida/manage/__init__.py b/aiida/manage/__init__.py deleted file mode 100644 index 2f41729ccf..0000000000 --- a/aiida/manage/__init__.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. 
# -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Managing an AiiDA instance: - - * configuration file - * profiles - * databases - * repositories - * external components (such as Postgres, RabbitMQ) - -.. note:: Modules in this sub package may require the database environment to be loaded - -""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .caching import * -from .configuration import * -from .external import * -from .manager import * - -__all__ = ( - 'BROKER_DEFAULTS', - 'CURRENT_CONFIG_VERSION', - 'Config', - 'ConfigValidationError', - 'DEFAULT_DBINFO', - 'MIGRATIONS', - 'ManagementApiConnectionError', - 'OLDEST_COMPATIBLE_CONFIG_VERSION', - 'Option', - 'Postgres', - 'PostgresConnectionMode', - 'ProcessLauncher', - 'Profile', - 'RabbitmqManagementClient', - 'check_and_migrate_config', - 'config_needs_migrating', - 'config_schema', - 'disable_caching', - 'downgrade_config', - 'enable_caching', - 'get_current_version', - 'get_launch_queue_name', - 'get_manager', - 'get_message_exchange_name', - 'get_option', - 'get_option_names', - 'get_rmq_url', - 'get_task_exchange_name', - 'get_use_cache', - 'parse_option', - 'upgrade_config', -) - -# yapf: enable diff --git a/aiida/manage/caching.py b/aiida/manage/caching.py deleted file mode 100644 index 2dfea4f9f4..0000000000 --- a/aiida/manage/caching.py +++ /dev/null @@ -1,269 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Definition of caching mechanism and configuration for calculations.""" -from collections import namedtuple -from contextlib import contextmanager, suppress -from enum import Enum -import keyword -import re - -from aiida.common import exceptions -from aiida.common.lang import type_check -from aiida.manage.configuration import get_config_option -from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP, ENTRY_POINT_STRING_SEPARATOR - -__all__ = ('get_use_cache', 'enable_caching', 'disable_caching') - - -class ConfigKeys(Enum): - """Valid keys for caching configuration.""" - - DEFAULT = 'caching.default_enabled' - ENABLED = 'caching.enabled_for' - DISABLED = 'caching.disabled_for' - - -class _ContextCache: - """Cache options, accounting for when in enable_caching or disable_caching contexts.""" - - def __init__(self): - self._default_all = None - self._enable = [] - self._disable = [] - - def clear(self): - """Clear caching overrides.""" - self.__init__() # type: ignore - - def enable_all(self): - self._default_all = 'enable' - - def disable_all(self): - self._default_all = 'disable' - - def enable(self, identifier): - self._enable.append(identifier) - with suppress(ValueError): - self._disable.remove(identifier) - - def disable(self, identifier): - self._disable.append(identifier) - with suppress(ValueError): - self._enable.remove(identifier) - - def get_options(self): - """Return the options, applying any context overrides.""" - - if self._default_all == 'disable': - return False, [], [] - - if self._default_all == 'enable': - return True, [], [] - - default = get_config_option(ConfigKeys.DEFAULT.value) - enabled = get_config_option(ConfigKeys.ENABLED.value)[:] - disabled = get_config_option(ConfigKeys.DISABLED.value)[:] - - for ident in self._disable: - disabled.append(ident) - with suppress(ValueError): - enabled.remove(ident) - - for ident in self._enable: - enabled.append(ident) - with suppress(ValueError): - disabled.remove(ident) - - # Check validity of enabled and disabled entries - try: - for identifier in enabled + disabled: - _validate_identifier_pattern(identifier=identifier) - except ValueError as exc: - raise exceptions.ConfigurationError('Invalid identifier pattern in enable or disable list.') from exc - - return default, enabled, disabled - - -_CONTEXT_CACHE = _ContextCache() - - -@contextmanager -def enable_caching(*, identifier=None): - """Context manager to enable caching, either for a specific node class, or globally. - - .. warning:: this does not affect the behavior of the daemon, only the local Python interpreter. - - :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it. - If not provided, caching is enabled for all classes. - :type identifier: str - """ - type_check(identifier, str, allow_none=True) - - if identifier is None: - _CONTEXT_CACHE.enable_all() - else: - _validate_identifier_pattern(identifier=identifier) - _CONTEXT_CACHE.enable(identifier) - yield - _CONTEXT_CACHE.clear() - - -@contextmanager -def disable_caching(*, identifier=None): - """Context manager to disable caching, either for a specific node class, or globally. - - .. 
warning:: this does not affect the behavior of the daemon, only the local Python interpreter. - - :param identifier: Process type string of the node, or a pattern with '*' wildcard that matches it. - If not provided, caching is disabled for all classes. - :type identifier: str - """ - type_check(identifier, str, allow_none=True) - - if identifier is None: - _CONTEXT_CACHE.disable_all() - else: - _validate_identifier_pattern(identifier=identifier) - _CONTEXT_CACHE.disable(identifier) - yield - _CONTEXT_CACHE.clear() - - -def get_use_cache(*, identifier=None): - """Return whether the caching mechanism should be used for the given process type according to the configuration. - - :param identifier: Process type string of the node - :type identifier: str - :return: boolean, True if caching is enabled, False otherwise - :raises: `~aiida.common.exceptions.ConfigurationError` if the configuration is invalid, either due to a general - configuration error, or by defining the class both enabled and disabled - """ - type_check(identifier, str, allow_none=True) - - default, enabled, disabled = _CONTEXT_CACHE.get_options() - - if identifier is not None: - type_check(identifier, str) - - enable_matches = [pattern for pattern in enabled if _match_wildcard(string=identifier, pattern=pattern)] - disable_matches = [pattern for pattern in disabled if _match_wildcard(string=identifier, pattern=pattern)] - - if enable_matches and disable_matches: - # If both enable and disable have matching identifier, we search for - # the most specific one. This is determined by checking whether - # all other patterns match the specific pattern. - PatternWithResult = namedtuple('PatternWithResult', ['pattern', 'use_cache']) - most_specific = [] - for specific_pattern in enable_matches: - if all( - _match_wildcard(string=specific_pattern, pattern=other_pattern) - for other_pattern in enable_matches + disable_matches - ): - most_specific.append(PatternWithResult(pattern=specific_pattern, use_cache=True)) - for specific_pattern in disable_matches: - if all( - _match_wildcard(string=specific_pattern, pattern=other_pattern) - for other_pattern in enable_matches + disable_matches - ): - most_specific.append(PatternWithResult(pattern=specific_pattern, use_cache=False)) - - if len(most_specific) > 1: - raise exceptions.ConfigurationError(( - 'Invalid configuration: multiple matches for identifier {}' - ', but the most specific identifier is not unique. Candidates: {}' - ).format(identifier, [match.pattern for match in most_specific])) - if not most_specific: - raise exceptions.ConfigurationError( - 'Invalid configuration: multiple matches for identifier {}, but none of them is most specific.'. - format(identifier) - ) - return most_specific[0].use_cache - if enable_matches: - return True - if disable_matches: - return False - return default - - -def _match_wildcard(*, string, pattern): - """ - Helper function to check whether a given name matches a pattern - which can contain '*' wildcards. - """ - regexp = '.*'.join(re.escape(part) for part in pattern.split('*')) - return re.fullmatch(pattern=regexp, string=string) is not None - - -def _validate_identifier_pattern(*, identifier): - """ - The identifier (without wildcards) can have one of two forms: - - 1. - - where `group_name` is one of the keys in `ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP` - and `tail` can be anything _except_ `ENTRY_POINT_STRING_SEPARATOR`. - - 2. 
a fully qualified Python name - - this is a colon-separated string, where each part satisfies - `part.isidentifier() and not keyword.iskeyword(part)` - - This function checks if an identifier _with_ wildcards can possibly - match one of these two forms. If it can not, a `ValueError` is raised. - - :param identifier: Process type string, or a pattern with '*' wildcard that matches it. - :type identifier: str - """ - common_error_msg = f"Invalid identifier pattern '{identifier}': " - assert ENTRY_POINT_STRING_SEPARATOR not in '.*' # The logic of this function depends on this - # Check if it can be an entry point string - if identifier.count(ENTRY_POINT_STRING_SEPARATOR) > 1: - raise ValueError( - f"{common_error_msg}Can contain at most one entry point string separator '{ENTRY_POINT_STRING_SEPARATOR}'" - ) - # If there is one separator, it must be an entry point string. - # Check if the left hand side is a matching pattern - if ENTRY_POINT_STRING_SEPARATOR in identifier: - group_pattern, _ = identifier.split(ENTRY_POINT_STRING_SEPARATOR) - if not any( - _match_wildcard(string=group_name, pattern=group_pattern) - for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP - ): - raise ValueError( - common_error_msg + "Group name pattern '{}' does not match any of the AiiDA entry point group names.". - format(group_pattern) - ) - # The group name pattern matches, and there are no further - # entry point string separators in the identifier, hence it is - # a valid pattern. - return - # The separator might be swallowed in a wildcard, for example - # aiida.* or aiida.calculations* - if '*' in identifier: - group_part, _ = identifier.split('*', 1) - if any(group_name.startswith(group_part) for group_name in ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP): - return - # Finally, check if it could be a fully qualified Python name - for identifier_part in identifier.split('.'): - # If it contains a wildcard, we can not check for keywords. - # Replacing all wildcards with a single letter must give an - # identifier - this checks for invalid characters, and that it - # does not start with a number. - if '*' in identifier_part: - if not identifier_part.replace('*', 'a').isidentifier(): - raise ValueError( - common_error_msg + - f"Identifier part '{identifier_part}' can not match a fully qualified Python name." - ) - else: - if not identifier_part.isidentifier(): - raise ValueError(f"{common_error_msg}'{identifier_part}' is not a valid Python identifier.") - if keyword.iskeyword(identifier_part): - raise ValueError(f"{common_error_msg}'{identifier_part}' is a reserved Python keyword.") diff --git a/aiida/manage/configuration/__init__.py b/aiida/manage/configuration/__init__.py deleted file mode 100644 index 11ca28d3a0..0000000000 --- a/aiida/manage/configuration/__init__.py +++ /dev/null @@ -1,310 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
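For reference, the wildcard resolution implemented by `_match_wildcard` and used by `get_use_cache` above can be illustrated with a small self-contained sketch. The helper is re-implemented locally so the snippet runs without a configured AiiDA profile, and `aiida.calculations:core.arithmetic.add` is used only as an example process-type string.

```python
import re


def match_wildcard(string: str, pattern: str) -> bool:
    """Same logic as ``_match_wildcard``: a '*' in the pattern matches any substring."""
    regexp = '.*'.join(re.escape(part) for part in pattern.split('*'))
    return re.fullmatch(regexp, string) is not None


identifier = 'aiida.calculations:core.arithmetic.add'  # example process type
enabled = ['aiida.calculations:core.arithmetic.add']
disabled = ['aiida.calculations:*']

print(match_wildcard(identifier, disabled[0]))  # True
print(match_wildcard(identifier, enabled[0]))   # True

# Both lists match, but the enabled pattern is itself matched by every other
# matching pattern, so it is the unique most specific one and ``get_use_cache``
# would return True for this identifier.
```

With aiida-core installed, the same effect can be forced per interpreter with `with enable_caching(identifier='aiida.calculations:core.arithmetic.add'): ...`; as the docstrings above note, this does not change the behaviour of the daemon.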
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Modules related to the configuration of an AiiDA instance.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .config import * -from .migrations import * -from .options import * -from .profile import * - -__all__ = ( - 'CURRENT_CONFIG_VERSION', - 'Config', - 'ConfigValidationError', - 'MIGRATIONS', - 'OLDEST_COMPATIBLE_CONFIG_VERSION', - 'Option', - 'Profile', - 'check_and_migrate_config', - 'config_needs_migrating', - 'config_schema', - 'downgrade_config', - 'get_current_version', - 'get_option', - 'get_option_names', - 'parse_option', - 'upgrade_config', -) - -# yapf: enable - -# END AUTO-GENERATED - -# pylint: disable=global-statement,redefined-outer-name,wrong-import-order - -__all__ += ( - 'get_config', 'get_config_option', 'get_config_path', 'get_profile', 'load_profile', 'reset_config', 'CONFIG' -) - -from contextlib import contextmanager -import os -import shutil -from typing import TYPE_CHECKING, Any, Optional -import warnings - -from aiida.common.warnings import AiidaDeprecationWarning, warn_deprecation - -if TYPE_CHECKING: - from aiida.manage.configuration import Config, Profile # pylint: disable=import-self - -# global variables for aiida -CONFIG: Optional['Config'] = None - - -def get_config_path(): - """Returns path to .aiida configuration directory.""" - from .settings import AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME - - return os.path.join(AIIDA_CONFIG_FOLDER, DEFAULT_CONFIG_FILE_NAME) - - -def load_config(create=False) -> 'Config': - """Instantiate Config object representing an AiiDA configuration file. - - Warning: Contrary to :func:`~aiida.manage.configuration.get_config`, this function is uncached and will always - create a new Config object. You may want to call :func:`~aiida.manage.configuration.get_config` instead. 
- - :param create: if True, will create the configuration file if it does not already exist - :type create: bool - - :return: the config - :rtype: :class:`~aiida.manage.configuration.config.Config` - :raises aiida.common.MissingConfigurationError: if the configuration file could not be found and create=False - """ - from aiida.common import exceptions - - from .config import Config - - filepath = get_config_path() - - if not os.path.isfile(filepath) and not create: - raise exceptions.MissingConfigurationError(f'configuration file {filepath} does not exist') - - try: - config = Config.from_file(filepath) - except ValueError as exc: - raise exceptions.ConfigurationError(f'configuration file {filepath} contains invalid JSON') from exc - - _merge_deprecated_cache_yaml(config, filepath) - - return config - - -def _merge_deprecated_cache_yaml(config, filepath): - """Merge the deprecated cache_config.yml into the config.""" - from aiida.common import timezone - cache_path = os.path.join(os.path.dirname(filepath), 'cache_config.yml') - if not os.path.exists(cache_path): - return - - cache_path_backup = None - # Keep generating a new backup filename based on the current time until it does not exist - while not cache_path_backup or os.path.isfile(cache_path_backup): - cache_path_backup = f"{cache_path}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" - - warn_deprecation( - 'cache_config.yml use is deprecated and support will be removed in `v3.0`. Merging into config.json and ' - f'moving to: {cache_path_backup}', - version=3 - ) - import yaml - with open(cache_path, 'r', encoding='utf8') as handle: - cache_config = yaml.safe_load(handle) - for profile_name, data in cache_config.items(): - if profile_name not in config.profile_names: - warnings.warn(f"Profile '{profile_name}' from cache_config.yml not in config.json, skipping", UserWarning) - continue - for key, option_name in [('default', 'caching.default_enabled'), ('enabled', 'caching.enabled_for'), - ('disabled', 'caching.disabled_for')]: - if key in data: - value = data[key] - # in case of empty key - value = [] if value is None and key != 'default' else value - config.set_option(option_name, value, scope=profile_name) - config.store() - shutil.move(cache_path, cache_path_backup) - - -def load_profile(profile: Optional[str] = None, allow_switch=False) -> 'Profile': - """Load a global profile, unloading any previously loaded profile. - - .. note:: if a profile is already loaded and no explicit profile is specified, nothing will be done - - :param profile: the name of the profile to load, by default will use the one marked as default in the config - :param allow_switch: if True, will allow switching to a different profile when storage is already loaded - - :return: the loaded `Profile` instance - :raises `aiida.common.exceptions.InvalidOperation`: - if another profile has already been loaded and allow_switch is False - """ - from aiida.manage import get_manager - return get_manager().load_profile(profile, allow_switch) - - -def get_profile() -> Optional['Profile']: - """Return the currently loaded profile. - - :return: the globally loaded `Profile` instance or `None` - """ - from aiida.manage import get_manager - return get_manager().get_profile() - - -@contextmanager -def profile_context(profile: Optional[str] = None, allow_switch=False) -> 'Profile': - """Return a context manager for temporarily loading a profile, and unloading on exit. 
- - :param profile: the name of the profile to load, by default will use the one marked as default in the config - :param allow_switch: if True, will allow switching to a different profile - - :return: a context manager for temporarily loading a profile - """ - from aiida.manage import get_manager - manager = get_manager() - current_profile = manager.get_profile() - manager.load_profile(profile, allow_switch) - yield profile - if current_profile is None: - manager.unload_profile() - else: - manager.load_profile(current_profile, allow_switch=True) - - -def reset_config(): - """Reset the globally loaded config. - - .. warning:: This is experimental functionality and should for now be used only internally. If the reset is unclean - weird unknown side-effects may occur that end up corrupting or destroying data. - """ - global CONFIG - CONFIG = None - - -def get_config(create=False): - """Return the current configuration. - - If the configuration has not been loaded yet - * the configuration is loaded using ``load_config`` - * the global `CONFIG` variable is set - * the configuration object is returned - - Note: This function will except if no configuration file can be found. Only call this function, if you need - information from the configuration file. - - :param create: if True, will create the configuration file if it does not already exist - :type create: bool - - :return: the config - :rtype: :class:`~aiida.manage.configuration.config.Config` - :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized - """ - global CONFIG - - if not CONFIG: - CONFIG = load_config(create=create) - - if CONFIG.get_option('warnings.showdeprecations'): - # If the user does not want to get AiiDA deprecation warnings, we disable them - this can be achieved with:: - # verdi config warnings.showdeprecations False - # Note that the AiidaDeprecationWarning does NOT inherit from DeprecationWarning - warnings.simplefilter('default', AiidaDeprecationWarning) # pylint: disable=no-member - # This should default to 'once', i.e. once per different message - else: - warnings.simplefilter('ignore', AiidaDeprecationWarning) # pylint: disable=no-member - - return CONFIG - - -def get_config_option(option_name: str) -> Any: - """Return the value of a configuration option. - - In order of priority, the option is returned from: - - 1. The current profile, if loaded and the option specified - 2. The current configuration, if loaded and the option specified - 3. The default value for the option - - :param option_name: the name of the option to return - :return: the value of the option - :raises `aiida.common.exceptions.ConfigurationError`: if the option is not found - """ - from aiida.manage import get_manager - return get_manager().get_option(option_name) - - -def load_documentation_profile(): - """Load a dummy profile just for the purposes of being able to build the documentation. - - The building of the documentation will require importing the `aiida` package and some code will try to access the - loaded configuration and profile, which if not done will except. - Calling this function allows the documentation to be built without having to install and configure AiiDA, - nor having an actual database present. 
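As a rough usage sketch of the accessors above (assuming an existing AiiDA installation with configured profiles; `other-profile` is a placeholder name):

```python
from aiida.manage.configuration import get_config_option, load_profile, profile_context

# Load the profile marked as default in the configuration file,
# or pass a profile name explicitly.
profile = load_profile()
print(profile.name)

# Temporarily switch to another profile; the previous one is restored on exit.
with profile_context('other-profile', allow_switch=True):
    # Options are resolved from the current profile first, then the global
    # configuration, and finally the option default.
    print(get_config_option('logging.aiida_loglevel'))
```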
- """ - import tempfile - - # imports required for docs/source/reference/api/public.rst - from aiida import ( # pylint: disable=unused-import - cmdline, - common, - engine, - manage, - orm, - parsers, - plugins, - schedulers, - tools, - transports, - ) - from aiida.cmdline.params import arguments, options # pylint: disable=unused-import - from aiida.storage.psql_dos.models.base import get_orm_metadata - - from .config import Config - - global CONFIG - - with tempfile.NamedTemporaryFile() as handle: - profile_name = 'readthedocs' - profile_config = { - 'storage': { - 'backend': 'core.psql_dos', - 'config': { - 'database_engine': 'postgresql_psycopg2', - 'database_port': 5432, - 'database_hostname': 'localhost', - 'database_name': 'aiidadb', - 'database_password': 'aiidadb', - 'database_username': 'aiida', - 'repository_uri': 'file:///dev/null', - } - }, - 'process_control': { - 'backend': 'rabbitmq', - 'config': { - 'broker_protocol': 'amqp', - 'broker_username': 'guest', - 'broker_password': 'guest', - 'broker_host': 'localhost', - 'broker_port': 5672, - 'broker_virtual_host': '', - } - }, - } - config = {'default_profile': profile_name, 'profiles': {profile_name: profile_config}} - CONFIG = Config(handle.name, config) - load_profile(profile_name) - - # we call this to make sure the ORM metadata is fully populated, - # so that ORM models can be properly documented - get_orm_metadata() diff --git a/aiida/manage/configuration/config.py b/aiida/manage/configuration/config.py deleted file mode 100644 index f4b3f1347f..0000000000 --- a/aiida/manage/configuration/config.py +++ /dev/null @@ -1,538 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module that defines the configuration file of an AiiDA instance and functions to create and load it.""" -import codecs -from functools import lru_cache -from importlib.resources import files -import json -import os -import shutil -import tempfile -from typing import Any, Dict, Optional, Sequence, Tuple - -import jsonschema - -from aiida.common.exceptions import ConfigurationError - -from . 
import schema as schema_module -from .options import NO_DEFAULT, Option, get_option, get_option_names, parse_option -from .profile import Profile - -__all__ = ('Config', 'config_schema', 'ConfigValidationError') - -SCHEMA_FILE = 'config-v9.schema.json' - - -@lru_cache(1) -def config_schema() -> Dict[str, Any]: - """Return the configuration schema.""" - return json.loads(files(schema_module).joinpath(SCHEMA_FILE).read_text(encoding='utf8')) - - -class ConfigValidationError(ConfigurationError): - """Configuration error raised when the file contents fails validation.""" - - def __init__( - self, message: str, keypath: Sequence[Any] = (), schema: Optional[dict] = None, filepath: Optional[str] = None - ): - super().__init__(message) - self._message = message - self._keypath = keypath - self._filepath = filepath - self._schema = schema - - def __str__(self) -> str: - prefix = f'{self._filepath}:' if self._filepath else '' - path = '/' + '/'.join(str(k) for k in self._keypath) + ': ' if self._keypath else '' - schema = f'\n schema:\n {self._schema}' if self._schema else '' - return f'Validation Error: {prefix}{path}{self._message}{schema}' - - -class Config: # pylint: disable=too-many-public-methods - """Object that represents the configuration file of an AiiDA instance.""" - - KEY_VERSION = 'CONFIG_VERSION' - KEY_VERSION_CURRENT = 'CURRENT' - KEY_VERSION_OLDEST_COMPATIBLE = 'OLDEST_COMPATIBLE' - KEY_DEFAULT_PROFILE = 'default_profile' - KEY_PROFILES = 'profiles' - KEY_OPTIONS = 'options' - KEY_SCHEMA = '$schema' - - @classmethod - def from_file(cls, filepath): - """Instantiate a configuration object from the contents of a given file. - - .. note:: if the filepath does not exist an empty file will be created with the current default configuration - and will be written to disk. If the filepath does already exist but contains a configuration with an - outdated schema, the content will be migrated and then written to disk. - - :param filepath: the absolute path to the configuration file - :return: `Config` instance - """ - from aiida.cmdline.utils import echo - - from .migrations import check_and_migrate_config, config_needs_migrating - - try: - with open(filepath, 'rb') as handle: - config = json.load(handle) - except FileNotFoundError: - config = Config(filepath, check_and_migrate_config({})) - config.store() - else: - migrated = False - - # If the configuration file needs to be migrated first create a specific backup so it can easily be reverted - if config_needs_migrating(config, filepath): - migrated = True - echo.echo_warning(f'current configuration file `{filepath}` is outdated and will be migrated') - filepath_backup = cls._backup(filepath) - echo.echo_warning(f'original backed up to `{filepath_backup}`') - - config = Config(filepath, check_and_migrate_config(config)) - - if migrated: - config.store() - - return config - - @classmethod - def _backup(cls, filepath): - """Create a backup of the configuration file with the given filepath. 
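The load-or-create behaviour of `from_file` can be seen with a throwaway path; this is only a sketch and assumes an environment where aiida-core is importable:

```python
import os
import tempfile

from aiida.manage.configuration.config import Config

with tempfile.TemporaryDirectory() as dirpath:
    filepath = os.path.join(dirpath, 'config.json')
    # The file does not exist yet, so a skeleton configuration is created,
    # migrated to the current schema version and written to disk.
    config = Config.from_file(filepath)
    print(config.version, os.path.isfile(filepath))
```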
- - :param filepath: absolute path to the configuration file to backup - :return: the absolute path of the created backup - """ - from aiida.common import timezone - - filepath_backup = None - - # Keep generating a new backup filename based on the current time until it does not exist - while not filepath_backup or os.path.isfile(filepath_backup): - filepath_backup = f"{filepath}.{timezone.now().strftime('%Y%m%d-%H%M%S.%f')}" - - shutil.copy(filepath, filepath_backup) - - return filepath_backup - - @staticmethod - def validate(config: dict, filepath: Optional[str] = None): - """Validate a configuration dictionary.""" - try: - jsonschema.validate(instance=config, schema=config_schema()) - except jsonschema.ValidationError as error: - raise ConfigValidationError( - message=error.message, keypath=error.path, schema=error.schema, filepath=filepath - ) - - def __init__(self, filepath: str, config: dict, validate: bool = True): - """Instantiate a configuration object from a configuration dictionary and its filepath. - - If an empty dictionary is passed, the constructor will create the skeleton configuration dictionary. - - :param filepath: the absolute filepath of the configuration file - :param config: the content of the configuration file in dictionary form - :param validate: validate the dictionary against the schema - """ - from .migrations import CURRENT_CONFIG_VERSION, OLDEST_COMPATIBLE_CONFIG_VERSION - - if validate: - self.validate(config, filepath) - - self._filepath = filepath - self._schema = config.get(self.KEY_SCHEMA, None) - version = config.get(self.KEY_VERSION, {}) - self._current_version = version.get(self.KEY_VERSION_CURRENT, CURRENT_CONFIG_VERSION) - self._oldest_compatible_version = version.get( - self.KEY_VERSION_OLDEST_COMPATIBLE, OLDEST_COMPATIBLE_CONFIG_VERSION - ) - self._profiles = {} - - known_keys = [self.KEY_SCHEMA, self.KEY_VERSION, self.KEY_PROFILES, self.KEY_OPTIONS, self.KEY_DEFAULT_PROFILE] - unknown_keys = set(config.keys()) - set(known_keys) - - if unknown_keys: - keys = ', '.join(unknown_keys) - self.handle_invalid(f'encountered unknown keys [{keys}] in `{filepath}` which have been removed') - - try: - self._options = config[self.KEY_OPTIONS] - except KeyError: - self._options = {} - - try: - self._default_profile = config[self.KEY_DEFAULT_PROFILE] - except KeyError: - self._default_profile = None - - for name, config_profile in config.get(self.KEY_PROFILES, {}).items(): - self._profiles[name] = Profile(name, config_profile) - - def __eq__(self, other): - """Two configurations are considered equal, when their dictionaries are equal.""" - return self.dictionary == other.dictionary - - def __ne__(self, other): - """Two configurations are considered unequal, when their dictionaries are unequal.""" - return self.dictionary != other.dictionary - - def handle_invalid(self, message): - """Handle an incoming invalid configuration dictionary. - - The current content of the configuration file will be written to a backup file. - - :param message: a string message to echo with describing the infraction - """ - from aiida.cmdline.utils import echo - filepath_backup = self._backup(self.filepath) - echo.echo_warning(message) - echo.echo_warning(f'backup of the original config file written to: `{filepath_backup}`') - - @property - def dictionary(self) -> dict: - """Return the dictionary representation of the config as it would be written to file. 
- - :return: dictionary representation of config as it should be written to file - """ - config = {} - if self._schema: - config[self.KEY_SCHEMA] = self._schema - - config[self.KEY_VERSION] = self.version_settings - config[self.KEY_PROFILES] = {name: profile.dictionary for name, profile in self._profiles.items()} - - if self._default_profile: - config[self.KEY_DEFAULT_PROFILE] = self._default_profile - - if self._options: - config[self.KEY_OPTIONS] = self._options - - return config - - @property - def version(self): - return self._current_version - - @version.setter - def version(self, version): - self._current_version = version - - @property - def version_oldest_compatible(self): - return self._oldest_compatible_version - - @version_oldest_compatible.setter - def version_oldest_compatible(self, version_oldest_compatible): - self._oldest_compatible_version = version_oldest_compatible - - @property - def version_settings(self): - return { - self.KEY_VERSION_CURRENT: self.version, - self.KEY_VERSION_OLDEST_COMPATIBLE: self.version_oldest_compatible - } - - @property - def filepath(self): - return self._filepath - - @property - def dirpath(self): - return os.path.dirname(self.filepath) - - @property - def default_profile_name(self): - """Return the default profile name. - - :return: the default profile name or None if not defined - """ - return self._default_profile - - @property - def profile_names(self): - """Return the list of profile names. - - :return: list of profile names - """ - return list(self._profiles.keys()) - - @property - def profiles(self): - """Return the list of profiles. - - :return: the profiles - :rtype: list of `Profile` instances - """ - return list(self._profiles.values()) - - def validate_profile(self, name): - """Validate that a profile exists. - - :param name: name of the profile: - :raises aiida.common.ProfileConfigurationError: if the name is not found in the configuration file - """ - from aiida.common import exceptions - - if name not in self.profile_names: - raise exceptions.ProfileConfigurationError(f'profile `{name}` does not exist') - - def get_profile(self, name: Optional[str] = None) -> Profile: - """Return the profile for the given name or the default one if not specified. - - :return: the profile instance or None if it does not exist - :raises aiida.common.ProfileConfigurationError: if the name is not found in the configuration file - """ - from aiida.common import exceptions - - if not name and not self.default_profile_name: - raise exceptions.ProfileConfigurationError( - f'no default profile defined: {self._default_profile}\n{self.dictionary}' - ) - - if not name: - name = self.default_profile_name - - self.validate_profile(name) - - return self._profiles[name] - - def add_profile(self, profile): - """Add a profile to the configuration. - - :param profile: the profile configuration dictionary - :return: self - """ - self._profiles[profile.name] = profile - return self - - def update_profile(self, profile): - """Update a profile in the configuration. - - :param profile: the profile instance to update - :return: self - """ - self._profiles[profile.name] = profile - return self - - def remove_profile(self, name): - """Remove a profile from the configuration. 
- - :param name: the name of the profile to remove - :raises aiida.common.ProfileConfigurationError: if the given profile does not exist - :return: self - """ - self.validate_profile(name) - self._profiles.pop(name) - return self - - def delete_profile( - self, - name: str, - include_database: bool = True, - include_database_user: bool = False, - include_repository: bool = True - ): - """Delete a profile including its storage. - - :param include_database: also delete the database configured for the profile. - :param include_database_user: also delete the database user configured for the profile. - :param include_repository: also delete the repository configured for the profile. - """ - from aiida.manage.external.postgres import Postgres - - profile = self.get_profile(name) - - if include_repository: - # Note, this is currently being hardcoded, but really this `delete_profile` should get the storage backend - # for the given profile and call `StorageBackend.erase` method. Currently this is leaking details - # of the implementation of the PsqlDosBackend into a generic function which would fail for profiles that - # use a different storage backend implementation. - from aiida.storage.psql_dos.backend import get_filepath_container - folder = get_filepath_container(profile).parent - if folder.exists(): - shutil.rmtree(folder) - - if include_database: - postgres = Postgres.from_profile(profile) - if postgres.db_exists(profile.storage_config['database_name']): - postgres.drop_db(profile.storage_config['database_name']) - - if include_database_user and postgres.dbuser_exists(profile.storage_config['database_username']): - postgres.drop_dbuser(profile.storage_config['database_username']) - - self.remove_profile(name) - self.store() - - def set_default_profile(self, name, overwrite=False): - """Set the given profile as the new default. - - :param name: name of the profile to set as new default - :param overwrite: when True, set the profile as the new default even if a default profile is already defined - :raises aiida.common.ProfileConfigurationError: if the given profile does not exist - :return: self - """ - if self.default_profile_name and not overwrite: - return self - - self.validate_profile(name) - self._default_profile = name - return self - - @property - def options(self): - return self._options - - @options.setter - def options(self, value): - self._options = value - - def set_option(self, option_name, option_value, scope=None, override=True): - """Set a configuration option for a certain scope. - - :param option_name: the name of the configuration option - :param option_value: the option value - :param scope: set the option for this profile or globally if not specified - :param override: boolean, if False, will not override the option if it already exists - - :returns: the parsed value (potentially cast to a valid type) - """ - option, parsed_value = parse_option(option_name, option_value) - - if parsed_value is not None: - value = parsed_value - elif option.default is not NO_DEFAULT: - value = option.default - else: - return - - if not option.global_only and scope is not None: - self.get_profile(scope).set_option(option.name, value, override=override) - else: - if option.name not in self.options or override: - self.options[option.name] = value - - return value - - def unset_option(self, option_name: str, scope=None): - """Unset a configuration option for a certain scope. 
- - :param option_name: the name of the configuration option - :param scope: unset the option for this profile or globally if not specified - """ - option = get_option(option_name) - - if scope is not None: - self.get_profile(scope).unset_option(option.name) - else: - self.options.pop(option.name, None) - - def get_option(self, option_name, scope=None, default=True): - """Get a configuration option for a certain scope. - - :param option_name: the name of the configuration option - :param scope: get the option for this profile or globally if not specified - :param default: boolean, If True will return the option default, even if not defined within the given scope - :return: the option value or None if not set for the given scope - """ - option = get_option(option_name) - - # Default value is `None` unless `default=True` and the `option.default` is not `NO_DEFAULT` - default_value = option.default if default and option.default is not NO_DEFAULT else None - - if scope is not None: - value = self.get_profile(scope).get_option(option.name, default_value) - else: - value = self.options.get(option.name, default_value) - - return value - - def get_options(self, scope: Optional[str] = None) -> Dict[str, Tuple[Option, str, Any]]: - """Return a dictionary of all option values and their source ('profile', 'global', or 'default'). - - :param scope: the profile name or globally if not specified - :returns: (option, source, value) - """ - profile = self.get_profile(scope) if scope else None - output = {} - for name in get_option_names(): - option = get_option(name) - if profile and name in profile.options: - value = profile.options.get(name) - source = 'profile' - elif name in self.options: - value = self.options.get(name) - source = 'global' - elif 'default' in option.schema: - value = option.default - source = 'default' - else: - continue - output[name] = (option, source, value) - return output - - def store(self): - """Write the current config to file. - - .. note:: if the configuration file already exists on disk and its contents differ from those in memory, a - backup of the original file on disk will be created before overwriting it. - - :return: self - """ - from aiida.common.files import md5_file, md5_from_filelike - - from .settings import DEFAULT_CONFIG_INDENT_SIZE - - # If the filepath of this configuration does not yet exist, simply write it. - if not os.path.isfile(self.filepath): - self._atomic_write() - return self - - # Otherwise, we write the content to a temporary file and compare its md5 checksum with the current config on - # disk. When the checksums differ, we first create a backup and only then overwrite the existing file. - with tempfile.NamedTemporaryFile() as handle: - json.dump(self.dictionary, codecs.getwriter('utf-8')(handle), indent=DEFAULT_CONFIG_INDENT_SIZE) - handle.seek(0) - - if md5_from_filelike(handle) != md5_file(self.filepath): - self._backup(self.filepath) - - self._atomic_write() - - return self - - def _atomic_write(self, filepath=None): - """Write the config as it is in memory, i.e. the contents of ``self.dictionary``, to disk. - - .. note:: this command will write the config from memory to a temporary file in the same directory as the - target file ``filepath``. It will then use ``os.rename`` to move the temporary file to ``filepath`` which - will be overwritten if it already exists. The ``os.rename`` is the operation that gives the best guarantee - of being atomic within the limitations of the application. 
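Taken together, a typical read-modify-write cycle on the `Config` object looks roughly like the sketch below; `my-profile` is a placeholder profile name, while the option names are real AiiDA options (they also appear in the option-renaming migration later in this diff):

```python
from aiida.manage.configuration import get_config

config = get_config()

# Values are parsed and validated against the option schema before being stored.
config.set_option('daemon.default_workers', 4)                             # global scope
config.set_option('logging.aiida_loglevel', 'DEBUG', scope='my-profile')   # per profile

print(config.get_option('daemon.default_workers'))
print(config.get_option('logging.aiida_loglevel', scope='my-profile'))

# Writing to disk goes through the atomic-write path; if the file already
# exists with different content, a timestamped backup is created first.
config.store()
```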
- - :param filepath: optional filepath to write the contents to, if not specified, the default filename is used. - """ - from .settings import DEFAULT_CONFIG_INDENT_SIZE, DEFAULT_UMASK - - umask = os.umask(DEFAULT_UMASK) - - if filepath is None: - filepath = self.filepath - - # Create a temporary file in the same directory as the target filepath, which guarantees that the temporary - # file is on the same filesystem, which is necessary to be able to use ``os.rename``. Since we are moving the - # temporary file, we should also tell the tempfile to not be automatically deleted as that will raise. - with tempfile.NamedTemporaryFile(dir=os.path.dirname(filepath), delete=False, mode='w') as handle: - try: - json.dump(self.dictionary, handle, indent=DEFAULT_CONFIG_INDENT_SIZE) - finally: - os.umask(umask) - - handle.flush() - os.rename(handle.name, self.filepath) diff --git a/aiida/manage/configuration/migrations/__init__.py b/aiida/manage/configuration/migrations/__init__.py deleted file mode 100644 index 5eb7bf3bba..0000000000 --- a/aiida/manage/configuration/migrations/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Methods and definitions of migrations for the configuration file of an AiiDA instance.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .migrations import * - -__all__ = ( - 'CURRENT_CONFIG_VERSION', - 'MIGRATIONS', - 'OLDEST_COMPATIBLE_CONFIG_VERSION', - 'check_and_migrate_config', - 'config_needs_migrating', - 'downgrade_config', - 'get_current_version', - 'upgrade_config', -) - -# yapf: enable diff --git a/aiida/manage/configuration/migrations/migrations.py b/aiida/manage/configuration/migrations/migrations.py deleted file mode 100644 index 01ddfe7d4f..0000000000 --- a/aiida/manage/configuration/migrations/migrations.py +++ /dev/null @@ -1,499 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Define the current configuration version and migrations.""" -from typing import Any, Dict, Iterable, Optional, Protocol, Type - -from aiida.common import exceptions -from aiida.common.log import AIIDA_LOGGER - -__all__ = ( - 'CURRENT_CONFIG_VERSION', 'OLDEST_COMPATIBLE_CONFIG_VERSION', 'get_current_version', 'check_and_migrate_config', - 'config_needs_migrating', 'upgrade_config', 'downgrade_config', 'MIGRATIONS' -) - -ConfigType = Dict[str, Any] - -# The expected version of the configuration file and the oldest backwards compatible configuration version. -# If the configuration file format is changed, the current version number should be upped and a migration added. 
-# When the configuration file format is changed in a backwards-incompatible way, the oldest compatible version should -# be set to the new current version. - -CURRENT_CONFIG_VERSION = 9 -OLDEST_COMPATIBLE_CONFIG_VERSION = 9 - -CONFIG_LOGGER = AIIDA_LOGGER.getChild('config') - - -class SingleMigration(Protocol): - """A single migration of the configuration.""" - - down_revision: int - """The initial configuration version.""" - - down_compatible: int - """The initial oldest backwards compatible configuration version""" - - up_revision: int - """The final configuration version.""" - - up_compatible: int - """The final oldest backwards compatible configuration version""" - - def upgrade(self, config: ConfigType) -> None: - """Migrate the configuration in-place.""" - - def downgrade(self, config: ConfigType) -> None: - """Downgrade the configuration in-place.""" - - -class Initial(SingleMigration): - """Base migration (no-op).""" - down_revision = 0 - down_compatible = 0 - up_revision = 1 - up_compatible = 0 - - def upgrade(self, config: ConfigType) -> None: - pass - - def downgrade(self, config: ConfigType) -> None: - pass - - -class AddProfileUuid(SingleMigration): - """Add the required values for a new default profile. - - * PROFILE_UUID - - The profile uuid will be used as a general purpose identifier for the profile, in - for example the RabbitMQ message queues and exchanges. - """ - down_revision = 1 - down_compatible = 0 - up_revision = 2 - up_compatible = 0 - - def upgrade(self, config: ConfigType) -> None: - from uuid import uuid4 # we require this import here, to patch it in the tests - for profile in config.get('profiles', {}).values(): - profile.setdefault('PROFILE_UUID', uuid4().hex) - - def downgrade(self, config: ConfigType) -> None: - # leave the uuid present, so we could migrate back up - pass - - -class SimplifyDefaultProfiles(SingleMigration): - """Replace process specific default profiles with single default profile key. - - The concept of a different 'process' for a profile has been removed and as such the default profiles key in the - configuration no longer needs a value per process ('verdi', 'daemon'). We remove the dictionary 'default_profiles' - and replace it with a simple value 'default_profile'. 
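The `SingleMigration` protocol is easiest to read next to a concrete instance. The class below is not part of aiida-core: the version numbers and the `example` option are invented purely to show the shape a new migration would take inside this module; a real one would also be appended to `MIGRATIONS` and `CURRENT_CONFIG_VERSION` bumped accordingly.

```python
class AddExampleOption(SingleMigration):
    """Hypothetical migration adding a default 'example' option to every profile."""

    down_revision = 9    # version this migration upgrades from
    down_compatible = 9
    up_revision = 10     # version it produces
    up_compatible = 10

    def upgrade(self, config: ConfigType) -> None:
        for profile in config.get('profiles', {}).values():
            profile.setdefault('options', {}).setdefault('example', True)

    def downgrade(self, config: ConfigType) -> None:
        for profile in config.get('profiles', {}).values():
            profile.get('options', {}).pop('example', None)
```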
- """ - down_revision = 2 - down_compatible = 0 - up_revision = 3 - up_compatible = 3 - - def upgrade(self, config: ConfigType) -> None: - from aiida.manage.configuration import get_profile - - global_profile = get_profile() - default_profiles = config.pop('default_profiles', None) - - if default_profiles and 'daemon' in default_profiles: - config['default_profile'] = default_profiles['daemon'] - elif default_profiles and 'verdi' in default_profiles: - config['default_profile'] = default_profiles['verdi'] - elif global_profile is not None: - config['default_profile'] = global_profile.name - - def downgrade(self, config: ConfigType) -> None: - if 'default_profile' in config: - default = config.pop('default_profile') - config['default_profiles'] = {'daemon': default, 'verdi': default} - - -class AddMessageBroker(SingleMigration): - """Add the configuration for the message broker, which was not configurable up to now.""" - down_revision = 3 - down_compatible = 3 - up_revision = 4 - up_compatible = 3 - - def upgrade(self, config: ConfigType) -> None: - from aiida.manage.external.rmq import BROKER_DEFAULTS - defaults = [ - ('broker_protocol', BROKER_DEFAULTS.protocol), - ('broker_username', BROKER_DEFAULTS.username), - ('broker_password', BROKER_DEFAULTS.password), - ('broker_host', BROKER_DEFAULTS.host), - ('broker_port', BROKER_DEFAULTS.port), - ('broker_virtual_host', BROKER_DEFAULTS.virtual_host), - ] - - for profile in config.get('profiles', {}).values(): - for key, default in defaults: - if key not in profile: - profile[key] = default - - def downgrade(self, config: ConfigType) -> None: - pass - - -class SimplifyOptions(SingleMigration): - """Remove unnecessary difference between file/internal representation of options""" - down_revision = 4 - down_compatible = 3 - up_revision = 5 - up_compatible = 5 - - conversions = ( - ('runner_poll_interval', 'runner.poll.interval'), - ('daemon_default_workers', 'daemon.default_workers'), - ('daemon_timeout', 'daemon.timeout'), - ('daemon_worker_process_slots', 'daemon.worker_process_slots'), - ('db_batch_size', 'db.batch_size'), - ('verdi_shell_auto_import', 'verdi.shell.auto_import'), - ('logging_aiida_log_level', 'logging.aiida_loglevel'), - ('logging_db_log_level', 'logging.db_loglevel'), - ('logging_plumpy_log_level', 'logging.plumpy_loglevel'), - ('logging_kiwipy_log_level', 'logging.kiwipy_loglevel'), - ('logging_paramiko_log_level', 'logging.paramiko_loglevel'), - ('logging_alembic_log_level', 'logging.alembic_loglevel'), - ('logging_sqlalchemy_loglevel', 'logging.sqlalchemy_loglevel'), - ('logging_circus_log_level', 'logging.circus_loglevel'), - ('user_email', 'autofill.user.email'), - ('user_first_name', 'autofill.user.first_name'), - ('user_last_name', 'autofill.user.last_name'), - ('user_institution', 'autofill.user.institution'), - ('show_deprecations', 'warnings.showdeprecations'), - ('task_retry_initial_interval', 'transport.task_retry_initial_interval'), - ('task_maximum_attempts', 'transport.task_maximum_attempts'), - ) - - def upgrade(self, config: ConfigType) -> None: - for current, new in self.conversions: - # replace in profile options - for profile in config.get('profiles', {}).values(): - if current in profile.get('options', {}): - profile['options'][new] = profile['options'].pop(current) - # replace in global options - if current in config.get('options', {}): - config['options'][new] = config['options'].pop(current) - - def downgrade(self, config: ConfigType) -> None: - for current, new in self.conversions: - # replace in 
profile options - for profile in config.get('profiles', {}).values(): - if new in profile.get('options', {}): - profile['options'][current] = profile['options'].pop(new) - # replace in global options - if new in config.get('options', {}): - config['options'][current] = config['options'].pop(new) - - -class AbstractStorageAndProcess(SingleMigration): - """Move the storage config under a top-level "storage" key and rabbitmq config under "processing". - - This allows for different storage backends to have different configuration. - """ - down_revision = 5 - down_compatible = 5 - up_revision = 6 - up_compatible = 6 - - storage_conversions = ( - ('AIIDADB_ENGINE', 'database_engine'), - ('AIIDADB_HOST', 'database_hostname'), - ('AIIDADB_PORT', 'database_port'), - ('AIIDADB_USER', 'database_username'), - ('AIIDADB_PASS', 'database_password'), - ('AIIDADB_NAME', 'database_name'), - ('AIIDADB_REPOSITORY_URI', 'repository_uri'), - ) - process_keys = ( - 'broker_protocol', - 'broker_username', - 'broker_password', - 'broker_host', - 'broker_port', - 'broker_virtual_host', - 'broker_parameters', - ) - - def upgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - profile.setdefault('storage', {}) - if 'AIIDADB_BACKEND' not in profile: - CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected "AIIDADB_BACKEND" key') - profile['storage']['backend'] = profile.pop('AIIDADB_BACKEND', None) - profile['storage'].setdefault('config', {}) - for old, new in self.storage_conversions: - if old in profile: - profile['storage']['config'][new] = profile.pop(old) - else: - CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected {old!r} key') - profile.setdefault('process_control', {}) - profile['process_control']['backend'] = 'rabbitmq' - profile['process_control'].setdefault('config', {}) - for key in self.process_keys: - if key in profile: - profile['process_control']['config'][key] = profile.pop(key) - elif key not in ('broker_parameters', 'broker_virtual_host'): - CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected {old!r} key') - - def downgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - profile['AIIDADB_BACKEND'] = profile.get('storage', {}).get('backend', None) - if profile['AIIDADB_BACKEND'] is None: - CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected "storage.backend" key') - for old, new in self.storage_conversions: - if new in profile.get('storage', {}).get('config', {}): - profile[old] = profile['storage']['config'].pop(new) - profile.pop('storage', None) - for key in self.process_keys: - if key in profile.get('process_control', {}).get('config', {}): - profile[key] = profile['process_control']['config'].pop(key) - profile.pop('process_control', None) - - -class MergeStorageBackendTypes(SingleMigration): - """`django` and `sqlalchemy` are now merged into `psql_dos`. - - The legacy name is stored under the `_v6_backend` key, to allow for downgrades. 
- """ - down_revision = 6 - down_compatible = 6 - up_revision = 7 - up_compatible = 7 - - def upgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - if 'storage' in profile: - storage = profile['storage'] - if 'backend' in storage: - if storage['backend'] in ('django', 'sqlalchemy'): - profile['storage']['_v6_backend'] = storage['backend'] - storage['backend'] = 'psql_dos' - else: - CONFIG_LOGGER.warning( - f'profile {profile_name!r} had unknown storage backend {storage["backend"]!r}' - ) - - def downgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - if '_v6_backend' in profile.get('storage', {}): - profile.setdefault('storage', {})['backend'] = profile['storage'].pop('_v6_backend') - else: - CONFIG_LOGGER.warning(f'profile {profile_name!r} had no expected "storage._v6_backend" key') - - -class AddTestProfileKey(SingleMigration): - """Add the ``test_profile`` key.""" - down_revision = 7 - down_compatible = 7 - up_revision = 8 - up_compatible = 8 - - def upgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - profile['test_profile'] = profile_name.startswith('test_') - - def downgrade(self, config: ConfigType) -> None: - profiles = config.get('profiles', {}) - profile_names = list(profiles.keys()) - - # Iterate over the fixed list of the profile names, since we are mutating the profiles dictionary. - for profile_name in profile_names: - - profile = profiles.pop(profile_name) - profile_name_new = None - test_profile = profile.pop('test_profile', False) # If absent, assume it is not a test profile - - if test_profile and not profile_name.startswith('test_'): - profile_name_new = f'test_{profile_name}' - CONFIG_LOGGER.warning( - f'profile `{profile_name}` is a test profile but does not start with the required `test_` prefix.' - ) - - if not test_profile and profile_name.startswith('test_'): - profile_name_new = profile_name[5:] - CONFIG_LOGGER.warning( - f'profile `{profile_name}` is not a test profile but starts with the `test_` prefix.' - ) - - if profile_name_new is not None: - - if profile_name_new in profile_names: - raise exceptions.ConfigurationError( - f'cannot change `{profile_name}` to `{profile_name_new}` because it already exists.' - ) - - CONFIG_LOGGER.warning(f'changing profile name from `{profile_name}` to `{profile_name_new}`.') - profile_name = profile_name_new - - profile['test_profile'] = test_profile - profiles[profile_name] = profile - - -class AddPrefixToStorageBackendTypes(SingleMigration): - """The ``storage.backend`` key should be prefixed with ``core.``. - - At this point, it should only ever contain ``psql_dos`` which should therefore become ``core.psql_dos``. To cover - for cases where people manually added a read only ``sqlite_zip`` profile, we also migrate that. - """ - down_revision = 8 - down_compatible = 8 - up_revision = 9 - up_compatible = 9 - - def upgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - if 'storage' in profile: - backend = profile['storage'].get('backend', None) - if backend in ('psql_dos', 'sqlite_zip', 'sqlite_temp'): - profile['storage']['backend'] = 'core.' 
+ backend - else: - CONFIG_LOGGER.warning(f'profile {profile_name!r} had unknown storage backend {backend!r}') - - def downgrade(self, config: ConfigType) -> None: - for profile_name, profile in config.get('profiles', {}).items(): - backend = profile.get('storage', {}).get('backend', None) - if backend in ('core.psql_dos', 'core.sqlite_zip', 'core.sqlite_temp'): - profile.setdefault('storage', {})['backend'] = backend[5:] - else: - CONFIG_LOGGER.warning( - f'profile {profile_name!r} has storage backend {backend!r} that will not be compatible ' - 'with the version of `aiida-core` that can be used with the new version of the configuration.' - ) - - -MIGRATIONS = ( - Initial, - AddProfileUuid, - SimplifyDefaultProfiles, - AddMessageBroker, - SimplifyOptions, - AbstractStorageAndProcess, - MergeStorageBackendTypes, - AddTestProfileKey, - AddPrefixToStorageBackendTypes, -) - - -def get_current_version(config): - """Return the current version of the config. - - :return: current config version or 0 if not defined - """ - return config.get('CONFIG_VERSION', {}).get('CURRENT', 0) - - -def get_oldest_compatible_version(config): - """Return the current oldest compatible version of the config. - - :return: current oldest compatible config version or 0 if not defined - """ - return config.get('CONFIG_VERSION', {}).get('OLDEST_COMPATIBLE', 0) - - -def upgrade_config( - config: ConfigType, - target: int = CURRENT_CONFIG_VERSION, - migrations: Iterable[Type[SingleMigration]] = MIGRATIONS -) -> ConfigType: - """Run the registered configuration migrations up to the target version. - - :param config: the configuration dictionary - :return: the migrated configuration dictionary - """ - current = get_current_version(config) - used = [] - while current < target: - current = get_current_version(config) - try: - migrator = next(m for m in migrations if m.down_revision == current) - except StopIteration: - raise exceptions.ConfigurationError(f'No migration found to upgrade version {current}') - if migrator in used: - raise exceptions.ConfigurationError(f'Circular migration detected, upgrading to {target}') - used.append(migrator) - migrator().upgrade(config) - current = migrator.up_revision - config.setdefault('CONFIG_VERSION', {})['CURRENT'] = current - config['CONFIG_VERSION']['OLDEST_COMPATIBLE'] = migrator.up_compatible - if current != target: - raise exceptions.ConfigurationError(f'Could not upgrade to version {target}, current version is {current}') - return config - - -def downgrade_config( - config: ConfigType, target: int, migrations: Iterable[Type[SingleMigration]] = MIGRATIONS -) -> ConfigType: - """Run the registered configuration migrations down to the target version. 
- - :param config: the configuration dictionary - :return: the migrated configuration dictionary - """ - current = get_current_version(config) - used = [] - while current > target: - current = get_current_version(config) - try: - migrator = next(m for m in migrations if m.up_revision == current) - except StopIteration: - raise exceptions.ConfigurationError(f'No migration found to downgrade version {current}') - if migrator in used: - raise exceptions.ConfigurationError(f'Circular migration detected, downgrading to {target}') - used.append(migrator) - migrator().downgrade(config) - config.setdefault('CONFIG_VERSION', {})['CURRENT'] = current = migrator.down_revision - config['CONFIG_VERSION']['OLDEST_COMPATIBLE'] = migrator.down_compatible - if current != target: - raise exceptions.ConfigurationError(f'Could not downgrade to version {target}, current version is {current}') - return config - - -def check_and_migrate_config(config, filepath: Optional[str] = None): - """Checks if the config needs to be migrated, and performs the migration if needed. - - :param config: the configuration dictionary - :param filepath: the path to the configuration file (optional, for error reporting) - :return: the migrated configuration dictionary - """ - if config_needs_migrating(config, filepath): - config = upgrade_config(config) - - return config - - -def config_needs_migrating(config, filepath: Optional[str] = None): - """Checks if the config needs to be migrated. - - If the oldest compatible version of the configuration is higher than the current configuration version defined - in the code, the config cannot be used and so the function will raise. - - :param filepath: the path to the configuration file (optional, for error reporting) - :return: True if the configuration has an older version and needs to be migrated, False otherwise - :raises aiida.common.ConfigurationVersionError: if the config's oldest compatible version is higher than the current - """ - current_version = get_current_version(config) - oldest_compatible_version = get_oldest_compatible_version(config) - - if oldest_compatible_version > CURRENT_CONFIG_VERSION: - filepath = filepath if filepath else '' - raise exceptions.ConfigurationVersionError( - f'The configuration file has version {current_version} ' - f'which is not compatible with the current version {CURRENT_CONFIG_VERSION}: {filepath}\n' - 'Use a newer version of AiiDA to downgrade this configuration.' - ) - - return CURRENT_CONFIG_VERSION > current_version diff --git a/aiida/manage/configuration/options.py b/aiida/manage/configuration/options.py deleted file mode 100644 index f41f6f3283..0000000000 --- a/aiida/manage/configuration/options.py +++ /dev/null @@ -1,139 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
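As a sketch of how these entry points behave, the snippet below applies them to a hand-written version-8 configuration dictionary; the profile content is minimal and invented, since only the keys touched by `AddPrefixToStorageBackendTypes` matter here:

```python
from aiida.manage.configuration.migrations import (
    downgrade_config,
    get_current_version,
    upgrade_config,
)

config = {
    'CONFIG_VERSION': {'CURRENT': 8, 'OLDEST_COMPATIBLE': 8},
    'default_profile': 'main',
    'profiles': {
        'main': {
            'storage': {'backend': 'psql_dos', 'config': {}},
            'process_control': {'backend': 'rabbitmq', 'config': {}},
            'test_profile': False,
        }
    },
}

upgrade_config(config)  # applies AddPrefixToStorageBackendTypes (8 -> 9)
print(get_current_version(config))                        # 9
print(config['profiles']['main']['storage']['backend'])   # 'core.psql_dos'

downgrade_config(config, 8)  # strips the 'core.' prefix again
print(get_current_version(config))                        # 8
```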
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Definition of known configuration options and methods to parse and get option values.""" -from typing import Any, Dict, List, Tuple - -import jsonschema - -from aiida.common.exceptions import ConfigurationError - -__all__ = ('get_option', 'get_option_names', 'parse_option', 'Option') - -NO_DEFAULT = () - - -class Option: - """Represent a configuration option schema.""" - - def __init__(self, name: str, schema: Dict[str, Any]): - self._name = name - self._schema = schema - - def __str__(self) -> str: - return f'Option(name={self._name})' - - @property - def name(self) -> str: - return self._name - - @property - def schema(self) -> Dict[str, Any]: - return self._schema - - @property - def valid_type(self) -> Any: - return self._schema.get('type', None) - - @property - def default(self) -> Any: - return self._schema.get('default', NO_DEFAULT) - - @property - def description(self) -> str: - return self._schema.get('description', '') - - @property - def global_only(self) -> bool: - return self._schema.get('global_only', False) - - def validate(self, value: Any, cast: bool = True) -> Any: - """Validate a value - - :param value: The input value - :param cast: Attempt to cast the value to the required type - - :return: The output value - :raise: ConfigValidationError - - """ - # pylint: disable=too-many-branches - from aiida.manage.caching import _validate_identifier_pattern - - from .config import ConfigValidationError - - if cast: - try: - if self.valid_type == 'boolean': - if isinstance(value, str): - if value.strip().lower() in ['0', 'false', 'f']: - value = False - elif value.strip().lower() in ['1', 'true', 't']: - value = True - else: - value = bool(value) - elif self.valid_type == 'string': - value = str(value) - elif self.valid_type == 'integer': - value = int(value) - elif self.valid_type == 'number': - value = float(value) - elif self.valid_type == 'array' and isinstance(value, str): - value = value.split() - except ValueError: - pass - - try: - jsonschema.validate(instance=value, schema=self.schema) - except jsonschema.ValidationError as exc: - raise ConfigValidationError(message=exc.message, keypath=[self.name, *(exc.path or [])], schema=exc.schema) - - # special caching validation - if self.name in ('caching.enabled_for', 'caching.disabled_for'): - for i, identifier in enumerate(value): - try: - _validate_identifier_pattern(identifier=identifier) - except ValueError as exc: - raise ConfigValidationError(message=str(exc), keypath=[self.name, str(i)]) - - return value - - -def get_schema_options() -> Dict[str, Dict[str, Any]]: - """Return schema for options.""" - from .config import config_schema - schema = config_schema() - return schema['definitions']['options']['properties'] - - -def get_option_names() -> List[str]: - """Return a list of available option names.""" - return list(get_schema_options()) - - -def get_option(name: str) -> Option: - """Return option.""" - options = get_schema_options() - if name not in options: - raise ConfigurationError(f'the option {name} does not exist') - return Option(name, options[name]) - - -def parse_option(option_name: str, option_value: Any) -> Tuple[Option, Any]: - """Parse and validate a value for a configuration option. 
- - :param option_name: the name of the configuration option - :param option_value: the option value - :return: a tuple of the option and the parsed value - - """ - option = get_option(option_name) - value = option.validate(option_value, cast=True) - - return option, value diff --git a/aiida/manage/configuration/profile.py b/aiida/manage/configuration/profile.py deleted file mode 100644 index 24d4ac640a..0000000000 --- a/aiida/manage/configuration/profile.py +++ /dev/null @@ -1,272 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""AiiDA profile related code""" -import collections -from copy import deepcopy -import os -import pathlib -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Type - -from aiida.common import exceptions - -from .options import parse_option - -if TYPE_CHECKING: - from aiida.orm.implementation import StorageBackend - -__all__ = ('Profile',) - - -class Profile: # pylint: disable=too-many-public-methods - """Class that models a profile as it is stored in the configuration file of an AiiDA instance.""" - - KEY_UUID = 'PROFILE_UUID' - KEY_DEFAULT_USER_EMAIL = 'default_user_email' - KEY_STORAGE = 'storage' - KEY_PROCESS = 'process_control' - KEY_STORAGE_BACKEND = 'backend' - KEY_STORAGE_CONFIG = 'config' - KEY_PROCESS_BACKEND = 'backend' - KEY_PROCESS_CONFIG = 'config' - KEY_OPTIONS = 'options' - KEY_TEST_PROFILE = 'test_profile' - - # keys that are expected to be in the parsed configuration - REQUIRED_KEYS = ( - KEY_STORAGE, - KEY_PROCESS, - ) - - def __init__(self, name: str, config: Mapping[str, Any], validate=True): - """Load a profile with the profile configuration.""" - if not isinstance(config, collections.abc.Mapping): - raise TypeError(f'config should be a mapping but is {type(config)}') - if validate and not set(config.keys()).issuperset(self.REQUIRED_KEYS): - raise exceptions.ConfigurationError( - f'profile {name!r} configuration does not contain all required keys: {self.REQUIRED_KEYS}' - ) - - self._name = name - self._attributes: Dict[str, Any] = deepcopy(config) - - # Create a default UUID if not specified - if self._attributes.get(self.KEY_UUID, None) is None: - from uuid import uuid4 - self._attributes[self.KEY_UUID] = uuid4().hex - - def __repr__(self) -> str: - return f'Profile' - - def copy(self): - """Return a copy of the profile.""" - return self.__class__(self.name, self._attributes) - - @property - def uuid(self) -> str: - """Return the profile uuid. 
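A short sketch of the option helpers above; `daemon.default_workers` is a real option name (it also appears in the option-renaming migration earlier in this diff) and no profile needs to be loaded, only the packaged configuration schema:

```python
from aiida.manage.configuration.options import get_option, get_option_names, parse_option

print(len(get_option_names()))  # number of known configuration options

option = get_option('daemon.default_workers')
print(option.valid_type, option.default, option.global_only)

# parse_option casts the raw value against the option schema: the string '4'
# comes back as the integer 4, or a ConfigValidationError is raised.
option, value = parse_option('daemon.default_workers', '4')
print(value)
```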
- - :return: string UUID - """ - return self._attributes[self.KEY_UUID] - - @uuid.setter - def uuid(self, value: str) -> None: - self._attributes[self.KEY_UUID] = value - - @property - def default_user_email(self) -> Optional[str]: - """Return the default user email.""" - return self._attributes.get(self.KEY_DEFAULT_USER_EMAIL, None) - - @default_user_email.setter - def default_user_email(self, value: Optional[str]) -> None: - """Set the default user email.""" - self._attributes[self.KEY_DEFAULT_USER_EMAIL] = value - - @property - def storage_backend(self) -> str: - """Return the type of the storage backend.""" - return self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_BACKEND] - - @property - def storage_config(self) -> Dict[str, Any]: - """Return the configuration required by the storage backend.""" - return self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_CONFIG] - - def set_storage(self, name: str, config: Dict[str, Any]) -> None: - """Set the storage backend and its configuration. - - :param name: the name of the storage backend - :param config: the configuration of the storage backend - """ - self._attributes.setdefault(self.KEY_STORAGE, {}) - self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_BACKEND] = name - self._attributes[self.KEY_STORAGE][self.KEY_STORAGE_CONFIG] = config - - @property - def storage_cls(self) -> Type['StorageBackend']: - """Return the storage backend class for this profile.""" - from aiida.plugins import StorageFactory - return StorageFactory(self.storage_backend) - - @property - def process_control_backend(self) -> str: - """Return the type of the process control backend.""" - return self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_BACKEND] - - @property - def process_control_config(self) -> Dict[str, Any]: - """Return the configuration required by the process control backend.""" - return self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_CONFIG] - - def set_process_controller(self, name: str, config: Dict[str, Any]) -> None: - """Set the process control backend and its configuration. - - :param name: the name of the process backend - :param config: the configuration of the process backend - """ - self._attributes.setdefault(self.KEY_PROCESS, {}) - self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_BACKEND] = name - self._attributes[self.KEY_PROCESS][self.KEY_PROCESS_CONFIG] = config - - @property - def options(self): - self._attributes.setdefault(self.KEY_OPTIONS, {}) - return self._attributes[self.KEY_OPTIONS] - - @options.setter - def options(self, value): - self._attributes[self.KEY_OPTIONS] = value - - def get_option(self, option_key, default=None): - return self.options.get(option_key, default) - - def set_option(self, option_key, value, override=True): - """Set a configuration option for a certain scope. - - :param option_key: the key of the configuration option - :param option_value: the option value - :param override: boolean, if False, will not override the option if it already exists - """ - _, parsed_value = parse_option(option_key, value) # ensure the value is validated - if option_key not in self.options or override: - self.options[option_key] = parsed_value - - def unset_option(self, option_key): - self.options.pop(option_key, None) - - @property - def name(self): - """Return the profile name. 
- - :return: the profile name - """ - return self._name - - @property - def dictionary(self) -> Dict[str, Any]: - """Return the profile attributes as a dictionary with keys as it is stored in the config - - :return: the profile configuration dictionary - """ - return self._attributes - - @property - def is_test_profile(self) -> bool: - """Return whether the profile is a test profile - - :return: boolean, True if test profile, False otherwise - """ - # Check explicitly for ``True`` for safety. If an invalid value is defined, we default to treating it as not - # a test profile as that can unintentionally clear the database. - return self._attributes.get(self.KEY_TEST_PROFILE, False) is True - - @is_test_profile.setter - def is_test_profile(self, value: bool) -> None: - """Set whether the profile is a test profile. - - :param value: boolean indicating whether this profile is a test profile. - """ - self._attributes[self.KEY_TEST_PROFILE] = value - - @property - def repository_path(self) -> pathlib.Path: - """Return the absolute path of the repository configured for this profile. - - The URI should be in the format `protocol://address` - - :note: At the moment, only the file protocol is supported. - - :return: absolute filepath of the profile's file repository - """ - from urllib.parse import urlparse - - from aiida.common.warnings import warn_deprecation - - warn_deprecation('This method has been deprecated', version=3) - - if 'repository_uri' not in self.storage_config: - raise KeyError('repository_uri not defined in profile storage config') - - parts = urlparse(self.storage_config['repository_uri']) - - if parts.scheme != 'file': - raise exceptions.ConfigurationError('invalid repository protocol, only the local `file://` is supported') - - if not os.path.isabs(parts.path): - raise exceptions.ConfigurationError('invalid repository URI: the path has to be absolute') - - return pathlib.Path(os.path.expanduser(parts.path)) - - @property - def rmq_prefix(self) -> str: - """Return the prefix that should be used for RMQ resources - - :return: the rmq prefix string - """ - return f'aiida-{self.uuid}' - - def get_rmq_url(self) -> str: - """Return the RMQ url for this profile.""" - from aiida.manage.external.rmq import get_rmq_url - - if self.process_control_backend != 'rabbitmq': - raise exceptions.ConfigurationError( - f"invalid process control backend, only 'rabbitmq' is supported: {self.process_control_backend}" - ) - kwargs = {key[7:]: val for key, val in self.process_control_config.items() if key.startswith('broker_')} - additional_kwargs = kwargs.pop('parameters', {}) - return get_rmq_url(**kwargs, **additional_kwargs) - - @property - def filepaths(self): - """Return the filepaths used by this profile. 
- - :return: a dictionary of filepaths - """ - from .settings import DAEMON_DIR, DAEMON_LOG_DIR - - return { - 'circus': { - 'log': str(DAEMON_LOG_DIR / f'circus-{self.name}.log'), - 'pid': str(DAEMON_DIR / f'circus-{self.name}.pid'), - 'port': str(DAEMON_DIR / f'circus-{self.name}.port'), - 'socket': { - 'file': str(DAEMON_DIR / f'circus-{self.name}.sockets'), - 'controller': 'circus.c.sock', - 'pubsub': 'circus.p.sock', - 'stats': 'circus.s.sock', - } - }, - 'daemon': { - 'log': str(DAEMON_LOG_DIR / f'aiida-{self.name}.log'), - 'pid': str(DAEMON_DIR / f'aiida-{self.name}.pid'), - } - } diff --git a/aiida/manage/configuration/schema/__init__.py b/aiida/manage/configuration/schema/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/manage/configuration/schema/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/manage/configuration/settings.py b/aiida/manage/configuration/settings.py deleted file mode 100644 index 1236055468..0000000000 --- a/aiida/manage/configuration/settings.py +++ /dev/null @@ -1,121 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Base settings required for the configuration of an AiiDA instance.""" -import os -import pathlib -import typing -import warnings - -DEFAULT_UMASK = 0o0077 -DEFAULT_AIIDA_PATH_VARIABLE = 'AIIDA_PATH' -DEFAULT_AIIDA_PATH = '~' -DEFAULT_AIIDA_USER = 'aiida@localhost' -DEFAULT_CONFIG_DIR_NAME = '.aiida' -DEFAULT_CONFIG_FILE_NAME = 'config.json' -DEFAULT_CONFIG_INDENT_SIZE = 4 -DEFAULT_DAEMON_DIR_NAME = 'daemon' -DEFAULT_DAEMON_LOG_DIR_NAME = 'log' -DEFAULT_ACCESS_CONTROL_DIR_NAME = 'access' - -AIIDA_CONFIG_FOLDER: typing.Optional[pathlib.Path] = None -DAEMON_DIR: typing.Optional[pathlib.Path] = None -DAEMON_LOG_DIR: typing.Optional[pathlib.Path] = None -ACCESS_CONTROL_DIR: typing.Optional[pathlib.Path] = None - - -def create_instance_directories(): - """Create the base directories required for a new AiiDA instance. - - This will create the base AiiDA directory defined by the AIIDA_CONFIG_FOLDER variable, unless it already exists. - Subsequently, it will create the daemon directory within it and the daemon log directory. 
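Concretely, with the default names defined above and `AIIDA_PATH` unset, the instance directories created by this function resolve to roughly the following layout (a sketch, assuming the default `~/.aiida` location):

```python
import pathlib

base = pathlib.Path('~/.aiida').expanduser()  # AIIDA_CONFIG_FOLDER
created = [
    base,                     # configuration folder
    base / 'daemon',          # DAEMON_DIR
    base / 'daemon' / 'log',  # DAEMON_LOG_DIR
    base / 'access',          # ACCESS_CONTROL_DIR
]
```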
- """ - from aiida.common import ConfigurationError - - directory_base = pathlib.Path(AIIDA_CONFIG_FOLDER).expanduser() - directory_daemon = directory_base / DAEMON_DIR - directory_daemon_log = directory_base / DAEMON_LOG_DIR - directory_access = directory_base / ACCESS_CONTROL_DIR - - list_of_paths = [ - directory_base, - directory_daemon, - directory_daemon_log, - directory_access, - ] - - umask = os.umask(DEFAULT_UMASK) - - try: - for path in list_of_paths: - - if path is directory_base and not path.exists(): - warnings.warn(f'Creating AiiDA configuration folder `{path}`.') - - try: - path.mkdir(parents=True, exist_ok=True) - except OSError as exc: - raise ConfigurationError(f'could not create the `{path}` configuration directory: {exc}') from exc - finally: - os.umask(umask) - - -def set_configuration_directory(aiida_config_folder: pathlib.Path = None): - """Determine location of configuration directory, set related global variables and create instance directories. - - The location of the configuration folder will be determined and optionally created following these heuristics: - - * If an explicit path is provided by `aiida_config_folder`, that will be set as the configuration folder. - * Otherwise, if the `AIIDA_PATH` variable is set, all the paths will be checked to see if they contain a - configuration folder. The first one to be encountered will be set as `AIIDA_CONFIG_FOLDER`. If none of them - contain one, a configuration folder will be created in the last path considered. - * If the `AIIDA_PATH` variable is not set the `DEFAULT_AIIDA_PATH` value will be used as base path and if it - does not yet contain a configuration folder, one will be created. - - In principle then, a configuration folder should always be found or automatically created. - """ - # pylint: disable = global-statement - global AIIDA_CONFIG_FOLDER - global DAEMON_DIR - global DAEMON_LOG_DIR - global ACCESS_CONTROL_DIR - - environment_variable = os.environ.get(DEFAULT_AIIDA_PATH_VARIABLE, None) - - if aiida_config_folder is not None: - AIIDA_CONFIG_FOLDER = aiida_config_folder - elif environment_variable: - - # Loop over all the paths in the `AIIDA_PATH` variable to see if any of them contain a configuration folder - for base_dir_path in [path for path in environment_variable.split(':') if path]: - - AIIDA_CONFIG_FOLDER = pathlib.Path(base_dir_path).expanduser() - - # Only add the base config directory name to the base path if it does not already do so - # Someone might already include it in the environment variable. 
e.g.: AIIDA_PATH=/home/some/path/.aiida - if AIIDA_CONFIG_FOLDER.name != DEFAULT_CONFIG_DIR_NAME: - AIIDA_CONFIG_FOLDER = AIIDA_CONFIG_FOLDER / DEFAULT_CONFIG_DIR_NAME - - # If the directory exists, we leave it set and break the loop - if AIIDA_CONFIG_FOLDER.is_dir(): - break - - else: - # The `AIIDA_PATH` variable is not set, so default to the default path and try to create it if it does not exist - AIIDA_CONFIG_FOLDER = pathlib.Path(DEFAULT_AIIDA_PATH).expanduser() / DEFAULT_CONFIG_DIR_NAME - - DAEMON_DIR = AIIDA_CONFIG_FOLDER / DEFAULT_DAEMON_DIR_NAME - DAEMON_LOG_DIR = DAEMON_DIR / DEFAULT_DAEMON_LOG_DIR_NAME - ACCESS_CONTROL_DIR = AIIDA_CONFIG_FOLDER / DEFAULT_ACCESS_CONTROL_DIR_NAME - - create_instance_directories() - - -# Initialize the configuration directory settings -set_configuration_directory() diff --git a/aiida/manage/external/__init__.py b/aiida/manage/external/__init__.py deleted file mode 100644 index d5ebc58bbd..0000000000 --- a/aiida/manage/external/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""User facing APIs to control AiiDA from the verdi cli, scripts or plugins""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .postgres import * -from .rmq import * - -__all__ = ( - 'BROKER_DEFAULTS', - 'DEFAULT_DBINFO', - 'ManagementApiConnectionError', - 'Postgres', - 'PostgresConnectionMode', - 'ProcessLauncher', - 'RabbitmqManagementClient', - 'get_launch_queue_name', - 'get_message_exchange_name', - 'get_rmq_url', - 'get_task_exchange_name', -) - -# yapf: enable diff --git a/aiida/manage/external/rmq/__init__.py b/aiida/manage/external/rmq/__init__.py deleted file mode 100644 index 503173d812..0000000000 --- a/aiida/manage/external/rmq/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with utilities to interact with RabbitMQ.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .client import * -from .defaults import * -from .launcher import * -from .utils import * - -__all__ = ( - 'BROKER_DEFAULTS', - 'ManagementApiConnectionError', - 'ProcessLauncher', - 'RabbitmqManagementClient', - 'get_launch_queue_name', - 'get_message_exchange_name', - 'get_rmq_url', - 'get_task_exchange_name', -) - -# yapf: enable diff --git a/aiida/manage/external/rmq/client.py b/aiida/manage/external/rmq/client.py deleted file mode 100644 index 3c0938b7a1..0000000000 --- a/aiida/manage/external/rmq/client.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- -"""Client for RabbitMQ Management HTTP API.""" -from __future__ import annotations - -import typing as t -from urllib.parse import quote - -import requests - -from aiida.common.exceptions import AiidaException - -__all__ = ('RabbitmqManagementClient', 'ManagementApiConnectionError') - - -class ManagementApiConnectionError(AiidaException): - """Raised when no connection can be made to the management HTTP API.""" - - -class RabbitmqManagementClient: - """Client for RabbitMQ Management HTTP API. - - This requires the ``rabbitmq_management`` plugin (https://www.rabbitmq.com/management.html) to be enabled. Typically - this is enabled by running ``rabbitmq-plugins enable rabbitmq_management``. - """ - - def __init__(self, username: str, password: str, hostname: str, virtual_host: str): - """Construct a new instance. - - :param username: The username to authenticate with. - :param password: The password to authenticate with. - :param hostname: The hostname of the RabbitMQ server. - :param virtual_host: The virtual host. - """ - self._username = username - self._password = password - self._hostname = hostname - self._virtual_host = virtual_host - self._authentication = requests.auth.HTTPBasicAuth(username, password) - - def format_url(self, url: str, url_params: dict[str, str] | None = None) -> str: - """Format the complete URL from a partial resource path with placeholders. - - The base URL will be automatically prepended. - - :param url: The resource path with placeholders, e.g., ``queues/{virtual_host}/{queue}``. - :param url_params: Dictionary with values for the placeholders in the ``url``. The ``virtual_host`` value is - automatically inserted and should not be specified. - :returns: The complete URL. - """ - url_params = url_params or {} - url_params['virtual_host'] = self._virtual_host if self._virtual_host else '/' - url_params = {key: quote(value, safe='') for key, value in url_params.items()} - return f'http://{self._hostname}:15672/api/{url.format(**url_params)}' - - def request( - self, - url: str, - url_params: dict[str, str] | None = None, - method: str = 'GET', - params: dict[str, t.Any] | None = None, - ) -> requests.Response: - """Make a request. - - :param url: The resource path with placeholders, e.g., ``queues/{virtual_host}/{queue}``. - :param url_params: Dictionary with values for the placeholders in the ``url``. The ``virtual_host`` value is - automatically inserted and should not be specified. - :param method: The HTTP method. - :param params: Query parameters to add to the URL. 
- :returns: The response of the request. - :raises `ManagementApiConnectionError`: If connection to the API cannot be made. - """ - url = self.format_url(url, url_params) - try: - return requests.request(method, url, auth=self._authentication, params=params or {}, timeout=5) - except requests.exceptions.ConnectionError as exception: - raise ManagementApiConnectionError( - 'Could not connect to the management API. Make sure RabbitMQ is running and the management plugin is ' - 'installed using `sudo rabbitmq-plugins enable rabbitmq_management`.' - ) from exception - - @property - def is_connected(self) -> bool: - """Return whether the API server can be connected to. - - .. note:: Tries to reach the server at the ``/api/cluster-name`` end-point. - - :returns: ``True`` if the server can be reached, ``False`` otherwise. - """ - try: - self.request('cluster-name') - except ManagementApiConnectionError: - return False - return True diff --git a/aiida/manage/external/rmq/defaults.py b/aiida/manage/external/rmq/defaults.py deleted file mode 100644 index 16058d8a52..0000000000 --- a/aiida/manage/external/rmq/defaults.py +++ /dev/null @@ -1,19 +0,0 @@ -# -*- coding: utf-8 -*- -"""Defaults related to RabbitMQ.""" -from aiida.common.extendeddicts import AttributeDict - -__all__ = ('BROKER_DEFAULTS',) - -LAUNCH_QUEUE = 'process.queue' -MESSAGE_EXCHANGE = 'messages' -TASK_EXCHANGE = 'tasks' - -BROKER_DEFAULTS = AttributeDict({ - 'protocol': 'amqp', - 'username': 'guest', - 'password': 'guest', - 'host': '127.0.0.1', - 'port': 5672, - 'virtual_host': '', - 'heartbeat': 600, -}) diff --git a/aiida/manage/external/rmq/utils.py b/aiida/manage/external/rmq/utils.py deleted file mode 100644 index 314f3946d3..0000000000 --- a/aiida/manage/external/rmq/utils.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -"""Utilites for RabbitMQ.""" -from urllib.parse import urlencode, urlunparse - -from . import defaults - -__all__ = ('get_rmq_url', 'get_launch_queue_name', 'get_message_exchange_name', 'get_task_exchange_name') - - -def get_rmq_url(protocol=None, username=None, password=None, host=None, port=None, virtual_host=None, **kwargs): - """Return the URL to connect to RabbitMQ. - - .. note:: - - The default of the ``host`` is set to ``127.0.0.1`` instead of ``localhost`` because on some computers localhost - resolves first to IPv6 with address ::1 and if RMQ is not running on IPv6 one gets an annoying warning. For more - info see: https://github.com/aiidateam/aiida-core/issues/1142 - - :param protocol: the protocol to use, `amqp` or `amqps`. - :param username: the username for authentication. - :param password: the password for authentication. - :param host: the hostname of the RabbitMQ server. - :param port: the port of the RabbitMQ server. - :param virtual_host: the virtual host to connect to. - :param kwargs: remaining keyword arguments that will be encoded as query parameters. - :returns: the connection URL string. 
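To make the URL construction below concrete, here is a hedged sketch of the strings produced when all arguments fall back to the `BROKER_DEFAULTS` defined above:

```python
from aiida.manage.external.rmq import get_rmq_url

# All parameters fall back to BROKER_DEFAULTS; the heartbeat is appended as a query parameter.
url = get_rmq_url()
# 'amqp://guest:guest@127.0.0.1:5672?heartbeat=600'

# A named virtual host ends up in the path component, prefixed with a slash.
url = get_rmq_url(virtual_host='aiida')
# 'amqp://guest:guest@127.0.0.1:5672/aiida?heartbeat=600'
```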
- """ - if 'heartbeat' not in kwargs: - kwargs['heartbeat'] = defaults.BROKER_DEFAULTS.heartbeat - - scheme = protocol or defaults.BROKER_DEFAULTS.protocol - netloc = '{username}:{password}@{host}:{port}'.format( - username=username or defaults.BROKER_DEFAULTS.username, - password=password or defaults.BROKER_DEFAULTS.password, - host=host or defaults.BROKER_DEFAULTS.host, - port=port or defaults.BROKER_DEFAULTS.port, - ) - path = virtual_host or defaults.BROKER_DEFAULTS.virtual_host - parameters = '' - query = urlencode(kwargs) - fragment = '' - - # The virtual host is optional but if it is specified it needs to start with a forward slash. If the virtual host - # itself contains forward slashes, they need to be encoded. - if path and not path.startswith('/'): - path = f'/{path}' - - return urlunparse((scheme, netloc, path, parameters, query, fragment)) - - -def get_launch_queue_name(prefix=None): - """Return the launch queue name with an optional prefix. - - :returns: launch queue name - """ - if prefix is not None: - return f'{prefix}.{defaults.LAUNCH_QUEUE}' - - return defaults.LAUNCH_QUEUE - - -def get_message_exchange_name(prefix): - """Return the message exchange name for a given prefix. - - :returns: message exchange name - """ - return f'{prefix}.{defaults.MESSAGE_EXCHANGE}' - - -def get_task_exchange_name(prefix): - """Return the task exchange name for a given prefix. - - :returns: task exchange name - """ - return f'{prefix}.{defaults.TASK_EXCHANGE}' diff --git a/aiida/manage/manager.py b/aiida/manage/manager.py deleted file mode 100644 index 9bb862cd69..0000000000 --- a/aiida/manage/manager.py +++ /dev/null @@ -1,508 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=cyclic-import -"""AiiDA manager for global settings""" -import asyncio -import functools -from typing import TYPE_CHECKING, Any, Optional, Union - -if TYPE_CHECKING: - from kiwipy.rmq import RmqThreadCommunicator - from plumpy.process_comms import RemoteProcessThreadController - - from aiida.engine.daemon.client import DaemonClient - from aiida.engine.persistence import AiiDAPersister - from aiida.engine.runners import Runner - from aiida.manage.configuration.config import Config - from aiida.manage.configuration.profile import Profile - from aiida.orm.implementation import StorageBackend - -__all__ = ('get_manager',) - -MANAGER: Optional['Manager'] = None - - -def get_manager() -> 'Manager': - """Return the AiiDA global manager instance.""" - global MANAGER # pylint: disable=global-statement - if MANAGER is None: - MANAGER = Manager() - return MANAGER - - -class Manager: # pylint: disable=too-many-public-methods - """Manager singleton for globally loaded resources. - - AiiDA can have the following global resources loaded: - - 1. A single configuration object that contains: - - - Global options overrides - - The name of a default profile - - A mapping of profile names to their configuration and option overrides - - 2. 
A single profile object that contains: - - - The name of the profile - - The UUID of the profile - - The configuration of the profile, for connecting to storage and processing resources - - The option overrides for the profile - - 3. A single storage backend object for the profile, to connect to data storage resources - 5. A single daemon client object for the profile, to connect to the AiiDA daemon - 4. A single communicator object for the profile, to connect to the process control resources - 6. A single process controller object for the profile, which uses the communicator to control process tasks - 7. A single runner object for the profile, which uses the process controller to start and stop processes - 8. A single persister object for the profile, which can persist running processes to the profile storage - - """ - - def __init__(self) -> None: - """Construct a new instance.""" - # note: the config currently references the global variables - self._profile: Optional['Profile'] = None - self._profile_storage: Optional['StorageBackend'] = None - self._daemon_client: Optional['DaemonClient'] = None - self._communicator: Optional['RmqThreadCommunicator'] = None - self._process_controller: Optional['RemoteProcessThreadController'] = None - self._persister: Optional['AiiDAPersister'] = None - self._runner: Optional['Runner'] = None - - @staticmethod - def get_config(create=False) -> 'Config': - """Return the current config. - - :return: current loaded config instance - :raises aiida.common.ConfigurationError: if the configuration file could not be found, read or deserialized - - """ - from .configuration import get_config - return get_config(create=create) - - def get_profile(self) -> Optional['Profile']: - """Return the current loaded profile, if any - - :return: current loaded profile instance - """ - return self._profile - - def load_profile(self, profile: Union[None, str, 'Profile'] = None, allow_switch=False) -> 'Profile': - """Load a global profile, unloading any previously loaded profile. - - .. note:: If a profile is already loaded and no explicit profile is specified, nothing will be done. 
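For reference, a minimal sketch of the loading behaviour described in the note above (the profile name 'dev' is a placeholder; without an argument the default profile from the configuration is used):

```python
from aiida.manage import get_manager

manager = get_manager()
profile = manager.load_profile()  # load the default profile, or return the already loaded one

# Switching to a different profile once storage has been loaded requires an explicit opt-in.
profile = manager.load_profile('dev', allow_switch=True)
```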
- - :param profile: the name of the profile to load, by default will use the one marked as default in the config - :param allow_switch: if True, will allow switching to a different profile when storage is already loaded - - :return: the loaded `Profile` instance - :raises `aiida.common.exceptions.InvalidOperation`: - if another profile has already been loaded and allow_switch is False - """ - from aiida.common.exceptions import InvalidOperation - from aiida.common.log import configure_logging - from aiida.manage.configuration.profile import Profile - - # If a profile is already loaded and no explicit profile is specified, we do nothing - if profile is None and self._profile: - return self._profile - - if profile is None or isinstance(profile, str): - profile = self.get_config().get_profile(profile) - elif not isinstance(profile, Profile): - raise TypeError(f'profile must be None, a string, or a Profile instance, got: {type(profile)}') - - # If a profile is loaded and the specified profile name is that of the currently loaded, do nothing - if self._profile and (self._profile.name == profile.name): - return self._profile - - if self._profile and self.profile_storage_loaded and not allow_switch: - raise InvalidOperation( - f'cannot switch to profile {profile.name!r} because profile {self._profile.name!r} storage ' - 'is already loaded and allow_switch is False' - ) - - self.unload_profile() - self._profile = profile - - # Reconfigure the logging to make sure that profile specific logging config options are taken into account. - # Note that we do not configure with `with_orm=True` because that will force the backend to be loaded. - # This should instead be done lazily in `Manager.get_profile_storage`. - configure_logging() - - # Check whether a development version is being run. Note that needs to be called after ``configure_logging`` - # because this function relies on the logging being properly configured for the warning to show. - self.check_version() - - return self._profile - - def reset_profile(self) -> None: - """Close and reset any associated resources for the current profile.""" - self.reset_profile_storage() - self.reset_communicator() - self.reset_runner() - - self._daemon_client = None - self._persister = None - - def reset_profile_storage(self) -> None: - """Reset the profile storage. - - This will close any connections to the services used by the storage, such as database connections. - """ - if self._profile_storage is not None: - self._profile_storage.close() - self._profile_storage = None - - def reset_communicator(self) -> None: - """Reset the communicator.""" - if self._communicator is not None: - self._communicator.close() - self._communicator = None - self._process_controller = None - - def reset_runner(self) -> None: - """Reset the process runner.""" - if self._runner is not None: - self._runner.close() - self._runner = None - - def unload_profile(self) -> None: - """Unload the current profile, closing any associated resources.""" - self.reset_profile() - self._profile = None - - @property - def profile_storage_loaded(self) -> bool: - """Return whether a storage backend has been loaded. - - :return: boolean, True if database backend is currently loaded, False otherwise - """ - return self._profile_storage is not None - - def get_option(self, option_name: str) -> Any: - """Return the value of a configuration option. - - In order of priority, the option is returned from: - - 1. The current profile, if loaded and the option specified - 2. 
The current configuration, if loaded and the option specified - 3. The default value for the option - - :param option_name: the name of the option to return - :return: the value of the option - :raises `aiida.common.exceptions.ConfigurationError`: if the option is not found - """ - from aiida.common.exceptions import ConfigurationError - from aiida.manage.configuration.options import get_option - - # try the profile - if self._profile and option_name in self._profile.options: - return self._profile.get_option(option_name) - # try the config - try: - config = self.get_config(create=True) - except ConfigurationError: - pass - else: - if option_name in config.options: - return config.get_option(option_name) - # try the defaults (will raise ConfigurationError if not present) - option = get_option(option_name) - return option.default - - def get_backend(self) -> 'StorageBackend': - """Return the current profile's storage backend, loading it if necessary. - - Deprecated: use `get_profile_storage` instead. - """ - from aiida.common.warnings import warn_deprecation - warn_deprecation('get_backend() is deprecated, use get_profile_storage() instead', version=3) - return self.get_profile_storage() - - def get_profile_storage(self) -> 'StorageBackend': - """Return the current profile's storage backend, loading it if necessary.""" - from aiida.common import ConfigurationError - from aiida.common.log import configure_logging - from aiida.manage.profile_access import ProfileAccessManager - - # if loaded, return the current storage backend (which is "synced" with the global profile) - if self._profile_storage is not None: - return self._profile_storage - - # get the currently loaded profile - profile = self.get_profile() - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' - ) - - # request access to the profile (for example, if it is being used by a maintenance operation) - ProfileAccessManager(profile).request_access() - - # retrieve the storage backend to use for the current profile - storage_cls = profile.storage_cls - - # now we can actually instatiate the backend and set the global variable, note: - # if the storage is not reachable, this will raise an exception - # if the storage schema is not at the latest version, this will except and the user will be informed to migrate - self._profile_storage = storage_cls(profile) - - # Reconfigure the logging with `with_orm=True` to make sure that profile specific logging configuration options - # are taken into account and the `DbLogHandler` is configured. - configure_logging(with_orm=True) - - return self._profile_storage - - def get_persister(self) -> 'AiiDAPersister': - """Return the persister - - :return: the current persister instance - - """ - from aiida.engine import persistence - - if self._persister is None: - self._persister = persistence.AiiDAPersister() - - return self._persister - - def get_communicator(self) -> 'RmqThreadCommunicator': - """Return the communicator - - :return: a global communicator instance - - """ - if self._communicator is None: - self._communicator = self.create_communicator() - - return self._communicator - - def create_communicator(self, task_prefetch_count: Optional[int] = None) -> 'RmqThreadCommunicator': - """Create a Communicator. 
- - :param task_prefetch_count: optional specify how many tasks this communicator take simultaneously - - :return: the communicator instance - - """ - import kiwipy.rmq - - from aiida.common import ConfigurationError - from aiida.manage.external import rmq - from aiida.orm.utils import serialize - - profile = self.get_profile() - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' - ) - - if task_prefetch_count is None: - task_prefetch_count = self.get_option('daemon.worker_process_slots') - - prefix = profile.rmq_prefix - - encoder = functools.partial(serialize.serialize, encoding='utf-8') - decoder = serialize.deserialize_unsafe - - communicator = kiwipy.rmq.RmqThreadCommunicator.connect( - connection_params={'url': profile.get_rmq_url()}, - message_exchange=rmq.get_message_exchange_name(prefix), - encoder=encoder, - decoder=decoder, - task_exchange=rmq.get_task_exchange_name(prefix), - task_queue=rmq.get_launch_queue_name(prefix), - task_prefetch_count=task_prefetch_count, - async_task_timeout=self.get_option('rmq.task_timeout'), - # This is needed because the verdi commands will call this function and when called in unit tests the - # testing_mode cannot be set. - testing_mode=profile.is_test_profile, - ) - - # Check whether a compatible version of RabbitMQ is being used. - self.check_rabbitmq_version(communicator) - - return communicator - - def get_daemon_client(self) -> 'DaemonClient': - """Return the daemon client for the current profile. - - :return: the daemon client - - :raises aiida.common.MissingConfigurationError: if the configuration file cannot be found - :raises aiida.common.ProfileConfigurationError: if the given profile does not exist - """ - from aiida.common import ConfigurationError - from aiida.engine.daemon.client import DaemonClient - - if self._daemon_client is None: - profile = self.get_profile() - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. Consider loading a profile using `aiida.load_profile()`.' - ) - self._daemon_client = DaemonClient(profile) - - return self._daemon_client - - def get_process_controller(self) -> 'RemoteProcessThreadController': - """Return the process controller - - :return: the process controller instance - - """ - from plumpy.process_comms import RemoteProcessThreadController - if self._process_controller is None: - self._process_controller = RemoteProcessThreadController(self.get_communicator()) - - return self._process_controller - - def get_runner(self, **kwargs) -> 'Runner': - """Return a runner that is based on the current profile settings and can be used globally by the code. - - :return: the global runner - - """ - if self._runner is None: - self._runner = self.create_runner(**kwargs) - - return self._runner - - def set_runner(self, new_runner: 'Runner') -> None: - """Set the currently used runner - - :param new_runner: the new runner to use - - """ - if self._runner is not None: - self._runner.close() - - self._runner = new_runner - - def create_runner(self, with_persistence: bool = True, **kwargs: Any) -> 'Runner': - """Create and return a new runner - - :param with_persistence: create a runner with persistence enabled - - :return: a new runner instance - - """ - from aiida.common import ConfigurationError - from aiida.engine import runners - - profile = self.get_profile() - if profile is None: - raise ConfigurationError( - 'Could not determine the current profile. 
Consider loading a profile using `aiida.load_profile()`.' - ) - poll_interval = 0.0 if profile.is_test_profile else self.get_option('runner.poll.interval') - - settings = {'rmq_submit': False, 'poll_interval': poll_interval} - settings.update(kwargs) - - if profile.process_control_backend == 'rabbitmq' and 'communicator' not in settings: - # Only call get_communicator if we have to as it will lazily create - settings['communicator'] = self.get_communicator() - - if with_persistence and 'persister' not in settings: - settings['persister'] = self.get_persister() - - return runners.Runner(**settings) - - def create_daemon_runner(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> 'Runner': - """Create and return a new daemon runner. - - This is used by workers when the daemon is running and in testing. - - :param loop: the (optional) asyncio event loop to use - - :return: a runner configured to work in the daemon configuration - - """ - from plumpy.persistence import LoadSaveContext - - from aiida.engine import persistence - from aiida.manage.external import rmq - - runner = self.create_runner(rmq_submit=True, loop=loop) - runner_loop = runner.loop - - # Listen for incoming launch requests - task_receiver = rmq.ProcessLauncher( - loop=runner_loop, - persister=self.get_persister(), - load_context=LoadSaveContext(runner=runner), - loader=persistence.get_object_loader() - ) - - assert runner.communicator is not None, 'communicator not set for runner' - runner.communicator.add_task_subscriber(task_receiver) - - return runner - - def check_rabbitmq_version(self, communicator: 'RmqThreadCommunicator'): - """Check the version of RabbitMQ that is being connected to and emit warning if it is not compatible.""" - from aiida.cmdline.utils import echo - - show_warning = self.get_option('warnings.rabbitmq_version') - version = get_rabbitmq_version(communicator) - - if show_warning and not is_rabbitmq_version_supported(communicator): - echo.echo_warning(f'RabbitMQ v{version} is not supported and will cause unexpected problems!') - echo.echo_warning('It can cause long-running workflows to crash and jobs to be submitted multiple times.') - echo.echo_warning('See https://github.com/aiidateam/aiida-core/wiki/RabbitMQ-version-to-use for details.') - return version, False - - return version, True - - def check_version(self): - """Check the currently installed version of ``aiida-core`` and warn if it is a post release development version. - - The ``aiida-core`` package maintains the protocol that the ``main`` branch will use a post release version - number. This means it will always append `.post0` to the version of the latest release. This should mean that if - this protocol is maintained properly, this method will print a warning if the currently installed version is a - post release development branch and not an actual release. - """ - from packaging.version import parse - - from aiida import __version__ - from aiida.cmdline.utils import echo - - # Showing of the warning can be turned off by setting the following option to false. 
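As the comment above says, the warning is gated on the `warnings.development_version` option; a minimal sketch of disabling it globally through the configuration API (the same option can also be changed via `verdi config`):

```python
from aiida.manage.configuration import get_config

config = get_config()
config.set_option('warnings.development_version', False)
config.store()  # persist the change to the configuration file
```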
- show_warning = self.get_option('warnings.development_version') - version = parse(__version__) - - if version.is_postrelease and show_warning: - echo.echo_warning(f'You are currently using a post release development version of AiiDA: {version}') - echo.echo_warning('Be aware that this is not recommended for production and is not officially supported.') - echo.echo_warning('Databases used with this version may not be compatible with future releases of AiiDA') - echo.echo_warning('as you might not be able to automatically migrate your data.\n') - - -def is_rabbitmq_version_supported(communicator: 'RmqThreadCommunicator') -> bool: - """Return whether the version of RabbitMQ configured for the current profile is supported. - - Versions 3.5 and below are not supported at all, whereas versions 3.8.15 and above are not compatible with a default - configuration of the RabbitMQ server. - - :return: boolean whether the current RabbitMQ version is supported. - """ - from packaging.version import parse - version = get_rabbitmq_version(communicator) - return parse('3.6.0') <= version < parse('3.8.15') - - -def get_rabbitmq_version(communicator: 'RmqThreadCommunicator'): - """Return the version of the RabbitMQ server that the current profile connects to. - - :return: :class:`packaging.version.Version` - """ - from packaging.version import parse - return parse(communicator.server_properties['version'].decode('utf-8')) diff --git a/aiida/manage/tests/__init__.py b/aiida/manage/tests/__init__.py deleted file mode 100644 index 4f4b410141..0000000000 --- a/aiida/manage/tests/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Testing infrastructure for easy testing of AiiDA plugins. -""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .main import * - -__all__ = ( - 'ProfileManager', - 'TemporaryProfileManager', - 'TestManager', - 'TestManagerError', - 'get_test_backend_name', - 'get_test_profile_name', - 'test_manager', -) - -# yapf: enable diff --git a/aiida/manage/tests/main.py b/aiida/manage/tests/main.py deleted file mode 100644 index d21bb15f17..0000000000 --- a/aiida/manage/tests/main.py +++ /dev/null @@ -1,535 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Testing infrastructure for easy testing of AiiDA plugins. 
- -""" -import contextlib -import os -import shutil -import tempfile -import warnings - -from aiida.common.log import override_log_level -from aiida.common.warnings import warn_deprecation -from aiida.manage import configuration, get_manager -from aiida.manage.configuration import settings -from aiida.manage.external.postgres import Postgres -from aiida.orm import User - -__all__ = ( - 'get_test_profile_name', - 'get_test_backend_name', - 'test_manager', - 'TestManager', - 'TestManagerError', - 'ProfileManager', - 'TemporaryProfileManager', -) - -_DEFAULT_PROFILE_INFO = { - 'name': 'test_profile', - 'email': 'tests@aiida.mail', - 'first_name': 'AiiDA', - 'last_name': 'Plugintest', - 'institution': 'aiidateam', - 'storage_backend': 'core.psql_dos', - 'database_engine': 'postgresql_psycopg2', - 'database_username': 'aiida', - 'database_password': 'aiida_pw', - 'database_name': 'aiida_db', - 'repo_dir': 'test_repo', - 'config_dir': '.aiida', - 'root_path': '', - 'broker_protocol': 'amqp', - 'broker_username': 'guest', - 'broker_password': 'guest', - 'broker_host': '127.0.0.1', - 'broker_port': 5672, - 'broker_virtual_host': '', - 'test_profile': True, -} - -warn_deprecation( - 'This module is deprecated; use the fixtures from `aiida.manage.tests.pytest_fixtures` instead', version=3 -) - - -class TestManagerError(Exception): - """Raised by TestManager in situations that may lead to inconsistent behaviour.""" - - def __init__(self, msg): - super().__init__() - self.msg = msg - - def __str__(self): - return repr(self.msg) - - -class TestManager: - """ - Test manager for plugin tests. - - Uses either ProfileManager for wrapping an existing profile or TemporaryProfileManager for setting up a complete - temporary AiiDA environment. - - For usage with pytest, see :py:class:`~aiida.manage.tests.pytest_fixtures`. - """ - - def __init__(self): - self._manager = None - - @property - def manager(self) -> 'ProfileManager': - assert self._manager is not None - return self._manager - - def use_temporary_profile(self, backend=None, pgtest=None): - """Set up Test manager to use temporary AiiDA profile. - - Uses :py:class:`aiida.manage.tests.main.TemporaryProfileManager` internally. - - :param backend: Backend to use. - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - if configuration.get_profile() is not None: - raise TestManagerError('An AiiDA profile must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - mngr = TemporaryProfileManager(backend=backend, pgtest=pgtest) - mngr.create_profile() - self._manager = mngr # don't assign before profile has actually been created! - - def use_profile(self, profile_name): - """Set up Test manager to use existing profile. - - Uses :py:class:`aiida.manage.tests.main.ProfileManager` internally. - - :param profile_name: Name of existing test profile to use. 
- """ - if configuration.get_profile() is not None: - raise TestManagerError('an AiiDA profile must not be loaded before setting up a test profile.') - if self._manager is not None: - raise TestManagerError('Profile manager already loaded.') - - self._manager = ProfileManager(profile_name=profile_name) - - def has_profile_open(self): - return self._manager and self._manager.has_profile_open() - - def reset_db(self): - warn_deprecation('reset_db() is deprecated, use clear_profile() instead', version=3) - return self._manager.clear_profile() - - def clear_profile(self): - """Reset the global profile, clearing all its data and closing any open resources.""" - return self._manager.clear_profile() - - def destroy_all(self): - if self._manager: - self._manager.destroy_all() - self._manager = None - - -class ProfileManager: - """ - Wraps existing AiiDA profile. - """ - - def __init__(self, profile_name): - """ - Use an existing profile. - - :param profile_name: Name of the profile to be loaded - """ - from aiida import load_profile - - self._profile = None - try: - self._profile = load_profile(profile_name) - except Exception: - raise TestManagerError(f'Unable to load test profile `{profile_name}`.') - if self._profile is None: - raise TestManagerError(f'Unable to load test profile `{profile_name}`.') - if not self._profile.is_test_profile: - raise TestManagerError(f'Profile `{profile_name}` is not a valid test profile.') - - def ensure_default_user(self): - """Ensure that the default user defined by the profile exists in the database.""" - created, user = User.collection.get_or_create(self._profile.default_user_email) - if created: - user.store() - - def clear_profile(self): - """Reset the global profile, clearing all its data and closing any open resources. - - If the daemon is running, it will be stopped because it might be holding on to entities that will be cleared - from the storage backend. - """ - from aiida.engine.daemon.client import get_daemon_client - - daemon_client = get_daemon_client() - - if daemon_client.is_daemon_running: - daemon_client.stop_daemon(wait=True) - - manager = get_manager() - manager.get_profile_storage()._clear() # pylint: disable=protected-access - manager.get_profile_storage() # reload the storage connection - manager.reset_communicator() - manager.reset_runner() - - self.ensure_default_user() - - def has_profile_open(self): - return self._profile is not None - - def destroy_all(self): - manager = get_manager() - manager.reset_profile() - - -class TemporaryProfileManager(ProfileManager): - """ - Manage the life cycle of a completely separated and temporary AiiDA environment. - - * No profile / database setup required - * Tests run via the TemporaryProfileManager never pollute the user's working environment - - Filesystem: - - * temporary ``.aiida`` configuration folder - * temporary repository folder - - Database: - - * temporary database cluster (via the ``pgtest`` package) - * with ``aiida`` database user - * with ``aiida_db`` database - - AiiDA: - - * configured to use the temporary configuration - * sets up a temporary profile for tests - - All of this happens automatically when using the corresponding tests classes & tests runners (unittest) - or fixtures (pytest). 
- - Example:: - - tests = TemporaryProfileManager(backend=backend) - tests.create_aiida_db() # set up only the database - tests.create_profile() # set up a profile (creates the db too if necessary) - - # ready for tests - - # run tests 1 - - tests.clear_profile() - # database ready for independent tests 2 - - # run tests 2 - - tests.destroy_all() - # everything cleaned up - - """ - - def __init__(self, backend='core.psql_dos', pgtest=None): # pylint: disable=super-init-not-called - """Construct a TemporaryProfileManager - - :param backend: a database backend - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - - """ - self.dbinfo = {} - self.profile_info = _DEFAULT_PROFILE_INFO - self.profile_info['storage_backend'] = backend - self._pgtest = pgtest or {} - - self.pg_cluster = None - self.postgres = None - self._profile = None - self._has_test_db = False - self._backup = { - 'config': configuration.CONFIG, - 'config_dir': settings.AIIDA_CONFIG_FOLDER, - settings.DEFAULT_AIIDA_PATH_VARIABLE: os.environ.get(settings.DEFAULT_AIIDA_PATH_VARIABLE, None) - } - - @property - def profile_dictionary(self): - """Profile parameters. - - Used to set up AiiDA profile from self.profile_info dictionary. - """ - dictionary = { - 'default_user_email': 'test@aiida.net', - 'test_profile': True, - 'storage': { - 'backend': self.profile_info.get('storage_backend'), - 'config': { - 'database_engine': self.profile_info.get('database_engine'), - 'database_port': self.profile_info.get('database_port'), - 'database_hostname': self.profile_info.get('database_hostname'), - 'database_name': self.profile_info.get('database_name'), - 'database_username': self.profile_info.get('database_username'), - 'database_password': self.profile_info.get('database_password'), - 'repository_uri': f'file://{self.repo}', - } - }, - 'process_control': { - 'backend': 'rabbitmq', - 'config': { - 'broker_protocol': self.profile_info.get('broker_protocol'), - 'broker_username': self.profile_info.get('broker_username'), - 'broker_password': self.profile_info.get('broker_password'), - 'broker_host': self.profile_info.get('broker_host'), - 'broker_port': self.profile_info.get('broker_port'), - 'broker_virtual_host': self.profile_info.get('broker_virtual_host'), - } - } - } - return dictionary - - def create_db_cluster(self): - """ - Create the database cluster using PGTest. - """ - from pgtest.pgtest import PGTest - - if self.pg_cluster is not None: - raise TestManagerError( - 'Running temporary postgresql cluster detected.Use destroy_all() before creating a new cluster.' - ) - self.pg_cluster = PGTest(**self._pgtest) - self.dbinfo.update(self.pg_cluster.dsn) - - def create_aiida_db(self): - """ - Create the necessary database on the temporary postgres instance. 
- """ - if configuration.get_profile() is not None: - raise TestManagerError('An AiiDA profile can not be loaded while creating a tests db environment') - if self.pg_cluster is None: - self.create_db_cluster() - self.postgres = Postgres(interactive=False, quiet=True, dbinfo=self.dbinfo) - # Note: We give the user CREATEDB privileges here, only since they are required for the migration tests - self.postgres.create_dbuser( - self.profile_info['database_username'], self.profile_info['database_password'], 'CREATEDB' - ) - self.postgres.create_db(self.profile_info['database_username'], self.profile_info['database_name']) - self.dbinfo = self.postgres.dbinfo - self.profile_info['database_hostname'] = self.postgres.host_for_psycopg2 - self.profile_info['database_port'] = self.postgres.port_for_psycopg2 - self._has_test_db = True - - def create_profile(self): - """ - Set AiiDA to use the tests config dir and create a default profile there - - Warning: the AiiDA dbenv must not be loaded when this is called! - """ - from aiida.manage.configuration import Profile - - manager = get_manager() - - if not self._has_test_db: - self.create_aiida_db() - - if not self.root_dir: - self.root_dir = tempfile.TemporaryDirectory().name # pylint: disable=consider-using-with - configuration.CONFIG = None - - os.environ[settings.DEFAULT_AIIDA_PATH_VARIABLE] = self.config_dir - - with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=UserWarning) - # This will raise a warning that the ``.aiida`` configuration directory is created. - settings.set_configuration_directory() - - manager.unload_profile() - profile_name = self.profile_info['name'] - config = configuration.get_config(create=True) - profile = Profile(profile_name, self.profile_dictionary) - config.add_profile(profile) - config.set_default_profile(profile_name).store() - self._profile = profile - - # Load the new profile and initialize the profile storage - with override_log_level(): - profile = manager.load_profile(profile_name) - profile.storage_cls.initialise(profile, reset=True) - - # Set options to suppress certain warnings - config.set_option('warnings.development_version', False) - config.set_option('warnings.rabbitmq_version', False) - - config.store() - - self.ensure_default_user() - - def repo_ok(self): - return bool(self.repo and os.path.isdir(os.path.dirname(self.repo))) - - @property - def repo(self): - return self._return_dir(self.profile_info['repo_dir']) - - def _return_dir(self, dir_path): - """Return a path to a directory from the fs environment""" - if os.path.isabs(dir_path): - return dir_path - return os.path.join(self.root_dir, dir_path) - - @property - def backend(self): - return self.profile_info['backend'] - - @backend.setter - def backend(self, backend): - if self.has_profile_open(): - raise TestManagerError('backend cannot be changed after setting up the environment') - - valid_backends = ['core.psql_dos'] - if backend not in valid_backends: - raise ValueError(f'invalid backend {backend}, must be one of {valid_backends}') - self.profile_info['backend'] = backend - - @property - def config_dir_ok(self): - return bool(self.config_dir and os.path.isdir(self.config_dir)) - - @property - def config_dir(self): - return self._return_dir(self.profile_info['config_dir']) - - @property - def root_dir(self): - return self.profile_info['root_path'] - - @root_dir.setter - def root_dir(self, root_dir): - self.profile_info['root_path'] = root_dir - - @property - def root_dir_ok(self): - return bool(self.root_dir and 
os.path.isdir(self.root_dir)) - - def destroy_all(self): - """Remove all traces of the tests run""" - super().destroy_all() - if self.root_dir: - shutil.rmtree(self.root_dir) - self.root_dir = None - if self.pg_cluster: - self.pg_cluster.close() - self.pg_cluster = None - self._has_test_db = False - self._profile = None - - if 'config' in self._backup: - configuration.CONFIG = self._backup['config'] - if 'config_dir' in self._backup: - settings.AIIDA_CONFIG_FOLDER = self._backup['config_dir'] - - if settings.DEFAULT_AIIDA_PATH_VARIABLE in self._backup and self._backup[settings.DEFAULT_AIIDA_PATH_VARIABLE]: - os.environ[settings.DEFAULT_AIIDA_PATH_VARIABLE] = self._backup[settings.DEFAULT_AIIDA_PATH_VARIABLE] - - def has_profile_open(self): - return self._profile is not None - - -_GLOBAL_TEST_MANAGER = TestManager() - - -@contextlib.contextmanager -def test_manager(backend='core.psql_dos', profile_name=None, pgtest=None): - """ Context manager for TestManager objects. - - Sets up temporary AiiDA environment for testing or reuses existing environment, - if `AIIDA_TEST_PROFILE` environment variable is set. - - Example pytest fixture:: - - def aiida_profile(): - with test_manager(backend) as test_mgr: - yield fixture_mgr - - Example unittest test runner:: - - with test_manager(backend) as test_mgr: - # ready for tests - # everything cleaned up - - - :param backend: storage backend type name - :param profile_name: name of test profile to be used or None (to use temporary profile) - :param pgtest: a dictionary of arguments to be passed to PGTest() for starting the postgresql cluster, - e.g. {'pg_ctl': '/somepath/pg_ctl'}. Should usually not be necessary. - """ - from aiida.common.log import configure_logging - from aiida.common.utils import Capturing - - try: - if not _GLOBAL_TEST_MANAGER.has_profile_open(): - if profile_name: - _GLOBAL_TEST_MANAGER.use_profile(profile_name=profile_name) - else: - with Capturing(): # capture output of AiiDA DB setup - _GLOBAL_TEST_MANAGER.use_temporary_profile(backend=backend, pgtest=pgtest) - configure_logging(with_orm=True) - yield _GLOBAL_TEST_MANAGER - finally: - _GLOBAL_TEST_MANAGER.destroy_all() - - -def get_test_backend_name() -> str: - """ Read name of storage backend from environment variable or the specified test profile. - - Reads storage backend from 'AIIDA_TEST_BACKEND' environment variable, - or the backend configured for the 'AIIDA_TEST_PROFILE'. - - :returns: name of storage backend - :raises: ValueError if unknown backend name detected. - :raises: ValueError if both 'AIIDA_TEST_BACKEND' and 'AIIDA_TEST_PROFILE' are set, and the two - backends do not match. - """ - test_profile_name = get_test_profile_name() - backend_env = os.environ.get('AIIDA_TEST_BACKEND', None) - if test_profile_name is not None: - backend_profile = configuration.get_config().get_profile(test_profile_name).storage_backend - if backend_env is not None and backend_env != backend_profile: - raise ValueError( - "The backend '{}' read from AIIDA_TEST_BACKEND does not match the backend '{}' " - "of AIIDA_TEST_PROFILE '{}'".format(backend_env, backend_profile, test_profile_name) - ) - backend_res = backend_profile - else: - backend_res = backend_env or 'core.psql_dos' - - if backend_res in ('core.psql_dos',): - return backend_res - raise ValueError(f"Unknown backend '{backend_res}' read from AIIDA_TEST_BACKEND environment variable") - - -def get_test_profile_name(): - """ Read name of test profile from environment variable. 
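A short illustration of how this helper interacts with the environment (the profile name is a placeholder; the function simply reads the variable, as shown below):

```python
import os

os.environ['AIIDA_TEST_PROFILE'] = 'my-test-profile'
assert get_test_profile_name() == 'my-test-profile'

# When the variable is unset, no existing profile is reused and a temporary one is created instead.
os.environ.pop('AIIDA_TEST_PROFILE', None)
assert get_test_profile_name() is None
```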
- - Reads name of existing test profile 'AIIDA_TEST_PROFILE' environment variable. - If specified, this profile is used for running the tests (instead of setting up a temporary profile). - - :returns: content of environment variable or `None` - """ - return os.environ.get('AIIDA_TEST_PROFILE', None) diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py deleted file mode 100644 index e9a26461ed..0000000000 --- a/aiida/orm/__init__.py +++ /dev/null @@ -1,120 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Main module to expose all orm classes and methods""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .authinfos import * -from .comments import * -from .computers import * -from .entities import * -from .extras import * -from .groups import * -from .logs import * -from .nodes import * -from .querybuilder import * -from .users import * -from .utils import * - -__all__ = ( - 'ASCENDING', - 'AbstractCode', - 'AbstractNodeMeta', - 'ArrayData', - 'AttributeManager', - 'AuthInfo', - 'AutoGroup', - 'BandsData', - 'BaseType', - 'Bool', - 'CalcFunctionNode', - 'CalcJobNode', - 'CalcJobResultManager', - 'CalculationEntityLoader', - 'CalculationNode', - 'CifData', - 'Code', - 'CodeEntityLoader', - 'Collection', - 'Comment', - 'Computer', - 'ComputerEntityLoader', - 'ContainerizedCode', - 'DESCENDING', - 'Data', - 'Dict', - 'Entity', - 'EntityExtras', - 'EntityTypes', - 'EnumData', - 'Float', - 'FolderData', - 'Group', - 'GroupEntityLoader', - 'ImportGroup', - 'InstalledCode', - 'Int', - 'JsonableData', - 'Kind', - 'KpointsData', - 'LinkManager', - 'LinkPair', - 'LinkTriple', - 'List', - 'Log', - 'Node', - 'NodeAttributes', - 'NodeEntityLoader', - 'NodeLinksManager', - 'NodeRepository', - 'NumericType', - 'OrbitalData', - 'OrderSpecifier', - 'OrmEntityLoader', - 'PortableCode', - 'ProcessNode', - 'ProjectionData', - 'QueryBuilder', - 'RemoteData', - 'RemoteStashData', - 'RemoteStashFolderData', - 'SinglefileData', - 'Site', - 'Str', - 'StructureData', - 'TrajectoryData', - 'UpfData', - 'UpfFamily', - 'User', - 'WorkChainNode', - 'WorkFunctionNode', - 'WorkflowNode', - 'XyData', - 'cif_from_ase', - 'find_bandgap', - 'get_loader', - 'get_query_type_from_type_string', - 'get_type_string_from_class', - 'has_pycifrw', - 'load_code', - 'load_computer', - 'load_entity', - 'load_group', - 'load_node', - 'load_node_class', - 'pycifrw_from_cif', - 'to_aiida_type', - 'validate_link', -) - -# yapf: enable diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py deleted file mode 100644 index e8131ff652..0000000000 --- a/aiida/orm/authinfos.py +++ /dev/null @@ -1,144 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
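The removed test-manager module above documents its intended use from a pytest fixture, though its docstring example yields `fixture_mgr` where the context variable is named `test_mgr`. Below is a minimal corrected sketch, assuming the historical import path `aiida.manage.tests` for these helpers:

```python
# Minimal sketch of a session-scoped pytest fixture built on the removed
# test_manager() context manager shown above. The import path
# `aiida.manage.tests` is an assumption based on where this code lived.
import pytest

from aiida.manage.tests import get_test_backend_name, get_test_profile_name, test_manager


@pytest.fixture(scope='session')
def aiida_profile():
    """Provide a temporary test profile, or reuse the one named in AIIDA_TEST_PROFILE."""
    with test_manager(
        backend=get_test_backend_name(),
        profile_name=get_test_profile_name(),
    ) as manager:
        yield manager
    # on exit, destroy_all() removes the temporary profile, database cluster and repository
```

Setting the `AIIDA_TEST_PROFILE` environment variable makes the same fixture reuse an existing profile instead of creating a temporary one.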
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for the `AuthInfo` ORM class.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type - -from aiida.common import exceptions -from aiida.common.lang import classproperty -from aiida.manage import get_manager -from aiida.plugins import TransportFactory - -from . import entities, users - -if TYPE_CHECKING: - from aiida.orm import Computer, User - from aiida.orm.implementation import BackendAuthInfo, StorageBackend - from aiida.transports import Transport - -__all__ = ('AuthInfo',) - - -class AuthInfoCollection(entities.Collection['AuthInfo']): - """The collection of `AuthInfo` entries.""" - - @staticmethod - def _entity_base_cls() -> Type['AuthInfo']: - return AuthInfo - - def delete(self, pk: int) -> None: - """Delete an entry from the collection. - - :param pk: the pk of the entry to delete - """ - self._backend.authinfos.delete(pk) - - -class AuthInfo(entities.Entity['BackendAuthInfo', AuthInfoCollection]): - """ORM class that models the authorization information that allows a `User` to connect to a `Computer`.""" - - _CLS_COLLECTION = AuthInfoCollection - - PROPERTY_WORKDIR = 'workdir' - - def __init__(self, computer: 'Computer', user: 'User', backend: Optional['StorageBackend'] = None) -> None: - """Create an `AuthInfo` instance for the given computer and user. - - :param computer: a `Computer` instance - :param user: a `User` instance - :param backend: the backend to use for the instance, or use the default backend if None - """ - backend = backend or get_manager().get_profile_storage() - model = backend.authinfos.create(computer=computer.backend_entity, user=user.backend_entity) - super().__init__(model) - - def __str__(self) -> str: - if self.enabled: - return f'AuthInfo for {self.user.email} on {self.computer.label}' - - return f'AuthInfo for {self.user.email} on {self.computer.label} [DISABLED]' - - @property - def enabled(self) -> bool: - """Return whether this instance is enabled. - - :return: True if enabled, False otherwise - """ - return self._backend_entity.enabled - - @enabled.setter - def enabled(self, enabled: bool) -> None: - """Set the enabled state - - :param enabled: boolean, True to enable the instance, False to disable it - """ - self._backend_entity.enabled = enabled - - @property - def computer(self) -> 'Computer': - """Return the computer associated with this instance.""" - from . 
import computers # pylint: disable=cyclic-import - return entities.from_backend_entity(computers.Computer, self._backend_entity.computer) - - @property - def user(self) -> 'User': - """Return the user associated with this instance.""" - return entities.from_backend_entity(users.User, self._backend_entity.user) - - def get_auth_params(self) -> Dict[str, Any]: - """Return the dictionary of authentication parameters - - :return: a dictionary with authentication parameters - """ - return self._backend_entity.get_auth_params() - - def set_auth_params(self, auth_params: Dict[str, Any]) -> None: - """Set the dictionary of authentication parameters - - :param auth_params: a dictionary with authentication parameters - """ - self._backend_entity.set_auth_params(auth_params) - - def get_metadata(self) -> Dict[str, Any]: - """Return the dictionary of metadata - - :return: a dictionary with metadata - """ - return self._backend_entity.get_metadata() - - def set_metadata(self, metadata: Dict[str, Any]) -> None: - """Set the dictionary of metadata - - :param metadata: a dictionary with metadata - """ - self._backend_entity.set_metadata(metadata) - - def get_workdir(self) -> str: - """Return the working directory. - - If no explicit work directory is set for this instance, the working directory of the computer will be returned. - - :return: the working directory - """ - try: - return self.get_metadata()[self.PROPERTY_WORKDIR] - except KeyError: - return self.computer.get_workdir() - - def get_transport(self) -> 'Transport': - """Return a fully configured transport that can be used to connect to the computer set for this instance.""" - computer = self.computer - transport_type = computer.transport_type - - try: - transport_class = TransportFactory(transport_type) - except exceptions.EntryPointError as exception: - raise exceptions.ConfigurationError(f'transport type `{transport_type}` could not be loaded: {exception}') - - return transport_class(machine=computer.hostname, **self.get_auth_params()) diff --git a/aiida/orm/comments.py b/aiida/orm/comments.py deleted file mode 100644 index 6ab090ada0..0000000000 --- a/aiida/orm/comments.py +++ /dev/null @@ -1,128 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Comment objects and functions""" -from datetime import datetime -from typing import TYPE_CHECKING, List, Optional, Type - -from aiida.common.lang import classproperty -from aiida.manage import get_manager - -from . 
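`AuthInfo.get_transport()` above returns a configured but closed transport, so it is normally used as a context manager. A hedged sketch, in which the computer label and the use of the default user are illustrative assumptions:

```python
# Sketch: open the transport configured for the default user on a computer.
# The computer label 'localhost' is a placeholder, not something this diff defines.
from aiida import load_profile
from aiida.orm import Computer, User

load_profile()

computer = Computer.collection.get(label='localhost')
user = User.collection.get_default()
authinfo = computer.get_authinfo(user)  # raises NotExistent if the computer is not configured

with authinfo.get_transport() as transport:
    # entering the context opens the connection; it is closed again on exit
    print(transport.whoami())
```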
import entities, users - -if TYPE_CHECKING: - from aiida.orm import Node, User - from aiida.orm.implementation import BackendComment, StorageBackend - -__all__ = ('Comment',) - - -class CommentCollection(entities.Collection['Comment']): - """The collection of Comment entries.""" - - @staticmethod - def _entity_base_cls() -> Type['Comment']: - return Comment - - def delete(self, pk: int) -> None: - """ - Remove a Comment from the collection with the given id - - :param pk: the id of the comment to delete - - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ - self._backend.comments.delete(pk) - - def delete_all(self) -> None: - """ - Delete all Comments from the Collection - - :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted - """ - self._backend.comments.delete_all() - - def delete_many(self, filters: dict) -> List[int]: - """ - Delete Comments from the Collection based on ``filters`` - - :param filters: similar to QueryBuilder filter - - :return: (former) ``PK`` s of deleted Comments - - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - return self._backend.comments.delete_many(filters) - - -class Comment(entities.Entity['BackendComment', CommentCollection]): - """Base class to map a DbComment that represents a comment attached to a certain Node.""" - - _CLS_COLLECTION = CommentCollection - - def __init__( - self, node: 'Node', user: 'User', content: Optional[str] = None, backend: Optional['StorageBackend'] = None - ): - """Create a Comment for a given node and user - - :param node: a Node instance - :param user: a User instance - :param content: the comment content - :param backend: the backend to use for the instance, or use the default backend if None - - :return: a Comment object associated to the given node and user - """ - backend = backend or get_manager().get_profile_storage() - model = backend.comments.create(node=node.backend_entity, user=user.backend_entity, content=content) - super().__init__(model) - - def __str__(self) -> str: - arguments = [self.uuid, self.node.pk, self.user.email, self.content] - return 'Comment<{}> for node<{}> and user<{}>: {}'.format(*arguments) - - @property - def uuid(self) -> str: - """Return the UUID for this comment. - - This identifier is unique across all entities types and backend instances. 
- - :return: the entity uuid - """ - return self._backend_entity.uuid - - @property - def ctime(self) -> datetime: - return self._backend_entity.ctime - - @property - def mtime(self) -> datetime: - return self._backend_entity.mtime - - def set_mtime(self, value: datetime) -> None: - return self._backend_entity.set_mtime(value) - - @property - def node(self) -> 'Node': - return self._backend_entity.node - - @property - def user(self) -> 'User': - return entities.from_backend_entity(users.User, self._backend_entity.user) - - def set_user(self, value: 'User') -> None: - self._backend_entity.user = value.backend_entity - - @property - def content(self) -> str: - return self._backend_entity.content - - def set_content(self, value: str) -> None: - return self._backend_entity.set_content(value) diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py deleted file mode 100644 index 178650fde7..0000000000 --- a/aiida/orm/computers.py +++ /dev/null @@ -1,706 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for Computer entities""" -import logging -import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union - -from aiida.common import exceptions -from aiida.common.lang import classproperty -from aiida.manage import get_manager -from aiida.plugins import SchedulerFactory, TransportFactory - -from . import entities, users - -if TYPE_CHECKING: - from aiida.orm import AuthInfo, User - from aiida.orm.implementation import BackendComputer, StorageBackend - from aiida.schedulers import Scheduler - from aiida.transports import Transport - -__all__ = ('Computer',) - - -class ComputerCollection(entities.Collection['Computer']): - """The collection of Computer entries.""" - - @staticmethod - def _entity_base_cls() -> Type['Computer']: - return Computer - - def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple[bool, 'Computer']: - """ - Try to retrieve a Computer from the DB with the given arguments; - create (and store) a new Computer if such a Computer was not present yet. - - :param label: computer label - - :return: (computer, created) where computer is the computer (new or existing, - in any case already stored) and created is a boolean saying - """ - if not label: - raise ValueError('Computer label must be provided') - - try: - return False, self.get(label=label) - except exceptions.NotExistent: - return True, Computer(backend=self.backend, label=label, **kwargs) - - def list_labels(self) -> List[str]: - """Return a list with all the labels of the computers in the DB.""" - return self._backend.computers.list_names() - - def delete(self, pk: int) -> None: - """Delete the computer with the given id""" - return self._backend.computers.delete(pk) - - -class Computer(entities.Entity['BackendComputer', ComputerCollection]): - """ - Computer entity. - """ - # pylint: disable=too-many-public-methods - - _logger = logging.getLogger(__name__) - - PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL = 'minimum_scheduler_poll_interval' # pylint: disable=invalid-name - PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT = 10. 
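Given the constructor signature above (`node`, `user`, `content`) and the collection's `delete` method, attaching and removing a comment might look like the following sketch; the node PK and the comment text are placeholders:

```python
# Sketch: attach a comment to an existing node and delete it again.
# The node PK (1234) and the comment text are placeholders.
from aiida import load_profile
from aiida.orm import Comment, User, load_node

load_profile()

node = load_node(1234)
user = User.collection.get_default()

comment = Comment(node=node, user=user, content='converged run, keep for publication').store()
print(comment)  # "Comment<uuid> for node<pk> and user<email>: ..."

# the collection shown above also allows deleting by primary key
Comment.collection.delete(comment.pk)
```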
# pylint: disable=invalid-name - PROPERTY_WORKDIR = 'workdir' - PROPERTY_SHEBANG = 'shebang' - - _CLS_COLLECTION = ComputerCollection - - def __init__( # pylint: disable=too-many-arguments - self, - label: Optional[str] = None, - hostname: str = '', - description: str = '', - transport_type: str = '', - scheduler_type: str = '', - workdir: Optional[str] = None, - backend: Optional['StorageBackend'] = None, - ) -> None: - """Construct a new computer.""" - backend = backend or get_manager().get_profile_storage() - model = backend.computers.create( - label=label, - hostname=hostname, - description=description, - transport_type=transport_type, - scheduler_type=scheduler_type - ) - super().__init__(model) - if workdir is not None: - self.set_workdir(workdir) - - def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' - - def __str__(self): - return f'{self.label} ({self.hostname}), pk: {self.pk}' - - @property - def uuid(self) -> str: - """Return the UUID for this computer. - - This identifier is unique across all entities types and backend instances. - - :return: the entity uuid - """ - return self._backend_entity.uuid - - @property - def logger(self) -> logging.Logger: - return self._logger - - @classmethod - def _label_validator(cls, label: str) -> None: - """ - Validates the label. - """ - if not label.strip(): - raise exceptions.ValidationError('No label specified') - - @classmethod - def _hostname_validator(cls, hostname: str) -> None: - """ - Validates the hostname. - """ - if not (hostname or hostname.strip()): - raise exceptions.ValidationError('No hostname specified') - - @classmethod - def _description_validator(cls, description: str) -> None: - """ - Validates the description. - """ - # The description is always valid - - @classmethod - def _transport_type_validator(cls, transport_type: str) -> None: - """ - Validates the transport string. - """ - from aiida.plugins.entry_point import get_entry_point_names - if transport_type not in get_entry_point_names('aiida.transports'): - raise exceptions.ValidationError('The specified transport is not a valid one') - - @classmethod - def _scheduler_type_validator(cls, scheduler_type: str) -> None: - """ - Validates the transport string. - """ - from aiida.plugins.entry_point import get_entry_point_names - if scheduler_type not in get_entry_point_names('aiida.schedulers'): - raise exceptions.ValidationError(f'The specified scheduler `{scheduler_type}` is not a valid one') - - @classmethod - def _prepend_text_validator(cls, prepend_text: str) -> None: - """ - Validates the prepend text string. - """ - # no validation done - - @classmethod - def _append_text_validator(cls, append_text: str) -> None: - """ - Validates the append text string. - """ - # no validation done - - @classmethod - def _workdir_validator(cls, workdir: str) -> None: - """ - Validates the transport string. - """ - if not workdir.strip(): - raise exceptions.ValidationError('No workdir specified') - - try: - convertedwd = workdir.format(username='test') - except KeyError as exc: - raise exceptions.ValidationError(f'In workdir there is an unknown replacement field {exc.args[0]}') - except ValueError as exc: - raise exceptions.ValidationError(f"Error in the string: '{exc}'") - - if not os.path.isabs(convertedwd): - raise exceptions.ValidationError('The workdir must be an absolute path') - - def _mpirun_command_validator(self, mpirun_cmd: Union[List[str], Tuple[str, ...]]) -> None: - """ - Validates the mpirun_command variable. 
MUST be called after properly - checking for a valid scheduler. - """ - if not isinstance(mpirun_cmd, (tuple, list)) or not all(isinstance(i, str) for i in mpirun_cmd): - raise exceptions.ValidationError('the mpirun_command must be a list of strings') - - try: - job_resource_keys = self.get_scheduler().job_resource_class.get_valid_keys() - except exceptions.EntryPointError: - raise exceptions.ValidationError('Unable to load the scheduler for this computer') - - subst = {i: 'value' for i in job_resource_keys} - subst['tot_num_mpiprocs'] = 'value' - - try: - for arg in mpirun_cmd: - arg.format(**subst) - except KeyError as exc: - raise exceptions.ValidationError(f'In workdir there is an unknown replacement field {exc.args[0]}') - except ValueError as exc: - raise exceptions.ValidationError(f"Error in the string: '{exc}'") - - def validate(self) -> None: - """ - Check if the attributes and files retrieved from the DB are valid. - Raise a ValidationError if something is wrong. - - Must be able to work even before storing: therefore, use the get_attr and similar methods - that automatically read either from the DB or from the internal attribute cache. - - For the base class, this is always valid. Subclasses will reimplement this. - In the subclass, always call the super().validate() method first! - """ - if not self.label.strip(): - raise exceptions.ValidationError('No name specified') - - self._label_validator(self.label) - self._hostname_validator(self.hostname) - self._description_validator(self.description) - self._transport_type_validator(self.transport_type) - self._scheduler_type_validator(self.scheduler_type) - self._workdir_validator(self.get_workdir()) - self.default_memory_per_machine_validator(self.get_default_memory_per_machine()) - - try: - mpirun_cmd = self.get_mpirun_command() - except exceptions.DbContentError: - raise exceptions.ValidationError('Error in the DB content of the metadata') - - # To be called AFTER the validation of the scheduler - self._mpirun_command_validator(mpirun_cmd) - - @classmethod - def _default_mpiprocs_per_machine_validator(cls, def_cpus_per_machine: Optional[int]) -> None: - """ - Validates the default number of CPUs per machine (node) - """ - if def_cpus_per_machine is None: - return - - if not isinstance(def_cpus_per_machine, int) or def_cpus_per_machine <= 0: - raise exceptions.ValidationError( - 'Invalid value for default_mpiprocs_per_machine, must be a positive integer, or an empty string if you ' - 'do not want to provide a default value.' - ) - - @classmethod - def default_memory_per_machine_validator(cls, def_memory_per_machine: Optional[int]) -> None: - """Validates the default amount of memory (kB) per machine (node)""" - if def_memory_per_machine is None: - return - - if not isinstance(def_memory_per_machine, int) or def_memory_per_machine <= 0: - raise exceptions.ValidationError( - f'Invalid value for def_memory_per_machine, must be a positive int, got: {def_memory_per_machine}' - ) - - def copy(self) -> 'Computer': - """ - Return a copy of the current object to work with, not stored yet. - """ - return entities.from_backend_entity(Computer, self._backend_entity.copy()) - - def store(self) -> 'Computer': - """ - Store the computer in the DB. - - Differently from Nodes, a computer can be re-stored if its properties - are to be changed (e.g. a new mpirun command, etc.) - """ - self.validate() - return super().store() - - @property - def label(self) -> str: - """Return the computer label. - - :return: the label. 
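Taken together, the constructor and the validators above define the minimum a `Computer` needs before `store()` (which calls `validate()`) will succeed. A sketch with assumed entry-point names (`core.local`, `core.direct`) and a placeholder label:

```python
# Sketch: create and store a Computer. The transport/scheduler entry points
# ('core.local', 'core.direct') and the label are assumptions for illustration.
from aiida import load_profile
from aiida.orm import Computer

load_profile()

computer = Computer(
    label='localhost',
    hostname='localhost',
    description='Local workstation',
    transport_type='core.local',       # must be a registered 'aiida.transports' entry point
    scheduler_type='core.direct',      # must be a registered 'aiida.schedulers' entry point
    workdir='/scratch/{username}/aiida_run/',  # must format to an absolute path
)
computer.store()  # store() calls validate(), which runs all the validators above
```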
- """ - return self._backend_entity.label - - @label.setter - def label(self, value: str) -> None: - """Set the computer label. - - :param value: the label to set. - """ - self._backend_entity.set_label(value) - - @property - def description(self) -> str: - """Return the computer computer. - - :return: the description. - """ - return self._backend_entity.description - - @description.setter - def description(self, value: str) -> None: - """Set the computer description. - - :param value: the description to set. - """ - self._backend_entity.set_description(value) - - @property - def hostname(self) -> str: - """Return the computer hostname. - - :return: the hostname. - """ - return self._backend_entity.hostname - - @hostname.setter - def hostname(self, value: str) -> None: - """Set the computer hostname. - - :param value: the hostname to set. - """ - self._backend_entity.set_hostname(value) - - @property - def scheduler_type(self) -> str: - """Return the computer scheduler type. - - :return: the scheduler type. - """ - return self._backend_entity.get_scheduler_type() - - @scheduler_type.setter - def scheduler_type(self, value: str) -> None: - """Set the computer scheduler type. - - :param value: the scheduler type to set. - """ - self._backend_entity.set_scheduler_type(value) - - @property - def transport_type(self) -> str: - """Return the computer transport type. - - :return: the transport_type. - """ - return self._backend_entity.get_transport_type() - - @transport_type.setter - def transport_type(self, value: str) -> None: - """Set the computer transport type. - - :param value: the transport_type to set. - """ - self._backend_entity.set_transport_type(value) - - @property - def metadata(self) -> Dict[str, Any]: - """Return the computer metadata. - - :return: the metadata. - """ - return self._backend_entity.get_metadata() - - @metadata.setter - def metadata(self, value: Dict[str, Any]) -> None: - """Set the computer metadata. - - :param value: the metadata to set. 
- """ - self._backend_entity.set_metadata(value) - - def delete_property(self, name: str, raise_exception: bool = True) -> None: - """ - Delete a property from this computer - - :param name: the name of the property - :param raise_exception: if True raise if the property does not exist, otherwise return None - """ - olddata = self.metadata - try: - del olddata[name] - self.metadata = olddata - except KeyError: - if raise_exception: - raise AttributeError(f"'{name}' property not found") - - def set_property(self, name: str, value: Any) -> None: - """Set a property on this computer - - :param name: the property name - :param value: the new value - """ - metadata = self.metadata or {} - metadata[name] = value - self.metadata = metadata - - def get_property(self, name: str, *args: Any) -> Any: - """Get a property of this computer - - :param name: the property name - :param args: additional arguments - - :return: the property value - """ - if len(args) > 1: - raise TypeError('get_property expected at most 2 arguments') - olddata = self.metadata - try: - return olddata[name] - except KeyError: - if not args: - raise AttributeError(f"'{name}' property not found") - return args[0] - - def get_prepend_text(self) -> str: - return self.get_property('prepend_text', '') - - def set_prepend_text(self, val: str) -> None: - self.set_property('prepend_text', str(val)) - - def get_append_text(self) -> str: - return self.get_property('append_text', '') - - def set_append_text(self, val: str) -> None: - self.set_property('append_text', str(val)) - - def get_use_double_quotes(self) -> bool: - """Return whether the command line parameters of this computer should be escaped with double quotes. - - :returns: True if to escape with double quotes, False otherwise which is also the default. - """ - return self.get_property('use_double_quotes', False) - - def set_use_double_quotes(self, val: bool) -> None: - """Set whether the command line parameters of this computer should be escaped with double quotes. - - :param use_double_quotes: True if to escape with double quotes, False otherwise. - """ - from aiida.common.lang import type_check - type_check(val, bool) - self.set_property('use_double_quotes', val) - - def get_mpirun_command(self) -> List[str]: - """ - Return the mpirun command. Must be a list of strings, that will be - then joined with spaces when submitting. - - I also provide a sensible default that may be ok in many cases. - """ - return self.get_property('mpirun_command', ['mpirun', '-np', '{tot_num_mpiprocs}']) - - def set_mpirun_command(self, val: Union[List[str], Tuple[str, ...]]) -> None: - """ - Set the mpirun command. It must be a list of strings (you can use - string.split() if you have a single, space-separated string). - """ - if not isinstance(val, (tuple, list)) or not all(isinstance(i, str) for i in val): - raise TypeError('the mpirun_command must be a list of strings') - self.set_property('mpirun_command', val) - - def get_default_mpiprocs_per_machine(self) -> Optional[int]: - """ - Return the default number of CPUs per machine (node) for this computer, - or None if it was not set. - """ - return self.get_property('default_mpiprocs_per_machine', None) - - def set_default_mpiprocs_per_machine(self, def_cpus_per_machine: Optional[int]) -> None: - """ - Set the default number of CPUs per machine (node) for this computer. - Accepts None if you do not want to set this value. 
- """ - if def_cpus_per_machine is None: - self.delete_property('default_mpiprocs_per_machine', raise_exception=False) - elif not isinstance(def_cpus_per_machine, int): - raise TypeError('def_cpus_per_machine must be an integer (or None)') - self.set_property('default_mpiprocs_per_machine', def_cpus_per_machine) - - def get_default_memory_per_machine(self) -> Optional[int]: - """ - Return the default amount of memory (kB) per machine (node) for this computer, - or None if it was not set. - """ - return self.get_property('default_memory_per_machine', None) - - def set_default_memory_per_machine(self, def_memory_per_machine: Optional[int]) -> None: - """ - Set the default amount of memory (kB) per machine (node) for this computer. - Accepts None if you do not want to set this value. - """ - self.default_memory_per_machine_validator(def_memory_per_machine) - self.set_property('default_memory_per_machine', def_memory_per_machine) - - def get_minimum_job_poll_interval(self) -> float: - """Get the minimum interval between subsequent requests to poll the scheduler for job status. - - .. note:: If no value was ever set for this computer it will fall back on the default provided by the associated - transport class in the ``DEFAULT_MINIMUM_JOB_POLL_INTERVAL`` attribute. If the computer doesn't have a - transport class, or it cannot be loaded, or it doesn't provide a job poll interval default, then this will - fall back on the ``PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT`` attribute of this class. - - :return: The minimum interval (in seconds). - """ - try: - default = self.get_transport_class().DEFAULT_MINIMUM_JOB_POLL_INTERVAL - except (exceptions.ConfigurationError, AttributeError): - default = self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL__DEFAULT - - return self.get_property(self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, default) - - def set_minimum_job_poll_interval(self, interval: float) -> None: - """ - Set the minimum interval between subsequent requests to update the list - of jobs currently running on this computer. - - :param interval: The minimum interval in seconds - """ - self.set_property(self.PROPERTY_MINIMUM_SCHEDULER_POLL_INTERVAL, interval) - - def get_workdir(self) -> str: - """ - Get the working directory for this computer - :return: The currently configured working directory - """ - return self.get_property(self.PROPERTY_WORKDIR, '/scratch/{username}/aiida_run/') - - def set_workdir(self, val: str) -> None: - self.set_property(self.PROPERTY_WORKDIR, val) - - def get_shebang(self) -> str: - return self.get_property(self.PROPERTY_SHEBANG, '#!/bin/bash') - - def set_shebang(self, val: str) -> None: - """ - :param str val: A valid shebang line - """ - if not isinstance(val, str): - raise ValueError(f'{val} is invalid. Input has to be a string') - if not val.startswith('#!'): - raise ValueError(f'{val} is invalid. A shebang line has to start with #!') - metadata = self.metadata - metadata['shebang'] = val - self.metadata = metadata - - def get_authinfo(self, user: 'User') -> 'AuthInfo': - """ - Return the aiida.orm.authinfo.AuthInfo instance for the - given user on this computer, if the computer - is configured for the given user. - - :param user: a User instance. - :return: a AuthInfo instance - :raise aiida.common.NotExistent: if the computer is not configured for the given - user. - """ - from . 
import authinfos - - try: - authinfo = authinfos.AuthInfo.collection(self.backend).get(dbcomputer_id=self.pk, aiidauser_id=user.pk) - except exceptions.NotExistent as exc: - raise exceptions.NotExistent( - f'Computer `{self.label}` (ID={self.pk}) not configured for user `{user.get_short_name()}` ' - f'(ID={user.pk}) - use `verdi computer configure` first' - ) from exc - - return authinfo - - @property - def is_configured(self) -> bool: - """Return whether the computer is configured for the current default user. - - :return: Boolean, ``True`` if the computer is configured for the current default user, ``False`` otherwise. - """ - return self.is_user_configured(users.User.collection(self.backend).get_default()) - - def is_user_configured(self, user: 'User') -> bool: - """ - Is the user configured on this computer? - - :param user: the user to check - :return: True if configured, False otherwise - """ - try: - self.get_authinfo(user) - return True - except exceptions.NotExistent: - return False - - def is_user_enabled(self, user: 'User') -> bool: - """ - Is the given user enabled to run on this computer? - - :param user: the user to check - :return: True if enabled, False otherwise - """ - try: - authinfo = self.get_authinfo(user) - return authinfo.enabled - except exceptions.NotExistent: - # Return False if the user is not configured (in a sense, it is disabled for that user) - return False - - def get_transport(self, user: Optional['User'] = None) -> 'Transport': - """ - Return a Transport class, configured with all correct parameters. - The Transport is closed (meaning that if you want to run any operation with - it, you have to open it first (i.e., e.g. for a SSH transport, you have - to open a connection). To do this you can call ``transports.open()``, or simply - run within a ``with`` statement:: - - transport = Computer.get_transport() - with transport: - print(transports.whoami()) - - :param user: if None, try to obtain a transport for the default user. - Otherwise, pass a valid User. - - :return: a (closed) Transport, already configured with the connection - parameters to the supercomputer, as configured with ``verdi computer configure`` - for the user specified as a parameter ``user``. - """ - from . import authinfos # pylint: disable=cyclic-import - - user = user or users.User.collection(self.backend).get_default() - authinfo = authinfos.AuthInfo.collection(self.backend).get(dbcomputer=self, aiidauser=user) - return authinfo.get_transport() - - def get_transport_class(self) -> Type['Transport']: - """Get the transport class for this computer. 
Can be used to instantiate a transport instance.""" - try: - return TransportFactory(self.transport_type) - except exceptions.EntryPointError as exception: - raise exceptions.ConfigurationError( - f'No transport found for {self.label} [type {self.transport_type}], message: {exception}' - ) - - def get_scheduler(self) -> 'Scheduler': - """Get a scheduler instance for this computer""" - try: - scheduler_class = SchedulerFactory(self.scheduler_type) - # I call the init without any parameter - return scheduler_class() - except exceptions.EntryPointError as exception: - raise exceptions.ConfigurationError( - f'No scheduler found for {self.label} [type {self.scheduler_type}], message: {exception}' - ) - - def configure(self, user: Optional['User'] = None, **kwargs: Any) -> 'AuthInfo': - """Configure a computer for a user with valid auth params passed via kwargs - - :param user: the user to configure the computer for - :kwargs: the configuration keywords with corresponding values - :return: the authinfo object for the configured user - """ - from . import authinfos - - transport_cls = self.get_transport_class() - user = user or users.User.collection(self.backend).get_default() - valid_keys = set(transport_cls.get_valid_auth_params()) - - if not set(kwargs.keys()).issubset(valid_keys): - invalid_keys = [key for key in kwargs if key not in valid_keys] - raise ValueError(f'{transport_cls}: received invalid authentication parameter(s) "{invalid_keys}"') - - try: - authinfo = self.get_authinfo(user) - except exceptions.NotExistent: - authinfo = authinfos.AuthInfo(self, user) - - auth_params = authinfo.get_auth_params() - - if valid_keys: - auth_params.update(kwargs) - authinfo.set_auth_params(auth_params) - authinfo.store() - - return authinfo - - def get_configuration(self, user: Optional['User'] = None) -> Dict[str, Any]: - """Get the configuration of computer for the given user as a dictionary - - :param user: the user to to get the configuration for, otherwise default user - """ - user = user or users.User.collection(self.backend).get_default() - - try: - authinfo = self.get_authinfo(user) - except exceptions.NotExistent: - return {} - - return authinfo.get_auth_params() diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py deleted file mode 100644 index e959fe3a08..0000000000 --- a/aiida/orm/convert.py +++ /dev/null @@ -1,145 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
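`Computer.configure()` above is the programmatic counterpart of `verdi computer configure`, with the set of accepted keyword arguments defined by the transport plugin's `get_valid_auth_params()`. In the sketch below the computer label and the SSH-style parameters are assumptions:

```python
# Sketch: configure a computer for the default user. The label 'cluster' and
# the SSH-style auth parameters (username, port) are assumptions; the valid
# keys are whatever the transport class reports via get_valid_auth_params().
from aiida import load_profile
from aiida.orm import Computer, User

load_profile()

computer = Computer.collection.get(label='cluster')
user = User.collection.get_default()

print(computer.get_transport_class().get_valid_auth_params())  # keys accepted below

authinfo = computer.configure(user=user, username='jdoe', port=22)
print(computer.is_user_configured(user))   # True once the AuthInfo is stored
print(computer.get_configuration(user))    # the stored authentication parameters
```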
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=cyclic-import -"""Module for converting backend entities into frontend, ORM, entities""" -from collections.abc import Iterator, Mapping, Sized -from functools import singledispatch - -from aiida.orm.entities import from_backend_entity -from aiida.orm.implementation import ( - BackendAuthInfo, - BackendComment, - BackendComputer, - BackendGroup, - BackendLog, - BackendNode, - BackendUser, -) - - -@singledispatch -def get_orm_entity(backend_entity): - raise TypeError(f'No corresponding AiiDA ORM class exists for backend instance {backend_entity.__class__.__name__}') - - -@get_orm_entity.register(Mapping) -def _(backend_entity): - """Convert all values of the given mapping to ORM entities if they are backend ORM instances.""" - converted = {} - - # Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only - # parts of the mapping will be converted. - for key, value in backend_entity.items(): - try: - converted[key] = get_orm_entity(value) - except TypeError: - converted[key] = value - - return converted - - -@get_orm_entity.register(list) -@get_orm_entity.register(tuple) -def _(backend_entity): - """Convert all values of the given list or tuple to ORM entities if they are backend ORM instances. - - Note that we do not register on `collections.abc.Sequence` because that will also match strings. - """ - if hasattr(backend_entity, '_asdict'): - # it is a NamedTuple, so return as is - return backend_entity - - converted = [] - - # Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only - # parts of the mapping will be converted. - for value in backend_entity: - try: - converted.append(get_orm_entity(value)) - except TypeError: - converted.append(value) - - return converted - - -@get_orm_entity.register(BackendGroup) -def _(backend_entity): - from .groups import load_group_class - group_class = load_group_class(backend_entity.type_string) - return from_backend_entity(group_class, backend_entity) - - -@get_orm_entity.register(BackendComputer) -def _(backend_entity): - from . import computers - return from_backend_entity(computers.Computer, backend_entity) - - -@get_orm_entity.register(BackendUser) -def _(backend_entity): - from . import users - return from_backend_entity(users.User, backend_entity) - - -@get_orm_entity.register(BackendAuthInfo) -def _(backend_entity): - from . import authinfos - return from_backend_entity(authinfos.AuthInfo, backend_entity) - - -@get_orm_entity.register(BackendLog) -def _(backend_entity): - from . import logs - return from_backend_entity(logs.Log, backend_entity) - - -@get_orm_entity.register(BackendComment) -def _(backend_entity): - from . 
import comments - return from_backend_entity(comments.Comment, backend_entity) - - -@get_orm_entity.register(BackendNode) -def _(backend_entity): - from .utils.node import load_node_class # pylint: disable=import-error,no-name-in-module - node_class = load_node_class(backend_entity.node_type) - return from_backend_entity(node_class, backend_entity) - - -class ConvertIterator(Iterator, Sized): - """ - Iterator that converts backend entities into frontend ORM entities as needed - - See :func:`aiida.orm.Group.nodes` for an example. - """ - - def __init__(self, backend_iterator): - super().__init__() - self._backend_iterator = backend_iterator - self.generator = self._genfunction() - - def _genfunction(self): - for backend_node in self._backend_iterator: - yield get_orm_entity(backend_node) - - def __iter__(self): - return self - - def __len__(self): - return len(self._backend_iterator) - - def __getitem__(self, value): - if isinstance(value, slice): - return [get_orm_entity(backend_node) for backend_node in self._backend_iterator[value]] - - return get_orm_entity(self._backend_iterator[value]) - - def __next__(self): - return next(self.generator) diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py deleted file mode 100644 index b15feaba3e..0000000000 --- a/aiida/orm/entities.py +++ /dev/null @@ -1,275 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for all common top level AiiDA entity classes and methods""" -import abc -from enum import Enum -from functools import lru_cache -from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast - -from plumpy.base.utils import call_with_super_check, super_check - -from aiida.common.exceptions import InvalidOperation -from aiida.common.lang import classproperty, type_check -from aiida.common.warnings import warn_deprecation -from aiida.manage import get_manager - -if TYPE_CHECKING: - from aiida.orm.implementation import BackendEntity, StorageBackend - from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder - -__all__ = ('Entity', 'Collection', 'EntityTypes') - -CollectionType = TypeVar('CollectionType', bound='Collection') -EntityType = TypeVar('EntityType', bound='Entity') -BackendEntityType = TypeVar('BackendEntityType', bound='BackendEntity') - - -class EntityTypes(Enum): - """Enum for referring to ORM entities in a backend-agnostic manner.""" - AUTHINFO = 'authinfo' - COMMENT = 'comment' - COMPUTER = 'computer' - GROUP = 'group' - LOG = 'log' - NODE = 'node' - USER = 'user' - LINK = 'link' - GROUP_NODE = 'group_node' - - -class Collection(abc.ABC, Generic[EntityType]): - """Container class that represents the collection of objects of a particular entity type.""" - - @staticmethod - @abc.abstractmethod - def _entity_base_cls() -> Type[EntityType]: - """The allowed entity class or subclasses thereof.""" - - @classmethod - @lru_cache(maxsize=100) - def get_cached(cls, entity_class: Type[EntityType], backend: 'StorageBackend'): - """Get the cached collection instance for the given entity class and backend. 
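The conversion layer above dispatches on the backend class via `functools.singledispatch`. The toy sketch below illustrates only that mechanism; it is not AiiDA code:

```python
# Toy illustration of the singledispatch pattern used by get_orm_entity():
# one generic function, with per-type converters registered against it.
from functools import singledispatch


class BackendThing:
    """Stand-in for a backend ORM instance."""

    def __init__(self, value):
        self.value = value


@singledispatch
def to_frontend(entity):
    raise TypeError(f'no converter registered for {entity.__class__.__name__}')


@to_frontend.register(BackendThing)
def _(entity):
    return f'Frontend({entity.value})'


@to_frontend.register(list)
@to_frontend.register(tuple)
def _(entity):
    # containers are converted element-wise, mirroring the Mapping/list handlers above
    return [to_frontend(item) for item in entity]


print(to_frontend(BackendThing(42)))                    # Frontend(42)
print(to_frontend([BackendThing(1), BackendThing(2)]))  # ['Frontend(1)', 'Frontend(2)']
```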
- - :param backend: the backend instance to get the collection for - """ - from aiida.orm.implementation import StorageBackend - type_check(backend, StorageBackend) - return cls(entity_class, backend=backend) - - def __init__(self, entity_class: Type[EntityType], backend: Optional['StorageBackend'] = None) -> None: - """ Construct a new entity collection. - - :param entity_class: the entity type e.g. User, Computer, etc - :param backend: the backend instance to get the collection for, or use the default - """ - from aiida.orm.implementation import StorageBackend - type_check(backend, StorageBackend, allow_none=True) - assert issubclass(entity_class, self._entity_base_cls()) - self._backend = backend or get_manager().get_profile_storage() - self._entity_type = entity_class - - def __call__(self: CollectionType, backend: 'StorageBackend') -> CollectionType: - """Get or create a cached collection using a new backend.""" - if backend is self._backend: - return self - return self.get_cached(self.entity_type, backend=backend) # type: ignore - - @property - def entity_type(self) -> Type[EntityType]: - """The entity type for this instance.""" - return self._entity_type - - @property - def backend(self) -> 'StorageBackend': - """Return the backend.""" - return self._backend - - def query( - self, - filters: Optional['FilterType'] = None, - order_by: Optional['OrderByType'] = None, - limit: Optional[int] = None, - offset: Optional[int] = None - ) -> 'QueryBuilder': - """Get a query builder for the objects of this collection. - - :param filters: the keyword value pair filters to match - :param order_by: a list of (key, direction) pairs specifying the sort order - :param limit: the maximum number of results to return - :param offset: number of initial results to be skipped - """ - from . import querybuilder - - filters = filters or {} - order_by = {self.entity_type: order_by} if order_by else {} - - query = querybuilder.QueryBuilder(backend=self._backend, limit=limit, offset=offset) - query.append(self.entity_type, project='*', filters=filters) - query.order_by([order_by]) - return query - - def get(self, **filters: Any) -> EntityType: - """Get a single collection entry that matches the filter criteria. - - :param filters: the filters identifying the object to get - - :return: the entry - """ - res = self.query(filters=filters) - return res.one()[0] - - def find( - self, - filters: Optional['FilterType'] = None, - order_by: Optional['OrderByType'] = None, - limit: Optional[int] = None - ) -> List[EntityType]: - """Find collection entries matching the filter criteria. - - :param filters: the keyword value pair filters to match - :param order_by: a list of (key, direction) pairs specifying the sort order - :param limit: the maximum number of results to return - - :return: a list of resulting matches - """ - query = self.query(filters=filters, order_by=order_by, limit=limit) - return cast(List[EntityType], query.all(flat=True)) - - def all(self) -> List[EntityType]: - """Get all entities in this collection. - - :return: A list of all entities - """ - return cast(List[EntityType], self.query().all(flat=True)) # pylint: disable=no-member - - def count(self, filters: Optional['FilterType'] = None) -> int: - """Count entities in this collection according to criteria. 
- - :param filters: the keyword value pair filters to match - - :return: The number of entities found using the supplied criteria - """ - return self.query(filters=filters).count() - - -class Entity(abc.ABC, Generic[BackendEntityType, CollectionType]): - """An AiiDA entity""" - - _CLS_COLLECTION: Type[CollectionType] = Collection # type: ignore - - @classproperty - def objects(cls: EntityType) -> CollectionType: # pylint: disable=no-self-argument - """Get a collection for objects of this type, with the default backend. - - .. deprecated:: This will be removed in v3, use ``collection`` instead. - - :return: an object that can be used to access entities of this type - """ - warn_deprecation('`objects` property is deprecated, use `collection` instead.', version=3, stacklevel=4) - return cls.collection - - @classproperty - def collection(cls) -> CollectionType: # pylint: disable=no-self-argument - """Get a collection for objects of this type, with the default backend. - - :return: an object that can be used to access entities of this type - """ - return cls._CLS_COLLECTION.get_cached(cls, get_manager().get_profile_storage()) - - @classmethod - def get(cls, **kwargs): - """Get an entity of the collection matching the given filters. - - .. deprecated: Will be removed in v3, use `Entity.collection.get` instead. - - """ - warn_deprecation( - f'`{cls.__name__}.get` method is deprecated, use `{cls.__name__}.collection.get` instead.', - version=3, - stacklevel=2 - ) - return cls.collection.get(**kwargs) # pylint: disable=no-member - - def __init__(self, backend_entity: BackendEntityType) -> None: - """ - :param backend_entity: the backend model supporting this entity - """ - self._backend_entity = backend_entity - call_with_super_check(self.initialize) - - def __getstate__(self): - """Prevent an ORM entity instance from being pickled.""" - raise InvalidOperation('pickling of AiiDA ORM instances is not supported.') - - @super_check - def initialize(self) -> None: - """Initialize instance attributes. - - This will be called after the constructor is called or an entity is created from an existing backend entity. - """ - - @property - def id(self) -> int: # pylint: disable=invalid-name - """Return the id for this entity. - - This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. - - .. deprecated: Will be removed in v3, use `pk` instead. - - :return: the entity's id - """ - warn_deprecation('`id` property is deprecated, use `pk` instead.', version=3, stacklevel=2) - return self._backend_entity.id - - @property - def pk(self) -> int: - """Return the primary key for this entity. - - This identifier is guaranteed to be unique amongst entities of the same type for a single backend instance. 
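The `Collection` methods above (`query`, `get`, `find`, `count`) are what every `Entity.collection` exposes. A short sketch against the `Computer` collection, with placeholder filter values:

```python
# Sketch: the generic Collection API exposed through `Entity.collection`,
# shown here for Computer. Filter values are placeholders.
from aiida import load_profile
from aiida.orm import Computer

load_profile()

print(Computer.collection.count())                       # number of stored computers
localhost = Computer.collection.get(label='localhost')   # must match exactly one entry
matches = Computer.collection.find(filters={'hostname': 'localhost'}, limit=5)
print(localhost.pk, len(matches))

# query() returns a QueryBuilder, so further clauses can still be appended to it
qb = Computer.collection.query(filters={'description': {'like': '%cluster%'}})
print(qb.count())
```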
- - :return: the entity's principal key - """ - return self._backend_entity.id - - def store(self: EntityType) -> EntityType: - """Store the entity.""" - self._backend_entity.store() - return self - - @property - def is_stored(self) -> bool: - """Return whether the entity is stored.""" - return self._backend_entity.is_stored - - @property - def backend(self) -> 'StorageBackend': - """Get the backend for this entity""" - return self._backend_entity.backend - - @property - def backend_entity(self) -> BackendEntityType: - """Get the implementing class for this object""" - return self._backend_entity - - -def from_backend_entity(cls: Type[EntityType], backend_entity: BackendEntityType) -> EntityType: - """Construct an entity from a backend entity instance - - :param backend_entity: the backend entity - - :return: an AiiDA entity instance - """ - from .implementation.entities import BackendEntity - - type_check(backend_entity, BackendEntity) - entity = cls.__new__(cls) - entity._backend_entity = backend_entity # pylint: disable=protected-access - call_with_super_check(entity.initialize) - return entity diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py deleted file mode 100644 index 7d1c58e1be..0000000000 --- a/aiida/orm/groups.py +++ /dev/null @@ -1,386 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""AiiDA Group entites""" -from abc import ABCMeta -from functools import cached_property -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Sequence, Tuple, Type, TypeVar, Union, cast -import warnings - -from aiida.common import exceptions -from aiida.common.lang import classproperty, type_check -from aiida.common.warnings import warn_deprecation -from aiida.manage import get_manager - -from . import convert, entities, extras, users - -if TYPE_CHECKING: - from aiida.orm import Node, User - from aiida.orm.implementation import BackendGroup, StorageBackend - from aiida.plugins.entry_point import EntryPoint # type: ignore - -__all__ = ('Group', 'AutoGroup', 'ImportGroup', 'UpfFamily') - -SelfType = TypeVar('SelfType', bound='Group') - - -def load_group_class(type_string: str) -> Type['Group']: - """Load the sub class of `Group` that corresponds to the given `type_string`. - - .. note:: will fall back on `aiida.orm.groups.Group` if `type_string` cannot be resolved to loadable entry point. - - :param type_string: the entry point name of the `Group` sub class - :return: sub class of `Group` registered through an entry point - """ - from aiida.common.exceptions import EntryPointError - from aiida.plugins.entry_point import load_entry_point - - try: - group_class = load_entry_point('aiida.groups', type_string) - except EntryPointError: - message = f'could not load entry point `{type_string}`, falling back onto `Group` base class.' 
- warnings.warn(message) # pylint: disable=no-member - group_class = Group - - return group_class - - -class GroupMeta(ABCMeta): - """Meta class for `aiida.orm.groups.Group` to automatically set the `type_string` attribute.""" - - def __new__(mcs, name, bases, namespace, **kwargs): - from aiida.plugins.entry_point import get_entry_point_from_class - - newcls = ABCMeta.__new__(mcs, name, bases, namespace, **kwargs) # pylint: disable=too-many-function-args - - mod = namespace['__module__'] - entry_point_group, entry_point = get_entry_point_from_class(mod, name) - - if entry_point_group is None or entry_point_group != 'aiida.groups': - newcls._type_string = None # type: ignore[attr-defined] - message = f'no registered entry point for `{mod}:{name}` so its instances will not be storable.' - warnings.warn(message) # pylint: disable=no-member - else: - assert entry_point is not None - newcls._type_string = entry_point.name # type: ignore[attr-defined] # pylint: disable=protected-access - - return newcls - - -class GroupCollection(entities.Collection['Group']): - """Collection of Groups""" - - @staticmethod - def _entity_base_cls() -> Type['Group']: - return Group - - def get_or_create(self, label: Optional[str] = None, **kwargs) -> Tuple['Group', bool]: - """ - Try to retrieve a group from the DB with the given arguments; - create (and store) a new group if such a group was not present yet. - - :param label: group label - - :return: (group, created) where group is the group (new or existing, - in any case already stored) and created is a boolean saying - """ - if not label: - raise ValueError('Group label must be provided') - - res = self.find(filters={'label': label}) - - if not res: - return self.entity_type(label, backend=self.backend, **kwargs).store(), True - - if len(res) > 1: - raise exceptions.MultipleObjectsError('More than one groups found in the database') - - return res[0], False - - def delete(self, pk: int) -> None: - """ - Delete a group - - :param pk: the id of the group to delete - """ - self._backend.groups.delete(pk) - - -class GroupBase: - """A namespace for group related functionality, that is not directly related to its user-facing properties.""" - - def __init__(self, group: 'Group') -> None: - """Construct a new instance of the base namespace.""" - self._group: 'Group' = group - - @cached_property - def extras(self) -> extras.EntityExtras: - """Return the extras of this group.""" - return extras.EntityExtras(self._group) - - -class Group(entities.Entity['BackendGroup', GroupCollection], metaclass=GroupMeta): - """An AiiDA ORM implementation of group of nodes.""" - - # added by metaclass - _type_string: ClassVar[Optional[str]] - - _CLS_COLLECTION = GroupCollection - - def __init__( - self, - label: Optional[str] = None, - user: Optional['User'] = None, - description: str = '', - type_string: Optional[str] = None, - backend: Optional['StorageBackend'] = None - ): - """ - Create a new group. Either pass a dbgroup parameter, to reload - a group from the DB (and then, no further parameters are allowed), - or pass the parameters for the Group creation. - - :param label: The group label, required on creation - :param description: The group description (by default, an empty string) - :param user: The owner of the group (by default, the automatic user) - :param type_string: a string identifying the type of group (by default, - an empty string, indicating an user-defined group. 
- """ - if not label: - raise ValueError('Group label must be provided') - - backend = backend or get_manager().get_profile_storage() - user = cast(users.User, user or backend.default_user) - type_check(user, users.User) - type_string = self._type_string - - model = backend.groups.create( - label=label, user=user.backend_entity, description=description, type_string=type_string - ) - super().__init__(model) - - @cached_property - def base(self) -> GroupBase: - """Return the group base namespace.""" - return GroupBase(self) - - def __repr__(self) -> str: - return ( - f'<{self.__class__.__name__}: {self.label!r} ' - f'[{"type " + self.type_string if self.type_string else "user-defined"}], of user {self.user.email}>' - ) - - def __str__(self) -> str: - return f'{self.__class__.__name__}<{self.label}>' - - def store(self: SelfType) -> SelfType: - """Verify that the group is allowed to be stored, which is the case along as `type_string` is set.""" - if self._type_string is None: - raise exceptions.StoringNotAllowed('`type_string` is `None` so the group cannot be stored.') - - return super().store() - - @classproperty - def entry_point(cls) -> Optional['EntryPoint']: - """Return the entry point associated this group type. - - :return: the associated entry point or ``None`` if it isn't known. - """ - from aiida.plugins.entry_point import get_entry_point_from_class - return get_entry_point_from_class(cls.__module__, cls.__name__)[1] - - @property - def uuid(self) -> str: - """Return the UUID for this group. - - This identifier is unique across all entities types and backend instances. - - :return: the entity uuid - """ - return self._backend_entity.uuid - - @property - def label(self) -> str: - """ - :return: the label of the group as a string - """ - return self._backend_entity.label - - @label.setter - def label(self, label: str) -> None: - """ - Attempt to change the label of the group instance. If the group is already stored - and the another group of the same type already exists with the desired label, a - UniquenessError will be raised - - :param label: the new group label - :type label: str - - :raises aiida.common.UniquenessError: if another group of same type and label already exists - """ - self._backend_entity.label = label - - @property - def description(self) -> str: - """ - :return: the description of the group as a string - """ - return self._backend_entity.description or '' - - @description.setter - def description(self, description: str) -> None: - """ - :param description: the description of the group as a string - """ - self._backend_entity.description = description - - @property - def type_string(self) -> str: - """ - :return: the string defining the type of the group - """ - return self._backend_entity.type_string - - @property - def user(self) -> 'User': - """ - :return: the user associated with this group - """ - return entities.from_backend_entity(users.User, self._backend_entity.user) - - @user.setter - def user(self, user: 'User') -> None: - """Set the user. - - :param user: the user - """ - type_check(user, users.User) - self._backend_entity.user = user.backend_entity - - def count(self) -> int: - """Return the number of entities in this group. 
- - :return: integer number of entities contained within the group - """ - return self._backend_entity.count() - - @property - def nodes(self) -> convert.ConvertIterator: - """ - Return a generator/iterator that iterates over all nodes and returns - the respective AiiDA subclasses of Node, and also allows to ask for - the number of nodes in the group using len(). - """ - return convert.ConvertIterator(self._backend_entity.nodes) - - @property - def is_empty(self) -> bool: - """Return whether the group is empty, i.e. it does not contain any nodes. - - :return: True if it contains no nodes, False otherwise - """ - try: - self.nodes[0] - except IndexError: - return True - return False - - def clear(self) -> None: - """Remove all the nodes from this group.""" - return self._backend_entity.clear() - - def add_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: - """Add a node or a set of nodes to the group. - - :note: all the nodes *and* the group itself have to be stored. - - :param nodes: a single `Node` or a list of `Nodes` - """ - from .nodes import Node - - if not self.is_stored: - raise exceptions.ModificationNotAllowed('cannot add nodes to an unstored group') - - # Cannot use `collections.Iterable` here, because that would also match iterable `Node` sub classes like `List` - if not isinstance(nodes, (list, tuple)): - nodes = [nodes] # type: ignore - - for node in nodes: - type_check(node, Node) - - self._backend_entity.add_nodes([node.backend_entity for node in nodes]) - - def remove_nodes(self, nodes: Union['Node', Sequence['Node']]) -> None: - """Remove a node or a set of nodes to the group. - - :note: all the nodes *and* the group itself have to be stored. - - :param nodes: a single `Node` or a list of `Nodes` - """ - from .nodes import Node - - if not self.is_stored: - raise exceptions.ModificationNotAllowed('cannot add nodes to an unstored group') - - # Cannot use `collections.Iterable` here, because that would also match iterable `Node` sub classes like `List` - if not isinstance(nodes, (list, tuple)): - nodes = [nodes] # type: ignore - - for node in nodes: - type_check(node, Node) - - self._backend_entity.remove_nodes([node.backend_entity for node in nodes]) - - def is_user_defined(self) -> bool: - """ - :return: True if the group is user defined, False otherwise - """ - return not self.type_string - - _deprecated_extra_methods = { - 'extras': 'all', - 'get_extra': 'get', - 'get_extra_many': 'get_many', - 'set_extra': 'set', - 'set_extra_many': 'set_many', - 'reset_extras': 'reset', - 'delete_extra': 'delete', - 'delete_extra_many': 'delete_many', - 'clear_extras': 'clear', - 'extras_items': 'items', - 'extras_keys': 'keys', - } - - def __getattr__(self, name: str) -> Any: - """ - This method is called when an extras is not found in the instance. - - It allows for the handling of deprecated mixin methods. 
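Taken together, the group methods above (creation, `add_nodes`, counting, and the deprecated extras shims handled by `__getattr__`) map onto the following front-end usage. A minimal sketch, assuming a configured profile is loaded; the labels and values are illustrative:

```python
from aiida import load_profile, orm

load_profile()  # assumes a default AiiDA profile has been configured

# Retrieve the group if it exists, otherwise create and store it.
group, created = orm.Group.collection.get_or_create(label='my-calculations')

# Both the node and the group must be stored before nodes can be added.
node = orm.Int(5).store()
group.add_nodes(node)
print(group.count(), group.is_empty)  # e.g. 1 False

# Extras live in the ``base.extras`` namespace; ``group.set_extra(...)`` is a
# deprecated shim that forwards here via ``__getattr__``.
group.base.extras.set('project', 'demo')
print(group.base.extras.get('project'))
```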
- """ - if name in self._deprecated_extra_methods: - new_name = self._deprecated_extra_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.extras.{new_name}` instead.', version=3, stacklevel=3 - ) - return getattr(self.base.extras, new_name) - - raise AttributeError(name) - - -class AutoGroup(Group): - """Group to be used to contain selected nodes generated, whilst autogrouping is enabled.""" - - -class ImportGroup(Group): - """Group to be used to contain all nodes from an export archive that has been imported.""" - - -class UpfFamily(Group): - """Group that represents a pseudo potential family containing `UpfData` nodes.""" diff --git a/aiida/orm/implementation/__init__.py b/aiida/orm/implementation/__init__.py deleted file mode 100644 index 0f02fcbf65..0000000000 --- a/aiida/orm/implementation/__init__.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module containing the backend entity abstracts for storage backends.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .authinfos import * -from .comments import * -from .computers import * -from .entities import * -from .groups import * -from .logs import * -from .nodes import * -from .querybuilder import * -from .storage_backend import * -from .users import * -from .utils import * - -__all__ = ( - 'BackendAuthInfo', - 'BackendAuthInfoCollection', - 'BackendCollection', - 'BackendComment', - 'BackendCommentCollection', - 'BackendComputer', - 'BackendComputerCollection', - 'BackendEntity', - 'BackendEntityExtrasMixin', - 'BackendGroup', - 'BackendGroupCollection', - 'BackendLog', - 'BackendLogCollection', - 'BackendNode', - 'BackendNodeCollection', - 'BackendQueryBuilder', - 'BackendUser', - 'BackendUserCollection', - 'EntityType', - 'StorageBackend', - 'clean_value', - 'validate_attribute_extra_key', -) - -# yapf: enable diff --git a/aiida/orm/implementation/authinfos.py b/aiida/orm/implementation/authinfos.py deleted file mode 100644 index 0294de6203..0000000000 --- a/aiida/orm/implementation/authinfos.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for the backend implementation of the `AuthInfo` ORM class.""" -import abc -from typing import TYPE_CHECKING, Any, Dict - -from .entities import BackendCollection, BackendEntity - -if TYPE_CHECKING: - from .computers import BackendComputer - from .users import BackendUser - -__all__ = ('BackendAuthInfo', 'BackendAuthInfoCollection') - - -class BackendAuthInfo(BackendEntity): - """Backend implementation for the `AuthInfo` ORM class. 
- - An authinfo is a set of credentials that can be used to authenticate to a remote computer. - """ - - METADATA_WORKDIR = 'workdir' - - @property - @abc.abstractmethod - def enabled(self) -> bool: - """Return whether this instance is enabled. - - :return: boolean, True if enabled, False otherwise - """ - - @enabled.setter - @abc.abstractmethod - def enabled(self, value: bool) -> None: - """Set the enabled state - - :param enabled: boolean, True to enable the instance, False to disable it - """ - - @property - @abc.abstractmethod - def computer(self) -> 'BackendComputer': - """Return the computer associated with this instance.""" - - @property - @abc.abstractmethod - def user(self) -> 'BackendUser': - """Return the user associated with this instance.""" - - @abc.abstractmethod - def get_auth_params(self) -> Dict[str, Any]: - """Return the dictionary of authentication parameters - - :return: a dictionary with authentication parameters - """ - - @abc.abstractmethod - def set_auth_params(self, auth_params: Dict[str, Any]) -> None: - """Set the dictionary of authentication parameters - - :param auth_params: a dictionary with authentication parameters - """ - - @abc.abstractmethod - def get_metadata(self) -> Dict[str, Any]: - """Return the dictionary of metadata - - :return: a dictionary with metadata - """ - - @abc.abstractmethod - def set_metadata(self, metadata: Dict[str, Any]) -> None: - """Set the dictionary of metadata - - :param metadata: a dictionary with metadata - """ - - -class BackendAuthInfoCollection(BackendCollection[BackendAuthInfo]): - """The collection of backend `AuthInfo` entries.""" - - ENTITY_CLASS = BackendAuthInfo - - @abc.abstractmethod - def delete(self, pk: int) -> None: - """Delete an entry from the collection. - - :param pk: the pk of the entry to delete - """ diff --git a/aiida/orm/implementation/comments.py b/aiida/orm/implementation/comments.py deleted file mode 100644 index b44d1932d1..0000000000 --- a/aiida/orm/implementation/comments.py +++ /dev/null @@ -1,120 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for comment backend classes.""" -import abc -from datetime import datetime -from typing import TYPE_CHECKING, List, Optional - -from .entities import BackendCollection, BackendEntity - -if TYPE_CHECKING: - from .nodes import BackendNode - from .users import BackendUser - -__all__ = ('BackendComment', 'BackendCommentCollection') - - -class BackendComment(BackendEntity): - """Backend implementation for the `Comment` ORM class. - - A comment is a text that can be attached to a node. 
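The `BackendAuthInfo` interface above is what the front-end `AuthInfo` class delegates to. A hedged sketch of inspecting stored credentials from user code, assuming a loaded profile and an already configured computer labelled `my-cluster` (label and values are illustrative):

```python
from aiida import load_profile, orm
from aiida.manage import get_manager

load_profile()

computer = orm.load_computer('my-cluster')               # assumes this computer exists
user = get_manager().get_profile_storage().default_user  # the profile's default user

# The AuthInfo holds the credentials for this (user, computer) pair.
authinfo = computer.get_authinfo(user)

print(authinfo.enabled)
print(authinfo.get_auth_params())              # e.g. SSH username, port, ...
print(authinfo.get_metadata().get('workdir'))  # the METADATA_WORKDIR key above
```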
- """ - - @property - @abc.abstractmethod - def uuid(self) -> str: - """Return the UUID of the comment.""" - - @property - @abc.abstractmethod - def ctime(self) -> datetime: - """Return the creation time of the comment.""" - - @property - @abc.abstractmethod - def mtime(self) -> datetime: - """Return the modified time of the comment.""" - - @abc.abstractmethod - def set_mtime(self, value: datetime) -> None: - """Set the modified time of the comment.""" - - @property - @abc.abstractmethod - def node(self) -> 'BackendNode': - """Return the comment's node.""" - - @property - @abc.abstractmethod - def user(self) -> 'BackendUser': - """Return the comment owner.""" - - @abc.abstractmethod - def set_user(self, value: 'BackendUser') -> None: - """Set the comment owner.""" - - @property - @abc.abstractmethod - def content(self) -> str: - """Return the comment content.""" - - @abc.abstractmethod - def set_content(self, value: str): - """Set the comment content.""" - - -class BackendCommentCollection(BackendCollection[BackendComment]): - """The collection of Comment entries.""" - - ENTITY_CLASS = BackendComment - - @abc.abstractmethod - def create( # type: ignore[override] # pylint: disable=arguments-differ - self, node: 'BackendNode', user: 'BackendUser', content: Optional[str] = None, **kwargs): - """ - Create a Comment for a given node and user - - :param node: a Node instance - :param user: a User instance - :param content: the comment content - :return: a Comment object associated to the given node and user - """ - - @abc.abstractmethod - def delete(self, comment_id: int) -> None: - """ - Remove a Comment from the collection with the given id - - :param comment_id: the id of the comment to delete - - :raises TypeError: if ``comment_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Comment with ID ``comment_id`` is not found - """ - - @abc.abstractmethod - def delete_all(self) -> None: - """ - Delete all Comment entries. - - :raises `~aiida.common.exceptions.IntegrityError`: if all Comments could not be deleted - """ - - @abc.abstractmethod - def delete_many(self, filters: dict) -> List[int]: - """ - Delete Comments based on ``filters`` - - :param filters: similar to QueryBuilder filter - - :return: (former) ``PK`` s of deleted Comments - - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ diff --git a/aiida/orm/implementation/computers.py b/aiida/orm/implementation/computers.py deleted file mode 100644 index 804ce24011..0000000000 --- a/aiida/orm/implementation/computers.py +++ /dev/null @@ -1,109 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Backend specific computer objects and methods""" -import abc -import logging -from typing import Any, Dict - -from .entities import BackendCollection, BackendEntity - -__all__ = ('BackendComputer', 'BackendComputerCollection') - - -class BackendComputer(BackendEntity): - """Backend implementation for the `Computer` ORM class. 
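On the front end, the comment interface above is reached through a node's `base.comments` namespace. A hedged sketch, assuming a loaded profile; the method names follow the 2.x front-end API as I understand it:

```python
from aiida import load_profile, orm

load_profile()

node = orm.Int(1).store()

# Attach a comment; the author defaults to the profile's default user.
comment = node.base.comments.add('Converged on the second try.')

for entry in node.base.comments.all():
    print(entry.ctime, entry.user.email, entry.content)

# Comments can be removed again by pk.
node.base.comments.remove(comment.pk)
```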
- - A computer is a resource that can be used to run calculations: - It has an associated transport_type, which points to a plugin for connecting to the resource and passing data, - and a scheduler_type, which points to a plugin for scheduling calculations. - """ - # pylint: disable=too-many-public-methods - - _logger = logging.getLogger(__name__) - - @property - @abc.abstractmethod - def uuid(self) -> str: - """Return the UUID of the computer.""" - - @property - @abc.abstractmethod - def label(self) -> str: - """Return the (unique) label of the computer.""" - - @abc.abstractmethod - def set_label(self, val: str): - """Set the (unique) label of the computer.""" - - @property - @abc.abstractmethod - def description(self) -> str: - """Return the description of the computer.""" - - @abc.abstractmethod - def set_description(self, val: str): - """Set the description of the computer.""" - - @property - @abc.abstractmethod - def hostname(self) -> str: - """Return the hostname of the computer (used to associate the connected device).""" - - @abc.abstractmethod - def set_hostname(self, val: str) -> None: - """ - Set the hostname of this computer - :param val: The new hostname - """ - - @abc.abstractmethod - def get_metadata(self) -> Dict[str, Any]: - """Return the metadata for the computer.""" - - @abc.abstractmethod - def set_metadata(self, metadata: Dict[str, Any]) -> None: - """Set the metadata for the computer.""" - - @abc.abstractmethod - def get_scheduler_type(self) -> str: - """Return the scheduler plugin type.""" - - @abc.abstractmethod - def set_scheduler_type(self, scheduler_type: str) -> None: - """Set the scheduler plugin type.""" - - @abc.abstractmethod - def get_transport_type(self) -> str: - """Return the transport plugin type.""" - - @abc.abstractmethod - def set_transport_type(self, transport_type: str) -> None: - """Set the transport plugin type.""" - - @abc.abstractmethod - def copy(self) -> 'BackendComputer': - """Create an un-stored clone of an already stored `Computer`. - - :raises: ``InvalidOperation`` if the computer is not stored. - """ - - -class BackendComputerCollection(BackendCollection[BackendComputer]): - """The collection of Computer entries.""" - - ENTITY_CLASS = BackendComputer - - @abc.abstractmethod - def delete(self, pk: int) -> None: - """ - Delete an entry with the given pk - - :param pk: the pk of the entry to delete - """ diff --git a/aiida/orm/implementation/entities.py b/aiida/orm/implementation/entities.py deleted file mode 100644 index 41f8e8b988..0000000000 --- a/aiida/orm/implementation/entities.py +++ /dev/null @@ -1,203 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
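For completeness, a minimal sketch of creating a `Computer` that satisfies the interface above from the front end, assuming a loaded profile; the label, hostname and plugin names are illustrative:

```python
from aiida import load_profile, orm

load_profile()

computer = orm.Computer(
    label='my-cluster',
    hostname='cluster.example.com',
    description='Example SLURM cluster',
    transport_type='core.ssh',    # transport plugin, cf. get/set_transport_type above
    scheduler_type='core.slurm',  # scheduler plugin, cf. get/set_scheduler_type above
).store()

computer.set_workdir('/scratch/{username}/aiida_run')
print(computer.uuid, computer.label)
```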
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Classes and methods for backend non-specific entities""" -import abc -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Generic, Iterable, List, Tuple, Type, TypeVar - -if TYPE_CHECKING: - from aiida.orm.implementation import StorageBackend - -__all__ = ('BackendEntity', 'BackendCollection', 'EntityType', 'BackendEntityExtrasMixin') - -EntityType = TypeVar('EntityType', bound='BackendEntity') # pylint: disable=invalid-name - - -class BackendEntity(abc.ABC): - """An first-class entity in the backend""" - - def __init__(self, backend: 'StorageBackend', **kwargs: Any): # pylint: disable=unused-argument - self._backend = backend - - @property - def backend(self) -> 'StorageBackend': - """Return the backend this entity belongs to - - :return: the backend instance - """ - return self._backend - - @property - @abc.abstractmethod - def id(self) -> int: # pylint: disable=invalid-name - """Return the id for this entity. - - This is unique only amongst entities of this type for a particular backend. - - :return: the entity id - """ - - @property - def pk(self) -> int: - """Return the id for this entity. - - This is unique only amongst entities of this type for a particular backend. - - :return: the entity id - """ - return self.id - - @abc.abstractmethod - def store(self: EntityType) -> EntityType: - """Store this entity in the backend. - - Whether it is possible to call store more than once is delegated to the object itself - """ - - @property - @abc.abstractmethod - def is_stored(self) -> bool: - """Return whether the entity is stored. - - :return: True if stored, False otherwise - """ - - -class BackendCollection(Generic[EntityType]): - """Container class that represents a collection of entries of a particular backend entity.""" - - ENTITY_CLASS: ClassVar[Type[EntityType]] # type: ignore[misc] - - def __init__(self, backend: 'StorageBackend'): - """ - :param backend: the backend this collection belongs to - """ - assert issubclass(self.ENTITY_CLASS, BackendEntity), 'Must set the ENTRY_CLASS class variable to an entity type' - self._backend = backend - - @property - def backend(self) -> 'StorageBackend': - """Return the backend.""" - return self._backend - - def create(self, **kwargs: Any) -> EntityType: - """ - Create new a entry and set the attributes to those specified in the keyword arguments - - :return: the newly created entry of type ENTITY_CLASS - """ - return self.ENTITY_CLASS(backend=self._backend, **kwargs) - - -class BackendEntityExtrasMixin(abc.ABC): - """Mixin class that adds all abstract methods for the extras column to a backend entity""" - - @property - @abc.abstractmethod - def extras(self) -> Dict[str, Any]: - """Return the complete extras dictionary. - - .. warning:: While the entity is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - extras will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. 
If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. - - :return: the extras as a dictionary - """ - - @abc.abstractmethod - def get_extra(self, key: str) -> Any: - """Return the value of an extra. - - .. warning:: While the entity is unstored, this will return a reference of the extra on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - extra will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. - - :param key: name of the extra - :return: the value of the extra - :raises AttributeError: if the extra does not exist - """ - - def get_extra_many(self, keys: Iterable[str]) -> List[Any]: - """Return the values of multiple extras. - - .. warning:: While the entity is unstored, this will return references of the extras on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - extras will be a deep copy and mutations of the database extras will have to go through the appropriate set - methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you only need the keys - or some values, use the iterators `extras_keys` and `extras_items`, or the getters `get_extra` and - `get_extra_many` instead. - - :param keys: a list of extra names - :return: a list of extra values - :raises AttributeError: if at least one extra does not exist - """ - return [self.get_extra(key) for key in keys] - - @abc.abstractmethod - def set_extra(self, key: str, value: Any) -> None: - """Set an extra to the given value. - - :param key: name of the extra - :param value: value of the extra - """ - - def set_extra_many(self, extras: Dict[str, Any]) -> None: - """Set multiple extras. - - .. note:: This will override any existing extras that are present in the new dictionary. - - :param extras: a dictionary with the extras to set - """ - for key, value in extras.items(): - self.set_extra(key, value) - - @abc.abstractmethod - def reset_extras(self, extras: Dict[str, Any]) -> None: - """Reset the extras. - - .. note:: This will completely clear any existing extras and replace them with the new dictionary. - - :param extras: a dictionary with the extras to set - """ - - @abc.abstractmethod - def delete_extra(self, key: str) -> None: - """Delete an extra. - - :param key: name of the extra - :raises AttributeError: if the extra does not exist - """ - - def delete_extra_many(self, keys: Iterable[str]) -> None: - """Delete multiple extras. 
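The extras mixin above is shared by nodes and groups; on the front end the same operations are exposed through the `base.extras` namespace. A small sketch, assuming a loaded profile, which also illustrates the deep-copy caveat from the docstrings:

```python
from aiida import load_profile, orm

load_profile()

node = orm.Dict({'energy': -1.23}).store()

node.base.extras.set('tags', ['test', 'demo'])
node.base.extras.set_many({'project': 'demo', 'priority': 2})

print(node.base.extras.get('project'))
print(dict(node.base.extras.items()))

# Once stored, the returned extras are deep copies: in-place edits are not
# persisted, so changes must go back through ``set``/``set_many``.
tags = node.base.extras.get('tags')
tags.append('production')
node.base.extras.set('tags', tags)  # this call is what updates the database
```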
- - :param keys: names of the extras to delete - :raises AttributeError: if at least one of the extra does not exist - """ - for key in keys: - self.delete_extra(key) - - @abc.abstractmethod - def clear_extras(self) -> None: - """Delete all extras.""" - - @abc.abstractmethod - def extras_items(self) -> Iterable[Tuple[str, Any]]: - """Return an iterator over the extras key/value pairs.""" - - @abc.abstractmethod - def extras_keys(self) -> Iterable[str]: - """Return an iterator over the extra keys.""" diff --git a/aiida/orm/implementation/groups.py b/aiida/orm/implementation/groups.py deleted file mode 100644 index 87b33a8679..0000000000 --- a/aiida/orm/implementation/groups.py +++ /dev/null @@ -1,165 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Backend group module""" -import abc -from typing import TYPE_CHECKING, List, Optional, Protocol, Sequence, Union - -from .entities import BackendCollection, BackendEntity, BackendEntityExtrasMixin -from .nodes import BackendNode - -if TYPE_CHECKING: - from .users import BackendUser - -__all__ = ('BackendGroup', 'BackendGroupCollection') - - -class NodeIterator(Protocol): - """Protocol for iterating over nodes in a group""" - - def __iter__(self) -> 'NodeIterator': # pylint: disable=non-iterator-returned - """Return an iterator over the nodes in the group.""" - - def __next__(self) -> BackendNode: - """Return the next node in the group.""" - - def __getitem__(self, value: Union[int, slice]) -> Union[BackendNode, List[BackendNode]]: - """Index node(s) from the group.""" - - def __len__(self) -> int: # pylint: disable=invalid-length-returned - """Return the number of nodes in the group.""" - - -class BackendGroup(BackendEntity, BackendEntityExtrasMixin): - """Backend implementation for the `Group` ORM class. - - A group is a collection of nodes. - """ - - @property - @abc.abstractmethod - def label(self) -> str: - """Return the name of the group as a string.""" - - @label.setter - @abc.abstractmethod - def label(self, name: str) -> None: - """ - Attempt to change the name of the group instance. 
If the group is already stored - and the another group of the same type already exists with the desired name, a - UniquenessError will be raised - - :param name: the new group name - :raises aiida.common.UniquenessError: if another group of same type and name already exists - """ - - @property - @abc.abstractmethod - def description(self) -> Optional[str]: - """Return the description of the group as a string.""" - - @description.setter - @abc.abstractmethod - def description(self, value: Optional[str]): - """Return the description of the group as a string.""" - - @property - @abc.abstractmethod - def type_string(self) -> str: - """Return the string defining the type of the group.""" - - @property - @abc.abstractmethod - def user(self) -> 'BackendUser': - """Return a backend user object, representing the user associated to this group.""" - - @user.setter - @abc.abstractmethod - def user(self, user: 'BackendUser') -> None: - """Set the user of this group.""" - - @property - @abc.abstractmethod - def uuid(self) -> str: - """Return the UUID of the group.""" - - @property - @abc.abstractmethod - def nodes(self) -> NodeIterator: - """ - Return a generator/iterator that iterates over all nodes and returns - the respective AiiDA subclasses of Node, and also allows to ask for - the number of nodes in the group using len(). - """ - - @abc.abstractmethod - def count(self) -> int: - """Return the number of entities in this group. - - :return: integer number of entities contained within the group - """ - - @abc.abstractmethod - def clear(self) -> None: - """Remove all the nodes from this group.""" - - def add_nodes(self, nodes: Sequence[BackendNode], **kwargs): # pylint: disable=unused-argument - """Add a set of nodes to the group. - - :note: all the nodes *and* the group itself have to be stored. - - :param nodes: a list of `BackendNode` instances to be added to this group - """ - if not self.is_stored: - raise ValueError('group has to be stored before nodes can be added') - - if not isinstance(nodes, (list, tuple)): - raise TypeError('nodes has to be a list or tuple') - - if any(not isinstance(node, BackendNode) for node in nodes): - raise TypeError(f'nodes have to be of type {BackendNode}') - - def remove_nodes(self, nodes: Sequence[BackendNode]) -> None: - """Remove a set of nodes from the group. - - :note: all the nodes *and* the group itself have to be stored. 
- - :param nodes: a list of `BackendNode` instances to be removed from this group - """ - if not self.is_stored: - raise ValueError('group has to be stored before nodes can be removed') - - if not isinstance(nodes, (list, tuple)): - raise TypeError('nodes has to be a list or tuple') - - if any(not isinstance(node, BackendNode) for node in nodes): - raise TypeError(f'nodes have to be of type {BackendNode}') - - def __repr__(self) -> str: - return f'<{self.__class__.__name__}: {str(self)}>' - - def __str__(self) -> str: - if self.type_string: - return f'"{self.label}" [type {self.type_string}], of user {self.user.email}' - - return f'"{self.label}" [user-defined], of user {self.user.email}' - - -class BackendGroupCollection(BackendCollection[BackendGroup]): - """The collection of Group entries.""" - - ENTITY_CLASS = BackendGroup - - @abc.abstractmethod - def delete(self, id: int) -> None: # pylint: disable=redefined-builtin, invalid-name - """ - Delete a group with the given id - - :param id: the id of the group to delete - """ diff --git a/aiida/orm/implementation/logs.py b/aiida/orm/implementation/logs.py deleted file mode 100644 index 1cb3fec884..0000000000 --- a/aiida/orm/implementation/logs.py +++ /dev/null @@ -1,97 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Backend group module""" -import abc -from datetime import datetime -from typing import Any, Dict, List - -from .entities import BackendCollection, BackendEntity - -__all__ = ('BackendLog', 'BackendLogCollection') - - -class BackendLog(BackendEntity): - """Backend implementation for the `Log` ORM class. - - A log is a record of logging call for a particular node. - """ - - @property - @abc.abstractmethod - def uuid(self) -> str: - """Return the UUID of the log entry.""" - - @property - @abc.abstractmethod - def time(self) -> datetime: - """Return the time corresponding to the log entry.""" - - @property - @abc.abstractmethod - def loggername(self) -> str: - """Return the name of the logger that created this entry.""" - - @property - @abc.abstractmethod - def levelname(self) -> str: - """Return the name of the log level.""" - - @property - @abc.abstractmethod - def dbnode_id(self) -> int: - """Return the id of the object that created the log entry.""" - - @property - @abc.abstractmethod - def message(self) -> str: - """Return the message corresponding to the log entry.""" - - @property - @abc.abstractmethod - def metadata(self) -> Dict[str, Any]: - """Return the metadata corresponding to the log entry.""" - - -class BackendLogCollection(BackendCollection[BackendLog]): - """The collection of Log entries.""" - - ENTITY_CLASS = BackendLog - - @abc.abstractmethod - def delete(self, log_id: int) -> None: - """ - Remove a Log entry from the collection with the given id - - :param log_id: id of the Log to delete - - :raises TypeError: if ``log_id`` is not an `int` - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``log_id`` is not found - """ - - @abc.abstractmethod - def delete_all(self) -> None: - """ - Delete all Log entries. 
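Logs are mostly consumed through queries rather than created by hand. A sketch of retrieving the log entries attached to a process node, assuming a loaded profile; pk 123 is illustrative, and `with_node` is the relationship declared for logs in the query-builder mapping further below:

```python
from aiida import load_profile, orm

load_profile()

qb = orm.QueryBuilder()
qb.append(orm.Node, filters={'id': 123}, tag='proc')  # pk 123 is illustrative
qb.append(orm.Log, with_node='proc', project=['time', 'levelname', 'message'])

for time, levelname, message in qb.all():
    print(time, levelname, message)
```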
- - :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted - """ - - @abc.abstractmethod - def delete_many(self, filters: dict) -> List[int]: - """ - Delete Logs based on ``filters`` - - :param filters: similar to QueryBuilder filter - - :return: (former) ``PK`` s of deleted Logs - - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ diff --git a/aiida/orm/implementation/nodes.py b/aiida/orm/implementation/nodes.py deleted file mode 100644 index dee1d94f20..0000000000 --- a/aiida/orm/implementation/nodes.py +++ /dev/null @@ -1,339 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Abstract BackendNode and BackendNodeCollection implementation.""" -import abc -from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar - -from .entities import BackendCollection, BackendEntity, BackendEntityExtrasMixin - -if TYPE_CHECKING: - from ..utils import LinkTriple - from .computers import BackendComputer - from .users import BackendUser - -__all__ = ('BackendNode', 'BackendNodeCollection') - -BackendNodeType = TypeVar('BackendNodeType', bound='BackendNode') - - -class BackendNode(BackendEntity, BackendEntityExtrasMixin, metaclass=abc.ABCMeta): - """Backend implementation for the `Node` ORM class. - - A node stores data input or output from a computation. - """ - - # pylint: disable=too-many-public-methods - - @abc.abstractmethod - def clone(self: BackendNodeType) -> BackendNodeType: - """Return an unstored clone of ourselves. - - :return: an unstored `BackendNode` with the exact same attributes and extras as self - """ - - @property - @abc.abstractmethod - def uuid(self) -> str: - """Return the node UUID. - - :return: the string representation of the UUID - """ - - @property - @abc.abstractmethod - def node_type(self) -> str: - """Return the node type. - - :return: the node type - """ - - @property - @abc.abstractmethod - def process_type(self) -> Optional[str]: - """Return the node process type. - - :return: the process type - """ - - @process_type.setter - @abc.abstractmethod - def process_type(self, value: Optional[str]) -> None: - """Set the process type. - - :param value: the new value to set - """ - - @property - @abc.abstractmethod - def label(self) -> str: - """Return the node label. - - :return: the label - """ - - @label.setter - @abc.abstractmethod - def label(self, value: str) -> None: - """Set the label. - - :param value: the new value to set - """ - - @property - @abc.abstractmethod - def description(self) -> str: - """Return the node description. - - :return: the description - """ - - @description.setter - @abc.abstractmethod - def description(self, value: str) -> None: - """Set the description. - - :param value: the new value to set - """ - - @property - @abc.abstractmethod - def repository_metadata(self) -> Dict[str, Any]: - """Return the node repository metadata. 
- - :return: the repository metadata - """ - - @repository_metadata.setter - @abc.abstractmethod - def repository_metadata(self, value: Dict[str, Any]) -> None: - """Set the repository metadata. - - :param value: the new value to set - """ - - @property - @abc.abstractmethod - def computer(self) -> Optional['BackendComputer']: - """Return the computer of this node. - - :return: the computer or None - """ - - @computer.setter - @abc.abstractmethod - def computer(self, computer: Optional['BackendComputer']) -> None: - """Set the computer of this node. - - :param computer: a `BackendComputer` - """ - - @property - @abc.abstractmethod - def user(self) -> 'BackendUser': - """Return the user of this node. - - :return: the user - """ - - @user.setter - @abc.abstractmethod - def user(self, user: 'BackendUser') -> None: - """Set the user of this node. - - :param user: a `BackendUser` - """ - - @property - @abc.abstractmethod - def ctime(self) -> datetime: - """Return the node ctime. - - :return: the ctime - """ - - @property - @abc.abstractmethod - def mtime(self) -> datetime: - """Return the node mtime. - - :return: the mtime - """ - - @abc.abstractmethod - def add_incoming(self, source: 'BackendNode', link_type, link_label): - """Add a link of the given type from a given node to ourself. - - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :return: True if the proposed link is allowed, False otherwise - :raise TypeError: if `source` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - :raise aiida.common.ModificationNotAllowed: if either source or target node is not stored - """ - - @abc.abstractmethod - def store( # pylint: disable=arguments-differ - self: BackendNodeType, - links: Optional[Sequence['LinkTriple']] = None, - with_transaction: bool = True, - clean: bool = True - ) -> BackendNodeType: - """Store the node in the database. - - :param links: optional links to add before storing - :param with_transaction: if False, do not use a transaction because the caller will already have opened one. - :param clean: boolean, if True, will clean the attributes and extras before attempting to store - """ - - @abc.abstractmethod - def clean_values(self): - """Clean the values of the node fields. - - This method is called before storing the node. - The purpose of this method is to convert data to a type which can be serialized and deserialized - for storage in the DB without its value changing. - """ - - # attributes methods - - @property - @abc.abstractmethod - def attributes(self) -> Dict[str, Any]: - """Return the complete attributes dictionary. - - .. warning:: While the entity is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. 
- - :return: the attributes as a dictionary - """ - - @abc.abstractmethod - def get_attribute(self, key: str) -> Any: - """Return the value of an attribute. - - .. warning:: While the entity is unstored, this will return a reference of the attribute on the database model, - meaning that changes on the returned value (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attribute will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. - - :param key: name of the attribute - :return: the value of the attribute - :raises AttributeError: if the attribute does not exist - """ - - def get_attribute_many(self, keys: Iterable[str]) -> List[Any]: - """Return the values of multiple attributes. - - .. warning:: While the entity is unstored, this will return references of the attributes on the database model, - meaning that changes on the returned values (if they are mutable themselves, e.g. a list or dictionary) will - automatically be reflected on the database model as well. As soon as the entity is stored, the returned - attributes will be a deep copy and mutations of the database attributes will have to go through the - appropriate set methods. Therefore, once stored, retrieving a deep copy can be a heavy operation. If you - only need the keys or some values, use the iterators `attributes_keys` and `attributes_items`, or the - getters `get_attribute` and `get_attribute_many` instead. - - :param keys: a list of attribute names - :return: a list of attribute values - :raises AttributeError: if at least one attribute does not exist - """ - try: - return [self.get_attribute(key) for key in keys] - except KeyError as exception: - raise AttributeError(f'attribute `{exception}` does not exist') from exception - - @abc.abstractmethod - def set_attribute(self, key: str, value: Any) -> None: - """Set an attribute to the given value. - - :param key: name of the attribute - :param value: value of the attribute - """ - - def set_attribute_many(self, attributes: Dict[str, Any]) -> None: - """Set multiple attributes. - - .. note:: This will override any existing attributes that are present in the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - for key, value in attributes.items(): - self.set_attribute(key, value) - - @abc.abstractmethod - def reset_attributes(self, attributes: Dict[str, Any]) -> None: - """Reset the attributes. - - .. note:: This will completely clear any existing attributes and replace them with the new dictionary. - - :param attributes: a dictionary with the attributes to set - """ - - @abc.abstractmethod - def delete_attribute(self, key: str) -> None: - """Delete an attribute. - - :param key: name of the attribute - :raises AttributeError: if the attribute does not exist - """ - - def delete_attribute_many(self, keys: Iterable[str]) -> None: - """Delete multiple attributes. - - :param keys: names of the attributes to delete - :raises AttributeError: if at least one of the attribute does not exist - """ - for key in keys: - self.delete_attribute(key) - - @abc.abstractmethod - def clear_attributes(self): - """Delete all attributes.""" - - @abc.abstractmethod - def attributes_items(self) -> Iterable[Tuple[str, Any]]: - """Return an iterator over the attributes. 
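On the front end these attribute methods are reached through `Node.base.attributes`; unlike extras, attributes become immutable once the node is stored. A minimal sketch under that assumption:

```python
from aiida import load_profile, orm

load_profile()

node = orm.Data()  # a plain, unstored Data node

# While the node is unstored, attributes may be set freely.
node.base.attributes.set('max_iterations', 100)
node.base.attributes.set_many({'tolerance': 1e-6, 'method': 'cg'})

print(node.base.attributes.get('method'))
print(dict(node.base.attributes.items()))

node.store()
# After storing, further ``set`` calls raise ``ModificationNotAllowed``.
```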
- - :return: an iterator with attribute key value pairs - """ - - @abc.abstractmethod - def attributes_keys(self) -> Iterable[str]: - """Return an iterator over the attribute keys. - - :return: an iterator with attribute keys - """ - - -class BackendNodeCollection(BackendCollection[BackendNode]): - """The collection of `BackendNode` entries.""" - - ENTITY_CLASS = BackendNode - - @abc.abstractmethod - def get(self, pk: int): - """Return a Node entry from the collection with the given id - - :param pk: id of the node - """ - - @abc.abstractmethod - def delete(self, pk: int) -> None: - """Remove a Node entry from the collection with the given id - - :param pk: id of the node to delete - """ diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py deleted file mode 100644 index 55e649aac3..0000000000 --- a/aiida/orm/implementation/querybuilder.py +++ /dev/null @@ -1,151 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Abstract `QueryBuilder` definition.""" -import abc -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional, Set, TypedDict, Union - -from aiida.common.lang import type_check -from aiida.common.log import AIIDA_LOGGER -from aiida.orm.entities import EntityTypes - -if TYPE_CHECKING: - from aiida.orm.implementation import StorageBackend - -__all__ = ('BackendQueryBuilder',) - -QUERYBUILD_LOGGER = AIIDA_LOGGER.getChild('orm.querybuilder') - -EntityRelationships: Dict[str, Set[str]] = { - EntityTypes.AUTHINFO.value: {'with_computer', 'with_user'}, - EntityTypes.COMMENT.value: {'with_node', 'with_user'}, - EntityTypes.COMPUTER.value: {'with_node'}, - EntityTypes.GROUP.value: {'with_node', 'with_user'}, - EntityTypes.LOG.value: {'with_node'}, - EntityTypes.NODE.value: { - 'with_comment', 'with_log', 'with_incoming', 'with_outgoing', 'with_descendants', 'with_ancestors', - 'with_computer', 'with_user', 'with_group' - }, - EntityTypes.USER.value: {'with_authinfo', 'with_comment', 'with_group', 'with_node'}, - EntityTypes.LINK.value: set(), -} - - -class PathItemType(TypedDict): - """An item on the query path""" - - entity_type: Union[str, List[str]] - # this can be derived from the entity_type, but it is more efficient to store - orm_base: Literal['node', 'group', 'authinfo', 'comment', 'computer', 'log', 'user'] - tag: str - joining_keyword: str - joining_value: str - outerjoin: bool - edge_tag: str - - -class QueryDictType(TypedDict): - """A JSON serialisable representation of a ``QueryBuilder`` instance""" - - path: List[PathItemType] - # mapping: tag -> 'and' | 'or' | '~or' | '~and' | '!and' | '!or' -> [] -> operator -> value - # -> operator -> value - filters: Dict[str, Dict[str, Union[Dict[str, List[Dict[str, Any]]], Dict[str, Any]]]] - # mapping: tag -> [] -> field -> 'func' -> 'max' | 'min' | 'count' - # 'cast' -> 'b' | 'd' | 'f' | 'i' | 'j' | 't' - project: Dict[str, List[Dict[str, Dict[str, Any]]]] - # list of mappings: tag -> [] -> field -> 'order' -> 'asc' | 'desc' - # 'cast' -> 'b' | 'd' | 'f' | 'i' | 'j' | 't' - order_by: List[Dict[str, List[Dict[str, Dict[str, 
str]]]]] - offset: Optional[int] - limit: Optional[int] - distinct: bool - - -# This global variable is necessary to enable the subclassing functionality for the `Group` entity. The current -# implementation of the `QueryBuilder` was written with the assumption that only `Node` was subclassable. Support for -# subclassing was added later for `Group` and is based on its `type_string`, but the current implementation does not -# allow to extend this support to the `QueryBuilder` in an elegant way. The prefix `group.` needs to be used in various -# places to make it work, but really the internals of the `QueryBuilder` should be rewritten to in principle support -# subclassing for any entity type. This workaround should then be able to be removed. -GROUP_ENTITY_TYPE_PREFIX = 'group.' - - -class BackendQueryBuilder(abc.ABC): - """Backend query builder interface""" - - def __init__(self, backend: 'StorageBackend'): - """ - :param backend: the backend - """ - from .storage_backend import StorageBackend - type_check(backend, StorageBackend) - self._backend = backend - - @abc.abstractmethod - def count(self, data: QueryDictType) -> int: - """Return the number of results of the query""" - - @abc.abstractmethod - def first(self, data: QueryDictType) -> Optional[List[Any]]: - """Executes query, asking for one instance. - - :returns: One row of aiida results - """ - - @abc.abstractmethod - def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[List[Any]]: - """Return an iterator over all the results of a list of lists.""" - - @abc.abstractmethod - def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Dict[str, Dict[str, Any]]]: - """Return an iterator over all the results of a list of dictionaries.""" - - def as_sql(self, data: QueryDictType, inline: bool = False) -> str: - """Convert the query to an SQL string representation. - - .. warning:: - - This method should be used for debugging purposes only, - since normally sqlalchemy will handle this process internally. - - :params inline: Inline bound parameters (this is normally handled by the Python DBAPI). - """ - raise NotImplementedError - - def analyze_query(self, data: QueryDictType, execute: bool = True, verbose: bool = False) -> str: - """Return the query plan, i.e. a list of SQL statements that will be executed. - - See: https://www.postgresql.org/docs/11/sql-explain.html - - :params execute: Carry out the command and show actual run times and other statistics. - :params verbose: Display additional information regarding the plan. - """ - raise NotImplementedError - - @abc.abstractmethod - def get_creation_statistics(self, user_pk: Optional[int] = None) -> Dict[str, Any]: - """Return a dictionary with the statistics of node creation, summarized by day. - - :note: Days when no nodes were created are not present in the returned `ctime_by_day` dictionary. - - :param user_pk: If None (default), return statistics for all users. - If user pk is specified, return only the statistics for the given user. - - :return: a dictionary as follows:: - - { - "total": TOTAL_NUM_OF_NODES, - "types": {TYPESTRING1: count, TYPESTRING2: count, ...}, - "ctime_by_day": {'YYYY-MMM-DD': count, ...} - } - - where in `ctime_by_day` the key is a string in the format 'YYYY-MM-DD' and the value is - an integer with the number of nodes created that day. 
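The typed dictionaries above describe the JSON-serialisable form to which a front-end `QueryBuilder` is reduced before being handed to the backend. A sketch of what that looks like from user code, assuming a loaded profile:

```python
from aiida import load_profile, orm

load_profile()

qb = orm.QueryBuilder()
qb.append(orm.Group, filters={'label': {'like': 'my-%'}}, tag='g', project=['label'])
qb.append(orm.Node, with_group='g', project=['uuid'])
qb.limit(10)

# ``as_dict`` returns the serialisable representation (path, filters, project,
# order_by, limit, ...) matching the QueryDictType layout above.
query_dict = qb.as_dict()
print(query_dict['path'])
print(query_dict['limit'])

print(qb.count())
for group_label, node_uuid in qb.all():
    print(group_label, node_uuid)
```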
- """ diff --git a/aiida/orm/implementation/storage_backend.py b/aiida/orm/implementation/storage_backend.py deleted file mode 100644 index bc99001ffb..0000000000 --- a/aiida/orm/implementation/storage_backend.py +++ /dev/null @@ -1,345 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Generic backend related objects""" -import abc -from typing import TYPE_CHECKING, Any, ContextManager, List, Optional, Sequence, TypeVar, Union - -if TYPE_CHECKING: - from aiida.manage.configuration.profile import Profile - from aiida.orm.autogroup import AutogroupManager - from aiida.orm.entities import EntityTypes - from aiida.orm.implementation import ( - BackendAuthInfoCollection, - BackendCommentCollection, - BackendComputerCollection, - BackendGroupCollection, - BackendLogCollection, - BackendNodeCollection, - BackendQueryBuilder, - BackendUserCollection, - ) - from aiida.orm.users import User - from aiida.repository.backend.abstract import AbstractRepositoryBackend - -__all__ = ('StorageBackend',) - -TransactionType = TypeVar('TransactionType') # pylint: disable=invalid-name - - -class StorageBackend(abc.ABC): # pylint: disable=too-many-public-methods - """Abstraction for a backend to read/write persistent data for a profile's provenance graph. - - AiiDA splits data storage into two sources: - - - Searchable data, which is stored in the database and can be queried using the QueryBuilder - - Non-searchable (binary) data, which is stored in the repository and can be loaded using the RepositoryBackend - - The two sources are inter-linked by the ``Node.base.repository.metadata``. - Once stored, the leaf values of this dictionary must be valid pointers to object keys in the repository. - - For a completely new storage, the ``initialise`` method should be called first. This will automatically initialise - the repository and the database with the current schema. The class methods,`version_profile` and `migrate` should be - able to be called for existing storage, at any supported schema version. But an instance of this class should be - created only for the latest schema version. - """ - - @classmethod - @abc.abstractmethod - def version_head(cls) -> str: - """Return the head schema version of this storage backend type.""" - - @classmethod - @abc.abstractmethod - def version_profile(cls, profile: 'Profile') -> Optional[str]: - """Return the schema version of the given profile's storage, or None for empty/uninitialised storage. - - :raises: `~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed - """ - - @classmethod - @abc.abstractmethod - def initialise(cls, profile: 'Profile', reset: bool = False) -> bool: - """Initialise the storage backend. - - This is typically used once when a new storage backed is created. If this method returns without exceptions the - storage backend is ready for use. If the backend already seems initialised, this method is a no-op. - - :param reset: If ``true``, destroy the backend if it already exists including all of its data before recreating - and initialising it. 
This is useful for example for test profiles that need to be reset before or after - tests having run. - :returns: ``True`` if the storage was initialised by the function call, ``False`` if it was already initialised. - """ - - @classmethod - @abc.abstractmethod - def migrate(cls, profile: 'Profile') -> None: - """Migrate the storage of a profile to the latest schema version. - - If the schema version is already the latest version, this method does nothing. If the storage is uninitialised, - this method will raise an exception. - - :raises: :class`~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed. - :raises: :class:`~aiida.common.exceptions.StorageMigrationError` if the storage is not initialised. - """ - - @abc.abstractmethod - def __init__(self, profile: 'Profile') -> None: - """Initialize the backend, for this profile. - - :raises: `~aiida.common.exceptions.UnreachableStorage` if the storage cannot be accessed - :raises: `~aiida.common.exceptions.IncompatibleStorageSchema` - if the profile's storage schema is not at the latest version (and thus should be migrated) - :raises: :raises: :class:`aiida.common.exceptions.CorruptStorage` if the storage is internally inconsistent - """ - from aiida.orm.autogroup import AutogroupManager - self._profile = profile - self._default_user: Optional['User'] = None - self._autogroup = AutogroupManager(self) - - @abc.abstractmethod - def __str__(self) -> str: - """Return a string showing connection details for this instance.""" - - @property - def profile(self) -> 'Profile': - """Return the profile for this backend.""" - return self._profile - - @property - def autogroup(self) -> 'AutogroupManager': - """Return the autogroup manager for this backend.""" - return self._autogroup - - def version(self) -> str: - """Return the schema version of the profile's storage.""" - version = self.version_profile(self.profile) - assert version is not None - return version - - @abc.abstractmethod - def close(self): - """Close the storage access.""" - - @property - @abc.abstractmethod - def is_closed(self) -> bool: - """Return whether the storage is closed.""" - - @abc.abstractmethod - def _clear(self) -> None: - """Clear the storage, removing all data. - - .. warning:: This is a destructive operation, and should only be used for testing purposes. - """ - from aiida.orm.autogroup import AutogroupManager - self._autogroup = AutogroupManager(self) - self._default_user = None - - @property - @abc.abstractmethod - def authinfos(self) -> 'BackendAuthInfoCollection': - """Return the collection of authorisation information objects""" - - @property - @abc.abstractmethod - def comments(self) -> 'BackendCommentCollection': - """Return the collection of comments""" - - @property - @abc.abstractmethod - def computers(self) -> 'BackendComputerCollection': - """Return the collection of computers""" - - @property - @abc.abstractmethod - def groups(self) -> 'BackendGroupCollection': - """Return the collection of groups""" - - @property - @abc.abstractmethod - def logs(self) -> 'BackendLogCollection': - """Return the collection of logs""" - - @property - @abc.abstractmethod - def nodes(self) -> 'BackendNodeCollection': - """Return the collection of nodes""" - - @property - @abc.abstractmethod - def users(self) -> 'BackendUserCollection': - """Return the collection of users""" - - @property - def default_user(self) -> Optional['User']: - """Return the default user for the profile, if it has been created. 
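From user code, the storage backend of the loaded profile is obtained through the manager, exactly as the `Group` constructor earlier in this diff does. A minimal sketch, assuming a loaded profile:

```python
from aiida import load_profile
from aiida.manage import get_manager

load_profile()

storage = get_manager().get_profile_storage()

print(storage)             # connection details, per ``__str__`` above
print(storage.version())   # current schema version of the profile's storage
print(storage.default_user)

# The entity collections declared above are exposed as properties.
print(storage.users, storage.nodes, storage.groups)
```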
- - This is cached, since it is a frequently used operation, for creating other entities. - """ - from aiida.orm import QueryBuilder, User - - if self._default_user is None and self.profile.default_user_email: - query = QueryBuilder(self).append(User, filters={'email': self.profile.default_user_email}) - self._default_user = query.first(flat=True) - return self._default_user - - @abc.abstractmethod - def query(self) -> 'BackendQueryBuilder': - """Return an instance of a query builder implementation for this backend""" - - @abc.abstractmethod - def transaction(self) -> ContextManager[Any]: - """ - Get a context manager that can be used as a transaction context for a series of backend operations. - If there is an exception within the context then the changes will be rolled back and the state will - be as before entering. Transactions can be nested. - - :return: a context manager to group database operations - """ - - @property - @abc.abstractmethod - def in_transaction(self) -> bool: - """Return whether a transaction is currently active.""" - - @abc.abstractmethod - def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]: - """Insert a list of entities into the database, directly into a backend transaction. - - :param entity_type: The type of the entity - :param data: A list of dictionaries, containing all fields of the backend model, - except the `id` field (a.k.a primary key), which will be generated dynamically - :param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)), - otherwise, allow default values for missing fields. - - :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table - - :returns: The list of generated primary keys for the entities - """ - - @abc.abstractmethod - def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None: - """Update a list of entities in the database, directly with a backend transaction. - - :param entity_type: The type of the entity - :param data: A list of dictionaries, containing fields of the backend model to update, - and the `id` field (a.k.a primary key) - - :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table - """ - - @abc.abstractmethod - def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): - """Delete all nodes corresponding to pks in the input and any links to/from them. - - This method is intended to be used within a transaction context. - - :param pks_to_delete: a sequence of node pks to delete - - :raises: ``AssertionError`` if a transaction is not active - """ - - @abc.abstractmethod - def get_repository(self) -> 'AbstractRepositoryBackend': - """Return the object repository configured for this backend.""" - - @abc.abstractmethod - def set_global_variable( - self, key: str, value: Union[None, str, int, float], description: Optional[str] = None, overwrite=True - ) -> None: - """Set a global variable in the storage. - - :param key: the key of the setting - :param value: the value of the setting - :param description: the description of the setting (optional) - :param overwrite: if True, overwrite the setting if it already exists - - :raises: `ValueError` if the key already exists and `overwrite` is False - """ - - @abc.abstractmethod - def get_global_variable(self, key: str) -> Union[None, str, int, float]: - """Return a global variable from the storage. 
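A hedged sketch of the transactional and key-value APIs above used directly; this is backend-level code that user scripts rarely need, and it assumes a loaded profile (the setting key is illustrative):

```python
from aiida import load_profile, orm
from aiida.manage import get_manager

load_profile()
storage = get_manager().get_profile_storage()

# Group several writes into one transaction; if anything raises inside the
# context, all changes are rolled back.
with storage.transaction():
    group = orm.Group(label='tx-demo').store()
    group.add_nodes(orm.Int(1).store())

# Simple global key-value settings stored alongside the schema.
storage.set_global_variable('my_plugin.hint', 'v1', description='illustrative key')
print(storage.get_global_variable('my_plugin.hint'))
```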
- - :param key: the key of the setting - - :raises: `KeyError` if the setting does not exist - """ - - @abc.abstractmethod - def maintain(self, full: bool = False, dry_run: bool = False, **kwargs) -> None: - """Perform maintenance tasks on the storage. - - If `full == True`, then this method may attempt to block the profile associated with the - storage to guarantee the safety of its procedures. This will not only prevent any other - subsequent process from accessing that profile, but will also first check if there is - already any process using it and raise if that is the case. The user will have to manually - stop any processes that is currently accessing the profile themselves or wait for it to - finish on its own. - - :param full: flag to perform operations that require to stop using the profile to be maintained. - :param dry_run: flag to only print the actions that would be taken without actually executing them. - """ - - def get_info(self, detailed: bool = False) -> dict: - """Return general information on the storage. - - :param detailed: flag to request more detailed information about the content of the storage. - :returns: a nested dict with the relevant information. - """ - return {'entities': self.get_orm_entities(detailed=detailed)} - - def get_orm_entities(self, detailed: bool = False) -> dict: - """Return a mapping with an overview of the storage contents regarding ORM entities. - - :param detailed: flag to request more detailed information about the content of the storage. - :returns: a nested dict with the relevant information. - """ - from aiida.orm import Comment, Computer, Group, Log, Node, QueryBuilder, User - - data = {} - - query_user = QueryBuilder(self).append(User, project=['email']) - data['Users'] = {'count': query_user.count()} - if detailed: - data['Users']['emails'] = sorted({email for email, in query_user.iterall() if email is not None}) - - query_comp = QueryBuilder(self).append(Computer, project=['label']) - data['Computers'] = {'count': query_comp.count()} - if detailed: - data['Computers']['labels'] = sorted({comp for comp, in query_comp.iterall() if comp is not None}) - - count = QueryBuilder(self).append(Node).count() - data['Nodes'] = {'count': count} - if detailed: - node_types = sorted({ - typ for typ, in QueryBuilder(self).append(Node, project=['node_type']).iterall() if typ is not None - }) - data['Nodes']['node_types'] = node_types - process_types = sorted({ - typ for typ, in QueryBuilder(self).append(Node, project=['process_type']).iterall() if typ is not None - }) - data['Nodes']['process_types'] = [p for p in process_types if p] - - query_group = QueryBuilder(self).append(Group, project=['type_string']) - data['Groups'] = {'count': query_group.count()} - if detailed: - data['Groups']['type_strings'] = sorted({typ for typ, in query_group.iterall() if typ is not None}) - - count = QueryBuilder(self).append(Comment).count() - data['Comments'] = {'count': count} - - count = QueryBuilder(self).append(Log).count() - data['Logs'] = {'count': count} - - count = QueryBuilder(self).append(entity_type='link').count() - data['Links'] = {'count': count} - - return data diff --git a/aiida/orm/implementation/users.py b/aiida/orm/implementation/users.py deleted file mode 100644 index c67d13e805..0000000000 --- a/aiida/orm/implementation/users.py +++ /dev/null @@ -1,99 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. 
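A short sketch of the global-variable and introspection calls defined above (illustrative only: the key name is hypothetical and a configured profile is assumed).

```python
from aiida.manage import get_manager

storage = get_manager().get_profile_storage()

# Profile-wide key/value settings stored alongside the data.
storage.set_global_variable('example_setting', 42, description='hypothetical example key')
assert storage.get_global_variable('example_setting') == 42

# High-level overview of the storage contents, built from get_orm_entities().
info = storage.get_info(detailed=True)
print(info['entities']['Nodes']['count'], info['entities']['Users']['emails'])
```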
# -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Backend user""" -import abc - -from .entities import BackendCollection, BackendEntity - -__all__ = ('BackendUser', 'BackendUserCollection') - - -class BackendUser(BackendEntity): - """Backend implementation for the `User` ORM class. - - A user can be assigned as the creator of a variety of other entities. - """ - - @property - @abc.abstractmethod - def email(self) -> str: - """ - Get the email address of the user - - :return: the email address - """ - - @email.setter - @abc.abstractmethod - def email(self, val: str) -> None: - """ - Set the email address of the user - - :param val: the new email address - """ - - @property - @abc.abstractmethod - def first_name(self) -> str: - """ - Get the user's first name - - :return: the first name - """ - - @first_name.setter - @abc.abstractmethod - def first_name(self, val: str) -> None: - """ - Set the user's first name - - :param val: the new first name - """ - - @property - @abc.abstractmethod - def last_name(self) -> str: - """ - Get the user's last name - - :return: the last name - """ - - @last_name.setter - @abc.abstractmethod - def last_name(self, val: str) -> None: - """ - Set the user's last name - - :param val: the new last name - """ - - @property - @abc.abstractmethod - def institution(self) -> str: - """ - Get the user's institution - - :return: the institution - """ - - @institution.setter - @abc.abstractmethod - def institution(self, val: str) -> None: - """ - Set the user's institution - - :param val: the new institution - """ - - -class BackendUserCollection(BackendCollection[BackendUser]): - - ENTITY_CLASS = BackendUser diff --git a/aiida/orm/implementation/utils.py b/aiida/orm/implementation/utils.py deleted file mode 100644 index 76791336c2..0000000000 --- a/aiida/orm/implementation/utils.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility methods for backend non-specific implementations.""" -from collections.abc import Iterable, Mapping -from decimal import Decimal -import math -import numbers - -from aiida.common import exceptions -from aiida.common.constants import AIIDA_FLOAT_PRECISION - -# This separator character is reserved to indicate nested fields in node attribute and extras dictionaries and -# therefore is not allowed in individual attribute or extra keys. -FIELD_SEPARATOR = '.' - -__all__ = ('validate_attribute_extra_key', 'clean_value') - - -def validate_attribute_extra_key(key): - """Validate the key for an entity attribute or extra. 
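The abstract `BackendUser` removed above is what the public `orm.User` class delegates to; a minimal front-end sketch (the e-mail and institution are made-up values, and a loaded profile is assumed):

```python
from aiida import load_profile, orm

load_profile()

user = orm.User(email='alice@example.com', first_name='Alice', institution='EPFL').store()
print(user.email, user.first_name, user.institution)
```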
- - :raise aiida.common.ValidationError: if the key is not a string or contains reserved separator character - """ - if not key or not isinstance(key, str): - raise exceptions.ValidationError('key for attributes or extras should be a string') - - if FIELD_SEPARATOR in key: - raise exceptions.ValidationError( - f'key for attributes or extras cannot contain the character `{FIELD_SEPARATOR}`' - ) - - -def clean_value(value): - """ - Get value from input and (recursively) replace, if needed, all occurrences - of BaseType AiiDA data nodes with their value, and List with a standard list. - It also makes a deep copy of everything - The purpose of this function is to convert data to a type which can be serialized and deserialized - for storage in the DB without its value changing. - - Note however that there is no logic to avoid infinite loops when the - user passes some perverse recursive dictionary or list. - In any case, however, this would not be storable by AiiDA... - - :param value: A value to be set as an attribute or an extra - :return: a "cleaned" value, potentially identical to value, but with - values replaced where needed. - """ - # Must be imported in here to avoid recursive imports - from aiida.orm import BaseType - - def clean_builtin(val): - """ - A function to clean build-in python values (`BaseType`). - - It mainly checks that we don't store NaN or Inf. - """ - # This is a whitelist of all the things we understand currently - if val is None or isinstance(val, (bool, str)): - return val - - # This fixes #2773 - in python3, ``numpy.int64(-1)`` cannot be json-serialized - # Note that `numbers.Integral` also match booleans but they are already returned above - if isinstance(val, numbers.Integral): - return int(val) - - if isinstance(val, numbers.Real) and (math.isnan(val) or math.isinf(val)): - # see https://www.postgresql.org/docs/current/static/datatype-json.html#JSON-TYPE-MAPPING-TABLE - raise exceptions.ValidationError('nan and inf/-inf can not be serialized to the database') - - # This is for float-like types, like ``numpy.float128`` that are not json-serializable - # Note that `numbers.Real` also match booleans but they are already returned above - if isinstance(val, (numbers.Real, Decimal)): - string_representation = f'{{:.{AIIDA_FLOAT_PRECISION}g}}'.format(val) - new_val = float(string_representation) - if 'e' in string_representation and new_val.is_integer(): - # This is indeed often quite unexpected, because it is going to change the type of the data - # from float to int. But anyway clean_value is changing some types, and we are also bound to what - # our current backends do. - # Currently, in both Django and SQLA (with JSONB attributes), if we store 1.e1, ..., 1.e14, 1.e15, - # they will be stored as floats; instead 1.e16, 1.e17, ... will all be stored as integer anyway, - # even if we don't run this clean_value step. - # So, for consistency, it's better if we do the conversion ourselves here, and we do it for a bit - # smaller numbers than python+[SQL+JSONB] would do (the AiiDA float precision is here 14), so the - # results are consistent, and the hashing will work also after a round trip as expected. 
- return int(new_val) - return new_val - - # Anything else we do not understand and we refuse - raise exceptions.ValidationError(f'type `{type(val)}` is not supported as it is not json-serializable') - - if isinstance(value, BaseType): - return clean_builtin(value.value) - - if isinstance(value, Mapping): - # Check dictionary before iterables - return {k: clean_value(v) for k, v in value.items()} - - if (isinstance(value, Iterable) and not isinstance(value, str)): - # list, tuple, ... but not a string - # This should also properly take care of dealing with the - # basedatatypes.List object - return [clean_value(v) for v in value] - - # If I don't know what to do I just return the value - # itself - it's not super robust, but relies on duck typing - # (e.g. if there is something that behaves like an integer - # but is not an integer, I still accept it) - - return clean_builtin(value) diff --git a/aiida/orm/logs.py b/aiida/orm/logs.py deleted file mode 100644 index 88cc6612cc..0000000000 --- a/aiida/orm/logs.py +++ /dev/null @@ -1,235 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for orm logging abstract classes""" -from datetime import datetime -import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type - -from aiida.common import timezone -from aiida.common.lang import classproperty -from aiida.manage import get_manager - -from . import entities - -if TYPE_CHECKING: - from aiida.orm import Node - from aiida.orm.implementation import BackendLog, StorageBackend - from aiida.orm.querybuilder import FilterType, OrderByType - -__all__ = ('Log', 'OrderSpecifier', 'ASCENDING', 'DESCENDING') - -ASCENDING = 'asc' -DESCENDING = 'desc' - - -def OrderSpecifier(field, direction): # pylint: disable=invalid-name - return {field: direction} - - -class LogCollection(entities.Collection['Log']): - """ - This class represents the collection of logs and can be used to create - and retrieve logs. 
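To make the behaviour of the `clean_value` helper removed above concrete, a small illustrative snippet (the import path is the module's location before this removal):

```python
import numpy as np

from aiida.common import exceptions
from aiida.orm.implementation.utils import clean_value  # location before this removal

assert clean_value(np.int64(-1)) == -1        # numpy integers become plain, JSON-serializable ints
assert isinstance(clean_value(np.float64(1.5)), float)
assert clean_value({'nested': (1, 2)}) == {'nested': [1, 2]}  # iterables become lists

try:
    clean_value(float('nan'))                  # NaN and +/-inf are rejected
except exceptions.ValidationError:
    print('nan is not storable')
```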
- """ - - @staticmethod - def _entity_base_cls() -> Type['Log']: - return Log - - def create_entry_from_record(self, record: logging.LogRecord) -> Optional['Log']: - """Helper function to create a log entry from a record created as by the python logging library - - :param record: The record created by the logging module - :return: A stored log instance - """ - dbnode_id = record.__dict__.get('dbnode_id', None) - - # Do not store if dbnode_id is not set - if dbnode_id is None: - return None - - metadata = dict(record.__dict__) - - # If an `exc_info` is present, the log message was an exception, so format the full traceback - try: - import traceback - exc_info = metadata.pop('exc_info') - message = ''.join(traceback.format_exception(*exc_info)) - except (TypeError, KeyError): - message = record.getMessage() - - # Stringify the content of `args` if they exist in the metadata to ensure serializability - for key in ['args']: - if key in metadata: - metadata[key] = str(metadata[key]) - - return Log( - time=timezone.make_aware(datetime.fromtimestamp(record.created)), - loggername=record.name, - levelname=record.levelname, - dbnode_id=dbnode_id, - message=message, - metadata=metadata, - backend=self.backend - ) - - def get_logs_for(self, entity: 'Node', order_by: Optional['OrderByType'] = None) -> List['Log']: - """Get all the log messages for a given node and optionally sort - - :param entity: the entity to get logs for - :param order_by: a list of (key, direction) pairs specifying the sort order - - :return: the list of log entries - """ - from . import nodes - - if not isinstance(entity, nodes.Node): - raise Exception('Only node logs are stored') # pylint: disable=broad-exception-raised - - return self.find({'dbnode_id': entity.pk}, order_by=order_by) - - def delete(self, pk: int) -> None: - """Remove a Log entry from the collection with the given id - - :param pk: id of the Log to delete - - :raises `~aiida.common.exceptions.NotExistent`: if Log with ID ``pk`` is not found - """ - return self._backend.logs.delete(pk) - - def delete_all(self) -> None: - """Delete all Logs in the collection - - :raises `~aiida.common.exceptions.IntegrityError`: if all Logs could not be deleted - """ - return self._backend.logs.delete_all() - - def delete_many(self, filters: 'FilterType') -> List[int]: - """Delete Logs based on ``filters`` - - :param filters: filters to pass to the QueryBuilder - :return: (former) ``PK`` s of deleted Logs - - :raises TypeError: if ``filters`` is not a `dict` - :raises `~aiida.common.exceptions.ValidationError`: if ``filters`` is empty - """ - return self._backend.logs.delete_many(filters) - - -class Log(entities.Entity['BackendLog', LogCollection]): - """ - An AiiDA Log entity. Corresponds to a logged message against a particular AiiDA node. 
- """ - - _CLS_COLLECTION = LogCollection - - def __init__( - self, - time: datetime, - loggername: str, - levelname: str, - dbnode_id: int, - message: str = '', - metadata: Optional[Dict[str, Any]] = None, - backend: Optional['StorageBackend'] = None - ): # pylint: disable=too-many-arguments - """Construct a new log - - :param time: time - :param loggername: name of logger - :param levelname: name of log level - :param dbnode_id: id of database node - :param message: log message - :param metadata: metadata - :param backend: database backend - """ - from aiida.common import exceptions - - if metadata is not None and not isinstance(metadata, dict): - raise TypeError('metadata must be a dict') - - if not loggername or not levelname: - raise exceptions.ValidationError('The loggername and levelname cannot be empty') - - backend = backend or get_manager().get_profile_storage() - model = backend.logs.create( - time=time, - loggername=loggername, - levelname=levelname, - dbnode_id=dbnode_id, - message=message, - metadata=metadata - ) - super().__init__(model) - self.store() # Logs are immutable and automatically stored - - @property - def uuid(self) -> str: - """Return the UUID for this log. - - This identifier is unique across all entities types and backend instances. - - :return: the entity uuid - """ - return self._backend_entity.uuid - - @property - def time(self) -> datetime: - """ - Get the time corresponding to the entry - - :return: The entry timestamp - """ - return self._backend_entity.time - - @property - def loggername(self) -> str: - """ - The name of the logger that created this entry - - :return: The entry loggername - """ - return self._backend_entity.loggername - - @property - def levelname(self) -> str: - """ - The name of the log level - - :return: The entry log level name - """ - return self._backend_entity.levelname - - @property - def dbnode_id(self) -> int: - """ - Get the id of the object that created the log entry - - :return: The id of the object that created the log entry - """ - return self._backend_entity.dbnode_id - - @property - def message(self) -> str: - """ - Get the message corresponding to the entry - - :return: The entry message - """ - return self._backend_entity.message - - @property - def metadata(self) -> Dict[str, Any]: - """ - Get the metadata corresponding to the entry - - :return: The entry metadata - """ - return self._backend_entity.metadata diff --git a/aiida/orm/nodes/__init__.py b/aiida/orm/nodes/__init__.py deleted file mode 100644 index 3af33b89cc..0000000000 --- a/aiida/orm/nodes/__init__.py +++ /dev/null @@ -1,74 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub classes for data and processes.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .attributes import * -from .data import * -from .node import * -from .process import * -from .repository import * - -__all__ = ( - 'AbstractCode', - 'ArrayData', - 'BandsData', - 'BaseType', - 'Bool', - 'CalcFunctionNode', - 'CalcJobNode', - 'CalculationNode', - 'CifData', - 'Code', - 'ContainerizedCode', - 'Data', - 'Dict', - 'EnumData', - 'Float', - 'FolderData', - 'InstalledCode', - 'Int', - 'JsonableData', - 'Kind', - 'KpointsData', - 'List', - 'Node', - 'NodeAttributes', - 'NodeRepository', - 'NumericType', - 'OrbitalData', - 'PortableCode', - 'ProcessNode', - 'ProjectionData', - 'RemoteData', - 'RemoteStashData', - 'RemoteStashFolderData', - 'SinglefileData', - 'Site', - 'Str', - 'StructureData', - 'TrajectoryData', - 'UpfData', - 'WorkChainNode', - 'WorkFunctionNode', - 'WorkflowNode', - 'XyData', - 'cif_from_ase', - 'find_bandgap', - 'has_pycifrw', - 'pycifrw_from_cif', - 'to_aiida_type', -) - -# yapf: enable diff --git a/aiida/orm/nodes/caching.py b/aiida/orm/nodes/caching.py deleted file mode 100644 index 0d647c2eed..0000000000 --- a/aiida/orm/nodes/caching.py +++ /dev/null @@ -1,164 +0,0 @@ -# -*- coding: utf-8 -*- -"""Interface to control caching of a node instance.""" -from __future__ import annotations - -import importlib -import typing as t - -from aiida.common import exceptions -from aiida.common.hashing import make_hash -from aiida.common.lang import type_check - -from ..querybuilder import QueryBuilder - - -class NodeCaching: - """Interface to control caching of a node instance.""" - - # The keys in the extras that are used to store the hash of the node and whether it should be used in caching. - _HASH_EXTRA_KEY: str = '_aiida_hash' - _VALID_CACHE_KEY: str = '_aiida_valid_cache' - - def __init__(self, node: 'Node') -> None: - """Initialize the caching interface.""" - self._node = node - - def get_hash(self, ignore_errors: bool = True, **kwargs: t.Any) -> str | None: - """Return the hash for this node based on its attributes. - - :param ignore_errors: return ``None`` on ``aiida.common.exceptions.HashingError`` (logging the exception) - """ - if not self._node.is_stored: - raise exceptions.InvalidOperation('You can get the hash only after having stored the node') - - return self._get_hash(ignore_errors=ignore_errors, **kwargs) - - def _get_hash(self, ignore_errors: bool = True, **kwargs: t.Any) -> str | None: - """ - Return the hash for this node based on its attributes. - - This will always work, even before storing. 
- - :param ignore_errors: return ``None`` on ``aiida.common.exceptions.HashingError`` (logging the exception) - """ - try: - return make_hash(self._get_objects_to_hash(), **kwargs) - except exceptions.HashingError: - if not ignore_errors: - raise - if self._node.logger: - self._node.logger.exception('Node hashing failed') - return None - - def _get_objects_to_hash(self) -> list[t.Any]: - """Return a list of objects which should be included in the hash.""" - top_level_module = self._node.__module__.split('.', 1)[0] - try: - version = importlib.import_module(top_level_module).__version__ - except (ImportError, AttributeError) as exc: - raise exceptions.HashingError("The node's package version could not be determined") from exc - objects = [ - version, - { - key: val - for key, val in self._node.base.attributes.items() - if key not in self._node._hash_ignored_attributes and key not in self._node._updatable_attributes # pylint: disable=unsupported-membership-test,protected-access - }, - self._node.base.repository.hash(), - self._node.computer.uuid if self._node.computer is not None else None - ] - return objects - - def rehash(self) -> None: - """Regenerate the stored hash of the Node.""" - self._node.base.extras.set(self._HASH_EXTRA_KEY, self.get_hash()) - - def clear_hash(self) -> None: - """Sets the stored hash of the Node to None.""" - self._node.base.extras.set(self._HASH_EXTRA_KEY, None) - - def get_cache_source(self) -> str | None: - """Return the UUID of the node that was used in creating this node from the cache, or None if it was not cached. - - :return: source node UUID or None - """ - return self._node.base.extras.get('_aiida_cached_from', None) - - @property - def is_created_from_cache(self) -> bool: - """Return whether this node was created from a cached node. - - :return: boolean, True if the node was created by cloning a cached node, False otherwise - """ - return self.get_cache_source() is not None - - def _get_same_node(self) -> 'Node' | None: - """Returns a stored node from which the current Node can be cached or None if it does not exist - - If a node is returned it is a valid cache, meaning its `_aiida_hash` extra matches `self.get_hash()`. - If there are multiple valid matches, the first one is returned. - If no matches are found, `None` is returned. - - :return: a stored `Node` instance with the same hash as this code or None - - Note: this should be only called on stored nodes, or internally from .store() since it first calls - clean_value() on the attributes to normalise them. - """ - try: - return next(self._iter_all_same_nodes(allow_before_store=True)) - except StopIteration: - return None - - def get_all_same_nodes(self) -> list['Node']: - """Return a list of stored nodes which match the type and hash of the current node. - - All returned nodes are valid caches, meaning their `_aiida_hash` extra matches `self.get_hash()`. - - Note: this can be called only after storing a Node (since at store time attributes will be cleaned with - `clean_value` and the hash should become idempotent to the action of serialization/deserialization) - """ - return list(self._iter_all_same_nodes()) - - def _iter_all_same_nodes(self, allow_before_store=False) -> t.Iterator['Node']: - """ - Returns an iterator of all same nodes. - - Note: this should be only called on stored nodes, or internally from .store() since it first calls - clean_value() on the attributes to normalise them. 
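A sketch of how the caching interface above is reached from a node instance (the pk is hypothetical; a loaded profile is assumed):

```python
from aiida import load_profile, orm

load_profile()

node = orm.load_node(1234)                      # hypothetical pk of a stored node
print(node.base.caching.get_hash())             # hash used for cache matching
print(node.base.caching.is_created_from_cache)  # True if this node was cloned from a cached one
equivalent = node.base.caching.get_all_same_nodes()
node.base.caching.is_valid_cache = False        # opt this node out of future cache hits
```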
- """ - if not allow_before_store and not self._node.is_stored: - raise exceptions.InvalidOperation('You can get the hash only after having stored the node') - - node_hash = self._get_hash() - - if not node_hash or not self._node._cachable: # pylint: disable=protected-access - return iter(()) - - builder = QueryBuilder(backend=self._node.backend) - builder.append(self._node.__class__, filters={f'extras.{self._HASH_EXTRA_KEY}': node_hash}, subclassing=False) - - return ( - node for node in builder.all(flat=True) if node.base.caching.is_valid_cache - ) # type: ignore[misc,union-attr] - - @property - def is_valid_cache(self) -> bool: - """Hook to exclude certain ``Node`` classes from being considered a valid cache. - - The base class assumes that all node instances are valid to cache from, unless the ``_VALID_CACHE_KEY`` extra - has been set to ``False`` explicitly. Subclasses can override this property with more specific logic, but should - probably also consider the value returned by this base class. - """ - return self._node.base.extras.get(self._VALID_CACHE_KEY, True) - - @is_valid_cache.setter - def is_valid_cache(self, valid: bool) -> None: - """Set whether this node instance is considered valid for caching or not. - - If a node instance has this property set to ``False``, it will never be used in the caching mechanism, unless - the subclass overrides the ``is_valid_cache`` property and ignores it implementation completely. - - :param valid: whether the node is valid or invalid for use in caching. - """ - type_check(valid, bool) - self._node.base.extras.set(self._VALID_CACHE_KEY, valid) diff --git a/aiida/orm/nodes/comments.py b/aiida/orm/nodes/comments.py deleted file mode 100644 index d926edff23..0000000000 --- a/aiida/orm/nodes/comments.py +++ /dev/null @@ -1,64 +0,0 @@ -# -*- coding: utf-8 -*- -"""Interface for comments of a node instance.""" -from __future__ import annotations - -import typing as t - -from ..comments import Comment -from ..users import User - - -class NodeComments: - """Interface for comments of a node instance.""" - - def __init__(self, node: 'Node') -> None: - """Initialize the comments interface.""" - self._node = node - - def add(self, content: str, user: t.Optional[User] = None) -> Comment: - """Add a new comment. - - :param content: string with comment - :param user: the user to associate with the comment, will use default if not supplied - :return: the newly created comment - """ - user = user or User.collection(self._node.backend).get_default() - return Comment(node=self._node, user=user, content=content).store() - - def get(self, identifier: int) -> Comment: - """Return a comment corresponding to the given identifier. - - :param identifier: the comment pk - :raise aiida.common.NotExistent: if the comment with the given id does not exist - :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment - :return: the comment - """ - return Comment.collection(self._node.backend).get(dbnode_id=self._node.pk, id=identifier) - - def all(self) -> list[Comment]: - """Return a sorted list of comments for this node. - - :return: the list of comments, sorted by pk - """ - return Comment.collection(self._node.backend - ).find(filters={'dbnode_id': self._node.pk}, order_by=[{ - 'id': 'asc' - }]) - - def update(self, identifier: int, content: str) -> None: - """Update the content of an existing comment. 
- - :param identifier: the comment pk - :param content: the new comment content - :raise aiida.common.NotExistent: if the comment with the given id does not exist - :raise aiida.common.MultipleObjectsError: if the id cannot be uniquely resolved to a comment - """ - comment = Comment.collection(self._node.backend).get(dbnode_id=self._node.pk, id=identifier) - comment.set_content(content) - - def remove(self, identifier: int) -> None: - """Delete an existing comment. - - :param identifier: the comment pk - """ - Comment.collection(self._node.backend).delete(identifier) diff --git a/aiida/orm/nodes/data/__init__.py b/aiida/orm/nodes/data/__init__.py deleted file mode 100644 index 395de5f979..0000000000 --- a/aiida/orm/nodes/data/__init__.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub classes for data structures.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .array import * -from .base import * -from .bool import * -from .cif import * -from .code import * -from .data import * -from .dict import * -from .enum import * -from .float import * -from .folder import * -from .int import * -from .jsonable import * -from .list import * -from .numeric import * -from .orbital import * -from .remote import * -from .singlefile import * -from .str import * -from .structure import * -from .upf import * - -__all__ = ( - 'AbstractCode', - 'ArrayData', - 'BandsData', - 'BaseType', - 'Bool', - 'CifData', - 'Code', - 'ContainerizedCode', - 'Data', - 'Dict', - 'EnumData', - 'Float', - 'FolderData', - 'InstalledCode', - 'Int', - 'JsonableData', - 'Kind', - 'KpointsData', - 'List', - 'NumericType', - 'OrbitalData', - 'PortableCode', - 'ProjectionData', - 'RemoteData', - 'RemoteStashData', - 'RemoteStashFolderData', - 'SinglefileData', - 'Site', - 'Str', - 'StructureData', - 'TrajectoryData', - 'UpfData', - 'XyData', - 'cif_from_ase', - 'find_bandgap', - 'has_pycifrw', - 'pycifrw_from_cif', - 'to_aiida_type', -) - -# yapf: enable diff --git a/aiida/orm/nodes/data/array/__init__.py b/aiida/orm/nodes/data/array/__init__.py deleted file mode 100644 index f12feedfbe..0000000000 --- a/aiida/orm/nodes/data/array/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub classes for array based data structures.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .array import * -from .bands import * -from .kpoints import * -from .projection import * -from .trajectory import * -from .xy import * - -__all__ = ( - 'ArrayData', - 'BandsData', - 'KpointsData', - 'ProjectionData', - 'TrajectoryData', - 'XyData', - 'find_bandgap', -) - -# yapf: enable diff --git a/aiida/orm/nodes/data/array/array.py b/aiida/orm/nodes/data/array/array.py deleted file mode 100644 index 89cb8c5357..0000000000 --- a/aiida/orm/nodes/data/array/array.py +++ /dev/null @@ -1,251 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -AiiDA ORM data class storing (numpy) arrays -""" -from ..data import Data - -__all__ = ('ArrayData',) - - -class ArrayData(Data): - """ - Store a set of arrays on disk (rather than on the database) in an efficient - way using numpy.save() (therefore, this class requires numpy to be - installed). - - Each array is stored within the Node folder as a different .npy file. - - :note: Before storing, no caching is done: if you perform a - :py:meth:`.get_array` call, the array will be re-read from disk. - If instead the ArrayData node has already been stored, - the array is cached in memory after the first read, and the cached array - is used thereafter. - If too much RAM memory is used, you can clear the - cache with the :py:meth:`.clear_internal_cache` method. - """ - array_prefix = 'array|' - _cached_arrays = None - - def initialize(self): - super().initialize() - self._cached_arrays = {} - - def delete_array(self, name): - """ - Delete an array from the node. Can only be called before storing. - - :param name: The name of the array to delete from the node. - """ - fname = f'{name}.npy' - if fname not in self.base.repository.list_object_names(): - raise KeyError(f"Array with name '{name}' not found in node pk= {self.pk}") - - # remove both file and attribute - self.base.repository.delete_object(fname) - try: - self.base.attributes.delete(f'{self.array_prefix}{name}') - except (KeyError, AttributeError): - # Should not happen, but do not crash if for some reason the property was not set. - pass - - def get_arraynames(self): - """ - Return a list of all arrays stored in the node, listing the files (and - not relying on the properties). - - .. versionadded:: 0.7 - Renamed from arraynames - """ - return self._arraynames_from_properties() - - def _arraynames_from_files(self): - """ - Return a list of all arrays stored in the node, listing the files (and - not relying on the properties). 
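A short usage sketch of the `ArrayData` class described above (illustrative; numpy is required, as the class docstring notes, and a loaded profile is assumed):

```python
import numpy as np

from aiida import load_profile, orm

load_profile()

array_node = orm.ArrayData()
array_node.set_array('forces', np.zeros((4, 3)))
array_node.store()

print(array_node.get_arraynames())        # ['forces']
print(array_node.get_shape('forces'))     # (4, 3)
forces = array_node.get_array('forces')   # re-read from disk, then cached once the node is stored
```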
- """ - return [i[:-4] for i in self.base.repository.list_object_names() if i.endswith('.npy')] - - def _arraynames_from_properties(self): - """ - Return a list of all arrays stored in the node, listing the attributes - starting with the correct prefix. - """ - return [i[len(self.array_prefix):] for i in self.base.attributes.keys() if i.startswith(self.array_prefix)] - - def get_shape(self, name): - """ - Return the shape of an array (read from the value cached in the - properties for efficiency reasons). - - :param name: The name of the array. - """ - return tuple(self.base.attributes.get(f'{self.array_prefix}{name}')) - - def get_iterarrays(self): - """ - Iterator that returns tuples (name, array) for each array stored in the node. - - .. versionadded:: 1.0 - Renamed from iterarrays - """ - for name in self.get_arraynames(): - yield (name, self.get_array(name)) - - def get_array(self, name): - """ - Return an array stored in the node - - :param name: The name of the array to return. - """ - import numpy - - def get_array_from_file(self, name): - """Return the array stored in a .npy file""" - filename = f'{name}.npy' - - if filename not in self.base.repository.list_object_names(): - raise KeyError(f'Array with name `{name}` not found in ArrayData<{self.pk}>') - - # Open a handle in binary read mode as the arrays are written as binary files as well - with self.base.repository.open(filename, mode='rb') as handle: - return numpy.load(handle, allow_pickle=False) # pylint: disable=unexpected-keyword-arg - - # Return with proper caching if the node is stored, otherwise always re-read from disk - if not self.is_stored: - return get_array_from_file(self, name) - - if name not in self._cached_arrays: - self._cached_arrays[name] = get_array_from_file(self, name) - - return self._cached_arrays[name] - - def clear_internal_cache(self): - """ - Clear the internal memory cache where the arrays are stored after being - read from disk (used in order to reduce at minimum the readings from - disk). - This function is useful if you want to keep the node in memory, but you - do not want to waste memory to cache the arrays in RAM. - """ - self._cached_arrays = {} - - def set_array(self, name, array): - """ - Store a new numpy array inside the node. Possibly overwrite the array - if it already existed. - - Internally, it stores a name.npy file in numpy format. - - :param name: The name of the array. - :param array: The numpy array to store. - """ - import re - import tempfile - - import numpy - - if not isinstance(array, numpy.ndarray): - raise TypeError('ArrayData can only store numpy arrays. 
Convert the object to an array first') - - # Check if the name is valid - if not name or re.sub('[0-9a-zA-Z_]', '', name): - raise ValueError( - 'The name assigned to the array ({}) is not valid,' - 'it can only contain digits, letters and underscores' - ) - - # Write the array to a temporary file, and then add it to the repository of the node - with tempfile.NamedTemporaryFile() as handle: - numpy.save(handle, array, allow_pickle=False) - - # Flush and rewind the handle, otherwise the command to store it in the repo will write an empty file - handle.flush() - handle.seek(0) - - # Write the numpy array to the repository, keeping the byte representation - self.base.repository.put_object_from_filelike(handle, f'{name}.npy') - - # Store the array name and shape for querying purposes - self.base.attributes.set(f'{self.array_prefix}{name}', list(array.shape)) - - def _validate(self): - """ - Check if the list of .npy files stored inside the node and the - list of properties match. Just a name check, no check on the size - since this would require to reload all arrays and this may take time - and memory. - """ - from aiida.common.exceptions import ValidationError - - files = self._arraynames_from_files() - properties = self._arraynames_from_properties() - - if set(files) != set(properties): - raise ValidationError( - f'Mismatch of files and properties for ArrayData node (pk= {self.pk}): {files} vs. {properties}' - ) - super()._validate() - - def _get_array_entries(self): - """Return a dictionary with the different array entries. - - The idea is that this dictionary contains the array name as a key and - the value is the numpy array transformed into a list. This is so that - it can be transformed into a json object. - """ - - array_dict = {} - for key, val in self.get_iterarrays(): - - array_dict[key] = clean_array(val) - return array_dict - - def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument - """Dump the content of the arrays stored in this node into JSON format. - - :param comments: if True, includes comments (if it makes sense for the given format) - """ - import json - - from aiida import get_file_header - - json_dict = self._get_array_entries() - json_dict['original_uuid'] = self.uuid - - if comments: - json_dict['comments'] = get_file_header(comment_char='') - - return json.dumps(json_dict).encode('utf-8'), {} - - -def clean_array(array): - """ - Replacing np.nan and np.inf/-np.inf for Nones. - - The function will also sanitize the array removing ``np.nan`` and ``np.inf`` - for ``None`` of this way the resulting JSON is always valid. - Both ``np.nan`` and ``np.inf``/``-np.inf`` are set to None to be in - accordance with the - `ECMA-262 standard `_. - - :param array: input array to be cleaned - :return: cleaned list to be serialized - :rtype: list - """ - import numpy as np - - output = np.reshape( - np.asarray([ - entry if not np.isnan(entry) and not np.isinf(entry) else None for entry in array.flatten().tolist() - ]), array.shape - ) - - return output.tolist() diff --git a/aiida/orm/nodes/data/array/bands.py b/aiida/orm/nodes/data/array/bands.py deleted file mode 100644 index 590b33a57d..0000000000 --- a/aiida/orm/nodes/data/array/bands.py +++ /dev/null @@ -1,1925 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-lines -""" -This module defines the classes related to band structures or dispersions -in a Brillouin zone, and how to operate on them. -""" -import json -from string import Template - -import numpy - -from aiida.common.exceptions import ValidationError -from aiida.common.utils import join_labels, prettify_labels - -from .kpoints import KpointsData - -__all__ = ('BandsData', 'find_bandgap') - - -def prepare_header_comment(uuid, plot_info, comment_char='#'): - """Prepare the header.""" - from aiida import get_file_header - - filetext = [] - filetext += get_file_header(comment_char='').splitlines() - filetext.append('') - filetext.append(f'Dumped from BandsData UUID={uuid}') - filetext.append('\tpoints\tbands') - filetext.append('\t{}\t{}'.format(*plot_info['y'].shape)) - filetext.append('') - filetext.append('\tlabel\tpoint') - for label in plot_info['raw_labels']: - filetext.append(f'\t{label[1]}\t{label[0]:.8f}') - - return '\n'.join(f'{comment_char} {line}' for line in filetext) - - -def find_bandgap(bandsdata, number_electrons=None, fermi_energy=None): - """ - Tries to guess whether the bandsdata represent an insulator. - This method is meant to be used only for electronic bands (not phonons) - By default, it will try to use the occupations to guess the number of - electrons and find the Fermi Energy, otherwise, it can be provided - explicitely. - Also, there is an implicit assumption that the kpoints grid is - "sufficiently" dense, so that the bandsdata are not missing the - intersection between valence and conduction band if present. - Use this function with care! - - :param number_electrons: (optional, float) number of electrons in the unit cell - :param fermi_energy: (optional, float) value of the fermi energy. - - :note: By default, the algorithm uses the occupations array - to guess the number of electrons and the occupied bands. This is to be - used with care, because the occupations could be smeared so at a - non-zero temperature, with the unwanted effect that the conduction bands - might be occupied in an insulator. - Prefer to pass the number_of_electrons explicitly - - :note: Only one between number_electrons and fermi_energy can be specified at the - same time. - - :return: (is_insulator, gap), where is_insulator is a boolean, and gap a - float. The gap is None in case of a metal, zero when the homo is - equal to the lumo (e.g. in semi-metals). 
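Given the caveats spelled out in the docstring above, a minimal sketch of calling `find_bandgap` (the pk and electron count are made-up values; the import path is the module's location before this removal):

```python
from aiida import load_profile, orm
from aiida.orm.nodes.data.array.bands import find_bandgap  # location before this removal

load_profile()

bands = orm.load_node(1234)  # hypothetical pk of a stored BandsData node
is_insulator, gap = find_bandgap(bands, number_electrons=8)
if is_insulator:
    print(f'insulating, gap = {gap} (in the units of the stored bands)')
else:
    print('metallic, or valence and conduction bands cross')
```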
- """ - - # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements,no-else-return - - def nint(num): - """ - Stable rounding function - """ - if num > 0: - return int(num + .5) - return int(num - .5) - - if fermi_energy and number_electrons: - raise ValueError('Specify either the number of electrons or the Fermi energy, but not both') - - try: - stored_bands = bandsdata.get_bands() - except KeyError: - raise KeyError('Cannot do much of a band analysis without bands') - - if len(stored_bands.shape) == 3: - # I write the algorithm for the generic case of having both the spin up and spin down array - # put all spins on one band per kpoint - bands = numpy.concatenate(stored_bands, axis=1) - else: - bands = stored_bands - - # analysis on occupations: - if fermi_energy is None: - - num_kpoints = len(bands) - - if number_electrons is None: - try: - _, stored_occupations = bandsdata.get_bands(also_occupations=True) - except KeyError: - raise KeyError("Cannot determine metallicity if I don't have either fermi energy, or occupations") - - # put the occupations in the same order of bands, also in case of multiple bands - if len(stored_occupations.shape) == 3: - # I write the algorithm for the generic case of having both the - # spin up and spin down array - - # put all spins on one band per kpoint - occupations = numpy.concatenate(stored_occupations, axis=1) - else: - occupations = stored_occupations - - # now sort the bands by energy - # Note: I am sort of assuming that I have an electronic ground state - - # sort the bands by energy, and reorder the occupations accordingly - # since after joining the two spins, I might have unsorted stuff - bands, occupations = [ - numpy.array(y) for y in zip( - *[ - list(zip(*j)) for j in [ - sorted(zip(i[0].tolist(), i[1].tolist()), key=lambda x: x[0]) - for i in zip(bands, occupations) - ] - ] - ) - ] - number_electrons = int(round(sum(sum(i) for i in occupations) / num_kpoints)) - - homo_indexes = [numpy.where(numpy.array([nint(_) for _ in x]) > 0)[0][-1] for x in occupations] - if len(set(homo_indexes)) > 1: # there must be intersections of valence and conduction bands - return False, None - - homo = [_[0][_[1]] for _ in zip(bands, homo_indexes)] - try: - lumo = [_[0][_[1] + 1] for _ in zip(bands, homo_indexes)] - except IndexError: - raise ValueError( - 'To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons' - ) - - else: - bands = numpy.sort(bands) - number_electrons = int(number_electrons) - - # find the zero-temperature occupation per band (1 for spin-polarized - # calculation, 2 otherwise) - number_electrons_per_band = 4 - len(stored_bands.shape) # 1 or 2 - # gather the energies of the homo band, for every kpoint - homo = [i[number_electrons // number_electrons_per_band - 1] for i in bands] # take the nth level - try: - # gather the energies of the lumo band, for every kpoint - lumo = [i[number_electrons // number_electrons_per_band] for i in bands] # take the n+1th level - except IndexError: - raise ValueError( - 'To understand if it is a metal or insulator, ' - 'need more bands than n_band=number_electrons' - ) - - if number_electrons % 2 == 1 and len(stored_bands.shape) == 2: - # if #electrons is odd and we have a non spin polarized calculation - # it must be a metal and I don't need further checks - return False, None - - # if the nth band crosses the (n+1)th, it is an insulator - gap = min(lumo) - max(homo) - if gap == 0.: - return False, 0. 
- - if gap < 0.: - return False, None - - return True, gap - - # analysis on the fermi energy - else: - # reorganize the bands, rather than per kpoint, per energy level - - # I need the bands sorted by energy - bands.sort() - - levels = bands.transpose() - max_mins = [(max(i), min(i)) for i in levels] - - if fermi_energy > bands.max(): - raise ValueError("The Fermi energy is above all band energies, don't know what to do") - if fermi_energy < bands.min(): - raise ValueError("The Fermi energy is below all band energies, don't know what to do.") - - # one band is crossed by the fermi energy - if any(i[1] < fermi_energy and fermi_energy < i[0] for i in max_mins): # pylint: disable=chained-comparison - return False, None - - # case of semimetals, fermi energy at the crossing of two bands - # this will only work if the dirac point is computed! - if (any(i[0] == fermi_energy for i in max_mins) and any(i[1] == fermi_energy for i in max_mins)): - return False, 0. - - # insulating case, take the max of the band maxima below the fermi energy - homo = max(i[0] for i in max_mins if i[0] < fermi_energy) - # take the min of the band minima above the fermi energy - lumo = min(i[1] for i in max_mins if i[1] > fermi_energy) - gap = lumo - homo - if gap <= 0.: - raise RuntimeError('Something wrong has been implemented. Revise the code!') - return True, gap - - -class BandsData(KpointsData): - """ - Class to handle bands data - """ - - def set_kpointsdata(self, kpointsdata): - """ - Load the kpoints from a kpoint object. - :param kpointsdata: an instance of KpointsData class - """ - if not isinstance(kpointsdata, KpointsData): - raise ValueError('kpointsdata must be of the KpointsData class') - try: - self.cell = kpointsdata.cell - except AttributeError: - pass - try: - self.pbc = kpointsdata.pbc - except AttributeError: - pass - try: - the_kpoints = kpointsdata.get_kpoints() - except AttributeError: - the_kpoints = None - try: - the_weights = kpointsdata.get_kpoints(also_weights=True)[1] - except AttributeError: - the_weights = None - self.set_kpoints(the_kpoints, weights=the_weights) - try: - self.labels = kpointsdata.labels - except (AttributeError, TypeError): - self.labels = [] - - def _validate_bands_occupations(self, bands, occupations=None, labels=None): - """ - Validate the list of bands and of occupations before storage. - Kpoints must be set in advance. - Bands and occupations must be convertible into arrays of - Nkpoints x Nbands floats or Nspins x Nkpoints x Nbands; Nkpoints must - correspond to the number of kpoints. 
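A compact sketch of assembling a `BandsData` node with the `set_kpointsdata`/`set_bands` API of this class (values are arbitrary, a loaded profile is assumed, and the kpoints are given as fractional coordinates on the assumption that no cell is then required):

```python
import numpy as np

from aiida import load_profile, orm

load_profile()

kpoints = orm.KpointsData()
kpoints.set_kpoints(np.array([[0.0, 0.0, 0.0], [0.25, 0.0, 0.0], [0.5, 0.0, 0.0]]))

bands = orm.BandsData()
bands.set_kpointsdata(kpoints)
bands.set_bands(np.random.rand(3, 4), units='eV')  # 3 kpoints x 4 bands
bands.store()
```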
- """ - # pylint: disable=too-many-branches - try: - kpoints = self.get_kpoints() - except AttributeError: - raise AttributeError('Must first set the kpoints, then the bands') - - the_bands = numpy.array(bands) - - if len(the_bands.shape) not in [2, 3]: - raise ValueError( - 'Bands must be an array of dimension 2' - '([N_kpoints, N_bands]) or of dimension 3 ' - ' ([N_arrays, N_kpoints, N_bands]), found instead {}'.format(len(the_bands.shape)) - ) - - list_of_arrays_to_be_checked = [] - - # check that the shape of everything is consistent with the kpoints - num_kpoints_from_bands = the_bands.shape[0] if len(the_bands.shape) == 2 else the_bands.shape[1] - if num_kpoints_from_bands != len(kpoints): - raise ValueError('There must be energy values for every kpoint') - - if occupations is not None: - the_occupations = numpy.array(occupations) - if the_occupations.shape != the_bands.shape: - raise ValueError( - f'Shape of occupations {the_occupations.shape} different from shapeshape of bands {the_bands.shape}' - ) - - if not the_bands.dtype.type == numpy.float64: - list_of_arrays_to_be_checked.append([the_occupations, 'occupations']) - else: - the_occupations = None - # list_of_arrays_to_be_checked = [ [the_bands,'bands'] ] - - # check that there every element is a float - if not the_bands.dtype.type == numpy.float64: - list_of_arrays_to_be_checked.append([the_bands, 'bands']) - - for x, msg in list_of_arrays_to_be_checked: - try: - [float(_) for _ in x.flatten() if _ is not None] - except (TypeError, ValueError): - raise ValueError(f'The {msg} array can only contain float or None values') - - # check the labels - if labels is not None: - if isinstance(labels, str): - the_labels = [str(labels)] - elif isinstance(labels, (tuple, list)) and all(isinstance(_, str) for _ in labels): - the_labels = [str(_) for _ in labels] - else: - raise ValidationError( - 'Band labels have an unrecognized type ({})' - 'but should be a string or a list of strings'.format(labels.__class__) - ) - - if len(the_bands.shape) == 2 and len(the_labels) != 1: - raise ValidationError('More array labels than the number of arrays') - elif len(the_bands.shape) == 3 and len(the_labels) != the_bands.shape[0]: - raise ValidationError('More array labels than the number of arrays') - else: - the_labels = None - - return the_bands, the_occupations, the_labels - - def set_bands(self, bands, units=None, occupations=None, labels=None): - """ - Set an array of band energies of dimension (nkpoints x nbands). - Kpoints must be set in advance. Can contain floats or None. - :param bands: a list of nkpoints lists of nbands bands, or a 2D array - of shape (nkpoints x nbands), with band energies for each kpoint - :param units: optional, energy units - :param occupations: optional, a 2D list or array of floats of same - shape as bands, with the occupation associated to each band - """ - # checks bands and occupations - the_bands, the_occupations, the_labels = self._validate_bands_occupations(bands, occupations, labels) - # set bands and their units - self.set_array('bands', the_bands) - self.units = units - - if the_labels is not None: - self.base.attributes.set('array_labels', the_labels) - - if the_occupations is not None: - # set occupations - self.set_array('occupations', the_occupations) - - @property - def array_labels(self): - """ - Get the labels associated with the band arrays - """ - return self.base.attributes.get('array_labels', None) - - @property - def units(self): - """ - Units in which the data in bands were stored. 
A string - """ - # return copy.deepcopy(self._pbc) - return self.base.attributes.get('units') - - @units.setter - def units(self, value): - """ - Set the value of pbc, i.e. a tuple of three booleans, indicating if the - cell is periodic in the 1,2,3 crystal direction - """ - the_str = str(value) - self.base.attributes.set('units', the_str) - - def _set_pbc(self, value): - """ - validate the pbc, then store them - """ - from aiida.common.exceptions import ModificationNotAllowed - from aiida.orm.nodes.data.structure import get_valid_pbc - - if self.is_stored: - raise ModificationNotAllowed('The KpointsData object cannot be modified, it has already been stored') - the_pbc = get_valid_pbc(value) - self.base.attributes.set('pbc1', the_pbc[0]) - self.base.attributes.set('pbc2', the_pbc[1]) - self.base.attributes.set('pbc3', the_pbc[2]) - - def get_bands(self, also_occupations=False, also_labels=False): - """ - Returns an array (nkpoints x num_bands or nspins x nkpoints x num_bands) - of energies. - :param also_occupations: if True, returns also the occupations array. - Default = False - """ - try: - bands = numpy.array(self.get_array('bands')) - except KeyError: - raise AttributeError('No stored bands has been found') - - to_return = [bands] - - if also_occupations: - try: - occupations = numpy.array(self.get_array('occupations')) - except KeyError: - raise AttributeError('No occupations were set') - to_return.append(occupations) - - if also_labels: - to_return.append(self.array_labels) - - if len(to_return) == 1: - return bands - - return to_return - - def _get_bandplot_data(self, cartesian, prettify_format=None, join_symbol=None, get_segments=False, y_origin=0.): - """ - Get data to plot a band structure - - :param cartesian: if True, distances (for the x-axis) are computed in - cartesian coordinates, otherwise they are computed in reciprocal - coordinates. cartesian=True will fail if no cell has been set. - :param prettify_format: by default, strings are not prettified. If you want - to prettify them, pass a valid prettify_format string (see valid options - in the docstring of :py:func:prettify_labels). - :param join_symbols: by default, strings are not joined. If you pass a string, - this is used to join strings that are much closer than a given threshold. - The most typical string is the pipe symbol: ``|``. - :param get_segments: if True, also computes the band split into segments - :param y_origin: if present, shift bands so to set the value specified at ``y=0`` - :return: a plot_info dictiorary, whose keys are ``x`` (array of distances - for the x axis of the plot); ``y`` (array of bands), ``labels`` (list - of tuples in the format (float x value of the label, label string), - ``band_type_idx`` (array containing an index for each band: if there is only - one spin, then it's an array of zeros, of length equal to the number of bands - at each point; if there are two spins, then it's an array of zeros or ones - depending on the type of spin; the length is always equalt to the total - number of bands per kpoint). 
- """ - # pylint: disable=too-many-locals,too-many-branches,too-many-statements - # load the x and y's of the graph - stored_bands = self.get_bands() - if len(stored_bands.shape) == 2: - bands = stored_bands - band_type_idx = numpy.array([0] * stored_bands.shape[1]) - two_band_types = False - elif len(stored_bands.shape) == 3: - bands = numpy.concatenate(stored_bands, axis=1) - band_type_idx = numpy.array([0] * stored_bands.shape[2] + [1] * stored_bands.shape[2]) - two_band_types = True - else: - raise ValueError('Unexpected shape of bands') - - bands -= y_origin - - # here I build the x distances on the graph (in cartesian coordinates - # if cartesian==True AND if the cell was set, otherwise in reciprocal - # coordinates) - try: - kpoints = self.get_kpoints(cartesian=cartesian) - except AttributeError: - # this error is happening if cartesian==True and if no cell has been - # set -> we switch to reciprocal coordinates to compute distances - kpoints = self.get_kpoints() - # I take advantage of the path to recognize discontinuities - try: - labels = self.labels - labels_indices = [i[0] for i in labels] - except (AttributeError, TypeError): - labels = [] - labels_indices = [] - - # since I can have discontinuous paths, I set on those points the distance to zero - # as a result, where there are discontinuities in the path, - # I have two consecutive points with the same x coordinate - distances = [ - numpy.linalg.norm(kpoints[i] - - kpoints[i - 1]) if not (i in labels_indices and i - 1 in labels_indices) else 0. - for i in range(1, len(kpoints)) - ] - x = [float(sum(distances[:i])) for i in range(len(distances) + 1)] - - # transform the index of the labels in the coordinates of x - raw_labels = [(x[i[0]], i[1]) for i in labels] - - the_labels = raw_labels - - if prettify_format: - the_labels = prettify_labels(the_labels, format=prettify_format) - if join_symbol: - the_labels = join_labels(the_labels, join_symbol=join_symbol) - - plot_info = {} - plot_info['x'] = x - plot_info['y'] = bands - plot_info['band_type_idx'] = band_type_idx - plot_info['raw_labels'] = raw_labels - plot_info['labels'] = the_labels - - if get_segments: - plot_info['path'] = [] - plot_info['paths'] = [] - - if len(labels) > 1: - # I add an empty label that points to the first band if the first label does not do it - if labels[0][0] != 0: - labels.insert(0, (0, '')) - # I add an empty label that points to the last band if the last label does not do it - if labels[-1][0] != len(bands) - 1: - labels.append((len(bands) - 1, '')) - for (position_from, label_from), (position_to, label_to) in zip(labels[:-1], labels[1:]): - if position_to - position_from > 1: - # Create a new path line only if there are at least two points, - # otherwise it is probably just a discontinuity point in the band - # structure (e.g. 
Gamma-X|Y-Gamma), where X and Y would be two - # consecutive points, but there is no path between them - plot_info['path'].append([label_from, label_to]) - path_dict = { - 'length': position_to - position_from, - 'from': label_from, - 'to': label_to, - 'values': bands[position_from:position_to + 1, :].transpose().tolist(), - 'x': x[position_from:position_to + 1], - 'two_band_types': two_band_types, - } - plot_info['paths'].append(path_dict) - else: - label_from = '0' - label_to = '1' - path_dict = { - 'length': bands.shape[0] - 1, - 'from': label_from, - 'to': label_to, - 'values': bands.transpose().tolist(), - 'x': x, - 'two_band_types': two_band_types, - } - plot_info['paths'].append(path_dict) - plot_info['path'].append([label_from, label_to]) - - return plot_info - - def _prepare_agr_batch(self, main_file_name='', comments=True, prettify_format=None): - """ - Prepare two files, data and batch, to be plot with xmgrace as: - xmgrace -batch file.dat - - :param main_file_name: if the user asks to write the main content on a - file, this contains the filename. This should be used to infer a - good filename for the additional files. - In this case, we remove the extension, and add '_data.dat' - :param comments: if True, print comments (if it makes sense for the given - format) - :param prettify_format: if None, use the default prettify format. Otherwise - specify a string with the prettifier to use. - """ - # pylint: disable=too-many-locals - import os - - dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat' - - if prettify_format is None: - # Default. Specified like this to allow caller functions to pass 'None' - prettify_format = 'agr_seekpath' - - plot_info = self._get_bandplot_data(cartesian=True, prettify_format=prettify_format, join_symbol='|') - - bands = plot_info['y'] - x = plot_info['x'] - labels = plot_info['labels'] - - num_bands = bands.shape[1] - - # axis limits - y_max_lim = bands.max() - y_min_lim = bands.min() - x_min_lim = min(x) # this isn't a numpy array, but a list - x_max_lim = max(x) - - # first prepare the xy coordinates of the sets - raw_data, _ = self._prepare_dat_blocks(plot_info) - - batch = [] - if comments: - batch.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) - - batch.append(f'READ XY "{dat_filename}"') - - # axis limits - batch.append(f'world {x_min_lim}, {y_min_lim}, {x_max_lim}, {y_max_lim}') - - # axis label - batch.append('yaxis label "Dispersion"') - - # axis ticks - batch.append('xaxis tick place both') - batch.append('xaxis tick spec type both') - batch.append(f'xaxis tick spec {len(labels)}') - # set the name of the special points - for index, label in enumerate(labels): - batch.append(f'xaxis tick major {index}, {label[0]}') - batch.append(f'xaxis ticklabel {index}, "{label[1]}"') - batch.append('xaxis tick major color 7') - batch.append('xaxis tick major grid on') - - # minor graphical tweak - batch.append('yaxis tick minor ticks 3') - batch.append('frame linewidth 1.0') - - # use helvetica fonts - batch.append('map font 4 to "Helvetica", "Helvetica"') - batch.append('yaxis label font 4') - batch.append('xaxis label font 4') - - # set color and linewidths of bands - for index in range(num_bands): - batch.append(f's{index} line color 1') - batch.append(f's{index} linewidth 1') - - batch_data = '\n'.join(batch) + '\n' - extra_files = {dat_filename: raw_data} - - return batch_data.encode('utf-8'), extra_files - - def _prepare_dat_multicolumn(self, main_file_name='', comments=True): # pylint: disable=unused-argument 
- """ - Write an N x M matrix. First column is the distance between kpoints, - The other columns are the bands. Header contains number of kpoints and - the number of bands (commented). - - :param comments: if True, print comments (if it makes sense for the given - format) - """ - plot_info = self._get_bandplot_data(cartesian=True, prettify_format=None, join_symbol='|') - - bands = plot_info['y'] - x = plot_info['x'] - - return_text = [] - if comments: - return_text.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) - - for i in zip(x, bands): - line = [f'{i[0]:.8f}'] + [f'{j:.8f}' for j in i[1]] - return_text.append('\t'.join(line)) - - return ('\n'.join(return_text) + '\n').encode('utf-8'), {} - - def _prepare_dat_blocks(self, main_file_name='', comments=True): # pylint: disable=unused-argument - """ - Format suitable for gnuplot using blocks. - Columns with x and y (path and band energy). Several blocks, separated - by two empty lines, one per energy band. - - :param comments: if True, print comments (if it makes sense for the given - format) - """ - plot_info = self._get_bandplot_data(cartesian=True, prettify_format=None, join_symbol='|') - - bands = plot_info['y'] - x = plot_info['x'] - - return_text = [] - if comments: - return_text.append(prepare_header_comment(self.uuid, plot_info, comment_char='#')) - - for band in numpy.transpose(bands): - for i in zip(x, band): - line = [f'{i[0]:.8f}', f'{i[1]:.8f}'] - return_text.append('\t'.join(line)) - return_text.append('') - return_text.append('') - - return '\n'.join(return_text).encode('utf-8'), {} - - def _matplotlib_get_dict( - self, - main_file_name='', - comments=True, - title='', - legend=None, - legend2=None, - y_max_lim=None, - y_min_lim=None, - y_origin=0., - prettify_format=None, - **kwargs - ): # pylint: disable=unused-argument - """ - Prepare the data to send to the python-matplotlib plotting script. - - :param comments: if True, print comments (if it makes sense for the given - format) - :param plot_info: a dictionary - :param setnumber_offset: an offset to be applied to all set numbers - (i.e. s0 is replaced by s[offset], s1 by s[offset+1], etc.) - :param color_number: the color number for lines, symbols, error bars - and filling (should be less than the parameter MAX_NUM_AGR_COLORS - defined below) - :param title: the title - :param legend: the legend (applied only to the first of the set) - :param legend2: the legend for second-type spins - (applied only to the first of the set) - :param y_max_lim: the maximum on the y axis (if None, put the - maximum of the bands) - :param y_min_lim: the minimum on the y axis (if None, put the - minimum of the bands) - :param y_origin: the new origin of the y axis -> all bands are replaced - by bands-y_origin - :param prettify_format: if None, use the default prettify format. Otherwise - specify a string with the prettifier to use. 
- :param kwargs: additional customization variables; only a subset is - accepted, see internal variable 'valid_additional_keywords - """ - # pylint: disable=too-many-arguments,too-many-locals - - # Only these keywords are accepted in kwargs, and then set into the json - valid_additional_keywords = [ - 'bands_color', # Color of band lines - 'bands_linewidth', # linewidth of bands - 'bands_linestyle', # linestyle of bands - 'bands_marker', # marker for bands - 'bands_markersize', # size of the marker of bands - 'bands_markeredgecolor', # marker edge color for bands - 'bands_markeredgewidth', # marker edge width for bands - 'bands_markerfacecolor', # marker face color for bands - 'bands_color2', # Color of band lines (for other spin, if present) - 'bands_linewidth2', # linewidth of bands (for other spin, if present) - 'bands_linestyle2', # linestyle of bands (for other spin, if present) - 'bands_marker2', # marker for bands (for other spin, if present) - 'bands_markersize2', # size of the marker of bands (for other spin, if present) - 'bands_markeredgecolor2', # marker edge color for bands (for other spin, if present) - 'bands_markeredgewidth2', # marker edge width for bands (for other spin, if present) - 'bands_markerfacecolor2', # marker face color for bands (for other spin, if present) - 'plot_zero_axis', # If true, plot an axis at y=0 - 'zero_axis_color', # Color of the axis at y=0 - 'zero_axis_linestyle', # linestyle of the axis at y=0 - 'zero_axis_linewidth', # linewidth of the axis at y=0 - 'use_latex', # If true, use latex to render captions - ] - - # Note: I do not want to import matplotlib here, for two reasons: - # 1. I would like to be able to print the script for the user - # 2. I don't want to mess up with the user matplotlib backend - # (that I should do if the user does not have a X server, but that - # I do not want to do if he's e.g. in jupyter) - # Therefore I just create a string that can be executed as needed, e.g. with eval. - # I take care of sanitizing the output. - if prettify_format is None: - # Default. 
Specified like this to allow caller functions to pass 'None' - prettify_format = 'latex_seekpath' - - # The default for use_latex is False - join_symbol = r'\textbar{}' if kwargs.get('use_latex', False) else '|' - - plot_info = self._get_bandplot_data( - cartesian=True, - prettify_format=prettify_format, - join_symbol=join_symbol, - get_segments=True, - y_origin=y_origin - ) - - all_data = {} - - bands = plot_info['y'] - x = plot_info['x'] - labels = plot_info['labels'] - # prepare xticks labels - if labels: - tick_pos, tick_labels = zip(*labels) - else: - tick_pos = [] - tick_labels = [] - - all_data['paths'] = plot_info['paths'] - all_data['band_type_idx'] = plot_info['band_type_idx'].tolist() - - all_data['tick_pos'] = tick_pos - all_data['tick_labels'] = tick_labels - all_data['legend_text'] = legend - all_data['legend_text2'] = legend2 - all_data['yaxis_label'] = f'Dispersion ({self.units})' - all_data['title'] = title - if comments: - all_data['comment'] = prepare_header_comment(self.uuid, plot_info, comment_char='#') - - # axis limits - if y_max_lim is None: - y_max_lim = numpy.nanmax(bands) - if y_min_lim is None: - y_min_lim = numpy.nanmin(bands) - x_min_lim = min(x) # this isn't a numpy array, but a list - x_max_lim = max(x) - all_data['x_min_lim'] = x_min_lim - all_data['x_max_lim'] = x_max_lim - all_data['y_min_lim'] = y_min_lim - all_data['y_max_lim'] = y_max_lim - - for key, value in kwargs.items(): - if key not in valid_additional_keywords: - raise TypeError(f"_matplotlib_get_dict() got an unexpected keyword argument '{key}'") - all_data[key] = value - - return all_data - - def _prepare_mpl_singlefile(self, *args, **kwargs): - """ - Prepare a python script using matplotlib to plot the bands - - For the possible parameters, see documentation of - :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` - """ - all_data = self._matplotlib_get_dict(*args, **kwargs) - - s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() - s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) - s_body = self._get_mpl_body_template(all_data['paths']) - s_footer = MATPLOTLIB_FOOTER_TEMPLATE_SHOW.substitute() - - string = s_header + s_import + s_body + s_footer - - return string.encode('utf-8'), {} - - def _prepare_mpl_withjson(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - """ - Prepare a python script using matplotlib to plot the bands, with the JSON - returned as an independent file. 
- - For the possible parameters, see documentation of - :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` - """ - import os - - all_data = self._matplotlib_get_dict(*args, main_file_name=main_file_name, **kwargs) - - json_fname = os.path.splitext(main_file_name)[0] + '_data.json' - # Escape double_quotes - json_fname = json_fname.replace('"', '\"') - - ext_files = {json_fname: json.dumps(all_data, indent=2).encode('utf-8')} - - s_header = MATPLOTLIB_HEADER_TEMPLATE.substitute() - s_import = MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE.substitute(json_fname=json_fname) - s_body = self._get_mpl_body_template(all_data['paths']) - s_footer = MATPLOTLIB_FOOTER_TEMPLATE_SHOW.substitute() - - string = s_header + s_import + s_body + s_footer - - return string.encode('utf-8'), ext_files - - def _prepare_mpl_pdf(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument - """ - Prepare a python script using matplotlib to plot the bands, with the JSON - returned as an independent file. - - For the possible parameters, see documentation of - :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` - """ - import os - import subprocess - import sys - import tempfile - - all_data = self._matplotlib_get_dict(*args, **kwargs) - - # Use the Agg backend - s_header = MATPLOTLIB_HEADER_AGG_TEMPLATE.substitute() - s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) - s_body = self._get_mpl_body_template(all_data['paths']) - - # I get a temporary file name - handle, filename = tempfile.mkstemp() - os.close(handle) - os.remove(filename) - - escaped_fname = filename.replace('"', '\"') - - s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE.substitute(fname=escaped_fname, format='pdf') - - string = s_header + s_import + s_body + s_footer - - # I don't exec it because I might mess up with the matplotlib backend etc. - # I run instead in a different process, with the same executable - # (so it should work properly with virtualenvs) - with tempfile.NamedTemporaryFile(mode='w+') as handle: - handle.write(string) - handle.flush() - subprocess.check_output([sys.executable, handle.name]) - - if not os.path.exists(filename): - raise RuntimeError('Unable to generate the PDF...') - - with open(filename, 'rb', encoding=None) as handle: - imgdata = handle.read() - os.remove(filename) - - return imgdata, {} - - def _prepare_mpl_png(self, main_file_name='', *args, **kwargs): # pylint: disable=keyword-arg-before-vararg,unused-argument - """ - Prepare a python script using matplotlib to plot the bands, with the JSON - returned as an independent file. 
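
The `_prepare_mpl_*` methods above are normally reached through the generic export machinery rather than called directly. A hedged usage sketch follows; the format names are inferred from the `_prepare_<format>` method names, and `show_mpl` is the convenience wrapper defined further below.

```python
# Sketch, assuming a populated BandsData node `bands_node`.

# Interactive matplotlib window (uses the 'mpl_singlefile' format internally).
bands_node.show_mpl()

# Plotting script plus a separate JSON data file:
# returns (main file content as bytes, {filename: content} for extra files).
script, extra_files = bands_node._exportcontent(
    fileformat='mpl_withjson', main_file_name='bands.py'
)

# Render directly to PNG bytes (matplotlib runs in a subprocess with the Agg backend).
png_bytes, _ = bands_node._exportcontent(fileformat='mpl_png', main_file_name='bands.png')
```
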
- - For the possible parameters, see documentation of - :py:meth:`~aiida.orm.nodes.data.array.bands.BandsData._matplotlib_get_dict` - """ - import os - import subprocess - import sys - import tempfile - - all_data = self._matplotlib_get_dict(*args, **kwargs) - - # Use the Agg backend - s_header = MATPLOTLIB_HEADER_AGG_TEMPLATE.substitute() - s_import = MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE.substitute(all_data_json=json.dumps(all_data, indent=2)) - s_body = self._get_mpl_body_template(all_data['paths']) - - # I get a temporary file name - handle, filename = tempfile.mkstemp() - os.close(handle) - os.remove(filename) - - escaped_fname = filename.replace('"', '\"') - - s_footer = MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI.substitute(fname=escaped_fname, format='png', dpi=300) - - string = s_header + s_import + s_body + s_footer - - # I don't exec it because I might mess up with the matplotlib backend etc. - # I run instead in a different process, with the same executable - # (so it should work properly with virtualenvs) - with tempfile.NamedTemporaryFile(mode='w+') as handle: - handle.write(string) - handle.flush() - subprocess.check_output([sys.executable, handle.name]) - - if not os.path.exists(filename): - raise RuntimeError('Unable to generate the PNG...') - - with open(filename, 'rb', encoding=None) as handle: - imgdata = handle.read() - os.remove(filename) - - return imgdata, {} - - @staticmethod - def _get_mpl_body_template(paths): - """ - :param paths: paths of k-points - """ - if len(paths) == 1: - s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=SINGLE_KP) - else: - s_body = MATPLOTLIB_BODY_TEMPLATE.substitute(plot_code=MULTI_KP) - return s_body - - def show_mpl(self, **kwargs): - """ - Call a show() command for the band structure using matplotlib. - This uses internally the 'mpl_singlefile' format, with empty - main_file_name. - - Other kwargs are passed to self._exportcontent. - """ - exec(*self._exportcontent(fileformat='mpl_singlefile', main_file_name='', **kwargs)) # pylint: disable=exec-used - - def _prepare_gnuplot( - self, - main_file_name=None, - title='', - comments=True, - prettify_format=None, - y_max_lim=None, - y_min_lim=None, - y_origin=0. - ): - """ - Prepare an gnuplot script to plot the bands, with the .dat file - returned as an independent file. - - :param main_file_name: if the user asks to write the main content on a - file, this contains the filename. This should be used to infer a - good filename for the additional files. - In this case, we remove the extension, and add '_data.dat' - :param title: if specified, add a title to the plot - :param comments: if True, print comments (if it makes sense for the given - format) - :param prettify_format: if None, use the default prettify format. Otherwise - specify a string with the prettifier to use. - """ - # pylint: disable=too-many-arguments,too-many-locals - import os - - main_file_name = main_file_name or 'band.dat' - dat_filename = os.path.splitext(main_file_name)[0] + '_data.dat' - - if prettify_format is None: - # Default. 
Specified like this to allow caller functions to pass 'None' - prettify_format = 'gnuplot_seekpath' - - plot_info = self._get_bandplot_data( - cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin - ) - - bands = plot_info['y'] - x = plot_info['x'] - - # axis limits - if y_max_lim is None: - y_max_lim = bands.max() - if y_min_lim is None: - y_min_lim = bands.min() - x_min_lim = min(x) # this isn't a numpy array, but a list - x_max_lim = max(x) - - # first prepare the xy coordinates of the sets - raw_data, _ = self._prepare_dat_blocks(plot_info, comments=comments) - - xtics_string = ', '.join(f'"{label}" {pos}' for pos, label in plot_info['labels']) - - script = [] - # Start with some useful comments - - if comments: - script.append(prepare_header_comment(self.uuid, plot_info=plot_info, comment_char='# ')) - script.append('') - - script.append( - """## Uncomment the next two lines to write directly to PDF -## Note: You need to have gnuplot installed with pdfcairo support! -#set term pdfcairo -#set output 'out.pdf' - -### Uncomment one of the options below to change font -### For the LaTeX fonts, you can download them from here: -### https://sourceforge.net/projects/cm-unicode/ -### And then install them in your system -## LaTeX Serif font, if installed -#set termopt font "CMU Serif, 12" -## LaTeX Sans Serif font, if installed -#set termopt font "CMU Sans Serif, 12" -## Classical Times New Roman -#set termopt font "Times New Roman, 12" -""" - ) - - # Actual logic - script.append('set termopt enhanced') # Properly deals with e.g. subscripts - script.append('set encoding utf8') # To deal with Greek letters - script.append(f'set xtics ({xtics_string})') - script.append('unset key') - script.append(f'set yrange [{y_min_lim}:{y_max_lim}]') - script.append(f"set ylabel \"Dispersion ({self.units})\"") - - if title: - script.append('set title "{}"'.format(title.replace('"', '\"'))) - - # Plot, escaping filename - if len(x) > 1: - script.append(f'set xrange [{x_min_lim}:{x_max_lim}]') - script.append('set grid xtics lt 1 lc rgb "#888888"') - script.append('plot "{}" with l lc rgb "#000000"'.format(os.path.basename(dat_filename).replace('"', '\"'))) - else: - script.append('set xrange [-1.0:1.0]') - script.append( - 'plot "{}" using ($1-0.25):($2):(0.5):(0) with vectors nohead lc rgb "#000000"'.format( - os.path.basename(dat_filename).replace('"', '\"') - ) - ) - - script_data = '\n'.join(script) + '\n' - extra_files = {dat_filename: raw_data} - - return script_data.encode('utf-8'), extra_files - - def _prepare_agr( - self, - main_file_name='', - comments=True, - setnumber_offset=0, - color_number=1, - color_number2=2, - legend='', - title='', - y_max_lim=None, - y_min_lim=None, - y_origin=0., - prettify_format=None - ): - """ - Prepare an xmgrace agr file. - - :param comments: if True, print comments - (if it makes sense for the given format) - :param plot_info: a dictionary - :param setnumber_offset: an offset to be applied to all set numbers - (i.e. s0 is replaced by s[offset], s1 by s[offset+1], etc.) 
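
Since `_prepare_gnuplot` returns the script together with a companion data file, writing both to disk looks roughly like the sketch below. File names are purely illustrative; the `_data.dat` suffix follows the convention described in the docstring above.

```python
# Sketch, assuming a populated BandsData node `bands_node`.
script, extra_files = bands_node._exportcontent(fileformat='gnuplot', main_file_name='band.gnu')

with open('band.gnu', 'wb') as handle:
    handle.write(script)

# e.g. {'band_data.dat': b'...'} -- the blocks read by the generated gnuplot script
for name, content in extra_files.items():
    with open(name, 'wb') as handle:
        handle.write(content)
```
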
- :param color_number: the color number for lines, symbols, error bars - and filling (should be less than the parameter MAX_NUM_AGR_COLORS - defined below) - :param color_number2: the color number for lines, symbols, error bars - and filling for the second-type spins (should be less than the - parameter MAX_NUM_AGR_COLORS defined below) - :param legend: the legend (applied only to the first set) - :param title: the title - :param y_max_lim: the maximum on the y axis (if None, put the - maximum of the bands); applied *after* shifting the origin - by ``y_origin`` - :param y_min_lim: the minimum on the y axis (if None, put the - minimum of the bands); applied *after* shifting the origin - by ``y_origin`` - :param y_origin: the new origin of the y axis -> all bands are replaced - by bands-y_origin - :param prettify_format: if None, use the default prettify format. Otherwise - specify a string with the prettifier to use. - """ - # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,unused-argument - if prettify_format is None: - # Default. Specified like this to allow caller functions to pass 'None' - prettify_format = 'agr_seekpath' - - plot_info = self._get_bandplot_data( - cartesian=True, prettify_format=prettify_format, join_symbol='|', y_origin=y_origin - ) - - import math - - # load the x and y of every set - if color_number > MAX_NUM_AGR_COLORS: - raise ValueError(f'Color number is too high (should be less than {MAX_NUM_AGR_COLORS})') - if color_number2 > MAX_NUM_AGR_COLORS: - raise ValueError(f'Color number 2 is too high (should be less than {MAX_NUM_AGR_COLORS})') - - bands = plot_info['y'] - x = plot_info['x'] - the_bands = numpy.transpose(bands) - labels = plot_info['labels'] - num_labels = len(labels) - - # axis limits - if y_max_lim is None: - y_max_lim = the_bands.max() - if y_min_lim is None: - y_min_lim = the_bands.min() - x_min_lim = min(x) # this isn't a numpy array, but a list - x_max_lim = max(x) - ytick_spacing = 10**int(math.log10((y_max_lim - y_min_lim))) - - # prepare xticks labels - sx1 = '' - for i, label in enumerate(labels): - sx1 += AGR_SINGLE_XTICK_TEMPLATE.substitute( - index=i, - coord=label[0], - name=label[1], - ) - xticks = AGR_XTICKS_TEMPLATE.substitute( - num_labels=num_labels, - single_xtick_templates=sx1, - ) - - # build the arrays with the xy coordinates - all_sets = [] - for band in the_bands: - this_set = '' - for i in zip(x, band): - line = f'{i[0]:.8f}' + '\t' + f'{i[1]:.8f}' + '\n' - this_set += line - all_sets.append(this_set) - - set_descriptions = '' - for i, (this_set, band_type) in enumerate(zip(all_sets, plot_info['band_type_idx'])): - if band_type % 2 == 0: - linecolor = color_number - else: - linecolor = color_number2 - width = str(2.0) - set_descriptions += AGR_SET_DESCRIPTION_TEMPLATE.substitute( - set_number=i + setnumber_offset, - linewidth=width, - color_number=linecolor, - legend=legend if i == 0 else '' - ) - - units = self.units - - graphs = AGR_GRAPH_TEMPLATE.substitute( - x_min_lim=x_min_lim, - y_min_lim=y_min_lim, - x_max_lim=x_max_lim, - y_max_lim=y_max_lim, - yaxislabel=f'Dispersion ({units})', - xticks_template=xticks, - set_descriptions=set_descriptions, - ytick_spacing=ytick_spacing, - title=title, - ) - sets = [] - for i, this_set in enumerate(all_sets): - sets.append(AGR_SINGLESET_TEMPLATE.substitute(set_number=i + setnumber_offset, xydata=this_set)) - the_sets = '&\n'.join(sets) - - string = AGR_TEMPLATE.substitute(graphs=graphs, sets=the_sets) - - if comments: - string = prepare_header_comment(self.uuid, 
plot_info, comment_char='#') + '\n' + string - - return string.encode('utf-8'), {} - - def _get_band_segments(self, cartesian): - """Return the band segments.""" - plot_info = self._get_bandplot_data( - cartesian=cartesian, prettify_format=None, join_symbol=None, get_segments=True - ) - - out_dict = {'label': self.label} - - out_dict['path'] = plot_info['path'] - out_dict['paths'] = plot_info['paths'] - - return out_dict - - def _prepare_json(self, main_file_name='', comments=True): # pylint: disable=unused-argument - """ - Prepare a json file in a format compatible with the AiiDA band visualizer - - :param comments: if True, print comments (if it makes sense for the given - format) - """ - from aiida import get_file_header - - json_dict = self._get_band_segments(cartesian=True) - json_dict['original_uuid'] = self.uuid - - if comments: - json_dict['comments'] = get_file_header(comment_char='') - - return json.dumps(json_dict).encode('utf-8'), {} - - -MAX_NUM_AGR_COLORS = 15 - -AGR_TEMPLATE = Template( - """ - # Grace project file - # - @version 50122 - @page size 792, 612 - @page scroll 5% - @page inout 5% - @link page off - @map font 8 to "Courier", "Courier" - @map font 10 to "Courier-Bold", "Courier-Bold" - @map font 11 to "Courier-BoldOblique", "Courier-BoldOblique" - @map font 9 to "Courier-Oblique", "Courier-Oblique" - @map font 4 to "Helvetica", "Helvetica" - @map font 6 to "Helvetica-Bold", "Helvetica-Bold" - @map font 7 to "Helvetica-BoldOblique", "Helvetica-BoldOblique" - @map font 5 to "Helvetica-Oblique", "Helvetica-Oblique" - @map font 14 to "NimbusMonoL-BoldOblique", "NimbusMonoL-BoldOblique" - @map font 15 to "NimbusMonoL-Regular", "NimbusMonoL-Regular" - @map font 16 to "NimbusMonoL-RegularOblique", "NimbusMonoL-RegularOblique" - @map font 17 to "NimbusRomanNo9L-Medium", "NimbusRomanNo9L-Medium" - @map font 18 to "NimbusRomanNo9L-MediumItalic", "NimbusRomanNo9L-MediumItalic" - @map font 19 to "NimbusRomanNo9L-Regular", "NimbusRomanNo9L-Regular" - @map font 20 to "NimbusRomanNo9L-RegularItalic", "NimbusRomanNo9L-RegularItalic" - @map font 21 to "NimbusSansL-Bold", "NimbusSansL-Bold" - @map font 22 to "NimbusSansL-BoldCondensed", "NimbusSansL-BoldCondensed" - @map font 23 to "NimbusSansL-BoldCondensedItalic", "NimbusSansL-BoldCondensedItalic" - @map font 24 to "NimbusSansL-BoldItalic", "NimbusSansL-BoldItalic" - @map font 25 to "NimbusSansL-Regular", "NimbusSansL-Regular" - @map font 26 to "NimbusSansL-RegularCondensed", "NimbusSansL-RegularCondensed" - @map font 27 to "NimbusSansL-RegularCondensedItalic", "NimbusSansL-RegularCondensedItalic" - @map font 28 to "NimbusSansL-RegularItalic", "NimbusSansL-RegularItalic" - @map font 29 to "StandardSymbolsL-Regular", "StandardSymbolsL-Regular" - @map font 12 to "Symbol", "Symbol" - @map font 31 to "Symbol-Regular", "Symbol-Regular" - @map font 2 to "Times-Bold", "Times-Bold" - @map font 3 to "Times-BoldItalic", "Times-BoldItalic" - @map font 1 to "Times-Italic", "Times-Italic" - @map font 0 to "Times-Roman", "Times-Roman" - @map font 36 to "URWBookmanL-DemiBold", "URWBookmanL-DemiBold" - @map font 37 to "URWBookmanL-DemiBoldItalic", "URWBookmanL-DemiBoldItalic" - @map font 38 to "URWBookmanL-Light", "URWBookmanL-Light" - @map font 39 to "URWBookmanL-LightItalic", "URWBookmanL-LightItalic" - @map font 40 to "URWChanceryL-MediumItalic", "URWChanceryL-MediumItalic" - @map font 41 to "URWGothicL-Book", "URWGothicL-Book" - @map font 42 to "URWGothicL-BookOblique", "URWGothicL-BookOblique" - @map font 43 to "URWGothicL-Demi", 
"URWGothicL-Demi" - @map font 44 to "URWGothicL-DemiOblique", "URWGothicL-DemiOblique" - @map font 45 to "URWPalladioL-Bold", "URWPalladioL-Bold" - @map font 46 to "URWPalladioL-BoldItalic", "URWPalladioL-BoldItalic" - @map font 47 to "URWPalladioL-Italic", "URWPalladioL-Italic" - @map font 48 to "URWPalladioL-Roman", "URWPalladioL-Roman" - @map font 13 to "ZapfDingbats", "ZapfDingbats" - @map color 0 to (255, 255, 255), "white" - @map color 1 to (0, 0, 0), "black" - @map color 2 to (255, 0, 0), "red" - @map color 3 to (0, 255, 0), "green" - @map color 4 to (0, 0, 255), "blue" - @map color 5 to (255, 215, 0), "yellow" - @map color 6 to (188, 143, 143), "brown" - @map color 7 to (220, 220, 220), "grey" - @map color 8 to (148, 0, 211), "violet" - @map color 9 to (0, 255, 255), "cyan" - @map color 10 to (255, 0, 255), "magenta" - @map color 11 to (255, 165, 0), "orange" - @map color 12 to (114, 33, 188), "indigo" - @map color 13 to (103, 7, 72), "maroon" - @map color 14 to (64, 224, 208), "turquoise" - @map color 15 to (0, 139, 0), "green4" - @reference date 0 - @date wrap off - @date wrap year 1950 - @default linewidth 1.0 - @default linestyle 1 - @default color 1 - @default pattern 1 - @default font 0 - @default char size 1.000000 - @default symbol size 1.000000 - @default sformat "%.8g" - @background color 0 - @page background fill on - @timestamp off - @timestamp 0.03, 0.03 - @timestamp color 1 - @timestamp rot 0 - @timestamp font 0 - @timestamp char size 1.000000 - @timestamp def "Wed Jul 30 16:44:34 2014" - @r0 off - @link r0 to g0 - @r0 type above - @r0 linestyle 1 - @r0 linewidth 1.0 - @r0 color 1 - @r0 line 0, 0, 0, 0 - @r1 off - @link r1 to g0 - @r1 type above - @r1 linestyle 1 - @r1 linewidth 1.0 - @r1 color 1 - @r1 line 0, 0, 0, 0 - @r2 off - @link r2 to g0 - @r2 type above - @r2 linestyle 1 - @r2 linewidth 1.0 - @r2 color 1 - @r2 line 0, 0, 0, 0 - @r3 off - @link r3 to g0 - @r3 type above - @r3 linestyle 1 - @r3 linewidth 1.0 - @r3 color 1 - @r3 line 0, 0, 0, 0 - @r4 off - @link r4 to g0 - @r4 type above - @r4 linestyle 1 - @r4 linewidth 1.0 - @r4 color 1 - @r4 line 0, 0, 0, 0 - $graphs - $sets - """ -) - -AGR_XTICKS_TEMPLATE = Template(""" - @ xaxis tick spec $num_labels - $single_xtick_templates - """) - -AGR_SINGLE_XTICK_TEMPLATE = Template( - """ - @ xaxis tick major $index, $coord - @ xaxis ticklabel $index, "$name" - """ -) - -AGR_GRAPH_TEMPLATE = Template( - """ - @g0 on - @g0 hidden false - @g0 type XY - @g0 stacked false - @g0 bar hgap 0.000000 - @g0 fixedpoint off - @g0 fixedpoint type 0 - @g0 fixedpoint xy 0.000000, 0.000000 - @g0 fixedpoint format general general - @g0 fixedpoint prec 6, 6 - @with g0 - @ world $x_min_lim, $y_min_lim, $x_max_lim, $y_max_lim - @ stack world 0, 0, 0, 0 - @ znorm 1 - @ view 0.150000, 0.150000, 1.150000, 0.850000 - @ title "$title" - @ title font 0 - @ title size 1.500000 - @ title color 1 - @ subtitle "" - @ subtitle font 0 - @ subtitle size 1.000000 - @ subtitle color 1 - @ xaxes scale Normal - @ yaxes scale Normal - @ xaxes invert off - @ yaxes invert off - @ xaxis on - @ xaxis type zero false - @ xaxis offset 0.000000 , 0.000000 - @ xaxis bar on - @ xaxis bar color 1 - @ xaxis bar linestyle 1 - @ xaxis bar linewidth 1.0 - @ xaxis label "" - @ xaxis label layout para - @ xaxis label place auto - @ xaxis label char size 1.000000 - @ xaxis label font 4 - @ xaxis label color 1 - @ xaxis label place normal - @ xaxis tick on - @ xaxis tick major 5 - @ xaxis tick minor ticks 0 - @ xaxis tick default 6 - @ xaxis tick place rounded true - @ 
xaxis tick in - @ xaxis tick major size 1.000000 - @ xaxis tick major color 1 - @ xaxis tick major linewidth 1.0 - @ xaxis tick major linestyle 1 - @ xaxis tick major grid on - @ xaxis tick minor color 1 - @ xaxis tick minor linewidth 1.0 - @ xaxis tick minor linestyle 1 - @ xaxis tick minor grid off - @ xaxis tick minor size 0.500000 - @ xaxis ticklabel on - @ xaxis ticklabel format general - @ xaxis ticklabel prec 5 - @ xaxis ticklabel formula "" - @ xaxis ticklabel append "" - @ xaxis ticklabel prepend "" - @ xaxis ticklabel angle 0 - @ xaxis ticklabel skip 0 - @ xaxis ticklabel stagger 0 - @ xaxis ticklabel place normal - @ xaxis ticklabel offset auto - @ xaxis ticklabel offset 0.000000 , 0.010000 - @ xaxis ticklabel start type auto - @ xaxis ticklabel start 0.000000 - @ xaxis ticklabel stop type auto - @ xaxis ticklabel stop 0.000000 - @ xaxis ticklabel char size 1.500000 - @ xaxis ticklabel font 4 - @ xaxis ticklabel color 1 - @ xaxis tick place both - @ xaxis tick spec type both - $xticks_template - @ yaxis on - @ yaxis type zero false - @ yaxis offset 0.000000 , 0.000000 - @ yaxis bar on - @ yaxis bar color 1 - @ yaxis bar linestyle 1 - @ yaxis bar linewidth 1.0 - @ yaxis label "$yaxislabel" - @ yaxis label layout para - @ yaxis label place auto - @ yaxis label char size 1.500000 - @ yaxis label font 4 - @ yaxis label color 1 - @ yaxis label place normal - @ yaxis tick on - @ yaxis tick major $ytick_spacing - @ yaxis tick minor ticks 1 - @ yaxis tick default 6 - @ yaxis tick place rounded true - @ yaxis tick in - @ yaxis tick major size 1.000000 - @ yaxis tick major color 1 - @ yaxis tick major linewidth 1.0 - @ yaxis tick major linestyle 1 - @ yaxis tick major grid off - @ yaxis tick minor color 1 - @ yaxis tick minor linewidth 1.0 - @ yaxis tick minor linestyle 1 - @ yaxis tick minor grid off - @ yaxis tick minor size 0.500000 - @ yaxis ticklabel on - @ yaxis ticklabel format general - @ yaxis ticklabel prec 5 - @ yaxis ticklabel formula "" - @ yaxis ticklabel append "" - @ yaxis ticklabel prepend "" - @ yaxis ticklabel angle 0 - @ yaxis ticklabel skip 0 - @ yaxis ticklabel stagger 0 - @ yaxis ticklabel place normal - @ yaxis ticklabel offset auto - @ yaxis ticklabel offset 0.000000 , 0.010000 - @ yaxis ticklabel start type auto - @ yaxis ticklabel start 0.000000 - @ yaxis ticklabel stop type auto - @ yaxis ticklabel stop 0.000000 - @ yaxis ticklabel char size 1.250000 - @ yaxis ticklabel font 4 - @ yaxis ticklabel color 1 - @ yaxis tick place both - @ yaxis tick spec type none - @ altxaxis off - @ altyaxis off - @ legend on - @ legend loctype view - @ legend 0.85, 0.8 - @ legend box color 1 - @ legend box pattern 1 - @ legend box linewidth 1.0 - @ legend box linestyle 1 - @ legend box fill color 0 - @ legend box fill pattern 1 - @ legend font 0 - @ legend char size 1.000000 - @ legend color 1 - @ legend length 4 - @ legend vgap 1 - @ legend hgap 1 - @ legend invert false - @ frame type 0 - @ frame linestyle 1 - @ frame linewidth 1.0 - @ frame color 1 - @ frame pattern 1 - @ frame background color 0 - @ frame background pattern 0 - $set_descriptions - """ -) - -AGR_SET_DESCRIPTION_TEMPLATE = Template( - """ - @ s$set_number hidden false - @ s$set_number type xy - @ s$set_number symbol 0 - @ s$set_number symbol size 1.000000 - @ s$set_number symbol color $color_number - @ s$set_number symbol pattern 1 - @ s$set_number symbol fill color $color_number - @ s$set_number symbol fill pattern 0 - @ s$set_number symbol linewidth 1.0 - @ s$set_number symbol linestyle 1 - @ s$set_number 
symbol char 65 - @ s$set_number symbol char font 0 - @ s$set_number symbol skip 0 - @ s$set_number line type 1 - @ s$set_number line linestyle 1 - @ s$set_number line linewidth $linewidth - @ s$set_number line color $color_number - @ s$set_number line pattern 1 - @ s$set_number baseline type 0 - @ s$set_number baseline off - @ s$set_number dropline off - @ s$set_number fill type 0 - @ s$set_number fill rule 0 - @ s$set_number fill color $color_number - @ s$set_number fill pattern 1 - @ s$set_number avalue off - @ s$set_number avalue type 2 - @ s$set_number avalue char size 1.000000 - @ s$set_number avalue font 0 - @ s$set_number avalue color 1 - @ s$set_number avalue rot 0 - @ s$set_number avalue format general - @ s$set_number avalue prec 3 - @ s$set_number avalue prepend "" - @ s$set_number avalue append "" - @ s$set_number avalue offset 0.000000 , 0.000000 - @ s$set_number errorbar on - @ s$set_number errorbar place both - @ s$set_number errorbar color $color_number - @ s$set_number errorbar pattern 1 - @ s$set_number errorbar size 1.000000 - @ s$set_number errorbar linewidth 1.0 - @ s$set_number errorbar linestyle 1 - @ s$set_number errorbar riser linewidth 1.0 - @ s$set_number errorbar riser linestyle 1 - @ s$set_number errorbar riser clip off - @ s$set_number errorbar riser clip length 0.100000 - @ s$set_number comment "Cols 1:2" - @ s$set_number legend "$legend" - """ -) - -AGR_SINGLESET_TEMPLATE = Template(""" - @target G0.S$set_number - @type xy - $xydata - """) - -MATPLOTLIB_HEADER_AGG_TEMPLATE = Template( - """# -*- coding: utf-8 -*- - -import matplotlib -matplotlib.use('Agg') - -from matplotlib import rc -# Uncomment to change default font -#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) -rc('font', **{'family': 'serif', 'serif': ['Computer Modern', 'CMU Serif', 'Times New Roman', 'DejaVu Serif']}) -# To use proper font for, e.g., Gamma if usetex is set to False -rc('mathtext', fontset='cm') - -rc('text', usetex=True) - -import pylab as pl - -# I use json to make sure the input is sanitized -import json - -print_comment = False -""" -) - -MATPLOTLIB_HEADER_TEMPLATE = Template( - """# -*- coding: utf-8 -*- - -from matplotlib import rc -# Uncomment to change default font -#rc('font',**{'family':'sans-serif','sans-serif':['Helvetica']}) -rc('font', **{'family': 'serif', 'serif': ['Computer Modern', 'CMU Serif', 'Times New Roman', 'DejaVu Serif']}) -# To use proper font for, e.g., Gamma if usetex is set to False -rc('mathtext', fontset='cm') - -rc('text', usetex=True) - -import pylab as pl - -# I use json to make sure the input is sanitized -import json - -print_comment = False -""" -) - -MATPLOTLIB_IMPORT_DATA_INLINE_TEMPLATE = Template('''all_data_str = r"""$all_data_json""" -''') - -MATPLOTLIB_IMPORT_DATA_FROMFILE_TEMPLATE = Template( - """with open("$json_fname", encoding='utf8') as f: - all_data_str = f.read() -""" -) - -MULTI_KP = """ -for path in paths: - if path['length'] <= 1: - # Avoid printing empty lines - continue - x = path['x'] - #for band in bands: - for band, band_type in zip(path['values'], all_data['band_type_idx']): - - # For now we support only two colors - if band_type % 2 == 0: - further_plot_options = further_plot_options1 - else: - further_plot_options = further_plot_options2 - - # Put the legend text only once - label = None - if first_band_1 and band_type % 2 == 0: - first_band_1 = False - label = all_data.get('legend_text', None) - elif first_band_2 and band_type % 2 == 1: - first_band_2 = False - label = all_data.get('legend_text2', 
None) - - p.plot(x, band, label=label, - **further_plot_options - ) -""" - -SINGLE_KP = """ -path = paths[0] -values = path['values'] -x = [path['x'] for _ in values] -p.scatter(x, values, marker="_") -""" - -MATPLOTLIB_BODY_TEMPLATE = Template( - """all_data = json.loads(all_data_str) - -if not all_data.get('use_latex', False): - rc('text', usetex=False) - -#x = all_data['x'] -#bands = all_data['bands'] -paths = all_data['paths'] -tick_pos = all_data['tick_pos'] -tick_labels = all_data['tick_labels'] - -# Option for bands (all, or those of type 1 if there are two spins) -further_plot_options1 = {} -further_plot_options1['color'] = all_data.get('bands_color', 'k') -further_plot_options1['linewidth'] = all_data.get('bands_linewidth', 0.5) -further_plot_options1['linestyle'] = all_data.get('bands_linestyle', None) -further_plot_options1['marker'] = all_data.get('bands_marker', None) -further_plot_options1['markersize'] = all_data.get('bands_markersize', None) -further_plot_options1['markeredgecolor'] = all_data.get('bands_markeredgecolor', None) -further_plot_options1['markeredgewidth'] = all_data.get('bands_markeredgewidth', None) -further_plot_options1['markerfacecolor'] = all_data.get('bands_markerfacecolor', None) - -# Options for second-type of bands if present (e.g. spin up vs. spin down) -further_plot_options2 = {} -further_plot_options2['color'] = all_data.get('bands_color2', 'r') -# Use the values of further_plot_options1 by default -further_plot_options2['linewidth'] = all_data.get('bands_linewidth2', - further_plot_options1['linewidth'] - ) -further_plot_options2['linestyle'] = all_data.get('bands_linestyle2', - further_plot_options1['linestyle'] - ) -further_plot_options2['marker'] = all_data.get('bands_marker2', - further_plot_options1['marker'] - ) -further_plot_options2['markersize'] = all_data.get('bands_markersize2', - further_plot_options1['markersize'] - ) -further_plot_options2['markeredgecolor'] = all_data.get('bands_markeredgecolor2', - further_plot_options1['markeredgecolor'] - ) -further_plot_options2['markeredgewidth'] = all_data.get('bands_markeredgewidth2', - further_plot_options1['markeredgewidth'] - ) -further_plot_options2['markerfacecolor'] = all_data.get('bands_markerfacecolor2', - further_plot_options1['markerfacecolor'] - ) - -fig = pl.figure() -p = fig.add_subplot(1,1,1) - -first_band_1 = True -first_band_2 = True - -${plot_code} - -p.set_xticks(tick_pos) -p.set_xticklabels(tick_labels) -p.set_xlim([all_data['x_min_lim'], all_data['x_max_lim']]) -p.set_ylim([all_data['y_min_lim'], all_data['y_max_lim']]) -p.xaxis.grid(True, which='major', color='#888888', linestyle='-', linewidth=0.5) - -if all_data.get('plot_zero_axis', False): - p.axhline( - 0., - color=all_data.get('zero_axis_color', '#888888'), - linestyle=all_data.get('zero_axis_linestyle', '--'), - linewidth=all_data.get('zero_axis_linewidth', 0.5), - ) -if all_data['title']: - p.set_title(all_data['title']) -if all_data['legend_text']: - p.legend(loc='best') -p.set_ylabel(all_data['yaxis_label']) - -try: - if print_comment: - print(all_data['comment']) -except KeyError: - pass -""" -) - -MATPLOTLIB_FOOTER_TEMPLATE_SHOW = Template("""pl.show()""") - -MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE = Template("""pl.savefig("$fname", format="$format")""") - -MATPLOTLIB_FOOTER_TEMPLATE_EXPORTFILE_WITH_DPI = Template("""pl.savefig("$fname", format="$format", dpi=$dpi)""") - - -def get_bands_and_parents_structure(args, backend=None): - """Search for bands and return bands and the closest structure that is a parent 
of the instance. - - :returns: - A list of sublists, each latter containing (in order): - pk as string, formula as string, creation date, bandsdata-label - """ - # pylint: disable=too-many-locals,too-many-branches - - import datetime - - from aiida import orm - from aiida.common import timezone - - q_build = orm.QueryBuilder(backend=backend) - if args.all_users is False: - q_build.append(orm.User, tag='creator', filters={'email': orm.User.collection.get_default().email}) - else: - q_build.append(orm.User, tag='creator') - - group_filters = {} - with_args = {} - - if args.group_name is not None: - group_filters.update({'label': {'in': args.group_name}}) - if args.group_pk is not None: - group_filters.update({'id': {'in': args.group_pk}}) - - if group_filters: - q_build.append(orm.Group, tag='group', filters=group_filters, with_user='creator') - with_args = {'with_group': 'group'} - else: - # Note: This is a workaround for the QB constraint of not allowing multiple ``with_*`` criteria. Correctly we - # would like to specify with_user always on the ``BandsData`` directly and optionally add with_group. Until this - # is resolved, add the ``with_user`` on the group if specified and on the ``BandsData`` if not. - with_args = {'with_user': 'creator'} - - bdata_filters = {} - if args.past_days is not None: - bdata_filters.update({'ctime': {'>=': timezone.now() - datetime.timedelta(days=args.past_days)}}) - - q_build.append(orm.BandsData, tag='bdata', filters=bdata_filters, project=['id', 'label', 'ctime'], **with_args) - bands_list_data = q_build.all() - - q_build.append( - orm.StructureData, - tag='sdata', - with_descendants='bdata', - # We don't care about the creator of StructureData - project=['id', 'attributes.kinds', 'attributes.sites'] - ) - - q_build.order_by({orm.StructureData: {'ctime': 'desc'}}) - - structure_dict = {} - list_data = q_build.distinct().all() - for bid, _, _, _, akinds, asites in list_data: - structure_dict[bid] = (akinds, asites) - - entry_list = [] - already_visited_bdata = set() - - for [bid, blabel, bdate] in bands_list_data: - - # We process only one StructureData per BandsData. - # We want to process the closest StructureData to - # every BandsData. - # We hope that the StructureData with the latest - # creation time is the closest one. - # This will be updated when the QueryBuilder supports - # order_by by the distance of two nodes. - if already_visited_bdata.__contains__(bid): - continue - already_visited_bdata.add(bid) - strct = structure_dict.get(bid, None) - - if strct is not None: - akinds, asites = strct - formula = _extract_formula(akinds, asites, args) - else: - if args.element is not None or args.element_only is not None: - formula = None - else: - formula = '<>' - - if formula is None: - continue - entry_list.append([str(bid), str(formula), bdate.strftime('%d %b %Y'), blabel]) - - return entry_list - - -def _extract_formula(akinds, asites, args): - """ - Extract formula from the structure object. - - :param akinds: list of kinds, e.g. [{'mass': 55.845, 'name': 'Fe', 'symbols': ['Fe'], 'weights': [1.0]}, - {'mass': 15.9994, 'name': 'O', 'symbols': ['O'], 'weights': [1.0]}] - :param asites: list of structure sites e.g. 
[{'position': [0.0, 0.0, 0.0], 'kind_name': 'Fe'}, - {'position': [2.0, 2.0, 2.0], 'kind_name': 'O'}] - :param args: a namespace with parsed command line parameters, here only 'element' and 'element_only' are used - :type args: dict - - :return: a string with formula if the formula is found - """ - from aiida.orm.nodes.data.structure import get_formula, get_symbols_string - - if args.element is not None: - all_symbols = [_['symbols'][0] for _ in akinds] - if not any(s in args.element for s in all_symbols): - return None - - if args.element_only is not None: - all_symbols = [_['symbols'][0] for _ in akinds] - if not all(s in all_symbols for s in args.element_only): - return None - - # We want only the StructureData that have attributes - if akinds is None or asites is None: - return '<>' - - symbol_dict = {} - for k in akinds: - symbols = k['symbols'] - weights = k['weights'] - symbol_dict[k['name']] = get_symbols_string(symbols, weights) - - try: - symbol_list = [] - for site in asites: - symbol_list.append(symbol_dict[site['kind_name']]) - formula = get_formula(symbol_list, mode=args.formulamode) - # If for some reason there is no kind with the name - # referenced by the site - except KeyError: - formula = '<>' - return formula diff --git a/aiida/orm/nodes/data/array/kpoints.py b/aiida/orm/nodes/data/array/kpoints.py deleted file mode 100644 index 1f7432d8f2..0000000000 --- a/aiida/orm/nodes/data/array/kpoints.py +++ /dev/null @@ -1,506 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Module of the KpointsData class, defining the AiiDA data type for storing -lists and meshes of k-points (i.e., points in the reciprocal space of a -periodic crystal structure). -""" -import numpy - -from .array import ArrayData - -__all__ = ('KpointsData',) - -_DEFAULT_EPSILON_LENGTH = 1e-5 -_DEFAULT_EPSILON_ANGLE = 1e-5 - - -class KpointsData(ArrayData): - """ - Class to handle array of kpoints in the Brillouin zone. Provide methods to - generate either user-defined k-points or path of k-points along symmetry - lines. - Internally, all k-points are defined in terms of crystal (fractional) - coordinates. - Cell and lattice vector coordinates are in Angstroms, reciprocal lattice - vectors in Angstrom^-1 . - :note: The methods setting and using the Bravais lattice info assume the - PRIMITIVE unit cell is provided in input to the set_cell or - set_cell_from_structure methods. - """ - - def get_description(self): - """ - Returns a string with infos retrieved from kpoints node's properties. - :param node: - :return: retstr - """ - try: - mesh = self.get_kpoints_mesh() - return 'Kpoints mesh: {}x{}x{} (+{:.1f},{:.1f},{:.1f})'.format( - mesh[0][0], mesh[0][1], mesh[0][2], mesh[1][0], mesh[1][1], mesh[1][2] - ) - except AttributeError: - try: - return f'(Path of {len(self.get_kpoints())} kpts)' - except OSError: - return self.node_type - - @property - def cell(self): - """ - The crystal unit cell. Rows are the crystal vectors in Angstroms. 
- :return: a 3x3 numpy.array - """ - return numpy.array(self.base.attributes.get('cell')) - - @cell.setter - def cell(self, value): - """ - Set the crystal unit cell - :param value: a 3x3 list/tuple/array of numbers (units = Angstroms). - """ - self._set_cell(value) - - def _set_cell(self, value): - """ - Validate if 'value' is a allowed crystal unit cell - :param value: something compatible with a 3x3 tuple of floats - """ - from aiida.common.exceptions import ModificationNotAllowed - from aiida.orm.nodes.data.structure import _get_valid_cell - - if self.is_stored: - raise ModificationNotAllowed('KpointsData cannot be modified, it has already been stored') - - the_cell = _get_valid_cell(value) - - self.base.attributes.set('cell', the_cell) - - @property - def pbc(self): - """ - The periodic boundary conditions along the vectors a1,a2,a3. - - :return: a tuple of three booleans, each one tells if there are periodic - boundary conditions for the i-th real-space direction (i=1,2,3) - """ - # return copy.deepcopy(self._pbc) - return (self.base.attributes.get('pbc1'), self.base.attributes.get('pbc2'), self.base.attributes.get('pbc3')) - - @pbc.setter - def pbc(self, value): - """ - Set the value of pbc, i.e. a tuple of three booleans, indicating if the - cell is periodic in the 1,2,3 crystal direction - """ - self._set_pbc(value) - - def _set_pbc(self, value): - """ - validate the pbc, then store them - """ - from aiida.common.exceptions import ModificationNotAllowed - from aiida.orm.nodes.data.structure import get_valid_pbc - - if self.is_stored: - raise ModificationNotAllowed('The KpointsData object cannot be modified, it has already been stored') - the_pbc = get_valid_pbc(value) - self.base.attributes.set('pbc1', the_pbc[0]) - self.base.attributes.set('pbc2', the_pbc[1]) - self.base.attributes.set('pbc3', the_pbc[2]) - - @property - def labels(self): - """ - Labels associated with the list of kpoints. - List of tuples with kpoint index and kpoint name: ``[(0,'G'),(13,'M'),...]`` - """ - label_numbers = self.base.attributes.get('label_numbers', None) - labels = self.base.attributes.get('labels', None) - if labels is None or label_numbers is None: - return None - return list(zip(label_numbers, labels)) - - @labels.setter - def labels(self, value): - self._set_labels(value) - - def _set_labels(self, value): - """ - set label names. Must pass in input a list like: ``[[0,'X'],[34,'L'],... ]`` - """ - # check if kpoints were set - try: - self.get_kpoints() - except AttributeError: - raise AttributeError('Kpoints must be set before the labels') - - if value is None: - value = [] - - try: - label_numbers = [int(i[0]) for i in value] - except ValueError: - raise ValueError('The input must contain an integer index, to map the labels into the kpoint list') - labels = [str(i[1]) for i in value] - - if any(i > len(self.get_kpoints()) - 1 for i in label_numbers): - raise ValueError('Index of label exceeding the list of kpoints') - - self.base.attributes.set('label_numbers', label_numbers) - self.base.attributes.set('labels', labels) - - def _change_reference(self, kpoints, to_cartesian=True): - """ - Change reference system, from cartesian to crystal coordinates (units of b1,b2,b3) or viceversa. 
- :param kpoints: a list of (3) point coordinates - :return kpoints: a list of (3) point coordinates in the new reference - """ - if not isinstance(kpoints, numpy.ndarray): - raise ValueError('kpoints must be a numpy.array for method change_reference()') - - try: - rec_cell = self.reciprocal_cell - except AttributeError: - # rec_cell = numpy.eye(3) - raise AttributeError('Cannot use cartesian coordinates without having defined a cell') - - trec_cell = numpy.transpose(numpy.array(rec_cell)) - if to_cartesian: - matrix = trec_cell - else: - matrix = numpy.linalg.inv(trec_cell) - - # note: kpoints is a list Nx3, matrix is 3x3. - # hence, first transpose kpoints, then multiply, finally transpose it back - return numpy.transpose(numpy.dot(matrix, numpy.transpose(kpoints))) - - def set_cell_from_structure(self, structuredata): - """ - Set a cell to be used for symmetry analysis from an AiiDA structure. - Inherits both the cell and the pbc's. - To set manually a cell, use "set_cell" - - :param structuredata: an instance of StructureData - """ - from aiida.orm import StructureData - - if not isinstance(structuredata, StructureData): - raise ValueError( - 'An instance of StructureData should be passed to ' - 'the KpointsData, found instead {}'.format(structuredata.__class__) - ) - cell = structuredata.cell - self.set_cell(cell, structuredata.pbc) - - def set_cell(self, cell, pbc=None): - """ - Set a cell to be used for symmetry analysis. - To set a cell from an AiiDA structure, use "set_cell_from_structure". - - :param cell: 3x3 matrix of cell vectors. Orientation: each row - represent a lattice vector. Units are Angstroms. - :param pbc: list of 3 booleans, True if in the nth crystal direction the - structure is periodic. Default = [True,True,True] - """ - self.cell = cell - if pbc is None: - pbc = [True, True, True] - self.pbc = pbc - - @property - def reciprocal_cell(self): - """ - Compute reciprocal cell from the internally set cell. - - :returns: reciprocal cell in units of 1/Angstrom with cell vectors stored as rows. - Use e.g. reciprocal_cell[0] to access the first reciprocal cell vector. - """ - the_cell = numpy.array(self.cell) - reciprocal_cell = 2. * numpy.pi * numpy.linalg.inv(the_cell).transpose() - return reciprocal_cell - - def set_kpoints_mesh(self, mesh, offset=None): - """ - Set KpointsData to represent a uniformily spaced mesh of kpoints in the - Brillouin zone. This excludes the possibility of set/get kpoints - - :param mesh: a list of three integers, representing the size of the - kpoint mesh along b1,b2,b3. - :param offset: (optional) a list of three floats between 0 and 1. - [0.,0.,0.] is Gamma centered mesh - [0.5,0.5,0.5] is half shifted - [1.,1.,1.] by periodicity should be equivalent to [0.,0.,0.] - Default = [0.,0.,0.]. - """ - from aiida.common.exceptions import ModificationNotAllowed - - # validate - try: - the_mesh = [int(i) for i in mesh] - if len(the_mesh) != 3: - raise ValueError - except (IndexError, ValueError, TypeError): - raise ValueError('The kpoint mesh must be a list of three integers') - if offset is None: - offset = [0., 0., 0.] 
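
Putting the `KpointsData` pieces together, a minimal sketch of defining a regular mesh; `structure` is assumed to be an existing `StructureData` node.

```python
from aiida.orm import KpointsData

kpoints = KpointsData()
kpoints.set_cell_from_structure(structure)                    # inherit cell and pbc
kpoints.set_kpoints_mesh([4, 4, 4], offset=[0.5, 0.5, 0.5])   # half-shifted 4x4x4 mesh

mesh, offset = kpoints.get_kpoints_mesh()   # -> ([4, 4, 4], [0.5, 0.5, 0.5])
rec_cell = kpoints.reciprocal_cell          # reciprocal vectors as rows, in 1/Angstrom
```
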
- try: - the_offset = [float(i) for i in offset] - if len(the_offset) != 3: - raise ValueError - except (IndexError, ValueError, TypeError): - raise ValueError('The offset must be a list of three floats') - # check that there is no list of kpoints saved already - # I cannot have both of them at the same time - try: - _ = self.get_array('kpoints') - raise ModificationNotAllowed('KpointsData has already a kpoint-list stored') - except KeyError: - pass - - # store - self.base.attributes.set('mesh', the_mesh) - self.base.attributes.set('offset', the_offset) - - def get_kpoints_mesh(self, print_list=False): - """ - Get the mesh of kpoints. - - :param print_list: default=False. If True, prints the mesh of kpoints as a list - - :raise AttributeError: if no mesh has been set - :return mesh,offset: (if print_list=False) a list of 3 integers and a list of three - floats 0= self.numsteps: - raise IndexError(f'You have only {self.numsteps} steps, but you are looking beyond (index={index})') - - vel = self.get_velocities() - if vel is not None: - vel = vel[index, :, :] - time = self.get_times() - if time is not None: - time = time[index] - cells = self.get_cells() - if cells is not None: - cell = cells[index, :, :] - else: - cell = None - return (self.get_stepids()[index], time, cell, self.symbols, self.get_positions()[index, :, :], vel) - - def get_step_structure(self, index, custom_kinds=None): - """ - Return an AiiDA :py:class:`aiida.orm.nodes.data.structure.StructureData` node - (not stored yet!) with the coordinates of the given step, identified by - its index. If you know only the step value, use the - :py:meth:`.get_index_from_stepid` method to get the corresponding index. - - .. note:: The periodic boundary conditions are always set to True. - - .. versionadded:: 0.7 - Renamed from step_to_structure - - :param index: The index of the step that you want to retrieve, from - 0 to ``self.numsteps- 1``. - :param custom_kinds: (Optional) If passed must be a list of - :py:class:`aiida.orm.nodes.data.structure.Kind` objects. There must be one - kind object for each different string in the ``symbols`` array, with - ``kind.name`` set to this string. - If this parameter is omitted, the automatic kind generation of AiiDA - :py:class:`aiida.orm.nodes.data.structure.StructureData` nodes is used, - meaning that the strings in the ``symbols`` array must be valid - chemical symbols. - - :return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node. - """ - from aiida.orm.nodes.data.structure import Kind, Site, StructureData - - # ignore step, time, and velocities - _, _, cell, symbols, positions, _ = self.get_step_data(index) - - if custom_kinds is not None: - kind_names = [] - for k in custom_kinds: - if not isinstance(k, Kind): - raise TypeError( - 'Each element of the custom_kinds list must ' - 'be a aiida.orm.nodes.data.structure.Kind object' - ) - kind_names.append(k.name) - if len(kind_names) != len(set(kind_names)): - raise ValueError('Multiple kinds with the same name passed as custom_kinds') - if set(kind_names) != set(symbols): - raise ValueError( - 'If you pass custom_kinds, you have to ' - 'pass one Kind object for each symbol ' - 'that is present in the trajectory. 
You ' - 'passed {}, but the symbols are {}'.format(sorted(kind_names), sorted(symbols)) - ) - - struc = StructureData(cell=cell) - if custom_kinds is not None: - for _k in custom_kinds: - struc.append_kind(_k) - for _s, _p in zip(symbols, positions): - struc.append_site(Site(kind_name=_s, position=_p)) - else: - for _s, _p in zip(symbols, positions): - # Automatic species generation - struc.append_atom(symbols=_s, position=_p) - - return struc - - def _prepare_xsf(self, index=None, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given trajectory to a string of format XSF (for XCrySDen). - """ - from aiida.common.constants import elements - _atomic_numbers = {data['symbol']: num for num, data in elements.items()} - - indices = list(range(self.numsteps)) - if index is not None: - indices = [index] - return_string = f'ANIMSTEPS {len(indices)}\nCRYSTAL\n' - # Do the checks once and for all here: - structure = self.get_step_structure(index=0) - if structure.is_alloy or structure.has_vacancies: - raise NotImplementedError('XSF for alloys or systems with vacancies not implemented.') - cells = self.get_cells() - if cells is None: - raise ValueError('No cell parameters have been supplied for TrajectoryData') - positions = self.get_positions() - symbols = self.symbols - atomic_numbers_list = [_atomic_numbers[s] for s in symbols] - nat = len(symbols) - - for idx in indices: - return_string += f'PRIMVEC {idx + 1}\n' - for cell_vector in cells[idx]: - return_string += ' '.join([f'{i:18.5f}' for i in cell_vector]) - return_string += '\n' - return_string += f'PRIMCOORD {idx + 1}\n' - return_string += f'{nat} 1\n' - for atn, pos in zip(atomic_numbers_list, positions[idx]): - try: - return_string += f'{atn} {pos[0]:18.10f} {pos[1]:18.10f} {pos[2]:18.10f}\n' - except: - print(atn, pos) - raise - return return_string.encode('utf-8'), {} - - def _prepare_cif(self, trajectory_index=None, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given trajectory to a string of format CIF. - """ - from aiida.common.utils import Capturing - from aiida.orm.nodes.data.cif import ase_loops, cif_from_ase, pycifrw_from_cif - - cif = '' - indices = list(range(self.numsteps)) - if trajectory_index is not None: - indices = [trajectory_index] - for idx in indices: - structure = self.get_step_structure(idx) - ciffile = pycifrw_from_cif(cif_from_ase(structure.get_ase()), ase_loops) - with Capturing(): - cif = cif + ciffile.WriteOut() - return cif.encode('utf-8'), {} - - def get_structure(self, store=False, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.structure.StructureData`. - - .. versionadded:: 1.0 - Renamed from _get_aiida_structure - - :param store: If True, intermediate calculation gets stored in the - AiiDA database for record. Default False. - :param index: The index of the step that you want to retrieve, from - 0 to ``self.numsteps- 1``. - :param custom_kinds: (Optional) If passed must be a list of - :py:class:`aiida.orm.nodes.data.structure.Kind` objects. There must be one - kind object for each different string in the ``symbols`` array, with - ``kind.name`` set to this string. - If this parameter is omitted, the automatic kind generation of AiiDA - :py:class:`aiida.orm.nodes.data.structure.StructureData` nodes is used, - meaning that the strings in the ``symbols`` array must be valid - chemical symbols. - :param custom_cell: (Optional) The cell matrix of the structure. 
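To illustrate the step-extraction methods above, here is a hedged sketch. The node `traj` is an assumption: any `TrajectoryData` that already holds step ids, cells, symbols and positions (for example, parsed from a molecular-dynamics run).

```python
# `traj` is assumed to be an existing, populated TrajectoryData node.
last = traj.numsteps - 1

# Raw data of a single frame ...
stepid, time, cell, symbols, positions, velocities = traj.get_step_data(last)

# ... or an (unstored) StructureData built directly from that frame.
structure = traj.get_step_structure(last)
print(structure.get_formula())
```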
- If omitted, the cell will be read from the trajectory, if present, - otherwise the default cell of - :py:class:`aiida.orm.nodes.data.structure.StructureData` will be used. - - :return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node. - """ - from aiida.orm.nodes.data.dict import Dict - from aiida.tools.data.array.trajectory import _get_aiida_structure_inline - - param = Dict(kwargs) - - ret_dict = _get_aiida_structure_inline(trajectory=self, parameters=param, metadata={'store_provenance': store}) # pylint: disable=unexpected-keyword-arg - return ret_dict['structure'] - - def get_cif(self, index=None, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.cif.CifData` - - .. versionadded:: 1.0 - Renamed from _get_cif - """ - struct = self.get_structure(index=index, **kwargs) - cif = struct.get_cif(**kwargs) - return cif - - def _parse_xyz_pos(self, inputstring): - """ - Load positions from a XYZ file. - - .. note:: The steps and symbols must be set manually before calling this - import function as a consistency measure. Even though the symbols - and steps could be extracted from the XYZ file, the data present in - the XYZ file may or may not be correct and the same logic would have - to be present in the XYZ-velocities function. It was therefore - decided not to implement it at all but require it to be set - explicitly. - - Usage:: - - from aiida.orm.nodes.data.array.trajectory import TrajectoryData - - t = TrajectoryData() - # get sites and number of timesteps - t.set_array('steps', arange(ntimesteps)) - t.set_array('symbols', array([site.kind for site in s.sites])) - t.importfile('some-calc/AIIDA-PROJECT-pos-1.xyz', 'xyz_pos') - """ - - from numpy import array - - from aiida.common.exceptions import ValidationError - from aiida.tools.data.structure import xyz_parser_iterator - - numsteps = self.numsteps - if numsteps == 0: - raise ValidationError('steps must be set before importing positional data') - - numsites = self.numsites - if numsites == 0: - raise ValidationError('symbols must be set before importing positional data') - - positions = array( - [[list(position) for _, position in atoms] for _, _, atoms in xyz_parser_iterator(inputstring)] - ) - - if positions.shape != (numsteps, numsites, 3): - raise ValueError( - 'TrajectoryData.positions must have shape (s,n,3), ' - 'with s=number of steps={} and ' - 'n=number of symbols={}'.format(numsteps, numsites) - ) - - self.set_array('positions', positions) - - def _parse_xyz_vel(self, inputstring): - """ - Load velocities from a XYZ file. - - .. note:: The steps and symbols must be set manually before calling this - import function as a consistency measure. 
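Continuing with the same assumed `traj` node, the conversion entry points above can be exercised as follows; both go through conversion calculation functions, so `store=True` would record the conversion in the provenance graph.

```python
# Convert a single frame to other AiiDA data types.
structure = traj.get_structure(index=0, store=False)
cif = traj.get_cif(index=0)
```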
See also comment for - :py:meth:`._parse_xyz_pos` - """ - - from numpy import array - - from aiida.common.exceptions import ValidationError - from aiida.tools.data.structure import xyz_parser_iterator - - numsteps = self.numsteps - if numsteps == 0: - raise ValidationError('steps must be set before importing positional data') - - numsites = self.numsites - if numsites == 0: - raise ValidationError('symbols must be set before importing positional data') - - velocities = array( - [[list(velocity) for _, velocity in atoms] for _, _, atoms in xyz_parser_iterator(inputstring)] - ) - - if velocities.shape != (numsteps, numsites, 3): - raise ValueError( - 'TrajectoryData.positions must have shape (s,n,3), ' - 'with s=number of steps={} and ' - 'n=number of symbols={}'.format(numsteps, numsites) - ) - - self.set_array('velocities', velocities) - - def show_mpl_pos(self, **kwargs): # pylint: disable=too-many-locals - """ - Shows the positions as a function of time, separate for XYZ coordinates - - :param int stepsize: The stepsize for the trajectory, set higher than 1 to - reduce number of points - :param int mintime: Time to start from - :param int maxtime: Maximum time - :param list elements: - A list of atomic symbols that should be displayed. - If not specified, all atoms are displayed. - :param list indices: - A list of indices of that atoms that can be displayed. - If not specified, all atoms of the correct species are displayed. - :param bool dont_block: If True, interpreter is not blocked when figure is displayed. - """ - from ase.data import atomic_numbers - - # Reading the arrays I need: - positions = self.get_positions() - times = self.get_times() - symbols = self.symbols - - # Try to get the units. - try: - positions_unit = self.base.attributes.get('units|positions') - except AttributeError: - positions_unit = 'A' - try: - times_unit = self.base.attributes.get('units|times') - except AttributeError: - times_unit = 'ps' - - # Getting the keyword input - stepsize = kwargs.pop('stepsize', 1) - maxtime = kwargs.pop('maxtime', times[-1]) - mintime = kwargs.pop('mintime', times[0]) - element_list = kwargs.pop('elements', None) - index_list = kwargs.pop('indices', None) - dont_block = kwargs.pop('dont_block', False) - label = kwargs.pop('label', None) or self.label or self.__repr__() - # Choosing the color scheme - - colors = kwargs.pop('colors', 'jmol') - if colors == 'jmol': - from ase.data.colors import jmol_colors as colors - elif colors == 'cpk': - from ase.data.colors import cpk_colors as colors - else: - raise ValueError(f'Unknown color spec {colors}') - if kwargs: - raise ValueError(f'Unrecognized keyword {kwargs.keys()}') - - if element_list is None: - # If not all elements are allowed - allowed_elements = set(symbols) - else: - # A subset of elements are allowed - allowed_elements = set(element_list) - color_dict = {s: colors[atomic_numbers[s]] for s in set(symbols)} - # Here I am trying to find out the atoms to show - if index_list is None: - # If not index_list was provided, I will see if an element_list - # was given to me - indices_to_show = [i for i, sym in enumerate(symbols) if sym in allowed_elements] - else: - indices_to_show = index_list - # I refrain from checking if indices are ok, will crash if not... 
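A small usage sketch for the matplotlib position plot described above, again assuming a populated `traj` node whose `times` array is set (units default to ps) and that matplotlib is installed.

```python
traj.show_mpl_pos(
    stepsize=10,           # plot every 10th frame only
    elements=['O', 'H'],   # restrict the plot to oxygen and hydrogen atoms
    mintime=1.0,
    maxtime=5.0,
    dont_block=True,       # do not block the interpreter on plt.show()
)
```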
- - # The color_list is a list of colors (RGB) that I will - # pass, so the different species give different colors in the plot - color_list = [color_dict[s] for s in symbols] - - # Reducing array size based on stepsize variable - _times = times[::stepsize] - _positions = positions[::stepsize] - - # Calling - plot_positions_XYZ( - _times, - _positions, - indices_to_show, - color_list, - label, - positions_unit, - times_unit, - dont_block, - mintime, - maxtime, - ) - - def show_mpl_heatmap(self, **kwargs): # pylint: disable=invalid-name,too-many-arguments,too-many-locals,too-many-statements,too-many-branches - """ - Show a heatmap of the trajectory with matplotlib. - """ - import numpy as np - from scipy import stats - try: - from mayavi import mlab - except ImportError: - raise ImportError( - 'Unable to import the mayavi package, that is required to' - 'use the plotting feature you requested. ' - 'Please install it first and then call this command again ' - '(note that the installation of mayavi is quite complicated ' - 'and requires that you already installed the python numpy ' - 'package, as well as the vtk package' - ) - from ase.data import atomic_numbers - from ase.data.colors import jmol_colors - - # pylint: disable=invalid-name - - def collapse_into_unit_cell(point, cell): - """ - Applies linear transformation to coordinate system based on crystal - lattice, vectors. The inverse of that inverse transformation matrix with the - point given results in the point being given as a multiples of lattice vectors - Than take the integer of the rows to find how many times you have to shift - the point back""" - invcell = np.matrix(cell).T.I # pylint: disable=no-member - # point in crystal coordinates - points_in_crystal = np.dot(invcell, point).tolist()[0] - #point collapsed into unit cell - points_in_unit_cell = [i % 1 for i in points_in_crystal] - return np.dot(cell.T, points_in_unit_cell).tolist() - - elements = kwargs.pop('elements', None) - mintime = kwargs.pop('mintime', None) - maxtime = kwargs.pop('maxtime', None) - stepsize = kwargs.pop('stepsize', None) or 1 - contours = np.array(kwargs.pop('contours', None) or (0.1, 0.5)) - sampling_stepsize = int(kwargs.pop('sampling_stepsize', None) or 0) - - times = self.get_times() - if mintime is None: - minindex = 0 - else: - minindex = np.argmax(times > mintime) - if maxtime is None: - maxindex = len(times) - else: - maxindex = np.argmin(times < maxtime) - positions = self.get_positions()[minindex:maxindex:stepsize] - - try: - if self.base.attributes.get('units|positions') in ('bohr', 'atomic'): - bohr_to_ang = 0.52917720859 - positions *= bohr_to_ang - except AttributeError: - pass - - symbols = self.symbols - if elements is None: - elements = set(symbols) - - cells = self.get_cells() - if cells is None: - raise ValueError('No cell parameters have been supplied for TrajectoryData') - else: - cell = np.array(cells[0]) - storage_dict = {s: {} for s in elements} - for ele in elements: - storage_dict[ele] = [np.array([]), np.array([]), np.array([])] - for iat, ele in enumerate(symbols): - if ele in elements: - for idim in range(3): - storage_dict[ele][idim] = np.concatenate( - (storage_dict[ele][idim], positions[:, iat, idim].flatten()) - ) - - for ele in elements: - storage_dict[ele] = np.array(storage_dict[ele]).T - storage_dict[ele] = np.array([collapse_into_unit_cell(pos, cell) for pos in storage_dict[ele]]).T - - white = (1, 1, 1) - mlab.figure(bgcolor=white, size=(1080, 720)) - - for i1, a in enumerate(cell): - i2 = (i1 + 1) % 3 - i3 = 
(i1 + 2) % 3 - for b in [np.zeros(3), cell[i2]]: - for c in [np.zeros(3), cell[i3]]: - p1 = b + c - p2 = p1 + a - mlab.plot3d([p1[0], p2[0]], [p1[1], p2[1]], [p1[2], p2[2]], tube_radius=0.1) - - for ele, data in storage_dict.items(): - kde = stats.gaussian_kde(data, bw_method=0.15) - - _x = data[0, :] - _y = data[1, :] - _z = data[2, :] - xmin, ymin, zmin = _x.min(), _y.min(), _z.min() - xmax, ymax, zmax = _x.max(), _y.max(), _z.max() - - _xi, _yi, _zi = np.mgrid[xmin:xmax:60j, ymin:ymax:30j, zmin:zmax:30j] # pylint: disable=invalid-slice-index - coords = np.vstack([item.ravel() for item in [_xi, _yi, _zi]]) - density = kde(coords).reshape(_xi.shape) - - # Plot scatter with mayavi - #~ figure = mlab.figure('DensityPlot') - grid = mlab.pipeline.scalar_field(_xi, _yi, _zi, density) - #~ min = density.min() - maxdens = density.max() - #~ mlab.pipeline.volume(grid, vmin=min, vmax=min + .5*(max-min)) - surf = mlab.pipeline.iso_surface(grid, opacity=0.5, colormap='cool', contours=(maxdens * contours).tolist()) - lut = surf.module_manager.scalar_lut_manager.lut.table.to_array() - - # The lut is a 255x4 array, with the columns representing RGBA - # (red, green, blue, alpha) coded with integers going from 0 to 255. - - # We modify the alpha channel to add a transparency gradient - lut[:, -1] = np.linspace(100, 255, 256) - lut[:, 0:3] = 255 * jmol_colors[atomic_numbers[ele]] - # and finally we put this LUT back in the surface object. We could have - # added any 255*4 array rather than modifying an existing LUT. - surf.module_manager.scalar_lut_manager.lut.table = lut - - if sampling_stepsize > 0: - mlab.points3d( - _x[::sampling_stepsize], - _y[::sampling_stepsize], - _z[::sampling_stepsize], - color=tuple(jmol_colors[atomic_numbers[ele]].tolist()), - scale_mode='none', - scale_factor=0.3, - opacity=0.3 - ) - - mlab.view(azimuth=155, elevation=70, distance='auto') - mlab.show() - - -def plot_positions_XYZ( # pylint: disable=too-many-arguments,too-many-locals,invalid-name - times, - positions, - indices_to_show, - color_list, - label, - positions_unit='A', - times_unit='ps', - dont_block=False, - mintime=None, - maxtime=None, - label_sparsity=10): - """ - Plot with matplotlib the positions of the coordinates of the atoms - over time for a trajectory - - :param times: array of times - :param positions: array of positions - :param indices_to_show: list of indices of to show (0, 1, 2 for X, Y, Z) - :param color_list: list of valid color specifications for matplotlib - :param label: label for this plot to put in the title - :param positions_unit: label for the units of positions (for the x label) - :param times_unit: label for the units of times (for the y label) - :param dont_block: passed to plt.show() as ``block=not dont_block`` - :param mintime: if specified, cut the time axis at the specified min value - :param maxtime: if specified, cut the time axis at the specified max value - :param label_sparsity: how often to put a label with the pair (t, coord) - """ - from matplotlib import pyplot as plt - from matplotlib.gridspec import GridSpec - import numpy as np - - tlim = [times[0], times[-1]] - index_range = [0, len(times)] - if mintime is not None: - tlim[0] = mintime - index_range[0] = np.argmax(times > mintime) - if maxtime is not None: - tlim[1] = maxtime - index_range[1] = np.argmin(times < maxtime) - - trajectories = zip(*positions.tolist()) # only used in enumerate() below - fig = plt.figure(figsize=(12, 7)) - - plt.suptitle(r'Trajectory of {}'.format(label), fontsize=16) - nr_of_axes = 3 - 
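The heatmap method above depends on mayavi and scipy; a hedged call sketch, with the same assumed `traj` node, could look like this.

```python
# Density isosurfaces of the oxygen positions, collapsed into the unit cell
# of the first selected frame; requires the mayavi package.
traj.show_mpl_heatmap(
    elements=['O'],
    mintime=1.0,
    maxtime=5.0,
    stepsize=5,
    contours=(0.1, 0.5),   # iso-levels as fractions of the maximum density
    sampling_stepsize=50,  # additionally scatter every 50th raw position
)
```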
gridspec = GridSpec(nr_of_axes, 1, hspace=0.0) - - ax1 = fig.add_subplot(gridspec[0]) - plt.ylabel(r'X Position $\left[{}\right]$'.format(positions_unit)) - plt.xticks([]) - plt.xlim(*tlim) - ax2 = fig.add_subplot(gridspec[1]) - plt.ylabel(r'Y Position $\left[{}\right]$'.format(positions_unit)) - plt.xticks([]) - plt.xlim(*tlim) - ax3 = fig.add_subplot(gridspec[2]) - plt.ylabel(r'Z Position $\left[{}\right]$'.format(positions_unit)) - plt.xlabel(f'Time [{times_unit}]') - plt.xlim(*tlim) - sparse_indices = np.linspace(*index_range, num=label_sparsity, dtype=int) - - for index, traj in enumerate(trajectories): - if index not in indices_to_show: - continue - color = color_list[index] - _x, _y, _z = list(zip(*traj)) - ax1.plot(times, _x, color=color) - ax2.plot(times, _y, color=color) - ax3.plot(times, _z, color=color) - for i in sparse_indices: - ax1.text(times[i], _x[i], str(index), color=color, fontsize=5) - ax2.text(times[i], _x[i], str(index), color=color, fontsize=5) - ax3.text(times[i], _x[i], str(index), color=color, fontsize=5) - for axes in ax1, ax2, ax3: - yticks = axes.yaxis.get_major_ticks() - yticks[0].label1.set_visible(False) - - plt.show(block=not dont_block) diff --git a/aiida/orm/nodes/data/array/xy.py b/aiida/orm/nodes/data/array/xy.py deleted file mode 100644 index 7827fc050f..0000000000 --- a/aiida/orm/nodes/data/array/xy.py +++ /dev/null @@ -1,152 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -This module defines the classes related to Xy data. That is data that contains -collections of y-arrays bound to a single x-array, and the methods to operate -on them. -""" -import numpy as np - -from aiida.common.exceptions import NotExistent - -from .array import ArrayData - -__all__ = ('XyData',) - - -def check_convert_single_to_tuple(item): - """ - Checks if the item is a list or tuple, and converts it to a list if it is - not already a list or tuple - - :param item: an object which may or may not be a list or tuple - :return: item_list: the input item unchanged if list or tuple and [item] - otherwise - """ - if isinstance(item, (list, tuple)): - return item - - return [item] - - -class XyData(ArrayData): - """ - A subclass designed to handle arrays that have an "XY" relationship to - each other. That is there is one array, the X array, and there are several - Y arrays, which can be considered functions of X. - """ - - @staticmethod - def _arrayandname_validator(array, name, units): - """ - Validates that the array is an numpy.ndarray and that the name is - of type str. Raises TypeError or ValueError if this not the case. - """ - if not isinstance(name, str): - raise TypeError('The name must always be a str.') - - if not isinstance(array, np.ndarray): - raise TypeError('The input array must always be a numpy array') - try: - array.astype(float) - except ValueError as exc: - raise TypeError('The input array must only contain floats') from exc - if not isinstance(units, str): - raise TypeError('The units must always be a str.') - - def set_x(self, x_array, x_name, x_units): - """ - Sets the array and the name for the x values. 
- - :param x_array: A numpy.ndarray, containing only floats - :param x_name: a string for the x array name - :param x_units: the units of x - """ - self._arrayandname_validator(x_array, x_name, x_units) - self.base.attributes.set('x_name', x_name) - self.base.attributes.set('x_units', x_units) - self.set_array('x_array', x_array) - - def set_y(self, y_arrays, y_names, y_units): - """ - Set array(s) for the y part of the dataset. Also checks if the - x_array has already been set, and that, the shape of the y_arrays - agree with the x_array. - :param y_arrays: A list of y_arrays, numpy.ndarray - :param y_names: A list of strings giving the names of the y_arrays - :param y_units: A list of strings giving the units of the y_arrays - """ - # for the case of single name, array, tag input converts to a list - y_arrays = check_convert_single_to_tuple(y_arrays) - y_names = check_convert_single_to_tuple(y_names) - y_units = check_convert_single_to_tuple(y_units) - - # checks that the input lengths match - if len(y_arrays) != len(y_names): - raise ValueError('Length of arrays and names do not match!') - if len(y_units) != len(y_names): - raise ValueError('Length of units does not match!') - - # Try to get the x_array - try: - x_array = self.get_x()[1] - except NotExistent as exc: - raise ValueError('X array has not been set yet') from exc - # validate each of the y_arrays - for num, (y_array, y_name, y_unit) in enumerate(zip(y_arrays, y_names, y_units)): - self._arrayandname_validator(y_array, y_name, y_unit) - if np.shape(y_array) != np.shape(x_array): - raise ValueError(f'y_array {y_name} did not have the same shape has the x_array!') - self.set_array(f'y_array_{num}', y_array) - - # if the y_arrays pass the initial validation, sets each - self.base.attributes.set('y_names', y_names) - self.base.attributes.set('y_units', y_units) - - def get_x(self): - """ - Tries to retrieve the x array and x name raises a NotExistent - exception if no x array has been set yet. 
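The `XyData` setters and getters above can be exercised with a short, self-contained sketch (assuming a loaded profile); the array contents are arbitrary.

```python
import numpy as np

from aiida.orm import XyData

xy = XyData()

x = np.linspace(0.0, 10.0, 11)
xy.set_x(x, 'time', 'ps')

# Several y arrays can be bound to the same x array; one name and one unit
# per array, and every y array must have the same shape as x.
xy.set_y([x**2, np.sin(x)], ['position', 'signal'], ['A', 'arb.'])

x_name, x_array, x_units = xy.get_x()
for y_name, y_array, y_units in xy.get_y():
    print(y_name, y_array.shape, y_units)
```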
- :return x_name: the name set for the x_array - :return x_array: the x array set earlier - :return x_units: the x units set earlier - """ - try: - x_name = self.base.attributes.get('x_name') - x_array = self.get_array('x_array') - x_units = self.base.attributes.get('x_units') - except (KeyError, AttributeError): - raise NotExistent('No x array has been set yet!') - return x_name, x_array, x_units - - def get_y(self): - """ - Tries to retrieve the y arrays and the y names, raises a - NotExistent exception if they have not been set yet, or cannot be - retrieved - :return y_names: list of strings naming the y_arrays - :return y_arrays: list of y_arrays - :return y_units: list of strings giving the units for the y_arrays - """ - try: - y_names = self.base.attributes.get('y_names') - except (KeyError, AttributeError): - raise NotExistent('No y names has been set yet!') - try: - y_units = self.base.attributes.get('y_units') - except (KeyError, AttributeError): - raise NotExistent('No y units has been set yet!') - y_arrays = [] - try: - for i in range(len(y_names)): - y_arrays += [self.get_array(f'y_array_{i}')] - except (KeyError, AttributeError): - raise NotExistent(f'Could not retrieve array associated with y array {y_names[i]}') - return list(zip(y_names, y_arrays, y_units)) diff --git a/aiida/orm/nodes/data/base.py b/aiida/orm/nodes/data/base.py deleted file mode 100644 index f95cacaa2e..0000000000 --- a/aiida/orm/nodes/data/base.py +++ /dev/null @@ -1,54 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`Data` sub class to be used as a base for data containers that represent base python data types.""" -from functools import singledispatch - -from .data import Data - -__all__ = ('BaseType', 'to_aiida_type') - - -@singledispatch -def to_aiida_type(value): - """Turns basic Python types (str, int, float, bool) into the corresponding AiiDA types.""" - raise TypeError(f'Cannot convert value of type {type(value)} to AiiDA type.') - - -class BaseType(Data): - """`Data` sub class to be used as a base for data containers that represent base python data types.""" - - def __init__(self, value=None, **kwargs): - try: - getattr(self, '_type') - except AttributeError: - raise RuntimeError('Derived class must define the `_type` class member') - - super().__init__(**kwargs) - - self.value = value or self._type() # pylint: disable=no-member - - @property - def value(self): - return self.base.attributes.get('value', None) - - @value.setter - def value(self, value): - self.base.attributes.set('value', self._type(value)) # pylint: disable=no-member - - def __str__(self): - return f'{super().__str__()} value: {self.value}' - - def __eq__(self, other): - if isinstance(other, BaseType): - return self.value == other.value - return self.value == other - - def new(self, value=None): - return self.__class__(value) diff --git a/aiida/orm/nodes/data/cif.py b/aiida/orm/nodes/data/cif.py deleted file mode 100644 index 460bc2d9cb..0000000000 --- a/aiida/orm/nodes/data/cif.py +++ /dev/null @@ -1,807 +0,0 @@ -# -*- coding: utf-8 -*- 
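The `to_aiida_type` dispatcher and `BaseType` semantics removed above are easiest to see in use. A minimal sketch, assuming `Int` and `to_aiida_type` are re-exported from `aiida.orm` as in current releases:

```python
from aiida.orm import Int, to_aiida_type

# to_aiida_type is a singledispatch function: the Int, Float, Str and Bool
# plugins register the Python types they wrap, so plain values can be
# promoted to AiiDA nodes generically.
node = to_aiida_type(42)
print(type(node).__name__, node.value)  # Int 42

# BaseType subclasses compare equal both to other nodes and to bare values.
assert Int(5) == 5
assert Int(5) == Int(5)
```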
-########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,too-many-locals,too-many-statements -"""Tools for handling Crystallographic Information Files (CIF)""" - -import re - -from aiida.common.utils import Capturing - -from .singlefile import SinglefileData - -__all__ = ('CifData', 'cif_from_ase', 'has_pycifrw', 'pycifrw_from_cif') - -ase_loops = { - '_atom_site': [ - '_atom_site_label', - '_atom_site_occupancy', - '_atom_site_fract_x', - '_atom_site_fract_y', - '_atom_site_fract_z', - '_atom_site_adp_type', - '_atom_site_thermal_displace_type', - '_atom_site_B_iso_or_equiv', - '_atom_site_U_iso_or_equiv', - '_atom_site_B_equiv_geom_mean', - '_atom_site_U_equiv_geom_mean', - '_atom_site_type_symbol', - ] -} - - -def has_pycifrw(): - """ - :return: True if the PyCifRW module can be imported, False otherwise. - """ - # pylint: disable=unused-variable,unused-import - try: - import CifFile - from CifFile import CifBlock - except ImportError: - return False - return True - - -def cif_from_ase(ase, full_occupancies=False, add_fake_biso=False): - """ - Construct a CIF datablock from the ASE structure. The code is taken - from - https://wiki.fysik.dtu.dk/ase/ase/io/formatoptions.html#ase.io.cif.write_cif, - as the original ASE code contains a bug in printing the - Hermann-Mauguin symmetry space group symbol. - - :param ase: ASE "images" - :return: array of CIF datablocks - """ - from numpy import arccos, dot, pi - from numpy.linalg import norm - - if not isinstance(ase, (list, tuple)): - ase = [ase] - - datablocks = [] - for _, atoms in enumerate(ase): - datablock = {} - - cell = atoms.cell - a = norm(cell[0]) - b = norm(cell[1]) - c = norm(cell[2]) - alpha = arccos(dot(cell[1], cell[2]) / (b * c)) * 180. / pi - beta = arccos(dot(cell[0], cell[2]) / (a * c)) * 180. / pi - gamma = arccos(dot(cell[0], cell[1]) / (a * b)) * 180. 
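The module-level helpers above (`cif_from_ase`, `has_pycifrw`) can be used standalone. A hedged sketch, assuming ASE is installed and using the module path shown in the diff:

```python
from ase.build import bulk

from aiida.orm.nodes.data.cif import cif_from_ase, has_pycifrw

atoms = bulk('Si', 'diamond', a=5.43)

# cif_from_ase returns a list of plain dictionaries, one CIF datablock per image.
datablocks = cif_from_ase(atoms, full_occupancies=True, add_fake_biso=False)
print(datablocks[0]['_cell_length_a'], datablocks[0]['_atom_site_label'])

# Rendering those datablocks as an actual CIF file additionally requires PyCifRW.
print('PyCifRW available:', has_pycifrw())
```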
/ pi - - datablock['_cell_length_a'] = str(a) - datablock['_cell_length_b'] = str(b) - datablock['_cell_length_c'] = str(c) - datablock['_cell_angle_alpha'] = str(alpha) - datablock['_cell_angle_beta'] = str(beta) - datablock['_cell_angle_gamma'] = str(gamma) - - if atoms.pbc.all(): - datablock['_symmetry_space_group_name_H-M'] = 'P 1' - datablock['_symmetry_int_tables_number'] = str(1) - datablock['_symmetry_equiv_pos_as_xyz'] = ['x, y, z'] - - datablock['_atom_site_label'] = [] - datablock['_atom_site_fract_x'] = [] - datablock['_atom_site_fract_y'] = [] - datablock['_atom_site_fract_z'] = [] - datablock['_atom_site_type_symbol'] = [] - - if full_occupancies: - datablock['_atom_site_occupancy'] = [] - if add_fake_biso: - datablock['_atom_site_thermal_displace_type'] = [] - datablock['_atom_site_B_iso_or_equiv'] = [] - - scaled = atoms.get_scaled_positions() - no = {} - for i, atom in enumerate(atoms): - symbol = atom.symbol - if symbol in no: - no[symbol] += 1 - else: - no[symbol] = 1 - datablock['_atom_site_label'].append(symbol + str(no[symbol])) - datablock['_atom_site_fract_x'].append(str(scaled[i][0])) - datablock['_atom_site_fract_y'].append(str(scaled[i][1])) - datablock['_atom_site_fract_z'].append(str(scaled[i][2])) - datablock['_atom_site_type_symbol'].append(symbol) - - if full_occupancies: - datablock['_atom_site_occupancy'].append(str(1.0)) - if add_fake_biso: - datablock['_atom_site_thermal_displace_type'].append('Biso') - datablock['_atom_site_B_iso_or_equiv'].append(str(1.0)) - - datablocks.append(datablock) - return datablocks - - -# pylint: disable=too-many-branches -def pycifrw_from_cif(datablocks, loops=None, names=None): - """ - Constructs PyCifRW's CifFile from an array of CIF datablocks. - - :param datablocks: an array of CIF datablocks - :param loops: optional dict of lists of CIF tag loops. - :param names: optional list of datablock names - :return: CifFile - """ - try: - import CifFile - from CifFile import CifBlock - except ImportError as exc: - raise ImportError(f'{str(exc)}. 
You need to install the PyCifRW package.') - - if loops is None: - loops = {} - - cif = CifFile.CifFile() # pylint: disable=no-member - try: - cif.set_grammar('1.1') - except AttributeError: - # if no grammar can be set, we assume it's 1.1 (widespread standard) - pass - - if names and len(names) < len(datablocks): - raise ValueError( - f'Not enough names supplied for datablocks: {len(names)} (names) < {len(datablocks)} (datablocks)' - ) - for i, values in enumerate(datablocks): - name = str(i) - if names: - name = names[i] - datablock = CifBlock() - cif[name] = datablock - tags_in_loops = [] - for loopname in loops.keys(): - row_size = None - tags_seen = [] - for tag in loops[loopname]: - if tag in values: - tag_values = values.pop(tag) - if not isinstance(tag_values, list): - tag_values = [tag_values] - if row_size is None: - row_size = len(tag_values) - elif row_size != len(tag_values): - raise ValueError( - f'Number of values for tag `{tag}` is different from the others in the same loop' - ) - if row_size == 0: - continue - datablock.AddItem(tag, tag_values) - tags_seen.append(tag) - tags_in_loops.append(tag) - if row_size is not None and row_size > 0: - datablock.CreateLoop(datanames=tags_seen) - for tag in sorted(values.keys()): - if not tag in tags_in_loops: - datablock.AddItem(tag, values[tag]) - # create automatically a loop for non-scalar values - if isinstance(values[tag], (tuple, list)) and tag not in loops.keys(): - datablock.CreateLoop([tag]) - return cif - - -def parse_formula(formula): - """ - Parses the Hill formulae. Does not need spaces as separators. - Works also for partial occupancies and for chemical groups enclosed in round/square/curly brackets. - Elements are counted and a dictionary is returned. - e.g. 'C[NH2]3NO3' --> {'C': 1, 'N': 4, 'H': 6, 'O': 3} - """ - - def chemcount_str_to_number(string): - if not string: - quantity = 1 - else: - quantity = float(string) - if quantity.is_integer(): - quantity = int(quantity) - return quantity - - contents = {} - - # split blocks with parentheses - for block in re.split(r'(\([^\)]*\)[^A-Z\(\[\{]*|\[[^\]]*\][^A-Z\(\[\{]*|\{[^\}]*\}[^A-Z\(\[\{]*)', formula): - if not block: # block is void - continue - - # get molecular formula (within parentheses) & count - group = re.search(r'[\{\[\(](.+)[\}\]\)]([\.\d]*)', block) - if group is None: # block does not contain parentheses - molformula = block - molcount = 1 - else: - molformula = group.group(1) - molcount = chemcount_str_to_number(group.group(2)) - - for part in re.findall(r'[A-Z][^A-Z\s]*', molformula.replace(' ', '')): # split at uppercase letters - match = re.match(r'(\D+)([\.\d]+)?', part) # separates element and count - - if match is None: - continue - - species = match.group(1) - quantity = chemcount_str_to_number(match.group(2)) * molcount - contents[species] = contents.get(species, 0) + quantity - return contents - - -# pylint: disable=abstract-method,too-many-public-methods -# Note: Method 'query' is abstract in class 'Node' but is not overridden -class CifData(SinglefileData): - """ - Wrapper for Crystallographic Interchange File (CIF) - - .. note:: the file (physical) is held as the authoritative source of - information, so all conversions are done through the physical file: - when setting ``ase`` or ``values``, a physical CIF file is generated - first, the values are updated from the physical CIF file. 
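A short sketch of the Hill-formula parser defined above; `parse_formula` is module-level in the same file (not listed in `__all__`), so it is imported from the module path directly.

```python
from aiida.orm.nodes.data.cif import parse_formula

print(parse_formula('C[NH2]3NO3'))     # {'C': 1, 'N': 4, 'H': 6, 'O': 3}
print(parse_formula('Fe0.5 Mg0.5 O'))  # {'Fe': 0.5, 'Mg': 0.5, 'O': 1}
```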
- """ - # pylint: disable=abstract-method, too-many-public-methods - _SET_INCOMPATIBILITIES = [('ase', 'file'), ('ase', 'values'), ('file', 'values')] - _SCAN_TYPES = ('standard', 'flex') - _SCAN_TYPE_DEFAULT = 'standard' - _PARSE_POLICIES = ('eager', 'lazy') - _PARSE_POLICY_DEFAULT = 'eager' - - _values = None - _ase = None - - def __init__(self, ase=None, file=None, filename=None, values=None, scan_type=None, parse_policy=None, **kwargs): - """Construct a new instance and set the contents to that of the file. - - :param file: an absolute filepath or filelike object for CIF. - Hint: Pass io.BytesIO(b"my string") to construct the SinglefileData directly from a string. - :param filename: specify filename to use (defaults to name of provided file). - :param ase: ASE Atoms object to construct the CifData instance from. - :param values: PyCifRW CifFile object to construct the CifData instance from. - :param scan_type: scan type string for parsing with PyCIFRW ('standard' or 'flex'). See CifFile.ReadCif - :param parse_policy: 'eager' (parse CIF file on set_file) or 'lazy' (defer parsing until needed) - """ - - # pylint: disable=too-many-arguments, redefined-builtin - - args = { - 'ase': ase, - 'file': file, - 'values': values, - } - - for left, right in CifData._SET_INCOMPATIBILITIES: - if args[left] is not None and args[right] is not None: - raise ValueError(f'cannot pass {left} and {right} at the same time') - - super().__init__(file, filename=filename, **kwargs) - self.set_scan_type(scan_type or CifData._SCAN_TYPE_DEFAULT) - self.set_parse_policy(parse_policy or CifData._PARSE_POLICY_DEFAULT) - - if ase is not None: - self.set_ase(ase) - - if values is not None: - self.set_values(values) - - if not self.is_stored and file is not None and self.base.attributes.get('parse_policy') == 'eager': - self.parse() - - @staticmethod - def read_cif(fileobj, index=-1, **kwargs): - """ - A wrapper method that simulates the behavior of the old - function ase.io.cif.read_cif by using the new generic ase.io.read - function. - - Somewhere from 3.12 to 3.17 the tag concept was bundled with each Atom object. When - reading a CIF file, this is incremented and signifies the atomic species, even though - the CIF file do not have specific tags embedded. On reading CIF files we thus force the - ASE tag to zero for all Atom elements. - - """ - from ase.io import read - - # The read function returns a list as a cif file might contain multiple - # structures. - struct_list = read(fileobj, index=':', format='cif', **kwargs) - - if index is None: - # If index is explicitely set to None, the list is returned as such. - for atoms_entry in struct_list: - atoms_entry.set_tags(0) - return struct_list - # Otherwise return the desired structure specified by index, if no index is specified, - # the last structure is assumed by default. - struct_list[index].set_tags(0) - return struct_list[index] - - @classmethod - def from_md5(cls, md5, backend=None): - """ - Return a list of all CIF files that match a given MD5 hash. - - .. note:: the hash has to be stored in a ``_md5`` attribute, - otherwise the CIF file will not be found. - """ - from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder(backend=backend) - builder.append(cls, filters={'attributes.md5': {'==': md5}}) - return builder.all(flat=True) - - @classmethod - def get_or_create(cls, filename, use_first=False, store_cif=True): - """ - Pass the same parameter of the init; if a file with the same md5 - is found, that CifData is returned. 
- - :param filename: an absolute filename on disk - :param use_first: if False (default), raise an exception if more than \ - one CIF file is found.\ - If it is True, instead, use the first available CIF file. - :param bool store_cif: If false, the CifData objects are not stored in - the database. default=True. - :return (cif, created): where cif is the CifData object, and create is either\ - True if the object was created, or False if the object was retrieved\ - from the DB. - """ - import os - - from aiida.common.files import md5_file - - if not os.path.abspath(filename): - raise ValueError('filename must be an absolute path') - md5 = md5_file(filename) - - cifs = cls.from_md5(md5) - if not cifs: - if store_cif: - instance = cls(file=filename).store() - return (instance, True) - instance = cls(file=filename) - return (instance, True) - - if len(cifs) > 1: - if use_first: - return (cifs[0], False) - - raise ValueError( - 'More than one copy of a CIF file ' - 'with the same MD5 has been found in ' - 'the DB. pks={}'.format(','.join([str(i.pk) for i in cifs])) - ) - - return cifs[0], False - - # pylint: disable=attribute-defined-outside-init - @property - def ase(self): - """ - ASE object, representing the CIF. - - .. note:: requires ASE module. - """ - if self._ase is None: - self._ase = self.get_ase() - return self._ase - - def get_ase(self, **kwargs): - """ - Returns ASE object, representing the CIF. This function differs - from the property ``ase`` by the possibility to pass the keyworded - arguments (kwargs) to ase.io.cif.read_cif(). - - .. note:: requires ASE module. - """ - if not kwargs and self._ase: - return self.ase - with self.open() as handle: - return CifData.read_cif(handle, **kwargs) - - def set_ase(self, aseatoms): - """ - Set the contents of the CifData starting from an ASE atoms object - - :param aseatoms: the ASE atoms object - """ - import tempfile - cif = cif_from_ase(aseatoms) - with tempfile.NamedTemporaryFile(mode='w+') as tmpf: - with Capturing(): - tmpf.write(pycifrw_from_cif(cif, loops=ase_loops).WriteOut()) - tmpf.flush() - self.set_file(tmpf.name) - - @ase.setter - def ase(self, aseatoms): - self.set_ase(aseatoms) - - @property - def values(self): - """ - PyCifRW structure, representing the CIF datablocks. - - .. note:: requires PyCifRW module. - """ - if self._values is None: - import CifFile - from CifFile import CifBlock # pylint: disable=no-name-in-module - - with self.open() as handle: - c = CifFile.ReadCif(handle, scantype=self.base.attributes.get('scan_type', CifData._SCAN_TYPE_DEFAULT)) # pylint: disable=no-member - for k, v in c.items(): - c.dictionary[k] = CifBlock(v) - self._values = c - return self._values - - def set_values(self, values): - """ - Set internal representation to `values`. - - Warning: This also writes a new CIF file. - - :param values: PyCifRW CifFile object - - .. note:: requires PyCifRW module. - """ - import tempfile - with tempfile.NamedTemporaryFile(mode='w+') as tmpf: - with Capturing(): - tmpf.write(values.WriteOut()) - tmpf.flush() - tmpf.seek(0) - self.set_file(tmpf) - - self._values = values - - @values.setter - def values(self, values): - self.set_values(values) - - def parse(self, scan_type=None): - """ - Parses CIF file and sets attributes. 
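Construction and de-duplication of `CifData` nodes, as defined above, in a hedged sketch: it assumes ASE and PyCifRW are installed, a profile is loaded, and the CIF path is only a placeholder.

```python
from ase.build import bulk

from aiida.orm import CifData

# Exactly one of `ase`, `file` or `values` may be passed to the constructor.
cif = CifData(ase=bulk('NaCl', 'rocksalt', a=5.64))
print(cif.get_ase().get_chemical_formula())  # round-trips through the generated CIF file

# get_or_create de-duplicates by the MD5 of the file contents; the path below
# is a placeholder for an existing CIF file on disk.
node, created = CifData.get_or_create('/path/to/structure.cif', use_first=True)
```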
- - :param scan_type: See set_scan_type - """ - if scan_type is not None: - self.set_scan_type(scan_type) - - # Note: this causes parsing, if not already parsed - self.base.attributes.set('formulae', self.get_formulae()) - self.base.attributes.set('spacegroup_numbers', self.get_spacegroup_numbers()) - - def store(self, *args, **kwargs): # pylint: disable=signature-differs - """ - Store the node. - """ - if not self.is_stored: - self.base.attributes.set('md5', self.generate_md5()) - - return super().store(*args, **kwargs) - - def set_file(self, file, filename=None): - """ - Set the file. - - If the source is set and the MD5 checksum of new file - is different from the source, the source has to be deleted. - - :param file: filepath or filelike object of the CIF file to store. - Hint: Pass io.BytesIO(b"my string") to construct the file directly from a string. - :param filename: specify filename to use (defaults to name of provided file). - """ - # pylint: disable=redefined-builtin - super().set_file(file, filename=filename) - md5sum = self.generate_md5() - if isinstance(self.source, dict) and \ - self.source.get('source_md5', None) is not None and \ - self.source['source_md5'] != md5sum: - self.source = {} - self.base.attributes.set('md5', md5sum) - - self._values = None - self._ase = None - self.base.attributes.set('formulae', None) - self.base.attributes.set('spacegroup_numbers', None) - - def set_scan_type(self, scan_type): - """ - Set the scan_type for PyCifRW. - - The 'flex' scan_type of PyCifRW is faster for large CIF files but - does not yet support the CIF2 format as of 02/2018. - See the CifFile.ReadCif function - - :param scan_type: Either 'standard' or 'flex' (see _scan_types) - """ - if scan_type in CifData._SCAN_TYPES: - self.base.attributes.set('scan_type', scan_type) - else: - raise ValueError(f'Got unknown scan_type {scan_type}') - - def set_parse_policy(self, parse_policy): - """ - Set the parse policy. - - :param parse_policy: Either 'eager' (parse CIF file on set_file) - or 'lazy' (defer parsing until needed) - """ - if parse_policy in CifData._PARSE_POLICIES: - self.base.attributes.set('parse_policy', parse_policy) - else: - raise ValueError(f'Got unknown parse_policy {parse_policy}') - - def get_formulae(self, mode='sum', custom_tags=None): - """ - Return chemical formulae specified in CIF file. - - Note: This does not compute the formula, it only reads it from the - appropriate tag. Use refine_inline to compute formulae. - """ - # note: If formulae are not None, they could be returned - # directly (but the function is very cheap anyhow). - formula_tags = [f'_chemical_formula_{mode}'] - if custom_tags: - if not isinstance(custom_tags, (list, tuple)): - custom_tags = [custom_tags] - formula_tags.extend(custom_tags) - - formulae = [] - for datablock in self.values.keys(): - formula = None - for formula_tag in formula_tags: - if formula_tag in self.values[datablock].keys(): - formula = self.values[datablock][formula_tag] - break - formulae.append(formula) - - return formulae - - def get_spacegroup_numbers(self): - """ - Get the spacegroup international number. - """ - # note: If spacegroup_numbers are not None, they could be returned - # directly (but the function is very cheap anyhow). 
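The scan-type and parse-policy knobs above matter mostly for large files; a hedged sketch with a placeholder path:

```python
from aiida.orm import CifData

# The 'flex' PyCifRW grammar scan is faster for large files, and parsing can be
# deferred until the data is actually needed.
cif = CifData(file='/path/to/large.cif', scan_type='flex', parse_policy='lazy')

# Explicit parsing sets the 'formulae' and 'spacegroup_numbers' attributes.
cif.parse()
print(cif.get_formulae(), cif.get_spacegroup_numbers())
```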
- spg_tags = ['_space_group.it_number', '_space_group_it_number', '_symmetry_int_tables_number'] - spacegroup_numbers = [] - for datablock in self.values.keys(): - spacegroup_number = None - correct_tags = [tag for tag in spg_tags if tag in self.values[datablock].keys()] - if correct_tags: - try: - spacegroup_number = int(self.values[datablock][correct_tags[0]]) - except ValueError: - pass - spacegroup_numbers.append(spacegroup_number) - - return spacegroup_numbers - - @property - def has_partial_occupancies(self): - """ - Return if the cif data contains partial occupancies - - A partial occupancy is defined as site with an occupancy that differs from unity, within a precision of 1E-6 - - .. note: occupancies that cannot be parsed into a float are ignored - - :return: True if there are partial occupancies, False otherwise - """ - tag = '_atom_site_occupancy' - - epsilon = 1e-6 - partial_occupancies = False - - for datablock in self.values.keys(): - if tag in self.values[datablock].keys(): - for position in self.values[datablock][tag]: - try: - # First remove any parentheses to support value like 1.134(56) and then cast to float - occupancy = float(re.sub(r'[\(\)]', '', position)) - except ValueError: - pass - else: - if abs(occupancy - 1) > epsilon: - return True - - return partial_occupancies - - @property - def has_attached_hydrogens(self): - """ - Check if there are hydrogens without coordinates, specified as attached - to the atoms of the structure. - - :returns: True if there are attached hydrogens, False otherwise. - """ - tag = '_atom_site_attached_hydrogens' - for datablock in self.values.keys(): - if tag in self.values[datablock].keys(): - for value in self.values[datablock][tag]: - if value not in ['.', '?', '0']: - return True - - return False - - @property - def has_undefined_atomic_sites(self): - """ - Return whether the cif data contains any undefined atomic sites. - - An undefined atomic site is defined as a site where at least one of the fractional coordinates specified in the - `_atom_site_fract_*` tags, cannot be successfully interpreted as a float. If the cif data contains any site that - matches this description, or it does not contain any atomic site tags at all, the cif data is said to have - undefined atomic sites. - - :return: boolean, True if no atomic sites are defined or if any of the defined sites contain undefined positions - and False otherwise - """ - tag_x = '_atom_site_fract_x' - tag_y = '_atom_site_fract_y' - tag_z = '_atom_site_fract_z' - - # Some CifData files do not even contain a single `_atom_site_fract_*` tag - has_tags = False - - for datablock in self.values.keys(): - for tag in [tag_x, tag_y, tag_z]: - if tag in self.values[datablock].keys(): - for position in self.values[datablock][tag]: - - # The CifData contains at least one `_atom_site_fract_*` tag - has_tags = True - - try: - # First remove any parentheses to support value like 1.134(56) and then cast to float - float(re.sub(r'[\(\)]', '', position)) - except ValueError: - # Position cannot be converted to a float value, so we have undefined atomic sites - return True - - # At this point the file either has no tags at all, or it does and all coordinates were valid floats - return not has_tags - - @property - def has_atomic_sites(self): - """ - Returns whether there are any atomic sites defined in the cif data. 
That - is to say, it will check all the values for the `_atom_site_fract_*` tags - and if they are all equal to `?` that means there are no relevant atomic - sites defined and the function will return False. In all other cases the - function will return True - - :returns: False when at least one atomic site fractional coordinate is not - equal to `?` and True otherwise - """ - tag_x = '_atom_site_fract_x' - tag_y = '_atom_site_fract_y' - tag_z = '_atom_site_fract_z' - coords = [] - for datablock in self.values.keys(): - for tag in [tag_x, tag_y, tag_z]: - if tag in self.values[datablock].keys(): - coords.extend(self.values[datablock][tag]) - - return not all(coord == '?' for coord in coords) - - @property - def has_unknown_species(self): - """ - Returns whether the cif contains atomic species that are not recognized by AiiDA. - - The known species are taken from the elements dictionary in `aiida.common.constants`, with the exception of - the "unknown" placeholder element with symbol 'X', as this could not be used to construct a real structure. - If any of the formula of the cif data contain species that are not in that elements dictionary, the function - will return True and False in all other cases. If there is no formulae to be found, it will return None - - :returns: True when there are unknown species in any of the formulae, False if not, None if no formula found - """ - from aiida.common.constants import elements - - # Get all the elements known by AiiDA, excluding the "unknown" element with symbol 'X' - known_species = [element['symbol'] for element in elements.values() if element['symbol'] != 'X'] - - for formula in self.get_formulae(): - - if formula is None: - return None - - species = parse_formula(formula).keys() - if any(specie not in known_species for specie in species): - return True - - return False - - def generate_md5(self): - """ - Computes and returns MD5 hash of the CIF file. - """ - from aiida.common.files import md5_from_filelike - - # Open in binary mode which is required for generating the md5 checksum - with self.open(mode='rb') as handle: - return md5_from_filelike(handle) - - def get_structure(self, converter='pymatgen', store=False, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.structure.StructureData`. - - .. versionadded:: 1.0 - Renamed from _get_aiida_structure - - :param converter: specify the converter. Default 'pymatgen'. - :param store: if True, intermediate calculation gets stored in the - AiiDA database for record. Default False. - :param primitive_cell: if True, primitive cell is returned, - conventional cell if False. Default False. - :param occupancy_tolerance: If total occupancy of a site is between 1 and occupancy_tolerance, - the occupancies will be scaled down to 1. (pymatgen only) - :param site_tolerance: This tolerance is used to determine if two sites are sitting in the same position, - in which case they will be combined to a single disordered site. Defaults to 1e-4. (pymatgen only) - :return: :py:class:`aiida.orm.nodes.data.structure.StructureData` node. 
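The inspection properties and conversion method above are typically combined before building a structure. A sketch, assuming `cif` is a stored `CifData` node loaded from the profile storage:

```python
if cif.has_partial_occupancies or cif.has_undefined_atomic_sites:
    print('CIF needs refinement before a sensible StructureData can be built')

if cif.has_unknown_species:
    print('CIF contains species that AiiDA cannot map onto known elements')

# The md5 attribute written at store time can be recomputed and compared later.
assert cif.base.attributes.get('md5') == cif.generate_md5()

# Convert to a StructureData; the pymatgen-based converter is the default.
structure = cif.get_structure(converter='pymatgen', store=False)
```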
- """ - from aiida.orm import Dict - from aiida.tools.data import cif as cif_tools - - parameters = Dict(kwargs) - - try: - convert_function = getattr(cif_tools, f'_get_aiida_structure_{converter}_inline') - except AttributeError: - raise ValueError(f"No such converter '{converter}' available") - - result = convert_function(cif=self, parameters=parameters, metadata={'store_provenance': store}) - - return result['structure'] - - def _prepare_cif(self, **kwargs): # pylint: disable=unused-argument - """Return CIF string of CifData object. - - If parsed values are present, a CIF string is created and written to file. If no parsed values are present, the - CIF string is read from file. - """ - with self.open(mode='rb') as handle: - return handle.read(), {} - - def _get_object_ase(self): - """ - Converts CifData to ase.Atoms - - :return: an ase.Atoms object - """ - return self.ase - - def _get_object_pycifrw(self): - """ - Converts CifData to PyCIFRW.CifFile - - :return: a PyCIFRW.CifFile object - """ - return self.values - - def _validate(self): - """ - Validates MD5 hash of CIF file. - """ - from aiida.common.exceptions import ValidationError - - super()._validate() - - try: - attr_md5 = self.base.attributes.get('md5') - except AttributeError: - raise ValidationError("attribute 'md5' not set.") - md5 = self.generate_md5() - if attr_md5 != md5: - raise ValidationError(f"Attribute 'md5' says '{attr_md5}' but '{md5}' was parsed instead.") diff --git a/aiida/orm/nodes/data/code/__init__.py b/aiida/orm/nodes/data/code/__init__.py deleted file mode 100644 index 9e85d34d06..0000000000 --- a/aiida/orm/nodes/data/code/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- -"""Data plugins that represent an executable code.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .abstract import * -from .containerized import * -from .installed import * -from .legacy import * -from .portable import * - -__all__ = ( - 'AbstractCode', - 'Code', - 'ContainerizedCode', - 'InstalledCode', - 'PortableCode', -) - -# yapf: enable diff --git a/aiida/orm/nodes/data/code/abstract.py b/aiida/orm/nodes/data/code/abstract.py deleted file mode 100644 index fcd54f0dd1..0000000000 --- a/aiida/orm/nodes/data/code/abstract.py +++ /dev/null @@ -1,394 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Abstract data plugin representing an executable code.""" -from __future__ import annotations - -import abc -import collections -import pathlib -from typing import TYPE_CHECKING - -import click - -from aiida.cmdline.params.options.interactive import TemplateInteractiveOption -from aiida.common import exceptions -from aiida.common.folders import Folder -from aiida.common.lang import type_check -from aiida.orm import Computer -from aiida.plugins import CalculationFactory - -from ..data import Data - -if TYPE_CHECKING: - from aiida.engine import ProcessBuilder - -__all__ = ('AbstractCode',) - - -class AbstractCode(Data, metaclass=abc.ABCMeta): - """Abstract data plugin representing an executable code.""" - - # Should become ``default_calc_job_plugin`` once ``Code`` is dropped in ``aiida-core==3.0`` - _KEY_ATTRIBUTE_DEFAULT_CALC_JOB_PLUGIN: str = 'input_plugin' - _KEY_ATTRIBUTE_APPEND_TEXT: str = 'append_text' - _KEY_ATTRIBUTE_PREPEND_TEXT: str = 'prepend_text' - _KEY_ATTRIBUTE_USE_DOUBLE_QUOTES: str = 'use_double_quotes' - _KEY_ATTRIBUTE_WITH_MPI: str = 'with_mpi' - _KEY_ATTRIBUTE_WRAP_CMDLINE_PARAMS: str = 'wrap_cmdline_params' - _KEY_EXTRA_IS_HIDDEN: str = 'hidden' # Should become ``is_hidden`` once ``Code`` is dropped - - def __init__( - self, - default_calc_job_plugin: str | None = None, - append_text: str = '', - prepend_text: str = '', - use_double_quotes: bool = False, - with_mpi: bool | None = None, - is_hidden: bool = False, - wrap_cmdline_params: bool = False, - **kwargs - ): - """Construct a new instance. - - :param default_calc_job_plugin: The entry point name of the default ``CalcJob`` plugin to use. - :param append_text: The text that should be appended to the run line in the job script. - :param prepend_text: The text that should be prepended to the run line in the job script. - :param use_double_quotes: Whether the command line invocation of this code should be escaped with double quotes. - :param with_mpi: Whether the command should be run as an MPI program. - :param wrap_cmdline_params: Whether to wrap the executable and all its command line parameters into quotes to - form a single string. This is required to enable support for Docker with the ``ContainerizedCode``. - :param is_hidden: Whether the code is hidden. - """ - super().__init__(**kwargs) - self.default_calc_job_plugin = default_calc_job_plugin - self.append_text = append_text - self.prepend_text = prepend_text - self.use_double_quotes = use_double_quotes - self.with_mpi = with_mpi - self.wrap_cmdline_params = wrap_cmdline_params - self.is_hidden = is_hidden - - @abc.abstractmethod - def can_run_on_computer(self, computer: Computer) -> bool: - """Return whether the code can run on a given computer. - - :param computer: The computer. - :return: ``True`` if the code can run on ``computer``, ``False`` otherwise. - """ - - @abc.abstractmethod - def get_executable(self) -> pathlib.PurePosixPath: - """Return the executable that the submission script should execute to run the code. - - :return: The executable to be called in the submission script. - """ - - def get_executable_cmdline_params(self, cmdline_params: list[str] | None = None) -> list: - """Return the list of executable with its command line parameters. 
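The constructor options of `AbstractCode` shown above are exposed by its concrete subclasses. A hedged sketch using `InstalledCode` (assumed available as in current aiida-core releases); the computer label, executable path and entry point are placeholders:

```python
from aiida.orm import InstalledCode, load_computer

code = InstalledCode(
    computer=load_computer('localhost'),
    filepath_executable='/usr/bin/pw.x',
    label='pw-7.2',
    default_calc_job_plugin='quantumespresso.pw',
    prepend_text='module load quantum-espresso',
    with_mpi=True,
)
```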
- - :param cmdline_params: List of command line parameters provided by the ``CalcJob`` plugin. - :return: List of the executable followed by its command line parameters. - """ - return [str(self.get_executable())] + (cmdline_params or []) - - def get_prepend_cmdline_params( - self, mpi_args: list[str] | None = None, extra_mpirun_params: list[str] | None = None - ) -> list[str]: - """Return List of command line parameters to be prepended to the executable in submission line. - These command line parameters are typically parameters related to MPI invocations. - - :param mpi_args: List of MPI parameters provided by the ``Computer.get_mpirun_command`` method. - :param extra_mpiruns_params: List of MPI parameters provided by the ``metadata.options.extra_mpirun_params`` - input of the ``CalcJob``. - :return: List of command line parameters to be prepended to the executable in submission line. - """ - return (mpi_args or []) + (extra_mpirun_params or []) - - def validate_working_directory(self, folder: Folder): - """Validate content of the working directory created by the :class:`~aiida.engine.CalcJob` plugin. - - This method will be called by :meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.presubmit` when a new - calculation job is launched, passing the :class:`~aiida.common.folders.Folder` that was used by the plugin used - for the calculation to create the input files for the working directory. This method can be overridden by - implementations of the ``AbstractCode`` class that need to validate the contents of that folder. - - :param folder: A sandbox folder that the ``CalcJob`` plugin wrote input files to that will be copied to the - working directory for the corresponding calculation job instance. - :raises PluginInternalError: If the content of the sandbox folder is not valid. - """ - - @property - @abc.abstractmethod - def full_label(self) -> str: - """Return the full label of this code. - - The full label can be just the label itself but it can be something else. However, it at the very least has to - include the label of the code. - - :return: The full label of the code. - """ - - @property - def label(self) -> str: - """Return the label. - - :return: The label. - """ - return self.backend_entity.label - - @label.setter - def label(self, value: str) -> None: - """Set the label. - - The label cannot contain any ``@`` symbols. - - :param value: The new label. - :raises ValueError: If the label contains invalid characters. - """ - type_check(value, str) - - if '@' in value: - raise ValueError('The label contains a `@` symbol, which is not allowed.') - - self.backend_entity.label = value - - @property - def default_calc_job_plugin(self) -> str | None: - """Return the optional default ``CalcJob`` plugin. - - :return: The entry point name of the default ``CalcJob`` plugin to use. - """ - return self.base.attributes.get(self._KEY_ATTRIBUTE_DEFAULT_CALC_JOB_PLUGIN, None) - - @default_calc_job_plugin.setter - def default_calc_job_plugin(self, value: str | None) -> None: - """Set the default ``CalcJob`` plugin. - - :param value: The entry point name of the default ``CalcJob`` plugin to use. - """ - type_check(value, str, allow_none=True) - self.base.attributes.set(self._KEY_ATTRIBUTE_DEFAULT_CALC_JOB_PLUGIN, value) - - @property - def append_text(self) -> str: - """Return the text that should be appended to the run line in the job script. - - :return: The text that should be appended to the run line in the job script. 
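Continuing with the `code` node from the previous sketch, the methods above show how the pieces of the submission command line are assembled, and the label setter enforces the `@`-free rule:

```python
# The scheduler plugin combines these pieces into the run line of the job script.
print(code.get_executable())                                    # /usr/bin/pw.x
print(code.get_prepend_cmdline_params(['mpirun', '-np', '4']))  # MPI part first
print(code.get_executable_cmdline_params(['-in', 'aiida.in']))  # executable + its arguments

# code.label = 'pw@daint'  # would raise ValueError: '@' is not allowed in labels
```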
- """ - return self.base.attributes.get(self._KEY_ATTRIBUTE_APPEND_TEXT, '') - - @append_text.setter - def append_text(self, value: str) -> None: - """Set the text that should be appended to the run line in the job script. - - :param value: The text that should be appended to the run line in the job script. - """ - type_check(value, str, allow_none=True) - self.base.attributes.set(self._KEY_ATTRIBUTE_APPEND_TEXT, value) - - @property - def prepend_text(self) -> str: - """Return the text that should be prepended to the run line in the job script. - - :return: The text that should be prepended to the run line in the job script. - """ - return self.base.attributes.get(self._KEY_ATTRIBUTE_PREPEND_TEXT, '') - - @prepend_text.setter - def prepend_text(self, value: str) -> None: - """Set the text that should be prepended to the run line in the job script. - - :param value: The text that should be prepended to the run line in the job script. - """ - type_check(value, str, allow_none=True) - self.base.attributes.set(self._KEY_ATTRIBUTE_PREPEND_TEXT, value) - - @property - def use_double_quotes(self) -> bool: - """Return whether the command line invocation of this code should be escaped with double quotes. - - :return: ``True`` if to escape with double quotes, ``False`` otherwise. - """ - return self.base.attributes.get(self._KEY_ATTRIBUTE_USE_DOUBLE_QUOTES, False) - - @use_double_quotes.setter - def use_double_quotes(self, value: bool) -> None: - """Set whether the command line invocation of this code should be escaped with double quotes. - - :param value: ``True`` if to escape with double quotes, ``False`` otherwise. - """ - type_check(value, bool) - self.base.attributes.set(self._KEY_ATTRIBUTE_USE_DOUBLE_QUOTES, value) - - @property - def with_mpi(self) -> bool | None: - """Return whether the command should be run as an MPI program. - - :return: ``True`` if the code should be run as an MPI program, ``False`` if it shouldn't, ``None`` if unknown. - """ - return self.base.attributes.get(self._KEY_ATTRIBUTE_WITH_MPI, None) - - @with_mpi.setter - def with_mpi(self, value: bool | None) -> None: - """Set whether the command should be run as an MPI program. - - :param value: ``True`` if the code should be run as an MPI program, ``False`` if it shouldn't, ``None`` if - unknown. - """ - type_check(value, bool, allow_none=True) - self.base.attributes.set(self._KEY_ATTRIBUTE_WITH_MPI, value) - - @property - def wrap_cmdline_params(self) -> bool: - """Return whether all command line parameters should be wrapped with double quotes to form a single argument. - - ..note:: This is required to support certain containerization technologies, such as Docker. - - :return: ``True`` if command line parameters should be wrapped, ``False`` otherwise. - """ - return self.base.attributes.get(self._KEY_ATTRIBUTE_WRAP_CMDLINE_PARAMS, False) - - @wrap_cmdline_params.setter - def wrap_cmdline_params(self, value: bool) -> None: - """Set whether all command line parameters should be wrapped with double quotes to form a single argument. - - :param value: ``True`` if command line parameters should be wrapped, ``False`` otherwise. - """ - type_check(value, bool) - self.base.attributes.set(self._KEY_ATTRIBUTE_WRAP_CMDLINE_PARAMS, value) - - @property - def is_hidden(self) -> bool: - """Return whether the code is hidden. - - :return: ``True`` if the code is hidden, ``False`` otherwise, which is also the default. 
- """ - return self.base.extras.get(self._KEY_EXTRA_IS_HIDDEN, False) - - @is_hidden.setter - def is_hidden(self, value: bool) -> None: - """Define whether the code is hidden or not. - - :param value: ``True`` if the code should be hidden, ``False`` otherwise. - """ - type_check(value, bool) - self.base.extras.set(self._KEY_EXTRA_IS_HIDDEN, value) - - def get_builder(self) -> 'ProcessBuilder': - """Create and return a new ``ProcessBuilder`` for the ``CalcJob`` class of the plugin configured for this code. - - The configured calculation plugin class is defined by the ``default_calc_job_plugin`` property. - - .. note:: it also sets the ``builder.code`` value. - - :return: a ``ProcessBuilder`` instance with the ``code`` input already populated with ourselves - :raise aiida.common.EntryPointError: if the specified plugin does not exist. - :raise ValueError: if no default plugin was specified. - """ - entry_point = self.default_calc_job_plugin - - if entry_point is None: - raise ValueError('No default calculation input plugin specified for this code') - - try: - process_class = CalculationFactory(entry_point) - except exceptions.EntryPointError: - raise exceptions.EntryPointError(f'The calculation entry point `{entry_point}` could not be loaded') - - builder = process_class.get_builder() # type: ignore - builder.code = self - - return builder - - @staticmethod - def cli_validate_label_uniqueness(_, __, value): - """Validate the uniqueness of the label of the code.""" - from aiida.orm import load_code - - try: - load_code(value) - except exceptions.NotExistent: - pass - except exceptions.MultipleObjectsError: - raise click.BadParameter(f'Multiple codes with the label `{value}` already exist.') - else: - raise click.BadParameter(f'A code with the label `{value}` already exists.') - - return value - - @classmethod - def get_cli_options(cls) -> collections.OrderedDict: - """Return the CLI options that would allow to create an instance of this class.""" - return collections.OrderedDict(cls._get_cli_options()) - - @classmethod - def _get_cli_options(cls) -> dict: - """Return the CLI options that would allow to create an instance of this class.""" - return { - 'label': { - 'short_name': '-L', - 'required': True, - 'type': click.STRING, - 'prompt': 'Label', - 'help': 'A unique label to identify the code by.', - 'callback': cls.cli_validate_label_uniqueness, - }, - 'description': { - 'short_name': '-D', - 'type': click.STRING, - 'prompt': 'Description', - 'help': 'Human-readable description of this code ideally including version and compilation environment.' - }, - 'default_calc_job_plugin': { - 'short_name': '-P', - 'type': click.STRING, - 'prompt': 'Default `CalcJob` plugin', - 'help': 'Entry point name of the default plugin (as listed in `verdi plugin list aiida.calculations`).' - }, - 'use_double_quotes': { - 'is_flag': True, - 'default': False, - 'help': 'Whether the executable and arguments of the code in the submission script should be escaped ' - 'with single or double quotes.', - 'prompt': 'Escape using double quotes', - }, - 'with_mpi': { - 'is_flag': True, - 'default': None, - 'help': ( - 'Whether the executable should be run as an MPI program. This option can be left unspecified ' - 'in which case `None` will be set and it is left up to the calculation job plugin or inputs ' - 'whether to run with MPI.' 
- ), - 'prompt': 'Run with MPI', - }, - 'prepend_text': { - 'cls': TemplateInteractiveOption, - 'type': click.STRING, - 'default': '', - 'prompt': 'Prepend script', - 'help': 'Bash commands that should be prepended to the run line in all submit scripts for this code.', - 'extension': '.bash', - 'header': 'PREPEND_TEXT: if there is any bash commands that should be prepended to the executable call ' - 'in all submit scripts for this code, type that between the equal signs below and save the file.', - 'footer': 'All lines that start with `#=`: will be ignored.' - }, - 'append_text': { - 'cls': TemplateInteractiveOption, - 'type': click.STRING, - 'default': '', - 'prompt': 'Append script', - 'help': 'Bash commands that should be appended to the run line in all submit scripts for this code.', - 'extension': '.bash', - 'header': 'APPEND_TEXT: if there is any bash commands that should be appended to the executable call ' - 'in all submit scripts for this code, type that between the equal signs below and save the file.', - 'footer': 'All lines that start with `#=`: will be ignored.' - }, - } diff --git a/aiida/orm/nodes/data/code/installed.py b/aiida/orm/nodes/data/code/installed.py deleted file mode 100644 index b57d16f838..0000000000 --- a/aiida/orm/nodes/data/code/installed.py +++ /dev/null @@ -1,200 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Data plugin representing an executable code on a remote computer. - -This plugin should be used if an executable is pre-installed on a computer. The ``InstalledCode`` represents the code by -storing the absolute filepath of the relevant executable and the computer on which it is installed. The computer is -represented by an instance of :class:`aiida.orm.computers.Computer`. Each time a :class:`aiida.engine.CalcJob` is run -using an ``InstalledCode``, it will run its executable on the associated computer. -""" -from __future__ import annotations - -import pathlib - -import click - -from aiida.cmdline.params.types import ComputerParamType -from aiida.common import exceptions -from aiida.common.lang import type_check -from aiida.common.log import override_log_level -from aiida.orm import Computer -from aiida.orm.entities import from_backend_entity - -from .legacy import Code - -__all__ = ('InstalledCode',) - - -class InstalledCode(Code): - """Data plugin representing an executable code on a remote computer.""" - - _KEY_ATTRIBUTE_FILEPATH_EXECUTABLE: str = 'filepath_executable' - - def __init__(self, computer: Computer, filepath_executable: str, **kwargs): - """Construct a new instance. - - :param computer: The remote computer on which the executable is located. - :param filepath_executable: The absolute filepath of the executable on the remote computer. - """ - super().__init__(**kwargs) - self.computer = computer - self.filepath_executable = filepath_executable # type: ignore[assignment] - - def _validate(self): - """Validate the instance by checking that a computer has been defined. - - :raises :class:`aiida.common.exceptions.ValidationError`: If the state of the node is invalid. 
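As a point of reference, the ``InstalledCode`` described in the module docstring above is typically created along these lines. This is only an illustrative sketch: the profile, the ``localhost`` computer, the code label and the executable path are placeholders, and the default plugin shown is simply one entry point shipped with ``aiida-core``.

```python
from aiida import load_profile, orm

load_profile()

computer = orm.load_computer('localhost')  # assumes a computer with this label is configured
code = orm.InstalledCode(
    label='bash',                                   # placeholder code label
    computer=computer,
    filepath_executable='/bin/bash',                # absolute path on the target computer
    default_calc_job_plugin='core.arithmetic.add',  # example default CalcJob entry point
)
code.store()
print(code.full_label)  # e.g. 'bash@localhost'
```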
- """ - super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. # pylint: disable=bad-super-call - - if not self.computer: - raise exceptions.ValidationError('The `computer` is undefined.') - - try: - self.filepath_executable - except TypeError as exception: - raise exceptions.ValidationError('The `filepath_executable` is not set.') from exception - - def validate_filepath_executable(self): - """Validate the ``filepath_executable`` attribute. - - Checks whether the executable exists on the remote computer if a transport can be opened to it. This method - is intentionally not called in ``_validate`` as to allow the creation of ``Code`` instances whose computers can - not yet be connected to and as to not require the overhead of opening transports in storing a new code. - - .. note:: If the ``filepath_executable`` is not an absolute path, the check is skipped. - - :raises `~aiida.common.exceptions.ValidationError`: if no transport could be opened or if the defined executable - does not exist on the remote computer. - """ - if not self.filepath_executable.is_absolute(): - return - - try: - with override_log_level(): # Temporarily suppress noisy logging - with self.computer.get_transport() as transport: - file_exists = transport.isfile(str(self.filepath_executable)) - except Exception as exception: # pylint: disable=broad-except - raise exceptions.ValidationError( - 'Could not connect to the configured computer to determine whether the specified executable exists.' - ) from exception - - if not file_exists: - raise exceptions.ValidationError( - f'The provided remote absolute path `{self.filepath_executable}` does not exist on the computer.' - ) - - def can_run_on_computer(self, computer: Computer) -> bool: - """Return whether the code can run on a given computer. - - :param computer: The computer. - :return: ``True`` if the provided computer is the same as the one configured for this code. - """ - type_check(computer, Computer) - return computer.pk == self.computer.pk - - def get_executable(self) -> pathlib.PurePosixPath: - """Return the executable that the submission script should execute to run the code. - - :return: The executable to be called in the submission script. - """ - return self.filepath_executable - - @property # type: ignore[override] - def computer(self) -> Computer: - """Return the computer of this code.""" - assert self.backend_entity.computer is not None - return from_backend_entity(Computer, self.backend_entity.computer) - - @computer.setter - def computer(self, computer: Computer) -> None: - """Set the computer of this code. - - :param computer: A `Computer`. - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('cannot set the computer on a stored node') - - type_check(computer, Computer, allow_none=False) - self.backend_entity.computer = computer.backend_entity - - @property - def full_label(self) -> str: - """Return the full label of this code. - - The full label can be just the label itself but it can be something else. However, it at the very least has to - include the label of the code. - - :return: The full label of the code. - """ - return f'{self.label}@{self.computer.label}' - - @property - def filepath_executable(self) -> pathlib.PurePosixPath: - """Return the absolute filepath of the executable that this code represents. - - :return: The absolute filepath of the executable. 
- """ - return pathlib.PurePosixPath(self.base.attributes.get(self._KEY_ATTRIBUTE_FILEPATH_EXECUTABLE)) - - @filepath_executable.setter - def filepath_executable(self, value: str) -> None: - """Set the absolute filepath of the executable that this code represents. - - :param value: The absolute filepath of the executable. - """ - type_check(value, str) - self.base.attributes.set(self._KEY_ATTRIBUTE_FILEPATH_EXECUTABLE, value) - - @staticmethod - def cli_validate_label_uniqueness(ctx, _, value): - """Validate the uniqueness of the label of the code.""" - from aiida.orm import load_code - - computer = ctx.params.get('computer', None) - - if computer is None: - return value - - full_label = f'{value}@{computer.label}' - - try: - load_code(full_label) - except exceptions.NotExistent: - pass - except exceptions.MultipleObjectsError: - raise click.BadParameter(f'Multiple codes with the label `{full_label}` already exist.') - else: - raise click.BadParameter(f'A code with the label `{full_label}` already exists.') - - return value - - @classmethod - def _get_cli_options(cls) -> dict: - """Return the CLI options that would allow to create an instance of this class.""" - options = { - 'computer': { - 'short_name': '-Y', - 'required': True, - 'prompt': 'Computer', - 'help': 'The remote computer on which the executable resides.', - 'type': ComputerParamType(), - }, - 'filepath_executable': { - 'short_name': '-X', - 'required': True, - 'type': click.Path(exists=False), - 'prompt': 'Absolute filepath executable', - 'help': 'Absolute filepath of the executable on the remote computer.', - } - } - options.update(**super()._get_cli_options()) - - return options diff --git a/aiida/orm/nodes/data/code/legacy.py b/aiida/orm/nodes/data/code/legacy.py deleted file mode 100644 index e07164cdd8..0000000000 --- a/aiida/orm/nodes/data/code/legacy.py +++ /dev/null @@ -1,596 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Data plugin represeting an executable code to be wrapped and called through a `CalcJob` plugin.""" -import os -import pathlib - -from aiida.common import exceptions -from aiida.common.log import override_log_level -from aiida.common.warnings import warn_deprecation -from aiida.orm import Computer - -from .abstract import AbstractCode - -__all__ = ('Code',) - -warn_deprecation( - 'The `Code` class is deprecated. To create an instance, use the `aiida.orm.nodes.data.code.installed.InstalledCode`' - ' or `aiida.orm.nodes.data.code.portable.PortableCode` for a "remote" or "local" code, respectively. If you are ' - 'using this class to compare type, e.g. in `isinstance`, use `aiida.orm.nodes.data.code.abstract.AbstractCode`.', - version=3 -) - - -class Code(AbstractCode): - """ - A code entity. - It can either be 'local', or 'remote'. - - * Local code: it is a collection of files/dirs (added using the add_path() method), where one \ - file is flagged as executable (using the set_local_executable() method). - - * Remote code: it is a pair (remotecomputer, remotepath_of_executable) set using the \ - set_remote_computer_exec() method. 
- - For both codes, one can set some code to be executed right before or right after - the execution of the code, using the set_preexec_code() and set_postexec_code() - methods (e.g., the set_preexec_code() can be used to load specific modules required - for the code to be run). - """ - - # pylint: disable=too-many-public-methods - - def __init__(self, remote_computer_exec=None, local_executable=None, input_plugin_name=None, files=None, **kwargs): - super().__init__(**kwargs) - - if remote_computer_exec and local_executable: - raise ValueError('cannot set `remote_computer_exec` and `local_executable` at the same time') - - if remote_computer_exec: - warn_deprecation( - 'The `Code` plugin is deprecated, use the `InstalledCode` (`core.code.remote`) instead.', 3 - ) - self.set_remote_computer_exec(remote_computer_exec) - - if local_executable: - warn_deprecation('The `Code` plugin is deprecated, use the `PortableCode` (`core.code.local`) instead.', 3) - self.set_local_executable(local_executable) - - if input_plugin_name: - self.set_input_plugin_name(input_plugin_name) - - if files: - self.set_files(files) - - HIDDEN_KEY = 'hidden' - - def can_run_on_computer(self, computer: Computer) -> bool: - """Return whether the code can run on a given computer. - - :param computer: The computer. - :return: ``True`` if the code can run on ``computer``, ``False`` otherwise. - """ - from aiida import orm - from aiida.common.lang import type_check - - if self.is_local(): - return True - - type_check(computer, orm.Computer) - return computer.pk == self.get_remote_computer().pk - - def get_executable(self) -> pathlib.PurePosixPath: - """Return the executable that the submission script should execute to run the code. - - :return: The executable to be called in the submission script. - """ - if self.is_local(): - exec_path = f'./{self.get_local_executable()}' - else: - exec_path = self.get_remote_exec_path() - - return pathlib.PurePosixPath(exec_path) - - def hide(self): - """ - Hide the code (prevents from showing it in the verdi code list) - """ - warn_deprecation('`Code.hide` property is deprecated, use the `Code.is_hidden` property instead.', version=3) - self.is_hidden = True - - def reveal(self): - """ - Reveal the code (allows to show it in the verdi code list) - By default, it is revealed - """ - warn_deprecation('`Code.reveal` property is deprecated, use the `Code.is_hidden` property instead.', version=3) - self.is_hidden = False - - @property - def hidden(self): - """ - Determines whether the Code is hidden or not - """ - warn_deprecation('`Code.hidden` property is deprecated, use the `Code.is_hidden` property instead.', version=3) - return self.is_hidden - - def set_files(self, files): - """ - Given a list of filenames (or a single filename string), - add it to the path (all at level zero, i.e. without folders). - Therefore, be careful for files with the same name! - - :todo: decide whether to check if the Code must be a local executable - to be able to call this function. 
- """ - - if isinstance(files, str): - files = [files] - - for filename in files: - if os.path.isfile(filename): - with open(filename, 'rb') as handle: - self.base.repository.put_object_from_filelike(handle, os.path.split(filename)[1]) - - def __str__(self): - if self.computer is None: - return f"Local code '{self.label}' pk: {self.pk}, uuid: {self.uuid}" - - return f"Remote code '{self.label}' on {self.computer.label} pk: {self.pk}, uuid: {self.uuid}" - - def get_computer_label(self): - """Get label of this code's computer.""" - warn_deprecation( - '`Code.get_computer_label` method is deprecated, use the `InstalledCode.computer.label` property instead.', - version=3 - ) - return 'repository' if self.computer is None else self.computer.label - - @property - def full_label(self): - """Get full label of this code. - - Returns label of the form @. - """ - return f'{self.label}@{"repository" if self.computer is None else self.computer.label}' - - def relabel(self, new_label): - """Relabel this code. - - :param new_label: new code label - """ - warn_deprecation('`Code.relabel` method is deprecated, use the `label` property instead.', version=3) - if self.computer is not None: - suffix = f'@{self.computer.label}' - if new_label.endswith(suffix): - new_label = new_label[:-len(suffix)] - - self.label = new_label - - def get_description(self): - """Return a string description of this Code instance. - - :return: string description of this Code instance - """ - warn_deprecation( - '`Code.get_description` method is deprecated, use the `description` property instead.', version=3 - ) - return f'{self.description}' - - @classmethod - def get_code_helper(cls, label, machinename=None, backend=None): - """ - :param label: the code label identifying the code to load - :param machinename: the machine name where code is setup - - :raise aiida.common.NotExistent: if no code identified by the given string is found - :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely - a code - """ - from aiida.common.exceptions import MultipleObjectsError, NotExistent - from aiida.orm.querybuilder import QueryBuilder - - warn_deprecation( - '`Code.get_code_helper` classmethod is deprecated, use `aiida.orm.load_code` instead.', version=3 - ) - - query = QueryBuilder(backend=backend) - query.append(cls, filters={'label': label}, project='*', tag='code') - if machinename: - query.append(Computer, filters={'label': machinename}, with_node='code') - - if query.count() == 0: - raise NotExistent(f"'{label}' is not a valid code label.") - elif query.count() > 1: - codes = query.all(flat=True) - retstr = f"There are multiple codes with label '{label}', having IDs: " - retstr += f"{', '.join(sorted([str(c.pk) for c in codes]))}.\n" # type: ignore[union-attr] - retstr += ('Relabel them (using their ID), or refer to them with their ID.') - raise MultipleObjectsError(retstr) - else: - result = query.first() - if not result: - raise NotExistent(f"code '{label}' does not exist.") - - return result[0] - - @classmethod - def get(cls, pk=None, label=None, machinename=None): - """ - Get a Computer object with given identifier string, that can either be - the numeric ID (pk), or the label (and computername) (if unique). 
- - :param pk: the numeric ID (pk) for code - :param label: the code label identifying the code to load - :param machinename: the machine name where code is setup - - :raise aiida.common.NotExistent: if no code identified by the given string is found - :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely a code - :raise ValueError: if neither a pk nor a label was passed in - """ - # pylint: disable=arguments-differ - from aiida.orm.utils import load_code - - warn_deprecation('`Code.get` classmethod is deprecated, use `aiida.orm.load_code` instead.', version=3) - - # first check if code pk is provided - if pk: - code_int = int(pk) - try: - return load_code(pk=code_int) - except exceptions.NotExistent: - raise ValueError(f'{pk} is not valid code pk') - except exceptions.MultipleObjectsError: - raise exceptions.MultipleObjectsError(f"More than one code in the DB with pk='{pk}'!") - - # check if label (and machinename) is provided - elif label is not None: - return cls.get_code_helper(label, machinename) - - else: - raise ValueError('Pass either pk or code label (and machinename)') - - @classmethod - def get_from_string(cls, code_string): - """ - Get a Computer object with given identifier string in the format - label@machinename. See the note below for details on the string - detection algorithm. - - .. note:: the (leftmost) '@' symbol is always used to split code - and computername. Therefore do not use - '@' in the code name if you want to use this function - ('@' in the computer name are instead valid). - - :param code_string: the code string identifying the code to load - - :raise aiida.common.NotExistent: if no code identified by the given string is found - :raise aiida.common.MultipleObjectsError: if the string cannot identify uniquely - a code - :raise TypeError: if code_string is not of string type - - """ - from aiida.common.exceptions import MultipleObjectsError, NotExistent - - warn_deprecation( - '`Code.get_from_string` classmethod is deprecated, use `aiida.orm.load_code` instead.', version=3 - ) - - try: - label, _, machinename = code_string.partition('@') - except AttributeError: - raise TypeError('the provided code_string is not of valid string type') - - try: - return cls.get_code_helper(label, machinename) - except NotExistent: - raise NotExistent(f'{code_string} could not be resolved to a valid code label') - except MultipleObjectsError: - raise MultipleObjectsError(f'{code_string} could not be uniquely resolved') - - @classmethod - def list_for_plugin(cls, plugin, labels=True, backend=None): - """ - Return a list of valid code strings for a given plugin. - - :param plugin: The string of the plugin. - :param labels: if True, return a list of code names, otherwise - return the code PKs (integers). - :return: a list of string, with the code names if labels is True, - otherwise a list of integers with the code PKs. 
- """ - from aiida.orm.querybuilder import QueryBuilder - - warn_deprecation('`Code.list_for_plugin` classmethod has been deprecated, there is no replacement.', version=3) - - query = QueryBuilder(backend=backend) - query.append(cls, filters={'attributes.input_plugin': {'==': plugin}}) - valid_codes = query.all(flat=True) - - if labels: - return [c.label for c in valid_codes] # type: ignore[union-attr] - - return [c.pk for c in valid_codes] # type: ignore[union-attr] - - def _validate(self): - super()._validate() - - if self.is_local() is None: - raise exceptions.ValidationError('You did not set whether the code is local or remote') - - if self.is_local(): - if not self.get_local_executable(): - raise exceptions.ValidationError( - 'You have to set which file is the local executable ' - 'using the set_exec_filename() method' - ) - if self.get_local_executable() not in self.base.repository.list_object_names(): - raise exceptions.ValidationError( - f"The local executable '{self.get_local_executable()}' is not in the list of files of this code" - ) - else: - if self.base.repository.list_object_names(): - raise exceptions.ValidationError('The code is remote but it has files inside') - if not self.get_remote_computer(): - raise exceptions.ValidationError('You did not specify a remote computer') - if not self.get_remote_exec_path(): - raise exceptions.ValidationError('You did not specify a remote executable') - - def validate_remote_exec_path(self): - """Validate the ``remote_exec_path`` attribute. - - Checks whether the executable exists on the remote computer if a transport can be opened to it. This method - is intentionally not called in ``_validate`` as to allow the creation of ``Code`` instances whose computers can - not yet be connected to and as to not require the overhead of opening transports in storing a new code. - - :raises `~aiida.common.exceptions.ValidationError`: if no transport could be opened or if the defined executable - does not exist on the remote computer. - """ - warn_deprecation( - '`Code.validate_remote_exec_path` method is deprecated, use the ' - '`InstalledCode.validate_filepath_executable` property instead.', - version=3 - ) - filepath = self.get_remote_exec_path() - - if self.computer is None: - raise exceptions.ValidationError('checking the remote exec path is not available for a local code.') - - try: - with override_log_level(): # Temporarily suppress noisy logging - with self.computer.get_transport() as transport: - file_exists = transport.isfile(filepath) - except Exception: # pylint: disable=broad-except - raise exceptions.ValidationError( - 'Could not connect to the configured computer to determine whether the specified executable exists.' - ) - - if not file_exists: - raise exceptions.ValidationError( - f'the provided remote absolute path `{filepath}` does not exist on the computer.' - ) - - def set_prepend_text(self, code): - """ - Pass a string of code that will be put in the scheduler script before the - execution of the code. - """ - warn_deprecation( - '`Code.set_prepend_text` method is deprecated, use the `prepend_text` property instead.', version=3 - ) - self.prepend_text = code - - def get_prepend_text(self): - """ - Return the code that will be put in the scheduler script before the - execution, or an empty string if no pre-exec code was defined. 
- """ - warn_deprecation( - '`Code.get_prepend_text` method is deprecated, use the `prepend_text` property instead.', version=3 - ) - return self.prepend_text - - def set_input_plugin_name(self, input_plugin): - """ - Set the name of the default input plugin, to be used for the automatic - generation of a new calculation. - """ - warn_deprecation( - '`Code.set_input_plugin_name` method is deprecated, use the `default_calc_job_plugin` property instead.', - version=3 - ) - self.default_calc_job_plugin = input_plugin - - def get_input_plugin_name(self): - """ - Return the name of the default input plugin (or None if no input plugin - was set. - """ - warn_deprecation( - '`Code.get_input_plugin_name` method is deprecated, use the `default_calc_job_plugin` property instead.', - version=3 - ) - return self.default_calc_job_plugin - - def set_use_double_quotes(self, use_double_quotes: bool): - """Set whether the command line invocation of this code should be escaped with double quotes. - - :param use_double_quotes: True if to escape with double quotes, False otherwise. - """ - warn_deprecation( - '`Code.set_use_double_quotes` method is deprecated, use the `use_double_quotes` property instead.', - version=3 - ) - self.use_double_quotes = use_double_quotes - - def get_use_double_quotes(self) -> bool: - """Return whether the command line invocation of this code should be escaped with double quotes. - - :returns: True if to escape with double quotes, False otherwise which is also the default. - """ - warn_deprecation( - '`Code.get_use_double_quotes` method is deprecated, use the `use_double_quotes` property instead.', - version=3 - ) - return self.use_double_quotes - - def set_append_text(self, code): - """ - Pass a string of code that will be put in the scheduler script after the - execution of the code. - """ - warn_deprecation( - '`Code.set_append_text` method is deprecated, use the `append_text` property instead.', version=3 - ) - self.append_text = code - - def get_append_text(self): - """ - Return the postexec_code, or an empty string if no post-exec code was defined. - """ - warn_deprecation( - '`Code.get_append_text` method is deprecated, use the `append_text` property instead.', version=3 - ) - return self.append_text - - def set_local_executable(self, exec_name): - """ - Set the filename of the local executable. - Implicitly set the code as local. - """ - warn_deprecation('`Code.set_local_executable` method is deprecated, use `PortableCode`.', version=3) - - self._set_local() - self.filepath_executable = exec_name - - def get_local_executable(self): - """Return the local executable.""" - warn_deprecation( - '`Code.get_local_executable` method is deprecated, use `PortableCode.filepath_executable` instead.', - version=3 - ) - return self.filepath_executable - - def set_remote_computer_exec(self, remote_computer_exec): - """ - Set the code as remote, and pass the computer on which it resides - and the absolute path on that computer. - - :param remote_computer_exec: a tuple (computer, remote_exec_path), where computer is a aiida.orm.Computer and - remote_exec_path is the absolute path of the main executable on remote computer. 
- """ - from aiida import orm - from aiida.common.lang import type_check - - warn_deprecation('`Code.set_remote_computer_exec` method is deprecated, use `InstalledCode`.', version=3) - - if (not isinstance(remote_computer_exec, (list, tuple)) or len(remote_computer_exec) != 2): - raise ValueError( - 'remote_computer_exec must be a list or tuple of length 2, with machine and executable name' - ) - - computer, remote_exec_path = tuple(remote_computer_exec) - - if not os.path.isabs(remote_exec_path): - raise ValueError('exec_path must be an absolute path (on the remote machine)') - - type_check(computer, orm.Computer) - - self._set_remote() - self.computer = computer - self.base.attributes.set('remote_exec_path', remote_exec_path) - - def get_remote_exec_path(self): - """Return the ``remote_exec_path`` attribute.""" - warn_deprecation( - '`Code.get_remote_exec_path` method is deprecated, use `InstalledCode.filepath_executable` instead.', - version=3 - ) - if self.is_local(): - raise ValueError('The code is local') - return self.base.attributes.get('remote_exec_path', '') - - def get_remote_computer(self): - """Return the remote computer associated with this code.""" - warn_deprecation( - '`Code.get_remote_computer` method is deprecated, use the `computer` attribute instead.', version=3 - ) - if self.is_local(): - raise ValueError('The code is local') - - return self.computer - - def _set_local(self): - """ - Set the code as a 'local' code, meaning that all the files belonging to the code - will be copied to the cluster, and the file set with set_exec_filename will be - run. - - It also deletes the flags related to the local case (if any) - """ - self.base.attributes.set('is_local', True) - self.computer = None - try: - self.base.attributes.delete('remote_exec_path') - except AttributeError: - pass - - def _set_remote(self): - """ - Set the code as a 'remote' code, meaning that the code itself has no files attached, - but only a location on a remote computer (with an absolute path of the executable on - the remote computer). - - It also deletes the flags related to the local case (if any) - """ - self.base.attributes.set('is_local', False) - try: - self.base.attributes.delete('local_executable') - except AttributeError: - pass - - def is_local(self): - """ - Return True if the code is 'local', False if it is 'remote' (see also documentation - of the set_local and set_remote functions). - """ - warn_deprecation( - '`Code.is_local` method is deprecated, use a `PortableCode` instance and check the type.', version=3 - ) - return self.base.attributes.get('is_local', None) - - def can_run_on(self, computer): - """ - Return True if this code can run on the given computer, False otherwise. - - Local codes can run on any machine; remote codes can run only on the machine - on which they reside. - - TODO: add filters to mask the remote machines on which a local code can run. - """ - from aiida import orm - from aiida.common.lang import type_check - - warn_deprecation('`Code.can_run_on` method is deprecated, use `can_run_on_computer` instead.', version=3) - - if self.is_local(): - return True - - type_check(computer, orm.Computer) - return computer.pk == self.get_remote_computer().pk - - def get_execname(self): - """ - Return the executable string to be put in the script. - For local codes, it is ./LOCAL_EXECUTABLE_NAME - For remote codes, it is the absolute path to the executable. 
- """ - warn_deprecation('`Code.get_execname` method is deprecated, use `get_executable` instead.', version=3) - return str(self.get_executable()) diff --git a/aiida/orm/nodes/data/code/portable.py b/aiida/orm/nodes/data/code/portable.py deleted file mode 100644 index 8fb2bd8364..0000000000 --- a/aiida/orm/nodes/data/code/portable.py +++ /dev/null @@ -1,169 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Data plugin representing an executable code stored in AiiDA's storage. - -This plugin should be used for executables that are not already installed on the target computer, but instead are -available on the machine where AiiDA is running. The plugin assumes that the code is self-contained by a single -directory containing all the necessary files, including a main executable. When constructing a ``PortableCode``, passing -the absolute filepath as ``filepath_files`` will make sure that all the files contained within are uploaded to AiiDA's -storage. The ``filepath_executable`` should indicate the filename of the executable within that directory. Each time a -:class:`aiida.engine.CalcJob` is run using a ``PortableCode``, the uploaded files will be automatically copied to the -working directory on the selected computer and the executable will be run there. -""" -from __future__ import annotations - -import pathlib - -import click - -from aiida.common import exceptions -from aiida.common.folders import Folder -from aiida.common.lang import type_check -from aiida.orm import Computer - -from .legacy import Code - -__all__ = ('PortableCode',) - - -class PortableCode(Code): - """Data plugin representing an executable code stored in AiiDA's storage.""" - - _KEY_ATTRIBUTE_FILEPATH_EXECUTABLE: str = 'filepath_executable' - - def __init__(self, filepath_executable: str, filepath_files: pathlib.Path, **kwargs): - """Construct a new instance. - - .. note:: If the files necessary for this code are not all located in a single directory or the directory - contains files that should not be uploaded, and so the ``filepath_files`` cannot be used. One can use the - methods of the :class:`aiida.orm.nodes.repository.NodeRepository` class. This can be accessed through the - ``base.repository`` attribute of the instance after it has been constructed. For example:: - - code = PortableCode(filepath_executable='some_name.exe') - code.put_object_from_file() - code.put_object_from_filelike() - code.put_object_from_tree() - - :param filepath_executable: The relative filepath of the executable within the directory of uploaded files. - :param filepath_files: The filepath to the directory containing all the files of the code. - """ - super().__init__(**kwargs) - type_check(filepath_files, pathlib.Path) - self.filepath_executable = filepath_executable # type: ignore[assignment] - self.base.repository.put_object_from_tree(str(filepath_files)) - - def _validate(self): - """Validate the instance by checking that an executable is defined and it is part of the repository files. - - :raises :class:`aiida.common.exceptions.ValidationError`: If the state of the node is invalid. 
- """ - super(Code, self)._validate() # Change to ``super()._validate()`` once deprecated ``Code`` class is removed. # pylint: disable=bad-super-call - - try: - filepath_executable = self.filepath_executable - except TypeError as exception: - raise exceptions.ValidationError('The `filepath_executable` is not set.') from exception - - objects = self.base.repository.list_object_names() - - if str(filepath_executable) not in objects: - raise exceptions.ValidationError( - f'The executable `{filepath_executable}` is not one of the uploaded files: {objects}' - ) - - def can_run_on_computer(self, computer: Computer) -> bool: - """Return whether the code can run on a given computer. - - A ``PortableCode`` should be able to be run on any computer in principle. - - :param computer: The computer. - :return: ``True`` if the provided computer is the same as the one configured for this code. - """ - return True - - def get_executable(self) -> pathlib.PurePosixPath: - """Return the executable that the submission script should execute to run the code. - - :return: The executable to be called in the submission script. - """ - return self.filepath_executable - - def validate_working_directory(self, folder: Folder): - """Validate content of the working directory created by the :class:`~aiida.engine.CalcJob` plugin. - - This method will be called by :meth:`~aiida.engine.processes.calcjobs.calcjob.CalcJob.presubmit` when a new - calculation job is launched, passing the :class:`~aiida.common.folders.Folder` that was used by the plugin used - for the calculation to create the input files for the working directory. This method can be overridden by - implementations of the ``AbstractCode`` class that need to validate the contents of that folder. - - :param folder: A sandbox folder that the ``CalcJob`` plugin wrote input files to that will be copied to the - working directory for the corresponding calculation job instance. - :raises PluginInternalError: The ``CalcJob`` plugin created a file that has the same relative filepath as the - executable for this portable code. - """ - if str(self.filepath_executable) in folder.get_content_list(): - raise exceptions.PluginInternalError( - f'The plugin created a file {self.filepath_executable} that is also the executable name!' - ) - - @property - def full_label(self) -> str: - """Return the full label of this code. - - The full label can be just the label itself but it can be something else. However, it at the very least has to - include the label of the code. - - :return: The full label of the code. - """ - return self.label - - @property - def filepath_executable(self) -> pathlib.PurePosixPath: - """Return the relative filepath of the executable that this code represents. - - :return: The relative filepath of the executable. - """ - return pathlib.PurePosixPath(self.base.attributes.get(self._KEY_ATTRIBUTE_FILEPATH_EXECUTABLE)) - - @filepath_executable.setter - def filepath_executable(self, value: str) -> None: - """Set the relative filepath of the executable that this code represents. - - :param value: The relative filepath of the executable within the directory of uploaded files. 
- """ - type_check(value, str) - - if pathlib.PurePosixPath(value).is_absolute(): - raise ValueError('The `filepath_executable` should not be absolute.') - - self.base.attributes.set(self._KEY_ATTRIBUTE_FILEPATH_EXECUTABLE, value) - - @classmethod - def _get_cli_options(cls) -> dict: - """Return the CLI options that would allow to create an instance of this class.""" - options = { - 'filepath_executable': { - 'short_name': '-X', - 'required': True, - 'type': click.STRING, - 'prompt': 'Relative filepath executable', - 'help': 'Relative filepath of executable with directory of code files.', - }, - 'filepath_files': { - 'short_name': '-F', - 'required': True, - 'type': click.Path(exists=True, file_okay=False, dir_okay=True, path_type=pathlib.Path), - 'prompt': 'Code directory', - 'help': 'Filepath to directory containing code files.', - } - } - options.update(**super()._get_cli_options()) - - return options diff --git a/aiida/orm/nodes/data/data.py b/aiida/orm/nodes/data/data.py deleted file mode 100644 index 564139691f..0000000000 --- a/aiida/orm/nodes/data/data.py +++ /dev/null @@ -1,363 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class `Data` to be used as a base class for data structures.""" -from typing import Dict - -from aiida.common import exceptions -from aiida.common.lang import override -from aiida.common.links import LinkType -from aiida.orm.entities import from_backend_entity - -from ..node import Node - -__all__ = ('Data',) - - -class Data(Node): - """ - The base class for all Data nodes. - - AiiDA Data classes are subclasses of Node and must support multiple inheritance. - - Architecture note: - Calculation plugins are responsible for converting raw output data from simulation codes to Data nodes. - Nodes are responsible for validating their content (see _validate method). - """ - _source_attributes = ['db_name', 'db_uri', 'uri', 'id', 'version', 'extras', 'source_md5', 'description', 'license'] - - # Replace this with a dictionary in each subclass that, given a file - # extension, returns the corresponding fileformat string. - # - # This is used in the self.export() method. - # By default, if not found here, - # The fileformat string is assumed to match the extension. - # Example: {'dat': 'dat_multicolumn'} - _export_format_replacements: Dict[str, str] = {} - - # Data nodes are storable - _storable = True - _unstorable_message = 'storing for this node has been disabled' - - def __init__(self, *args, source=None, **kwargs): - """Construct a new instance, setting the ``source`` attribute if provided as a keyword argument.""" - super().__init__(*args, **kwargs) - if source is not None: - self.source = source - - def __copy__(self): - """Copying a Data node is not supported, use copy.deepcopy or call Data.clone().""" - raise exceptions.InvalidOperation('copying a Data node is not supported, use copy.deepcopy') - - def __deepcopy__(self, memo): - """ - Create a clone of the Data node by piping through to the clone method and return the result. 
- - :returns: an unstored clone of this Data node - """ - return self.clone() - - def clone(self): - """Create a clone of the Data node. - - :returns: an unstored clone of this Data node - """ - import copy - - backend_clone = self.backend_entity.clone() - clone = from_backend_entity(self.__class__, backend_clone) - clone.base.attributes.reset(copy.deepcopy(self.base.attributes.all)) - clone.base.repository._clone(self.base.repository) # pylint: disable=protected-access - - return clone - - @property - def source(self): - """ - Gets the dictionary describing the source of Data object. Possible fields: - - * **db_name**: name of the source database. - * **db_uri**: URI of the source database. - * **uri**: URI of the object's source. Should be a permanent link. - * **id**: object's source identifier in the source database. - * **version**: version of the object's source. - * **extras**: a dictionary with other fields for source description. - * **source_md5**: MD5 checksum of object's source. - * **description**: human-readable free form description of the object's source. - * **license**: a string with a type of license. - - .. note:: some limitations for setting the data source exist, see ``_validate`` method. - - :return: dictionary describing the source of Data object. - """ - return self.base.attributes.get('source', None) - - @source.setter - def source(self, source): - """ - Sets the dictionary describing the source of Data object. - - :raise KeyError: if dictionary contains unknown field. - :raise ValueError: if supplied source description is not a dictionary. - """ - if not isinstance(source, dict): - raise ValueError('Source must be supplied as a dictionary') - unknown_attrs = tuple(set(source.keys()) - set(self._source_attributes)) - if unknown_attrs: - raise KeyError(f"Unknown source parameters: {', '.join(unknown_attrs)}") - - self.base.attributes.set('source', source) - - def set_source(self, source): - """ - Sets the dictionary describing the source of Data object. - """ - self.source = source - - @property - def creator(self): - """Return the creator of this node or None if it does not exist. - - :return: the creating node or None - """ - inputs = self.base.links.get_incoming(link_type=LinkType.CREATE) - link = inputs.first() - if link: - return link.node - - return None - - @override - def _exportcontent(self, fileformat, main_file_name='', **kwargs): - """ - Converts a Data node to one (or multiple) files. - - Note: Export plugins should return utf8-encoded **bytes**, which can be - directly dumped to file. - - :param fileformat: the extension, uniquely specifying the file format. - :type fileformat: str - :param main_file_name: (empty by default) Can be used by plugin to - infer sensible names for additional files, if necessary. E.g. if the - main file is '../myplot.gnu', the plugin may decide to store the dat - file under '../myplot_data.dat'. - :type main_file_name: str - :param kwargs: other parameters are passed down to the plugin - :returns: a tuple of length 2. The first element is the content of the - otuput file. The second is a dictionary (possibly empty) in the format - {filename: filecontent} for any additional file that should be produced. - :rtype: (bytes, dict) - """ - exporters = self._get_exporters() - - try: - func = exporters[fileformat] - except KeyError: - if exporters.keys(): - raise ValueError( - 'The format {} is not implemented for {}. 
' - 'Currently implemented are: {}.'.format( - fileformat, self.__class__.__name__, ','.join(exporters.keys()) - ) - ) - else: - raise ValueError( - 'The format {} is not implemented for {}. ' - 'No formats are implemented yet.'.format(fileformat, self.__class__.__name__) - ) - - string, dictionary = func(main_file_name=main_file_name, **kwargs) - assert isinstance(string, bytes), 'export function `{}` did not return the content as a byte string.' - - return string, dictionary - - @override - def export(self, path, fileformat=None, overwrite=False, **kwargs): - """ - Save a Data object to a file. - - :param fname: string with file name. Can be an absolute or relative path. - :param fileformat: kind of format to use for the export. If not present, - it will try to use the extension of the file name. - :param overwrite: if set to True, overwrites file found at path. Default=False - :param kwargs: additional parameters to be passed to the - _exportcontent method - :return: the list of files created - """ - import os - - if not path: - raise ValueError('Path not recognized') - - if os.path.exists(path) and not overwrite: - raise OSError(f'A file was already found at {path}') - - if fileformat is None: - extension = os.path.splitext(path)[1] - if extension.startswith(os.path.extsep): - extension = extension[len(os.path.extsep):] - if not extension: - raise ValueError('Cannot recognized the fileformat from the extension') - - # Replace the fileformat using the replacements specified in the - # _export_format_replacements dictionary. If not found there, - # by default assume the fileformat string is identical to the extension - fileformat = self._export_format_replacements.get(extension, extension) - - retlist = [] - - filetext, extra_files = self._exportcontent(fileformat, main_file_name=path, **kwargs) - - if not overwrite: - for fname in extra_files: - if os.path.exists(fname): - raise OSError(f'The file {fname} already exists, stopping.') - - if os.path.exists(path): - raise OSError(f'The file {path} already exists, stopping.') - - for additional_fname, additional_fcontent in extra_files.items(): - retlist.append(additional_fname) - with open(additional_fname, 'wb', encoding=None) as fhandle: - fhandle.write(additional_fcontent) # This is up to each specific plugin - retlist.append(path) - with open(path, 'wb', encoding=None) as fhandle: - fhandle.write(filetext) - - return retlist - - def _get_exporters(self): - """ - Get all implemented export formats. - The convention is to find all _prepare_... methods. - Returns a dictionary of method_name: method_function - """ - # NOTE: To add support for a new format, write a new function called as - # _prepare_"" with the name of the new format - exporter_prefix = '_prepare_' - valid_format_names = self.get_export_formats() - valid_formats = {k: getattr(self, exporter_prefix + k) for k in valid_format_names} - return valid_formats - - @classmethod - def get_export_formats(cls): - """ - Get the list of valid export format strings - - :return: a list of valid formats - """ - exporter_prefix = '_prepare_' - method_names = dir(cls) # get list of class methods names - valid_format_names = [ - i[len(exporter_prefix):] for i in method_names if i.startswith(exporter_prefix) - ] # filter them - return sorted(valid_format_names) - - def importstring(self, inputstring, fileformat, **kwargs): - """ - Converts a Data object to other text format. - - :param fileformat: a string (the extension) to describe the file format. 
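The export machinery above discovers formats from the ``_prepare_<format>`` methods, so the available formats can be listed per ``Data`` subclass. ``StructureData`` is just one example; the exact list depends on the subclass.

```python
from aiida.orm import StructureData

# Formats correspond to the _prepare_* methods defined on the class.
print(StructureData.get_export_formats())  # e.g. ['cif', 'xsf', 'xyz', ...]
```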
- :returns: a string with the structure description. - """ - importers = self._get_importers() - - try: - func = importers[fileformat] - except KeyError: - if importers.keys(): - raise ValueError( - 'The format {} is not implemented for {}. ' - 'Currently implemented are: {}.'.format( - fileformat, self.__class__.__name__, ','.join(importers.keys()) - ) - ) - else: - raise ValueError( - 'The format {} is not implemented for {}. ' - 'No formats are implemented yet.'.format(fileformat, self.__class__.__name__) - ) - - # func is bound to self by getattr in _get_importers() - func(inputstring, **kwargs) - - def importfile(self, fname, fileformat=None): - """ - Populate a Data object from a file. - - :param fname: string with file name. Can be an absolute or relative path. - :param fileformat: kind of format to use for the export. If not present, - it will try to use the extension of the file name. - """ - if fileformat is None: - fileformat = fname.split('.')[-1] - with open(fname, 'r', encoding='utf8') as fhandle: # reads in cwd, if fname is not absolute - self.importstring(fhandle.read(), fileformat) - - def _get_importers(self): - """ - Get all implemented import formats. - The convention is to find all _parse_... methods. - Returns a list of strings. - """ - # NOTE: To add support for a new format, write a new function called as - # _parse_"" with the name of the new format - importer_prefix = '_parse_' - method_names = dir(self) # get list of class methods names - valid_format_names = [i[len(importer_prefix):] for i in method_names if i.startswith(importer_prefix)] - valid_formats = {k: getattr(self, importer_prefix + k) for k in valid_format_names} - return valid_formats - - def convert(self, object_format=None, *args): - """ - Convert the AiiDA StructureData into another python object - - :param object_format: Specify the output format - """ - # pylint: disable=keyword-arg-before-vararg - - if object_format is None: - raise ValueError('object_format must be provided') - - if not isinstance(object_format, str): - raise ValueError('object_format should be a string') - - converters = self._get_converters() - - try: - func = converters[object_format] - except KeyError: - if converters.keys(): - raise ValueError( - 'The format {} is not implemented for {}. ' - 'Currently implemented are: {}.'.format( - object_format, self.__class__.__name__, ','.join(converters.keys()) - ) - ) - else: - raise ValueError( - 'The format {} is not implemented for {}. ' - 'No formats are implemented yet.'.format(object_format, self.__class__.__name__) - ) - - return func(*args) - - def _get_converters(self): - """ - Get all implemented converter formats. - The convention is to find all _get_object_... methods. - Returns a list of strings. - """ - # NOTE: To add support for a new format, write a new function called as - # _prepare_"" with the name of the new format - exporter_prefix = '_get_object_' - method_names = dir(self) # get list of class methods names - valid_format_names = [i[len(exporter_prefix):] for i in method_names if i.startswith(exporter_prefix)] - valid_formats = {k: getattr(self, exporter_prefix + k) for k in valid_format_names} - return valid_formats diff --git a/aiida/orm/nodes/data/folder.py b/aiida/orm/nodes/data/folder.py deleted file mode 100644 index 38c679d9ef..0000000000 --- a/aiida/orm/nodes/data/folder.py +++ /dev/null @@ -1,40 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. 
All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""`Data` sub class to represent a folder on a file system.""" - -from .data import Data - -__all__ = ('FolderData',) - - -class FolderData(Data): - """`Data` sub class to represent a folder on a file system.""" - - def __init__(self, **kwargs): - """Construct a new `FolderData` to which any files and folders can be added. - - Use the `tree` keyword to simply wrap a directory: - - folder = FolderData(tree='/absolute/path/to/directory') - - Alternatively, one can construct the node first and then use the various repository methods to add objects: - - folder = FolderData() - folder.put_object_from_tree('/absolute/path/to/directory') - folder.put_object_from_filepath('/absolute/path/to/file.txt') - folder.put_object_from_filelike(filelike_object) - - :param tree: absolute path to a folder to wrap - :type tree: str - """ - tree = kwargs.pop('tree', None) - super().__init__(**kwargs) - if tree: - self.base.repository.put_object_from_tree(tree) diff --git a/aiida/orm/nodes/data/orbital.py b/aiida/orm/nodes/data/orbital.py deleted file mode 100644 index 32f1640cac..0000000000 --- a/aiida/orm/nodes/data/orbital.py +++ /dev/null @@ -1,118 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Data plugin to model an atomic orbital.""" -import copy - -from aiida.common.exceptions import ValidationError -from aiida.plugins import OrbitalFactory - -from .data import Data - -__all__ = ('OrbitalData',) - - -class OrbitalData(Data): - """ - Used for storing collections of orbitals, as well as - providing methods for accessing them internally. - """ - - def clear_orbitals(self): - """ - Remove all orbitals that were added to the class - Cannot work if OrbitalData has been already stored - """ - self.base.attributes.set('orbital_dicts', []) - - def get_orbitals(self, **kwargs): - """ - Returns all orbitals by default. If a site is provided, returns - all orbitals cooresponding to the location of that site, additional - arguments may be provided, which act as filters on the retrieved - orbitals. 
- - :param site: if provided, returns all orbitals with position of site - :kwargs: attributes than can filter the set of returned orbitals - :return list_of_outputs: a list of orbitals - """ - - orbital_dicts = copy.deepcopy(self.base.attributes.get('orbital_dicts', None)) - if orbital_dicts is None: - raise AttributeError('Orbitals must be set before being retrieved') - - filter_dict = {} - filter_dict.update(kwargs) - # prevents KeyError from occuring - orbital_dicts = [x for x in orbital_dicts if all(y in x for y in filter_dict)] - orbital_dicts = [x for x in orbital_dicts if all(x[y] == z for y, z in filter_dict.items())] - - list_of_outputs = [] - for orbital_dict in orbital_dicts: - try: - orbital_type = orbital_dict.pop('_orbital_type') - except KeyError: - raise ValidationError(f'No _orbital_type found in: {orbital_dict}') - - cls = OrbitalFactory(orbital_type) - orbital = cls(**orbital_dict) - list_of_outputs.append(orbital) - return list_of_outputs - - def set_orbitals(self, orbitals): - """ - Sets the orbitals into the database. Uses the orbital's inherent - set_orbital_dict method to generate a orbital dict string. - - :param orbital: an orbital or list of orbitals to be set - """ - if not isinstance(orbitals, list): - orbitals = [orbitals] - orbital_dicts = [] - - for orbital in orbitals: - orbital_dict = copy.deepcopy(orbital.get_orbital_dict()) - try: - _orbital_type = orbital_dict['_orbital_type'] - except KeyError: - raise ValueError(f'No _orbital_type found in: {orbital_dict}') - orbital_dicts.append(orbital_dict) - self.base.attributes.set('orbital_dicts', orbital_dicts) - - -########################################################################## -# Here are some ideas for potential future convenience methods -######################################################################### -# def set_projection_on_site(self, orbital, site, tag=None): -# """ -# Sets a orbital on a site -# We prepare the description dictionary, using information `parsed` -# from the site. -# """ -# diffusivity = from_site_guess_diffusivity(site) # or 1. -# position = site.position -# description = {'somedictionary of the above':''} -# self.set_projection(orbital=orbital, description=description) - -# def delete_projections_by_attribute(self, selection_attributes): -# """ -# Deletes all projections whose internal attributes correspond to the -# selection_attributes -# """ -# raise NotImplementedError -# -# def modify_projections(self, key_attributes_to_select_projections, attributes_to_be_modified): -# """ -# Modifies the projections, as selected by the key_attributes. 
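# A minimal usage sketch for `OrbitalData`: orbitals are stored as plain
# dictionaries (each must carry an `_orbital_type` key) and are rebuilt via
# `OrbitalFactory` on retrieval, so only orbital classes registered as entry
# points survive the round trip. The entry-point name and constructor
# arguments below are assumptions made for illustration only.
from aiida.orm import OrbitalData
from aiida.plugins import OrbitalFactory

RealhydrogenOrbital = OrbitalFactory('core.realhydrogen')  # assumed entry-point name
orbital = RealhydrogenOrbital(position=(0., 0., 0.), angular_momentum=0, magnetic_number=0)

node = OrbitalData()
node.set_orbitals(orbital)                         # accepts a single orbital or a list
retrieved = node.get_orbitals(angular_momentum=0)  # keyword arguments act as filters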
-# Overwrites attributes inside these projections, to values stored -# in attributes_to_be_modified -# """ -# -# def set_realhydrogenorbitals_from_structure(self, structure, pseudo_family=None): -# raise NotImplementedError diff --git a/aiida/orm/nodes/data/remote/__init__.py b/aiida/orm/nodes/data/remote/__init__.py deleted file mode 100644 index ae1b5dbc4f..0000000000 --- a/aiida/orm/nodes/data/remote/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -"""Module with data plugins that represent remote resources and so effectively are symbolic links.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .base import * -from .stash import * - -__all__ = ( - 'RemoteData', - 'RemoteStashData', - 'RemoteStashFolderData', -) - -# yapf: enable diff --git a/aiida/orm/nodes/data/remote/base.py b/aiida/orm/nodes/data/remote/base.py deleted file mode 100644 index d4b485d926..0000000000 --- a/aiida/orm/nodes/data/remote/base.py +++ /dev/null @@ -1,196 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Data plugin that models a folder on a remote computer.""" -import os - -from aiida.orm import AuthInfo - -from ..data import Data - -__all__ = ('RemoteData',) - - -class RemoteData(Data): - """ - Store a link to a file or folder on a remote machine. - - Remember to pass a computer! - """ - - KEY_EXTRA_CLEANED = 'cleaned' - - def __init__(self, remote_path=None, **kwargs): - super().__init__(**kwargs) - if remote_path is not None: - self.set_remote_path(remote_path) - - def get_remote_path(self): - return self.base.attributes.get('remote_path') - - def set_remote_path(self, val): - self.base.attributes.set('remote_path', val) - - @property - def is_empty(self): - """ - Check if remote folder is empty - """ - authinfo = self.get_authinfo() - transport = authinfo.get_transport() - - with transport: - try: - transport.chdir(self.get_remote_path()) - except IOError: - # If the transport IOError the directory no longer exists and was deleted - return True - - return not transport.listdir() - - def getfile(self, relpath, destpath): - """ - Connects to the remote folder and retrieves the content of a file. - - :param relpath: The relative path of the file on the remote to retrieve. - :param destpath: The absolute path of where to store the file on the local machine. - """ - authinfo = self.get_authinfo() - - with authinfo.get_transport() as transport: - try: - full_path = os.path.join(self.get_remote_path(), relpath) - transport.getfile(full_path, destpath) - except IOError as exception: - if exception.errno == 2: # file does not exist - raise IOError( - 'The required remote file {} on {} does not exist or has been deleted.'.format( - full_path, - self.computer.label # pylint: disable=no-member - ) - ) from exception - raise - - def listdir(self, relpath='.'): - """ - Connects to the remote folder and lists the directory content. - - :param relpath: If 'relpath' is specified, lists the content of the given subfolder. - :return: a flat list of file/directory names (as strings). 
- """ - authinfo = self.get_authinfo() - - with authinfo.get_transport() as transport: - try: - full_path = os.path.join(self.get_remote_path(), relpath) - transport.chdir(full_path) - except IOError as exception: - if exception.errno in (2, 20): # directory not existing or not a directory - exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member - ) - exc.errno = exception.errno - raise exc from exception - else: - raise - - try: - return transport.listdir() - except IOError as exception: - if exception.errno in (2, 20): # directory not existing or not a directory - exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member - ) - exc.errno = exception.errno - raise exc from exception - else: - raise - - def listdir_withattributes(self, path='.'): - """ - Connects to the remote folder and lists the directory content. - - :param relpath: If 'relpath' is specified, lists the content of the given subfolder. - :return: a list of dictionaries, where the documentation is in :py:class:Transport.listdir_withattributes. - """ - authinfo = self.get_authinfo() - - with authinfo.get_transport() as transport: - try: - full_path = os.path.join(self.get_remote_path(), path) - transport.chdir(full_path) - except IOError as exception: - if exception.errno in (2, 20): # directory not existing or not a directory - exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member - ) - exc.errno = exception.errno - raise exc from exception - else: - raise - - try: - return transport.listdir_withattributes() - except IOError as exception: - if exception.errno in (2, 20): # directory not existing or not a directory - exc = IOError( - 'The required remote folder {} on {} does not exist, is not a directory or has been deleted.'. - format(full_path, self.computer.label) # pylint: disable=no-member - ) - exc.errno = exception.errno - raise exc from exception - else: - raise - - def _clean(self, transport=None): - """Remove all content of the remote folder on the remote computer. - - When the cleaning operation is successful, the extra with the key ``RemoteData.KEY_EXTRA_CLEANED`` is set. - - :param transport: Provide an optional transport that is already open. If not provided, a transport will be - automatically opened, based on the current default user and the computer of this data node. Passing in the - transport can be used for efficiency if a great number of nodes need to be cleaned for the same computer. - Note that the user should take care that the correct transport is passed. - :raises ValueError: If the hostname of the provided transport does not match that of the node's computer. - """ - from aiida.orm.utils.remote import clean_remote - - remote_dir = self.get_remote_path() - - if transport is None: - with self.get_authinfo().get_transport() as transport: # pylint: disable=redefined-argument-from-local - clean_remote(transport, remote_dir) - else: - if transport.hostname != self.computer.hostname: - raise ValueError( - f'Transport hostname `{transport.hostname}` does not equal `{self.computer.hostname}` of {self}.' 
- ) - clean_remote(transport, remote_dir) - - self.base.extras.set(self.KEY_EXTRA_CLEANED, True) - - def _validate(self): - from aiida.common.exceptions import ValidationError - - super()._validate() - - try: - self.get_remote_path() - except AttributeError as exception: - raise ValidationError("attribute 'remote_path' not set.") from exception - - computer = self.computer - if computer is None: - raise ValidationError('Remote computer not set.') - - def get_authinfo(self): - return AuthInfo.collection(self.backend).get(dbcomputer=self.computer, aiidauser=self.user) diff --git a/aiida/orm/nodes/data/remote/stash/__init__.py b/aiida/orm/nodes/data/remote/stash/__init__.py deleted file mode 100644 index e06481e842..0000000000 --- a/aiida/orm/nodes/data/remote/stash/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -*- coding: utf-8 -*- -"""Module with data plugins that represent files of completed calculations jobs that have been stashed.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .base import * -from .folder import * - -__all__ = ( - 'RemoteStashData', - 'RemoteStashFolderData', -) - -# yapf: enable diff --git a/aiida/orm/nodes/data/remote/stash/base.py b/aiida/orm/nodes/data/remote/stash/base.py deleted file mode 100644 index c768505249..0000000000 --- a/aiida/orm/nodes/data/remote/stash/base.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -"""Data plugin that models an archived folder on a remote computer.""" -from aiida.common.datastructures import StashMode -from aiida.common.lang import type_check - -from ...data import Data - -__all__ = ('RemoteStashData',) - - -class RemoteStashData(Data): - """Data plugin that models an archived folder on a remote computer. - - A stashed folder is essentially an instance of ``RemoteData`` that has been archived. Archiving in this context can - simply mean copying the content of the folder to another location on the same or another filesystem as long as it is - on the same machine. In addition, the folder may have been compressed into a single file for efficiency or even - written to tape. The ``stash_mode`` attribute will distinguish how the folder was stashed which will allow the - implementation to also `unstash` it and transform it back into a ``RemoteData`` such that it can be used as an input - for new ``CalcJobs``. - - This class is a non-storable base class that merely registers the ``stash_mode`` attribute. Only its subclasses, - that actually implement a certain stash mode, can be instantiated and therefore stored. The reason for this design - is that because the behavior of the class can change significantly based on the mode employed to stash the files and - implementing all these variants in the same class will lead to an unintuitive interface where certain properties or - methods of the class will only be available or function properly based on the ``stash_mode``. - """ - - _storable = False - - def __init__(self, stash_mode: StashMode, **kwargs): - """Construct a new instance - - :param stash_mode: the stashing mode with which the data was stashed on the remote. - """ - super().__init__(**kwargs) - self.stash_mode = stash_mode - - @property - def stash_mode(self) -> StashMode: - """Return the mode with which the data was stashed on the remote. - - :return: the stash mode. - """ - return StashMode(self.base.attributes.get('stash_mode')) - - @stash_mode.setter - def stash_mode(self, value: StashMode): - """Set the mode with which the data was stashed on the remote. 
- - :param value: the stash mode. - """ - type_check(value, StashMode) - self.base.attributes.set('stash_mode', value.value) diff --git a/aiida/orm/nodes/data/remote/stash/folder.py b/aiida/orm/nodes/data/remote/stash/folder.py deleted file mode 100644 index bf182b7a5c..0000000000 --- a/aiida/orm/nodes/data/remote/stash/folder.py +++ /dev/null @@ -1,67 +0,0 @@ -# -*- coding: utf-8 -*- -"""Data plugin that models a stashed folder on a remote computer.""" -import typing - -from aiida.common.datastructures import StashMode -from aiida.common.lang import type_check - -from .base import RemoteStashData - -__all__ = ('RemoteStashFolderData',) - - -class RemoteStashFolderData(RemoteStashData): - """Data plugin that models a folder with files of a completed calculation job that has been stashed through a copy. - - This data plugin can and should be used to stash files if and only if the stash mode is `StashMode.COPY`. - """ - - _storable = True - - def __init__(self, stash_mode: StashMode, target_basepath: str, source_list: typing.List, **kwargs): - """Construct a new instance - - :param stash_mode: the stashing mode with which the data was stashed on the remote. - :param target_basepath: the target basepath. - :param source_list: the list of source files. - """ - super().__init__(stash_mode, **kwargs) - self.target_basepath = target_basepath - self.source_list = source_list - - if stash_mode != StashMode.COPY: - raise ValueError('`RemoteStashFolderData` can only be used with `stash_mode == StashMode.COPY`.') - - @property - def target_basepath(self) -> str: - """Return the target basepath. - - :return: the target basepath. - """ - return self.base.attributes.get('target_basepath') - - @target_basepath.setter - def target_basepath(self, value: str): - """Set the target basepath. - - :param value: the target basepath. - """ - type_check(value, str) - self.base.attributes.set('target_basepath', value) - - @property - def source_list(self) -> typing.Union[typing.List, typing.Tuple]: - """Return the list of source files that were stashed. - - :return: the list of source files. - """ - return self.base.attributes.get('source_list') - - @source_list.setter - def source_list(self, value: typing.Union[typing.List, typing.Tuple]): - """Set the list of source files that were stashed. - - :param value: the list of source files. - """ - type_check(value, (list, tuple)) - self.base.attributes.set('source_list', value) diff --git a/aiida/orm/nodes/data/singlefile.py b/aiida/orm/nodes/data/singlefile.py deleted file mode 100644 index 4db02f1abe..0000000000 --- a/aiida/orm/nodes/data/singlefile.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
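# A minimal usage sketch: `RemoteData` points at a directory on a configured
# `Computer`, while `RemoteStashFolderData` records where a calculation's files
# were stashed (only `StashMode.COPY` is accepted by that subclass). The
# computer label and paths are placeholders; a configured computer, authinfo
# and loaded profile are assumed.
from aiida.common.datastructures import StashMode
from aiida.orm import RemoteData, RemoteStashFolderData, load_computer

computer = load_computer('localhost')  # assumed computer label

remote = RemoteData(remote_path='/scratch/some/run_dir', computer=computer)
print(remote.is_empty)      # opens a transport and checks the remote directory
print(remote.listdir('.'))  # flat list of file/directory names

stashed = RemoteStashFolderData(
    stash_mode=StashMode.COPY,
    target_basepath='/archive/stash/run_dir',
    source_list=['aiida.out', 'output_data/'],
    computer=computer,
)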
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Data class that can be used to store a single file in its repository.""" -from __future__ import annotations - -import contextlib -import io -import os -import pathlib - -from aiida.common import exceptions - -from .data import Data - -__all__ = ('SinglefileData',) - - -class SinglefileData(Data): - """Data class that can be used to store a single file in its repository.""" - - DEFAULT_FILENAME = 'file.txt' - - @classmethod - def from_string(cls, content: str, filename: str | pathlib.Path | None = None, **kwargs): - """Construct a new instance and set ``content`` as its contents. - - :param content: The content as a string. - :param filename: Specify filename to use (defaults to ``file.txt``). - """ - return cls(io.StringIO(content), filename, **kwargs) - - def __init__(self, file, filename: str | pathlib.Path | None = None, **kwargs): - """Construct a new instance and set the contents to that of the file. - - :param file: an absolute filepath or filelike object whose contents to copy. - Hint: Pass io.BytesIO(b"my string") to construct the SinglefileData directly from a string. - :param filename: specify filename to use (defaults to name of provided file). - """ - # pylint: disable=redefined-builtin - super().__init__(**kwargs) - - if file is not None: - self.set_file(file, filename=filename) - - @property - def filename(self): - """Return the name of the file stored. - - :return: the filename under which the file is stored in the repository - """ - return self.base.attributes.get('filename') - - @contextlib.contextmanager - def open(self, path=None, mode='r'): - """Return an open file handle to the content of this data node. - - :param path: the relative path of the object within the repository. - :param mode: the mode with which to open the file handle (default: read mode) - :return: a file handle - """ - if path is None: - path = self.filename - - with self.base.repository.open(path, mode=mode) as handle: - yield handle - - def get_content(self): - """Return the content of the single file stored for this data node. - - :return: the content of the file as a string - """ - with self.open() as handle: - return handle.read() - - def set_file(self, file, filename: str | pathlib.Path | None = None): - """Store the content of the file in the node's repository, deleting any other existing objects. - - :param file: an absolute filepath or filelike object whose contents to copy - Hint: Pass io.BytesIO(b"my string") to construct the file directly from a string. - :param filename: specify filename to use (defaults to name of provided file). 
- """ - # pylint: disable=redefined-builtin - - if isinstance(file, (str, pathlib.Path)): - is_filelike = False - - key = os.path.basename(file) - if not os.path.isabs(file): - raise ValueError(f'path `{file}` is not absolute') - - if not os.path.isfile(file): - raise ValueError(f'path `{file}` does not correspond to an existing file') - else: - is_filelike = True - try: - key = os.path.basename(file.name) - except AttributeError: - key = self.DEFAULT_FILENAME - - key = str(filename) if filename is not None else key - existing_object_names = self.base.repository.list_object_names() - - try: - # Remove the 'key' from the list of currently existing objects such that it is not deleted after storing - existing_object_names.remove(key) - except ValueError: - pass - - if is_filelike: - self.base.repository.put_object_from_filelike(file, key) - else: - self.base.repository.put_object_from_file(file, key) - - # Delete any other existing objects (minus the current `key` which was already removed from the list) - for existing_key in existing_object_names: - self.base.repository.delete_object(existing_key) - - self.base.attributes.set('filename', key) - - def _validate(self): - """Ensure that there is one object stored in the repository, whose key matches value set for `filename` attr.""" - super()._validate() - - try: - filename = self.filename - except AttributeError: - raise exceptions.ValidationError('the `filename` attribute is not set.') - - objects = self.base.repository.list_object_names() - - if [filename] != objects: - raise exceptions.ValidationError( - f'respository files {objects} do not match the `filename` attribute `{filename}`.' - ) diff --git a/aiida/orm/nodes/data/structure.py b/aiida/orm/nodes/data/structure.py deleted file mode 100644 index 8f8dc95e47..0000000000 --- a/aiida/orm/nodes/data/structure.py +++ /dev/null @@ -1,2520 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-lines -""" -This module defines the classes for structures and all related -functions to operate on them. -""" -import copy -import functools -import itertools -import json - -from aiida.common.constants import elements -from aiida.common.exceptions import UnsupportedSpeciesError - -from .data import Data - -__all__ = ('StructureData', 'Kind', 'Site') - -# Threshold used to check if the mass of two different Site objects is the same. - -_MASS_THRESHOLD = 1.e-3 -# Threshold to check if the sum is one or not -_SUM_THRESHOLD = 1.e-6 -# Default cell -_DEFAULT_CELL = ((0, 0, 0), (0, 0, 0), (0, 0, 0)) - -_valid_symbols = tuple(i['symbol'] for i in elements.values()) -_atomic_masses = {el['symbol']: el['mass'] for el in elements.values()} -_atomic_numbers = {data['symbol']: num for num, data in elements.items()} - - -def _get_valid_cell(inputcell): - """ - Return the cell in a valid format from a generic input. - - :raise ValueError: whenever the format is not valid. 
- """ - try: - the_cell = list(list(float(c) for c in i) for i in inputcell) - if len(the_cell) != 3: - raise ValueError - if any(len(i) != 3 for i in the_cell): - raise ValueError - except (IndexError, ValueError, TypeError): - raise ValueError('Cell must be a list of three vectors, each defined as a list of three coordinates.') - - return the_cell - - -def get_valid_pbc(inputpbc): - """ - Return a list of three booleans for the periodic boundary conditions, - in a valid format from a generic input. - - :raise ValueError: if the format is not valid. - """ - if isinstance(inputpbc, bool): - the_pbc = (inputpbc, inputpbc, inputpbc) - elif hasattr(inputpbc, '__iter__'): - # To manage numpy lists of bools, whose elements are of type numpy.bool_ - # and for which isinstance(i,bool) return False... - if hasattr(inputpbc, 'tolist'): - the_value = inputpbc.tolist() - else: - the_value = inputpbc - if all(isinstance(i, bool) for i in the_value): - if len(the_value) == 3: - the_pbc = tuple(i for i in the_value) - elif len(the_value) == 1: - the_pbc = (the_value[0], the_value[0], the_value[0]) - else: - raise ValueError('pbc length must be either one or three.') - else: - raise ValueError('pbc elements are not booleans.') - else: - raise ValueError('pbc must be a boolean or a list of three booleans.', inputpbc) - - return the_pbc - - -def has_ase(): - """ - :return: True if the ase module can be imported, False otherwise. - """ - try: - import ase # pylint: disable=unused-import - except ImportError: - return False - return True - - -def has_pymatgen(): - """ - :return: True if the pymatgen module can be imported, False otherwise. - """ - try: - import pymatgen # pylint: disable=unused-import - except ImportError: - return False - return True - - -def get_pymatgen_version(): - """ - :return: string with pymatgen version, None if can not import. - """ - if not has_pymatgen(): - return None - try: - from pymatgen import __version__ - except ImportError: - # this was changed in version 2022.0.3 - from pymatgen.core import __version__ - return __version__ - - -def has_spglib(): - """ - :return: True if the spglib module can be imported, False otherwise. - """ - try: - import spglib # pylint: disable=unused-import - except ImportError: - return False - return True - - -def calc_cell_volume(cell): - """ - Compute the three-dimensional cell volume in Angstrom^3. - - :param cell: the cell vectors; the must be a 3x3 list of lists of floats - :returns: the cell volume. - """ - import numpy as np - return np.abs(np.dot(cell[0], np.cross(cell[1], cell[2]))) - - -def _create_symbols_tuple(symbols): - """ - Returns a tuple with the symbols provided. If a string is provided, - this is converted to a tuple with one single element. - """ - if isinstance(symbols, str): - symbols_list = (symbols,) - else: - symbols_list = tuple(symbols) - return symbols_list - - -def _create_weights_tuple(weights): - """ - Returns a tuple with the weights provided. If a number is provided, - this is converted to a tuple with one single element. 
- If None is provided, this is converted to the tuple (1.,) - """ - import numbers - - if weights is None: - weights_tuple = (1.,) - elif isinstance(weights, numbers.Number): - weights_tuple = (weights,) - else: - weights_tuple = tuple(float(i) for i in weights) - return weights_tuple - - -def create_automatic_kind_name(symbols, weights): - """ - Create a string obtained with the symbols appended one - after the other, without spaces, in alphabetical order; - if the site has a vacancy, a X is appended at the end too. - """ - sorted_symbol_list = list(set(symbols)) - sorted_symbol_list.sort() # In-place sort - name_string = ''.join(sorted_symbol_list) - if has_vacancies(weights): - name_string += 'X' - return name_string - - -def validate_weights_tuple(weights_tuple, threshold): - """ - Validates the weight of the atomic kinds. - - :raise: ValueError if the weights_tuple is not valid. - - :param weights_tuple: the tuple to validate. It must be a - a tuple of floats (as created by :func:_create_weights_tuple). - :param threshold: a float number used as a threshold to check that the sum - of the weights is <= 1. - - If the sum is less than one, it means that there are vacancies. - Each element of the list must be >= 0, and the sum must be <= 1. - """ - w_sum = sum(weights_tuple) - if (any(i < 0. for i in weights_tuple) or (w_sum - 1. > threshold)): - raise ValueError('The weight list is not valid (each element must be positive, and the sum must be <= 1).') - - -def is_valid_symbol(symbol): - """ - Validates the chemical symbol name. - - :return: True if the symbol is a valid chemical symbol (with correct - capitalization), or the dummy X, False otherwise. - - Recognized symbols are for elements from hydrogen (Z=1) to lawrencium - (Z=103). In addition, a dummy element unknown name (Z=0) is supported. - """ - return symbol in _valid_symbols - - -def validate_symbols_tuple(symbols_tuple): - """ - Used to validate whether the chemical species are valid. - - :param symbols_tuple: a tuple (or list) with the chemical symbols name. - :raises: UnsupportedSpeciesError if any symbol in the tuple is not a valid chemical - symbol (with correct capitalization). - - Refer also to the documentation of :func:is_valid_symbol - """ - if len(symbols_tuple) == 0: - valid = False - else: - valid = all(is_valid_symbol(sym) for sym in symbols_tuple) - if not valid: - raise UnsupportedSpeciesError( - f'At least one element of the symbol list {symbols_tuple} has not been recognized.' - ) - - -def is_ase_atoms(ase_atoms): - """ - Check if the ase_atoms parameter is actually a ase.Atoms object. - - :param ase_atoms: an object, expected to be an ase.Atoms. - :return: a boolean. - - Requires the ability to import ase, by doing 'import ase'. - """ - import ase - return isinstance(ase_atoms, ase.Atoms) - - -def group_symbols(_list): - """ - Group a list of symbols to a list containing the number of consecutive - identical symbols, and the symbol itself. 
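# A short check of the automatic kind naming described above: the kind name is
# the alphabetically sorted set of symbols, with an 'X' appended when the
# weights leave a vacancy. Import path as in this diff; it may differ in other
# versions.
from aiida.orm.nodes.data.structure import create_automatic_kind_name

print(create_automatic_kind_name(['Ti', 'Ba'], [0.5, 0.5]))  # 'BaTi'
print(create_automatic_kind_name(['Ba'], [0.9]))             # 'BaX' (0.1 vacancy)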
- - Examples: - - * ``['Ba','Ti','O','O','O','Ba']`` will return - ``[[1,'Ba'],[1,'Ti'],[3,'O'],[1,'Ba']]`` - - * ``[ [ [1,'Ba'],[1,'Ti'] ],[ [1,'Ba'],[1,'Ti'] ] ]`` will return - ``[[2, [ [1, 'Ba'], [1, 'Ti'] ] ]]`` - - :param _list: a list of elements representing a chemical formula - :return: a list of length-2 lists of the form [ multiplicity , element ] - """ - - the_list = copy.deepcopy(_list) - the_list.reverse() - grouped_list = [[1, the_list.pop()]] - while the_list: - elem = the_list.pop() - if elem == grouped_list[-1][1]: - # same symbol is repeated - grouped_list[-1][0] += 1 - else: - grouped_list.append([1, elem]) - - return grouped_list - - -def get_formula_from_symbol_list(_list, separator=''): - """ - Return a string with the formula obtained from the list of symbols. - Examples: - * ``[[1,'Ba'],[1,'Ti'],[3,'O']]`` will return ``'BaTiO3'`` - * ``[[2, [ [1, 'Ba'], [1, 'Ti'] ] ]]`` will return ``'(BaTi)2'`` - - :param _list: a list of symbols and multiplicities as obtained from - the function group_symbols - :param separator: a string used to concatenate symbols. Default empty. - - :return: a string - """ - - list_str = [] - for elem in _list: - if elem[0] == 1: - multiplicity_str = '' - else: - multiplicity_str = str(elem[0]) - - if isinstance(elem[1], str): - list_str.append(f'{elem[1]}{multiplicity_str}') - elif elem[0] > 1: - list_str.append(f'({get_formula_from_symbol_list(elem[1], separator=separator)}){multiplicity_str}') - else: - list_str.append(f'{get_formula_from_symbol_list(elem[1], separator=separator)}{multiplicity_str}') - - return separator.join(list_str) - - -def get_formula_group(symbol_list, separator=''): - """ - Return a string with the chemical formula from a list of chemical symbols. - The formula is written in a compact" way, i.e. trying to group as much as - possible parts of the formula. - - .. note:: it works for instance very well if structure was obtained - from an ASE supercell. - - Example of result: - ``['Ba', 'Ti', 'O', 'O', 'O', 'Ba', 'Ti', 'O', 'O', 'O', - 'Ba', 'Ti', 'Ti', 'O', 'O', 'O']`` will return ``'(BaTiO3)2BaTi2O3'``. - - :param symbol_list: list of symbols - (e.g. ['Ba','Ti','O','O','O']) - :param separator: a string used to concatenate symbols. Default empty. - :returns: a string with the chemical formula for the given structure. 
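# The two grouping helpers above, exercised with the examples from their own
# docstrings. Import path as in this diff; it may differ in other versions.
from aiida.orm.nodes.data.structure import get_formula_from_symbol_list, group_symbols

print(group_symbols(['Ba', 'Ti', 'O', 'O', 'O', 'Ba']))
# [[1, 'Ba'], [1, 'Ti'], [3, 'O'], [1, 'Ba']]
print(get_formula_from_symbol_list([[1, 'Ba'], [1, 'Ti'], [3, 'O']]))  # 'BaTiO3'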
- """ - - def group_together(_list, group_size, offset): - """ - :param _list: a list - :param group_size: size of the groups - :param offset: beginning grouping after offset elements - :return : a list of lists made of groups of size group_size - obtained by grouping list elements together - The first elements (up to _list[offset-1]) are not grouped - example: - ``group_together(['O','Ba','Ti','Ba','Ti'],2,1) = - ['O',['Ba','Ti'],['Ba','Ti']]`` - """ - - the_list = copy.deepcopy(_list) - the_list.reverse() - grouped_list = [] - for _ in range(offset): - grouped_list.append([the_list.pop()]) - - while the_list: - sub_list = [] - for _ in range(group_size): - if the_list: - sub_list.append(the_list.pop()) - grouped_list.append(sub_list) - - return grouped_list - - def cleanout_symbol_list(_list): - """ - :param _list: a list of groups of symbols and multiplicities - :return : a list where all groups with multiplicity 1 have - been reduced to minimum - example: ``[[1,[[1,'Ba']]]]`` will return ``[[1,'Ba']]`` - """ - the_list = [] - for elem in _list: - if elem[0] == 1 and isinstance(elem[1], list): - the_list.extend(elem[1]) - else: - the_list.append(elem) - - return the_list - - def group_together_symbols(_list, group_size): - """ - Successive application of group_together, group_symbols and - cleanout_symbol_list, in order to group a symbol list, scanning all - possible offsets, for a given group size - :param _list: the symbol list (see function group_symbols) - :param group_size: the size of the groups - :return the_symbol_list: the new grouped symbol list - :return has_grouped: True if we grouped something - """ - the_symbol_list = copy.deepcopy(_list) - has_grouped = False - offset = 0 - while not has_grouped and offset < group_size: - grouped_list = group_together(the_symbol_list, group_size, offset) - new_symbol_list = group_symbols(grouped_list) - if len(new_symbol_list) < len(grouped_list): - the_symbol_list = copy.deepcopy(new_symbol_list) - the_symbol_list = cleanout_symbol_list(the_symbol_list) - has_grouped = True - # print get_formula_from_symbol_list(the_symbol_list) - offset += 1 - - return the_symbol_list, has_grouped - - def group_all_together_symbols(_list): - """ - Successive application of the function group_together_symbols, to group - a symbol list, scanning all possible offsets and group sizes - :param _list: the symbol list (see function group_symbols) - :return: the new grouped symbol list - """ - has_finished = False - group_size = 2 - the_symbol_list = copy.deepcopy(_list) - - while not has_finished and group_size <= len(_list) // 2: - # try to group as much as possible by groups of size group_size - the_symbol_list, has_grouped = group_together_symbols(the_symbol_list, group_size) - has_finished = has_grouped - group_size += 1 - # stop as soon as we managed to group something - # or when the group_size is too big to get anything - - return the_symbol_list - - # initial grouping of the chemical symbols - old_symbol_list = [-1] - new_symbol_list = group_symbols(symbol_list) - - # successively apply the grouping procedure until the symbol list does not - # change anymore - while new_symbol_list != old_symbol_list: - old_symbol_list = copy.deepcopy(new_symbol_list) - new_symbol_list = group_all_together_symbols(old_symbol_list) - - return get_formula_from_symbol_list(new_symbol_list, separator=separator) - - -def get_formula(symbol_list, mode='hill', separator=''): - """ - Return a string with the chemical formula. - - :param symbol_list: a list of symbols, e.g. 
``['H','H','O']`` - :param mode: a string to specify how to generate the formula, can - assume one of the following values: - - * 'hill' (default): count the number of atoms of each species, - then use Hill notation, i.e. alphabetical order with C and H - first if one or several C atom(s) is (are) present, e.g. - ``['C','H','H','H','O','C','H','H','H']`` will return ``'C2H6O'`` - ``['S','O','O','H','O','H','O']`` will return ``'H2O4S'`` - From E. A. Hill, J. Am. Chem. Soc., 22 (8), pp 478–494 (1900) - - * 'hill_compact': same as hill but the number of atoms for each - species is divided by the greatest common divisor of all of them, e.g. - ``['C','H','H','H','O','C','H','H','H','O','O','O']`` - will return ``'CH3O2'`` - - * 'reduce': group repeated symbols e.g. - ``['Ba', 'Ti', 'O', 'O', 'O', 'Ba', 'Ti', 'O', 'O', 'O', - 'Ba', 'Ti', 'Ti', 'O', 'O', 'O']`` will return ``'BaTiO3BaTiO3BaTi2O3'`` - - * 'group': will try to group as much as possible parts of the formula - e.g. - ``['Ba', 'Ti', 'O', 'O', 'O', 'Ba', 'Ti', 'O', 'O', 'O', - 'Ba', 'Ti', 'Ti', 'O', 'O', 'O']`` will return ``'(BaTiO3)2BaTi2O3'`` - - * 'count': same as hill (i.e. one just counts the number - of atoms of each species) without the re-ordering (take the - order of the atomic sites), e.g. - ``['Ba', 'Ti', 'O', 'O', 'O','Ba', 'Ti', 'O', 'O', 'O']`` - will return ``'Ba2Ti2O6'`` - - * 'count_compact': same as count but the number of atoms - for each species is divided by the greatest common divisor of - all of them, e.g. - ``['Ba', 'Ti', 'O', 'O', 'O','Ba', 'Ti', 'O', 'O', 'O']`` - will return ``'BaTiO3'`` - - :param separator: a string used to concatenate symbols. Default empty. - - :return: a string with the formula - - .. note:: in modes reduce, group, count and count_compact, the - initial order in which the atoms were appended by the user is - used to group and/or order the symbols in the formula - """ - - if mode == 'group': - return get_formula_group(symbol_list, separator=separator) - - # for hill and count cases, simply count the occurences of each - # chemical symbol (with some re-ordering in hill) - if mode in ['hill', 'hill_compact']: - if 'C' in symbol_list: - ordered_symbol_set = sorted(set(symbol_list), key=lambda elem: {'C': '0', 'H': '1'}.get(elem, elem)) - else: - ordered_symbol_set = sorted(set(symbol_list)) - the_symbol_list = [[symbol_list.count(elem), elem] for elem in ordered_symbol_set] - - elif mode in ['count', 'count_compact']: - ordered_symbol_indexes = sorted([symbol_list.index(elem) for elem in set(symbol_list)]) - ordered_symbol_set = [symbol_list[i] for i in ordered_symbol_indexes] - the_symbol_list = [[symbol_list.count(elem), elem] for elem in ordered_symbol_set] - - elif mode == 'reduce': - the_symbol_list = group_symbols(symbol_list) - - else: - raise ValueError('Mode should be hill, hill_compact, group, reduce, count or count_compact') - - if mode in ['hill_compact', 'count_compact']: - from math import gcd - the_gcd = functools.reduce(gcd, [e[0] for e in the_symbol_list]) - the_symbol_list = [[e[0] // the_gcd, e[1]] for e in the_symbol_list] - - return get_formula_from_symbol_list(the_symbol_list, separator=separator) - - -def get_symbols_string(symbols, weights): - """ - Return a string that tries to match as good as possible the symbols - and weights. If there is only one symbol (no alloy) with 100% - occupancy, just returns the symbol name. Otherwise, groups the full - string in curly brackets, and try to write also the composition - (with 2 precision only). 
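# The different `mode` values of `get_formula`, reproducing examples from the
# docstring above. Import path as in this diff; it may differ in other versions.
from aiida.orm.nodes.data.structure import get_formula

ethanol = ['C', 'H', 'H', 'H', 'O', 'C', 'H', 'H', 'H']
print(get_formula(ethanol, mode='hill'))         # 'C2H6O' (C and H first, then alphabetical)

batio = ['Ba', 'Ti', 'O', 'O', 'O', 'Ba', 'Ti', 'O', 'O', 'O']
print(get_formula(batio, mode='hill'))           # 'Ba2O6Ti2'
print(get_formula(batio, mode='count'))          # 'Ba2Ti2O6' (site order preserved)
print(get_formula(batio, mode='count_compact'))  # 'BaTiO3'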
- If (sum of weights<1), we indicate it with the X symbol followed - by 1-sum(weights) (still with 2 digits precision, so it can be 0.00) - - :param symbols: the symbols as obtained from ._symbols - :param weights: the weights as obtained from ._weights - - .. note:: Note the difference with respect to the symbols and the - symbol properties! - """ - if len(symbols) == 1 and weights[0] == 1.: - return symbols[0] - - pieces = [] - for symbol, weight in zip(symbols, weights): - pieces.append(f'{symbol}{weight:4.2f}') - if has_vacancies(weights): - pieces.append(f'X{1.0 - sum(weights):4.2f}') - return f"{{{''.join(sorted(pieces))}}}" - - -def has_vacancies(weights): - """ - Returns True if the sum of the weights is less than one. - It uses the internal variable _SUM_THRESHOLD as a threshold. - :param weights: the weights - :return: a boolean - """ - w_sum = sum(weights) - return not 1. - w_sum < _SUM_THRESHOLD - - -def symop_ortho_from_fract(cell): - """ - Creates a matrix for conversion from orthogonal to fractional - coordinates. - - Taken from - svn://www.crystallography.net/cod-tools/trunk/lib/perl5/Fractional.pm, - revision 850. - - :param cell: array of cell parameters (three lengths and three angles) - """ - # pylint: disable=invalid-name - import math - - import numpy - - a, b, c, alpha, beta, gamma = cell - alpha, beta, gamma = [math.pi * x / 180 for x in [alpha, beta, gamma]] - ca, cb, cg = [math.cos(x) for x in [alpha, beta, gamma]] - sg = math.sin(gamma) - - return numpy.array([[a, b * cg, c * cb], [0, b * sg, c * (ca - cb * cg) / sg], - [0, 0, c * math.sqrt(sg * sg - ca * ca - cb * cb + 2 * ca * cb * cg) / sg]]) - - -def symop_fract_from_ortho(cell): - """ - Creates a matrix for conversion from fractional to orthogonal - coordinates. - - Taken from - svn://www.crystallography.net/cod-tools/trunk/lib/perl5/Fractional.pm, - revision 850. - - :param cell: array of cell parameters (three lengths and three angles) - """ - # pylint: disable=invalid-name - import math - - import numpy - - a, b, c, alpha, beta, gamma = cell - alpha, beta, gamma = [math.pi * x / 180 for x in [alpha, beta, gamma]] - ca, cb, cg = [math.cos(x) for x in [alpha, beta, gamma]] - sg = math.sin(gamma) - ctg = cg / sg - D = math.sqrt(sg * sg - cb * cb - ca * ca + 2 * ca * cb * cg) - - return numpy.array([ - [1.0 / a, -(1.0 / a) * ctg, (ca * cg - cb) / (a * D)], - [0, 1.0 / (b * sg), -(ca - cb * cg) / (b * D * sg)], - [0, 0, sg / (c * D)], - ]) - - -def ase_refine_cell(aseatoms, **kwargs): - """ - Detect the symmetry of the structure, remove symmetric atoms and - refine unit cell. 
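# A sanity check of the two conversion matrices above for an orthorhombic cell:
# fractional -> orthogonal is diag(a, b, c) and the two helpers are mutual
# inverses. Both take the six cell parameters (lengths in Angstrom, angles in
# degrees). Import path as in this diff; it may differ in other versions.
import numpy

from aiida.orm.nodes.data.structure import symop_fract_from_ortho, symop_ortho_from_fract

cell_params = (4.0, 5.0, 6.0, 90.0, 90.0, 90.0)
to_ortho = symop_ortho_from_fract(cell_params)  # ~ diag(4, 5, 6)
to_fract = symop_fract_from_ortho(cell_params)  # ~ diag(1/4, 1/5, 1/6)
print(numpy.allclose(to_ortho @ to_fract, numpy.eye(3)))  # True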
- - :param aseatoms: an ase.atoms.Atoms instance - :param symprec: symmetry precision, used by spglib - :return newase: refined cell with reduced set of atoms - :return symmetry: a dictionary describing the symmetry space group - """ - from ase.atoms import Atoms - from spglib import get_symmetry_dataset, refine_cell - cell, positions, numbers = refine_cell(aseatoms, **kwargs) - - refined_atoms = Atoms(numbers, scaled_positions=positions, cell=cell, pbc=True) - - sym_dataset = get_symmetry_dataset(refined_atoms, **kwargs) - - unique_numbers = [] - unique_positions = [] - - for i in set(sym_dataset['equivalent_atoms']): - unique_numbers.append(refined_atoms.numbers[i]) - unique_positions.append(refined_atoms.get_scaled_positions()[i]) - - unique_atoms = Atoms(unique_numbers, scaled_positions=unique_positions, cell=cell, pbc=True) - - return unique_atoms, { - 'hm': sym_dataset['international'], - 'hall': sym_dataset['hall'], - 'tables': sym_dataset['number'], - 'rotations': sym_dataset['rotations'], - 'translations': sym_dataset['translations'] - } - - -def atom_kinds_to_html(atom_kind): - """ - - Construct in html format - - an alloy with 0.5 Ge, 0.4 Si and 0.1 vacancy is represented as - Ge0.5 + Si0.4 + vacancy0.1 - - Args: - atom_kind: a string with the name of the atomic kind, as printed by - kind.get_symbols_string(), e.g. Ba0.80Ca0.10X0.10 - - Returns: - html code for rendered formula - """ - - # Parse the formula (TODO can be made more robust though never fails if - # it takes strings generated with kind.get_symbols_string()) - import re - matched_elements = re.findall(r'([A-Z][a-z]*)([0-1][.[0-9]*]?)?', atom_kind) - - # Compose the html string - html_formula_pieces = [] - - for element in matched_elements: - - # replace element X by 'vacancy' - species = element[0] if element[0] != 'X' else 'vacancy' - weight = element[1] if element[1] != '' else None - - if weight is not None: - html_formula_pieces.append(f'{species}{weight}') - else: - html_formula_pieces.append(species) - - html_formula = ' + '.join(html_formula_pieces) - - return html_formula - - -class StructureData(Data): - """ - This class contains the information about a given structure, i.e. a - collection of sites together with a cell, the - boundary conditions (whether they are periodic or not) and other - related useful information. 
- """ - - # pylint: disable=too-many-public-methods - - _set_incompatibilities = [('ase', 'cell'), ('ase', 'pbc'), ('ase', 'pymatgen'), ('ase', 'pymatgen_molecule'), - ('ase', 'pymatgen_structure'), ('cell', 'pymatgen'), ('cell', 'pymatgen_molecule'), - ('cell', 'pymatgen_structure'), ('pbc', 'pymatgen'), ('pbc', 'pymatgen_molecule'), - ('pbc', 'pymatgen_structure'), ('pymatgen', 'pymatgen_molecule'), - ('pymatgen', 'pymatgen_structure'), ('pymatgen_molecule', 'pymatgen_structure')] - - _dimensionality_label = {0: '', 1: 'length', 2: 'surface', 3: 'volume'} - _internal_kind_tags = None - - def __init__( - self, - cell=None, - pbc=None, - ase=None, - pymatgen=None, - pymatgen_structure=None, - pymatgen_molecule=None, - **kwargs - ): # pylint: disable=too-many-arguments - args = { - 'cell': cell, - 'pbc': pbc, - 'ase': ase, - 'pymatgen': pymatgen, - 'pymatgen_structure': pymatgen_structure, - 'pymatgen_molecule': pymatgen_molecule, - } - - for left, right in self._set_incompatibilities: - if args[left] is not None and args[right] is not None: - raise ValueError(f'cannot pass {left} and {right} at the same time') - - super().__init__(**kwargs) - - if any(ext is not None for ext in [ase, pymatgen, pymatgen_structure, pymatgen_molecule]): - - if ase is not None: - self.set_ase(ase) - - if pymatgen is not None: - self.set_pymatgen(pymatgen) - - if pymatgen_structure is not None: - self.set_pymatgen_structure(pymatgen_structure) - - if pymatgen_molecule is not None: - self.set_pymatgen_molecule(pymatgen_molecule) - - else: - if cell is None: - cell = _DEFAULT_CELL - self.set_cell(cell) - - if pbc is None: - pbc = [True, True, True] - self.set_pbc(pbc) - - def get_dimensionality(self): - """ - Return the dimensionality of the structure and its length/surface/volume. - - Zero-dimensional structures are assigned "volume" 0. - - :return: returns a dictionary with keys "dim" (dimensionality integer), "label" (dimensionality label) - and "value" (numerical length/surface/volume). - """ - return _get_dimensionality(self.pbc, self.cell) - - def set_ase(self, aseatoms): - """ - Load the structure from a ASE object - """ - if is_ase_atoms(aseatoms): - # Read the ase structure - self.cell = aseatoms.cell - self.pbc = aseatoms.pbc - self.clear_kinds() # This also calls clear_sites - for atom in aseatoms: - self.append_atom(ase=atom) - else: - raise TypeError('The value is not an ase.Atoms object') - - def set_pymatgen(self, obj, **kwargs): - """ - Load the structure from a pymatgen object. - - .. note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors). - """ - typestr = type(obj).__name__ - try: - func = getattr(self, f'set_pymatgen_{typestr.lower()}') - except AttributeError: - raise AttributeError(f"Converter for '{typestr}' to AiiDA structure does not exist") - func(obj, **kwargs) - - def set_pymatgen_molecule(self, mol, margin=5): - """ - Load the structure from a pymatgen Molecule object. - - :param margin: the margin to be added in all directions of the - bounding box of the molecule. - - .. note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors). 
- """ - box = [ - max(x.coords.tolist()[0] for x in mol.sites) - min(x.coords.tolist()[0] for x in mol.sites) + 2 * margin, - max(x.coords.tolist()[1] for x in mol.sites) - min(x.coords.tolist()[1] for x in mol.sites) + 2 * margin, - max(x.coords.tolist()[2] for x in mol.sites) - min(x.coords.tolist()[2] for x in mol.sites) + 2 * margin - ] - self.set_pymatgen_structure(mol.get_boxed_structure(*box)) - self.pbc = [False, False, False] - - def set_pymatgen_structure(self, struct): - """ - Load the structure from a pymatgen Structure object. - - .. note:: periodic boundary conditions are set to True in all - three directions. - .. note:: Requires the pymatgen module (version >= 3.3.5, usage - of earlier versions may cause errors). - - :raise ValueError: if there are partial occupancies together with spins. - """ - - def build_kind_name(species_and_occu): - """ - Build a kind name from a pymatgen Composition, including an additional ordinal if spin is included, - e.g. it returns '1' for an atom with spin < 0 and '2' for an atom with spin > 0, - otherwise (no spin) it returns None - - :param species_and_occu: a pymatgen species and occupations dictionary - :return: a string representing the kind name or None - """ - species = list(species_and_occu.keys()) - occupations = list(species_and_occu.values()) - - has_spin = any(specie.as_dict().get('properties', {}).get('spin', 0) != 0 for specie in species) - has_partial_occupancies = (len(occupations) != 1 or occupations[0] != 1.0) - - if has_partial_occupancies and has_spin: - raise ValueError('Cannot set partial occupancies and spins at the same time') - - if has_spin: - - symbols = [specie.symbol for specie in species] - kind_name = create_automatic_kind_name(symbols, occupations) - - # If there is spin, we can only have a single specie, otherwise we would have raised above - specie = species[0] - spin = specie.as_dict().get('properties', {}).get('spin', 0) - - if spin < 0: - kind_name += '1' - else: - kind_name += '2' - - return kind_name - - return None - - self.cell = struct.lattice.matrix.tolist() - self.pbc = [True, True, True] - self.clear_kinds() - - for site in struct.sites: - - species_and_occu = site.species - - if 'kind_name' in site.properties: - kind_name = site.properties['kind_name'] - else: - kind_name = build_kind_name(species_and_occu) - - inputs = { - 'symbols': [x.symbol for x in species_and_occu.keys()], - 'weights': list(species_and_occu.values()), - 'position': site.coords.tolist() - } - - if kind_name is not None: - inputs['name'] = kind_name - - self.append_atom(**inputs) - - def _validate(self): - """ - Performs some standard validation tests. 
- """ - - from aiida.common.exceptions import ValidationError - - super()._validate() - - try: - _get_valid_cell(self.cell) - except ValueError as exc: - raise ValidationError(f'Invalid cell: {exc}') - - try: - get_valid_pbc(self.pbc) - except ValueError as exc: - raise ValidationError(f'Invalid periodic boundary conditions: {exc}') - - _validate_dimensionality(self.pbc, self.cell) - - try: - # This will try to create the kinds objects - kinds = self.kinds - except ValueError as exc: - raise ValidationError(f'Unable to validate the kinds: {exc}') - - from collections import Counter - - counts = Counter([k.name for k in kinds]) - for count in counts: - if counts[count] != 1: - raise ValidationError(f"Kind with name '{count}' appears {counts[count]} times instead of only one") - - try: - # This will try to create the sites objects - sites = self.sites - except ValueError as exc: - raise ValidationError(f'Unable to validate the sites: {exc}') - - for site in sites: - if site.kind_name not in [k.name for k in kinds]: - raise ValidationError(f'A site has kind {site.kind_name}, but no specie with that name exists') - - kinds_without_sites = (set(k.name for k in kinds) - set(s.kind_name for s in sites)) - if kinds_without_sites: - raise ValidationError( - f'The following kinds are defined, but there are no sites with that kind: {list(kinds_without_sites)}' - ) - - def _prepare_xsf(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format XSF (for XCrySDen). - """ - if self.is_alloy or self.has_vacancies: - raise NotImplementedError('XSF for alloys or systems with vacancies not implemented.') - - sites = self.sites - - return_string = 'CRYSTAL\nPRIMVEC 1\n' - for cell_vector in self.cell: - return_string += ' '.join([f'{i:18.10f}' for i in cell_vector]) - return_string += '\n' - return_string += 'PRIMCOORD 1\n' - return_string += f'{int(len(sites))} 1\n' - for site in sites: - # I checked above that it is not an alloy, therefore I take the - # first symbol - return_string += f'{_atomic_numbers[self.get_kind(site.kind_name).symbols[0]]} ' - return_string += '%18.10f %18.10f %18.10f\n' % tuple(site.position) - return return_string.encode('utf-8'), {} - - def _prepare_cif(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format CIF. - """ - from aiida.orm import CifData - - cif = CifData(ase=self.get_ase()) - return cif._prepare_cif() # pylint: disable=protected-access - - def _prepare_chemdoodle(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format required by ChemDoodle. - """ - # pylint: disable=too-many-locals,invalid-name - from itertools import product - - import numpy as np - - supercell_factors = [1, 1, 1] - - # Get cell vectors and atomic position - lattice_vectors = np.array(self.base.attributes.get('cell')) - base_sites = self.base.attributes.get('sites') - - start1 = -int(supercell_factors[0] / 2) - start2 = -int(supercell_factors[1] / 2) - start3 = -int(supercell_factors[2] / 2) - - stop1 = start1 + supercell_factors[0] - stop2 = start2 + supercell_factors[1] - stop3 = start3 + supercell_factors[2] - - grid1 = range(start1, stop1) - grid2 = range(start2, stop2) - grid3 = range(start3, stop3) - - atoms_json = [] - - # Manual recenter of the structure - center = (lattice_vectors[0] + lattice_vectors[1] + lattice_vectors[2]) / 2. 
- - for ix, iy, iz in product(grid1, grid2, grid3): - for base_site in base_sites: - shift = (ix * lattice_vectors[0] + iy * lattice_vectors[1] + \ - iz * lattice_vectors[2] - center).tolist() - - kind_name = base_site['kind_name'] - kind_string = self.get_kind(kind_name).get_symbols_string() - - atoms_json.append({ - 'l': kind_string, - 'x': base_site['position'][0] + shift[0], - 'y': base_site['position'][1] + shift[1], - 'z': base_site['position'][2] + shift[2], - 'atomic_elements_html': atom_kinds_to_html(kind_string) - }) - - cell_json = { - 't': 'UnitCell', - 'i': 's0', - 'o': (-center).tolist(), - 'x': (lattice_vectors[0] - center).tolist(), - 'y': (lattice_vectors[1] - center).tolist(), - 'z': (lattice_vectors[2] - center).tolist(), - 'xy': (lattice_vectors[0] + lattice_vectors[1] - center).tolist(), - 'xz': (lattice_vectors[0] + lattice_vectors[2] - center).tolist(), - 'yz': (lattice_vectors[1] + lattice_vectors[2] - center).tolist(), - 'xyz': (lattice_vectors[0] + lattice_vectors[1] + lattice_vectors[2] - center).tolist(), - } - - return_dict = {'s': [cell_json], 'm': [{'a': atoms_json}], 'units': 'Å'} - - return json.dumps(return_dict).encode('utf-8'), {} - - def _prepare_xyz(self, main_file_name=''): # pylint: disable=unused-argument - """ - Write the given structure to a string of format XYZ. - """ - if self.is_alloy or self.has_vacancies: - raise NotImplementedError('XYZ for alloys or systems with vacancies not implemented.') - - sites = self.sites - cell = self.cell - - return_list = [f'{len(sites)}'] - return_list.append( - 'Lattice="{} {} {} {} {} {} {} {} {}" pbc="{} {} {}"'.format( - cell[0][0], cell[0][1], cell[0][2], cell[1][0], cell[1][1], cell[1][2], cell[2][0], cell[2][1], - cell[2][2], self.pbc[0], self.pbc[1], self.pbc[2] - ) - ) - for site in sites: - # I checked above that it is not an alloy, therefore I take the - # first symbol - return_list.append( - '{:6s} {:18.10f} {:18.10f} {:18.10f}'.format( - self.get_kind(site.kind_name).symbols[0], site.position[0], site.position[1], site.position[2] - ) - ) - - return_string = '\n'.join(return_list) - return return_string.encode('utf-8'), {} - - def _parse_xyz(self, inputstring): - """ - Read the structure from a string of format XYZ. - """ - from aiida.tools.data.structure import xyz_parser_iterator - - # idiom to get to the last block - atoms = None - for _, _, atoms in xyz_parser_iterator(inputstring): - pass - - if atoms is None: - raise TypeError('The data does not contain any XYZ data') - - self.clear_kinds() - self.pbc = (False, False, False) - - for sym, position in atoms: - self.append_atom(symbols=sym, position=position) - - def _adjust_default_cell(self, vacuum_factor=1.0, vacuum_addition=10.0, pbc=(False, False, False)): - """ - If the structure was imported from an xyz file, it lacks a cell. 
- This method will adjust the cell - """ - # pylint: disable=invalid-name - import numpy as np - - def get_extremas_from_positions(positions): - """ - returns the minimum and maximum value for each dimension in the positions given - """ - return list(zip(*[(min(values), max(values)) for values in zip(*positions)])) - - # Calculating the minimal cell: - positions = np.array([site.position for site in self.sites]) - position_min, _ = get_extremas_from_positions(positions) - - # Translate the structure to the origin, such that the minimal values in each dimension - # amount to (0,0,0) - positions -= position_min - for index, site in enumerate(self.base.attributes.get('sites')): - site['position'] = list(positions[index]) - - # The orthorhombic cell that (just) accomodates the whole structure is now given by the - # extremas of position in each dimension: - minimal_orthorhombic_cell_dimensions = np.array(get_extremas_from_positions(positions)[1]) - minimal_orthorhombic_cell_dimensions = np.dot(vacuum_factor, minimal_orthorhombic_cell_dimensions) - minimal_orthorhombic_cell_dimensions += vacuum_addition - - # Transform the vector (a, b, c ) to [[a,0,0], [0,b,0], [0,0,c]] - newcell = np.diag(minimal_orthorhombic_cell_dimensions) - self.set_cell(newcell.tolist()) - - # Now set PBC (checks are done in set_pbc, no need to check anything here) - self.set_pbc(pbc) - - return self - - def get_description(self): - """ - Returns a string with infos retrieved from StructureData node's properties - - :param self: the StructureData node - :return: retsrt: the description string - """ - return self.get_formula(mode='hill_compact') - - def get_symbols_set(self): - """ - Return a set containing the names of all elements involved in - this structure (i.e., for it joins the list of symbols for each - kind k in the structure). - - :returns: a set of strings of element names. - """ - return set(itertools.chain.from_iterable(kind.symbols for kind in self.kinds)) - - def get_formula(self, mode='hill', separator=''): - """ - Return a string with the chemical formula. - - :param mode: a string to specify how to generate the formula, can - assume one of the following values: - - * 'hill' (default): count the number of atoms of each species, - then use Hill notation, i.e. alphabetical order with C and H - first if one or several C atom(s) is (are) present, e.g. - ``['C','H','H','H','O','C','H','H','H']`` will return ``'C2H6O'`` - ``['S','O','O','H','O','H','O']`` will return ``'H2O4S'`` - From E. A. Hill, J. Am. Chem. Soc., 22 (8), pp 478–494 (1900) - - * 'hill_compact': same as hill but the number of atoms for each - species is divided by the greatest common divisor of all of them, e.g. - ``['C','H','H','H','O','C','H','H','H','O','O','O']`` - will return ``'CH3O2'`` - - * 'reduce': group repeated symbols e.g. - ``['Ba', 'Ti', 'O', 'O', 'O', 'Ba', 'Ti', 'O', 'O', 'O', - 'Ba', 'Ti', 'Ti', 'O', 'O', 'O']`` will return ``'BaTiO3BaTiO3BaTi2O3'`` - - * 'group': will try to group as much as possible parts of the formula - e.g. - ``['Ba', 'Ti', 'O', 'O', 'O', 'Ba', 'Ti', 'O', 'O', 'O', - 'Ba', 'Ti', 'Ti', 'O', 'O', 'O']`` will return ``'(BaTiO3)2BaTi2O3'`` - - * 'count': same as hill (i.e. one just counts the number - of atoms of each species) without the re-ordering (take the - order of the atomic sites), e.g. 
- ``['Ba', 'Ti', 'O', 'O', 'O','Ba', 'Ti', 'O', 'O', 'O']`` - will return ``'Ba2Ti2O6'`` - - * 'count_compact': same as count but the number of atoms - for each species is divided by the greatest common divisor of - all of them, e.g. - ``['Ba', 'Ti', 'O', 'O', 'O','Ba', 'Ti', 'O', 'O', 'O']`` - will return ``'BaTiO3'`` - - :param separator: a string used to concatenate symbols. Default empty. - - :return: a string with the formula - - .. note:: in modes reduce, group, count and count_compact, the - initial order in which the atoms were appended by the user is - used to group and/or order the symbols in the formula - """ - - symbol_list = [self.get_kind(s.kind_name).get_symbols_string() for s in self.sites] - - return get_formula(symbol_list, mode=mode, separator=separator) - - def get_site_kindnames(self): - """ - Return a list with length equal to the number of sites of this structure, - where each element of the list is the kind name of the corresponding site. - - .. note:: This is NOT necessarily a list of chemical symbols! Use - ``[ self.get_kind(s.kind_name).get_symbols_string() for s in self.sites]`` - for chemical symbols - - :return: a list of strings - """ - return [this_site.kind_name for this_site in self.sites] - - def get_composition(self, mode='full'): - """ - Returns the chemical composition of this structure as a dictionary, - where each key is the kind symbol (e.g. H, Li, Ba), - and each value is the number of occurences of that element in this - structure. - - :param mode: Specify the mode of the composition to return. Choose from ``full``, ``reduced`` or ``fractional``. - For example, given the structure with formula Ba2Zr2O6, the various modes operate as follows. - ``full``: The default, the counts are left unnnormalized. - ``reduced``: The counts are renormalized to the greatest common denominator. - ``fractional``: The counts are renormalized such that the sum equals 1. - - :returns: a dictionary with the composition - """ - import numpy as np - symbols_list = [self.get_kind(s.kind_name).get_symbols_string() for s in self.sites] - symbols_set = set(symbols_list) - - if mode == 'full': - return {symbol: symbols_list.count(symbol) for symbol in symbols_set} - - if mode == 'reduced': - gcd = np.gcd.reduce([symbols_list.count(symbol) for symbol in symbols_set]) - return {symbol: (symbols_list.count(symbol) / gcd) for symbol in symbols_set} - - if mode == 'fractional': - sum_comp = sum(symbols_list.count(symbol) for symbol in symbols_set) - return {symbol: symbols_list.count(symbol) / sum_comp for symbol in symbols_set} - - raise ValueError(f'mode `{mode}` is invalid, choose from `full`, `reduced` or `fractional`.') - - def get_ase(self): - """ - Get the ASE object. - Requires to be able to import ase. - - :return: an ASE object corresponding to this - :py:class:`StructureData ` - object. - - .. note:: If any site is an alloy or has vacancies, a ValueError - is raised (from the site.get_ase() routine). - """ - return self._get_object_ase() - - def get_pymatgen(self, **kwargs): - """ - Get pymatgen object. Returns Structure for structures with - periodic boundary conditions (in three dimensions) and Molecule - otherwise. - :param add_spin: True to add the spins to the pymatgen structure. - Default is False (no spin added). - - .. note:: The spins are set according to the following rule: - - * if the kind name ends with 1 -> spin=+1 - - * if the kind name ends with 2 -> spin=-1 - - .. 
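To make the formula modes above concrete, a short sketch contrasting `get_formula` and `get_composition` on a cubic BaTiO3-like cell (lattice parameter and positions illustrative, configured profile assumed):

```python
from aiida import load_profile
from aiida.orm import StructureData

load_profile()

alat = 4.0  # illustrative lattice parameter in angstrom
structure = StructureData(cell=[[alat, 0, 0], [0, alat, 0], [0, 0, alat]])
for frac_position, symbol in [
    ((0.0, 0.0, 0.0), 'Ba'),
    ((0.5, 0.5, 0.5), 'Ti'),
    ((0.5, 0.5, 0.0), 'O'),
    ((0.5, 0.0, 0.5), 'O'),
    ((0.0, 0.5, 0.5), 'O'),
]:
    structure.append_atom(position=[alat * x for x in frac_position], symbols=symbol)

print(structure.get_formula())                        # 'BaO3Ti' (Hill: alphabetical, no carbon present)
print(structure.get_formula(mode='count'))            # 'BaTiO3' (order of the atomic sites preserved)
print(structure.get_composition())                    # counts per element, e.g. {'Ba': 1, 'Ti': 1, 'O': 3}
print(structure.get_composition(mode='fractional'))   # same, renormalized so the values sum to 1
```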
note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors). - """ - return self._get_object_pymatgen(**kwargs) - - def get_pymatgen_structure(self, **kwargs): - """ - Get the pymatgen Structure object. - :param add_spin: True to add the spins to the pymatgen structure. - Default is False (no spin added). - - .. note:: The spins are set according to the following rule: - - * if the kind name ends with 1 -> spin=+1 - - * if the kind name ends with 2 -> spin=-1 - - .. note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors). - - :return: a pymatgen Structure object corresponding to this - :py:class:`StructureData ` - object. - :raise ValueError: if periodic boundary conditions do not hold - in at least one dimension of real space. - """ - return self._get_object_pymatgen_structure(**kwargs) - - def get_pymatgen_molecule(self): - """ - Get the pymatgen Molecule object. - - .. note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors). - - :return: a pymatgen Molecule object corresponding to this - :py:class:`StructureData ` - object. - """ - return self._get_object_pymatgen_molecule() - - def append_kind(self, kind): - """ - Append a kind to the - :py:class:`StructureData `. - It makes a copy of the kind. - - :param kind: the site to append, must be a Kind object. - """ - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - - new_kind = Kind(kind=kind) # So we make a copy - - if kind.name in [k.name for k in self.kinds]: - raise ValueError(f'A kind with the same name ({kind.name}) already exists.') - - # If here, no exceptions have been raised, so I add the site. - self.base.attributes.all.setdefault('kinds', []).append(new_kind.get_raw()) - # Note, this is a dict (with integer keys) so it allows for empty spots! - if self._internal_kind_tags is None: - self._internal_kind_tags = {} - - self._internal_kind_tags[len(self.base.attributes.get('kinds')) - 1] = kind._internal_tag # pylint: disable=protected-access - - def append_site(self, site): - """ - Append a site to the - :py:class:`StructureData `. - It makes a copy of the site. - - :param site: the site to append. It must be a Site object. - """ - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - - new_site = Site(site=site) # So we make a copy - - if site.kind_name not in [kind.name for kind in self.kinds]: - raise ValueError( - f"No kind with name '{site.kind_name}', available kinds are: {[kind.name for kind in self.kinds]}" - ) - - # If here, no exceptions have been raised, so I add the site. - self.base.attributes.all.setdefault('sites', []).append(new_site.get_raw()) - - def append_atom(self, **kwargs): - """ - Append an atom to the Structure, taking care of creating the - corresponding kind. - - :param ase: the ase Atom object from which we want to create a new atom - (if present, this must be the only parameter) - :param position: the position of the atom (three numbers in angstrom) - :param symbols: passed to the constructor of the Kind object. - :param weights: passed to the constructor of the Kind object. - :param name: passed to the constructor of the Kind object. See also the note below. - - .. 
note :: Note on the 'name' parameter (that is, the name of the kind): - - * if specified, no checks are done on existing species. Simply, - a new kind with that name is created. If there is a name - clash, a check is done: if the kinds are identical, no error - is issued; otherwise, an error is issued because you are trying - to store two different kinds with the same name. - - * if not specified, the name is automatically generated. Before - adding the kind, a check is done. If other species with the - same properties already exist, no new kinds are created, but - the site is added to the existing (identical) kind. - (Actually, the first kind that is encountered). - Otherwise, the name is made unique first, by adding to the string - containing the list of chemical symbols a number starting from 1, - until an unique name is found - - .. note :: checks of equality of species are done using - the :py:meth:`~aiida.orm.nodes.data.structure.Kind.compare_with` method. - """ - # pylint: disable=too-many-branches - aseatom = kwargs.pop('ase', None) - if aseatom is not None: - if kwargs: - raise ValueError( - "If you pass 'ase' as a parameter to " - 'append_atom, you cannot pass any further' - 'parameter' - ) - position = aseatom.position - kind = Kind(ase=aseatom) - else: - position = kwargs.pop('position', None) - if position is None: - raise ValueError('You have to specify the position of the new atom') - # all remaining parameters - kind = Kind(**kwargs) - - # I look for identical species only if the name is not specified - _kinds = self.kinds - - if 'name' not in kwargs: - # If the kind is identical to an existing one, I use the existing - # one, otherwise I replace it - exists_already = False - for idx, existing_kind in enumerate(_kinds): - try: - existing_kind._internal_tag = self._internal_kind_tags[idx] # pylint: disable=protected-access - except KeyError: - # self._internal_kind_tags does not contain any info for - # the kind in position idx: I don't have to add anything - # then, and I continue - pass - if kind.compare_with(existing_kind)[0]: - kind = existing_kind - exists_already = True - break - if not exists_already: - # There is not an identical kind. - # By default, the name of 'kind' just contains the elements. - # I then check that the name of 'kind' does not already exist, - # and if it exists I add a number (starting from 1) until I - # find a non-used name. - existing_names = [k.name for k in _kinds] - simplename = kind.name - counter = 1 - while kind.name in existing_names: - kind.name = f'{simplename}{counter}' - counter += 1 - self.append_kind(kind) - else: # 'name' was specified - old_kind = None - for existing_kind in _kinds: - if existing_kind.name == kwargs['name']: - old_kind = existing_kind - break - if old_kind is None: - self.append_kind(kind) - else: - is_the_same, firstdiff = kind.compare_with(old_kind) - if is_the_same: - kind = old_kind - else: - raise ValueError( - 'You are explicitly setting the name ' - "of the kind to '{}', that already " - 'exists, but the two kinds are different!' - ' (first difference: {})'.format(kind.name, firstdiff) - ) - - site = Site(kind_name=kind.name, position=position) - self.append_site(site) - - def clear_kinds(self): - """ - Removes all kinds for the StructureData object. - - .. note:: Also clear all sites! 
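A sketch of the kind-naming behaviour of `append_atom` described above: without an explicit `name`, identical species are merged into a single kind, while explicit names create separate kinds for the same element (illustrative bcc-like cell, configured profile assumed):

```python
from aiida import load_profile
from aiida.orm import StructureData

load_profile()

alat = 2.87  # illustrative lattice parameter in angstrom
structure = StructureData(cell=[[alat, 0, 0], [0, alat, 0], [0, 0, alat]])

# No name given: the second Fe atom is attached to the existing, identical 'Fe' kind
structure.append_atom(position=(0, 0, 0), symbols='Fe')
structure.append_atom(position=(alat / 2,) * 3, symbols='Fe')
print(structure.get_kind_names())        # ['Fe']

# Explicit names create two distinct kinds for the same element,
# e.g. to tag two magnetic sublattices independently
structure.clear_kinds()                  # also removes all sites
structure.append_atom(position=(0, 0, 0), symbols='Fe', name='Fe1')
structure.append_atom(position=(alat / 2,) * 3, symbols='Fe', name='Fe2')
print(structure.get_kind_names())        # ['Fe1', 'Fe2']
```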
- """ - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - - self.base.attributes.set('kinds', []) - self._internal_kind_tags = {} - self.clear_sites() - - def clear_sites(self): - """ - Removes all sites for the StructureData object. - """ - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - - self.base.attributes.set('sites', []) - - @property - def sites(self): - """ - Returns a list of sites. - """ - try: - raw_sites = self.base.attributes.get('sites') - except AttributeError: - raw_sites = [] - return [Site(raw=i) for i in raw_sites] - - @property - def kinds(self): - """ - Returns a list of kinds. - """ - try: - raw_kinds = self.base.attributes.get('kinds') - except AttributeError: - raw_kinds = [] - return [Kind(raw=i) for i in raw_kinds] - - def get_kind(self, kind_name): - """ - Return the kind object associated with the given kind name. - - :param kind_name: String, the name of the kind you want to get - - :return: The Kind object associated with the given kind_name, if - a Kind with the given name is present in the structure. - - :raise: ValueError if the kind_name is not present. - """ - # Cache the kinds, if stored, for efficiency - if self.is_stored: - try: - kinds_dict = self._kinds_cache - except AttributeError: - self._kinds_cache = {_.name: _ for _ in self.kinds} # pylint: disable=attribute-defined-outside-init - kinds_dict = self._kinds_cache - else: - kinds_dict = {_.name: _ for _ in self.kinds} - - # Will raise ValueError if the kind is not present - try: - return kinds_dict[kind_name] - except KeyError: - raise ValueError(f"Kind name '{kind_name}' unknown") - - def get_kind_names(self): - """ - Return a list of kind names (in the same order of the ``self.kinds`` - property, but return the names rather than Kind objects) - - .. note:: This is NOT necessarily a list of chemical symbols! Use - get_symbols_set for chemical symbols - - :return: a list of strings. - """ - return [k.name for k in self.kinds] - - @property - def cell(self): - """ - Returns the cell shape. - - :return: a 3x3 list of lists. - """ - return copy.deepcopy(self.base.attributes.get('cell')) - - @cell.setter - def cell(self, value): - """Set the cell.""" - self.set_cell(value) - - def set_cell(self, value): - """Set the cell.""" - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - - the_cell = _get_valid_cell(value) - self.base.attributes.set('cell', the_cell) - - def reset_cell(self, new_cell): - """ - Reset the cell of a structure not yet stored to a new value. - - :param new_cell: list specifying the cell vectors - - :raises: - ModificationNotAllowed: if object is already stored - """ - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed() - - self.base.attributes.set('cell', new_cell) - - def reset_sites_positions(self, new_positions, conserve_particle=True): - """ - Replace all the Site positions attached to the Structure - - :param new_positions: list of (3D) positions for every sites. - - :param conserve_particle: if True, allows the possibility of removing a site. - currently not implemented. 
- - :raises aiida.common.ModificationNotAllowed: if object is stored already - :raises ValueError: if positions are invalid - - .. note:: it is assumed that the order of the new_positions is - given in the same order of the one it's substituting, i.e. the - kind of the site will not be checked. - """ - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed() - - if not conserve_particle: - raise NotImplementedError - else: - - # test consistency of th enew input - n_sites = len(self.sites) - if n_sites != len(new_positions) and conserve_particle: - raise ValueError('the new positions should be as many as the previous structure.') - - new_sites = [] - for i in range(n_sites): - try: - this_pos = [float(j) for j in new_positions[i]] - except ValueError: - raise ValueError(f'Expecting a list of floats. Found instead {new_positions[i]}') - - if len(this_pos) != 3: - raise ValueError(f'Expecting a list of lists of length 3. found instead {len(this_pos)}') - - # now append this Site to the new_site list. - new_site = Site(site=self.sites[i]) # So we make a copy - new_site.position = copy.deepcopy(this_pos) - new_sites.append(new_site) - - # now clear the old sites, and substitute with the new ones - self.clear_sites() - for this_new_site in new_sites: - self.append_site(this_new_site) - - @property - def pbc(self): - """ - Get the periodic boundary conditions. - - :return: a tuple of three booleans, each one tells if there are periodic - boundary conditions for the i-th real-space direction (i=1,2,3) - """ - # return copy.deepcopy(self._pbc) - return (self.base.attributes.get('pbc1'), self.base.attributes.get('pbc2'), self.base.attributes.get('pbc3')) - - @pbc.setter - def pbc(self, value): - """Set the periodic boundary conditions.""" - self.set_pbc(value) - - def set_pbc(self, value): - """Set the periodic boundary conditions.""" - from aiida.common.exceptions import ModificationNotAllowed - - if self.is_stored: - raise ModificationNotAllowed('The StructureData object cannot be modified, it has already been stored') - the_pbc = get_valid_pbc(value) - - # self._pbc = the_pbc - self.base.attributes.set('pbc1', the_pbc[0]) - self.base.attributes.set('pbc2', the_pbc[1]) - self.base.attributes.set('pbc3', the_pbc[2]) - - @property - def cell_lengths(self): - """ - Get the lengths of cell lattice vectors in angstroms. - """ - import numpy - - cell = self.cell - return [ - numpy.linalg.norm(cell[0]), - numpy.linalg.norm(cell[1]), - numpy.linalg.norm(cell[2]), - ] - - @cell_lengths.setter - def cell_lengths(self, value): - self.set_cell_lengths(value) - - def set_cell_lengths(self, value): - raise NotImplementedError('Modification is not implemented yet') - - @property - def cell_angles(self): - """ - Get the angles between the cell lattice vectors in degrees. - """ - import numpy - - cell = self.cell - lengths = self.cell_lengths - return [ - float(numpy.arccos(x) / numpy.pi * 180) for x in [ - numpy.vdot(cell[1], cell[2]) / lengths[1] / lengths[2], - numpy.vdot(cell[0], cell[2]) / lengths[0] / lengths[2], - numpy.vdot(cell[0], cell[1]) / lengths[0] / lengths[1], - ] - ] - - @cell_angles.setter - def cell_angles(self, value): - self.set_cell_angles(value) - - def set_cell_angles(self, value): - raise NotImplementedError('Modification is not implemented yet') - - @property - def is_alloy(self): - """Return whether the structure contains any alloy kinds. 
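A short sketch of the cell-related properties above on a hexagonal cell, so that `cell_angles` returns something non-trivial (lattice parameters illustrative, configured profile assumed):

```python
import math

from aiida import load_profile
from aiida.orm import StructureData

load_profile()

a, c = 3.2, 5.2  # illustrative hexagonal lattice parameters in angstrom
cell = [
    [a, 0.0, 0.0],
    [-a / 2, a * math.sqrt(3) / 2, 0.0],
    [0.0, 0.0, c],
]
structure = StructureData(cell=cell, pbc=(True, True, True))

print(structure.cell_lengths)       # ~[3.2, 3.2, 5.2]
print(structure.cell_angles)        # ~[90.0, 90.0, 120.0]
print(structure.get_cell_volume())  # cell volume in angstrom^3
print(structure.pbc)                # (True, True, True)
```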
- - :return: a boolean, True if at least one kind is an alloy - """ - return any(kind.is_alloy for kind in self.kinds) - - @property - def has_vacancies(self): - """Return whether the structure has vacancies in the structure. - - :return: a boolean, True if at least one kind has a vacancy - """ - return any(kind.has_vacancies for kind in self.kinds) - - def get_cell_volume(self): - """ - Returns the three-dimensional cell volume in Angstrom^3. - - Use the `get_dimensionality` method in order to get the area/length of lower-dimensional cells. - - :return: a float. - """ - return calc_cell_volume(self.cell) - - def get_cif(self, converter='ase', store=False, **kwargs): - """ - Creates :py:class:`aiida.orm.nodes.data.cif.CifData`. - - .. versionadded:: 1.0 - Renamed from _get_cif - - :param converter: specify the converter. Default 'ase'. - :param store: If True, intermediate calculation gets stored in the - AiiDA database for record. Default False. - :return: :py:class:`aiida.orm.nodes.data.cif.CifData` node. - """ - from aiida.tools.data import structure as structure_tools - - from .dict import Dict - - param = Dict(kwargs) - try: - conv_f = getattr(structure_tools, f'_get_cif_{converter}_inline') - except AttributeError: - raise ValueError(f"No such converter '{converter}' available") - ret_dict = conv_f(struct=self, parameters=param, metadata={'store_provenance': store}) - return ret_dict['cif'] - - def _get_object_phonopyatoms(self): - """ - Converts StructureData to PhonopyAtoms - - :return: a PhonopyAtoms object - """ - from phonopy.structure.atoms import PhonopyAtoms # pylint: disable=import-error,no-name-in-module - - atoms = PhonopyAtoms(symbols=[_.kind_name for _ in self.sites]) - # Phonopy internally uses scaled positions, so you must store cell first! - atoms.set_cell(self.cell) - atoms.set_positions([_.position for _ in self.sites]) - - return atoms - - def _get_object_ase(self): - """ - Converts - :py:class:`StructureData ` - to ase.Atoms - - :return: an ase.Atoms object - """ - import ase - - asecell = ase.Atoms(cell=self.cell, pbc=self.pbc) - _kinds = self.kinds - - for site in self.sites: - asecell.append(site.get_ase(kinds=_kinds)) - return asecell - - def _get_object_pymatgen(self, **kwargs): - """ - Converts - :py:class:`StructureData ` - to pymatgen object - - :return: a pymatgen Structure for structures with periodic boundary - conditions (in three dimensions) and Molecule otherwise - - .. note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors). - """ - if self.pbc == (True, True, True): - return self._get_object_pymatgen_structure(**kwargs) - - return self._get_object_pymatgen_molecule(**kwargs) - - def _get_object_pymatgen_structure(self, **kwargs): - """ - Converts - :py:class:`StructureData ` - to pymatgen Structure object - :param add_spin: True to add the spins to the pymatgen structure. - Default is False (no spin added). - - .. note:: The spins are set according to the following rule: - - * if the kind name ends with 1 -> spin=+1 - - * if the kind name ends with 2 -> spin=-1 - - :return: a pymatgen Structure object corresponding to this - :py:class:`StructureData ` - object - :raise ValueError: if periodic boundary conditions does not hold - in at least one dimension of real space; if there are partial occupancies - together with spins (defined by kind names ending with '1' or '2'). - - .. 
note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors) - """ - from pymatgen.core.structure import Structure - - if self.pbc != (True, True, True): - raise ValueError('Periodic boundary conditions must apply in all three dimensions of real space') - - species = [] - additional_kwargs = {} - - if (kwargs.pop('add_spin', False) and any(n.endswith('1') or n.endswith('2') for n in self.get_kind_names())): - # case when spins are defined -> no partial occupancy allowed - from pymatgen.core.periodic_table import Specie - oxidation_state = 0 # now I always set the oxidation_state to zero - for site in self.sites: - kind = self.get_kind(site.kind_name) - if len(kind.symbols) != 1 or (len(kind.weights) != 1 or sum(kind.weights) < 1.): - raise ValueError('Cannot set partial occupancies and spins at the same time') - species.append( - Specie( - kind.symbols[0], - oxidation_state, - properties={'spin': -1 if kind.name.endswith('1') else 1 if kind.name.endswith('2') else 0} - ) - ) - else: - # case when no spin are defined - for site in self.sites: - kind = self.get_kind(site.kind_name) - species.append(dict(zip(kind.symbols, kind.weights))) - if any( - create_automatic_kind_name(self.get_kind(name).symbols, - self.get_kind(name).weights) != name for name in self.get_site_kindnames() - ): - # add "kind_name" as a properties to each site, whenever - # the kind_name cannot be automatically obtained from the symbols - additional_kwargs['site_properties'] = {'kind_name': self.get_site_kindnames()} - - if kwargs: - raise ValueError(f'Unrecognized parameters passed to pymatgen converter: {kwargs.keys()}') - - positions = [list(x.position) for x in self.sites] - return Structure(self.cell, species, positions, coords_are_cartesian=True, **additional_kwargs) - - def _get_object_pymatgen_molecule(self, **kwargs): - """ - Converts - :py:class:`StructureData ` - to pymatgen Molecule object - - :return: a pymatgen Molecule object corresponding to this - :py:class:`StructureData ` - object. - - .. note:: Requires the pymatgen module (version >= 3.0.13, usage - of earlier versions may cause errors) - """ - from pymatgen.core.structure import Molecule - - if kwargs: - raise ValueError(f'Unrecognized parameters passed to pymatgen converter: {kwargs.keys()}') - - species = [] - for site in self.sites: - kind = self.get_kind(site.kind_name) - species.append(dict(zip(kind.symbols, kind.weights))) - - positions = [list(site.position) for site in self.sites] - return Molecule(species, positions) - - -class Kind: - """ - This class contains the information about the species (kinds) of the system. - - It can be a single atom, or an alloy, or even contain vacancies. - """ - - def __init__(self, **kwargs): - """ - Create a site. - One can either pass: - - :param raw: the raw python dictionary that will be converted to a - Kind object. - :param ase: an ase Atom object - :param kind: a Kind object (to get a copy) - - Or alternatively the following parameters: - - :param symbols: a single string for the symbol of this site, or a list - of symbol strings - :param weights: (optional) the weights for each atomic species of - this site. - If only a single symbol is provided, then this value is - optional and the weight is set to 1. - :param mass: (optional) the mass for this site in atomic mass units. - If not provided, the mass is set by the - self.reset_mass() function. - :param name: a string that uniquely identifies the kind, and that - is used to identify the sites. 
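A sketch of the pymatgen conversion with the spin rule documented above, assuming pymatgen is installed and a profile is configured; the kind names `Fe1`/`Fe2` and the cell are illustrative:

```python
from aiida import load_profile
from aiida.orm import StructureData

load_profile()

alat = 2.87  # illustrative
structure = StructureData(cell=[[alat, 0, 0], [0, alat, 0], [0, 0, alat]])
structure.append_atom(position=(0, 0, 0), symbols='Fe', name='Fe1')
structure.append_atom(position=(alat / 2,) * 3, symbols='Fe', name='Fe2')

# Plain conversion: species only; the kind names are kept as a 'kind_name' site property
pmg_structure = structure.get_pymatgen_structure()
print(pmg_structure.composition)

# With add_spin=True, the trailing '1'/'2' of the kind names is turned into
# opposite spins on the pymatgen species, following the rule documented above
pmg_spin = structure.get_pymatgen_structure(add_spin=True)
print([site.specie for site in pmg_spin])
```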
- """ - # pylint: disable=too-many-branches,too-many-statements - # Internal variables - self._mass = None - self._symbols = None - self._weights = None - self._name = None - - # It will be remain to None in general; it is used to further - # identify this species. At the moment, it is used only when importing - # from ASE, if the species had a tag (different from zero). - ## NOTE! This is not persisted on DB but only used while the class - # is loaded in memory (i.e., it is not output with the get_raw() method) - self._internal_tag = None - - # Logic to create the site from the raw format - if 'raw' in kwargs: - if len(kwargs) != 1: - raise ValueError("If you pass 'raw', then you cannot pass any other parameter.") - - raw = kwargs['raw'] - - try: - self.set_symbols_and_weights(raw['symbols'], raw['weights']) - except KeyError: - raise ValueError("You didn't specify either 'symbols' or 'weights' in the raw site data.") - try: - self.mass = raw['mass'] - except KeyError: - raise ValueError("You didn't specify the site mass in the raw site data.") - - try: - self.name = raw['name'] - except KeyError: - raise ValueError("You didn't specify the name in the raw site data.") - - elif 'kind' in kwargs: - if len(kwargs) != 1: - raise ValueError("If you pass 'kind', then you cannot pass any other parameter.") - oldkind = kwargs['kind'] - - try: - self.set_symbols_and_weights(oldkind.symbols, oldkind.weights) - self.mass = oldkind.mass - self.name = oldkind.name - self._internal_tag = oldkind._internal_tag - except AttributeError: - raise ValueError( - 'Error using the Kind object. Are you sure ' - 'it is a Kind object? [Introspection says it is ' - '{}]'.format(str(type(oldkind))) - ) - - elif 'ase' in kwargs: - aseatom = kwargs['ase'] - if len(kwargs) != 1: - raise ValueError("If you pass 'ase', then you cannot pass any other parameter.") - - try: - import numpy - self.set_symbols_and_weights([aseatom.symbol], [1.]) - # ASE sets mass to numpy.nan for unstable species - if not numpy.isnan(aseatom.mass): - self.mass = aseatom.mass - else: - self.reset_mass() - except AttributeError: - raise ValueError( - 'Error using the aseatom object. Are you sure ' - 'it is a ase.atom.Atom object? [Introspection says it is ' - '{}]'.format(str(type(aseatom))) - ) - if aseatom.tag != 0: - self.set_automatic_kind_name(tag=aseatom.tag) - self._internal_tag = aseatom.tag - else: - self.set_automatic_kind_name() - else: - if 'symbols' not in kwargs: - raise ValueError( - "'symbols' need to be " - 'specified (at least) to create a Site object. Otherwise, ' - "pass a raw site using the 'raw' parameter." - ) - weights = kwargs.pop('weights', None) - self.set_symbols_and_weights(kwargs.pop('symbols'), weights) - try: - self.mass = kwargs.pop('mass') - except KeyError: - self.reset_mass() - try: - self.name = kwargs.pop('name') - except KeyError: - self.set_automatic_kind_name() - if kwargs: - raise ValueError(f'Unrecognized parameters passed to Kind constructor: {kwargs.keys()}') - - def get_raw(self): - """ - Return the raw version of the site, mapped to a suitable dictionary. - This is the format that is actually used to store each kind of the - structure in the DB. - - :return: a python dictionary with the kind. - """ - return { - 'symbols': self.symbols, - 'weights': self.weights, - 'mass': self.mass, - 'name': self.name, - } - - def reset_mass(self): - """ - Reset the mass to the automatic calculated value. 
- - The mass can be set manually; by default, if not provided, - it is the mass of the constituent atoms, weighted with their - weight (after the weight has been normalized to one to take - correctly into account vacancies). - - This function uses the internal _symbols and _weights values and - thus assumes that the values are validated. - - It sets the mass to None if the sum of weights is zero. - """ - w_sum = sum(self._weights) - - if abs(w_sum) < _SUM_THRESHOLD: - self._mass = None - return - - normalized_weights = (i / w_sum for i in self._weights) - element_masses = (_atomic_masses[sym] for sym in self._symbols) - # Weighted mass - self._mass = sum(i * j for i, j in zip(normalized_weights, element_masses)) - - @property - def name(self): - """ - Return the name of this kind. - The name of a kind is used to identify the species of a site. - - :return: a string - """ - return self._name - - @name.setter - def name(self, value): - """ - Set the name of this site (a string). - """ - self._name = str(value) - - def set_automatic_kind_name(self, tag=None): - """ - Set the type to a string obtained with the symbols appended one - after the other, without spaces, in alphabetical order; - if the site has a vacancy, a X is appended at the end too. - """ - name_string = create_automatic_kind_name(self.symbols, self.weights) - if tag is None: - self.name = name_string - else: - self.name = f'{name_string}{tag}' - - def compare_with(self, other_kind): - """ - Compare with another Kind object to check if they are different. - - .. note:: This does NOT check the 'type' attribute. Instead, it compares - (with reasonable thresholds, where applicable): the mass, and the list - of symbols and of weights. Moreover, it compares the - ``_internal_tag``, if defined (at the moment, defined automatically - only when importing the Kind from ASE, if the atom has a non-zero tag). - Note that the _internal_tag is only used while the class is loaded, - but is not persisted on the database. - - :return: A tuple with two elements. The first one is True if the two sites - are 'equivalent' (same mass, symbols and weights), False otherwise. - The second element of the tuple is a string, - which is either None (if the first element was True), or contains - a 'human-readable' description of the first difference encountered - between the two sites. - """ - # Check length of symbols - if len(self.symbols) != len(other_kind.symbols): - return (False, 'Different length of symbols list') - - # Check list of symbols - for i, symbol in enumerate(self.symbols): - if symbol != other_kind.symbols[i]: - return (False, f'Symbol at position {i + 1:d} are different ({symbol} vs. {other_kind.symbols[i]})') - # Check weights (assuming length of weights and of symbols have same - # length, which should be always true - for i, weight in enumerate(self.weights): - if weight != other_kind.weights[i]: - return (False, f'Weight at position {i + 1:d} are different ({weight} vs. {other_kind.weights[i]})') - # Check masses - if abs(self.mass - other_kind.mass) > _MASS_THRESHOLD: - return (False, f'Masses are different ({self.mass} vs. {other_kind.mass})') - - if self._internal_tag != other_kind._internal_tag: # pylint: disable=protected-access - return ( - False, - 'Internal tags are different ({} vs. 
{})' - ''.format(self._internal_tag, other_kind._internal_tag) # pylint: disable=protected-access - ) - - # If we got here, the two Site objects are similar enough - # to be considered of the same kind - return (True, '') - - @property - def mass(self): - """ - The mass of this species kind. - - :return: a float - """ - return self._mass - - @mass.setter - def mass(self, value): - the_mass = float(value) - if the_mass <= 0: - raise ValueError('The mass must be positive.') - self._mass = the_mass - - @property - def weights(self): - """ - Weights for this species kind. Refer also to - :func:validate_symbols_tuple for the validation rules on the weights. - """ - return copy.deepcopy(self._weights) - - @weights.setter - def weights(self, value): - """ - If value is a number, a single weight is used. Otherwise, a list or - tuple of numbers is expected. - None is also accepted, corresponding to the list [1.]. - """ - weights_tuple = _create_weights_tuple(value) - - if len(weights_tuple) != len(self._symbols): - raise ValueError( - 'Cannot change the number of weights. Use the ' - 'set_symbols_and_weights function instead.' - ) - validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) - - self._weights = weights_tuple - - def get_symbols_string(self): - """ - Return a string that tries to match as good as possible the symbols - of this kind. If there is only one symbol (no alloy) with 100% - occupancy, just returns the symbol name. Otherwise, groups the full - string in curly brackets, and try to write also the composition - (with 2 precision only). - - .. note:: If there is a vacancy (sum of weights<1), we indicate it - with the X symbol followed by 1-sum(weights) (still with 2 - digits precision, so it can be 0.00) - - .. note:: Note the difference with respect to the symbols and the - symbol properties! - """ - return get_symbols_string(self._symbols, self._weights) - - @property - def symbol(self): - """ - If the kind has only one symbol, return it; otherwise, raise a - ValueError. - """ - if len(self._symbols) == 1: - return self._symbols[0] - - raise ValueError(f'This kind has more than one symbol (it is an alloy): {self._symbols}') - - @property - def symbols(self): - """ - List of symbols for this site. If the site is a single atom, - pass a list of one element only, or simply the string for that atom. - For alloys, a list of elements. - - .. note:: Note that if you change the list of symbols, the kind - name remains unchanged. - """ - return copy.deepcopy(self._symbols) - - @symbols.setter - def symbols(self, value): - """ - If value is a string, a single symbol is used. Otherwise, a list or - tuple of strings is expected. - - I set a copy of the list, so to avoid that the content changes - after the value is set. - """ - symbols_tuple = _create_symbols_tuple(value) - - if len(symbols_tuple) != len(self._weights): - raise ValueError( - 'Cannot change the number of symbols. Use the ' - 'set_symbols_and_weights function instead.' - ) - validate_symbols_tuple(symbols_tuple) - - self._symbols = symbols_tuple - - def set_symbols_and_weights(self, symbols, weights): - """ - Set the chemical symbols and the weights for the site. - - .. note:: Note that the kind name remains unchanged. 
- """ - symbols_tuple = _create_symbols_tuple(symbols) - weights_tuple = _create_weights_tuple(weights) - if len(symbols_tuple) != len(weights_tuple): - raise ValueError('The number of symbols and weights must coincide.') - validate_symbols_tuple(symbols_tuple) - validate_weights_tuple(weights_tuple, _SUM_THRESHOLD) - self._symbols = symbols_tuple - self._weights = weights_tuple - - @property - def is_alloy(self): - """Return whether the Kind is an alloy, i.e. contains more than one element - - :return: boolean, True if the kind has more than one element, False otherwise. - """ - return len(self._symbols) != 1 - - @property - def has_vacancies(self): - """Return whether the Kind contains vacancies, i.e. when the sum of the weights is less than one. - - .. note:: the property uses the internal variable `_SUM_THRESHOLD` as a threshold. - - :return: boolean, True if the sum of the weights is less than one, False otherwise - """ - return has_vacancies(self._weights) - - def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' - - def __str__(self): - symbol = self.get_symbols_string() - return f"name '{self.name}', symbol '{symbol}'" - - -class Site: - """ - This class contains the information about a given site of the system. - - It can be a single atom, or an alloy, or even contain vacancies. - """ - - def __init__(self, **kwargs): - """ - Create a site. - - :param kind_name: a string that identifies the kind (species) of this site. - This has to be found in the list of kinds of the StructureData - object. - Validation will be done at the StructureData level. - :param position: the absolute position (three floats) in angstrom - """ - self._kind_name = None - self._position = None - - if 'site' in kwargs: - site = kwargs.pop('site') - if kwargs: - raise ValueError("If you pass 'site', you cannot pass any further parameter to the Site constructor") - if not isinstance(site, Site): - raise ValueError("'site' must be of type Site") - self.kind_name = site.kind_name - self.position = site.position - elif 'raw' in kwargs: - raw = kwargs.pop('raw') - if kwargs: - raise ValueError("If you pass 'raw', you cannot pass any further parameter to the Site constructor") - try: - self.kind_name = raw['kind_name'] - self.position = raw['position'] - except KeyError as exc: - raise ValueError(f'Invalid raw object, it does not contain any key {exc.args[0]}') - except TypeError: - raise ValueError('Invalid raw object, it is not a dictionary') - - else: - try: - self.kind_name = kwargs.pop('kind_name') - self.position = kwargs.pop('position') - except KeyError as exc: - raise ValueError(f'You need to specify {exc.args[0]}') - if kwargs: - raise ValueError(f'Unrecognized parameters: {kwargs.keys}') - - def get_raw(self): - """ - Return the raw version of the site, mapped to a suitable dictionary. - This is the format that is actually used to store each site of the - structure in the DB. - - :return: a python dictionary with the site. - """ - return { - 'position': self.position, - 'kind_name': self.kind_name, - } - - def get_ase(self, kinds): - """ - Return a ase.Atom object for this site. - - :param kinds: the list of kinds from the StructureData object. - - .. note:: If any site is an alloy or has vacancies, a ValueError - is raised (from the site.get_ase() routine). 
- """ - # pylint: disable=too-many-branches - from collections import defaultdict - - import ase - - # I create the list of tags - tag_list = [] - used_tags = defaultdict(list) - for k in kinds: - # Skip alloys and vacancies - if k.is_alloy or k.has_vacancies: - tag_list.append(None) - # If the kind name is equal to the specie name, - # then no tag should be set - elif str(k.name) == str(k.symbols[0]): - tag_list.append(None) - else: - # Name is not the specie name - if k.name.startswith(k.symbols[0]): - try: - new_tag = int(k.name[len(k.symbols[0])]) - tag_list.append(new_tag) - used_tags[k.symbols[0]].append(new_tag) - continue - except ValueError: - pass - tag_list.append(k.symbols[0]) # I use a string as a placeholder - - for i, _ in enumerate(tag_list): - # If it is a string, it is the name of the element, - # and I have to generate a new integer for this element - # and replace tag_list[i] with this new integer - if isinstance(tag_list[i], str): - # I get a list of used tags for this element - existing_tags = used_tags[tag_list[i]] - if existing_tags: - new_tag = max(existing_tags) + 1 - else: # empty list - new_tag = 1 - # I store it also as a used tag! - used_tags[tag_list[i]].append(new_tag) - # I update the tag - tag_list[i] = new_tag - - found = False - for kind_candidate, tag_candidate in zip(kinds, tag_list): - if kind_candidate.name == self.kind_name: - kind = kind_candidate - tag = tag_candidate - found = True - break - if not found: - raise ValueError(f"No kind '{self.kind_name}' has been found in the list of kinds") - - if kind.is_alloy or kind.has_vacancies: - raise ValueError('Cannot convert to ASE if the kind represents an alloy or it has vacancies.') - aseatom = ase.Atom(position=self.position, symbol=str(kind.symbols[0]), mass=kind.mass) - if tag is not None: - aseatom.tag = tag # pylint: disable=assigning-non-slot - return aseatom - - @property - def kind_name(self): - """ - Return the kind name of this site (a string). - - The type of a site is used to decide whether two sites are identical - (same mass, symbols, weights, ...) or not. - """ - return self._kind_name - - @kind_name.setter - def kind_name(self, value): - """ - Set the type of this site (a string). - """ - self._kind_name = str(value) - - @property - def position(self): - """ - Return the position of this site in absolute coordinates, - in angstrom. - """ - return copy.deepcopy(self._position) - - @position.setter - def position(self, value): - """ - Set the position of this site in absolute coordinates, - in angstrom. - """ - try: - internal_pos = tuple(float(i) for i in value) - if len(internal_pos) != 3: - raise ValueError - # value is not iterable or elements are not floats or len != 3 - except (ValueError, TypeError): - raise ValueError('Wrong format for position, must be a list of three float numbers.') - self._position = internal_pos - - def __repr__(self): - return f'<{self.__class__.__name__}: {str(self)}>' - - def __str__(self): - return f"kind name '{self.kind_name}' @ {self.position[0]},{self.position[1]},{self.position[2]}" - - -def _get_dimensionality(pbc, cell): - """ - Return the dimensionality of the structure and its length/surface/volume. - - Zero-dimensional structures are assigned "volume" 0. - - :return: returns a dictionary with keys "dim" (dimensionality integer), "label" (dimensionality label) - and "value" (numerical length/surface/volume). 
- """ - - import numpy as np - - retdict = {} - - pbc = np.array(pbc) - cell = np.array(cell) - - dim = len(pbc[pbc]) - - retdict['dim'] = dim - retdict['label'] = StructureData._dimensionality_label[dim] # pylint: disable=protected-access - - if dim not in (0, 1, 2, 3): - raise ValueError(f'Dimensionality {dim} must be one of 0, 1, 2, 3') - - if dim == 0: - # We have no concept of 0d volume. Let's return a value of 0 for a consistent output dictionary - retdict['value'] = 0 - elif dim == 1: - retdict['value'] = np.linalg.norm(cell[pbc]) - elif dim == 2: - vectors = cell[pbc] - retdict['value'] = np.linalg.norm(np.cross(vectors[0], vectors[1])) - elif dim == 3: - retdict['value'] = calc_cell_volume(cell) - - return retdict - - -def _validate_dimensionality(pbc, cell): - """ - Check whether the given pbc and cell vectors are consistent. - """ - dim = _get_dimensionality(pbc, cell) - - # 0-d structures put no constraints on the cell - if dim['dim'] == 0: - return - - # finite-d structures should have a cell with finite volume - if dim['value'] == 0: - raise ValueError(f'Structure has periodicity {pbc} but {dim["dim"]}-d volume 0.') - - return diff --git a/aiida/orm/nodes/data/upf.py b/aiida/orm/nodes/data/upf.py deleted file mode 100644 index 4d59dd11f8..0000000000 --- a/aiida/orm/nodes/data/upf.py +++ /dev/null @@ -1,485 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module of `Data` sub class to represent a pseudopotential single file in UPF format and related utilities.""" -import json -import re - -from upf_to_json import upf_to_json - -from .singlefile import SinglefileData - -__all__ = ('UpfData',) - -REGEX_UPF_VERSION = re.compile(r""" - \s*.*)"> - """, re.VERBOSE) - -REGEX_ELEMENT_V1 = re.compile(r""" - (?P[a-zA-Z]{1,2}) - \s+ - Element - """, re.VERBOSE) - -REGEX_ELEMENT_V2 = re.compile( - r""" - \s* - element\s*=\s*(?P['"])\s* - (?P[a-zA-Z]{1,2})\s* - (?P=quote_symbol) - """, re.VERBOSE -) - - -def get_pseudos_from_structure(structure, family_name): - """Return a dictionary mapping each kind name of the structure to corresponding `UpfData` from given family. - - :param structure: a `StructureData` - :param family_name: the name of a UPF family group - :return: dictionary mapping each structure kind name onto `UpfData` of corresponding element - :raise aiida.common.MultipleObjectsError: if more than one UPF for the same element is found in the group. - :raise aiida.common.NotExistent: if no UPF for an element in the group is found in the group. 
- """ - from aiida.common.exceptions import MultipleObjectsError, NotExistent - - pseudo_list = {} - family_pseudos = {} - family = UpfData.get_upf_group(family_name) - - for node in family.nodes: - if isinstance(node, UpfData): - if node.element in family_pseudos: - raise MultipleObjectsError( - f'More than one UPF for element {node.element} found in family {family_name}' - ) - family_pseudos[node.element] = node - - for kind in structure.kinds: - try: - pseudo_list[kind.name] = family_pseudos[kind.symbol] - except KeyError: - raise NotExistent(f'No UPF for element {kind.symbol} found in family {family_name}') - - return pseudo_list - - -def upload_upf_family(folder, group_label, group_description, stop_if_existing=True, backend=None): - """Upload a set of UPF files in a given group. - - :param folder: a path containing all UPF files to be added. - Only files ending in .UPF (case-insensitive) are considered. - :param group_label: the name of the group to create. If it exists and is non-empty, a UniquenessError is raised. - :param group_description: string to be set as the group description. Overwrites previous descriptions. - :param stop_if_existing: if True, check for the md5 of the files and, if the file already exists in the DB, raises a - MultipleObjectsError. If False, simply adds the existing UPFData node to the group. - """ - # pylint: disable=too-many-locals,too-many-branches - import os - - from aiida import orm - from aiida.common import AIIDA_LOGGER - from aiida.common.exceptions import UniquenessError - from aiida.common.files import md5_file - - if not os.path.isdir(folder): - raise ValueError('folder must be a directory') - - # only files, and only those ending with .upf or .UPF; - # go to the real file if it is a symlink - filenames = [ - os.path.realpath(os.path.join(folder, i)) - for i in os.listdir(folder) - if os.path.isfile(os.path.join(folder, i)) and i.lower().endswith('.upf') - ] - - nfiles = len(filenames) - - automatic_user = orm.User.collection.get_default() - group, group_created = orm.UpfFamily.collection.get_or_create(label=group_label, user=automatic_user) - - if group.user.email != automatic_user.email: - raise UniquenessError( - 'There is already a UpfFamily group with label {}' - ', but it belongs to user {}, therefore you ' - 'cannot modify it'.format(group_label, group.user.email) - ) - - # Always update description, even if the group already existed - group.description = group_description - - # NOTE: GROUP SAVED ONLY AFTER CHECKS OF UNICITY - - pseudo_and_created = [] - - for filename in filenames: - md5sum = md5_file(filename) - builder = orm.QueryBuilder(backend=backend) - builder.append(UpfData, filters={'attributes.md5': {'==': md5sum}}) - existing_upf = builder.first(flat=True) - - if existing_upf is None: - # return the upfdata instances, not stored - pseudo, created = UpfData.get_or_create(filename, use_first=True, store_upf=False) - # to check whether only one upf per element exists - # NOTE: actually, created has the meaning of "to_be_created" - pseudo_and_created.append((pseudo, created)) - else: - if stop_if_existing: - raise ValueError(f'A UPF with identical MD5 to {filename} cannot be added with stop_if_existing') - pseudo_and_created.append((existing_upf, False)) - - # check whether pseudo are unique per element - elements = [(i[0].element, i[0].md5sum) for i in pseudo_and_created] - # If group already exists, check also that I am not inserting more than - # once the same element - if not group_created: - for aiida_n in group.nodes: - # 
Skip non-pseudos - if not isinstance(aiida_n, UpfData): - continue - elements.append((aiida_n.element, aiida_n.md5sum)) - - elements = set(elements) # Discard elements with the same MD5, that would - # not be stored twice - elements_names = [e[0] for e in elements] - - if not len(elements_names) == len(set(elements_names)): - duplicates = {x for x in elements_names if elements_names.count(x) > 1} - duplicates_string = ', '.join(i for i in duplicates) - raise UniquenessError(f'More than one UPF found for the elements: {duplicates_string}.') - - # At this point, save the group, if still unstored - if group_created: - group.store() - - # save the upf in the database, and add them to group - for pseudo, created in pseudo_and_created: - if created: - pseudo.store() - - AIIDA_LOGGER.debug(f'New node {pseudo.uuid} created for file {pseudo.filename}') - else: - AIIDA_LOGGER.debug(f'Reusing node {pseudo.uuid} for file {pseudo.filename}') - - # Add elements to the group all togetehr - group.add_nodes([pseudo for pseudo, created in pseudo_and_created]) - - nuploaded = len([_ for _, created in pseudo_and_created if created]) - - return nfiles, nuploaded - - -def parse_upf(fname, check_filename=True, encoding='utf-8'): - """ - Try to get relevant information from the UPF. For the moment, only the - element name. Note that even UPF v.2 cannot be parsed with the XML minidom! - (e.g. due to the & characters in the human-readable section). - - If check_filename is True, raise a ParsingError exception if the filename - does not start with the element name. - """ - # pylint: disable=too-many-branches - import os - - from aiida.common import AIIDA_LOGGER - from aiida.common.exceptions import ParsingError - from aiida.orm.nodes.data.structure import _valid_symbols - - parsed_data = {} - - try: - upf_contents = fname.read() - except AttributeError: - with open(fname, encoding=encoding) as handle: - upf_contents = handle.read() - else: - if check_filename: - raise ValueError('cannot use filelike objects when `check_filename=True`, use a filepath instead.') - fname = 'file.txt' - - match = REGEX_UPF_VERSION.search(upf_contents) - if match: - version = match.group('version') - AIIDA_LOGGER.debug(f'Version found: {version} for file {fname}') - else: - AIIDA_LOGGER.debug(f'Assuming version 1 for file {fname}') - version = '1' - - parsed_data['version'] = version - try: - version_major = int(version.partition('.')[0]) - except ValueError: - # If the version string does not contain a dot, fallback - # to version 1 - AIIDA_LOGGER.debug(f'Falling back to version 1 for file {fname} version string {version} unrecognized') - version_major = 1 - - element = None - if version_major == 1: - match = REGEX_ELEMENT_V1.search(upf_contents) - if match: - element = match.group('element_name') - else: # all versions > 1 - match = REGEX_ELEMENT_V2.search(upf_contents) - if match: - element = match.group('element_name') - - if element is None: - raise ParsingError(f'Unable to find the element of UPF {fname}') - element = element.capitalize() - if element not in _valid_symbols: - raise ParsingError(f'Unknown element symbol {element} for file {fname}') - if check_filename: - if not os.path.basename(fname).lower().startswith(element.lower()): - raise ParsingError( - 'Filename {0} was recognized for element ' - '{1}, but the filename does not start ' - 'with {1}'.format(fname, element) - ) - - parsed_data['element'] = element - - return parsed_data - - -class UpfData(SinglefileData): - """`Data` sub class to represent a pseudopotential 
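A sketch of how the family helpers defined in this (now removed) module were used together; the folder path and the family label `my-upf-family` are hypothetical, and a configured profile is assumed:

```python
from aiida import load_profile
from aiida.orm import StructureData
from aiida.orm.nodes.data.upf import get_pseudos_from_structure, upload_upf_family

load_profile()

# Hypothetical folder of .upf files and a hypothetical family label
nfiles, nuploaded = upload_upf_family(
    '/path/to/upf_folder',
    group_label='my-upf-family',
    group_description='Example UPF family',
    stop_if_existing=False,
)
print(f'{nuploaded} of {nfiles} files were new')

# Map every kind of a structure onto the matching UpfData of the family
structure = StructureData(cell=[[5.43, 0, 0], [0, 5.43, 0], [0, 0, 5.43]])
structure.append_atom(position=(0, 0, 0), symbols='Si')
pseudos = get_pseudos_from_structure(structure, 'my-upf-family')  # e.g. {'Si': <UpfData>}
```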
single file in UPF format.""" - - @classmethod - def get_or_create(cls, filepath, use_first=False, store_upf=True): - """Get the `UpfData` with the same md5 of the given file, or create it if it does not yet exist. - - :param filepath: an absolute filepath on disk - :param use_first: if False (default), raise an exception if more than one potential is found. - If it is True, instead, use the first available pseudopotential. - :param store_upf: boolean, if false, the `UpfData` if created will not be stored. - :return: tuple of `UpfData` and boolean indicating whether it was created. - """ - import os - - from aiida.common.files import md5_file - - if not os.path.isabs(filepath): - raise ValueError('filepath must be an absolute path') - - pseudos = cls.from_md5(md5_file(filepath)) - - if not pseudos: - instance = cls(file=filepath) - if store_upf: - instance.store() - return (instance, True) - - if len(pseudos) > 1: - if use_first: - return (pseudos[0], False) - - raise ValueError( - 'More than one copy of a pseudopotential with the same MD5 has been found in the DB. pks={}'.format( - ','.join([str(i.pk) for i in pseudos]) - ) - ) - - return (pseudos[0], False) - - def store(self, *args, **kwargs): # pylint: disable=signature-differs - """Store the node, reparsing the file so that the md5 and the element are correctly reset.""" - from aiida.common.exceptions import ParsingError - from aiida.common.files import md5_from_filelike - - if self.is_stored: - return self - - # Do not check the filename because it will fail since we are passing in a handle, which doesn't have a filename - # and so `parse_upf` will raise. The reason we have to pass in a handle is because this is the repository does - # not allow to get an absolute filepath. Anyway, the filename was already checked in `set_file` when the file - # was set for the first time. All the logic in this method is duplicated in `store` and `_validate` and badly - # needs to be refactored, but that is for another time. - with self.open(mode='r') as handle: - parsed_data = parse_upf(handle, check_filename=False) - - # Open in binary mode which is required for generating the md5 checksum - with self.open(mode='rb') as handle: - md5 = md5_from_filelike(handle) - - try: - element = parsed_data['element'] - except KeyError: - raise ParsingError(f'Could not parse the element from the UPF file {self.filename}') - - self.base.attributes.set('element', str(element)) - self.base.attributes.set('md5', md5) - - return super().store(*args, **kwargs) - - @classmethod - def from_md5(cls, md5, backend=None): - """Return a list of all `UpfData` that match the given md5 hash. - - .. note:: assumes hash of stored `UpfData` nodes is stored in the `md5` attribute - - :param md5: the file hash - :return: list of existing `UpfData` nodes that have the same md5 hash - """ - from aiida.orm.querybuilder import QueryBuilder - builder = QueryBuilder(backend=backend) - builder.append(cls, filters={'attributes.md5': {'==': md5}}) - return builder.all(flat=True) - - def set_file(self, file, filename=None): - """Store the file in the repository and parse it to set the `element` and `md5` attributes. - - :param file: filepath or filelike object of the UPF potential file to store. - Hint: Pass io.BytesIO(b"my string") to construct the file directly from a string. - :param filename: specify filename to use (defaults to name of provided file). 
- """ - # pylint: disable=redefined-builtin - from aiida.common.exceptions import ParsingError - from aiida.common.files import md5_file, md5_from_filelike - - parsed_data = parse_upf(file) - - try: - md5sum = md5_file(file) - except TypeError: - md5sum = md5_from_filelike(file) - - try: - element = parsed_data['element'] - except KeyError: - raise ParsingError(f"No 'element' parsed in the UPF file {self.filename}; unable to store") - - super().set_file(file, filename=filename) - - self.base.attributes.set('element', str(element)) - self.base.attributes.set('md5', md5sum) - - def get_upf_family_names(self): - """Get the list of all upf family names to which the pseudo belongs.""" - from aiida.orm import QueryBuilder, UpfFamily - - query = QueryBuilder(backend=self.backend) - query.append(UpfFamily, tag='group', project='label') - query.append(UpfData, filters={'id': {'==': self.pk}}, with_group='group') - return query.all(flat=True) - - @property - def element(self): - """Return the element of the UPF pseudopotential. - - :return: the element - """ - return self.base.attributes.get('element', None) - - @property - def md5sum(self): - """Return the md5 checksum of the UPF pseudopotential file. - - :return: the md5 checksum - """ - return self.base.attributes.get('md5', None) - - def _validate(self): - """Validate the UPF potential file stored for this node.""" - from aiida.common.exceptions import ValidationError - from aiida.common.files import md5_from_filelike - - super()._validate() - - # Do not check the filename because it will fail since we are passing in a handle, which doesn't have a filename - # and so `parse_upf` will raise. The reason we have to pass in a handle is because this is the repository does - # not allow to get an absolute filepath. Anyway, the filename was already checked in `set_file` when the file - # was set for the first time. All the logic in this method is duplicated in `store` and `_validate` and badly - # needs to be refactored, but that is for another time. - with self.open(mode='r') as handle: - parsed_data = parse_upf(handle, check_filename=False) - - # Open in binary mode which is required for generating the md5 checksum - with self.open(mode='rb') as handle: - md5 = md5_from_filelike(handle) - - try: - element = parsed_data['element'] - except KeyError: - raise ValidationError(f"No 'element' could be parsed in the UPF {self.filename}") - - try: - attr_element = self.base.attributes.get('element') - except AttributeError: - raise ValidationError("attribute 'element' not set.") - - try: - attr_md5 = self.base.attributes.get('md5') - except AttributeError: - raise ValidationError("attribute 'md5' not set.") - - if attr_element != element: - raise ValidationError(f"Attribute 'element' says '{attr_element}' but '{element}' was parsed instead.") - - if attr_md5 != md5: - raise ValidationError(f"Attribute 'md5' says '{attr_md5}' but '{md5}' was parsed instead.") - - def _prepare_upf(self, main_file_name=''): - """ - Return UPF content. - """ - # pylint: disable=unused-argument - return_string = self.get_content() - - return return_string.encode('utf-8'), {} - - @classmethod - def get_upf_group(cls, group_label): - """Return the UPF family group with the given label. 
- - :param group_label: the family group label - :return: the `Group` with the given label, if it exists - """ - from aiida.orm import UpfFamily - - return UpfFamily.get(label=group_label) - - @classmethod - def get_upf_groups(cls, filter_elements=None, user=None, backend=None): - """Return all names of groups of type UpfFamily, possibly with some filters. - - :param filter_elements: A string or a list of strings. - If present, returns only the groups that contains one UPF for every element present in the list. The default - is `None`, meaning that all families are returned. - :param user: if None (default), return the groups for all users. - If defined, it should be either a `User` instance or the user email. - :return: list of `Group` entities of type UPF. - """ - from aiida.orm import QueryBuilder, UpfFamily, User - - builder = QueryBuilder(backend=backend) - builder.append(UpfFamily, tag='group', project='*') - - if user: - builder.append(User, filters={'email': {'==': user}}, with_group='group') - - if isinstance(filter_elements, str): - filter_elements = [filter_elements] - - if filter_elements is not None: - builder.append(UpfData, filters={'attributes.element': {'in': filter_elements}}, with_group='group') - - builder.order_by({UpfFamily: {'id': 'asc'}}) - - return builder.all(flat=True) - - # pylint: disable=unused-argument - def _prepare_json(self, main_file_name=''): - """ - Returns UPF PP in json format. - """ - with self.open() as file_handle: - upf_json = upf_to_json(file_handle.read(), fname=self.filename) - return json.dumps(upf_json).encode('utf-8'), {} diff --git a/aiida/orm/nodes/links.py b/aiida/orm/nodes/links.py deleted file mode 100644 index 6b0a83842e..0000000000 --- a/aiida/orm/nodes/links.py +++ /dev/null @@ -1,240 +0,0 @@ -# -*- coding: utf-8 -*- -"""Interface for links of a node instance.""" -from __future__ import annotations - -import typing as t -from typing import Optional, cast - -from aiida.common import exceptions -from aiida.common.escaping import sql_string_match -from aiida.common.lang import type_check -from aiida.common.links import LinkType - -from ..querybuilder import QueryBuilder -from ..utils.links import LinkManager, LinkTriple - -if t.TYPE_CHECKING: - from .node import Node # pylint: disable=unused-import - - -class NodeLinks: - """Interface for links of a node instance.""" - - def __init__(self, node: 'Node') -> None: - """Initialize the links interface.""" - self._node = node - self.incoming_cache: list[LinkTriple] = [] - - def _add_incoming_cache(self, source: 'Node', link_type: LinkType, link_label: str) -> None: - """Add an incoming link to the cache. - - .. note: the proposed link is not validated in this function, so this should not be called directly - but it should only be called by `Node.add_incoming`. - - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :raise aiida.common.UniquenessError: if the given link triple already exists in the cache - """ - assert self.incoming_cache is not None, 'incoming_cache not initialised' - - link_triple = LinkTriple(source, link_type, link_label) - - if link_triple in self.incoming_cache: - raise exceptions.UniquenessError(f'the link triple {link_triple} is already present in the cache') - - self.incoming_cache.append(link_triple) - - def add_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: - """Add a link of the given type from a given node to ourself. 
- - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `source` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - self.validate_incoming(source, link_type, link_label) - source.base.links.validate_outgoing(self._node, link_type, link_label) - - if self._node.is_stored and source.is_stored: - self._node.backend_entity.add_incoming(source.backend_entity, link_type, link_label) - else: - self._add_incoming_cache(source, link_type, link_label) - - def validate_incoming(self, source: 'Node', link_type: LinkType, link_label: str) -> None: - """Validate adding a link of the given type from a given node to ourself. - - This function will first validate the types of the inputs, followed by the node and link types and validate - whether in principle a link of that type between the nodes of these types is allowed. - - Subsequently, the validity of the "degree" of the proposed link is validated, which means validating the - number of links of the given type from the given node type is allowed. - - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `source` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - from aiida.orm.utils.links import validate_link - - from .node import Node # pylint: disable=redefined-outer-name - - validate_link(source, self._node, link_type, link_label, backend=self._node.backend) - - # Check if the proposed link would introduce a cycle in the graph following ancestor/descendant rules - if link_type in [LinkType.CREATE, LinkType.INPUT_CALC, LinkType.INPUT_WORK]: - builder = QueryBuilder(backend=self._node.backend).append( - Node, filters={'id': self._node.pk}, tag='parent').append( - Node, filters={'id': source.pk}, tag='child', with_ancestors='parent') # yapf:disable - if builder.count() > 0: - raise ValueError('the link you are attempting to create would generate a cycle in the graph') - - def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: # pylint: disable=unused-argument - """Validate adding a link of the given type from ourself to a given node. - - The validity of the triple (source, link, target) should be validated in the `validate_incoming` call. - This method will be called afterwards and can be overriden by subclasses to add additional checks that are - specific to that subclass. - - :param target: the node to which the link is going - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `target` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - from .node import Node # pylint: disable=redefined-outer-name - type_check(link_type, LinkType, f'link_type should be a LinkType enum but got: {type(link_type)}') - type_check(target, Node, f'target should be a `Node` instance but got: {type(target)}') - - def get_stored_link_triples( - self, - node_class: Optional[t.Type['Node']] = None, - link_type: t.Union[LinkType, t.Sequence[LinkType]] = (), - link_label_filter: t.Optional[str] = None, - link_direction: str = 'incoming', - only_uuid: bool = False - ) -> list[LinkTriple]: - """Return the list of stored link triples directly incoming to or outgoing of this node. 
- - Note this will only return link triples that are stored in the database. Anything in the cache is ignored. - - :param node_class: If specified, should be a class, and it filters only elements of that (subclass of) type - :param link_type: Only get inputs of this link type, if empty tuple then returns all inputs of all link types. - :param link_label_filter: filters the incoming nodes by its link label. This should be a regex statement as - one would pass directly to a QueryBuilder filter statement with the 'like' operation. - :param link_direction: `incoming` or `outgoing` to get the incoming or outgoing links, respectively. - :param only_uuid: project only the node UUID instead of the instance onto the `NodeTriple.node` entries - """ - from .node import Node # pylint: disable=redefined-outer-name - - if not isinstance(link_type, (tuple, list)): - link_type = cast(t.Sequence[LinkType], (link_type,)) - - if link_type and not all(isinstance(t, LinkType) for t in link_type): - raise TypeError(f'link_type should be a LinkType or tuple of LinkType: got {link_type}') - - node_class = node_class or Node - node_filters: dict[str, t.Any] = {'id': {'==': self._node.pk}} - edge_filters: dict[str, t.Any] = {} - - if link_type: - edge_filters['type'] = {'in': [t.value for t in link_type]} - - if link_label_filter: - edge_filters['label'] = {'like': link_label_filter} - - builder = QueryBuilder(backend=self._node.backend) - builder.append(Node, filters=node_filters, tag='main') - - node_project = ['uuid'] if only_uuid else ['*'] - if link_direction == 'outgoing': - builder.append( - node_class, - with_incoming='main', - project=node_project, - edge_project=['type', 'label'], - edge_filters=edge_filters - ) - else: - builder.append( - node_class, - with_outgoing='main', - project=node_project, - edge_project=['type', 'label'], - edge_filters=edge_filters - ) - - return [LinkTriple(entry[0], LinkType(entry[1]), entry[2]) for entry in builder.all()] - - def get_incoming( - self, - node_class: Optional[t.Type['Node']] = None, - link_type: t.Union[LinkType, t.Sequence[LinkType]] = (), - link_label_filter: t.Optional[str] = None, - only_uuid: bool = False - ) -> LinkManager: - """Return a list of link triples that are (directly) incoming into this node. - - :param node_class: If specified, should be a class or tuple of classes, and it filters only - elements of that specific type (or a subclass of 'type') - :param link_type: If specified should be a string or tuple to get the inputs of this - link type, if None then returns all inputs of all link types. - :param link_label_filter: filters the incoming nodes by its link label. - Here wildcards (% and _) can be passed in link label filter as we are using "like" in QB. 
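# Illustrative usage sketch of `get_incoming` with a wildcard link-label filter,
# as described above; the pk and the label pattern are hypothetical.
from aiida.common.links import LinkType
from aiida.orm import load_node

node = load_node(1234)
incoming = node.base.links.get_incoming(link_type=LinkType.INPUT_CALC, link_label_filter='pseudo%')
for triple in incoming.all():
    print(triple.link_label, triple.node.uuid)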
- :param only_uuid: project only the node UUID instead of the instance onto the `NodeTriple.node` entries - """ - if not isinstance(link_type, (tuple, list)): - link_type = cast(t.Sequence[LinkType], (link_type,)) - - if self._node.is_stored: - link_triples = self.get_stored_link_triples( - node_class, link_type, link_label_filter, 'incoming', only_uuid=only_uuid - ) - else: - link_triples = [] - - # Get all cached link triples - for link_triple in self.incoming_cache: - - if only_uuid: - link_triple = LinkTriple( - link_triple.node.uuid, # type: ignore - link_triple.link_type, - link_triple.link_label, - ) - - if link_triple in link_triples: - raise exceptions.InternalError( - f'Node<{self._node.pk}> has both a stored and cached link triple {link_triple}' - ) - - if not link_type or link_triple.link_type in link_type: - if link_label_filter is not None: - if sql_string_match(string=link_triple.link_label, pattern=link_label_filter): - link_triples.append(link_triple) - else: - link_triples.append(link_triple) - - return LinkManager(link_triples) - - def get_outgoing( - self, - node_class: Optional[t.Type['Node']] = None, - link_type: t.Union[LinkType, t.Sequence[LinkType]] = (), - link_label_filter: t.Optional[str] = None, - only_uuid: bool = False - ) -> LinkManager: - """Return a list of link triples that are (directly) outgoing of this node. - - :param node_class: If specified, should be a class or tuple of classes, and it filters only - elements of that specific type (or a subclass of 'type') - :param link_type: If specified should be a string or tuple to get the inputs of this - link type, if None then returns all outputs of all link types. - :param link_label_filter: filters the outgoing nodes by its link label. - Here wildcards (% and _) can be passed in link label filter as we are using "like" in QB. - :param only_uuid: project only the node UUID instead of the instance onto the `NodeTriple.node` entries - """ - link_triples = self.get_stored_link_triples(node_class, link_type, link_label_filter, 'outgoing', only_uuid) - return LinkManager(link_triples) diff --git a/aiida/orm/nodes/node.py b/aiida/orm/nodes/node.py deleted file mode 100644 index b1ac17c631..0000000000 --- a/aiida/orm/nodes/node.py +++ /dev/null @@ -1,722 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-arguments -"""Package for node ORM classes.""" -import datetime -from functools import cached_property -from logging import Logger -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Iterator, List, Optional, Tuple, Type, TypeVar -from uuid import UUID - -from aiida.common import exceptions -from aiida.common.lang import classproperty, type_check -from aiida.common.links import LinkType -from aiida.common.warnings import warn_deprecation -from aiida.manage import get_manager -from aiida.orm.utils.node import AbstractNodeMeta - -from ..computers import Computer -from ..entities import Collection as EntityCollection -from ..entities import Entity, from_backend_entity -from ..extras import EntityExtras -from ..querybuilder import QueryBuilder -from ..users import User -from .attributes import NodeAttributes -from .caching import NodeCaching -from .comments import NodeComments -from .links import NodeLinks -from .repository import NodeRepository - -if TYPE_CHECKING: - from aiida.plugins.entry_point import EntryPoint # type: ignore - - from ..implementation import BackendNode, StorageBackend - -__all__ = ('Node',) - -NodeType = TypeVar('NodeType', bound='Node') # pylint: disable=invalid-name - - -class NodeCollection(EntityCollection[NodeType], Generic[NodeType]): - """The collection of nodes.""" - - @staticmethod - def _entity_base_cls() -> Type['Node']: # type: ignore - return Node - - def delete(self, pk: int) -> None: - """Delete a `Node` from the collection with the given id - - :param pk: the node id - """ - node = self.get(id=pk) - - if not node.is_stored: - return - - if node.base.links.get_incoming().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has incoming links') - - if node.base.links.get_outgoing().all(): - raise exceptions.InvalidOperation(f'cannot delete Node<{node.pk}> because it has outgoing links') - - self._backend.nodes.delete(pk) - - def iter_repo_keys(self, - filters: Optional[dict] = None, - subclassing: bool = True, - batch_size: int = 100) -> Iterator[str]: - """Iterate over all repository object keys for this ``Node`` class - - .. 
note:: keys will not be deduplicated, wrap in a ``set`` to achieve this - - :param filters: Filters for the node query - :param subclassing: Whether to include subclasses of the given class - :param batch_size: The number of nodes to fetch data for at once - """ - from aiida.repository import Repository - query = QueryBuilder(backend=self.backend) - query.append(self.entity_type, subclassing=subclassing, filters=filters, project=['repository_metadata']) - for metadata, in query.iterall(batch_size=batch_size): - for key in Repository.flatten(metadata).values(): - if key is not None: - yield key - - -class NodeBase: - """A namespace for node related functionality, that is not directly related to its user-facing properties.""" - - def __init__(self, node: 'Node') -> None: - """Construct a new instance of the base namespace.""" - self._node: 'Node' = node - - @cached_property - def repository(self) -> 'NodeRepository': - """Return the repository for this node.""" - return NodeRepository(self._node) - - @cached_property - def caching(self) -> 'NodeCaching': - """Return an interface to interact with the caching of this node.""" - return self._node._CLS_NODE_CACHING(self._node) # pylint: disable=protected-access - - @cached_property - def comments(self) -> 'NodeComments': - """Return an interface to interact with the comments of this node.""" - return NodeComments(self._node) - - @cached_property - def attributes(self) -> 'NodeAttributes': - """Return an interface to interact with the attributes of this node.""" - return NodeAttributes(self._node) - - @cached_property - def extras(self) -> 'EntityExtras': - """Return an interface to interact with the extras of this node.""" - return EntityExtras(self._node) - - @cached_property - def links(self) -> 'NodeLinks': - """Return an interface to interact with the links of this node.""" - return self._node._CLS_NODE_LINKS(self._node) # pylint: disable=protected-access - - -class Node(Entity['BackendNode', NodeCollection], metaclass=AbstractNodeMeta): - """ - Base class for all nodes in AiiDA. - - Stores attributes starting with an underscore. - - Caches files and attributes before the first save, and saves everything - only on store(). After the call to store(), attributes cannot be changed. - - Only after storing (or upon loading from uuid) extras can be modified - and in this case they are directly set on the db. - - In the plugin, also set the _plugin_type_string, to be set in the DB in - the 'type' field. - """ - # pylint: disable=too-many-public-methods - - _CLS_COLLECTION = NodeCollection - _CLS_NODE_LINKS = NodeLinks - _CLS_NODE_CACHING = NodeCaching - - # added by metaclass - _plugin_type_string: ClassVar[str] - _query_type_string: ClassVar[str] - - # This will be set by the metaclass call - _logger: Optional[Logger] = None - - # A tuple of attribute names that can be updated even after node is stored - # Requires Sealable mixin, but needs empty tuple for base class - _updatable_attributes: Tuple[str, ...] = tuple() - - # A tuple of attribute names that will be ignored when creating the hash. - _hash_ignored_attributes: Tuple[str, ...] = tuple() - - # Flag that determines whether the class can be cached. - _cachable = False - - # Flag that determines whether the class can be stored. 
- _storable = False - _unstorable_message = 'only Data, WorkflowNode, CalculationNode or their subclasses can be stored' - - def __init__( - self, - backend: Optional['StorageBackend'] = None, - user: Optional[User] = None, - computer: Optional[Computer] = None, - **kwargs: Any - ) -> None: - backend = backend or get_manager().get_profile_storage() - - if computer and not computer.is_stored: - raise ValueError('the computer is not stored') - - backend_computer = computer.backend_entity if computer else None - user = user if user else backend.default_user - - if user is None: - raise ValueError('the user cannot be None') - - backend_entity = backend.nodes.create( - node_type=self.class_node_type, user=user.backend_entity, computer=backend_computer, **kwargs - ) - super().__init__(backend_entity) - - @cached_property - def base(self) -> NodeBase: - """Return the node base namespace.""" - return NodeBase(self) - - def _check_mutability_attributes(self, keys: Optional[List[str]] = None) -> None: # pylint: disable=unused-argument - """Check if the entity is mutable and raise an exception if not. - - This is called from `NodeAttributes` methods that modify the attributes. - - :param keys: the keys that will be mutated, or all if None - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('the attributes of a stored entity are immutable') - - def __eq__(self, other: Any) -> bool: - """Fallback equality comparison by uuid (can be overwritten by specific types)""" - if isinstance(other, Node) and self.uuid == other.uuid: - return True - return super().__eq__(other) - - def __hash__(self) -> int: - """Python-Hash: Implementation that is compatible with __eq__""" - return UUID(self.uuid).int - - def __repr__(self) -> str: - return f'<{self.__class__.__name__}: {str(self)}>' - - def __str__(self) -> str: - if not self.is_stored: - return f'uuid: {self.uuid} (unstored)' - - return f'uuid: {self.uuid} (pk: {self.pk})' - - def __copy__(self): - """Copying a Node is not supported in general, but only for the Data sub class.""" - raise exceptions.InvalidOperation('copying a base Node is not supported') - - def __deepcopy__(self, memo): - """Deep copying a Node is not supported in general, but only for the Data sub class.""" - raise exceptions.InvalidOperation('deep copying a base Node is not supported') - - def _validate(self) -> bool: - """Validate information stored in Node object. - - For the :py:class:`~aiida.orm.Node` base class, this check is always valid. - Subclasses can override this method to perform additional checks - and should usually call ``super()._validate()`` first! - - This method is called automatically before storing the node in the DB. - Therefore, use :py:meth:`~aiida.orm.nodes.attributes.NodeAttributes.get()` and similar methods that - automatically read either from the DB or from the internal attribute cache. - """ - return True - - def _validate_storability(self) -> None: - """Verify that the current node is allowed to be stored. - - :raises `aiida.common.exceptions.StoringNotAllowed`: if the node does not match all requirements for storing - """ - from aiida.plugins.entry_point import is_registered_entry_point - - if not self._storable: - raise exceptions.StoringNotAllowed(self._unstorable_message) - - if not is_registered_entry_point(self.__module__, self.__class__.__name__, groups=('aiida.node', 'aiida.data')): - raise exceptions.StoringNotAllowed( - f'class `{self.__module__}:{self.__class__.__name__}` does not have a registered entry point. 
' - 'Check that the corresponding plugin is installed ' - 'and that the entry point shows up in `verdi plugin list`.' - ) - - @classproperty - def class_node_type(cls) -> str: - """Returns the node type of this node (sub) class.""" - # pylint: disable=no-self-argument,no-member - return cls._plugin_type_string - - @classproperty - def entry_point(cls) -> Optional['EntryPoint']: - """Return the entry point associated this node class. - - :return: the associated entry point or ``None`` if it isn't known. - """ - # pylint: disable=no-self-argument - from aiida.plugins.entry_point import get_entry_point_from_class - return get_entry_point_from_class(cls.__module__, cls.__name__)[1] - - @property - def logger(self) -> Optional[Logger]: - """Return the logger configured for this Node. - - :return: Logger object - """ - return self._logger - - @property - def uuid(self) -> str: - """Return the node UUID. - - :return: the string representation of the UUID - """ - return self.backend_entity.uuid - - @property - def node_type(self) -> str: - """Return the node type. - - :return: the node type - """ - return self.backend_entity.node_type - - @property - def process_type(self) -> Optional[str]: - """Return the node process type. - - :return: the process type - """ - return self.backend_entity.process_type - - @process_type.setter - def process_type(self, value: str) -> None: - """Set the node process type. - - :param value: the new value to set - """ - self.backend_entity.process_type = value - - @property - def label(self) -> str: - """Return the node label. - - :return: the label - """ - return self.backend_entity.label - - @label.setter - def label(self, value: str) -> None: - """Set the label. - - :param value: the new value to set - """ - self.backend_entity.label = value - - @property - def description(self) -> str: - """Return the node description. - - :return: the description - """ - return self.backend_entity.description - - @description.setter - def description(self, value: str) -> None: - """Set the description. - - :param value: the new value to set - """ - self.backend_entity.description = value - - @property - def computer(self) -> Optional[Computer]: - """Return the computer of this node.""" - if self.backend_entity.computer: - return from_backend_entity(Computer, self.backend_entity.computer) - - return None - - @computer.setter - def computer(self, computer: Optional[Computer]) -> None: - """Set the computer of this node. - - :param computer: a `Computer` - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('cannot set the computer on a stored node') - - type_check(computer, Computer, allow_none=True) - - self.backend_entity.computer = None if computer is None else computer.backend_entity - - @property - def user(self) -> User: - """Return the user of this node.""" - return from_backend_entity(User, self._backend_entity.user) - - @user.setter - def user(self, user: User) -> None: - """Set the user of this node. - - :param user: a `User` - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed('cannot set the user on a stored node') - - type_check(user, User) - self.backend_entity.user = user.backend_entity - - @property - def ctime(self) -> datetime.datetime: - """Return the node ctime. - - :return: the ctime - """ - return self.backend_entity.ctime - - @property - def mtime(self) -> datetime.datetime: - """Return the node mtime. 
- - :return: the mtime - """ - return self.backend_entity.mtime - - def store_all(self, with_transaction: bool = True) -> 'Node': - """Store the node, together with all input links. - - Unstored nodes from cached incoming linkswill also be stored. - - :parameter with_transaction: if False, do not use a transaction because the caller will already have opened one. - """ - if self.is_stored: - raise exceptions.ModificationNotAllowed(f'Node<{self.pk}> is already stored') - - # For each node of a cached incoming link, check that all its incoming links are stored - for link_triple in self.base.links.incoming_cache: - link_triple.node._verify_are_parents_stored() # pylint: disable=protected-access - - for link_triple in self.base.links.incoming_cache: - if not link_triple.node.is_stored: - link_triple.node.store(with_transaction=with_transaction) - - return self.store(with_transaction) - - def store(self, with_transaction: bool = True) -> 'Node': # pylint: disable=arguments-differ - """Store the node in the database while saving its attributes and repository directory. - - After being called attributes cannot be changed anymore! Instead, extras can be changed only AFTER calling - this store() function. - - :note: After successful storage, those links that are in the cache, and for which also the parent node is - already stored, will be automatically stored. The others will remain unstored. - - :parameter with_transaction: if False, do not use a transaction because the caller will already have opened one. - """ - from aiida.manage.caching import get_use_cache - - if not self.is_stored: - - # Call `_validate_storability` directly and not in `_validate` in case sub class forgets to call the super. - self._validate_storability() - self._validate() - - # Verify that parents are already stored. Raises if this is not the case. - self._verify_are_parents_stored() - - # Determine whether the cache should be used for the process type of this node. - use_cache = get_use_cache(identifier=self.process_type) - - # Clean the values on the backend node *before* computing the hash in `_get_same_node`. This will allow - # us to set `clean=False` if we are storing normally, since the values will already have been cleaned - self._backend_entity.clean_values() - - # Retrieve the cached node. - same_node = self.base.caching._get_same_node() if use_cache else None # pylint: disable=protected-access - - if same_node is not None: - self._store_from_cache(same_node, with_transaction=with_transaction) - else: - self._store(with_transaction=with_transaction, clean=True) - - if self.backend.autogroup.is_to_be_grouped(self): - group = self.backend.autogroup.get_or_create_group() - group.add_nodes(self) - - return self - - def _store(self, with_transaction: bool = True, clean: bool = True) -> 'Node': - """Store the node in the database while saving its attributes and repository directory. - - :param with_transaction: if False, do not use a transaction because the caller will already have opened one. - :param clean: boolean, if True, will clean the attributes and extras before attempting to store - """ - self.base.repository._store() # pylint: disable=protected-access - - links = self.base.links.incoming_cache - self._backend_entity.store(links, with_transaction=with_transaction, clean=clean) - - self.base.links.incoming_cache = [] - self.base.caching.rehash() - - return self - - def _verify_are_parents_stored(self) -> None: - """Verify that all `parent` nodes are already stored. 
- - :raise aiida.common.ModificationNotAllowed: if one of the source nodes of incoming links is not stored. - """ - for link_triple in self.base.links.incoming_cache: - if not link_triple.node.is_stored: - raise exceptions.ModificationNotAllowed( - f'Cannot store because source node of link triple {link_triple} is not stored' - ) - - def _store_from_cache(self, cache_node: 'Node', with_transaction: bool) -> None: - """Store this node from an existing cache node. - - .. note:: - - With the current implementation of the backend repository, which automatically deduplicates the content that - it contains, we do not have to copy the contents of the source node. Since the content should be exactly - equal, the repository will already contain it and there is nothing to copy. We simply replace the current - ``repository`` instance with a clone of that of the source node, which does not actually copy any files. - - """ - from aiida.orm.utils.mixins import Sealable - assert self.node_type == cache_node.node_type - - # Make sure the node doesn't have any RETURN links - if cache_node.base.links.get_outgoing(link_type=LinkType.RETURN).all(): - raise ValueError('Cannot use cache from nodes with RETURN links.') - - self.label = cache_node.label - self.description = cache_node.description - - # Make sure to reinitialize the repository instance of the clone to that of the source node. - self.base.repository._copy(cache_node.base.repository) # pylint: disable=protected-access - - for key, value in cache_node.base.attributes.all.items(): - if key != Sealable.SEALED_KEY: - self.base.attributes.set(key, value) - - self._store(with_transaction=with_transaction, clean=False) - self._add_outputs_from_cache(cache_node) - self.base.extras.set('_aiida_cached_from', cache_node.uuid) - - def _add_outputs_from_cache(self, cache_node: 'Node') -> None: - """Replicate the output links and nodes from the cached node onto this node.""" - for entry in cache_node.base.links.get_outgoing(link_type=LinkType.CREATE): - new_node = entry.node.clone() - new_node.base.links.add_incoming(self, link_type=LinkType.CREATE, link_label=entry.link_label) - new_node.store() - - def get_description(self) -> str: - """Return a string with a description of the node. - - :return: a description string - """ - return '' - - @property - def is_valid_cache(self) -> bool: - """Hook to exclude certain ``Node`` classes from being considered a valid cache. - - The base class assumes that all node instances are valid to cache from, unless the ``_VALID_CACHE_KEY`` extra - has been set to ``False`` explicitly. Subclasses can override this property with more specific logic, but should - probably also consider the value returned by this base class. - """ - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.is_valid_cache` is deprecated, use `{kls}.base.caching.is_valid_cache` instead.', - version=3, - stacklevel=2 - ) - return self.base.caching.is_valid_cache - - @is_valid_cache.setter - def is_valid_cache(self, valid: bool) -> None: - """Set whether this node instance is considered valid for caching or not. - - If a node instance has this property set to ``False``, it will never be used in the caching mechanism, unless - the subclass overrides the ``is_valid_cache`` property and ignores it implementation completely. - - :param valid: whether the node is valid or invalid for use in caching. 
- """ - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.is_valid_cache` is deprecated, use `{kls}.base.caching.is_valid_cache` instead.', - version=3, - stacklevel=2 - ) - self.base.caching.is_valid_cache = valid - - _deprecated_repo_methods = { - 'copy_tree': 'copy_tree', - 'delete_object': 'delete_object', - 'get_object': 'get_object', - 'get_object_content': 'get_object_content', - 'glob': 'glob', - 'list_objects': 'list_objects', - 'list_object_names': 'list_object_names', - 'open': 'open', - 'put_object_from_filelike': 'put_object_from_filelike', - 'put_object_from_file': 'put_object_from_file', - 'put_object_from_tree': 'put_object_from_tree', - 'walk': 'walk', - 'repository_metadata': 'metadata', - } - - _deprecated_attr_methods = { - 'attributes': 'all', - 'get_attribute': 'get', - 'get_attribute_many': 'get_many', - 'set_attribute': 'set', - 'set_attribute_many': 'set_many', - 'reset_attributes': 'reset', - 'delete_attribute': 'delete', - 'delete_attribute_many': 'delete_many', - 'clear_attributes': 'clear', - 'attributes_items': 'items', - 'attributes_keys': 'keys', - } - - _deprecated_extra_methods = { - 'extras': 'all', - 'get_extra': 'get', - 'get_extra_many': 'get_many', - 'set_extra': 'set', - 'set_extra_many': 'set_many', - 'reset_extras': 'reset', - 'delete_extra': 'delete', - 'delete_extra_many': 'delete_many', - 'clear_extras': 'clear', - 'extras_items': 'items', - 'extras_keys': 'keys', - } - - _deprecated_comment_methods = { - 'add_comment': 'add', - 'get_comment': 'get', - 'get_comments': 'all', - 'remove_comment': 'remove', - 'update_comment': 'update', - } - - _deprecated_caching_methods = { - 'get_hash': 'get_hash', - '_get_hash': '_get_hash', - '_get_objects_to_hash': '_get_objects_to_hash', - 'rehash': 'rehash', - 'clear_hash': 'clear_hash', - 'get_cache_source': 'get_cache_source', - 'is_created_from_cache': 'is_created_from_cache', - '_get_same_node': '_get_same_node', - 'get_all_same_nodes': 'get_all_same_nodes', - '_iter_all_same_nodes': '_iter_all_same_nodes', - } - - _deprecated_links_methods = { - 'add_incoming': 'add_incoming', - 'validate_incoming': 'validate_incoming', - 'validate_outgoing': 'validate_outgoing', - 'get_stored_link_triples': 'get_stored_link_triples', - 'get_incoming': 'get_incoming', - 'get_outgoing': 'get_outgoing', - } - - @classproperty - def Collection(cls): # pylint: disable=invalid-name - """Return the collection type for this class. - - This used to be a class argument with the value ``NodeCollection``. The argument is deprecated and this property - is here for backwards compatibility to print the deprecation warning. - """ - warn_deprecation( - 'This attribute is deprecated, use `aiida.orm.nodes.node.NodeCollection` instead.', version=3, stacklevel=2 - ) - return NodeCollection - - def __getattr__(self, name: str) -> Any: - """This method is called when an attribute is not found in the instance. - - It allows for the handling of deprecated mixin methods. 
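# Illustrative sketch of the deprecation forwarding set up above: the old flat
# accessors still work but emit a warning and map onto the `base` namespaces.
# The pk and the extra key are hypothetical.
from aiida.orm import load_node

node = load_node(1234)
node.set_extra('note', 'converged')        # deprecated spelling, emits a deprecation warning
node.base.extras.set('note', 'converged')  # equivalent new-style call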
- """ - if name in self._deprecated_extra_methods: - new_name = self._deprecated_extra_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.extras.{new_name}` instead.', version=3, stacklevel=3 - ) - return getattr(self.base.extras, new_name) - - if name in self._deprecated_attr_methods: - new_name = self._deprecated_attr_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.attributes.{new_name}` instead.', - version=3, - stacklevel=3 - ) - return getattr(self.base.attributes, new_name) - - if name in self._deprecated_repo_methods: - new_name = self._deprecated_repo_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.repository.{new_name}` instead.', - version=3, - stacklevel=3 - ) - return getattr(self.base.repository, new_name) - - if name in self._deprecated_comment_methods: - new_name = self._deprecated_comment_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.comments.{new_name}` instead.', version=3, stacklevel=3 - ) - return getattr(self.base.comments, new_name) - - if name in self._deprecated_caching_methods: - new_name = self._deprecated_caching_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.caching.{new_name}` instead.', version=3, stacklevel=3 - ) - return getattr(self.base.caching, new_name) - - if name in self._deprecated_links_methods: - new_name = self._deprecated_links_methods[name] - kls = self.__class__.__name__ - warn_deprecation( - f'`{kls}.{name}` is deprecated, use `{kls}.base.links.{new_name}` instead.', version=3, stacklevel=3 - ) - return getattr(self.base.links, new_name) - - raise AttributeError(name) diff --git a/aiida/orm/nodes/process/__init__.py b/aiida/orm/nodes/process/__init__.py deleted file mode 100644 index 283b14e9b0..0000000000 --- a/aiida/orm/nodes/process/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub classes for processes.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .calculation import * -from .process import * -from .workflow import * - -__all__ = ( - 'CalcFunctionNode', - 'CalcJobNode', - 'CalculationNode', - 'ProcessNode', - 'WorkChainNode', - 'WorkFunctionNode', - 'WorkflowNode', -) - -# yapf: enable diff --git a/aiida/orm/nodes/process/calculation/__init__.py b/aiida/orm/nodes/process/calculation/__init__.py deleted file mode 100644 index 21af4e576e..0000000000 --- a/aiida/orm/nodes/process/calculation/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub classes for calculation processes.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .calcfunction import * -from .calcjob import * -from .calculation import * - -__all__ = ( - 'CalcFunctionNode', - 'CalcJobNode', - 'CalculationNode', -) - -# yapf: enable diff --git a/aiida/orm/nodes/process/calculation/calcfunction.py b/aiida/orm/nodes/process/calculation/calcfunction.py deleted file mode 100644 index 818fec3d06..0000000000 --- a/aiida/orm/nodes/process/calculation/calcfunction.py +++ /dev/null @@ -1,53 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for calculation function processes.""" -from typing import TYPE_CHECKING - -from aiida.common.links import LinkType -from aiida.orm.utils.mixins import FunctionCalculationMixin - -from ..process import ProcessNodeLinks -from .calculation import CalculationNode - -if TYPE_CHECKING: - from aiida.orm import Node - -__all__ = ('CalcFunctionNode',) - - -class CalcFunctionNodeLinks(ProcessNodeLinks): - """Interface for links of a node instance.""" - - def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: - """Validate adding a link of the given type from ourself to a given node. - - A calcfunction cannot return Data, so if we receive an outgoing link to a stored Data node, that means - the user created a Data node within our function body and stored it themselves or they are returning an input - node. The latter use case is reserved for @workfunctions, as they can have RETURN links. - - :param target: the node to which the link is going - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `target` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - super().validate_outgoing(target, link_type, link_label) - if link_type is LinkType.CREATE and target.is_stored: - raise ValueError( - 'trying to return an already stored Data node from a @calcfunction, however, @calcfunctions cannot ' - 'return data. If you stored the node yourself, simply do not call `store()` yourself. If you want to ' - 'return an input node, use a @workfunction instead.' 
- ) - - -class CalcFunctionNode(FunctionCalculationMixin, CalculationNode): # type: ignore - """ORM class for all nodes representing the execution of a calcfunction.""" - - _CLS_NODE_LINKS = CalcFunctionNodeLinks diff --git a/aiida/orm/nodes/process/calculation/calcjob.py b/aiida/orm/nodes/process/calculation/calcjob.py deleted file mode 100644 index 8463c871f9..0000000000 --- a/aiida/orm/nodes/process/calculation/calcjob.py +++ /dev/null @@ -1,513 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for calculation job processes.""" -import datetime -from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Sequence, Tuple, Type, Union - -from aiida.common import exceptions -from aiida.common.datastructures import CalcJobState -from aiida.common.lang import classproperty -from aiida.common.links import LinkType - -from ..process import ProcessNodeCaching -from .calculation import CalculationNode - -if TYPE_CHECKING: - from aiida.engine.processes.builder import ProcessBuilder - from aiida.orm import FolderData - from aiida.orm.authinfos import AuthInfo - from aiida.orm.utils.calcjob import CalcJobResultManager - from aiida.parsers import Parser - from aiida.schedulers.datastructures import JobInfo, JobState - from aiida.tools.calculations import CalculationTools - from aiida.transports import Transport - -__all__ = ('CalcJobNode',) - - -class CalcJobNodeCaching(ProcessNodeCaching): - """Interface to control caching of a node instance.""" - - def _get_objects_to_hash(self) -> List[Any]: - """Return a list of objects which should be included in the hash. - - This method is purposefully overridden from the base `Node` class, because we do not want to include the - repository folder in the hash. The reason is that the hash of this node is computed in the `store` method, at - which point the input files that will be stored in the repository have not yet been generated. Including these - anyway in the computation of the hash would mean that the hash of the node would change as soon as the process - has started and the input files have been written to the repository. 
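# Illustrative sketch of the behaviour enforced by `CalcFunctionNodeLinks.validate_outgoing`
# above: a @calcfunction must not return a node it stored itself.
from aiida.engine import calcfunction
from aiida.orm import Int


@calcfunction
def add_one(x):
    return Int(x.value + 1)          # fine: the engine stores the result node


@calcfunction
def add_one_broken(x):
    return Int(x.value + 1).store()  # triggers the ValueError raised above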
- """ - from importlib import import_module - objects = [ - import_module(self._node.__module__.split('.', 1)[0]).__version__, - { - key: val - for key, val in self._node.base.attributes.items() - if key not in self._node._hash_ignored_attributes and key not in self._node._updatable_attributes # pylint: disable=unsupported-membership-test,protected-access - }, - self._node.computer.uuid if self._node.computer is not None else None, # pylint: disable=no-member - { - entry.link_label: entry.node.base.caching.get_hash() - for entry in self._node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)) - if entry.link_label not in self._hash_ignored_inputs - } - ] - return objects - - -class CalcJobNode(CalculationNode): - """ORM class for all nodes representing the execution of a CalcJob.""" - - # pylint: disable=too-many-public-methods - _CLS_NODE_CACHING = CalcJobNodeCaching - - CALC_JOB_STATE_KEY = 'state' - IMMIGRATED_KEY = 'imported' - REMOTE_WORKDIR_KEY = 'remote_workdir' - RETRIEVE_LIST_KEY = 'retrieve_list' - RETRIEVE_TEMPORARY_LIST_KEY = 'retrieve_temporary_list' - SCHEDULER_JOB_ID_KEY = 'job_id' - SCHEDULER_STATE_KEY = 'scheduler_state' - SCHEDULER_LAST_CHECK_TIME_KEY = 'scheduler_lastchecktime' - SCHEDULER_LAST_JOB_INFO_KEY = 'last_job_info' - SCHEDULER_DETAILED_JOB_INFO_KEY = 'detailed_job_info' - - # An optional entry point for a CalculationTools instance - _tools = None - - @property - def tools(self) -> 'CalculationTools': - """Return the calculation tools that are registered for the process type associated with this calculation. - - If the entry point name stored in the `process_type` of the CalcJobNode has an accompanying entry point in the - `aiida.tools.calculations` entry point category, it will attempt to load the entry point and instantiate it - passing the node to the constructor. If the entry point does not exist, cannot be resolved or loaded, a warning - will be logged and the base CalculationTools class will be instantiated and returned. 
- - :return: CalculationTools instance - """ - from aiida.plugins.entry_point import get_entry_point_from_string, is_valid_entry_point_string, load_entry_point - from aiida.tools.calculations import CalculationTools - - if self._tools is None: - entry_point_string = self.process_type - - if entry_point_string and is_valid_entry_point_string(entry_point_string): - entry_point = get_entry_point_from_string(entry_point_string) - - try: - tools_class = load_entry_point('aiida.tools.calculations', entry_point.name) - self._tools = tools_class(self) - except exceptions.EntryPointError as exception: - self._tools = CalculationTools(self) - self.logger.warning( - f'could not load the calculation tools entry point {entry_point.name}: {exception}' - ) - - return self._tools - - @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument - return super()._updatable_attributes + ( - cls.CALC_JOB_STATE_KEY, - cls.IMMIGRATED_KEY, - cls.REMOTE_WORKDIR_KEY, - cls.RETRIEVE_LIST_KEY, - cls.RETRIEVE_TEMPORARY_LIST_KEY, - cls.SCHEDULER_JOB_ID_KEY, - cls.SCHEDULER_STATE_KEY, - cls.SCHEDULER_LAST_CHECK_TIME_KEY, - cls.SCHEDULER_LAST_JOB_INFO_KEY, - cls.SCHEDULER_DETAILED_JOB_INFO_KEY, - ) - - @classproperty - def _hash_ignored_attributes(cls) -> Tuple[str, ...]: # pylint: disable=no-self-argument - return super()._hash_ignored_attributes + ( - 'queue_name', - 'account', - 'qos', - 'priority', - 'max_wallclock_seconds', - 'max_memory_kb', - ) - - @property - def is_imported(self) -> bool: - """Return whether the calculation job was imported instead of being an actual run.""" - return self.base.attributes.get(self.IMMIGRATED_KEY, None) is True - - def get_option(self, name: str) -> Optional[Any]: - """ - Retun the value of an option that was set for this CalcJobNode - - :param name: the option name - :return: the option value or None - :raises: ValueError for unknown option - """ - return self.base.attributes.get(name, None) - - def set_option(self, name: str, value: Any) -> None: - """ - Set an option to the given value - - :param name: the option name - :param value: the value to set - :raises: ValueError for unknown option - :raises: TypeError for values with invalid type - """ - self.base.attributes.set(name, value) - - def get_options(self) -> Dict[str, Any]: - """ - Return the dictionary of options set for this CalcJobNode - - :return: dictionary of the options and their values - """ - options = {} - for name in self.process_class.spec_options.keys(): # type: ignore[attr-defined] - value = self.get_option(name) - if value is not None: - options[name] = value - - return options - - def set_options(self, options: Dict[str, Any]) -> None: - """ - Set the options for this CalcJobNode - - :param options: dictionary of option and their values to set - """ - for name, value in options.items(): - self.set_option(name, value) - - def get_state(self) -> Optional[CalcJobState]: - """Return the calculation job active sub state. - - The calculation job state serves to give more granular state information to `CalcJobs`, in addition to the - generic process state, while the calculation job is active. The state can take values from the enumeration - defined in `aiida.common.datastructures.CalcJobState` and can be used to query for calculation jobs in specific - active states. 
- - :return: instance of `aiida.common.datastructures.CalcJobState` or `None` if invalid value, or not set - """ - state = self.base.attributes.get(self.CALC_JOB_STATE_KEY, None) - - try: - state = CalcJobState(state) - except ValueError: - state = None - - return state - - def set_state(self, state: CalcJobState) -> None: - """Set the calculation active job state. - - :raise: ValueError if state is invalid - """ - if not isinstance(state, CalcJobState): - raise ValueError(f'{state} is not a valid CalcJobState') - - self.base.attributes.set(self.CALC_JOB_STATE_KEY, state.value) - - def delete_state(self) -> None: - """Delete the calculation job state attribute if it exists.""" - try: - self.base.attributes.delete(self.CALC_JOB_STATE_KEY) - except AttributeError: - pass - - def set_remote_workdir(self, remote_workdir: str) -> None: - """Set the absolute path to the working directory on the remote computer where the calculation is run. - - :param remote_workdir: absolute filepath to the remote working directory - """ - self.base.attributes.set(self.REMOTE_WORKDIR_KEY, remote_workdir) - - def get_remote_workdir(self) -> Optional[str]: - """Return the path to the remote (on cluster) scratch folder of the calculation. - - :return: a string with the remote path - """ - return self.base.attributes.get(self.REMOTE_WORKDIR_KEY, None) - - @staticmethod - def _validate_retrieval_directive(directives: Sequence[Union[str, Tuple[str, str, str]]]) -> None: - """Validate a list or tuple of file retrieval directives. - - :param directives: a list or tuple of file retrieval directives - :raise ValueError: if the format of the directives is invalid - """ - if not isinstance(directives, (tuple, list)): - raise TypeError('file retrieval directives has to be a list or tuple') - - for directive in directives: - - # A string as a directive is valid, so we continue - if isinstance(directive, str): - continue - - # Otherwise, it has to be a tuple of length three with specific requirements - if not isinstance(directive, (tuple, list)) or len(directive) != 3: - raise ValueError(f'invalid directive, not a list or tuple of length three: {directive}') - - if not isinstance(directive[0], str): - raise ValueError('invalid directive, first element has to be a string representing remote path') - - if not isinstance(directive[1], str): - raise ValueError('invalid directive, second element has to be a string representing local path') - - if not isinstance(directive[2], int): - raise ValueError('invalid directive, three element has to be an integer representing the depth') - - def set_retrieve_list(self, retrieve_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None: - """Set the retrieve list. - - This list of directives will instruct the daemon what files to retrieve after the calculation has completed. - list or tuple of files or paths that should be retrieved by the daemon. - - :param retrieve_list: list or tuple of with filepath directives - """ - self._validate_retrieval_directive(retrieve_list) - self.base.attributes.set(self.RETRIEVE_LIST_KEY, retrieve_list) - - def get_retrieve_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]: - """Return the list of files/directories to be retrieved on the cluster after the calculation has completed. - - :return: a list of file directives - """ - return self.base.attributes.get(self.RETRIEVE_LIST_KEY, None) - - def set_retrieve_temporary_list(self, retrieve_temporary_list: Sequence[Union[str, Tuple[str, str, str]]]) -> None: - """Set the retrieve temporary list. 
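# Illustrative sketch of the retrieval directive format validated above: plain
# filenames or (remote glob, local path, depth) triples. Normally the engine
# sets this attribute; the unstored node below is only for illustration and
# assumes a loaded AiiDA profile.
from aiida.orm import CalcJobNode

calcjob = CalcJobNode()
calcjob.set_retrieve_list([
    'aiida.out',              # a plain filename
    ('out/*.xml', 'xml', 0),  # (remote glob, local path, depth)
])
print(calcjob.get_retrieve_list())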
- - The retrieve temporary list stores files that are retrieved after completion and made available during parsing - and are deleted as soon as the parsing has been completed. - - :param retrieve_temporary_list: list or tuple of with filepath directives - """ - self._validate_retrieval_directive(retrieve_temporary_list) - self.base.attributes.set(self.RETRIEVE_TEMPORARY_LIST_KEY, retrieve_temporary_list) - - def get_retrieve_temporary_list(self) -> Optional[Sequence[Union[str, Tuple[str, str, str]]]]: - """Return list of files to be retrieved from the cluster which will be available during parsing. - - :return: a list of file directives - """ - return self.base.attributes.get(self.RETRIEVE_TEMPORARY_LIST_KEY, None) - - def set_job_id(self, job_id: Union[int, str]) -> None: - """Set the job id that was assigned to the calculation by the scheduler. - - .. note:: the id will always be stored as a string - - :param job_id: the id assigned by the scheduler after submission - """ - return self.base.attributes.set(self.SCHEDULER_JOB_ID_KEY, str(job_id)) - - def get_job_id(self) -> Optional[str]: - """Return job id that was assigned to the calculation by the scheduler. - - :return: the string representation of the scheduler job id - """ - return self.base.attributes.get(self.SCHEDULER_JOB_ID_KEY, None) - - def set_scheduler_state(self, state: 'JobState') -> None: - """Set the scheduler state. - - :param state: an instance of `JobState` - """ - from aiida.common import timezone - from aiida.schedulers.datastructures import JobState - - if not isinstance(state, JobState): - raise ValueError(f'scheduler state should be an instance of JobState, got: {state}') - - self.base.attributes.set(self.SCHEDULER_STATE_KEY, state.value) - self.base.attributes.set(self.SCHEDULER_LAST_CHECK_TIME_KEY, timezone.now().isoformat()) - - def get_scheduler_state(self) -> Optional['JobState']: - """Return the status of the calculation according to the cluster scheduler. - - :return: a JobState enum instance. - """ - from aiida.schedulers.datastructures import JobState - - state = self.base.attributes.get(self.SCHEDULER_STATE_KEY, None) - - if state is None: - return state - - return JobState(state) - - def get_scheduler_lastchecktime(self) -> Optional[datetime.datetime]: - """Return the time of the last update of the scheduler state by the daemon or None if it was never set. - - :return: a datetime object or None - """ - value = self.base.attributes.get(self.SCHEDULER_LAST_CHECK_TIME_KEY, None) - - if value is not None: - value = datetime.datetime.fromisoformat(value) - - return value - - def set_detailed_job_info(self, detailed_job_info: Optional[dict]) -> None: - """Set the detailed job info dictionary. - - :param detailed_job_info: a dictionary with metadata with the accounting of a completed job - """ - self.base.attributes.set(self.SCHEDULER_DETAILED_JOB_INFO_KEY, detailed_job_info) - - def get_detailed_job_info(self) -> Optional[dict]: - """Return the detailed job info dictionary. - - The scheduler is polled for the detailed job info after the job is completed and ready to be retrieved. - - :return: the dictionary with detailed job info if defined or None - """ - return self.base.attributes.get(self.SCHEDULER_DETAILED_JOB_INFO_KEY, None) - - def set_last_job_info(self, last_job_info: 'JobInfo') -> None: - """Set the last job info. 
- - :param last_job_info: a `JobInfo` object - """ - self.base.attributes.set(self.SCHEDULER_LAST_JOB_INFO_KEY, last_job_info.get_dict()) - - def get_last_job_info(self) -> Optional['JobInfo']: - """Return the last information asked to the scheduler about the status of the job. - - The last job info is updated on every poll of the scheduler, except for the final poll when the job drops from - the scheduler's job queue. - For completed jobs, the last job info therefore contains the "second-to-last" job info that still shows the job - as running. Please use :meth:`~aiida.orm.nodes.process.calculation.calcjob.CalcJobNode.get_detailed_job_info` - instead. - - :return: a `JobInfo` object (that closely resembles a dictionary) or None. - """ - from aiida.schedulers.datastructures import JobInfo - - last_job_info_dictserialized = self.base.attributes.get(self.SCHEDULER_LAST_JOB_INFO_KEY, None) - - if last_job_info_dictserialized is not None: - job_info = JobInfo.load_from_dict(last_job_info_dictserialized) - else: - job_info = None - - return job_info - - def get_authinfo(self) -> 'AuthInfo': - """Return the `AuthInfo` that is configured for the `Computer` set for this node. - - :return: `AuthInfo` - """ - computer = self.computer - - if computer is None: - raise exceptions.NotExistent('No computer has been set for this calculation') - - return computer.get_authinfo(self.user) # pylint: disable=no-member - - def get_transport(self) -> 'Transport': - """Return the transport for this calculation. - - :return: `Transport` configured with the `AuthInfo` associated to the computer of this node - """ - return self.get_authinfo().get_transport() - - def get_parser_class(self) -> Optional[Type['Parser']]: - """Return the output parser object for this calculation or None if no parser is set. - - :return: a `Parser` class. - :raises `aiida.common.exceptions.EntryPointError`: if the parser entry point can not be resolved. - """ - from aiida.plugins import ParserFactory - - parser_name = self.get_option('parser_name') - - if parser_name is not None: - return ParserFactory(parser_name) - - return None - - @property - def link_label_retrieved(self) -> str: - """Return the link label used for the retrieved FolderData node.""" - return 'retrieved' - - def get_retrieved_node(self) -> Optional['FolderData']: - """Return the retrieved data folder. - - :return: the retrieved FolderData node or None if not found - """ - from aiida.orm import FolderData - try: - return self.base.links.get_outgoing(node_class=FolderData, - link_label_filter=self.link_label_retrieved).one().node - except ValueError: - return None - - @property - def res(self) -> 'CalcJobResultManager': - """ - To be used to get direct access to the parsed parameters. - - :return: an instance of the CalcJobResultManager. - - :note: a practical example on how it is meant to be used: let's say that there is a key 'energy' - in the dictionary of the parsed results which contains a list of floats. - The command `calc.res.energy` will return such a list. - """ - from aiida.orm.utils.calcjob import CalcJobResultManager - return CalcJobResultManager(self) - - def get_scheduler_stdout(self) -> Optional[AnyStr]: - """Return the scheduler stderr output if the calculation has finished and been retrieved, None otherwise. 
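Building on the `res` docstring above, a short illustrative snippet; the pk and the `energy` key are hypothetical:

```python
from aiida.orm import load_node

calc = load_node(1234)                      # hypothetical pk of a finished CalcJobNode
retrieved = calc.get_retrieved_node()       # FolderData linked with label 'retrieved', or None
if retrieved is not None:
    print(retrieved.base.repository.list_object_names())

# Shortcut to the parsed results; 'energy' is a hypothetical key that must exist
# in the default output dictionary of this particular calculation.
print(calc.res.energy)
```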
- - :return: scheduler stderr output or None - """ - filename = self.get_option('scheduler_stdout') - retrieved_node = self.get_retrieved_node() - - if filename is None or retrieved_node is None: - return None - - try: - stdout = retrieved_node.base.repository.get_object_content(filename) - except IOError: - stdout = None - - return stdout - - def get_scheduler_stderr(self) -> Optional[AnyStr]: - """Return the scheduler stdout output if the calculation has finished and been retrieved, None otherwise. - - :return: scheduler stdout output or None - """ - filename = self.get_option('scheduler_stderr') - retrieved_node = self.get_retrieved_node() - - if filename is None or retrieved_node is None: - return None - - try: - stderr = retrieved_node.base.repository.get_object_content(filename) - except IOError: - stderr = None - - return stderr - - def get_description(self) -> str: - """Return a description of the node based on its properties.""" - state = self.get_state() - if not state: - return '' - return state.value diff --git a/aiida/orm/nodes/process/calculation/calculation.py b/aiida/orm/nodes/process/calculation/calculation.py deleted file mode 100644 index 4dd8b9bf23..0000000000 --- a/aiida/orm/nodes/process/calculation/calculation.py +++ /dev/null @@ -1,47 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for calculation processes.""" - -from aiida.common.links import LinkType -from aiida.orm.utils.managers import NodeLinksManager - -from ..process import ProcessNode - -__all__ = ('CalculationNode',) - - -class CalculationNode(ProcessNode): - """Base class for all nodes representing the execution of a calculation process.""" - - _storable = True # Calculation nodes are storable - _cachable = True # Calculation nodes can be cached from - _unstorable_message = 'storing for this node has been disabled' - - @property - def inputs(self) -> NodeLinksManager: - """Return an instance of `NodeLinksManager` to manage incoming INPUT_CALC links - - The returned Manager allows you to easily explore the nodes connected to this node - via an incoming INPUT_CALC link. - The incoming nodes are reachable by their link labels which are attributes of the manager. - - """ - return NodeLinksManager(node=self, link_type=LinkType.INPUT_CALC, incoming=True) - - @property - def outputs(self) -> NodeLinksManager: - """Return an instance of `NodeLinksManager` to manage outgoing CREATE links - - The returned Manager allows you to easily explore the nodes connected to this node - via an outgoing CREATE link. - The outgoing nodes are reachable by their link labels which are attributes of the manager. - - """ - return NodeLinksManager(node=self, link_type=LinkType.CREATE, incoming=False) diff --git a/aiida/orm/nodes/process/process.py b/aiida/orm/nodes/process/process.py deleted file mode 100644 index 10d55cdfa6..0000000000 --- a/aiida/orm/nodes/process/process.py +++ /dev/null @@ -1,584 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. 
All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for processes.""" -import enum -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union - -from plumpy.process_states import ProcessState - -from aiida.common import exceptions -from aiida.common.lang import classproperty -from aiida.common.links import LinkType -from aiida.orm.utils.mixins import Sealable - -from ..caching import NodeCaching -from ..node import Node, NodeLinks - -if TYPE_CHECKING: - from aiida.engine.processes import ExitCode, Process - from aiida.engine.processes.builder import ProcessBuilder - -__all__ = ('ProcessNode',) - - -class ProcessNodeCaching(NodeCaching): - """Interface to control caching of a node instance.""" - - # The link_type might not be correct while the object is being created. - _hash_ignored_inputs = ['CALL_CALC', 'CALL_WORK'] - - @property - def is_valid_cache(self) -> bool: - """Return whether the node is valid for caching - - :returns: True if this process node is valid to be used for caching, False otherwise - """ - if not (super().is_valid_cache and self._node.is_finished): - return False - - try: - process_class = self._node.process_class - except ValueError as exc: - self._node.logger.warning( - f"Not considering {self} for caching, '{exc!r}' when accessing its process class." - ) - return False - - # For process functions, the `process_class` does not have an is_valid_cache attribute - try: - is_valid_cache_func = process_class.is_valid_cache - except AttributeError: - return True - - return is_valid_cache_func(self._node) - - @is_valid_cache.setter - def is_valid_cache(self, valid: bool) -> None: - """Set whether this node instance is considered valid for caching or not. - - :param valid: whether the node is valid or invalid for use in caching. - """ - super(ProcessNodeCaching, self.__class__).is_valid_cache.fset(self, valid) - - def _get_objects_to_hash(self) -> List[Any]: - """ - Return a list of objects which should be included in the hash. - """ - res = super()._get_objects_to_hash() # pylint: disable=protected-access - res.append({ - entry.link_label: entry.node.base.caching.get_hash() - for entry in self._node.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)) - if entry.link_label not in self._hash_ignored_inputs - }) - return res - - -class ProcessNodeLinks(NodeLinks): - """Interface for links of a node instance.""" - - def validate_incoming(self, source: Node, link_type: LinkType, link_label: str) -> None: - """Validate adding a link of the given type from a given node to ourself. - - Adding an input link to a `ProcessNode` once it is stored is illegal because this should be taken care of - by the engine in one go. If a link is being added after the node is stored, it is most likely not by the engine - and it should not be allowed. 
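The caching check above delegates to `process_class.is_valid_cache` when the process class defines it. A hedged sketch of how a plugin might hook into that; the class and the exit-status criterion are illustrative, not part of this diff:

```python
from aiida.engine import CalcJob

class MyCalcJob(CalcJob):
    """Hypothetical plugin that refuses to serve as a cache source when it failed."""

    @classmethod
    def is_valid_cache(cls, node):
        # Consulted through ProcessNodeCaching.is_valid_cache for finished nodes.
        return super().is_valid_cache(node) and node.exit_status == 0
```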
- - :param source: the node from which the link is coming - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `source` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - if self._node.is_sealed: - raise exceptions.ModificationNotAllowed('Cannot add a link to a sealed node') - - if self._node.is_stored: - raise ValueError('attempted to add an input link after the process node was already stored.') - - super().validate_incoming(source, link_type, link_label) - - def validate_outgoing(self, target, link_type, link_label): - """Validate adding a link of the given type from ourself to a given node. - - Adding an outgoing link from a sealed node is forbidden. - - :param target: the node to which the link is going - :param link_type: the link type - :param link_label: the link label - :raise aiida.common.ModificationNotAllowed: if the source node (self) is sealed - """ - if self._node.is_sealed: - raise exceptions.ModificationNotAllowed('Cannot add a link from a sealed node') - - super().validate_outgoing(target, link_type=link_type, link_label=link_label) - - -class ProcessNode(Sealable, Node): - """ - Base class for all nodes representing the execution of a process - - This class and its subclasses serve as proxies in the database, for actual `Process` instances being run. The - `Process` instance in memory will leverage an instance of this class (the exact sub class depends on the sub class - of `Process`) to persist important information of its state to the database. This serves as a way for the user to - inspect the state of the `Process` during its execution as well as a permanent record of its execution in the - provenance graph, after the execution has terminated. - """ - # pylint: disable=too-many-public-methods,abstract-method - - _CLS_NODE_LINKS = ProcessNodeLinks - _CLS_NODE_CACHING = ProcessNodeCaching - - CHECKPOINT_KEY = 'checkpoints' - EXCEPTION_KEY = 'exception' - EXIT_MESSAGE_KEY = 'exit_message' - EXIT_STATUS_KEY = 'exit_status' - PROCESS_PAUSED_KEY = 'paused' - PROCESS_LABEL_KEY = 'process_label' - PROCESS_STATE_KEY = 'process_state' - PROCESS_STATUS_KEY = 'process_status' - METADATA_INPUTS_KEY: str = 'metadata_inputs' - - _unstorable_message = 'only Data, WorkflowNode, CalculationNode or their subclasses can be stored' - - def __str__(self) -> str: - base = super().__str__() - if self.process_type: - return f'{base} ({self.process_type})' - - return f'{base}' - - @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: - # pylint: disable=no-self-argument - return super()._updatable_attributes + ( - cls.PROCESS_PAUSED_KEY, - cls.CHECKPOINT_KEY, - cls.EXCEPTION_KEY, - cls.EXIT_MESSAGE_KEY, - cls.EXIT_STATUS_KEY, - cls.PROCESS_LABEL_KEY, - cls.PROCESS_STATE_KEY, - cls.PROCESS_STATUS_KEY, - ) - - def set_metadata_inputs(self, value: Dict[str, Any]) -> None: - """Set the mapping of inputs corresponding to ``metadata`` ports that were passed to the process.""" - return self.base.attributes.set(self.METADATA_INPUTS_KEY, value) - - def get_metadata_inputs(self) -> Optional[Dict[str, Any]]: - """Return the mapping of inputs corresponding to ``metadata`` ports that were passed to the process.""" - return self.base.attributes.get(self.METADATA_INPUTS_KEY, None) - - @property - def logger(self): - """ - Get the logger of the Calculation object, so that it also logs to the DB. 
- - :return: LoggerAdapter object, that works like a logger, but also has the 'extra' embedded - """ - from aiida.orm.utils.log import create_logger_adapter - return create_logger_adapter(self._logger, self) - - def get_builder_restart(self) -> 'ProcessBuilder': - """Return a `ProcessBuilder` that is ready to relaunch the process that created this node. - - The process class will be set based on the `process_type` of this node and the inputs of the builder will be - prepopulated with the inputs registered for this node. This functionality is very useful if a process has - completed and you want to relaunch it with slightly different inputs. - - :return: `~aiida.engine.processes.builder.ProcessBuilder` instance - """ - builder = self.process_class.get_builder() - builder._update(self.base.links.get_incoming(link_type=(LinkType.INPUT_CALC, LinkType.INPUT_WORK)).nested()) # pylint: disable=protected-access - builder._merge(self.get_metadata_inputs() or {}) # pylint: disable=protected-access - - return builder - - @property - def process_class(self) -> Type['Process']: - """Return the process class that was used to create this node. - - :return: `Process` class - :raises ValueError: if no process type is defined, it is an invalid process type string or cannot be resolved - to load the corresponding class - """ - from aiida.plugins.entry_point import load_entry_point_from_string - - if not self.process_type: - raise ValueError(f'no process type for Node<{self.pk}>: cannot recreate process class') - - try: - process_class = load_entry_point_from_string(self.process_type) - except exceptions.EntryPointError as exception: - raise ValueError( - f'could not load process class for entry point `{self.process_type}` for Node<{self.pk}>: {exception}' - ) from exception - except ValueError as exception: - import importlib - - def str_rsplit_iter(string, sep='.'): - components = string.split(sep) - for idx in range(1, len(components)): - yield sep.join(components[:-idx]), components[-idx:] - - for module_name, class_names in str_rsplit_iter(self.process_type): - try: - module = importlib.import_module(module_name) - process_class = module - for objname in class_names: - process_class = getattr(process_class, objname) - break - except (AttributeError, ValueError, ImportError): - pass - else: - raise ValueError( - f'could not load process class from `{self.process_type}` for Node<{self.pk}>' - ) from exception - - return process_class - - def set_process_type(self, process_type_string: str) -> None: - """ - Set the process type string. - - :param process_type: the process type string identifying the class using this process node as storage. 
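The restart-builder pattern described in `get_builder_restart` above, in practice; the pk and the commented-out tweak are placeholders:

```python
from aiida.engine import submit
from aiida.orm import load_node

node = load_node(1234)                # hypothetical pk of a terminated process
builder = node.get_builder_restart()  # inputs pre-populated from the stored provenance
# Adjust whatever should differ from the original run before resubmitting, e.g.:
# builder.metadata.label = 'restarted run'
submit(builder)
```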
- """ - self.process_type = process_type_string - - @property - def process_label(self) -> Optional[str]: - """ - Return the process label - - :returns: the process label - """ - return self.base.attributes.get(self.PROCESS_LABEL_KEY, None) - - def set_process_label(self, label: str) -> None: - """ - Set the process label - - :param label: process label string - """ - self.base.attributes.set(self.PROCESS_LABEL_KEY, label) - - @property - def process_state(self) -> Optional[ProcessState]: - """ - Return the process state - - :returns: the process state instance of ProcessState enum - """ - state = self.base.attributes.get(self.PROCESS_STATE_KEY, None) - - if state is None: - return state - - return ProcessState(state) - - def set_process_state(self, state: Union[str, ProcessState, None]): - """ - Set the process state - - :param state: value or instance of ProcessState enum - """ - if isinstance(state, ProcessState): - state = state.value - return self.base.attributes.set(self.PROCESS_STATE_KEY, state) - - @property - def process_status(self) -> Optional[str]: - """ - Return the process status - - The process status is a generic status message e.g. the reason it might be paused or when it is being killed - - :returns: the process status - """ - return self.base.attributes.get(self.PROCESS_STATUS_KEY, None) - - def set_process_status(self, status: Optional[str]) -> None: - """ - Set the process status - - The process status is a generic status message e.g. the reason it might be paused or when it is being killed. - If status is None, the corresponding attribute will be deleted. - - :param status: string process status - """ - if status is None: - try: - self.base.attributes.delete(self.PROCESS_STATUS_KEY) - except AttributeError: - pass - return None - - if not isinstance(status, str): - raise TypeError('process status should be a string') - - return self.base.attributes.set(self.PROCESS_STATUS_KEY, status) - - @property - def is_terminated(self) -> bool: - """ - Return whether the process has terminated - - Terminated means that the process has reached any terminal state. - - :return: True if the process has terminated, False otherwise - :rtype: bool - """ - return self.is_excepted or self.is_finished or self.is_killed - - @property - def is_excepted(self) -> bool: - """ - Return whether the process has excepted - - Excepted means that during execution of the process, an exception was raised that was not caught. - - :return: True if during execution of the process an exception occurred, False otherwise - :rtype: bool - """ - return self.process_state == ProcessState.EXCEPTED - - @property - def is_killed(self) -> bool: - """ - Return whether the process was killed - - Killed means the process was killed directly by the user or by the calling process being killed. - - :return: True if the process was killed, False otherwise - :rtype: bool - """ - return self.process_state == ProcessState.KILLED - - @property - def is_finished(self) -> bool: - """ - Return whether the process has finished - - Finished means that the process reached a terminal state nominally. - Note that this does not necessarily mean successfully, but there were no exceptions and it was not killed. 
- - :return: True if the process has finished, False otherwise - :rtype: bool - """ - return self.process_state == ProcessState.FINISHED - - @property - def is_finished_ok(self) -> bool: - """ - Return whether the process has finished successfully - - Finished successfully means that it terminated nominally and had a zero exit status. - - :return: True if the process has finished successfully, False otherwise - :rtype: bool - """ - return self.is_finished and self.exit_status == 0 - - @property - def is_failed(self) -> bool: - """ - Return whether the process has failed - - Failed means that the process terminated nominally but it had a non-zero exit status. - - :return: True if the process has failed, False otherwise - :rtype: bool - """ - return self.is_finished and self.exit_status != 0 - - @property - def exit_code(self) -> Optional['ExitCode']: - """Return the exit code of the process. - - It is reconstituted from the ``exit_status`` and ``exit_message`` attributes if both of those are defined. - - :returns: The exit code if defined, or ``None``. - """ - from aiida.engine.processes.exit_code import ExitCode - - exit_status = self.exit_status - exit_message = self.exit_message - - if exit_status is None or exit_message is None: - return None - - return ExitCode(exit_status, exit_message) - - @property - def exit_status(self) -> Optional[int]: - """ - Return the exit status of the process - - :returns: the exit status, an integer exit code or None - """ - return self.base.attributes.get(self.EXIT_STATUS_KEY, None) - - def set_exit_status(self, status: Union[None, enum.Enum, int]) -> None: - """ - Set the exit status of the process - - :param state: an integer exit code or None, which will be interpreted as zero - """ - if status is None: - status = 0 - - if isinstance(status, enum.Enum): - status = status.value - - if not isinstance(status, int): - raise ValueError(f'exit status has to be an integer, got {status}') - - return self.base.attributes.set(self.EXIT_STATUS_KEY, status) - - @property - def exit_message(self) -> Optional[str]: - """ - Return the exit message of the process - - :returns: the exit message - """ - return self.base.attributes.get(self.EXIT_MESSAGE_KEY, None) - - def set_exit_message(self, message: Optional[str]) -> None: - """ - Set the exit message of the process, if None nothing will be done - - :param message: a string message - """ - if message is None: - return None - - if not isinstance(message, str): - raise ValueError(f'exit message has to be a string type, got {type(message)}') - - return self.base.attributes.set(self.EXIT_MESSAGE_KEY, message) - - @property - def exception(self) -> Optional[str]: - """ - Return the exception of the process or None if the process is not excepted. - - If the process is marked as excepted yet there is no exception attribute, an empty string will be returned. 
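As the `exit_code` property notes, a full `ExitCode` is only reconstituted when both the status and the message are set; continuing with the hypothetical node from the previous snippet:

```python
exit_code = node.exit_code
if exit_code is None:
    # Either the process has not terminated yet, or exit_status/exit_message was never set.
    print('no exit code recorded')
else:
    print(exit_code.status, exit_code.message)
```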
- - :returns: the exception message or None - """ - if self.is_excepted: - return self.base.attributes.get(self.EXCEPTION_KEY, '') - - return None - - def set_exception(self, exception: str) -> None: - """ - Set the exception of the process - - :param exception: the exception message - """ - if not isinstance(exception, str): - raise ValueError(f'exception message has to be a string type, got {type(exception)}') - - return self.base.attributes.set(self.EXCEPTION_KEY, exception) - - @property - def checkpoint(self) -> Optional[str]: - """ - Return the checkpoint bundle set for the process - - :returns: checkpoint bundle if it exists, None otherwise - """ - return self.base.attributes.get(self.CHECKPOINT_KEY, None) - - def set_checkpoint(self, checkpoint: str) -> None: - """ - Set the checkpoint bundle set for the process - - :param state: string representation of the stepper state info - """ - return self.base.attributes.set(self.CHECKPOINT_KEY, checkpoint) - - def delete_checkpoint(self) -> None: - """ - Delete the checkpoint bundle set for the process - """ - try: - self.base.attributes.delete(self.CHECKPOINT_KEY) - except AttributeError: - pass - - @property - def paused(self) -> bool: - """ - Return whether the process is paused - - :returns: True if the Calculation is marked as paused, False otherwise - """ - return self.base.attributes.get(self.PROCESS_PAUSED_KEY, False) - - def pause(self) -> None: - """ - Mark the process as paused by setting the corresponding attribute. - - This serves only to reflect that the corresponding Process is paused and so this method should not be called - by anyone but the Process instance itself. - """ - return self.base.attributes.set(self.PROCESS_PAUSED_KEY, True) - - def unpause(self) -> None: - """ - Mark the process as unpaused by removing the corresponding attribute. - - This serves only to reflect that the corresponding Process is unpaused and so this method should not be called - by anyone but the Process instance itself. - """ - try: - self.base.attributes.delete(self.PROCESS_PAUSED_KEY) - except AttributeError: - pass - - @property - def called(self) -> List['ProcessNode']: - """ - Return a list of nodes that the process called - - :returns: list of process nodes called by this process - """ - return self.base.links.get_outgoing(link_type=(LinkType.CALL_CALC, LinkType.CALL_WORK)).all_nodes() - - @property - def called_descendants(self) -> List['ProcessNode']: - """ - Return a list of all nodes that have been called downstream of this process - - This will recursively find all the called processes for this process and its children. 
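A hedged illustration of walking the call hierarchy with `called` and `called_descendants` (the latter is implemented just below); the pk is a placeholder:

```python
from aiida.orm import load_node

workchain = load_node(5678)             # hypothetical pk of a WorkChainNode
for child in workchain.called:          # direct CALL_CALC / CALL_WORK children
    print(child.process_label, child.process_state)

# Recursively collected descendants of this process and its children.
print(len(workchain.called_descendants), 'descendant processes in total')
```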
- """ - descendants = [] - - for descendant in self.called: - descendants.append(descendant) - descendants.extend(descendant.called_descendants) - - return descendants - - @property - def caller(self) -> Optional['ProcessNode']: - """ - Return the process node that called this process node, or None if it does not have a caller - - :returns: process node that called this process node instance or None - """ - try: - caller = self.base.links.get_incoming(link_type=(LinkType.CALL_CALC, LinkType.CALL_WORK)).one().node - except ValueError: - return None - return caller diff --git a/aiida/orm/nodes/process/workflow/__init__.py b/aiida/orm/nodes/process/workflow/__init__.py deleted file mode 100644 index f4125a4f8f..0000000000 --- a/aiida/orm/nodes/process/workflow/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub classes for workflow processes.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .workchain import * -from .workflow import * -from .workfunction import * - -__all__ = ( - 'WorkChainNode', - 'WorkFunctionNode', - 'WorkflowNode', -) - -# yapf: enable diff --git a/aiida/orm/nodes/process/workflow/workchain.py b/aiida/orm/nodes/process/workflow/workchain.py deleted file mode 100644 index 0a673431c1..0000000000 --- a/aiida/orm/nodes/process/workflow/workchain.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for workchain processes.""" -from typing import Optional, Tuple - -from aiida.common.lang import classproperty - -from .workflow import WorkflowNode - -__all__ = ('WorkChainNode',) - - -class WorkChainNode(WorkflowNode): - """ORM class for all nodes representing the execution of a WorkChain.""" - - STEPPER_STATE_INFO_KEY = 'stepper_state_info' - - @classproperty - def _updatable_attributes(cls) -> Tuple[str, ...]: # type: ignore - # pylint: disable=no-self-argument - return super()._updatable_attributes + (cls.STEPPER_STATE_INFO_KEY,) - - @property - def stepper_state_info(self) -> Optional[str]: - """ - Return the stepper state info - - :returns: string representation of the stepper state info - """ - return self.base.attributes.get(self.STEPPER_STATE_INFO_KEY, None) - - def set_stepper_state_info(self, stepper_state_info: str) -> None: - """ - Set the stepper state info - - :param state: string representation of the stepper state info - """ - return self.base.attributes.set(self.STEPPER_STATE_INFO_KEY, stepper_state_info) diff --git a/aiida/orm/nodes/process/workflow/workflow.py b/aiida/orm/nodes/process/workflow/workflow.py deleted file mode 100644 index bef3bee3d3..0000000000 --- a/aiida/orm/nodes/process/workflow/workflow.py +++ /dev/null @@ -1,80 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for workflow processes.""" -from typing import TYPE_CHECKING - -from aiida.common.links import LinkType -from aiida.orm.utils.managers import NodeLinksManager - -from ..process import ProcessNode, ProcessNodeLinks - -if TYPE_CHECKING: - from aiida.orm import Node - -__all__ = ('WorkflowNode',) - - -class WorkflowNodeLinks(ProcessNodeLinks): - """Interface for links of a node instance.""" - - def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: - """Validate adding a link of the given type from ourself to a given node. - - A workflow cannot 'create' Data, so if we receive an outgoing link to an unstored Data node, that means - the user created a Data node within our function body and tries to attach it as an output. This is strictly - forbidden and can cause provenance to be lost. - - :param target: the node to which the link is going - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `target` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - super().validate_outgoing(target, link_type, link_label) - if link_type is LinkType.RETURN and not target.is_stored: - raise ValueError( - 'Workflow<{}> tried returning an unstored `Data` node. This likely means new `Data` is being created ' - 'inside the workflow. 
In order to preserve data provenance, use a `calcfunction` to create this node ' - 'and return its output from the workflow'.format(self._node.process_label) - ) - - -class WorkflowNode(ProcessNode): - """Base class for all nodes representing the execution of a workflow process.""" - - _CLS_NODE_LINKS = WorkflowNodeLinks - - # Workflow nodes are storable - _storable = True - _unstorable_message = 'storing for this node has been disabled' - - @property - def inputs(self) -> NodeLinksManager: - """Return an instance of `NodeLinksManager` to manage incoming INPUT_WORK links - - The returned Manager allows you to easily explore the nodes connected to this node - via an incoming INPUT_WORK link. - The incoming nodes are reachable by their link labels which are attributes of the manager. - - :return: `NodeLinksManager` - """ - return NodeLinksManager(node=self, link_type=LinkType.INPUT_WORK, incoming=True) - - @property - def outputs(self) -> NodeLinksManager: - """Return an instance of `NodeLinksManager` to manage outgoing RETURN links - - The returned Manager allows you to easily explore the nodes connected to this node - via an outgoing RETURN link. - The outgoing nodes are reachable by their link labels which are attributes of the manager. - - :return: `NodeLinksManager` - """ - return NodeLinksManager(node=self, link_type=LinkType.RETURN, incoming=False) diff --git a/aiida/orm/nodes/process/workflow/workfunction.py b/aiida/orm/nodes/process/workflow/workfunction.py deleted file mode 100644 index 73d37c0ab2..0000000000 --- a/aiida/orm/nodes/process/workflow/workfunction.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with `Node` sub class for workflow function processes.""" -from typing import TYPE_CHECKING - -from aiida.common.links import LinkType -from aiida.orm.utils.mixins import FunctionCalculationMixin - -from .workflow import WorkflowNode, WorkflowNodeLinks - -if TYPE_CHECKING: - from aiida.orm import Node - -__all__ = ('WorkFunctionNode',) - - -class WorkFunctionNodeLinks(WorkflowNodeLinks): - """Interface for links of a node instance.""" - - def validate_outgoing(self, target: 'Node', link_type: LinkType, link_label: str) -> None: - """Validate adding a link of the given type from ourself to a given node. - - A workfunction cannot create Data, so if we receive an outgoing RETURN link to an unstored Data node, that means - the user created a Data node within our function body and is trying to return it. This use case should be - reserved for @calcfunctions, as they can have CREATE links. 
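The rule enforced above, seen from the user side: a `@workfunction` may only return data that already carries provenance. A small sketch of the correct pattern; the function names are illustrative:

```python
from aiida.engine import calcfunction, workfunction
from aiida.orm import Int

@calcfunction
def add_one(x):
    return Int(x.value + 1)             # created here, so it receives a CREATE link

@workfunction
def wrapper(x):
    # Returning the calcfunction's output is allowed: the node is stored with provenance.
    # Returning `Int(x.value + 1)` directly here would trigger the ValueError above.
    return add_one(x)

result = wrapper(Int(41))
print(result)                           # Int(42) with full provenance
```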
- - :param target: the node to which the link is going - :param link_type: the link type - :param link_label: the link label - :raise TypeError: if `target` is not a Node instance or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - super().validate_outgoing(target, link_type, link_label) - if link_type is LinkType.RETURN and not target.is_stored: - raise ValueError( - 'trying to return an unstored Data node from a @workfunction, however, @workfunctions cannot create ' - 'data. You probably want to use a @calcfunction instead.' - ) - - -class WorkFunctionNode(FunctionCalculationMixin, WorkflowNode): # type: ignore - """ORM class for all nodes representing the execution of a workfunction.""" - - _CLS_NODE_LINKS = WorkFunctionNodeLinks diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py deleted file mode 100644 index ccc814e20b..0000000000 --- a/aiida/orm/nodes/repository.py +++ /dev/null @@ -1,324 +0,0 @@ -# -*- coding: utf-8 -*- -"""Interface to the file repository of a node instance.""" -import contextlib -import copy -import io -import pathlib -import tempfile -from typing import TYPE_CHECKING, Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, TextIO, Tuple, Union - -from aiida.common import exceptions -from aiida.manage import get_config_option -from aiida.repository import File, Repository -from aiida.repository.backend import SandboxRepositoryBackend - -if TYPE_CHECKING: - from .node import Node - -__all__ = ('NodeRepository',) - -FilePath = Union[str, pathlib.PurePosixPath] - - -class NodeRepository: - """Interface to the file repository of a node instance. - - This is the compatibility layer between the `Node` class and the `Repository` class. The repository in principle has - no concept of immutability, so it is implemented here. Any mutating operations will raise a `ModificationNotAllowed` - exception if the node is stored. Otherwise the operation is just forwarded to the repository instance. - - The repository instance keeps an internal mapping of the file hierarchy that it maintains, starting from an empty - hierarchy if the instance was constructed normally, or from a specific hierarchy if reconstructred through the - ``Repository.from_serialized`` classmethod. This is only the case for stored nodes, because unstored nodes do not - have any files yet when they are constructed. Once the node get's stored, the repository is asked to serialize its - metadata contents which is then stored in the ``repository_metadata`` field of the backend node. - This layer explicitly does not update the metadata of the node on a mutation action. - The reason is that for stored nodes these actions are anyway forbidden and for unstored nodes, - the final metadata will be stored in one go, once the node is stored, - so there is no need to keep updating the node metadata intermediately. - Note that this does mean that ``repository_metadata`` does not give accurate information, - as long as the node is not yet stored. - """ - - def __init__(self, node: 'Node') -> None: - """Construct a new instance of the repository interface.""" - self._node: 'Node' = node - self._repository_instance: Optional[Repository] = None - - @property - def metadata(self) -> Dict[str, Any]: - """Return the repository metadata, representing the virtual file hierarchy. - - Note, this is only accurate if the node is stored. 
- - :return: the repository metadata - """ - return self._node.backend_entity.repository_metadata - - def _update_repository_metadata(self): - """Refresh the repository metadata of the node if it is stored.""" - if self._node.is_stored: - self._node.backend_entity.repository_metadata = self.serialize() - - def _check_mutability(self): - """Check if the node is mutable. - - :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. - """ - if self._node.is_stored: - raise exceptions.ModificationNotAllowed('the node is stored and therefore the repository is immutable.') - - @property - def _repository(self) -> Repository: - """Return the repository instance, lazily constructing it if necessary. - - .. note:: this property is protected because a node's repository should not be accessed outside of its scope. - - :return: the file repository instance. - """ - if self._repository_instance is None: - if self._node.is_stored: - backend = self._node.backend.get_repository() - self._repository_instance = Repository.from_serialized(backend=backend, serialized=self.metadata) - else: - filepath = get_config_option('storage.sandbox') or None - self._repository_instance = Repository(backend=SandboxRepositoryBackend(filepath)) - - return self._repository_instance - - @_repository.setter - def _repository(self, repository: Repository) -> None: - """Set a new repository instance, deleting the current reference if it has been initialized. - - :param repository: the new repository instance to set. - """ - if self._repository_instance is not None: - del self._repository_instance - - self._repository_instance = repository - - def _store(self) -> None: - """Store the repository in the backend.""" - if isinstance(self._repository.backend, SandboxRepositoryBackend): - # Only if the backend repository is a sandbox do we have to clone its contents to the permanent repository. - repository_backend = self._node.backend.get_repository() - repository = Repository(backend=repository_backend) - repository.clone(self._repository) - # Swap the sandbox repository for the new permanent repository instance which should delete the sandbox - self._repository_instance = repository - # update the metadata on the node backend - self._node.backend_entity.repository_metadata = self.serialize() - - def _copy(self, repo: 'NodeRepository') -> None: - """Copy a repository from another instance. - - This is used when storing cached nodes. - - :param repo: the repository to clone. - """ - self._repository = copy.copy(repo._repository) # pylint: disable=protected-access - - def _clone(self, repo: 'NodeRepository') -> None: - """Clone the repository from another instance. - - This is used when cloning a node. - - :param repo: the repository to clone. - """ - self._repository.clone(repo._repository) # pylint: disable=protected-access - - def serialize(self) -> Dict: - """Serialize the metadata of the repository content into a JSON-serializable format. - - :return: dictionary with the content metadata. - """ - return self._repository.serialize() - - def hash(self) -> str: - """Generate a hash of the repository's contents. - - :return: the hash representing the contents of the repository. - """ - return self._repository.hash() - - def list_objects(self, path: Optional[str] = None) -> List[File]: - """Return a list of the objects contained in this repository sorted by name, optionally in given sub directory. - - :param path: the relative path where to store the object in the repository. 
- :return: a list of `File` named tuples representing the objects present in directory with the given key. - :raises TypeError: if the path is not a string and relative path. - :raises FileNotFoundError: if no object exists for the given path. - :raises NotADirectoryError: if the object at the given path is not a directory. - """ - return self._repository.list_objects(path) - - def list_object_names(self, path: Optional[str] = None) -> List[str]: - """Return a sorted list of the object names contained in this repository, optionally in the given sub directory. - - :param path: the relative path where to store the object in the repository. - :return: a list of `File` named tuples representing the objects present in directory with the given key. - :raises TypeError: if the path is not a string and relative path. - :raises FileNotFoundError: if no object exists for the given path. - :raises NotADirectoryError: if the object at the given path is not a directory. - """ - return self._repository.list_object_names(path) - - @contextlib.contextmanager - def open(self, path: str, mode='r') -> Iterator[Union[BinaryIO, TextIO]]: - """Open a file handle to an object stored under the given key. - - .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method - ``put_object_from_filelike`` instead. - - :param path: the relative path of the object within the repository. - :return: yield a byte stream object. - :raises TypeError: if the path is not a string and relative path. - :raises FileNotFoundError: if the file does not exist. - :raises IsADirectoryError: if the object is a directory and not a file. - :raises OSError: if the file could not be opened. - """ - if mode not in ['r', 'rb']: - raise ValueError(f'the mode {mode} is not supported.') - - with self._repository.open(path) as handle: - if 'b' not in mode: - yield io.StringIO(handle.read().decode('utf-8')) - else: - yield handle - - def get_object(self, path: Optional[FilePath] = None) -> File: - """Return the object at the given path. - - :param path: the relative path where to store the object in the repository. - :return: the `File` representing the object located at the given relative path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if no object exists for the given path. - """ - return self._repository.get_object(path) - - def get_object_content(self, path: str, mode='r') -> Union[str, bytes]: - """Return the content of a object identified by key. - - :param key: fully qualified identifier for the object within the repository. - :raises TypeError: if the path is not a string and relative path. - :raises FileNotFoundError: if the file does not exist. - :raises IsADirectoryError: if the object is a directory and not a file. - :raises OSError: if the file could not be opened. - """ - if mode not in ['r', 'rb']: - raise ValueError(f'the mode {mode} is not supported.') - - if 'b' not in mode: - return self._repository.get_object_content(path).decode('utf-8') - - return self._repository.get_object_content(path) - - def put_object_from_bytes(self, content: bytes, path: str) -> None: - """Store the given content in the repository at the given path. - - :param path: the relative path where to store the object in the repository. - :param content: the content to store. - :raises TypeError: if the path is not a string and relative path. - :raises FileExistsError: if an object already exists at the given path. 
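A minimal round trip through this repository interface on an unstored node; the `Data` node and path are arbitrary and a loaded profile is assumed:

```python
from aiida.orm import Data

node = Data()                           # unstored, so the repository is still mutable
node.base.repository.put_object_from_bytes(b'hello world\n', 'greetings/hello.txt')
print(node.base.repository.get_object_content('greetings/hello.txt'))   # text mode by default
print(node.base.repository.list_object_names('greetings'))              # ['hello.txt']
node.store()                            # afterwards, mutating calls raise ModificationNotAllowed
```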
- """ - self._check_mutability() - self._repository.put_object_from_filelike(io.BytesIO(content), path) - self._update_repository_metadata() - - def put_object_from_filelike(self, handle: io.BufferedReader, path: str): - """Store the byte contents of a file in the repository. - - :param handle: filelike object with the byte content to be stored. - :param path: the relative path where to store the object in the repository. - :raises TypeError: if the path is not a string and relative path. - :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. - """ - self._check_mutability() - - if isinstance(handle, io.StringIO): - handle = io.BytesIO(handle.read().encode('utf-8')) - - if isinstance(handle, tempfile._TemporaryFileWrapper): # pylint: disable=protected-access - if 'b' in handle.file.mode: - handle = io.BytesIO(handle.read()) - else: - handle = io.BytesIO(handle.read().encode('utf-8')) - - self._repository.put_object_from_filelike(handle, path) - self._update_repository_metadata() - - def put_object_from_file(self, filepath: str, path: str): - """Store a new object under `path` with contents of the file located at `filepath` on the local file system. - - :param filepath: absolute path of file whose contents to copy to the repository - :param path: the relative path where to store the object in the repository. - :raises TypeError: if the path is not a string and relative path, or the handle is not a byte stream. - :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. - """ - self._check_mutability() - self._repository.put_object_from_file(filepath, path) - self._update_repository_metadata() - - def put_object_from_tree(self, filepath: str, path: Optional[str] = None): - """Store the entire contents of `filepath` on the local file system in the repository with under given `path`. - - :param filepath: absolute path of the directory whose contents to copy to the repository. - :param path: the relative path where to store the objects in the repository. - :raises TypeError: if the path is not a string and relative path. - :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. - """ - self._check_mutability() - self._repository.put_object_from_tree(filepath, path) - self._update_repository_metadata() - - def walk(self, path: Optional[FilePath] = None) -> Iterable[Tuple[pathlib.PurePosixPath, List[str], List[str]]]: - """Walk over the directories and files contained within this repository. - - .. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in - line with the ``os.walk`` implementation where the order depends on the underlying file system used. - - :param path: the relative path of the directory within the repository whose contents to walk. 
- :return: tuples of root, dirnames and filenames just like ``os.walk``, with the exception that the root path is - always relative with respect to the repository root, instead of an absolute path and it is an instance of - ``pathlib.PurePosixPath`` instead of a normal string - """ - yield from self._repository.walk(path) - - def glob(self) -> Iterable[pathlib.PurePosixPath]: - """Yield a recursive list of all paths (files and directories).""" - for dirpath, dirnames, filenames in self.walk(): - for dirname in dirnames: - yield dirpath / dirname - for filename in filenames: - yield dirpath / filename - - def copy_tree(self, target: Union[str, pathlib.Path], path: Optional[FilePath] = None) -> None: - """Copy the contents of the entire node repository to another location on the local file system. - - :param target: absolute path of the directory where to copy the contents to. - :param path: optional relative path whose contents to copy. - """ - self._repository.copy_tree(target, path) - - def delete_object(self, path: str): - """Delete the object from the repository. - - :param key: fully qualified identifier for the object within the repository. - :raises TypeError: if the path is not a string and relative path. - :raises FileNotFoundError: if the file does not exist. - :raises IsADirectoryError: if the object is a directory and not a file. - :raises OSError: if the file could not be deleted. - :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. - """ - self._check_mutability() - self._repository.delete_object(path) - self._update_repository_metadata() - - def erase(self): - """Delete all objects from the repository. - - :raises `~aiida.common.exceptions.ModificationNotAllowed`: when the node is stored and therefore immutable. - """ - self._check_mutability() - self._repository.erase() - self._update_repository_metadata() diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py deleted file mode 100644 index 9e5fa36ff7..0000000000 --- a/aiida/orm/querybuilder.py +++ /dev/null @@ -1,1495 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=too-many-lines -""" -The QueryBuilder: A class that allows you to query the AiiDA database, independent from backend. -Note that the backend implementation is enforced and handled with a composition model! -:func:`QueryBuilder` is the frontend class that the user can use. It inherits from *object* and contains -backend-specific functionality. Backend specific functionality is provided by the implementation classes. - -These inherit from :func:`aiida.orm.implementation.querybuilder.BackendQueryBuilder`, -an interface classes which enforces the implementation of its defined methods. -An instance of one of the implementation classes becomes a member of the :func:`QueryBuilder` instance -when instantiated by the user. 
-""" -from __future__ import annotations - -from copy import deepcopy -from inspect import isclass as inspect_isclass -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Literal, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Type, - Union, - cast, - overload, -) -import warnings - -from aiida.common.log import AIIDA_LOGGER -from aiida.common.warnings import warn_deprecation -from aiida.manage import get_manager -from aiida.orm.entities import EntityTypes -from aiida.orm.implementation.querybuilder import ( - GROUP_ENTITY_TYPE_PREFIX, - BackendQueryBuilder, - EntityRelationships, - PathItemType, - QueryDictType, -) - -from . import authinfos, comments, computers, convert, entities, groups, logs, nodes, users - -if TYPE_CHECKING: - # pylint: disable=ungrouped-imports - from aiida.engine import Process - from aiida.orm.implementation import StorageBackend - -__all__ = ('QueryBuilder',) - -# re-usable type annotations -EntityClsType = Type[Union[entities.Entity, 'Process']] # pylint: disable=invalid-name -ProjectType = Union[str, dict, Sequence[Union[str, dict]]] # pylint: disable=invalid-name -FilterType = Dict[str, Any] # pylint: disable=invalid-name -OrderByType = Union[dict, List[dict], Tuple[dict, ...]] - -LOGGER = AIIDA_LOGGER.getChild('querybuilder') - - -class Classifier(NamedTuple): - """A classifier for an entity.""" - ormclass_type_string: str - process_type_string: Optional[str] = None - - -class QueryBuilder: - """ - The class to query the AiiDA database. - - Usage:: - - from aiida.orm.querybuilder import QueryBuilder - qb = QueryBuilder() - # Querying nodes: - qb.append(Node) - # retrieving the results: - results = qb.all() - - """ - - # pylint: disable=too-many-instance-attributes,too-many-public-methods - - # This tag defines how edges are tagged (labeled) by the QueryBuilder default - # namely tag of first entity + _EDGE_TAG_DELIM + tag of second entity - _EDGE_TAG_DELIM = '--' - _VALID_PROJECTION_KEYS = ('func', 'cast') - - def __init__( - self, - backend: Optional['StorageBackend'] = None, - *, - debug: bool | None = None, - path: Optional[Sequence[Union[str, Dict[str, Any], EntityClsType]]] = (), - filters: Optional[Dict[str, FilterType]] = None, - project: Optional[Dict[str, ProjectType]] = None, - limit: Optional[int] = None, - offset: Optional[int] = None, - order_by: Optional[OrderByType] = None, - distinct: bool = False, - ) -> None: - """ - Instantiates a QueryBuilder instance. - - Which backend is used decided here based on backend-settings (taken from the user profile). - This cannot be overridden so far by the user. - - :param debug: - Turn on debug mode. This feature prints information on the screen about the stages - of the QueryBuilder. Does not affect results. - :param path: - A list of the vertices to traverse. Leave empty if you plan on using the method - :func:`QueryBuilder.append`. - :param filters: - The filters to apply. You can specify the filters here, when appending to the query - using :func:`QueryBuilder.append` or even later using :func:`QueryBuilder.add_filter`. - Check latter gives API-details. - :param project: - The projections to apply. You can specify the projections here, when appending to the query - using :func:`QueryBuilder.append` or even later using :func:`QueryBuilder.add_projection`. - Latter gives you API-details. - :param limit: - Limit the number of rows to this number. Check :func:`QueryBuilder.limit` - for more information. - :param offset: - Set an offset for the results returned. 
Details in :func:`QueryBuilder.offset`. - :param order_by: - How to order the results. As the 2 above, can be set also at later stage, - check :func:`QueryBuilder.order_by` for more information. - :param distinct: Whether to return de-duplicated rows - - """ - self._backend = backend or get_manager().get_profile_storage() - self._impl: BackendQueryBuilder = self._backend.query() - - # SERIALISABLE ATTRIBUTES - # A list storing the path being traversed by the query - self._path: List[PathItemType] = [] - # map tags to filters - self._filters: Dict[str, FilterType] = {} - # map tags to projections: tag -> list(fields) -> func | cast -> value - self._projections: Dict[str, List[Dict[str, Dict[str, Any]]]] = {} - # list of mappings: tag -> list(fields) -> 'order' | 'cast' -> value (str('asc' | 'desc'), str(cast_key)) - self._order_by: List[Dict[str, List[Dict[str, Dict[str, str]]]]] = [] - self._limit: Optional[int] = None - self._offset: Optional[int] = None - self._distinct: bool = distinct - - # cache of tag mappings, populated during appends - self._tags = _QueryTagMap() - - # Set the debug level - if debug is not None: - warn_deprecation( - 'The `debug` argument is deprecated. Configure the log level of the AiiDA logger instead.', version=3 - ) - else: - debug = False - - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - self.set_debug(debug) - - # Validate & add the query path - if not isinstance(path, (list, tuple)): - raise TypeError('Path needs to be a tuple or a list') - for path_spec in path: - if isinstance(path_spec, dict): - self.append(**path_spec) - elif isinstance(path_spec, str): - # Assume user means the entity_type - self.append(entity_type=path_spec) - else: - self.append(cls=path_spec) - # Validate & add projections - projection_dict = project or {} - if not isinstance(projection_dict, dict): - raise TypeError('You need to provide the projections as dictionary') - for key, val in projection_dict.items(): - self.add_projection(key, val) - # Validate & add filters - filter_dict = filters or {} - if not isinstance(filter_dict, dict): - raise TypeError('You need to provide the filters as dictionary') - for key, val in filter_dict.items(): - self.add_filter(key, val) - # Validate & add limit - self.limit(limit) - # Validate & add offset - self.offset(offset) - # Validate & add order_by - if order_by: - self.order_by(order_by) - - @property - def backend(self) -> 'StorageBackend': - """Return the backend used by the QueryBuilder.""" - return self._backend - - def as_dict(self, copy: bool = True) -> QueryDictType: - """Convert to a JSON serialisable dictionary representation of the query.""" - data: QueryDictType = { - 'path': self._path, - 'filters': self._filters, - 'project': self._projections, - 'order_by': self._order_by, - 'limit': self._limit, - 'offset': self._offset, - 'distinct': self._distinct, - } - if copy: - return deepcopy(data) - return data - - @property - def queryhelp(self) -> 'QueryDictType': - """"Legacy name for ``as_dict`` method.""" - warn_deprecation('`QueryBuilder.queryhelp` is deprecated, use `QueryBuilder.as_dict()` instead', version=3) - return self.as_dict() - - @classmethod - def from_dict(cls, dct: Dict[str, Any]) -> 'QueryBuilder': - """Create an instance from a dictionary representation of the query.""" - return cls(**dct) - - def __repr__(self) -> str: - """Return an unambiguous string representation of the instance.""" - params = ', '.join(f'{key}={value!r}' for key, value in self.as_dict(copy=False).items()) - return 
f'QueryBuilder({params})' - - def __str__(self) -> str: - """Return a readable string representation of the instance.""" - return repr(self) - - def __deepcopy__(self, memo) -> 'QueryBuilder': - """Create deep copy of the instance.""" - return type(self)(backend=self.backend, **self.as_dict()) # type: ignore - - def get_used_tags(self, vertices: bool = True, edges: bool = True) -> List[str]: - """Return a list of all the tags that are being used. - - :param vertices: If True, adds the tags of vertices to the returned list - :param edges: If True, adds the tags of edges to the returned list. - - :returns: A list of tags - """ - given_tags = [] - for idx, path in enumerate(self._path): - if vertices: - given_tags.append(path['tag']) - if edges and idx > 0: - given_tags.append(path['edge_tag']) - return given_tags - - def _get_unique_tag(self, classifiers: List[Classifier]) -> str: - """ - Using the function get_tag_from_type, I get a tag. - I increment an index that is appended to that tag until I have an unused tag. - This function is called in :func:`QueryBuilder.append` when no tag is given. - - :param dict classifiers: - Classifiers, containing the string that defines the type of the AiiDA ORM class. - For subclasses of Node, this is the Node._plugin_type_string, for others they are - as returned by :func:`QueryBuilder._get_ormclass`. - - Can also be a list of dictionaries, when multiple classes are passed to QueryBuilder.append - - :returns: A tag as a string (it is a single string also when passing multiple classes). - """ - basetag = '-'.join([t.ormclass_type_string.rstrip('.').split('.')[-1] or 'node' for t in classifiers]) - for i in range(1, 100): - tag = f'{basetag}_{i}' - if tag not in self._tags: - return tag - - raise RuntimeError('Cannot find a tag after 100 tries') - - def append( - self, - cls: Optional[Union[EntityClsType, Sequence[EntityClsType]]] = None, - entity_type: Optional[Union[str, Sequence[str]]] = None, - tag: Optional[str] = None, - filters: Optional[FilterType] = None, - project: Optional[ProjectType] = None, - subclassing: bool = True, - edge_tag: Optional[str] = None, - edge_filters: Optional[FilterType] = None, - edge_project: Optional[ProjectType] = None, - outerjoin: bool = False, - joining_keyword: Optional[str] = None, - joining_value: Optional[Any] = None, - orm_base: Optional[str] = None, # pylint: disable=unused-argument - **kwargs: Any - ) -> 'QueryBuilder': - """ - Any iterative procedure to build the path for a graph query - needs to invoke this method to append to the path. - - :param cls: - The AiiDA class (or backend class) defining the appended vertex. - Also supports a tuple/list of classes. This results in all instances of - these classes being accepted in a query. However, the classes have to share the same ORM base class - for the joining to work, i.e. both have to be subclasses of Node. Valid is:: - - cls=(StructureData, Dict) - - This is invalid: - - cls=(Group, Node) - - :param entity_type: The node type of the class, if cls is not given. Also here, a tuple or list is accepted. - :param tag: - A unique tag. If none is given, I will create a unique tag myself. - :param filters: - Filters to apply for this vertex. - See :meth:`.add_filter`, the method invoked in the background, or usage examples for details. - :param project: - Projections to apply. See usage examples for details. - More information also in :meth:`.add_projection`. - :param subclassing: - Whether to include subclasses of the given class (default **True**). - E.g.
Specifying a ProcessNode as cls will include CalcJobNode, WorkChainNode, CalcFunctionNode, etc.. - :param edge_tag: - The tag that the edge will get. If nothing is specified - (and there is a meaningful edge) the default is tag1--tag2 with tag1 being the entity joining - from and tag2 being the entity joining to (this entity). - :param edge_filters: - The filters to apply on the edge. Also here, details in :meth:`.add_filter`. - :param edge_project: - The project from the edges. API-details in :meth:`.add_projection`. - :param outerjoin: - If True, (default is False), will do a left outerjoin - instead of an inner join - - Joining can be specified in two ways: - - - Specifying the 'joining_keyword' and 'joining_value' arguments - - Specify a single keyword argument - - The joining keyword wil be ``with_*`` or ``direction``, depending on the joining entity type. - The joining value is the tag name or class of the entity to join to. - - A small usage example how this can be invoked:: - - qb = QueryBuilder() # Instantiating empty querybuilder instance - qb.append(cls=StructureData) # First item is StructureData node - # The - # next node in the path is a PwCalculation, with - # the structure joined as an input - qb.append( - cls=PwCalculation, - with_incoming=StructureData - ) - - :return: self - """ - # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements - # INPUT CHECKS ########################## - # This function can be called by users, so I am checking the input now. - # First of all, let's make sure the specified the class or the type (not both) - - if cls is not None and entity_type is not None: - raise ValueError(f'You cannot specify both a class ({cls}) and a entity_type ({entity_type})') - - if cls is None and entity_type is None: - raise ValueError('You need to specify at least a class or a entity_type') - - # Let's check if it is a valid class or type - if cls: - if isinstance(cls, (list, tuple)): - for sub_cls in cls: - if not inspect_isclass(sub_cls): - raise TypeError(f"{sub_cls} was passed with kw 'cls', but is not a class") - elif not inspect_isclass(cls): - raise TypeError(f"{cls} was passed with kw 'cls', but is not a class") - elif entity_type is not None: - if isinstance(entity_type, (list, tuple)): - for sub_type in entity_type: - if not isinstance(sub_type, str): - raise TypeError(f'{sub_type} was passed as entity_type, but is not a string') - elif not isinstance(entity_type, str): - raise TypeError(f'{entity_type} was passed as entity_type, but is not a string') - - ormclass, classifiers = _get_ormclass(cls, entity_type) - - # TAG ################################# - # Let's get a tag - if tag: - if self._EDGE_TAG_DELIM in tag: - raise ValueError( - f'tag cannot contain {self._EDGE_TAG_DELIM}\nsince this is used as a delimiter for links' - ) - if tag in self._tags: - raise ValueError(f'This tag ({tag}) is already in use') - else: - tag = self._get_unique_tag(classifiers) - - # Checks complete - # This is where I start doing changes to self! - # Now, several things can go wrong along the way, so I need to split into - # atomic blocks that I can reverse if something goes wrong. 
- - # TAG ALIASING ############################## - try: - self._tags.add(tag, ormclass, cls) - except Exception as exception: - LOGGER.debug('Exception caught in append, cleaning up: %s', exception) - self._tags.remove(tag) - raise - - # FILTERS ###################################### - try: - self._filters[tag] = {} - # Subclassing is currently only implemented for the `Node` and `Group` classes. - # So for those cases we need to construct the correct filters, - # corresponding to the provided classes and value of `subclassing`. - if ormclass == EntityTypes.NODE: - self._add_node_type_filter(tag, classifiers, subclassing) - self._add_process_type_filter(tag, classifiers, subclassing) - - elif ormclass == EntityTypes.GROUP: - self._add_group_type_filter(tag, classifiers, subclassing) - - # The order has to be first _add_node_type_filter and then add_filter. - # If the user adds a query on the type column, it overwrites what I did - # if the user specified a filter, add it: - if filters is not None: - self.add_filter(tag, filters) - except Exception as exception: - LOGGER.debug('Exception caught in append, cleaning up: %s', exception) - self._tags.remove(tag) - self._filters.pop(tag) - raise - - # PROJECTIONS ############################## - try: - self._projections[tag] = [] - if project is not None: - self.add_projection(tag, project) - except Exception as exception: - LOGGER.debug('Exception caught in append, cleaning up: %s', exception) - self._tags.remove(tag) - self._filters.pop(tag) - self._projections.pop(tag) - raise exception - - # JOINING ##################################### - # pylint: disable=too-many-nested-blocks - try: - # Get the functions that are implemented: - spec_to_function_map = set(EntityRelationships[ormclass.value]) - if ormclass == EntityTypes.NODE: - # 'direction 'was an old implementation, which is now converted below to with_outgoing or with_incoming - spec_to_function_map.add('direction') - for key, val in kwargs.items(): - if key not in spec_to_function_map: - raise ValueError( - f"'{key}' is not a valid keyword for {ormclass.value!r} joining specification\n" - f'Valid keywords are: {spec_to_function_map or []!r}' - ) - if joining_keyword: - raise ValueError( - 'You already specified joining specification {}\n' - 'But you now also want to specify {}' - ''.format(joining_keyword, key) - ) - - joining_keyword = key - if joining_keyword == 'direction': - if not isinstance(val, int): - raise TypeError('direction=n expects n to be an integer') - try: - if val < 0: - joining_keyword = 'with_outgoing' - elif val > 0: - joining_keyword = 'with_incoming' - else: - raise ValueError('direction=0 is not valid') - joining_value = self._path[-abs(val)]['tag'] - except IndexError as exc: - raise ValueError( - f'You have specified a non-existent entity with\ndirection={joining_value}\n{exc}\n' - ) - else: - joining_value = self._tags.get(val) - - if joining_keyword is None and len(self._path) > 0: - # the default is that this vertice is 'with_incoming' as the previous one - if ormclass == EntityTypes.NODE: - joining_keyword = 'with_incoming' - else: - joining_keyword = 'with_node' - joining_value = self._path[-1]['tag'] - - except Exception as exception: - LOGGER.debug('Exception caught in append (part filters), cleaning up: %s', exception) - self._tags.remove(tag) - self._filters.pop(tag) - self._projections.pop(tag) - # There's not more to clean up here! 
- raise exception - - # EDGES ################################# - if len(self._path) > 0: - joining_value = cast(str, joining_value) - try: - if edge_tag is None: - edge_destination_tag = self._tags.get(joining_value) - edge_tag = edge_destination_tag + self._EDGE_TAG_DELIM + tag - else: - if edge_tag in self._tags: - raise ValueError(f'The tag {edge_tag} is already in use') - LOGGER.debug('edge_tag chosen: %s', edge_tag) - - # edge tags do not have an ormclass - self._tags.add(edge_tag) - - # Filters on links: - # Beware, I alway add this entry now, but filtering here might be - # non-sensical, since this ONLY works for m2m relationship where - # I go through a different table - self._filters[edge_tag] = {} - if edge_filters is not None: - self.add_filter(edge_tag, edge_filters) - # Projections on links - self._projections[edge_tag] = [] - if edge_project is not None: - self.add_projection(edge_tag, edge_project) - except Exception as exception: - LOGGER.debug('Exception caught in append (part joining), cleaning up %s', exception) - self._tags.remove(tag) - self._filters.pop(tag) - self._projections.pop(tag) - if edge_tag is not None: - self._tags.remove(edge_tag) - self._filters.pop(edge_tag, None) - self._projections.pop(edge_tag, None) - # There's not more to clean up here! - raise exception - - # EXTENDING THE PATH ################################# - # Note: 'type' being a list is a relict of an earlier implementation - # Could simply pass all classifiers here. - path_type: Union[List[str], str] - if len(classifiers) > 1: - path_type = [c.ormclass_type_string for c in classifiers] - else: - path_type = classifiers[0].ormclass_type_string - - self._path.append( - dict( - entity_type=path_type, - orm_base=ormclass.value, # type: ignore[typeddict-item] - tag=tag, - # for the first item joining_keyword/joining_value can be None, - # but after they always default to 'with_incoming' of the previous item - joining_keyword=joining_keyword, # type: ignore - joining_value=joining_value, # type: ignore - # same for edge_tag for which a default is applied - edge_tag=edge_tag, # type: ignore - outerjoin=outerjoin, - ) - ) - - return self - - def order_by(self, order_by: OrderByType) -> 'QueryBuilder': - """ - Set the entity to order by - - :param order_by: - This is a list of items, where each item is a dictionary specifies - what to sort for an entity - - In each dictionary in that list, keys represent valid tags of - entities (tables), and values are list of columns. 
- - Usage:: - - #Sorting by id (ascending): - qb = QueryBuilder() - qb.append(Node, tag='node') - qb.order_by({'node':['id']}) - - # or - #Sorting by id (ascending): - qb = QueryBuilder() - qb.append(Node, tag='node') - qb.order_by({'node':[{'id':{'order':'asc'}}]}) - - # for descending order: - qb = QueryBuilder() - qb.append(Node, tag='node') - qb.order_by({'node':[{'id':{'order':'desc'}}]}) - - # or (shorter) - qb = QueryBuilder() - qb.append(Node, tag='node') - qb.order_by({'node':[{'id':'desc'}]}) - """ - # pylint: disable=too-many-nested-blocks,too-many-branches - self._order_by = [] - allowed_keys = ('cast', 'order') - possible_orders = ('asc', 'desc') - - if not isinstance(order_by, (list, tuple)): - order_by = [order_by] - - for order_spec in order_by: - if not isinstance(order_spec, dict): - raise TypeError( - f'Invalid input for order_by statement: {order_spec!r}\n' - 'Expecting a dictionary like: {tag: field} or {tag: [field1, field2, ...]}' - ) - _order_spec: dict = {} - for tagspec, items_to_order_by in order_spec.items(): - if not isinstance(items_to_order_by, (tuple, list)): - items_to_order_by = [items_to_order_by] - tag = self._tags.get(tagspec) - _order_spec[tag] = [] - for item_to_order_by in items_to_order_by: - if isinstance(item_to_order_by, str): - item_to_order_by = {item_to_order_by: {}} - elif isinstance(item_to_order_by, dict): - pass - else: - raise ValueError( - f'Cannot deal with input to order_by {item_to_order_by}\nof type{type(item_to_order_by)}\n' - ) - for entityname, orderspec in item_to_order_by.items(): - # if somebody specifies eg {'node':{'id':'asc'}} - # tranform to {'node':{'id':{'order':'asc'}}} - - if isinstance(orderspec, str): - this_order_spec = {'order': orderspec} - elif isinstance(orderspec, dict): - this_order_spec = orderspec - else: - raise TypeError( - 'I was expecting a string or a dictionary\n' - 'You provided {} {}\n' - ''.format(type(orderspec), orderspec) - ) - for key in this_order_spec: - if key not in allowed_keys: - raise ValueError( - 'The allowed keys for an order specification\n' - 'are {}\n' - '{} is not valid\n' - ''.format(', '.join(allowed_keys), key) - ) - this_order_spec['order'] = this_order_spec.get('order', 'asc') - if this_order_spec['order'] not in possible_orders: - raise ValueError( - 'You gave {} as an order parameters,\n' - 'but it is not a valid order parameter\n' - 'Valid orders are: {}\n' - ''.format(this_order_spec['order'], possible_orders) - ) - item_to_order_by[entityname] = this_order_spec - - _order_spec[tag].append(item_to_order_by) - - self._order_by.append(_order_spec) - return self - - def add_filter(self, tagspec: Union[str, EntityClsType], filter_spec: FilterType) -> 'QueryBuilder': - """ - Adding a filter to my filters. 
- - :param tagspec: A tag string or an ORM class which maps to an existing tag - :param filter_spec: The specifications for the filter, has to be a dictionary - - Usage:: - - qb = QueryBuilder() # Instantiating the QueryBuilder instance - qb.append(Node, tag='node') # Appending a Node - #let's put some filters: - qb.add_filter('node',{'id':{'>':12}}) - # 2 filters together: - qb.add_filter('node',{'label':'foo', 'uuid':{'like':'ab%'}}) - # Now I am overriding the first filter I set: - qb.add_filter('node',{'id':13}) - """ - filters = self._process_filters(filter_spec) - tag = self._tags.get(tagspec) - self._filters[tag].update(filters) - return self - - @staticmethod - def _process_filters(filters: FilterType) -> Dict[str, Any]: - """Process filters.""" - if not isinstance(filters, dict): - raise TypeError('Filters have to be passed as dictionaries') - - processed_filters = {} - - for key, value in filters.items(): - if isinstance(value, entities.Entity): - # Convert to be the id of the joined entity because we can't query - # for the object instance directly - processed_filters[f'{key}_id'] = value.pk - else: - processed_filters[key] = value - - return processed_filters - - def _add_node_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool): - """ - Add a filter based on node type. - - :param tagspec: The tag, which has to exist already as a key in self._filters - :param classifiers: a dictionary with classifiers - :param subclassing: if True, allow for subclasses of the ormclass - """ - if len(classifiers) > 1: - # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - entity_type_filter: dict = {'or': []} - for classifier in classifiers: - entity_type_filter['or'].append(_get_node_type_filter(classifier, subclassing)) - else: - entity_type_filter = _get_node_type_filter(classifiers[0], subclassing) - - self.add_filter(tagspec, {'node_type': entity_type_filter}) - - def _add_process_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool) -> None: - """ - Add a filter based on process type. - - :param tagspec: The tag, which has to exist already as a key in self._filters - :param classifiers: a dictionary with classifiers - :param subclassing: if True, allow for subclasses of the process type - - Note: This function handles the case when process_type_string is None. - """ - if len(classifiers) > 1: - # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - process_type_filter: dict = {'or': []} - for classifier in classifiers: - if classifier.process_type_string is not None: - process_type_filter['or'].append(_get_process_type_filter(classifier, subclassing)) - - if len(process_type_filter['or']) > 0: - self.add_filter(tagspec, {'process_type': process_type_filter}) - - else: - if classifiers[0].process_type_string is not None: - process_type_filter = _get_process_type_filter(classifiers[0], subclassing) - self.add_filter(tagspec, {'process_type': process_type_filter}) - - def _add_group_type_filter(self, tagspec: str, classifiers: List[Classifier], subclassing: bool) -> None: - """ - Add a filter based on group type. 
- - :param tagspec: The tag, which has to exist already as a key in self._filters - :param classifiers: a dictionary with classifiers - :param subclassing: if True, allow for subclasses of the ormclass - """ - if len(classifiers) > 1: - # If a list was passed to QueryBuilder.append, this propagates to a list in the classifiers - type_string_filter: dict = {'or': []} - for classifier in classifiers: - type_string_filter['or'].append(_get_group_type_filter(classifier, subclassing)) - else: - type_string_filter = _get_group_type_filter(classifiers[0], subclassing) - - self.add_filter(tagspec, {'type_string': type_string_filter}) - - def add_projection(self, tag_spec: Union[str, EntityClsType], projection_spec: ProjectType) -> None: - r"""Adds a projection - - :param tag_spec: A tag string or an ORM class which maps to an existing tag - :param projection_spec: - The specification for the projection. - A projection is a list of dictionaries, with each dictionary - containing key-value pairs where the key is database entity - (e.g. a column / an attribute) and the value is (optional) - additional information on how to process this database entity. - - If the given *projection_spec* is not a list, it will be expanded to - a list. - If the listitems are not dictionaries, but strings (No additional - processing of the projected results desired), they will be expanded to - dictionaries. - - Usage:: - - qb = QueryBuilder() - qb.append(StructureData, tag='struc') - - # Will project the uuid and the kinds - qb.add_projection('struc', ['uuid', 'attributes.kinds']) - - The above example will project the uuid and the kinds-attribute of all matching structures. - There are 2 (so far) special keys. - - The single star *\** will project the *ORM-instance*:: - - qb = QueryBuilder() - qb.append(StructureData, tag='struc') - # Will project the ORM instance - qb.add_projection('struc', '*') - print type(qb.first()[0]) - # >>> aiida.orm.nodes.data.structure.StructureData - - The double star ``**`` projects all possible projections of this entity: - - QueryBuilder().append(StructureData,tag='s', project='**').limit(1).dict()[0]['s'].keys() - - # >>> 'user_id, description, ctime, label, extras, mtime, id, attributes, dbcomputer_id, type, uuid' - - Be aware that the result of ``**`` depends on the backend implementation. - - """ - tag = self._tags.get(tag_spec) - _projections = [] - LOGGER.debug('Adding projection of %s: %s', tag_spec, projection_spec) - if not isinstance(projection_spec, (list, tuple)): - projection_spec = [projection_spec] # type: ignore - for projection in projection_spec: - if isinstance(projection, dict): - _thisprojection = projection - elif isinstance(projection, str): - _thisprojection = {projection: {}} - else: - raise ValueError(f'Cannot deal with projection specification {projection}\n') - for spec in _thisprojection.values(): - if not isinstance(spec, dict): - raise TypeError( - f'\nThe value of a key-value pair in a projection\nhas to be a dictionary\nYou gave: {spec}\n' - ) - - for key, val in spec.items(): - if key not in self._VALID_PROJECTION_KEYS: - raise ValueError(f'{key} is not a valid key {self._VALID_PROJECTION_KEYS}') - if not isinstance(val, str): - raise TypeError(f'{val} has to be a string') - _projections.append(_thisprojection) - LOGGER.debug('projections have become: %s', _projections) - self._projections[tag] = _projections - - def set_debug(self, debug: bool) -> 'QueryBuilder': - """ - Run in debug mode. 
This does not affect functionality, but prints intermediate stages - when creating a query on screen. - - :param debug: Turn debug on or off - """ - warn_deprecation( - '`QueryBuilder.set_debug` is deprecated. Configure the log level of the AiiDA logger instead.', version=3 - ) - if not isinstance(debug, bool): - raise TypeError('I expect a boolean') - self._debug = debug - - return self - - def debug(self, msg: str, *objects: Any) -> None: - """Log debug message. - - objects will be passed to the format string, e.g. ``msg % objects`` - """ - warn_deprecation('`QueryBuilder.debug` is deprecated.', version=3) - if self._debug: - print(f'DEBUG: {msg}' % objects) - - def limit(self, limit: Optional[int]) -> 'QueryBuilder': - """ - Set the limit (number of rows to return) - - :param limit: the number of rows to return, or None for no limit - """ - if (limit is not None) and (not isinstance(limit, int)): - raise TypeError('The limit has to be an integer, or None') - self._limit = limit - return self - - def offset(self, offset: Optional[int]) -> 'QueryBuilder': - """ - Set the offset. If offset is set, that many rows are skipped before returning. - *offset* = 0 is the same as omitting setting the offset. - If both offset and limit appear, - then *offset* rows are skipped before starting to count the *limit* rows - that are returned. - - :param offset: the number of rows to skip, or None for no offset - """ - if (offset is not None) and (not isinstance(offset, int)): - raise TypeError('offset has to be an integer, or None') - self._offset = offset - return self - - def distinct(self, value: bool = True) -> 'QueryBuilder': - """ - Asks for distinct rows, which is the same as asking the backend to remove - duplicates. - Does not execute the query! - - If you want a distinct query:: - - qb = QueryBuilder() - # append stuff! - qb.append(...) - qb.append(...) - ... - qb.distinct().all() #or - qb.distinct().dict() - - :returns: self - """ - if not isinstance(value, bool): - raise TypeError(f'distinct() takes a boolean as parameter, not {value!r}') - self._distinct = value - return self - - def inputs(self, **kwargs: Any) -> 'QueryBuilder': - """ - Join to inputs of the previous vertex in the path. - - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_outgoing=join_to, **kwargs) - return self - - def outputs(self, **kwargs: Any) -> 'QueryBuilder': - """ - Join to outputs of the previous vertex in the path. - - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_incoming=join_to, **kwargs) - return self - - def children(self, **kwargs: Any) -> 'QueryBuilder': - """ - Join to children/descendants of the previous vertex in the path. - - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_ancestors=join_to, **kwargs) - return self - - def parents(self, **kwargs: Any) -> 'QueryBuilder': - """ - Join to parents/ancestors of the previous vertex in the path. - - :returns: self - """ - from aiida.orm import Node - join_to = self._path[-1]['tag'] - cls = kwargs.pop('cls', Node) - self.append(cls=cls, with_descendants=join_to, **kwargs) - return self - - def as_sql(self, inline: bool = False) -> str: - """Convert the query to an SQL string representation. - - .. warning:: - - This method should be used for debugging purposes only, - since normally sqlalchemy will handle this process internally.
- - :params inline: Inline bound parameters (this is normally handled by the Python DB-API). - """ - return self._impl.as_sql(data=self.as_dict(), inline=inline) - - def analyze_query(self, execute: bool = True, verbose: bool = False) -> str: - """Return the query plan, i.e. a list of SQL statements that will be executed. - - See: https://www.postgresql.org/docs/11/sql-explain.html - - :params execute: Carry out the command and show actual run times and other statistics. - :params verbose: Display additional information regarding the plan. - """ - return self._impl.analyze_query(data=self.as_dict(), execute=execute, verbose=verbose) - - @staticmethod - def _get_aiida_entity_res(value) -> Any: - """Convert a projected query result to front end class if it is an instance of a `BackendEntity`. - - Values that are not an `BackendEntity` instance will be returned unaltered - - :param value: a projected query result to convert - :return: the converted value - """ - try: - return convert.get_orm_entity(value) - except TypeError: - return value - - @overload - def first(self, flat: Literal[False] = False) -> Optional[list[Any]]: - ... - - @overload - def first(self, flat: Literal[True]) -> Optional[Any]: - ... - - def first(self, flat: bool = False) -> Optional[list[Any] | Any]: - """Return the first result of the query. - - Calling ``first`` results in an execution of the underlying query. - - Note, this may change if several rows are valid for the query, as persistent ordering is not guaranteed unless - explicitly specified. - - :param flat: if True, return just the projected quantity if there is just a single projection. - :returns: One row of results as a list, or None if no result returned. - """ - result = self._impl.first(self.as_dict()) - - if result is None: - return None - - result = [self._get_aiida_entity_res(rowitem) for rowitem in result] - - if flat and len(result) == 1: - return result[0] - - return result - - def count(self) -> int: - """ - Counts the number of rows returned by the backend. - - :returns: the number of rows as an integer - """ - return self._impl.count(self.as_dict()) - - def iterall(self, batch_size: Optional[int] = 100) -> Iterable[List[Any]]: - """ - Same as :meth:`.all`, but returns a generator. - Be aware that this is only safe if no commit will take place during this - transaction. You might also want to read the SQLAlchemy documentation on - https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per - - :param batch_size: - The size of the batches to ask the backend to batch results in subcollections. - You can optimize the speed of the query by tuning this parameter. - - :returns: a generator of lists - """ - for item in self._impl.iterall(self.as_dict(), batch_size): - # Convert to AiiDA frontend entities (if they are such) - for i, item_entry in enumerate(item): - item[i] = self._get_aiida_entity_res(item_entry) - - yield item - - def iterdict(self, batch_size: Optional[int] = 100) -> Iterable[Dict[str, Dict[str, Any]]]: - """ - Same as :meth:`.dict`, but returns a generator. - Be aware that this is only safe if no commit will take place during this - transaction. You might also want to read the SQLAlchemy documentation on - https://docs.sqlalchemy.org/en/14/orm/query.html#sqlalchemy.orm.Query.yield_per - - :param batch_size: - The size of the batches to ask the backend to batch results in subcollections. - You can optimize the speed of the query by tuning this parameter. 
- - :returns: a generator of dictionaries - """ - for item in self._impl.iterdict(self.as_dict(), batch_size): - for key, value in item.items(): - item[key] = self._get_aiida_entity_res(value) - - yield item - - def all(self, batch_size: Optional[int] = None, flat: bool = False) -> Union[List[List[Any]], List[Any]]: - """Executes the full query with the order of the rows as returned by the backend. - - The order inside each row is given by the order of the vertices in the path and the order of the projections for - each vertex in the path. - - :param batch_size: the size of the batches to ask the backend to batch results in subcollections. You can - optimize the speed of the query by tuning this parameter. Leave the default `None` if speed is not critical - or if you don't know what you're doing. - :param flat: return the result as a flat list of projected entities without sub lists. - :returns: a list of lists of all projected entities. - """ - matches = list(self.iterall(batch_size=batch_size)) - - if not flat: - return matches - - return [projection for entry in matches for projection in entry] - - def one(self) -> List[Any]: - """Executes the query asking for exactly one result. - - Will raise an exception if this is not the case: - - :raises: MultipleObjectsError if more than one row can be returned - :raises: NotExistent if no result was found - """ - from aiida.common.exceptions import MultipleObjectsError, NotExistent - limit = self._limit - self.limit(2) - try: - res = self.all() - finally: - self.limit(limit) - if len(res) > 1: - raise MultipleObjectsError('More than one result was found') - elif len(res) == 0: - raise NotExistent('No result was found') - return res[0] - - def dict(self, batch_size: Optional[int] = None) -> List[Dict[str, Dict[str, Any]]]: - """ - Executes the full query with the order of the rows as returned by the backend. - The order inside each row is given by the order of the vertices in the path - and the order of the projections for each vertex in the path. - - :param batch_size: - The size of the batches to ask the backend to batch results in subcollections. - You can optimize the speed of the query by tuning this parameter. - Leave the default (*None*) if speed is not critical or if you don't know what you're doing! - - :returns: A list of dictionaries of all projected entities: tag -> field -> value - - Usage:: - - qb = QueryBuilder() - qb.append( - StructureData, - tag='structure', - filters={'uuid':{'==':myuuid}}, - ) - qb.append( - Node, - with_ancestors='structure', - project=['entity_type', 'id'], # returns entity_type (string) and id (integer) - tag='descendant' - ) - - # Return the dictionaries: - print('qb.iterdict()') - for d in qb.iterdict(): - print('>>>', d) - - results in the following output:: - - qb.iterdict() - >>> {'descendant': { - 'entity_type': 'calculation.job.quantumespresso.pw.PwCalculation.', - 'id': 7716} - } - >>> {'descendant': { - 'entity_type': 'data.remote.RemoteData.', - 'id': 8510} - } - - """ - return list(self.iterdict(batch_size=batch_size)) - - -def _get_ormclass( - cls: Union[None, EntityClsType, Sequence[EntityClsType]], entity_type: Union[None, str, Sequence[str]] -) -> Tuple[EntityTypes, List[Classifier]]: - """Get ORM classifiers from either class(es) or ormclass_type_string(s). - - :param cls: a class or tuple/set/list of classes that are either AiiDA ORM classes or backend ORM classes.
- :param ormclass_type_string: type string for ORM class - - :returns: the ORM class as well as a dictionary with additional classifier strings - - Handles the case of lists as well. - """ - if cls is not None: - func = _get_ormclass_from_cls - input_info = cls - elif entity_type is not None: - func = _get_ormclass_from_str # type: ignore - input_info = entity_type # type: ignore - else: - raise ValueError('Neither cls nor entity_type specified') - - if isinstance(input_info, str) or not isinstance(input_info, Sequence): - input_info = (input_info,) - - ormclass = EntityTypes.NODE - classifiers = [] - - for index, classifier in enumerate(input_info): - new_ormclass, new_classifiers = func(classifier) - if index: - # check consistency with previous item - if new_ormclass != ormclass: - raise ValueError('Non-matching types have been passed as list/tuple/set.') - else: - ormclass = new_ormclass - - classifiers.append(new_classifiers) - - return ormclass, classifiers - - -def _get_ormclass_from_cls(cls: EntityClsType) -> Tuple[EntityTypes, Classifier]: - """ - Return the correct classifiers for the QueryBuilder from an ORM class. - - :param cls: an AiiDA ORM class or backend ORM class. - :param query: an instance of the appropriate QueryBuilder backend. - :returns: the ORM class as well as a dictionary with additional classifier strings - - Note: the ormclass_type_string is currently hardcoded for group, computer etc. One could instead use something like - aiida.orm.utils.node.get_type_string_from_class(cls.__module__, cls.__name__) - """ - # pylint: disable=protected-access,too-many-branches,too-many-statements - # Note: Unable to move this import to the top of the module for some reason - from aiida.engine import Process - from aiida.orm.utils.node import is_valid_node_type_string - - classifiers: Classifier - - if issubclass(cls, nodes.Node): - classifiers = Classifier(cls.class_node_type) - ormclass = EntityTypes.NODE - elif issubclass(cls, groups.Group): - type_string = cls._type_string - assert type_string is not None, 'Group not registered as entry point' - classifiers = Classifier(GROUP_ENTITY_TYPE_PREFIX + type_string) - ormclass = EntityTypes.GROUP - elif issubclass(cls, computers.Computer): - classifiers = Classifier('computer') - ormclass = EntityTypes.COMPUTER - elif issubclass(cls, users.User): - classifiers = Classifier('user') - ormclass = EntityTypes.USER - elif issubclass(cls, authinfos.AuthInfo): - classifiers = Classifier('authinfo') - ormclass = EntityTypes.AUTHINFO - elif issubclass(cls, comments.Comment): - classifiers = Classifier('comment') - ormclass = EntityTypes.COMMENT - elif issubclass(cls, logs.Log): - classifiers = Classifier('log') - ormclass = EntityTypes.LOG - - # Process - # This is a special case, since Process is not an ORM class. - # We need to deduce the ORM class used by the Process. - elif issubclass(cls, Process): - classifiers = Classifier(cls._node_class._plugin_type_string, cls.build_process_type()) - ormclass = EntityTypes.NODE - - else: - raise ValueError(f'I do not know what to do with {cls}') - - if ormclass == EntityTypes.NODE: - is_valid_node_type_string(classifiers.ormclass_type_string, raise_on_false=True) - - return ormclass, classifiers - - -def _get_ormclass_from_str(type_string: str) -> Tuple[EntityTypes, Classifier]: - """Return the correct classifiers for the QueryBuilder from an ORM type string. - - :param type_string: type string for ORM class - :param query: an instance of the appropriate QueryBuilder backend. 
- :returns: the ORM class as well as a dictionary with additional classifier strings - """ - from aiida.orm.utils.node import is_valid_node_type_string - - classifiers: Classifier - type_string_lower = type_string.lower() - - if type_string_lower.startswith(GROUP_ENTITY_TYPE_PREFIX): - classifiers = Classifier('group.core') - ormclass = EntityTypes.GROUP - elif type_string_lower == EntityTypes.COMPUTER.value: - classifiers = Classifier('computer') - ormclass = EntityTypes.COMPUTER - elif type_string_lower == EntityTypes.USER.value: - classifiers = Classifier('user') - ormclass = EntityTypes.USER - elif type_string_lower == EntityTypes.LINK.value: - classifiers = Classifier('link') - ormclass = EntityTypes.LINK - else: - # At this point, we assume it is a node. The only valid type string then is a string - # that matches exactly the _plugin_type_string of a node class - is_valid_node_type_string(type_string, raise_on_false=True) - classifiers = Classifier(type_string) - ormclass = EntityTypes.NODE - - return ormclass, classifiers - - -def _get_node_type_filter(classifiers: Classifier, subclassing: bool) -> dict: - """ - Return filter dictionaries given a set of classifiers. - - :param classifiers: a dictionary with classifiers (note: does *not* support lists) - :param subclassing: if True, allow for subclasses of the ormclass - - :returns: dictionary in QueryBuilder filter language to pass into {"type": ... } - """ - from aiida.common.escaping import escape_for_sql_like - from aiida.orm.utils.node import get_query_type_from_type_string - value = classifiers.ormclass_type_string - - if not subclassing: - filters = {'==': value} - else: - # Note: the query_type_string always ends with a dot. This ensures that "like {str}%" matches *only* - # the query type string - filters = {'like': f'{escape_for_sql_like(get_query_type_from_type_string(value))}%'} - - return filters - - -def _get_process_type_filter(classifiers: Classifier, subclassing: bool) -> dict: - """ - Return filter dictionaries given a set of classifiers. - - :param classifiers: a dictionary with classifiers (note: does *not* support lists) - :param subclassing: if True, allow for subclasses of the process type - This is activated only, if an entry point can be found for the process type - (as well as for a selection of built-in process types) - - - :returns: dictionary in QueryBuilder filter language to pass into {"process_type": ... } - """ - from aiida.common.escaping import escape_for_sql_like - from aiida.common.warnings import AiidaEntryPointWarning - from aiida.engine.processes.process import get_query_string_from_process_type_string - - value = classifiers.process_type_string - assert value is not None - filters: Dict[str, Any] - - if not subclassing: - filters = {'==': value} - else: - if ':' in value: - # if value is an entry point, do usual subclassing - - # Note: the process_type_string stored in the database does *not* end in a dot. - # In order to avoid that querying for class 'Begin' will also find class 'BeginEnd', - # we need to search separately for equality and 'like'. - filters = { - 'or': [ - { - '==': value - }, - { - 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) - }, - ] - } - elif value.startswith('aiida.engine'): - # For core process types, a filter is not is needed since each process type has a corresponding - # ormclass type that already specifies everything. 
- # Note: This solution is fragile and will break as soon as there is not an exact one-to-one correspondence - # between process classes and node classes - - # Note: Improve this when issue https://github.com/aiidateam/aiida-core/issues/2475 is addressed - filters = {'like': '%'} - else: - warnings.warn( - "Process type '{value}' does not correspond to a registered entry. " - 'This risks queries to fail once the location of the process class changes. ' - "Add an entry point for '{value}' to remove this warning.".format(value=value), AiidaEntryPointWarning - ) - filters = { - 'or': [ - { - '==': value - }, - { - 'like': escape_for_sql_like(get_query_string_from_process_type_string(value)) - }, - ] - } - - return filters - - -class _QueryTagMap: - """Cache of tag mappings for a query.""" - - def __init__(self): - """Construct a new instance.""" - self._tag_to_type: Dict[str, Union[None, EntityTypes]] = {} - # A dictionary for classes passed to the tag given to them - # Everything is specified with unique tags, which are strings. - # But somebody might not care about giving tags, so to do - # everything with classes one needs a map, that also defines classes - # as tags, to allow the following example: - - # qb = QueryBuilder() - # qb.append(PwCalculation, tag='pwcalc') - # qb.append(StructureData, tag='structure', with_outgoing=PwCalculation) - - # The cls_to_tag_map in this case would be: - # {PwCalculation: {'pwcalc'}, StructureData: {'structure'}} - self._cls_to_tag_map: Dict[Any, Set[str]] = {} - - def __repr__(self) -> str: - return repr(list(self._tag_to_type)) - - def __contains__(self, tag: str) -> bool: - return tag in self._tag_to_type - - def __iter__(self): - return iter(self._tag_to_type) - - def add( - self, - tag: str, - etype: Union[None, EntityTypes] = None, - klasses: Union[None, EntityClsType, Sequence[EntityClsType]] = None - ) -> None: - """Add a tag.""" - self._tag_to_type[tag] = etype - # if a class was specified allow to get the tag given a class - if klasses: - tag_key = tuple(klasses) if isinstance(klasses, (list, set)) else klasses - self._cls_to_tag_map.setdefault(tag_key, set()).add(tag) - - def remove(self, tag: str) -> None: - """Remove a tag.""" - self._tag_to_type.pop(tag, None) - for tags in self._cls_to_tag_map.values(): - tags.discard(tag) - - def get(self, tag_or_cls: Union[str, EntityClsType]) -> str: - """Return the tag or, given a class(es), map to a tag. - - :raises ValueError: if the tag is not found, or the class(es) does not map to a single tag - """ - if isinstance(tag_or_cls, str): - if tag_or_cls in self: - return tag_or_cls - raise ValueError(f'Tag {tag_or_cls!r} is not among my known tags: {list(self)}') - if self._cls_to_tag_map.get(tag_or_cls, None): - if len(self._cls_to_tag_map[tag_or_cls]) != 1: - raise ValueError( - f'The object used as a tag ({tag_or_cls}) has multiple values associated with it: ' - f'{self._cls_to_tag_map[tag_or_cls]}' - ) - return list(self._cls_to_tag_map[tag_or_cls])[0] - raise ValueError(f'The given object ({tag_or_cls}) has no tags associated with it.') - - -def _get_group_type_filter(classifiers: Classifier, subclassing: bool) -> dict: - """Return filter dictionaries for `Group.type_string` given a set of classifiers. - - :param classifiers: a dictionary with classifiers (note: does *not* support lists) - :param subclassing: if True, allow for subclasses of the ormclass - - :returns: dictionary in QueryBuilder filter language to pass into {'type_string': ... 
} - """ - from aiida.common.escaping import escape_for_sql_like - - value = classifiers.ormclass_type_string[len(GROUP_ENTITY_TYPE_PREFIX):] - - if not subclassing: - filters = {'==': value} - else: - # This is a hardcoded solution to the problem that the base class `Group` should match all subclasses, however - # its entry point string is `core` and so will only match those subclasses whose entry point also starts with - # 'core', however, this is only the case for group subclasses shipped with `aiida-core`. Any plugins from - # external packages will never be matched. Making the entry point name of `Group` an empty string is also not - # possible so we perform the switch here in code. - if value == 'core': - value = '' - filters = {'like': f'{escape_for_sql_like(value)}%'} - - return filters diff --git a/aiida/orm/users.py b/aiida/orm/users.py deleted file mode 100644 index 3f2b49c288..0000000000 --- a/aiida/orm/users.py +++ /dev/null @@ -1,149 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for the ORM user class.""" -from typing import TYPE_CHECKING, Optional, Tuple, Type - -from aiida.common import exceptions -from aiida.common.lang import classproperty -from aiida.manage import get_manager - -from . import entities - -if TYPE_CHECKING: - from aiida.orm.implementation import BackendUser, StorageBackend - -__all__ = ('User',) - - -class UserCollection(entities.Collection['User']): - """The collection of users stored in a backend.""" - - @staticmethod - def _entity_base_cls() -> Type['User']: - return User - - def get_or_create(self, email: str, **kwargs) -> Tuple[bool, 'User']: - """Get the existing user with a given email address or create an unstored one - - :param kwargs: The properties of the user to get or create - :return: The corresponding user object - :raises: :class:`aiida.common.exceptions.MultipleObjectsError`, - :class:`aiida.common.exceptions.NotExistent` - """ - try: - return False, self.get(email=email) - except exceptions.NotExistent: - return True, User(backend=self.backend, email=email, **kwargs) - - def get_default(self) -> Optional['User']: - """Get the current default user""" - return self.backend.default_user - - -class User(entities.Entity['BackendUser', UserCollection]): - """AiiDA User""" - - _CLS_COLLECTION = UserCollection - - def __init__( - self, - email: str, - first_name: str = '', - last_name: str = '', - institution: str = '', - backend: Optional['StorageBackend'] = None - ): - """Create a new `User`.""" - # pylint: disable=too-many-arguments - backend = backend or get_manager().get_profile_storage() - email = self.normalize_email(email) - backend_entity = backend.users.create( - email=email, first_name=first_name, last_name=last_name, institution=institution - ) - super().__init__(backend_entity) - - def __str__(self) -> str: - return self.email - - @staticmethod - def normalize_email(email: str) -> str: - """Normalize the address by lowercasing the domain part of the email address (taken from Django).""" - email = email or '' - try: - email_name, domain_part = email.strip().rsplit('@', 1) - 
except ValueError: - pass - else: - email = '@'.join([email_name, domain_part.lower()]) - return email - - @property - def email(self) -> str: - return self._backend_entity.email - - @email.setter - def email(self, email: str) -> None: - self._backend_entity.email = email - - @property - def first_name(self) -> str: - return self._backend_entity.first_name - - @first_name.setter - def first_name(self, first_name: str) -> None: - self._backend_entity.first_name = first_name - - @property - def last_name(self) -> str: - return self._backend_entity.last_name - - @last_name.setter - def last_name(self, last_name: str) -> None: - self._backend_entity.last_name = last_name - - @property - def institution(self) -> str: - return self._backend_entity.institution - - @institution.setter - def institution(self, institution: str) -> None: - self._backend_entity.institution = institution - - def get_full_name(self) -> str: - """ - Return the user full name - - :return: the user full name - """ - if self.first_name and self.last_name: - full_name = f'{self.first_name} {self.last_name} ({self.email})' - elif self.first_name: - full_name = f'{self.first_name} ({self.email})' - elif self.last_name: - full_name = f'{self.last_name} ({self.email})' - else: - full_name = f'{self.email}' - - return full_name - - def get_short_name(self) -> str: - """ - Return the user short name (typically, this returns the email) - - :return: The short name - """ - return self.email - - @property - def uuid(self) -> None: - """ - For now users do not have UUIDs so always return None - """ - return None diff --git a/aiida/orm/utils/__init__.py b/aiida/orm/utils/__init__.py deleted file mode 100644 index 16e7b146c1..0000000000 --- a/aiida/orm/utils/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities related to the ORM.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .calcjob import * -from .links import * -from .loaders import * -from .managers import * -from .node import * - -__all__ = ( - 'AbstractNodeMeta', - 'AttributeManager', - 'CalcJobResultManager', - 'CalculationEntityLoader', - 'CodeEntityLoader', - 'ComputerEntityLoader', - 'GroupEntityLoader', - 'LinkManager', - 'LinkPair', - 'LinkTriple', - 'NodeEntityLoader', - 'NodeLinksManager', - 'OrmEntityLoader', - 'get_loader', - 'get_query_type_from_type_string', - 'get_type_string_from_class', - 'load_code', - 'load_computer', - 'load_entity', - 'load_group', - 'load_node', - 'load_node_class', - 'validate_link', -) - -# yapf: enable diff --git a/aiida/orm/utils/builders/__init__.py b/aiida/orm/utils/builders/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/orm/utils/builders/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/orm/utils/builders/code.py b/aiida/orm/utils/builders/code.py deleted file mode 100644 index e94db906ae..0000000000 --- a/aiida/orm/utils/builders/code.py +++ /dev/null @@ -1,222 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Manage code objects with lazy loading of the db env""" -import enum -import pathlib - -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.utils import ErrorAccumulator -from aiida.common.warnings import warn_deprecation -from aiida.orm import InstalledCode, PortableCode - -warn_deprecation('This module is deprecated. To create a new code instance, simply use the constructor.', version=3) - - -class CodeBuilder: - """Build a code with validation of attribute combinations""" - - def __init__(self, **kwargs): - """Construct a new instance.""" - self._err_acc = ErrorAccumulator(self.CodeValidationError) - self._code_spec = {} - - # code_type must go first - for key in ['code_type']: - self.__setattr__(key, kwargs.pop(key)) - - # then set the rest - for key, value in kwargs.items(): - self.__setattr__(key, value) - - def validate(self, raise_error=True): - self._err_acc.run(self.validate_code_type) - self._err_acc.run(self.validate_upload) - self._err_acc.run(self.validate_installed) - return self._err_acc.result(raise_error=self.CodeValidationError if raise_error else False) - - @with_dbenv() - def new(self): - """Build and return a new code instance (not stored)""" - self.validate() - - # Will be used at the end to check if all keys are known (those that are not None) - passed_keys = set(k for k in self._code_spec.keys() if self._code_spec[k] is not None) - used = set() - - if self._get_and_count('code_type', used) == self.CodeType.STORE_AND_UPLOAD: - code = PortableCode( - filepath_executable=self._get_and_count('code_rel_path', used), - filepath_files=pathlib.Path(self._get_and_count('code_folder', used)) - ) - else: - code = InstalledCode( - computer=self._get_and_count('computer', used), - filepath_executable=self._get_and_count('remote_abs_path', used) - ) - - code.label = self._get_and_count('label', used) - code.description = self._get_and_count('description', used) - code.default_calc_job_plugin = self._get_and_count('input_plugin', used) - code.use_double_quotes = self._get_and_count('use_double_quotes', used) - code.prepend_text = self._get_and_count('prepend_text', used) - code.append_text = self._get_and_count('append_text', used) - - # Complain if there are keys that are passed but not used - if passed_keys - used: - raise self.CodeValidationError( - f"Unknown parameters passed to the CodeBuilder: {', '.join(sorted(passed_keys - used))}" - ) - - return code - - @staticmethod - def from_code(code): - """Create CodeBuilder from existing code instance. 
- - See also :py:func:`~CodeBuilder.get_code_spec` - """ - spec = CodeBuilder.get_code_spec(code) - return CodeBuilder(**spec) - - @staticmethod - def get_code_spec(code): - """Get code attributes from existing code instance. - - These attributes can be used to create a new CodeBuilder:: - - spec = CodeBuilder.get_code_spec(old_code) - builder = CodeBuilder(**spec) - new_code = builder.new() - - """ - spec = {} - spec['label'] = code.label - spec['description'] = code.description - spec['input_plugin'] = code.default_calc_job_plugin - spec['use_double_quotes'] = code.use_double_quotes - spec['prepend_text'] = code.prepend_text - spec['append_text'] = code.append_text - - if isinstance(code, PortableCode): - spec['code_type'] = CodeBuilder.CodeType.STORE_AND_UPLOAD - spec['code_folder'] = code.get_code_folder() - spec['code_rel_path'] = code.get_code_rel_path() - else: - spec['code_type'] = CodeBuilder.CodeType.ON_COMPUTER - spec['computer'] = code.computer - spec['remote_abs_path'] = str(code.get_executable()) - - return spec - - def __getattr__(self, key): - """Access code attributes used to build the code""" - if not key.startswith('_'): - try: - return self._code_spec[key] - except KeyError: - raise KeyError(f"Attribute '{key}' not set") - return None - - def _get(self, key): - """ - Return a spec, or None if not defined - - :param key: name of a code spec - """ - return self._code_spec.get(key) - - def _get_and_count(self, key, used): - """ - Return a spec, or raise if not defined. - Moreover, add the key to the 'used' dict. - - :param key: name of a code spec - :param used: should be a set of keys that you want to track. - ``key`` will be added to this set if the value exists in the spec and can be retrieved. - """ - retval = self.__getattr__(key) - # I first get a retval, so if I get an exception, I don't add it to the 'used' set - used.add(key) - return retval - - def __setattr__(self, key, value): - if not key.startswith('_'): - self._set_code_attr(key, value) - super().__setattr__(key, value) - - def _set_code_attr(self, key, value): - """Set a code attribute, if it passes validation. - - Checks compatibility with other code attributes. 
- """ - if key == 'description' and value is None: - value = '' - - backup = self._code_spec.copy() - self._code_spec[key] = value - success, _ = self.validate(raise_error=False) - if not success: - self._code_spec = backup - self.validate() - - def validate_code_type(self): - """Make sure the code type is set correctly""" - if self._get('code_type') and self.code_type not in self.CodeType: - raise self.CodeValidationError( - f'invalid code type: must be one of {list(self.CodeType)}, not {self.code_type}' - ) - - def validate_upload(self): - """If the code is stored and uploaded, catch invalid on-computer attributes""" - messages = [] - if self.is_local(): - if self._get('computer'): - messages.append('invalid option for store-and-upload code: "computer"') - if self._get('remote_abs_path'): - messages.append('invalid option for store-and-upload code: "remote_abs_path"') - if messages: - raise self.CodeValidationError(f'{messages}') - - def validate_installed(self): - """If the code is on-computer, catch invalid store-and-upload attributes""" - messages = [] - if self._get('code_type') == self.CodeType.ON_COMPUTER: - if self._get('code_folder'): - messages.append('invalid options for on-computer code: "code_folder"') - if self._get('code_rel_path'): - messages.append('invalid options for on-computer code: "code_rel_path"') - if messages: - raise self.CodeValidationError(f'{messages}') - - class CodeValidationError(ValueError): - """ - A CodeBuilder instance may raise this - - * when asked to instanciate a code with missing or invalid code attributes - * when asked for a code attibute that has not been set yet - """ - - def __init__(self, msg): - super().__init__() - self.msg = msg - - def __str__(self): - return self.msg - - def __repr__(self): - return f'' - - def is_local(self): - """Analogous to Code.is_local()""" - return self.__getattr__('code_type') == self.CodeType.STORE_AND_UPLOAD - - class CodeType(enum.Enum): - STORE_AND_UPLOAD = 'store in the db and upload' - ON_COMPUTER = 'on computer' diff --git a/aiida/orm/utils/builders/computer.py b/aiida/orm/utils/builders/computer.py deleted file mode 100644 index 3f9e9da6de..0000000000 --- a/aiida/orm/utils/builders/computer.py +++ /dev/null @@ -1,192 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Manage computer objects with lazy loading of the db env""" -from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import ValidationError -from aiida.common.utils import ErrorAccumulator - - -class ComputerBuilder: # pylint: disable=too-many-instance-attributes - """Build a computer with validation of attribute combinations""" - - @staticmethod - def from_computer(computer): - """Create ComputerBuilder from existing computer instance. - - See also :py:func:`~ComputerBuilder.get_computer_spec`""" - spec = ComputerBuilder.get_computer_spec(computer) - return ComputerBuilder(**spec) - - @staticmethod - def get_computer_spec(computer): - """Get computer attributes from existing computer instance. 
- - These attributes can be used to create a new ComputerBuilder:: - - spec = ComputerBuilder.get_computer_spec(old_computer) - builder = ComputerBuilder(**spec) - new_computer = builder.new()""" - spec = {} - spec['label'] = computer.label - spec['description'] = computer.description - spec['hostname'] = computer.hostname - spec['scheduler'] = computer.scheduler_type - spec['transport'] = computer.transport_type - spec['use_double_quotes'] = computer.get_use_double_quotes() - spec['prepend_text'] = computer.get_prepend_text() - spec['append_text'] = computer.get_append_text() - spec['work_dir'] = computer.get_workdir() - spec['shebang'] = computer.get_shebang() - spec['mpirun_command'] = ' '.join(computer.get_mpirun_command()) - spec['mpiprocs_per_machine'] = computer.get_default_mpiprocs_per_machine() - spec['default_memory_per_machine'] = computer.get_default_memory_per_machine() - - return spec - - def __init__(self, **kwargs): - """Construct a new instance.""" - self._computer_spec = {} - self._err_acc = ErrorAccumulator(self.ComputerValidationError) - - for key, value in kwargs.items(): - self.__setattr__(key, value) - - def validate(self, raise_error=True): - """Validate the computer options.""" - return self._err_acc.result(raise_error=self.ComputerValidationError if raise_error else False) - - @with_dbenv() - def new(self): - """Build and return a new computer instance (not stored)""" - from aiida.orm import Computer - - self.validate() - - # Will be used at the end to check if all keys are known - passed_keys = set(self._computer_spec.keys()) - used = set() - - computer = Computer(label=self._get_and_count('label', used), hostname=self._get_and_count('hostname', used)) - - computer.description = self._get_and_count('description', used) - computer.scheduler_type = self._get_and_count('scheduler', used) - computer.transport_type = self._get_and_count('transport', used) - computer.set_use_double_quotes(self._get_and_count('use_double_quotes', used)) - computer.set_prepend_text(self._get_and_count('prepend_text', used)) - computer.set_append_text(self._get_and_count('append_text', used)) - computer.set_workdir(self._get_and_count('work_dir', used)) - computer.set_shebang(self._get_and_count('shebang', used)) - - mpiprocs_per_machine = self._get_and_count('mpiprocs_per_machine', used) - # In the command line, 0 means unspecified - if mpiprocs_per_machine == 0: - mpiprocs_per_machine = None - if mpiprocs_per_machine is not None: - try: - mpiprocs_per_machine = int(mpiprocs_per_machine) - except ValueError: - raise self.ComputerValidationError( - 'Invalid value provided for mpiprocs_per_machine, ' - 'must be a valid integer' - ) - if mpiprocs_per_machine <= 0: - raise self.ComputerValidationError( - 'Invalid value provided for mpiprocs_per_machine, ' - 'must be positive' - ) - computer.set_default_mpiprocs_per_machine(mpiprocs_per_machine) - - def_memory_per_machine = self._get_and_count('default_memory_per_machine', used) - if def_memory_per_machine is not None: - try: - def_memory_per_machine = int(def_memory_per_machine) - except ValueError: - raise self.ComputerValidationError( - 'Invalid value provided for memory_per_machine, must be a valid integer' - ) - try: - computer.set_default_memory_per_machine(def_memory_per_machine) - except ValidationError as exception: - raise self.ComputerValidationError(f'Invalid value for `default_memory_per_machine`: {exception}') - - mpirun_command_internal = self._get_and_count('mpirun_command', used).strip().split(' ') - if 
mpirun_command_internal == ['']: - mpirun_command_internal = [] - computer._mpirun_command_validator(mpirun_command_internal) # pylint: disable=protected-access - computer.set_mpirun_command(mpirun_command_internal) - - # Complain if there are keys that are passed but not used - if passed_keys - used: - raise self.ComputerValidationError( - f"Unknown parameters passed to the ComputerBuilder: {', '.join(sorted(passed_keys - used))}" - ) - - return computer - - def __getattr__(self, key): - """Access computer attributes used to build the computer""" - if not key.startswith('_'): - try: - return self._computer_spec[key] - except KeyError: - raise self.ComputerValidationError(f'{key} not set') - return None - - def _get(self, key): - """ - Return a spec, or None if not defined - - :param key: name of a computer spec""" - return self._computer_spec.get(key) - - def _get_and_count(self, key, used): - """ - Return a spec, or raise if not defined. - Moreover, add the key to the 'used' set. - - :param key: name of a computer spec - :param used: should be a set of keys that you want to track. - ``key`` will be added to this set if the value exists in the spec and can be retrieved. - """ - retval = self.__getattr__(key) # pylint: disable=unnecessary-dunder-call - # I first get a retval, so if I get an exception, I don't add it to the 'used' set - used.add(key) - return retval - - def __setattr__(self, key, value): - if not key.startswith('_'): - self._set_computer_attr(key, value) - super().__setattr__(key, value) - - def _set_computer_attr(self, key, value): - """Set a computer attribute if it passes validation.""" - backup = self._computer_spec.copy() - self._computer_spec[key] = value - success, _ = self.validate(raise_error=False) - if not success: - self._computer_spec = backup - self.validate() - - class ComputerValidationError(Exception): - """ - A ComputerBuilder instance may raise this - - * when asked to instantiate a computer with missing or invalid computer attributes - * when asked for a computer attribute that has not been set yet.""" - - def __init__(self, msg): - super().__init__() - self.msg = msg - - def __str__(self): - return self.msg - - def __repr__(self): - return f'<ComputerValidationError: {self}>' diff --git a/aiida/orm/utils/calcjob.py b/aiida/orm/utils/calcjob.py deleted file mode 100644 index 5fc58150a6..0000000000 --- a/aiida/orm/utils/calcjob.py +++ /dev/null @@ -1,117 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities to operate on `CalcJobNode` instances.""" - -from aiida.common import exceptions - -__all__ = ('CalcJobResultManager',) - - -class CalcJobResultManager: - """ - Utility class to easily access the contents of the 'default output' node of a `CalcJobNode`. - - A `CalcJob` process can mark one of its outputs as the 'default output'. The default output node will always be - returned by the `CalcJob` and will always be a `Dict` node. - - If a `CalcJob` defines such a default output node, this utility class will simplify retrieving the result of said - node through the `CalcJobNode` instance produced by the execution of the `CalcJob`.
- - The default results are only defined if the `CalcJobNode` has a `process_type` that can be successfully used - to load the corresponding `CalcJob` process class *and* if its process spec defines a `default_output_node`. - If both these conditions are met, the results are defined as the dictionary contained within the default - output node. - """ - - def __init__(self, node): - """Construct an instance of the `CalcJobResultManager`. - - :param calc: the `CalcJobNode` instance. - """ - self._node = node - self._result_node = None - self._results = None - - @property - def node(self): - """Return the `CalcJobNode` associated with this result manager instance.""" - return self._node - - def _load_results(self): - """Try to load the results for the `CalcJobNode` of this result manager. - - :raises ValueError: if no default output node could be loaded - """ - try: - process_class = self._node.process_class - except ValueError as exception: - raise ValueError(f'cannot load results because process class cannot be loaded: {exception}') - - process_spec = process_class.spec() - default_output_node_label = process_spec.default_output_node - - if default_output_node_label is None: - raise ValueError(f'cannot load results as {process_class} does not specify a default output node') - - try: - default_output_node = self.node.base.links.get_outgoing().get_node_by_label(default_output_node_label) - except exceptions.NotExistent as exception: - raise ValueError(f'cannot load results as the default node could not be retrieved: {exception}') - - self._result_node = default_output_node - self._results = default_output_node.get_dict() - - def get_results(self): - """Return the results dictionary of the default results node of the calculation node. - - This property will lazily load the dictionary. - - :return: the dictionary of the default result node - """ - if self._results is None: - self._load_results() - return self._results - - def __dir__(self): - """Add the keys of the results dictionary such that they can be autocompleted.""" - return sorted(list(self.get_results().keys())) - - def __iter__(self): - """Return an iterator over the keys of the result dictionary.""" - for key in self.get_results().keys(): - yield key - - def __getattr__(self, name): - """Return an attribute from the results dictionary. - - :param name: name of the result return - :return: value of the attribute - :raises AttributeError: if the results node cannot be retrieved or it does not contain the `name` attribute - """ - try: - return self.get_results()[name] - except ValueError as exception: - raise AttributeError from exception - except KeyError: - raise AttributeError(f"Default result node<{self._result_node.pk}> does not contain key '{name}'") - - def __getitem__(self, name): - """Return an attribute from the results dictionary. 
- - :param name: name of the result return - :return: value of the attribute - :raises KeyError: if the results node cannot be retrieved or it does not contain the `name` attribute - """ - try: - return self.get_results()[name] - except ValueError as exception: - raise KeyError from exception - except KeyError: - raise KeyError(f"Default result node<{self._result_node.pk}> does not contain key '{name}'") diff --git a/aiida/orm/utils/links.py b/aiida/orm/utils/links.py deleted file mode 100644 index 42566d948e..0000000000 --- a/aiida/orm/utils/links.py +++ /dev/null @@ -1,374 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities for dealing with links between nodes.""" -from collections import OrderedDict -from collections.abc import Mapping -from typing import TYPE_CHECKING, Generator, Iterator, List, NamedTuple, Optional - -from aiida.common import exceptions -from aiida.common.lang import type_check - -if TYPE_CHECKING: - from aiida.common.links import LinkType - from aiida.orm import Node - from aiida.orm.implementation.storage_backend import StorageBackend - -__all__ = ('LinkPair', 'LinkTriple', 'LinkManager', 'validate_link') - - -class LinkPair(NamedTuple): - link_type: 'LinkType' - link_label: str - - -class LinkTriple(NamedTuple): - node: 'Node' - link_type: 'LinkType' - link_label: str - - -class LinkQuadruple(NamedTuple): - source_id: int - target_id: int - link_type: 'LinkType' - link_label: str - - -def link_triple_exists( - source: 'Node', - target: 'Node', - link_type: 'LinkType', - link_label: str, - backend: Optional['StorageBackend'] = None -) -> bool: - """Return whether a link with the given type and label exists between the given source and target node. - - :param source: node from which the link is outgoing - :param target: node to which the link is incoming - :param link_type: the link type - :param link_label: the link label - :return: boolean, True if the link triple exists, False otherwise - """ - from aiida.orm import Node, QueryBuilder - - target_links_cache = target.base.links.incoming_cache - - # First check if the triple exist in the cache, in the case of an unstored target node - if target_links_cache and LinkTriple(source, link_type, link_label) in target_links_cache: - return True - - # If either node is unstored (i.e. does not have a pk), the link cannot exist in the database, so no need to check - if source.pk is None or target.pk is None: - return False - - # Here we have two stored nodes, so we need to check if the same link already exists in the database. 
- # Finding just a single match is sufficient so we can use the `limit` clause for efficiency - builder = QueryBuilder(backend=backend) - builder.append(Node, filters={'id': source.pk}, project=['id']) - builder.append(Node, filters={'id': target.pk}, edge_filters={'type': link_type.value, 'label': link_label}) - builder.limit(1) - - return builder.count() != 0 - - -def validate_link( - source: 'Node', - target: 'Node', - link_type: 'LinkType', - link_label: str, - backend: Optional['StorageBackend'] = None -) -> None: - """ - Validate adding a link of the given type and label from a given node to ourself. - - This function will first validate the class types of the inputs and will subsequently validate whether a link of - the specified type is allowed at all between the nodes types of the source and target. - - Subsequently, the validity of the "indegree" and "outdegree" of the proposed link is validated, which means - validating that the uniqueness constraints of the incoming links into the target node and the outgoing links from - the source node are not violated. In AiiDA's provenance graph each link type has one of the following three types - of "degree" character:: - - * unique - * unique pair - * unique triple - - Each degree character has a different unique constraint on its links, here defined for the indegree:: - - * unique: any target node, it can only have a single incoming link of this type, regardless of the link label. - * unique pair: a node can have an infinite amount of incoming links of this type, as long as the labels within - that sub set, are unique. In short, it is the link pair, i.e. the tuple of the link type and label, that has - a uniquess constraint for the incoming links to a given node. - * unique triple: a node can have an infinite amount of incoming links of this type, as long as the triple tuple - of source node, link type and link label is unique. In other words, it is the link triple that has a - uniqueness constraint for the incoming links. - - The same holds for outdegree, but then it concerns outgoing links from the source node to the target node. - - For illustration purposes, consider the following example provenance graphs that are considered legal, where - `WN`, `DN` and `CN` represent a `WorkflowNode`, a `DataNode` and a `CalculationNode`, respectively:: - - 1 2 3 - ______ ______ ______ ______ ______ - | | | | | | | | | | - | WN | | DN | | DN | | WN | | WN | - |______| |______| |______| |______| |______| - | / | | | / - a | / a a | | b a | / a - _|___/ |___|_ _|___/ - | | | | | | - | CN | | CN | | DN | - |______| |______| |______| - - In example 1, the link uniqueness constraint is not violated because despite the labels having the same label `a`, - their link types, `CALL_CALC` and `INPUT_CALC`, respectively, are different and their `unique_pair` indegree is - not violated. - - Similarly, in the second example, the constraint is not violated, because despite both links having the same link - type `INPUT_CALC`, the have different labels, so the `unique_pair` indegree of the `INPUT_CALC` is not violated. - - Finally, in the third example, we see two `WorkflowNodes` both returning the same `DataNode` and with the same - label. Despite the two incoming links here having both the same type as well as the same label, the uniqueness - constraint is not violated, because the indegree for `RETURN` links is `unique_triple` which means that the triple - of source node and link type and label should be unique. 
- - :param source: the node from which the link is coming - :param target: the node to which the link is going - :param link_type: the type of link - :param link_label: link label - :raise TypeError: if `source` or `target` is not a Node instance, or `link_type` is not a `LinkType` enum - :raise ValueError: if the proposed link is invalid - """ - # yapf: disable - from aiida.common.links import LinkType, validate_link_label - from aiida.orm import CalculationNode, Data, Node, WorkflowNode - - type_check(link_type, LinkType, f'link_type should be a LinkType enum but got: {type(link_type)}') - type_check(source, Node, f'source should be a `Node` but got: {type(source)}') - type_check(target, Node, f'target should be a `Node` but got: {type(target)}') - - if source.uuid is None or target.uuid is None: - raise ValueError('source or target node does not have a UUID') - - if source.uuid == target.uuid: - raise ValueError('cannot add a link to oneself') - - try: - validate_link_label(link_label) - except ValueError as exception: - raise ValueError(f'invalid link label `{link_label}`: {exception}') - - # For each link type, define a tuple that defines the valid types for the source and target node, as well as - # the outdegree and indegree character. If the degree is `unique` that means that there can only be a single - # link of this type regardless of the label. If instead it is `unique_label`, an infinite amount of links of that - # type can be defined, as long as the link label is unique for the sub set of links of that type. Finally, for - # `unique_triple` the triple of node, link type and link label has to be unique. - link_mapping = { - LinkType.CALL_CALC: (WorkflowNode, CalculationNode, 'unique_triple', 'unique'), - LinkType.CALL_WORK: (WorkflowNode, WorkflowNode, 'unique_triple', 'unique'), - LinkType.CREATE: (CalculationNode, Data, 'unique_pair', 'unique'), - LinkType.INPUT_CALC: (Data, CalculationNode, 'unique_triple', 'unique_pair'), - LinkType.INPUT_WORK: (Data, WorkflowNode, 'unique_triple', 'unique_pair'), - LinkType.RETURN: (WorkflowNode, Data, 'unique_pair', 'unique_triple'), - } - - type_source, type_target, outdegree, indegree = link_mapping[link_type] - - if not isinstance(source, type_source) or not isinstance(target, type_target): - raise ValueError(f'cannot add a {link_type} link from {type(source)} to {type(target)}') - - if outdegree == 'unique_triple' or indegree == 'unique_triple': - # For a `unique_triple` degree we just have to check if an identical triple already exist, either in the cache - # or stored, in which case, the new proposed link is a duplicate and thus illegal - duplicate_link_triple = link_triple_exists(source, target, link_type, link_label, backend) - - # If the outdegree is `unique` there cannot already be any other outgoing link of that type - if outdegree == 'unique' and source.base.links.get_outgoing(link_type=link_type, only_uuid=True).all(): - raise ValueError(f'node<{source.uuid}> already has an outgoing {link_type} link') - - # If the outdegree is `unique_pair`, then the link labels for outgoing links of this type should be unique - elif outdegree == 'unique_pair' and source.base.links.get_outgoing( - link_type=link_type, only_uuid=True, link_label_filter=link_label).all(): - raise ValueError(f'node<{source.uuid}> already has an outgoing {link_type} link with label "{link_label}"') - - # If the outdegree is `unique_triple`, then the link triples of link type, link label and target should be unique - elif outdegree == 'unique_triple' and 
duplicate_link_triple: - raise ValueError('node<{}> already has an outgoing {} link with label "{}" from node<{}>'.format( - source.uuid, link_type, link_label, target.uuid)) - - # If the indegree is `unique` there cannot already be any other incoming links of that type - if indegree == 'unique' and target.base.links.get_incoming(link_type=link_type, only_uuid=True).all(): - raise ValueError(f'node<{target.uuid}> already has an incoming {link_type} link') - - # If the indegree is `unique_pair`, then the link labels for incoming links of this type should be unique - elif indegree == 'unique_pair' and target.base.links.get_incoming( - link_type=link_type, link_label_filter=link_label, only_uuid=True).all(): - raise ValueError(f'node<{target.uuid}> already has an incoming {link_type} link with label "{link_label}"') - - # If the indegree is `unique_triple`, then the link triples of link type, link label and source should be unique - elif indegree == 'unique_triple' and duplicate_link_triple: - raise ValueError('node<{}> already has an incoming {} link with label "{}" from node<{}>'.format( - target.uuid, link_type, link_label, source.uuid)) - - -class LinkManager: - """ - Class to convert a list of LinkTriple tuples into an iterator. - - It defines convenience methods to retrieve certain subsets of LinkTriple while checking for consistency. - For example:: - - LinkManager.one(): returns the only entry in the list or it raises an exception - LinkManager.first(): returns the first entry from the list - LinkManager.all(): returns all entries from list - - The methods `all_nodes` and `all_link_labels` are syntactic sugar wrappers around `all` to get a list of only the - incoming nodes or link labels, respectively. - """ - - def __init__(self, link_triples: List[LinkTriple]): - """Initialise the collection.""" - self.link_triples = link_triples - - def __iter__(self) -> Iterator[LinkTriple]: - """Return an iterator of LinkTriple instances. - - :return: iterator of LinkTriple instances - """ - return iter(self.link_triples) - - def __next__(self) -> Generator[LinkTriple, None, None]: - """Return the next element in the iterator. - - :return: LinkTriple - """ - for link_triple in self.link_triples: - yield link_triple - - def __bool__(self): - return bool(len(self.link_triples)) - - def next(self) -> Generator[LinkTriple, None, None]: - """Return the next element in the iterator. - - :return: LinkTriple - """ - return self.__next__() - - def one(self) -> LinkTriple: - """Return a single entry from the iterator. - - If the iterator contains no or more than one entry, an exception will be raised - :return: LinkTriple instance - :raises ValueError: if the iterator contains anything but one entry - """ - if self.link_triples: - if len(self.link_triples) > 1: - raise ValueError('more than one entry found') - return self.link_triples[0] - - raise ValueError('no entries found') - - def first(self) -> Optional[LinkTriple]: - """Return the first entry from the iterator. - - :return: LinkTriple instance or None if no entries were matched - """ - if self.link_triples: - return self.link_triples[0] - - return None - - def all(self) -> List[LinkTriple]: - """Return all entries from the list. - - :return: list of LinkTriple instances - """ - return self.link_triples - - def all_nodes(self) -> List['Node']: - """Return a list of all nodes. - - :return: list of nodes - """ - return [entry.node for entry in self.all()] - - def all_link_pairs(self) -> List[LinkPair]: - """Return a list of all link pairs. 
- - :return: list of LinkPair instances - """ - return [LinkPair(entry.link_type, entry.link_label) for entry in self.all()] - - def all_link_labels(self) -> List[str]: - """Return a list of all link labels. - - :return: list of link labels - """ - return [entry.link_label for entry in self.all()] - - def get_node_by_label(self, label: str) -> 'Node': - """Return the node from list for given label. - - :return: node that corresponds to the given label - :raises aiida.common.NotExistent: if the label is not present among the link_triples - """ - matching_entry = None - for entry in self.link_triples: - if entry.link_label == label: - if matching_entry is None: - matching_entry = entry.node - else: - raise exceptions.MultipleObjectsError( - f'more than one neighbor with the label {label} found' - ) - - if matching_entry is None: - raise exceptions.NotExistent(f'no neighbor with the label {label} found') - - return matching_entry - - def nested(self, sort=True): - """Construct (nested) dictionary of matched nodes that mirrors the original nesting of link namespaces. - - Process input and output namespaces can be nested, however the link labels that represent them in the database - have a flat hierarchy, and so the link labels are flattened representations of the nested namespaces. - This function reconstructs the original node nesting based on the flattened links. - - :return: dictionary of nested namespaces - :raises KeyError: if there are duplicate link labels in a namespace - """ - from aiida.engine.processes.ports import PORT_NAMESPACE_SEPARATOR - - nested: dict = {} - - for entry in self.link_triples: - - current_namespace = nested - breadcrumbs = entry.link_label.split(PORT_NAMESPACE_SEPARATOR) - - # The last element is the "leaf" port name the preceding elements are nested port namespaces - port_name = breadcrumbs[-1] - port_namespaces = breadcrumbs[:-1] - - # Get the nested namespace - for subspace in port_namespaces: - current_namespace = current_namespace.setdefault(subspace, {}) - - # Insert the node at the given port name - if port_name in current_namespace: - raise KeyError(f"duplicate label '{port_name}' in namespace '{'.'.join(port_namespaces)}'") - - current_namespace[port_name] = entry.node - - if sort: - return OrderedDict(sorted(nested.items(), key=lambda x: (not isinstance(x[1], Mapping), x))) - - return nested diff --git a/aiida/orm/utils/log.py b/aiida/orm/utils/log.py deleted file mode 100644 index 38e379e316..0000000000 --- a/aiida/orm/utils/log.py +++ /dev/null @@ -1,70 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for logging methods/classes that need the ORM.""" -import logging - - -class DBLogHandler(logging.Handler): - """A custom db log handler for writing logs tot he database""" - - def emit(self, record): - if record.exc_info: - # We do this because if there is exc_info this will put an appropriate string in exc_text. 
- # See: - # https://github.com/python/cpython/blob/1c2cb516e49ceb56f76e90645e67e8df4e5df01a/Lib/logging/handlers.py#L590 - self.format(record) - - from aiida import orm - - try: - try: - backend = record.__dict__.pop('backend') - orm.Log.collection(backend).create_entry_from_record(record) - except KeyError: - # The backend should be set. We silently absorb this error - pass - - except Exception: # pylint: disable=broad-except - # To avoid loops with the error handler, I just print. - # Hopefully, though, this should not happen! - import traceback - traceback.print_exc() - raise - - -def get_dblogger_extra(node): - """Return the additional information necessary to attach any log records to the given node instance. - - :param node: a Node instance - """ - from aiida.orm import Node - - # If the object is not a Node or it is not stored, then any associated log records should bot be stored. This is - # accomplished by returning an empty dictionary because the `dbnode_id` is required to successfully store it. - if not isinstance(node, Node) or not node.is_stored: - return {} - - return {'dbnode_id': node.pk, 'backend': node.backend} - - -def create_logger_adapter(logger, node): - """Create a logger adapter for the given Node instance. - - :param logger: the logger to adapt - :param node: the node instance to create the adapter for - :return: the logger adapter - :rtype: :class:`logging.LoggerAdapter` - """ - from aiida.orm import Node - - if not isinstance(node, Node): - raise TypeError('node should be an instance of `Node`') - - return logging.LoggerAdapter(logger=logger, extra=get_dblogger_extra(node)) diff --git a/aiida/orm/utils/node.py b/aiida/orm/utils/node.py deleted file mode 100644 index 4734a27b9d..0000000000 --- a/aiida/orm/utils/node.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities to operate on `Node` classes.""" -from abc import ABCMeta -import logging -import warnings - -from aiida.common import exceptions -from aiida.common.utils import strip_prefix - -__all__ = ( - 'load_node_class', - 'get_type_string_from_class', - 'get_query_type_from_type_string', - 'AbstractNodeMeta', -) - - -def load_node_class(type_string): - """ - Return the `Node` sub class that corresponds to the given type string. - - :param type_string: the `type` string of the node - :return: a sub class of `Node` - """ - from aiida.orm import Data, Node - from aiida.plugins.entry_point import load_entry_point - - if type_string == '': - return Node - - if type_string == 'data.Data.': - return Data - - if not type_string.endswith('.'): - raise exceptions.DbContentError(f'The type string `{type_string}` is invalid') - - try: - base_path = type_string.rsplit('.', 2)[0] - except ValueError: - raise exceptions.EntryPointError from ValueError - - # This exception needs to be there to make migrations work that rely on the old type string starting with `node.` - # Since now the type strings no longer have that prefix, we simply strip it and continue with the normal logic. 
- if base_path.startswith('node.'): - base_path = strip_prefix(base_path, 'node.') - - # Data nodes are the only ones with sub classes that are still external, so if the plugin is not available - # we fall back on the base node type - if base_path.startswith('data.'): - entry_point_name = strip_prefix(base_path, 'data.') - try: - return load_entry_point('aiida.data', entry_point_name) - except exceptions.MissingEntryPointError: - return Data - - if base_path.startswith('process'): - entry_point_name = strip_prefix(base_path, 'nodes.') - return load_entry_point('aiida.node', entry_point_name) - - # At this point we really have an anomalous type string. At some point, storing nodes with unresolvable type strings - # was allowed, for example by creating a sub class in a shell and then storing an instance. Attempting to load the - # node then would fail miserably. This is now no longer allowed, but we need a fallback for existing cases, which - # should be rare. We fallback on `Data` and not `Node` because bare node instances are also not storable and so the - # logic of the ORM is not well defined for a loaded instance of the base `Node` class. - warnings.warn(f'unknown type string `{type_string}`, falling back onto `Data` class') # pylint: disable=no-member - - return Data - - -def get_type_string_from_class(class_module, class_name): - """ - Given the module and name of a class, determine the orm_class_type string, which codifies the - orm class that is to be used. The returned string will always have a terminating period, which - is required to query for the string in the database - - :param class_module: module of the class - :param class_name: name of the class - """ - from aiida.plugins.entry_point import ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP, get_entry_point_from_class - - group, entry_point = get_entry_point_from_class(class_module, class_name) - - # If we can reverse engineer an entry point group and name, we're dealing with an external class - if group and entry_point: - module_base_path = ENTRY_POINT_GROUP_TO_MODULE_PATH_MAP[group] - type_string = f'{module_base_path}.{entry_point.name}.{class_name}.' - - # Otherwise we are dealing with an internal class - else: - type_string = f'{class_module}.{class_name}.' - - prefixes = ('aiida.orm.nodes.',) - - # Sequentially and **in order** strip the prefixes if present - for prefix in prefixes: - type_string = strip_prefix(type_string, prefix) - - # This needs to be here as long as `aiida.orm.nodes.data` does not live in `aiida.orm.nodes.data` because all the - # `Data` instances will have a type string that starts with `data.` instead of `nodes.`, so in order to match any - # `Node` we have to look for any type string essentially. - if type_string == 'node.Node.': - type_string = '' - - return type_string - - -def is_valid_node_type_string(type_string, raise_on_false=False): - """ - Checks whether type string of a Node is valid. - - :param type_string: the plugin_type_string attribute of a Node - :return: True if type string is valid, else false - """ - # Currently the type string for the top-level node is empty. - # Change this when a consistent type string hierarchy is introduced. - if type_string == '': - return True - - # Note: this allows for the user-defined type strings like 'group' in the QueryBuilder - # as well as the usual type strings like 'data.parameter.ParameterData.' 
- if type_string.count('.') == 1 or not type_string.endswith('.'): - if raise_on_false: - raise exceptions.DbContentError(f'The type string {type_string} is invalid') - return False - - return True - - -def get_query_type_from_type_string(type_string): - """ - Take the type string of a Node and create the queryable type string - - :param type_string: the plugin_type_string attribute of a Node - :return: the type string that can be used to query for - """ - is_valid_node_type_string(type_string, raise_on_false=True) - - # Currently the type string for the top-level node is empty. - # Change this when a consistent type string hierarchy is introduced. - if type_string == '': - return '' - - type_path = type_string.rsplit('.', 2)[0] - type_string = f'{type_path}.' - - return type_string - - -class AbstractNodeMeta(ABCMeta): - """Some python black magic to set correctly the logger also in subclasses.""" - - def __new__(mcs, name, bases, namespace, **kwargs): - newcls = ABCMeta.__new__(mcs, name, bases, namespace, **kwargs) # pylint: disable=too-many-function-args - newcls._logger = logging.getLogger(f"{namespace['__module__']}.{name}") - - # Set the plugin type string and query type string based on the plugin type string - newcls._plugin_type_string = get_type_string_from_class(namespace['__module__'], name) # pylint: disable=protected-access - newcls._query_type_string = get_query_type_from_type_string(newcls._plugin_type_string) # pylint: disable=protected-access - - return newcls diff --git a/aiida/orm/utils/remote.py b/aiida/orm/utils/remote.py deleted file mode 100644 index c8c9ef8138..0000000000 --- a/aiida/orm/utils/remote.py +++ /dev/null @@ -1,127 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities for operations on files on remote computers.""" -import os - -from aiida.orm.nodes.data.remote.base import RemoteData - - -def clean_remote(transport, path): - """ - Recursively remove a remote folder, with the given absolute path, and all its contents. The path should be - made accessible through the transport channel, which should already be open - - :param transport: an open Transport channel - :param path: an absolute path on the remote made available through the transport - """ - if not isinstance(path, str): - raise ValueError('the path has to be a string type') - - if not os.path.isabs(path): - raise ValueError('the path should be absolute') - - if not transport.is_open: - raise ValueError('the transport should already be open') - - basedir, relative_path = os.path.split(path) - - try: - transport.chdir(basedir) - transport.rmtree(relative_path) - except IOError: - pass - - -def get_calcjob_remote_paths( # pylint: disable=too-many-locals - pks=None, - past_days=None, - older_than=None, - computers=None, - user=None, - backend=None, - exit_status=None, - only_not_cleaned=False, -): - """ - Return a mapping of computer uuids to a list of remote paths, for a given set of calcjobs. The set of - calcjobs will be determined by a query with filters based on the pks, past_days, older_than, - computers and user arguments. 
- - :param pks: only include calcjobs with a pk in this list - :param past_days: only include calcjobs created since past_days - :param older_than: only include calcjobs older than - :param computers: only include calcjobs that were ran on these computers - :param user: only include calcjobs of this user - :param exit_status: only select calcjob with this exit_status - :param only_not_cleaned: only include calcjobs whose workdir have not been cleaned - :return: mapping of computer uuid and list of remote folder - """ - from datetime import timedelta - - from aiida import orm - from aiida.common import timezone - from aiida.orm import CalcJobNode - - filters_calc = {} - filters_computer = {} - filters_remote = {} - - if user is None: - user = orm.User.collection.get_default() - - if computers is not None: - filters_computer['id'] = {'in': [computer.pk for computer in computers]} - - if past_days is not None: - filters_calc['mtime'] = {'>': timezone.now() - timedelta(days=past_days)} - - if older_than is not None: - older_filter = {'<': timezone.now() - timedelta(days=older_than)} - # Check if we need to apply the AND condition - if 'mtime' not in filters_calc: - filters_calc['mtime'] = older_filter - else: - past_filter = filters_calc['mtime'] - filters_calc['mtime'] = {'and': [past_filter, older_filter]} - - if exit_status is not None: - filters_calc['attributes.exit_status'] = exit_status - - if pks: - filters_calc['id'] = {'in': pks} - - if only_not_cleaned is True: - filters_remote['or'] = [{ - f'extras.{RemoteData.KEY_EXTRA_CLEANED}': { - '!==': True - } - }, { - 'extras': { - '!has_key': RemoteData.KEY_EXTRA_CLEANED - } - }] - - query = orm.QueryBuilder(backend=backend) - query.append(CalcJobNode, tag='calc', filters=filters_calc) - query.append( - RemoteData, tag='remote', project=['*'], edge_filters={'label': 'remote_folder'}, filters=filters_remote - ) - query.append(orm.Computer, with_node='calc', tag='computer', project=['uuid'], filters=filters_computer) - query.append(orm.User, with_node='calc', filters={'email': user.email}) - - if query.count() == 0: - return None - - path_mapping = {} - - for remote_data, computer_uuid in query.all(): - path_mapping.setdefault(computer_uuid, []).append(remote_data) - - return path_mapping diff --git a/aiida/parsers/__init__.py b/aiida/parsers/__init__.py deleted file mode 100644 index b3789ed596..0000000000 --- a/aiida/parsers/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module for classes and utilities to write parsers for calculation jobs.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .parser import * - -__all__ = ( - 'Parser', -) - -# yapf: enable diff --git a/aiida/parsers/parser.py b/aiida/parsers/parser.py deleted file mode 100644 index c4d4976063..0000000000 --- a/aiida/parsers/parser.py +++ /dev/null @@ -1,181 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. 
# -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -This module implements a generic output plugin, that is general enough -to allow the reading of the outputs of a calculation. -""" -from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple - -from aiida.common import exceptions, extendeddicts -from aiida.engine import ExitCode, ExitCodesNamespace, calcfunction -from aiida.engine.processes.ports import CalcJobOutputPort - -if TYPE_CHECKING: - from aiida import orm - from aiida.orm import CalcJobNode, Data, FolderData - -__all__ = ('Parser',) - - -class Parser(ABC): - """Base class for a Parser that can parse the outputs produced by a CalcJob process.""" - - def __init__(self, node: 'CalcJobNode'): - """Construct the Parser instance. - - :param node: the `CalcJobNode` that contains the results of the executed `CalcJob` process. - """ - from aiida.common.log import AIIDA_LOGGER - from aiida.orm.utils.log import create_logger_adapter - - self._logger = create_logger_adapter(AIIDA_LOGGER.getChild('parser').getChild(self.__class__.__name__), node) - self._node = node - self._outputs = extendeddicts.AttributeDict() - - @property - def logger(self): - """Return the logger preconfigured for the calculation node associated with this parser instance. - - :return: `logging.Logger` - """ - return self._logger - - @property - def node(self) -> 'CalcJobNode': - """Return the node instance - - :return: the `CalcJobNode` instance - """ - return self._node - - @property - def exit_codes(self) -> ExitCodesNamespace: - """Return the exit codes defined for the process class of the node being parsed. - - :returns: ExitCodesNamespace of ExitCode named tuples - """ - return self.node.process_class.exit_codes - - @property - def retrieved(self) -> 'FolderData': - return self.node.base.links.get_outgoing().get_node_by_label( - self.node.process_class.link_label_retrieved # type: ignore - ) - - @property - def outputs(self): - """Return the dictionary of outputs that have been registered. - - :return: an AttributeDict instance with the registered output nodes - """ - return self._outputs - - def out(self, link_label: str, node: 'Data') -> None: - """Register a node as an output with the given link label. - - :param link_label: the name of the link label - :param node: the node to register as an output - :raises aiida.common.ModificationNotAllowed: if an output node was already registered with the same link label - """ - if link_label in self._outputs: - raise exceptions.ModificationNotAllowed(f'the output {link_label} already exists') - self._outputs[link_label] = node - - def get_outputs_for_parsing(self): - """Return the dictionary of nodes that should be passed to the `Parser.parse` call. - - Output nodes can be marked as being required by the `parse` method, by setting the `pass_to_parser` attribute, - in the `spec.output` call in the process spec of the `CalcJob`, to True. 
- - :return: dictionary of nodes that are required by the `parse` method - """ - link_triples = self.node.base.links.get_outgoing() - result = {} - - for label, port in self.node.process_class.spec().outputs.items(): - if isinstance(port, CalcJobOutputPort) and port.pass_to_parser: - try: - result[label] = link_triples.get_node_by_label(label) - except exceptions.NotExistent: - if port.required: - raise - - return result - - @classmethod - def parse_from_node(cls, - node: 'CalcJobNode', - store_provenance=True, - retrieved_temporary_folder=None) -> Tuple[Optional[Dict[str, Any]], 'orm.CalcFunctionNode']: - """Parse the outputs directly from the `CalcJobNode`. - - If `store_provenance` is set to False, a `CalcFunctionNode` will still be generated, but it will not be stored. - It's storing method will also be disabled, making it impossible to store, because storing it afterwards would - not have the expected effect, as the outputs it produced will not be stored with it. - - This method is useful to test parsing in unit tests where a `CalcJobNode` can be mocked without actually having - to run a `CalcJob`. It can also be useful to actually re-perform the parsing of a completed `CalcJob` with a - different parser. - - :param node: a `CalcJobNode` instance - :param store_provenance: bool, if True will store the parsing as a `CalcFunctionNode` in the provenance - :param retrieved_temporary_folder: absolute path to folder with contents of `retrieved_temporary_list` - :return: a tuple of the parsed results and the `CalcFunctionNode` representing the process of parsing - """ - parser = cls(node=node) - - @calcfunction - def parse_calcfunction(**kwargs): - """A wrapper function that will turn calling the `Parser.parse` method into a `CalcFunctionNode`. - - .. warning:: This implementation of a `calcfunction` uses the `Process.current` to circumvent the limitation - of not being able to return both output nodes as well as an exit code. However, since this calculation - function is supposed to emulate the parsing of a `CalcJob`, which *does* have that capability, I have - to use this method. This method should however not be used in process functions, in other words: - - Do not try this at home! - - :param kwargs: keyword arguments that are passed to `Parser.parse` after it has been constructed - """ - from aiida.engine import Process - - if retrieved_temporary_folder is not None: - kwargs['retrieved_temporary_folder'] = retrieved_temporary_folder - - exit_code = parser.parse(**kwargs) - outputs = parser.outputs - - if exit_code and exit_code.status: - # In the case that an exit code was returned, still attach all the registered outputs on the current - # process as well, which should represent this `CalcFunctionNode`. Otherwise the caller of the - # `parse_from_node` method will get an empty dictionary as a result, despite the `Parser.parse` method - # having registered outputs. - process = Process.current() - process.out_many(outputs) # type: ignore - return exit_code - - return dict(outputs) - - inputs = {'metadata': {'store_provenance': store_provenance}} - inputs.update(parser.get_outputs_for_parsing()) - - return parse_calcfunction.run_get_node(**inputs) # type: ignore - - @abstractmethod - def parse(self, **kwargs) -> Optional[ExitCode]: - """Parse the contents of the output files retrieved in the `FolderData`. - - This method should be implemented in the sub class. Outputs can be registered through the `out` method. 
- After the `parse` call finishes, the runner will automatically link them up to the underlying `CalcJobNode`. - - :param kwargs: output nodes attached to the `CalcJobNode` of the parser instance. - :return: an instance of ExitCode or None - """ diff --git a/aiida/parsers/plugins/__init__.py b/aiida/parsers/plugins/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/parsers/plugins/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/parsers/plugins/arithmetic/__init__.py b/aiida/parsers/plugins/arithmetic/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/parsers/plugins/arithmetic/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/parsers/plugins/arithmetic/add.py b/aiida/parsers/plugins/arithmetic/add.py deleted file mode 100644 index 8c2e32ee3d..0000000000 --- a/aiida/parsers/plugins/arithmetic/add.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# Warning: this implementation is used directly in the documentation as a literal-include, which means that if any part -# of this code is changed, the snippets in the file `docs/source/howto/codes.rst` have to be checked for consistency. 
-# mypy: disable_error_code=arg-type -"""Parser for an `ArithmeticAddCalculation` job.""" -from aiida.parsers.parser import Parser - - -class ArithmeticAddParser(Parser): - """Parser for an `ArithmeticAddCalculation` job.""" - - def parse(self, **kwargs): - """Parse the contents of the output files stored in the `retrieved` output node.""" - from aiida.orm import Int - - try: - with self.retrieved.base.repository.open(self.node.get_option('output_filename'), 'r') as handle: - result = int(handle.read()) - except OSError: - return self.exit_codes.ERROR_READING_OUTPUT_FILE - except ValueError: - return self.exit_codes.ERROR_INVALID_OUTPUT - - self.out('sum', Int(result)) - - if result < 0: - return self.exit_codes.ERROR_NEGATIVE_NUMBER - - -class SimpleArithmeticAddParser(Parser): - """Simple parser for an `ArithmeticAddCalculation` job (for demonstration purposes only).""" - - def parse(self, **kwargs): - """Parse the contents of the output files stored in the `retrieved` output node.""" - from aiida.orm import Int - - output_folder = self.retrieved - - with output_folder.base.repository.open(self.node.get_option('output_filename'), 'r') as handle: - result = int(handle.read()) - - self.out('sum', Int(result)) diff --git a/aiida/parsers/plugins/templatereplacer/__init__.py b/aiida/parsers/plugins/templatereplacer/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/parsers/plugins/templatereplacer/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/parsers/plugins/templatereplacer/parser.py b/aiida/parsers/plugins/templatereplacer/parser.py deleted file mode 100644 index fa5d0acf0f..0000000000 --- a/aiida/parsers/plugins/templatereplacer/parser.py +++ /dev/null @@ -1,69 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Parser for the `TemplatereplacerCalculation` calculation job.""" -import os - -from aiida.orm import Dict -from aiida.parsers.parser import Parser - - -class TemplatereplacerParser(Parser): - """Parser for the `TemplatereplacerCalculation` calculation job.""" - - def parse(self, **kwargs): - """Parse the contents of the output files retrieved in the `FolderData`.""" - output_folder = self.retrieved - template = self.node.inputs.template.get_dict() - - try: - output_file = template['output_file_name'] - except KeyError: - return self.exit_codes.ERROR_NO_OUTPUT_FILE_NAME_DEFINED - - try: - with output_folder.base.repository.open(output_file, 'r') as handle: - result = handle.read() - except (OSError, IOError): - self.logger.exception(f'unable to parse the output for CalcJobNode<{self.node.pk}>') - return self.exit_codes.ERROR_READING_OUTPUT_FILE - - output_dict: dict = {'value': result, 'retrieved_temporary_files': []} - retrieve_temporary_files = template.get('retrieve_temporary_files', None) - - # If the 'retrieve_temporary_files' key was set in the template input node, we expect a temporary directory - # to have been passed in the keyword arguments under the name `retrieved_temporary_folder`. - if retrieve_temporary_files is not None: - try: - retrieved_temporary_folder = kwargs['retrieved_temporary_folder'] - except KeyError: - return self.exit_codes.ERROR_NO_TEMPORARY_RETRIEVED_FOLDER - - for retrieved_file in retrieve_temporary_files: - - file_path = os.path.join(retrieved_temporary_folder, retrieved_file) - - if not os.path.isfile(file_path): - self.logger.error( - 'the file {} was not found in the temporary retrieved folder {}'.format( - retrieved_file, retrieved_temporary_folder - ) - ) - return self.exit_codes.ERROR_READING_TEMPORARY_RETRIEVED_FILE - - with open(file_path, 'r', encoding='utf8') as handle: - parsed_value = handle.read().strip() - - # We always strip the content of the file from whitespace to simplify testing for expected output - output_dict['retrieved_temporary_files'].append((retrieved_file, parsed_value)) - - label = self.node.process_class.spec().default_output_node # type: ignore - self.out(label, Dict(dict=output_dict)) - - return diff --git a/aiida/plugins/__init__.py b/aiida/plugins/__init__.py deleted file mode 100644 index c09ca21af1..0000000000 --- a/aiida/plugins/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Classes and functions to load and interact with plugin classes accessible through defined entry points.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .entry_point import * -from .factories import * -from .utils import * - -__all__ = ( - 'BaseFactory', - 'CalcJobImporterFactory', - 'CalculationFactory', - 'DataFactory', - 'DbImporterFactory', - 'GroupFactory', - 'OrbitalFactory', - 'ParserFactory', - 'PluginVersionProvider', - 'SchedulerFactory', - 'StorageFactory', - 'TransportFactory', - 'WorkflowFactory', - 'get_entry_points', - 'load_entry_point', - 'load_entry_point_from_string', - 'parse_entry_point', -) - -# yapf: enable diff --git a/aiida/plugins/utils.py b/aiida/plugins/utils.py deleted file mode 100644 index 2d34ffdb6b..0000000000 --- a/aiida/plugins/utils.py +++ /dev/null @@ -1,96 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utilities dealing with plugins and entry points.""" -from __future__ import annotations - -from importlib import import_module -from inspect import isclass, isfunction -from logging import Logger -from types import FunctionType -import typing as t - -from aiida.common import AIIDA_LOGGER -from aiida.common.exceptions import EntryPointError - -from .entry_point import load_entry_point_from_string - -__all__ = ('PluginVersionProvider',) - -KEY_VERSION_ROOT: str = 'version' -KEY_VERSION_CORE: str = 'core' # The version of `aiida-core` -KEY_VERSION_PLUGIN: str = 'plugin' # The version of the plugin top level module, e.g. `aiida-quantumespresso` - - -class PluginVersionProvider: - """Utility class that determines version information about a given plugin resource.""" - - def __init__(self): - self._cache: dict[type | FunctionType, dict[t.Any, dict[t.Any, t.Any]]] = {} - self._logger: Logger = AIIDA_LOGGER.getChild('plugin_version_provider') - - @property - def logger(self) -> Logger: - return self._logger - - def get_version_info(self, plugin: str | type) -> dict[t.Any, dict[t.Any, t.Any]]: - """Get the version information for a given plugin. - - .. note:: - - This container will keep a cache, so if this method was already called for the given ``plugin`` before for - this instance, the result computed at the last invocation will be returned. - - :param plugin: A class, function, or an entry point string. If the type is string, it will be assumed to be an - entry point string and the class will attempt to load it first. It should be a full entry point string, - including the entry point group. - :return: Dictionary with the `version.core` and optionally `version.plugin` if it could be determined. - :raises EntryPointError: If ``plugin`` is a string but could not be loaded as a valid entry point. 
- :raises TypeError: If ``plugin`` (or the resource pointed to it in the case of an entry point) is not a class - or a function. - """ - from aiida import __version__ as version_core - - if isinstance(plugin, str): - try: - plugin = load_entry_point_from_string(plugin) - except EntryPointError as exc: - raise EntryPointError(f'got string `{plugin}` but could not load corresponding entry point') from exc - - if not isclass(plugin) and not isfunction(plugin): - raise TypeError(f'`{plugin}` is not a class nor a function.') - - # If the `plugin` already exists in the cache, simply return it. On purpose we do not verify whether the version - # information is completed. If it failed the first time, we don't retry. If the failure was temporarily, whoever - # holds a reference to this instance can simply reconstruct it to start with a clean slate. - if plugin in self._cache: - return self._cache[plugin] - - self._cache[plugin] = { - KEY_VERSION_ROOT: { - KEY_VERSION_CORE: version_core, - } - } - - try: - parent_module_name = plugin.__module__.split('.')[0] - parent_module = import_module(parent_module_name) - except (AttributeError, IndexError, ImportError): - self.logger.debug(f'could not determine the top level module for plugin: {plugin}') - return self._cache[plugin] - - try: - version_plugin = parent_module.__version__ - except AttributeError: - self.logger.debug(f'parent module does not define `__version__` attribute for plugin: {plugin}') - return self._cache[plugin] - - self._cache[plugin][KEY_VERSION_ROOT][KEY_VERSION_PLUGIN] = version_plugin - - return self._cache[plugin] diff --git a/aiida/repository/__init__.py b/aiida/repository/__init__.py deleted file mode 100644 index c828ca07f1..0000000000 --- a/aiida/repository/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with resources dealing with the file repository.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .backend import * -from .common import * -from .repository import * - -__all__ = ( - 'AbstractRepositoryBackend', - 'DiskObjectStoreRepositoryBackend', - 'File', - 'FileType', - 'Repository', - 'SandboxRepositoryBackend', -) - -# yapf: enable diff --git a/aiida/repository/backend/__init__.py b/aiida/repository/backend/__init__.py deleted file mode 100644 index ea4ab3386f..0000000000 --- a/aiida/repository/backend/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# -*- coding: utf-8 -*- -"""Module for file repository backend implementations.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .abstract import * -from .disk_object_store import * -from .sandbox import * - -__all__ = ( - 'AbstractRepositoryBackend', - 'DiskObjectStoreRepositoryBackend', - 'SandboxRepositoryBackend', -) - -# yapf: enable diff --git a/aiida/repository/backend/abstract.py b/aiida/repository/backend/abstract.py deleted file mode 100644 index c19f8629b1..0000000000 --- a/aiida/repository/backend/abstract.py +++ /dev/null @@ -1,219 +0,0 @@ -# -*- coding: utf-8 -*- -"""Class that defines the abstract interface for an object repository. - -The scope of this class is intentionally very narrow. Any backend implementation should merely provide the methods to -store binary blobs, or "objects", and return a string-based key that unique identifies the object that was just created. -This key should then be able to be used to retrieve the bytes of the corresponding object or to delete it. -""" -import abc -import contextlib -import hashlib -import io -import pathlib -from typing import BinaryIO, Iterable, Iterator, List, Optional, Tuple, Union - -from aiida.common.hashing import chunked_file_hash - -__all__ = ('AbstractRepositoryBackend',) - - -class AbstractRepositoryBackend(metaclass=abc.ABCMeta): - """Class that defines the abstract interface for an object repository. - - The repository backend only deals with raw bytes, both when creating new objects as well as when returning a stream - or the content of an existing object. The encoding and decoding of the byte content should be done by the client - upstream. The file repository backend is also not expected to keep any kind of file hierarchy but must be assumed - to be a simple flat data store. When files are created in the file object repository, the implementation will return - a string-based key with which the content of the stored object can be addressed. This key is guaranteed to be unique - and persistent. Persisting the key or mapping it onto a virtual file hierarchy is again up to the client upstream. - """ - - @property - @abc.abstractmethod - def uuid(self) -> Optional[str]: - """Return the unique identifier of the repository.""" - - @property - @abc.abstractmethod - def key_format(self) -> Optional[str]: - """Return the format for the keys of the repository. - - Important for when migrating between backends (e.g. archive -> main), as if they are not equal then it is - necessary to re-compute all the `Node.base.repository.metadata` before importing (otherwise they will not match - with the repository). 
- """ - - @abc.abstractmethod - def initialise(self, **kwargs) -> None: - """Initialise the repository if it hasn't already been initialised. - - :param kwargs: parameters for the initialisation. - """ - - @property - @abc.abstractmethod - def is_initialised(self) -> bool: - """Return whether the repository has been initialised.""" - - @abc.abstractmethod - def erase(self) -> None: - """Delete the repository itself and all its contents. - - .. note:: This should not merely delete the contents of the repository but any resources it created. For - example, if the repository is essentially a folder on disk, the folder itself should also be deleted, not - just its contents. - """ - - @staticmethod - def is_readable_byte_stream(handle) -> bool: - return hasattr(handle, 'read') and hasattr(handle, 'mode') and 'b' in handle.mode - - def put_object_from_filelike(self, handle: BinaryIO) -> str: - """Store the byte contents of a file in the repository. - - :param handle: filelike object with the byte content to be stored. - :return: the generated fully qualified identifier for the object within the repository. - :raises TypeError: if the handle is not a byte stream. - """ - if not isinstance(handle, io.BufferedIOBase) and not self.is_readable_byte_stream(handle): - raise TypeError(f'handle does not seem to be a byte stream: {type(handle)}.') - return self._put_object_from_filelike(handle) - - @abc.abstractmethod - def _put_object_from_filelike(self, handle: BinaryIO) -> str: - pass - - def put_object_from_file(self, filepath: Union[str, pathlib.Path]) -> str: - """Store a new object with contents of the file located at `filepath` on this file system. - - :param filepath: absolute path of file whose contents to copy to the repository. - :return: the generated fully qualified identifier for the object within the repository. - :raises TypeError: if the handle is not a byte stream. - """ - with open(filepath, mode='rb') as handle: - return self.put_object_from_filelike(handle) - - @abc.abstractmethod - def has_objects(self, keys: List[str]) -> List[bool]: - """Return whether the repository has an object with the given key. - - :param keys: - list of fully qualified identifiers for objects within the repository. - :return: - list of logicals, in the same order as the keys provided, with value True if the respective - object exists and False otherwise. - """ - - def has_object(self, key: str) -> bool: - """Return whether the repository has an object with the given key. - - :param key: fully qualified identifier for the object within the repository. - :return: True if the object exists, False otherwise. - """ - return self.has_objects([key])[0] - - @abc.abstractmethod - def list_objects(self) -> Iterable[str]: - """Return iterable that yields all available objects by key. - - :return: An iterable for all the available object keys. - """ - - @abc.abstractmethod - def get_info(self, detailed: bool = False, **kwargs) -> dict: - """Returns relevant information about the content of the repository. - - :param detailed: - flag to enable extra information (detailed=False by default, only returns basic information). - - :return: a dictionary with the information. - """ - - @abc.abstractmethod - def maintain(self, dry_run: bool = False, live: bool = True, **kwargs) -> None: - """Performs maintenance operations. - - :param dry_run: - flag to only print the actions that would be taken without actually executing them. - - :param live: - flag to indicate to the backend whether AiiDA is live or not (i.e. 
if the profile of the - backend is currently being used/accessed). The backend is expected then to only allow (and - thus set by default) the operations that are safe to perform in this state. - """ - - @contextlib.contextmanager - def open(self, key: str) -> Iterator[BinaryIO]: # type: ignore - """Open a file handle to an object stored under the given key. - - .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method - ``put_object_from_filelike`` instead. - - :param key: fully qualified identifier for the object within the repository. - :return: yield a byte stream object. - :raise FileNotFoundError: if the file does not exist. - :raise OSError: if the file could not be opened. - """ - if not self.has_object(key): - raise FileNotFoundError(f'object with key `{key}` does not exist.') - - def get_object_content(self, key: str) -> bytes: - """Return the content of a object identified by key. - - :param key: fully qualified identifier for the object within the repository. - :raise FileNotFoundError: if the file does not exist. - :raise OSError: if the file could not be opened. - """ - with self.open(key) as handle: # pylint: disable=not-context-manager - return handle.read() - - @abc.abstractmethod - def iter_object_streams(self, keys: List[str]) -> Iterator[Tuple[str, BinaryIO]]: - """Return an iterator over the (read-only) byte streams of objects identified by key. - - .. note:: handles should only be read within the context of this iterator. - - :param keys: fully qualified identifiers for the objects within the repository. - :return: an iterator over the object byte streams. - :raise FileNotFoundError: if the file does not exist. - :raise OSError: if a file could not be opened. - """ - - def get_object_hash(self, key: str) -> str: - """Return the SHA-256 hash of an object stored under the given key. - - .. important:: - A SHA-256 hash should always be returned, - to ensure consistency across different repository implementations. - - :param key: fully qualified identifier for the object within the repository. - :raise FileNotFoundError: if the file does not exist. - :raise OSError: if the file could not be opened. - """ - with self.open(key) as handle: # pylint: disable=not-context-manager - return chunked_file_hash(handle, hashlib.sha256) - - @abc.abstractmethod - def delete_objects(self, keys: List[str]) -> None: - """Delete the objects from the repository. - - :param keys: list of fully qualified identifiers for the objects within the repository. - :raise FileNotFoundError: if any of the files does not exist. - :raise OSError: if any of the files could not be deleted. - """ - keys_exist = self.has_objects(keys) - if not all(keys_exist): - error_message = 'some of the keys provided do not correspond to any object in the repository:\n' - for indx, key_exists in enumerate(keys_exist): - if not key_exists: - error_message += f' > object with key `{keys[indx]}` does not exist.\n' - raise FileNotFoundError(error_message) - - def delete_object(self, key: str) -> None: - """Delete the object from the repository. - - :param key: fully qualified identifier for the object within the repository. - :raise FileNotFoundError: if the file does not exist. - :raise OSError: if the file could not be deleted. 
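# A minimal round trip through a concrete backend, assuming the
# ``SandboxRepositoryBackend`` exported by ``aiida.repository.backend``; the
# content bytes are made up. The backend only hands back flat keys, any file
# hierarchy is kept by the client.
import io

from aiida.repository.backend import SandboxRepositoryBackend

backend = SandboxRepositoryBackend()
key = backend.put_object_from_filelike(io.BytesIO(b'some content'))
assert backend.has_object(key)
assert backend.get_object_content(key) == b'some content'
print(backend.get_object_hash(key))  # SHA-256 hex digest of the raw bytes
backend.delete_object(key)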
- """ - return self.delete_objects([key]) diff --git a/aiida/repository/common.py b/aiida/repository/common.py deleted file mode 100644 index f99d1a55dd..0000000000 --- a/aiida/repository/common.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Module with resources common to the repository.""" -import enum -import typing - -__all__ = ('FileType', 'File') - - -class FileType(enum.Enum): - """Enumeration to represent the type of a file object.""" - - DIRECTORY = 0 - FILE = 1 - - -class File(): - """Data class representing a file object.""" - - def __init__( - self, - name: str = '', - file_type: FileType = FileType.DIRECTORY, - key: typing.Union[str, None] = None, - objects: typing.Optional[typing.Dict[str, 'File']] = None - ) -> None: - """Construct a new instance. - - :param name: The final element of the file path - :param file_type: Identifies whether the File is a file or a directory - :param key: A key to map the file to its contents in the backend repository (file only) - :param objects: Mapping of child names to child Files (directory only) - - :raises ValueError: If a key is defined for a directory, - or objects are defined for a file - """ - if not isinstance(name, str): - raise TypeError('name should be a string.') - - if not isinstance(file_type, FileType): - raise TypeError('file_type should be an instance of `FileType`.') - - if key is not None and not isinstance(key, str): - raise TypeError('key should be `None` or a string.') - - if objects is not None and any(not isinstance(obj, self.__class__) for obj in objects.values()): - raise TypeError('objects should be `None` or a dictionary of `File` instances.') - - if file_type == FileType.DIRECTORY and key is not None: - raise ValueError('an object of type `FileType.DIRECTORY` cannot define a key.') - - if file_type == FileType.FILE and objects is not None: - raise ValueError('an object of type `FileType.FILE` cannot define any objects.') - - self._name = name - self._file_type = file_type - self._key = key - self._objects = objects or {} - - @classmethod - def from_serialized(cls, serialized: dict, name='') -> 'File': - """Construct a new instance from a serialized instance. - - :param serialized: the serialized instance. - :return: the reconstructed file object. - """ - if 'k' in serialized: - file_type = FileType.FILE - key = serialized['k'] - objects = None - else: - file_type = FileType.DIRECTORY - key = None - objects = {name: File.from_serialized(obj, name) for name, obj in serialized.get('o', {}).items()} - - instance = cls.__new__(cls) - instance.__init__(name, file_type, key, objects) # type: ignore[misc] - return instance - - def serialize(self) -> dict: - """Serialize the metadata into a JSON-serializable format. - - .. note:: the serialization format is optimized to reduce the size in bytes. - - :return: dictionary with the content metadata. 
- """ - if self.file_type == FileType.DIRECTORY: - if self.objects: - return {'o': {key: obj.serialize() for key, obj in self.objects.items()}} - return {} - return {'k': self.key} - - @property - def name(self) -> str: - """Return the name of the file object.""" - return self._name - - @property - def file_type(self) -> FileType: - """Return the file type of the file object.""" - return self._file_type - - def is_file(self) -> bool: - """Return whether this instance is a file object.""" - return self.file_type == FileType.FILE - - def is_dir(self) -> bool: - """Return whether this instance is a directory object.""" - return self.file_type == FileType.DIRECTORY - - @property - def key(self) -> typing.Union[str, None]: - """Return the key of the file object.""" - return self._key - - @property - def objects(self) -> typing.Dict[str, 'File']: - """Return the objects of the file object.""" - return self._objects - - def __eq__(self, other) -> bool: - """Return whether this instance is equal to another file object instance.""" - if not isinstance(other, self.__class__): - return False - - equal_attributes = all(getattr(self, key) == getattr(other, key) for key in ['name', 'file_type', 'key']) - equal_object_keys = sorted(self.objects) == sorted(other.objects) - equal_objects = equal_object_keys and all(obj == other.objects[key] for key, obj in self.objects.items()) - - return equal_attributes and equal_objects - - def __repr__(self): - args = (self.name, self.file_type.value, self.key, self.objects.items()) - return 'File'.format(*args) diff --git a/aiida/repository/repository.py b/aiida/repository/repository.py deleted file mode 100644 index b61575f0b1..0000000000 --- a/aiida/repository/repository.py +++ /dev/null @@ -1,543 +0,0 @@ -# -*- coding: utf-8 -*- -"""Module for the implementation of a file repository.""" -import contextlib -import pathlib -from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Tuple, Union - -from aiida.common.hashing import make_hash -from aiida.common.lang import type_check - -from .backend import AbstractRepositoryBackend, SandboxRepositoryBackend -from .common import File, FileType - -__all__ = ('Repository',) - -FilePath = Union[str, pathlib.PurePosixPath] - - -class Repository: - """File repository. - - This class provides an interface to a backend file repository instance, but unlike the backend repository, this - class keeps a reference of the virtual file hierarchy. This means that through this interface, a client can create - files and directories with a file hierarchy, just as they would on a local file system, except it is completely - virtual as the files are stored by the backend which can store them in a completely flat structure. This also means - that the internal virtual hierarchy of a ``Repository`` instance does not necessarily represent all the files that - are stored by repository backend. The repository exposes a mere subset of all the file objects stored in the - backend. This is why object deletion is also implemented as a soft delete, by default, where the files are just - removed from the internal virtual hierarchy, but not in the actual backend. This is because those objects can be - referenced by other instances. - """ - - # pylint: disable=too-many-public-methods - - _file_cls = File - - def __init__(self, backend: Optional[AbstractRepositoryBackend] = None): - """Construct a new instance with empty metadata. - - :param backend: instance of repository backend to use to actually store the file objects. 
By default, an - instance of the ``SandboxRepositoryBackend`` will be created. - """ - if backend is None: - backend = SandboxRepositoryBackend() - - self.set_backend(backend) - self.reset() - - def __str__(self) -> str: - """Return the string representation of this repository.""" - return f'Repository<{str(self.backend)}>' - - @property - def uuid(self) -> Optional[str]: - """Return the unique identifier of the repository backend or ``None`` if it doesn't have one.""" - return self.backend.uuid - - @property - def is_initialised(self) -> bool: - """Return whether the repository backend has been initialised.""" - return self.backend.is_initialised - - @classmethod - def from_serialized(cls, backend: AbstractRepositoryBackend, serialized: Dict[str, Any]) -> 'Repository': - """Construct an instance where the metadata is initialized from the serialized content. - - :param backend: instance of repository backend to use to actually store the file objects. - """ - instance = cls.__new__(cls) - instance.__init__(backend) # type: ignore[misc] # pylint: disable=unnecessary-dunder-call - - if serialized: - for name, obj in serialized['o'].items(): - instance.get_directory().objects[name] = cls._file_cls.from_serialized(obj, name) - - return instance - - def reset(self) -> None: - self._directory = self._file_cls() - - def serialize(self) -> Dict[str, Any]: - """Serialize the metadata into a JSON-serializable format. - - :return: dictionary with the content metadata. - """ - return self._directory.serialize() - - @classmethod - def flatten(cls, serialized=Optional[Dict[str, Any]], delimiter: str = '/') -> Dict[str, Optional[str]]: - """Flatten the serialized content of a repository into a mapping of path -> key or None (if folder). - - Note, all folders are represented in the flattened output, and their path is suffixed with the delimiter. - - :param serialized: the serialized content of the repository. - :param delimiter: the delimiter to use to separate the path elements. - :return: dictionary with the flattened content. - """ - if serialized is None: - return {} - items: Dict[str, Optional[str]] = {} - stack = [('', serialized)] - while stack: - path, sub_dict = stack.pop() - for name, obj in sub_dict.get('o', {}).items(): - sub_path = f'{path}{delimiter}{name}' if path else name - if not obj: - items[f'{sub_path}{delimiter}'] = None - elif 'k' in obj: - items[sub_path] = obj['k'] - else: - items[f'{sub_path}{delimiter}'] = None - stack.append((sub_path, obj)) - return items - - def hash(self) -> str: - """Generate a hash of the repository's contents. - - .. warning:: this will read the content of all file objects contained within the virtual hierarchy into memory. - - :return: the hash representing the contents of the repository. - """ - objects: Dict[str, Any] = {} - for root, dirnames, filenames in self.walk(): - objects['__dirnames__'] = dirnames - for filename in filenames: - key = self.get_file(root / filename).key - assert key is not None, 'Expected FileType.File to have a key' - objects[str(root / filename)] = self.backend.get_object_hash(key) - - return make_hash(objects) - - @staticmethod - def _pre_process_path(path: Optional[FilePath] = None) -> pathlib.PurePosixPath: - """Validate and convert the path to instance of ``pathlib.PurePosixPath``. - - This should be called by every method of this class before doing anything, such that it can safely assume that - the path is a ``pathlib.PurePosixPath`` object, which makes path manipulation a lot easier. 
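# A short sketch of ``Repository.flatten`` shown above, using a made-up
# serialized dictionary: every directory appears with a trailing delimiter and
# maps to ``None``, while files map to their backend key.
from aiida.repository import Repository

serialized = {'o': {'sub': {'o': {'a.txt': {'k': 'key-1'}}}, 'empty': {}}}
assert Repository.flatten(serialized) == {
    'sub/': None,
    'sub/a.txt': 'key-1',
    'empty/': None,
}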
- - :param path: the path as a ``pathlib.PurePosixPath`` object or `None`. - :raises TypeError: if the type of path was not a str nor a ``pathlib.PurePosixPath`` instance. - """ - if path is None: - return pathlib.PurePosixPath() - - if isinstance(path, str): - path = pathlib.PurePosixPath(path) - - if not isinstance(path, pathlib.PurePosixPath): - raise TypeError('path is not of type `str` nor `pathlib.PurePosixPath`.') - - if path.is_absolute(): - raise TypeError(f'path `{path}` is not a relative path.') - - return path - - @property - def backend(self) -> AbstractRepositoryBackend: - """Return the current repository backend. - - :return: the repository backend. - """ - return self._backend - - def set_backend(self, backend: AbstractRepositoryBackend) -> None: - """Set the backend for this repository. - - :param backend: the repository backend. - :raises TypeError: if the type of the backend is invalid. - """ - type_check(backend, AbstractRepositoryBackend) - self._backend = backend - - def _insert_file(self, path: pathlib.PurePosixPath, key: str) -> None: - """Insert a new file object in the object mapping. - - .. note:: this assumes the path is a valid relative path, so should be checked by the caller. - - :param path: the relative path where to store the object in the repository. - :param key: fully qualified identifier for the object within the repository. - """ - directory = self.create_directory(path.parent) - directory.objects[path.name] = self._file_cls(path.name, FileType.FILE, key) - - def create_directory(self, path: FilePath) -> File: - """Create a new directory with the given path. - - :param path: the relative path of the directory. - :return: the created directory. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - """ - if path is None: - raise TypeError('path cannot be `None`.') - - path = self._pre_process_path(path) - directory = self._directory - - for part in path.parts: - if part not in directory.objects: - directory.objects[part] = self._file_cls(part) - - directory = directory.objects[part] - - return directory - - def get_file_keys(self) -> List[str]: - """Return the keys of all file objects contained within this repository. - - :return: list of keys, which map a file to its content in the backend repository. - """ - file_keys: List[str] = [] - - def _add_file_keys(keys, objects): - """Recursively add keys of all file objects to the keys list.""" - for obj in objects.values(): - if obj.file_type == FileType.FILE and obj.key is not None: - keys.append(obj.key) - elif obj.file_type == FileType.DIRECTORY: - _add_file_keys(keys, obj.objects) - - _add_file_keys(file_keys, self._directory.objects) - - return file_keys - - def get_object(self, path: Optional[FilePath] = None) -> File: - """Return the object at the given path. - - :param path: the relative path where to store the object in the repository. - :return: the `File` representing the object located at the given relative path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if no object exists for the given path. 
- """ - path = self._pre_process_path(path) - file_object = self._directory - - if not path.parts: - return file_object - - for part in path.parts: - if part not in file_object.objects: - raise FileNotFoundError(f'object with path `{path}` does not exist.') - - file_object = file_object.objects[part] - - return file_object - - def get_directory(self, path: Optional[FilePath] = None) -> File: - """Return the directory object at the given path. - - :param path: the relative path of the directory. - :return: the `File` representing the object located at the given relative path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if no object exists for the given path. - :raises NotADirectoryError: if the object at the given path is not a directory. - """ - file_object = self.get_object(path) - - if file_object.file_type != FileType.DIRECTORY: - raise NotADirectoryError(f'object with path `{path}` is not a directory.') - - return file_object - - def get_file(self, path: FilePath) -> File: - """Return the file object at the given path. - - :param path: the relative path of the file object. - :return: the `File` representing the object located at the given relative path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if no object exists for the given path. - :raises IsADirectoryError: if the object at the given path is not a directory. - """ - if path is None: - raise TypeError('path cannot be `None`.') - - path = self._pre_process_path(path) - - file_object = self.get_object(path) - - if file_object.file_type != FileType.FILE: - raise IsADirectoryError(f'object with path `{path}` is not a file.') - - return file_object - - def list_objects(self, path: Optional[FilePath] = None) -> List[File]: - """Return a list of the objects contained in this repository sorted by name, optionally in given sub directory. - - :param path: the relative path of the directory. - :return: a list of `File` named tuples representing the objects present in directory with the given path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if no object exists for the given path. - :raises NotADirectoryError: if the object at the given path is not a directory. - """ - directory = self.get_directory(path) - return sorted(directory.objects.values(), key=lambda obj: obj.name) - - def list_object_names(self, path: Optional[FilePath] = None) -> List[str]: - """Return a sorted list of the object names contained in this repository, optionally in the given sub directory. - - :param path: the relative path of the directory. - :return: a list of `File` named tuples representing the objects present in directory with the given path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if no object exists for the given path. - :raises NotADirectoryError: if the object at the given path is not a directory. - """ - return [entry.name for entry in self.list_objects(path)] - - def put_object_from_filelike(self, handle: BinaryIO, path: FilePath) -> None: - """Store the byte contents of a file in the repository. - - :param handle: filelike object with the byte content to be stored. - :param path: the relative path where to store the object in the repository. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. 
- """ - path = self._pre_process_path(path) - key = self.backend.put_object_from_filelike(handle) - self._insert_file(path, key) - - def put_object_from_file(self, filepath: FilePath, path: FilePath) -> None: - """Store a new object under `path` with contents of the file located at `filepath` on the local file system. - - :param filepath: absolute path of file whose contents to copy to the repository - :param path: the relative path where to store the object in the repository. - :raises TypeError: if the path is not a string and relative path, or the handle is not a byte stream. - """ - with open(filepath, 'rb') as handle: - self.put_object_from_filelike(handle, path) - - def put_object_from_tree(self, filepath: FilePath, path: Optional[FilePath] = None) -> None: - """Store the entire contents of `filepath` on the local file system in the repository with under given `path`. - - :param filepath: absolute path of the directory whose contents to copy to the repository. - :param path: the relative path where to store the objects in the repository. - :raises TypeError: if the filepath is not a string or ``Path``, or is a relative path. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - """ - import os - - path = self._pre_process_path(path) - - if isinstance(filepath, str): - filepath = pathlib.PurePosixPath(filepath) - - if not isinstance(filepath, pathlib.PurePosixPath): - raise TypeError(f'filepath `{filepath}` is not of type `str` nor `pathlib.PurePosixPath`.') - - if not filepath.is_absolute(): - raise TypeError(f'filepath `{filepath}` is not an absolute path.') - - # Explicitly create the base directory if specified by `path`, just in case `filepath` contains no file objects. - if path.parts: - self.create_directory(path) - - for root_str, dirnames, filenames in os.walk(filepath): - - root = pathlib.PurePosixPath(root_str) - - for dirname in dirnames: - self.create_directory(path / root.relative_to(filepath) / dirname) - - for filename in filenames: - self.put_object_from_file(root / filename, path / root.relative_to(filepath) / filename) - - def is_empty(self) -> bool: - """Return whether the repository is empty. - - :return: True if the repository contains no file objects. - """ - return not self._directory.objects - - def has_object(self, path: FilePath) -> bool: - """Return whether the repository has an object with the given path. - - :param path: the relative path of the object within the repository. - :return: True if the object exists, False otherwise. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - """ - try: - self.get_object(path) - except FileNotFoundError: - return False - return True - - @contextlib.contextmanager - def open(self, path: FilePath) -> Iterator[BinaryIO]: - """Open a file handle to an object stored under the given path. - - .. note:: this should only be used to open a handle to read an existing file. To write a new file use the method - ``put_object_from_filelike`` instead. - - :param path: the relative path of the object within the repository. - :return: yield a byte stream object. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if the file does not exist. - :raises IsADirectoryError: if the object is a directory and not a file. - :raises OSError: if the file could not be opened. 
- """ - key = self.get_file(path).key - assert key is not None, 'Expected FileType.File to have a key' - with self.backend.open(key) as handle: - yield handle - - def get_object_content(self, path: FilePath) -> bytes: - """Return the content of a object identified by path. - - :param path: the relative path of the object within the repository. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if the file does not exist. - :raises IsADirectoryError: if the object is a directory and not a file. - :raises OSError: if the file could not be opened. - """ - key = self.get_file(path).key - assert key is not None, 'Expected FileType.File to have a key' - return self.backend.get_object_content(key) - - def delete_object(self, path: FilePath, hard_delete: bool = False) -> None: - """Soft delete the object from the repository. - - .. note:: can only delete file objects, but not directories. - - :param path: the relative path of the object within the repository. - :param hard_delete: when true, not only remove the file from the internal mapping but also call through to the - ``delete_object`` method of the actual repository backend. - :raises TypeError: if the path is not a string or ``Path``, or is an absolute path. - :raises FileNotFoundError: if the file does not exist. - :raises IsADirectoryError: if the object is a directory and not a file. - :raises OSError: if the file could not be deleted. - """ - path = self._pre_process_path(path) - file_object = self.get_object(path) - - if file_object.file_type == FileType.DIRECTORY: - raise IsADirectoryError(f'object with path `{path}` is a directory.') - - if hard_delete: - assert file_object.key is not None, 'Expected FileType.File to have a key' - self.backend.delete_object(file_object.key) - - directory = self.get_directory(path.parent) - directory.objects.pop(path.name) - - def erase(self) -> None: - """Delete all objects from the repository. - - .. important: this intentionally does not call through to any ``erase`` method of the backend, because unlike - this class, the backend does not just store the objects of a single node, but potentially of a lot of other - nodes. Therefore, we manually delete all file objects and then simply reset the internal file hierarchy. - - """ - for file_key in self.get_file_keys(): - self.backend.delete_object(file_key) - self.reset() - - def clone(self, source: 'Repository') -> None: - """Clone the contents of another repository instance.""" - if not isinstance(source, Repository): - raise TypeError('source is not an instance of `Repository`.') - - for root, dirnames, filenames in source.walk(): - for dirname in dirnames: - self.create_directory(root / dirname) - for filename in filenames: - with source.open(root / filename) as handle: - self.put_object_from_filelike(handle, root / filename) - - def walk(self, path: Optional[FilePath] = None) -> Iterable[Tuple[pathlib.PurePosixPath, List[str], List[str]]]: - """Walk over the directories and files contained within this repository. - - .. note:: the order of the dirname and filename lists that are returned is not necessarily sorted. This is in - line with the ``os.walk`` implementation where the order depends on the underlying file system used. - - :param path: the relative path of the directory within the repository whose contents to walk. 
- :return: tuples of root, dirnames and filenames just like ``os.walk``, with the exception that the root path is - always relative with respect to the repository root, instead of an absolute path and it is an instance of - ``pathlib.PurePosixPath`` instead of a normal string - """ - path = self._pre_process_path(path) - - directory = self.get_directory(path) - dirnames = [obj.name for obj in directory.objects.values() if obj.file_type == FileType.DIRECTORY] - filenames = [obj.name for obj in directory.objects.values() if obj.file_type == FileType.FILE] - - if dirnames: - for dirname in dirnames: - yield from self.walk(path / dirname) - - yield path, dirnames, filenames - - def copy_tree(self, target: Union[str, pathlib.Path], path: Optional[FilePath] = None) -> None: - """Copy the contents of the entire node repository to another location on the local file system. - - .. note:: If ``path`` is specified, only its contents are copied, and the relative path with respect to the - root is discarded. For example, if ``path`` is ``relative/sub``, the contents of ``sub`` will be copied - directly to the target, without the ``relative/sub`` directory. - - :param target: absolute path of the directory where to copy the contents to. - :param path: optional relative path whose contents to copy. - :raises TypeError: if ``target`` is of incorrect type or not absolute. - :raises NotADirectoryError: if ``path`` does not reference a directory. - """ - path = self._pre_process_path(path) - file_object = self.get_object(path) - - if file_object.file_type != FileType.DIRECTORY: - raise NotADirectoryError(f'object with path `{path}` is not a directory.') - - if isinstance(target, str): - target = pathlib.Path(target) - - if not isinstance(target, pathlib.Path): - raise TypeError(f'path `{path}` is not of type `str` nor `pathlib.Path`.') - - if not target.is_absolute(): - raise TypeError(f'provided target `{target}` is not an absolute path.') - - for root, dirnames, filenames in self.walk(path): - for dirname in dirnames: - dirpath = target / (root / dirname).relative_to(path) - dirpath.mkdir(parents=True, exist_ok=True) - - for filename in filenames: - dirpath = target / root.relative_to(path) - filepath = dirpath / filename - - dirpath.mkdir(parents=True, exist_ok=True) - - with self.open(root / filename) as handle: - filepath.write_bytes(handle.read()) - - # these methods are not actually used in aiida-core, but are here for completeness - - def initialise(self, **kwargs: Any) -> None: - """Initialise the repository if it hasn't already been initialised. - - :param kwargs: keyword argument that will be passed to the ``initialise`` call of the backend. - """ - self.backend.initialise(**kwargs) - - def delete(self) -> None: - """Delete the repository. - - .. important:: This will not just delete the contents of the repository but also the repository itself and all - of its assets. For example, if the repository is stored inside a folder on disk, the folder may be deleted. - """ - self.backend.erase() - self.reset() diff --git a/aiida/restapi/__init__.py b/aiida/restapi/__init__.py deleted file mode 100644 index 5cdd575a4a..0000000000 --- a/aiida/restapi/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -In this module, AiiDA provides REST API to access different -AiiDA nodes stored in database. The REST API is implemented -using Flask RESTFul framework. -""" - -# AUTO-GENERATED - -__all__ = () diff --git a/aiida/restapi/common/__init__.py b/aiida/restapi/common/__init__.py deleted file mode 100644 index 2776a55f97..0000000000 --- a/aiida/restapi/common/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### diff --git a/aiida/restapi/common/config.py b/aiida/restapi/common/config.py deleted file mode 100644 index 0569640824..0000000000 --- a/aiida/restapi/common/config.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -Default configuration for the REST API -""" -import os - -API_CONFIG = { - 'LIMIT_DEFAULT': 400, # default records total - 'PERPAGE_DEFAULT': 20, # default records per page - 'PREFIX': '/api/v4', # prefix for all URLs - 'VERSION': '4.1.0', -} - -APP_CONFIG = { - 'DEBUG': False, # use False for production - 'PROPAGATE_EXCEPTIONS': True, # serve REST exceptions to client instead of generic 500 internal server error -} - -SERIALIZER_CONFIG = {'datetime_format': 'default'} # use 'asinput' or 'default' - -CACHE_CONFIG = {'CACHE_TYPE': 'memcached'} -CACHING_TIMEOUTS = { # Caching timeouts in seconds - 'nodes': 10, - 'users': 10, - 'calculations': 10, - 'computers': 10, - 'datas': 10, - 'groups': 10, - 'codes': 10, -} - -# IO tree -MAX_TREE_DEPTH = 5 - -CLI_DEFAULTS = { - 'HOST_NAME': '127.0.0.1', - 'PORT': 5000, - 'CONFIG_DIR': os.path.dirname(os.path.abspath(__file__)), - 'WSGI_PROFILE': False, - 'HOOKUP_APP': True, - 'CATCH_INTERNAL_SERVER': False, - 'POSTING': True, # Include POST endpoints (currently only /querybuilder) -} diff --git a/aiida/restapi/common/exceptions.py b/aiida/restapi/common/exceptions.py deleted file mode 100644 index 083ee24ff7..0000000000 --- a/aiida/restapi/common/exceptions.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" -This file contains the exceptions that are raised by the RESTapi at the -highest level, namely that of the interaction with the client. Their -specificity resides into the fact that they return a message that is embedded -into the HTTP response. - -Example: - - .../api/v1/nodes/ ... (TODO compete this with an actual example) - -Other errors arising at deeper level, e.g. those raised by the QueryBuilder -or internal errors, are not embedded into the HTTP response. -""" - -from aiida.common.exceptions import FeatureNotAvailable, InputValidationError, ValidationError - - -class RestValidationError(ValidationError): - """ - If validation error in code - E.g. more than one node available for given uuid - """ - - -class RestInputValidationError(InputValidationError): - """ - If inputs passed in query strings are wrong - """ - - -class RestFeatureNotAvailable(FeatureNotAvailable): - """ - If endpoint is not emplemented for given node type - """ diff --git a/aiida/restapi/common/utils.py b/aiida/restapi/common/utils.py deleted file mode 100644 index 6a26270f52..0000000000 --- a/aiida/restapi/common/utils.py +++ /dev/null @@ -1,846 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -""" Util methods """ -from datetime import datetime, timedelta -import urllib.parse - -from flask import jsonify -from flask.json.provider import DefaultJSONProvider -from wrapt import decorator - -from aiida.common.exceptions import InputValidationError, ValidationError -from aiida.common.utils import DatetimePrecision -from aiida.manage import get_manager -from aiida.restapi.common.exceptions import RestInputValidationError, RestValidationError - -# Important to match querybuilder keys -PK_DBSYNONYM = 'id' -# Example uuid (version 4) -UUID_REF = 'd55082b6-76dc-426b-af89-0e08b59524d2' - - -########################## Classes ##################### -class CustomJSONProvider(DefaultJSONProvider): - """ - Custom json encoder for serialization. - This has to be provided to the Flask app in order to replace the default - encoder. - """ - - def default(self, obj, **kwargs): - """ - Override serialization of ``DefaultJSONProvider`` for ``datetime`` and ``bytes`` objects. - - :param obj: Object e.g. dict, list that will be serialized. - :return: Serialized object as a string. 
- """ - - from aiida.restapi.common.config import SERIALIZER_CONFIG - - # Treat the datetime objects - if isinstance(obj, datetime): - if 'datetime_format' in SERIALIZER_CONFIG and SERIALIZER_CONFIG['datetime_format'] != 'default': - if SERIALIZER_CONFIG['datetime_format'] == 'asinput': - if obj.utcoffset() is not None: - obj = obj - obj.utcoffset() - return '-'.join([str(obj.year), str(obj.month).zfill(2), - str(obj.day).zfill(2)]) + 'T' + \ - ':'.join([str( - obj.hour).zfill(2), str(obj.minute).zfill(2), - str(obj.second).zfill(2)]) - - # To support bytes objects, try to decode to a string - try: - return obj.decode('utf-8') - except (UnicodeDecodeError, AttributeError): - pass - - # If not returned yet, do it in the default way - return super().default(obj, **kwargs) - - -class Utils: - """ - A class that gathers all the utility functions for parsing URI, - validating request, pass it to the translator, and building HTTP response - - An istance of Utils has to be included in the api class so that the - configuration parameters used to build the api are automatically - accessible by the methods of Utils without the need to import them from - the config.py file. - - """ - - # Conversion map from the query_string operators to the query_builder - # operators - op_conv_map = { - '=': '==', - '!=': '!==', - '=in=': 'in', - '=notin=': '!in', - '>': '>', - '<': '<', - '>=': '>=', - '<=': '<=', - '=like=': 'like', - '=ilike=': 'ilike' - } - - def __init__(self, **kwargs): - """ - Sets internally the configuration parameters - """ - - self.prefix = kwargs['PREFIX'] - self.perpage_default = kwargs['PERPAGE_DEFAULT'] - self.limit_default = kwargs['LIMIT_DEFAULT'] - - def strip_api_prefix(self, path): - """ - Removes the PREFIX from an URL path. PREFIX must be defined in the - config.py file:: - - PREFIX = "/api/v2" - path = "/api/v2/calculations/page/2" - strip_api_prefix(path) ==> "/calculations/page/2" - - :param path: the URL path string - :return: the same URL without the prefix - """ - if path.startswith(self.prefix): - return path[len(self.prefix):] - - raise ValidationError(f'path has to start with {self.prefix}') - - @staticmethod - def split_path(path): - """ - :param path: entire path contained in flask request - :return: list of each element separated by '/' - """ - return [f for f in path.split('/') if f] - - def parse_path(self, path_string, parse_pk_uuid=None): - # pylint: disable=too-many-return-statements,too-many-branches, too-many-statements - """ - Takes the path and parse it checking its validity. Does not parse "io", - "content" fields. I do not check the validity of the path, since I assume - that this is done by the Flask routing methods. - - :param path_string: the path string - :param parse_id_uuid: if 'pk' ('uuid') expects an integer (uuid starting pattern) - :return: resource_type (string) - page (integer) - node_id (string: uuid starting pattern, int: pk) - query_type (string)) - """ - - ## Initialization - page = None - node_id = None - query_type = 'default' - path = self.split_path(self.strip_api_prefix(path_string)) - - ## Pop out iteratively the "words" of the path until it is an empty - # list. - ## This way it should be easier to plug in more endpoint logic - - # Resource type - resource_type = path.pop(0) - if not path: - return (resource_type, page, node_id, query_type) - - # Validate uuid or starting pattern of uuid. - # Technique: - take our UUID_REF and replace the first characters the - # string to be validated as uuid. 
- # - validate instead the newly built string - if parse_pk_uuid == 'pk': - raw_id = path[0] - try: - # Check whether it can be an integer - node_id = int(raw_id) - except ValueError: - pass - else: - path.pop(0) - elif parse_pk_uuid == 'uuid': - import uuid - raw_id = path[0] - maybe_uuid = raw_id + UUID_REF[len(raw_id):] - try: - _ = uuid.UUID(maybe_uuid, version=4) - except ValueError: - # assume that it cannot be an id and go to the next check - pass - else: - # It is a node_id so pop out the path element - node_id = raw_id - path.pop(0) - - if not path: - return (resource_type, page, node_id, query_type) - - if path[0] in [ - 'projectable_properties', 'statistics', 'full_types', 'full_types_count', 'download', 'download_formats', - 'report', 'status', 'input_files', 'output_files' - ]: - query_type = path.pop(0) - if path: - raise RestInputValidationError('Given url does not accept further fields') - elif path[0] in ['links', 'contents']: - path.pop(0) - query_type = path.pop(0) - elif path[0] in ['repo']: - path.pop(0) - query_type = f'repo_{path.pop(0)}' - - if not path: - return (resource_type, page, node_id, query_type) - - # Page (this has to be in any case the last field) - if path[0] == 'page': - path.pop(0) - if not path: - page = 1 - return (resource_type, page, node_id, query_type) - page = int(path.pop(0)) - else: - raise RestInputValidationError('The requested URL is not found on the server.') - - return (resource_type, page, node_id, query_type) - - def validate_request( - self, limit=None, offset=None, perpage=None, page=None, query_type=None, is_querystring_defined=False - ): - # pylint: disable=fixme,too-many-arguments,too-many-branches - """ - Performs various checks on the consistency of the request. - Add here all the checks that you want to do, except validity of the page - number that is done in paginate(). - Future additional checks must be added here - """ - - # TODO Consider using **kwargs so to make easier to add more validations - # 1. perpage incompatible with offset and limits - if perpage is not None and (limit is not None or offset is not None): - raise RestValidationError('perpage key is incompatible with limit and offset') - # 2. /page/ in path is incompatible with limit and offset - if page is not None and (limit is not None or offset is not None): - raise RestValidationError('requesting a specific page is incompatible with limit and offset') - # 3. perpage requires that the path contains a page request - if perpage is not None and page is None: - raise RestValidationError( - 'perpage key requires that a page is ' - 'requested (i.e. the path must contain ' - '/page/)' - ) - # 4. No querystring if query type = projectable_properties' - if query_type in ('projectable_properties',) and is_querystring_defined: - raise RestInputValidationError('projectable_properties requests do not allow specifying a query string') - - def paginate(self, page, perpage, total_count): - """ - Calculates limit and offset for the reults of a query, - given the page and the number of restuls per page. - Moreover, calculates the last available page and raises an exception - if the - required page exceeds that limit. 
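# A short numeric sketch of the pagination arithmetic implemented below.
# ``Utils`` takes the API configuration keys at construction; the values used
# here mirror the defaults from ``aiida.restapi.common.config`` earlier in this
# diff, and the page/count numbers are made up.
from aiida.restapi.common.utils import Utils

utils = Utils(PREFIX='/api/v4', PERPAGE_DEFAULT=20, LIMIT_DEFAULT=400)
limit, offset, rel_pages = utils.paginate(page=2, perpage=20, total_count=45)
assert (limit, offset) == (20, 20)  # 20 results per page, skipping the first 20
assert rel_pages == {'prev': 1, 'next': 3, 'first': 1, 'last': 3}  # last page = ceil(45 / 20)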
- If number of rows==0, only page 1 exists - :param page: integer number of the page that has to be viewed - :param perpage: integer defining how many results a page contains - :param total_count: the total number of rows retrieved by the query - :return: integers: limit, offset, rel_pages - """ - from math import ceil - - ## Type checks - # Mandatory params - try: - page = int(page) - except ValueError: - raise InputValidationError('page number must be an integer') - try: - total_count = int(total_count) - except ValueError: - raise InputValidationError('total_count must be an integer') - # Non-mandatory params - if perpage is not None: - try: - perpage = int(perpage) - except ValueError: - raise InputValidationError('perpage must be an integer') - else: - perpage = self.perpage_default - - ## First_page is anyway 1 - first_page = 1 - - ## Calculate last page - if total_count == 0: - last_page = 1 - else: - last_page = int(ceil(total_count / perpage)) - - ## Check validity of required page and calculate limit, offset, - # previous, - # and next page - if page > last_page or page < 1: - raise RestInputValidationError( - f'Non existent page requested. The page range is [{first_page} : {last_page}]' - ) - - limit = perpage - offset = (page - 1) * perpage - prev_page = None - if page > 1: - prev_page = page - 1 - - next_page = None - if page < last_page: - next_page = page + 1 - - rel_pages = {'prev': prev_page, 'next': next_page, 'first': first_page, 'last': last_page} - - return (limit, offset, rel_pages) - - def build_headers(self, rel_pages=None, url=None, total_count=None): - """ - Construct the header dictionary for an HTTP response. It includes related - pages, total count of results (before pagination). - - :param rel_pages: a dictionary defining related pages (first, prev, next, last) - :param url: (string) the full url, i.e. the url that the client uses to get Rest resources - """ - - ## Type validation - # mandatory parameters - try: - total_count = int(total_count) - except ValueError: - raise InputValidationError('total_count must be a long integer') - - # non mandatory parameters - if rel_pages is not None and not isinstance(rel_pages, dict): - raise InputValidationError('rel_pages must be a dictionary') - - if url is not None: - try: - url = str(url) - except ValueError: - raise InputValidationError('url must be a string') - - ## Input consistency - # rel_pages cannot be defined without url - if rel_pages is not None and url is None: - raise InputValidationError("'rel_pages' parameter requires 'url' parameter to be defined") - - headers = {} - - ## Setting mandatory headers - # set X-Total-Count - headers['X-Total-Count'] = total_count - expose_header = ['X-Total-Count'] - - ## Two auxiliary functions - def split_url(url): - """ Split url into path and query string """ - if '?' in url: - [path, query_string] = url.split('?') - question_mark = '?' 
- else: - path = url - query_string = '' - question_mark = '' - return (path, query_string, question_mark) - - def make_rel_url(rel, page): - new_path_elems = path_elems + ['page', str(page)] - return f"<{'/'.join(new_path_elems)}{question_mark}{query_string}>; rel={rel}, " - - ## Setting non-mandatory parameters - # set links to related pages - if rel_pages is not None: - (path, query_string, question_mark) = split_url(url) - path_elems = self.split_path(path) - - if path_elems.pop(-1) == 'page' or path_elems.pop(-1) == 'page': - links = [] - for (rel, page) in rel_pages.items(): - if page is not None: - links.append(make_rel_url(rel, page)) - headers['Link'] = ''.join(links) - expose_header.append('Link') - else: - pass - - # to expose header access in cross-domain requests - headers['Access-Control-Expose-Headers'] = ','.join(expose_header) - - return headers - - @staticmethod - def build_response(status=200, headers=None, data=None): - """ - Build the response - - :param status: status of the response, e.g. 200=OK, 400=bad request - :param headers: dictionary for additional header k,v pairs, - e.g. X-total-count= - :param data: a dictionary with the data returned by the Resource - - :return: a Flask response object - """ - - ## Type checks - # mandatory parameters - if not isinstance(data, dict): - raise InputValidationError('data must be a dictionary') - - # non-mandatory parameters - if status is not None: - try: - status = int(status) - except ValueError: - raise InputValidationError('status must be an integer') - - if headers is not None and not isinstance(headers, dict): - raise InputValidationError('header must be a dictionary') - - # Build response - response = jsonify(data) - response.status_code = status - - if headers is not None: - for key, val in headers.items(): - response.headers[key] = val - - return response - - @staticmethod - def build_datetime_filter(dtobj): - """ - This function constructs a filter for a datetime object to be in a - certain datetime interval according to the precision. - - The interval is [reference_datetime, reference_datetime + delta_time], - where delta_time is a function fo the required precision. - - This function should be used to replace a datetime filter based on - the equality operator that is inehrently "picky" because it implies - matching two datetime objects down to the microsecond or something, - by a "tolerant" operator which checks whether the datetime is in an - interval. 
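# A small sketch of the interval filter built below, assuming the constructor
# signature ``DatetimePrecision(dtobj, precision)`` from ``aiida.common.utils``
# (precision 1 corresponds to a one-day window).
from datetime import datetime, timedelta

from aiida.common.utils import DatetimePrecision
from aiida.restapi.common.utils import Utils

reference = datetime(2023, 5, 1)
filters = Utils.build_datetime_filter(DatetimePrecision(reference, 1))
assert filters == {'and': [{'>=': reference}, {'<': reference + timedelta(days=1)}]}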
-    @staticmethod
-    def build_datetime_filter(dtobj):
-        """
-        This function constructs a filter for a datetime object to be in a
-        certain datetime interval according to the precision.
-
-        The interval is [reference_datetime, reference_datetime + delta_time],
-        where delta_time is a function of the required precision.
-
-        This function should be used to replace a datetime filter based on
-        the equality operator, which is inherently "picky" because it implies
-        matching two datetime objects down to the microsecond, by a
-        "tolerant" operator which checks whether the datetime lies in an
-        interval.
-
-        :return: a suitable entry of the filter dictionary
-        """
-
-        if not isinstance(dtobj, DatetimePrecision):
-            raise TypeError('dtobj argument has to be a DatetimePrecision object')
-
-        reference_datetime = dtobj.dtobj
-        precision = dtobj.precision
-
-        ## Define interval according to the precision
-        if precision == 1:
-            delta_time = timedelta(days=1)
-        elif precision == 2:
-            delta_time = timedelta(hours=1)
-        elif precision == 3:
-            delta_time = timedelta(minutes=1)
-        elif precision == 4:
-            delta_time = timedelta(seconds=1)
-        else:
-            raise RestValidationError('The datetime resolution is not valid.')
-
-        filters = {'and': [{'>=': reference_datetime}, {'<': reference_datetime + delta_time}]}
-
-        return filters
-
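(The interval construction above is easiest to see on a concrete value. A minimal sketch, assuming a filter such as `ctime=2023-05-01` parsed with day precision; the date is made up.)

```python
from datetime import datetime, timedelta

# Day precision (1): equality on 2023-05-01 becomes membership in that whole day.
reference = datetime(2023, 5, 1)
filters = {'and': [{'>=': reference}, {'<': reference + timedelta(days=1)}]}
# i.e. 2023-05-01T00:00:00 <= ctime < 2023-05-02T00:00:00
```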
-    def build_translator_parameters(self, field_list):
-        # pylint: disable=too-many-locals,too-many-statements,too-many-branches
-        """
-        Takes a list of elements resulting from parsing the query_string and
-        elaborates them in order to provide translator-compliant instructions.
-
-        :param field_list: a (nested) list of elements resulting from parsing the query_string
-        :returns: the translator-compliant instructions (limit, offset, perpage, orderby, filters, ...)
-        """
-        ## Create void variables
-        filters = {}
-        orderby = []
-        limit = None
-        offset = None
-        perpage = None
-        filename = None
-        download_format = None
-        download = True
-        attributes = None
-        attributes_filter = None
-        extras = None
-        extras_filter = None
-        full_type = None
-        profile = None
-
-        # io tree limit parameters
-        tree_in_limit = None
-        tree_out_limit = None
-
-        ## Count how many times a key has been used for the filters
-        # and check whether reserved keywords have been used more than once
-        field_counts = {}
-        for field in field_list:
-            field_key = field[0]
-            if field_key not in field_counts:
-                field_counts[field_key] = 1
-                # Store the information whether membership operator is used
-                # is_membership = (field[1] is '=in=')
-            else:
-                # Check if the key of a filter using membership operator is used
-                # in multiple filters
-                # if is_membership is True or field[1] is '=in=':
-                #     raise RestInputValidationError("If a key appears in "
-                #                                    "multiple filters, "
-                #                                    "those cannot use "
-                #                                    "membership operator '=in='")
-                field_counts[field_key] = field_counts[field_key] + 1
-
-        ## Check the reserved keywords
-        if 'limit' in field_counts and field_counts['limit'] > 1:
-            raise RestInputValidationError('You cannot specify limit more than once')
-        if 'offset' in field_counts and field_counts['offset'] > 1:
-            raise RestInputValidationError('You cannot specify offset more than once')
-        if 'perpage' in field_counts and field_counts['perpage'] > 1:
-            raise RestInputValidationError('You cannot specify perpage more than once')
-        if 'orderby' in field_counts and field_counts['orderby'] > 1:
-            raise RestInputValidationError('You cannot specify orderby more than once')
-        if 'download' in field_counts and field_counts['download'] > 1:
-            raise RestInputValidationError('You cannot specify download more than once')
-        if 'download_format' in field_counts and field_counts['download_format'] > 1:
-            raise RestInputValidationError('You cannot specify download_format more than once')
-        if 'filename' in field_counts and field_counts['filename'] > 1:
-            raise RestInputValidationError('You cannot specify filename more than once')
-        if 'in_limit' in field_counts and field_counts['in_limit'] > 1:
-            raise RestInputValidationError('You cannot specify in_limit more than once')
-        if 'out_limit' in field_counts and field_counts['out_limit'] > 1:
-            raise RestInputValidationError('You cannot specify out_limit more than once')
-        if 'attributes' in field_counts and field_counts['attributes'] > 1:
-            raise RestInputValidationError('You cannot specify attributes more than once')
-        if 'attributes_filter' in field_counts and field_counts['attributes_filter'] > 1:
-            raise RestInputValidationError('You cannot specify attributes_filter more than once')
-        if 'extras' in field_counts and field_counts['extras'] > 1:
-            raise RestInputValidationError('You cannot specify extras more than once')
-        if 'extras_filter' in field_counts and field_counts['extras_filter'] > 1:
-            raise RestInputValidationError('You cannot specify extras_filter more than once')
-        if 'full_type' in field_counts and field_counts['full_type'] > 1:
-            raise RestInputValidationError('You cannot specify full_type more than once')
-        if 'profile' in field_counts and field_counts['profile'] > 1:
-            raise RestInputValidationError('You cannot specify profile more than once')
-
-        ## Extract results
-        for field in field_list:
-            if field[0] == 'profile':
-                if field[1] == '=':
-                    profile = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'profile'")
-            elif field[0] == 'limit':
-                if field[1] == '=':
-                    limit = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'limit'")
-            elif field[0] == 'offset':
-                if field[1] == '=':
-                    offset = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'offset'")
-            elif field[0] == 'perpage':
-                if field[1] == '=':
-                    perpage = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'perpage'")
-
-            elif field[0] == 'orderby':
-                if field[1] == '=':
-                    # Consider value (gives string) and value_list (gives list of
-                    # strings) cases
-                    if isinstance(field[2], list):
-                        orderby.extend(field[2])
-                    else:
-                        orderby.extend([field[2]])
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'orderby'")
-
-            elif field[0] == 'download':
-                if field[1] == '=':
-                    download = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'download'")
-
-            elif field[0] == 'download_format':
-                if field[1] == '=':
-                    download_format = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'download_format'")
-
-            elif field[0] == 'filename':
-                if field[1] == '=':
-                    filename = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'filename'")
-
-            elif field[0] == 'full_type':
-                if field[1] == '=':
-                    full_type = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'full_type'")
-
-            elif field[0] == 'in_limit':
-                if field[1] == '=':
-                    tree_in_limit = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'in_limit'")
-
-            elif field[0] == 'out_limit':
-                if field[1] == '=':
-                    tree_out_limit = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'out_limit'")
-
-            elif field[0] == 'attributes':
-                if field[1] == '=':
-                    attributes = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'attributes'")
-
-            elif field[0] == 'attributes_filter':
-                if field[1] == '=':
-                    attributes_filter = field[2]
-                else:
-                    raise RestInputValidationError(
-                        "only assignment operator '=' is permitted after 'attributes_filter'"
-                    )
-            elif field[0] == 'extras':
-                if field[1] == '=':
-                    extras = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'extras'")
-
-            elif field[0] == 'extras_filter':
-                if field[1] == '=':
-                    extras_filter = field[2]
-                else:
-                    raise RestInputValidationError("only assignment operator '=' is permitted after 'extras_filter'")
-
-            else:
-
-                ## Construct the filter entry.
-                field_key = field[0]
-                operator = field[1]
-                field_value = field[2]
-
-                if isinstance(field_value, DatetimePrecision) and operator == '=':
-                    filter_value = self.build_datetime_filter(field_value)
-                else:
-                    filter_value = {self.op_conv_map[field[1]]: field_value}
-
-                # Here I treat the AND clause
-                if field_counts[field_key] > 1:
-
-                    if field_key not in filters:
-                        filters.update({field_key: {'and': [filter_value]}})
-                    else:
-                        filters[field_key]['and'].append(filter_value)
-                else:
-                    filters.update({field_key: filter_value})
-
-        # #Impose defaults if needed
-        # if limit is None:
-        #     limit = self.limit_default
-
-        return (
-            limit, offset, perpage, orderby, filters, download_format, download, filename, tree_in_limit,
-            tree_out_limit, attributes, attributes_filter, extras, extras_filter, full_type, profile
-        )
-
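(To make the branching above concrete: a made-up `field_list` of `(key, operator, value)` triples and the outcome one would expect from this routine, assuming the operator map renders `=like=` as the query builder's `like` operator.)

```python
# Hypothetical parsed query string: limit=5&id>10&id<100&node_type=like="data%"
field_list = [
    ('limit', '=', 5),
    ('id', '>', 10),
    ('id', '<', 100),
    ('node_type', '=like=', 'data%'),
]

# 'limit' is a reserved keyword, so it is returned as limit=5 rather than a filter.
# 'id' appears twice, so its two conditions end up under a single 'and' clause:
expected_filters = {
    'id': {'and': [{'>': 10}, {'<': 100}]},
    'node_type': {'like': 'data%'},
}
```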
-    def parse_query_string(self, query_string):
-        # pylint: disable=too-many-locals
-        """
-        Function that parses the query string, extracting info for limit, offset,
-        ordering, filters, attribute and extra projections.
-        :param query_string: the query string (as obtained from request.query_string)
-        :return: parsed values for the querykeys
-        """
-        from psycopg2.tz import FixedOffsetTimezone
-        from pyparsing import Combine, Group, Literal, OneOrMore, Optional, ParseException, QuotedString
-        from pyparsing import StringEnd as SE
-        from pyparsing import StringStart as SS
-        from pyparsing import Suppress, Word
-        from pyparsing import WordEnd as WE
-        from pyparsing import ZeroOrMore, alphanums, alphas, nums, printables
-        from pyparsing import pyparsing_common as ppc
-
-        ## Define grammar
-        # key types
-        key = Word(f'{alphas}_', f'{alphanums}_-')
-        # operators
-        operator = (
-            Literal('=like=') | Literal('=ilike=') | Literal('=in=') | Literal('=notin=') | Literal('=') |
-            Literal('!=') | Literal('>=') | Literal('>') | Literal('<=') | Literal('<')
-        )
-        # Value types
-        value_num = ppc.number
-        value_bool = (Literal('true') | Literal('false')).addParseAction(lambda toks: toks[0] == 'true')
-        value_string = QuotedString('"', escQuote='""')
-        value_orderby = Combine(Optional(Word('+-', exact=1)) + key)
-
-        ## DateTimeShift value. First, compose the atomic values and then
-        # combine them and convert them to datetime objects
-        # Date
-        value_date = Combine(
-            Word(nums, exact=4) + Literal('-') + Word(nums, exact=2) + Literal('-') + Word(nums, exact=2)
-        )
-        # Time
-        value_time = Combine(
-            Literal('T') + Word(nums, exact=2) + Optional(Literal(':') + Word(nums, exact=2)) +
-            Optional(Literal(':') + Word(nums, exact=2))
-        )
-        # Shift
-        value_shift = Combine(Word('+-', exact=1) + Word(nums, exact=2) + Optional(Literal(':') + Word(nums, exact=2)))
-        # Combine atomic values
-        value_datetime = Combine(
-            value_date + Optional(value_time) + Optional(value_shift) + WE(printables.replace('&', ''))
-            # To use this, the word must end with '&' or at the end of the string.
-            # Adding WordEnd only here is very important: it makes the atomic
-            # values for date, time and shift not really usable individually.
-        )
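(For a feel of what a grammar of this kind accepts, here is a tiny self-contained pyparsing sketch. It is deliberately simplified, numbers and bare dates only, and is not the removed grammar itself.)

```python
from pyparsing import Combine, Literal, Word, alphanums, alphas, nums
from pyparsing import pyparsing_common as ppc

key = Word(f'{alphas}_', f'{alphanums}_-')
operator = (
    Literal('=like=') | Literal('=in=') | Literal('=') | Literal('!=') |
    Literal('>=') | Literal('>') | Literal('<=') | Literal('<')
)
value_date = Combine(Word(nums, exact=4) + '-' + Word(nums, exact=2) + '-' + Word(nums, exact=2))
field = key + operator + (value_date | ppc.number)

print(field.parseString('id>123').asList())             # ['id', '>', 123]
print(field.parseString('ctime>=2023-05-01').asList())  # ['ctime', '>=', '2023-05-01']
```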
-
-        ########################################################################
-
-        def validate_time(toks):
-            """
-            Function to convert a datetime string into a datetime object. The format is
-            compliant with ParseAction requirements.
-
-            :param toks: datetime string passed in tokens
-            :return: datetime object
-            """
-            datetime_string = toks[0]
-
-            # Check the precision
-            precision = len(datetime_string.replace('T', ':').split(':'))
-
-            # Parse
-            try:
-                dtobj = datetime.fromisoformat(datetime_string)
-            except ValueError:
-                raise RestInputValidationError(
-                    'time value has wrong format. The '
-                    'right format is '
-                    'T