Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 07:15:01 +08:00)

[V1] [Hybrid] Enable Full CUDA Graph (decode-only) for Mamba layers (#21401)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>

parent 42172ad18f
commit 61f67d8acd
@@ -384,3 +384,63 @@ def test_distributed_correctness(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         compilation_config={'full_cuda_graph': True},
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
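For context, the test above enables the feature the same way an end user would: by passing compilation_config={'full_cuda_graph': True} to the engine, plus the FLASHINFER attention backend for hybrid models due to reorder_batch behaviour. A minimal offline sketch of that configuration follows; the prompt, sampling settings, and the exact LLM-entrypoint invocation are illustrative assumptions, not part of this commit.

import os

# Mirrors what the test sets via monkeypatch (V1 engine + FLASHINFER for hybrid models).
os.environ.setdefault("VLLM_USE_V1", "1")
os.environ.setdefault("VLLM_ATTENTION_BACKEND", "FLASHINFER")

from vllm import LLM, SamplingParams

# Assumed invocation: same knobs as the vllm_runner call in the test above.
llm = LLM(
    model="Zyphra/Zamba2-1.2B-instruct",
    max_num_seqs=4,  # keep CUDA graph capture sizes <= max_num_seqs
    compilation_config={"full_cuda_graph": True},
    enable_prefix_caching=False,
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)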
@@ -7,8 +7,10 @@ from typing import ClassVar, Optional
 import torch

 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:

 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY

     reorder_batch_threshold: ClassVar[int] = 1

@@ -90,8 +94,18 @@ class Mamba2AttentionMetadataBuilder(
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )

     def build(self,
               common_prefix_len: int,
@@ -144,6 +158,14 @@ class Mamba2AttentionMetadataBuilder(
                     query_start_loc_p, self.chunk_size,
                     num_prefill_tokens))

+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad state tensor for CUDA graph
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
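The elif branch above is what makes decode batches replayable through a captured graph: the builder keeps one persistent state_indices_tensor sized to min(max_num_seqs, max_capture_size), copies the live per-request indices into its head, and fills the padded tail with PAD_SLOT_ID so the kernels ignore those slots. A standalone sketch of the same technique follows; the buffer sizes and sentinel value are assumptions for illustration, not values taken from vLLM.

import torch

PAD_SLOT_ID = -1              # assumed sentinel meaning "no state slot"
max_num_seqs = 8              # assumed scheduler limit
max_capture_size = 16         # assumed largest CUDA graph batch size
decode_cudagraph_max_bs = min(max_num_seqs, max_capture_size)

# Persistent buffer: allocated once so every graph replay reads the same storage.
persistent_state_indices = torch.empty((decode_cudagraph_max_bs, ),
                                       dtype=torch.int32)


def pad_state_indices(live_indices: torch.Tensor,
                      padded_batch_size: int) -> torch.Tensor:
    """Copy live per-request state indices into the persistent buffer and pad
    the tail so a graph captured for padded_batch_size always sees that shape."""
    num_decodes = live_indices.numel()
    assert num_decodes <= padded_batch_size <= decode_cudagraph_max_bs
    persistent_state_indices[:num_decodes].copy_(live_indices)
    padded = persistent_state_indices[:padded_batch_size]
    padded[num_decodes:] = PAD_SLOT_ID  # padded slots are skipped at replay time
    return padded


# Example: 3 live decode requests replayed through a graph captured for batch 4.
print(pad_state_indices(torch.tensor([5, 9, 2], dtype=torch.int32), 4))
# tensor([ 5,  9,  2, -1], dtype=torch.int32)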
@@ -160,3 +182,23 @@ class Mamba2AttentionMetadataBuilder(
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
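build_for_cudagraph_capture only has to cover pure-decode batches (max_query_len == 1), which is what keeps full-graph capture tractable for Mamba layers: every replay sees a fixed batch shape and reads its inputs from persistent buffers. The self-contained sketch below shows that capture/replay pattern with torch.cuda.CUDAGraph; the toy decode_step, weights, and shapes are illustrative assumptions, not vLLM code.

import torch

assert torch.cuda.is_available(), "sketch requires a CUDA device"

capture_bs, hidden = 4, 32                      # assumed capture batch size / width
static_input = torch.zeros(capture_bs, hidden, device="cuda")
static_state = torch.zeros(capture_bs, hidden, device="cuda")
weight = torch.randn(hidden, hidden, device="cuda")


def decode_step(x: torch.Tensor, state: torch.Tensor) -> torch.Tensor:
    # Stand-in for a single-token decode update; the real Mamba kernel differs.
    return torch.tanh(x @ weight + state)


# Warm up on a side stream, then capture one decode step into a graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = decode_step(static_input, static_state)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = decode_step(static_input, static_state)

# Replay: copy the (padded) decode batch into the persistent buffers, then
# launch the whole step with a single replay; padded rows are simply ignored.
real_bs = 3                                     # fewer live requests than capture_bs
static_input[:real_bs].copy_(torch.randn(real_bs, hidden, device="cuda"))
static_input[real_bs:].zero_()
graph.replay()
print(static_out[:real_bs])                     # only the first real_bs rows matter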