From ebfce922f91ee6ab407099335fd16373f6586d88 Mon Sep 17 00:00:00 2001
From: Lucas Wilkinson
Date: Fri, 26 Sep 2025 12:51:46 -0700
Subject: [PATCH] full cg support

Signed-off-by: Lucas Wilkinson
---
 .../attention/backends/mla/flashinfer_mla.py | 15 +++++++++++++--
 vllm/v1/attention/backends/mla/triton_mla.py | 17 ++++++++++++++---
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 701248670f72e..a949a6eaa1e3a 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -1,16 +1,18 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
 
 from vllm.attention.backends.abstract import AttentionLayer, AttentionType
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
@@ -23,6 +25,10 @@ class FlashInferMLABackend(MLACommonBackend):
     def get_name() -> str:
         return "FLASHINFER_MLA"
 
+    @staticmethod
+    def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
+        return FlashInferMLAMetadataBuilder
+
     @staticmethod
     def get_impl_cls() -> type["FlashInferMLAImpl"]:
         return FlashInferMLAImpl
@@ -34,6 +40,11 @@ g_fi_workspace = torch.zeros(
     device="cuda",
 )
 
 
+class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+    pass
+
 class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
 
diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py
index 076152061d502..0ee1ae7e4e39b 100644
--- a/vllm/v1/attention/backends/mla/triton_mla.py
+++ b/vllm/v1/attention/backends/mla/triton_mla.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional, Union
+from typing import ClassVar, Optional, Union
 
 import torch
 
@@ -13,9 +13,11 @@ from vllm.attention.ops.triton_flash_attention import triton_attention
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
@@ -24,12 +26,21 @@ class TritonMLABackend(MLACommonBackend):
 
     @staticmethod
     def get_name() -> str:
         return "TRITON_MLA"
+
+    @staticmethod
+    def get_builder_cls() -> type["TritonMLAMetadataBuilder"]:
+        return TritonMLAMetadataBuilder
 
     @staticmethod
     def get_impl_cls() -> type["TritonMLAImpl"]:
         return TritonMLAImpl
-    
+
+
+class TritonMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_BATCH
+    pass
 
 class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True