Skip to content

Commit

Permalink
Raise when returning an unstored Data from a WorkflowNode (#2747)
Browse files Browse the repository at this point in the history
This will happen if one tries to return an unstored `Data` node from within
the body of a `WorkChain` or `workfunction`, which means that they probably
created a new node based on its inputs or data returned by one of the
processes it calls. However, this is strictly forbidden as the
provenance of the new node will be lost. Given that beginning users are
likely to make this mistake, instead of issuing a warning we explicitly
forbid this behavior and raise in the link validation.
  • Loading branch information
sphuber authored Apr 13, 2019
1 parent cc38745 commit d0ab2b8
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 27 deletions.
8 changes: 4 additions & 4 deletions .ci/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def finalize(self):
if self.should_submit():
self.report('Getting sub-workchain output.')
sub_workchain = self.ctx.workchain[0]
self.out('output', sub_workchain.outputs.output + 1)
self.out('output', Int(sub_workchain.outputs.output + 1).store())
else:
self.report('Bottom-level workchain reached.')
self.out('output', Int(0))
self.out('output', Int(0).store())


class SerializeWorkChain(WorkChain):
Expand Down Expand Up @@ -109,7 +109,7 @@ def do_test(self):
input_list = self.inputs.namespace.input
assert isinstance(input_list, list)
assert not isinstance(input_list, List)
self.out('output', List(list=list(input_list)))
self.out('output', List(list=list(input_list)).store())


class DynamicDbInput(WorkChain):
Expand Down Expand Up @@ -140,7 +140,7 @@ def do_test(self):
assert isinstance(input_non_db, int)
assert not isinstance(input_non_db, Int)
assert isinstance(input_db, Int)
self.out('output', input_db + input_non_db)
self.out('output', Int(input_db + input_non_db).store())


class CalcFunctionRunnerWorkChain(WorkChain):
Expand Down
2 changes: 1 addition & 1 deletion aiida/backends/tests/engine/test_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def define(cls, spec):
spec.output('result', valid_type=Int)

def add(self):
self.out('result', self.inputs.a + self.inputs.b)
self.out('result', Int(self.inputs.a + self.inputs.b).store())


class TestLaunchers(AiidaTestCase):
Expand Down
4 changes: 2 additions & 2 deletions aiida/backends/tests/engine/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ def define(cls, spec):

def run(self):
if self.inputs.add_outputs:
self.out('required_string', orm.Str('testing'))
self.out('integer.namespace.two', orm.Int(2))
self.out('required_string', orm.Str('testing').store())
self.out('integer.namespace.two', orm.Int(2).store())

results, node = run_get_node(TestProcess)

Expand Down
10 changes: 10 additions & 0 deletions aiida/backends/tests/engine/test_process_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def function_exit_code(exit_status, exit_message):
def function_excepts(exception):
raise RuntimeError(exception.value)

@workfunction
def function_out_unstored():
return orm.Int(DEFAULT_INT)

self.function_return_input = function_return_input
self.function_return_true = function_return_true
self.function_args = function_args
Expand All @@ -105,6 +109,7 @@ def function_excepts(exception):
self.function_defaults = function_defaults
self.function_exit_code = function_exit_code
self.function_excepts = function_excepts
self.function_out_unstored = function_out_unstored

def tearDown(self):
super(TestProcessFunction, self).tearDown()
Expand Down Expand Up @@ -332,6 +337,11 @@ def test_normal_exception(self):
self.assertTrue(node.is_excepted)
self.assertEqual(node.exception, exception)

def test_function_out_unstored(self):
"""A workfunction that returns an unstored node should raise as it indicates users tried to create data."""
# `function_out_unstored` returns `orm.Int(DEFAULT_INT)` without calling `.store()` on it
with self.assertRaises(ValueError):
self.function_out_unstored()

def test_simple_workflow(self):
"""Test construction of simple workflow by chaining process functions."""

Expand Down
39 changes: 31 additions & 8 deletions aiida/backends/tests/engine/test_work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def failure(self):
return ExitCode()

def success(self):
self.out(self.OUTPUT_LABEL, Int(self.OUTPUT_VALUE))
self.out(self.OUTPUT_LABEL, Int(self.OUTPUT_VALUE).store())
return


Expand Down Expand Up @@ -365,6 +365,25 @@ def define(cls, spec):
with self.assertRaises(AssertionError):
launch.run(IncompleteDefineWorkChain)

def test_out_unstored(self):
"""Calling `self.out` on an unstored `Node` should raise.

It indicates that users created new data whose provenance will be lost.
"""

class IllegalWorkChain(WorkChain):

@classmethod
def define(cls, spec):
super(IllegalWorkChain, cls).define(spec)
spec.outline(cls.illegal)

def illegal(self):
# The `Int` is created inside the outline step and never stored before being passed to `self.out`
self.out('not_allowed', orm.Int(2))

with self.assertRaises(ValueError):
launch.run(IllegalWorkChain)

def test_same_input_node(self):

class Wf(WorkChain):
Expand All @@ -385,8 +404,8 @@ def check_a_b(self):
run_and_check_success(Wf, a=x, b=x)

def test_context(self):
A = Str("a")
B = Str("b")
A = Str("a").store()
B = Str("b").store()

test_case = self

Expand Down Expand Up @@ -541,12 +560,14 @@ def define(cls, spec):
spec.outline(cls.do_run)

def do_run(self):
self.out("value", Int(5))
self.out("value", Int(5).store())

run_and_check_success(MainWorkChain)

def test_tocontext_schedule_workchain(self):

node = Int(5).store()

class MainWorkChain(WorkChain):

@classmethod
Expand All @@ -559,7 +580,7 @@ def do_run(self):
return ToContext(subwc=self.submit(SubWorkChain))

def check(self):
assert self.ctx.subwc.outputs.value == Int(5)
assert self.ctx.subwc.outputs.value == node

class SubWorkChain(WorkChain):

Expand All @@ -569,7 +590,7 @@ def define(cls, spec):
spec.outline(cls.do_run)

def do_run(self):
self.out('value', Int(5))
self.out('value', node)

run_and_check_success(MainWorkChain)

Expand Down Expand Up @@ -635,7 +656,7 @@ def check(self):
run_and_check_success(TestWorkChain)

def test_to_context(self):
val = Int(5)
val = Int(5).store()

test_case = self

Expand Down Expand Up @@ -1140,7 +1161,9 @@ def define(cls, spec):
spec.outline(cls.do_run)

def do_run(self):
self.out('a', self.inputs.a + self.inputs.b)
summed = self.inputs.a + self.inputs.b
summed.store()
self.out('a', summed)
self.out('b', self.inputs.b)
self.out('c', self.inputs.c)

Expand Down
26 changes: 18 additions & 8 deletions aiida/backends/tests/orm/node/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def test_add_incoming_return(self):
"""Nodes can have an infinite amount of incoming RETURN links, as long as the link triple is unique."""
source_one = WorkflowNode()
source_two = WorkflowNode()
target = Data()
target = Data().store() # Needs to be stored: see `test_validate_outgoing_workflow`

target.add_incoming(source_one, LinkType.RETURN, 'link_label')

Expand All @@ -202,6 +202,19 @@ def test_add_incoming_return(self):
target.validate_incoming(source_one, LinkType.RETURN, 'other_label')
target.validate_incoming(source_two, LinkType.RETURN, 'link_label')

def test_validate_outgoing_workflow(self):
"""Verify that attaching an unstored `Data` node with `RETURN` link from a `WorkflowNode` raises.

This would for example be the case if a user inside a workfunction or work chain creates a new node based on its
inputs or the outputs returned by another process and tries to attach it as an output. This would cause the
provenance of that data node to be lost and should be explicitly forbidden by raising.
"""
source = WorkflowNode().store()
# The target is deliberately left unstored: that is the condition under test
target = Data()

with self.assertRaises(ValueError):
target.add_incoming(source, LinkType.RETURN, 'link_label')

def test_get_incoming(self):
"""Test that `Node.get_incoming` will return stored and cached input links."""
source_one = Data().store()
Expand Down Expand Up @@ -253,12 +266,11 @@ def test_node_indegree_unique_triple(self):
"""
return_one = WorkflowNode().store()
return_two = WorkflowNode().store()
data = Data()
data = Data().store() # Needs to be stored: see `test_validate_outgoing_workflow`

# Verify that adding two return links with the same link label but from different source is allowed
data.add_incoming(return_one, link_type=LinkType.RETURN, link_label='returned')
data.add_incoming(return_two, link_type=LinkType.RETURN, link_label='returned')
data.store()

uuids_incoming = set(node.uuid for node in data.get_incoming().all_nodes())
uuids_expected = set([return_one.uuid, return_two.uuid])
Expand Down Expand Up @@ -325,9 +337,8 @@ def test_tab_completable_properties(self):
output1 = Data().store()
output2 = Data().store()

# top_workflow has two inputs, proxies them to workflow, that in turn
# calls two calcs (passing 1 data to each),
# and return the two data nodes returned one by each called calculation
# The `top_workflow` has two inputs and proxies them to `workflow`, which in turn calls two calculations, passing
# one data node to each as input, and returns the two data nodes, one returned by each called calculation
top_workflow.add_incoming(input1, link_type=LinkType.INPUT_WORK, link_label='a')
top_workflow.add_incoming(input2, link_type=LinkType.INPUT_WORK, link_label='b')

Expand All @@ -348,7 +359,6 @@ def test_tab_completable_properties(self):
output1.add_incoming(top_workflow, link_type=LinkType.RETURN, link_label='result_a')
output2.add_incoming(top_workflow, link_type=LinkType.RETURN, link_label='result_b')

## Now we test the methods
# creator
self.assertEqual(output1.creator.pk, calc1.pk)
self.assertEqual(output2.creator.pk, calc2.pk)
Expand Down Expand Up @@ -386,4 +396,4 @@ def test_tab_completable_properties(self):
self.assertEqual(workflow.outputs.result_a.pk, output1.pk)
self.assertEqual(workflow.outputs.result_b.pk, output2.pk)
with self.assertRaises(exceptions.NotExistent):
_ = workflow.outputs.some_label # noqa
_ = workflow.outputs.some_label
2 changes: 1 addition & 1 deletion aiida/backends/tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def failure(self):
return ExitCode()

def success(self):
self.out(self.OUTPUT_LABEL, orm.Int(self.OUTPUT_VALUE))
self.out(self.OUTPUT_LABEL, orm.Int(self.OUTPUT_VALUE).store())

class DummyWorkChain(WorkChain):
pass
Expand Down
3 changes: 2 additions & 1 deletion aiida/backends/tests/utils/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def define(cls, spec):
spec.output('result', required=True)

def run(self):
self.out(self.inputs.a + self.inputs.b)
summed = self.inputs.a + self.inputs.b
self.out(summed.store())


class BadOutput(Process):
Expand Down
22 changes: 20 additions & 2 deletions aiida/orm/nodes/process/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
class WorkflowNode(ProcessNode):
"""Base class for all nodes representing the execution of a workflow process."""

# pylint: disable=too-few-public-methods

# Workflow nodes are storable
_storable = True
_unstorable_message = 'storing for this node has been disabled'
Expand Down Expand Up @@ -52,3 +50,23 @@ def outputs(self):
:return: `NodeLinksManager`
"""
return NodeLinksManager(node=self, link_type=LinkType.RETURN, incoming=False)

def validate_outgoing(self, target, link_type, link_label):
"""Validate adding a link of the given type from ourself to a given node.

A workflow cannot 'create' Data, so if we receive an outgoing link to an unstored Data node, that means
the user created a Data node within our function body and tries to attach it as an output. This is strictly
forbidden and can cause provenance to be lost.

:param target: the node to which the link is going
:param link_type: the link type
:param link_label: the link label
:raise TypeError: if `target` is not a Node instance or `link_type` is not a `LinkType` enum
:raise ValueError: if the proposed link is invalid, or if `link_type` is `RETURN` and `target` is not stored
"""
# Run the parent class validation first; the check below is the workflow-specific addition
super(WorkflowNode, self).validate_outgoing(target, link_type, link_label)
# Only `RETURN` links are restricted here; the check looks at `target.is_stored`, not at the target's type
if link_type is LinkType.RETURN and not target.is_stored:
raise ValueError(
'Workflow<{}> tried returning an unstored `Data` node. This likely means new `Data` is being created '
'inside the workflow. In order to preserve data provenance, use a `calcfunction` to create this node '
'and return its output from the workflow'.format(self.process_label))

0 comments on commit d0ab2b8

Please sign in to comment.