diff --git a/aiida/backends/tests/work/work_chain.py b/aiida/backends/tests/work/work_chain.py index 9430339334..82efa93a40 100644 --- a/aiida/backends/tests/work/work_chain.py +++ b/aiida/backends/tests/work/work_chain.py @@ -26,7 +26,7 @@ from aiida.utils.capturing import Capturing from aiida.workflows.wf_demo import WorkflowDemo from aiida import work -from aiida.work import Process +from aiida.work import ExitCode, Process from aiida.work.persistence import ObjectLoader from aiida.work.workchain import * @@ -161,53 +161,97 @@ class PotentialFailureWorkChain(WorkChain): EXIT_STATUS = 1 EXIT_MESSAGE = 'Well you did ask for it' + OUTPUT_LABEL = 'optional_output' + OUTPUT_VALUE = 144 @classmethod def define(cls, spec): super(PotentialFailureWorkChain, cls).define(spec) spec.input('success', valid_type=Bool) + spec.input('through_return', valid_type=Bool, default=Bool(False)) spec.input('through_exit_code', valid_type=Bool, default=Bool(False)) spec.exit_code(cls.EXIT_STATUS, 'EXIT_STATUS', cls.EXIT_MESSAGE) spec.outline( + if_(cls.should_return_out_of_outline)( + return_(cls.EXIT_STATUS) + ), cls.failure, cls.success ) + spec.output('optional', required=False) + + def should_return_out_of_outline(self): + return self.inputs.through_return.value def failure(self): if self.inputs.success.value is False: + # Returning either 0 or ExitCode with non-zero status should terminate the workchain if self.inputs.through_exit_code.value is False: return self.EXIT_STATUS else: return self.exit_codes.EXIT_STATUS + else: + # Returning 0 or ExitCode with zero status should *not* terminate the workchain + if self.inputs.through_exit_code.value is False: + return 0 + else: + return ExitCode() def success(self): + self.out(self.OUTPUT_LABEL, Int(self.OUTPUT_VALUE)) return class TestExitStatus(AiidaTestCase): + """ + This class should test the various ways that one can exit from the outline flow of a WorkChain, other than + it running it all the way through. Currently this can be done directly in the outline by calling the `return_` + construct, or from an outline step function by returning a non-zero integer or an ExitCode with a non-zero status + """ - def test_failing_workchain(self): + def test_failing_workchain_through_integer(self): result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(False)) self.assertEquals(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) self.assertEquals(node.exit_message, None) self.assertEquals(node.is_finished, True) self.assertEquals(node.is_finished_ok, False) self.assertEquals(node.is_failed, True) + self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict()) - def test_failing_workchain_with_message(self): + def test_failing_workchain_through_exit_code(self): result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(False), through_exit_code=Bool(True)) self.assertEquals(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) self.assertEquals(node.exit_message, PotentialFailureWorkChain.EXIT_MESSAGE) self.assertEquals(node.is_finished, True) self.assertEquals(node.is_finished_ok, False) self.assertEquals(node.is_failed, True) + self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict()) - def test_successful_workchain(self): + def test_successful_workchain_through_integer(self): result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(True)) self.assertEquals(node.exit_status, 0) self.assertEquals(node.is_finished, True) self.assertEquals(node.is_finished_ok, True) self.assertEquals(node.is_failed, False) + self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict()) + self.assertEquals(node.get_outputs_dict()[PotentialFailureWorkChain.OUTPUT_LABEL], PotentialFailureWorkChain.OUTPUT_VALUE) + + def test_successful_workchain_through_exit_code(self): + result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(True), through_exit_code=Bool(True)) + self.assertEquals(node.exit_status, 0) + self.assertEquals(node.is_finished, True) + self.assertEquals(node.is_finished_ok, True) + self.assertEquals(node.is_failed, False) + self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict()) + self.assertEquals(node.get_outputs_dict()[PotentialFailureWorkChain.OUTPUT_LABEL], PotentialFailureWorkChain.OUTPUT_VALUE) + + def test_return_out_of_outline(self): + result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(True), through_return=Bool(True)) + self.assertEquals(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS) + self.assertEquals(node.is_finished, True) + self.assertEquals(node.is_finished_ok, False) + self.assertEquals(node.is_failed, True) + self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict()) class IfTest(work.WorkChain): diff --git a/aiida/work/workchain.py b/aiida/work/workchain.py index 0af77c5e1f..9fb1e4bbb5 100644 --- a/aiida/work/workchain.py +++ b/aiida/work/workchain.py @@ -10,6 +10,7 @@ """Components for the WorkChain concept of the workflow engine.""" from __future__ import absolute_import import functools +import six from plumpy import auto_persist, WorkChainSpec, Wait, Continue from plumpy.workchains import if_, while_, return_, _PropagateReturn @@ -156,13 +157,18 @@ def _do_step(self): except _PropagateReturn as exception: finished, result = True, exception.exit_code else: - if isinstance(stepper_result, (int, ExitCode)): + # Set result to None unless stepper_result was non-zero positive integer or ExitCode with similar status + if isinstance(stepper_result, six.integer_types) and stepper_result > 0: + result = ExitCode(stepper_result) + elif isinstance(stepper_result, ExitCode) and stepper_result.status > 0: result = stepper_result else: result = None - if not finished and (stepper_result is None or isinstance(stepper_result, ToContext)): - + # If the stepper said we are finished or the result is an ExitCode, we exit by returning + if finished or isinstance(result, ExitCode): + return result + else: if isinstance(stepper_result, ToContext): self.to_context(**stepper_result) @@ -171,8 +177,6 @@ def _do_step(self): return Continue(self._do_step) - return result - def on_wait(self, awaitables): super(WorkChain, self).on_wait(awaitables) if self._awaitables: