Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ NEW: Add Backend bulk methods #5171

Merged
merged 17 commits into from
Oct 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ repos:
aiida/manage/manager.py|
aiida/manage/database/delete/nodes.py|
aiida/orm/querybuilder.py|
aiida/orm/implementation/backends.py|
aiida/orm/implementation/sql/backends.py|
aiida/orm/implementation/django/backend.py|
aiida/orm/implementation/sqlalchemy/backend.py|
aiida/orm/implementation/querybuilder.py|
aiida/orm/implementation/sqlalchemy/querybuilder/.*py|
aiida/orm/nodes/data/jsonable.py|
Expand Down
28 changes: 0 additions & 28 deletions aiida/backends/djsite/utils.py

This file was deleted.

5 changes: 5 additions & 0 deletions aiida/backends/sqlalchemy/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
)


class DbGroupNode(Base):
    """Class to store group to nodes relation using SQLA backend.

    The class is not given columns of its own: it is mapped directly onto the
    already defined ``table_groups_nodes`` table through ``__table__``.
    """
    # reuse the existing table object instead of declaring the columns a second time
    __table__ = table_groups_nodes


class DbGroup(Base):
"""Class to store groups using SQLA backend."""

Expand Down
28 changes: 0 additions & 28 deletions aiida/backends/sqlalchemy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,6 @@
"""Utility functions specific to the SqlAlchemy backend."""


def delete_nodes_and_connections_sqla(pks_to_delete):  # pylint: disable=invalid-name
    """
    Delete all nodes corresponding to pks in the input, together with their group memberships and links.

    All statements are executed inside a single backend transaction.

    :param pks_to_delete: An iterable (e.g. list, tuple or set) of pks that should be deleted.
    """
    # pylint: disable=no-value-for-parameter
    from aiida.backends.sqlalchemy.models.group import table_groups_nodes
    from aiida.backends.sqlalchemy.models.node import DbLink, DbNode
    from aiida.manage.manager import get_manager

    backend = get_manager().get_backend()

    # Materialise the pks exactly once: the same collection is needed by four statements, and a one-shot
    # iterable would otherwise be exhausted after the first one, silently turning the rest into no-ops.
    pks = list(pks_to_delete)

    with backend.transaction() as session:
        # First remove the membership of these nodes to groups.
        # Since table_groups_nodes is a sqlalchemy.schema.Table, the expression language is used to compile
        # a statement that is executed by the session, whereas the deletions below go through the ORM query API.
        stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(pks))
        session.execute(stmt)
        # Delete links before the nodes, since deletions are not cascading.
        # Links coming out of the nodes marked for deletion.
        session.query(DbLink).filter(DbLink.input_id.in_(pks)).delete(synchronize_session='fetch')
        # Links pointing to the nodes marked for deletion.
        session.query(DbLink).filter(DbLink.output_id.in_(pks)).delete(synchronize_session='fetch')
        # Finally delete the nodes themselves.
        session.query(DbNode).filter(DbNode.id.in_(pks)).delete(synchronize_session='fetch')


def flag_modified(instance, key):
"""Wrapper around `sqlalchemy.orm.attributes.flag_modified` to correctly dereference utils.ModelWrapper

Expand Down
15 changes: 0 additions & 15 deletions aiida/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Backend-agnostic utility functions"""
from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA
from aiida.manage import configuration

AIIDA_ATTRIBUTE_SEP = '.'


Expand Down Expand Up @@ -47,15 +44,3 @@ def create_scoped_session_factory(engine, **kwargs):
"""Create scoped SQLAlchemy session factory"""
from sqlalchemy.orm import scoped_session, sessionmaker
return scoped_session(sessionmaker(bind=engine, future=True, **kwargs))


def delete_nodes_and_connections(pks):
    """Delete nodes and their connections using the implementation of the profile's database backend.

    The backend specific function is imported lazily, so only the module of the configured backend is loaded.

    :param pks: the pks of the nodes to delete, passed on unchanged to the backend specific function
    :raises Exception: if the database backend of the loaded profile is neither Django nor SqlAlchemy
    """
    backend_name = configuration.PROFILE.database_backend

    if backend_name not in (BACKEND_DJANGO, BACKEND_SQLA):
        raise Exception(f'unknown backend {configuration.PROFILE.database_backend}')

    if backend_name == BACKEND_DJANGO:
        from aiida.backends.djsite.utils import delete_nodes_and_connections_django as delete_nodes_backend
    else:
        from aiida.backends.sqlalchemy.utils import delete_nodes_and_connections_sqla as delete_nodes_backend

    delete_nodes_backend(pks)
14 changes: 14 additions & 0 deletions aiida/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""Module for all common top level AiiDA entity classes and methods"""
import abc
import copy
from enum import Enum
import typing

from plumpy.base.utils import call_with_super_check, super_check
Expand All @@ -25,6 +26,19 @@
_NO_DEFAULT = tuple()


class EntityTypes(Enum):
    """Enum for referring to ORM entities in a backend-agnostic manner.

    Each member names one kind of stored entity; the value is a lower-case string identifier.
    """
    AUTHINFO = 'authinfo'
    COMMENT = 'comment'
    COMPUTER = 'computer'
    GROUP = 'group'
    LOG = 'log'
    NODE = 'node'
    USER = 'user'
    # the two members below refer to relations between entities rather than to entities themselves
    LINK = 'link'  # link between nodes
    GROUP_NODE = 'group_node'  # membership of a node in a group


class Collection(typing.Generic[EntityType]):
"""Container class that represents the collection of objects of a particular type."""

Expand Down
63 changes: 54 additions & 9 deletions aiida/orm/implementation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
###########################################################################
"""Generic backend related objects"""
import abc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ContextManager, List, Sequence, TypeVar

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import (
BackendAuthInfoCollection,
BackendCommentCollection,
Expand All @@ -27,12 +28,14 @@

__all__ = ('Backend',)

TransactionType = TypeVar('TransactionType')


class Backend(abc.ABC):
"""The public interface that defines a backend factory that creates backend specific concrete objects."""

@abc.abstractmethod
def migrate(self):
def migrate(self) -> None:
"""Migrate the database to the latest schema generation or version."""

@property
Expand Down Expand Up @@ -65,17 +68,24 @@ def logs(self) -> 'BackendLogCollection':
def nodes(self) -> 'BackendNodeCollection':
"""Return the collection of nodes"""

@property
@abc.abstractmethod
def users(self) -> 'BackendUserCollection':
"""Return the collection of users"""

@abc.abstractmethod
def query(self) -> 'BackendQueryBuilder':
"""Return an instance of a query builder implementation for this backend"""

@property
@abc.abstractmethod
def users(self) -> 'BackendUserCollection':
"""Return the collection of users"""
def get_session(self) -> 'Session':
"""Return a database session that can be used by the `QueryBuilder` to perform its query.

:return: an instance of :class:`sqlalchemy.orm.session.Session`
"""

@abc.abstractmethod
def transaction(self):
def transaction(self) -> ContextManager[Any]:
"""
Get a context manager that can be used as a transaction context for a series of backend operations.
If there is an exception within the context then the changes will be rolled back and the state will
Expand All @@ -84,9 +94,44 @@ def transaction(self):
:return: a context manager to group database operations
"""

@property
@abc.abstractmethod
def get_session(self) -> 'Session':
"""Return a database session that can be used by the `QueryBuilder` to perform its query.
def in_transaction(self) -> bool:
"""Return whether a transaction is currently active."""

:return: an instance of :class:`sqlalchemy.orm.session.Session`
@abc.abstractmethod
def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]:
    """Insert a list of entities into the database, directly into a backend transaction.

    :param entity_type: The type of the entity
    :param rows: A list of dictionaries, containing all fields of the backend model,
        except the `id` field (a.k.a primary key), which will be generated dynamically
    :param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)),
        otherwise, allow default values for missing fields.

    :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table

    :returns: The list of generated primary keys for the entities
    """

@abc.abstractmethod
def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None:
    """Update a list of entities in the database, directly with a backend transaction.

    :param entity_type: The type of the entity
    :param rows: A list of dictionaries, containing fields of the backend model to update,
        and the `id` field (a.k.a primary key)

    :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
    """

@abc.abstractmethod
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
    """Delete all nodes corresponding to pks in the input and any links to/from them.

    This method is intended to be used within a transaction context.

    :param pks_to_delete: a sequence of node pks to delete

    :raises: ``AssertionError`` if a transaction is not active
    """
104 changes: 96 additions & 8 deletions aiida/orm/implementation/django/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@
###########################################################################
"""Django implementation of `aiida.orm.implementation.backends.Backend`."""
from contextlib import contextmanager
import functools
from typing import Any, ContextManager, List, Sequence

# pylint: disable=import-error,no-name-in-module
from django.db import models, transaction
from django.apps import apps
from django.db import models
from django.db import transaction as django_transaction

from aiida.backends.djsite.db import models as dbm
from aiida.backends.djsite.manager import DjangoBackendManager
from aiida.common.exceptions import IntegrityError
from aiida.orm.entities import EntityTypes

from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users
from ..sql.backends import SqlBackend
Expand Down Expand Up @@ -69,11 +76,6 @@ def query(self):
def users(self):
return self._users

@staticmethod
def transaction():
"""Open a transaction to be used as a context manager."""
return transaction.atomic()

@staticmethod
def get_session():
"""Return a database session that can be used by the `QueryBuilder` to perform its query.
Expand All @@ -86,6 +88,92 @@ def get_session():
from aiida.backends.djsite import get_scoped_session
return get_scoped_session()

@staticmethod
def transaction() -> ContextManager[Any]:
    """Open a transaction to be used as a context manager.

    :return: the context manager returned by ``django.db.transaction.atomic()``
    """
    return django_transaction.atomic()

@property
def in_transaction(self) -> bool:
    """Return whether a transaction is currently active.

    The answer is derived from Django's autocommit flag: a transaction is reported as active when
    autocommit is switched off.
    """
    # NOTE(review): relies on Django disabling autocommit while inside ``atomic()`` - confirm against Django docs
    return not django_transaction.get_autocommit()

@staticmethod
@functools.lru_cache(maxsize=18)
def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool):
    """Return the Django model and the set of its field names for the given entity type.

    The result is memoised per ``(entity_type, with_pk)`` combination.

    :param entity_type: the entity for which to look up the model
    :param with_pk: if True, the fields returned will include the primary key
    :return: tuple of the Django model class and the set of column keys of its SQLAlchemy mapper
    """
    from sqlalchemy import inspect

    # The group-node relation is looked up by its table name among all registered models,
    # including the auto-created ones (presumably because it has no explicit class in ``dbm``).
    models_by_table = {}
    for registered in apps.get_models(include_auto_created=True):
        models_by_table[registered._meta.db_table] = registered

    entity_to_model = {
        EntityTypes.AUTHINFO: dbm.DbAuthInfo,
        EntityTypes.COMMENT: dbm.DbComment,
        EntityTypes.COMPUTER: dbm.DbComputer,
        EntityTypes.GROUP: dbm.DbGroup,
        EntityTypes.LOG: dbm.DbLog,
        EntityTypes.NODE: dbm.DbNode,
        EntityTypes.USER: dbm.DbUser,
        EntityTypes.LINK: dbm.DbLink,
        EntityTypes.GROUP_NODE: models_by_table['db_dbgroup_dbnodes'],
    }
    model = entity_to_model[entity_type]

    # aldjemy exposes the SQLAlchemy model of a Django model through the ``sa`` attribute
    mapper = inspect(model.sa).mapper
    keys = set()
    for name, column in mapper.c.items():
        if with_pk or column not in mapper.primary_key:
            keys.add(name)
    return model, keys

def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
    """Create one database row per entry of ``rows`` with a single ``bulk_create`` call.

    :param entity_type: the entity whose backing model should receive the rows
    :param rows: one dictionary of field values per row, without the primary key
    :param allow_defaults: if ``False`` every row must specify exactly all non primary key fields,
        if ``True`` a row may leave fields out
    :raises IntegrityError: if the keys of a row do not match the fields of the model as described above
    :return: the ``id`` of each created object, in the order of ``rows``
    """
    model, keys = self._get_model_from_entity(entity_type, False)

    # validate every row up front, so that nothing is created if any row is malformed
    for row in rows:
        given = set(row)
        if allow_defaults and not given.issubset(keys):
            raise IntegrityError(f'Incorrect fields given for {entity_type}: {given} not subset of {keys}')
        if not allow_defaults and given != keys:
            raise IntegrityError(f'Incorrect fields given for {entity_type}: {given} != {keys}')

    instances = []
    for row in rows:
        instances.append(model(**row))

    # nodes and comments carry an ``mtime`` field: its automatic update is suppressed so the given value is stored
    has_mtime = entity_type in (EntityTypes.NODE, EntityTypes.COMMENT)
    if not has_mtime:
        model.objects.bulk_create(instances)
    else:
        with dbm.suppress_auto_now([(model, ['mtime'])]):
            model.objects.bulk_create(instances)

    return [instance.id for instance in instances]

def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
    """Update existing rows of the model backing ``entity_type`` with a single ``bulk_update`` call.

    :param entity_type: the entity whose backing model should be updated
    :param rows: one dictionary per row to update, each containing the ``id`` of the row and the fields to set;
        all rows have to specify the same fields, in the same order
    :raises IntegrityError: if a row contains a key that is not a field of the model, or lacks the ``id`` key
    :raises NotImplementedError: if the rows do not all specify the same fields
    """
    model, keys = self._get_model_from_entity(entity_type, True)
    id_entries = {}
    fields = None
    for row in rows:
        if not keys.issuperset(row):
            raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
        # keep the ``try`` limited to the only lookup that is expected to fail
        try:
            pk = row['id']
        except KeyError as exception:
            raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") from exception
        id_entries[pk] = {key: value for key, value in row.items() if key != 'id'}
        row_fields = list(id_entries[pk])
        # Explicit comparison instead of an ``assert``: assertions are stripped when running with ``-O``,
        # in which case rows with different fields would be accepted and only partially written.
        if not fields:
            fields = row_fields
        elif fields != row_fields:
            # this is handled in sqlalchemy, but would require more complex logic here
            raise NotImplementedError(f'Cannot bulk update {entity_type} with different fields')
    if fields is None:
        # no rows were given, so there is nothing to update
        return
    objects = []
    for pk, obj in model.objects.in_bulk(list(id_entries), field_name='id').items():
        for name, value in id_entries[pk].items():
            setattr(obj, name, value)
        objects.append(obj)
    model.objects.bulk_update(objects, fields)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
    """Delete the nodes with the given pks, after deleting every link that starts or ends at one of them.

    :param pks_to_delete: the pks of the nodes to delete
    :raises AssertionError: if no transaction is active when the method is called
    """
    if not self.in_transaction:
        raise AssertionError('Cannot delete nodes and links outside a transaction')

    # links first: select those that have a node marked for deletion as their input or as their output
    touches_deleted_node = models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)
    links = dbm.DbLink.objects.filter(touches_deleted_node)
    links.delete()

    # then the nodes themselves
    nodes_to_delete = dbm.DbNode.objects.filter(pk__in=pks_to_delete)
    nodes_to_delete.delete()

# Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend`

def get_backend_entity(self, model):
Expand All @@ -100,7 +188,7 @@ def cursor(self):
:rtype: :class:`psycopg2.extensions.cursor`
"""
try:
yield self.get_connection().cursor()
yield self._get_connection().cursor()
finally:
pass

Expand All @@ -117,7 +205,7 @@ def execute_raw(self, query):
return results

@staticmethod
def get_connection():
def _get_connection():
"""
Get the Django connection

Expand Down
Loading