[core] Multi Step Scheduling (vllm-project#7000)
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
SolitaryThinker and afeldman-nm authored Aug 19, 2024
1 parent 0400383 commit 27a1999
Showing 13 changed files with 1,004 additions and 34 deletions.
9 changes: 9 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -311,6 +311,15 @@ steps:
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
 
+- label: Multi-step Tests (4 GPUs) # 10min
+  working_dir: "/vllm-workspace/tests"
+  num_gpus: 4
+  source_file_dependencies:
+  - vllm/
+  - tests/multi_step/test_correctness.py
+  commands:
+    - pytest -v -s multi_step/test_correctness.py
+
 - label: Pipeline Parallelism Test # 23min
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
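For reference, the new CI step can be reproduced locally from the vLLM tests/ directory; a minimal sketch, assuming a vLLM development checkout (for tests/conftest.py) and enough GPUs for the parametrized tensor/pipeline-parallel cases:

# Minimal sketch: run the new multi-step correctness test the same way the CI
# step's `pytest -v -s multi_step/test_correctness.py` command does.
import pytest

pytest.main(["-v", "-s", "multi_step/test_correctness.py"])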
Empty file added tests/multi_step/__init__.py
85 changes: 85 additions & 0 deletions tests/multi_step/test_correctness.py
@@ -0,0 +1,85 @@
+# Test the AsyncLLMEngine with multi-step-decoding
+
+from typing import List
+
+import pytest
+
+from ..utils import RemoteOpenAIServer
+
+MODELS = [
+    "JackFram/llama-160m",
+]
+NUM_SCHEDULER_STEPS = [8]  # Multi-step decoding steps
+NUM_PROMPTS = [10]
+
+DEFAULT_SERVER_ARGS: List[str] = [
+    "--disable-log-requests",
+    "--use-v2-block-manager",
+    "--worker-use-ray",
+    "--gpu-memory-utilization",
+    "0.85",
+    "--swap-space",
+    "16",
+]
+
+
+async def completions_with_server_args(prompts: List[str], model_name: str,
+                                       server_cli_args: List[str]):
+
+    outputs = None
+    with RemoteOpenAIServer(model_name, server_cli_args) as server:
+        client = server.get_async_client()
+        outputs = await client.completions.create(model=model_name,
+                                                  prompt=prompts,
+                                                  temperature=0,
+                                                  stream=False,
+                                                  max_tokens=5)
+    assert outputs is not None
+
+    return outputs
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize(("tp_size, pp_size"), [
+    (1, 1),
+    (2, 2),
+])
+@pytest.mark.parametrize("eager_mode", [False, True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.asyncio
+async def test_multi_step(example_prompts, model: str, tp_size: int,
+                          pp_size: int, eager_mode: bool,
+                          num_scheduler_steps: int, num_prompts: int):
+
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    server_args = DEFAULT_SERVER_ARGS + ["--enforce-eager"]
+    ms_server_args = DEFAULT_SERVER_ARGS + \
+        ["--num-scheduler-steps", f"{num_scheduler_steps}"]
+
+    if eager_mode:
+        ms_server_args.append("--enforce-eager")
+
+    distributed_args = [
+        "--tensor-parallel-size",
+        str(tp_size),
+        "--pipeline-parallel-size",
+        str(pp_size),
+    ]
+
+    ref_completions = await completions_with_server_args(
+        prompts, model, server_args + distributed_args)
+    test_completions = await completions_with_server_args(
+        prompts, model, ms_server_args + distributed_args)
+
+    def get_text_generations(completions):
+        return [x.text for x in completions.choices]
+
+    ref_generations = get_text_generations(ref_completions)
+    test_generations = get_text_generations(test_completions)
+    assert ref_generations == test_generations
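For context, the comparison the helper above performs maps onto a plain OpenAI-compatible request; a minimal sketch, assuming a vLLM server was already launched with --num-scheduler-steps 8 on the default local port (model name and prompts are illustrative):

# Minimal sketch (not part of the diff): query a multi-step-enabled vLLM
# OpenAI-compatible server with the synchronous client. Greedy decoding
# (temperature=0) keeps the output comparable with a single-step baseline.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="JackFram/llama-160m",
    prompt=["Hello, my name is", "The capital of France is"],
    temperature=0,
    max_tokens=5,
)
print([choice.text for choice in completion.choices])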
77 changes: 77 additions & 0 deletions tests/worker/test_model_input.py
@@ -10,6 +10,7 @@
 from vllm.worker.embedding_model_runner import (
     ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+from vllm.worker.multi_step_model_runner import StatefulModelInput
 
 
 class MockAttentionBackend(AttentionBackend):
@@ -154,3 +155,79 @@ def test_embedding_model_runner_input():
                        None) == getattr(attn_metadata, field.name, None)
     # Pooling metadata is not broadcast.
     assert received_model_input.pooling_metadata is None
+
+
+def test_multi_step_model_runner_input():
+    sampling_metadata = SamplingMetadata(
+        ["seq_group"],
+        "selected_token_indices",
+        "categorized_sample_indices",
+        "num_prompts",
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        sampling_metadata=sampling_metadata,
+        attn_metadata=attn_metadata)
+
+    model_input = StatefulModelInput(
+        frozen_model_input=frozen_model_input,
+        is_last_step=True,
+        is_first_multi_step=False,
+        current_step=4,
+        last_sampled_token_ids=torch.ones((10, 1)),
+        is_multi_step=True,
+        num_queries=8,
+        num_seqs=5,
+        cached_outputs=[],
+    )
+
+    assert isinstance(model_input, StatefulModelInput)
+
+    # Test round trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (StatefulModelInput.from_broadcasted_tensor_dict(
+        tensor_dict, attn_backend=attn_backend))
+
+    received_frozen_input = received_model_input.frozen_model_input
+
+    # Check that the received copy has correct values.
+    assert isinstance(received_model_input, StatefulModelInput)
+    assert received_frozen_input.input_tokens is not None
+    assert (received_frozen_input.input_tokens ==
+            frozen_model_input.input_tokens).all()
+    assert received_frozen_input.input_positions is not None
+    assert (received_frozen_input.input_positions ==
+            frozen_model_input.input_positions).all()
+    assert received_frozen_input.multi_modal_kwargs is None
+    assert (received_frozen_input.multi_modal_kwargs ==
+            frozen_model_input.multi_modal_kwargs)
+    assert received_frozen_input.lora_requests is None
+    assert (received_frozen_input.lora_requests ==
+            frozen_model_input.lora_requests)
+    assert received_frozen_input.lora_mapping is None
+    assert (
+        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_frozen_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # For sampling metadata, only selected_token_indices is copied.
+    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
+            sampling_metadata.selected_token_indices)
+    assert received_frozen_input.sampling_metadata.seq_groups is None
+
+    # Check the non-frozen, per-step fields.
+    assert received_model_input.is_last_step == model_input.is_last_step
+    assert (received_model_input.is_first_multi_step ==
+            model_input.is_first_multi_step)
+    assert received_model_input.current_step == model_input.current_step
+    assert (received_model_input.last_sampled_token_ids ==
+            model_input.last_sampled_token_ids).all()
+    assert received_model_input.is_multi_step == model_input.is_multi_step
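An illustrative sketch of the frozen/non-frozen split the test exercises, continuing from the objects built above (assuming StatefulModelInput is a plain mutable dataclass, as the constructor usage suggests): between scheduler steps only the per-step wrapper fields change, while frozen_model_input is reused.

# Illustrative only: advance the per-step state without rebuilding the frozen
# GPU inputs. Field semantics are inferred from their names, not from the diff.
model_input.current_step += 1
model_input.is_last_step = (model_input.current_step == 8)  # e.g. 8 scheduler steps
assert model_input.frozen_model_input is frozen_model_input  # reused across steps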
7 changes: 6 additions & 1 deletion vllm/engine/arg_utils.py
@@ -853,6 +853,12 @@ def create_engine_config(self, ) -> EngineConfig:
                 "in low performance due to small KV cache space. Consider "
                 "setting --max-model-len to a smaller value.", max_model_len)
 
+        if self.num_scheduler_steps > 1 and not self.use_v2_block_manager:
+            self.use_v2_block_manager = True
+            logger.warning(
+                "Enabled BlockSpaceManagerV2 because it is "
+                "required for multi-step (--num-scheduler-steps > 1)")
+
         speculative_config = SpeculativeConfig.maybe_create_spec_config(
             target_model_config=model_config,
             target_parallel_config=parallel_config,
@@ -881,7 +887,6 @@ def create_engine_config(self, ) -> EngineConfig:
         )
 
         if self.num_scheduler_steps > 1:
-            raise NotImplementedError("Multi-step is not yet supported.")
             if speculative_config is not None:
                 raise ValueError("Speculative decoding is not supported with "
                                  "multi-step (--num-scheduler-steps > 1)")