Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-09 11:54:54 +08:00)

[V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035)

Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>

parent 2461d9e562
commit 3663870c72
@@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i
#### Mamba Models

Models using selective state-space mechanisms instead of standard transformer attention are supported.

Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1.

Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
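For reference, a minimal sketch of running a Mamba-1 model under these constraints (the model name and sampling settings are placeholders; `enable_prefix_caching=False` mirrors the requirement above, and `enforce_eager` is no longer needed after this change):

from vllm import LLM, SamplingParams

# Sketch: Mamba-1 on V1 with prefix caching disabled (required), and without
# enforce_eager, since piecewise/full CUDA graphs are now supported.
llm = LLM(
    model="state-spaces/mamba-130m-hf",
    enable_prefix_caching=False,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32, temperature=0.0))
print(outputs[0].outputs[0].text)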
@@ -54,16 +54,14 @@ V1_SUPPORTED_MODELS = [
    "tiiuae/Falcon-H1-0.5B-Base",
]

FULL_CUDA_GRAPH_MODELS = [
    "ai21labs/Jamba-tiny-dev",
    "Zyphra/Zamba2-1.2B-instruct",
]

# Avoid OOM
MAX_NUM_SEQS = 4

# Once we add support for FCG in Mamba1, this list will be removed and
# all test cases will use enforce_eager=False
ENFORCE_EAGER_MODELS_V1 = [
    "state-spaces/mamba-130m-hf",
    "ai21labs/Jamba-tiny-dev",
]


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@@ -101,19 +99,13 @@ def test_models(
            example_prompts, max_tokens, num_logprobs)

    if model in V1_SUPPORTED_MODELS:
        enforce_eager = False
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
        if model in HYBRID_MODELS:
            # required due to reorder_batch behaviour
            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")

        if model in ENFORCE_EAGER_MODELS_V1:
            enforce_eager = True

        with vllm_runner(model,
                         max_num_seqs=MAX_NUM_SEQS,
                         enforce_eager=enforce_eager,
                         enable_prefix_caching=False) as vllm_model:
            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)
@@ -373,7 +365,7 @@ def test_distributed_correctness(
    )


@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_full_cuda_graph(
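Outside the test suite, full CUDA graph capture is requested through the compilation config; a rough sketch follows (passing `compilation_config` as a dict and the `full_cuda_graph` key are assumptions based on the flag checked by the metadata builders later in this diff):

from vllm import LLM

# Sketch: ask for full CUDA graph capture on a hybrid Mamba-1 model.
llm = LLM(
    model="ai21labs/Jamba-tiny-dev",
    enable_prefix_caching=False,
    compilation_config={"full_cuda_graph": True},
)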
@@ -336,6 +336,7 @@ class CompilationConfig:
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
    ]

    def compute_hash(self) -> str:
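The new "vllm.mamba_mixer" entry makes the piecewise CUDA graph backend split the compiled graph around the Mamba-1 custom op, just as it already does for attention and "vllm.mamba_mixer2". A sketch of the equivalent explicit configuration (constructing `CompilationConfig` directly is an assumption; normally the default list above is used):

from vllm.config import CompilationConfig

# Sketch: ops listed here stay outside the piecewise-compiled/captured
# regions, so stateful kernels such as the Mamba-1 mixer run eagerly
# between CUDA-graph-captured segments.
compilation_config = CompilationConfig(
    splitting_ops=[
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.mamba_mixer2",
        "vllm.mamba_mixer",
    ])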
@@ -27,6 +27,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
    selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata

@@ -183,22 +185,26 @@ class MambaMixer(MambaBase, CustomOp):

    def forward(self,
                hidden_states: torch.Tensor,
                output: torch.Tensor,
                mamba_cache_params: Optional[MambaCacheParams] = None):
        if not envs.VLLM_USE_V1:
            return CustomOp.forward(self, hidden_states, mamba_cache_params)
            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
        else:
            return self.forward_cuda(
            torch.ops.vllm.mamba_mixer(
                hidden_states,
                mamba_cache_params,
                output,
                self.prefix,
            )

    def forward_native(self,
                       hidden_states: torch.Tensor,
                       output: torch.Tensor,
                       mamba_cache_params: Optional[MambaCacheParams] = None):
        pass

    def forward_cuda(self,
                     hidden_states: torch.Tensor,
                     output: torch.Tensor,
                     mamba_cache_params: Optional[MambaCacheParams] = None):
        """
        Run the Mamba-1 SSM pipeline.
@@ -237,6 +243,7 @@ class MambaMixer(MambaBase, CustomOp):
            conv_state = self_kv_cache[0].transpose(-1, -2)
            ssm_state = self_kv_cache[1]
            has_initial_states = mamba1_metadata.has_initial_states
            num_padded_decodes = mamba1_metadata.num_padded_decodes
        else:
            assert isinstance(attn_metadata, AttentionMetadata)
            assert mamba_cache_params is not None
@@ -248,6 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
            has_initial_states = None
            if context_lens_tensor is not None:
                has_initial_states = context_lens_tensor > 0
            num_padded_decodes = attn_metadata.num_decode_tokens

        # 1. Gated MLP's linear projection
        projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -267,6 +275,7 @@ class MambaMixer(MambaBase, CustomOp):
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        has_prefill = num_prefill_tokens > 0
        has_decode = num_decode_tokens > 0
        num_actual_tokens = num_prefill_tokens + num_decode_tokens

        prefill_decode_split = split_batch_to_prefill_and_decode(
            hidden_states_BC,
@@ -278,6 +287,7 @@ class MambaMixer(MambaBase, CustomOp):
            num_decode_tokens,
            num_prefills,
            num_decodes,
            num_padded_decodes,
        )
        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -371,7 +381,7 @@ class MambaMixer(MambaBase, CustomOp):
        else:
            out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]

        return out
        output[:num_actual_tokens] = out

    def get_state_dtype(self) -> tuple[torch.dtype]:
        assert self.model_config is not None
@@ -421,18 +431,27 @@ def split_batch_to_prefill_and_decode(
    num_decode_tokens: int,
    num_prefills: int,
    num_decodes: int,
    num_padded_decodes: int,
) -> PrefillDecodeSplit:
    num_actual_tokens = num_prefill_tokens + num_padded_decodes

    if envs.VLLM_USE_V1:
        # In v1, decode tokens come first, then prefill tokens.
        hidden_states_BC_d, hidden_states_BC_p = torch.split(
            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
        gate_d, gate_p = torch.split(gate,
                                     [num_decode_tokens, num_prefill_tokens],
            hidden_states_BC[..., :num_actual_tokens],
            [num_padded_decodes, num_prefill_tokens],
            dim=-1)
        gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
                                     [num_padded_decodes, num_prefill_tokens],
                                     dim=-1)

        # num_padded_decodes accounts for CUDA graph padding when applicable
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor, [num_decodes, num_prefills], dim=0)
            state_indices_tensor[:num_padded_decodes + num_prefills],
            [num_padded_decodes, num_prefills],
            dim=0)
        query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
                             num_decodes if num_prefills > 0 else None)
                             num_padded_decodes if num_prefills > 0 else None)
        has_initial_states_p = has_initial_states[-num_prefills:] if (
            has_initial_states is not None and num_prefills > 0) else None
    else:
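The V1 layout this split relies on (decode tokens first, then prefill tokens, with decodes padded up to the captured CUDA graph batch size) can be illustrated with plain tensors; all sizes below are made up:

import torch

# Toy batch: 4 padded decode columns (e.g. 3 real decodes + 1 CUDA graph pad)
# followed by 6 prefill columns, plus trailing padding that is sliced away.
hidden_size, num_padded_decodes, num_prefill_tokens = 8, 4, 6
num_actual_tokens = num_padded_decodes + num_prefill_tokens
hidden_states_BC = torch.randn(hidden_size, num_actual_tokens + 2)

decode_part, prefill_part = torch.split(
    hidden_states_BC[..., :num_actual_tokens],
    [num_padded_decodes, num_prefill_tokens],
    dim=-1)

assert decode_part.shape == (hidden_size, num_padded_decodes)
assert prefill_part.shape == (hidden_size, num_prefill_tokens)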
@@ -459,3 +478,32 @@ def split_batch_to_prefill_and_decode(
        query_start_loc_p=query_start_loc_p,
        has_initial_states_p=has_initial_states_p,
    )


def mamba_mixer(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states,
                      output=output,
                      mamba_cache_params=None)


def mamba_mixer_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="mamba_mixer",
    op_func=mamba_mixer,
    mutates_args=["output"],
    fake_impl=mamba_mixer_fake,
    dispatch_key=current_platform.dispatch_key,
)
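The registration above follows vLLM's custom-op pattern: a real implementation that mutates a pre-allocated output buffer, plus a no-op fake implementation for tracing. A stripped-down sketch of the same pattern with a toy op (the op itself is hypothetical; only the registration call mirrors the code above):

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def toy_scale(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Real kernel: write into the caller-provided buffer so the op can be
    # declared as mutating `out`, which is what lets CUDA graphs reuse it.
    out.copy_(x * scale)


def toy_scale_fake(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    # Fake (meta) implementation used during torch.compile tracing: no work.
    return


direct_register_custom_op(
    op_name="toy_scale",
    op_func=toy_scale,
    mutates_args=["out"],
    fake_impl=toy_scale_fake,
    dispatch_key=current_platform.dispatch_key,
)

# Assuming registration lands in the same `vllm` namespace used by
# torch.ops.vllm.mamba_mixer above:
#   x = torch.randn(4, device="cuda"); out = torch.empty_like(x)
#   torch.ops.vllm.toy_scale(x, out, 2.0)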
@@ -10,6 +10,7 @@ from transformers import JambaConfig

from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@@ -154,10 +155,10 @@ class JambaMambaDecoderLayer(nn.Module):
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)

        hidden_states = self.mamba(hidden_states, mamba_cache_params)
        output = torch.empty_like(hidden_states)
        self.mamba(hidden_states, output, mamba_cache_params)
        # Fully Connected
        hidden_states, residual = self.pre_ff_layernorm(
            hidden_states, residual)
        hidden_states, residual = self.pre_ff_layernorm(output, residual)
        hidden_states = self.feed_forward(hidden_states)
        return hidden_states, residual
@@ -278,6 +279,7 @@ ALL_DECODER_LAYER_TYPES = {
}


@support_torch_compile
class JambaModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -9,6 +9,7 @@ from torch import nn
from transformers import MambaConfig

from vllm import envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -81,10 +82,12 @@ class MambaDecoderLayer(nn.Module):
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        hidden_states = self.mixer(hidden_states, mamba_cache_params)
        return hidden_states, residual
        output = torch.empty_like(hidden_states)
        self.mixer(hidden_states, output, mamba_cache_params)
        return output, residual


@support_torch_compile
class MambaModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -2,16 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.attention.backends.mamba_attn import (
    BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec


class Mamba1AttentionBackend(AttentionBackend):
@@ -31,24 +31,11 @@ class Mamba1AttentionMetadata:
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_padded_decodes: int


class Mamba1AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba1AttentionMetadata]):
    reorder_batch_threshold: ClassVar[int] = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        vllm_config: VllmConfig,
        device: torch.device,
        layer_names: list[str],
    ):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names
        BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):

    def build(
        self,
@@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder(
                decode_threshold=1))

        has_initial_states = None
        padded_decodes = num_decodes

        if num_prefills > 0:
            has_initial_states = context_lens_tensor > 0
        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
              and self.compilation_config.full_cuda_graph):
            state_indices_for_decode = state_indices_tensor[:num_decodes]
            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(
                state_indices_for_decode, non_blocking=True)
            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
@@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder(
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_padded_decodes=padded_decodes,
        )
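The padding step can be pictured with a toy example: real decode requests are copied into a persistent buffer sized for the captured graph, and the padded tail is pointed at a pad slot (the value -1 here is only a stand-in for `PAD_SLOT_ID`):

import torch

PAD_SLOT_ID = -1  # stand-in for vllm.attention.backends.utils.PAD_SLOT_ID

# Toy: 3 real decode requests padded to a captured CUDA graph size of 4.
num_decodes, padded_decodes = 3, 4
persistent_buf = torch.empty(8, dtype=torch.int32)  # decode_cudagraph_max_bs
real_state_indices = torch.tensor([5, 2, 7], dtype=torch.int32)

persistent_buf[:num_decodes].copy_(real_state_indices)
state_indices_tensor = persistent_buf[:padded_decodes]
state_indices_tensor[num_decodes:] = PAD_SLOT_ID  # padded slots map nowhere

print(state_indices_tensor.tolist())  # [5, 2, 7, -1]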
@@ -2,18 +2,18 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
from vllm.v1.attention.backends.mamba_attn import (
    BaseMambaAttentionMetadataBuilder)
from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
from vllm.v1.kv_cache_interface import AttentionSpec


def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
@@ -88,29 +88,14 @@ class Mamba2AttentionMetadata:


class Mamba2AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba2AttentionMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    reorder_batch_threshold: ClassVar[int] = 1
        BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]):

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models")
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_capture_size)
        self.state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build(self,
              common_prefix_len: int,
@@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder(
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata):
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "Mamba only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        m.max_query_len = 1  # decode-only

        return self.build(0, m)
vllm/v1/attention/backends/mamba_attn.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import abc
from typing import ClassVar, TypeVar

import torch

from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
                                              AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

M = TypeVar("M")


class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    reorder_batch_threshold: ClassVar[int] = 1
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        assert isinstance(kv_cache_spec, MambaSpec)
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        self.vllm_config = vllm_config
        self.layer_names = layer_names

        self.compilation_config = vllm_config.compilation_config
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_capture_size)
        self.state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "Mamba only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        m.max_query_len = 1  # decode-only

        return self.build(0, m)