Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure WorkChain does not exit unless stepper returns non-zero value #1945

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions aiida/backends/tests/work/work_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from aiida.utils.capturing import Capturing
from aiida.workflows.wf_demo import WorkflowDemo
from aiida import work
from aiida.work import Process
from aiida.work import ExitCode, Process
from aiida.work.persistence import ObjectLoader
from aiida.work.workchain import *

Expand Down Expand Up @@ -161,53 +161,97 @@ class PotentialFailureWorkChain(WorkChain):

EXIT_STATUS = 1
EXIT_MESSAGE = 'Well you did ask for it'
OUTPUT_LABEL = 'optional_output'
OUTPUT_VALUE = 144

@classmethod
def define(cls, spec):
super(PotentialFailureWorkChain, cls).define(spec)
spec.input('success', valid_type=Bool)
spec.input('through_return', valid_type=Bool, default=Bool(False))
spec.input('through_exit_code', valid_type=Bool, default=Bool(False))
spec.exit_code(cls.EXIT_STATUS, 'EXIT_STATUS', cls.EXIT_MESSAGE)
spec.outline(
if_(cls.should_return_out_of_outline)(
return_(cls.EXIT_STATUS)
),
cls.failure,
cls.success
)
spec.output('optional', required=False)

def should_return_out_of_outline(self):
return self.inputs.through_return.value

def failure(self):
if self.inputs.success.value is False:
# Returning either 0 or ExitCode with non-zero status should terminate the workchain
if self.inputs.through_exit_code.value is False:
return self.EXIT_STATUS
else:
return self.exit_codes.EXIT_STATUS
else:
# Returning 0 or ExitCode with zero status should *not* terminate the workchain
if self.inputs.through_exit_code.value is False:
return 0
else:
return ExitCode()

def success(self):
self.out(self.OUTPUT_LABEL, Int(self.OUTPUT_VALUE))
return


class TestExitStatus(AiidaTestCase):
"""
This class should test the various ways that one can exit from the outline flow of a WorkChain, other than
it running it all the way through. Currently this can be done directly in the outline by calling the `return_`
construct, or from an outline step function by returning a non-zero integer or an ExitCode with a non-zero status
"""

def test_failing_workchain(self):
def test_failing_workchain_through_integer(self):
result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(False))
self.assertEquals(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS)
self.assertEquals(node.exit_message, None)
self.assertEquals(node.is_finished, True)
self.assertEquals(node.is_finished_ok, False)
self.assertEquals(node.is_failed, True)
self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict())

def test_failing_workchain_with_message(self):
def test_failing_workchain_through_exit_code(self):
result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(False), through_exit_code=Bool(True))
self.assertEquals(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS)
self.assertEquals(node.exit_message, PotentialFailureWorkChain.EXIT_MESSAGE)
self.assertEquals(node.is_finished, True)
self.assertEquals(node.is_finished_ok, False)
self.assertEquals(node.is_failed, True)
self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict())

def test_successful_workchain(self):
def test_successful_workchain_through_integer(self):
result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(True))
self.assertEquals(node.exit_status, 0)
self.assertEquals(node.is_finished, True)
self.assertEquals(node.is_finished_ok, True)
self.assertEquals(node.is_failed, False)
self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict())
self.assertEquals(node.get_outputs_dict()[PotentialFailureWorkChain.OUTPUT_LABEL], PotentialFailureWorkChain.OUTPUT_VALUE)

def test_successful_workchain_through_exit_code(self):
result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(True), through_exit_code=Bool(True))
self.assertEquals(node.exit_status, 0)
self.assertEquals(node.is_finished, True)
self.assertEquals(node.is_finished_ok, True)
self.assertEquals(node.is_failed, False)
self.assertIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict())
self.assertEquals(node.get_outputs_dict()[PotentialFailureWorkChain.OUTPUT_LABEL], PotentialFailureWorkChain.OUTPUT_VALUE)

def test_return_out_of_outline(self):
result, node = work.run_get_node(PotentialFailureWorkChain, success=Bool(True), through_return=Bool(True))
self.assertEquals(node.exit_status, PotentialFailureWorkChain.EXIT_STATUS)
self.assertEquals(node.is_finished, True)
self.assertEquals(node.is_finished_ok, False)
self.assertEquals(node.is_failed, True)
self.assertNotIn(PotentialFailureWorkChain.OUTPUT_LABEL, node.get_outputs_dict())


class IfTest(work.WorkChain):
Expand Down
14 changes: 9 additions & 5 deletions aiida/work/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""Components for the WorkChain concept of the workflow engine."""
from __future__ import absolute_import
import functools
import six

from plumpy import auto_persist, WorkChainSpec, Wait, Continue
from plumpy.workchains import if_, while_, return_, _PropagateReturn
Expand Down Expand Up @@ -156,13 +157,18 @@ def _do_step(self):
except _PropagateReturn as exception:
finished, result = True, exception.exit_code
else:
if isinstance(stepper_result, (int, ExitCode)):
# Set result to None unless stepper_result was non-zero positive integer or ExitCode with similar status
if isinstance(stepper_result, six.integer_types) and stepper_result > 0:
result = ExitCode(stepper_result)
elif isinstance(stepper_result, ExitCode) and stepper_result.status > 0:
result = stepper_result
else:
result = None

if not finished and (stepper_result is None or isinstance(stepper_result, ToContext)):

# If the stepper said we are finished or the result is an ExitCode, we exit by returning
if finished or isinstance(result, ExitCode):
return result
else:
if isinstance(stepper_result, ToContext):
self.to_context(**stepper_result)

Expand All @@ -171,8 +177,6 @@ def _do_step(self):

return Continue(self._do_step)

return result

def on_wait(self, awaitables):
super(WorkChain, self).on_wait(awaitables)
if self._awaitables:
Expand Down