Add option to completion API to truncate prompt tokens #3144

Merged · 13 commits · Apr 5, 2024
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -183,6 +183,7 @@ class CompletionRequest(BaseModel):
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
truncate_prompt_tokens: Optional[int] = None
Collaborator

nit: can you make this a constrained integer? https://docs.pydantic.dev/2.3/api/types/#pydantic.types.conint

Member Author

done
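
For reference, a minimal sketch of the constrained-integer declaration being suggested (pydantic v2 conint; the ge=1 lower bound is an assumption taken from the ">= 1" check added in sampling_params.py below):

# Sketch only: field name matches the diff; the ge=1 bound is assumed from the
# "must be >= 1" validation in sampling_params.py.
from typing import Optional

from pydantic import BaseModel, conint


class CompletionRequest(BaseModel):
    # ... other request fields elided ...
    truncate_prompt_tokens: Optional[conint(ge=1)] = None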


def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
@@ -225,6 +226,7 @@ def logit_bias_logits_processor(
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)

@model_validator(mode="before")
10 changes: 8 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -300,10 +300,16 @@ async def create_completion(self, request: CompletionRequest,
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)

generators.append(
self.engine.generate(None,
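With the parameter plumbed through create_completion as above, a client can ask for server-side prompt truncation per request. A hypothetical example against a locally running vLLM OpenAI-compatible server (the URL and model name are placeholders):

# Hypothetical request; server URL and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "my-model",
        "prompt": "a very long prompt that might exceed the context window ...",
        "max_tokens": 16,
        # keep only the last 512 prompt tokens (left truncation)
        "truncate_prompt_tokens": 512,
    },
)
print(resp.json())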
19 changes: 15 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -62,7 +62,8 @@ async def _post_init(self):
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
@@ -147,15 +148,25 @@ def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[int] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")

input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
input_ids = prompt_ids

token_num = len(input_ids)

if request.max_tokens is None:
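Both branches above are intended to agree: because the tokenizer is constructed with truncation_side="left", tokenizing with truncation=True and max_length=k keeps the last k tokens, matching the prompt_ids[-k:] slice used for pre-tokenized prompts. A standalone sketch of that behaviour (gpt2 is just an example tokenizer):

# Sketch: left truncation keeps the last k tokens, whichever path is taken.
from transformers import AutoTokenizer

k = 4
tokenizer = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

text = "one two three four five six seven eight"
full_ids = tokenizer(text).input_ids
truncated_ids = tokenizer(text, truncation=True, max_length=k).input_ids

# both approaches keep only the last k tokens
assert truncated_ids == full_ids[-k:]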
12 changes: 11 additions & 1 deletion vllm/sampling_params.py
@@ -91,6 +91,9 @@ class SamplingParams:
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None (i.e.,
no truncation).
"""

def __init__(
@@ -118,6 +121,7 @@ def __init__(
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
truncate_prompt_tokens: Optional[int] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
@@ -150,6 +154,7 @@ def __init__(
self.spaces_between_special_tokens = spaces_between_special_tokens
self.logits_processors = logits_processors
self.include_stop_str_in_output = include_stop_str_in_output
self.truncate_prompt_tokens = truncate_prompt_tokens
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
@@ -197,6 +202,10 @@ def _verify_args(self) -> None:
if self.prompt_logprobs is not None and self.prompt_logprobs < 0:
raise ValueError(f"prompt_logprobs must be non-negative, got "
f"{self.prompt_logprobs}.")
if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
raise ValueError(
f"truncate_prompt_tokens must be >= 1, got {self.truncate_prompt_tokens}"
)
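
As a quick illustration of the new check, passing a value below 1 should raise; a sketch with the other SamplingParams arguments left at their defaults:

# Sketch: the new _verify_args check rejects values below 1.
from vllm import SamplingParams

SamplingParams(truncate_prompt_tokens=128)   # accepted

try:
    SamplingParams(truncate_prompt_tokens=0)
except ValueError as err:
    print(err)  # truncate_prompt_tokens must be >= 1, got 0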

def _verify_beam_search(self) -> None:
if self.best_of == 1:
@@ -276,4 +285,5 @@ def __repr__(self) -> str:
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "
"spaces_between_special_tokens="
f"{self.spaces_between_special_tokens})")
f"{self.spaces_between_special_tokens}, "
f"truncate_prompt_tokens={self.truncate_prompt_tokens})")