Skip to content

Commit 8f1d693

Browse files
DarkLight1337dtrifiro
authored andcommitted
[Bugfix][CI/Build] Fix test and improve code for merge_async_iterators (vllm-project#5096)
1 parent 07a7677 commit 8f1d693

File tree

3 files changed

+62
-45
lines changed

3 files changed

+62
-45
lines changed

tests/async_engine/test_merge_async_iterators.py

-41
This file was deleted.

tests/test_utils.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,64 @@
1+
import asyncio
2+
import sys
3+
from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol,
4+
Tuple, TypeVar)
5+
16
import pytest
27

3-
from vllm.utils import deprecate_kwargs
8+
from vllm.utils import deprecate_kwargs, merge_async_iterators
49

510
from .utils import error_on_warning
611

12+
if sys.version_info < (3, 10):
13+
if TYPE_CHECKING:
14+
_AwaitableT = TypeVar("_AwaitableT", bound=Awaitable[Any])
15+
_AwaitableT_co = TypeVar("_AwaitableT_co",
16+
bound=Awaitable[Any],
17+
covariant=True)
18+
19+
class _SupportsSynchronousAnext(Protocol[_AwaitableT_co]):
20+
21+
def __anext__(self) -> _AwaitableT_co:
22+
...
23+
24+
def anext(i: "_SupportsSynchronousAnext[_AwaitableT]", /) -> "_AwaitableT":
25+
return i.__anext__()
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_merge_async_iterators():
30+
31+
async def mock_async_iterator(idx: int) -> AsyncIterator[str]:
32+
try:
33+
while True:
34+
yield f"item from iterator {idx}"
35+
await asyncio.sleep(0.1)
36+
except asyncio.CancelledError:
37+
pass
38+
39+
iterators = [mock_async_iterator(i) for i in range(3)]
40+
merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators(
41+
*iterators)
42+
43+
async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
44+
async for idx, output in generator:
45+
print(f"idx: {idx}, output: {output}")
46+
47+
task = asyncio.create_task(stream_output(merged_iterator))
48+
await asyncio.sleep(0.5)
49+
task.cancel()
50+
with pytest.raises(asyncio.CancelledError):
51+
await task
52+
53+
for iterator in iterators:
54+
try:
55+
await asyncio.wait_for(anext(iterator), 1)
56+
except StopAsyncIteration:
57+
# All iterators should be cancelled and print this message.
58+
print("Iterator was cancelled normally")
59+
except (Exception, asyncio.CancelledError) as e:
60+
raise AssertionError() from e
61+
762

863
def test_deprecate_kwargs_always():
964

vllm/utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
import socket
77
import subprocess
8+
import sys
89
import tempfile
910
import threading
1011
import uuid
@@ -234,9 +235,11 @@ async def consumer():
234235
yield item
235236
except (Exception, asyncio.CancelledError) as e:
236237
for task in _tasks:
237-
# NOTE: Pass the error msg in cancel()
238-
# when only Python 3.9+ is supported.
239-
task.cancel()
238+
if sys.version_info >= (3, 9):
239+
# msg parameter only supported in Python 3.9+
240+
task.cancel(e)
241+
else:
242+
task.cancel()
240243
raise e
241244
await asyncio.gather(*_tasks)
242245

0 commit comments

Comments
 (0)