diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index f47499309d8f6..e2c861587583c 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1418,14 +1418,15 @@ class EngineArgs:
                 and not envs.is_set("VLLM_ATTENTION_BACKEND")
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
-            if current_platform.is_rocm():
+            if current_platform.is_rocm() or (
+                    current_platform.is_cuda()
+                    and current_platform.is_device_capability(100)):
                 supported = True
             elif fp8_attention and will_use_fa:
                 from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
-            elif envs.VLLM_USE_TRTLLM_DECODE_ATTENTION:
-                supported = True
+
             if not supported:
                 _raise_or_fallback(feature_name="--kv-cache-dtype",
                                    recommend_to_remove=False)
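
For readers skimming the diff, below is a minimal, self-contained sketch of the decision logic as it reads after this change. It is not part of the patch: the function name and boolean parameters are hypothetical stand-ins for `current_platform.is_rocm()`, `current_platform.is_cuda()`, `current_platform.is_device_capability(100)` (compute capability 10.0, i.e. Blackwell-class GPUs) and `flash_attn_supports_fp8()`; the real check lives inside `EngineArgs` in `vllm/engine/arg_utils.py`.

```python
# Hypothetical, standalone sketch of the --kv-cache-dtype support check
# after this diff. All parameters are stand-ins for vLLM internals.

def kv_cache_dtype_supported(
    is_rocm: bool,
    is_cuda: bool,
    device_capability: int,        # e.g. 100 for compute capability 10.0
    fp8_attention: bool,
    will_use_fa: bool,
    flash_attn_supports_fp8: bool,
) -> bool:
    """Mirror the branch structure introduced by the diff."""
    supported = False
    # ROCm, and now CUDA devices reporting capability 10.0, are accepted
    # without further checks.
    if is_rocm or (is_cuda and device_capability == 100):
        supported = True
    # Otherwise, FP8 attention on the FlashAttention backend is only
    # supported when the installed FA build can handle fp8.
    elif fp8_attention and will_use_fa:
        supported = flash_attn_supports_fp8
    return supported


if __name__ == "__main__":
    # Capability 10.0 CUDA device: supported regardless of FA fp8 support.
    print(kv_cache_dtype_supported(False, True, 100, True, True, False))  # True
    # Older CUDA device whose FlashAttention build lacks fp8: falls back.
    print(kv_cache_dtype_supported(False, True, 90, True, True, False))   # False
```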