diff --git a/aiida/backends/tests/__init__.py b/aiida/backends/tests/__init__.py index 7cebe4f483..bbe7ced09b 100644 --- a/aiida/backends/tests/__init__.py +++ b/aiida/backends/tests/__init__.py @@ -147,7 +147,6 @@ 'plugins.entry_point': ['aiida.backends.tests.plugins.test_entry_point'], 'plugins.factories': ['aiida.backends.tests.plugins.test_factories'], 'plugins.utils': ['aiida.backends.tests.plugins.test_utils'], - 'query': ['aiida.backends.tests.test_query'], 'restapi.identifiers': ['aiida.backends.tests.restapi.test_identifiers'], 'restapi': ['aiida.backends.tests.test_restapi'], 'tools.data.orbital': ['aiida.backends.tests.tools.data.orbital.test_orbitals'], diff --git a/aiida/backends/tests/orm/test_querybuilder.py b/aiida/backends/tests/orm/test_querybuilder.py index c031b0f81d..2b6c7906e9 100644 --- a/aiida/backends/tests/orm/test_querybuilder.py +++ b/aiida/backends/tests/orm/test_querybuilder.py @@ -7,18 +7,612 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Unit tests for the QueryBuilder ORM class.""" +# pylint: disable=invalid-name,missing-docstring,too-many-lines +"""Tests for the QueryBuilder.""" from __future__ import division -from __future__ import print_function from __future__ import absolute_import +from __future__ import print_function -import six +import warnings + +from six.moves import range, zip +from six import string_types from aiida import orm from aiida.backends.testbase import AiidaTestCase +from aiida.common.links import LinkType +from aiida.manage import configuration class TestQueryBuilder(AiidaTestCase): + + def setUp(self): + super(TestQueryBuilder, self).setUp() + self.clean_db() + self.insert_data() + + def test_ormclass_type_classification(self): + """ + This tests the classifications of the QueryBuilder + """ + # pylint: disable=protected-access + from aiida.common.exceptions import DbContentError + + qb = orm.QueryBuilder() + + # Asserting that improper declarations of the class type raise an error + with self.assertRaises(DbContentError): + qb._get_ormclass(None, 'data') + with self.assertRaises(DbContentError): + qb._get_ormclass(None, 'data.Data') + with self.assertRaises(DbContentError): + qb._get_ormclass(None, '.') + + # Asserting that the query type string and plugin type string are returned: + for _cls, classifiers in ( + qb._get_ormclass(orm.StructureData, None), + qb._get_ormclass(None, 'data.structure.StructureData.'), + ): + self.assertEqual(classifiers['ormclass_type_string'], orm.StructureData._plugin_type_string) # pylint: disable=no-member + + for _cls, classifiers in ( + qb._get_ormclass(orm.Group, None), + qb._get_ormclass(None, 'group'), + qb._get_ormclass(None, 'Group'), + ): + self.assertEqual(classifiers['ormclass_type_string'], 'group') + + for _cls, classifiers in ( + qb._get_ormclass(orm.User, None), + qb._get_ormclass(None, 'user'), + qb._get_ormclass(None, 'User'), + ): + self.assertEqual(classifiers['ormclass_type_string'], 'user') + + for _cls, classifiers in ( + qb._get_ormclass(orm.Computer, None), + qb._get_ormclass(None, 'computer'), + qb._get_ormclass(None, 'Computer'), + ): + self.assertEqual(classifiers['ormclass_type_string'], 'computer') + + for _cls, classifiers in ( + qb._get_ormclass(orm.Data, None), + qb._get_ormclass(None, 'data.Data.'), + ): + self.assertEqual(classifiers['ormclass_type_string'], orm.Data._plugin_type_string) # pylint: disable=no-member + + def test_process_type_classification(self): + """ + This tests the classifications of the QueryBuilder + """ + from aiida.engine import WorkChain + from aiida.plugins import CalculationFactory + + ArithmeticAdd = CalculationFactory('arithmetic.add') + + qb = orm.QueryBuilder() + + # pylint: disable=protected-access + + # When passing a WorkChain class, it should return the type of the corresponding Node + # including the appropriate filter on the process_type + _cls, classifiers = qb._get_ormclass(WorkChain, None) + self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') + self.assertEqual(classifiers['process_type_string'], 'aiida.engine.processes.workchains.workchain.WorkChain') + + # When passing a WorkChainNode, no process_type filter is applied + _cls, classifiers = qb._get_ormclass(orm.WorkChainNode, None) + self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') + self.assertEqual(classifiers['process_type_string'], None) + + # Same tests for a calculation + _cls, classifiers = qb._get_ormclass(ArithmeticAdd, None) + self.assertEqual(classifiers['ormclass_type_string'], 'process.calculation.calcjob.CalcJobNode.') + self.assertEqual(classifiers['process_type_string'], 'aiida.calculations:arithmetic.add') + + def test_process_query(self): + """ + Test querying for a process class. + """ + from aiida.engine import run, WorkChain, if_, return_, ExitCode + from aiida.common.warnings import AiidaEntryPointWarning + + class PotentialFailureWorkChain(WorkChain): + EXIT_STATUS = 1 + EXIT_MESSAGE = 'Well you did ask for it' + OUTPUT_LABEL = 'optional_output' + OUTPUT_VALUE = 144 + + @classmethod + def define(cls, spec): + super(PotentialFailureWorkChain, cls).define(spec) + spec.input('success', valid_type=orm.Bool) + spec.input('through_return', valid_type=orm.Bool, default=orm.Bool(False)) + spec.input('through_exit_code', valid_type=orm.Bool, default=orm.Bool(False)) + spec.exit_code(cls.EXIT_STATUS, 'EXIT_STATUS', cls.EXIT_MESSAGE) + spec.outline(if_(cls.should_return_out_of_outline)(return_(cls.EXIT_STATUS)), cls.failure, cls.success) + spec.output(cls.OUTPUT_LABEL, required=False) + + def should_return_out_of_outline(self): + return self.inputs.through_return.value + + def failure(self): + # pylint: disable=no-else-return + + if self.inputs.success.value is False: + # Returning either 0 or ExitCode with non-zero status should terminate the workchain + if self.inputs.through_exit_code.value is False: + return self.EXIT_STATUS + else: + return self.exit_codes.EXIT_STATUS # pylint: disable=no-member + else: + # Returning 0 or ExitCode with zero status should *not* terminate the workchain + if self.inputs.through_exit_code.value is False: + return 0 + else: + return ExitCode() + + def success(self): + self.out(self.OUTPUT_LABEL, orm.Int(self.OUTPUT_VALUE).store()) + + class DummyWorkChain(WorkChain): + pass + + # Run a simple test WorkChain + _result = run(PotentialFailureWorkChain, success=orm.Bool(True)) + + # Query for nodes associated with this type of WorkChain + qb = orm.QueryBuilder() + + with warnings.catch_warnings(record=True) as w: # pylint: disable=no-member + # Cause all warnings to always be triggered. + warnings.simplefilter('always') # pylint: disable=no-member + + qb.append(PotentialFailureWorkChain) + + # Verify some things + assert len(w) == 1 + assert issubclass(w[-1].category, AiidaEntryPointWarning) + + # There should be one result of type WorkChainNode + self.assertEqual(qb.count(), 1) + self.assertTrue(isinstance(qb.all()[0][0], orm.WorkChainNode)) + + # Query for nodes of a different type of WorkChain + qb = orm.QueryBuilder() + + with warnings.catch_warnings(record=True) as w: # pylint: disable=no-member + # Cause all warnings to always be triggered. + warnings.simplefilter('always') # pylint: disable=no-member + + qb.append(DummyWorkChain) + + # Verify some things + assert len(w) == 1 + assert issubclass(w[-1].category, AiidaEntryPointWarning) + + # There should be no result + self.assertEqual(qb.count(), 0) + + # Query for all WorkChain nodes + qb = orm.QueryBuilder() + qb.append(WorkChain) + + # There should be one result + self.assertEqual(qb.count(), 1) + + def test_simple_query_1(self): + """ + Testing a simple query + """ + # pylint: disable=too-many-statements + + n1 = orm.Data() + n1.label = 'node1' + n1.set_attribute('foo', ['hello', 'goodbye']) + n1.store() + + n2 = orm.CalculationNode() + n2.label = 'node2' + n2.set_attribute('foo', 1) + + n3 = orm.Data() + n3.label = 'node3' + n3.set_attribute('foo', 1.0000) # Stored as fval + n3.store() + + n4 = orm.CalculationNode() + n4.label = 'node4' + n4.set_attribute('foo', 'bar') + + n5 = orm.Data() + n5.label = 'node5' + n5.set_attribute('foo', None) + n5.store() + + n2.add_incoming(n1, link_type=LinkType.INPUT_CALC, link_label='link1') + n2.store() + n3.add_incoming(n2, link_type=LinkType.CREATE, link_label='link2') + + n4.add_incoming(n3, link_type=LinkType.INPUT_CALC, link_label='link3') + n4.store() + n5.add_incoming(n4, link_type=LinkType.CREATE, link_label='link4') + + qb1 = orm.QueryBuilder() + qb1.append(orm.Node, filters={'attributes.foo': 1.000}) + + self.assertEqual(len(qb1.all()), 2) + + qb2 = orm.QueryBuilder() + qb2.append(orm.Data) + self.assertEqual(qb2.count(), 3) + + qb2 = orm.QueryBuilder() + qb2.append(entity_type='data.Data.') + self.assertEqual(qb2.count(), 3) + + qb3 = orm.QueryBuilder() + qb3.append(orm.Node, project='label', tag='node1') + qb3.append(orm.Node, project='label', tag='node2') + self.assertEqual(qb3.count(), 4) + + qb4 = orm.QueryBuilder() + qb4.append(orm.CalculationNode, tag='node1') + qb4.append(orm.Data, tag='node2') + self.assertEqual(qb4.count(), 2) + + qb5 = orm.QueryBuilder() + qb5.append(orm.Data, tag='node1') + qb5.append(orm.CalculationNode, tag='node2') + self.assertEqual(qb5.count(), 2) + + qb6 = orm.QueryBuilder() + qb6.append(orm.Data, tag='node1') + qb6.append(orm.Data, tag='node2') + self.assertEqual(qb6.count(), 0) + + def test_simple_query_2(self): + from datetime import datetime + from aiida.common.exceptions import MultipleObjectsError, NotExistent + n0 = orm.Data() + n0.label = 'hello' + n0.description = '' + n0.set_attribute('foo', 'bar') + + n1 = orm.CalculationNode() + n1.label = 'foo' + n1.description = 'I am FoO' + + n2 = orm.Data() + n2.label = 'bar' + n2.description = 'I am BaR' + + n2.add_incoming(n1, link_type=LinkType.CREATE, link_label='random_2') + n1.add_incoming(n0, link_type=LinkType.INPUT_CALC, link_label='random_1') + + for n in (n0, n1, n2): + n.store() + + qb1 = orm.QueryBuilder() + qb1.append(orm.Node, filters={'label': 'hello'}) + self.assertEqual(len(list(qb1.all())), 1) + + qh = { + 'path': [{ + 'cls': orm.Node, + 'tag': 'n1' + }, { + 'cls': orm.Node, + 'tag': 'n2', + 'with_incoming': 'n1' + }], + 'filters': { + 'n1': { + 'label': { + 'ilike': '%foO%' + }, + }, + 'n2': { + 'label': { + 'ilike': 'bar%' + }, + } + }, + 'project': { + 'n1': ['id', 'uuid', 'ctime', 'label'], + 'n2': ['id', 'description', 'label'], + } + } + + qb2 = orm.QueryBuilder(**qh) + + resdict = qb2.dict() + self.assertEqual(len(resdict), 1) + self.assertTrue(isinstance(resdict[0]['n1']['ctime'], datetime)) + + res_one = qb2.one() + self.assertTrue('bar' in res_one) + + qh = { + 'path': [{ + 'cls': orm.Node, + 'tag': 'n1' + }, { + 'cls': orm.Node, + 'tag': 'n2', + 'with_incoming': 'n1' + }], + 'filters': { + 'n1--n2': { + 'label': { + 'like': '%_2' + } + } + } + } + qb = orm.QueryBuilder(**qh) + self.assertEqual(qb.count(), 1) + + # Test the hashing: + query1 = qb.get_query() + qb.add_filter('n2', {'label': 'nonexistentlabel'}) + self.assertEqual(qb.count(), 0) + + with self.assertRaises(NotExistent): + qb.one() + with self.assertRaises(MultipleObjectsError): + orm.QueryBuilder().append(orm.Node).one() + + query2 = qb.get_query() + query3 = qb.get_query() + + self.assertTrue(id(query1) != id(query2)) + self.assertTrue(id(query2) == id(query3)) + + def test_operators_eq_lt_gt(self): + nodes = [orm.Data() for _ in range(8)] + + nodes[0].set_attribute('fa', 1) + nodes[1].set_attribute('fa', 1.0) + nodes[2].set_attribute('fa', 1.01) + nodes[3].set_attribute('fa', 1.02) + nodes[4].set_attribute('fa', 1.03) + nodes[5].set_attribute('fa', 1.04) + nodes[6].set_attribute('fa', 1.05) + nodes[7].set_attribute('fa', 1.06) + + for n in nodes: + n.store() + + self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1}}).count(), 0) + self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'==': 1}}).count(), 2) + self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1.02}}).count(), 3) + self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<=': 1.02}}).count(), 4) + self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>': 1.02}}).count(), 4) + self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>=': 1.02}}).count(), 5) + + def test_subclassing(self): + s = orm.StructureData() + s.set_attribute('cat', 'miau') + s.store() + + d = orm.Data() + d.set_attribute('cat', 'miau') + d.store() + + p = orm.Dict(dict=dict(cat='miau')) + p.store() + + # Now when asking for a node with attr.cat==miau, I want 3 esults: + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.cat': 'miau'}) + self.assertEqual(qb.count(), 3) + + qb = orm.QueryBuilder().append(orm.Data, filters={'attributes.cat': 'miau'}) + self.assertEqual(qb.count(), 3) + + # If I'm asking for the specific lowest subclass, I want one result + for cls in (orm.StructureData, orm.Dict): + qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}) + self.assertEqual(qb.count(), 1) + + # Now I am not allow the subclassing, which should give 1 result for each + for cls, count in ((orm.StructureData, 1), (orm.Dict, 1), (orm.Data, 1), (orm.Node, 0)): + qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}, subclassing=False) + self.assertEqual(qb.count(), count) + + # Now I am testing the subclassing with tuples: + qb = orm.QueryBuilder().append(cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}) + self.assertEqual(qb.count(), 2) + qb = orm.QueryBuilder().append( + entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), filters={'attributes.cat': 'miau'} + ) + self.assertEqual(qb.count(), 2) + qb = orm.QueryBuilder().append( + cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}, subclassing=False + ) + self.assertEqual(qb.count(), 2) + qb = orm.QueryBuilder().append( + cls=(orm.StructureData, orm.Data), + filters={'attributes.cat': 'miau'}, + ) + self.assertEqual(qb.count(), 3) + qb = orm.QueryBuilder().append( + entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), + filters={'attributes.cat': 'miau'}, + subclassing=False + ) + self.assertEqual(qb.count(), 2) + qb = orm.QueryBuilder().append( + entity_type=('data.structure.StructureData.', 'data.Data.'), + filters={'attributes.cat': 'miau'}, + subclassing=False + ) + self.assertEqual(qb.count(), 2) + + def test_list_behavior(self): + for _i in range(4): + orm.Data().store() + + self.assertEqual(len(orm.QueryBuilder().append(orm.Node).all()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').all()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).all()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).all()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node).dict()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').dict()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).dict()), 4) + self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).dict()), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterall())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterall())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterall())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterall())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterdict())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterdict())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterdict())), 4) + self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterdict())), 4) + + def test_append_validation(self): + from aiida.common.exceptions import InputValidationError + + # So here I am giving two times the same tag + with self.assertRaises(InputValidationError): + orm.QueryBuilder().append(orm.StructureData, tag='n').append(orm.StructureData, tag='n') + # here I am giving a wrong filter specifications + with self.assertRaises(InputValidationError): + orm.QueryBuilder().append(orm.StructureData, filters=['jajjsd']) + # here I am giving a nonsensical projection: + with self.assertRaises(InputValidationError): + orm.QueryBuilder().append(orm.StructureData, project=True) + + # here I am giving a nonsensical projection for the edge: + with self.assertRaises(InputValidationError): + orm.QueryBuilder().append(orm.ProcessNode).append(orm.StructureData, edge_tag='t').add_projection('t', True) + # Giving a nonsensical limit + with self.assertRaises(InputValidationError): + orm.QueryBuilder().append(orm.ProcessNode).limit(2.3) + # Giving a nonsensical offset + with self.assertRaises(InputValidationError): + orm.QueryBuilder(offset=2.3) + + # So, I mess up one append, I want the QueryBuilder to clean it! + with self.assertRaises(InputValidationError): + qb = orm.QueryBuilder() + # This also checks if we correctly raise for wrong keywords + qb.append(orm.StructureData, tag='s', randomkeyword={}) + + # Now I'm checking whether this keyword appears anywhere in the internal dictionaries: + # pylint: disable=protected-access + self.assertTrue('s' not in qb._projections) + self.assertTrue('s' not in qb._filters) + self.assertTrue('s' not in qb.tag_to_alias_map) + self.assertTrue(len(qb._path) == 0) + self.assertTrue(orm.StructureData not in qb._cls_to_tag_map) + # So this should work now: + qb.append(orm.StructureData, tag='s').limit(2).dict() + + def test_tags(self): + qb = orm.QueryBuilder() + qb.append(orm.Node, tag='n1') + qb.append(orm.Node, tag='n2', edge_tag='e1', with_incoming='n1') + qb.append(orm.Node, tag='n3', edge_tag='e2', with_incoming='n2') + qb.append(orm.Computer, with_node='n3', tag='c1', edge_tag='nonsense') + self.assertEqual(qb.get_used_tags(), ['n1', 'n2', 'e1', 'n3', 'e2', 'c1', 'nonsense']) + + # Now I am testing the default tags, + qb = orm.QueryBuilder().append(orm.StructureData + ).append(orm.ProcessNode).append(orm.StructureData + ).append(orm.Dict, with_outgoing=orm.ProcessNode) + self.assertEqual( + qb.get_used_tags(), [ + 'StructureData_1', 'ProcessNode_1', 'StructureData_1--ProcessNode_1', 'StructureData_2', + 'ProcessNode_1--StructureData_2', 'Dict_1', 'ProcessNode_1--Dict_1' + ] + ) + self.assertEqual( + qb.get_used_tags(edges=False), [ + 'StructureData_1', + 'ProcessNode_1', + 'StructureData_2', + 'Dict_1', + ] + ) + self.assertEqual( + qb.get_used_tags(vertices=False), + ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] + ) + self.assertEqual( + qb.get_used_tags(edges=False), [ + 'StructureData_1', + 'ProcessNode_1', + 'StructureData_2', + 'Dict_1', + ] + ) + self.assertEqual( + qb.get_used_tags(vertices=False), + ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] + ) + + def test_direction_keyword(self): + """ + The direction keyword is a special case with the QueryBuilder append + method, so some tests are good. + """ + d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)] + c1, c2 = [orm.CalculationNode() for _ in range(2)] + c1.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link_d1c1') + c1.store() + d2.add_incoming(c1, link_type=LinkType.CREATE, link_label='link_c1d2') + d4.add_incoming(c1, link_type=LinkType.CREATE, link_label='link_c1d4') + c2.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link_d2c2') + c2.store() + d3.add_incoming(c2, link_type=LinkType.CREATE, link_label='link_c2d3') + + # testing direction=1 for d1, which should return the outgoing + qb = orm.QueryBuilder() + qb.append(orm.Data, filters={'id': d1.id}) + qb.append(orm.CalculationNode, direction=1, project='id') + res1 = {_ for _, in qb.all()} + + qb = orm.QueryBuilder() + qb.append(orm.Data, filters={'id': d1.id}, tag='data') + qb.append(orm.CalculationNode, with_incoming='data', project='id') + res2 = {_ for _, in qb.all()} + + self.assertEqual(res1, res2) + self.assertEqual(res1, {c1.id}) + + # testing direction=-1, which should return the incoming + qb = orm.QueryBuilder() + qb.append(orm.Data, filters={'id': d2.id}) + qb.append(orm.CalculationNode, direction=-1, project='id') + res1 = {_ for _, in qb.all()} + + qb = orm.QueryBuilder() + qb.append(orm.Data, filters={'id': d2.id}, tag='data') + qb.append(orm.CalculationNode, with_outgoing='data', project='id') + res2 = {_ for _, in qb.all()} + self.assertEqual(res1, res2) + self.assertEqual(res1, {c1.id}) + + # testing direction higher than 1 + qb = orm.QueryBuilder() + qb.append(orm.CalculationNode, tag='c1', filters={'id': c1.id}) + qb.append(orm.Data, with_incoming='c1', tag='d2or4') + qb.append(orm.CalculationNode, tag='c2', with_incoming='d2or4') + qb.append(orm.Data, tag='d3', with_incoming='c2', project='id') + qh = qb.get_json_compatible_queryhelp() # saving query for later + qb.append(orm.Data, direction=-4, project='id') + res1 = {item[1] for item in qb.all()} + self.assertEqual(res1, {d1.id}) + + qb = orm.QueryBuilder(**qh) + qb.append(orm.Data, direction=4, project='id') + res2 = {item[1] for item in qb.all()} + self.assertEqual(res2, {d2.id, d4.id}) + + +class TestMultipleProjections(AiidaTestCase): """Unit tests for the QueryBuilder ORM class.""" def test_first_multiple_projections(self): @@ -31,5 +625,769 @@ def test_first_multiple_projections(self): self.assertEqual(type(result), list) self.assertEqual(len(result), 2) - self.assertIsInstance(result[0], six.string_types) + self.assertIsInstance(result[0], string_types) self.assertIsInstance(result[1], orm.Data) + + +class TestQueryHelp(AiidaTestCase): + + def test_queryhelp(self): + """ + Here I test the queryhelp by seeing whether results are the same as using the append method. + I also check passing of tuples. + """ + g = orm.Group(label='helloworld').store() + for cls in (orm.StructureData, orm.Dict, orm.Data): + obj = cls() + obj.set_attribute('foo-qh2', 'bar') + obj.store() + g.add_nodes(obj) + + for cls, expected_count, subclassing in ( + (orm.StructureData, 1, True), + (orm.Dict, 1, True), + (orm.Data, 3, True), + (orm.Data, 1, False), + ((orm.Dict, orm.StructureData), 2, True), + ((orm.Dict, orm.StructureData), 2, False), + ((orm.Dict, orm.Data), 2, False), + ((orm.Dict, orm.Data), 3, True), + ((orm.Dict, orm.Data, orm.StructureData), 3, False), + ): + qb = orm.QueryBuilder() + qb.append(cls, filters={'attributes.foo-qh2': 'bar'}, subclassing=subclassing, project='uuid') + self.assertEqual(qb.count(), expected_count) + + qh = qb.get_json_compatible_queryhelp() + qb_new = orm.QueryBuilder(**qh) + self.assertEqual(qb_new.count(), expected_count) + self.assertEqual(sorted([uuid for uuid, in qb.all()]), sorted([uuid for uuid, in qb_new.all()])) + + qb = orm.QueryBuilder().append(orm.Group, filters={'label': 'helloworld'}) + self.assertEqual(qb.count(), 1) + + qb = orm.QueryBuilder().append((orm.Group,), filters={'label': 'helloworld'}) + self.assertEqual(qb.count(), 1) + + qb = orm.QueryBuilder().append(orm.Computer,) + self.assertEqual(qb.count(), 1) + + qb = orm.QueryBuilder().append(cls=(orm.Computer,)) + self.assertEqual(qb.count(), 1) + + def test_recreate_from_queryhelp(self): + """Test recreating a QueryBuilder from the Query Help""" + import copy + + qb1 = orm.QueryBuilder() + qb1.append(orm.Data) + qb1.append(orm.CalcJobNode) + + qb2 = orm.QueryBuilder(**qb1.queryhelp) + self.assertDictEqual(qb1.queryhelp, qb2.queryhelp) + + qb3 = copy.deepcopy(qb1) + self.assertDictEqual(qb1.queryhelp, qb3.queryhelp) + + +class TestQueryBuilderCornerCases(AiidaTestCase): + """ + In this class corner cases of QueryBuilder are added. + """ + + def test_computer_json(self): # pylint: disable=no-self-use + """ + In this test we check the correct behavior of QueryBuilder when + retrieving the _metadata with no content. + Note that they are in JSON format in both backends. Forcing the + decoding of a None value leads to an exception (this was the case + under Django). + """ + n1 = orm.CalculationNode() + n1.label = 'node2' + n1.set_attribute('foo', 1) + n1.store() + + # Checking the correct retrieval of _metadata which is + # a JSON field (in both backends). + qb = orm.QueryBuilder() + qb.append(orm.CalculationNode, project=['id'], tag='calc') + qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc') + qb.all() + + +class TestAttributes(AiidaTestCase): + + def test_attribute_existence(self): + # I'm storing a value under key whatever: + val = 1. + res_uuids = set() + n1 = orm.Data() + n1.set_attribute('whatever', 3.) + n1.set_attribute('test_case', 'test_attribute_existence') + n1.store() + + # I want all the nodes where whatever is smaller than 1. or there is no such value: + + qb = orm.QueryBuilder() + qb.append( + orm.Data, + filters={ + 'or': [{ + 'attributes': { + '!has_key': 'whatever' + } + }, { + 'attributes.whatever': { + '<': val + } + }], + }, + project='uuid' + ) + res_query = {str(_[0]) for _ in qb.all()} + self.assertEqual(res_query, res_uuids) + + def test_attribute_type(self): + key = 'value_test_attr_type' + n_int, n_float, n_str, n_str2, n_bool, n_arr = [orm.Data() for _ in range(6)] + n_int.set_attribute(key, 1) + n_float.set_attribute(key, 1.0) + n_bool.set_attribute(key, True) + n_str.set_attribute(key, '1') + n_str2.set_attribute(key, 'one') + n_arr.set_attribute(key, [4, 3, 5]) + + for n in (n_str2, n_str, n_int, n_float, n_bool, n_arr): + n.store() + + # Here I am testing which values contain a number 1. + # Both 1 and 1.0 are legitimate values if ask for either 1 or 1.0 + for val in (1.0, 1): + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): val}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'>': 0.5}}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'<': 1.5}}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) + # Now I am testing the boolean value: + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): True}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_bool.uuid,))) + + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'like': '%n%'}}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_str2.uuid,))) + qb = orm.QueryBuilder().append( + orm.Node, filters={'attributes.{}'.format(key): { + 'ilike': 'On%' + }}, project='uuid' + ) + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_str2.uuid,))) + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'like': '1'}}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_str.uuid,))) + qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'==': '1'}}, project='uuid') + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_str.uuid,))) + if configuration.PROFILE.database_backend == u'sqlalchemy': + # I can't query the length of an array with Django, + # so I exclude. Not the nicest way, But I would like to keep this piece + # of code because of the initialization part, that would need to be + # duplicated or wrapped otherwise. + qb = orm.QueryBuilder().append( + orm.Node, filters={'attributes.{}'.format(key): { + 'of_length': 3 + }}, project='uuid' + ) + res = [str(_) for _, in qb.all()] + self.assertEqual(set(res), set((n_arr.uuid,))) + + +class QueryBuilderLimitOffsetsTest(AiidaTestCase): + + def test_ordering_limits_offsets_of_results_general(self): + # Creating 10 nodes with an attribute that can be ordered + for i in range(10): + n = orm.Data() + n.set_attribute('foo', i) + n.store() + + qb = orm.QueryBuilder().append(orm.Node, project='attributes.foo').order_by({orm.Node: 'ctime'}) + + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(10))) + + # Now applying an offset: + qb.offset(5) + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(5, 10))) + + # Now also applying a limit: + qb.limit(3) + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(5, 8))) + + # Specifying the order explicitly the order: + qb = orm.QueryBuilder().append(orm.Node, + project='attributes.foo').order_by({orm.Node: { + 'ctime': { + 'order': 'asc' + } + }}) + + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(10))) + + # Now applying an offset: + qb.offset(5) + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(5, 10))) + + # Now also applying a limit: + qb.limit(3) + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(5, 8))) + + # Reversing the order: + qb = orm.QueryBuilder().append(orm.Node, + project='attributes.foo').order_by({orm.Node: { + 'ctime': { + 'order': 'desc' + } + }}) + + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(9, -1, -1))) + + # Now applying an offset: + qb.offset(5) + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(4, -1, -1))) + + # Now also applying a limit: + qb.limit(3) + res = next(zip(*qb.all())) + self.assertEqual(res, tuple(range(4, 1, -1))) + + +class QueryBuilderJoinsTests(AiidaTestCase): + + def test_joins1(self): + # Creating n1, who will be a parent: + parent = orm.Data() + parent.label = 'mother' + parent.store() + + good_child = orm.CalculationNode() + good_child.label = 'good_child' + good_child.set_attribute('is_good', True) + + bad_child = orm.CalculationNode() + bad_child.label = 'bad_child' + bad_child.set_attribute('is_good', False) + + unrelated = orm.CalculationNode() + unrelated.label = 'unrelated' + unrelated.store() + + good_child.add_incoming(parent, link_type=LinkType.INPUT_CALC, link_label='parent') + bad_child.add_incoming(parent, link_type=LinkType.INPUT_CALC, link_label='parent') + good_child.store() + bad_child.store() + + # Using a standard inner join + qb = orm.QueryBuilder() + qb.append(orm.Node, tag='parent') + qb.append(orm.Node, tag='children', project='label', filters={'attributes.is_good': True}) + self.assertEqual(qb.count(), 1) + + qb = orm.QueryBuilder() + qb.append(orm.Node, tag='parent') + qb.append(orm.Node, tag='children', outerjoin=True, project='label', filters={'attributes.is_good': True}) + self.assertEqual(qb.count(), 1) + + def test_joins2(self): + # Creating n1, who will be a parent: + + students = [orm.Data() for i in range(10)] + advisors = [orm.CalculationNode() for i in range(3)] + for i, a in enumerate(advisors): + a.label = 'advisor {}'.format(i) + a.set_attribute('advisor_id', i) + + for n in advisors + students: + n.store() + + # advisor 0 get student 0, 1 + for i in (0, 1): + students[i].add_incoming(advisors[0], link_type=LinkType.CREATE, link_label='is_advisor_{}'.format(i)) + + # advisor 1 get student 3, 4 + for i in (3, 4): + students[i].add_incoming(advisors[1], link_type=LinkType.CREATE, link_label='is_advisor_{}'.format(i)) + + # advisor 2 get student 5, 6, 7 + for i in (5, 6, 7): + students[i].add_incoming(advisors[2], link_type=LinkType.CREATE, link_label='is_advisor_{}'.format(i)) + + # let's add a differnt relationship than advisor: + students[9].add_incoming(advisors[2], link_type=LinkType.CREATE, link_label='lover') + + self.assertEqual( + orm.QueryBuilder().append( + orm.Node + ).append(orm.Node, edge_filters={ + 'label': { + 'like': 'is\\_advisor\\_%' + } + }, tag='student').count(), 7 + ) + + for adv_id, number_students in zip(list(range(3)), (2, 2, 3)): + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'attributes.advisor_id': adv_id + }).append(orm.Node, edge_filters={ + 'label': { + 'like': 'is\\_advisor\\_%' + } + }, tag='student').count(), number_students + ) + + def test_joins3_user_group(self): + # Create another user + new_email = 'newuser@new.n' + user = orm.User(email=new_email).store() + + # Create a group that belongs to that user + group = orm.Group(label='node_group') + group.user = user + group.store() + + # Search for the group of the user + qb = orm.QueryBuilder() + qb.append(orm.User, tag='user', filters={'id': {'==': user.id}}) + qb.append(orm.Group, with_user='user', filters={'id': {'==': group.id}}) + self.assertEqual(qb.count(), 1, 'The expected group that belongs to ' 'the selected user was not found.') + + # Search for the user that owns a group + qb = orm.QueryBuilder() + qb.append(orm.Group, tag='group', filters={'id': {'==': group.id}}) + qb.append(orm.User, with_group='group', filters={'id': {'==': user.id}}) + + self.assertEqual(qb.count(), 1, 'The expected user that owns the ' 'selected group was not found.') + + def test_joins_group_node(self): + """ + This test checks that the querying for the nodes that belong to a group works correctly (using QueryBuilder). + This is important for the Django backend with the use of aldjemy for the Django to SQLA schema translation. + Since this is not backend specific test (even if it is mainly used to test the querying of Django backend + with QueryBuilder), we keep it at the general tests (ran by both backends). + """ + new_email = 'newuser@new.n2' + user = orm.User(email=new_email).store() + + # Create a group that belongs to that user + group = orm.Group(label='node_group_2') + group.user = user + group.store() + + # Create nodes and add them to the created group + n1 = orm.Data() + n1.label = 'node1' + n1.set_attribute('foo', ['hello', 'goodbye']) + n1.store() + + n2 = orm.CalculationNode() + n2.label = 'node2' + n2.set_attribute('foo', 1) + n2.store() + + n3 = orm.Data() + n3.label = 'node3' + n3.set_attribute('foo', 1.0000) # Stored as fval + n3.store() + + n4 = orm.CalculationNode() + n4.label = 'node4' + n4.set_attribute('foo', 'bar') + n4.store() + + group.add_nodes([n1, n2, n3, n4]) + + # Check that the nodes are in the group + qb = orm.QueryBuilder() + qb.append(orm.Node, tag='node', project=['id']) + qb.append(orm.Group, with_node='node', filters={'id': {'==': group.id}}) + self.assertEqual(qb.count(), 4, 'There should be 4 nodes in the group') + id_res = [_ for [_] in qb.all()] + for curr_id in [n1.id, n2.id, n3.id, n4.id]: + self.assertIn(curr_id, id_res) + + +class QueryBuilderPath(AiidaTestCase): + + def test_query_path(self): + # pylint: disable=too-many-statements + + q = self.backend.query_manager + n1 = orm.Data() + n1.label = 'n1' + n2 = orm.CalculationNode() + n2.label = 'n2' + n3 = orm.Data() + n3.label = 'n3' + n4 = orm.Data() + n4.label = 'n4' + n5 = orm.CalculationNode() + n5.label = 'n5' + n6 = orm.Data() + n6.label = 'n6' + n7 = orm.CalculationNode() + n7.label = 'n7' + n8 = orm.Data() + n8.label = 'n8' + n9 = orm.Data() + n9.label = 'n9' + + # I create a strange graph, inserting links in a order + # such that I often have to create the transitive closure + # between two graphs + n3.add_incoming(n2, link_type=LinkType.CREATE, link_label='link1') + n2.add_incoming(n1, link_type=LinkType.INPUT_CALC, link_label='link2') + n5.add_incoming(n3, link_type=LinkType.INPUT_CALC, link_label='link3') + n5.add_incoming(n4, link_type=LinkType.INPUT_CALC, link_label='link4') + n4.add_incoming(n2, link_type=LinkType.CREATE, link_label='link5') + n7.add_incoming(n6, link_type=LinkType.INPUT_CALC, link_label='link6') + n8.add_incoming(n7, link_type=LinkType.CREATE, link_label='link7') + + for node in [n1, n2, n3, n4, n5, n6, n7, n8, n9]: + node.store() + + # There are no parents to n9, checking that + self.assertEqual(set([]), set(q.get_all_parents([n9.pk]))) + # There is one parent to n6 + self.assertEqual({(_,) for _ in (n6.pk,)}, {tuple(_) for _ in q.get_all_parents([n7.pk])}) + # There are several parents to n4 + self.assertEqual({(_.pk,) for _ in (n1, n2)}, {tuple(_) for _ in q.get_all_parents([n4.pk])}) + # There are several parents to n5 + self.assertEqual({(_.pk,) for _ in (n1, n2, n3, n4)}, {tuple(_) for _ in q.get_all_parents([n5.pk])}) + + # Yet, no links from 1 to 8 + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + 'id': n8.pk + }).count(), 0 + ) + + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + 'id': n1.pk + }).count(), 0 + ) + + n6.add_incoming(n5, link_type=LinkType.CREATE, link_label='link1') + # Yet, now 2 links from 1 to 8 + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + 'id': n8.pk + }).count(), 2 + ) + + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + 'id': n1.pk + }).count(), 2 + ) + + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': { + '<': 6 + } + }, + ).count(), 2 + ) + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': 5 + }, + ).count(), 2 + ) + self.assertEqual( + orm.QueryBuilder().append(orm.Node, filters={ + 'id': n8.pk + }, tag='desc').append( + orm.Node, + with_descendants='desc', + filters={ + 'id': n1.pk + }, + edge_filters={ + 'depth': { + '<': 5 + } + }, + ).count(), 0 + ) + + # TODO write a query that can filter certain paths by traversed ID # pylint: disable=fixme + qb = orm.QueryBuilder().append( + orm.Node, + filters={ + 'id': n8.pk + }, + tag='desc', + ).append(orm.Node, with_descendants='desc', edge_project='path', filters={'id': n1.pk}) + queried_path_set = {frozenset(p) for p, in qb.all()} + + paths_there_should_be = { + frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]), + frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) + } + + self.assertTrue(queried_path_set == paths_there_should_be) + + qb = orm.QueryBuilder().append(orm.Node, filters={ + 'id': n1.pk + }, tag='anc').append(orm.Node, with_ancestors='anc', filters={'id': n8.pk}, edge_project='path') + + self.assertEqual({frozenset(p) for p, in qb.all()}, { + frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]), + frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) + }) + + # This part of the test is no longer possible as the nodes have already been stored and the previous parts of + # the test rely on this, which means however, that here, no more links can be added as that will raise. + + # n7.add_incoming(n9, link_type=LinkType.INPUT_CALC, link_label='link0') + # # Still two links... + + # self.assertEqual( + # orm.QueryBuilder().append(orm.Node, filters={ + # 'id': n1.pk + # }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + # 'id': n8.pk + # }).count(), 2) + + # self.assertEqual( + # orm.QueryBuilder().append(orm.Node, filters={ + # 'id': n8.pk + # }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + # 'id': n1.pk + # }).count(), 2) + # n9.add_incoming(n5, link_type=LinkType.CREATE, link_label='link6') + # # And now there should be 4 nodes + + # self.assertEqual( + # orm.QueryBuilder().append(orm.Node, filters={ + # 'id': n1.pk + # }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ + # 'id': n8.pk + # }).count(), 4) + + # self.assertEqual( + # orm.QueryBuilder().append(orm.Node, filters={ + # 'id': n8.pk + # }, tag='desc').append(orm.Node, with_descendants='desc', filters={ + # 'id': n1.pk + # }).count(), 4) + + # qb = orm.QueryBuilder().append( + # orm.Node, filters={ + # 'id': n1.pk + # }, tag='anc').append( + # orm.Node, with_ancestors='anc', filters={'id': n8.pk}, edge_tag='edge') + # qb.add_projection('edge', 'depth') + # self.assertTrue(set(next(zip(*qb.all()))), set([5, 6])) + # qb.add_filter('edge', {'depth': 5}) + # self.assertTrue(set(next(zip(*qb.all()))), set([5])) + + +class TestConsistency(AiidaTestCase): + + def test_create_node_and_query(self): + """ + Testing whether creating nodes within a iterall iteration changes the results. + """ + for _i in range(100): + n = orm.Data() + n.store() + + for idx, _item in enumerate( + orm.QueryBuilder().append(orm.Node, project=['id', 'label']).iterall(batch_size=10) + ): + if idx % 10 == 10: + n = orm.Data() + n.store() + self.assertEqual(idx, 99) # pylint: disable=undefined-loop-variable + self.assertTrue(len(orm.QueryBuilder().append(orm.Node, project=['id', 'label']).all(batch_size=10)) > 99) + + def test_len_results(self): + """ + Test whether the len of results matches the count returned. + See also https://github.com/aiidateam/aiida-core/issues/1600 + SQLAlchemy has a deduplication strategy that leads to strange behavior, tested against here + """ + parent = orm.CalculationNode().store() + # adding 5 links going out: + for inode in range(5): + output_node = orm.Data().store() + output_node.add_incoming(parent, link_type=LinkType.CREATE, link_label='link_{}'.format(inode)) + for projection in ('id', '*'): + qb = orm.QueryBuilder() + qb.append(orm.CalculationNode, filters={'id': parent.id}, tag='parent', project=projection) + qb.append(orm.Data, with_incoming='parent') + self.assertEqual(len(qb.all()), qb.count()) + + +class TestManager(AiidaTestCase): + + def test_statistics(self): + """ + Test if the statistics query works properly. + + I try to implement it in a way that does not depend on the past state. + """ + from collections import defaultdict + + # pylint: disable=protected-access + + def store_and_add(n, statistics): + n.store() + statistics['total'] += 1 + statistics['types'][n._plugin_type_string] += 1 # pylint: disable=no-member + statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1 + + qmanager = self.backend.query_manager + current_db_statistics = qmanager.get_creation_statistics() + types = defaultdict(int) + types.update(current_db_statistics['types']) + ctime_by_day = defaultdict(int) + ctime_by_day.update(current_db_statistics['ctime_by_day']) + + expected_db_statistics = {'total': current_db_statistics['total'], 'types': types, 'ctime_by_day': ctime_by_day} + + store_and_add(orm.Data(), expected_db_statistics) + store_and_add(orm.Dict(), expected_db_statistics) + store_and_add(orm.Dict(), expected_db_statistics) + store_and_add(orm.CalculationNode(), expected_db_statistics) + + new_db_statistics = qmanager.get_creation_statistics() + # I only check a few fields + new_db_statistics = {k: v for k, v in new_db_statistics.items() if k in expected_db_statistics} + + expected_db_statistics = { + k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() + } + + self.assertEqual(new_db_statistics, expected_db_statistics) + + def test_statistics_default_class(self): + """ + Test if the statistics query works properly. + + I try to implement it in a way that does not depend on the past state. + """ + from collections import defaultdict + + def store_and_add(n, statistics): + n.store() + statistics['total'] += 1 + statistics['types'][n._plugin_type_string] += 1 # pylint: disable=no-member,protected-access + statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1 + + current_db_statistics = self.backend.query_manager.get_creation_statistics() + types = defaultdict(int) + types.update(current_db_statistics['types']) + ctime_by_day = defaultdict(int) + ctime_by_day.update(current_db_statistics['ctime_by_day']) + + expected_db_statistics = {'total': current_db_statistics['total'], 'types': types, 'ctime_by_day': ctime_by_day} + + store_and_add(orm.Data(), expected_db_statistics) + store_and_add(orm.Dict(), expected_db_statistics) + store_and_add(orm.Dict(), expected_db_statistics) + store_and_add(orm.CalculationNode(), expected_db_statistics) + + new_db_statistics = self.backend.query_manager.get_creation_statistics() + # I only check a few fields + new_db_statistics = {k: v for k, v in new_db_statistics.items() if k in expected_db_statistics} + + expected_db_statistics = { + k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() + } + + self.assertEqual(new_db_statistics, expected_db_statistics) + + +class TestDoubleStar(AiidaTestCase): + """ + In this test class we check if QueryBuilder returns the correct results + when double star is provided as projection. + """ + + def test_statistics_default_class(self): + + # The expected result + # pylint: disable=no-member + expected_dict = { + u'description': self.computer.description, + u'scheduler_type': self.computer.get_scheduler_type(), + u'hostname': self.computer.hostname, + u'uuid': self.computer.uuid, + u'name': self.computer.name, + u'transport_type': self.computer.get_transport_type(), + u'id': self.computer.id, + u'metadata': self.computer.get_metadata(), + } + + qb = orm.QueryBuilder() + qb.append(orm.Computer, project=['**']) + # We expect one result + self.assertEqual(qb.count(), 1) + + # Get the one result record and check that the returned + # data are correct + res = list(qb.dict()[0].values())[0] + self.assertDictEqual(res, expected_dict) + + # Ask the same query as above using queryhelp + qh = {'project': {'computer': ['**']}, 'path': [{'tag': 'computer', 'cls': orm.Computer}]} + qb = orm.QueryBuilder(**qh) + # We expect one result + self.assertEqual(qb.count(), 1) + + # Get the one result record and check that the returned + # data are correct + res = list(qb.dict()[0].values())[0] + self.assertDictEqual(res, expected_dict) diff --git a/aiida/backends/tests/test_query.py b/aiida/backends/tests/test_query.py deleted file mode 100644 index 38f4666e50..0000000000 --- a/aiida/backends/tests/test_query.py +++ /dev/null @@ -1,1361 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# pylint: disable=invalid-name,missing-docstring,too-many-lines -"""Tests for the QueryBuilder.""" -from __future__ import division -from __future__ import absolute_import -from __future__ import print_function - -import warnings - -from six.moves import range, zip - -from aiida import orm -from aiida.backends.testbase import AiidaTestCase -from aiida.common.links import LinkType -from aiida.manage import configuration - - -class TestQueryBuilder(AiidaTestCase): - - def setUp(self): - super(TestQueryBuilder, self).setUp() - self.clean_db() - self.insert_data() - - def test_ormclass_type_classification(self): - """ - This tests the classifications of the QueryBuilder - """ - # pylint: disable=protected-access - from aiida.common.exceptions import DbContentError - - qb = orm.QueryBuilder() - - # Asserting that improper declarations of the class type raise an error - with self.assertRaises(DbContentError): - qb._get_ormclass(None, 'data') - with self.assertRaises(DbContentError): - qb._get_ormclass(None, 'data.Data') - with self.assertRaises(DbContentError): - qb._get_ormclass(None, '.') - - # Asserting that the query type string and plugin type string are returned: - for _cls, classifiers in ( - qb._get_ormclass(orm.StructureData, None), - qb._get_ormclass(None, 'data.structure.StructureData.'), - ): - self.assertEqual(classifiers['ormclass_type_string'], orm.StructureData._plugin_type_string) # pylint: disable=no-member - - for _cls, classifiers in ( - qb._get_ormclass(orm.Group, None), - qb._get_ormclass(None, 'group'), - qb._get_ormclass(None, 'Group'), - ): - self.assertEqual(classifiers['ormclass_type_string'], 'group') - - for _cls, classifiers in ( - qb._get_ormclass(orm.User, None), - qb._get_ormclass(None, 'user'), - qb._get_ormclass(None, 'User'), - ): - self.assertEqual(classifiers['ormclass_type_string'], 'user') - - for _cls, classifiers in ( - qb._get_ormclass(orm.Computer, None), - qb._get_ormclass(None, 'computer'), - qb._get_ormclass(None, 'Computer'), - ): - self.assertEqual(classifiers['ormclass_type_string'], 'computer') - - for _cls, classifiers in ( - qb._get_ormclass(orm.Data, None), - qb._get_ormclass(None, 'data.Data.'), - ): - self.assertEqual(classifiers['ormclass_type_string'], orm.Data._plugin_type_string) # pylint: disable=no-member - - def test_process_type_classification(self): - """ - This tests the classifications of the QueryBuilder - """ - from aiida.engine import WorkChain - from aiida.plugins import CalculationFactory - - ArithmeticAdd = CalculationFactory('arithmetic.add') - - qb = orm.QueryBuilder() - - # pylint: disable=protected-access - - # When passing a WorkChain class, it should return the type of the corresponding Node - # including the appropriate filter on the process_type - _cls, classifiers = qb._get_ormclass(WorkChain, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') - self.assertEqual(classifiers['process_type_string'], 'aiida.engine.processes.workchains.workchain.WorkChain') - - # When passing a WorkChainNode, no process_type filter is applied - _cls, classifiers = qb._get_ormclass(orm.WorkChainNode, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.workflow.workchain.WorkChainNode.') - self.assertEqual(classifiers['process_type_string'], None) - - # Same tests for a calculation - _cls, classifiers = qb._get_ormclass(ArithmeticAdd, None) - self.assertEqual(classifiers['ormclass_type_string'], 'process.calculation.calcjob.CalcJobNode.') - self.assertEqual(classifiers['process_type_string'], 'aiida.calculations:arithmetic.add') - - def test_process_query(self): - """ - Test querying for a process class. - """ - from aiida.engine import run, WorkChain, if_, return_, ExitCode - from aiida.common.warnings import AiidaEntryPointWarning - - class PotentialFailureWorkChain(WorkChain): - EXIT_STATUS = 1 - EXIT_MESSAGE = 'Well you did ask for it' - OUTPUT_LABEL = 'optional_output' - OUTPUT_VALUE = 144 - - @classmethod - def define(cls, spec): - super(PotentialFailureWorkChain, cls).define(spec) - spec.input('success', valid_type=orm.Bool) - spec.input('through_return', valid_type=orm.Bool, default=orm.Bool(False)) - spec.input('through_exit_code', valid_type=orm.Bool, default=orm.Bool(False)) - spec.exit_code(cls.EXIT_STATUS, 'EXIT_STATUS', cls.EXIT_MESSAGE) - spec.outline(if_(cls.should_return_out_of_outline)(return_(cls.EXIT_STATUS)), cls.failure, cls.success) - spec.output(cls.OUTPUT_LABEL, required=False) - - def should_return_out_of_outline(self): - return self.inputs.through_return.value - - def failure(self): - # pylint: disable=no-else-return - - if self.inputs.success.value is False: - # Returning either 0 or ExitCode with non-zero status should terminate the workchain - if self.inputs.through_exit_code.value is False: - return self.EXIT_STATUS - else: - return self.exit_codes.EXIT_STATUS # pylint: disable=no-member - else: - # Returning 0 or ExitCode with zero status should *not* terminate the workchain - if self.inputs.through_exit_code.value is False: - return 0 - else: - return ExitCode() - - def success(self): - self.out(self.OUTPUT_LABEL, orm.Int(self.OUTPUT_VALUE).store()) - - class DummyWorkChain(WorkChain): - pass - - # Run a simple test WorkChain - _result = run(PotentialFailureWorkChain, success=orm.Bool(True)) - - # Query for nodes associated with this type of WorkChain - qb = orm.QueryBuilder() - - with warnings.catch_warnings(record=True) as w: # pylint: disable=no-member - # Cause all warnings to always be triggered. - warnings.simplefilter('always') # pylint: disable=no-member - - qb.append(PotentialFailureWorkChain) - - # Verify some things - assert len(w) == 1 - assert issubclass(w[-1].category, AiidaEntryPointWarning) - - # There should be one result of type WorkChainNode - self.assertEqual(qb.count(), 1) - self.assertTrue(isinstance(qb.all()[0][0], orm.WorkChainNode)) - - # Query for nodes of a different type of WorkChain - qb = orm.QueryBuilder() - - with warnings.catch_warnings(record=True) as w: # pylint: disable=no-member - # Cause all warnings to always be triggered. - warnings.simplefilter('always') # pylint: disable=no-member - - qb.append(DummyWorkChain) - - # Verify some things - assert len(w) == 1 - assert issubclass(w[-1].category, AiidaEntryPointWarning) - - # There should be no result - self.assertEqual(qb.count(), 0) - - # Query for all WorkChain nodes - qb = orm.QueryBuilder() - qb.append(WorkChain) - - # There should be one result - self.assertEqual(qb.count(), 1) - - def test_simple_query_1(self): - """ - Testing a simple query - """ - # pylint: disable=too-many-statements - - n1 = orm.Data() - n1.label = 'node1' - n1.set_attribute('foo', ['hello', 'goodbye']) - n1.store() - - n2 = orm.CalculationNode() - n2.label = 'node2' - n2.set_attribute('foo', 1) - - n3 = orm.Data() - n3.label = 'node3' - n3.set_attribute('foo', 1.0000) # Stored as fval - n3.store() - - n4 = orm.CalculationNode() - n4.label = 'node4' - n4.set_attribute('foo', 'bar') - - n5 = orm.Data() - n5.label = 'node5' - n5.set_attribute('foo', None) - n5.store() - - n2.add_incoming(n1, link_type=LinkType.INPUT_CALC, link_label='link1') - n2.store() - n3.add_incoming(n2, link_type=LinkType.CREATE, link_label='link2') - - n4.add_incoming(n3, link_type=LinkType.INPUT_CALC, link_label='link3') - n4.store() - n5.add_incoming(n4, link_type=LinkType.CREATE, link_label='link4') - - qb1 = orm.QueryBuilder() - qb1.append(orm.Node, filters={'attributes.foo': 1.000}) - - self.assertEqual(len(qb1.all()), 2) - - qb2 = orm.QueryBuilder() - qb2.append(orm.Data) - self.assertEqual(qb2.count(), 3) - - qb2 = orm.QueryBuilder() - qb2.append(entity_type='data.Data.') - self.assertEqual(qb2.count(), 3) - - qb3 = orm.QueryBuilder() - qb3.append(orm.Node, project='label', tag='node1') - qb3.append(orm.Node, project='label', tag='node2') - self.assertEqual(qb3.count(), 4) - - qb4 = orm.QueryBuilder() - qb4.append(orm.CalculationNode, tag='node1') - qb4.append(orm.Data, tag='node2') - self.assertEqual(qb4.count(), 2) - - qb5 = orm.QueryBuilder() - qb5.append(orm.Data, tag='node1') - qb5.append(orm.CalculationNode, tag='node2') - self.assertEqual(qb5.count(), 2) - - qb6 = orm.QueryBuilder() - qb6.append(orm.Data, tag='node1') - qb6.append(orm.Data, tag='node2') - self.assertEqual(qb6.count(), 0) - - def test_simple_query_2(self): - from datetime import datetime - from aiida.common.exceptions import MultipleObjectsError, NotExistent - n0 = orm.Data() - n0.label = 'hello' - n0.description = '' - n0.set_attribute('foo', 'bar') - - n1 = orm.CalculationNode() - n1.label = 'foo' - n1.description = 'I am FoO' - - n2 = orm.Data() - n2.label = 'bar' - n2.description = 'I am BaR' - - n2.add_incoming(n1, link_type=LinkType.CREATE, link_label='random_2') - n1.add_incoming(n0, link_type=LinkType.INPUT_CALC, link_label='random_1') - - for n in (n0, n1, n2): - n.store() - - qb1 = orm.QueryBuilder() - qb1.append(orm.Node, filters={'label': 'hello'}) - self.assertEqual(len(list(qb1.all())), 1) - - qh = { - 'path': [{ - 'cls': orm.Node, - 'tag': 'n1' - }, { - 'cls': orm.Node, - 'tag': 'n2', - 'with_incoming': 'n1' - }], - 'filters': { - 'n1': { - 'label': { - 'ilike': '%foO%' - }, - }, - 'n2': { - 'label': { - 'ilike': 'bar%' - }, - } - }, - 'project': { - 'n1': ['id', 'uuid', 'ctime', 'label'], - 'n2': ['id', 'description', 'label'], - } - } - - qb2 = orm.QueryBuilder(**qh) - - resdict = qb2.dict() - self.assertEqual(len(resdict), 1) - self.assertTrue(isinstance(resdict[0]['n1']['ctime'], datetime)) - - res_one = qb2.one() - self.assertTrue('bar' in res_one) - - qh = { - 'path': [{ - 'cls': orm.Node, - 'tag': 'n1' - }, { - 'cls': orm.Node, - 'tag': 'n2', - 'with_incoming': 'n1' - }], - 'filters': { - 'n1--n2': { - 'label': { - 'like': '%_2' - } - } - } - } - qb = orm.QueryBuilder(**qh) - self.assertEqual(qb.count(), 1) - - # Test the hashing: - query1 = qb.get_query() - qb.add_filter('n2', {'label': 'nonexistentlabel'}) - self.assertEqual(qb.count(), 0) - - with self.assertRaises(NotExistent): - qb.one() - with self.assertRaises(MultipleObjectsError): - orm.QueryBuilder().append(orm.Node).one() - - query2 = qb.get_query() - query3 = qb.get_query() - - self.assertTrue(id(query1) != id(query2)) - self.assertTrue(id(query2) == id(query3)) - - def test_operators_eq_lt_gt(self): - nodes = [orm.Data() for _ in range(8)] - - nodes[0].set_attribute('fa', 1) - nodes[1].set_attribute('fa', 1.0) - nodes[2].set_attribute('fa', 1.01) - nodes[3].set_attribute('fa', 1.02) - nodes[4].set_attribute('fa', 1.03) - nodes[5].set_attribute('fa', 1.04) - nodes[6].set_attribute('fa', 1.05) - nodes[7].set_attribute('fa', 1.06) - - for n in nodes: - n.store() - - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1}}).count(), 0) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'==': 1}}).count(), 2) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<': 1.02}}).count(), 3) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'<=': 1.02}}).count(), 4) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>': 1.02}}).count(), 4) - self.assertEqual(orm.QueryBuilder().append(orm.Node, filters={'attributes.fa': {'>=': 1.02}}).count(), 5) - - def test_subclassing(self): - s = orm.StructureData() - s.set_attribute('cat', 'miau') - s.store() - - d = orm.Data() - d.set_attribute('cat', 'miau') - d.store() - - p = orm.Dict(dict=dict(cat='miau')) - p.store() - - # Now when asking for a node with attr.cat==miau, I want 3 esults: - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 3) - - qb = orm.QueryBuilder().append(orm.Data, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 3) - - # If I'm asking for the specific lowest subclass, I want one result - for cls in (orm.StructureData, orm.Dict): - qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 1) - - # Now I am not allow the subclassing, which should give 1 result for each - for cls, count in ((orm.StructureData, 1), (orm.Dict, 1), (orm.Data, 1), (orm.Node, 0)): - qb = orm.QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}, subclassing=False) - self.assertEqual(qb.count(), count) - - # Now I am testing the subclassing with tuples: - qb = orm.QueryBuilder().append(cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}) - self.assertEqual(qb.count(), 2) - qb = orm.QueryBuilder().append( - entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), filters={'attributes.cat': 'miau'} - ) - self.assertEqual(qb.count(), 2) - qb = orm.QueryBuilder().append( - cls=(orm.StructureData, orm.Dict), filters={'attributes.cat': 'miau'}, subclassing=False - ) - self.assertEqual(qb.count(), 2) - qb = orm.QueryBuilder().append( - cls=(orm.StructureData, orm.Data), - filters={'attributes.cat': 'miau'}, - ) - self.assertEqual(qb.count(), 3) - qb = orm.QueryBuilder().append( - entity_type=('data.structure.StructureData.', 'data.dict.Dict.'), - filters={'attributes.cat': 'miau'}, - subclassing=False - ) - self.assertEqual(qb.count(), 2) - qb = orm.QueryBuilder().append( - entity_type=('data.structure.StructureData.', 'data.Data.'), - filters={'attributes.cat': 'miau'}, - subclassing=False - ) - self.assertEqual(qb.count(), 2) - - def test_list_behavior(self): - for _i in range(4): - orm.Data().store() - - self.assertEqual(len(orm.QueryBuilder().append(orm.Node).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).all()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node).dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project='*').dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).dict()), 4) - self.assertEqual(len(orm.QueryBuilder().append(orm.Node, project=['id']).dict()), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterall())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node).iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project='*').iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['*', 'id']).iterdict())), 4) - self.assertEqual(len(list(orm.QueryBuilder().append(orm.Node, project=['id']).iterdict())), 4) - - def test_append_validation(self): - from aiida.common.exceptions import InputValidationError - - # So here I am giving two times the same tag - with self.assertRaises(InputValidationError): - orm.QueryBuilder().append(orm.StructureData, tag='n').append(orm.StructureData, tag='n') - # here I am giving a wrong filter specifications - with self.assertRaises(InputValidationError): - orm.QueryBuilder().append(orm.StructureData, filters=['jajjsd']) - # here I am giving a nonsensical projection: - with self.assertRaises(InputValidationError): - orm.QueryBuilder().append(orm.StructureData, project=True) - - # here I am giving a nonsensical projection for the edge: - with self.assertRaises(InputValidationError): - orm.QueryBuilder().append(orm.ProcessNode).append(orm.StructureData, edge_tag='t').add_projection('t', True) - # Giving a nonsensical limit - with self.assertRaises(InputValidationError): - orm.QueryBuilder().append(orm.ProcessNode).limit(2.3) - # Giving a nonsensical offset - with self.assertRaises(InputValidationError): - orm.QueryBuilder(offset=2.3) - - # So, I mess up one append, I want the QueryBuilder to clean it! - with self.assertRaises(InputValidationError): - qb = orm.QueryBuilder() - # This also checks if we correctly raise for wrong keywords - qb.append(orm.StructureData, tag='s', randomkeyword={}) - - # Now I'm checking whether this keyword appears anywhere in the internal dictionaries: - # pylint: disable=protected-access - self.assertTrue('s' not in qb._projections) - self.assertTrue('s' not in qb._filters) - self.assertTrue('s' not in qb.tag_to_alias_map) - self.assertTrue(len(qb._path) == 0) - self.assertTrue(orm.StructureData not in qb._cls_to_tag_map) - # So this should work now: - qb.append(orm.StructureData, tag='s').limit(2).dict() - - def test_tags(self): - qb = orm.QueryBuilder() - qb.append(orm.Node, tag='n1') - qb.append(orm.Node, tag='n2', edge_tag='e1', with_incoming='n1') - qb.append(orm.Node, tag='n3', edge_tag='e2', with_incoming='n2') - qb.append(orm.Computer, with_node='n3', tag='c1', edge_tag='nonsense') - self.assertEqual(qb.get_used_tags(), ['n1', 'n2', 'e1', 'n3', 'e2', 'c1', 'nonsense']) - - # Now I am testing the default tags, - qb = orm.QueryBuilder().append(orm.StructureData - ).append(orm.ProcessNode).append(orm.StructureData - ).append(orm.Dict, with_outgoing=orm.ProcessNode) - self.assertEqual( - qb.get_used_tags(), [ - 'StructureData_1', 'ProcessNode_1', 'StructureData_1--ProcessNode_1', 'StructureData_2', - 'ProcessNode_1--StructureData_2', 'Dict_1', 'ProcessNode_1--Dict_1' - ] - ) - self.assertEqual( - qb.get_used_tags(edges=False), [ - 'StructureData_1', - 'ProcessNode_1', - 'StructureData_2', - 'Dict_1', - ] - ) - self.assertEqual( - qb.get_used_tags(vertices=False), - ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] - ) - self.assertEqual( - qb.get_used_tags(edges=False), [ - 'StructureData_1', - 'ProcessNode_1', - 'StructureData_2', - 'Dict_1', - ] - ) - self.assertEqual( - qb.get_used_tags(vertices=False), - ['StructureData_1--ProcessNode_1', 'ProcessNode_1--StructureData_2', 'ProcessNode_1--Dict_1'] - ) - - def test_direction_keyword(self): - """ - The direction keyword is a special case with the QueryBuilder append - method, so some tests are good. - """ - d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)] - c1, c2 = [orm.CalculationNode() for _ in range(2)] - c1.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link_d1c1') - c1.store() - d2.add_incoming(c1, link_type=LinkType.CREATE, link_label='link_c1d2') - d4.add_incoming(c1, link_type=LinkType.CREATE, link_label='link_c1d4') - c2.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link_d2c2') - c2.store() - d3.add_incoming(c2, link_type=LinkType.CREATE, link_label='link_c2d3') - - # testing direction=1 for d1, which should return the outgoing - qb = orm.QueryBuilder() - qb.append(orm.Data, filters={'id': d1.id}) - qb.append(orm.CalculationNode, direction=1, project='id') - res1 = {_ for _, in qb.all()} - - qb = orm.QueryBuilder() - qb.append(orm.Data, filters={'id': d1.id}, tag='data') - qb.append(orm.CalculationNode, with_incoming='data', project='id') - res2 = {_ for _, in qb.all()} - - self.assertEqual(res1, res2) - self.assertEqual(res1, {c1.id}) - - # testing direction=-1, which should return the incoming - qb = orm.QueryBuilder() - qb.append(orm.Data, filters={'id': d2.id}) - qb.append(orm.CalculationNode, direction=-1, project='id') - res1 = {_ for _, in qb.all()} - - qb = orm.QueryBuilder() - qb.append(orm.Data, filters={'id': d2.id}, tag='data') - qb.append(orm.CalculationNode, with_outgoing='data', project='id') - res2 = {_ for _, in qb.all()} - self.assertEqual(res1, res2) - self.assertEqual(res1, {c1.id}) - - # testing direction higher than 1 - qb = orm.QueryBuilder() - qb.append(orm.CalculationNode, tag='c1', filters={'id': c1.id}) - qb.append(orm.Data, with_incoming='c1', tag='d2or4') - qb.append(orm.CalculationNode, tag='c2', with_incoming='d2or4') - qb.append(orm.Data, tag='d3', with_incoming='c2', project='id') - qh = qb.get_json_compatible_queryhelp() # saving query for later - qb.append(orm.Data, direction=-4, project='id') - res1 = {item[1] for item in qb.all()} - self.assertEqual(res1, {d1.id}) - - qb = orm.QueryBuilder(**qh) - qb.append(orm.Data, direction=4, project='id') - res2 = {item[1] for item in qb.all()} - self.assertEqual(res2, {d2.id, d4.id}) - - -class TestQueryHelp(AiidaTestCase): - - def test_queryhelp(self): - """ - Here I test the queryhelp by seeing whether results are the same as using the append method. - I also check passing of tuples. - """ - g = orm.Group(label='helloworld').store() - for cls in (orm.StructureData, orm.Dict, orm.Data): - obj = cls() - obj.set_attribute('foo-qh2', 'bar') - obj.store() - g.add_nodes(obj) - - for cls, expected_count, subclassing in ( - (orm.StructureData, 1, True), - (orm.Dict, 1, True), - (orm.Data, 3, True), - (orm.Data, 1, False), - ((orm.Dict, orm.StructureData), 2, True), - ((orm.Dict, orm.StructureData), 2, False), - ((orm.Dict, orm.Data), 2, False), - ((orm.Dict, orm.Data), 3, True), - ((orm.Dict, orm.Data, orm.StructureData), 3, False), - ): - qb = orm.QueryBuilder() - qb.append(cls, filters={'attributes.foo-qh2': 'bar'}, subclassing=subclassing, project='uuid') - self.assertEqual(qb.count(), expected_count) - - qh = qb.get_json_compatible_queryhelp() - qb_new = orm.QueryBuilder(**qh) - self.assertEqual(qb_new.count(), expected_count) - self.assertEqual(sorted([uuid for uuid, in qb.all()]), sorted([uuid for uuid, in qb_new.all()])) - - qb = orm.QueryBuilder().append(orm.Group, filters={'label': 'helloworld'}) - self.assertEqual(qb.count(), 1) - - qb = orm.QueryBuilder().append((orm.Group,), filters={'label': 'helloworld'}) - self.assertEqual(qb.count(), 1) - - qb = orm.QueryBuilder().append(orm.Computer,) - self.assertEqual(qb.count(), 1) - - qb = orm.QueryBuilder().append(cls=(orm.Computer,)) - self.assertEqual(qb.count(), 1) - - -class TestQueryBuilderCornerCases(AiidaTestCase): - """ - In this class corner cases of QueryBuilder are added. - """ - - def test_computer_json(self): # pylint: disable=no-self-use - """ - In this test we check the correct behavior of QueryBuilder when - retrieving the _metadata with no content. - Note that they are in JSON format in both backends. Forcing the - decoding of a None value leads to an exception (this was the case - under Django). - """ - n1 = orm.CalculationNode() - n1.label = 'node2' - n1.set_attribute('foo', 1) - n1.store() - - # Checking the correct retrieval of _metadata which is - # a JSON field (in both backends). - qb = orm.QueryBuilder() - qb.append(orm.CalculationNode, project=['id'], tag='calc') - qb.append(orm.Computer, project=['id', 'metadata'], outerjoin=True, with_node='calc') - qb.all() - - -class TestAttributes(AiidaTestCase): - - def test_attribute_existence(self): - # I'm storing a value under key whatever: - val = 1. - res_uuids = set() - n1 = orm.Data() - n1.set_attribute('whatever', 3.) - n1.set_attribute('test_case', 'test_attribute_existence') - n1.store() - - # I want all the nodes where whatever is smaller than 1. or there is no such value: - - qb = orm.QueryBuilder() - qb.append( - orm.Data, - filters={ - 'or': [{ - 'attributes': { - '!has_key': 'whatever' - } - }, { - 'attributes.whatever': { - '<': val - } - }], - }, - project='uuid' - ) - res_query = {str(_[0]) for _ in qb.all()} - self.assertEqual(res_query, res_uuids) - - def test_attribute_type(self): - key = 'value_test_attr_type' - n_int, n_float, n_str, n_str2, n_bool, n_arr = [orm.Data() for _ in range(6)] - n_int.set_attribute(key, 1) - n_float.set_attribute(key, 1.0) - n_bool.set_attribute(key, True) - n_str.set_attribute(key, '1') - n_str2.set_attribute(key, 'one') - n_arr.set_attribute(key, [4, 3, 5]) - - for n in (n_str2, n_str, n_int, n_float, n_bool, n_arr): - n.store() - - # Here I am testing which values contain a number 1. - # Both 1 and 1.0 are legitimate values if ask for either 1 or 1.0 - for val in (1.0, 1): - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): val}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'>': 0.5}}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'<': 1.5}}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_float.uuid, n_int.uuid))) - # Now I am testing the boolean value: - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): True}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_bool.uuid,))) - - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'like': '%n%'}}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str2.uuid,))) - qb = orm.QueryBuilder().append( - orm.Node, filters={'attributes.{}'.format(key): { - 'ilike': 'On%' - }}, project='uuid' - ) - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str2.uuid,))) - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'like': '1'}}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str.uuid,))) - qb = orm.QueryBuilder().append(orm.Node, filters={'attributes.{}'.format(key): {'==': '1'}}, project='uuid') - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_str.uuid,))) - if configuration.PROFILE.database_backend == u'sqlalchemy': - # I can't query the length of an array with Django, - # so I exclude. Not the nicest way, But I would like to keep this piece - # of code because of the initialization part, that would need to be - # duplicated or wrapped otherwise. - qb = orm.QueryBuilder().append( - orm.Node, filters={'attributes.{}'.format(key): { - 'of_length': 3 - }}, project='uuid' - ) - res = [str(_) for _, in qb.all()] - self.assertEqual(set(res), set((n_arr.uuid,))) - - -class QueryBuilderLimitOffsetsTest(AiidaTestCase): - - def test_ordering_limits_offsets_of_results_general(self): - # Creating 10 nodes with an attribute that can be ordered - for i in range(10): - n = orm.Data() - n.set_attribute('foo', i) - n.store() - - qb = orm.QueryBuilder().append(orm.Node, project='attributes.foo').order_by({orm.Node: 'ctime'}) - - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(10))) - - # Now applying an offset: - qb.offset(5) - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 10))) - - # Now also applying a limit: - qb.limit(3) - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 8))) - - # Specifying the order explicitly the order: - qb = orm.QueryBuilder().append(orm.Node, - project='attributes.foo').order_by({orm.Node: { - 'ctime': { - 'order': 'asc' - } - }}) - - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(10))) - - # Now applying an offset: - qb.offset(5) - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 10))) - - # Now also applying a limit: - qb.limit(3) - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(5, 8))) - - # Reversing the order: - qb = orm.QueryBuilder().append(orm.Node, - project='attributes.foo').order_by({orm.Node: { - 'ctime': { - 'order': 'desc' - } - }}) - - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(9, -1, -1))) - - # Now applying an offset: - qb.offset(5) - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(4, -1, -1))) - - # Now also applying a limit: - qb.limit(3) - res = next(zip(*qb.all())) - self.assertEqual(res, tuple(range(4, 1, -1))) - - -class QueryBuilderJoinsTests(AiidaTestCase): - - def test_joins1(self): - # Creating n1, who will be a parent: - parent = orm.Data() - parent.label = 'mother' - parent.store() - - good_child = orm.CalculationNode() - good_child.label = 'good_child' - good_child.set_attribute('is_good', True) - - bad_child = orm.CalculationNode() - bad_child.label = 'bad_child' - bad_child.set_attribute('is_good', False) - - unrelated = orm.CalculationNode() - unrelated.label = 'unrelated' - unrelated.store() - - good_child.add_incoming(parent, link_type=LinkType.INPUT_CALC, link_label='parent') - bad_child.add_incoming(parent, link_type=LinkType.INPUT_CALC, link_label='parent') - good_child.store() - bad_child.store() - - # Using a standard inner join - qb = orm.QueryBuilder() - qb.append(orm.Node, tag='parent') - qb.append(orm.Node, tag='children', project='label', filters={'attributes.is_good': True}) - self.assertEqual(qb.count(), 1) - - qb = orm.QueryBuilder() - qb.append(orm.Node, tag='parent') - qb.append(orm.Node, tag='children', outerjoin=True, project='label', filters={'attributes.is_good': True}) - self.assertEqual(qb.count(), 1) - - def test_joins2(self): - # Creating n1, who will be a parent: - - students = [orm.Data() for i in range(10)] - advisors = [orm.CalculationNode() for i in range(3)] - for i, a in enumerate(advisors): - a.label = 'advisor {}'.format(i) - a.set_attribute('advisor_id', i) - - for n in advisors + students: - n.store() - - # advisor 0 get student 0, 1 - for i in (0, 1): - students[i].add_incoming(advisors[0], link_type=LinkType.CREATE, link_label='is_advisor_{}'.format(i)) - - # advisor 1 get student 3, 4 - for i in (3, 4): - students[i].add_incoming(advisors[1], link_type=LinkType.CREATE, link_label='is_advisor_{}'.format(i)) - - # advisor 2 get student 5, 6, 7 - for i in (5, 6, 7): - students[i].add_incoming(advisors[2], link_type=LinkType.CREATE, link_label='is_advisor_{}'.format(i)) - - # let's add a differnt relationship than advisor: - students[9].add_incoming(advisors[2], link_type=LinkType.CREATE, link_label='lover') - - self.assertEqual( - orm.QueryBuilder().append( - orm.Node - ).append(orm.Node, edge_filters={ - 'label': { - 'like': 'is\\_advisor\\_%' - } - }, tag='student').count(), 7 - ) - - for adv_id, number_students in zip(list(range(3)), (2, 2, 3)): - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'attributes.advisor_id': adv_id - }).append(orm.Node, edge_filters={ - 'label': { - 'like': 'is\\_advisor\\_%' - } - }, tag='student').count(), number_students - ) - - def test_joins3_user_group(self): - # Create another user - new_email = 'newuser@new.n' - user = orm.User(email=new_email).store() - - # Create a group that belongs to that user - group = orm.Group(label='node_group') - group.user = user - group.store() - - # Search for the group of the user - qb = orm.QueryBuilder() - qb.append(orm.User, tag='user', filters={'id': {'==': user.id}}) - qb.append(orm.Group, with_user='user', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 1, 'The expected group that belongs to ' 'the selected user was not found.') - - # Search for the user that owns a group - qb = orm.QueryBuilder() - qb.append(orm.Group, tag='group', filters={'id': {'==': group.id}}) - qb.append(orm.User, with_group='group', filters={'id': {'==': user.id}}) - - self.assertEqual(qb.count(), 1, 'The expected user that owns the ' 'selected group was not found.') - - def test_joins_group_node(self): - """ - This test checks that the querying for the nodes that belong to a group works correctly (using QueryBuilder). - This is important for the Django backend with the use of aldjemy for the Django to SQLA schema translation. - Since this is not backend specific test (even if it is mainly used to test the querying of Django backend - with QueryBuilder), we keep it at the general tests (ran by both backends). - """ - new_email = 'newuser@new.n2' - user = orm.User(email=new_email).store() - - # Create a group that belongs to that user - group = orm.Group(label='node_group_2') - group.user = user - group.store() - - # Create nodes and add them to the created group - n1 = orm.Data() - n1.label = 'node1' - n1.set_attribute('foo', ['hello', 'goodbye']) - n1.store() - - n2 = orm.CalculationNode() - n2.label = 'node2' - n2.set_attribute('foo', 1) - n2.store() - - n3 = orm.Data() - n3.label = 'node3' - n3.set_attribute('foo', 1.0000) # Stored as fval - n3.store() - - n4 = orm.CalculationNode() - n4.label = 'node4' - n4.set_attribute('foo', 'bar') - n4.store() - - group.add_nodes([n1, n2, n3, n4]) - - # Check that the nodes are in the group - qb = orm.QueryBuilder() - qb.append(orm.Node, tag='node', project=['id']) - qb.append(orm.Group, with_node='node', filters={'id': {'==': group.id}}) - self.assertEqual(qb.count(), 4, 'There should be 4 nodes in the group') - id_res = [_ for [_] in qb.all()] - for curr_id in [n1.id, n2.id, n3.id, n4.id]: - self.assertIn(curr_id, id_res) - - -class QueryBuilderPath(AiidaTestCase): - - def test_query_path(self): - # pylint: disable=too-many-statements - - q = self.backend.query_manager - n1 = orm.Data() - n1.label = 'n1' - n2 = orm.CalculationNode() - n2.label = 'n2' - n3 = orm.Data() - n3.label = 'n3' - n4 = orm.Data() - n4.label = 'n4' - n5 = orm.CalculationNode() - n5.label = 'n5' - n6 = orm.Data() - n6.label = 'n6' - n7 = orm.CalculationNode() - n7.label = 'n7' - n8 = orm.Data() - n8.label = 'n8' - n9 = orm.Data() - n9.label = 'n9' - - # I create a strange graph, inserting links in a order - # such that I often have to create the transitive closure - # between two graphs - n3.add_incoming(n2, link_type=LinkType.CREATE, link_label='link1') - n2.add_incoming(n1, link_type=LinkType.INPUT_CALC, link_label='link2') - n5.add_incoming(n3, link_type=LinkType.INPUT_CALC, link_label='link3') - n5.add_incoming(n4, link_type=LinkType.INPUT_CALC, link_label='link4') - n4.add_incoming(n2, link_type=LinkType.CREATE, link_label='link5') - n7.add_incoming(n6, link_type=LinkType.INPUT_CALC, link_label='link6') - n8.add_incoming(n7, link_type=LinkType.CREATE, link_label='link7') - - for node in [n1, n2, n3, n4, n5, n6, n7, n8, n9]: - node.store() - - # There are no parents to n9, checking that - self.assertEqual(set([]), set(q.get_all_parents([n9.pk]))) - # There is one parent to n6 - self.assertEqual({(_,) for _ in (n6.pk,)}, {tuple(_) for _ in q.get_all_parents([n7.pk])}) - # There are several parents to n4 - self.assertEqual({(_.pk,) for _ in (n1, n2)}, {tuple(_) for _ in q.get_all_parents([n4.pk])}) - # There are several parents to n5 - self.assertEqual({(_.pk,) for _ in (n1, n2, n3, n4)}, {tuple(_) for _ in q.get_all_parents([n5.pk])}) - - # Yet, no links from 1 to 8 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - 'id': n8.pk - }).count(), 0 - ) - - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - 'id': n1.pk - }).count(), 0 - ) - - n6.add_incoming(n5, link_type=LinkType.CREATE, link_label='link1') - # Yet, now 2 links from 1 to 8 - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - 'id': n8.pk - }).count(), 2 - ) - - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - 'id': n1.pk - }).count(), 2 - ) - - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': { - '<': 6 - } - }, - ).count(), 2 - ) - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': 5 - }, - ).count(), 2 - ) - self.assertEqual( - orm.QueryBuilder().append(orm.Node, filters={ - 'id': n8.pk - }, tag='desc').append( - orm.Node, - with_descendants='desc', - filters={ - 'id': n1.pk - }, - edge_filters={ - 'depth': { - '<': 5 - } - }, - ).count(), 0 - ) - - # TODO write a query that can filter certain paths by traversed ID # pylint: disable=fixme - qb = orm.QueryBuilder().append( - orm.Node, - filters={ - 'id': n8.pk - }, - tag='desc', - ).append(orm.Node, with_descendants='desc', edge_project='path', filters={'id': n1.pk}) - queried_path_set = {frozenset(p) for p, in qb.all()} - - paths_there_should_be = { - frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]), - frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) - } - - self.assertTrue(queried_path_set == paths_there_should_be) - - qb = orm.QueryBuilder().append(orm.Node, filters={ - 'id': n1.pk - }, tag='anc').append(orm.Node, with_ancestors='anc', filters={'id': n8.pk}, edge_project='path') - - self.assertEqual({frozenset(p) for p, in qb.all()}, { - frozenset([n1.pk, n2.pk, n3.pk, n5.pk, n6.pk, n7.pk, n8.pk]), - frozenset([n1.pk, n2.pk, n4.pk, n5.pk, n6.pk, n7.pk, n8.pk]) - }) - - # This part of the test is no longer possible as the nodes have already been stored and the previous parts of - # the test rely on this, which means however, that here, no more links can be added as that will raise. - - # n7.add_incoming(n9, link_type=LinkType.INPUT_CALC, link_label='link0') - # # Still two links... - - # self.assertEqual( - # orm.QueryBuilder().append(orm.Node, filters={ - # 'id': n1.pk - # }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - # 'id': n8.pk - # }).count(), 2) - - # self.assertEqual( - # orm.QueryBuilder().append(orm.Node, filters={ - # 'id': n8.pk - # }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - # 'id': n1.pk - # }).count(), 2) - # n9.add_incoming(n5, link_type=LinkType.CREATE, link_label='link6') - # # And now there should be 4 nodes - - # self.assertEqual( - # orm.QueryBuilder().append(orm.Node, filters={ - # 'id': n1.pk - # }, tag='anc').append(orm.Node, with_ancestors='anc', filters={ - # 'id': n8.pk - # }).count(), 4) - - # self.assertEqual( - # orm.QueryBuilder().append(orm.Node, filters={ - # 'id': n8.pk - # }, tag='desc').append(orm.Node, with_descendants='desc', filters={ - # 'id': n1.pk - # }).count(), 4) - - # qb = orm.QueryBuilder().append( - # orm.Node, filters={ - # 'id': n1.pk - # }, tag='anc').append( - # orm.Node, with_ancestors='anc', filters={'id': n8.pk}, edge_tag='edge') - # qb.add_projection('edge', 'depth') - # self.assertTrue(set(next(zip(*qb.all()))), set([5, 6])) - # qb.add_filter('edge', {'depth': 5}) - # self.assertTrue(set(next(zip(*qb.all()))), set([5])) - - -class TestConsistency(AiidaTestCase): - - def test_create_node_and_query(self): - """ - Testing whether creating nodes within a iterall iteration changes the results. - """ - for _i in range(100): - n = orm.Data() - n.store() - - for idx, _item in enumerate( - orm.QueryBuilder().append(orm.Node, project=['id', 'label']).iterall(batch_size=10) - ): - if idx % 10 == 10: - n = orm.Data() - n.store() - self.assertEqual(idx, 99) # pylint: disable=undefined-loop-variable - self.assertTrue(len(orm.QueryBuilder().append(orm.Node, project=['id', 'label']).all(batch_size=10)) > 99) - - def test_len_results(self): - """ - Test whether the len of results matches the count returned. - See also https://github.com/aiidateam/aiida-core/issues/1600 - SQLAlchemy has a deduplication strategy that leads to strange behavior, tested against here - """ - parent = orm.CalculationNode().store() - # adding 5 links going out: - for inode in range(5): - output_node = orm.Data().store() - output_node.add_incoming(parent, link_type=LinkType.CREATE, link_label='link_{}'.format(inode)) - for projection in ('id', '*'): - qb = orm.QueryBuilder() - qb.append(orm.CalculationNode, filters={'id': parent.id}, tag='parent', project=projection) - qb.append(orm.Data, with_incoming='parent') - self.assertEqual(len(qb.all()), qb.count()) - - -class TestManager(AiidaTestCase): - - def test_statistics(self): - """ - Test if the statistics query works properly. - - I try to implement it in a way that does not depend on the past state. - """ - from collections import defaultdict - - # pylint: disable=protected-access - - def store_and_add(n, statistics): - n.store() - statistics['total'] += 1 - statistics['types'][n._plugin_type_string] += 1 # pylint: disable=no-member - statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1 - - qmanager = self.backend.query_manager - current_db_statistics = qmanager.get_creation_statistics() - types = defaultdict(int) - types.update(current_db_statistics['types']) - ctime_by_day = defaultdict(int) - ctime_by_day.update(current_db_statistics['ctime_by_day']) - - expected_db_statistics = {'total': current_db_statistics['total'], 'types': types, 'ctime_by_day': ctime_by_day} - - store_and_add(orm.Data(), expected_db_statistics) - store_and_add(orm.Dict(), expected_db_statistics) - store_and_add(orm.Dict(), expected_db_statistics) - store_and_add(orm.CalculationNode(), expected_db_statistics) - - new_db_statistics = qmanager.get_creation_statistics() - # I only check a few fields - new_db_statistics = {k: v for k, v in new_db_statistics.items() if k in expected_db_statistics} - - expected_db_statistics = { - k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() - } - - self.assertEqual(new_db_statistics, expected_db_statistics) - - def test_statistics_default_class(self): - """ - Test if the statistics query works properly. - - I try to implement it in a way that does not depend on the past state. - """ - from collections import defaultdict - - def store_and_add(n, statistics): - n.store() - statistics['total'] += 1 - statistics['types'][n._plugin_type_string] += 1 # pylint: disable=no-member,protected-access - statistics['ctime_by_day'][n.ctime.strftime('%Y-%m-%d')] += 1 - - current_db_statistics = self.backend.query_manager.get_creation_statistics() - types = defaultdict(int) - types.update(current_db_statistics['types']) - ctime_by_day = defaultdict(int) - ctime_by_day.update(current_db_statistics['ctime_by_day']) - - expected_db_statistics = {'total': current_db_statistics['total'], 'types': types, 'ctime_by_day': ctime_by_day} - - store_and_add(orm.Data(), expected_db_statistics) - store_and_add(orm.Dict(), expected_db_statistics) - store_and_add(orm.Dict(), expected_db_statistics) - store_and_add(orm.CalculationNode(), expected_db_statistics) - - new_db_statistics = self.backend.query_manager.get_creation_statistics() - # I only check a few fields - new_db_statistics = {k: v for k, v in new_db_statistics.items() if k in expected_db_statistics} - - expected_db_statistics = { - k: dict(v) if isinstance(v, defaultdict) else v for k, v in expected_db_statistics.items() - } - - self.assertEqual(new_db_statistics, expected_db_statistics) - - -class TestDoubleStar(AiidaTestCase): - """ - In this test class we check if QueryBuilder returns the correct results - when double star is provided as projection. - """ - - def test_statistics_default_class(self): - - # The expected result - # pylint: disable=no-member - expected_dict = { - u'description': self.computer.description, - u'scheduler_type': self.computer.get_scheduler_type(), - u'hostname': self.computer.hostname, - u'uuid': self.computer.uuid, - u'name': self.computer.name, - u'transport_type': self.computer.get_transport_type(), - u'id': self.computer.id, - u'metadata': self.computer.get_metadata(), - } - - qb = orm.QueryBuilder() - qb.append(orm.Computer, project=['**']) - # We expect one result - self.assertEqual(qb.count(), 1) - - # Get the one result record and check that the returned - # data are correct - res = list(qb.dict()[0].values())[0] - self.assertDictEqual(res, expected_dict) - - # Ask the same query as above using queryhelp - qh = {'project': {'computer': ['**']}, 'path': [{'tag': 'computer', 'cls': orm.Computer}]} - qb = orm.QueryBuilder(**qh) - # We expect one result - self.assertEqual(qb.count(), 1) - - # Get the one result record and check that the returned - # data are correct - res = list(qb.dict()[0].values())[0] - self.assertDictEqual(res, expected_dict) diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index 32193618dc..c93b53cd9f 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -1766,6 +1766,23 @@ def get_json_compatible_queryhelp(self): 'offset': self._offset, }) + @property + def queryhelp(self): + """queryhelp dictionary correspondig to QueryBuilder instance. + + The queryhelp can be used to create a copy of the QueryBuilder instance like so:: + + qb = QueryBuilder(limit=3).append(StructureData, project='id').order_by({StructureData:'id'}) + qb2=QueryBuilder(**qb.queryhelp) + + :return: a queryhelp dictionary + """ + return self.get_json_compatible_queryhelp() + + def __deepcopy__(self, memodict={}): + qb = type(self)(**self.queryhelp) + return qb + def _build_order(self, alias, entitytag, entityspec): """ Build the order parameter of the query