
Commit 9b8ba16

ms for chunked prefill

Varun Sundar Rabindranath committed Aug 16, 2024
1 parent 3351973

Showing 5 changed files with 166 additions and 47 deletions.
66 changes: 41 additions & 25 deletions examples/openai_completion_client.py
@@ -4,28 +4,44 @@
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Completion API
stream = False
completion = client.completions.create(
    model=model,
    prompt="A robot may not injure a human being",
    echo=False,
    n=2,
    stream=stream,
    logprobs=3)

print("Completion results:")
if stream:
    for c in completion:
        print(c)
else:
    print(completion)

def get_prompts(n=1):
    ps = ['A robot may not injure a human being']
    for i in range(1, n):
        ps.append(' '.join(["hi!"] * i))

    return ps


def main():
    client = OpenAI(
        # defaults to os.environ.get("OPENAI_API_KEY")
        api_key=openai_api_key,
        base_url=openai_api_base,
    )

    models = client.models.list()
    model = models.data[0].id

    prompts = get_prompts(50)
    #print (f"{prompts}")
    print(f"# PROMPTS : {len(prompts)}")

    # Completion API
    stream = False
    completion = client.completions.create(model=model,
                                           prompt=prompts,
                                           echo=False,
                                           n=1,
                                           stream=stream)

    print("Completion results:")
    if stream:
        for c in completion:
            print(c)
    else:
        print(completion)


if __name__ == '__main__':
    main()
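
For reference, a small standalone sketch (not part of the commit) of what the new get_prompts helper returns; the driver loop and printed word counts below are illustrative only:

# Reproduces get_prompts() from the example above; the __main__ driver is
# illustrative and not part of the commit.
def get_prompts(n=1):
    ps = ['A robot may not injure a human being']
    for i in range(1, n):
        ps.append(' '.join(["hi!"] * i))
    return ps


if __name__ == '__main__':
    for p in get_prompts(4):
        print(len(p.split()), repr(p))
    # 7 'A robot may not injure a human being'
    # 1 'hi!'
    # 2 'hi! hi!'
    # 3 'hi! hi! hi!'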
75 changes: 75 additions & 0 deletions tests/basic_correctness/test_multi_step_chunked_prefill.py
@@ -0,0 +1,75 @@
# Test the AsyncLLMEngine with multi-step-decoding and chunked prefill

from typing import List

import pytest

from ..utils import RemoteOpenAIServer

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
]
NUM_SCHEDULER_STEPS = [8, 16] # Multi-step decoding steps
NUM_PROMPTS = [100]

# TODO (varun) : Expand tests for multiple TP & PP
DEFAULT_SERVER_ARGS: List[str] = [
"--disable-log-requests",
"--use-v2-block-manager",
"--worker-use-ray",
"--gpu-memory-utilization",
"0.90",
"--swap-space",
"16",
"--tensor-parallel-size",
"1",
"--pipeline-parallel-size",
"1",
]


async def completions_with_server_args(prompts: List[str], model_name: str,
                                       server_cli_args: List[str]):

    outputs = None
    with RemoteOpenAIServer(model_name, server_cli_args) as server:
        client = server.get_async_client()
        outputs = await client.completions.create(model=model_name,
                                                  prompt=prompts,
                                                  temperature=0,
                                                  stream=False,
                                                  max_tokens=150)
    assert outputs is not None

    return outputs


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
@pytest.mark.asyncio
async def test_multi_step_with_chunked_prefill(example_prompts, model: str,
                                               num_scheduler_steps: int,
                                               num_prompts: int):

    prompts = example_prompts
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    prompts = prompts[:num_prompts]
    assert len(prompts) == num_prompts

    server_args = DEFAULT_SERVER_ARGS + \
        ["--num-scheduler-steps", f"{num_scheduler_steps}"]

    ref_completions = await completions_with_server_args(
        prompts, model, server_args)
    test_completions = await completions_with_server_args(
        prompts, model, server_args + ["--enable-chunked-prefill"])

    def get_text_generations(completions):
        return [x.text for x in completions.choices]

    ref_generations = get_text_generations(ref_completions)
    test_generations = get_text_generations(test_completions)
    assert ref_generations == test_generations
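
The pass criterion above is exact text equality between the two servers, which is meaningful only because temperature=0 makes both runs deterministic. A condensed sketch of that check, reusing the completions_with_server_args helper defined in this file (the outputs_match name is illustrative):

# Condensed sketch of the comparison performed by the test above; assumes the
# completions_with_server_args helper and server args defined in this file.
async def outputs_match(prompts, model, base_args) -> bool:
    ref = await completions_with_server_args(prompts, model, base_args)
    test = await completions_with_server_args(
        prompts, model, base_args + ["--enable-chunked-prefill"])
    # Greedy (temperature=0) decoding is deterministic, so the chunked-prefill
    # run must reproduce the reference text exactly.
    return [c.text for c in ref.choices] == [c.text for c in test.choices]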
18 changes: 17 additions & 1 deletion vllm/core/scheduler.py
@@ -981,6 +981,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs:
            [s.seq_group for s in swapped_in.prefill_seq_groups])
        # Update swapped requests.
        self.swapped.extend(running_scheduled.swapped_out)

        if self.scheduler_config.is_multi_step:
            # It may be the case that prefills are scheduled along
            # with decodes. In that case, update the multi-step state
            # of all the scheduled sequences to perform just a single
            # decoding step.
            has_prefills = len(prefills.seq_groups) + \
                len(running_scheduled.prefill_seq_groups) + \
                len(swapped_in.prefill_seq_groups) > 0
            if has_prefills:
                for sg in running_scheduled.decode_seq_groups:
                    sg.seq_group.init_multi_step(1)
                for sg in swapped_in.decode_seq_groups:
                    sg.seq_group.init_multi_step(1)

        return SchedulerOutputs(
            scheduled_seq_groups=(prefills.seq_groups +
                                  running_scheduled.prefill_seq_groups +
@@ -1187,7 +1202,8 @@ def _append_slots(
                the new source and destination block indices for the appended
                slots.
        """
        num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
        num_lookahead_slots = self._get_num_lookahead_slots(\
            is_prefill=seq_group.is_prefill())
        seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1)

        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
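
Condensed, the scheduler change says: in a multi-step run, if any prefill was scheduled into this chunked-prefill batch, the co-scheduled decode groups fall back to a single decode step. A standalone sketch of that rule with hypothetical names (not vLLM's API):

# Hypothetical, simplified restatement of the rule added to
# _schedule_chunked_prefill above; these names are illustrative, not vLLM API.
def decode_steps_this_iteration(is_multi_step: bool,
                                num_scheduled_prefills: int,
                                num_scheduler_steps: int) -> int:
    """How many decode steps the scheduled decode groups should run."""
    if is_multi_step and num_scheduled_prefills > 0:
        # Prefills were co-scheduled with decodes: run just one decode step.
        return 1
    # Otherwise multi-step decoding proceeds as configured.
    return num_scheduler_steps if is_multi_step else 1


assert decode_steps_this_iteration(True, 2, 8) == 1
assert decode_steps_this_iteration(True, 0, 8) == 8
assert decode_steps_this_iteration(False, 0, 8) == 1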
24 changes: 15 additions & 9 deletions vllm/engine/async_llm_engine.py
@@ -294,10 +294,10 @@ async def step_async(
        seq_group_metadata_list, scheduler_outputs = self.scheduler[
            virtual_engine].schedule()

        if (self.scheduler_config.is_multi_step
                and scheduler_outputs.num_lookahead_slots > 0):
        if self.scheduler_config.is_multi_step and \
                self._remaining_steps(seq_group_metadata_list) > 1:
            # cache the scheduler outputs for the next iteration if we have
            # lookahead slots
            # one.
            self._cache_scheduler_outputs_for_multi_step(
                virtual_engine, seq_group_metadata_list, scheduler_outputs)

@@ -361,14 +361,15 @@ async def step_async(

        return request_outputs

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
    def _remaining_steps(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> int:
        if not self.scheduler_config.is_multi_step:
            return False
            return 0

        if not seq_group_metadata_list:
            return False
            return 0

        # TODO(will) this is a sanity check for now to make sure that all the
        # seqs are on the same steps. Eventually we will want to do some sort of
@@ -381,7 +382,12 @@ def _has_remaining_steps(
            raise AssertionError(("All running sequence groups should "
                                  "have the same remaining steps."))

        return ref_remaining_steps > 0
        return ref_remaining_steps

    def _has_remaining_steps(
        self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]]
    ) -> bool:
        return self._remaining_steps(seq_group_metadata_list) > 0

    def _cache_scheduler_outputs_for_multi_step(
            self, virtual_engine: int,
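
The engine change replaces the boolean _has_remaining_steps with an integer-valued _remaining_steps and derives both the boolean check and the "should we cache scheduler outputs?" decision from it. A minimal sketch of that contract using stand-in types (the dataclasses below are illustrative, not vLLM's SequenceGroupMetadata):

# Stand-in sketch of the new _remaining_steps contract; the dataclasses below
# are illustrative, not vLLM's SequenceGroupMetadata.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class _State:
    remaining_steps: int


@dataclass
class _SeqGroupMeta:
    state: _State


def remaining_steps(is_multi_step: bool,
                    metas: Optional[List[_SeqGroupMeta]]) -> int:
    if not is_multi_step or not metas:
        return 0
    ref = metas[0].state.remaining_steps
    # Sanity check: all running sequence groups must be on the same step.
    assert all(m.state.remaining_steps == ref for m in metas)
    return ref


def has_remaining_steps(is_multi_step: bool,
                        metas: Optional[List[_SeqGroupMeta]]) -> bool:
    return remaining_steps(is_multi_step, metas) > 0


def should_cache_scheduler_outputs(is_multi_step: bool,
                                   metas: Optional[List[_SeqGroupMeta]]) -> bool:
    # step_async() above caches only when more than one step remains.
    return remaining_steps(is_multi_step, metas) > 1


metas = [_SeqGroupMeta(_State(3)), _SeqGroupMeta(_State(3))]
assert remaining_steps(True, metas) == 3
assert has_remaining_steps(True, metas)
assert should_cache_scheduler_outputs(True, metas)
assert remaining_steps(False, metas) == 0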
30 changes: 18 additions & 12 deletions vllm/worker/multi_step_model_runner.py
@@ -481,7 +481,10 @@ def _pythonize_sampler_output(
    # samples generation should have been skipped
    assert not output.outputs

    pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
    # Don't use num_queries, as some of the sequences may not need sampling
    # (e.g. chunked-prefill seqs).
    n_sampled_token_ids = sampled_token_ids.shape[0]
    pinned_buffer = pinned_sampled_token_buffer[:n_sampled_token_ids]

    # CPU GPU sync
    pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
@@ -491,20 +494,23 @@

    sampling_metadata = frozen_model_input.sampling_metadata

    for (seq_group, sample_result) in zip(sampling_metadata.seq_groups,
                                          samples_list):
        seq_ids = seq_group.seq_ids
        next_token_ids = sample_result
        parent_ids = [0]
    sample_result_it = iter(samples_list)
    for seq_group in sampling_metadata.seq_groups:
        seq_outputs: List[SequenceOutput] = []
        if seq_group.sampling_params.logits_processors:
            assert len(seq_group.sampling_params.logits_processors) == 0, (
                "Logits Processors are not supported in multi-step decoding")
        for parent_id, next_token_id in zip(parent_ids, next_token_ids):
            # TODO(will): support logprobs
            # Hard coded logprob
            seq_outputs.append(
                SequenceOutput(seq_ids[parent_id], next_token_id,
                               {next_token_id: Logprob(logprob=42)}))
        if seq_group.do_sample:
            sample_result = next(sample_result_it)
            seq_ids = seq_group.seq_ids
            next_token_ids = sample_result
            parent_ids = [0]
            for parent_id, next_token_id in zip(parent_ids, next_token_ids):
                # TODO(will): support logprobs
                # Hard coded logprob
                seq_outputs.append(
                    SequenceOutput(seq_ids[parent_id], next_token_id,
                                   {next_token_id: Logprob(logprob=42)}))
        output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None))

    assert len(output.outputs) > 0
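
The net effect of the _pythonize_sampler_output change is that only sequence groups with do_sample=True consume an entry from the sampler output, so chunked-prefill groups no longer shift sampled tokens onto the wrong group. A toy sketch of that pairing (Group is a hypothetical stand-in):

# Toy illustration of the do_sample-aware pairing added above; Group is a
# hypothetical stand-in, not one of vLLM's metadata classes.
from typing import List, NamedTuple, Tuple


class Group(NamedTuple):
    seq_ids: List[int]
    do_sample: bool


def pair_samples(groups: List[Group],
                 samples: List[List[int]]) -> List[Tuple[List[int], List[int]]]:
    sample_it = iter(samples)
    paired = []
    for g in groups:
        # Only sampling groups consume an entry from the sampler output.
        token_ids = next(sample_it) if g.do_sample else []
        paired.append((g.seq_ids, token_ids))
    return paired


# Two decode groups sample; the chunked-prefill group in the middle does not,
# so the second sample result still lands on the third group.
print(pair_samples([Group([0], True), Group([1], False), Group([2], True)],
                   [[11], [22]]))
# -> [([0], [11]), ([1], []), ([2], [22])]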
