[V1] [Hybrid] Enable Full CUDA Graph (decode-only) for Mamba layers (#21401)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 42172ad18f
commit 61f67d8acd
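For context, here is a minimal sketch of how a user could exercise this feature offline, mirroring the settings the new test below uses (the model name is the one the test parametrizes; the prompt and sampling settings are illustrative, and it is assumed that the installed vLLM build exposes compilation_config and enable_prefix_caching on LLM):

import os

# Set engine-selection env vars before constructing the engine, as the test does.
os.environ["VLLM_USE_V1"] = "1"
# Hybrid models need an attention backend whose batch reordering matches the
# Mamba layers (see the reorder_batch comment in the test).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",
    compilation_config={"full_cuda_graph": True},
    enable_prefix_caching=False,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)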
@@ -384,3 +384,63 @@ def test_distributed_correctness(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         compilation_config={'full_cuda_graph': True},
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
 import torch

 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:

 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     reorder_batch_threshold: ClassVar[int] = 1

@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )

     def build(self,
               common_prefix_len: int,
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))

+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad state tensor for CUDA graph
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
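To illustrate the pattern implemented above in isolation: the builder keeps a persistent state_indices_tensor sized for the largest capture batch, copies each step's real indices into its head, and pads the tail with PAD_SLOT_ID so a graph captured once can be replayed for any smaller decode batch. Below is a standalone, simplified sketch of that idea in plain PyTorch; the buffer names, toy shapes, and gather step are invented for illustration and are not vLLM code.

import torch

PAD_SLOT_ID = -1   # sentinel for padded slots (vLLM's PAD_SLOT_ID is also -1)
MAX_BS = 8         # assumed maximum capture batch size (decode_cudagraph_max_bs)

# Persistent buffers with fixed shapes and addresses, as graph replay requires.
state_indices = torch.full((MAX_BS, ), PAD_SLOT_ID,
                           dtype=torch.int32, device="cuda")
states = torch.randn(16, 4, device="cuda")       # toy per-slot state table
out = torch.empty(MAX_BS, 4, device="cuda")      # fixed output buffer

def decode_step():
    # Padded slots read slot 0; the caller simply ignores their results.
    idx = state_indices.clamp(min=0).long()
    out.copy_(states[idx])

# Warm up once, then capture the graph at the maximum batch size.
decode_step()
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    decode_step()

# Replay for a smaller decode batch: copy real indices in, pad the tail.
real_indices = torch.tensor([3, 7, 1], dtype=torch.int32, device="cuda")
num_decodes = real_indices.numel()
state_indices[:num_decodes].copy_(real_indices, non_blocking=True)
state_indices[num_decodes:] = PAD_SLOT_ID
graph.replay()
print(out[:num_decodes])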
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
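The two hooks above split the responsibility: build_for_cudagraph_capture asserts the capture batch is pure decode (one token per request) before building metadata, and can_run_in_cudagraph tells the caller whether the current batch qualifies for graph replay. A hypothetical helper restating that gating condition, purely for illustration and not part of this commit:

def should_replay_full_graph(num_reqs: int, num_actual_tokens: int,
                             max_query_len: int, max_capture_bs: int) -> bool:
    # Pure decode means every scheduled request contributes exactly one token,
    # and the batch must fit within the largest captured size.
    pure_decode = max_query_len == 1 and num_reqs == num_actual_tokens
    return pure_decode and num_reqs <= max_capture_bs

assert should_replay_full_graph(4, 4, 1, 8)        # pure decode batch of 4
assert not should_replay_full_graph(4, 37, 12, 8)  # batch still contains prefills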