mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-16 15:47:22 +08:00
[Perf] Add decode full-graph support to FlashInfer-MLA backend (#26313)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
f231e5bc21
commit
f77df94647
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, Union
|
||||
from typing import ClassVar, Optional, Union
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
@ -12,13 +12,20 @@ from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||
# enable full CUDA Graph support for decode-only capture
|
||||
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@ -28,6 +35,10 @@ class FlashInferMLABackend(MLACommonBackend):
|
||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||
return FlashInferMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||
return FlashInferMLAMetadataBuilder
|
||||
|
||||
|
||||
g_fi_workspace = torch.zeros(
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user