Skip to content

Commit

Permalink
fix bug where .serve exits early and add integration test (#15691)
Browse files Browse the repository at this point in the history
Co-authored-by: Alexander Streed <desertaxle@users.noreply.github.com>
  • Loading branch information
zzstoatzz and desertaxle authored Oct 15, 2024
1 parent 245d718 commit 3aa2d89
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 48 deletions.
11 changes: 5 additions & 6 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@ jobs:
fail-fast: false
matrix:
server-version: [
# These versions correspond to Prefect image tags, the patch version is
# excluded to always pull the latest patch of each minor version. The ref
# should generally be set to the latest patch release for that version.
{version: "2.19", ref: "2.19.2", image: "prefecthq/prefect:2.19-python3.10"},
{version: "main", ref: "main"},
]
# These versions correspond to Prefect image tags, the patch version is
# excluded to always pull the latest patch of each minor version. The ref
# should generally be set to the latest patch release for that version.
{ version: "main", ref: "main" },
]

steps:
- uses: actions/checkout@v4
Expand Down
3 changes: 2 additions & 1 deletion flows/flow_retries.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from prefect import flow, task
from prefect.cache_policies import INPUTS, TASK_SOURCE

flow_run_count = 0
task_run_count = 0


@task
@task(cache_policy=INPUTS + TASK_SOURCE)
def happy_task():
global task_run_count
task_run_count += 1
Expand Down
55 changes: 55 additions & 0 deletions flows/serve_a_flow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import signal
import tempfile
from datetime import timedelta
from pathlib import Path

from prefect import flow
from prefect.settings import PREFECT_RUNNER_POLL_FREQUENCY, temporary_settings


@flow
def may_i_take_your_hat_sir(item: str, counter_dir: Path):
assert item == "hat", "I don't know how to do everything"
(counter_dir / f"{id(may_i_take_your_hat_sir)}.txt").touch()
return f"May I take your {item}?"


def timeout_handler(signum, frame):
raise TimeoutError("Timeout reached. Shutting down gracefully.")


def count_runs(counter_dir: Path):
return len(list(counter_dir.glob("*.txt")))


if __name__ == "__main__":
TIMEOUT: int = 10
INTERVAL_SECONDS: int = 3

EXPECTED_N_FLOW_RUNS: int = TIMEOUT // INTERVAL_SECONDS

signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(TIMEOUT)

with tempfile.TemporaryDirectory() as tmp_dir:
counter_dir = Path(tmp_dir) / "flow_run_counter"
counter_dir.mkdir(exist_ok=True)

with temporary_settings({PREFECT_RUNNER_POLL_FREQUENCY: 1}):
try:
may_i_take_your_hat_sir.serve(
interval=timedelta(seconds=INTERVAL_SECONDS),
parameters={"item": "hat", "counter_dir": counter_dir},
)
except TimeoutError as e:
print(str(e))
finally:
signal.alarm(0)

actual_run_count = count_runs(counter_dir)

assert (
actual_run_count >= EXPECTED_N_FLOW_RUNS
), f"Expected at least {EXPECTED_N_FLOW_RUNS} flow runs, got {actual_run_count}"

print(f"Successfully completed and audited {actual_run_count} flow runs")
14 changes: 11 additions & 3 deletions flows/task_results.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,36 @@
import time
from uuid import UUID

import anyio

from prefect import flow, task
from prefect.client.orchestration import get_client
from prefect.states import State


@task(persist_result=True)
def hello():
def hello() -> str:
return "Hello!"


@flow
def wrapper_flow():
def wrapper_flow() -> State[str]:
return hello(return_state=True)


async def get_state_from_api(task_run_id):
async def get_state_from_api(task_run_id: UUID) -> State[str]:
async with get_client() as client:
task_run = await client.read_task_run(task_run_id)
assert task_run.state is not None
return task_run.state


if __name__ == "__main__":
task_state = wrapper_flow()
assert task_state.result() == "Hello!"
assert task_state.state_details.task_run_id is not None

time.sleep(0.3) # wait for task run state to propagate

api_state = anyio.run(get_state_from_api, task_state.state_details.task_run_id)

Expand Down
13 changes: 10 additions & 3 deletions scripts/run-integration-flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,13 @@

def run_script(script_path: str):
print(f" {script_path} ".center(90, "-"), flush=True)
result = subprocess.run(["python", script_path], capture_output=True, text=True)
return result.stdout, result.stderr
try:
result = subprocess.run(
["uv", "run", script_path], capture_output=True, text=True, check=True
)
return result.stdout, result.stderr, None
except subprocess.CalledProcessError as e:
return e.stdout, e.stderr, e


def run_flows(search_path: Union[str, Path]):
Expand All @@ -39,10 +44,12 @@ def run_flows(search_path: Union[str, Path]):
with ProcessPoolExecutor(max_workers=4) as executor:
results = list(executor.map(run_script, scripts))

for script, (stdout, stderr) in zip(scripts, results):
for script, (stdout, stderr, error) in zip(scripts, results):
print(f" {script.relative_to(search_path)} ".center(90, "-"), flush=True)
print(stdout)
print(stderr)
if error:
raise error
print("".center(90, "-") + "\n", flush=True)
count += 1

Expand Down
2 changes: 1 addition & 1 deletion src/prefect/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> No
@overload
def __call__(
self: "Flow[P, Coroutine[Any, Any, T]]", *args: P.args, **kwargs: P.kwargs
) -> Awaitable[T]:
) -> Coroutine[Any, Any, T]:
...

@overload
Expand Down
74 changes: 40 additions & 34 deletions src/prefect/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def fast_flow():
import sys
import tempfile
import threading
from contextlib import AsyncExitStack
from copy import deepcopy
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -186,7 +185,6 @@ def goodbye_flow(name):
self.query_seconds = query_seconds or PREFECT_RUNNER_POLL_FREQUENCY.value()
self._prefetch_seconds = prefetch_seconds

self._exit_stack = AsyncExitStack()
self._limiter: Optional[anyio.CapacityLimiter] = None
self._client = get_client()
self._submitting_flow_run_ids = set()
Expand Down Expand Up @@ -400,37 +398,40 @@ def goodbye_flow(name):
start_client_metrics_server()

async with self as runner:
for storage in self._storage_objs:
if storage.pull_interval:
self._runs_task_group.start_soon(
partial(
critical_service_loop,
workload=storage.pull_code,
interval=storage.pull_interval,
run_once=run_once,
jitter_range=0.3,
# This task group isn't included in the exit stack because we want to
# stay in this function until the runner is told to stop
async with self._loops_task_group as loops_task_group:
for storage in self._storage_objs:
if storage.pull_interval:
loops_task_group.start_soon(
partial(
critical_service_loop,
workload=storage.pull_code,
interval=storage.pull_interval,
run_once=run_once,
jitter_range=0.3,
)
)
else:
loops_task_group.start_soon(storage.pull_code)
loops_task_group.start_soon(
partial(
critical_service_loop,
workload=runner._get_and_submit_flow_runs,
interval=self.query_seconds,
run_once=run_once,
jitter_range=0.3,
)
else:
self._runs_task_group.start_soon(storage.pull_code)
self._runs_task_group.start_soon(
partial(
critical_service_loop,
workload=runner._get_and_submit_flow_runs,
interval=self.query_seconds,
run_once=run_once,
jitter_range=0.3,
)
)
self._runs_task_group.start_soon(
partial(
critical_service_loop,
workload=runner._check_for_cancelled_flow_runs,
interval=self.query_seconds * 2,
run_once=run_once,
jitter_range=0.3,
loops_task_group.start_soon(
partial(
critical_service_loop,
workload=runner._check_for_cancelled_flow_runs,
interval=self.query_seconds * 2,
run_once=run_once,
jitter_range=0.3,
)
)
)

def execute_in_background(self, func, *args, **kwargs):
"""
Expand Down Expand Up @@ -1265,16 +1266,14 @@ async def __aenter__(self):
if not hasattr(self, "_loop") or not self._loop:
self._loop = asyncio.get_event_loop()

await self._exit_stack.__aenter__()
await self._client.__aenter__()

await self._exit_stack.enter_async_context(self._client)
if not hasattr(self, "_runs_task_group") or not self._runs_task_group:
self._runs_task_group: anyio.abc.TaskGroup = anyio.create_task_group()
await self._exit_stack.enter_async_context(self._runs_task_group)
await self._runs_task_group.__aenter__()

if not hasattr(self, "_loops_task_group") or not self._loops_task_group:
self._loops_task_group: anyio.abc.TaskGroup = anyio.create_task_group()
await self._exit_stack.enter_async_context(self._loops_task_group)

self.started = True
return self
Expand All @@ -1284,9 +1283,16 @@ async def __aexit__(self, *exc_info):
if self.pause_on_shutdown:
await self._pause_schedules()
self.started = False

for scope in self._scheduled_task_scopes:
scope.cancel()
await self._exit_stack.__aexit__(*exc_info)

if self._runs_task_group:
await self._runs_task_group.__aexit__(*exc_info)

if self._client:
await self._client.__aexit__(*exc_info)

shutil.rmtree(str(self._tmp_dir))
del self._runs_task_group, self._loops_task_group

Expand Down

0 comments on commit 3aa2d89

Please sign in to comment.