Skip to content

Commit

Permalink
[mypy] Enable following imports for some directories (vllm-project#6681)
Browse files Browse the repository at this point in the history
  • Loading branch information
DarkLight1337 authored and kylesayrs committed Aug 17, 2024
1 parent 354f34f commit 21098f5
Show file tree
Hide file tree
Showing 18 changed files with 185 additions and 143 deletions.
31 changes: 13 additions & 18 deletions .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,17 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy tests --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/inputs --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/platforms --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/spec_decode --follow-imports skip
mypy vllm/worker --follow-imports skip
mypy
30 changes: 13 additions & 17 deletions format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,19 @@ echo 'vLLM yapf: Done'

# Run mypy
echo 'vLLM mypy:'
mypy tests --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/logging --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
mypy vllm/model_executor --config-file pyproject.toml
mypy vllm/multimodal --config-file pyproject.toml
mypy vllm/prompt_adapter --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
mypy vllm/prompt_adapter --follow-imports skip
mypy vllm/spec_decode --follow-imports skip
mypy vllm/worker --follow-imports skip
mypy


# If git diff returns a file that is in the skip list, the file may be checked anyway:
Expand Down
18 changes: 16 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,23 @@ python_version = "3.8"

ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "skip"
follow_imports = "silent"

files = "vllm"
# After fixing type errors resulting from follow_imports: "skip" -> "silent",
# move the directory here and remove it from format.sh and mypy.yaml
files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
"vllm/platforms",
"vllm/server",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
]
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
Expand Down
2 changes: 1 addition & 1 deletion vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: Type[torch.dtype],
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
Expand Down
62 changes: 49 additions & 13 deletions vllm/_ipex_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,33 @@ def _reshape_activation_tensor(
x2 = x2.reshape(num, d)
return x1, x2

@staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.silu_mul(x1, x2, out)

@staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "none")

@staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")

@staticmethod
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))

@staticmethod
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None:
out.copy_(torch.nn.functional.gelu(x))

# TODO add implementation of gelu_quick here
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:

@staticmethod
def paged_attention_v1(
out: torch.Tensor,
query: torch.Tensor,
Expand Down Expand Up @@ -78,12 +84,21 @@ def paged_attention_v1(
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v1(out, query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, scale,
block_tables, context_lens, block_size,
max_context_len, alibi_slopes)
torch.xpu.paged_attention_v1( # type: ignore
out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
)

@staticmethod
def paged_attention_v2(
out: torch.Tensor,
exp_sum: torch.Tensor,
Expand Down Expand Up @@ -119,13 +134,24 @@ def paged_attention_v2(
).view(num_kv_heads,
1).repeat_interleave(num_queries_per_tokens).flatten()
# todo: ipex will refactor namespace
torch.xpu.paged_attention_v2(out, exp_sum, max_logits, tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache, head_mapping, block_tables,
context_lens, scale, block_size,
max_context_len, alibi_slopes)
torch.xpu.paged_attention_v2( # type: ignore
out,
exp_sum,
max_logits,
tmp_out,
query.contiguous(),
key_cache.view_as(value_cache),
value_cache,
head_mapping,
block_tables,
context_lens,
scale,
block_size,
max_context_len,
alibi_slopes,
)

@staticmethod
def rotary_embedding(
positions: torch.Tensor, # [batch_size, seq_len]
query: torch.Tensor, # [batch_size, seq_len, num_heads*head_size]
Expand Down Expand Up @@ -158,6 +184,7 @@ def rotary_embedding(
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)

@staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
Expand Down Expand Up @@ -189,17 +216,20 @@ def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)

@staticmethod
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp)

@staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
weight: torch.Tensor, epsilon: float) -> None:
tmp = ipex.llm.functional.add_rms_norm(residual, input, weight, None,
epsilon, True)
input.copy_(tmp)

@staticmethod
def varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
Expand All @@ -222,6 +252,7 @@ def varlen_attention(
softmax_scale, zero_tensors,
is_causal, return_softmax, gen_)

@staticmethod
def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
Expand All @@ -240,8 +271,13 @@ def reshape_and_cache(
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:
torch.xpu.copy_blocks(key_caches, value_caches, block_mapping)
torch.xpu.copy_blocks( # type: ignore
key_caches,
value_caches,
block_mapping,
)

@staticmethod
def swap_blocks(src: torch.Tensor, dst: torch.Tensor,
block_mapping: torch.Tensor) -> None:
torch.xpu.swap_blocks(src, dst, block_mapping)
torch.xpu.swap_blocks(src, dst, block_mapping) # type: ignore
30 changes: 15 additions & 15 deletions vllm/adapter_commons/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
super().__init__(capacity)
self.deactivate_fn = deactivate_fn

def _on_remove(self, key: Hashable, value: T):
def _on_remove(self, key: Hashable, value: Optional[T]):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
Expand Down Expand Up @@ -59,46 +59,46 @@ def __len__(self) -> int:

@property
@abstractmethod
def adapter_slots(self):
...
def adapter_slots(self) -> int:
raise NotImplementedError

@property
@abstractmethod
def capacity(self):
...
def capacity(self) -> int:
raise NotImplementedError

@abstractmethod
def activate_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError

@abstractmethod
def deactivate_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError

@abstractmethod
def add_adapter(self, adapter: Any) -> bool:
...
raise NotImplementedError

@abstractmethod
def set_adapter_mapping(self, mapping: Any) -> None:
...
raise NotImplementedError

@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError

@abstractmethod
def remove_all_adapters(self):
...
def remove_all_adapters(self) -> None:
raise NotImplementedError

@abstractmethod
def get_adapter(self, adapter_id: int) -> Optional[Any]:
...
raise NotImplementedError

@abstractmethod
def list_adapters(self) -> Dict[int, Any]:
...
raise NotImplementedError

@abstractmethod
def pin_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError
10 changes: 5 additions & 5 deletions vllm/adapter_commons/request.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest:
class AdapterRequest(ABC):
"""
Base class for adapter requests.
"""

@property
@abstractmethod
def adapter_id(self):
...
def adapter_id(self) -> int:
raise NotImplementedError

def __post_init__(self):
def __post_init__(self) -> None:
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")

Expand Down
14 changes: 7 additions & 7 deletions vllm/adapter_commons/worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,25 @@ def __init__(self, device: torch.device):
@property
@abstractmethod
def is_enabled(self) -> bool:
...
raise NotImplementedError

@abstractmethod
def set_active_adapters(self, requests: Set[Any],
mapping: Optional[Any]) -> None:
...
raise NotImplementedError

@abstractmethod
def add_adapter(self, adapter_request: Any) -> bool:
...
raise NotImplementedError

@abstractmethod
def remove_adapter(self, adapter_id: int) -> bool:
...
raise NotImplementedError

@abstractmethod
def remove_all_adapters(self):
...
def remove_all_adapters(self) -> None:
raise NotImplementedError

@abstractmethod
def list_adapters(self) -> Set[int]:
...
raise NotImplementedError
3 changes: 2 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,7 @@ def __init__(
backend)

self._verify_args()
self.rank = 0
self.rank: int = 0

@property
def use_ray(self) -> bool:
Expand Down Expand Up @@ -850,6 +850,7 @@ def _verify_args(self) -> None:


class DeviceConfig:
device: Optional[torch.device]

def __init__(self, device: str = "auto") -> None:
if device == "auto":
Expand Down
Loading

0 comments on commit 21098f5

Please sign in to comment.