Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding nodes to group optimisation using SQLA code to generate SQL INSERT #2518

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions aiida/backends/sqlalchemy/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
from __future__ import absolute_import

from aiida.backends.testbase import AiidaTestCase
from aiida.orm import Node, Data
from aiida.common import exceptions
from aiida.orm import Data, Node


class TestComputer(AiidaTestCase):
Expand Down Expand Up @@ -130,6 +129,42 @@ def test_query(self):
self.assertSetEqual(set(_.pk for _ in res), set(_.pk for _ in [g1, g2]))


class TestGroupNoOrmSQLA(AiidaTestCase):
"""
These tests check that the group node addition works ok when the skip_orm=True flag is used
"""

def test_group_general(self):
"""
General tests to verify that the group addition with the skip_orm=True flag
work properly
"""
backend = self.backend

node_01 = Data().store().backend_entity
node_02 = Data().store().backend_entity
node_03 = Data().store().backend_entity
node_04 = Data().store().backend_entity
node_05 = Data().store().backend_entity
nodes = [node_01, node_02, node_03, node_04, node_05]

simple_user = backend.users.create('simple1@ton.com')
group = backend.groups.create(label='test_adding_nodes', user=simple_user).store()
# Single node in a list
group.add_nodes([node_01], skip_orm=True)
# List of nodes
group.add_nodes([node_02, node_03], skip_orm=True)
# Tuple of nodes
group.add_nodes((node_04, node_05), skip_orm=True)

# Check
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

# Try to add a node that is already present: there should be no problem
group.add_nodes([node_01], skip_orm=True)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))


class TestDbExtrasSqla(AiidaTestCase):
"""
Characterized functions
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/django/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def next(self):

return NodesIterator(self._dbmodel.dbnodes.all(), self._backend)

def add_nodes(self, nodes):
def add_nodes(self, nodes, **kwargs):
from .nodes import DjangoNode

super(DjangoGroup, self).add_nodes(nodes)
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def nodes(self):
the number of nodes in the group using len().
"""

def add_nodes(self, nodes):
def add_nodes(self, nodes, **kwargs): # pylint: disable=unused-argument
"""Add a set of nodes to the group.
:note: all the nodes *and* the group itself have to be stored.
Expand Down
68 changes: 45 additions & 23 deletions aiida/orm/implementation/sqlalchemy/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
# For further information please visit http://www.aiida.net #
###########################################################################
"""SQLA groups"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import collections
import logging

import six

from aiida.backends import sqlalchemy as sa
Expand All @@ -21,7 +23,6 @@
from aiida.common.exceptions import UniquenessError
from aiida.common.lang import type_check
from aiida.orm.implementation.groups import BackendGroup, BackendGroupCollection

from . import entities
from . import users
from . import utils
Expand Down Expand Up @@ -161,40 +162,61 @@ def next(self):

return Iterator(self._dbmodel.dbnodes, self._backend)

def add_nodes(self, nodes):
def add_nodes(self, nodes, **kwargs):
"""Add a node or a set of nodes to the group.
:note: all the nodes *and* the group itself have to be stored.
:param nodes: a list of `BackendNode` instance to be added to this group
:param kwargs:
skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used
to create a direct SQL INSERT statement to the group-node relationship
table (to improve speed).
"""
from sqlalchemy.exc import IntegrityError # pylint: disable=import-error, no-name-in-module
from sqlalchemy.dialects.postgresql import insert # pylint: disable=import-error, no-name-in-module
from aiida.orm.implementation.sqlalchemy.nodes import SqlaNode
from aiida.backends.sqlalchemy import get_scoped_session
from aiida.backends.sqlalchemy.models.base import Base

super(SqlaGroup, self).add_nodes(nodes)
skip_orm = kwargs.get('skip_orm', False)

with utils.disable_expire_on_commit(get_scoped_session()) as session:
def check_node(given_node):
""" Check if given node is of correct type and stored """
if not isinstance(given_node, SqlaNode):
raise TypeError('invalid type {}, has to be {}'.format(type(given_node), SqlaNode))

if not given_node.is_stored:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self._dbmodel.dbnodes

for node in nodes:
if not isinstance(node, SqlaNode):
raise TypeError('invalid type {}, has to be {}'.format(type(node), SqlaNode))

if not node.is_stored:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

# Use pattern as suggested here:
# http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
try:
with session.begin_nested():
dbnodes.append(node.dbmodel)
session.flush()
except IntegrityError:
# Duplicate entry, skip
pass
with utils.disable_expire_on_commit(get_scoped_session()) as session:
if not skip_orm:
# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self._dbmodel.dbnodes

for node in nodes:
check_node(node)

# Use pattern as suggested here:
# http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
try:
with session.begin_nested():
dbnodes.append(node.dbmodel)
session.flush()
except IntegrityError:
# Duplicate entry, skip
pass
else:
ins_dict = list()
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})

my_table = Base.metadata.tables['db_dbgroup_dbnodes']
ins = insert(my_table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
session.commit()
Expand Down
11 changes: 6 additions & 5 deletions aiida/orm/importexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,8 +1739,9 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
for node_uuid in groupnodes]
qb_nodes = QueryBuilder().append(
Node, filters={'id': {'in': nodes_ids_to_add}})
nodes_to_add = [n[0] for n in qb_nodes.all()]
group.add_nodes(nodes_to_add)
# Adding nodes to group avoiding the SQLA ORM to increase speed
nodes_to_add = [n[0].backend_entity for n in qb_nodes.all()]
group.backend_entity.add_nodes(nodes_to_add, skip_orm=True)

######################################################
# Put everything in a specific group
Expand Down Expand Up @@ -1782,9 +1783,9 @@ def import_data_sqla(in_path, user_group=None, ignore_unknown_nodes=False,
counter += 1

# Add all the nodes to the new group
# TODO: decide if we want to return the group label
nodes = [entry[0] for entry in QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}).all()]
group.add_nodes(nodes)
# Adding nodes to group avoiding the SQLA ORM to increase speed
nodes = [entry[0].backend_entity for entry in QueryBuilder().append(Node, filters={'id': {'in': pks_for_group}}).all()]
group.backend_entity.add_nodes(nodes, skip_orm=True)

if not silent:
print("IMPORTED NODES GROUPED IN IMPORT GROUP NAMED '{}'".format(group.label))
Expand Down