Commit
Add run_flow_in_subprocess utility (#16802)
desertaxle authored Jan 28, 2025
1 parent 776d354 commit 10d2c75
Showing 2 changed files with 404 additions and 3 deletions.
119 changes: 119 additions & 0 deletions src/prefect/flow_engine.py
@@ -2,10 +2,13 @@

import asyncio
import logging
import multiprocessing
import multiprocessing.context
import os
import time
from contextlib import ExitStack, asynccontextmanager, contextmanager, nullcontext
from dataclasses import dataclass, field
from functools import wraps
from typing import (
    Any,
    AsyncGenerator,
@@ -37,9 +40,12 @@
from prefect.context import (
    AsyncClientContext,
    FlowRunContext,
    SettingsContext,
    SyncClientContext,
    TagsContext,
    get_settings_context,
    hydrated_context,
    serialize_context,
)
from prefect.exceptions import (
    Abort,
@@ -62,6 +68,8 @@
    should_persist_result,
)
from prefect.settings import PREFECT_DEBUG_MODE
from prefect.settings.context import get_current_settings
from prefect.settings.models.root import Settings
from prefect.states import (
    Failed,
    Pending,
@@ -83,6 +91,7 @@
from prefect.utilities.asyncutils import run_coro_as_sync
from prefect.utilities.callables import (
    call_with_parameters,
    cloudpickle_wrapped_call,
    get_call_parameters,
    parameters_to_args_kwargs,
)
@@ -1533,3 +1542,113 @@ def _flow_parameters(
    parameters = flow_run.parameters if flow_run else {}
    call_args, call_kwargs = parameters_to_args_kwargs(flow.fn, parameters)
    return get_call_parameters(flow.fn, call_args, call_kwargs)


def run_flow_in_subprocess(
    flow: "Flow[..., Any]",
    flow_run: "FlowRun | None" = None,
    parameters: dict[str, Any] | None = None,
    wait_for: Iterable[PrefectFuture[R]] | None = None,
    context: dict[str, Any] | None = None,
) -> multiprocessing.context.SpawnProcess:
    """
    Run a flow in a subprocess.

    Note the result of the flow will only be accessible if the flow is configured to
    persist its result.

    Args:
        flow: The flow to run.
        flow_run: The flow run object containing run metadata.
        parameters: The parameters to use when invoking the flow.
        wait_for: The futures to wait for before starting the flow.
        context: A serialized context to hydrate before running the flow. If not provided,
            the current context will be used. A serialized context should be provided if
            this function is called in a separate memory space from the parent run (e.g.
            in a subprocess or on another machine).

    Returns:
        A multiprocessing.context.SpawnProcess representing the process that is running the flow.
    """
    from prefect.flow_engine import run_flow

    @wraps(run_flow)
    def run_flow_with_env(
        *args: Any,
        env: dict[str, str] | None = None,
        **kwargs: Any,
    ):
        """
        Wrapper function to update environment variables and settings before running the flow.
        """
        engine_logger = logging.getLogger("prefect.engine")

        os.environ.update(env or {})
        settings_context = get_settings_context()
        # Create a new settings context with a new settings object to pick up the updated
        # environment variables
        with SettingsContext(
            profile=settings_context.profile,
            settings=Settings(),
        ):
            try:
                maybe_coro = run_flow(*args, **kwargs)
                if asyncio.iscoroutine(maybe_coro):
                    # This is running in a brand new process, so there won't be an existing
                    # event loop.
                    asyncio.run(maybe_coro)
            except Abort as abort_signal:
                abort_signal: Abort
                if flow_run:
                    msg = f"Execution of flow run '{flow_run.id}' aborted by orchestrator: {abort_signal}"
                else:
                    msg = f"Execution aborted by orchestrator: {abort_signal}"
                engine_logger.info(msg)
                exit(0)
            except Pause as pause_signal:
                pause_signal: Pause
                if flow_run:
                    msg = f"Execution of flow run '{flow_run.id}' is paused: {pause_signal}"
                else:
                    msg = f"Execution is paused: {pause_signal}"
                engine_logger.info(msg)
                exit(0)
            except Exception:
                if flow_run:
                    msg = f"Execution of flow run '{flow_run.id}' exited with unexpected exception"
                else:
                    msg = "Execution exited with unexpected exception"
                engine_logger.error(msg, exc_info=True)
                exit(1)
            except BaseException:
                if flow_run:
                    msg = f"Execution of flow run '{flow_run.id}' interrupted by base exception"
                else:
                    msg = "Execution interrupted by base exception"
                engine_logger.error(msg, exc_info=True)
                # Let the exit code be determined by the base exception type
                raise

    ctx = multiprocessing.get_context("spawn")

    context = context or serialize_context()

    process = ctx.Process(
        target=cloudpickle_wrapped_call(
            run_flow_with_env,
            env=get_current_settings().to_environment_variables(exclude_unset=True)
            | os.environ
            | {
                # TODO: make this a thing we can pass into the engine
                "PREFECT__ENABLE_CANCELLATION_AND_CRASHED_HOOKS": "false",
            },
            flow=flow,
            flow_run=flow_run,
            parameters=parameters,
            wait_for=wait_for,
            context=context,
        ),
    )
    process.start()

    return process
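
For reference, a minimal usage sketch of the new utility (not part of this commit): the example flow greet, its parameter, and the persist_result=True option are illustrative assumptions; only run_flow_in_subprocess and its signature come from the diff above.

from prefect import flow
from prefect.flow_engine import run_flow_in_subprocess

# Persist the result so it can be retrieved after the subprocess exits,
# per the note in the docstring above.
@flow(persist_result=True)
def greet(name: str) -> str:
    return f"Hello, {name}!"

if __name__ == "__main__":  # guard is needed because the child is started with the "spawn" method
    process = run_flow_in_subprocess(flow=greet, parameters={"name": "Marvin"})
    process.join()  # block until the flow run finishes
    print(process.exitcode)  # 0 on success, 1 on an unexpected exception (per the handler above)

Because the process is created with the "spawn" start method, the target call is wrapped with cloudpickle_wrapped_call so the flow (including one defined in __main__) can be serialized into a fresh interpreter, and run_flow_with_env rebuilds Settings from the forwarded environment variables so the child picks up the parent's Prefect configuration.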