diff --git a/docs/source/tutorial.ipynb b/docs/source/tutorial.ipynb index 90194728..2da8bf87 100644 --- a/docs/source/tutorial.ipynb +++ b/docs/source/tutorial.ipynb @@ -117,7 +117,7 @@ "source": [ "class SimpleProcess(plumpy.Process):\n", "\n", - " def run(self):\n", + " async def run(self):\n", " print(self.state.name)\n", " \n", "process = SimpleProcess()\n", @@ -219,7 +219,7 @@ " spec.output('output2.output2a')\n", " spec.output('output2.output2b')\n", "\n", - " def run(self):\n", + " async def run(self):\n", " self.out('output1', self.inputs.input1)\n", " self.out('output2.output2a', self.inputs.input2.input2a)\n", " self.out('output2.output2b', self.inputs.input2.input2b)\n", @@ -277,7 +277,7 @@ "source": [ "class ContinueProcess(plumpy.Process):\n", "\n", - " def run(self):\n", + " async def run(self):\n", " print(\"running\")\n", " return plumpy.Continue(self.continue_fn)\n", " \n", @@ -340,7 +340,7 @@ "\n", "class WaitProcess(plumpy.Process):\n", "\n", - " def run(self):\n", + " async def run(self):\n", " return plumpy.Wait(self.resume_fn)\n", " \n", " def resume_fn(self):\n", @@ -405,7 +405,7 @@ " super().define(spec)\n", " spec.input('name')\n", "\n", - " def run(self):\n", + " async def run(self):\n", " print(self.inputs.name, \"run\")\n", " return plumpy.Continue(self.continue_fn)\n", "\n", @@ -469,12 +469,12 @@ "source": [ "class SimpleProcess(plumpy.Process):\n", " \n", - " def run(self):\n", + " async def run(self):\n", " print(self.get_name())\n", " \n", "class PauseProcess(plumpy.Process):\n", "\n", - " def run(self):\n", + " async def run(self):\n", " print(f\"{self.get_name()}: pausing\")\n", " self.pause()\n", " print(f\"{self.get_name()}: continue step\")\n", @@ -727,7 +727,7 @@ " spec.input('name', valid_type=str, default='process')\n", " spec.output('value')\n", "\n", - " def run(self):\n", + " async def run(self):\n", " print(self.inputs.name)\n", " self.out('value', 'value')\n", "\n", diff --git a/examples/process_helloworld.py b/examples/process_helloworld.py index cf043eba..23a2929b 100644 --- a/examples/process_helloworld.py +++ b/examples/process_helloworld.py @@ -10,7 +10,7 @@ def define(cls, spec): spec.input('name', default='World', required=True) spec.output('greeting', valid_type=str) - def run(self): + async def run(self): self.out('greeting', f'Hello {self.inputs.name}!') return plumpy.Stop(None, True) diff --git a/examples/process_launch.py b/examples/process_launch.py index 645af0fd..b3a212f6 100644 --- a/examples/process_launch.py +++ b/examples/process_launch.py @@ -18,7 +18,7 @@ def define(cls, spec): spec.outputs.dynamic = True spec.output('default', valid_type=int) - def run(self): + async def run(self): self.out('default', 5) diff --git a/examples/process_wait_and_resume.py b/examples/process_wait_and_resume.py index 03e8b57a..008ea4fb 100644 --- a/examples/process_wait_and_resume.py +++ b/examples/process_wait_and_resume.py @@ -6,7 +6,7 @@ class WaitForResumeProc(plumpy.Process): - def run(self): + async def run(self): print(f'Now I am running: {self.state}') return plumpy.Wait(self.after_resume_and_exec) diff --git a/src/plumpy/futures.py b/src/plumpy/futures.py index 365b8008..2de36ff1 100644 --- a/src/plumpy/futures.py +++ b/src/plumpy/futures.py @@ -3,7 +3,7 @@ Module containing future related methods and classes """ import asyncio -from typing import Any, Callable, Coroutine, Optional +from typing import Any, Awaitable, Callable, Optional import kiwipy @@ -54,7 +54,7 @@ def run(self, *args: Any, **kwargs: Any) -> None: self._action = None # type: ignore -def create_task(coro: Callable[[], Coroutine], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: +def create_task(coro: Callable[[], Awaitable[Any]], loop: Optional[asyncio.AbstractEventLoop] = None) -> Future: """ Schedule a call to a coro in the event loop and wrap the outcome in a future. diff --git a/src/plumpy/process_states.py b/src/plumpy/process_states.py index da91b506..fcd57742 100644 --- a/src/plumpy/process_states.py +++ b/src/plumpy/process_states.py @@ -4,7 +4,7 @@ import sys import traceback from types import TracebackType -from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, Union, cast import yaml from yaml.loader import Loader @@ -20,7 +20,7 @@ from .base import state_machine from .lang import NULL from .persistence import auto_persist -from .utils import SAVED_STATE_TYPE +from .utils import SAVED_STATE_TYPE, ensure_coroutine __all__ = [ 'ProcessState', @@ -195,10 +195,12 @@ class Running(State): _running: bool = False _run_handle = None - def __init__(self, process: 'Process', run_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> None: + def __init__( + self, process: 'Process', run_fn: Callable[..., Union[Awaitable[Any], Any]], *args: Any, **kwargs: Any + ) -> None: super().__init__(process) assert run_fn is not None - self.run_fn = run_fn + self.run_fn = ensure_coroutine(run_fn) self.args = args self.kwargs = kwargs self._run_handle = None @@ -211,7 +213,7 @@ def save_instance_state(self, out_state: SAVED_STATE_TYPE, save_context: persist def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persistence.LoadSaveContext) -> None: super().load_instance_state(saved_state, load_context) - self.run_fn = getattr(self.process, saved_state[self.RUN_FN]) + self.run_fn = ensure_coroutine(getattr(self.process, saved_state[self.RUN_FN])) if self.COMMAND in saved_state: self._command = persistence.Savable.load(saved_state[self.COMMAND], load_context) # type: ignore @@ -225,7 +227,7 @@ async def execute(self) -> State: # type: ignore # pylint: disable=invalid-over try: try: self._running = True - result = self.run_fn(*self.args, **self.kwargs) + result = await self.run_fn(*self.args, **self.kwargs) finally: self._running = False except Interruption: diff --git a/src/plumpy/processes.py b/src/plumpy/processes.py index 40b2ccbb..2285955c 100644 --- a/src/plumpy/processes.py +++ b/src/plumpy/processes.py @@ -1182,7 +1182,7 @@ def recreate_state(self, saved_state: persistence.Bundle) -> process_states.Stat # region Execution related methods - def run(self) -> Any: + async def run(self) -> Any: """This function will be run when the process is triggered. It should be overridden by a subclass. """ diff --git a/src/plumpy/utils.py b/src/plumpy/utils.py index a11ebd01..4b5863d3 100644 --- a/src/plumpy/utils.py +++ b/src/plumpy/utils.py @@ -7,8 +7,20 @@ import inspect import logging import types -from typing import Set # pylint: disable=unused-import -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterator, List, MutableMapping, Optional, Tuple, Type +from typing import ( # pylint: disable=unused-import + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Hashable, + Iterator, + List, + MutableMapping, + Optional, + Set, + Tuple, + Type, +) from . import lang from .settings import check_override, check_protected @@ -221,7 +233,7 @@ def type_check(obj: Any, expected_type: Type) -> None: raise TypeError(f"Got object of type '{type(obj)}' when expecting '{expected_type}'") -def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Any]: +def ensure_coroutine(coro_or_fn: Any) -> Callable[..., Awaitable[Any]]: """ Ensure that the given function ``fct`` is a coroutine diff --git a/src/plumpy/workchains.py b/src/plumpy/workchains.py index 90e35482..95e192b1 100644 --- a/src/plumpy/workchains.py +++ b/src/plumpy/workchains.py @@ -156,7 +156,7 @@ def to_context(self, **kwargs: Union[asyncio.Future, processes.Process]) -> None awaitable = awaitable.future() self._awaitables[awaitable] = key - def run(self) -> Any: + async def run(self) -> Any: return self._do_step() def _do_step(self) -> Any: diff --git a/test/test_process_comms.py b/test/test_process_comms.py index 6d3d335c..1c6f4dcb 100644 --- a/test/test_process_comms.py +++ b/test/test_process_comms.py @@ -12,7 +12,7 @@ class Process(plumpy.Process): - def run(self): + async def run(self): pass diff --git a/test/test_processes.py b/test/test_processes.py index 737b463d..aa283b76 100644 --- a/test/test_processes.py +++ b/test/test_processes.py @@ -327,7 +327,7 @@ def test_logging(self): class LoggerTester(Process): - def run(self, **kwargs): + async def run(self, **kwargs): self.logger.info('Test') # TODO: Test giving a custom logger to see if it gets used @@ -442,7 +442,7 @@ def test_kill_in_run(self): class KillProcess(Process): after_kill = False - def run(self, **kwargs): + async def run(self, **kwargs): self.kill('killed') # The following line should be executed because kill will not # interrupt execution of a method call in the RUNNING state @@ -459,7 +459,7 @@ def test_kill_when_paused_in_run(self): class PauseProcess(Process): - def run(self, **kwargs): + async def run(self, **kwargs): self.pause() self.kill() @@ -513,7 +513,7 @@ def test_invalid_output(self): class InvalidOutput(plumpy.Process): - def run(self): + async def run(self): self.out('invalid', 5) proc = InvalidOutput() @@ -541,7 +541,7 @@ class Proc(Process): def define(cls, spec): super().define(spec) - def run(self): + async def run(self): return plumpy.UnsuccessfulResult(ERROR_CODE) proc = Proc() @@ -555,7 +555,7 @@ def test_pause_in_process(self): class TestPausePlay(plumpy.Process): - def run(self): + async def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) @@ -580,7 +580,7 @@ def test_pause_play_in_process(self): class TestPausePlay(plumpy.Process): - def run(self): + async def run(self): fut = self.pause() test_case.assertIsInstance(fut, plumpy.Future) result = self.play() @@ -597,7 +597,7 @@ def test_process_stack(self): class StackTest(plumpy.Process): - def run(self): + async def run(self): test_case.assertIs(self, Process.current()) proc = StackTest() @@ -614,7 +614,7 @@ def test_nested(process): class StackTest(plumpy.Process): - def run(self): + async def run(self): # TODO: unexpected behaviour here # if assert error happend here not raise # it will be handled by try except clause in process @@ -624,7 +624,7 @@ def run(self): class ParentProcess(plumpy.Process): - def run(self): + async def run(self): expect_true.append(self == Process.current()) StackTest().execute() @@ -647,12 +647,12 @@ def test_process_nested(self): class StackTest(plumpy.Process): - def run(self): + async def run(self): pass class ParentProcess(plumpy.Process): - def run(self): + async def run(self): StackTest().execute() ParentProcess().execute() @@ -661,7 +661,7 @@ def test_call_soon(self): class CallSoon(plumpy.Process): - def run(self): + async def run(self): self.call_soon(self.do_except) def do_except(self): @@ -699,7 +699,7 @@ def test_exception_during_run(self): class RaisingProcess(Process): - def run(self): + async def run(self): raise RuntimeError('exception during run') process = RaisingProcess() @@ -719,7 +719,7 @@ def init(self): super().init() self.steps_ran = [] - def run(self): + async def run(self): self.pause() self.steps_ran.append(self.run.__name__) return plumpy.Continue(self.step2) @@ -811,6 +811,7 @@ def test_saving_each_step(self): saver = utils.ProcessSaver(proc) saver.capture() self.assertEqual(proc.state, ProcessState.FINISHED) + print(proc) self.assertTrue(utils.check_process_against_snapshots(loop, proc_class, saver.snapshots)) def test_restart(self): @@ -980,7 +981,7 @@ def define(cls, spec): spec.output('required_bool', valid_type=bool) spec.output_namespace(namespace, valid_type=int, dynamic=True) - def run(self): + async def run(self): if self.inputs.output_mode == OutputMode.NONE: pass elif self.inputs.output_mode == OutputMode.DYNAMIC_PORT_NAMESPACE: diff --git a/test/test_workchains.py b/test/test_workchains.py index 71cd0f6a..3c449309 100644 --- a/test/test_workchains.py +++ b/test/test_workchains.py @@ -205,7 +205,7 @@ def define(cls, spec): super().define(spec) spec.output('res') - def run(self): + async def run(self): self.out('res', A) class ReturnB(plumpy.Process): @@ -215,7 +215,7 @@ def define(cls, spec): super().define(spec) spec.output('res') - def run(self): + async def run(self): self.out('res', B) class Wf(WorkChain): @@ -362,7 +362,7 @@ def define(cls, spec): spec.outline(cls.run, cls.check) spec.outputs.dynamic = True - def run(self): + async def run(self): return ToContext(subwc=self.launch(SubWorkChain)) def check(self): @@ -375,7 +375,7 @@ def define(cls, spec): super().define(spec) spec.outline(cls.run) - def run(self): + async def run(self): self.out('value', 5) workchain = MainWorkChain() @@ -419,7 +419,7 @@ def define(cls, spec): super().define(spec) spec.output('_return') - def run(self): + async def run(self): self.out('_return', val) class Workchain(WorkChain): diff --git a/test/utils.py b/test/utils.py index feb3d1c8..65d5355d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -40,7 +40,7 @@ def define(cls, spec): spec.outputs.dynamic = True spec.output('default', valid_type=int) - def run(self, **kwargs): + async def run(self, **kwargs): self.out('default', 5) @@ -53,21 +53,21 @@ def define(cls, spec): spec.inputs.dynamic = True spec.outputs.dynamic = True - def run(self, **kwargs): + async def run(self, **kwargs): self.out('default', 5) class KeyboardInterruptProc(processes.Process): @utils.override - def run(self): + async def run(self): raise KeyboardInterrupt() class ProcessWithCheckpoint(processes.Process): @utils.override - def run(self): + async def run(self): return process_states.Continue(self.last_step) def last_step(self): @@ -77,7 +77,7 @@ def last_step(self): class WaitForSignalProcess(processes.Process): @utils.override - def run(self): + async def run(self): return process_states.Wait(self.last_step) def last_step(self): @@ -87,7 +87,7 @@ def last_step(self): class KillProcess(processes.Process): @utils.override - def run(self): + async def run(self): return process_states.Kill('killed') @@ -171,7 +171,7 @@ def define(cls, spec): super().define(spec) spec.outputs.dynamic = True - def run(self): + async def run(self): self.out('test', 5) @@ -181,7 +181,7 @@ class ThreeSteps(ProcessEventsTester): _last_checkpoint = None @utils.override - def run(self): + async def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) @@ -194,7 +194,7 @@ def last_step(self): class TwoCheckpointNoFinish(ProcessEventsTester): - def run(self): + async def run(self): self.out('test', 5) return process_states.Continue(self.middle_step) @@ -204,7 +204,7 @@ def middle_step(self): class ExceptionProcess(ProcessEventsTester): - def run(self): + async def run(self): self.out('test', 5) raise RuntimeError('Great scott!')