From 73cfb3c5eeb8b00a6e222751a28fd89a5f6229dc Mon Sep 17 00:00:00 2001 From: Chih-Chieh Yang <7364402+cyang49@users.noreply.github.com> Date: Tue, 16 Sep 2025 10:53:43 -0400 Subject: [PATCH] [Model] Clean up and simplify Mamba2 Metadata Usage in both V0 and V1 (#24331) Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com> --- .../layers/mamba/mamba2_metadata.py | 62 +++++++------------ .../layers/mamba/mamba_mixer2.py | 29 +++------ vllm/model_executor/models/plamo2.py | 29 ++++----- 3 files changed, 44 insertions(+), 76 deletions(-) diff --git a/vllm/model_executor/layers/mamba/mamba2_metadata.py b/vllm/model_executor/layers/mamba/mamba2_metadata.py index 3256ac034aa1..368bfe3af1d3 100644 --- a/vllm/model_executor/layers/mamba/mamba2_metadata.py +++ b/vllm/model_executor/layers/mamba/mamba2_metadata.py @@ -17,14 +17,13 @@ from vllm.v1.attention.backends.mamba2_attn import ( @dataclass class Mamba2Metadata: - - has_initial_states: torch.Tensor prep_initial_states: bool - chunk_size: int - seq_idx: torch.Tensor - chunk_indices: torch.Tensor - chunk_offsets: torch.Tensor + + has_initial_states_p: torch.Tensor + seq_idx_p: torch.Tensor + chunk_indices_p: torch.Tensor + chunk_offsets_p: torch.Tensor """ With continuous batching layout of `x` in vLLM, to enable a Triton program to handle a request in parallel, two supporting tensors are used @@ -68,7 +67,6 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]: def prepare_mamba2_metadata( chunk_size: int, attn_metadata: AttentionMetadata, - mamba2_metadata=None, ) -> Mamba2Metadata: # compute number of prefill and decode requests @@ -76,11 +74,11 @@ def prepare_mamba2_metadata( num_prefills = attn_metadata.num_prefills num_prefill_tokens = attn_metadata.num_prefill_tokens - seq_idx = None - chunk_indices, chunk_offsets = None, None + seq_idx_p = None + chunk_indices_p, chunk_offsets_p = None, None # Need flags to indicate if there are initial states # currently we really only support the FlashAttention backend - has_initial_states = None + has_initial_states_p = None prep_initial_states = False # Compute seq_idx, chunk_indices and chunk_offsets for prefill only @@ -91,44 +89,30 @@ def prepare_mamba2_metadata( # precompute flag to avoid device syncs later in mamba2 layer # forwards # prep is only needed for mamba2 ssd prefill processing - has_initial_states = attn_metadata.context_lens_tensor > 0 - prep_initial_states = torch.any( - has_initial_states[:num_prefills]).item() - query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1] - seq_idx = torch.repeat_interleave(torch.arange( - num_prefills, dtype=torch.int32, device=query_start_loc.device), - query_start_loc.diff(), - output_size=num_prefill_tokens) - seq_idx.unsqueeze_(0) + has_initial_states_p = ( + attn_metadata.context_lens_tensor[:num_prefills] > 0) + prep_initial_states = torch.any(has_initial_states_p).item() + query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1] + seq_idx_p = torch.repeat_interleave(torch.arange( + num_prefills, dtype=torch.int32, device=query_start_loc_p.device), + query_start_loc_p.diff(), + output_size=num_prefill_tokens) + seq_idx_p.unsqueeze_(0) # We compute metadata for chunked prefill once at the top level model # forward and reuse them in mamba layers. If not needed, they will be # ignored inside mamba kernels. if prep_initial_states: - chunk_indices, chunk_offsets = \ + chunk_indices_p, chunk_offsets_p = \ _query_start_loc_to_chunk_indices_offsets( - query_start_loc, chunk_size, num_prefill_tokens) + query_start_loc_p, chunk_size, num_prefill_tokens) - if mamba2_metadata is not None: - mamba2_metadata.has_initial_states = has_initial_states - mamba2_metadata.prep_initial_states = prep_initial_states - mamba2_metadata.chunk_size = chunk_size - mamba2_metadata.seq_idx = seq_idx - mamba2_metadata.chunk_indices = chunk_indices - mamba2_metadata.chunk_offsets = chunk_offsets - # We use 1 reset flag: - # * mamba2_metadata.cu_seqlen is None - # update config specific to (each input) - # (become available at first layer, e.g. conv_weights) - mamba2_metadata.cu_seqlen = None # suppose to be updated at each input - - return mamba2_metadata - return Mamba2Metadata(has_initial_states=has_initial_states, + return Mamba2Metadata(has_initial_states_p=has_initial_states_p, prep_initial_states=prep_initial_states, chunk_size=chunk_size, - seq_idx=seq_idx, - chunk_indices=chunk_indices, - chunk_offsets=chunk_offsets) + seq_idx_p=seq_idx_p, + chunk_indices_p=chunk_indices_p, + chunk_offsets_p=chunk_offsets_p) def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 23e19da430e1..02e6a9138c05 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -518,22 +518,19 @@ class MambaMixer2(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] state_indices_tensor = attn_metadata.state_indices_tensor - has_initial_states_p = attn_metadata.has_initial_states_p - prep_initial_states = attn_metadata.prep_initial_states - chunk_size = attn_metadata.chunk_size - seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state state_indices_tensor = mamba_cache_params.state_indices_tensor - has_initial_states_p = mamba2_metadata.has_initial_states + + # Common members between V1 metadata and V0 metadata + if mamba2_metadata is not None: + has_initial_states_p = mamba2_metadata.has_initial_states_p prep_initial_states = mamba2_metadata.prep_initial_states chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx - chunk_indices_p = mamba2_metadata.chunk_indices - chunk_offsets_p = mamba2_metadata.chunk_offsets + seq_idx_p = mamba2_metadata.seq_idx_p + chunk_indices_p = mamba2_metadata.chunk_indices_p + chunk_offsets_p = mamba2_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -677,15 +674,9 @@ class MambaMixer2(MambaBase, CustomOp): # 3. State Space Model sequence transformation initial_states = None if (has_initial_states_p is not None and prep_initial_states): - # making a copy of the states - if envs.VLLM_USE_V1: - initial_states = torch.where( - has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) - else: - initial_states = torch.where( - has_initial_states_p[:num_prefills, None, None, None], - ssm_state[state_indices_tensor_p], 0) + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) # NOTE: final output is an in-place update of out tensor varlen_state = mamba_chunk_scan_combined( diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index b9869f5e5880..ef96d272adfb 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -279,22 +279,19 @@ class Plamo2MambaMixer(MambaBase, CustomOp): conv_state = self_kv_cache[0].transpose(-1, -2) ssm_state = self_kv_cache[1] state_indices_tensor = attn_metadata.state_indices_tensor - has_initial_states_p = attn_metadata.has_initial_states_p - prep_initial_states = attn_metadata.prep_initial_states - chunk_size = attn_metadata.chunk_size - seq_idx_p = attn_metadata.seq_idx_p - chunk_indices_p = attn_metadata.chunk_indices_p - chunk_offsets_p = attn_metadata.chunk_offsets_p else: conv_state = mamba_cache_params.conv_state ssm_state = mamba_cache_params.ssm_state state_indices_tensor = mamba_cache_params.state_indices_tensor - has_initial_states_p = mamba2_metadata.has_initial_states + + # Common members between V1 metadata and V0 metadata + if mamba2_metadata is not None: + has_initial_states_p = mamba2_metadata.has_initial_states_p prep_initial_states = mamba2_metadata.prep_initial_states chunk_size = mamba2_metadata.chunk_size - seq_idx_p = mamba2_metadata.seq_idx - chunk_indices_p = mamba2_metadata.chunk_indices - chunk_offsets_p = mamba2_metadata.chunk_offsets + seq_idx_p = mamba2_metadata.seq_idx_p + chunk_indices_p = mamba2_metadata.chunk_indices_p + chunk_offsets_p = mamba2_metadata.chunk_offsets_p # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states) @@ -414,14 +411,10 @@ class Plamo2MambaMixer(MambaBase, CustomOp): initial_states = None if has_initial_states_p is not None and prep_initial_states: # making a copy of the states - if envs.VLLM_USE_V1: - initial_states = torch.where( - has_initial_states_p[:, None, None, None], - ssm_state[state_indices_tensor_p], 0) - else: - initial_states = torch.where( - has_initial_states_p[:num_prefills, None, None, None], - ssm_state[state_indices_tensor_p], 0) + initial_states = torch.where( + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0) + varlen_state = mamba_chunk_scan_combined( hidden_states_p.view(1, num_prefill_tokens, self.num_heads // self.tp_size,