[Core] Add engine option to return only deltas or final output (vllm-…
njhill authored and MengqingCao committed Sep 30, 2024
1 parent f5f4369 commit a105ca6
Showing 10 changed files with 370 additions and 136 deletions.
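
The diffs below revolve around a new RequestOutputKind setting carried on SamplingParams, with three kinds referenced throughout: CUMULATIVE, DELTA, and FINAL_ONLY. The enum itself is defined in vllm/sampling_params.py, one of the changed files not expanded in this view; the sketch below is inferred from how the three kinds are used in the hunks that follow, so the member values and comments are illustrative rather than authoritative.

# Inferred sketch of the new option; the real definition lives in
# vllm/sampling_params.py, which is not expanded in this view.
from enum import Enum


class RequestOutputKind(Enum):
    # Return the full output accumulated so far in every RequestOutput
    # (the pre-existing streaming behavior, presumably still the default).
    CUMULATIVE = 0
    # Return only the tokens/text generated since the previous RequestOutput.
    DELTA = 1
    # Return nothing until the request finishes, then one final output.
    FINAL_ONLY = 2

Callers select a kind through SamplingParams, either at construction time (as the OpenAI protocol change below does) or by assigning sampling_params.output_kind (as the new tests do).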
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -50,6 +50,7 @@ steps:
- tests/worker
commands:
- pytest -v -s async_engine # Async Engine
- NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
- pytest -v -s test_inputs.py
- pytest -v -s multimodal
- pytest -v -s test_utils.py # Utils
161 changes: 147 additions & 14 deletions tests/async_engine/test_async_llm_engine.py
@@ -1,7 +1,10 @@
import asyncio
import os
import uuid
from asyncio import CancelledError
from copy import copy
from dataclasses import dataclass
-from typing import Optional
+from typing import List, Optional

import pytest
import pytest_asyncio
@@ -11,6 +14,7 @@
from vllm.config import ParallelConfig
from vllm.engine.async_llm_engine import AsyncEngineArgs, AsyncLLMEngine
from vllm.outputs import RequestOutput as RealRequestOutput
from vllm.sampling_params import RequestOutputKind

from ..conftest import cleanup
from ..utils import wait_for_gpu_memory_to_clear
@@ -122,8 +126,17 @@ def start_engine():
timeout_s=60,
)

num_scheduler_steps = int(os.getenv("NUM_SCHEDULER_STEPS", "1"))
print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

return AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m", enforce_eager=True))
+        AsyncEngineArgs(model="facebook/opt-125m",
+                        enforce_eager=True,
+                        num_scheduler_steps=num_scheduler_steps))


def uid() -> str:
return str(uuid.uuid4())


@pytest_asyncio.fixture(scope="module")
@@ -148,57 +161,177 @@ def should_do_global_cleanup_after_test(request) -> bool:
@pytest.mark.asyncio(scope="module")
async def test_asyncio_run(async_engine):

scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

async def run(prompt: str):
sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)

output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
sampling_params,
-                                                   request_id=prompt):
+                                                   request_id=uid()):
output_count += 1
final_output = output
-        return final_output
+        return final_output, output_count

results = await asyncio.gather(
run("test0"),
run("test1"),
run("test0"),
)
assert len(results) == 2
first, second = results

# remove nondeterministic fields for comparison
first[0].metrics = None
second[0].metrics = None
first[0].request_id = None
second[0].request_id = None

assert str(first) == str(second)

output_count = results[0][1]
if num_scheduler_steps == 1:
assert output_count == 32
else:
assert 1 < output_count < 32


@pytest.mark.asyncio(scope="module")
async def test_output_kinds(async_engine):
"""Test that output_kind works as expected and that
results are equivalent across different kinds."""

scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

sampling_params = SamplingParams(
temperature=0,
max_tokens=32,
min_tokens=32,
)

async def run(prompt: str, kind: RequestOutputKind):
params = copy(sampling_params)
params.output_kind = kind

output_count = 0
final_output = None
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
output_count += 1
final_output = output

assert final_output is not None
return (final_output.prompt_token_ids,
final_output.outputs[0].token_ids,
final_output.outputs[0].text, output_count)

async def run_deltas(prompt: str):
params = copy(sampling_params)
params.output_kind = RequestOutputKind.DELTA

prompt_tokens = None
output_tokens: List[int] = []
output_text = ""
output_count = 0
async for output in async_engine.generate(prompt,
params,
request_id=uid()):
token_ids = output.outputs[0].token_ids
text = output.outputs[0].text

# Ensure we get prompt ids iff we haven't yet received output tokens
if output_tokens:
assert 1 <= len(token_ids) <= num_scheduler_steps
assert text
assert not output.prompt_token_ids
else:
assert output.prompt_token_ids
prompt_tokens = output.prompt_token_ids

output_tokens.extend(token_ids)
output_text += text

output_count += 1
return prompt_tokens, output_tokens, output_text, output_count

results = await asyncio.gather(
run("common input prompt", RequestOutputKind.CUMULATIVE),
run("common input prompt", RequestOutputKind.FINAL_ONLY),
run_deltas("common input prompt"))

# Make sure outputs are the same
prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
assert len(prompt_set) == 1

text_set = set(text for _, _, text, _ in results)
assert len(text_set) == 1

tokens_set = set(tuple(ids) for _, ids, _, _ in results)
assert len(tokens_set) == 1

cumulative, final, deltas = results

# output message counts
assert cumulative[3] == deltas[3]

if num_scheduler_steps == 1:
assert cumulative[3] == 32
else:
assert 1 < cumulative[3] < 32

assert final[3] == 1


@pytest.mark.asyncio(scope="module")
async def test_cancellation(async_engine):
scheduler_config = await async_engine.get_scheduler_config()
num_scheduler_steps = scheduler_config.num_scheduler_steps

sampling_params = SamplingParams(
temperature=0,
-        min_tokens=10,
-        max_tokens=10,
+        min_tokens=13,
+        max_tokens=13,
)

stop_at = 5 if num_scheduler_steps == 1 else 1

request_id = uid()

i = 0
with pytest.raises(CancelledError):
async for output in async_engine.generate("test2",
sampling_params,
request_id="test2"):
request_id=request_id):
assert not output.finished
i += 1
-            if i == 5:
-                await async_engine.abort("test2")
+            if i == stop_at:
+                await async_engine.abort(request_id)

-    assert i == 5
+    assert i == stop_at


@pytest.mark.asyncio(scope="module")
async def test_delayed_generator(async_engine):
scheduler_config = await async_engine.get_scheduler_config()

if scheduler_config.num_scheduler_steps != 1:
pytest.skip("no need to test this one with multistep")

sampling_params = SamplingParams(
temperature=0,
min_tokens=10,
max_tokens=10,
)

-    stream = async_engine.generate("test3",
-                                   sampling_params,
-                                   request_id="test3")
+    stream = async_engine.generate("test3", sampling_params, request_id=uid())
i = 0
final_output: Optional[RealRequestOutput] = None
async for output in stream:
22 changes: 12 additions & 10 deletions vllm/engine/llm_engine.py
@@ -39,7 +39,7 @@
RequestOutputFactory)
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
Sequence, SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
@@ -225,9 +225,6 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
-        # To improve performance, only final requests outputs may be required.
-        # If this set to true, then no intermediate outputs will be returned.
-        step_return_finished_only: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@@ -295,7 +292,6 @@ def __init__(
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
-        self.step_return_finished_only = step_return_finished_only

if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@@ -1388,7 +1384,8 @@ def _process_model_outputs(self,
seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
-            ctx.request_outputs.append(request_output)
+            if request_output:
+                ctx.request_outputs.append(request_output)

# When we process a single request, we skip it for the next time,
# and invoke the request output callback (if there was final output)
@@ -1425,14 +1422,19 @@ def _process_model_outputs(self,

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
-            if (seq_group.is_finished()
-                    if self.step_return_finished_only else True):
-                request_output = RequestOutputFactory.create(seq_group)
+            request_output = RequestOutputFactory.create(seq_group)
+            if request_output:
ctx.request_outputs.append(request_output)

for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue

request_output = RequestOutputFactory.create(seq_group)
-            ctx.request_outputs.append(request_output)
+            if request_output:
+                ctx.request_outputs.append(request_output)

# Immediately process request outputs here (if callback is given)
if (ctx.request_outputs
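
The repeated if request_output: guards above imply that RequestOutputFactory.create can now decline to produce an output for a request on a given step; the factory change itself sits in vllm/outputs.py, which is not expanded in this view. The predicate below is a hypothetical illustration only, written to make the per-step emission rule suggested by the commit title and the visible hunks concrete; it is not the actual vLLM code, and the DELTA-with-no-new-tokens case in particular is an assumption.

from vllm.sampling_params import RequestOutputKind


# Hypothetical illustration only: the real decision is made inside
# RequestOutputFactory.create (vllm/outputs.py), which this view does not show.
def emits_output_this_step(kind: RequestOutputKind, finished: bool,
                           num_new_tokens: int) -> bool:
    if finished:
        # A finished request always yields its (final) output.
        return True
    if kind == RequestOutputKind.FINAL_ONLY:
        # No intermediate outputs at all.
        return False
    if kind == RequestOutputKind.DELTA and num_new_tokens == 0:
        # Nothing new to report for this step (assumed behavior).
        return False
    # CUMULATIVE, and DELTA with fresh tokens, report every step.
    return True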
23 changes: 8 additions & 15 deletions vllm/entrypoints/llm.py
@@ -19,7 +19,7 @@
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -642,14 +642,12 @@ def _validate_and_add_requests(
raise ValueError("The lengths of prompts and lora_request "
"must be the same.")

-        if isinstance(params, list):
-            params = [
-                self._add_guided_processor(param, guided_options)
-                if isinstance(param, SamplingParams) else param
-                for param in params
-            ]
-        elif isinstance(params, SamplingParams):
-            params = self._add_guided_processor(params, guided_options)
+        for sp in params if isinstance(params, list) else (params, ):
+            if isinstance(sp, SamplingParams):
+                self._add_guided_processor(sp, guided_options)
+
+                # We only care about the final output
+                sp.output_kind = RequestOutputKind.FINAL_ONLY

# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
@@ -709,9 +707,6 @@ def _run_engine(
f"output: {0:.2f} toks/s"),
)

-        # In the loop below, only finished outputs are used
-        self.llm_engine.step_return_finished_only = True

# Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
@@ -724,6 +719,7 @@
if use_tqdm:
if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput
assert output.prompt_token_ids is not None
total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
@@ -735,9 +731,6 @@
f"output: {out_spd:.2f} toks/s")
pbar.update(1)

-        # Restore original behavior
-        self.llm_engine.step_return_finished_only = False

if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
7 changes: 6 additions & 1 deletion vllm/entrypoints/openai/protocol.py
@@ -12,7 +12,8 @@
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from vllm.entrypoints.openai.logits_processors import get_logits_processors
from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sampling_params import (LogitsProcessor, RequestOutputKind,
+                                  SamplingParams)
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid
@@ -316,6 +317,8 @@ def to_sampling_params(
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
@@ -559,6 +562,8 @@ def to_sampling_params(
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
)

@model_validator(mode="before")
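
Taken together, the hunks above wire the new kinds into the existing entrypoints: the offline LLM entrypoint forces FINAL_ONLY for every request, the OpenAI server asks for DELTA when stream is set and FINAL_ONLY otherwise, and AsyncLLMEngine callers can choose a kind per request through SamplingParams. A minimal caller-side sketch, assuming output_kind is accepted as a SamplingParams keyword (as the protocol change implies) and that CUMULATIVE remains the default:

from vllm import SamplingParams
from vllm.sampling_params import RequestOutputKind

# Streamed serving: each RequestOutput carries only the tokens and text
# generated since the previous one.
delta_params = SamplingParams(max_tokens=32,
                              output_kind=RequestOutputKind.DELTA)

# Non-streamed serving or offline batch use: a single output at the end.
final_params = SamplingParams(max_tokens=32,
                              output_kind=RequestOutputKind.FINAL_ONLY)

# Equivalent attribute-style selection, as used in the tests above.
cumulative_params = SamplingParams(max_tokens=32)
cumulative_params.output_kind = RequestOutputKind.CUMULATIVE

With DELTA, a client reconstructs the full completion by concatenating outputs[0].text across the stream, exactly as run_deltas does in the test file above.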