[Bugfix][mamba] Fix type annotation of Mamba2Metadata (#22787)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-08-13 06:07:09 -07:00 committed by GitHub
parent 6b794c756c
commit fceafaf582
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 21 deletions

View File

@ -473,12 +473,12 @@ class MambaMixer2(MambaBase, CustomOp):
conv_state = self_kv_cache[0].transpose(-1, -2) conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1] ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states has_initial_states_p = attn_metadata.has_initial_states_p
prep_initial_states = attn_metadata.prep_initial_states prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets chunk_offsets_p = attn_metadata.chunk_offsets_p
else: else:
conv_state = mamba_cache_params.conv_state conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state ssm_state = mamba_cache_params.ssm_state

View File

@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
seq_lens: torch.Tensor seq_lens: torch.Tensor
has_initial_states: torch.Tensor
prep_initial_states: bool prep_initial_states: bool
chunk_size: int chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor # The following tensors only contain prefill requests and will be None if
chunk_offsets: torch.Tensor # the batch has no prefill request.
has_initial_states_p: Optional[torch.Tensor]
seq_idx_p: Optional[torch.Tensor]
chunk_indices_p: Optional[torch.Tensor]
chunk_offsets_p: Optional[torch.Tensor]
state_indices_tensor: torch.Tensor # shape: [batch,] state_indices_tensor: torch.Tensor # shape: [batch,]
# The following attributes are for the Triton implementation of causal_conv1d
nums_dict: Optional[dict] = None nums_dict: Optional[dict] = None
cu_seqlen: Optional[int] = None cu_seqlen: Optional[int] = None
batch_ptr: Optional[torch.tensor] = None batch_ptr: Optional[torch.tensor] = None
@ -115,11 +120,11 @@ class Mamba2AttentionMetadataBuilder(
query_start_loc = common_attn_metadata.query_start_loc query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens seq_lens = common_attn_metadata.seq_lens
seq_idx = None seq_idx_p = None
chunk_indices, chunk_offsets = None, None chunk_indices_p, chunk_offsets_p = None, None
# Need flags to indicate if there are initial states # Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend # currently we really only support the FlashAttention backend
has_initial_states = None has_initial_states_p = None
prep_initial_states = False prep_initial_states = False
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
@ -135,25 +140,25 @@ class Mamba2AttentionMetadataBuilder(
common_attn_metadata. common_attn_metadata.
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0) num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
prep_initial_states = torch.any(has_initial_states_cpu).item() prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states = has_initial_states_cpu.to( has_initial_states_p = has_initial_states_cpu.to(
query_start_loc.device) query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[ query_start_loc_p = common_attn_metadata.query_start_loc[
-num_prefills - 1:] - num_decode_tokens -num_prefills - 1:] - num_decode_tokens
seq_idx = torch.repeat_interleave(torch.arange( seq_idx_p = torch.repeat_interleave(torch.arange(
num_prefills, num_prefills,
dtype=torch.int32, dtype=torch.int32,
device=query_start_loc_p.device), device=query_start_loc_p.device),
query_start_loc_p.diff(), query_start_loc_p.diff(),
output_size=num_prefill_tokens) output_size=num_prefill_tokens)
seq_idx.unsqueeze_(0) seq_idx_p.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level # We compute metadata for chunked prefill once at the top level
# model forward and reuse them in mamba layers. If not needed, # model forward and reuse them in mamba layers. If not needed,
# they will be ignored inside mamba kernels. # they will be ignored inside mamba kernels.
if prep_initial_states: if prep_initial_states:
chunk_indices, chunk_offsets = ( chunk_indices_p, chunk_offsets_p = (
_query_start_loc_to_chunk_indices_offsets( _query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, self.chunk_size, query_start_loc_p, self.chunk_size,
num_prefill_tokens)) num_prefill_tokens))
@ -173,12 +178,12 @@ class Mamba2AttentionMetadataBuilder(
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
seq_lens=seq_lens, seq_lens=seq_lens,
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states, prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size, chunk_size=self.chunk_size,
seq_idx=seq_idx, has_initial_states_p=has_initial_states_p,
chunk_indices=chunk_indices, seq_idx_p=seq_idx_p,
chunk_offsets=chunk_offsets, chunk_indices_p=chunk_indices_p,
chunk_offsets_p=chunk_offsets_p,
state_indices_tensor=state_indices_tensor, state_indices_tensor=state_indices_tensor,
) )
return attn_metadata return attn_metadata