🐛 handle MistralTokenizer special case #162

Merged: 5 commits, merged Oct 12, 2024

Changes from 2 commits
66 changes: 37 additions & 29 deletions src/vllm_tgis_adapter/grpc/grpc_server.py
@@ -6,11 +6,7 @@
 import time
 import uuid
 from collections.abc import Callable, Coroutine
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    TypeVar,
-)
+from typing import TYPE_CHECKING, Any, TypeVar

 import grpc
 from grpc import StatusCode, aio
@@ -26,6 +22,8 @@
     extract_trace_headers,
     log_tracing_disabled_warning,
 )
+from vllm.transformers_utils.tokenizer import AnyTokenizer  # noqa: TCH002
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 from vllm.utils import iterate_with_cancellation

 from vllm_tgis_adapter.logging import init_logger
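
Note (added for context, not part of the diff): AnyTokenizer is vLLM's union of the tokenizer types the engine can hand back, which is why it can replace PreTrainedTokenizer in the annotations below. A rough sketch of its shape, paraphrased from the vLLM source around this release (check your installed version):

    from typing import Union

    from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
    from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    # Approximate definition of vllm.transformers_utils.tokenizer.AnyTokenizer:
    AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, MistralTokenizer]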
@@ -64,7 +62,6 @@
     from collections.abc import AsyncIterator, MutableSequence

     from grpc.aio import ServicerContext
-    from transformers import PreTrainedTokenizer
     from vllm import CompletionOutput, RequestOutput
     from vllm.config import ModelConfig

@@ -473,7 +470,7 @@ def _convert_input_details(
         resp_options: ResponseOptions,
         sampling_params: SamplingParams,
         response: GenerationResponse,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: AnyTokenizer,
     ) -> GenerationResponse:
         if result.prompt_token_ids:
             response.input_token_count = len(result.prompt_token_ids)
@@ -504,7 +501,7 @@ def _convert_output(  # noqa: PLR0913
         *,
         generated_token_count: int,
         max_is_token_limit: bool,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: AnyTokenizer,
         time_limit_reached: bool = False,
     ) -> GenerationResponse:
         stop_reason, stop_sequence = self._convert_reason(
@@ -548,7 +545,7 @@ def request_id(context: ServicerContext) -> str:
     async def _validate_and_convert_params(
         self,
         params: Parameters,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[SamplingParams, float | None]:
         """Return (sampling_params, deadline)."""
@@ -681,9 +678,7 @@ async def _validate_adapters(
             await context.abort(StatusCode.INVALID_ARGUMENT, str(e))
         return adapters

-    async def _get_tokenizer(
-        self, adapter_kwargs: dict[str, Any]
-    ) -> PreTrainedTokenizer:
+    async def _get_tokenizer(self, adapter_kwargs: dict[str, Any]) -> AnyTokenizer:
         return await self.engine.get_tokenizer(
             adapter_kwargs.get("lora_request"),
         )
@@ -694,7 +689,7 @@ def _convert_reason(
         *,
         max_is_token_limit: bool,
         time_limit_reached: bool,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: AnyTokenizer,
     ) -> tuple[StopReason, str | None]:
         finish_reason = output.finish_reason
         stop_sequence = None
@@ -735,7 +730,7 @@ def _convert_tokens(  # noqa: PLR0913
         include_logprobs: bool,
         include_ranks: bool,
         top_n_tokens: int,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: AnyTokenizer,
         token_infos: MutableSequence[TokenInfo],  # OUT
         token_start_offset: int = 0,
     ) -> None:
@@ -789,7 +784,7 @@ async def _validate_prompt_and_tokenize(
         sampling_params: SamplingParams,
         truncate_input_tokens: int | None,
         prompt: str,
-        tokenizer: PreTrainedTokenizer,
+        tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[list[int], bool]:
         assert self.config is not None
@@ -861,18 +856,25 @@ async def Tokenize(
         tokenizer = await self._get_tokenizer(adapter_kwargs)

         responses: list[TokenizeResponse] = []
+        is_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)

         # TODO: maybe parallelize, also move convert_ids_to_tokens into the
         # other threads
         for req in request.requests:
-            batch_encoding = tokenizer.encode_plus(
-                text=req.text,
-                return_offsets_mapping=request.return_offsets,
-                add_special_tokens=ADD_SPECIAL_TOKENS,
-            )
+            if is_mistral_tokenizer:
+                token_ids = tokenizer.encode(
+                    prompt=req.text,
+                )
+            else:
+                batch_encoding = tokenizer.encode_plus(
+                    text=req.text,
+                    return_offsets_mapping=request.return_offsets,
+                    add_special_tokens=ADD_SPECIAL_TOKENS,
+                )

-            # Tokenize the input text
-            token_ids = batch_encoding.input_ids
+                # Tokenize the input text
+                token_ids = batch_encoding.input_ids

             token_count = len(token_ids)

             if 0 < request.truncate_input_tokens < token_count:
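
The branch above is needed because MistralTokenizer implements vLLM's own encode() interface and does not provide the Hugging Face encode_plus() API. A minimal standalone sketch of the same dispatch (tokenize_ids and its parameters are illustrative, not adapter code):

    from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

    def tokenize_ids(tokenizer, text: str, return_offsets: bool = False) -> list[int]:
        """Return token ids for `text`, handling the Mistral special case."""
        if isinstance(tokenizer, MistralTokenizer):
            # MistralTokenizer has no encode_plus(); use its encode() directly.
            return tokenizer.encode(prompt=text)
        batch_encoding = tokenizer.encode_plus(
            text=text,
            return_offsets_mapping=return_offsets,
            add_special_tokens=True,  # stand-in for the adapter's ADD_SPECIAL_TOKENS
        )
        return batch_encoding.input_ids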
@@ -883,13 +885,19 @@
             offsets = None

             if request.return_offsets:
-                offsets = [
-                    {"start": start, "end": end}
-                    for start, end in batch_encoding.offset_mapping
-                    if start is not None and end is not None
-                ]
-                # Truncate offset list if request.truncate_input_tokens
-                offsets = offsets[-token_count:]
+                if is_mistral_tokenizer:
+                    logger.warning(
+                        "Mistral tokenizer doesn't support "
+                        "return_offsets at the moment. "
+                    )
+                else:
+                    offsets = [
+                        {"start": start, "end": end}
+                        for start, end in batch_encoding.offset_mapping
+                        if start is not None and end is not None
+                    ]
+                    # Truncate offset list if request.truncate_input_tokens
+                    offsets = offsets[-token_count:]

             tokens = tokens[-token_count:] if request.return_tokens else None

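For reference, offset_mapping from a fast Hugging Face tokenizer is one (start, end) character span per token, which the loop above converts into {"start", "end"} dicts; MistralTokenizer exposes no equivalent, hence the warning instead. An illustrative snippet (gpt2 is only an example model):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")  # any fast tokenizer works
    enc = tok.encode_plus("hello world", return_offsets_mapping=True)
    offsets = [
        {"start": start, "end": end}
        for start, end in enc.offset_mapping
        if start is not None and end is not None
    ]
    # e.g. [{"start": 0, "end": 5}, {"start": 5, "end": 11}]
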
4 changes: 2 additions & 2 deletions src/vllm_tgis_adapter/tgis_utils/guided_decoding.py
@@ -4,7 +4,6 @@
 import concurrent.futures
 from re import escape as regex_escape

-from transformers import PreTrainedTokenizer
 from vllm.model_executor.guided_decoding import outlines_decoding
 from vllm.model_executor.guided_decoding.outlines_decoding import (
     GuidedDecodingMode,
@@ -16,12 +15,13 @@
     JSONLogitsProcessor,
     RegexLogitsProcessor,
 )
+from vllm.transformers_utils.tokenizer import AnyTokenizer

 from vllm_tgis_adapter.grpc.pb.generation_pb2 import DecodingParameters


 async def get_outlines_guided_decoding_logits_processor(
-    decoding_params: DecodingParameters, tokenizer: PreTrainedTokenizer
+    decoding_params: DecodingParameters, tokenizer: AnyTokenizer
 ) -> JSONLogitsProcessor | RegexLogitsProcessor | None:
     """Check for guided decoding parameters.