Solution for the fast node addition to group for v0.12 #2471

Merged · 2 commits · Feb 22, 2019
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
###########################################################################
# Copyright (c), The AiiDA team. All rights reserved. #
# This file is part of the AiiDA code. #
# #
# The code is hosted on GitHub at https://github.com/aiidateam/aiida_core #
# For further information on the license, see the LICENSE.txt file #
# For further information please visit http://www.aiida.net #
###########################################################################
"""Unique constraints for the db_dbgroup_dbnodes table

Revision ID: 7a6587e16f4c
Revises: 35d4ee9a1b0e
Create Date: 2019-02-11 19:25:11.744902

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = '7a6587e16f4c'
down_revision = '35d4ee9a1b0e'
branch_labels = None
depends_on = None


def upgrade():
"""
Add unique constraints to the db_dbgroup_dbnodes table.
"""
op.create_unique_constraint('uix_dbnode_id_dbgroup_id', 'db_dbgroup_dbnodes', ['dbnode_id', 'dbgroup_id'])


def downgrade():
"""
Remove unique constraints from the db_dbgroup_dbnodes table.
"""
op.drop_constraint('uix_dbnode_id_dbgroup_id', 'db_dbgroup_dbnodes')
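
As a quick sanity check (not part of this PR), one can confirm on a PostgreSQL backend that the migration actually created the constraint; a minimal sketch, assuming an initialised SQLA profile (the pg_constraint query is PostgreSQL-specific):

# Sketch: verify that the unique constraint exists after the upgrade.
from aiida.backends.sqlalchemy import get_scoped_session

session = get_scoped_session()
rows = session.execute(
    "SELECT conname FROM pg_constraint "
    "WHERE conrelid = 'db_dbgroup_dbnodes'::regclass AND contype = 'u';")
assert 'uix_dbnode_id_dbgroup_id' in [row[0] for row in rows]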
5 changes: 4 additions & 1 deletion aiida/backends/sqlalchemy/models/group.py
@@ -26,7 +26,10 @@
Base.metadata,
Column('id', Integer, primary_key=True),
Column('dbnode_id', Integer, ForeignKey('db_dbnode.id', deferrable=True, initially="DEFERRED")),
Column('dbgroup_id', Integer, ForeignKey('db_dbgroup.id', deferrable=True, initially="DEFERRED"))
Column('dbgroup_id', Integer, ForeignKey('db_dbgroup.id', deferrable=True, initially="DEFERRED")),

# explicit/composite unique constraint. 'name' is optional.
UniqueConstraint('dbnode_id', 'dbgroup_id', name='uix_dbnode_id_dbgroup_id')
)
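
Side note: the DDL that SQLAlchemy generates for this association table can be inspected to confirm the composite constraint is picked up. A minimal sketch (illustrative, not part of the diff), assuming the PostgreSQL dialect:

# Sketch: compile the CREATE TABLE statement for the association table and
# check that the named composite unique constraint appears in the DDL.
from sqlalchemy.schema import CreateTable
from sqlalchemy.dialects import postgresql
from aiida.backends.sqlalchemy.models.group import table_groups_nodes

ddl = str(CreateTable(table_groups_nodes).compile(dialect=postgresql.dialect()))
assert 'uix_dbnode_id_dbgroup_id' in ddl  # composite UNIQUE (dbnode_id, dbgroup_id)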


63 changes: 63 additions & 0 deletions aiida/backends/sqlalchemy/tests/generic.py
@@ -125,6 +125,69 @@ def test_query(self):
newuser.delete()


class TestGroupNoOrmSQLA(AiidaTestCase):
"""
These tests check that group-node addition works correctly when the skip_orm=True flag is used
"""

def test_group_general(self):
"""
General tests to verify that group addition with the skip_orm=True flag
works properly
"""
from aiida.orm.group import Group
from aiida.orm.data import Data

node_01 = Data().store()
node_02 = Data().store()
node_03 = Data().store()
node_04 = Data().store()
node_05 = Data().store()
node_06 = Data().store()
node_07 = Data().store()
node_08 = Data().store()
nodes = [node_01, node_02, node_03, node_04, node_05, node_06, node_07, node_08]

group = Group(name='test_adding_nodes').store()
# Single node
group.add_nodes(node_01, skip_orm=True)
# List of nodes
group.add_nodes([node_02, node_03], skip_orm=True)
# Single DbNode
group.add_nodes(node_04.dbnode, skip_orm=True)
# List of DbNodes
group.add_nodes([node_05.dbnode, node_06.dbnode], skip_orm=True)
# List of orm.Nodes and DbNodes
group.add_nodes([node_07, node_08.dbnode], skip_orm=True)

# Check
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

# Try to add a node that is already present: there should be no problem
group.add_nodes(node_01, skip_orm=True)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))

def test_group_batch_size(self):
"""
Test that the group addition in batches works as expected.
"""
from aiida.orm.group import Group
from aiida.orm.data import Data

# Create 100 nodes
nodes = []
for _ in range(100):
nodes.append(Data().store())

# Add nodes to groups using different batch sizes and check at the end
# that the addition was done correctly.
batch_sizes = (1, 3, 10, 1000)
for batch_size in batch_sizes:
group = Group(name='test_batches_' + str(batch_size)).store()
group.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
self.assertEqual(set(_.pk for _ in nodes), set(_.pk for _ in group.nodes))


class TestDbExtrasSqla(AiidaTestCase):
"""
Characterized functions
2 changes: 1 addition & 1 deletion aiida/orm/implementation/django/group.py
@@ -125,7 +125,7 @@ def store(self):
# To allow to do directly g = Group(...).store()
return self

def add_nodes(self, nodes):
def add_nodes(self, nodes, **kargs):
from aiida.backends.djsite.db.models import DbNode
if not self.is_stored:
raise ModificationNotAllowed("Cannot add nodes to a group before "
2 changes: 1 addition & 1 deletion aiida/orm/implementation/general/group.py
@@ -232,7 +232,7 @@ def store(self):
pass

@abstractmethod
def add_nodes(self, nodes):
def add_nodes(self, nodes, **kargs):
"""
Add a node or a set of nodes to the group.

68 changes: 62 additions & 6 deletions aiida/orm/implementation/sqlalchemy/group.py
@@ -128,10 +128,20 @@ def store(self):

return self

def add_nodes(self, nodes):
def add_nodes(self, nodes, skip_orm=False, batch_size=5000):
"""
:param nodes: See the description of the abstract method that it extends
:param skip_orm: When the flag is on, the SQLA ORM is skipped and a raw SQL
statement is issued (to improve speed).
:param batch_size: The maximum number of nodes added per SQL query when
skip_orm=True.
"""
if not self.is_stored:
raise ModificationNotAllowed("Cannot add nodes to a group before "
"storing")
if skip_orm and batch_size <= 0:
raise ValueError("batch_size should be a positive number")

from aiida.orm.implementation.sqlalchemy.node import Node
from aiida.backends.sqlalchemy import get_scoped_session
session = get_scoped_session()
@@ -147,7 +157,11 @@ def add_nodes(self, nodes):
"of such objects, it is instead {}".format(
str(type(nodes))))

list_nodes = []
# In the following list we store the (node, group) pairs that will
# be used for the non-ORM insert
ins_list = list()
insert_txt = ""
node_count = 0
for node in nodes:
if not isinstance(node, (Node, DbNode)):
raise TypeError("Invalid type of one of the elements passed "
@@ -163,11 +177,53 @@
else:
to_add = node

if to_add not in self._dbgroup.dbnodes:
# ~ list_nodes.append(to_add)
self._dbgroup.dbnodes.append(to_add)
# If we would like to skip the ORM, we just populate the list containing
# the group/node pairs to be inserted
if skip_orm:
# We keep the nodes in batches that are inserted into the group one
# batch at a time, to avoid creating very big SQL statements
insert_txt += "({}, {}), ".format(node.id, self.id)
if node_count < batch_size:
node_count += 1
else:
node_count = 0
# Strip the trailing ', ' separator
insert_txt = insert_txt[:-2]
# Store the completed batch and reset the string for the next one
ins_list.append(insert_txt)
insert_txt = ""

# Otherwise follow the traditional ORM approach
else:
if to_add not in self._dbgroup.dbnodes:
self._dbgroup.dbnodes.append(to_add)

# Here we do the final insert for the non-ORM case
if skip_orm:
# Take care of the last batch if it was not yet added to the list
if len(insert_txt) > 0:
insert_txt = insert_txt[:-2]
ins_list.append(insert_txt)

for ins_item in ins_list:
statement = """INSERT INTO db_dbgroup_dbnodes(dbnode_id, dbgroup_id) VALUES {}
ON CONFLICT DO NOTHING;""".format(ins_item)
session.execute(statement)

# #### The following code can be used to generate an SQL statement similar to the one shown above.
# #### The difference is that it will create one SQL statement per added node, resulting in a
# #### performance degradation of ~50% compared to the above SQL statement, which
# #### contains all nodes in a single statement.
# #### We don't use the following SQLA code for the moment because "on_conflict_do_nothing"
# #### is not supported by the SQLA version (1.0.x) that we use in AiiDA 0.12.x.

# from sqlalchemy.dialects.postgresql import insert
# from aiida.backends.sqlalchemy.models.group import table_groups_nodes
# # Create the insert statement and update the relationship table
# insert_statement = insert(table_groups_nodes).values(ins_dict)
# session.execute(insert_statement.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

session.commit()
# ~ self._dbgroup.dbnodes.extend(list_nodes)

@property
def nodes(self):
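For readers skimming the diff, the skip_orm code path above boils down to the following standalone sketch of the batched INSERT ... ON CONFLICT DO NOTHING technique (illustrative only; the helper name and the way the pairs are collected are not AiiDA API, and nodes, group and session are assumed to come from the caller):

def batched_values(pairs, batch_size=5000):
    """Yield SQL VALUES fragments containing at most batch_size pairs each."""
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start:start + batch_size]
        yield ", ".join("({}, {})".format(node_id, group_id) for node_id, group_id in chunk)

pairs = [(node.pk, group.pk) for node in nodes]
for values in batched_values(pairs):
    statement = ("INSERT INTO db_dbgroup_dbnodes(dbnode_id, dbgroup_id) VALUES {} "
                 "ON CONFLICT DO NOTHING;".format(values))
    session.execute(statement)
session.commit()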
10 changes: 4 additions & 6 deletions aiida/orm/importexport.py
@@ -1402,7 +1402,8 @@ def import_data_sqla(in_path, ignore_unknown_nodes=False, silent=False):
qb_nodes = QueryBuilder().append(
Node, filters={'id': {'in': nodes_ids_to_add}})
nodes_to_add = [n[0] for n in qb_nodes.all()]
group.add_nodes(nodes_to_add)
# Adding nodes to group avoiding the SQLA ORM to increase speed
group.add_nodes(nodes_to_add, skip_orm=True)

######################################################
# Put everything in a specific group
@@ -1442,12 +1443,9 @@
# Add all the nodes to the new group
# TODO: decide if we want to return the group name
from aiida.backends.sqlalchemy.models.node import DbNode
# Adding nodes to group avoiding the SQLA ORM to increase speed
group.add_nodes(session.query(DbNode).filter(
DbNode.id.in_(pks_for_group)).distinct().all())

# group.add_nodes(models.DbNode.objects.filter(
# pk__in=pks_for_group))

DbNode.id.in_(pks_for_group)).distinct().all(), skip_orm=True)
if not silent:
print "IMPORTED NODES GROUPED IN IMPORT GROUP NAMED '{}'".format(group.name)
else:
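
From the user side, the fast path is used exactly as in the tests above; a minimal usage sketch for AiiDA 0.12.x (the group name and node count are arbitrary):

from aiida.orm.group import Group
from aiida.orm.data import Data

# Store a number of nodes, then add them all to a group; with skip_orm=True the
# addition happens via batched raw-SQL INSERTs of roughly batch_size pairs each.
nodes = [Data().store() for _ in range(10000)]
group = Group(name='fast_addition_example').store()
group.add_nodes(nodes, skip_orm=True, batch_size=5000)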