From f77df94647ca278575c2de4dada36d21d089e979 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Mon, 6 Oct 2025 19:03:49 -0400
Subject: [PATCH] [Perf] Add decode full-graph support to FlashInfer-MLA
 backend (#26313)

Signed-off-by: Benjamin Chislett
---
 vllm/v1/attention/backends/mla/flashinfer_mla.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index f0ea1d653c3e7..13552edab87bb 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
@@ -12,13 +12,20 @@ from vllm.v1.attention.backends.mla.common import (
     MLACommonBackend,
     MLACommonImpl,
     MLACommonMetadata,
+    MLACommonMetadataBuilder,
 )
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
 
 
+class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
+
+
 class FlashInferMLABackend(MLACommonBackend):
     @staticmethod
     def get_name() -> str:
@@ -28,6 +35,10 @@ class FlashInferMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["FlashInferMLAImpl"]:
         return FlashInferMLAImpl
 
+    @staticmethod
+    def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
+        return FlashInferMLAMetadataBuilder
+
 
 g_fi_workspace = torch.zeros(
     FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,