From 5f0ec3935a0118fee8cf2764728f765c8cc53d2a Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Mon, 20 Jan 2025 21:54:16 +0800 Subject: [PATCH] [V1] Remove `_get_cache_block_size` (#12214) Signed-off-by: Chen Zhang --- vllm/v1/worker/gpu_worker.py | 24 +----------------------- 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 0929e64d58f1e..bd40112aea5e8 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -8,14 +8,13 @@ import torch.nn as nn import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.platforms import current_platform -from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, LayerBlockType, get_dtype_size from vllm.v1.core.scheduler import SchedulerOutput from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput @@ -235,24 +234,3 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): f"of at least 8.0. Your {gpu_name} GPU {compute_str}. " "You can use float16 instead by explicitly setting the" "`dtype` flag in CLI, for example: --dtype=half.") - - -def _get_cache_block_size( - cache_config: CacheConfig, - model_config: ModelConfig, - parallel_config: ParallelConfig, -) -> int: - head_size = model_config.get_head_size() - num_heads = model_config.get_num_kv_heads(parallel_config) - num_attention_layers = model_config.get_num_layers_by_block_type( - parallel_config, LayerBlockType.attention) - - key_cache_block = cache_config.block_size * num_heads * head_size - value_cache_block = key_cache_block - total = num_attention_layers * (key_cache_block + value_cache_block) - if cache_config.cache_dtype == "auto": - dtype = model_config.dtype - else: - dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - dtype_size = get_dtype_size(dtype) - return dtype_size * total