mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-13 17:25:38 +08:00
Reorder kv dtype check to avoid nvcc not found error on AMD platform (#3104)
This commit is contained in:
parent
29e70e3e88
commit
baee28c46c
@@ -330,15 +330,14 @@ class CacheConfig:
|
|||||||
if self.cache_dtype == "auto":
|
if self.cache_dtype == "auto":
|
||||||
pass
|
pass
|
||||||
elif self.cache_dtype == "fp8_e5m2":
|
elif self.cache_dtype == "fp8_e5m2":
|
||||||
|
if is_hip():
|
||||||
|
raise NotImplementedError(
|
||||||
|
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
|
||||||
nvcc_cuda_version = get_nvcc_cuda_version()
|
nvcc_cuda_version = get_nvcc_cuda_version()
|
||||||
if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
|
if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"FP8 is not supported when cuda version is lower than 11.8."
|
"FP8 is not supported when cuda version is lower than 11.8."
|
||||||
)
|
)
|
||||||
device_name = torch.cuda.get_device_name()
|
|
||||||
if "AMD" in device_name:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Using fp8_e5m2 data type to store kv cache. It reduces "
|
"Using fp8_e5m2 data type to store kv cache. It reduces "
|
||||||
"the GPU memory footprint and boosts the performance. "
|
"the GPU memory footprint and boosts the performance. "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user