[Bugfix] Enable FP8 KV cache for FlashInfer and Triton backend on non-sm100 GPUs (#24577)

Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
Thien Tran 2025-09-11 03:33:41 +08:00 committed by GitHub
parent 09e68bce34
commit a0933c3bd6
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 1 deletion
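
For context, this change matters when a deployment requests an FP8 KV cache together with the FlashInfer or Triton attention backend on a pre-Blackwell (non-sm100) GPU. A minimal usage sketch follows; the model name is a placeholder, and forcing the backend through VLLM_ATTENTION_BACKEND is an assumption about how the affected configuration is reached rather than part of this commit.

import os
# Force one of the backends touched by this fix (assumption: the standard
# VLLM_ATTENTION_BACKEND override is used to select it).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # or "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

# Request an FP8 KV cache; before this fix the platform support check did not
# report this combination as supported on non-sm100 GPUs for these backends.
llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8")
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)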


@@ -530,6 +530,10 @@ class CudaPlatformBase(Platform):
                     supported = flash_attn_supports_fp8()
                 else:
                     supported = True
+            elif attention_backend == "FLASHINFER":
+                supported = True
+            elif attention_backend == "TRITON_ATTN_VLLM_V1":
+                supported = cls.supports_fp8()
         return supported
 
     @classmethod
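
Taken together, the first hunk makes the FP8 KV cache support check resolve per attention backend. A standalone sketch of the resulting decision logic is below; the function name, the fp8_attention flag, and the FLASH_ATTN_VLLM_V1 branch label are assumptions inferred from the visible context, and the two support probes are passed in as plain callables for illustration.

from typing import Callable

def fp8_kv_cache_supported(attention_backend: str,
                           fp8_attention: bool,
                           flash_attn_supports_fp8: Callable[[], bool],
                           device_supports_fp8: Callable[[], bool]) -> bool:
    # Illustrative rendering of the per-backend check shown in the hunk above.
    supported = False
    if attention_backend == "FLASH_ATTN_VLLM_V1":
        # FlashAttention only serves an FP8 KV cache if its kernels support it.
        supported = flash_attn_supports_fp8() if fp8_attention else True
    elif attention_backend == "FLASHINFER":
        # FlashInfer can serve an FP8 KV cache regardless of GPU generation.
        supported = True
    elif attention_backend == "TRITON_ATTN_VLLM_V1":
        # The Triton backend is gated on the device's native FP8 support.
        supported = device_supports_fp8()
    return supported

# Example: Triton backend on a GPU that reports FP8 support.
assert fp8_kv_cache_supported("TRITON_ATTN_VLLM_V1", True,
                              lambda: False, lambda: True)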


@@ -202,7 +202,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.q_data_type = self.kv_cache_dtype
+        if supports_trtllm_attention()[0]:
+            self.q_data_type = self.kv_cache_dtype
+        else:
+            self.q_data_type = self.model_config.dtype
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
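
The second hunk is what actually unblocks FlashInfer on non-sm100 GPUs: the query dtype only follows the (possibly FP8) KV cache dtype when the TRT-LLM attention path reported by supports_trtllm_attention() is available; otherwise the query stays in the model's compute dtype and only the KV cache is quantized. A small self-contained sketch of that selection, with torch dtypes and a boolean standing in for the builder's attributes and for supports_trtllm_attention()[0]:

import torch

def select_q_dtype(kv_cache_dtype: torch.dtype,
                   model_dtype: torch.dtype,
                   trtllm_available: bool) -> torch.dtype:
    # TRT-LLM attention kernels can consume a query quantized to match an
    # FP8 KV cache; the other FlashInfer paths expect the query in the
    # model's compute dtype even when the KV cache itself is FP8.
    return kv_cache_dtype if trtllm_available else model_dtype

# Example: FP8 KV cache on a non-sm100 GPU keeps the query in bfloat16.
q_dtype = select_q_dtype(torch.float8_e4m3fn, torch.bfloat16,
                         trtllm_available=False)
assert q_dtype is torch.bfloat16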