Skip to content

Commit

Permalink
Merge pull request #904 from sphuber/fix_902_incorrect_stepper_recrea…
Browse files Browse the repository at this point in the history
…tion_if_construct

Fix recreation of stepper in workchain 'if' logical block
  • Loading branch information
nmounet authored Nov 9, 2017
2 parents 318a6e5 + f250f86 commit 666bce8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 18 deletions.
60 changes: 57 additions & 3 deletions aiida/backends/tests/work/workChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from aiida.backends.testbase import AiidaTestCase
from plum.engine.ticking import TickingEngine
from plum.persistence.bundle import Bundle
import plum.process_monitor
from aiida.orm.calculation.work import WorkCalculation
from aiida.work.workchain import WorkChain, \
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self):
[self.s1.__name__, self.s2.__name__, self.s3.__name__,
self.s4.__name__, self.s5.__name__, self.s6.__name__,
self.isA.__name__, self.isB.__name__, self.ltN.__name__]
}
}

def s1(self):
self._set_finished(inspect.stack()[0][3])
Expand Down Expand Up @@ -120,6 +121,33 @@ def test_dict(self):
c['new_attr']


class IfTest(WorkChain):
@classmethod
def define(cls, spec):
super(IfTest, cls).define(spec)
spec.outline(
if_(cls.condition)(
cls.step1,
cls.step2
)
)

def on_create(self, pid, inputs, saved_state):
super(IfTest, self).on_create(pid, inputs, saved_state)
if saved_state is None:
self.ctx.s1 = False
self.ctx.s2 = False

def condition(self):
return True

def step1(self):
self.ctx.s1 = True

def step2(self):
self.ctx.s2 = True


class TestWorkchain(AiidaTestCase):
def setUp(self):
super(TestWorkchain, self).setUp()
Expand Down Expand Up @@ -321,6 +349,29 @@ def run(self):

run(MainWorkChain)

def test_if_block_persistence(self):
""" This test was created to capture issue #902 """
wc = IfTest.new_instance()

while not wc.ctx.s1 and not wc.has_finished():
wc.tick()
self.assertTrue(wc.ctx.s1)
self.assertFalse(wc.ctx.s2)

# Now bundle the thing
b = Bundle()
wc.save_instance_state(b)
# Abort the current one
wc.stop()
wc.destroy(execute=True)

# Load from saved tate
wc = IfTest.create_from(b)
self.assertTrue(wc.ctx.s1)
self.assertFalse(wc.ctx.s2)

wc.run_until_complete()

def _run_with_checkpoints(self, wf_class, inputs=None):
finished_steps = {}

Expand All @@ -336,7 +387,6 @@ def _run_with_checkpoints(self, wf_class, inputs=None):


class TestWorkchainWithOldWorkflows(AiidaTestCase):

def setUp(self):
super(TestWorkchainWithOldWorkflows, self).setUp()
import logging
Expand Down Expand Up @@ -409,10 +459,12 @@ def test_get_proc_outputs(self):
self.assertEquals(outputs['a'], a)
self.assertEquals(outputs['b'], b)


class TestWorkChainAbort(AiidaTestCase):
"""
Test the functionality to abort a workchain
"""

class AbortableWorkChain(WorkChain):
@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -490,11 +542,13 @@ def test_simple_kill_through_process(self):
self.assertEquals(future.process.calc.has_aborted(), True)
engine.shutdown()


class TestWorkChainAbortChildren(AiidaTestCase):
"""
Test the functionality to abort a workchain and verify that children
are also aborted appropriately
"""

class SubWorkChain(WorkChain):
@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -575,4 +629,4 @@ def test_simple_kill_through_node(self):
self.assertEquals(future.process.calc.has_finished_ok(), False)
self.assertEquals(future.process.calc.has_failed(), False)
self.assertEquals(future.process.calc.has_aborted(), True)
engine.shutdown()
engine.shutdown()
40 changes: 25 additions & 15 deletions aiida/work/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def abort(self, msg=None, timeout=None):
self._aborted = True
self.stop()


def ToContext(**kwargs):
"""
Utility function that returns a list of UpdateContext Interstep instances
Expand All @@ -366,6 +367,7 @@ class _InterstepFactory(object):
Factory to create the appropriate Interstep instance based
on the class string that was written to the bundle
"""

def create(self, bundle):
class_string = bundle[Bundle.CLASS]
if class_string == get_class_string(ToContext):
Expand Down Expand Up @@ -567,22 +569,21 @@ class Stepper(Stepper):
def __init__(self, workflow, if_spec):
super(_If.Stepper, self).__init__(workflow)
self._if_spec = if_spec
self._pos = 0
self._pos = -1
self._current_stepper = None

def step(self):
if self._current_stepper is None:
stepper = self._get_next_stepper()
# If we can't get a stepper then no conditions match, return
if stepper is None:
return True, None
self._current_stepper = stepper
self._create_stepper()

# If we can't get a stepper then no conditions match, return
if self._current_stepper is None:
return True, None

finished, retval = self._current_stepper.step()
if finished:
self._current_stepper = None
else:
self._pos += 1
self._pos = -1

return finished, retval

Expand All @@ -596,15 +597,24 @@ def save_position(self, out_position):
def load_position(self, bundle):
self._pos = bundle[self._POSITION]
if self._STEPPER_POS in bundle:
self._current_stepper = self._get_next_stepper()
self._create_stepper()
self._current_stepper.load_position(bundle[self._STEPPER_POS])
else:
self._current_stepper = None

def _get_next_stepper(self):
# Check the conditions until we find that that is true
for conditional in self._if_spec.conditionals[self._pos:]:
if conditional.is_true(self._workflow):
return conditional.body.create_stepper(self._workflow)
return None
def _create_stepper(self):
if self._pos == -1:
self._current_stepper = None
# Check the conditions until we find one that is true
for idx, condition in enumerate(self._if_spec.conditionals):
if condition.is_true(self._workflow):
stepper = condition.body.create_stepper(self._workflow)
self._pos = idx
self._current_stepper = stepper
return
else:
branch = self._if_spec.conditionals[self._pos]
self._current_stepper = branch.body.create_stepper(self._workflow)

def __init__(self, condition):
super(_If, self).__init__()
Expand Down

0 comments on commit 666bce8

Please sign in to comment.