From fe91ce9591760e3b58fc24845fa80364fbcdd07f Mon Sep 17 00:00:00 2001
From: amirai21 <89905406+amirai21@users.noreply.github.com>
Date: Fri, 15 Aug 2025 11:59:52 +0300
Subject: [PATCH] [V1] - Split Prefill and Decode for Mamba1 models (#22653)

Signed-off-by: amirk
Signed-off-by: asafg
Co-authored-by: asafg
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
---
 .../models/language/generation/test_hybrid.py |  13 +
 .../layers/mamba/mamba_mixer.py               | 305 +++++++++++++-----
 vllm/v1/attention/backends/mamba1_attn.py     |  26 +-
 3 files changed, 251 insertions(+), 93 deletions(-)
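
Reviewer note: the core of this change is slicing each mixed batch into decode
and prefill segments before the conv1d/SSM kernels run, then concatenating the
two outputs ahead of the final projection. A minimal, self-contained sketch of
the V1 (decode-first) slicing follows; the sizes and names are illustrative
only and are not part of the patch.

    import torch

    # Toy mixed batch: 2 decode requests (one token each) followed by a
    # 5-token prefill request, laid out decode-first as V1 schedules it.
    num_decode_tokens, num_prefill_tokens = 2, 5
    hidden_states_BC = torch.randn(8, num_decode_tokens + num_prefill_tokens)

    # Decode tokens are sliced off the front, prefill tokens off the back.
    hs_d, hs_p = torch.split(
        hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
    assert hs_d.shape == (8, 2) and hs_p.shape == (8, 5)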
diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 19fcbf561640..e75677347f03 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
+# Once we add support for FCG (full CUDA graph) in Mamba1, this list will be
+# removed and all test cases will use enforce_eager=False
+ENFORCE_EAGER_MODELS_V1 = [
+    "state-spaces/mamba-130m-hf",
+    "ai21labs/Jamba-tiny-dev",
+]
+
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -94,13 +101,19 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
+        enforce_eager = False
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
                 # required due to reorder_batch behaviour
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+
+            if model in ENFORCE_EAGER_MODELS_V1:
+                enforce_eager = True
+
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=enforce_eager,
                              enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py
index 17b7f84a933f..3b17fb0ca8c7 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import NamedTuple, Optional
 
 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -154,13 +155,38 @@ class MambaMixer(MambaBase, CustomOp):
 
         self.prefix = prefix
 
+    def _ssm_transform(
+            self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.is_lora_enabled:
+            # LoRA kernel requires a contiguous tensor.
+            ssm_params = self.x_proj(x.contiguous())[0]
+        else:
+            ssm_params = self.x_proj(x)[0]
+        time_step, B, C = torch.split(
+            ssm_params,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1)
+        if self.use_rms_norm:
+            assert self.dt_layernorm is not None
+            assert self.b_layernorm is not None
+            assert self.c_layernorm is not None
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
+        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
+        return discrete_time_step, B, C
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
             return CustomOp.forward(self, hidden_states, mamba_cache_params)
         else:
-            return self.forward_cuda(hidden_states, mamba_cache_params)
+            return self.forward_cuda(
+                hidden_states,
+                mamba_cache_params,
+            )
 
     def forward_native(self,
                        hidden_states: torch.Tensor,
@@ -170,6 +196,27 @@ class MambaMixer(MambaBase, CustomOp):
     def forward_cuda(self,
                      hidden_states: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
+        """
+        Run the Mamba-1 SSM pipeline.
+
+        Steps
+        -----
+        1. Apply the gated-MLP linear projection to the raw input.
+        2. Pass the projected sequence through the convolutional mixing layer.
+        3. Feed the result into the State-Space Model (SSM) blocks.
+        4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+           to produce contextual representations.
+        5. Project the contextualized sequence back
+           to the output embedding dimension.
+
+        Batch handling
+        --------------
+        Prefill and decode tokens are processed by dedicated CUDA
+        kernels for both the convolutional (conv1d) and SSM stages.
+        For a mixed batch (one containing both prefill and decode
+        tokens), the two sets of kernels run independently and their
+        outputs are concatenated before the final output projection.
+        """
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata
@@ -185,126 +232,142 @@ class MambaMixer(MambaBase, CustomOp):
             self_kv_cache = self.kv_cache[forward_context.virtual_engine]
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
-            has_initial_state = mamba1_metadata.has_initial_states
-            context_lens_tensor = mamba1_metadata.context_lens_tensor
+            has_initial_states = mamba1_metadata.has_initial_states
         else:
+            assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
             query_start_loc = attn_metadata.query_start_loc
             context_lens_tensor = attn_metadata.context_lens_tensor
-
-        if context_lens_tensor is not None:
-            has_initial_state = context_lens_tensor > 0
+            has_initial_states = None
+            if context_lens_tensor is not None:
+                has_initial_states = context_lens_tensor > 0
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
-        hidden_states, gate = projected_states.chunk(2, dim=-2)
+        hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
         if envs.VLLM_USE_V1 and attn_metadata is None:
             # V1 profile run
-            hidden_states = hidden_states.contiguous()
-            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+            hidden_states_BC = hidden_states_BC.contiguous()
+            return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
 
-        if query_start_loc is not None and context_lens_tensor is not None:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-            hidden_states = causal_conv1d_fn(
-                hidden_states,
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        num_decode_tokens = attn_metadata.num_decode_tokens  # token count
+        num_prefills = attn_metadata.num_prefills  # request count
+        # Each decode request contributes exactly one token, so the decode
+        # token count doubles as the decode request count.
+        num_decodes = attn_metadata.num_decode_tokens
+        has_prefill = num_prefill_tokens > 0
+        has_decode = num_decode_tokens > 0
+
+        prefill_decode_split = split_batch_to_prefill_and_decode(
+            hidden_states_BC,
+            gate,
+            state_indices_tensor,
+            query_start_loc,
+            has_initial_states,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+            num_decodes,
+        )
+        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
+        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
+        gate_p = prefill_decode_split.gate_p
+        gate_d = prefill_decode_split.gate_d
+        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
+        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
+        query_start_loc_p = prefill_decode_split.query_start_loc_p
+        has_initial_states_p = prefill_decode_split.has_initial_states_p
+
+        ssm_outputs = []
+
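+        # The split above handles both layouts: decode-first in V1,
+        # prefill-first in V0 (see split_batch_to_prefill_and_decode).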
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            conv_out_p = causal_conv1d_fn(
+                hidden_states_BC_p,
                 conv_weights,
-                bias=self.conv1d.bias,
+                self.conv1d.bias,
                 activation=self.activation,
                 conv_states=conv_state,
-                has_initial_state=has_initial_state,
-                cache_indices=state_indices_tensor,
-                query_start_loc=query_start_loc)
-        else:
-            hidden_states = causal_conv1d_update(
-                hidden_states.transpose(0, 1),
+                has_initial_state=has_initial_states_p,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p)
+
+            # 3. State Space Model sequence transformation
+            discrete_time_step_p, B_p, C_p = self._ssm_transform(
+                conv_out_p.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_out_p = selective_scan_fn(
+                conv_out_p,
+                ssm_state,
+                discrete_time_step_p,
+                self.A,
+                B_p.transpose(-2, -1),
+                C_p.transpose(-2, -1),
+                self.D.float(),
+                gate_p,
+                time_proj_bias,
+                delta_softplus=True,
+                cache_indices=state_indices_tensor_p,
+                has_initial_state=has_initial_states_p,
+                query_start_loc=query_start_loc_p)
+            ssm_outputs.append(scan_out_p)
+
+        if has_decode:
+            # 2. Convolution sequence transformation
+            conv_out_d = causal_conv1d_update(
+                hidden_states_BC_d.transpose(0, 1),
                 conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=state_indices_tensor)
-            hidden_states = hidden_states.transpose(0, 1)
+                conv_state_indices=state_indices_tensor_d).transpose(0, 1)
 
-        # 3. State Space Model sequence transformation
-        # 3.a. input varying initialization of time_step, B and C
+            # 3. State Space Model sequence transformation
+            discrete_time_step_d, B_d, C_d = self._ssm_transform(
+                conv_out_d.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
 
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            ssm_parameters = self.x_proj(
-                hidden_states.transpose(-2, -1).contiguous())[0]
-        else:
-            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
-
-        time_step, B, C = torch.split(
-            ssm_parameters,
-            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
-            dim=-1,
-        )
-        if self.use_rms_norm:
-            assert self.dt_layernorm is not None
-            assert self.b_layernorm is not None
-            assert self.c_layernorm is not None
-            time_step = self.dt_layernorm(time_step.contiguous())
-            B = self.b_layernorm(B.contiguous())
-            C = self.c_layernorm(C.contiguous())
-
-        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
-        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
-            self.dt_proj, "bias") else None)
-
-        if query_start_loc is not None and context_lens_tensor is not None:
-            scan_outputs = selective_scan_fn(
-                hidden_states,
-                ssm_state,
-                discrete_time_step,
-                self.A,
-                B.transpose(-2, -1),
-                C.transpose(-2, -1),
-                self.D.float(),
-                gate,
-                time_proj_bias,
-                delta_softplus=True,
-                cache_indices=state_indices_tensor,
-                has_initial_state=has_initial_state,
-                query_start_loc=query_start_loc)
-        else:
-            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_outputs_d = torch.empty_like(
+                hidden_states_BC_d.transpose(0, 1))
             selective_state_update(ssm_state,
-                                   hidden_states.transpose(0, 1),
-                                   discrete_time_step.transpose(0, 1),
+                                   conv_out_d.transpose(0, 1),
+                                   discrete_time_step_d.transpose(0, 1),
                                    self.A,
-                                   B,
-                                   C,
+                                   B_d,
+                                   C_d,
                                    self.D,
-                                   gate.transpose(0, 1),
+                                   gate_d.transpose(0, 1),
                                    time_proj_bias,
                                    dt_softplus=True,
-                                   state_batch_indices=state_indices_tensor,
-                                   out=scan_outputs)
-        scan_outputs = scan_outputs.transpose(0, 1)
+                                   state_batch_indices=state_indices_tensor_d,
+                                   out=scan_outputs_d)
+            scan_outputs_d = scan_outputs_d.transpose(0, 1)
 
-        # 4. Final linear projection
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1).contiguous())[0]
+            if envs.VLLM_USE_V1:
+                ssm_outputs.insert(0, scan_outputs_d)
+            else:
+                ssm_outputs.append(scan_outputs_d)
+
+        scan_outputs_combined = ssm_outputs[0] if len(
+            ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
+
+        # 5. Final output projection
+        if self.is_lora_enabled:
+            # LoRA kernel requires a contiguous tensor.
+            scan_outputs_combined = scan_outputs_combined.transpose(
+                -2, -1).contiguous()
+            out = self.out_proj(scan_outputs_combined)[0]
         else:
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1))[0]
-        return contextualized_states
+            out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
+
+        return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
@@ -317,3 +380,69 @@ class MambaMixer(MambaBase, CustomOp):
     @property
     def mamba_type(self) -> str:
         return "mamba1"
+
+    def _time_proj_bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
+            return self.dt_proj.bias.float()
+        return None
+
+
+class PrefillDecodeSplit(NamedTuple):
+    hidden_states_BC_p: torch.Tensor
+    hidden_states_BC_d: torch.Tensor
+    gate_p: torch.Tensor
+    gate_d: torch.Tensor
+    state_indices_tensor_p: torch.Tensor
+    state_indices_tensor_d: torch.Tensor
+    query_start_loc_p: Optional[torch.Tensor]
+    has_initial_states_p: Optional[torch.Tensor]
+
+
+def split_batch_to_prefill_and_decode(
+    hidden_states_BC: torch.Tensor,
+    gate: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    has_initial_states: Optional[torch.Tensor],
+    num_prefill_tokens: int,
+    num_decode_tokens: int,
+    num_prefills: int,
+    num_decodes: int,
+) -> PrefillDecodeSplit:
+    if envs.VLLM_USE_V1:
+        # In V1, decode tokens come first, then prefill tokens.
+        hidden_states_BC_d, hidden_states_BC_p = torch.split(
+            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
+        gate_d, gate_p = torch.split(gate,
+                                     [num_decode_tokens, num_prefill_tokens],
+                                     dim=-1)
+        state_indices_tensor_d, state_indices_tensor_p = torch.split(
+            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+        query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
+                             num_decodes if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[-num_prefills:] if (
+            has_initial_states is not None and num_prefills > 0) else None
+    else:
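The V1 branch's query_start_loc_p arithmetic above is easy to misread, so
here is a small worked example (the values are illustrative, not from this
patch): keep the last num_prefills + 1 cumulative boundaries, then shift out
the decode tokens so the prefill offsets start at zero.

    import torch

    # Mixed V1 batch: 2 decode requests (one token each) followed by two
    # prefill requests of 3 and 4 tokens; cumulative token boundaries.
    query_start_loc = torch.tensor([0, 1, 2, 5, 9])
    num_decodes, num_prefills = 2, 2

    query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
    assert query_start_loc_p.tolist() == [0, 3, 7]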
+        # In V0, prefill tokens come first, then decode tokens.
+        hidden_states_BC_p, hidden_states_BC_d = torch.split(
+            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
+        gate_p, gate_d = torch.split(gate,
+                                     [num_prefill_tokens, num_decode_tokens],
+                                     dim=-1)
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            state_indices_tensor, [num_prefills, num_decodes], dim=0)
+        query_start_loc_p = (query_start_loc[:num_prefills + 1]
+                             if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[:num_prefills] if (
+            has_initial_states is not None and num_prefills > 0) else None
+
+    return PrefillDecodeSplit(
+        hidden_states_BC_p=hidden_states_BC_p,
+        hidden_states_BC_d=hidden_states_BC_d,
+        gate_p=gate_p,
+        gate_d=gate_d,
+        state_indices_tensor_p=state_indices_tensor_p,
+        state_indices_tensor_d=state_indices_tensor_d,
+        query_start_loc_p=query_start_loc_p,
+        has_initial_states_p=has_initial_states_p,
+    )
diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py
index f0e4636fdb52..6cdc509083ae 100644
--- a/vllm/v1/attention/backends/mamba1_attn.py
+++ b/vllm/v1/attention/backends/mamba1_attn.py
@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
 
@@ -25,12 +26,15 @@ class Mamba1AttentionMetadata:
     query_start_loc: torch.Tensor
     context_lens_tensor: torch.Tensor
     state_indices_tensor: torch.Tensor
-    has_initial_states: torch.Tensor
+    has_initial_states: Optional[torch.Tensor]
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+    num_decode_tokens: int
 
 
 class Mamba1AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba1AttentionMetadata]):
-
     reorder_batch_threshold: ClassVar[int] = 1
 
     def __init__(
@@ -57,11 +61,23 @@ class Mamba1AttentionMetadataBuilder(
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
         context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
             query_start_loc.device)
-        has_initial_states = (context_lens_tensor > 0)
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))
+
+        has_initial_states = None
+
+        if num_prefills > 0:
+            has_initial_states = context_lens_tensor > 0
 
         return Mamba1AttentionMetadata(
             query_start_loc=query_start_loc,
             context_lens_tensor=context_lens_tensor,
             has_initial_states=has_initial_states,
             state_indices_tensor=state_indices_tensor,
            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
        )
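
For context on the builder change: split_decodes_and_prefills (from
vllm.v1.attention.backends.utils) returns (num_decodes, num_prefills,
num_decode_tokens, num_prefill_tokens), relying on the decode-first request
order that reorder_batch_threshold = 1 is there to enforce. Below is a
behaviour-only sketch; unlike the real helper, which takes a
CommonAttentionMetadata, this hypothetical function works on query_start_loc
directly.

    import torch

    def split_decodes_and_prefills_sketch(query_start_loc: torch.Tensor,
                                          decode_threshold: int = 1):
        # Per-request query lengths from the cumulative boundaries.
        query_lens = (query_start_loc[1:] - query_start_loc[:-1]).tolist()
        # Decode-first layout: count leading requests at/below the threshold.
        num_decodes = 0
        for qlen in query_lens:
            if qlen > decode_threshold:
                break
            num_decodes += 1
        num_prefills = len(query_lens) - num_decodes
        return (num_decodes, num_prefills, sum(query_lens[:num_decodes]),
                sum(query_lens[num_decodes:]))

    # Same toy batch as above: two decodes, then prefills of 3 and 4 tokens.
    assert (split_decodes_and_prefills_sketch(torch.tensor([0, 1, 2, 5, 9]))
            == (2, 2, 2, 7))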