[Kernel] refactor cpu worker v0 cache dtype (#20080)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
Ning Xie 2025-07-03 16:39:14 +08:00 committed by GitHub
parent b024a42e93
commit fb14d53cf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,7 +18,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
from vllm.utils import bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@ -54,13 +54,8 @@ class CPUCacheEngine:
# in the scheduler.
self.num_cpu_blocks = cache_config.num_gpu_blocks
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
self.dtype = torch.float8_e5m2
else:
raise NotImplementedError(f"Unsupported KV cache type "
f"{cache_config.cache_dtype}.")
self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
model_config)
# Get attention backend.
self.attn_backend = get_attn_backend(
@ -97,10 +92,20 @@ class CPUCacheEngine:
def copy(self, src_to_dsts: torch.Tensor) -> None:
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
@staticmethod
def get_kv_cache_dtype(cache_config: CacheConfig,
model_config: ModelConfig):
if cache_config.cache_dtype == "auto":
return model_config.dtype
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
return torch.float8_e5m2
else:
raise NotImplementedError(f"Unsupported KV cache type "
f"{cache_config.cache_dtype}.")
@staticmethod
def get_cache_block_size(
block_size: int,
cache_dtype: str,
cache_config: CacheConfig,
model_config: ModelConfig,
parallel_config: ParallelConfig,
) -> int:
@ -108,13 +113,10 @@ class CPUCacheEngine:
num_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)
key_cache_block = block_size * num_heads * head_size
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block if not model_config.use_mla else 0
total = num_layers * (key_cache_block + value_cache_block)
if cache_dtype == "auto":
dtype = model_config.dtype
else:
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
dtype_size = torch.tensor([], dtype=dtype).element_size()
return dtype_size * total
@ -399,9 +401,9 @@ class CPUWorker(LocalOrDistributedWorkerBase):
def get_cache_block_size_bytes(self) -> int:
"""Return the size in bytes of a single KV cache block.
"""
return CPUCacheEngine.get_cache_block_size(
self.cache_config.block_size, self.cache_config.cache_dtype,
self.model_config, self.parallel_config)
return CPUCacheEngine.get_cache_block_size(self.cache_config,
self.model_config,
self.parallel_config)
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
"""Return CPUs id binding based on NUMA nodes.