Skip to content

Commit

Permalink
✨ NEW: Add Backend bulk methods (#5171)
Browse files Browse the repository at this point in the history
Adds `Backend.bulk_insert` and `Backend.bulk_update methods`, and moves the `delete_nodes_and_connections` function to a method on the Backend.
This removes the need for using backend specific code outside the backend.
  • Loading branch information
chrisjsewell authored Oct 16, 2021
1 parent 0b7db7b commit 8fb1457
Show file tree
Hide file tree
Showing 14 changed files with 420 additions and 133 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ repos:
aiida/manage/manager.py|
aiida/manage/database/delete/nodes.py|
aiida/orm/querybuilder.py|
aiida/orm/implementation/backends.py|
aiida/orm/implementation/sql/backends.py|
aiida/orm/implementation/django/backend.py|
aiida/orm/implementation/sqlalchemy/backend.py|
aiida/orm/implementation/querybuilder.py|
aiida/orm/implementation/sqlalchemy/querybuilder/.*py|
aiida/orm/nodes/data/jsonable.py|
Expand Down
28 changes: 0 additions & 28 deletions aiida/backends/djsite/utils.py

This file was deleted.

5 changes: 5 additions & 0 deletions aiida/backends/sqlalchemy/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
)


class DbGroupNode(Base):
"""Class to store group to nodes relation using SQLA backend."""
__table__ = table_groups_nodes


class DbGroup(Base):
"""Class to store groups using SQLA backend."""

Expand Down
28 changes: 0 additions & 28 deletions aiida/backends/sqlalchemy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,6 @@
"""Utility functions specific to the SqlAlchemy backend."""


def delete_nodes_and_connections_sqla(pks_to_delete): # pylint: disable=invalid-name
"""
Delete all nodes corresponding to pks in the input.
:param pks_to_delete: A list, tuple or set of pks that should be deleted.
"""
# pylint: disable=no-value-for-parameter
from aiida.backends.sqlalchemy.models.group import table_groups_nodes
from aiida.backends.sqlalchemy.models.node import DbLink, DbNode
from aiida.manage.manager import get_manager

backend = get_manager().get_backend()

with backend.transaction() as session:
# I am first making a statement to delete the membership of these nodes to groups.
# Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile
# a stmt to be executed by the session. It works, but it's not nice that two different ways are used!
# Can this be changed?
stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete)))
session.execute(stmt)
# First delete links, then the Nodes, since we are not cascading deletions.
# Here I delete the links coming out of the nodes marked for deletion.
session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Here I delete the links pointing to the nodes marked for deletion.
session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
# Now I am deleting the nodes
session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')


def flag_modified(instance, key):
"""Wrapper around `sqlalchemy.orm.attributes.flag_modified` to correctly dereference utils.ModelWrapper
Expand Down
15 changes: 0 additions & 15 deletions aiida/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""Backend-agnostic utility functions"""
from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA
from aiida.manage import configuration

AIIDA_ATTRIBUTE_SEP = '.'


Expand Down Expand Up @@ -47,15 +44,3 @@ def create_scoped_session_factory(engine, **kwargs):
"""Create scoped SQLAlchemy session factory"""
from sqlalchemy.orm import scoped_session, sessionmaker
return scoped_session(sessionmaker(bind=engine, future=True, **kwargs))


def delete_nodes_and_connections(pks):
"""Backend-agnostic function to delete Nodes and connections"""
if configuration.PROFILE.database_backend == BACKEND_DJANGO:
from aiida.backends.djsite.utils import delete_nodes_and_connections_django as delete_nodes_backend
elif configuration.PROFILE.database_backend == BACKEND_SQLA:
from aiida.backends.sqlalchemy.utils import delete_nodes_and_connections_sqla as delete_nodes_backend
else:
raise Exception(f'unknown backend {configuration.PROFILE.database_backend}')

delete_nodes_backend(pks)
14 changes: 14 additions & 0 deletions aiida/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""Module for all common top level AiiDA entity classes and methods"""
import abc
import copy
from enum import Enum
import typing

from plumpy.base.utils import call_with_super_check, super_check
Expand All @@ -25,6 +26,19 @@
_NO_DEFAULT = tuple()


class EntityTypes(Enum):
"""Enum for referring to ORM entities in a backend-agnostic manner."""
AUTHINFO = 'authinfo'
COMMENT = 'comment'
COMPUTER = 'computer'
GROUP = 'group'
LOG = 'log'
NODE = 'node'
USER = 'user'
LINK = 'link'
GROUP_NODE = 'group_node'


class Collection(typing.Generic[EntityType]):
"""Container class that represents the collection of objects of a particular type."""

Expand Down
63 changes: 54 additions & 9 deletions aiida/orm/implementation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
###########################################################################
"""Generic backend related objects"""
import abc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, ContextManager, List, Sequence, TypeVar

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from aiida.orm.entities import EntityTypes
from aiida.orm.implementation import (
BackendAuthInfoCollection,
BackendCommentCollection,
Expand All @@ -27,12 +28,14 @@

__all__ = ('Backend',)

TransactionType = TypeVar('TransactionType')


class Backend(abc.ABC):
"""The public interface that defines a backend factory that creates backend specific concrete objects."""

@abc.abstractmethod
def migrate(self):
def migrate(self) -> None:
"""Migrate the database to the latest schema generation or version."""

@property
Expand Down Expand Up @@ -65,17 +68,24 @@ def logs(self) -> 'BackendLogCollection':
def nodes(self) -> 'BackendNodeCollection':
"""Return the collection of nodes"""

@property
@abc.abstractmethod
def users(self) -> 'BackendUserCollection':
"""Return the collection of users"""

@abc.abstractmethod
def query(self) -> 'BackendQueryBuilder':
"""Return an instance of a query builder implementation for this backend"""

@property
@abc.abstractmethod
def users(self) -> 'BackendUserCollection':
"""Return the collection of users"""
def get_session(self) -> 'Session':
"""Return a database session that can be used by the `QueryBuilder` to perform its query.
:return: an instance of :class:`sqlalchemy.orm.session.Session`
"""

@abc.abstractmethod
def transaction(self):
def transaction(self) -> ContextManager[Any]:
"""
Get a context manager that can be used as a transaction context for a series of backend operations.
If there is an exception within the context then the changes will be rolled back and the state will
Expand All @@ -84,9 +94,44 @@ def transaction(self):
:return: a context manager to group database operations
"""

@property
@abc.abstractmethod
def get_session(self) -> 'Session':
"""Return a database session that can be used by the `QueryBuilder` to perform its query.
def in_transaction(self) -> bool:
"""Return whether a transaction is currently active."""

:return: an instance of :class:`sqlalchemy.orm.session.Session`
@abc.abstractmethod
def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]:
"""Insert a list of entities into the database, directly into a backend transaction.
:param entity_type: The type of the entity
:param data: A list of dictionaries, containing all fields of the backend model,
except the `id` field (a.k.a primary key), which will be generated dynamically
:param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)),
otherwise, allow default values for missing fields.
:raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
:returns: The list of generated primary keys for the entities
"""

@abc.abstractmethod
def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None:
"""Update a list of entities in the database, directly with a backend transaction.
:param entity_type: The type of the entity
:param data: A list of dictionaries, containing fields of the backend model to update,
and the `id` field (a.k.a primary key)
:raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table
"""

@abc.abstractmethod
def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]):
"""Delete all nodes corresponding to pks in the input and any links to/from them.
This method is intended to be used within a transaction context.
:param pks_to_delete: a sequence of node pks to delete
:raises: ``AssertionError`` if a transaction is not active
"""
104 changes: 96 additions & 8 deletions aiida/orm/implementation/django/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@
###########################################################################
"""Django implementation of `aiida.orm.implementation.backends.Backend`."""
from contextlib import contextmanager
import functools
from typing import Any, ContextManager, List, Sequence

# pylint: disable=import-error,no-name-in-module
from django.db import models, transaction
from django.apps import apps
from django.db import models
from django.db import transaction as django_transaction

from aiida.backends.djsite.db import models as dbm
from aiida.backends.djsite.manager import DjangoBackendManager
from aiida.common.exceptions import IntegrityError
from aiida.orm.entities import EntityTypes

from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users
from ..sql.backends import SqlBackend
Expand Down Expand Up @@ -69,11 +76,6 @@ def query(self):
def users(self):
return self._users

@staticmethod
def transaction():
"""Open a transaction to be used as a context manager."""
return transaction.atomic()

@staticmethod
def get_session():
"""Return a database session that can be used by the `QueryBuilder` to perform its query.
Expand All @@ -86,6 +88,92 @@ def get_session():
from aiida.backends.djsite import get_scoped_session
return get_scoped_session()

@staticmethod
def transaction() -> ContextManager[Any]:
"""Open a transaction to be used as a context manager."""
return django_transaction.atomic()

@property
def in_transaction(self) -> bool:
return not django_transaction.get_autocommit()

@staticmethod
@functools.lru_cache(maxsize=18)
def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool):
"""Return the Django model and fields corresponding to the given entity.
:param with_pk: if True, the fields returned will include the primary key
"""
from sqlalchemy import inspect

model = {
EntityTypes.AUTHINFO: dbm.DbAuthInfo,
EntityTypes.COMMENT: dbm.DbComment,
EntityTypes.COMPUTER: dbm.DbComputer,
EntityTypes.GROUP: dbm.DbGroup,
EntityTypes.LOG: dbm.DbLog,
EntityTypes.NODE: dbm.DbNode,
EntityTypes.USER: dbm.DbUser,
EntityTypes.LINK: dbm.DbLink,
EntityTypes.GROUP_NODE:
{model._meta.db_table: model for model in apps.get_models(include_auto_created=True)}['db_dbgroup_dbnodes']
}[entity_type]
mapper = inspect(model.sa).mapper # here aldjemy provides us the SQLAlchemy model
keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key}
return model, keys

def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]:
model, keys = self._get_model_from_entity(entity_type, False)
if allow_defaults:
for row in rows:
if not keys.issuperset(row):
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
else:
for row in rows:
if set(row) != keys:
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}')
objects = [model(**row) for row in rows]
# if there is an mtime field, disable the automatic update, so as not to change it
if entity_type in (EntityTypes.NODE, EntityTypes.COMMENT):
with dbm.suppress_auto_now([(model, ['mtime'])]):
model.objects.bulk_create(objects)
else:
model.objects.bulk_create(objects)
return [obj.id for obj in objects]

def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None:
model, keys = self._get_model_from_entity(entity_type, True)
id_entries = {}
fields = None
for row in rows:
if not keys.issuperset(row):
raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}')
try:
id_entries[row['id']] = {k: v for k, v in row.items() if k != 'id'}
fields = fields or list(id_entries[row['id']])
assert fields == list(id_entries[row['id']])
except KeyError:
raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}")
except AssertionError:
# this is handled in sqlalchemy, but would require more complex logic here
raise NotImplementedError(f'Cannot bulk update {entity_type} with different fields')
if fields is None:
return
objects = []
for pk, obj in model.objects.in_bulk(list(id_entries), field_name='id').items():
for name, value in id_entries[pk].items():
setattr(obj, name, value)
objects.append(obj)
model.objects.bulk_update(objects, fields)

def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:
if not self.in_transaction:
raise AssertionError('Cannot delete nodes and links outside a transaction')
# Delete all links pointing to or from a given node
dbm.DbLink.objects.filter(models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)).delete()
# now delete nodes
dbm.DbNode.objects.filter(pk__in=pks_to_delete).delete()

# Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend`

def get_backend_entity(self, model):
Expand All @@ -100,7 +188,7 @@ def cursor(self):
:rtype: :class:`psycopg2.extensions.cursor`
"""
try:
yield self.get_connection().cursor()
yield self._get_connection().cursor()
finally:
pass

Expand All @@ -117,7 +205,7 @@ def execute_raw(self, query):
return results

@staticmethod
def get_connection():
def _get_connection():
"""
Get the Django connection
Expand Down
Loading

0 comments on commit 8fb1457

Please sign in to comment.