[Perf] Add decode full-graph support to FlashInfer-MLA backend (#26313)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
Benjamin Chislett 2025-10-06 19:03:49 -04:00 committed by GitHub
parent f231e5bc21
commit f77df94647
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import ClassVar, Optional, Union
import torch
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
@ -12,13 +12,20 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
logger = init_logger(__name__)
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
# enable full CUDA Graph support for decode-only capture
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
class FlashInferMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
@ -28,6 +35,10 @@ class FlashInferMLABackend(MLACommonBackend):
def get_impl_cls() -> type["FlashInferMLAImpl"]:
return FlashInferMLAImpl
@staticmethod
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
return FlashInferMLAMetadataBuilder
g_fi_workspace = torch.zeros(
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,