Mirror of https://git.datalinker.icu/vllm-project/vllm.git
[Bugfix] Enable FP8 KV cache for FlashInfer and Triton backend on non-sm100 GPUs (#24577)
Signed-off-by: Thien Tran <gau.nernst@yahoo.com.sg>
commit a0933c3bd6
parent 09e68bce34
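
A hedged usage sketch of what this fix enables: requesting an FP8 KV cache together with the FlashInfer (or Triton) attention backend on a pre-Blackwell (non-sm100) GPU. The kv_cache_dtype argument and the VLLM_ATTENTION_BACKEND variable are standard vLLM options; the model name is an arbitrary example, not taken from this commit.

    import os
    # Select the FlashInfer backend explicitly (TRITON_ATTN_VLLM_V1 is the other
    # backend this commit enables).
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

    from vllm import LLM

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # arbitrary example model
        kv_cache_dtype="fp8",  # store the KV cache in FP8 (typically e4m3 on CUDA)
    )
    outputs = llm.generate(["Hello, my name is"])
    print(outputs[0].outputs[0].text)
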
@@ -530,6 +530,10 @@ class CudaPlatformBase(Platform):
                 supported = flash_attn_supports_fp8()
             else:
                 supported = True
+        elif attention_backend == "FLASHINFER":
+            supported = True
+        elif attention_backend == "TRITON_ATTN_VLLM_V1":
+            supported = cls.supports_fp8()
         return supported
 
     @classmethod
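
For context, a minimal self-contained sketch of the per-backend dispatch this hunk extends. The backend names and the flash_attn_supports_fp8() / cls.supports_fp8() calls come from the diff itself; the sm100 shortcut and the SM 8.9 threshold are assumptions inferred from the commit title, and the helpers are stubbed for illustration, not vLLM's actual implementation.

    def flash_attn_supports_fp8() -> bool:
        # stub: the real check inspects the installed FlashAttention build
        return False

    def device_supports_fp8(compute_capability: tuple[int, int]) -> bool:
        # assumption: FP8 KV cache needs Ada/Hopper or newer (SM 8.9+)
        return compute_capability >= (8, 9)

    def kv_cache_fp8_supported(attention_backend: str,
                               compute_capability: tuple[int, int]) -> bool:
        if compute_capability >= (10, 0):
            # assumption from the commit title: sm100 (Blackwell) was already allowed
            return True
        if attention_backend == "FLASH_ATTN_VLLM_V1":
            return flash_attn_supports_fp8()
        if attention_backend == "FLASHINFER":
            # newly allowed on non-sm100 GPUs by this commit
            return True
        if attention_backend == "TRITON_ATTN_VLLM_V1":
            # mirrors `supported = cls.supports_fp8()` in the hunk
            return device_supports_fp8(compute_capability)
        return False

    # e.g. on an H100 (SM 9.0): kv_cache_fp8_supported("FLASHINFER", (9, 0)) -> True
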
@@ -202,7 +202,11 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         else:
             assert self.kv_cache_spec.dtype == self.model_config.dtype
             self.kv_cache_dtype = self.kv_cache_spec.dtype
-        self.q_data_type = self.kv_cache_dtype
+
+        if supports_trtllm_attention()[0]:
+            self.q_data_type = self.kv_cache_dtype
+        else:
+            self.q_data_type = self.model_config.dtype
 
         self._cascade_wrapper = None  # Wrapper for cascade attention
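
The second hunk decouples the query dtype from the KV-cache dtype: only when FlashInfer's TRT-LLM attention path is available do queries adopt the FP8 KV-cache dtype; otherwise they stay in the model's compute dtype while only the cache is stored in FP8. A minimal sketch of that selection, with select_q_dtype as a hypothetical helper name:

    import torch

    def select_q_dtype(kv_cache_dtype: torch.dtype,
                       model_dtype: torch.dtype,
                       trtllm_attention_available: bool) -> torch.dtype:
        # TRT-LLM attention kernels accept FP8 queries, so match the KV-cache dtype
        if trtllm_attention_available:
            return kv_cache_dtype
        # plain FlashInfer path: queries stay in the model dtype, only the cache is FP8
        return model_dtype

    # e.g. an FP8 cache without TRT-LLM attention keeps bf16 queries:
    # select_q_dtype(torch.float8_e4m3fn, torch.bfloat16, False) -> torch.bfloat16
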