From 6aae3406b26530b906829fb77eb151cda1de90ad Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 1 Aug 2024 16:13:49 -0700 Subject: [PATCH 1/3] [Performance] Optimize get_seqs --- vllm/core/block_manager_v1.py | 2 +- vllm/sequence.py | 38 ++++++++++++++------------ vllm/transformers_utils/detokenizer.py | 2 +- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index e29eba375f4d..d81648caa585 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -700,5 +700,5 @@ def get_common_computed_block_ids( def mark_blocks_as_computed(self, seq_group: SequenceGroup): if self.enable_caching: - for seq in seq_group.seqs_dict.values(): + for seq in seq_group.get_seqs(): self.compute_full_blocks_in_seq(seq) diff --git a/vllm/sequence.py b/vllm/sequence.py index ab50cfdfd29a..db60b594e729 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -444,6 +444,7 @@ def __init__( prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id + self.seqs = seqs self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.metrics = RequestMetrics(arrival_time=arrival_time, @@ -458,25 +459,24 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers - self._first_seq = next(iter(self.seqs_dict.values())) @property def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return self._first_seq.prompt + return self.seqs[0].prompt @property def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return self._first_seq.prompt_token_ids + return self.seqs[0].prompt_token_ids @property def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. # We use the multi-modal data of an arbitrary sequence. - return self._first_seq.multi_modal_data + return self.seqs[0].multi_modal_data @property def lora_int_id(self) -> int: @@ -512,7 +512,7 @@ def maybe_set_first_token_time(self, time: float) -> None: # in TPOT, rather than recalculating TTFT (since from the ) # POV of the user, there is simply a long generation delay. if (self.metrics.first_token_time is None - and self.get_seqs()[0].get_output_len() == 1): + and self.seqs[0].get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: @@ -548,9 +548,10 @@ def get_seqs( self, status: Optional[SequenceStatus] = None, ) -> List[Sequence]: - return list(self.seqs_dict.values()) if status is None else [ - seq for seq in self.seqs_dict.values() if seq.status == status - ] + if status is None: + return self.seqs + else: + return [seq for seq in self.seqs if seq.status == status] def is_encoder_decoder(self) -> bool: return self.encoder_seq is not None @@ -560,21 +561,21 @@ def get_encoder_seq(self) -> Optional[Sequence]: def get_unfinished_seqs(self) -> List[Sequence]: return [ - seq for seq in self.seqs_dict.values() if not seq.is_finished() + seq for seq in self.seqs if not seq.is_finished() ] def get_finished_seqs(self) -> List[Sequence]: - return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + return [seq for seq in self.seqs if seq.is_finished()] def update_num_computed_tokens(self, num_new_computed_tokens: int): """Update number of tokens computed so far.""" - for seq in self.seqs_dict.values(): + for seq in self.seqs: if not seq.is_finished(): seq.data.update_num_computed_tokens(num_new_computed_tokens) def get_num_uncomputed_tokens(self) -> int: num_uncomputed_tokens = 0 - for seq in self.get_seqs(): + for seq in self.seqs: if not seq.is_finished(): num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens() return num_uncomputed_tokens @@ -583,7 +584,7 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: # Optimization. We don't need to call get_seqs if we don't need to # filter by states. if status is None: - return len(self.seqs_dict) + return len(self.seqs) return len(self.get_seqs(status)) @@ -602,23 +603,26 @@ def add(self, seq: Sequence) -> None: if seq.seq_id in self.seqs_dict: raise ValueError(f"Sequence {seq.seq_id} already exists.") self.seqs_dict[seq.seq_id] = seq + self.seqs.append(seq) def remove(self, seq_id: int) -> None: - if seq_id not in self.seqs_dict: + seq = self.seqs_dict.get(seq_id, None) + if seq is None: raise ValueError(f"Sequence {seq_id} not found.") del self.seqs_dict[seq_id] + self.seqs.remove(seq) def is_finished(self) -> bool: - return all(seq.is_finished() for seq in self.get_seqs()) + return all(seq.is_finished() for seq in self.seqs) def is_prefill(self) -> bool: # Every sequence should be in the same stage. - return self.get_seqs()[0].is_prefill() + return self.seqs[0].is_prefill() def __repr__(self) -> str: return (f"SequenceGroup(request_id={self.request_id}, " f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs_dict)})") + f"num_seqs={len(self.seqs)})") class SequenceGroupMetadata: diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index 76f418674532..001af67f3bb9 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -40,7 +40,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, assert prms is not None # We can pick any sequence for the prompt. - seq = next(iter(seq_group.seqs_dict.values())) + seq = seq_group.get_seqs()[0] # Only prompt, without the generated token. all_token_ids = seq.get_token_ids() prompt_token_ids = all_token_ids[:-1] From 1f5b63d29e70eda3bbe2214db32d6ae97790cfaa Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 1 Aug 2024 16:16:04 -0700 Subject: [PATCH 2/3] yapf --- vllm/sequence.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index db60b594e729..b5cb6b6101f8 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -560,9 +560,7 @@ def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq def get_unfinished_seqs(self) -> List[Sequence]: - return [ - seq for seq in self.seqs if not seq.is_finished() - ] + return [seq for seq in self.seqs if not seq.is_finished()] def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs if seq.is_finished()] From 4d3d3b9d30abe44c8264c3a6edfe21a9cee29b53 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Thu, 1 Aug 2024 16:36:14 -0700 Subject: [PATCH 3/3] Address review --- vllm/sequence.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index b5cb6b6101f8..7ef9387c611f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -550,8 +550,7 @@ def get_seqs( ) -> List[Sequence]: if status is None: return self.seqs - else: - return [seq for seq in self.seqs if seq.status == status] + return [seq for seq in self.seqs if seq.status == status] def is_encoder_decoder(self) -> bool: return self.encoder_seq is not None @@ -604,10 +603,9 @@ def add(self, seq: Sequence) -> None: self.seqs.append(seq) def remove(self, seq_id: int) -> None: - seq = self.seqs_dict.get(seq_id, None) + seq = self.seqs_dict.pop(seq_id, None) if seq is None: raise ValueError(f"Sequence {seq_id} not found.") - del self.seqs_dict[seq_id] self.seqs.remove(seq) def is_finished(self) -> bool: