Exploit vLLM options to return deltas/final-output only
This yields a nontrivial performance benefit, particularly when running with a decoupled front-end process.

These changes require vLLM >= 0.6.1.post2.
njhill authored and dtrifiro committed Sep 27, 2024
1 parent aa0a81b commit 00ac2f1
Showing 2 changed files with 44 additions and 30 deletions.
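
For context, a minimal sketch of the mechanism the commit message refers to (the consume() helper and its variable names are illustrative, not part of the change; the only APIs assumed are SamplingParams, RequestOutputKind, and AsyncLLMEngine.generate as they exist in vLLM >= 0.6.1.post2):

from vllm.sampling_params import RequestOutputKind, SamplingParams

# Unary path (Generate): ask the engine for a single RequestOutput holding the
# final completion instead of a cumulative snapshot after every step.
unary_params = SamplingParams(max_tokens=64)
unary_params.output_kind = RequestOutputKind.FINAL_ONLY

# Streaming path (GenerateStream): ask the engine to emit only the text and
# token ids generated since the previous RequestOutput, so the adapter no
# longer has to slice cumulative outputs with text/token start offsets.
stream_params = SamplingParams(max_tokens=64)
stream_params.output_kind = RequestOutputKind.DELTA

async def consume(engine, prompt, request_id):
    # With DELTA outputs the caller accumulates the totals itself, mirroring
    # generated_token_count and full_output in GenerateStream below.
    full_text, n_tokens = "", 0
    async for result in engine.generate(prompt, stream_params, request_id):
        output = result.outputs[0]
        full_text += output.text
        n_tokens += len(output.token_ids)
    return full_text, n_tokens
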
68 changes: 39 additions & 29 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -17,10 +17,10 @@
from grpc._cython.cygrpc import AbortError
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection
from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.inputs import LLMInputs
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
@@ -42,6 +42,7 @@
ServiceMetrics,
TGISStatLogger,
)
from vllm_tgis_adapter.utils import to_list

from .adapters import AdapterStore, validate_adapters
from .pb import generation_pb2_grpc
@@ -244,6 +245,7 @@ async def Generate(
sampling_params, deadline = await self._validate_and_convert_params(
request.params, tokenizer, context
)
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)
request_count = len(request.requests)

Expand Down Expand Up @@ -302,12 +304,14 @@ async def is_cancelled() -> bool:

for i in range(len(responses)):
res = responses[i]
output = res.outputs[0]
response = self._convert_output(
res.outputs[0],
output,
resp_options,
max_is_token_limit=max_is_token_limit[i],
tokenizer=tokenizer,
time_limit_reached=time_limit_reached,
generated_token_count=len(output.token_ids),
)
response = self._convert_input_details(
res, resp_options, sampling_params, response, tokenizer
@@ -327,7 +331,7 @@ async def is_cancelled() -> bool:
return BatchedGenerationResponse(responses=responses)

@log_rpc_handler_errors
async def GenerateStream( # noqa: PLR0915
async def GenerateStream( # noqa: PLR0915, C901
self,
request: SingleGenerationRequest,
context: ServicerContext,
@@ -341,6 +345,7 @@ async def GenerateStream( # noqa: PLR0915
sampling_params, deadline = await self._validate_and_convert_params(
request.params, tokenizer, context
)
sampling_params.output_kind = RequestOutputKind.DELTA
truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)

input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize(
@@ -381,17 +386,23 @@ async def is_cancelled() -> bool:

first_response = None
last_response = None
last_output_length = 0
last_token_count = 0
generated_token_count = 0
time_limit_reached = False
full_output = ""
last_engine_response = None
async for result in result_generator:
if result.prompt is None:
result.prompt = request.request.text
last_engine_response = result
if first_response is None:
service_metrics.observe_queue_time(result)
# In chunked prefill case it's possible that there will be
# multiple prompt-only outputs
if first_response is not None or (
result.prompt_token_ids and not generated_token_count
):
if first_response is None:
service_metrics.observe_queue_time(result)

if result.prompt is None:
result.prompt = request.request.text

first_response = self._convert_input_details(
result,
resp_options,
@@ -402,28 +413,33 @@ async def is_cancelled() -> bool:
last_response = first_response
yield first_response

output = result.outputs[0]

if deadline is not None and time.time() >= deadline:
await self.engine.abort(request_id)
time_limit_reached = True

# Convert output text and token_ids to deltas
output = result.outputs[0]
generated_token_count += len(output.token_ids)

if (
not generated_token_count
and not output.finish_reason
and not time_limit_reached
):
continue

# Convert output text and token_ids
last_response = self._convert_output(
output,
resp_options,
max_is_token_limit=max_is_tok_limit,
tokenizer=tokenizer,
time_limit_reached=time_limit_reached,
text_start_offset=last_output_length,
token_start_offset=last_token_count,
generated_token_count=generated_token_count,
)
yield last_response

last_output_length = len(output.text)
last_token_count = len(output.token_ids)
# Save full output for logging
full_output = output.text
full_output += output.text

if time_limit_reached:
break
@@ -485,11 +501,10 @@ def _convert_output( # noqa: PLR0913
output: CompletionOutput,
resp_options: ResponseOptions,
*,
generated_token_count: int,
max_is_token_limit: bool,
tokenizer: PreTrainedTokenizer,
time_limit_reached: bool = False,
text_start_offset: int = 0,
token_start_offset: int = 0,
) -> GenerationResponse:
stop_reason, stop_sequence = self._convert_reason(
output,
@@ -498,22 +513,21 @@
tokenizer=tokenizer,
)
response = GenerationResponse(
text=output.text[text_start_offset:],
generated_token_count=len(output.token_ids),
text=output.text,
generated_token_count=generated_token_count,
stop_reason=stop_reason,
stop_sequence=stop_sequence,
)

if resp_options.generated_tokens:
self._convert_tokens(
list(output.token_ids),
to_list(output.token_ids),
output.logprobs,
include_logprobs=resp_options.token_logprobs,
include_ranks=resp_options.token_ranks,
top_n_tokens=resp_options.top_n_tokens,
tokenizer=tokenizer,
token_infos=response.tokens,
token_start_offset=token_start_offset,
)
return response

@@ -713,7 +727,7 @@ def _convert_reason(
return stop_reason, stop_sequence

@staticmethod
def _convert_tokens( # noqa: PLR0913 C901
def _convert_tokens( # noqa: PLR0913
token_ids: list[int],
logprobs_list: list[dict[int, Logprob] | None] | None,
*,
@@ -731,11 +745,7 @@
token_texts = tokenizer.convert_ids_to_tokens(token_ids)
for i, text in enumerate(token_texts):
token_info = TokenInfo(text=text)
if logprobs_list is None:
token_infos.append(token_info)
continue

logprobs = logprobs_list[i]
logprobs = logprobs_list[i] if logprobs_list else None
# Logprobs entry will be None for first prompt token
if logprobs is None:
token_infos.append(token_info)
6 changes: 5 additions & 1 deletion src/vllm_tgis_adapter/utils.py
@@ -1,5 +1,5 @@
import asyncio
from collections.abc import Iterable
from collections.abc import Iterable, Sequence


def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
@@ -34,3 +34,7 @@ def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:

logger = init_logger("vllm-tgis-adapter")
logger.exception("Unable to write termination logs to %s", file)


def to_list(seq: Sequence[int]) -> list[int]:
return seq if isinstance(seq, list) else list(seq)
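
A small usage sketch of the new to_list helper, which _convert_tokens now uses in place of list(output.token_ids); the array input below is just an illustrative non-list Sequence[int]:

from array import array

from vllm_tgis_adapter.utils import to_list

ids_as_list = [101, 7592, 2088]
ids_as_array = array("l", ids_as_list)  # token ids may arrive as a non-list sequence

assert to_list(ids_as_list) is ids_as_list   # already a list: returned as-is, no copy
assert to_list(ids_as_array) == ids_as_list  # any other Sequence[int]: copied into a list
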
