[V1] [Hybrid] Enable Full CUDA Graph (decode-only) for Mamba layers (#21401)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell 2025-08-10 05:16:11 +02:00 committed by GitHub
parent 42172ad18f
commit 61f67d8acd
2 changed files with 103 additions and 1 deletion
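The change below lets the V1 engine capture and replay full CUDA graphs for Mamba2 layers, restricted to pure-decode batches, and adds a correctness test that compares logprobs against HF and the V0 path. As an illustrative sketch only (it assumes the standard vllm.LLM offline API, which is not part of this commit), the feature can be switched on with the same settings the test passes to vllm_runner:

import os

from vllm import LLM, SamplingParams

# Assumed offline-API equivalent of the test configuration below.
os.environ["VLLM_USE_V1"] = "1"                      # V1 engine
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # see the reorder_batch note in the test

llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",
    compilation_config={"full_cuda_graph": True},    # enable full CUDA graphs
    enable_prefix_caching=False,
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=8)))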


@@ -384,3 +384,63 @@ def test_distributed_correctness(
        name_0="vllm_tp_1",
        name_1="vllm_tp_2",
    )


@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
with hf_runner(model) as hf_model:
if model not in HF_UNSUPPORTED_MODELS:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
hf_outputs = None
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
compilation_config={'full_cuda_graph': True},
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if hf_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
name_0="hf",
name_1="vllm-v0",
)
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)
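For context on the reorder_batch comment in the test above: the Mamba2 metadata builder (second file below) declares reorder_batch_threshold = 1, which asks the model runner to place single-token decode requests ahead of prefills so the decode portion of the batch forms one contiguous block that a captured graph can replay. The attention backend used for the hybrid model's attention layers has to apply a compatible ordering, hence the FlashInfer backend in the test. A minimal, illustrative sketch of such a decode-first reordering (hypothetical helper, not vLLM's implementation):

# Hypothetical sketch: order request indices so that decode requests
# (query_len <= threshold) come before prefill requests.
def reorder_decodes_first(query_lens: list[int],
                          threshold: int = 1) -> list[int]:
    decodes = [i for i, n in enumerate(query_lens) if n <= threshold]
    prefills = [i for i, n in enumerate(query_lens) if n > threshold]
    return decodes + prefills

# Query lengths [1, 37, 1, 5] -> order [0, 2, 1, 3]: decodes first.
assert reorder_decodes_first([1, 37, 1, 5]) == [0, 2, 1, 3]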


@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
class Mamba2AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba2AttentionMetadata]):

    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.PURE_DECODE_ONLY

    reorder_batch_threshold: ClassVar[int] = 1
@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models")
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_capture_size)
        self.state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )
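The state_indices_tensor buffer above is allocated once at the maximum decode batch size because a captured CUDA graph replays fixed device addresses: per-step inputs must be copied into persistent buffers rather than freshly allocated each step. A minimal, illustrative PyTorch sketch of that capture/replay pattern (generic example, not vLLM code):

import torch

# Persistent ("static") buffer: replay reuses the addresses recorded at
# capture time, so new inputs are written in place with copy_().
static_in = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as PyTorch recommends.
warmup_stream = torch.cuda.Stream()
warmup_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(warmup_stream):
    static_out = static_in * 2
torch.cuda.current_stream().wait_stream(warmup_stream)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = static_in * 2          # work recorded into the graph

static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
graph.replay()                          # static_out now holds 0, 2, 4, ..., 14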

    def build(self,
              common_prefix_len: int,
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
                        query_start_loc_p, self.chunk_size,
                        num_prefill_tokens))

        elif num_decodes <= self.decode_cudagraph_max_bs:
            # Pad state tensor for CUDA graph
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
                                                          non_blocking=True)
            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

        attn_metadata = Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata
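To make the padding step above concrete: pad_for_cudagraph rounds the decode batch up to the nearest captured graph size, and the slots between num_decodes and that padded size are filled with PAD_SLOT_ID so the replayed kernels only touch scratch state. An illustrative sketch with assumed capture sizes and a stand-in for the real helper (the actual sizes come from the compilation config):

# Illustrative only: hypothetical capture sizes and a stand-in for
# VllmConfig.pad_for_cudagraph; padded entries are marked with PAD_SLOT_ID.
CAPTURE_SIZES = [1, 2, 4, 8]
PAD_SLOT_ID = -1

def pad_for_cudagraph(num_decodes: int) -> int:
    # Round up to the smallest captured batch size that fits.
    return next(size for size in CAPTURE_SIZES if size >= num_decodes)

num_decodes = 3
padded = pad_for_cudagraph(num_decodes)                 # -> 4
state_indices = [17, 42, 5] + [PAD_SLOT_ID] * (padded - num_decodes)
# state_indices == [17, 42, 5, -1]: three live slots, one padded slot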

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "Mamba only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seqs."

        m.max_query_len = 1  # decode-only

        return self.build(0, m)

    def can_run_in_cudagraph(
            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
        return common_attn_metadata.max_query_len == 1
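can_run_in_cudagraph gates full-graph replay on max_query_len == 1, i.e. every request in the batch contributes exactly one query token (pure decode); any step containing prefill tokens falls back to the regular execution path. A minimal sketch of that gate with a stand-in metadata type (hypothetical, not vLLM's runner logic):

from dataclasses import dataclass

@dataclass
class FakeBatchMetadata:              # stand-in for CommonAttentionMetadata
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int

def can_use_full_cudagraph(m: FakeBatchMetadata) -> bool:
    # Mirrors can_run_in_cudagraph above: only decode-only batches qualify.
    return m.max_query_len == 1

decode_batch = FakeBatchMetadata(num_reqs=4, num_actual_tokens=4, max_query_len=1)
mixed_batch = FakeBatchMetadata(num_reqs=4, num_actual_tokens=35, max_query_len=32)
assert can_use_full_cudagraph(decode_batch)
assert not can_use_full_cudagraph(mixed_batch)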