[Feature] Full CUDA Graph Support for Cutlass MLA and 6% E2E Throughput Improvement (#22763)

Signed-off-by: yewentao256 <zhyanwentao@126.com>

parent b4cef5e6c7
commit 5c3fbfe46b
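The commit adds a decode-only full-CUDA-graph path for the Cutlass MLA attention backend (SM100/Blackwell) plus a test that checks it produces the same outputs as piecewise capture. From the user side, the feature is driven by the same knobs the test below uses: the CUTLASS_MLA attention backend plus CompilationConfig(full_cuda_graph=True). A minimal usage sketch, assuming a Blackwell GPU; the model name is just the MLA example from the test, and the CompilationConfig import path is the standard vLLM one:

import os

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# Select the V1 engine and the Cutlass MLA attention backend
# (the same environment variables the test fixture below sets).
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "CUTLASS_MLA"

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # any MLA model
    trust_remote_code=True,
    max_model_len=1024,
    compilation_config=CompilationConfig(
        # capture whole decode iterations as single CUDA graphs
        full_cuda_graph=True,
    ),
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8)))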
@@ -66,6 +66,80 @@ def llm_pair(request):
     )
 
 
+@pytest.fixture(scope="class")
+def cutlass_mla_llm_pair(request):
+    model = request.param
+
+    # force V1 engine and Cutlass MLA backend
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+    }):
+        full = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(
+                full_cuda_graph=True,
+                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
+            ),
+        )
+        piecewise = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(),
+        )
+
+    yield weakref.proxy(full), weakref.proxy(piecewise)
+    del full
+    del piecewise
+
+    wait_for_gpu_memory_to_clear(
+        devices=[0],
+        threshold_ratio=0.1,
+    )
+
+
+@pytest.mark.parametrize(
+    "cutlass_mla_llm_pair",
+    [
+        # use an MLA model
+        "deepseek-ai/DeepSeek-V2-Lite",
+    ],
+    indirect=True)
+@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
+                    reason="Only Blackwell GPUs support Cutlass MLA")
+class TestFullCUDAGraphCutlassMLA:
+    """
+    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
+    """
+
+    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
+        (8, 8),
+    ])
+    def test_full_cudagraph_sm100_cutlass_mla(
+            self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
+                                                                      LLM]):
+        full_cudagraph_llm, piecewise_llm = cutlass_mla_llm_pair
+
+        prompts = ["Hello, my name is"] * batch_size
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens,
+                                         top_p=0.95)
+
+        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
+        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
+
+        for piecewise_res, full_res in zip(piecewise_responses,
+                                           full_responses):
+            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
+
+
 @pytest.mark.parametrize(
     "llm_pair",
     [
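The fixture above relies on two helpers imported elsewhere in this test module, temporary_environ and wait_for_gpu_memory_to_clear. As a rough idea of what the first one must do — this is a hypothetical sketch, not vLLM's actual implementation — a context manager that patches os.environ and restores the previous values on exit suffices:

import os
from contextlib import contextmanager

@contextmanager
def temporary_environ(env_vars: dict[str, str]):
    """Set environment variables for the duration of the block,
    then restore whatever was there before (hypothetical sketch)."""
    saved = {key: os.environ.get(key) for key in env_vars}
    os.environ.update(env_vars)
    try:
        yield
    finally:
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old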
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from typing import Optional
+from typing import ClassVar, Optional
 
 import torch
 
@@ -12,11 +12,19 @@ from vllm.attention.backends.abstract import (AttentionType,
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 
+class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    attn_cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
+
+
 class CutlassMLABackend(MLACommonBackend):
 
     @staticmethod
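This ClassVar is the actual opt-in: AttentionCGSupport.PURE_DECODE_ONLY advertises that Cutlass MLA can run inside a full CUDA graph only for pure-decode batches, while prefill keeps falling back to piecewise capture — which is why the test fixture captures decode-sized graphs only. As an illustrative sketch of how a caller might gate on this flag (the dispatch logic below is not vLLM's actual code, and it assumes the enum also defines NEVER and ALWAYS members):

from vllm.v1.attention.backends.utils import AttentionCGSupport

def can_use_full_cudagraph(builder_cls, is_pure_decode: bool) -> bool:
    # Illustrative gate: allow full-graph capture only when the
    # attention metadata builder supports it for this kind of batch.
    support = getattr(builder_cls, "attn_cudagraph_support",
                      AttentionCGSupport.NEVER)
    if support == AttentionCGSupport.ALWAYS:
        return True
    if support == AttentionCGSupport.PURE_DECODE_ONLY:
        return is_pure_decode  # full graph for decode-only batches
    return False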
@@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["CutlassMLAImpl"]:
         return CutlassMLAImpl
 
+    @staticmethod
+    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
+        return CutlassMLAMetadataBuilder
+
 
 class SM100Workspace:
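With get_builder_cls overridden, the wiring is complete: the engine can discover the CUDA-graph-capable metadata builder from the backend class instead of the common default. A quick illustrative check, with the module path as implied by this diff:

from vllm.v1.attention.backends.mla.cutlass_mla import (
    CutlassMLABackend, CutlassMLAMetadataBuilder)

assert CutlassMLABackend.get_builder_cls() is CutlassMLAMetadataBuilder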