From 6b6d4961147220fb80f9cc7dcb74db478f9c9a23 Mon Sep 17 00:00:00 2001 From: chunxiaozheng <55471457+chunxiaozheng@users.noreply.github.com> Date: Tue, 27 May 2025 21:08:44 +0800 Subject: [PATCH] optimize get_kv_cache_torch_dtype (#18531) Signed-off-by: idellzheng --- vllm/utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 7222a3c99102..846df7743736 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -759,16 +759,15 @@ def get_kv_cache_torch_dtype( model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: if isinstance(cache_dtype, str): if cache_dtype == "auto": - if isinstance(model_dtype, str): + if isinstance(model_dtype, + str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] elif isinstance(model_dtype, torch.dtype): torch_dtype = model_dtype else: raise ValueError(f"Invalid model dtype: {model_dtype}") - elif cache_dtype in ["half", "bfloat16", "float"]: + elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] - elif cache_dtype == "fp8": - torch_dtype = torch.uint8 else: raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") elif isinstance(cache_dtype, torch.dtype):