diff --git a/src/vllm_tgis_adapter/grpc/grpc_server.py b/src/vllm_tgis_adapter/grpc/grpc_server.py
index 1ee9bcc..b61a8db 100644
--- a/src/vllm_tgis_adapter/grpc/grpc_server.py
+++ b/src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -17,10 +17,10 @@
 from grpc._cython.cygrpc import AbortError
 from grpc_health.v1 import health, health_pb2, health_pb2_grpc
 from grpc_reflection.v1alpha import reflection
-from vllm import SamplingParams
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import LLMInputs
+from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.tracing import (
     contains_trace_headers,
     extract_trace_headers,
@@ -42,6 +42,7 @@
     ServiceMetrics,
     TGISStatLogger,
 )
+from vllm_tgis_adapter.utils import to_list

 from .adapters import AdapterStore, validate_adapters
 from .pb import generation_pb2_grpc
@@ -244,6 +245,7 @@ async def Generate(
         sampling_params, deadline = await self._validate_and_convert_params(
             request.params, tokenizer, context
         )
+        sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
         truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)

         request_count = len(request.requests)
@@ -302,12 +304,14 @@ async def is_cancelled() -> bool:

         for i in range(len(responses)):
             res = responses[i]
+            output = res.outputs[0]
             response = self._convert_output(
-                res.outputs[0],
+                output,
                 resp_options,
                 max_is_token_limit=max_is_token_limit[i],
                 tokenizer=tokenizer,
                 time_limit_reached=time_limit_reached,
+                generated_token_count=len(output.token_ids),
             )
             response = self._convert_input_details(
                 res, resp_options, sampling_params, response, tokenizer
@@ -327,7 +331,7 @@
         return BatchedGenerationResponse(responses=responses)

     @log_rpc_handler_errors
-    async def GenerateStream(  # noqa: PLR0915
+    async def GenerateStream(  # noqa: PLR0915, C901
         self,
         request: SingleGenerationRequest,
         context: ServicerContext,
@@ -341,6 +345,7 @@ async def GenerateStream(  # noqa: PLR0915
         sampling_params, deadline = await self._validate_and_convert_params(
             request.params, tokenizer, context
         )
+        sampling_params.output_kind = RequestOutputKind.DELTA
         truncate_input_tokens = with_default(request.params.truncate_input_tokens, None)

         input_ids, max_is_tok_limit = await self._validate_prompt_and_tokenize(
@@ -381,17 +386,23 @@ async def is_cancelled() -> bool:

         first_response = None
         last_response = None
-        last_output_length = 0
-        last_token_count = 0
+        generated_token_count = 0
         time_limit_reached = False
         full_output = ""
         last_engine_response = None
         async for result in result_generator:
-            if result.prompt is None:
-                result.prompt = request.request.text
             last_engine_response = result
-            if first_response is None:
-                service_metrics.observe_queue_time(result)
+            # In chunked prefill case it's possible that there will be
+            # multiple prompt-only outputs
+            if first_response is None or (
+                result.prompt_token_ids and not generated_token_count
+            ):
+                if first_response is None:
+                    service_metrics.observe_queue_time(result)
+
+                if result.prompt is None:
+                    result.prompt = request.request.text
+
                 first_response = self._convert_input_details(
                     result,
                     resp_options,
@@ -402,28 +413,33 @@ async def is_cancelled() -> bool:
                 last_response = first_response
                 yield first_response

-            output = result.outputs[0]
-
             if deadline is not None and time.time() >= deadline:
                 await self.engine.abort(request_id)
                 time_limit_reached = True

-            # Convert output text and token_ids to deltas
+            output = result.outputs[0]
+            generated_token_count += len(output.token_ids)
+
+            if (
+                not generated_token_count
+                and not output.finish_reason
+                and not time_limit_reached
+            ):
+                continue
+
+            # Convert output text and token_ids
             last_response = self._convert_output(
                 output,
                 resp_options,
                 max_is_token_limit=max_is_tok_limit,
                 tokenizer=tokenizer,
                 time_limit_reached=time_limit_reached,
-                text_start_offset=last_output_length,
-                token_start_offset=last_token_count,
+                generated_token_count=generated_token_count,
             )
             yield last_response
-            last_output_length = len(output.text)
-            last_token_count = len(output.token_ids)

             # Save full output for logging
-            full_output = output.text
+            full_output += output.text

             if time_limit_reached:
                 break
@@ -485,11 +501,10 @@ def _convert_output(  # noqa: PLR0913
         output: CompletionOutput,
         resp_options: ResponseOptions,
         *,
+        generated_token_count: int,
         max_is_token_limit: bool,
         tokenizer: PreTrainedTokenizer,
         time_limit_reached: bool = False,
-        text_start_offset: int = 0,
-        token_start_offset: int = 0,
     ) -> GenerationResponse:
         stop_reason, stop_sequence = self._convert_reason(
             output,
@@ -498,22 +513,21 @@
             tokenizer=tokenizer,
         )
         response = GenerationResponse(
-            text=output.text[text_start_offset:],
-            generated_token_count=len(output.token_ids),
+            text=output.text,
+            generated_token_count=generated_token_count,
             stop_reason=stop_reason,
             stop_sequence=stop_sequence,
         )

         if resp_options.generated_tokens:
             self._convert_tokens(
-                list(output.token_ids),
+                to_list(output.token_ids),
                 output.logprobs,
                 include_logprobs=resp_options.token_logprobs,
                 include_ranks=resp_options.token_ranks,
                 top_n_tokens=resp_options.top_n_tokens,
                 tokenizer=tokenizer,
                 token_infos=response.tokens,
-                token_start_offset=token_start_offset,
             )
         return response
@@ -713,7 +727,7 @@ def _convert_reason(
         return stop_reason, stop_sequence

     @staticmethod
-    def _convert_tokens(  # noqa: PLR0913 C901
+    def _convert_tokens(  # noqa: PLR0913
         token_ids: list[int],
         logprobs_list: list[dict[int, Logprob] | None] | None,
         *,
@@ -731,11 +745,7 @@ def _convert_tokens(  # noqa: PLR0913 C901
         token_texts = tokenizer.convert_ids_to_tokens(token_ids)
         for i, text in enumerate(token_texts):
             token_info = TokenInfo(text=text)
-            if logprobs_list is None:
-                token_infos.append(token_info)
-                continue
-
-            logprobs = logprobs_list[i]
+            logprobs = logprobs_list[i] if logprobs_list else None
             # Logprobs entry will be None for first prompt token
             if logprobs is None:
                 token_infos.append(token_info)
diff --git a/src/vllm_tgis_adapter/utils.py b/src/vllm_tgis_adapter/utils.py
index e9df4e4..a77b1b1 100644
--- a/src/vllm_tgis_adapter/utils.py
+++ b/src/vllm_tgis_adapter/utils.py
@@ -1,5 +1,5 @@
 import asyncio
-from collections.abc import Iterable
+from collections.abc import Iterable, Sequence


 def check_for_failed_tasks(tasks: Iterable[asyncio.Task]) -> None:
@@ -34,3 +34,7 @@ def write_termination_log(msg: str, file: str = "/dev/termination-log") -> None:
         logger = init_logger("vllm-tgis-adapter")
         logger.exception("Unable to write termination logs to %s", file)
+
+
+def to_list(seq: Sequence[int]) -> list[int]:
+    return seq if isinstance(seq, list) else list(seq)
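
Note on the streaming change: with RequestOutputKind.DELTA each engine result carries only the text and token ids generated since the previous result, so GenerateStream now keeps running totals (generated_token_count, full_output += output.text) instead of slicing a cumulative output with text_start_offset/token_start_offset. Below is a minimal sketch of that accumulation contract; FakeOutput is a hypothetical stand-in, not a vLLM type, and to_list mirrors the helper added to utils.py, on the assumption that the engine may hand back an array-like sequence rather than a plain list:

    from array import array
    from collections.abc import Sequence
    from dataclasses import dataclass


    def to_list(seq: Sequence[int]) -> list[int]:
        # Same shape as vllm_tgis_adapter.utils.to_list: skip the copy
        # when the engine already produced a plain list.
        return seq if isinstance(seq, list) else list(seq)


    @dataclass
    class FakeOutput:
        # Hypothetical stand-in for a DELTA-mode CompletionOutput.
        text: str                 # only the newly generated text
        token_ids: Sequence[int]  # only the new token ids


    deltas = [
        FakeOutput("Hel", array("l", [101])),
        FakeOutput("lo", array("l", [102, 103])),
        FakeOutput("!", [104]),
    ]

    full_output = ""
    generated_token_count = 0
    for output in deltas:
        # Running totals replace the old *_start_offset slicing.
        generated_token_count += len(output.token_ids)
        full_output += output.text
        token_ids = to_list(output.token_ids)  # normalize before tokenizer calls

    assert full_output == "Hello!"
    assert generated_token_count == 4

The unary Generate path requests RequestOutputKind.FINAL_ONLY instead, so its single result already holds the complete output and len(output.token_ids) can be passed straight through as generated_token_count.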