Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Overhaul async request cancellation #7111

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions tests/async_engine/api_server_async_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
from typing import Any, Dict
from typing import Any, Dict, Iterable

import uvicorn
from fastapi.responses import JSONResponse, Response
Expand All @@ -18,9 +18,10 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0

async def abort(self, request_id: str) -> None:
await super().abort(request_id)
self._num_aborts += 1
async def _engine_abort(self, request_ids: Iterable[str]):
ids = list(request_ids)
self._num_aborts += len(ids)
await super()._engine_abort(ids)

def testing_stats(self) -> Dict[str, Any]:
return {"num_aborted_requests": self._num_aborts}
Expand Down
25 changes: 12 additions & 13 deletions tests/async_engine/test_request_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,23 @@ async def test_request_tracker():
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
assert not aborted
assert not stream_1.finished

stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
assert not finished
assert not aborted
assert not stream_2.finished
assert not stream_3.finished

Expand All @@ -36,19 +36,19 @@ async def test_request_tracker():
assert not tracker.new_requests_event.is_set()

tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "1" in finished
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "1" in aborted
assert not new
assert stream_1.finished

stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "4" in aborted
assert not new
assert stream_4.finished

Expand All @@ -57,10 +57,9 @@ async def test_request_tracker():
tracker.process_request_output(
RequestOutput("2", "output", [], [], [], finished=True))
await tracker.wait_for_new_requests()
new, finished = tracker.get_new_and_finished_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert not tracker.new_requests_event.is_set()
assert len(finished) == 1
assert "2" in finished
assert not aborted
assert len(new) == 1
assert new[0]["request_id"] == "5"
assert stream_2.finished
Expand Down
5 changes: 3 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import socket
import sys
from functools import partial
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
Tuple, TypeVar)

Expand Down Expand Up @@ -37,11 +38,11 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
yield f"item from iterator {idx}"
await asyncio.sleep(0.1)
except asyncio.CancelledError:
pass
print(f"iterator {idx} cancelled")

iterators = [mock_async_iterator(i) for i in range(3)]
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
*iterators)
*iterators, is_cancelled=partial(asyncio.sleep, 0, result=False))

async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
async for idx, output in generator:
Expand Down
Loading
Loading