Skip to content

Commit

Permalink
[Core] Reduce unnecessary compute when logprobs=None (vllm-project#6532)
Browse files Browse the repository at this point in the history
  • Loading branch information
peng1999 authored Jul 29, 2024
1 parent 766435e commit db9e570
Showing 4 changed files with 135 additions and 80 deletions.
39 changes: 37 additions & 2 deletions tests/samplers/test_logprobs.py
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
@pytest.mark.parametrize("dtype",
["float"]) # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
def test_get_prompt_logprobs(
hf_runner,
@@ -63,7 +63,10 @@ def test_get_prompt_logprobs(
assert result.outputs[0].logprobs is not None
assert len(result.outputs[0].logprobs) == max_tokens
for logprobs in result.outputs[0].logprobs:
assert len(logprobs) == num_top_logprobs
# If the output token is not included in the top X
# logprob, it can return 1 more data
assert (len(logprobs) == num_top_logprobs
or len(logprobs) == num_top_logprobs + 1)
output_text = result.outputs[0].text
output_string_from_most_likely_tokens_lst: List[str] = []
for top_logprobs in result.outputs[0].logprobs:
@@ -135,3 +138,35 @@ def test_max_logprobs():
bad_sampling_params = SamplingParams(logprobs=2)
with pytest.raises(ValueError):
runner.generate(["Hello world"], sampling_params=bad_sampling_params)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("detokenize", [True, False])
def test_none_logprobs(vllm_runner, model, chunked_prefill_token_size: int,
detokenize: bool, example_prompts):
max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
if chunked_prefill_token_size != -1:
enable_chunked_prefill = True
max_num_seqs = min(chunked_prefill_token_size, max_num_seqs)
max_num_batched_tokens = chunked_prefill_token_size
max_tokens = 5

with vllm_runner(
model,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,
) as vllm_model:
sampling_params_logprobs_none = SamplingParams(max_tokens=max_tokens,
logprobs=None,
temperature=0.0,
detokenize=detokenize)
results_logprobs_none = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params_logprobs_none)

for i in range(len(results_logprobs_none)):
assert results_logprobs_none[i].outputs[0].logprobs is None
assert results_logprobs_none[i].outputs[0].cumulative_logprob is None
144 changes: 81 additions & 63 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from math import inf
from typing import Dict, List, Optional, Tuple

import torch
@@ -774,8 +775,11 @@ def _get_logprobs(
# The next token ids to get the logprob value from.
next_token_ids: List[int] = []
# The largest requested number of logprobs. We find logprobs as many as the
# largest num logprobs in this API.
largest_num_logprobs = 1
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
largest_num_logprobs = -1
# If beam search is enabled.
use_beam_search = False

# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
@@ -808,42 +812,49 @@ def _get_logprobs(
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)

use_beam_search = use_beam_search or sampling_params.use_beam_search

assert len(next_token_ids) == len(query_indices)

if len(query_indices) == 0:
empty_sampled_logprob: SampleLogprobs = []
empty_prompt_logprob: Optional[PromptLogprobs] = None
return [empty_prompt_logprob], [empty_sampled_logprob]

query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device)

# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]

# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
if largest_num_logprobs > 0:
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
else:
top_logprobs, top_token_ids = None, None
selected_logprobs, ranks = None, None
top_logprobs, top_token_ids = None, None

# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation.
if largest_num_logprobs >= 0 or use_beam_search:
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids,
device=logprobs.device)

# (num_selected_query_tokens, num_logprobs). Note that query_indices can
# contain duplicates if beam search is enabled.
selected_logprobs = logprobs[[
query_indices_gpu,
next_token_ids_gpu,
]]
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
)
assert selected_logprobs.shape[0] == ranks.shape[0]

# We need to compute top k only if there exists logprobs > 0.
if largest_num_logprobs > 0:
# Logprobs of topk tokens for a batch of sequence groups.
# (num_query_tokens_across_batch).
top_logprobs, top_token_ids = torch.topk(logprobs,
largest_num_logprobs,
dim=-1)
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')

selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.to('cpu')
if top_logprobs is not None and top_token_ids is not None:
top_logprobs = top_logprobs.to('cpu')
top_token_ids = top_token_ids.to('cpu')
selected_logprobs = selected_logprobs.to('cpu')
ranks = ranks.to('cpu')

# Find prompt/sample logprobs.
prompt_logprobs_per_seq_group: List[Optional[PromptLogprobs]] = []
@@ -940,46 +951,53 @@ def _get_sampled_logprob_if_needed(
):
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs or 0
num_logprobs = seq_group.sampling_params.logprobs
use_beam_search = seq_group.sampling_params.use_beam_search
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result

if seq_group.do_sample:
assert len(next_token_ids) > 0
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
for idx, (next_token_id,
parent_id) in enumerate(zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id: (selected_logprob_items[idx], rank_items[idx])
}
# Get top K logprobs.
if num_logprobs > 0:
top_ids = top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[top_logprob_idx +
parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(top_ids, top_probs,
top_ranks)
if num_logprobs is None and not use_beam_search:
for next_token_id in next_token_ids:
# Use a dummy logprob
sampled_logprobs.append({next_token_id: Logprob(inf)})
else:
# Pre-select items from tensor. tolist() is faster than repetitive
# `.item()` calls.
selected_logprob_items = selected_logprobs[
selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
rank_items = ranks[selected_logprobs_idx:selected_logprobs_idx +
len(next_token_ids)].tolist()
for idx, (next_token_id, parent_id) in enumerate(
zip(next_token_ids, parent_seq_ids)):
# Get the logprob of a sampled token.
sampled_logprobs_dict = {
next_token_id:
(selected_logprob_items[idx], rank_items[idx])
}
if num_logprobs is not None and num_logprobs > 0:
# Get top K logprobs.
top_ids = top_token_ids[top_logprob_idx +
parent_id, :num_logprobs].tolist()
top_probs = top_logprobs[
top_logprob_idx + parent_id, :num_logprobs].tolist()
# Top K is already sorted by rank, so we can use 1 ~
# num_logprobs + 1 for rank.
top_ranks = range(1, num_logprobs + 1)
sampled_logprobs_dict.update({
top_id: (top_prob, rank)
for top_id, top_prob, rank in zip(
top_ids, top_probs, top_ranks)
})

sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})

sampled_logprobs.append({
token_id: Logprob(*logprob_and_rank)
for token_id, logprob_and_rank in
sampled_logprobs_dict.items()
})

# NOTE: This part of code is not intuitive. `selected_logprobs` include
# logprobs for the current step, which has len(next_token_ids) tokens
# per sequence group. `logprobs` includes logprobs from the previous
17 changes: 9 additions & 8 deletions vllm/outputs.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ class CompletionOutput:
index: int
text: str
token_ids: Tuple[int, ...]
cumulative_logprob: float
cumulative_logprob: Optional[float]
logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
@@ -124,13 +124,14 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
include_logprobs = seq_group.sampling_params.logprobs is not None
text_buffer_length = seq_group.sampling_params.output_text_buffer_length
outputs = [
CompletionOutput(seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob(),
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
CompletionOutput(
seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length),
seq.get_output_token_ids(),
seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason) for seq in top_n_seqs
]

# Every sequence in the sequence group should have the same prompt.
15 changes: 8 additions & 7 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
@@ -92,11 +92,12 @@ class SamplingParams:
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
result includes the log probabilities on the `logprobs` most likely
tokens, as well the chosen tokens. The API will always return the
log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response.
When set to None, no probability is returned. If set to a non-None
value, the result includes the log probabilities of the specified
number of most likely tokens, as well as the chosen tokens.
Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
@@ -168,8 +169,8 @@ def __init__(
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs
self.logprobs = 1 if logprobs is True else logprobs
self.prompt_logprobs = 1 if prompt_logprobs is True else prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.

0 comments on commit db9e570

Please sign in to comment.