Fix earlier LoRA tokenizer changes
njhill authored and dtrifiro committed Jul 22, 2024
1 parent 76cce9d commit 52dac4b
Showing 1 changed file with 14 additions and 4 deletions.
src/vllm_tgis_adapter/grpc/grpc_server.py (14 additions, 4 deletions)
@@ -293,10 +293,11 @@ async def Generate(
                 res.outputs[0],
                 resp_options,
                 max_is_token_limit=max_is_token_limit[i],
+                tokenizer=tokenizer,
                 time_limit_reached=time_limit_reached,
             )
             response = self._convert_input_details(
-                res, resp_options, sampling_params, response
+                res, resp_options, sampling_params, response, tokenizer
             )
             logs.log_response(
                 request=request,
@@ -368,7 +369,11 @@ async def GenerateStream(
             if first_response is None:
                 service_metrics.observe_queue_time(result)
                 first_response = self._convert_input_details(
-                    result, resp_options, sampling_params, GenerationResponse()
+                    result,
+                    resp_options,
+                    sampling_params,
+                    GenerationResponse(),
+                    tokenizer,
                 )
                 last_response = first_response
                 yield first_response
@@ -384,6 +389,7 @@
                 output,
                 resp_options,
                 max_is_token_limit=max_is_tok_limit,
+                tokenizer=tokenizer,
                 time_limit_reached=time_limit_reached,
                 text_start_offset=last_output_length,
                 token_start_offset=last_token_count,
@@ -420,12 +426,13 @@
         )
         service_metrics.observe_generation_success(start_time=start_time)
 
-    def _convert_input_details(
+    def _convert_input_details(  # noqa: PLR0913
         self,
         result: RequestOutput,
         resp_options: ResponseOptions,
         sampling_params: SamplingParams,
         response: GenerationResponse,
+        tokenizer: PreTrainedTokenizer,
     ) -> GenerationResponse:
         response.input_token_count = len(result.prompt_token_ids)
         if resp_options.input_tokens:
@@ -435,6 +442,7 @@ def _convert_input_details(
                 include_logprobs=resp_options.token_logprobs,
                 include_ranks=resp_options.token_ranks,
                 top_n_tokens=resp_options.top_n_tokens,
+                tokenizer=tokenizer,
                 token_infos=response.input_tokens,
             )
 
@@ -453,6 +461,7 @@ def _convert_output(  # noqa: PLR0913
         resp_options: ResponseOptions,
         *,
         max_is_token_limit: bool,
+        tokenizer: PreTrainedTokenizer,
         time_limit_reached: bool = False,
         text_start_offset: int = 0,
         token_start_offset: int = 0,
@@ -471,11 +480,12 @@ def _convert_output(  # noqa: PLR0913
 
         if resp_options.generated_tokens:
             self._convert_tokens(
-                output.token_ids,
+                list(output.token_ids),
                 output.logprobs,
                 include_logprobs=resp_options.token_logprobs,
                 include_ranks=resp_options.token_ranks,
                 top_n_tokens=resp_options.top_n_tokens,
+                tokenizer=tokenizer,
                 token_infos=response.tokens,
                 token_start_offset=token_start_offset,
             )
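Note on the change: with LoRA, the active adapter can carry its own tokenizer (for example, added special tokens), so token conversion must use the tokenizer resolved for the specific request rather than a single engine-wide one. This diff threads that tokenizer through _convert_input_details and _convert_output and down into _convert_tokens; the extra parameter also pushes _convert_input_details past Ruff's argument-count limit (PLR0913, "too many arguments"), hence the added noqa. Below is a minimal sketch of how such a per-request tokenizer might be resolved before calling these helpers. It is not code from this repository: tokenizer_for_request is a hypothetical helper, and the get_tokenizer(lora_request) signature is an assumption about the vLLM version in use.

# Sketch only (not from this repo): resolving the per-request tokenizer
# that Generate/GenerateStream would then pass into the _convert_* helpers.
# Assumes vLLM's AsyncLLMEngine.get_tokenizer() accepts an optional
# LoRARequest; verify against the installed vLLM version.
from typing import Optional

from vllm import AsyncLLMEngine
from vllm.lora.request import LoRARequest


async def tokenizer_for_request(
    engine: AsyncLLMEngine,
    lora_request: Optional[LoRARequest] = None,
):
    # Each LoRA adapter may bundle its own tokenizer, so resolve it per
    # request instead of reusing one tokenizer cached on the servicer.
    return await engine.get_tokenizer(lora_request)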

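A second, smaller fix rides along: _convert_tokens now receives list(output.token_ids) instead of output.token_ids. A plausible reading (an assumption, not stated in the commit) is that vLLM may expose token_ids as a tuple or array.array rather than a plain list, and the conversion code expects list semantics:

# Illustration of the list(...) normalization; the array storage is an
# assumption about vLLM internals, not taken from this diff.
from array import array

token_ids = array("l", [101, 2023, 102])  # compact integer storage
ids = list(token_ids)                     # plain list for downstream code
assert ids == [101, 2023, 102]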