mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-16 19:09:16 +08:00
[Feature] Full Cuda Graph Support for Cutlass MLA and 6% E2E Throughput Improvement (#22763)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
b4cef5e6c7
commit
5c3fbfe46b
@ -66,6 +66,80 @@ def llm_pair(request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="class")
|
||||||
|
def cutlass_mla_llm_pair(request):
|
||||||
|
model = request.param
|
||||||
|
|
||||||
|
# force V1 engine and Cutlass MLA backend
|
||||||
|
with temporary_environ({
|
||||||
|
"VLLM_USE_V1": "1",
|
||||||
|
"VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
|
||||||
|
"FORCE_NUM_KV_SPLITS":
|
||||||
|
"1", # TODO: remove this when hang issue is fixed
|
||||||
|
}):
|
||||||
|
full = LLM(
|
||||||
|
model=model,
|
||||||
|
gpu_memory_utilization=0.45,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=1024,
|
||||||
|
compilation_config=CompilationConfig(
|
||||||
|
full_cuda_graph=True,
|
||||||
|
cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
piecewise = LLM(
|
||||||
|
model=model,
|
||||||
|
gpu_memory_utilization=0.45,
|
||||||
|
trust_remote_code=True,
|
||||||
|
max_model_len=1024,
|
||||||
|
compilation_config=CompilationConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
yield weakref.proxy(full), weakref.proxy(piecewise)
|
||||||
|
del full
|
||||||
|
del piecewise
|
||||||
|
|
||||||
|
wait_for_gpu_memory_to_clear(
|
||||||
|
devices=[0],
|
||||||
|
threshold_ratio=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"cutlass_mla_llm_pair",
|
||||||
|
[
|
||||||
|
# use an MLA model
|
||||||
|
"deepseek-ai/DeepSeek-V2-Lite",
|
||||||
|
],
|
||||||
|
indirect=True)
|
||||||
|
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
|
||||||
|
reason="Only Blackwell GPUs support Cutlass MLA")
|
||||||
|
class TestFullCUDAGraphCutlassMLA:
|
||||||
|
"""
|
||||||
|
Validate full CUDA Graph with Cutlass MLA (decode-only capture).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(("batch_size", "max_tokens"), [
|
||||||
|
(8, 8),
|
||||||
|
])
|
||||||
|
def test_full_cudagraph_sm100_cutlass_mla(
|
||||||
|
self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
|
||||||
|
LLM]):
|
||||||
|
piecewise_llm, full_cudagraph_llm = cutlass_mla_llm_pair
|
||||||
|
|
||||||
|
prompts = ["Hello, my name is"] * batch_size
|
||||||
|
sampling_params = SamplingParams(temperature=0.0,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=0.95)
|
||||||
|
|
||||||
|
piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
|
||||||
|
full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
for piecewise_res, full_res in zip(piecewise_responses,
|
||||||
|
full_responses):
|
||||||
|
assert piecewise_res.outputs[0].text == full_res.outputs[0].text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"llm_pair",
|
"llm_pair",
|
||||||
[
|
[
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Optional
|
from typing import ClassVar, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -12,11 +12,19 @@ from vllm.attention.backends.abstract import (AttentionType,
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||||
MLACommonImpl,
|
MLACommonImpl,
|
||||||
MLACommonMetadata)
|
MLACommonMetadata,
|
||||||
|
MLACommonMetadataBuilder)
|
||||||
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
|
||||||
|
# enable full CUDA Graph support for decode-only capture
|
||||||
|
attn_cudagraph_support: ClassVar[
|
||||||
|
AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
|
||||||
|
|
||||||
|
|
||||||
class CutlassMLABackend(MLACommonBackend):
|
class CutlassMLABackend(MLACommonBackend):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend):
|
|||||||
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
def get_impl_cls() -> type["CutlassMLAImpl"]:
|
||||||
return CutlassMLAImpl
|
return CutlassMLAImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
|
||||||
|
return CutlassMLAMetadataBuilder
|
||||||
|
|
||||||
|
|
||||||
class SM100Workspace:
|
class SM100Workspace:
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user