[mypy] Enable mypy type checking for vllm/core #7229

Merged 2 commits on Aug 27, 2024
Changes from 1 commit
1 change: 0 additions & 1 deletion .github/workflows/mypy.yaml
@@ -35,7 +35,6 @@ jobs:
         mypy
         mypy tests --follow-imports skip
         mypy vllm/attention --follow-imports skip
-        mypy vllm/core --follow-imports skip
         mypy vllm/distributed --follow-imports skip
         mypy vllm/engine --follow-imports skip
         mypy vllm/executor --follow-imports skip
1 change: 0 additions & 1 deletion format.sh
@@ -99,7 +99,6 @@ echo 'vLLM mypy:'
 mypy --follow-imports skip # Note that this is less strict than CI
 mypy tests --follow-imports skip
 mypy vllm/attention --follow-imports skip
-mypy vllm/core --follow-imports skip
 mypy vllm/distributed --follow-imports skip
 mypy vllm/engine --follow-imports skip
 mypy vllm/executor --follow-imports skip
1 change: 1 addition & 0 deletions pyproject.toml
@@ -58,6 +58,7 @@ files = [
     "vllm/adapter_commons",
     "vllm/assets",
     "vllm/entrypoints",
+    "vllm/core",
     "vllm/inputs",
     "vllm/logging",
     "vllm/multimodal",
8 changes: 6 additions & 2 deletions vllm/block.py
@@ -1,9 +1,9 @@
 """Token blocks."""
-from typing import List, Optional
+from typing import TYPE_CHECKING, Iterator, List, Optional
 
 from vllm.utils import Device
 
-DEFAULT_LAST_ACCESSED_TIME = -1
+DEFAULT_LAST_ACCESSED_TIME: float = -1
 
 
 class PhysicalTokenBlock:
@@ -59,6 +59,10 @@ def __len__(self) -> int:
     def __getitem__(self, key):
         return self._blocks[key]
 
+    if TYPE_CHECKING:
+
+        def __iter__(self) -> Iterator[PhysicalTokenBlock]:
+            raise RuntimeError("Method should be automatically generated")
+
     def __setitem__(self, key, value):
         if isinstance(key, slice):
             blocks = value
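
A note on the TYPE_CHECKING hunk above: at runtime Python can iterate any object that defines __getitem__ (the legacy iteration protocol), but mypy only treats a class as iterable if it declares __iter__. Declaring the method under TYPE_CHECKING gives mypy the element type without shipping a real implementation. A self-contained sketch of the pattern; the class and names here are illustrative, not part of vLLM:

    from typing import TYPE_CHECKING, Iterator, List


    class IntBlocks:
        """Iterable at runtime via __getitem__ alone."""

        def __init__(self, items: List[int]) -> None:
            self._items = items

        def __len__(self) -> int:
            return len(self._items)

        def __getitem__(self, key: int) -> int:
            return self._items[key]  # raises IndexError past the end

        if TYPE_CHECKING:
            # Never executed; only tells mypy that iteration yields ints.
            def __iter__(self) -> Iterator[int]:
                raise RuntimeError("Synthesized by Python at runtime")


    # Without the stub, mypy rejects this loop ("IntBlocks" not iterable);
    # at runtime it works either way.
    print(sum(x for x in IntBlocks([1, 2, 3])))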
2 changes: 1 addition & 1 deletion vllm/core/block/cpu_gpu_block_allocator.py
@@ -132,7 +132,7 @@ def allocate_mutable_block(self, prev_block: Optional[Block],
 
     def allocate_immutable_blocks(self, prev_block: Optional[Block],
                                   block_token_ids: List[List[int]],
-                                  device: Optional[Device]) -> List[Block]:
+                                  device: Device) -> List[Block]:
         """Allocates a new group of immutable blocks with the provided block
         token IDs on the specified device.
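
Tightening Optional[Device] to Device is the opposite move from the assert-based fixes in the files below: instead of guarding against None inside the function, the signature now promises a real device, and mypy pushes the None handling out to call sites. A tiny illustration under assumed names; the enum is a stand-in for vllm.utils.Device:

    from enum import Enum
    from typing import Optional


    class Device(Enum):  # stand-in for vllm.utils.Device
        CPU = 0
        GPU = 1


    def allocate_on(device: Device) -> str:
        # No None branch needed; the signature rules it out.
        return f"allocated on {device.name}"


    maybe: Optional[Device] = None
    # allocate_on(maybe)  # mypy error: incompatible type "Optional[Device]"
    print(allocate_on(maybe if maybe is not None else Device.GPU))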
7 changes: 4 additions & 3 deletions vllm/core/block_manager_v1.py
@@ -278,7 +278,7 @@ def __init__(
         # request ID
         self.cross_block_tables: Dict[str, BlockTable] = {}
 
-    def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
+    def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
         return 0 if seq is None else seq.n_blocks
 
     def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
@@ -310,13 +310,14 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
             return AllocStatus.LATER
 
     def _allocate_sequence(self, \
-                           seq: Sequence, \
+                           seq: Optional[Sequence], \
                            ref_count: int, \
                            is_encoder_decoder: bool = True) -> BlockTable:
         # Allocate new physical token blocks that will store the prompt tokens.
-        num_prompt_blocks = seq.n_blocks
+        num_prompt_blocks = self._get_seq_num_required_blocks(seq)
 
         block_table: BlockTable = BlockTable()
+        assert seq is not None
Review thread on the added assert:

Member:

I don't think this is required? num_prompt_blocks = 0 if seq is None anyway.

@jberkhahn (Contributor, Author), Aug 27, 2024:

The issue is that seq is Optional and is dereferenced directly on lines 328-330:

    block = self.gpu_allocator.allocate(
        seq.hash_of_block(logical_idx),
        seq.num_hashed_tokens_of_block(logical_idx))

It's not technically possible to get there if seq is None, but mypy still complains because it doesn't take the check inside _get_seq_num_required_blocks(seq) into account. Do you have a more straightforward suggestion?

Member:

I see, thanks for the explanation.

         for logical_idx in range(num_prompt_blocks):
             if (self.block_sliding_window is not None
                     and logical_idx >= self.block_sliding_window):
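
The thread above captures the underlying mypy behavior: narrowing is local to a function body, so a None check buried inside a helper like _get_seq_num_required_blocks never narrows the caller's variable, and an explicit assert is the conventional escape hatch. A minimal sketch of the same shape, with illustrative names rather than vLLM's:

    from typing import Optional


    def required_blocks(seq: Optional[str]) -> int:
        # The None check lives here, invisible to the caller's narrowing.
        return 0 if seq is None else len(seq)


    def allocate(seq: Optional[str]) -> None:
        n = required_blocks(seq)
        # Without the assert, mypy reports on the loop body:
        #   Item "None" of "Optional[str]" has no attribute "upper"
        # even though n == 0 (an empty loop) whenever seq is None. The
        # assert states the caller-side invariant in a form mypy accepts.
        assert seq is not None
        for _ in range(n):
            print(seq.upper())


    allocate("tokens")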
8 changes: 6 additions & 2 deletions vllm/core/block_manager_v2.py
@@ -120,8 +120,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
         )
 
         if seq_group.is_encoder_decoder():
+            encoder_seq = seq_group.get_encoder_seq()
+            assert encoder_seq is not None
             num_required_blocks += BlockTable.get_num_required_blocks(
-                seq_group.get_encoder_seq().get_token_ids(),
+                encoder_seq.get_token_ids(),
                 block_size=self.block_size,
             )
 
@@ -189,7 +191,9 @@ def allocate(self, seq_group: SequenceGroup) -> None:
         check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)
 
         if seq_group.is_encoder_decoder():
-            block_table = self._allocate_sequence(seq_group.get_encoder_seq())
+            encoder_seq = seq_group.get_encoder_seq()
+            assert encoder_seq is not None
+            block_table = self._allocate_sequence(encoder_seq)
             self.cross_block_tables[request_id] = block_table
 
     def can_append_slots(self, seq_group: SequenceGroup,
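
Both hunks above follow one recipe for Optional-returning accessors: call once, bind to a local, assert, then use. Chaining seq_group.get_encoder_seq().get_token_ids() fails because the intermediate value is Optional, and asserting on one call would not help a second call anyway, since mypy does not assume a method returns the same value each time. A hedged sketch with stand-in classes, not vLLM's:

    from typing import List, Optional


    class EncoderSeq:
        def get_token_ids(self) -> List[int]:
            return [101, 7592, 102]


    class SeqGroup:
        def __init__(self, encoder: Optional[EncoderSeq]) -> None:
            self._encoder = encoder

        def get_encoder_seq(self) -> Optional[EncoderSeq]:
            return self._encoder


    def token_count(group: SeqGroup) -> int:
        # group.get_encoder_seq().get_token_ids() -> mypy error:
        #   Item "None" of "Optional[EncoderSeq]" has no attribute ...
        encoder_seq = group.get_encoder_seq()
        assert encoder_seq is not None  # narrowed to EncoderSeq below
        return len(encoder_seq.get_token_ids())


    print(token_count(SeqGroup(EncoderSeq())))  # 3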
4 changes: 2 additions & 2 deletions vllm/core/embedding_model_block_manager.py
@@ -77,8 +77,8 @@ def access_all_blocks_in_seq(
         pass
 
     def get_common_computed_block_ids(self,
-                                      seq_group: SequenceGroup) -> List[int]:
-        return None  # type: ignore
+                                      seq_group: List[Sequence]) -> List[int]:
+        return []
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
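
The return [] change trades a suppressed error for an honest neutral value: a method annotated -> List[int] that actually returns None needs a # type: ignore and can surprise callers at runtime. Returning the empty list satisfies the signature and makes iteration a safe no-op. A trivial before/after sketch with hypothetical function names:

    from typing import List


    def block_ids_before() -> List[int]:
        return None  # type: ignore  # signature lies; callers may crash


    def block_ids_after() -> List[int]:
        return []  # honest no-op value for a manager with no blocks


    for block_id in block_ids_after():
        print(block_id)  # never runs, and never raises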
15 changes: 9 additions & 6 deletions vllm/core/scheduler.py
@@ -220,10 +220,10 @@ class SchedulerSwappedInOutputs:
     """
     # Selected sequences that are going to be swapped in and is in a
     # decoding phase.
-    decode_seq_groups: List[SequenceGroup]
+    decode_seq_groups: List[ScheduledSequenceGroup]
     # Selected sequences that are going to be swapped in and in a prefill
     # phase. I.e., it means the prefill has been chunked.
-    prefill_seq_groups: List[SequenceGroup]
+    prefill_seq_groups: List[ScheduledSequenceGroup]
     # The blocks to swap in.
     blocks_to_swap_in: List[Tuple[int, int]]
     # The blocks to copy.
@@ -253,7 +253,7 @@ class SchedulerPrefillOutputs:
     to be recomputed from scratch.
     """
     # Selected sequences for prefill.
-    seq_groups: List[SequenceGroup]
+    seq_groups: List[ScheduledSequenceGroup]
     # Ignored sequence groups.
     ignored_seq_groups: List[SequenceGroup]
     num_lookahead_slots: int
@@ -288,7 +288,8 @@ def scheduler_running_outputs_builder():
 
 
 def scheduled_seq_group_builder():
-    return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
+    return ScheduledSequenceGroup(SequenceGroup("", [], -1), token_chunk_size=0)
+    # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
 
 
 class Scheduler:
@@ -737,7 +738,7 @@ def _schedule_prefills(
             SchedulerPrefillOutputs.
         """
         ignored_seq_groups: List[SequenceGroup] = []
-        seq_groups: List[SequenceGroup] = []
+        seq_groups: List[ScheduledSequenceGroup] = []
 
         waiting_queue = self.waiting
 
@@ -1057,7 +1058,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
 
             if seq_group.is_encoder_decoder():
                 # Encoder associated with SequenceGroup
-                encoder_seq_data = seq_group.get_encoder_seq().data
+                encoder_seq = seq_group.get_encoder_seq()
+                assert encoder_seq is not None
+                encoder_seq_data = encoder_seq.data
                 # Block table for cross-attention
                 # Also managed at SequenceGroup level
                 cross_block_table = self.block_manager.get_cross_block_table(