diff --git a/aiida/manage/configuration/config.py b/aiida/manage/configuration/config.py index d88f81f30c..7b53339b12 100644 --- a/aiida/manage/configuration/config.py +++ b/aiida/manage/configuration/config.py @@ -16,9 +16,11 @@ from __future__ import annotations import codecs +import contextlib +import io import json import os -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import uuid from pydantic import ( # pylint: disable=no-name-in-module @@ -33,11 +35,19 @@ from aiida.common.exceptions import ConfigurationError from aiida.common.log import LogLevels +from aiida.common.exceptions import ConfigurationError, EntryPointError, StorageMigrationError +from aiida.common.log import AIIDA_LOGGER + from .options import Option, get_option, get_option_names, parse_option from .profile import Profile __all__ = ('Config',) +if TYPE_CHECKING: + from aiida.orm.implementation.storage_backend import StorageBackend + +LOGGER = AIIDA_LOGGER.getChild(__file__) + class ConfigVersionSchema(BaseModel, defer_build=True): """Schema for the version configuration of an AiiDA instance.""" @@ -126,7 +136,6 @@ def validate_caching_identifier_pattern(cls, value: List[str]) -> List[str]: from aiida.manage.caching import _validate_identifier_pattern for identifier in value: _validate_identifier_pattern(identifier=identifier) - return value @@ -446,6 +455,70 @@ def get_profile(self, name: Optional[str] = None) -> Profile: return self._profiles[name] + def create_profile(self, name: str, storage_cls: Type['StorageBackend'], storage_config: dict[str, str]) -> Profile: + """Create a new profile and initialise its storage. + + :param name: The profile name. + :param storage_cls: The :class:`aiida.orm.implementation.storage_backend.StorageBackend` implementation to use. + :param storage_config: The configuration necessary to initialise and connect to the storage backend. + :returns: The created profile. + :raises ValueError: If the profile already exists. + :raises TypeError: If the ``storage_cls`` is not a subclass of + :class:`aiida.orm.implementation.storage_backend.StorageBackend`. + :raises EntryPointError: If the ``storage_cls`` does not have an associated entry point. + :raises StorageMigrationError: If the storage cannot be initialised. + """ + from aiida.orm.implementation.storage_backend import StorageBackend + from aiida.plugins.entry_point import get_entry_point_from_class + + if name in self.profile_names: + raise ValueError(f'The profile `{name}` already exists.') + + if not issubclass(storage_cls, StorageBackend): + raise TypeError( + f'The `storage_cls={storage_cls}` is not subclass of `aiida.orm.implementationStorageBackend`.' + ) + + _, storage_entry_point = get_entry_point_from_class(storage_cls.__module__, storage_cls.__name__) + + if storage_entry_point is None: + raise EntryPointError(f'`{storage_cls}` does not have a registered entry point.') + + profile = Profile( + name, { + 'storage': { + 'backend': storage_entry_point.name, + 'config': storage_config, + }, + 'process_control': { + 'backend': 'rabbitmq', + 'config': { + 'broker_protocol': 'amqp', + 'broker_username': 'guest', + 'broker_password': 'guest', + 'broker_host': '127.0.0.1', + 'broker_port': 5672, + 'broker_virtual_host': '' + } + }, + } + ) + + LOGGER.report('Initialising the storage backend.') + try: + with contextlib.redirect_stdout(io.StringIO()): + profile.storage_cls.initialise(profile) + except Exception as exception: # pylint: disable=broad-except + raise StorageMigrationError( + f'Storage backend initialisation failed, probably because the configuration is incorrect:\n{exception}' + ) + LOGGER.report('Storage initialisation completed.') + + self.add_profile(profile) + self.store() + + return profile + def add_profile(self, profile): """Add a profile to the configuration. diff --git a/tests/manage/configuration/test_config.py b/tests/manage/configuration/test_config.py index 613a346617..35786ff683 100644 --- a/tests/manage/configuration/test_config.py +++ b/tests/manage/configuration/test_config.py @@ -11,6 +11,7 @@ import json import os import pathlib +import uuid import pytest @@ -18,6 +19,8 @@ from aiida.manage.configuration import Config, Profile, settings from aiida.manage.configuration.migrations import CURRENT_CONFIG_VERSION, OLDEST_COMPATIBLE_CONFIG_VERSION from aiida.manage.configuration.options import get_option +from aiida.orm.implementation.storage_backend import StorageBackend +from aiida.storage.sqlite_temp import SqliteTempBackend @pytest.fixture @@ -418,3 +421,38 @@ def test_delete_profile(config_with_profile, profile_factory): # Now reload the config from disk to make sure the changes after deletion were persisted to disk config_on_disk = Config.from_file(config.filepath) assert profile_name not in config_on_disk.profile_names + + +def test_create_profile_raises(config_with_profile, monkeypatch): + """Test the ``create_profile`` method when it raises.""" + config = config_with_profile + profile_name = uuid.uuid4().hex + + def raise_storage_migration_error(*args, **kwargs): + raise exceptions.StorageMigrationError() + + monkeypatch.setattr(SqliteTempBackend, 'initialise', raise_storage_migration_error) + + class UnregisteredStorageBackend(StorageBackend): + pass + + with pytest.raises(ValueError, match=r'The profile `.*` already exists.'): + config.create_profile(config_with_profile.default_profile_name, SqliteTempBackend, {}) + + with pytest.raises(TypeError, match=r'The `storage_cls=.*` is not subclass of `.*`.'): + config.create_profile(profile_name, object, {}) + + with pytest.raises(exceptions.EntryPointError, match=r'.*does not have a registered entry point.'): + config.create_profile(profile_name, UnregisteredStorageBackend, {}) + + with pytest.raises(exceptions.StorageMigrationError, match='Storage backend initialisation failed.*'): + config.create_profile(profile_name, SqliteTempBackend, {}) + + +def test_create_profile(config_with_profile): + """Test the ``create_profile`` method.""" + config = config_with_profile + profile_name = uuid.uuid4().hex + + config.create_profile(profile_name, SqliteTempBackend, {}) + assert profile_name in config.profile_names