Add output streaming support to multi-step + async while ensuring RequestOutput obj reuse #8335

Merged
6 changes: 5 additions & 1 deletion tests/entrypoints/openai/test_accuracy.py
@@ -19,7 +19,11 @@
RTOL = 0.03
EXPECTED_VALUE = 0.58
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]]
MORE_ARGS_LIST = [
["--enable-chunked-prefill"], # Chunked
["--num-scheduler-steps", "8"], # MS
["--num-scheduler-steps", "8", "--multi-step-stream-outputs"] # MS+Stream
]


@pytest.mark.parametrize("more_args", MORE_ARGS_LIST)
2 changes: 2 additions & 0 deletions vllm/config.py
@@ -959,6 +959,7 @@ def __init__(self,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1,
multi_step_stream_outputs: bool = False,
send_delta_data: bool = False) -> None:
if max_num_batched_tokens is None:
if enable_chunked_prefill:
@@ -999,6 +1000,7 @@ def __init__(self,
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data
self._verify_args()

6 changes: 6 additions & 0 deletions vllm/engine/arg_utils.py
@@ -145,6 +145,7 @@ class EngineArgs:
max_cpu_loras: Optional[int] = None
device: str = 'auto'
num_scheduler_steps: int = 1
multi_step_stream_outputs: bool = False
ray_workers_use_nsight: bool = False
num_gpu_blocks_override: Optional[int] = None
num_lookahead_slots: int = 0
@@ -595,6 +596,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help=('Maximum number of forward steps per '
'scheduler call.'))

parser.add_argument(
'--multi-step-stream-outputs',
action='store_true',
help='If True, then multi-step will stream outputs for every step')
parser.add_argument(
'--scheduler-delay-factor',
type=float,
@@ -999,6 +1004,7 @@ def create_engine_config(self) -> EngineConfig:
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,
multi_step_stream_outputs=self.multi_step_stream_outputs,
send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER
and parallel_config.use_ray),
)
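The new flag is also a field on `EngineArgs` (and is plumbed into `SchedulerConfig` above), so it can be enabled programmatically as well as on the command line. A minimal sketch, assuming the standard top-level vLLM exports; the model name is only a placeholder, not part of this PR:

```python
# Sketch: multi-step scheduling with per-step output streaming enabled.
from vllm import AsyncEngineArgs, AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",        # placeholder model, not from this PR
    num_scheduler_steps=8,            # multi-step scheduling
    multi_step_stream_outputs=True,   # equivalent to --multi-step-stream-outputs
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```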
37 changes: 28 additions & 9 deletions vllm/engine/llm_engine.py
@@ -95,14 +95,16 @@ class OutputData(NamedTuple):

class SchedulerContext:

def __init__(self):
def __init__(self, multi_step_stream_outputs: bool = False):
self.output_queue: Deque[OutputData] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
self.seq_group_metadata_list: Optional[
List[SequenceGroupMetadata]] = None
self.scheduler_outputs: Optional[SchedulerOutputs] = None

self.multi_step_stream_outputs: bool = multi_step_stream_outputs

def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
@@ -219,6 +221,7 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
use_cached_outputs: bool = False,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@@ -234,8 +237,9 @@
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, observability_config=%r, "
"seed=%d, served_model_name=%s, use_v2_block_manager=%s, "
"num_scheduler_steps=%d, enable_prefix_caching=%s, "
"use_async_output_proc=%s, mm_processor_kwargs=%s)",
"num_scheduler_steps=%d, multi_step_stream_outputs=%s, "
"enable_prefix_caching=%s, use_async_output_proc=%s, "
"use_cached_outputs=%s, mm_processor_kwargs=%s)",
VLLM_VERSION,
model_config.model,
speculative_config,
@@ -266,8 +270,10 @@ def __init__(
model_config.served_model_name,
scheduler_config.use_v2_block_manager,
scheduler_config.num_scheduler_steps,
scheduler_config.multi_step_stream_outputs,
cache_config.enable_prefix_caching,
model_config.use_async_output_proc,
use_cached_outputs,
model_config.mm_processor_kwargs,
)
# TODO(woosuk): Print more configs in debug mode.
@@ -287,6 +293,7 @@ def __init__(
self.observability_config = observability_config or ObservabilityConfig(
)
self.log_stats = log_stats
self.use_cached_outputs = use_cached_outputs

if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@@ -379,7 +386,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
]

self.scheduler_contexts = [
SchedulerContext()
SchedulerContext(multi_step_stream_outputs=self.scheduler_config.
multi_step_stream_outputs)
for _ in range(self.parallel_config.pipeline_parallel_size)
]

@@ -998,7 +1006,8 @@ def _process_model_outputs(self,

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

@@ -1019,8 +1028,8 @@ def _process_model_outputs(self,
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# For multi-step, do not create outputs each iteration
if not is_last_step:
# For multi-step without streaming, don't create outputs each iteration
if not is_last_step and not ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if (finished_now
and self.process_request_outputs_callback is not None):
@@ -1037,17 +1046,27 @@

seq_group = scheduled_seq_group.seq_group
seq_group.maybe_set_first_token_time(now)
request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

# For multi-step with streaming, create outputs each iteration
if not is_last_step and ctx.multi_step_stream_outputs:
# Immediately process request outputs here (if callback is given)
if self.process_request_outputs_callback is not None:
self.process_request_outputs_callback(ctx.request_outputs)
ctx.request_outputs.clear()
return

for seq_group in scheduler_outputs.ignored_seq_groups:
params = seq_group.sampling_params
if params is not None and params.output_kind == (
RequestOutputKind.DELTA) and not seq_group.is_finished():
continue

request_output = RequestOutputFactory.create(seq_group)
request_output = RequestOutputFactory.create(
seq_group, use_cache=self.use_cached_outputs)
if request_output:
ctx.request_outputs.append(request_output)

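The two early-return branches above determine when request outputs are materialized during a multi-step batch. A condensed sketch of that decision (not the actual method; requests that finish mid-batch are emitted immediately in either mode):

```python
# Simplified sketch of the multi-step output policy in _process_model_outputs.
def builds_outputs_this_step(is_last_step: bool,
                             multi_step_stream_outputs: bool) -> bool:
    # Without streaming, outputs for still-running sequences are only built on
    # the final step of the batch; with streaming, they are built (and flushed
    # through the callback) on every step.
    return is_last_step or multi_step_stream_outputs
```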
9 changes: 8 additions & 1 deletion vllm/engine/multiprocessing/engine.py
@@ -66,7 +66,14 @@ def __init__(self,
*args,
log_requests: bool = True,
**kwargs) -> None:
self.engine = LLMEngine(*args, **kwargs)
# For MQLLMEngine, we can use cached outputs, since each new request
# output is immediately pickled and sent over the socket, which frees
# the Python object to be reused again.
use_cached_outputs = True

self.engine = LLMEngine(*args,
**kwargs,
use_cached_outputs=use_cached_outputs)
self.log_requests = log_requests

self.use_async_sockets = use_async_sockets
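The comment above leans on the fact that pickling takes an independent byte snapshot of the object, so the live Python instance can be mutated and reused for the next step. A toy illustration of that property (plain Python, not vLLM code):

```python
import pickle


class Output:
    def __init__(self, text: str = "") -> None:
        self.text = text


buf = Output()  # one long-lived object, reused across "steps"
frames = []
for step_text in ["Hello", "Hello,", "Hello, world"]:
    buf.text = step_text              # mutate in place instead of allocating
    frames.append(pickle.dumps(buf))  # the snapshot is a copy, so reuse is safe

print([pickle.loads(f).text for f in frames])
# ['Hello', 'Hello,', 'Hello, world']
```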
96 changes: 74 additions & 22 deletions vllm/outputs.py
@@ -114,17 +114,28 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids

@classmethod
def from_seq_group(cls,
seq_group: SequenceGroup) -> Optional["RequestOutput"]:
def from_seq_group(cls, seq_group: SequenceGroup,
use_cache: bool) -> Optional["RequestOutput"]:
sampling_params = seq_group.sampling_params
if sampling_params is None:
raise ValueError(
"Sampling parameters are missing for a CompletionRequest.")

finished = seq_group.is_finished()
if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and (
not finished):
return None

# Init cache (if needed)
if use_cache and seq_group.cached_request_output is None:
seq_group.cached_request_output = RequestOutput( # type: ignore
request_id="",
prompt=None,
prompt_token_ids=[],
prompt_logprobs=None,
outputs=[],
finished=False)
Comment on lines +131 to +137

Collaborator: nit: it would be nice if we had a canonical EmptyRequestOutput object to compare against

Author: I think it is not necessary, since we don't compare to EmptyRequestOutput in this code. It works like an initial placeholder here.

seqs = seq_group.get_seqs()
if len(seqs) == 1:
top_n_seqs = seqs
@@ -149,29 +160,66 @@ def from_seq_group(cls,

outputs = []
include_prompt = True
for seq in top_n_seqs:
for i, seq in enumerate(top_n_seqs):
output_text = seq.get_output_text_to_return(
text_buffer_length, delta)

output_token_ids = seq.get_output_token_ids_to_return(delta)
num_output_tokens = 1 if isinstance(output_token_ids,
int) else len(output_token_ids)

output_logprobs = seq.output_logprobs if include_logprobs else None

if delta:
# Slice logprobs delta if applicable
if output_logprobs:
output_logprobs = output_logprobs[-len(output_token_ids):]
output_logprobs = output_logprobs[-num_output_tokens:]
# Don't include prompt if this is after the first output
# containing decode token ids
if include_prompt and seq.get_output_len() > len(
output_token_ids):
if include_prompt and seq.get_output_len() > num_output_tokens:
include_prompt = False

outputs.append(
CompletionOutput(
seqs.index(seq), output_text, output_token_ids,
if use_cache:
# Get cached output object
cached_outputs = seq_group.cached_request_output.outputs # type: ignore
if i >= len(cached_outputs):
cached_outputs.append(
CompletionOutput(index=i,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None,
finish_reason=None,
stop_reason=None))
output = cached_outputs[i]

# Init cached output object
assert output.index == i
output.text = output_text

if isinstance(output_token_ids, int):
output.token_ids.clear()
output.token_ids.append(output_token_ids)
else:
output.token_ids = output_token_ids

output.cumulative_logprob = seq.get_cumulative_logprob() \
if include_logprobs else None
output.logprobs = output_logprobs
output.finish_reason = SequenceStatus.get_finished_reason(
seq.status)
output.stop_reason = seq.stop_reason

else:
output = CompletionOutput(
seqs.index(seq), output_text, [output_token_ids]
if isinstance(output_token_ids, int) else output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None,
output_logprobs,
SequenceStatus.get_finished_reason(seq.status),
seq.stop_reason))
seq.stop_reason)

outputs.append(output)

# Every sequence in the sequence group should have the same prompt.
if include_prompt:
@@ -188,16 +236,20 @@ def from_seq_group(cls,
prompt_logprobs = None
finished_time = time.time() if finished else None
seq_group.set_finished_time(finished_time)
return cls(seq_group.request_id,
prompt,
prompt_token_ids,
prompt_logprobs,
outputs,
finished,
seq_group.metrics,
lora_request=seq_group.lora_request,
encoder_prompt=encoder_prompt,
encoder_prompt_token_ids=encoder_prompt_token_ids)

init_args = (seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished, seq_group.metrics,
seq_group.lora_request, encoder_prompt,
encoder_prompt_token_ids)

if use_cache:
request_output = seq_group.cached_request_output
request_output.__init__(*init_args) # type: ignore

else:
request_output = cls(*init_args)

return request_output

def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
@@ -261,10 +313,10 @@ def __repr__(self):
class RequestOutputFactory:

@staticmethod
def create(seq_group):
def create(seq_group: SequenceGroup, use_cache: bool = False):
# Determine the type based on a condition, for example:
if hasattr(seq_group,
'embeddings') and seq_group.embeddings is not None:
return EmbeddingRequestOutput.from_seq_group(seq_group)
else:
return RequestOutput.from_seq_group(seq_group)
return RequestOutput.from_seq_group(seq_group, use_cache)
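The `use_cache` path above recycles one `RequestOutput` per sequence group by calling `__init__` on the cached instance instead of constructing a new object. A toy illustration of that in-place reinitialization pattern (generic Python, not the vLLM classes):

```python
class Box:
    def __init__(self, value: int = 0) -> None:
        self.value = value


cached = Box()            # allocated once, e.g. kept on the sequence group
original_id = id(cached)

for step in range(3):
    cached.__init__(value=step)       # reset the fields in place
    assert id(cached) == original_id  # still the very same object
    print(cached.value)               # 0, then 1, then 2
```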
28 changes: 19 additions & 9 deletions vllm/sequence.py
@@ -436,7 +436,7 @@ def __init__(
self.stop_reason: Union[int, str, None] = None

# These are used to keep track of delta outputs
self._last_token_ids_offset: int = 0
self._last_output_token_ids_offset: int = 0
self._last_output_text_offset: int = 0

# Used for incremental detokenization
@@ -499,18 +499,26 @@ def get_output_text_to_return(self, buffer_length: int,
return self.output_text[last_offset:length]
return ""

def get_output_token_ids_to_return(self,
delta: bool) -> GenericSequence[int]:
def get_output_token_ids_to_return(
self, delta: bool) -> Union[GenericSequence[int], int]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if not delta:
return self.get_output_token_ids()
length = self.get_output_len()
last_offset = self._last_token_ids_offset
if last_offset < length:
self._last_token_ids_offset = length
return self.data._output_token_ids[last_offset:]
return ()

output_len = self.get_output_len()

# Get the number of new tokens
num_new_tokens = output_len - self._last_output_token_ids_offset
self._last_output_token_ids_offset = output_len

# Return new tokens
if num_new_tokens == 1:
# Optimization for single decode token case
# (which is what we have most of the time)
return self.data._cached_all_token_ids[-1]

return self.data._cached_all_token_ids[-num_new_tokens:]

def hash_of_block(self, logical_idx: int) -> int:
# TODO This can produce incorrect hash when block size > prompt size
@@ -671,6 +679,8 @@ def __init__(
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers

self.cached_request_output = None

@property
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
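With this change, `get_output_token_ids_to_return(delta=True)` returns a bare `int` in the common single-new-token case and a token sequence otherwise, so callers (such as `from_seq_group` above) branch on the type. A small sketch of that consumer-side normalization, with hypothetical values:

```python
from typing import List, Sequence, Union


def as_token_list(output_token_ids: Union[Sequence[int], int]) -> List[int]:
    # Mirror the isinstance check used in from_seq_group: a bare int means
    # exactly one new decode token since the last call.
    if isinstance(output_token_ids, int):
        return [output_token_ids]
    return list(output_token_ids)


print(as_token_list(42))         # [42]
print(as_token_list([7, 8, 9]))  # [7, 8, 9]
```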