Skip to content

Commit

Permalink
[Frontend][Core] Move merge_async_iterators to utils (vllm-project#…
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored Apr 12, 2024
1 parent 1096717 commit 7fd3949
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 39 deletions.
38 changes: 1 addition & 37 deletions vllm/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional, Tuple)
Expand All @@ -17,7 +16,7 @@
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import random_uuid
from vllm.utils import merge_async_iterators, random_uuid

logger = init_logger(__name__)

Expand Down Expand Up @@ -50,41 +49,6 @@ def parse_prompt_format(prompt) -> Tuple[bool, list]:
return prompt_is_tokens, prompts


def merge_async_iterators(*iterators):
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before
    others. When it yields, it yields a tuple ``(i, item)`` where ``i``
    is the index of the iterator that produced ``item``.

    If any source iterator raises, the exception is forwarded through
    the queue and re-raised in the consumer.
    """
    # Single shared queue through which every producer task funnels its
    # items (or its exception) to the one consumer generator.
    queue = asyncio.Queue()

    # finished[i] flips to True once producer i has exhausted (or failed
    # on) its iterator; the consumer polls this to know when to stop.
    finished = [False] * len(iterators)

    async def producer(i, iterator):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            # Forward the failure to the consumer instead of letting it
            # die silently inside the background task.
            await queue.put(e)
        finished[i] = True

    # One background task per source iterator, started eagerly.
    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        # NOTE(review): if a producer finishes without ever putting an
        # item (e.g. an empty iterator), `not all(finished)` can be
        # observed as True while the queue stays empty forever, leaving
        # this `queue.get()` blocked indefinitely — verify callers never
        # pass empty iterators, or switch to a per-producer sentinel.
        while not all(finished) or not queue.empty():
            item = await queue.get()
            if isinstance(item, Exception):
                raise item
            yield item
        # Propagate any task failure that bypassed the queue.
        await asyncio.gather(*_tasks)

    return consumer()


class OpenAIServingCompletion(OpenAIServing):

def __init__(self,
Expand Down
40 changes: 38 additions & 2 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from collections import OrderedDict, defaultdict
from functools import lru_cache, partial
from platform import uname
from typing import (Any, Awaitable, Callable, Dict, Generic, Hashable, List,
Optional, Tuple, TypeVar, Union)
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, Tuple, TypeVar, Union)

import psutil
import torch
Expand Down Expand Up @@ -181,6 +181,42 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future:
return _async_wrapper


def merge_async_iterators(
        *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before
    others. When it yields, it yields a tuple ``(i, item)`` where ``i``
    is the index of the iterator that produced ``item``.

    If any source iterator raises, the exception is re-raised in the
    merged iterator's consumer.
    """
    # Each producer funnels its items (or its exception) through this
    # shared queue; `_sentinel` marks that a producer has completed.
    queue: asyncio.Queue = asyncio.Queue()
    _sentinel = object()

    async def producer(i: int, iterator: AsyncIterator[T]):
        try:
            async for item in iterator:
                await queue.put((i, item))
        except Exception as e:
            # Forward the failure to the consumer instead of losing it
            # inside the background task.
            await queue.put(e)
        finally:
            # Always signal completion — even for an empty iterator or
            # on failure — so the consumer can never block forever on
            # `queue.get()`. (The previous `all(finished)` polling raced
            # with producers that finished without producing any item,
            # deadlocking the consumer.)
            await queue.put(_sentinel)

    # One background task per source iterator, started eagerly.
    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        # Drain the queue until every producer has posted its sentinel.
        remaining = len(iterators)
        while remaining:
            item = await queue.get()
            if item is _sentinel:
                remaining -= 1
            elif isinstance(item, Exception):
                raise item
            else:
                yield item
        # Surface any task failure that bypassed the queue
        # (e.g. cancellation).
        await asyncio.gather(*_tasks)

    return consumer()


def get_ip() -> str:
host_ip = os.environ.get("HOST_IP")
if host_ip:
Expand Down

0 comments on commit 7fd3949

Please sign in to comment.