From baee28c46c242b72f90d6b1211ab9d7872ab05d3 Mon Sep 17 00:00:00 2001
From: cloudhan
Date: Sat, 2 Mar 2024 14:34:48 +0800
Subject: [PATCH] Reorder kv dtype check to avoid nvcc not found error on AMD platform (#3104)

---
 vllm/config.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index e260e6a0cb1d..ff8536c1aca5 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -330,15 +330,14 @@ class CacheConfig:
         if self.cache_dtype == "auto":
             pass
         elif self.cache_dtype == "fp8_e5m2":
+            if is_hip():
+                raise NotImplementedError(
+                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
             nvcc_cuda_version = get_nvcc_cuda_version()
             if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                 raise ValueError(
                     "FP8 is not supported when cuda version is lower than 11.8."
                 )
-            device_name = torch.cuda.get_device_name()
-            if "AMD" in device_name:
-                raise NotImplementedError(
-                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
             logger.info(
                 "Using fp8_e5m2 data type to store kv cache. It reduces "
                 "the GPU memory footprint and boosts the performance. "