From a2f89576c9b6a1ffca15b950476a9a569b565d48 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 07:55:03 +0200 Subject: [PATCH 01/14] Add abstract methods --- aiida/orm/implementation/backends.py | 32 ++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index a0d43a7b43..92ec52b40d 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -90,3 +90,35 @@ def get_session(self) -> 'Session': :return: an instance of :class:`sqlalchemy.orm.session.Session` """ + + @abc.abstractmethod + def bulk_insert(self, + entity_type: 'EntityTypes', + rows: List[dict], + transaction: Any, + allow_defaults: bool = False) -> List[int]: + """Insert a list of entities into the database, directly into a backend transaction. + + :param entity_type: The type of the entity + :param data: A list of dictionaries, containing all fields of the backend model, + except the `id` field (a.k.a primary key), which will be generated dynamically + :param transaction: the returned object of the ``self.transaction`` context + :param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)), + otherwise, allow default values for missing fields. + + :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table + + :returns: The list of generated primary keys for the entities + """ + + @abc.abstractmethod + def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], transaction: Any) -> None: + """Update a list of entities in the database, directly with a backend transaction. + + :param entity_type: The type of the entity + :param data: A list of dictionaries, containing fields of the backend model to update, + and the `id` field (a.k.a primary key) + :param transaction: the returned object of the ``self.transaction`` context + + :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table + """ From 4bbf562acc5bf378d0738588bcd437831da96fd1 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 07:58:29 +0200 Subject: [PATCH 02/14] add django implementation --- aiida/orm/implementation/backends.py | 2 +- aiida/orm/implementation/django/backend.py | 81 +++++++++++++++++++++- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 92ec52b40d..941219635f 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,7 +9,7 @@ ########################################################################### """Generic backend related objects""" import abc -from typing import TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING if TYPE_CHECKING: from sqlalchemy.orm.session import Session diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index b6056bda35..3917731c20 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -9,11 +9,18 @@ ########################################################################### """Django implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager +import functools +from typing import Any, List # pylint: disable=import-error,no-name-in-module -from django.db import models, transaction +from django.apps import apps +from django.db import models +from django.db import transaction as django_transaction +from aiida.backends.djsite.db import models as dbm from aiida.backends.djsite.manager import DjangoBackendManager +from aiida.common.exceptions import IntegrityError +from aiida.orm.entities import EntityTypes from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users from ..sql.backends import SqlBackend @@ -72,7 +79,7 @@ def users(self): @staticmethod def transaction(): """Open a transaction to be used as a context manager.""" - return transaction.atomic() + return django_transaction.atomic() @staticmethod def get_session(): @@ -86,6 +93,76 @@ def get_session(): from aiida.backends.djsite import get_scoped_session return get_scoped_session() + @staticmethod + @functools.lru_cache(maxsize=18) + def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): + """Return the Django model corresponding to the given entity.""" + from sqlalchemy import inspect + + model = { + EntityTypes.AUTHINFO: dbm.DbAuthInfo, + EntityTypes.COMMENT: dbm.DbComment, + EntityTypes.COMPUTER: dbm.DbComputer, + EntityTypes.GROUP: dbm.DbGroup, + EntityTypes.LOG: dbm.DbLog, + EntityTypes.NODE: dbm.DbNode, + EntityTypes.USER: dbm.DbUser, + EntityTypes.LINK: dbm.DbLink, + EntityTypes.GROUP_NODE: + {model._meta.db_table: model for model in apps.get_models(include_auto_created=True)}['db_dbgroup_dbnodes'] + }[entity_type] + mapper = inspect(model.sa).mapper # here aldjemy provides us the SQLAlchemy model + keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} + return model, keys + + def bulk_insert(self, + entity_type: EntityTypes, + rows: List[dict], + transaction: Any, + allow_defaults: bool = False) -> List[int]: + model, keys = self._get_model_from_entity(entity_type, False) + if allow_defaults: + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + else: + for row in rows: + if set(row) != keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') + objects = [model(**row) for row in rows] + # if there is an mtime field, disable the automatic update, so as not to change it + if entity_type in (EntityTypes.NODE, EntityTypes.COMMENT): + with dbm.suppress_auto_now([(model, ['mtime'])]): + model.objects.bulk_create(objects) + else: + model.objects.bulk_create(objects) + return [obj.id for obj in objects] + + def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Any) -> None: + model, keys = self._get_model_from_entity(entity_type, True) + id_entries = {} + fields = None + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + try: + id_entries[row['id']] = {k: v for k, v in row.items() if k != 'id'} + fields = fields or list(id_entries[row['id']]) + assert fields == list(id_entries[row['id']]) + except KeyError: + raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") + except AssertionError: + # this is handled in sqlalchemy, but would require more complex logic here + raise NotImplementedError(f'Cannot bulk update {entity_type} with different fields') + if fields is None: + return + objects = [] + for pk, obj in model.objects.in_bulk(list(id_entries), field_name='id').items(): + for name, value in id_entries[pk].items(): + setattr(obj, name, value) + objects.append(obj) + model.objects.bulk_update(objects, fields) + # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): From f8771ea7939253bbfd122f20eae74dd6bdd085b0 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:01:50 +0200 Subject: [PATCH 03/14] add sqla implmentation --- .../orm/implementation/sqlalchemy/backend.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 64a7109bf9..53594ba9ee 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -8,10 +8,15 @@ # For further information please visit http://www.aiida.net # ########################################################################### """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" +# pylint: disable=missing-function-docstring from contextlib import contextmanager +import functools +from typing import Any, List from aiida.backends.sqlalchemy.manager import SqlaBackendManager from aiida.backends.sqlalchemy.models import base +from aiida.common.exceptions import IntegrityError +from aiida.orm.entities import EntityTypes from . import authinfos, comments, computers, convert, groups, logs, nodes, querybuilder, users from ..sql.backends import SqlBackend @@ -78,11 +83,76 @@ def transaction(self): if session.in_transaction(): with session.begin_nested(): yield session + session.commit() else: with session.begin(): with session.begin_nested(): yield session + @staticmethod + @functools.lru_cache(maxsize=18) + def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): + """Return the Sqlalchemy mapper and non-primary keys corresponding to the given entity.""" + from sqlalchemy import inspect + + from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo + from aiida.backends.sqlalchemy.models.comment import DbComment + from aiida.backends.sqlalchemy.models.computer import DbComputer + from aiida.backends.sqlalchemy.models.group import DbGroup, DbGroupNode + from aiida.backends.sqlalchemy.models.log import DbLog + from aiida.backends.sqlalchemy.models.node import DbLink, DbNode + from aiida.backends.sqlalchemy.models.user import DbUser + model = { + EntityTypes.AUTHINFO: DbAuthInfo, + EntityTypes.COMMENT: DbComment, + EntityTypes.COMPUTER: DbComputer, + EntityTypes.GROUP: DbGroup, + EntityTypes.LOG: DbLog, + EntityTypes.NODE: DbNode, + EntityTypes.USER: DbUser, + EntityTypes.LINK: DbLink, + EntityTypes.GROUP_NODE: DbGroupNode, + }[entity_type] + mapper = inspect(model).mapper + keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} + return mapper, keys + + def bulk_insert(self, + entity_type: EntityTypes, + rows: List[dict], + transaction: Any, + allow_defaults: bool = False) -> List[int]: + mapper, keys = self._get_mapper_from_entity(entity_type, False) + if not rows: + return [] + if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG): + for row in rows: + row['_metadata'] = row.pop('metadata') + if allow_defaults: + for row in rows: + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + else: + for row in rows: + if set(row) != keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {keys}') + # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see + # https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases + # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html + transaction.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) + return [row['id'] for row in rows] + + def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Any) -> None: + mapper, keys = self._get_mapper_from_entity(entity_type, True) + if not rows: + return None + for row in rows: + if 'id' not in row: + raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") + if not keys.issuperset(row): + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') + transaction.bulk_update_mappings(mapper, rows) + @staticmethod def get_session(): """Return a database session that can be used by the `QueryBuilder` to perform its query. From a2226c5bc08dfa0e095a964fe1f15054dbdbfa97 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:03:12 +0200 Subject: [PATCH 04/14] add EntityTypes --- aiida/orm/entities.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index f5019e2a99..19477a8671 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -10,6 +10,7 @@ """Module for all common top level AiiDA entity classes and methods""" import abc import copy +from enum import Enum import typing from plumpy.base.utils import call_with_super_check, super_check @@ -25,6 +26,19 @@ _NO_DEFAULT = tuple() +class EntityTypes(Enum): + """Enum for referring to ORM entities in a backend-agnostic manner.""" + AUTHINFO = 'authinfo' + COMMENT = 'comment' + COMPUTER = 'computer' + GROUP = 'group' + LOG = 'log' + NODE = 'node' + USER = 'user' + LINK = 'link' + GROUP_NODE = 'group_node' + + class Collection(typing.Generic[EntityType]): """Container class that represents the collection of objects of a particular type.""" From dcbbf6cd38cc51b3242b5e0fd2c24907c0685cf8 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:04:32 +0200 Subject: [PATCH 05/14] add DbGroupNode --- aiida/backends/sqlalchemy/models/group.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/aiida/backends/sqlalchemy/models/group.py b/aiida/backends/sqlalchemy/models/group.py index f943e7a519..1fdb898987 100644 --- a/aiida/backends/sqlalchemy/models/group.py +++ b/aiida/backends/sqlalchemy/models/group.py @@ -31,6 +31,11 @@ ) +class DbGroupNode(Base): + """Class to store group to nodes relation using SQLA backend.""" + __table__ = table_groups_nodes + + class DbGroup(Base): """Class to store groups using SQLA backend.""" From 7c8b600aa0758b9bb2f00f57c4e2180bfc46d3be Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:06:04 +0200 Subject: [PATCH 06/14] add tests --- tests/orm/implementation/test_backend.py | 63 +++++++++++++++++++++--- 1 file changed, 55 insertions(+), 8 deletions(-) diff --git a/tests/orm/implementation/test_backend.py b/tests/orm/implementation/test_backend.py index b7b89e7fe5..4b6429babf 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -8,14 +8,22 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Unit tests for the ORM Backend class.""" +import pytest + from aiida import orm -from aiida.backends.testbase import AiidaTestCase from aiida.common import exceptions +from aiida.orm.entities import EntityTypes -class TestBackend(AiidaTestCase): +@pytest.mark.usefixtures('clear_database_before_test') +class TestBackend: """Test backend.""" + @pytest.fixture(autouse=True) + def init_test(self, backend): + """Set up the backend.""" + self.backend = backend # pylint: disable=attribute-defined-outside-init + def test_transaction_nesting(self): """Test that transaction nesting works.""" user = orm.User('initial@email.com').store() @@ -24,12 +32,12 @@ def test_transaction_nesting(self): try: with self.backend.transaction(): user.email = 'failure@email.com' - self.assertEqual(user.email, 'failure@email.com') + assert user.email == 'failure@email.com' raise RuntimeError except RuntimeError: pass - self.assertEqual(user.email, 'pre-failure@email.com') - self.assertEqual(user.email, 'pre-failure@email.com') + assert user.email == 'pre-failure@email.com' + assert user.email == 'pre-failure@email.com' def test_transaction(self): """Test that transaction nesting works.""" @@ -43,8 +51,8 @@ def test_transaction(self): raise RuntimeError except RuntimeError: pass - self.assertEqual(user1.email, 'user1@email.com') - self.assertEqual(user2.email, 'user2@email.com') + assert user1.email == 'user1@email.com' + assert user2.email == 'user2@email.com' def test_store_in_transaction(self): """Test that storing inside a transaction is correctly dealt with.""" @@ -62,5 +70,44 @@ def test_store_in_transaction(self): except RuntimeError: pass - with self.assertRaises(exceptions.NotExistent): + with pytest.raises(exceptions.NotExistent): orm.User.objects.get(email='user_store_fail@email.com') + + def test_bulk_insert(self): + """Test that bulk insert works.""" + rows = [{'email': 'user1@email.com'}, {'email': 'user2@email.com'}] + with self.backend.transaction() as transaction: + # should fail if all fields are not given and allow_defaults=False + with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): + self.backend.bulk_insert(EntityTypes.USER, rows, transaction) + pks = self.backend.bulk_insert(EntityTypes.USER, rows, transaction, allow_defaults=True) + assert len(pks) == len(rows) + for pk, row in zip(pks, rows): + assert isinstance(pk, int) + user = orm.User.objects.get(id=pk) + assert user.email == row['email'] + + def test_bulk_update(self): + """Test that bulk update works.""" + user1 = orm.User('user1@email.com').store() + user2 = orm.User('user2@email.com').store() + user3 = orm.User('user3@email.com').store() + with self.backend.transaction() as transaction: + # should raise if the 'id' field is not present + with pytest.raises(exceptions.IntegrityError, match="'id' field not given"): + self.backend.bulk_update(EntityTypes.USER, [{'email': 'other'}], transaction) + # should raise if a non-existent field is present + with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): + self.backend.bulk_update(EntityTypes.USER, [{'id': user1.pk, 'x': 'other'}], transaction) + self.backend.bulk_update( + EntityTypes.USER, [{ + 'id': user1.pk, + 'email': 'other1' + }, { + 'id': user2.pk, + 'email': 'other2' + }], transaction + ) + assert user1.email == 'other1' + assert user2.email == 'other2' + assert user3.email == 'user3@email.com' From 0817ef6bde99eb01bed6454b97bd5d75316882d7 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:13:29 +0200 Subject: [PATCH 07/14] Minor fixes --- aiida/orm/implementation/backends.py | 2 +- aiida/orm/implementation/django/backend.py | 5 ++++- aiida/orm/implementation/sqlalchemy/backend.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 941219635f..bf8396569f 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,7 +9,7 @@ ########################################################################### """Generic backend related objects""" import abc -from typing import Any, List, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, List if TYPE_CHECKING: from sqlalchemy.orm.session import Session diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index 3917731c20..ba681b4299 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -96,7 +96,10 @@ def get_session(): @staticmethod @functools.lru_cache(maxsize=18) def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): - """Return the Django model corresponding to the given entity.""" + """Return the Django model and fields corresponding to the given entity. + + :param with_pk: if True, the fields returned will include the primary key + """ from sqlalchemy import inspect model = { diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 53594ba9ee..72d46831de 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -92,7 +92,10 @@ def transaction(self): @staticmethod @functools.lru_cache(maxsize=18) def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): - """Return the Sqlalchemy mapper and non-primary keys corresponding to the given entity.""" + """Return the Sqlalchemy mapper and fields corresponding to the given entity. + + :param with_pk: if True, the fields returned will include the primary key + """ from sqlalchemy import inspect from aiida.backends.sqlalchemy.models.authinfo import DbAuthInfo From 95971c2f7bff8428e4afe601c7096a0fe2ef7dc4 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:16:09 +0200 Subject: [PATCH 08/14] Update backends.py --- aiida/orm/implementation/backends.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index bf8396569f..43d5fdf42d 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -14,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy.orm.session import Session + from aiida.orm.entities import EntityTypes from aiida.orm.implementation import ( BackendAuthInfoCollection, BackendCommentCollection, From 76485abec4998c56cea36020611b64c750601d82 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:23:15 +0200 Subject: [PATCH 09/14] fixes --- aiida/backends/sqlalchemy/models/node.py | 5 +++-- docs/source/nitpick-exceptions | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/aiida/backends/sqlalchemy/models/node.py b/aiida/backends/sqlalchemy/models/node.py index ba59190b3e..10571cfcfe 100644 --- a/aiida/backends/sqlalchemy/models/node.py +++ b/aiida/backends/sqlalchemy/models/node.py @@ -156,8 +156,9 @@ class DbLink(Base): Integer, ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), index=True ) - input = relationship('DbNode', primaryjoin='DbLink.input_id == DbNode.id') - output = relationship('DbNode', primaryjoin='DbLink.output_id == DbNode.id') + # https://docs.sqlalchemy.org/en/14/errors.html#relationship-x-will-copy-column-q-to-column-p-which-conflicts-with-relationship-s-y + input = relationship('DbNode', primaryjoin='DbLink.input_id == DbNode.id', overlaps='inputs_q,outputs_q') + output = relationship('DbNode', primaryjoin='DbLink.output_id == DbNode.id', overlaps='inputs_q,outputs_q') label = Column(String(255), index=True, nullable=False) type = Column(String(255), index=True) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 99af999ad5..8bc6d9ef4f 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -35,6 +35,7 @@ py:class aiida.engine.runners.ResultAndNode py:class aiida.engine.runners.ResultAndPk py:class aiida.engine.processes.workchains.workchain.WorkChainSpec py:class aiida.manage.manager.Manager +py:class aiida.orm.entities.EntityTypes py:class aiida.orm.nodes.node.WarnWhenNotEntered py:class aiida.orm.implementation.querybuilder.QueryDictType py:class aiida.orm.querybuilder.Classifier From 7d776854dab5e891c824453501d7613d18aac418 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Fri, 8 Oct 2021 08:25:47 +0200 Subject: [PATCH 10/14] Update nitpick-exceptions --- docs/source/nitpick-exceptions | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index 8bc6d9ef4f..bdc826251d 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -21,6 +21,7 @@ py:class builtins.dict # typing py:class asyncio.events.AbstractEventLoop py:class EntityType +py:class EntityTypes py:class function py:class IO py:class traceback From 01cb3079066fbf8331093ae27eb9cc1df15ef36d Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 10 Oct 2021 02:32:08 +0200 Subject: [PATCH 11/14] Move delete_nodes_and_connections to backend --- .pre-commit-config.yaml | 4 ++ aiida/backends/djsite/utils.py | 28 -------- aiida/backends/sqlalchemy/utils.py | 28 -------- aiida/backends/utils.py | 15 ---- aiida/orm/implementation/backends.py | 44 ++++++++---- aiida/orm/implementation/django/backend.py | 26 ++++--- .../orm/implementation/sqlalchemy/backend.py | 70 +++++++++++-------- aiida/tools/graph/deletions.py | 15 ++-- tests/orm/node/test_node.py | 15 ++-- 9 files changed, 110 insertions(+), 135 deletions(-) delete mode 100644 aiida/backends/djsite/utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fff8155541..810c81fd80 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -75,6 +75,10 @@ repos: aiida/manage/manager.py| aiida/manage/database/delete/nodes.py| aiida/orm/querybuilder.py| + aiida/orm/implementation/backends.py| + aiida/orm/implementation/sql/backends.py| + aiida/orm/implementation/django/backend.py| + aiida/orm/implementation/sqlalchemy/backend.py| aiida/orm/implementation/querybuilder.py| aiida/orm/implementation/sqlalchemy/querybuilder/.*py| aiida/orm/nodes/data/jsonable.py| diff --git a/aiida/backends/djsite/utils.py b/aiida/backends/djsite/utils.py deleted file mode 100644 index 74bfb56269..0000000000 --- a/aiida/backends/djsite/utils.py +++ /dev/null @@ -1,28 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Utility functions specific to the Django backend.""" - - -def delete_nodes_and_connections_django(pks_to_delete): # pylint: disable=invalid-name - """Delete all nodes corresponding to pks in the input. - - :param pks_to_delete: A list, tuple or set of pks that should be deleted. - """ - # pylint: disable=no-member,import-error,no-name-in-module - from django.db import transaction - from django.db.models import Q - - from aiida.backends.djsite.db import models - with transaction.atomic(): - # This is fixed in pylint-django>=2, but this supports only py3 - # Delete all links pointing to or from a given node - models.DbLink.objects.filter(Q(input__in=pks_to_delete) | Q(output__in=pks_to_delete)).delete() - # now delete nodes - models.DbNode.objects.filter(pk__in=pks_to_delete).delete() diff --git a/aiida/backends/sqlalchemy/utils.py b/aiida/backends/sqlalchemy/utils.py index a8d76265ef..780df99bf3 100644 --- a/aiida/backends/sqlalchemy/utils.py +++ b/aiida/backends/sqlalchemy/utils.py @@ -11,34 +11,6 @@ """Utility functions specific to the SqlAlchemy backend.""" -def delete_nodes_and_connections_sqla(pks_to_delete): # pylint: disable=invalid-name - """ - Delete all nodes corresponding to pks in the input. - :param pks_to_delete: A list, tuple or set of pks that should be deleted. - """ - # pylint: disable=no-value-for-parameter - from aiida.backends.sqlalchemy.models.group import table_groups_nodes - from aiida.backends.sqlalchemy.models.node import DbLink, DbNode - from aiida.manage.manager import get_manager - - backend = get_manager().get_backend() - - with backend.transaction() as session: - # I am first making a statement to delete the membership of these nodes to groups. - # Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile - # a stmt to be executed by the session. It works, but it's not nice that two different ways are used! - # Can this be changed? - stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete))) - session.execute(stmt) - # First delete links, then the Nodes, since we are not cascading deletions. - # Here I delete the links coming out of the nodes marked for deletion. - session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Here I delete the links pointing to the nodes marked for deletion. - session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Now I am deleting the nodes - session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - - def flag_modified(instance, key): """Wrapper around `sqlalchemy.orm.attributes.flag_modified` to correctly dereference utils.ModelWrapper diff --git a/aiida/backends/utils.py b/aiida/backends/utils.py index 234412e1f1..d73be4674d 100644 --- a/aiida/backends/utils.py +++ b/aiida/backends/utils.py @@ -8,9 +8,6 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Backend-agnostic utility functions""" -from aiida.backends import BACKEND_DJANGO, BACKEND_SQLA -from aiida.manage import configuration - AIIDA_ATTRIBUTE_SEP = '.' @@ -47,15 +44,3 @@ def create_scoped_session_factory(engine, **kwargs): """Create scoped SQLAlchemy session factory""" from sqlalchemy.orm import scoped_session, sessionmaker return scoped_session(sessionmaker(bind=engine, future=True, **kwargs)) - - -def delete_nodes_and_connections(pks): - """Backend-agnostic function to delete Nodes and connections""" - if configuration.PROFILE.database_backend == BACKEND_DJANGO: - from aiida.backends.djsite.utils import delete_nodes_and_connections_django as delete_nodes_backend - elif configuration.PROFILE.database_backend == BACKEND_SQLA: - from aiida.backends.sqlalchemy.utils import delete_nodes_and_connections_sqla as delete_nodes_backend - else: - raise Exception(f'unknown backend {configuration.PROFILE.database_backend}') - - delete_nodes_backend(pks) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 43d5fdf42d..83ffca68ed 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,7 +9,7 @@ ########################################################################### """Generic backend related objects""" import abc -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, ContextManager, Generic, List, Sequence, TypeVar if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -28,12 +28,14 @@ __all__ = ('Backend',) +TransactionType = TypeVar('TransactionType') -class Backend(abc.ABC): + +class Backend(abc.ABC, Generic[TransactionType]): """The public interface that defines a backend factory that creates backend specific concrete objects.""" @abc.abstractmethod - def migrate(self): + def migrate(self) -> None: """Migrate the database to the latest schema generation or version.""" @property @@ -76,7 +78,14 @@ def users(self) -> 'BackendUserCollection': """Return the collection of users""" @abc.abstractmethod - def transaction(self): + def get_session(self) -> 'Session': + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ + + @abc.abstractmethod + def transaction(self) -> ContextManager[TransactionType]: """ Get a context manager that can be used as a transaction context for a series of backend operations. If there is an exception within the context then the changes will be rolled back and the state will @@ -86,20 +95,27 @@ def transaction(self): """ @abc.abstractmethod - def get_session(self) -> 'Session': - """Return a database session that can be used by the `QueryBuilder` to perform its query. + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: TransactionType): + """Delete all nodes corresponding to pks in the input. - :return: an instance of :class:`sqlalchemy.orm.session.Session` + This method is intended to be used within a transaction context. + + :param pks_to_delete: a sequence of node pks to delete + :param transact: the returned instance from entering transaction context """ @abc.abstractmethod - def bulk_insert(self, - entity_type: 'EntityTypes', - rows: List[dict], - transaction: Any, - allow_defaults: bool = False) -> List[int]: + def bulk_insert( + self, + entity_type: 'EntityTypes', + rows: List[dict], + transaction: TransactionType, + allow_defaults: bool = False + ) -> List[int]: """Insert a list of entities into the database, directly into a backend transaction. + This method is intended to be used within a transaction context. + :param entity_type: The type of the entity :param data: A list of dictionaries, containing all fields of the backend model, except the `id` field (a.k.a primary key), which will be generated dynamically @@ -113,9 +129,11 @@ def bulk_insert(self, """ @abc.abstractmethod - def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], transaction: Any) -> None: + def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], transaction: TransactionType) -> None: """Update a list of entities in the database, directly with a backend transaction. + This method is intended to be used within a transaction context. + :param entity_type: The type of the entity :param data: A list of dictionaries, containing fields of the backend model to update, and the `id` field (a.k.a primary key) diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index ba681b4299..de45f423cd 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -10,7 +10,7 @@ """Django implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager import functools -from typing import Any, List +from typing import ContextManager, List, Sequence # pylint: disable=import-error,no-name-in-module from django.apps import apps @@ -76,11 +76,6 @@ def query(self): def users(self): return self._users - @staticmethod - def transaction(): - """Open a transaction to be used as a context manager.""" - return django_transaction.atomic() - @staticmethod def get_session(): """Return a database session that can be used by the `QueryBuilder` to perform its query. @@ -93,6 +88,11 @@ def get_session(): from aiida.backends.djsite import get_scoped_session return get_scoped_session() + @staticmethod + def transaction() -> ContextManager[None]: + """Open a transaction to be used as a context manager.""" + return django_transaction.atomic() + @staticmethod @functools.lru_cache(maxsize=18) def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): @@ -121,7 +121,7 @@ def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], - transaction: Any, + transaction: None, allow_defaults: bool = False) -> List[int]: model, keys = self._get_model_from_entity(entity_type, False) if allow_defaults: @@ -141,7 +141,7 @@ def bulk_insert(self, model.objects.bulk_create(objects) return [obj.id for obj in objects] - def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Any) -> None: + def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: None) -> None: model, keys = self._get_model_from_entity(entity_type, True) id_entries = {} fields = None @@ -166,6 +166,12 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: A objects.append(obj) model.objects.bulk_update(objects, fields) + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: None) -> None: + # Delete all links pointing to or from a given node + dbm.DbLink.objects.filter(models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)).delete() + # now delete nodes + dbm.DbNode.objects.filter(pk__in=pks_to_delete).delete() + # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): @@ -180,7 +186,7 @@ def cursor(self): :rtype: :class:`psycopg2.extensions.cursor` """ try: - yield self.get_connection().cursor() + yield self._get_connection().cursor() finally: pass @@ -197,7 +203,7 @@ def execute_raw(self, query): return results @staticmethod - def get_connection(): + def _get_connection(): """ Get the Django connection diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 72d46831de..f55166c252 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -11,7 +11,9 @@ # pylint: disable=missing-function-docstring from contextlib import contextmanager import functools -from typing import Any, List +from typing import Iterator, List, Sequence + +from sqlalchemy.orm import Session from aiida.backends.sqlalchemy.manager import SqlaBackendManager from aiida.backends.sqlalchemy.models import base @@ -72,8 +74,17 @@ def query(self): def users(self): return self._users + @staticmethod + def get_session() -> Session: + """Return a database session that can be used by the `QueryBuilder` to perform its query. + + :return: an instance of :class:`sqlalchemy.orm.session.Session` + """ + from aiida.backends.sqlalchemy import get_scoped_session + return get_scoped_session() + @contextmanager - def transaction(self): + def transaction(self) -> Iterator[Session]: """Open a transaction to be used as a context manager. If there is an exception within the context then the changes will be rolled back and the state will be as before @@ -120,11 +131,13 @@ def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} return mapper, keys - def bulk_insert(self, - entity_type: EntityTypes, - rows: List[dict], - transaction: Any, - allow_defaults: bool = False) -> List[int]: + def bulk_insert( + self, + entity_type: EntityTypes, + rows: List[dict], + transaction: Session, + allow_defaults: bool = False + ) -> List[int]: mapper, keys = self._get_mapper_from_entity(entity_type, False) if not rows: return [] @@ -145,7 +158,7 @@ def bulk_insert(self, transaction.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) return [row['id'] for row in rows] - def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Any) -> None: + def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Session) -> None: # pylint: disable=no-self-use mapper, keys = self._get_mapper_from_entity(entity_type, True) if not rows: return None @@ -156,41 +169,42 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: A raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') transaction.bulk_update_mappings(mapper, rows) - @staticmethod - def get_session(): - """Return a database session that can be used by the `QueryBuilder` to perform its query. + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transact: Session) -> None: # pylint: disable=no-self-use + # pylint: disable=no-value-for-parameter + from aiida.backends.sqlalchemy.models.group import table_groups_nodes + from aiida.backends.sqlalchemy.models.node import DbLink, DbNode - :return: an instance of :class:`sqlalchemy.orm.session.Session` - """ - from aiida.backends.sqlalchemy import get_scoped_session - return get_scoped_session() + session = transact + + # I am first making a statement to delete the membership of these nodes to groups. + # Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile + # a stmt to be executed by the session. It works, but it's not nice that two different ways are used! + # Can this be changed? + stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete))) + session.execute(stmt) + # First delete links, then the Nodes, since we are not cascading deletions. + # Here I delete the links coming out of the nodes marked for deletion. + session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Here I delete the links pointing to the nodes marked for deletion. + session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Now I am deleting the nodes + session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` def get_backend_entity(self, model): - """Return a `BackendEntity` instance from a `DbModel` instance.""" return convert.get_backend_entity(model, self) @contextmanager def cursor(self): - """Return a psycopg cursor to be used in a context manager. - - :return: a psycopg cursor - :rtype: :class:`psycopg2.extensions.cursor` - """ from aiida.backends import sqlalchemy as sa try: connection = sa.ENGINE.raw_connection() yield connection.cursor() finally: - self.get_connection().close() + self._get_connection().close() def execute_raw(self, query): - """Execute a raw SQL statement and return the result. - - :param query: a string containing a raw SQL statement - :return: the result of the query - """ from sqlalchemy import text from sqlalchemy.exc import ResourceClosedError # pylint: disable=import-error,no-name-in-module @@ -205,7 +219,7 @@ def execute_raw(self, query): return results @staticmethod - def get_connection(): + def _get_connection(): """Get the SQLA database connection :return: the SQLA database connection diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index 57f785e9c2..0bb0b02f62 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -11,8 +11,8 @@ import logging from typing import Callable, Iterable, Set, Tuple, Union -from aiida.backends.utils import delete_nodes_and_connections from aiida.common.log import AIIDA_LOGGER +from aiida.manage.manager import get_manager from aiida.orm import Group, Node, QueryBuilder from aiida.tools.graph.graph_traversers import get_nodes_delete @@ -21,9 +21,12 @@ DELETE_LOGGER = AIIDA_LOGGER.getChild('delete') -def delete_nodes(pks: Iterable[int], - dry_run: Union[bool, Callable[[Set[int]], bool]] = True, - **traversal_rules: bool) -> Tuple[Set[int], bool]: +def delete_nodes( + pks: Iterable[int], + dry_run: Union[bool, Callable[[Set[int]], bool]] = True, + backend=None, + **traversal_rules: bool +) -> Tuple[Set[int], bool]: """Delete nodes given a list of "starting" PKs. This command will delete not only the specified nodes, but also the ones that are @@ -60,6 +63,7 @@ def delete_nodes(pks: Iterable[int], :returns: (pks to delete, whether they were deleted) """ + backend = backend or get_manager().get_backend() # pylint: disable=too-many-arguments,too-many-branches,too-many-locals,too-many-statements @@ -99,7 +103,8 @@ def _missing_callback(_pks: Iterable[int]): return (pks_set_to_delete, True) DELETE_LOGGER.report('Starting node deletion...') - delete_nodes_and_connections(pks_set_to_delete) + with backend.transaction() as transaction: + backend.delete_nodes_and_connections(pks_set_to_delete, transaction) DELETE_LOGGER.report('Deletion of nodes completed.') return (pks_set_to_delete, True) diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index 1533fa7c24..3d5763038d 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -16,7 +16,8 @@ import pytest -from aiida.common import LinkType, exceptions +from aiida.common import LinkType, exceptions, timezone +from aiida.manage.manager import get_manager from aiida.orm import CalculationNode, Computer, Data, Log, Node, User, WorkflowNode, load_node from aiida.orm.utils.links import LinkTriple @@ -800,10 +801,9 @@ class TestNodeDelete: # pylint: disable=no-member,no-self-use @pytest.mark.usefixtures('clear_database_before_test') - def test_delete_through_utility_method(self): - """Test deletion works correctly through the `aiida.backends.utils.delete_nodes_and_connections`.""" - from aiida.backends.utils import delete_nodes_and_connections - from aiida.common import timezone + def test_delete_through_backend(self): + """Test deletion works correctly through the backend.""" + backend = get_manager().get_backend() data_one = Data().store() data_two = Data().store() @@ -820,7 +820,8 @@ def test_delete_through_utility_method(self): assert len(Log.objects.get_logs_for(data_two)) == 1 assert Log.objects.get_logs_for(data_two)[0].pk == log_two.pk - delete_nodes_and_connections([data_two.pk]) + with backend.transaction() as transaction: + backend.delete_nodes_and_connections([data_two.pk], transaction) assert len(Log.objects.get_logs_for(data_one)) == 1 assert Log.objects.get_logs_for(data_one)[0].pk == log_one.pk @@ -829,8 +830,6 @@ def test_delete_through_utility_method(self): @pytest.mark.usefixtures('clear_database_before_test') def test_delete_collection_logs(self): """Test deletion works correctly through objects collection.""" - from aiida.common import timezone - data_one = Data().store() data_two = Data().store() From 0f805846b9db01ba3ae9f307672ce753be846655 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sun, 10 Oct 2021 04:39:00 +0200 Subject: [PATCH 12/14] fixes --- aiida/orm/implementation/backends.py | 8 ++++---- aiida/orm/implementation/django/backend.py | 2 +- aiida/orm/implementation/sql/backends.py | 9 ++++----- aiida/orm/implementation/sqlalchemy/backend.py | 14 ++++++-------- docs/source/nitpick-exceptions | 3 +++ 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 83ffca68ed..90434c32f6 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -68,15 +68,15 @@ def logs(self) -> 'BackendLogCollection': def nodes(self) -> 'BackendNodeCollection': """Return the collection of nodes""" - @abc.abstractmethod - def query(self) -> 'BackendQueryBuilder': - """Return an instance of a query builder implementation for this backend""" - @property @abc.abstractmethod def users(self) -> 'BackendUserCollection': """Return the collection of users""" + @abc.abstractmethod + def query(self) -> 'BackendQueryBuilder': + """Return an instance of a query builder implementation for this backend""" + @abc.abstractmethod def get_session(self) -> 'Session': """Return a database session that can be used by the `QueryBuilder` to perform its query. diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index de45f423cd..4a08a00ccc 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -28,7 +28,7 @@ __all__ = ('DjangoBackend',) -class DjangoBackend(SqlBackend[models.Model]): +class DjangoBackend(SqlBackend[None, models.Model]): """Django implementation of `aiida.orm.implementation.backends.Backend`.""" def __init__(self): diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index 2bb21f22af..e729ed3f6b 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -11,15 +11,15 @@ import abc import typing -from .. import backends +from .. import backends, entities __all__ = ('SqlBackend',) -# The template type for the base ORM model type +# The template type for the base sqlalchemy/django ORM model type ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name -class SqlBackend(typing.Generic[ModelType], backends.Backend): +class SqlBackend(typing.Generic[backends.TransactionType, ModelType], backends.Backend[backends.TransactionType]): """ A class for SQL based backends. Assumptions are that: * there is an ORM @@ -30,13 +30,12 @@ class SqlBackend(typing.Generic[ModelType], backends.Backend): """ @abc.abstractmethod - def get_backend_entity(self, model): + def get_backend_entity(self, model: ModelType) -> entities.BackendEntity: """ Return the backend entity that corresponds to the given Model instance :param model: the ORM model instance to promote to a backend instance :return: the backend entity corresponding to the given model - :rtype: :class:`aiida.orm.implementation.entities.BackendEntity` """ @abc.abstractmethod diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index f55166c252..a3fca0e70f 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -26,7 +26,7 @@ __all__ = ('SqlaBackend',) -class SqlaBackend(SqlBackend[base.Base]): +class SqlaBackend(SqlBackend[Session, base.Base]): """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" def __init__(self): @@ -169,26 +169,24 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: S raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') transaction.bulk_update_mappings(mapper, rows) - def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transact: Session) -> None: # pylint: disable=no-self-use + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: Session) -> None: # pylint: disable=no-self-use # pylint: disable=no-value-for-parameter from aiida.backends.sqlalchemy.models.group import table_groups_nodes from aiida.backends.sqlalchemy.models.node import DbLink, DbNode - session = transact - # I am first making a statement to delete the membership of these nodes to groups. # Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile # a stmt to be executed by the session. It works, but it's not nice that two different ways are used! # Can this be changed? stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete))) - session.execute(stmt) + transaction.execute(stmt) # First delete links, then the Nodes, since we are not cascading deletions. # Here I delete the links coming out of the nodes marked for deletion. - session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + transaction.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') # Here I delete the links pointing to the nodes marked for deletion. - session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + transaction.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') # Now I am deleting the nodes - session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + transaction.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` diff --git a/docs/source/nitpick-exceptions b/docs/source/nitpick-exceptions index bdc826251d..82ee5e3994 100644 --- a/docs/source/nitpick-exceptions +++ b/docs/source/nitpick-exceptions @@ -19,9 +19,12 @@ py:class builtins.str py:class builtins.dict # typing +py:class AbstractContextManager py:class asyncio.events.AbstractEventLoop py:class EntityType py:class EntityTypes +py:class ModelType +py:class TransactionType py:class function py:class IO py:class traceback From 264c61bd3d822ad18fbab99ece565d7f51226063 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sat, 16 Oct 2021 02:27:39 +0200 Subject: [PATCH 13/14] Improve testing --- aiida/orm/implementation/backends.py | 44 +++---- aiida/orm/implementation/django/backend.py | 22 ++-- aiida/orm/implementation/sql/backends.py | 2 +- .../orm/implementation/sqlalchemy/backend.py | 50 ++++---- aiida/tools/graph/deletions.py | 4 +- tests/orm/implementation/test_backend.py | 114 +++++++++++++----- 6 files changed, 145 insertions(+), 91 deletions(-) diff --git a/aiida/orm/implementation/backends.py b/aiida/orm/implementation/backends.py index 90434c32f6..b1273661d9 100644 --- a/aiida/orm/implementation/backends.py +++ b/aiida/orm/implementation/backends.py @@ -9,7 +9,7 @@ ########################################################################### """Generic backend related objects""" import abc -from typing import TYPE_CHECKING, ContextManager, Generic, List, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, ContextManager, List, Sequence, TypeVar if TYPE_CHECKING: from sqlalchemy.orm.session import Session @@ -31,7 +31,7 @@ TransactionType = TypeVar('TransactionType') -class Backend(abc.ABC, Generic[TransactionType]): +class Backend(abc.ABC): """The public interface that defines a backend factory that creates backend specific concrete objects.""" @abc.abstractmethod @@ -85,7 +85,7 @@ def get_session(self) -> 'Session': """ @abc.abstractmethod - def transaction(self) -> ContextManager[TransactionType]: + def transaction(self) -> ContextManager[Any]: """ Get a context manager that can be used as a transaction context for a series of backend operations. If there is an exception within the context then the changes will be rolled back and the state will @@ -94,32 +94,18 @@ def transaction(self) -> ContextManager[TransactionType]: :return: a context manager to group database operations """ + @property @abc.abstractmethod - def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: TransactionType): - """Delete all nodes corresponding to pks in the input. - - This method is intended to be used within a transaction context. - - :param pks_to_delete: a sequence of node pks to delete - :param transact: the returned instance from entering transaction context - """ + def in_transaction(self) -> bool: + """Return whether a transaction is currently active.""" @abc.abstractmethod - def bulk_insert( - self, - entity_type: 'EntityTypes', - rows: List[dict], - transaction: TransactionType, - allow_defaults: bool = False - ) -> List[int]: + def bulk_insert(self, entity_type: 'EntityTypes', rows: List[dict], allow_defaults: bool = False) -> List[int]: """Insert a list of entities into the database, directly into a backend transaction. - This method is intended to be used within a transaction context. - :param entity_type: The type of the entity :param data: A list of dictionaries, containing all fields of the backend model, except the `id` field (a.k.a primary key), which will be generated dynamically - :param transaction: the returned object of the ``self.transaction`` context :param allow_defaults: If ``False``, assert that each row contains all fields (except primary key(s)), otherwise, allow default values for missing fields. @@ -129,15 +115,23 @@ def bulk_insert( """ @abc.abstractmethod - def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict], transaction: TransactionType) -> None: + def bulk_update(self, entity_type: 'EntityTypes', rows: List[dict]) -> None: """Update a list of entities in the database, directly with a backend transaction. - This method is intended to be used within a transaction context. - :param entity_type: The type of the entity :param data: A list of dictionaries, containing fields of the backend model to update, and the `id` field (a.k.a primary key) - :param transaction: the returned object of the ``self.transaction`` context :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table """ + + @abc.abstractmethod + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): + """Delete all nodes corresponding to pks in the input and any links to/from them. + + This method is intended to be used within a transaction context. + + :param pks_to_delete: a sequence of node pks to delete + + :raises: ``AssertionError`` if a transaction is not active + """ diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index 4a08a00ccc..2d02ae4f5c 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -10,7 +10,7 @@ """Django implementation of `aiida.orm.implementation.backends.Backend`.""" from contextlib import contextmanager import functools -from typing import ContextManager, List, Sequence +from typing import Any, ContextManager, List, Sequence # pylint: disable=import-error,no-name-in-module from django.apps import apps @@ -28,7 +28,7 @@ __all__ = ('DjangoBackend',) -class DjangoBackend(SqlBackend[None, models.Model]): +class DjangoBackend(SqlBackend[models.Model]): """Django implementation of `aiida.orm.implementation.backends.Backend`.""" def __init__(self): @@ -89,10 +89,14 @@ def get_session(): return get_scoped_session() @staticmethod - def transaction() -> ContextManager[None]: + def transaction() -> ContextManager[Any]: """Open a transaction to be used as a context manager.""" return django_transaction.atomic() + @property + def in_transaction(self) -> bool: + return not django_transaction.get_autocommit() + @staticmethod @functools.lru_cache(maxsize=18) def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): @@ -118,11 +122,7 @@ def _get_model_from_entity(entity_type: EntityTypes, with_pk: bool): keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} return model, keys - def bulk_insert(self, - entity_type: EntityTypes, - rows: List[dict], - transaction: None, - allow_defaults: bool = False) -> List[int]: + def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: model, keys = self._get_model_from_entity(entity_type, False) if allow_defaults: for row in rows: @@ -141,7 +141,7 @@ def bulk_insert(self, model.objects.bulk_create(objects) return [obj.id for obj in objects] - def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: None) -> None: + def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: model, keys = self._get_model_from_entity(entity_type, True) id_entries = {} fields = None @@ -166,7 +166,9 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: N objects.append(obj) model.objects.bulk_update(objects, fields) - def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: None) -> None: + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: + if not self.in_transaction: + raise AssertionError('Cannot delete nodes outside a transaction') # Delete all links pointing to or from a given node dbm.DbLink.objects.filter(models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)).delete() # now delete nodes diff --git a/aiida/orm/implementation/sql/backends.py b/aiida/orm/implementation/sql/backends.py index e729ed3f6b..1423ce5d22 100644 --- a/aiida/orm/implementation/sql/backends.py +++ b/aiida/orm/implementation/sql/backends.py @@ -19,7 +19,7 @@ ModelType = typing.TypeVar('ModelType') # pylint: disable=invalid-name -class SqlBackend(typing.Generic[backends.TransactionType, ModelType], backends.Backend[backends.TransactionType]): +class SqlBackend(typing.Generic[ModelType], backends.Backend): """ A class for SQL based backends. Assumptions are that: * there is an ORM diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index a3fca0e70f..f00af12e02 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -26,7 +26,7 @@ __all__ = ('SqlaBackend',) -class SqlaBackend(SqlBackend[Session, base.Base]): +class SqlaBackend(SqlBackend[base.Base]): """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" def __init__(self): @@ -100,6 +100,10 @@ def transaction(self) -> Iterator[Session]: with session.begin_nested(): yield session + @property + def in_transaction(self) -> bool: + return self.get_session().in_nested_transaction() + @staticmethod @functools.lru_cache(maxsize=18) def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): @@ -131,13 +135,7 @@ def _get_mapper_from_entity(entity_type: EntityTypes, with_pk: bool): keys = {key for key, col in mapper.c.items() if with_pk or col not in mapper.primary_key} return mapper, keys - def bulk_insert( - self, - entity_type: EntityTypes, - rows: List[dict], - transaction: Session, - allow_defaults: bool = False - ) -> List[int]: + def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: mapper, keys = self._get_mapper_from_entity(entity_type, False) if not rows: return [] @@ -155,10 +153,10 @@ def bulk_insert( # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see # https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html - transaction.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) + self.get_session().bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) return [row['id'] for row in rows] - def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: Session) -> None: # pylint: disable=no-self-use + def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: # pylint: disable=no-self-use mapper, keys = self._get_mapper_from_entity(entity_type, True) if not rows: return None @@ -167,26 +165,26 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict], transaction: S raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") if not keys.issuperset(row): raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') - transaction.bulk_update_mappings(mapper, rows) + self.get_session().bulk_update_mappings(mapper, rows) - def delete_nodes_and_connections(self, pks_to_delete: Sequence[int], transaction: Session) -> None: # pylint: disable=no-self-use + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # pylint: disable=no-self-use # pylint: disable=no-value-for-parameter - from aiida.backends.sqlalchemy.models.group import table_groups_nodes + from aiida.backends.sqlalchemy.models.group import DbGroupNode from aiida.backends.sqlalchemy.models.node import DbLink, DbNode - # I am first making a statement to delete the membership of these nodes to groups. - # Since table_groups_nodes is a sqlalchemy.schema.Table, I am using expression language to compile - # a stmt to be executed by the session. It works, but it's not nice that two different ways are used! - # Can this be changed? - stmt = table_groups_nodes.delete().where(table_groups_nodes.c.dbnode_id.in_(list(pks_to_delete))) - transaction.execute(stmt) - # First delete links, then the Nodes, since we are not cascading deletions. - # Here I delete the links coming out of the nodes marked for deletion. - transaction.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Here I delete the links pointing to the nodes marked for deletion. - transaction.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') - # Now I am deleting the nodes - transaction.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + if not self.in_transaction: + raise AssertionError('Cannot delete nodes outside a transaction') + + session = self.get_session() + # Delete the membership of these nodes to groups. + session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete)) + ).delete(synchronize_session='fetch') + # Delete the links coming out of the nodes marked for deletion. + session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the links pointing to the nodes marked for deletion. + session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') + # Delete the actual nodes + session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch') # Below are abstract methods inherited from `aiida.orm.implementation.sql.backends.SqlBackend` diff --git a/aiida/tools/graph/deletions.py b/aiida/tools/graph/deletions.py index 0bb0b02f62..d14d9c7dd5 100644 --- a/aiida/tools/graph/deletions.py +++ b/aiida/tools/graph/deletions.py @@ -103,8 +103,8 @@ def _missing_callback(_pks: Iterable[int]): return (pks_set_to_delete, True) DELETE_LOGGER.report('Starting node deletion...') - with backend.transaction() as transaction: - backend.delete_nodes_and_connections(pks_set_to_delete, transaction) + with backend.transaction(): + backend.delete_nodes_and_connections(pks_set_to_delete) DELETE_LOGGER.report('Deletion of nodes completed.') return (pks_set_to_delete, True) diff --git a/tests/orm/implementation/test_backend.py b/tests/orm/implementation/test_backend.py index 4b6429babf..82d3b6f72b 100644 --- a/tests/orm/implementation/test_backend.py +++ b/tests/orm/implementation/test_backend.py @@ -12,6 +12,7 @@ from aiida import orm from aiida.common import exceptions +from aiida.common.links import LinkType from aiida.orm.entities import EntityTypes @@ -46,6 +47,7 @@ def test_transaction(self): try: with self.backend.transaction(): + assert self.backend.in_transaction user1.email = 'broken1@email.com' user2.email = 'broken2@email.com' raise RuntimeError @@ -76,38 +78,96 @@ def test_store_in_transaction(self): def test_bulk_insert(self): """Test that bulk insert works.""" rows = [{'email': 'user1@email.com'}, {'email': 'user2@email.com'}] - with self.backend.transaction() as transaction: - # should fail if all fields are not given and allow_defaults=False - with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): - self.backend.bulk_insert(EntityTypes.USER, rows, transaction) - pks = self.backend.bulk_insert(EntityTypes.USER, rows, transaction, allow_defaults=True) + # should fail if all fields are not given and allow_defaults=False + with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): + self.backend.bulk_insert(EntityTypes.USER, rows) + pks = self.backend.bulk_insert(EntityTypes.USER, rows, allow_defaults=True) assert len(pks) == len(rows) for pk, row in zip(pks, rows): assert isinstance(pk, int) user = orm.User.objects.get(id=pk) assert user.email == row['email'] + def test_bulk_insert_in_transaction(self): + """Test that bulk insert in a cancelled transaction is not committed.""" + rows = [{'email': 'user1@email.com'}, {'email': 'user2@email.com'}] + try: + with self.backend.transaction(): + self.backend.bulk_insert(EntityTypes.USER, rows, allow_defaults=True) + raise RuntimeError + except RuntimeError: + pass + for row in rows: + with pytest.raises(exceptions.NotExistent): + orm.User.objects.get(email=row['email']) + def test_bulk_update(self): """Test that bulk update works.""" - user1 = orm.User('user1@email.com').store() - user2 = orm.User('user2@email.com').store() - user3 = orm.User('user3@email.com').store() - with self.backend.transaction() as transaction: - # should raise if the 'id' field is not present - with pytest.raises(exceptions.IntegrityError, match="'id' field not given"): - self.backend.bulk_update(EntityTypes.USER, [{'email': 'other'}], transaction) - # should raise if a non-existent field is present - with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): - self.backend.bulk_update(EntityTypes.USER, [{'id': user1.pk, 'x': 'other'}], transaction) - self.backend.bulk_update( - EntityTypes.USER, [{ - 'id': user1.pk, - 'email': 'other1' - }, { - 'id': user2.pk, - 'email': 'other2' - }], transaction - ) - assert user1.email == 'other1' - assert user2.email == 'other2' - assert user3.email == 'user3@email.com' + users = [orm.User(f'user{i}@email.com').store() for i in range(3)] + # should raise if the 'id' field is not present + with pytest.raises(exceptions.IntegrityError, match="'id' field not given"): + self.backend.bulk_update(EntityTypes.USER, [{'email': 'other'}]) + # should raise if a non-existent field is present + with pytest.raises(exceptions.IntegrityError, match='Incorrect fields'): + self.backend.bulk_update(EntityTypes.USER, [{'id': users[0].pk, 'x': 'other'}]) + self.backend.bulk_update( + EntityTypes.USER, [{ + 'id': users[0].pk, + 'email': 'other0' + }, { + 'id': users[1].pk, + 'email': 'other1' + }] + ) + assert users[0].email == 'other0' + assert users[1].email == 'other1' + assert users[2].email == 'user2@email.com' + + def test_bulk_update_in_transaction(self): + """Test that bulk update in a cancelled transaction is not committed.""" + users = [orm.User(f'user{i}@email.com').store() for i in range(3)] + try: + with self.backend.transaction(): + self.backend.bulk_update( + EntityTypes.USER, [{ + 'id': users[0].pk, + 'email': 'other0' + }, { + 'id': users[1].pk, + 'email': 'other1' + }] + ) + raise RuntimeError + except RuntimeError: + pass + for i, user in enumerate(users): + assert user.email == f'user{i}@email.com' + + def test_delete_nodes_and_connections(self): + """Delete all nodes and connections.""" + # create node, link and add to group + node = orm.Data() + calc_node = orm.CalcFunctionNode().store() + node.add_incoming(calc_node, link_type=LinkType.CREATE, link_label='link') + node.store() + node_pk = node.pk + group = orm.Group('name').store() + group.add_nodes([node]) + + # checks before deletion + orm.Node.objects.get(id=node_pk) + assert len(calc_node.get_outgoing().all()) == 1 + assert len(group.nodes) == 1 + + # cannot call outside a transaction + with pytest.raises(AssertionError): + self.backend.delete_nodes_and_connections([node_pk]) + + with self.backend.transaction(): + self.backend.delete_nodes_and_connections([node_pk]) + + # checks after deletion + with pytest.raises(exceptions.NotExistent): + orm.Node.objects.get(id=node_pk) + assert len(calc_node.get_outgoing().all()) == 0 + assert len(group.nodes) == 0 From 346a6b8f6482e3fbf77494877d8fe399a7126400 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Sat, 16 Oct 2021 03:39:59 +0200 Subject: [PATCH 14/14] fix test --- aiida/orm/implementation/django/backend.py | 2 +- aiida/orm/implementation/sqlalchemy/backend.py | 12 ++++++++---- tests/orm/node/test_node.py | 4 ++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/aiida/orm/implementation/django/backend.py b/aiida/orm/implementation/django/backend.py index 2d02ae4f5c..915000c170 100644 --- a/aiida/orm/implementation/django/backend.py +++ b/aiida/orm/implementation/django/backend.py @@ -168,7 +168,7 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: if not self.in_transaction: - raise AssertionError('Cannot delete nodes outside a transaction') + raise AssertionError('Cannot delete nodes and links outside a transaction') # Delete all links pointing to or from a given node dbm.DbLink.objects.filter(models.Q(input__in=pks_to_delete) | models.Q(output__in=pks_to_delete)).delete() # now delete nodes diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index f00af12e02..01a34c125b 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -9,7 +9,7 @@ ########################################################################### """SqlAlchemy implementation of `aiida.orm.implementation.backends.Backend`.""" # pylint: disable=missing-function-docstring -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext import functools from typing import Iterator, List, Sequence @@ -153,7 +153,9 @@ def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults # note for postgresql+psycopg2 we could also use `save_all` + `flush` with minimal performance degradation, see # https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases # by contrast, in sqlite, bulk_insert is faster: https://docs.sqlalchemy.org/en/14/faq/performance.html - self.get_session().bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): # type: ignore[attr-defined] + session.bulk_insert_mappings(mapper, rows, render_nulls=True, return_defaults=True) return [row['id'] for row in rows] def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: # pylint: disable=no-self-use @@ -165,7 +167,9 @@ def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: # py raise IntegrityError(f"'id' field not given for {entity_type}: {set(row)}") if not keys.issuperset(row): raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} not subset of {keys}') - self.get_session().bulk_update_mappings(mapper, rows) + session = self.get_session() + with (nullcontext() if self.in_transaction else self.transaction()): # type: ignore[attr-defined] + session.bulk_update_mappings(mapper, rows) def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # pylint: disable=no-self-use # pylint: disable=no-value-for-parameter @@ -173,7 +177,7 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None: # from aiida.backends.sqlalchemy.models.node import DbLink, DbNode if not self.in_transaction: - raise AssertionError('Cannot delete nodes outside a transaction') + raise AssertionError('Cannot delete nodes and links outside a transaction') session = self.get_session() # Delete the membership of these nodes to groups. diff --git a/tests/orm/node/test_node.py b/tests/orm/node/test_node.py index 3d5763038d..20d2d12c62 100644 --- a/tests/orm/node/test_node.py +++ b/tests/orm/node/test_node.py @@ -820,8 +820,8 @@ def test_delete_through_backend(self): assert len(Log.objects.get_logs_for(data_two)) == 1 assert Log.objects.get_logs_for(data_two)[0].pk == log_two.pk - with backend.transaction() as transaction: - backend.delete_nodes_and_connections([data_two.pk], transaction) + with backend.transaction(): + backend.delete_nodes_and_connections([data_two.pk]) assert len(Log.objects.get_logs_for(data_one)) == 1 assert Log.objects.get_logs_for(data_one)[0].pk == log_one.pk