mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 06:55:01 +08:00
optimize get_kv_cache_torch_dtype (#18531)
Signed-off-by: idellzheng <idellzheng@tencent.com>
This commit is contained in:
parent
aaa4ac1c95
commit
6b6d496114
@ -759,16 +759,15 @@ def get_kv_cache_torch_dtype(
|
||||
model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
|
||||
if isinstance(cache_dtype, str):
|
||||
if cache_dtype == "auto":
|
||||
if isinstance(model_dtype, str):
|
||||
if isinstance(model_dtype,
|
||||
str) and model_dtype in STR_DTYPE_TO_TORCH_DTYPE:
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
|
||||
elif isinstance(model_dtype, torch.dtype):
|
||||
torch_dtype = model_dtype
|
||||
else:
|
||||
raise ValueError(f"Invalid model dtype: {model_dtype}")
|
||||
elif cache_dtype in ["half", "bfloat16", "float"]:
|
||||
elif cache_dtype in STR_DTYPE_TO_TORCH_DTYPE:
|
||||
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype]
|
||||
elif cache_dtype == "fp8":
|
||||
torch_dtype = torch.uint8
|
||||
else:
|
||||
raise ValueError(f"Invalid kv cache dtype: {cache_dtype}")
|
||||
elif isinstance(cache_dtype, torch.dtype):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user