diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md
index 54af970ea842d..9bf0c5842c6be 100644
--- a/docs/usage/v1_guide.md
+++ b/docs/usage/v1_guide.md
@@ -107,7 +107,7 @@ to enable simultaneous generation and embedding using the same engine instance i
 
 #### Mamba Models
 
 Models using selective state-space mechanisms instead of standard transformer attention are supported.
-Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.
+Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1.
 
 Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`, `JambaForCausalLM`). Please note that
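The documentation change above maps onto a single engine flag. A minimal offline-inference sketch of the documented V1 requirement, assuming an illustrative Mamba-1 checkpoint and prompt (the test below exercises the same flag through `vllm_runner` and sets `VLLM_USE_V1=1` via `monkeypatch`):

    from vllm import LLM, SamplingParams

    # Sketch only: prefix caching stays disabled for Mamba-family models on V1;
    # enforce_eager=True is no longer needed for Mamba-1 after this change.
    llm = LLM(
        model="state-spaces/mamba-130m-hf",  # illustrative Mamba-1 checkpoint
        enable_prefix_caching=False,
    )
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)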
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index aee0a50336c09..f8c0eaa8cf3a2 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -54,16 +54,14 @@ V1_SUPPORTED_MODELS = [
     "tiiuae/Falcon-H1-0.5B-Base",
 ]
 
+FULL_CUDA_GRAPH_MODELS = [
+    "ai21labs/Jamba-tiny-dev",
+    "Zyphra/Zamba2-1.2B-instruct",
+]
+
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
-# Once we add support for FCG in Mamba1, this list will be removed and tests
-# all test cases will use enforce_eager=False
-ENFORCE_EAGER_MODELS_V1 = [
-    "state-spaces/mamba-130m-hf",
-    "ai21labs/Jamba-tiny-dev",
-]
-
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -101,19 +99,13 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
-        enforce_eager = False
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
                 # required due to reorder_batch behaviour
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
-
-            if model in ENFORCE_EAGER_MODELS_V1:
-                enforce_eager = True
-
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
-                             enforce_eager=enforce_eager,
                              enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
@@ -373,7 +365,7 @@ def test_distributed_correctness(
     )
 
 
-@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("model", FULL_CUDA_GRAPH_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
 def test_full_cuda_graph(
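The new `FULL_CUDA_GRAPH_MODELS` list extends the full-CUDA-graph test to Jamba, a Mamba-1 hybrid. A hedged sketch of enabling the same feature outside the test harness; `full_cuda_graph` is the `CompilationConfig` flag the metadata builders in this diff check, but passing it as a `compilation_config` dict on `LLM` is an assumption about the user-facing wiring, not something this patch adds:

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="ai21labs/Jamba-tiny-dev",               # one of FULL_CUDA_GRAPH_MODELS
        max_num_seqs=4,                                # mirrors MAX_NUM_SEQS in the test
        enable_prefix_caching=False,                   # still required for Mamba layers
        compilation_config={"full_cuda_graph": True},  # assumed spelling of the knob
    )
    out = llm.generate(["Ping"], SamplingParams(max_tokens=8))
    print(out[0].outputs[0].text)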
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index 56a2183f8e2c1..c654485f4fe9c 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -336,6 +336,7 @@ class CompilationConfig:
         "vllm.unified_attention",
         "vllm.unified_attention_with_output",
         "vllm.mamba_mixer2",
+        "vllm.mamba_mixer",
     ]
 
     def compute_hash(self) -> str:
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 3c7322260df43..a24e72778b34b 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -27,6 +27,8 @@
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -183,22 +185,26 @@ class MambaMixer(MambaBase, CustomOp):
 
     def forward(self,
                 hidden_states: torch.Tensor,
+                output: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
-            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
         else:
-            return self.forward_cuda(
+            torch.ops.vllm.mamba_mixer(
                 hidden_states,
-                mamba_cache_params,
+                output,
+                self.prefix,
             )
 
     def forward_native(self,
                        hidden_states: torch.Tensor,
+                       output: torch.Tensor,
                        mamba_cache_params: Optional[MambaCacheParams] = None):
         pass
 
     def forward_cuda(self,
                      hidden_states: torch.Tensor,
+                     output: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
         """
         Run the Mamba-1 SSM pipeline.
@@ -237,6 +243,7 @@ class MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
+            num_padded_decodes = mamba1_metadata.num_padded_decodes
         else:
             assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
@@ -248,6 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
             has_initial_states = None
             if context_lens_tensor is not None:
                 has_initial_states = context_lens_tensor > 0
+            num_padded_decodes = attn_metadata.num_decode_tokens
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -267,6 +275,7 @@ class MambaMixer(MambaBase, CustomOp):
             num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
             has_prefill = num_prefill_tokens > 0
             has_decode = num_decode_tokens > 0
+            num_actual_tokens = num_prefill_tokens + num_decode_tokens
 
             prefill_decode_split = split_batch_to_prefill_and_decode(
                 hidden_states_BC,
@@ -278,6 +287,7 @@ class MambaMixer(MambaBase, CustomOp):
                 num_decode_tokens,
                 num_prefills,
                 num_decodes,
+                num_padded_decodes,
             )
             hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
             hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -371,7 +381,7 @@ class MambaMixer(MambaBase, CustomOp):
         else:
             out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
 
-        return out
+        output[:num_actual_tokens] = out
 
     def get_state_dtype(self) -> tuple[torch.dtype]:
         assert self.model_config is not None
@@ -421,18 +431,27 @@ def split_batch_to_prefill_and_decode(
     num_decode_tokens: int,
     num_prefills: int,
     num_decodes: int,
+    num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
+    num_actual_tokens = num_prefill_tokens + num_padded_decodes
+
     if envs.VLLM_USE_V1:
         # In v1, decode tokens come first, then prefill tokens.
         hidden_states_BC_d, hidden_states_BC_p = torch.split(
-            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
-        gate_d, gate_p = torch.split(gate,
-                                     [num_decode_tokens, num_prefill_tokens],
+            hidden_states_BC[..., :num_actual_tokens],
+            [num_padded_decodes, num_prefill_tokens],
+            dim=-1)
+        gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
+                                     [num_padded_decodes, num_prefill_tokens],
                                      dim=-1)
+
+        # num_padded_decodes accounts for CUDA graph padding when applicable
         state_indices_tensor_d, state_indices_tensor_p = torch.split(
-            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+            state_indices_tensor[:num_padded_decodes + num_prefills],
+            [num_padded_decodes, num_prefills],
+            dim=0)
         query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
-                             num_decodes if num_prefills > 0 else None)
+                             num_padded_decodes if num_prefills > 0 else None)
         has_initial_states_p = has_initial_states[-num_prefills:] if (
             has_initial_states is not None and num_prefills > 0) else None
     else:
@@ -459,3 +478,32 @@ def split_batch_to_prefill_and_decode(
         query_start_loc_p=query_start_loc_p,
         has_initial_states_p=has_initial_states_p,
     )
+
+
+def mamba_mixer(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None)
+
+
+def mamba_mixer_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer",
+    op_func=mamba_mixer,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
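`MambaMixer.forward` now dispatches to `torch.ops.vllm.mamba_mixer`, an op that fills a caller-allocated `output` buffer instead of returning a tensor, so the compiled graph sees one opaque call that only mutates its argument. A standalone sketch of that pattern with plain `torch.library` (hypothetical `demo` namespace; vLLM's `direct_register_custom_op` helper wraps the same idea):

    import torch

    @torch.library.custom_op("demo::scale_into", mutates_args=("out",))
    def scale_into(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
        # Real implementation: write the result into the preallocated buffer.
        out.copy_(x * factor)

    @scale_into.register_fake
    def _(x: torch.Tensor, out: torch.Tensor, factor: float) -> None:
        # Fake (meta) implementation: nothing to allocate or return, only `out` mutates.
        return None

    x = torch.randn(4)
    out = torch.empty_like(x)
    torch.ops.demo.scale_into(x, out, 2.0)
    assert torch.allclose(out, 2 * x)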
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 0b32d6f256590..3c1a0b68df56e 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -10,6 +10,7 @@ from transformers import JambaConfig
 
 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -154,10 +155,10 @@ class JambaMambaDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)
 
-        hidden_states = self.mamba(hidden_states, mamba_cache_params)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params)
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
 
@@ -278,6 +279,7 @@ ALL_DECODER_LAYER_TYPES = {
 }
 
 
+@support_torch_compile
 class JambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index f4aaf0c6f467c..f02499a4f96b5 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -9,6 +9,7 @@ from torch import nn
 from transformers import MambaConfig
 
 from vllm import envs
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -81,10 +82,12 @@ class MambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)
 
-        hidden_states = self.mixer(hidden_states, mamba_cache_params)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params)
+        return output, residual
 
 
+@support_torch_compile
 class MambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py
index 6cdc509083ae9..97a1aa86dda0d 100644
--- a/vllm/v1/attention/backends/mamba1_attn.py
+++ b/vllm/v1/attention/backends/mamba1_attn.py
@@ -2,16 +2,16 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
+from vllm.attention.backends.utils import PAD_SLOT_ID
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
 class Mamba1AttentionBackend(AttentionBackend):
@@ -31,24 +31,11 @@ class Mamba1AttentionMetadata:
     num_prefill_tokens: int
     num_decodes: int
     num_decode_tokens: int
+    num_padded_decodes: int
 
 
 class Mamba1AttentionMetadataBuilder(
-        AttentionMetadataBuilder[Mamba1AttentionMetadata]):
-    reorder_batch_threshold: ClassVar[int] = 1
-
-    def __init__(
-        self,
-        kv_cache_spec: AttentionSpec,
-        vllm_config: VllmConfig,
-        device: torch.device,
-        layer_names: list[str],
-    ):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
-        self.device = device
-        self.vllm_config = vllm_config
-        self.layer_names = layer_names
+        BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]):
 
     def build(
         self,
@@ -67,9 +54,18 @@ class Mamba1AttentionMetadataBuilder(
                 decode_threshold=1))
 
        has_initial_states = None
+        padded_decodes = num_decodes
        if num_prefills > 0:
            has_initial_states = context_lens_tensor > 0
+        elif (num_decodes > 0 and num_decodes <= self.decode_cudagraph_max_bs
+              and self.compilation_config.full_cuda_graph):
+            state_indices_for_decode = state_indices_tensor[:num_decodes]
+            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(
+                state_indices_for_decode, non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
 
        return Mamba1AttentionMetadata(
            query_start_loc=query_start_loc,
@@ -80,4 +76,5 @@ class Mamba1AttentionMetadataBuilder(
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
+            num_padded_decodes=padded_decodes,
         )
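The new `elif` branch pads a decode-only batch up to the nearest captured CUDA graph size and parks the extra slots on a sentinel index; `num_padded_decodes` then carries that padded count into `split_batch_to_prefill_and_decode`. A toy illustration of the arithmetic (capture sizes, slot values, and the -1 sentinel are made up for the example):

    import torch

    PAD_SLOT_ID = -1              # sentinel for padded entries (illustrative value)
    capture_sizes = [1, 2, 4, 8]  # hypothetical cudagraph capture sizes
    num_decodes = 3

    # Stand-in for vllm_config.pad_for_cudagraph(): round up to a captured size.
    padded_decodes = min(s for s in capture_sizes if s >= num_decodes)  # -> 4

    state_indices = torch.tensor([7, 2, 5], dtype=torch.int32)  # real cache slots
    padded = torch.full((padded_decodes, ), PAD_SLOT_ID, dtype=torch.int32)
    padded[:num_decodes] = state_indices
    print(padded)  # tensor([ 7,  2,  5, -1], dtype=torch.int32)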
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index ace078e2b27c6..ed30884fdbc94 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -2,18 +2,18 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import ClassVar, Optional
+from typing import Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionCGSupport,
-                                              AttentionMetadataBuilder,
-                                              CommonAttentionMetadata,
+from vllm.v1.attention.backends.mamba_attn import (
+    BaseMambaAttentionMetadataBuilder)
+from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               split_decodes_and_prefills)
-from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+from vllm.v1.kv_cache_interface import AttentionSpec
 
 
 def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
@@ -88,29 +88,14 @@ class Mamba2AttentionMetadata:
 
 
 class Mamba2AttentionMetadataBuilder(
-        AttentionMetadataBuilder[Mamba2AttentionMetadata]):
-    cudagraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
-
-    reorder_batch_threshold: ClassVar[int] = 1
+        BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]):
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
-        assert isinstance(kv_cache_spec, MambaSpec)
-        self.kv_cache_spec = kv_cache_spec
+        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
-        self.vllm_config = vllm_config
-        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
-        self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
-        self.state_indices_tensor = torch.empty(
-            (self.decode_cudagraph_max_bs, ),
-            dtype=torch.int32,
-            device=device,
-        )
 
     def build(self,
               common_prefix_len: int,
@@ -187,19 +172,3 @@ class Mamba2AttentionMetadataBuilder(
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
-
-    def build_for_cudagraph_capture(
-            self, common_attn_metadata: CommonAttentionMetadata):
-        """
-        This method builds the metadata for full cudagraph capture.
-        Currently, only decode is supported for full cudagraphs with Mamba.
-        """
-        m = common_attn_metadata
-
-        assert m.num_reqs == m.num_actual_tokens, \
-            "Mamba only supports decode-only full CUDAGraph capture. " \
-            "Make sure all cudagraph capture sizes <= max_num_seq."
-
-        m.max_query_len = 1  # decode-only
-
-        return self.build(0, m)
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
new file mode 100644
index 0000000000000..07ef7cb69a160
--- /dev/null
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -0,0 +1,55 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import abc
+from typing import ClassVar, TypeVar
+
+import torch
+
+from vllm.config import VllmConfig
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
+                                              CommonAttentionMetadata)
+from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
+
+M = TypeVar("M")
+
+
+class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
+    reorder_batch_threshold: ClassVar[int] = 1
+    cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
+
+    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
+                 vllm_config: VllmConfig, device: torch.device):
+        assert isinstance(kv_cache_spec, MambaSpec)
+        self.kv_cache_spec = kv_cache_spec
+        self.device = device
+        self.vllm_config = vllm_config
+        self.layer_names = layer_names
+
+        self.compilation_config = vllm_config.compilation_config
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata) -> M:
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
\ No newline at end of file
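For reference, the decode-only invariant asserted in `build_for_cudagraph_capture` just says that every request in the captured batch schedules exactly one token. A tiny self-contained check with made-up numbers:

    import torch

    query_start_loc = torch.tensor([0, 1, 2, 3, 4])  # 4 requests, one token each
    num_reqs = query_start_loc.numel() - 1
    num_actual_tokens = int(query_start_loc[-1])
    max_query_len = int((query_start_loc[1:] - query_start_loc[:-1]).max())

    assert num_reqs == num_actual_tokens == 4
    assert max_query_len == 1  # the decode-only shape that graph capture assumes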