diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py
index efe9c843f144c..cc1a95b820a46 100644
--- a/tests/compile/piecewise/test_full_cudagraph.py
+++ b/tests/compile/piecewise/test_full_cudagraph.py
@@ -66,6 +66,80 @@ def llm_pair(request):
     )
 
 
+@pytest.fixture(scope="class")
+def cutlass_mla_llm_pair(request):
+    model = request.param
+
+    # force V1 engine and Cutlass MLA backend
+    with temporary_environ({
+            "VLLM_USE_V1": "1",
+            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
+            "FORCE_NUM_KV_SPLITS":
+            "1",  # TODO: remove this when hang issue is fixed
+    }):
+        full = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(
+                full_cuda_graph=True,
+                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
+            ),
+        )
+        piecewise = LLM(
+            model=model,
+            gpu_memory_utilization=0.45,
+            trust_remote_code=True,
+            max_model_len=1024,
+            compilation_config=CompilationConfig(),
+        )
+
+    yield weakref.proxy(full), weakref.proxy(piecewise)
+    del full
+    del piecewise
+
+    wait_for_gpu_memory_to_clear(
+        devices=[0],
+        threshold_ratio=0.1,
+    )
+
+
+@pytest.mark.parametrize(
+    "cutlass_mla_llm_pair",
+    [
+        # use an MLA model
+        "deepseek-ai/DeepSeek-V2-Lite",
+    ],
+    indirect=True)
+@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
+                    reason="Only Blackwell GPUs support Cutlass MLA")
+class TestFullCUDAGraphCutlassMLA:
+    """
+    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
+    """
+
+    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
+        (8, 8),
+    ])
+    def test_full_cudagraph_sm100_cutlass_mla(
+            self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
+                                                                      LLM]):
+        full_cudagraph_llm, piecewise_llm = cutlass_mla_llm_pair
+
+        prompts = ["Hello, my name is"] * batch_size
+        sampling_params = SamplingParams(temperature=0.0,
+                                         max_tokens=max_tokens,
+                                         top_p=0.95)
+
+        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
+        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)
+
+        for piecewise_res, full_res in zip(piecewise_responses,
+                                           full_responses):
+            assert piecewise_res.outputs[0].text == full_res.outputs[0].text
+
+
 @pytest.mark.parametrize(
     "llm_pair",
     [
diff --git a/vllm/v1/attention/backends/mla/cutlass_mla.py b/vllm/v1/attention/backends/mla/cutlass_mla.py
index b23a8f0a5e870..b076613c8645a 100644
--- a/vllm/v1/attention/backends/mla/cutlass_mla.py
+++ b/vllm/v1/attention/backends/mla/cutlass_mla.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import os
-from typing import Optional
+from typing import ClassVar, Optional
 
 import torch
 
@@ -12,11 +12,19 @@ from vllm.attention.backends.abstract import (AttentionType,
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
-                                                   MLACommonMetadata)
+                                                   MLACommonMetadata,
+                                                   MLACommonMetadataBuilder)
+from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 logger = init_logger(__name__)
 
 
+class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
+    # enable full CUDA Graph support for decode-only capture
+    attn_cudagraph_support: ClassVar[
+        AttentionCGSupport] = AttentionCGSupport.PURE_DECODE_ONLY
+
+
 class CutlassMLABackend(MLACommonBackend):
 
     @staticmethod
@@ -27,6 +35,10 @@ class CutlassMLABackend(MLACommonBackend):
     def get_impl_cls() -> type["CutlassMLAImpl"]:
         return CutlassMLAImpl
 
+    @staticmethod
+    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
+        return CutlassMLAMetadataBuilder
+
 
 class SM100Workspace: