[misc] hide best_of from engine #9261

Merged: 12 commits, Oct 11, 2024
4 changes: 0 additions & 4 deletions tests/entrypoints/openai/test_metrics.py
@@ -70,7 +70,6 @@ async def client(server):
     [("_sum", _NUM_REQUESTS * _NUM_GENERATION_TOKENS_PER_REQUEST),
      ("_count", _NUM_REQUESTS)],
     "vllm:request_params_n": [("_count", _NUM_REQUESTS)],
-    "vllm:request_params_best_of": [("_count", _NUM_REQUESTS)],
     "vllm:prompt_tokens": [("_total",
                             _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST)],
     "vllm:generation_tokens":
@@ -151,9 +150,6 @@ async def test_metrics_counts(client: openai.AsyncOpenAI):
     "vllm:request_params_n_sum",
     "vllm:request_params_n_bucket",
     "vllm:request_params_n_count",
-    "vllm:request_params_best_of_sum",
-    "vllm:request_params_best_of_bucket",
-    "vllm:request_params_best_of_count",
     "vllm:num_preemptions_total",
     "vllm:prompt_tokens_total",
     "vllm:generation_tokens_total",
1 change: 0 additions & 1 deletion tests/metrics/test_metrics.py
@@ -326,7 +326,6 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
         "vllm:e2e_request_latency_seconds",
         "vllm:request_prompt_tokens",
         "vllm:request_generation_tokens",
-        "vllm:request_params_best_of",
         "vllm:request_params_n",
     ]
     for metric_name in request_histogram_metrics:
4 changes: 0 additions & 4 deletions tests/tracing/test_tracing.py
@@ -98,8 +98,6 @@ def test_traces(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)
@@ -155,8 +153,6 @@ def test_traces_with_detailed_steps(trace_service):
         SpanAttributes.LLM_REQUEST_TOP_P) == sampling_params.top_p
     assert attributes.get(
         SpanAttributes.LLM_REQUEST_MAX_TOKENS) == sampling_params.max_tokens
-    assert attributes.get(
-        SpanAttributes.LLM_REQUEST_BEST_OF) == sampling_params.best_of
     assert attributes.get(SpanAttributes.LLM_REQUEST_N) == sampling_params.n
     assert attributes.get(SpanAttributes.LLM_USAGE_PROMPT_TOKENS) == len(
         outputs[0].prompt_token_ids)
2 changes: 1 addition & 1 deletion vllm/core/scheduler.py
@@ -1205,7 +1205,7 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
         # async_output_proc is allowed only when we have a single sequence
         # in the sequence group
         no_single_seq = seq_group.sampling_params is None or (
-            seq_group.sampling_params.best_of == 1)
+            seq_group.sampling_params.n == 1)
         return no_single_seq

     def schedule(
11 changes: 2 additions & 9 deletions vllm/engine/llm_engine.py
@@ -767,7 +767,7 @@ def add_request(
         Details:
             - Set arrival_time to the current time if it is None.
             - Set prompt_token_ids to the encoded prompt if it is None.
-            - Create `best_of` number of :class:`~vllm.Sequence` objects.
+            - Create `n` number of :class:`~vllm.Sequence` objects.
            - Create a :class:`~vllm.SequenceGroup` object
              from the list of :class:`~vllm.Sequence`.
            - Add the :class:`~vllm.SequenceGroup` object to the scheduler.
@@ -1242,8 +1242,7 @@ def _advance_to_next_step(
             if seq_group_metadata.do_sample:
                 assert len(sequence_group_outputs.samples) == 1, (
                     "Async output processor expects a single sample"
-                    " (i.e sampling_params.n == 1 and no "
-                    "sampling_params.best_of > 1)")
+                    " (i.e sampling_params.n == 1)")
                 sample = sequence_group_outputs.samples[0]

                 assert len(seq_group.seqs) == 1
@@ -1612,7 +1611,6 @@ def _get_stats(self,
         # Metadata
         num_prompt_tokens_requests: List[int] = []
         num_generation_tokens_requests: List[int] = []
-        best_of_requests: List[int] = []
         n_requests: List[int] = []
         finished_reason_requests: List[str] = []

@@ -1683,8 +1681,6 @@ def _get_stats(self,
                         for seq in seq_group.get_finished_seqs()
                     ])
                 if seq_group.sampling_params is not None:
-                    best_of_requests.append(
-                        seq_group.sampling_params.best_of)
                     n_requests.append(seq_group.sampling_params.n)
                     finished_reason_requests.extend([
                         SequenceStatus.get_finished_reason(seq.status)
@@ -1737,7 +1733,6 @@ def _get_stats(self,
            # Metadata
            num_prompt_tokens_requests=num_prompt_tokens_requests,
            num_generation_tokens_requests=num_generation_tokens_requests,
-           best_of_requests=best_of_requests,
            n_requests=n_requests,
            finished_reason_requests=finished_reason_requests,
        )
@@ -1824,8 +1819,6 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
                                       seq_group.sampling_params.top_p)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_MAX_TOKENS,
                                       seq_group.sampling_params.max_tokens)
-               seq_span.set_attribute(SpanAttributes.LLM_REQUEST_BEST_OF,
-                                      seq_group.sampling_params.best_of)
                seq_span.set_attribute(SpanAttributes.LLM_REQUEST_N,
                                       seq_group.sampling_params.n)
                seq_span.set_attribute(SpanAttributes.LLM_USAGE_NUM_SEQUENCES,
8 changes: 0 additions & 8 deletions vllm/engine/metrics.py
@@ -134,12 +134,6 @@ def __init__(self, labelnames: List[str], max_model_len: int):
            labelnames=labelnames,
            buckets=build_1_2_5_buckets(max_model_len),
        )
-        self.histogram_best_of_request = self._histogram_cls(
-            name="vllm:request_params_best_of",
-            documentation="Histogram of the best_of request parameter.",
-            labelnames=labelnames,
-            buckets=[1, 2, 5, 10, 20],
-        )
        self.histogram_n_request = self._histogram_cls(
            name="vllm:request_params_n",
            documentation="Histogram of the n request parameter.",
@@ -473,8 +467,6 @@ def _log_prometheus(self, stats: Stats) -> None:
            self.metrics.histogram_num_generation_tokens_request,
            stats.num_generation_tokens_requests)
        self._log_histogram(self.metrics.histogram_n_request, stats.n_requests)
-       self._log_histogram(self.metrics.histogram_best_of_request,
-                           stats.best_of_requests)

    def _log_prometheus_interval(self, prompt_throughput: float,
                                 generation_throughput: float) -> None:
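For reference, the surviving `n` histogram is, at its core, a plain prometheus_client histogram. Below is a minimal standalone sketch that assumes a single `model_name` label and reuses the [1, 2, 5, 10, 20] buckets that the removed best_of histogram used; vLLM wraps this in its own _histogram_cls, so treat it as an approximation rather than the real wiring.

from prometheus_client import Histogram

# Standalone approximation of histogram_n_request above; the label set and
# buckets are assumptions for illustration, not copied from vLLM.
histogram_n_request = Histogram(
    name="vllm:request_params_n",
    documentation="Histogram of the n request parameter.",
    labelnames=["model_name"],
    buckets=[1, 2, 5, 10, 20],
)

# Record one request that asked for n=2 completions.
histogram_n_request.labels(model_name="demo-model").observe(2)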
1 change: 0 additions & 1 deletion vllm/engine/metrics_types.py
@@ -49,7 +49,6 @@ class Stats:
    # Metadata
    num_prompt_tokens_requests: List[int]
    num_generation_tokens_requests: List[int]
-   best_of_requests: List[int]
    n_requests: List[int]
    finished_reason_requests: List[str]

2 changes: 1 addition & 1 deletion vllm/engine/output_processor/single_step.py
@@ -112,7 +112,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                                        outputs: SequenceGroupOutput,
                                        is_async: bool) -> None:
        sampling_params = seq_group.sampling_params
-       if sampling_params.best_of == 1:
+       if sampling_params.n == 1:
            # only have one output sample
            sample = outputs.samples[0]
            # only have one sequence
17 changes: 8 additions & 9 deletions vllm/model_executor/layers/sampler.py
@@ -508,7 +508,7 @@ def _random_sample(
        same as the length of selected_seq_groups. If the corresponding
        seq_group has do_sample=False, tuple contains ([], [])
    """
-   # Find the maximum best_of value of the prompt phase requests.
+   # Find the maximum n value of the prompt phase requests.
    random_samples = random_samples.cpu()
    sample_idx = 0
    results: SampleResultType = []
@@ -523,9 +523,9 @@
        num_parent_seqs = len(seq_ids)
        if is_prompt:
            # Prompt phase.
-           parent_ids = [0] * sampling_params.best_of
+           parent_ids = [0] * sampling_params.n
            next_token_ids = random_samples[
-               sample_idx, :sampling_params.best_of].tolist()
+               sample_idx, :sampling_params.n].tolist()
        else:
            # Generation phase.
            parent_ids = list(range(num_parent_seqs))
@@ -570,7 +570,7 @@ def _beam_search_sample(
        is_prompt = seq_group.is_prompt
        seq_ids, sampling_params = seq_group.seq_ids, seq_group.sampling_params
        num_parent_seqs = len(seq_ids)
-       beam_width = sampling_params.best_of
+       beam_width = sampling_params.n
        seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs]
        if is_prompt:
            # Prompt phase.
@@ -797,12 +797,11 @@ def _sample_with_torch(
                                                greedy_samples)

    elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-       max_best_of_in_batch = 1
+       max_n_in_batch = 1
        for seq_group in seq_groups:
            if seq_group.is_prompt:
                sampling_params = seq_group.sampling_params
-               max_best_of_in_batch = max(max_best_of_in_batch,
-                                          sampling_params.best_of)
+               max_n_in_batch = max(max_n_in_batch, sampling_params.n)
        seq_groups_arg = (None if sampling_type == SamplingType.RANDOM else
                          seq_groups)

@@ -812,13 +811,13 @@
                probs[long_sample_indices],
                sampling_tensors.top_ks[long_sample_indices],
                sampling_tensors.top_ps[long_sample_indices],
-               max_best_of_in_batch,
+               max_n_in_batch,
                seq_groups_arg,
            )
        else:
            multinomial_samples[sampling_type] = _multinomial(
                probs[long_sample_indices],
-               max_best_of_in_batch,
+               max_n_in_batch,
                seq_groups=seq_groups_arg)

        if sampled_token_ids_tensor is not None:
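The `max_n_in_batch` value computed above caps how many candidates are drawn per prompt-phase row. A rough sketch of that idea in plain torch, not vLLM's `_multinomial` (the helper name and shapes here are assumptions):

import torch


def sample_up_to_max_n(probs: torch.Tensor, max_n_in_batch: int) -> torch.Tensor:
    """Draw max_n_in_batch token ids per row of a [num_seqs, vocab_size] matrix.

    Groups that requested fewer samples simply ignore the extra columns, much
    like the slicing random_samples[sample_idx, :sampling_params.n] above.
    """
    return torch.multinomial(probs, num_samples=max_n_in_batch, replacement=True)


probs = torch.softmax(torch.randn(2, 32), dim=-1)  # 2 sequences, 32-token vocab
print(sample_up_to_max_n(probs, max_n_in_batch=4).shape)  # torch.Size([2, 4])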
2 changes: 1 addition & 1 deletion vllm/outputs.py
@@ -141,7 +141,7 @@ def from_seq_group(cls, seq_group: SequenceGroup,
            top_n_seqs = seqs
        else:
            # Get the top-n sequences.
-           n = sampling_params.n
+           n = sampling_params._real_n or sampling_params.n
            sorting_key = lambda seq: seq.get_cumulative_logprob()
            sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
            top_n_seqs = sorted_seqs[:n]
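The `_real_n or n` fallback above is what trims the engine's `n` candidate sequences back down to the number the caller asked for. A self-contained sketch of that selection, using hypothetical stand-in classes rather than vLLM's Sequence:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SeqStub:
    text: str
    cumulative_logprob: float

    def get_cumulative_logprob(self) -> float:
        return self.cumulative_logprob


def top_n_seqs(seqs: List[SeqStub], n: int, real_n: Optional[int]) -> List[SeqStub]:
    # Mirrors `n = sampling_params._real_n or sampling_params.n`: the engine
    # generated n sequences, but the caller only asked for real_n of them.
    keep = real_n or n
    return sorted(seqs, key=lambda s: s.get_cumulative_logprob(), reverse=True)[:keep]


seqs = [SeqStub("a", -1.2), SeqStub("b", -0.4), SeqStub("c", -3.0)]
print([s.text for s in top_n_seqs(seqs, n=3, real_n=2)])  # ['b', 'a']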
33 changes: 17 additions & 16 deletions vllm/sampling_params.py
@@ -106,9 +106,8 @@ class SamplingParams(
        n: Number of output sequences to return for the given prompt.
        best_of: Number of output sequences that are generated from the prompt.
            From these `best_of` sequences, the top `n` sequences are returned.
-           `best_of` must be greater than or equal to `n`. This is treated as
-           the beam width when `use_beam_search` is True. By default, `best_of`
-           is set to `n`.
+           `best_of` must be greater than or equal to `n`. By default,
+           `best_of` is set to `n`.
        presence_penalty: Float that penalizes new tokens based on whether they
            appear in the generated text so far. Values > 0 encourage the model
            to use new tokens, while values < 0 encourage the model to repeat
@@ -173,6 +172,7 @@ class SamplingParams(

    n: int = 1
    best_of: Optional[int] = None
+   _real_n: Optional[int] = None
    presence_penalty: float = 0.0
    frequency_penalty: float = 0.0
    repetition_penalty: float = 1.0
@@ -282,7 +282,19 @@ def from_optional(
        )

    def __post_init__(self) -> None:
-       self.best_of = self.best_of or self.n
+       # How we deal with `best_of`:
+       # if `best_of` is not set, we default to `n`;
+       # if `best_of` is set, we set `n` to `best_of`,
+       # and set `_real_n` to the original `n`.
+       # When we return the result, we check
+       # whether we need to return `n` or `_real_n` results.
+       if self.best_of:
+           if self.best_of < self.n:
+               raise ValueError(
+                   f"best_of must be greater than or equal to n, "
+                   f"got n={self.n} and best_of={self.best_of}.")
+           self._real_n = self.n
+           self.n = self.best_of
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
                "temperature %s is less than %s, which may cause numerical "
@@ -329,12 +341,6 @@ def _verify_args(self) -> None:
                            f"type {type(self.n)}")
        if self.n < 1:
            raise ValueError(f"n must be at least 1, got {self.n}.")
-       if not isinstance(self.best_of, int):
-           raise ValueError(f"best_of must be an int, but is of "
-                            f"type {type(self.best_of)}")
-       if self.best_of < self.n:
-           raise ValueError(f"best_of must be greater than or equal to n, "
-                            f"got n={self.n} and best_of={self.best_of}.")
        if not -2.0 <= self.presence_penalty <= 2.0:
            raise ValueError("presence_penalty must be in [-2, 2], got "
                             f"{self.presence_penalty}.")
@@ -385,18 +391,14 @@ def _verify_args(self) -> None:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
-       if self.best_of != self.n and self.output_kind == (
+       if self.best_of != self._real_n and self.output_kind == (
                RequestOutputKind.DELTA):
            raise ValueError("best_of must equal n to use output_kind=DELTA")

    def _verify_greedy_sampling(self) -> None:
        if self.n > 1:
            raise ValueError("n must be 1 when using greedy sampling, "
                             f"got {self.n}.")
-       assert isinstance(self.best_of, int)
-       if self.best_of > 1:
-           raise ValueError("best_of must be 1 when using greedy sampling, "
-                            f"got {self.best_of}.")

    def update_from_generation_config(
        self,
@@ -453,7 +455,6 @@ def clone(self) -> "SamplingParams":
    def __repr__(self) -> str:
        return (
            f"SamplingParams(n={self.n}, "
-           f"best_of={self.best_of}, "
            f"presence_penalty={self.presence_penalty}, "
            f"frequency_penalty={self.frequency_penalty}, "
            f"repetition_penalty={self.repetition_penalty}, "
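The effect of the new `__post_init__` handling can be shown with a minimal stand-in class that copies just the swap logic (a sketch, not vLLM's SamplingParams):

from dataclasses import dataclass
from typing import Optional


@dataclass
class ParamsSketch:
    n: int = 1
    best_of: Optional[int] = None
    _real_n: Optional[int] = None

    def __post_init__(self) -> None:
        # If best_of is set, the engine-facing n becomes best_of and the
        # caller's original n is remembered in _real_n for output trimming.
        if self.best_of:
            if self.best_of < self.n:
                raise ValueError(
                    f"best_of must be greater than or equal to n, "
                    f"got n={self.n} and best_of={self.best_of}.")
            self._real_n = self.n
            self.n = self.best_of


params = ParamsSketch(n=2, best_of=4)
assert params.n == 4 and params._real_n == 2   # engine samples 4, caller gets 2

plain = ParamsSketch(n=3)
assert plain.n == 3 and plain._real_n is None  # best_of unset: nothing changes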
10 changes: 5 additions & 5 deletions vllm/sequence.py
@@ -803,14 +803,14 @@ def get_max_num_running_seqs(self) -> int:
        """The maximum number of sequences running in parallel in the remaining
        lifetime of the request."""
        if self.sampling_params:
-           best_of = self.sampling_params.best_of
-           assert isinstance(best_of, int)
-           if best_of > self.num_seqs():
+           n = self.sampling_params.n
+           assert isinstance(n, int)
+           if n > self.num_seqs():
                # At prompt stage, the sequence group is not yet filled up
                # and only have one sequence running. However, in the
-               # generation stage, we will have `best_of` sequences
+               # generation stage, we will have `n` sequences
                # running.
-               return best_of
+               return n
            # At sampling stages, return the number of actual sequences
            # that are not finished yet.
            return self.num_unfinished_seqs()
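A small numeric walkthrough of the branch above, with the scheduler-facing logic reduced to a hypothetical helper:

def max_num_running_seqs(n: int, num_seqs: int, num_unfinished: int) -> int:
    # Prompt stage: the group still holds a single sequence, but n forks will
    # exist once sampling starts, so reserve room for all of them.
    if n > num_seqs:
        return n
    # Generation stage: count only the sequences that are still running.
    return num_unfinished


print(max_num_running_seqs(n=4, num_seqs=1, num_unfinished=1))  # 4
print(max_num_running_seqs(n=4, num_seqs=4, num_unfinished=3))  # 3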
1 change: 0 additions & 1 deletion vllm/tracing.py
@@ -96,7 +96,6 @@ class SpanAttributes(BaseSpanAttributes):
    # The following span attribute names are added here because they are missing
    # from the Semantic Conventions for LLM.
    LLM_REQUEST_ID = "gen_ai.request.id"
-   LLM_REQUEST_BEST_OF = "gen_ai.request.best_of"
    LLM_REQUEST_N = "gen_ai.request.n"
    LLM_USAGE_NUM_SEQUENCES = "gen_ai.usage.num_sequences"
    LLM_LATENCY_TIME_IN_QUEUE = "gen_ai.latency.time_in_queue"