[Feature] Full CUDA Graph Support for Cutlass MLA and 6% E2E Throughput Improvement (#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Wentao Ye 2025-08-15 02:27:30 -04:00 committed by GitHub
parent b4cef5e6c7
commit 5c3fbfe46b
2 changed files with 88 additions and 2 deletions
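
For context, enabling the feature end to end looks roughly like the test added below. This is a minimal sketch, assuming a Blackwell (SM100) GPU; the model name and generation settings are taken from the test fixture in this commit, and CompilationConfig is imported from vllm.config as elsewhere in vLLM.

    import os

    # Cutlass MLA is selected via environment variables, as in the test
    # fixture added by this commit.
    os.environ["VLLM_USE_V1"] = "1"
    os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA"

    from vllm import LLM, SamplingParams
    from vllm.config import CompilationConfig

    # full_cuda_graph=True asks the engine to capture one CUDA graph for
    # the whole forward pass (decode-only for this backend) rather than
    # piecewise graphs that break around attention.
    llm = LLM(
        model="deepseek-ai/DeepSeek-V2-Lite",
        trust_remote_code=True,
        max_model_len=1024,
        compilation_config=CompilationConfig(full_cuda_graph=True),
    )

    out = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
    print(out[0].outputs[0].text)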


@@ -66,6 +66,80 @@ def llm_pair(request):
     )
 
 
+@pytest.fixture(scope="class")
+def cutlass_mla_llm_pair(request):
+    model = request.param
+
+    # force V1 engine and Cutlass MLA backend
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+    }):
+        full = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(
+                full_cuda_graph=True,
+                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
+            ),
+        )
+        piecewise = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(),
+        )
+
+    yield weakref.proxy(full), weakref.proxy(piecewise)
+    del full
+    del piecewise
+
+    wait_for_gpu_memory_to_clear(
+        devices=[0],
+        threshold_ratio=0.1,
+    )
+
+
+@pytest.mark.parametrize(
+    "cutlass_mla_llm_pair",
+    [
+        # use an MLA model
+        "deepseek-ai/DeepSeek-V2-Lite",
+    ],
+    indirect=True)
+@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
+                    reason="Only Blackwell GPUs support Cutlass MLA")
+class TestFullCUDAGraphCutlassMLA:
+    """
+    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
+    """
+
+    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
+        (8, 8),
+    ])
+    def test_full_cudagraph_sm100_cutlass_mla(
+            self, batch_size, max_tokens,
+            cutlass_mla_llm_pair: tuple[LLM, LLM]):
+        full_cudagraph_llm, piecewise_llm = cutlass_mla_llm_pair
+
+        prompts = ["Hello, my name is"] * batch_size
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens,
+                                         top_p=0.95)
+
+        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
+        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
+
+        for piecewise_res, full_res in zip(piecewise_responses,
+                                           full_responses):
+            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
+
+
 @pytest.mark.parametrize(
     "llm_pair",
     [

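The fixture above relies on a temporary_environ context manager defined in the test module. One plausible implementation, shown here only to make the fixture self-explanatory (the actual helper may differ), is:

    import contextlib
    import os

    @contextlib.contextmanager
    def temporary_environ(env_vars):
        """Temporarily set environment variables, restoring the previous
        values (or unsetting added keys) on exit."""
        saved = {k: os.environ.get(k) for k in env_vars}
        os.environ.update(env_vars)
        try:
            yield
        finally:
            for k, v in saved.items():
                if v is None:
                    os.environ.pop(k, None)
                else:
                    os.environ[k] = v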

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from typing import Optional
+from typing import ClassVar, Optional
 
 import torch
@@ -12,11 +12,19 @@ from vllm.attention.backends.abstract import (AttentionType,
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 
+class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    attn_cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
+
+
 class CutlassMLABackend(MLACommonBackend):
 
     @staticmethod
@@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["CutlassMLAImpl"]:
         return CutlassMLAImpl
 
+    @staticmethod
+    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
+        return CutlassMLAMetadataBuilder
+
 
 class SM100Workspace:
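
PURE_DECODE_ONLY means the builder advertises full-graph capture only for batches that contain no prefill tokens; mixed or prefill batches still run piecewise. A toy illustration of the kind of gate this declaration enables follows; the enum members other than PURE_DECODE_ONLY and the check itself are assumptions for illustration, not vLLM's actual dispatch code.

    from enum import Enum

    class AttentionCGSupport(Enum):
        # Assumed member names for the sketch; only PURE_DECODE_ONLY is
        # confirmed by the diff above.
        NEVER = 0
        PURE_DECODE_ONLY = 1
        ALWAYS = 2

    def can_capture_full_graph(builder_support: AttentionCGSupport,
                               is_pure_decode_batch: bool) -> bool:
        """Return whether a full CUDA graph may be captured for this batch."""
        if builder_support is AttentionCGSupport.ALWAYS:
            return True
        if builder_support is AttentionCGSupport.PURE_DECODE_ONLY:
            # Cutlass MLA only supports full-graph capture for
            # decode-only batches.
            return is_pure_decode_batch
        return False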