Varun/multi step chunked prefill #7563

66 changes: 41 additions & 25 deletions examples/openai_completion_client.py
@@ -4,28 +4,44 @@
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
model=model,
prompt="A robot may not injure a human being",
echo=False,
n=2,
stream=stream,
logprobs=3)

print("Completion results:")
if stream:
for c in completion:
print(c)
else:
print(completion)

def get_prompts(n=1):
ps = ['A robot may not injure a human being']
for i in range(1, n):
ps.append(' '.join(["hi!"] * i))

return ps


def main():
client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
api_key=openai_api_key,
base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

prompts = get_prompts(50)
#print (f"{prompts}")
print(f"# PROMPTS : {len(prompts)}")

# Completion API
stream = False
completion = client.completions.create(model=model,
prompt=prompts,
echo=False,
n=1,
stream=stream)

print("Completion results:")
if stream:
for c in completion:
print(c)
else:
print(completion)


if __name__ == '__main__':
main()
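
A quick way to sanity-check the prompt generator above (a sketch, assuming get_prompts() from this example is in scope; not part of the script itself): the helper returns one fixed prompt followed by progressively longer filler prompts, which gives chunked prefill a spread of prefill sizes.

# Hypothetical check of get_prompts(); illustrative only.
prompts = get_prompts(4)
assert prompts == [
    'A robot may not injure a human being',
    'hi!',
    'hi! hi!',
    'hi! hi! hi!',
]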
75 changes: 75 additions & 0 deletions tests/basic_correctness/test_multi_step_chunked_prefill.py
@@ -0,0 +1,75 @@
# Test the AsyncLLMEngine with multi-step-decoding and chunked prefill

from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
NUM_SCHEDULER_STEPS = [8, 16] # Multi-step decoding steps
NUM_PROMPTS = [100]

# TODO (varun) : Expand tests for multiple TP & PP
DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
"--use-v2-block-manager",
"--worker-use-ray",
"--gpu-memory-utilization",
"0.90",
"--swap-space",
"16",
"--tensor-parallel-size",
"1",
"--pipeline-parallel-size",
"1",
]


async def completions_with_server_args(prompts: List[str], model_name: str,
server_cli_args: List[str]):

outputs = None
with RemoteOpenAIServer(model_name, server_cli_args) as server:
client = server.get_async_client()
outputs = await client.completions.create(model=model_name,
prompt=prompts,
temperature=0,
stream=False,
max_tokens=150)
assert outputs is not None

return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step_with_chunked_prefill(example_prompts, model: str,
num_scheduler_steps: int,
num_prompts: int):

prompts = example_prompts
if len(prompts) < num_prompts:
prompts = prompts * ((num_prompts // len(prompts)) + 1)
prompts = prompts[:num_prompts]
assert len(prompts) == num_prompts

server_args = DEFAULT_SERVER_ARGS + \
["--num-scheduler-steps", f"{num_scheduler_steps}"]

ref_completions = await completions_with_server_args(
prompts, model, server_args)
test_completions = await completions_with_server_args(
prompts, model, server_args + ["--enable-chunked-prefill"])

def get_text_generations(completions):
return [x.text for x in completions.choices]

ref_generations = get_text_generations(ref_completions)
test_generations = get_text_generations(test_completions)
assert ref_generations == test_generations
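
Note that get_text_generations() above assumes the server returns exactly one choice per prompt, in prompt order. A slightly more defensive variant (a sketch, not part of this PR) could also pin down the count:

def get_text_generations_checked(completions, num_prompts: int):
    # Assumes one choice per prompt, returned in prompt order.
    texts = [choice.text for choice in completions.choices]
    assert len(texts) == num_prompts
    return texts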
81 changes: 81 additions & 0 deletions tests/worker/test_model_input.py
@@ -10,6 +10,8 @@
from vllm.worker.embedding_model_runner import (
ModelInputForGPUWithPoolingMetadata)
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm.worker.multi_step_model_runner import (
MutableModelInputForGPUWithMultiStepMetadata)


class MockAttentionBackend(AttentionBackend):
@@ -154,3 +156,82 @@ def test_embedding_model_runner_input():
None) == getattr(attn_metadata, field.name, None)
# Pooling metadata is not broadcast.
assert received_model_input.pooling_metadata is None


def test_multi_step_model_runner_input():
sampling_metadata = SamplingMetadata(
["seq_group"],
"selected_token_indices",
"categorized_sample_indices",
"num_prompts",
)
attn_metadata = AttentionMetadata(
num_prefills=1,
num_prefill_tokens=2,
num_decode_tokens=3,
slot_mapping=torch.zeros(1),
)
frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10),
input_positions=torch.ones(10),
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata)

model_input = MutableModelInputForGPUWithMultiStepMetadata(
frozen_model_input=frozen_model_input,
is_last_step=True,
is_first_multi_step=False,
current_step=4,
last_sampled_token_ids=torch.ones((10, 1)),
is_multi_step=True,
num_queries=8,
num_seqs=5,
outputs=[],
)

assert isinstance(model_input,
MutableModelInputForGPUWithMultiStepMetadata)

# Test round trip serialization.
tensor_dict = model_input.as_broadcastable_tensor_dict()
attn_backend = MockAttentionBackend()
received_model_input = (MutableModelInputForGPUWithMultiStepMetadata.
from_broadcasted_tensor_dict(
tensor_dict, attn_backend=attn_backend))

receieved_frozen_input = received_model_input.frozen_model_input

# Check that received copy has correct values.
assert isinstance(received_model_input,
MutableModelInputForGPUWithMultiStepMetadata)
assert receieved_frozen_input.input_tokens is not None
assert (receieved_frozen_input.input_tokens ==
frozen_model_input.input_tokens).all()
assert receieved_frozen_input.input_positions is not None
assert (receieved_frozen_input.input_positions ==
frozen_model_input.input_positions).all()
assert receieved_frozen_input.multi_modal_kwargs is None
assert (receieved_frozen_input.multi_modal_kwargs ==
frozen_model_input.multi_modal_kwargs)
assert receieved_frozen_input.lora_requests is None
assert (receieved_frozen_input.lora_requests ==
frozen_model_input.lora_requests)
assert receieved_frozen_input.lora_mapping is None
assert (
receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
for field in dataclasses.fields(AttentionMetadata):
assert getattr(receieved_frozen_input.attn_metadata, field.name,
None) == getattr(attn_metadata, field.name, None)
# For sampling metadata, only selected_token_indices is copied.
assert (receieved_frozen_input.sampling_metadata.selected_token_indices ==
sampling_metadata.selected_token_indices)
assert receieved_frozen_input.sampling_metadata.seq_groups is None

# check non frozen fields
assert received_model_input.is_last_step == model_input.is_last_step
assert (received_model_input.is_first_multi_step ==
model_input.is_first_multi_step)
assert received_model_input.current_step == model_input.current_step
assert (received_model_input.last_sampled_token_ids ==
model_input.last_sampled_token_ids).all()
assert received_model_input.is_multi_step == model_input.is_multi_step
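
The round trip above follows the broadcastable-tensor-dict pattern used by the worker model inputs. A minimal, self-contained sketch of that pattern (TinyModelInput below is illustrative and not part of vLLM):

from dataclasses import dataclass

import torch


@dataclass
class TinyModelInput:
    # Illustrative stand-in for a model input that must survive a
    # serialize / broadcast / deserialize round trip.
    input_tokens: torch.Tensor
    current_step: int

    def as_broadcastable_tensor_dict(self):
        return {
            "input_tokens": self.input_tokens,
            "current_step": self.current_step,
        }

    @classmethod
    def from_broadcasted_tensor_dict(cls, tensor_dict):
        return cls(**tensor_dict)


original = TinyModelInput(input_tokens=torch.ones(4), current_step=2)
restored = TinyModelInput.from_broadcasted_tensor_dict(
    original.as_broadcastable_tensor_dict())
assert torch.equal(restored.input_tokens, original.input_tokens)
assert restored.current_step == original.current_step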
18 changes: 17 additions & 1 deletion vllm/core/scheduler.py
@@ -981,6 +981,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
[s.seq_group for s in swapped_in.prefill_seq_groups])
# Update swapped requests.
self.swapped.extend(running_scheduled.swapped_out)

if self.scheduler_config.is_multi_step:
# It may be the case that prefills are scheduled along
# with decodes. In that case, update the multi-step state
# of all the scheduled sequences to perform just a single
# decoding step.
has_prefills = len(prefills.seq_groups) + \
len(running_scheduled.prefill_seq_groups) + \
len(swapped_in.prefill_seq_groups) > 0
if has_prefills:
for sg in running_scheduled.decode_seq_groups:
sg.seq_group.init_multi_step(1)
for sg in swapped_in.decode_seq_groups:
sg.seq_group.init_multi_step(1)

return SchedulerOutputs(
scheduled_seq_groups=(prefills.seq_groups +
running_scheduled.prefill_seq_groups +
@@ -1187,7 +1202,8 @@ def _append_slots(
the new source and destination block indices for the appended
slots.
"""
num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
num_lookahead_slots = self._get_num_lookahead_slots(
    is_prefill=seq_group.is_prefill())
seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
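Taken together, the scheduler changes above mean that under multi-step scheduling, decode sequence groups normally run num_scheduler_steps steps per scheduling round, but any round that also schedules prefills limits the decodes to a single step so the whole batch stays in lockstep. A standalone sketch of that rule (illustrative function, not the scheduler API):

def decode_steps_this_round(num_scheduler_steps: int,
                            has_prefills: bool) -> int:
    # Mixed prefill + decode rounds are limited to one decoding step;
    # pure decode rounds keep the full multi-step budget.
    return 1 if has_prefills else num_scheduler_steps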
5 changes: 2 additions & 3 deletions vllm/engine/arg_utils.py
@@ -865,12 +865,11 @@ def create_engine_config(self, ) -> EngineConfig:
)

if self.num_scheduler_steps > 1:
raise NotImplementedError("Multi-step is not yet supported.")
if speculative_config is not None:
raise ValueError("Speculative decoding is not supported with "
"multi-step (--num-scheduler-steps > 1)")
if self.enable_chunked_prefill:
raise ValueError("Chunked prefill is not supported with "
if not self.use_v2_block_manager:
raise ValueError("BlockSpaceManagerV2 is required for "
"multi-step (--num-scheduler-steps > 1)")

# make sure num_lookahead_slots is set the higher value depending on
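The net effect of the argument-validation change is that, with --num-scheduler-steps > 1, speculative decoding is still rejected, BlockSpaceManagerV2 becomes mandatory, and chunked prefill is no longer rejected. A condensed sketch of that rule outside the real EngineArgs class (names illustrative):

def validate_multi_step(num_scheduler_steps: int,
                        speculative_config,
                        use_v2_block_manager: bool) -> None:
    # Illustrative restatement of the checks in create_engine_config().
    if num_scheduler_steps <= 1:
        return
    if speculative_config is not None:
        raise ValueError("Speculative decoding is not supported with "
                         "multi-step (--num-scheduler-steps > 1)")
    if not use_v2_block_manager:
        raise ValueError("BlockSpaceManagerV2 is required for "
                         "multi-step (--num-scheduler-steps > 1)")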