Skip to content

Commit

Permalink
Merge pull request #1271 from sphuber/fix_1270_get_db_input_links
Browse files Browse the repository at this point in the history
Ensure that get_inputs methods respect link_type argument for SqlAlchemy
  • Loading branch information
sphuber authored Mar 12, 2018
2 parents a96b118 + 991596e commit 41cdb29
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 49 deletions.
55 changes: 55 additions & 0 deletions aiida/backends/tests/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1969,3 +1969,58 @@ def test_check_single_calc_source(self):
# more than one input to the same data object!
with self.assertRaises(ValueError):
d1.add_link_from(calc2, link_type=LinkType.CREATE)

def test_node_get_inputs_outputs_link_type_stored(self):
"""
Test that the link_type parameter in get_inputs and get_outputs only
returns those nodes with the correct link type for stored nodes
"""
node_origin = Node().store()
node_caller = Node().store()
node_called = Node().store()
node_input = Node().store()
node_output = Node().store()
node_return = Node().store()

# Input links of node_origin
node_origin.add_link_from(node_caller, label='caller', link_type=LinkType.CALL)
node_origin.add_link_from(node_input, label='input', link_type=LinkType.INPUT)

# Output links of node_origin
node_called.add_link_from(node_origin, label='called', link_type=LinkType.CALL)
node_output.add_link_from(node_origin, label='output', link_type=LinkType.CREATE)
node_return.add_link_from(node_origin, label='return', link_type=LinkType.RETURN)

# All inputs and outputs
self.assertEquals(len(node_origin.get_inputs()), 2)
self.assertEquals(len(node_origin.get_outputs()), 3)

# Link specific inputs
self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.CALL)), 1)
self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.INPUT)), 1)

# Link specific outputs
self.assertEquals(len(node_origin.get_outputs(link_type=LinkType.CALL)), 1)
self.assertEquals(len(node_origin.get_outputs(link_type=LinkType.CREATE)), 1)
self.assertEquals(len(node_origin.get_outputs(link_type=LinkType.RETURN)), 1)

def test_node_get_inputs_link_type_unstored(self):
"""
Test that the link_type parameter in get_inputs only returns those nodes with
the correct link type for unstored nodes. We don't check this analogously for
get_outputs because there is not output links cache
"""
node_origin = Node()
node_caller = Node()
node_input = Node()

# Input links of node_origin
node_origin.add_link_from(node_caller, label='caller', link_type=LinkType.CALL)
node_origin.add_link_from(node_input, label='input', link_type=LinkType.INPUT)

# All inputs and outputs
self.assertEquals(len(node_origin.get_inputs()), 2)

# Link specific inputs
self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.CALL)), 1)
self.assertEquals(len(node_origin.get_inputs(link_type=LinkType.INPUT)), 1)
69 changes: 32 additions & 37 deletions aiida/orm/implementation/general/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,28 +647,24 @@ def get_outputs_dict(self, link_type=None):

return new_outputs

def get_inputs(self,
node_type=None,
also_labels=False,
only_in_db=False,
link_type=None):
def get_inputs(self, node_type=None, also_labels=False, only_in_db=False, link_type=None):
"""
Return a list of nodes that enter (directly) in this node
:param node_type: If specified, should be a class, and it filters only
elements of that specific type (or a subclass of 'type')
:param also_labels: If False (default) only return a list of input nodes.
If True, return a list of tuples, where each tuple has the
following format: ('label', Node), with 'label' the link label,
and Node a Node instance or subclass
If True, return a list of tuples, where each tuple has the
following format: ('label', Node), with 'label' the link label,
and Node a Node instance or subclass
:param only_in_db: Return only the inputs that are in the database,
ignoring those that are in the local cache. Otherwise, return
all links.
ignoring those that are in the local cache. Otherwise, return
all links.
:param link_type: Only get inputs of this link type, if None then
returns all inputs of all link types.
returns all inputs of all link types.
"""
if link_type is not None and not isinstance(link_type, LinkType):
raise TypeError("link_type should be a LinkType object")
raise TypeError('link_type should be a LinkType object')

inputs_list = self._get_db_input_links(link_type=link_type)

Expand All @@ -678,19 +674,18 @@ def get_inputs(self,

for label, v in self._inputlinks_cache.iteritems():
src = v[0]
input_link_type = v[1]
if label in input_list_keys:
raise InternalError(
"There exist a link with the same name "
"'{}' both in the DB and in the internal "
"cache for node pk= {}!".format(label, self.pk))
inputs_list.append((label, src))
raise InternalError("There exist a link with the same name '{}' both in the DB "
"and in the internal cache for node pk= {}!".format(label, self.pk))

if link_type is None or input_link_type is link_type:
inputs_list.append((label, src))

if node_type is None:
filtered_list = inputs_list
else:
filtered_list = [
i for i in inputs_list if isinstance(i[1], node_type)
]
filtered_list = [i for i in inputs_list if isinstance(i[1], node_type)]

if also_labels:
return list(filtered_list)
Expand All @@ -708,33 +703,33 @@ def _get_db_input_links(self, link_type):
"""
pass

# pylint: disable=no-else-return
@override
def get_outputs(self, type=None, also_labels=False, link_type=None):
def get_outputs(self, node_type=None, also_labels=False, link_type=None):
"""
Return a list of nodes that exit (directly) from this node
:param type: if specified, should be a class, and it filters only
elements of that specific type (or a subclass of 'type')
:param node_type: if specified, should be a class, and it filters only
elements of that specific node_type (or a subclass of 'node_type')
:param also_labels: if False (default) only return a list of input nodes.
If True, return a list of tuples, where each tuple has the
following format: ('label', Node), with 'label' the link label,
and Node a Node instance or subclass
If True, return a list of tuples, where each tuple has the
following format: ('label', Node), with 'label' the link label,
and Node a Node instance or subclass
:param link_type: Only return outputs connected by links of this type.
"""
if link_type is not None and not isinstance(link_type, LinkType):
raise TypeError('link_type should be a LinkType object')

outputs_list = self._get_db_output_links(link_type=link_type)

if type is None:
if also_labels:
return list(outputs_list)
else:
return [i[1] for i in outputs_list]
if node_type is None:
filtered_list = outputs_list
else:
filtered_list = (i for i in outputs_list if isinstance(i[1], type))
if also_labels:
return list(filtered_list)
else:
return [i[1] for i in filtered_list]
filtered_list = (i for i in outputs_list if isinstance(i[1], node_type))

if also_labels:
return list(filtered_list)

return [i[1] for i in filtered_list]

@abstractmethod
def _get_db_output_links(self, link_type):
Expand Down
2 changes: 1 addition & 1 deletion aiida/orm/implementation/sqlalchemy/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def _get_db_input_links(self, link_type):
if link_type is not None:
link_filter['type'] = link_type.value
return [(i.label, i.input.get_aiida_class()) for i in
DbLink.query.filter_by(output=self.dbnode).distinct().all()]
DbLink.query.filter_by(**link_filter).distinct().all()]


def _get_db_output_links(self, link_type):
Expand Down
13 changes: 6 additions & 7 deletions aiida/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,20 +114,19 @@ def get_result_parameterdata_node(self):
from aiida.orm.data.parameter import ParameterData
from aiida.common.exceptions import NotExistent

out_parameters = self._calc.get_outputs(type=ParameterData, also_labels=True)
out_parameterdata = [i[1] for i in out_parameters
if i[0] == self.get_linkname_outparams()]
out_parameters = self._calc.get_outputs(node_type=ParameterData, also_labels=True)
out_parameter_data = [i[1] for i in out_parameters if i[0] == self.get_linkname_outparams()]

if not out_parameterdata:
if not out_parameter_data:
raise NotExistent("No output .res ParameterData node found")
elif len(out_parameterdata) > 1:
elif len(out_parameter_data) > 1:
from aiida.common.exceptions import UniquenessError

raise UniquenessError("Output ParameterData should be found once, "
"found it instead {} times"
.format(len(out_parameterdata)))
.format(len(out_parameter_data)))

return out_parameterdata[0]
return out_parameter_data[0]

def get_result_keys(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion aiida/tools/dbexporters/tcod.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,7 @@ def export_cifnode(what, parameters=None, trajectory_index=None,
raise ValueError("Supplied parameters are not an "
"instance of ParameterData")
elif calc is not None:
params = calc.get_outputs(type=ParameterData, link_type=LinkType.CREATE)
params = calc.get_outputs(node_type=ParameterData, link_type=LinkType.CREATE)
if len(params) == 1:
parameters = params[0]
elif len(params) > 0:
Expand Down
2 changes: 1 addition & 1 deletion aiida/workflows/wf_XTiO3.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def final_step(self):
optimal_alat = self.get_attribute("optimal_alat")

opt_calc = self.get_step_calculations(self.optimize)[0] # .get_calculations()[0]
opt_e = opt_calc.get_outputs(type=ParameterData)[0].get_dict()['energy']
opt_e = opt_calc.get_outputs(node_type=ParameterData)[0].get_dict()['energy']

self.append_to_report(x_material + "Ti03 optimal with a=" + str(optimal_alat) + ", e=" + str(opt_e))

Expand Down
4 changes: 2 additions & 2 deletions docs/source/old_workflows/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ aside to the final optimal cell parameter value.
optimal_alat = self.get_attribute("optimal_alat")
opt_calc = self.get_step_calculations(self.optimize)[0] #.get_calculations()[0]
opt_e = opt_calc.get_outputs(type=ParameterData)[0].get_dict()['energy']
opt_e = opt_calc.get_outputs(node_type=ParameterData)[0].get_dict()['energy']
self.append_to_report(x_material+"Ti03 optimal with a="+str(optimal_alat)+", e="+str(opt_e))
Expand Down Expand Up @@ -741,7 +741,7 @@ phonon vibrational frequncies for some XTiO3 materials, namely Ba, Sr and Pb.
run_ph_calcs = self.get_step_calculations(self.run_ph) #.get_calculations()
for c in run_ph_calcs:
dm = c.get_outputs(type=ParameterData)[0].get_dict()['dynamical_matrix_1']
dm = c.get_outputs(node_type=ParameterData)[0].get_dict()['dynamical_matrix_1']
self.append_to_report("Point q: {0} Frequencies: {1}".format(dm['q_point'],dm['frequencies']))
self.next(self.exit)
Expand Down

0 comments on commit 41cdb29

Please sign in to comment.