mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 15:11:25 +08:00
[Perf] Add decode full-graph support to FlashInfer-MLA backend (#26313)
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
This commit is contained in:
parent
f231e5bc21
commit
f77df94647
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Optional, Union
|
from typing import ClassVar, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||||
@@ -12,13 +12,20 @@ from vllm.v1.attention.backends.mla.common import (
|
|||||||
MLACommonBackend,
|
MLACommonBackend,
|
||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
MLACommonMetadata,
|
MLACommonMetadata,
|
||||||
|
MLACommonMetadataBuilder,
|
||||||
)
|
)
|
||||||
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||||
|
# enable full CUDA Graph support for decode-only capture
|
||||||
|
cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMLABackend(MLACommonBackend):
|
class FlashInferMLABackend(MLACommonBackend):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_name() -> str:
|
def get_name() -> str:
|
||||||
@@ -28,6 +35,10 @@ class FlashInferMLABackend(MLACommonBackend):
|
|||||||
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
def get_impl_cls() -> type["FlashInferMLAImpl"]:
|
||||||
return FlashInferMLAImpl
|
return FlashInferMLAImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
|
||||||
|
return FlashInferMLAMetadataBuilder
|
||||||
|
|
||||||
|
|
||||||
g_fi_workspace = torch.zeros(
|
g_fi_workspace = torch.zeros(
|
||||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user