diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index fc1a399d6f43f..f15c5594c27b2 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -530,6 +530,10 @@ class CudaPlatformBase(Platform):
                     supported = flash_attn_supports_fp8()
                 else:
                     supported = True
+            elif attention_backend == "FLASHINFER":
+                supported = True
+            elif attention_backend == "TRITON_ATTN_VLLM_V1":
+                supported = cls.supports_fp8()
         return supported
 
     @classmethod
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index c7a565810b45b..4a491eeaafd10 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -202,7 +202,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.q_data_type = self.kv_cache_dtype
+
+        if supports_trtllm_attention()[0]:
+            self.q_data_type = self.kv_cache_dtype
+        else:
+            self.q_data_type = self.model_config.dtype
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
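
Note on the second hunk (illustrative, not part of the patch): the FlashInfer metadata builder now keeps the query in the KV-cache dtype only when TRT-LLM attention is usable (`supports_trtllm_attention()[0]`), and otherwise leaves the query in the model's compute dtype, presumably because only the TRT-LLM path consumes an FP8 query. A minimal, self-contained sketch of that selection rule, using a hypothetical helper name (`choose_q_dtype`) and no vLLM imports:

```python
import torch

def choose_q_dtype(kv_cache_dtype: torch.dtype,
                   model_dtype: torch.dtype,
                   trtllm_attention_available: bool) -> torch.dtype:
    # Mirrors the patched selection: the query follows the KV-cache dtype
    # only when TRT-LLM attention kernels are available; otherwise it stays
    # in the model's compute dtype.
    if trtllm_attention_available:
        return kv_cache_dtype
    return model_dtype

# FP8 KV cache, TRT-LLM kernels unavailable -> query stays in bf16.
assert choose_q_dtype(torch.float8_e4m3fn, torch.bfloat16, False) is torch.bfloat16
# FP8 KV cache, TRT-LLM kernels available -> query follows the cache dtype.
assert choose_q_dtype(torch.float8_e4m3fn, torch.bfloat16, True) is torch.float8_e4m3fn
```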