Skip to content

Commit

Permalink
Fixes 902
Browse files Browse the repository at this point in the history
There was a mistake in the logic of the 'if' block of workchains when
loading.  A position varaiable was being used to keep track of how many
times that conditional block was ticked.  This was being upped each
time but also being used to determine which of possible conditional
branches the condition was at e.g.:

if(...)( <-- pos 0

) elif(...) <-- pos 1

) else(

)

Long story short, when the condition was reloaded from a saved state
it was possible that pos was larger than the number of conditions
(usual just one if there's only an if) and it couldn't resume
from where it was when it was saved.
  • Loading branch information
muhrin authored and sphuber committed Nov 9, 2017
1 parent 318a6e5 commit f250f86
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 18 deletions.
60 changes: 57 additions & 3 deletions aiida/backends/tests/work/workChain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from aiida.backends.testbase import AiidaTestCase
from plum.engine.ticking import TickingEngine
from plum.persistence.bundle import Bundle
import plum.process_monitor
from aiida.orm.calculation.work import WorkCalculation
from aiida.work.workchain import WorkChain, \
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self):
[self.s1.__name__, self.s2.__name__, self.s3.__name__,
self.s4.__name__, self.s5.__name__, self.s6.__name__,
self.isA.__name__, self.isB.__name__, self.ltN.__name__]
}
}

def s1(self):
self._set_finished(inspect.stack()[0][3])
Expand Down Expand Up @@ -120,6 +121,33 @@ def test_dict(self):
c['new_attr']


class IfTest(WorkChain):
@classmethod
def define(cls, spec):
super(IfTest, cls).define(spec)
spec.outline(
if_(cls.condition)(
cls.step1,
cls.step2
)
)

def on_create(self, pid, inputs, saved_state):
super(IfTest, self).on_create(pid, inputs, saved_state)
if saved_state is None:
self.ctx.s1 = False
self.ctx.s2 = False

def condition(self):
return True

def step1(self):
self.ctx.s1 = True

def step2(self):
self.ctx.s2 = True


class TestWorkchain(AiidaTestCase):
def setUp(self):
super(TestWorkchain, self).setUp()
Expand Down Expand Up @@ -321,6 +349,29 @@ def run(self):

run(MainWorkChain)

def test_if_block_persistence(self):
""" This test was created to capture issue #902 """
wc = IfTest.new_instance()

while not wc.ctx.s1 and not wc.has_finished():
wc.tick()
self.assertTrue(wc.ctx.s1)
self.assertFalse(wc.ctx.s2)

# Now bundle the thing
b = Bundle()
wc.save_instance_state(b)
# Abort the current one
wc.stop()
wc.destroy(execute=True)

# Load from saved tate
wc = IfTest.create_from(b)
self.assertTrue(wc.ctx.s1)
self.assertFalse(wc.ctx.s2)

wc.run_until_complete()

def _run_with_checkpoints(self, wf_class, inputs=None):
finished_steps = {}

Expand All @@ -336,7 +387,6 @@ def _run_with_checkpoints(self, wf_class, inputs=None):


class TestWorkchainWithOldWorkflows(AiidaTestCase):

def setUp(self):
super(TestWorkchainWithOldWorkflows, self).setUp()
import logging
Expand Down Expand Up @@ -409,10 +459,12 @@ def test_get_proc_outputs(self):
self.assertEquals(outputs['a'], a)
self.assertEquals(outputs['b'], b)


class TestWorkChainAbort(AiidaTestCase):
"""
Test the functionality to abort a workchain
"""

class AbortableWorkChain(WorkChain):
@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -490,11 +542,13 @@ def test_simple_kill_through_process(self):
self.assertEquals(future.process.calc.has_aborted(), True)
engine.shutdown()


class TestWorkChainAbortChildren(AiidaTestCase):
"""
Test the functionality to abort a workchain and verify that children
are also aborted appropriately
"""

class SubWorkChain(WorkChain):
@classmethod
def define(cls, spec):
Expand Down Expand Up @@ -575,4 +629,4 @@ def test_simple_kill_through_node(self):
self.assertEquals(future.process.calc.has_finished_ok(), False)
self.assertEquals(future.process.calc.has_failed(), False)
self.assertEquals(future.process.calc.has_aborted(), True)
engine.shutdown()
engine.shutdown()
40 changes: 25 additions & 15 deletions aiida/work/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def abort(self, msg=None, timeout=None):
self._aborted = True
self.stop()


def ToContext(**kwargs):
"""
Utility function that returns a list of UpdateContext Interstep instances
Expand All @@ -366,6 +367,7 @@ class _InterstepFactory(object):
Factory to create the appropriate Interstep instance based
on the class string that was written to the bundle
"""

def create(self, bundle):
class_string = bundle[Bundle.CLASS]
if class_string == get_class_string(ToContext):
Expand Down Expand Up @@ -567,22 +569,21 @@ class Stepper(Stepper):
def __init__(self, workflow, if_spec):
super(_If.Stepper, self).__init__(workflow)
self._if_spec = if_spec
self._pos = 0
self._pos = -1
self._current_stepper = None

def step(self):
if self._current_stepper is None:
stepper = self._get_next_stepper()
# If we can't get a stepper then no conditions match, return
if stepper is None:
return True, None
self._current_stepper = stepper
self._create_stepper()

# If we can't get a stepper then no conditions match, return
if self._current_stepper is None:
return True, None

finished, retval = self._current_stepper.step()
if finished:
self._current_stepper = None
else:
self._pos += 1
self._pos = -1

return finished, retval

Expand All @@ -596,15 +597,24 @@ def save_position(self, out_position):
def load_position(self, bundle):
self._pos = bundle[self._POSITION]
if self._STEPPER_POS in bundle:
self._current_stepper = self._get_next_stepper()
self._create_stepper()
self._current_stepper.load_position(bundle[self._STEPPER_POS])
else:
self._current_stepper = None

def _get_next_stepper(self):
# Check the conditions until we find that that is true
for conditional in self._if_spec.conditionals[self._pos:]:
if conditional.is_true(self._workflow):
return conditional.body.create_stepper(self._workflow)
return None
def _create_stepper(self):
if self._pos == -1:
self._current_stepper = None
# Check the conditions until we find one that is true
for idx, condition in enumerate(self._if_spec.conditionals):
if condition.is_true(self._workflow):
stepper = condition.body.create_stepper(self._workflow)
self._pos = idx
self._current_stepper = stepper
return
else:
branch = self._if_spec.conditionals[self._pos]
self._current_stepper = branch.body.create_stepper(self._workflow)

def __init__(self, condition):
super(_If, self).__init__()
Expand Down

0 comments on commit f250f86

Please sign in to comment.