
[Performance] Optimize get_seqs #7051

Merged · 3 commits · Aug 2, 2024
2 changes: 1 addition & 1 deletion vllm/core/block_manager_v1.py
@@ -700,5 +700,5 @@ def get_common_computed_block_ids(

     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         if self.enable_caching:
-            for seq in seq_group.seqs_dict.values():
+            for seq in seq_group.get_seqs():
                 self.compute_full_blocks_in_seq(seq)
40 changes: 21 additions & 19 deletions vllm/sequence.py
@@ -444,6 +444,7 @@ def __init__(
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
+        self.seqs = seqs
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
         self.sampling_params = sampling_params
         self.metrics = RequestMetrics(arrival_time=arrival_time,
@@ -458,25 +459,24 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
-        self._first_seq = next(iter(self.seqs_dict.values()))
Member:

I think you can still keep `self._first_seq = seqs[0]` and use it to replace `self.seqs[0]`.

Collaborator (Author):

I don't think it hurts much to use `seqs[0]` without caching it. `_first_seq` was introduced to avoid the overhead of retrieving a value from the dictionary, and the overhead of `seqs[0]` should be negligible even in Python.

Also, since a sequence can be removed, I feel more comfortable with `self.seqs[0]` than with caching the sequence.
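
As a quick, illustrative check of this claim (not part of the PR), a micro-benchmark comparing the two access patterns can be sketched as follows; the container names are hypothetical stand-ins and absolute numbers will vary by machine:

```python
import timeit

# Hypothetical stand-ins for a SequenceGroup's internal containers.
seqs = [object() for _ in range(4)]
seqs_dict = {i: seq for i, seq in enumerate(seqs)}

# Old pattern: pull the first value out of the dict via an iterator.
t_dict = timeit.timeit(lambda: next(iter(seqs_dict.values())), number=1_000_000)
# New pattern: plain list indexing.
t_list = timeit.timeit(lambda: seqs[0], number=1_000_000)

print(f"next(iter(dict.values())): {t_dict:.3f}s")
print(f"list[0]:                   {t_list:.3f}s")
```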


     @property
     def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return self._first_seq.prompt
+        return self.seqs[0].prompt

     @property
     def prompt_token_ids(self) -> List[int]:
         # All sequences in the group should have the same prompt.
         # We use the prompt of an arbitrary sequence.
-        return self._first_seq.prompt_token_ids
+        return self.seqs[0].prompt_token_ids

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
         # We use the multi-modal data of an arbitrary sequence.
-        return self._first_seq.multi_modal_data
+        return self.seqs[0].multi_modal_data

@property
def lora_int_id(self) -> int:
@@ -512,7 +512,7 @@ def maybe_set_first_token_time(self, time: float) -> None:
         # in TPOT, rather than recalculating TTFT (since from the )
         # POV of the user, there is simply a long generation delay.
         if (self.metrics.first_token_time is None
-                and self.get_seqs()[0].get_output_len() == 1):
+                and self.seqs[0].get_output_len() == 1):
             self.metrics.first_token_time = time

     def maybe_set_first_scheduled_time(self, time: float) -> None:
@@ -548,9 +548,10 @@ def get_seqs(
         self,
         status: Optional[SequenceStatus] = None,
     ) -> List[Sequence]:
-        return list(self.seqs_dict.values()) if status is None else [
-            seq for seq in self.seqs_dict.values() if seq.status == status
-        ]
+        if status is None:
+            return self.seqs
+        else:
+            return [seq for seq in self.seqs if seq.status == status]

     def is_encoder_decoder(self) -> bool:
         return self.encoder_seq is not None
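
Note that returning `self.seqs` directly hands callers a reference to the group's internal list rather than a per-call copy, which is where the allocation saving comes from; callers must treat the result as read-only. A minimal sketch of the difference (illustrative, not from the PR):

```python
# Stand-in for SequenceGroup.seqs, the group's internal list.
internal = ["seq0", "seq1"]

copied = list(internal)   # old get_seqs(): fresh list on every call
aliased = internal        # new get_seqs(): shared reference, no allocation

copied.append("x")        # does not affect the group
aliased.append("y")       # mutates the group's internal state!
assert internal == ["seq0", "seq1", "y"]
```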
@@ -559,22 +560,20 @@ def get_encoder_seq(self) -> Optional[Sequence]:
         return self.encoder_seq

     def get_unfinished_seqs(self) -> List[Sequence]:
-        return [
-            seq for seq in self.seqs_dict.values() if not seq.is_finished()
-        ]
+        return [seq for seq in self.seqs if not seq.is_finished()]

     def get_finished_seqs(self) -> List[Sequence]:
-        return [seq for seq in self.seqs_dict.values() if seq.is_finished()]
+        return [seq for seq in self.seqs if seq.is_finished()]

     def update_num_computed_tokens(self, num_new_computed_tokens: int):
         """Update number of tokens computed so far."""
-        for seq in self.seqs_dict.values():
+        for seq in self.seqs:
             if not seq.is_finished():
                 seq.data.update_num_computed_tokens(num_new_computed_tokens)

     def get_num_uncomputed_tokens(self) -> int:
         num_uncomputed_tokens = 0
-        for seq in self.get_seqs():
+        for seq in self.seqs:
             if not seq.is_finished():
                 num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
         return num_uncomputed_tokens
@@ -583,7 +582,7 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
         # Optimization. We don't need to call get_seqs if we don't need to
         # filter by states.
         if status is None:
-            return len(self.seqs_dict)
+            return len(self.seqs)

         return len(self.get_seqs(status))

@@ -602,23 +601,26 @@ def add(self, seq: Sequence) -> None:
         if seq.seq_id in self.seqs_dict:
             raise ValueError(f"Sequence {seq.seq_id} already exists.")
         self.seqs_dict[seq.seq_id] = seq
+        self.seqs.append(seq)

     def remove(self, seq_id: int) -> None:
-        if seq_id not in self.seqs_dict:
+        seq = self.seqs_dict.get(seq_id, None)
+        if seq is None:
             raise ValueError(f"Sequence {seq_id} not found.")
         del self.seqs_dict[seq_id]
+        self.seqs.remove(seq)

     def is_finished(self) -> bool:
-        return all(seq.is_finished() for seq in self.get_seqs())
+        return all(seq.is_finished() for seq in self.seqs)

     def is_prefill(self) -> bool:
         # Every sequence should be in the same stage.
-        return self.get_seqs()[0].is_prefill()
+        return self.seqs[0].is_prefill()

     def __repr__(self) -> str:
         return (f"SequenceGroup(request_id={self.request_id}, "
                 f"sampling_params={self.sampling_params}, "
-                f"num_seqs={len(self.seqs_dict)})")
+                f"num_seqs={len(self.seqs)})")


class SequenceGroupMetadata:
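
With `add()` and `remove()` now maintaining both containers, `self.seqs` and `self.seqs_dict` must always hold exactly the same Sequence objects; `list.remove` is O(n), which is acceptable since a group holds only a handful of sequences. A hypothetical test helper for this invariant (not part of the PR):

```python
def check_seqs_in_sync(seq_group) -> None:
    # Both views must agree in size and in object identity.
    assert len(seq_group.seqs) == len(seq_group.seqs_dict)
    for seq in seq_group.seqs:
        assert seq_group.seqs_dict[seq.seq_id] is seq
```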
2 changes: 1 addition & 1 deletion vllm/transformers_utils/detokenizer.py
@@ -40,7 +40,7 @@ def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
         assert prms is not None

         # We can pick any sequence for the prompt.
-        seq = next(iter(seq_group.seqs_dict.values()))
+        seq = seq_group.get_seqs()[0]
         # Only prompt, without the generated token.
         all_token_ids = seq.get_token_ids()
         prompt_token_ids = all_token_ids[:-1]