[Attention][FlashInfer] Enable FP8 FlashInfer (TRTLLM) MLA decode (#24705)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
parent c89ed8de43
commit 7ba32aa60b
@@ -178,6 +178,7 @@ class MockAttentionLayer:
        self._k_scale = torch.tensor(1.0, device=device)
        self._v_scale = torch.tensor(1.0, device=device)
        # Add float versions for flashinfer
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0
@@ -33,10 +33,12 @@ def test_ragged_paged_attention():
    )

    class FakeAttentionLayer:
        _q_scale_float: float
        _k_scale_float: float
        _v_scale_float: float

    layer = FakeAttentionLayer()
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
@@ -240,6 +240,7 @@ class AttentionLayer(Protocol):
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float
    _prob_scale: torch.Tensor
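These attributes belong to a `typing.Protocol`, so conformance is structural: any layer object that exposes them (like the test layers above) can be passed to code typed against `AttentionLayer` without inheriting from it. A minimal sketch of that idea; the protocol below is abridged to the scale attributes, and `_MyLayer` / `_use_scales` are hypothetical names, not vLLM code:

```python
from typing import Protocol

import torch


class AttentionLayer(Protocol):
    # Scales kept both as tensors (for kernels taking tensor arguments)
    # and as plain floats (for kernels taking Python scalars).
    _q_scale: torch.Tensor
    _k_scale: torch.Tensor
    _v_scale: torch.Tensor
    _q_scale_float: float
    _k_scale_float: float
    _v_scale_float: float


class _MyLayer:  # hypothetical example class, not part of vLLM
    def __init__(self) -> None:
        self._q_scale = torch.tensor(1.0)
        self._k_scale = torch.tensor(1.0)
        self._v_scale = torch.tensor(1.0)
        self._q_scale_float = 1.0
        self._k_scale_float = 1.0
        self._v_scale_float = 1.0


def _use_scales(layer: AttentionLayer) -> float:
    # Structural typing: _MyLayer is accepted without subclassing the Protocol.
    return layer._q_scale_float * layer._k_scale_float


print(_use_scales(_MyLayer()))  # 1.0
```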
@@ -476,6 +476,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
    # - "FLASHINFER": use flashinfer
    # - "FLASHMLA": use FlashMLA
    # - "FLASH_ATTN_MLA": use FlashAttention for MLA
    # - "FLASHINFER_MLA": use FlashInfer for MLA
    # - "CUTLASS_MLA": use CUTLASS for MLA
    "VLLM_ATTENTION_BACKEND":
    lambda: os.getenv("VLLM_ATTENTION_BACKEND", None),
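The documented values above are read through `VLLM_ATTENTION_BACKEND`, so the FlashInfer MLA backend can be requested explicitly. A minimal usage sketch, assuming an MLA-style model and an FP8 KV cache; the model name is a placeholder:

```python
import os

# Must be set before the engine is constructed so backend selection sees it.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_MLA"

from vllm import LLM  # noqa: E402

# Placeholder model name; a DeepSeek-style MLA checkpoint is assumed.
llm = LLM(model="your-org/your-mla-model", kv_cache_dtype="fp8")
print(llm.generate("Hello")[0].outputs[0].text)
```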
@@ -88,6 +88,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 "Setting it to k_scale. This only matters for "
                 "the flash-attn backend.")
             layer._q_scale.copy_(k_scale)
+            layer._q_scale_float = k_scale
 
         # These are used in the final Attention.forward()
         layer._k_scale.copy_(k_scale)
@@ -124,6 +125,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
 
         # These are used in the final Attention.forward()
         layer._q_scale.copy_(q_scale)
+        layer._q_scale_float = q_scale
         layer._prob_scale.copy_(prob_scale)
         if layer.kv_cache_dtype == "fp8" and (q_scale == 1.0
                                               or prob_scale == 1.0):
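The float mirrors let kernels that take plain Python scalars for their scales (such as the TRT-LLM decode call further down in this diff) be invoked without reading the scale back off the GPU on every step. A small illustrative container, not vLLM code, showing the tensor/float pairing that the quant method maintains:

```python
import torch


class KVScales:
    """Illustrative container: keeps a scale both as a tensor, for kernels
    that take tensor arguments, and as a Python float, for kernels such as
    the TRT-LLM decode path that take scalars."""

    def __init__(self, device: str = "cpu") -> None:
        self._k_scale = torch.tensor(1.0, device=device)
        self._k_scale_float = 1.0

    def load_k_scale(self, k_scale: float) -> None:
        # Mirror the checkpoint value into both representations,
        # as BaseKVCacheMethod does above.
        self._k_scale.copy_(torch.tensor(k_scale))
        self._k_scale_float = k_scale


scales = KVScales()
scales.load_k_scale(0.5)
assert scales._k_scale_float == 0.5
assert float(scales._k_scale) == 0.5
```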
@@ -179,6 +179,7 @@ class CudaPlatformBase(Platform):
            cache_config.block_size = 128
            logger.info("Forcing kv cache block size to 128 for "
                        "CUTLASS_MLA backend.")

        if use_flashinfer_mla and cache_config.block_size not in [32, 64]:
            cache_config.block_size = 64
            logger.info(
@@ -541,7 +542,9 @@ class CudaPlatformBase(Platform):
             attention_backend = "FLASHMLA"
 
         # Only FlashMLA and CUTLASS_MLA support fp8
-        if attention_backend in ["FLASHMLA", "CUTLASS_MLA"]:
+        if attention_backend in [
+                "FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"
+        ]:
             supported = True
         else:
             supported = (not fp8_attention)
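The effect of this check, which decides whether the selected MLA backend may be combined with an FP8-quantized KV cache, can be restated as a standalone predicate. A sketch; the function and argument names are illustrative, only the backend strings come from the diff:

```python
def mla_backend_supports_fp8_kv_cache(attention_backend: str,
                                      fp8_attention: bool) -> bool:
    """FlashMLA, CUTLASS MLA and (with this change) FlashInfer MLA accept an
    FP8 KV cache; other MLA backends are only usable when attention is not
    FP8-quantized."""
    if attention_backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
        return True
    return not fp8_attention


assert mla_backend_supports_fp8_kv_cache("FLASHINFER_MLA", fp8_attention=True)
assert not mla_backend_supports_fp8_kv_cache("FLASH_ATTN_MLA", fp8_attention=True)
```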
@@ -584,7 +584,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            window_left=self._global_hyperparameters.window_left,
            logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
            q_data_type=self.model_config.dtype,
            kv_data_type=self.kv_cache_spec.dtype,
        )

        # Prepare context prefills
@@ -605,7 +604,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            logits_soft_cap=self._global_hyperparameters.
            logits_soft_cap,
            q_data_type=self.model_config.dtype,
            kv_data_type=self.kv_cache_spec.dtype,
        )

        prefill.prefill_main = self._fi_prefill_main
@@ -6,8 +6,7 @@ from typing import Optional, Union
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
-from vllm.attention.backends.abstract import (AttentionLayer, AttentionType,
-                                              is_quantized_kv_cache)
+from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
@@ -69,11 +68,9 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
                 "are not implemented for "
                 "FlashInferMLAImpl")
 
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "FlashInferMLA V1 with FP8 KV cache not yet supported")
-
         self._workspace_buffer = g_fi_workspace
+        self.bmm1_scale: Optional[float] = None
+        self.bmm2_scale: Optional[float] = None
 
     def _forward_decode(
         self,
@@ -92,6 +89,12 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # trtllm API requires extra dimension q_len_per_request for MTP
         q = q.unsqueeze(1)
 
+        if self.bmm1_scale is None:
+            self.bmm1_scale = (layer._q_scale_float * layer._k_scale_float *
+                               self.scale)
+        if self.bmm2_scale is None:
+            self.bmm2_scale = layer._v_scale_float
+
         o = trtllm_batch_decode_with_kv_cache_mla(
             query=q,
             kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
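The two scalars map onto the two GEMMs inside the TRT-LLM decode kernel: `bmm1_scale` folds the Q and K dequantization factors into the softmax scaling applied to Q·K^T, and `bmm2_scale` dequantizes the P·V product via the V scale. A sketch of the arithmetic with made-up scale values; the real `self.scale` depends on the model, and 1/sqrt(head_dim) is used here only for illustration:

```python
import math

# Illustrative per-tensor FP8 dequantization scales and softmax scale.
q_scale_float = 0.85   # query dequant scale
k_scale_float = 0.02   # key-side dequant scale of the KV cache
v_scale_float = 0.02   # value-side dequant scale of the KV cache
head_dim = 576         # placeholder; the impl uses its configured self.scale
softmax_scale = 1.0 / math.sqrt(head_dim)

# Applied to the Q·K^T GEMM: both dequant factors folded into 1/sqrt(d).
bmm1_scale = q_scale_float * k_scale_float * softmax_scale

# Applied to the P·V GEMM: dequantizes the value side.
bmm2_scale = v_scale_float

print(f"bmm1_scale={bmm1_scale:.6e}, bmm2_scale={bmm2_scale}")
```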
@@ -102,7 +105,8 @@ class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
             block_tables=attn_metadata.decode.block_table,
             seq_lens=attn_metadata.decode.seq_lens,
             max_seq_len=attn_metadata.max_seq_len,
-            bmm1_scale=self.scale,
+            bmm1_scale=self.bmm1_scale,
+            bmm2_scale=self.bmm2_scale,
         )
 
         # TODO: Return LSE pending support from Flashinfer API:
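Note that the scales are computed on the first decode call and then cached on the impl (the `is None` guards above), which appears reasonable since the per-tensor scales are fixed once weights are loaded. With an unquantized KV cache the float scales default to 1.0, as in the test layers earlier in the diff, so `bmm1_scale` reduces to the previous hard-coded `bmm1_scale=self.scale` and `bmm2_scale` to 1.0.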