Skip to content

Commit

Permalink
Add support for a tuple of classes or types in the QueryBuilder.append (
Browse files Browse the repository at this point in the history
#1607)

The `append` method of the `QueryBuilder` now accepts a tuple, list or set of orm classes
for the `cls` and `type` keyword argument, with the one restriction that all classes share a
common base class. This allows the user to append a join for a set of classes with the same
projection and filtering rules
  • Loading branch information
sphuber committed May 30, 2018
1 parent 78a7a4c commit 8a33ad0
Show file tree
Hide file tree
Showing 6 changed files with 270 additions and 222 deletions.
96 changes: 0 additions & 96 deletions aiida/backends/djsite/querybuilder_django/querybuilder_django.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,102 +507,6 @@ def get_aiida_res(self, key, res):
returnval = res
return returnval

def get_ormclass(self, cls, ormclasstype):
"""
Return the valid ormclass for the connections
"""
# Checks whether valid cls and ormclasstype are done before

# If it is a class:
if cls:
# Nodes:
if issubclass(cls, self.Node):
# If something pass an ormclass node
# Users wouldn't do that, by why not...
ormclasstype = self.AiidaNode._plugin_type_string
query_type_string = self.AiidaNode._query_type_string
ormclass = cls
elif issubclass(cls, self.AiidaNode):
ormclasstype = cls._plugin_type_string
query_type_string = cls._query_type_string
ormclass = self.Node
# Groups:
elif issubclass(cls, self.Group):
ormclasstype = 'group'
query_type_string = None
ormclass = cls
elif issubclass(cls, self.AiidaGroup):
ormclasstype = 'group'
query_type_string = None
ormclass = self.Group
# Computers:
elif issubclass(cls, self.Computer):
ormclasstype = 'computer'
query_type_string = None
ormclass = cls
elif issubclass(cls, self.AiidaComputer):
ormclasstype = 'computer'
query_type_string = None
ormclass = self.Computer

# Users
elif issubclass(cls, self.User):
ormclasstype = 'user'
query_type_string = None
ormclass = cls
elif issubclass(cls, self.AiidaUser):
ormclasstype = 'user'
query_type_string = None
ormclass = self.User
else:
raise InputValidationError(
"\n\n\n"
"I do not know what to do with {}"
"\n\n\n".format(cls)
)
# If it is not a class
else:
if ormclasstype.lower() == 'group':
ormclasstype = ormclasstype.lower()
query_type_string = None
ormclass = self.Group
elif ormclasstype.lower() == 'computer':
ormclasstype = ormclasstype.lower()
query_type_string = None
ormclass = self.Computer
elif ormclasstype.lower() == 'user':
ormclasstype = ormclasstype.lower()
query_type_string = None
ormclass = self.User
else:
# At this point, it has to be a node.
# The only valid string at this point is a string
# that matches exactly the _plugin_type_string
# of a node class
from aiida.plugins.loader import get_plugin_type_from_type_string, load_plugin
ormclass = self.Node
try:
plugin_type = get_plugin_type_from_type_string(ormclasstype)

# I want to check at this point if that is a valid class,
# so I use the load_plugin to load the plugin class
# and use the classes _plugin_type_string attribute
# In the future, assuming the user knows what he or she is doing
# we could remove that check
PluginClass = load_plugin(plugin_type)
except (DbContentError, MissingPluginError) as e:
raise InputValidationError(
"\nYou provide a vertice of the path with\n"
"type={}\n"
"But that string is not a valid type string\n"
"Exception raise during check\n"
"{}".format(ormclasstype, e)
)

ormclasstype = PluginClass._plugin_type_string
query_type_string = PluginClass._query_type_string

return ormclass, ormclasstype, query_type_string

def yield_per(self, query, batch_size):
"""
Expand Down
3 changes: 0 additions & 3 deletions aiida/backends/general/querybuilder_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@ def get_aiida_res(self, key, res):
"""
pass

@abstractmethod
def get_ormclass(self, cls, ormclasstype):
pass


@abstractmethod
Expand Down
97 changes: 0 additions & 97 deletions aiida/backends/sqlalchemy/querybuilder_sqla.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,103 +490,6 @@ def get_aiida_res(self, key, res):
returnval = res
return returnval

def get_ormclass(self, cls, ormclasstype):
"""
Return the valid ormclass for the connections
"""
# Checks whether valid cls and ormclasstype are done before

# If it is a class:
if cls:
# Nodes:
if issubclass(cls, self.Node):
# If something pass an ormclass node
# Users wouldn't do that, by why not...
ormclasstype = self.AiidaNode._plugin_type_string
query_type_string = self.AiidaNode._query_type_string
ormclass = cls
elif issubclass(cls, self.AiidaNode):
ormclasstype = cls._plugin_type_string
query_type_string = cls._query_type_string
ormclass = self.Node
# Groups:
elif issubclass(cls, self.Group):
ormclasstype = 'group'
query_type_string = None
ormclass = cls
elif issubclass(cls, self.AiidaGroup):
ormclasstype = 'group'
query_type_string = None
ormclass = self.Group
# Computers:
elif issubclass(cls, self.Computer):
ormclasstype = 'computer'
query_type_string = None
ormclass = cls
elif issubclass(cls, self.AiidaComputer):
ormclasstype = 'computer'
query_type_string = None
ormclass = self.Computer

# Users
elif issubclass(cls, self.User):
ormclasstype = 'user'
query_type_string = None
ormclass = cls
elif issubclass(cls, self.AiidaUser):
ormclasstype = 'user'
query_type_string = None
ormclass = self.User
else:
raise InputValidationError(
"\n\n\n"
"I do not know what to do with {}"
"\n\n\n".format(cls)
)
# If it is not a class
else:
if ormclasstype.lower() == 'group':
ormclasstype = ormclasstype.lower()
query_type_string = None
ormclass = self.Group
elif ormclasstype.lower() == 'computer':
ormclasstype = ormclasstype.lower()
query_type_string = None
ormclass = self.Computer
elif ormclasstype.lower() == 'user':
ormclasstype = ormclasstype.lower()
query_type_string = None
ormclass = self.User
else:
# At this point, it has to be a node.
# The only valid string at this point is a string
# that matches exactly the _plugin_type_string
# of a node class
from aiida.plugins.loader import get_plugin_type_from_type_string, load_plugin
ormclass = self.Node
try:
plugin_type = get_plugin_type_from_type_string(ormclasstype)

# I want to check at this point if that is a valid class,
# so I use the load_plugin to load the plugin class
# and use the classes _plugin_type_string attribute
# In the future, assuming the user knows what he or she is doing
# we could remove that check
PluginClass = load_plugin(plugin_type)
except (DbContentError, MissingPluginError) as e:
raise InputValidationError(
"\nYou provide a vertice of the path with\n"
"type={}\n"
"But that string is not a valid type string\n"
"Exception raise during check\n"
"{}".format(ormclasstype, e)
)

ormclasstype = PluginClass._plugin_type_string
query_type_string = PluginClass._query_type_string

return ormclass, ormclasstype, query_type_string

def yield_per(self, query, batch_size):
"""
:param count: Number of rows to yield per step
Expand Down
56 changes: 56 additions & 0 deletions aiida/backends/tests/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,22 @@ def test_subclassing(self):
qb = QueryBuilder().append(cls, filters={'attributes.cat': 'miau'}, subclassing=False)
self.assertEqual(qb.count(), 1)

# Now I am testing the subclassing with tuples:
qb = QueryBuilder().append(cls=(StructureData, ParameterData), filters={'attributes.cat':'miau'})
self.assertEqual(qb.count(), 2)
qb = QueryBuilder().append(type=('data.structure.StructureData.', 'data.parameter.ParameterData.'), filters={'attributes.cat':'miau'})
self.assertEqual(qb.count(), 2)
qb = QueryBuilder().append(cls=(StructureData, ParameterData), filters={'attributes.cat':'miau'}, subclassing=False)
self.assertEqual(qb.count(), 2)
qb = QueryBuilder().append(cls=(StructureData, Data), filters={'attributes.cat':'miau'}, )
self.assertEqual(qb.count(), 3)
qb = QueryBuilder().append(type=('data.structure.StructureData.', 'data.parameter.ParameterData.'),
filters={'attributes.cat':'miau'}, subclassing=False)
self.assertEqual(qb.count(), 2)
qb = QueryBuilder().append(type=('data.structure.StructureData.', 'data.Data.'),
filters={'attributes.cat':'miau'}, subclassing=False)
self.assertEqual(qb.count(), 2)

def test_list_behavior(self):
from aiida.orm import Node
from aiida.orm.querybuilder import QueryBuilder
Expand Down Expand Up @@ -415,6 +431,46 @@ def test_tags(self):
])


class TestQueryHelp(AiidaTestCase):
def test_queryhelp(self):
"""
Here I test the queryhelp by seeing whether results are the same as using the append method.
I also check passing of tuples.
"""

from aiida.orm.data.structure import StructureData
from aiida.orm.data.parameter import ParameterData
from aiida.orm.data import Data
from aiida.orm.querybuilder import QueryBuilder
for cls in (StructureData, ParameterData, Data):
obj = cls()
obj._set_attr('foo-qh2', 'bar')
obj.store()

for cls, expected_count, subclassing in (
(StructureData, 1, True),
(ParameterData, 1, True),
(Data, 3, True),
(Data, 1, False),
((ParameterData, StructureData), 2, True),
((ParameterData, StructureData), 2, False),
((ParameterData, Data), 2, False),
((ParameterData, Data), 3, True),
((ParameterData, Data, StructureData), 3, False),
):
qb = QueryBuilder()
qb.append(cls, filters={'attributes.foo-qh2':'bar'}, subclassing=subclassing, project='uuid')
self.assertEqual(qb.count(), expected_count)

qh = qb.get_json_compatible_queryhelp()
qb_new = QueryBuilder(**qh)
self.assertEqual(qb_new.count(), expected_count)
self.assertEqual(
sorted([uuid for uuid, in qb.all()]),
sorted([uuid for uuid, in qb_new.all()]))



class TestQueryBuilderCornerCases(AiidaTestCase):
"""
In this class corner cases of QueryBuilder are added.
Expand Down
Loading

0 comments on commit 8a33ad0

Please sign in to comment.