optimize get_kv_cache_torch_dtype (#18531)

Signed-off-by: idellzheng <idellzheng@tencent.com>
This commit is contained in:
chunxiaozheng 2025-05-27 21:08:44 +08:00 committed by GitHub
parent aaa4ac1c95
commit 6b6d496114
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -759,16 +759,15 @@ def get_kv_cache_torch_dtype(
         model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
     if isinstance(cache_dtype, str):
         if cache_dtype == "auto":
-            if isinstance(model_dtype, str):
+            if isinstance(model_dtype,
+                          str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
                 torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
             elif isinstance(model_dtype, torch.dtype):
                 torch_dtype = model_dtype
             else:
                 raise ValueError(f"Invalid model dtype: {model_dtype}")
-        elif cache_dtype in ["half", "bfloat16", "float"]:
+        elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
             torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
-        elif cache_dtype == "fp8":
-            torch_dtype = torch.uint8
         else:
             raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
     elif isinstance(cache_dtype, torch.dtype):