Skip to content

Commit

Permalink
Add a transaction context manager to backend (#2387)
Browse files Browse the repository at this point in the history
This allows to group operations that will be rolled back if the
context is exited with an exception. This is laying the groundwork for
implementing `Node` as part of the new backend system as links, caches,
etc will have to be done in a transaction.
  • Loading branch information
sphuber authored Jan 15, 2019
1 parent 69f3bd5 commit a2456ab
Show file tree
Hide file tree
Showing 7 changed files with 109 additions and 4 deletions.
1 change: 1 addition & 0 deletions aiida/backends/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
'orm.data.upf': ['aiida.backends.tests.orm.data.upf'],
'orm.entities': ['aiida.backends.tests.orm.entities'],
'orm.groups': ['aiida.backends.tests.orm.groups'],
'orm.implementation.backend': ['aiida.backends.tests.orm.implementation.test_backend'],
'orm.logs': ['aiida.backends.tests.orm.logs'],
'orm.mixins': ['aiida.backends.tests.orm.mixins'],
'orm.node': ['aiida.backends.tests.orm.node.test_node'],
Expand Down
Empty file.
69 changes: 69 additions & 0 deletions aiida/backends/tests/orm/implementation/test_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Unit tests for the ORM Backend class."""
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from aiida.backends.testbase import AiidaTestCase
from aiida import orm
from aiida.common import exceptions


class TestBackend(AiidaTestCase):
"""Test backend."""

def test_transaction_nesting(self):
"""Test that transaction nesting works."""
user = orm.User('initial@email.com').store()
with self.backend.transaction():
user.email = 'pre-failure@email.com'
try:
with self.backend.transaction():
user.email = 'failure@email.com'
self.assertEqual(user.email, 'failure@email.com')
raise RuntimeError
except RuntimeError:
pass
self.assertEqual(user.email, 'pre-failure@email.com')
self.assertEqual(user.email, 'pre-failure@email.com')

def test_transaction(self):
"""Test that transaction nesting works."""
user1 = orm.User('user1@email.com').store()
user2 = orm.User('user2@email.com').store()

try:
with self.backend.transaction():
user1.email = 'broken1@email.com'
user2.email = 'broken2@email.com'
raise RuntimeError
except RuntimeError:
pass
self.assertEqual(user1.email, 'user1@email.com')
self.assertEqual(user2.email, 'user2@email.com')

def test_store_in_transaction(self):
"""Test that storing inside a transaction is correctly dealt with."""
user1 = orm.User('user_store@email.com')
with self.backend.transaction():
user1.store()
# the following shouldn't raise
orm.User.objects.get(email='user_store@email.com')

user2 = orm.User('user_store_fail@email.com')
try:
with self.backend.transaction():
user2.store()
raise RuntimeError
except RuntimeError:
pass

with self.assertRaises(exceptions.NotExistent):
orm.User.objects.get(email='user_store_fail@email.com')
10 changes: 10 additions & 0 deletions aiida/orm/implementation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ def users(self):
:rtype: :class:`aiida.orm.implementation.BackendUserCollection`
"""

@abc.abstractmethod
def transaction(self):
"""
Get a context manager that can be used as a transaction context for a series of backend operations.
If there is an exception within the context then the changes will be rolled back and the state will
be as before entering. Transactions can be nested.
:return: a context manager to group database operations
"""


@six.add_metaclass(abc.ABCMeta)
class BackendEntity(object):
Expand Down
5 changes: 4 additions & 1 deletion aiida/orm/implementation/django/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from contextlib import contextmanager

from django.db import models
from django.db import models, transaction

from aiida.backends.djsite.queries import DjangoQueryManager
from aiida.orm.implementation.sql import SqlBackend
Expand Down Expand Up @@ -103,3 +103,6 @@ def cursor(self):
yield self.get_connection().cursor()
finally:
pass

def transaction(self):
return transaction.atomic()
17 changes: 17 additions & 0 deletions aiida/orm/implementation/sqlalchemy/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from contextlib import contextmanager

from aiida.orm.implementation.sql import SqlBackend
from aiida.backends.sqlalchemy import get_scoped_session
from aiida.backends.sqlalchemy.models import base
from aiida.backends.sqlalchemy.queries import SqlaQueryManager
from . import authinfo
Expand Down Expand Up @@ -107,3 +108,19 @@ def cursor(self):
yield connection.cursor()
finally:
self.get_connection().close()

@contextmanager
def transaction(self):
session = get_scoped_session()
nested = session.transaction.nested
try:
session.begin_nested()
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
if not nested:
# Make sure to commit the outermost session
session.commit()
11 changes: 8 additions & 3 deletions aiida/orm/implementation/sqlalchemy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import sqlalchemy.exc

from aiida.common import exceptions

from aiida.backends.sqlalchemy import get_scoped_session

__all__ = ['django_filter', 'get_attr']

Expand Down Expand Up @@ -53,7 +53,7 @@ def __getattr__(self, item):
if item == '_model':
raise AttributeError()

if self.is_saved() and self._is_model_field(item):
if self.is_saved() and not self._in_transaction() and self._is_model_field(item):
self._ensure_model_uptodate(fields=(item,))

return getattr(self._model, item)
Expand All @@ -70,7 +70,8 @@ def is_saved(self):
def save(self):
"""Store the model (possibly updating values if changed)."""
try:
self._model.save(commit=True)
commit = not self._in_transaction()
self._model.save(commit=commit)
except sqlalchemy.exc.IntegrityError as e:
self._model.session.rollback()
raise exceptions.IntegrityError(str(e))
Expand All @@ -90,6 +91,10 @@ def _ensure_model_uptodate(self, fields=None):
if self.is_saved():
self._model.session.expire(self._model, attribute_names=fields)

@staticmethod
def _in_transaction():
return get_scoped_session().transaction.nested


@contextlib.contextmanager
def disable_expire_on_commit(session):
Expand Down

0 comments on commit a2456ab

Please sign in to comment.