From fceafaf582cd72e6636f47127a665afb9e0ea0aa Mon Sep 17 00:00:00 2001
From: Chen Zhang
Date: Wed, 13 Aug 2025 06:07:09 -0700
Subject: [PATCH] [Bugfix][mamba] Fix type annotation of Mamba2Metadata (#22787)

Signed-off-by: Chen Zhang
---
 .../layers/mamba/mamba_mixer2.py         |  8 ++--
 vllm/v1/attention/backends/mamba_attn.py | 39 +++++++++++--------
 2 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index d5f4877135c94..10a5618c227e8 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -473,12 +473,12 @@ class MambaMixer2(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            has_initial_states_p = attn_metadata.has_initial_states
+            has_initial_states_p = attn_metadata.has_initial_states_p
             prep_initial_states = attn_metadata.prep_initial_states
             chunk_size = attn_metadata.chunk_size
-            seq_idx_p = attn_metadata.seq_idx
-            chunk_indices_p = attn_metadata.chunk_indices
-            chunk_offsets_p = attn_metadata.chunk_offsets
+            seq_idx_p = attn_metadata.seq_idx_p
+            chunk_indices_p = attn_metadata.chunk_indices_p
+            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 7c1226049f696..3f84f8967db7a 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
     query_start_loc: torch.Tensor
     seq_lens: torch.Tensor
 
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    # The following tensors only contain prefill requests and will be None if
+    # the batch has no prefill request.
+    has_initial_states_p: Optional[torch.Tensor]
+    seq_idx_p: Optional[torch.Tensor]
+    chunk_indices_p: Optional[torch.Tensor]
+    chunk_offsets_p: Optional[torch.Tensor]
 
     state_indices_tensor: torch.Tensor  # shape: [batch,]
+
+    # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.tensor] = None
@@ -115,11 +120,11 @@ class Mamba2AttentionMetadataBuilder(
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
 
-        seq_idx = None
-        chunk_indices, chunk_offsets = None, None
+        seq_idx_p = None
+        chunk_indices_p, chunk_offsets_p = None, None
         # Need flags to indicate if there are initial states
         # currently we really only support the FlashAttention backend
-        has_initial_states = None
+        has_initial_states_p = None
         prep_initial_states = False
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
@@ -135,25 +140,25 @@ class Mamba2AttentionMetadataBuilder(
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
                 -num_prefills - 1:] - num_decode_tokens
 
-            seq_idx = torch.repeat_interleave(torch.arange(
+            seq_idx_p = torch.repeat_interleave(torch.arange(
                 num_prefills,
                 dtype=torch.int32,
                 device=query_start_loc_p.device),
-                                              query_start_loc_p.diff(),
-                                              output_size=num_prefill_tokens)
-            seq_idx.unsqueeze_(0)
+                                                query_start_loc_p.diff(),
+                                                output_size=num_prefill_tokens)
+            seq_idx_p.unsqueeze_(0)
 
             # We compute metadata for chunked prefill once at the top level
             # model forward and reuse them in mamba layers. If not needed,
             # they will be ignored inside mamba kernels.
             if prep_initial_states:
-                chunk_indices, chunk_offsets = (
+                chunk_indices_p, chunk_offsets_p = (
                     _query_start_loc_to_chunk_indices_offsets(
                         query_start_loc_p, self.chunk_size,
                         num_prefill_tokens))
@@ -173,12 +178,12 @@ class Mamba2AttentionMetadataBuilder(
             num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
-            has_initial_states=has_initial_states,
             prep_initial_states=prep_initial_states,
             chunk_size=self.chunk_size,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
+            has_initial_states_p=has_initial_states_p,
+            seq_idx_p=seq_idx_p,
+            chunk_indices_p=chunk_indices_p,
+            chunk_offsets_p=chunk_offsets_p,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
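
Background on the fix: the patch renames the prefill-only fields of
Mamba2AttentionMetadata with a `_p` suffix and annotates them as
Optional[torch.Tensor], because the builder leaves them as None whenever the
batch contains no prefill request; the old annotations (plain torch.Tensor)
did not admit that. Below is a minimal, self-contained sketch of that
invariant, not vLLM code: `Metadata` and `build` are hypothetical stand-ins,
and `has_initial_states_p` is filled with a placeholder rather than the real
`num_computed_tokens > 0` check.

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class Metadata:  # stand-in for Mamba2AttentionMetadata
        num_prefills: int
        # Prefill-only tensors; None when the batch has no prefill request.
        has_initial_states_p: Optional[torch.Tensor]
        seq_idx_p: Optional[torch.Tensor]


    def build(num_prefills: int, num_prefill_tokens: int,
              query_start_loc_p: Optional[torch.Tensor]) -> Metadata:
        has_initial_states_p = None
        seq_idx_p = None
        if num_prefills > 0:
            assert query_start_loc_p is not None
            # Same construction as in the patch: one sequence index per
            # prefill token, e.g. query_start_loc [0, 2, 5] -> [0, 0, 1, 1, 1].
            seq_idx_p = torch.repeat_interleave(
                torch.arange(num_prefills, dtype=torch.int32),
                query_start_loc_p.diff(),
                output_size=num_prefill_tokens).unsqueeze(0)
            # Placeholder; vLLM derives this from the computed-token counts.
            has_initial_states_p = torch.zeros(num_prefills, dtype=torch.bool)
        return Metadata(num_prefills, has_initial_states_p, seq_idx_p)


    # Decode-only batch: the `_p` fields are legitimately None, which is
    # exactly what the corrected Optional annotations express.
    print(build(0, 0, None))
    # Mixed batch with two prefill requests of lengths 2 and 3.
    print(build(2, 5, torch.tensor([0, 2, 5])))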