Exploit vLLM options to return deltas/final-output only #137

Merged: 2 commits, Oct 7, 2024
Changes from all commits
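The change itself is small: vLLM's SamplingParams exposes an output_kind field (a RequestOutputKind enum in vllm.sampling_params), and the adapter now sets it per RPC: FINAL_ONLY for the unary Generate path, DELTA for GenerateStream. A minimal sketch of the two settings, assuming a vLLM version that ships RequestOutputKind:

    from vllm import SamplingParams
    from vllm.sampling_params import RequestOutputKind

    # Unary path: the engine yields a single, finished RequestOutput per request
    # instead of a growing cumulative output on every scheduler step.
    unary_params = SamplingParams(max_tokens=64)
    unary_params.output_kind = RequestOutputKind.FINAL_ONLY

    # Streaming path: each RequestOutput carries only the text / token_ids
    # generated since the previous yield (deltas, not cumulative output).
    stream_params = SamplingParams(max_tokens=64)
    stream_params.output_kind = RequestOutputKind.DELTA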
93 changes: 52 additions & 41 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -17,10 +17,10 @@
from grpc._cython.cygrpc import AbortError
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from grpc_reflection.v1alpha import reflection
from vllm import SamplingParams
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.serving_completion import merge_async_iterators
from vllm.inputs import LLMInputs
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tracing import (
contains_trace_headers,
extract_trace_headers,
@@ -42,6 +42,7 @@
ServiceMetrics,
TGISStatLogger,
)
from vllm_tgis_adapter.utils import to_list

from .adapters import AdapterStore, validate_adapters
from .pb import generation_pb2_grpc
@@ -244,6 +245,7 @@ async def Generate(
sampling_params, deadline = await self._validate_and_convert_params(
request.params, tokenizer, context
)
sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)
request_count = len(request.requests)
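With FINAL_ONLY, each per-request generator is expected to yield exactly once, with the already-complete output, so the unary path no longer sees intermediate results it would have to discard. A hypothetical consumption sketch; generators and responses are illustrative names, not the adapter's exact variables:

    from vllm.entrypoints.openai.serving_completion import merge_async_iterators

    # merge_async_iterators interleaves the per-request async generators and
    # yields (request_index, RequestOutput) pairs; under FINAL_ONLY each
    # generator produces one finished output.
    async for i, res in merge_async_iterators(*generators):
        responses[i] = res  # res.outputs[0] already holds the full text and token_ids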

Expand Down Expand Up @@ -302,12 +304,14 @@ async def is_cancelled() -> bool:

for i in range(len(responses)):
res = responses[i]
output = res.outputs[0]
response = self._convert_output(
res.outputs[0],
output,
resp_options,
max_is_token_limit=max_is_token_limit[i],
tokenizer=tokenizer,
time_limit_reached=time_limit_reached,
generated_token_count=len(output.token_ids),
)
response = self._convert_input_details(
res, resp_options, sampling_params, response, tokenizer
@@ -327,7 +331,7 @@ async def is_cancelled() -> bool:
return BatchedGenerationResponse(responses=responses)

@log_rpc_handler_errors
async def GenerateStream( # noqa: PLR0915
async def GenerateStream( # noqa: PLR0915, C901
self,
request: SingleGenerationRequest,
context: ServicerContext,
@@ -341,6 +345,7 @@ async def GenerateStream( # noqa: PLR0915
sampling_params, deadline = await self._validate_and_convert_params(
request.params, tokenizer, context
)
sampling_params.output_kind = RequestOutputKind.DELTA
truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)

input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize(
@@ -381,17 +386,23 @@ async def is_cancelled() -> bool:

first_response = None
last_response = None
last_output_length = 0
last_token_count = 0
generated_token_count = 0
time_limit_reached = False
full_output = ""
last_engine_response = None
async for result in result_generator:
if result.prompt is None:
result.prompt = request.request.text
last_engine_response = result
if first_response is None:
service_metrics.observe_queue_time(result)
# In chunked prefill case it's possible that there will be
# multiple prompt-only outputs
if first_response is None or (
result.prompt_token_ids and not generated_token_count
):
if first_response is None:
service_metrics.observe_queue_time(result)

if result.prompt is None:
result.prompt = request.request.text

first_response = self._convert_input_details(
result,
resp_options,
@@ -402,28 +413,33 @@ async def is_cancelled() -> bool:
last_response = first_response
yield first_response

output = result.outputs[0]

if deadline is not None and time.time() >= deadline:
await self.engine.abort(request_id)
time_limit_reached = True

# Convert output text and token_ids to deltas
output = result.outputs[0]
generated_token_count += len(output.token_ids)

if (
not generated_token_count
and not output.finish_reason
and not time_limit_reached
):
continue

# Convert output text and token_ids
last_response = self._convert_output(
output,
resp_options,
max_is_token_limit=max_is_tok_limit,
tokenizer=tokenizer,
time_limit_reached=time_limit_reached,
text_start_offset=last_output_length,
token_start_offset=last_token_count,
generated_token_count=generated_token_count,
)
yield last_response

last_output_length = len(output.text)
last_token_count = len(output.token_ids)
# Save full output for logging
full_output = output.text
full_output += output.text

if time_limit_reached:
break
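Since DELTA outputs only contain what was generated after the previous yield, the handler can drop the old last_output_length / last_token_count slicing and simply accumulate chunk lengths and text, as the loop above does. A standalone sketch of consuming a delta stream, assuming an AsyncLLMEngine handle named engine (illustrative, not the adapter's code):

    async def consume_stream(engine, prompt, stream_params):
        full_text = ""
        token_count = 0
        async for result in engine.generate(prompt, stream_params, request_id="req-0"):
            out = result.outputs[0]
            token_count += len(out.token_ids)  # token_ids holds only the new tokens
            full_text += out.text              # text is a delta, not the cumulative output
        return full_text, token_count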
@@ -459,19 +475,20 @@ def _convert_input_details(
response: GenerationResponse,
tokenizer: PreTrainedTokenizer,
) -> GenerationResponse:
response.input_token_count = len(result.prompt_token_ids)
if resp_options.input_tokens:
self._convert_tokens(
result.prompt_token_ids,
result.prompt_logprobs,
include_logprobs=resp_options.token_logprobs,
include_ranks=resp_options.token_ranks,
top_n_tokens=resp_options.top_n_tokens,
tokenizer=tokenizer,
token_infos=response.input_tokens,
)
if result.prompt_token_ids:
response.input_token_count = len(result.prompt_token_ids)
if resp_options.input_tokens:
self._convert_tokens(
result.prompt_token_ids,
result.prompt_logprobs,
include_logprobs=resp_options.token_logprobs,
include_ranks=resp_options.token_ranks,
top_n_tokens=resp_options.top_n_tokens,
tokenizer=tokenizer,
token_infos=response.input_tokens,
)

if resp_options.input_text:
if resp_options.input_text and result.prompt:
response.text = (
result.prompt if not response.text else result.prompt + response.text
)
@@ -485,11 +502,10 @@ def _convert_output( # noqa: PLR0913
output: CompletionOutput,
resp_options: ResponseOptions,
*,
generated_token_count: int,
max_is_token_limit: bool,
tokenizer: PreTrainedTokenizer,
time_limit_reached: bool = False,
text_start_offset: int = 0,
token_start_offset: int = 0,
) -> GenerationResponse:
stop_reason, stop_sequence = self._convert_reason(
output,
@@ -498,22 +514,21 @@
tokenizer=tokenizer,
)
response = GenerationResponse(
text=output.text[text_start_offset:],
generated_token_count=len(output.token_ids),
text=output.text,
generated_token_count=generated_token_count,
stop_reason=stop_reason,
stop_sequence=stop_sequence,
)

if resp_options.generated_tokens:
self._convert_tokens(
list(output.token_ids),
to_list(output.token_ids),
output.logprobs,
include_logprobs=resp_options.token_logprobs,
include_ranks=resp_options.token_ranks,
top_n_tokens=resp_options.top_n_tokens,
tokenizer=tokenizer,
token_infos=response.tokens,
token_start_offset=token_start_offset,
)
return response

@@ -713,7 +728,7 @@ def _convert_reason(
return stop_reason, stop_sequence

@staticmethod
def _convert_tokens( # noqa: PLR0913 C901
def _convert_tokens( # noqa: PLR0913
token_ids: list[int],
logprobs_list: list[dict[int, Logprob] | None] | None,
*,
@@ -731,11 +746,7 @@ def _convert_tokens( # noqa: PLR0913 C901
token_texts = tokenizer.convert_ids_to_tokens(token_ids)
for i, text in enumerate(token_texts):
token_info = TokenInfo(text=text)
if logprobs_list is None:
token_infos.append(token_info)
continue

logprobs = logprobs_list[i]
logprobs = logprobs_list[i] if logprobs_list else None
# Logprobs entry will be None for first prompt token
if logprobs is None:
token_infos.append(token_info)
6 changes: 5 additions & 1 deletion src/vllm_tgis_adapter/utils.py
@@ -1,5 +1,5 @@
import asyncio
from collections.abc import Iterable
from collections.abc import Iterable, Sequence


def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
@@ -34,3 +34,7 @@ def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:

logger = init_logger("vllm-tgis-adapter")
logger.exception("Unable to write termination logs to %s", file)


def to_list(seq: Sequence[int]) -> list[int]:
return seq if isinstance(seq, list) else list(seq)
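The helper exists because list(output.token_ids) always copies, while CompletionOutput.token_ids may already be a list; depending on the vLLM version it may instead arrive as a tuple or array.array (an assumption, not something this diff states). to_list returns the same object when it is already a list and converts otherwise:

    from array import array

    assert to_list([1, 2, 3]) == [1, 2, 3]               # already a list: returned as-is, no copy
    assert to_list((1, 2, 3)) == [1, 2, 3]               # tuple: converted to a new list
    assert to_list(array("l", [1, 2, 3])) == [1, 2, 3]   # array.array of ints: converted as well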