Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt message arguments passing to process controller #6668

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
'importlib-metadata~=6.0',
'numpy~=1.21',
'paramiko~=3.0',
'plumpy~=0.22.3',
'plumpy',
'pgsu~=0.3.0',
'psutil~=5.6',
'psycopg[binary]~=3.0',
Expand Down Expand Up @@ -509,3 +509,7 @@ passenv =
AIIDA_TEST_WORKERS
commands = molecule {posargs:test}
"""

# FIXME: remove before merge
[tool.uv.sources]
plumpy = {git = "https://github.com/unkcpz/plumpy", rev = "560098c2b55a312b60884a0b8dfac97f6e8139d8"}
18 changes: 14 additions & 4 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,13 @@ def process_kill(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
control.kill_processes(
processes,
msg_text='Killed through `verdi process kill`',
all_entries=all_entries,
timeout=timeout,
wait=wait,
)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -371,8 +376,13 @@ def process_pause(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
control.pause_processes(
processes,
msg_text='Paused through `verdi process pause`',
all_entries=all_entries,
timeout=timeout,
wait=wait,
)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down
19 changes: 11 additions & 8 deletions src/aiida/engine/processes/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from __future__ import annotations

import collections
import concurrent
import concurrent.futures
import functools
import typing as t

import kiwipy
Expand Down Expand Up @@ -135,7 +136,7 @@ def play_processes(
def pause_processes(
processes: list[ProcessNode] | None = None,
*,
message: str = 'Paused through `aiida.engine.processes.control.pause_processes`',
msg_text: str = 'Paused through `aiida.engine.processes.control.pause_processes`',
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -164,13 +165,14 @@ def pause_processes(
return

controller = get_manager().get_process_controller()
_perform_actions(processes, controller.pause_process, 'pause', 'pausing', timeout, wait, msg=message)
action = functools.partial(controller.pause_process, msg_text=msg_text)
_perform_actions(processes, action, 'pause', 'pausing', timeout, wait)


def kill_processes(
processes: list[ProcessNode] | None = None,
*,
message: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
msg_text: str = 'Killed through `aiida.engine.processes.control.kill_processes`',
all_entries: bool = False,
timeout: float = 5.0,
wait: bool = False,
Expand Down Expand Up @@ -199,12 +201,13 @@ def kill_processes(
return

controller = get_manager().get_process_controller()
_perform_actions(processes, controller.kill_process, 'kill', 'killing', timeout, wait, msg=message)
action = functools.partial(controller.kill_process, msg_text=msg_text)
_perform_actions(processes, action, 'kill', 'killing', timeout, wait)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @unkcpz
in _perform_actions we have:

 future = action(process.pk, **kwargs)

Therefore I would suggest either putting everything inside functools.partial, so that it serves as we suggested and discussed (one action that would be triggered), or taking the msg_text out.
Right now, it's kind of hard to understand why, because the rest of the arguments are passed via **kwargs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, makes sense. I removed the kwargs and added the typing for the action argument as well.



def _perform_actions(
processes: list[ProcessNode],
action: t.Callable,
action: t.Callable[[int | None], kiwipy.Future],
infinitive: str,
present: str,
timeout: t.Optional[float] = None,
Expand All @@ -231,7 +234,7 @@ def _perform_actions(
continue

try:
future = action(process.pk, **kwargs)
future = action(process.pk)
except communications.UnroutableError:
LOGGER.error(f'Process<{process.pk}> is unreachable.')
else:
Expand All @@ -241,7 +244,7 @@ def _perform_actions(


def _resolve_futures(
futures: dict[concurrent.futures.Future, ProcessNode],
futures: dict[kiwipy.Future, ProcessNode],
infinitive: str,
present: str,
wait: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions src/aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@
if kwargs and not process_class.spec().inputs.dynamic:
raise ValueError(f'{function.__name__} does not support these kwargs: {kwargs.keys()}')

process = process_class(inputs=inputs, runner=runner)
process: Process = process_class(inputs=inputs, runner=runner)

# Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner.
# Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown
Expand All @@ -235,7 +235,7 @@
def kill_process(_num, _frame):
"""Send the kill signal to the process in the current scope."""
LOGGER.critical('runner received interrupt, killing process %s', process.pid)
result = process.kill(msg='Process was killed because the runner received an interrupt')
result = process.kill(msg_text='Process was killed because the runner received an interrupt')

Check warning on line 238 in src/aiida/engine/processes/functions.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/engine/processes/functions.py#L238

Added line #L238 was not covered by tests
return result

# Store the current handler on the signal such that it can be restored after process has terminated
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def load_instance_state(

self.node.logger.info(f'Loaded process<{self.node.pk}> from saved state')

def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Future]:
def kill(self, msg_text: str | None = None) -> Union[bool, plumpy.futures.Future]:
"""Kill the process and all the children calculations it called

:param msg: message
Expand All @@ -338,7 +338,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur

had_been_terminated = self.has_terminated()

result = super().kill(msg)
result = super().kill(msg_text)

# Only kill children if we could be killed ourselves
if result is not False and not had_been_terminated:
Expand All @@ -348,7 +348,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur
self.logger.info('no controller available to kill child<%s>', child.pk)
continue
try:
result = self.runner.controller.kill_process(child.pk, f'Killed by parent<{self.node.pk}>')
result = self.runner.controller.kill_process(child.pk, msg_text=f'Killed by parent<{self.node.pk}>')
result = asyncio.wrap_future(result) # type: ignore[arg-type]
if asyncio.isfuture(result):
killing.append(result)
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
LOGGER.warning('runner received interrupt, process %s already being killed', process_inited.pid)
return
LOGGER.critical('runner received interrupt, killing process %s', process_inited.pid)
process_inited.kill(msg='Process was killed because the runner received an interrupt')
process_inited.kill(msg_text='Process was killed because the runner received an interrupt')

Check warning on line 253 in src/aiida/engine/runners.py

View check run for this annotation

Codecov / codecov/patch

src/aiida/engine/runners.py#L253

Added line #L253 was not covered by tests

original_handler_int = signal.getsignal(signal.SIGINT)
original_handler_term = signal.getsignal(signal.SIGTERM)
Expand Down
10 changes: 4 additions & 6 deletions tests/engine/test_rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ async def do_pause():
assert result
assert calc_node.paused

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text='Sorry, you have to go mate')
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand All @@ -112,7 +111,7 @@ async def do_pause_play():
await asyncio.sleep(0.1)

pause_message = 'Take a seat'
pause_future = controller.pause_process(calc_node.pk, msg=pause_message)
pause_future = controller.pause_process(calc_node.pk, msg_text=pause_message)
future = await with_timeout(asyncio.wrap_future(pause_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert calc_node.paused
Expand All @@ -126,8 +125,7 @@ async def do_pause_play():
assert not calc_node.paused
assert calc_node.process_status is None

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text='Sorry, you have to go mate')
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand All @@ -145,7 +143,7 @@ async def do_kill():
await asyncio.sleep(0.1)

kill_message = 'Sorry, you have to go mate'
kill_future = controller.kill_process(calc_node.pk, msg=kill_message)
kill_future = controller.kill_process(calc_node.pk, msg_text=kill_message)
future = await with_timeout(asyncio.wrap_future(kill_future))
result = await self.wait_future(asyncio.wrap_future(future))
assert result
Expand Down
21 changes: 8 additions & 13 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading