From fb14d53cf6e68866cc13e348db18f361da8fb7ec Mon Sep 17 00:00:00 2001
From: Ning Xie
Date: Thu, 3 Jul 2025 16:39:14 +0800
Subject: [PATCH] [Kernel] refactor cpu worker v0 cache dtype (#20080)

Signed-off-by: Andy Xie
---
 vllm/worker/cpu_worker.py | 38 ++++++++++++++++++++------------------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py
index ff110e050bb6f..a8998127b60f3 100644
--- a/vllm/worker/cpu_worker.py
+++ b/vllm/worker/cpu_worker.py
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
+from vllm.utils import bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -54,13 +54,8 @@ class CPUCacheEngine:
         # in the scheduler.
         self.num_cpu_blocks = cache_config.num_gpu_blocks
 
-        if cache_config.cache_dtype == "auto":
-            self.dtype = model_config.dtype
-        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
-            self.dtype = torch.float8_e5m2
-        else:
-            raise NotImplementedError(f"Unsupported KV cache type "
-                                      f"{cache_config.cache_dtype}.")
+        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
+                                                       model_config)
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
@@ -97,10 +92,20 @@ class CPUCacheEngine:
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
 
+    @staticmethod
+    def get_kv_cache_dtype(cache_config: CacheConfig,
+                           model_config: ModelConfig):
+        if cache_config.cache_dtype == "auto":
+            return model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            return torch.float8_e5m2
+        else:
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
+
     @staticmethod
     def get_cache_block_size(
-        block_size: int,
-        cache_dtype: str,
+        cache_config: CacheConfig,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
@@ -108,13 +113,10 @@ class CPUCacheEngine:
         head_size = model_config.get_head_size()
         num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
 
-        key_cache_block = block_size * num_heads * head_size
+        key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block if not model_config.use_mla else 0
         total = num_layers * (key_cache_block + value_cache_block)
-        if cache_dtype == "auto":
-            dtype = model_config.dtype
-        else:
-            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
         dtype_size = torch.tensor([], dtype=dtype).element_size()
         return dtype_size * total
 
@@ -399,9 +401,9 @@ class CPUWorker(LocalOrDistributedWorkerBase):
     def get_cache_block_size_bytes(self) -> int:
         """Return the size in bytes of a single KV cache block.
         """
-        return CPUCacheEngine.get_cache_block_size(
-            self.cache_config.block_size, self.cache_config.cache_dtype,
-            self.model_config, self.parallel_config)
+        return CPUCacheEngine.get_cache_block_size(self.cache_config,
+                                                   self.model_config,
+                                                   self.parallel_config)
 
     def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
         """Return CPUs id binding based on NUMA nodes.
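
Illustration (not part of the patch): the change folds the KV cache dtype
resolution that was duplicated across CPUCacheEngine.__init__ and
get_cache_block_size into a single get_kv_cache_dtype() helper, and narrows
both call sites to take the CacheConfig/ModelConfig pair. The standalone
sketch below mirrors that logic outside vLLM as a sanity check;
ModelConfigStub and CacheConfigStub are hypothetical stand-ins for vLLM's
ModelConfig and CacheConfig that model only the fields this patch touches.

from dataclasses import dataclass

import torch


@dataclass
class ModelConfigStub:
    """Hypothetical stand-in for vllm.config.ModelConfig (illustration)."""
    dtype: torch.dtype = torch.bfloat16
    head_size: int = 128
    num_kv_heads: int = 8
    num_layers: int = 32
    use_mla: bool = False


@dataclass
class CacheConfigStub:
    """Hypothetical stand-in for vllm.config.CacheConfig (illustration)."""
    cache_dtype: str = "auto"
    block_size: int = 16


def get_kv_cache_dtype(cache_config: CacheConfigStub,
                       model_config: ModelConfigStub) -> torch.dtype:
    # The rule the patch centralizes: "auto" follows the model dtype, and
    # the fp8 variants map to torch.float8_e5m2 on CPU.
    if cache_config.cache_dtype == "auto":
        return model_config.dtype
    if cache_config.cache_dtype in ("fp8", "fp8_e5m2"):
        return torch.float8_e5m2
    raise NotImplementedError(
        f"Unsupported KV cache type {cache_config.cache_dtype}.")


def get_cache_block_size(cache_config: CacheConfigStub,
                         model_config: ModelConfigStub) -> int:
    # One key block holds block_size tokens x KV heads x head size elements.
    key_cache_block = (cache_config.block_size * model_config.num_kv_heads *
                       model_config.head_size)
    # MLA models keep no separate value cache (the use_mla branch above).
    value_cache_block = 0 if model_config.use_mla else key_cache_block
    total = model_config.num_layers * (key_cache_block + value_cache_block)
    dtype = get_kv_cache_dtype(cache_config, model_config)
    # Same element-size trick the patch uses to get bytes per element.
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    return dtype_size * total


if __name__ == "__main__":
    model, cache = ModelConfigStub(), CacheConfigStub()
    # "auto" resolves to the model dtype; one block is
    # 2 B * 32 layers * 2 (K+V) * (16 * 8 * 128) elements = 2 MiB.
    assert get_kv_cache_dtype(cache, model) is torch.bfloat16
    assert get_cache_block_size(cache, model) == 2 * 32 * 2 * 16 * 8 * 128

    cache.cache_dtype = "fp8_e5m2"
    # fp8 halves the per-block footprint relative to bf16.
    assert get_cache_block_size(cache, model) == 32 * 2 * 16 * 8 * 128

Passing the whole CacheConfig also lets CPUWorker.get_cache_block_size_bytes
drop its block_size/cache_dtype plumbing, so any future cache-config field
is available to the sizing code without another signature change.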