[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode (#24705)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Authored by Matthew Bonanni on 2025-09-12 17:45:53 -04:00; committed by GitHub
parent c89ed8de43
commit 7ba32aa60b
8 changed files with 23 additions and 10 deletions

@@ -178,6 +178,7 @@ class MockAttentionLayer:
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
        # Add float versions for flashinfer
+        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0

@@ -33,10 +33,12 @@ def test_ragged_paged_attention():
    )

    class FakeAttentionLayer:
+        _q_scale_float: float
        _k_scale_float: float
        _v_scale_float: float

    layer = FakeAttentionLayer()
+    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0

@@ -240,6 +240,7 @@ class AttentionLayer(Protocol):
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
+    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor
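
Because AttentionLayer is a typing.Protocol, backends only require that a layer object expose these attributes; the MockAttentionLayer and FakeAttentionLayer updates in the test hunks above satisfy it structurally. A minimal sketch of that idea (the class names below are illustrative, not vLLM's):

    import torch
    from typing import Protocol

    class ScaledLayer(Protocol):
        # the subset of AttentionLayer fields touched by this commit
        _q_scale: torch.Tensor
        _q_scale_float: float

    class StubLayer:
        def __init__(self) -> None:
            self._q_scale = torch.tensor(1.0)
            self._q_scale_float = 1.0

    layer: ScaledLayer = StubLayer()  # accepted: StubLayer structurally matches the protocol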

@@ -476,6 +476,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # - "FLASHINFER": use flashinfer
    # - "FLASHMLA": use FlashMLA
    # - "FLASH_ATTN_MLA": use FlashAttention for MLA
+    # - "FLASHINFER_MLA": use FlashInfer for MLA
+    # - "CUTLASS_MLA": use CUTLASS for MLA
    "VLLM_ATTENTION_BACKEND":
    lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
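
With the new option documented above, the FlashInfer (TRTLLM) MLA decode path can be selected through this environment variable. A minimal usage sketch, assuming an MLA model served with an FP8 KV cache (the model name is illustrative only):

    import os

    # Choose the backend before vLLM picks one automatically.
    os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

    from vllm import LLM, SamplingParams

    # kv_cache_dtype="fp8" exercises the FP8 decode path this commit enables.
    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", kv_cache_dtype="fp8")
    out = llm.generate(["Hello"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)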

@@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                "Setting it to k_scale. This only matters for "
                "the flash-attn backend.")
            layer._q_scale.copy_(k_scale)
+            layer._q_scale_float = k_scale

        # These are used in the final Attention.forward()
        layer._k_scale.copy_(k_scale)
@@ -124,6 +125,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
        # These are used in the final Attention.forward()
        layer._q_scale.copy_(q_scale)
+        layer._q_scale_float = q_scale
        layer._prob_scale.copy_(prob_scale)
        if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                              or prob_scale == 1.0):
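
The float mirrors exist because the TRTLLM MLA decode kernel takes its scale arguments as host-side Python floats, and reading them back out of the `_q_scale` GPU tensor at decode time would force a device sync. A minimal sketch of keeping both views in step (the layer class below is a stand-in, not vLLM's):

    import torch

    class ToyLayer:
        def __init__(self) -> None:
            self._q_scale = torch.tensor(1.0)  # device tensor, consumed by tensor-based kernels
            self._q_scale_float = 1.0          # plain float, handed to the TRTLLM decode call

    layer = ToyLayer()
    q_scale = 0.5  # dequantization scale loaded from the checkpoint
    layer._q_scale.copy_(torch.tensor(q_scale))
    layer._q_scale_float = q_scale  # keep the two views of the same scale in sync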

@@ -179,6 +179,7 @@ class CudaPlatformBase(Platform):
                cache_config.block_size = 128
                logger.info("Forcing kv cache block size to 128 for "
                            "CUTLASS_MLA backend.")
            if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
                cache_config.block_size = 64
                logger.info(
@@ -541,7 +542,9 @@ class CudaPlatformBase(Platform):
            attention_backend = "FLASHMLA"

        # Only FlashMLA and CUTLASS_MLA support fp8
-        if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
+        if attention_backend in [
+                "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
+        ]:
            supported = True
        else:
            supported = (not fp8_attention)
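
Read as a predicate, the branch above now treats all three MLA backends as FP8-capable, while any other backend is supported only when FP8 attention is not requested. A rough equivalent (the function name is illustrative):

    def kv_cache_fp8_supported(attention_backend: str, fp8_attention: bool) -> bool:
        # FlashMLA, CUTLASS MLA and, with this commit, FlashInfer MLA handle FP8 KV caches.
        if attention_backend in ("FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"):
            return True
        # Other backends are fine only without FP8 attention.
        return not fp8_attention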

@@ -584,7 +584,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                window_left=self._global_hyperparameters.window_left,
                logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
                q_data_type=self.model_config.dtype,
-                kv_data_type=self.kv_cache_spec.dtype,
            )

        # Prepare context prefills
@@ -605,7 +604,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                    logits_soft_cap=self._global_hyperparameters.
                    logits_soft_cap,
                    q_data_type=self.model_config.dtype,
-                    kv_data_type=self.kv_cache_spec.dtype,
                )
            prefill.prefill_main = self._fi_prefill_main

@@ -6,8 +6,7 @@ from typing import Optional, Union
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

-from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
@@ -69,11 +68,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
                "are not implemented for "
                "FlashInferMLAImpl")

-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashInferMLA V1 with FP8 KV cache not yet supported")
-
        self._workspace_buffer = g_fi_workspace
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None

    def _forward_decode(
        self,
@@ -92,6 +89,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
        # trtllm API requires extra dimension q_len_per_request for MTP
        q = q.unsqueeze(1)

+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
@@ -102,7 +105,8 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
-            bmm1_scale=self.scale,
+            bmm1_scale=self.bmm1_scale,
+            bmm2_scale=self.bmm2_scale,
        )

        # TODO: Return LSE pending support from Flashinfer API:
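
For intuition, the two cached scales fold the FP8 dequantization factors into the kernel's two matrix multiplies: bmm1 (Q·K^T, together with the softmax scale) and bmm2 (P·V). A rough numeric sketch with illustrative values:

    # Illustrative values; the real ones come from the checkpoint and the model config.
    softmax_scale = 0.072                        # self.scale, fixed at init
    q_scale, k_scale, v_scale = 1.0, 0.5, 0.5    # layer._{q,k,v}_scale_float

    bmm1_scale = q_scale * k_scale * softmax_scale  # = 0.036, applied inside Q @ K^T
    bmm2_scale = v_scale                            # = 0.5, applied to the P @ V output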