Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Config: Switch from jsonschema to pydantic #6117

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aiida/common/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import collections
import contextlib
import enum
import logging
import types
from typing import cast
Expand Down Expand Up @@ -52,6 +53,8 @@ def report(self, msg: str, *args, **kwargs) -> None:
logging.getLevelName(logging.CRITICAL): logging.CRITICAL,
}

LogLevels = enum.Enum('LogLevels', {key: key for key in LOG_LEVELS}) # type: ignore[misc]
unkcpz marked this conversation as resolved.
Show resolved Hide resolved

AIIDA_LOGGER = cast(AiidaLoggerType, logging.getLogger('aiida'))

CLI_ACTIVE: bool | None = None
Expand Down
177 changes: 169 additions & 8 deletions aiida/manage/configuration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,19 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Module that defines the configuration file of an AiiDA instance and functions to create and load it."""
from __future__ import annotations

import codecs
from functools import cache
import json
import os
from typing import Any, Dict, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
sphuber marked this conversation as resolved.
Show resolved Hide resolved
import uuid

from pydantic import BaseModel, Field, ValidationError, validator # pylint: disable=no-name-in-module
sphuber marked this conversation as resolved.
Show resolved Hide resolved

from aiida.common.exceptions import ConfigurationError
from aiida.common.log import LogLevels

from . import schema as schema_module
from .options import Option, get_option, get_option_names, parse_option
Expand Down Expand Up @@ -52,6 +58,164 @@ def __str__(self) -> str:
return f'Validation Error: {prefix}{path}{self._message}{schema}'


class ConfigVersionSchema(BaseModel):
    """Pydantic model describing the version block of an AiiDA instance configuration.

    Both fields are mandatory integers; no default is provided for either of them.
    """

    # Presumably the schema version the configuration is currently written in -- TODO confirm against the migrations.
    CURRENT: int = Field(...)
    # Presumably the oldest schema version that can still read this configuration -- TODO confirm.
    OLDEST_COMPATIBLE: int = Field(...)


class ProfileOptionsSchema(BaseModel):
    """Schema for the options of an AiiDA profile.

    Each field corresponds to one configuration option. The field names use a double underscore as separator
    because a dot is not valid in a Python identifier. Every field has a default, so an empty mapping validates.
    """

    runner__poll__interval: int = Field(60, description='Polling interval in seconds to be used by process runners.')
    daemon__default_workers: int = Field(
        1, description='Default number of workers to be launched by `verdi daemon start`.'
    )
    daemon__timeout: int = Field(
        2,
        description=
        'Used to set default timeout in the :class:`aiida.engine.daemon.client.DaemonClient` for calls to the daemon.'
    )
    daemon__worker_process_slots: int = Field(
        200, description='Maximum number of concurrent process tasks that each daemon worker can handle.'
    )
    daemon__recursion_limit: int = Field(3000, description='Maximum recursion depth for the daemon workers.')
    db__batch_size: int = Field(
        100000,
        description='Batch size for bulk CREATE operations in the database. Avoids hitting MaxAllocSize of PostgreSQL '
        '(1GB) when creating large numbers of database records in one go.'
    )
    verdi__shell__auto_import: str = Field(
        ':',
        description='Additional modules/functions/classes to be automatically loaded in `verdi shell`, split by `:`.'
    )
    logging__aiida_loglevel: LogLevels = Field(
        'REPORT', description='Minimum level to log to daemon log and the `DbLog` table for the `aiida` logger.'
    )
    logging__verdi_loglevel: LogLevels = Field(
        'REPORT', description='Minimum level to log to console when running a `verdi` command.'
    )
    logging__db_loglevel: LogLevels = Field('REPORT', description='Minimum level to log to the DbLog table.')
    logging__plumpy_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `plumpy` logger.'
    )
    logging__kiwipy_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `kiwipy` logger'
    )
    logging__paramiko_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `paramiko` logger'
    )
    logging__alembic_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `alembic` logger'
    )
    logging__sqlalchemy_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `sqlalchemy` logger'
    )
    logging__circus_loglevel: LogLevels = Field(
        'INFO', description='Minimum level to log to daemon log and the `DbLog` table for the `circus` logger'
    )
    logging__aiopika_loglevel: LogLevels = Field(
        'WARNING', description='Minimum level to log to daemon log and the `DbLog` table for the `aiopika` logger'
    )
    warnings__showdeprecations: bool = Field(True, description='Whether to print AiiDA deprecation warnings.')
    warnings__rabbitmq_version: bool = Field(
        True, description='Whether to print a warning when an incompatible version of RabbitMQ is configured.'
    )
    transport__task_retry_initial_interval: int = Field(
        20, description='Initial time interval for the exponential backoff mechanism.'
    )
    transport__task_maximum_attempts: int = Field(
        5, description='Maximum number of transport task attempts before a Process is Paused.'
    )
    rmq__task_timeout: int = Field(10, description='Timeout in seconds for communications with RabbitMQ.')
    # The ``None`` default is spelled out instead of relying on the implicit default of an ``Optional`` annotation.
    storage__sandbox: Optional[str] = Field(
        None, description='Absolute path to the directory to store sandbox folders.'
    )
    caching__default_enabled: bool = Field(False, description='Enable calculation caching by default.')
    caching__enabled_for: List[str] = Field([], description='Calculation entry points to enable caching on.')
    caching__disabled_for: List[str] = Field([], description='Calculation entry points to disable caching on.')

    class Config:
        """Pydantic model configuration."""

        # Store the value of the ``LogLevels`` members on the model instead of the enum members themselves.
        use_enum_values = True

    @validator('caching__enabled_for', 'caching__disabled_for')
    @classmethod
    def validate_caching_identifier_pattern(cls, value: List[str]) -> List[str]:
        """Validate the caching identifier patterns.

        :param value: the list of identifier patterns to check.
        :return: the unchanged list if all patterns are valid.
        :raises ValueError: if any of the patterns is invalid. Pydantic collects this exception and reports it as
            part of the ``ValidationError`` raised for the model as a whole.
        """
        # Imported inside the method as in the original code; presumably to avoid a circular import -- TODO confirm.
        from aiida.manage.caching import _validate_identifier_pattern

        for identifier in value:
            # The ``ValueError`` raised here is deliberately left to propagate: a validator has to signal failure
            # with ``ValueError``, ``TypeError`` or ``AssertionError``. ``pydantic.ValidationError`` cannot be
            # constructed from just a message, so raising it manually would itself fail with a ``TypeError``.
            _validate_identifier_pattern(identifier=identifier)

        return value


class GlobalOptionsSchema(ProfileOptionsSchema):
    """Schema for the global options of an AiiDA instance.

    Extends :class:`ProfileOptionsSchema` with the options defined below, which are not part of the profile schema.
    The ``autofill`` fields default to ``None`` explicitly rather than relying on the implicit default that
    pydantic v1 assigns to ``Optional`` annotations, so the fields stay optional irrespective of that behaviour.
    """

    autofill__user__email: Optional[str] = Field(
        None, description='Default user email to use when creating new profiles.'
    )
    autofill__user__first_name: Optional[str] = Field(
        None, description='Default user first name to use when creating new profiles.'
    )
    autofill__user__last_name: Optional[str] = Field(
        None, description='Default user last name to use when creating new profiles.'
    )
    autofill__user__institution: Optional[str] = Field(
        None, description='Default user institution to use when creating new profiles.'
    )
    rest_api__profile_switching: bool = Field(
        False, description='Toggle whether the profile can be specified in requests submitted to the REST API.'
    )
    warnings__development_version: bool = Field(
        True,
        description='Whether to print a warning when a profile is loaded while a development version is installed.'
    )


class ProfileStorageConfig(BaseModel):
    """Pydantic model for the storage backend section of an AiiDA profile.

    Both fields are required. The content of ``config`` is not constrained beyond being a mapping with string keys.
    """

    # Identifier of the storage backend; presumably an entry point name -- TODO confirm.
    backend: str = Field(...)
    # Free-form settings that belong to the selected backend.
    config: Dict[str, Any] = Field(...)


class ProcessControlConfig(BaseModel):
    """Schema for the process control configuration of an AiiDA profile.

    All fields have a default, so an empty mapping is a valid process control configuration.
    """

    broker_protocol: str = Field('amqp', description='Protocol for connecting to the message broker.')
    broker_username: str = Field('guest', description='Username for message broker authentication.')
    broker_password: str = Field('guest', description='Password for message broker.')
    broker_host: str = Field('127.0.0.1', description='Hostname of the message broker.')
    # 5672 is the default AMQP port. The previous default of 5432 is the default port of PostgreSQL.
    broker_port: int = Field(5672, description='Port of the message broker.')
    broker_virtual_host: str = Field('', description='Virtual host to use for the message broker.')
    # The default has to be a mapping to match the annotation: pydantic does not validate defaults, so the previous
    # default of the string ``'guest'`` would have been stored as is. A factory gives each instance its own dict.
    broker_parameters: Dict[str, Any] = Field(
        default_factory=dict, description='Arguments to be encoded as query parameters.'
    )
sphuber marked this conversation as resolved.
Show resolved Hide resolved


class ProfileSchema(BaseModel):
    """Schema for the configuration of an AiiDA profile.

    ``storage`` and ``process_control`` are required, all other fields have a default.
    """

    # The field is annotated as ``str``, so the factory converts the generated UUID to a string. Defaults are not
    # validated by pydantic, which means that a bare ``uuid.uuid4`` would store a ``uuid.UUID`` instance instead.
    # The lambda is evaluated at call time in function scope, so ``uuid`` refers to the module and not this field.
    uuid: str = Field(description='', default_factory=lambda: str(uuid.uuid4()))
    storage: ProfileStorageConfig
    process_control: ProcessControlConfig
    default_user_email: Optional[str] = None
    test_profile: bool = False
    # Explicit ``None`` default instead of relying on the implicit default of an ``Optional`` annotation.
    options: Optional[ProfileOptionsSchema] = None

    class Config:
        """Pydantic model configuration."""

        # Serialize any ``uuid.UUID`` instance as its string representation when dumping the model to JSON.
        json_encoders = {
            uuid.UUID: str,
        }


class ConfigSchema(BaseModel):
    """Schema for the configuration of an AiiDA instance.

    Every top-level section is optional. The ``None`` defaults are spelled out instead of relying on the implicit
    default that pydantic v1 assigns to ``Optional`` annotations, so a partial configuration keeps validating
    irrespective of that behaviour.
    """

    CONFIG_VERSION: Optional[ConfigVersionSchema] = None
    # Mapping of profile name to the configuration of that profile.
    profiles: Optional[Dict[str, ProfileSchema]] = None
    options: Optional[GlobalOptionsSchema] = None
    # Presumably the name of one of the keys of ``profiles``; this is not enforced by the schema.
    default_profile: Optional[str] = None


class Config: # pylint: disable=too-many-public-methods
"""Object that represents the configuration file of an AiiDA instance."""

Expand Down Expand Up @@ -125,13 +289,10 @@ def _backup(cls, filepath):
@staticmethod
def validate(config: dict, filepath: Optional[str] = None):
    """Validate a configuration dictionary.

    The dictionary is validated by constructing a :class:`ConfigSchema` from it; the instance is discarded.

    :param config: the configuration dictionary to validate.
    :param filepath: optional path of the file the configuration was read from. It is only used in the message of
        the exception that is raised when the validation fails.
    :raises ConfigurationError: if the dictionary does not satisfy the schema.
    """
    try:
        ConfigSchema(**config)
    except ValidationError as exception:
        # Chain the original exception so the detailed pydantic error report remains available in the traceback.
        raise ConfigurationError(f'invalid config schema: {filepath}: {str(exception)}') from exception

def __init__(self, filepath: str, config: dict, validate: bool = True):
"""Instantiate a configuration object from a configuration dictionary and its filepath.
Expand Down Expand Up @@ -470,7 +631,7 @@ def get_options(self, scope: Optional[str] = None) -> Dict[str, Tuple[Option, st
elif name in self.options:
value = self.options.get(name)
source = 'global'
elif 'default' in option.schema:
elif option.default is not None:
value = option.default
source = 'default'
else:
Expand Down
84 changes: 23 additions & 61 deletions aiida/manage/configuration/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
class Option:
"""Represent a configuration option schema."""

def __init__(self, name: str, schema: Dict[str, Any]):
def __init__(self, name: str, schema: Dict[str, Any], field):
self._name = name
self._schema = schema
self._field = field

def __str__(self) -> str:
return f'Option(name={self._name})'
Expand All @@ -30,26 +31,27 @@ def name(self) -> str:
return self._name

@property
def schema(self) -> Dict[str, Any]:
return self._schema
def valid_type(self) -> Any:
return self._field.type_

@property
def valid_type(self) -> Any:
return self._schema.get('type', None)
def schema(self) -> Dict[str, Any]:
return self._schema

@property
def default(self) -> Any:
return self._schema.get('default', None)
return self._field.default

@property
def description(self) -> str:
return self._schema.get('description', '')
return self._field.field_info.description

@property
def global_only(self) -> bool:
    """Return whether the option can only be set globally and not for an individual profile.

    An option is global-only when it is not a field of ``ProfileOptionsSchema``. Options are looked up on
    ``GlobalOptionsSchema``, which extends ``ProfileOptionsSchema``, so any option that is missing from the
    profile schema was added by the global schema.

    :return: ``True`` if the option is not defined on the profile options schema, ``False`` otherwise.
    """
    # Imported inside the method as in the original code; presumably to avoid a circular import -- TODO confirm.
    from .config import ProfileOptionsSchema

    # The option name uses dots as separator whereas the schema fields use double underscores, so the name has to
    # be converted before the lookup. The test is negated: being part of the profile schema means *not* global-only.
    return self._name.replace('.', '__') not in ProfileOptionsSchema.__fields__

def validate(self, value: Any, cast: bool = True) -> Any:
def validate(self, value: Any) -> Any:
"""Validate a value

:param value: The input value
Expand All @@ -59,68 +61,28 @@ def validate(self, value: Any, cast: bool = True) -> Any:
:raise: ConfigValidationError

"""
# pylint: disable=too-many-branches
import jsonschema

from aiida.manage.caching import _validate_identifier_pattern

from .config import ConfigValidationError

if cast:
try:
if self.valid_type == 'boolean':
if isinstance(value, str):
if value.strip().lower() in ['0', 'false', 'f']:
value = False
elif value.strip().lower() in ['1', 'true', 't']:
value = True
else:
value = bool(value)
elif self.valid_type == 'string':
value = str(value)
elif self.valid_type == 'integer':
value = int(value)
elif self.valid_type == 'number':
value = float(value)
elif self.valid_type == 'array' and isinstance(value, str):
value = value.split()
except ValueError:
pass

try:
jsonschema.validate(instance=value, schema=self.schema)
except jsonschema.ValidationError as exc:
raise ConfigValidationError(message=exc.message, keypath=[self.name, *(exc.path or [])], schema=exc.schema)

# special caching validation
if self.name in ('caching.enabled_for', 'caching.disabled_for'):
for i, identifier in enumerate(value):
try:
_validate_identifier_pattern(identifier=identifier)
except ValueError as exc:
raise ConfigValidationError(message=str(exc), keypath=[self.name, str(i)])

return value
value, validation_error = self._field.validate(value, {}, loc=None)

if validation_error:
raise ConfigurationError(validation_error)

def get_schema_options() -> Dict[str, Dict[str, Any]]:
"""Return schema for options."""
from .config import config_schema
schema = config_schema()
return schema['definitions']['options']['properties']
return value


def get_option_names() -> List[str]:
    """Return the names of all available configuration options.

    The names are derived from the fields of ``GlobalOptionsSchema``, with the double underscore of the field
    names replaced by a dot.
    """
    from .config import GlobalOptionsSchema

    names = []

    for field_name in GlobalOptionsSchema.__fields__:
        names.append(field_name.replace('__', '.'))

    return names


def get_option(name: str) -> Option:
    """Return the option with the given name.

    :param name: the name of the option, using dots as separator.
    :return: the ``Option`` built from the matching field of ``GlobalOptionsSchema`` and its JSON schema.
    :raises ConfigurationError: if the global options schema has no field for this name.
    """
    from .config import GlobalOptionsSchema

    # The schema fields use a double underscore where the option names use a dot.
    field_name = name.replace('.', '__')
    fields = GlobalOptionsSchema.__fields__

    if field_name not in fields:
        raise ConfigurationError(f'the option {name} does not exist')

    field_schema = GlobalOptionsSchema.schema()['properties'][field_name]
    return Option(name, field_schema, fields[field_name])


def parse_option(option_name: str, option_value: Any) -> Tuple[Option, Any]:
Expand All @@ -132,6 +94,6 @@ def parse_option(option_name: str, option_value: Any) -> Tuple[Option, Any]:

"""
option = get_option(option_name)
value = option.validate(option_value, cast=True)
value = option.validate(option_value)

return option, value
1 change: 1 addition & 0 deletions aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def get_option(self, option_name: str) -> Any:
else:
if option_name in config.options:
return config.get_option(option_name)

# try the defaults (will raise ConfigurationError if not present)
option = get_option(option_name)
return option.default
Expand Down
4 changes: 2 additions & 2 deletions docs/source/nitpick-exceptions
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ py:class ndarray

py:class paramiko.proxy.ProxyCommand

py:class pydantic.main.BaseModel

# These can be removed once they are properly included in the `__all__` in `plumpy`
py:class plumpy.ports.PortNamespace
py:class plumpy.utils.AttributesDict
Expand Down Expand Up @@ -218,8 +220,6 @@ py:class CircusClient
py:class pgsu.PGSU
py:meth pgsu.PGSU.__init__

py:class jsonschema.exceptions._Error

py:class Session
py:class Query
py:class importlib_metadata.EntryPoint
Expand Down
Loading