mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-15 11:33:32 +08:00
[Kernel] refactor cpu worker v0 cache dtype (#20080)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
parent
b024a42e93
commit
fb14d53cf6
@ -18,7 +18,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor import set_random_seed
|
from vllm.model_executor import set_random_seed
|
||||||
from vllm.sequence import ExecuteModelRequest
|
from vllm.sequence import ExecuteModelRequest
|
||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
|
from vllm.utils import bind_kv_cache
|
||||||
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
|
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
|
||||||
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
|
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
|
||||||
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
|
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
|
||||||
@ -54,13 +54,8 @@ class CPUCacheEngine:
|
|||||||
# in the scheduler.
|
# in the scheduler.
|
||||||
self.num_cpu_blocks = cache_config.num_gpu_blocks
|
self.num_cpu_blocks = cache_config.num_gpu_blocks
|
||||||
|
|
||||||
if cache_config.cache_dtype == "auto":
|
self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
|
||||||
self.dtype = model_config.dtype
|
model_config)
|
||||||
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
|
|
||||||
self.dtype = torch.float8_e5m2
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Unsupported KV cache type "
|
|
||||||
f"{cache_config.cache_dtype}.")
|
|
||||||
|
|
||||||
# Get attention backend.
|
# Get attention backend.
|
||||||
self.attn_backend = get_attn_backend(
|
self.attn_backend = get_attn_backend(
|
||||||
@ -97,10 +92,20 @@ class CPUCacheEngine:
|
|||||||
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
||||||
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
|
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_dtype(cache_config: CacheConfig,
|
||||||
|
model_config: ModelConfig):
|
||||||
|
if cache_config.cache_dtype == "auto":
|
||||||
|
return model_config.dtype
|
||||||
|
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
|
||||||
|
return torch.float8_e5m2
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported KV cache type "
|
||||||
|
f"{cache_config.cache_dtype}.")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_cache_block_size(
|
def get_cache_block_size(
|
||||||
block_size: int,
|
cache_config: CacheConfig,
|
||||||
cache_dtype: str,
|
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
parallel_config: ParallelConfig,
|
parallel_config: ParallelConfig,
|
||||||
) -> int:
|
) -> int:
|
||||||
@ -108,13 +113,10 @@ class CPUCacheEngine:
|
|||||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||||
num_layers = model_config.get_num_layers(parallel_config)
|
num_layers = model_config.get_num_layers(parallel_config)
|
||||||
|
|
||||||
key_cache_block = block_size * num_heads * head_size
|
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||||
value_cache_block = key_cache_block if not model_config.use_mla else 0
|
value_cache_block = key_cache_block if not model_config.use_mla else 0
|
||||||
total = num_layers * (key_cache_block + value_cache_block)
|
total = num_layers * (key_cache_block + value_cache_block)
|
||||||
if cache_dtype == "auto":
|
dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
|
||||||
dtype = model_config.dtype
|
|
||||||
else:
|
|
||||||
dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
|
||||||
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
||||||
return dtype_size * total
|
return dtype_size * total
|
||||||
|
|
||||||
@ -399,9 +401,9 @@ class CPUWorker(LocalOrDistributedWorkerBase):
|
|||||||
def get_cache_block_size_bytes(self) -> int:
|
def get_cache_block_size_bytes(self) -> int:
|
||||||
"""Return the size in bytes of a single KV cache block.
|
"""Return the size in bytes of a single KV cache block.
|
||||||
"""
|
"""
|
||||||
return CPUCacheEngine.get_cache_block_size(
|
return CPUCacheEngine.get_cache_block_size(self.cache_config,
|
||||||
self.cache_config.block_size, self.cache_config.cache_dtype,
|
self.model_config,
|
||||||
self.model_config, self.parallel_config)
|
self.parallel_config)
|
||||||
|
|
||||||
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
|
def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
|
||||||
"""Return CPUs id binding based on NUMA nodes.
|
"""Return CPUs id binding based on NUMA nodes.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user