From 4d3c39a56e2aa12fad9fc381d379fa8242cd88b4 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 19 Sep 2024 18:30:05 +0000 Subject: [PATCH 1/6] Add streaming support to multistep --- tests/entrypoints/openai/test_accuracy.py | 15 ++- vllm/config.py | 2 + vllm/engine/arg_utils.py | 7 ++ vllm/engine/llm_engine.py | 42 +++++++-- vllm/engine/multiprocessing/client.py | 2 +- vllm/engine/multiprocessing/engine.py | 9 +- vllm/outputs.py | 108 +++++++++++++++++----- vllm/sequence.py | 25 +++-- 8 files changed, 169 insertions(+), 41 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 2ad8460023c2..6a0c1adabc2b 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,7 +19,20 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] +# MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] +MORE_ARGS_LIST = [[ + "--num-scheduler-steps", "8", "--multi-step-stream-outputs" +]] + +# @pytest.fixture(scope="module") +# def server(): +# args = [ +# "--max-model-len", "4096", "--enable-chunked-prefill", +# "--disable-log-requests", "--enforce-eager" +# ] + +# with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: +# yield remote_server @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) diff --git a/vllm/config.py b/vllm/config.py index fae2d44f174b..838334e59876 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -959,6 +959,7 @@ def __init__(self, is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, + multi_step_stream_outputs: bool = False, send_delta_data: bool = False) -> None: if max_num_batched_tokens is None: if enable_chunked_prefill: @@ -999,6 +1000,7 @@ def __init__(self, self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode self.num_scheduler_steps = num_scheduler_steps + self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data self._verify_args() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ca6034ddbe5c..cefe810d924d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -145,6 +145,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 + multi_step_stream_outputs: bool = False ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -595,6 +596,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Maximum number of forward steps per ' 'scheduler call.')) + # LoRA related configs + parser.add_argument( + '--multi-step-stream-outputs', + action='store_true', + help='If True, then multi-step will stream outputs for every step') parser.add_argument( '--scheduler-delay-factor', type=float, @@ -999,6 +1005,7 @@ def create_engine_config(self) -> EngineConfig: is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, + multi_step_stream_outputs=self.multi_step_stream_outputs, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), ) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 80dde804adda..b7f09be1cb14 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -95,7 +95,8 @@ 
class OutputData(NamedTuple): class SchedulerContext: - def __init__(self): + def __init__(self, + multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -103,6 +104,8 @@ def __init__(self): List[SequenceGroupMetadata]] = None self.scheduler_outputs: Optional[SchedulerOutputs] = None + self.multi_step_stream_outputs: bool = multi_step_stream_outputs + def append_output(self, outputs: List[SamplerOutput], seq_group_metadata_list: List[SequenceGroupMetadata], scheduler_outputs: SchedulerOutputs, is_async: bool, @@ -115,6 +118,10 @@ def append_output(self, outputs: List[SamplerOutput], is_last_step=is_last_step, skip=[])) + # TODO: Remove + # use_request_output_cache: bool = True + # request_output_cache: PyObjectCache = PyObjectCache(request_output_builder) + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -219,6 +226,7 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + use_cached_outputs: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -234,8 +242,9 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s, " - "use_async_output_proc=%s, mm_processor_kwargs=%s)", + "num_scheduler_steps=%d, multi_step_stream_outputs=%s, " + "enable_prefix_caching=%s, use_async_output_proc=%s, " + "use_cached_outputs=%s, mm_processor_kwargs=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -266,8 +275,10 @@ def __init__( model_config.served_model_name, scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, + scheduler_config.multi_step_stream_outputs, cache_config.enable_prefix_caching, model_config.use_async_output_proc, + use_cached_outputs, model_config.mm_processor_kwargs, ) # TODO(woosuk): Print more configs in debug mode. @@ -287,6 +298,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.use_cached_outputs = use_cached_outputs if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -379,7 +391,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ] self.scheduler_contexts = [ - SchedulerContext() + SchedulerContext(multi_step_stream_outputs=self.scheduler_config. 
+ multi_step_stream_outputs) for _ in range(self.parallel_config.pipeline_parallel_size) ] @@ -998,7 +1011,8 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) @@ -1019,8 +1033,8 @@ def _process_model_outputs(self, for scheduler in self.scheduler: scheduler.free_finished_seq_groups() - # For multi-step, do not create outputs each iteration - if not is_last_step: + # For multi-step without streaming, don't create outputs each iteration + if not is_last_step and not ctx.multi_step_stream_outputs: # Immediately process request outputs here (if callback is given) if (finished_now and self.process_request_outputs_callback is not None): @@ -1037,17 +1051,27 @@ def _process_model_outputs(self, seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) + # For multi-step with streaming, create outputs each iteration + if not is_last_step and ctx.multi_step_stream_outputs: + # Immediately process request outputs here (if callback is given) + if self.process_request_outputs_callback is not None: + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + for seq_group in scheduler_outputs.ignored_seq_groups: params = seq_group.sampling_params if params is not None and params.output_kind == ( RequestOutputKind.DELTA) and not seq_group.is_finished(): continue - request_output = RequestOutputFactory.create(seq_group) + request_output = RequestOutputFactory.create( + seq_group, use_cache=self.use_cached_outputs) if request_output: ctx.request_outputs.append(request_output) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 71099115ea12..e403fdaf95a1 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -488,7 +488,7 @@ async def _process_request( if isinstance(request_output, BaseException): raise request_output - + finished = request_output.finished yield request_output finally: diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 788c1573ae25..71bb132e9f65 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -66,7 +66,14 @@ def __init__(self, *args, log_requests: bool = True, **kwargs) -> None: - self.engine = LLMEngine(*args, **kwargs) + # For MQLLMEngine, we can use cached outputs, since each new request + # output is immediately pickled and send over the socket, which frees + # the python object to be reused again. 
+ use_cached_outputs = True + + self.engine = LLMEngine(*args, + **kwargs, + use_cached_outputs=use_cached_outputs) self.log_requests = log_requests self.use_async_sockets = use_async_sockets diff --git a/vllm/outputs.py b/vllm/outputs.py index 85ea9196b25d..996fa165ae0b 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -114,17 +114,28 @@ def __init__( self.encoder_prompt_token_ids = encoder_prompt_token_ids @classmethod - def from_seq_group(cls, - seq_group: SequenceGroup) -> Optional["RequestOutput"]: + def from_seq_group(cls, seq_group: SequenceGroup, + use_cache: bool) -> Optional["RequestOutput"]: sampling_params = seq_group.sampling_params if sampling_params is None: raise ValueError( "Sampling parameters are missing for a CompletionRequest.") + finished = seq_group.is_finished() if sampling_params.output_kind == RequestOutputKind.FINAL_ONLY and ( not finished): return None + # Init cache (if needed) + if use_cache and seq_group.cached_request_output is None: + seq_group.cached_request_output = RequestOutput( + request_id="", + prompt=None, + prompt_token_ids=[], + prompt_logprobs=None, + outputs=[], + finished=False) + seqs = seq_group.get_seqs() if len(seqs) == 1: top_n_seqs = seqs @@ -149,29 +160,65 @@ def from_seq_group(cls, outputs = [] include_prompt = True - for seq in top_n_seqs: + for i, seq in enumerate(top_n_seqs): output_text = seq.get_output_text_to_return( text_buffer_length, delta) + output_token_ids = seq.get_output_token_ids_to_return(delta) + num_output_tokens = 1 if isinstance(output_token_ids, + int) else len(output_token_ids) + output_logprobs = seq.output_logprobs if include_logprobs else None if delta: # Slice logprobs delta if applicable if output_logprobs: - output_logprobs = output_logprobs[-len(output_token_ids):] + output_logprobs = output_logprobs[-num_output_tokens:] # Don't include prompt if this is after the first output # containing decode token ids - if include_prompt and seq.get_output_len() > len( - output_token_ids): + if include_prompt and seq.get_output_len() > num_output_tokens: include_prompt = False - outputs.append( - CompletionOutput( + if use_cache: + # Get cached output object + cached_outputs = seq_group.cached_request_output.outputs + if i >= len(cached_outputs): + cached_outputs.append( + CompletionOutput(index=i, + text="", + token_ids=[], + cumulative_logprob=None, + logprobs=None, + finish_reason=None, + stop_reason=None)) + output = cached_outputs[i] + + # Init cached output object + output.index = i + output.text = output_text + + if isinstance(output_token_ids, int): + output.token_ids.clear() + output.token_ids.append(output_token_ids) + else: + output.token_ids = output_token_ids + + output.cumulative_logprob = seq.get_cumulative_logprob() \ + if include_logprobs else None + output.logprobs = output_logprobs + output.finish_reason = SequenceStatus.get_finished_reason( + seq.status) + output.stop_reason = seq.stop_reason + + else: + output = CompletionOutput( seqs.index(seq), output_text, output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), - seq.stop_reason)) + seq.stop_reason) + + outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
if include_prompt: @@ -188,16 +235,35 @@ def from_seq_group(cls, prompt_logprobs = None finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) - return cls(seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + + if use_cache: + request_output = seq_group.cached_request_output + request_output.__init__( # type: ignore + seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + seq_group.metrics, + lora_request=seq_group.lora_request, + encoder_prompt=encoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids) + + else: + request_output = cls( + seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + seq_group.metrics, + lora_request=seq_group.lora_request, + encoder_prompt=encoder_prompt, + encoder_prompt_token_ids=encoder_prompt_token_ids) + + return request_output def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -261,10 +327,10 @@ def __repr__(self): class RequestOutputFactory: @staticmethod - def create(seq_group): + def create(seq_group: SequenceGroup, use_cache: bool = False): # Determine the type based on a condition, for example: if hasattr(seq_group, 'embeddings') and seq_group.embeddings is not None: return EmbeddingRequestOutput.from_seq_group(seq_group) else: - return RequestOutput.from_seq_group(seq_group) + return RequestOutput.from_seq_group(seq_group, use_cache) diff --git a/vllm/sequence.py b/vllm/sequence.py index d8e54ff1fc70..1e12590a43a0 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -262,6 +262,8 @@ def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta def append_token_id(self, token_id: int, logprob: float) -> None: + self.last_appended_tokens.append(token_id) + self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) @@ -499,18 +501,23 @@ def get_output_text_to_return(self, buffer_length: int, return self.output_text[last_offset:length] return "" - def get_output_token_ids_to_return(self, - delta: bool) -> GenericSequence[int]: + def get_output_token_ids_to_return( + self, delta: bool) -> GenericSequence[int] | int: """If delta is True, only new tokens since the last call to this method are returned""" if not delta: return self.get_output_token_ids() - length = self.get_output_len() - last_offset = self._last_token_ids_offset - if last_offset < length: - self._last_token_ids_offset = length - return self.data._output_token_ids[last_offset:] - return () + + # Optimization for single decode token case + # (which is what we have most of the time) + if len(self.data.last_appended_tokens) == 1: + new_token = self.data.last_appended_tokens[0] + self.data.last_appended_tokens.clear() + return new_token + else: + new_tokens = self.data.last_appended_tokens + self.data.last_appended_tokens = [] + return new_tokens def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size @@ -671,6 +678,8 @@ def __init__( self.encoder_seq = encoder_seq self.trace_headers = trace_headers + self.cached_request_output = None + @property def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. 
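
The first patch hinges on two ideas: stream outputs on every scheduler step when
--multi-step-stream-outputs is set, and reuse a cached RequestOutput per sequence
group, which is safe in MQLLMEngine because each output is pickled and sent over
the socket immediately, leaving the Python object free to be refilled on the next
step. A minimal standalone sketch of that reuse pattern follows; StreamedOutput
and OutputCacheDemo are illustrative names only, not vLLM classes.

import pickle
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class StreamedOutput:
    # Illustrative stand-in for RequestOutput: per-step results for one request.
    request_id: str = ""
    token_ids: List[int] = field(default_factory=list)
    text: str = ""
    finished: bool = False


class OutputCacheDemo:
    """Refill one output object per request instead of allocating a new one
    each step; once the object has been serialized and handed to the socket,
    reusing it on the next step cannot affect anything downstream."""

    def __init__(self) -> None:
        self._cached: Optional[StreamedOutput] = None

    def emit(self, request_id: str, new_token: int, text: str,
             finished: bool) -> bytes:
        if self._cached is None:
            self._cached = StreamedOutput()
        out = self._cached
        # Refill the cached object in place for this step.
        out.request_id = request_id
        out.token_ids.clear()
        out.token_ids.append(new_token)
        out.text = text
        out.finished = finished
        # Serialize immediately, as the multiprocessing engine does before
        # sending the result to the client process.
        return pickle.dumps(out)


cache = OutputCacheDemo()
for step, tok in enumerate([11, 42, 7]):
    payload = cache.emit("req-0", tok, text=f"step {step}",
                         finished=(step == 2))
    print(f"step {step}: {len(payload)} bytes sent")
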
From bf6d04973475571f2eb04de3b1d5a33bc591778f Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 19 Sep 2024 18:52:09 +0000 Subject: [PATCH 2/6] format --- tests/entrypoints/openai/test_accuracy.py | 13 ++----------- vllm/engine/arg_utils.py | 1 - vllm/engine/llm_engine.py | 7 +------ vllm/engine/multiprocessing/client.py | 2 +- vllm/engine/multiprocessing/engine.py | 2 +- vllm/outputs.py | 7 ++++--- vllm/sequence.py | 2 +- 7 files changed, 10 insertions(+), 24 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 6a0c1adabc2b..6da5d4ba6382 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,21 +19,12 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -# MORE_ARGS_LIST = [["--enable-chunked-prefill"], ["--num-scheduler-steps", "8"]] +# MORE_ARGS_LIST = [["--enable-chunked +# -prefill"], ["--num-scheduler-steps", "8"]] MORE_ARGS_LIST = [[ "--num-scheduler-steps", "8", "--multi-step-stream-outputs" ]] -# @pytest.fixture(scope="module") -# def server(): -# args = [ -# "--max-model-len", "4096", "--enable-chunked-prefill", -# "--disable-log-requests", "--enforce-eager" -# ] - -# with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: -# yield remote_server - @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) def test_lm_eval_accuracy(more_args): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cefe810d924d..0d4559e37742 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -596,7 +596,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help=('Maximum number of forward steps per ' 'scheduler call.')) - # LoRA related configs parser.add_argument( '--multi-step-stream-outputs', action='store_true', diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b7f09be1cb14..1e77a01bfa9d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -95,8 +95,7 @@ class OutputData(NamedTuple): class SchedulerContext: - def __init__(self, - multi_step_stream_outputs: bool = False): + def __init__(self, multi_step_stream_outputs: bool = False): self.output_queue: Deque[OutputData] = deque() self.request_outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] @@ -118,10 +117,6 @@ def append_output(self, outputs: List[SamplerOutput], is_last_step=is_last_step, skip=[])) - # TODO: Remove - # use_request_output_cache: bool = True - # request_output_cache: PyObjectCache = PyObjectCache(request_output_builder) - class LLMEngine: """An LLM engine that receives requests and generates texts. 
diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index e403fdaf95a1..71099115ea12 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -488,7 +488,7 @@ async def _process_request( if isinstance(request_output, BaseException): raise request_output - + finished = request_output.finished yield request_output finally: diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 71bb132e9f65..3b0f617629d6 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -67,7 +67,7 @@ def __init__(self, log_requests: bool = True, **kwargs) -> None: # For MQLLMEngine, we can use cached outputs, since each new request - # output is immediately pickled and send over the socket, which frees + # output is immediately pickled and send over the socket, which frees # the python object to be reused again. use_cached_outputs = True diff --git a/vllm/outputs.py b/vllm/outputs.py index 996fa165ae0b..9f536bbc8900 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -128,7 +128,7 @@ def from_seq_group(cls, seq_group: SequenceGroup, # Init cache (if needed) if use_cache and seq_group.cached_request_output is None: - seq_group.cached_request_output = RequestOutput( + seq_group.cached_request_output = RequestOutput( # type: ignore request_id="", prompt=None, prompt_token_ids=[], @@ -181,7 +181,7 @@ def from_seq_group(cls, seq_group: SequenceGroup, if use_cache: # Get cached output object - cached_outputs = seq_group.cached_request_output.outputs + cached_outputs = seq_group.cached_request_output.outputs # type: ignore if i >= len(cached_outputs): cached_outputs.append( CompletionOutput(index=i, @@ -212,7 +212,8 @@ def from_seq_group(cls, seq_group: SequenceGroup, else: output = CompletionOutput( - seqs.index(seq), output_text, output_token_ids, + seqs.index(seq), output_text, [output_token_ids] + if isinstance(output_token_ids, int) else output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, output_logprobs, SequenceStatus.get_finished_reason(seq.status), diff --git a/vllm/sequence.py b/vllm/sequence.py index 1e12590a43a0..a48ad2321461 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -502,7 +502,7 @@ def get_output_text_to_return(self, buffer_length: int, return "" def get_output_token_ids_to_return( - self, delta: bool) -> GenericSequence[int] | int: + self, delta: bool) -> Union[GenericSequence[int], int]: """If delta is True, only new tokens since the last call to this method are returned""" if not delta: From a4aa1b93c9a5110ae587e15e7344ac5ae6f84437 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Thu, 19 Sep 2024 19:51:50 +0000 Subject: [PATCH 3/6] fix test --- tests/entrypoints/openai/test_accuracy.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/openai/test_accuracy.py b/tests/entrypoints/openai/test_accuracy.py index 6da5d4ba6382..63beaaba29a8 100644 --- a/tests/entrypoints/openai/test_accuracy.py +++ b/tests/entrypoints/openai/test_accuracy.py @@ -19,11 +19,11 @@ RTOL = 0.03 EXPECTED_VALUE = 0.58 DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"] -# MORE_ARGS_LIST = [["--enable-chunked -# -prefill"], ["--num-scheduler-steps", "8"]] -MORE_ARGS_LIST = [[ - "--num-scheduler-steps", "8", "--multi-step-stream-outputs" -]] +MORE_ARGS_LIST = [ + ["--enable-chunked-prefill"], # Chunked + ["--num-scheduler-steps", "8"], # MS + ["--num-scheduler-steps", "8", 
"--multi-step-stream-outputs"] # MS+Stream +] @pytest.mark.parametrize("more_args", MORE_ARGS_LIST) From 8e5ddf2ca8f1f2b370f11444c045437a44cbf42f Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Fri, 20 Sep 2024 20:00:59 +0000 Subject: [PATCH 4/6] Nick's comments --- vllm/outputs.py | 31 ++++++++----------------------- 1 file changed, 8 insertions(+), 23 deletions(-) diff --git a/vllm/outputs.py b/vllm/outputs.py index 9f536bbc8900..44cde6b561d8 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -194,7 +194,7 @@ def from_seq_group(cls, seq_group: SequenceGroup, output = cached_outputs[i] # Init cached output object - output.index = i + assert output.index == i output.text = output_text if isinstance(output_token_ids, int): @@ -237,32 +237,17 @@ def from_seq_group(cls, seq_group: SequenceGroup, finished_time = time.time() if finished else None seq_group.set_finished_time(finished_time) + init_args = (seq_group.request_id, prompt, prompt_token_ids, + prompt_logprobs, outputs, finished, seq_group.metrics, + seq_group.lora_request, encoder_prompt, + encoder_prompt_token_ids) + if use_cache: request_output = seq_group.cached_request_output - request_output.__init__( # type: ignore - seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + request_output.__init__(*init_args) # type: ignore else: - request_output = cls( - seq_group.request_id, - prompt, - prompt_token_ids, - prompt_logprobs, - outputs, - finished, - seq_group.metrics, - lora_request=seq_group.lora_request, - encoder_prompt=encoder_prompt, - encoder_prompt_token_ids=encoder_prompt_token_ids) + request_output = cls(*init_args) return request_output From aeeda6a1a217591c3220ac0b4d49ded484876682 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Sat, 21 Sep 2024 21:39:36 +0000 Subject: [PATCH 5/6] Refactor to Nick's suggestion to use _cached_all_token_ids --- vllm/sequence.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index a48ad2321461..329a11196b57 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -262,8 +262,6 @@ def mrope_position_delta(self, new_mrope_position_delta): self._mrope_position_delta = new_mrope_position_delta def append_token_id(self, token_id: int, logprob: float) -> None: - self.last_appended_tokens.append(token_id) - self._output_token_ids.append(token_id) self._new_appended_tokens.append(token_id) self._cached_all_token_ids.append(token_id) @@ -438,7 +436,7 @@ def __init__( self.stop_reason: Union[int, str, None] = None # These are used to keep track of delta outputs - self._last_token_ids_offset: int = 0 + self._last_output_token_ids_offset: int = 0 self._last_output_text_offset: int = 0 # Used for incremental detokenization @@ -508,16 +506,22 @@ def get_output_token_ids_to_return( if not delta: return self.get_output_token_ids() - # Optimization for single decode token case - # (which is what we have most of the time) - if len(self.data.last_appended_tokens) == 1: - new_token = self.data.last_appended_tokens[0] - self.data.last_appended_tokens.clear() - return new_token + prompt_len = self.get_prompt_len() + output_len = self.get_output_len() + + # Get the number of new tokens + output_last_offset = self._last_output_token_ids_offset + num_new_tokens = output_len - self._last_output_token_ids_offset + 
self._last_output_token_ids_offset = output_len + + # Return new tokens + if num_new_tokens == 1: + # Optimization for single decode token case + # (which is what we have most of the time) + return self.data._cached_all_token_ids[-1] else: - new_tokens = self.data.last_appended_tokens - self.data.last_appended_tokens = [] - return new_tokens + return self.data._cached_all_token_ids[prompt_len + + output_last_offset:] def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size From ece07c773d8720f1bee471d4d15ebc12a351052c Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Mon, 23 Sep 2024 17:24:05 +0000 Subject: [PATCH 6/6] Nick's fix --- vllm/sequence.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 329a11196b57..79e8a1f6244d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -506,11 +506,9 @@ def get_output_token_ids_to_return( if not delta: return self.get_output_token_ids() - prompt_len = self.get_prompt_len() output_len = self.get_output_len() # Get the number of new tokens - output_last_offset = self._last_output_token_ids_offset num_new_tokens = output_len - self._last_output_token_ids_offset self._last_output_token_ids_offset = output_len @@ -519,9 +517,8 @@ def get_output_token_ids_to_return( # Optimization for single decode token case # (which is what we have most of the time) return self.data._cached_all_token_ids[-1] - else: - return self.data._cached_all_token_ids[prompt_len + - output_last_offset:] + + return self.data._cached_all_token_ids[-num_new_tokens:] def hash_of_block(self, logical_idx: int) -> int: # TODO This can produce incorrect hash when block size > prompt size
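
Patches 5 and 6 replace the per-sequence last_appended_tokens list with a single
offset into the already-maintained _cached_all_token_ids, returning a bare int in
the common single-decode-token case and a slice of new tokens otherwise. The
sketch below shows that bookkeeping in isolation; DeltaTracker is an illustrative
name rather than vLLM's Sequence/SequenceData, and the empty-delta guard is an
extra robustness check that the patch itself does not need.

from typing import List, Sequence, Union


class DeltaTracker:
    """Keep one growing token list plus an offset of how many output tokens
    were already returned; hand back an int for a single new token."""

    def __init__(self, prompt_token_ids: List[int]) -> None:
        self._all_token_ids: List[int] = list(prompt_token_ids)
        self._prompt_len = len(prompt_token_ids)
        self._last_output_offset = 0  # output tokens already returned

    def append(self, token_id: int) -> None:
        self._all_token_ids.append(token_id)

    def get_output_token_ids_to_return(
            self, delta: bool) -> Union[Sequence[int], int]:
        output_len = len(self._all_token_ids) - self._prompt_len
        if not delta:
            # Non-delta callers always get the full output so far.
            return self._all_token_ids[self._prompt_len:]
        num_new = output_len - self._last_output_offset
        self._last_output_offset = output_len
        if num_new == 0:
            return []
        if num_new == 1:
            # Single decode token (the common case): skip building a list.
            return self._all_token_ids[-1]
        return self._all_token_ids[-num_new:]


tracker = DeltaTracker(prompt_token_ids=[1, 2, 3])
tracker.append(10)
print(tracker.get_output_token_ids_to_return(delta=True))  # 10
tracker.append(11)
tracker.append(12)
print(tracker.get_output_token_ids_to_return(delta=True))  # [11, 12]
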