Skip to content

Commit 99d69d2

Browse files
committed
fix may_trigger for nested and async
see #594
1 parent c72e801 commit 99d69d2

File tree

5 files changed

+22
-4
lines changed

5 files changed

+22
-4
lines changed

tests/test_async.py

+15
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,21 @@ async def run():
498498

499499
asyncio.run(run())
500500

501+
def test_may_transition_internal(self):
502+
states = ['A', 'B', 'C']
503+
d = DummyModel()
504+
_ = self.machine_cls(model=d, states=states, transitions=[["go", "A", "B"], ["wait", "B", None]],
505+
initial='A', auto_transitions=False)
506+
507+
async def run():
508+
assert await d.may_go()
509+
assert not await d.may_wait()
510+
await d.go()
511+
assert not await d.may_go()
512+
assert await d.may_wait()
513+
514+
asyncio.run(run())
515+
501516

502517
@skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz")
503518
class TestAsyncGraphMachine(TestAsync):

tests/test_nesting.py

+3
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,7 @@ def test_machine_may_transitions(self):
910910
transitions = [
911911
{'trigger': 'walk', 'source': 'A', 'dest': 'B'},
912912
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
913+
{'trigger': 'wait', 'source': 'B', 'dest': None},
913914
{'trigger': 'run_fast', 'source': 'C', 'dest': 'C{0}1'.format(self.separator)},
914915
{'trigger': 'sprint', 'source': 'C', 'dest': 'D'}
915916
]
@@ -920,11 +921,13 @@ def test_machine_may_transitions(self):
920921
assert not m.may_run()
921922
assert not m.may_run_fast()
922923
assert not m.may_sprint()
924+
assert not m.may_wait()
923925

924926
m.walk()
925927
assert not m.may_walk()
926928
assert m.may_run()
927929
assert not m.may_run_fast()
930+
assert m.may_wait()
928931

929932
m.run()
930933
assert m.may_run_fast()

transitions/core.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -884,7 +884,7 @@ def _can_trigger(self, model, trigger, *args, **kwargs):
884884
continue
885885
for transition in self.events[trigger_name].transitions[state]:
886886
try:
887-
_ = transition.source if transition.dest is None else self.get_state(transition.dest)
887+
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
888888
except ValueError:
889889
continue
890890

transitions/extensions/asyncio.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ async def _can_trigger(self, model, trigger, *args, **kwargs):
433433
continue
434434
for transition in self.events[trigger_name].transitions[state]:
435435
try:
436-
_ = self.get_state(transition.dest)
436+
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
437437
except ValueError:
438438
continue
439439
await self.callbacks(self.prepare_event, evt)
@@ -559,7 +559,7 @@ async def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
559559
state_name = self.state_cls.separator.join(source_path)
560560
for transition in self.events[trigger].transitions.get(state_name, []):
561561
try:
562-
_ = self.get_state(transition.dest)
562+
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
563563
except ValueError:
564564
continue
565565
await self.callbacks(self.prepare_event, evt)

transitions/extensions/nesting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -686,7 +686,7 @@ def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
686686
state_name = self.state_cls.separator.join(source_path)
687687
for transition in self.events[trigger].transitions.get(state_name, []):
688688
try:
689-
_ = self.get_state(transition.dest)
689+
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
690690
except ValueError:
691691
continue
692692
self.callbacks(self.prepare_event, evt)

0 commit comments

Comments
 (0)