Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-23 22:14:34 +08:00)
[Kernel] refactor cpu worker v0 cache dtype (#20080)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
parent b024a42e93
commit fb14d53cf6
@@ -18,7 +18,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import ExecuteModelRequest
-from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
+from vllm.utils import bind_kv_cache
 from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
 from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
 from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
@@ -54,13 +54,8 @@ class CPUCacheEngine:
         # in the scheduler.
         self.num_cpu_blocks = cache_config.num_gpu_blocks

-        if cache_config.cache_dtype == "auto":
-            self.dtype = model_config.dtype
-        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
-            self.dtype = torch.float8_e5m2
-        else:
-            raise NotImplementedError(f"Unsupported KV cache type "
-                                      f"{cache_config.cache_dtype}.")
+        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
+                                                       model_config)

         # Get attention backend.
         self.attn_backend = get_attn_backend(
@@ -97,10 +92,20 @@ class CPUCacheEngine:
     def copy(self, src_to_dsts: torch.Tensor) -> None:
         self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

+    @staticmethod
+    def get_kv_cache_dtype(cache_config: CacheConfig,
+                           model_config: ModelConfig):
+        if cache_config.cache_dtype == "auto":
+            return model_config.dtype
+        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
+            return torch.float8_e5m2
+        else:
+            raise NotImplementedError(f"Unsupported KV cache type "
+                                      f"{cache_config.cache_dtype}.")
+
     @staticmethod
     def get_cache_block_size(
-        block_size: int,
-        cache_dtype: str,
+        cache_config: CacheConfig,
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
     ) -> int:
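The extracted helper resolves the KV cache dtype from the two configs in one place. Below is a minimal standalone sketch of that mapping; the StubCacheConfig and StubModelConfig classes are illustrative stand-ins, not vLLM's real CacheConfig and ModelConfig.

from dataclasses import dataclass

import torch


# Illustrative stand-ins: only the fields the helper reads are modeled here.
@dataclass
class StubCacheConfig:
    cache_dtype: str          # "auto", "fp8", or "fp8_e5m2"


@dataclass
class StubModelConfig:
    dtype: torch.dtype        # the model's compute dtype


def get_kv_cache_dtype(cache_config: StubCacheConfig,
                       model_config: StubModelConfig) -> torch.dtype:
    # "auto" follows the model dtype; both fp8 spellings map to float8_e5m2.
    if cache_config.cache_dtype == "auto":
        return model_config.dtype
    elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
        return torch.float8_e5m2
    raise NotImplementedError(
        f"Unsupported KV cache type {cache_config.cache_dtype}.")


assert get_kv_cache_dtype(StubCacheConfig("auto"),
                          StubModelConfig(torch.bfloat16)) == torch.bfloat16
assert get_kv_cache_dtype(StubCacheConfig("fp8"),
                          StubModelConfig(torch.bfloat16)) == torch.float8_e5m2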
@@ -108,13 +113,10 @@ class CPUCacheEngine:
         num_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)

-        key_cache_block = block_size * num_heads * head_size
+        key_cache_block = cache_config.block_size * num_heads * head_size
         value_cache_block = key_cache_block if not model_config.use_mla else 0
         total = num_layers * (key_cache_block + value_cache_block)
-        if cache_dtype == "auto":
-            dtype = model_config.dtype
-        else:
-            dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
+        dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
         dtype_size = torch.tensor([], dtype=dtype).element_size()
         return dtype_size * total

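The block-size computation multiplies tokens per block by KV heads, head size, and layer count, then by the element size of the resolved dtype; MLA models drop the value cache entirely. A worked sketch of that arithmetic follows, with hypothetical shapes that are not taken from the commit.

import torch

# Hypothetical configuration; none of these numbers come from the commit.
block_size = 16          # tokens per KV cache block
num_heads = 8            # KV heads on this tensor-parallel rank
head_size = 128
num_layers = 32
use_mla = False          # MLA models keep no separate value cache
dtype = torch.bfloat16   # what get_kv_cache_dtype returns for "auto"

key_cache_block = block_size * num_heads * head_size        # 16384 elements
value_cache_block = key_cache_block if not use_mla else 0   # 16384 elements
total = num_layers * (key_cache_block + value_cache_block)  # 1048576 elements
dtype_size = torch.tensor([], dtype=dtype).element_size()   # 2 bytes
print(dtype_size * total)  # 2097152 bytes: each cache block costs 2 MiB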
@@ -399,9 +401,9 @@ class CPUWorker(LocalOrDistributedWorkerBase):
     def get_cache_block_size_bytes(self) -> int:
         """Return the size in bytes of a single KV cache block.
         """
-        return CPUCacheEngine.get_cache_block_size(
-            self.cache_config.block_size, self.cache_config.cache_dtype,
-            self.model_config, self.parallel_config)
+        return CPUCacheEngine.get_cache_block_size(self.cache_config,
+                                                   self.model_config,
+                                                   self.parallel_config)

     def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
         """Return CPUs id binding based on NUMA nodes.