mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-24 15:11:25 +08:00
[Bugfix][mamba] Fix type annotation of Mamba2Metadata (#22787)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
parent
6b794c756c
commit
fceafaf582
@ -473,12 +473,12 @@ class MambaMixer2(MambaBase, CustomOp):
|
|||||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||||
ssm_state = self_kv_cache[1]
|
ssm_state = self_kv_cache[1]
|
||||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||||
has_initial_states_p = attn_metadata.has_initial_states
|
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||||
prep_initial_states = attn_metadata.prep_initial_states
|
prep_initial_states = attn_metadata.prep_initial_states
|
||||||
chunk_size = attn_metadata.chunk_size
|
chunk_size = attn_metadata.chunk_size
|
||||||
seq_idx_p = attn_metadata.seq_idx
|
seq_idx_p = attn_metadata.seq_idx_p
|
||||||
chunk_indices_p = attn_metadata.chunk_indices
|
chunk_indices_p = attn_metadata.chunk_indices_p
|
||||||
chunk_offsets_p = attn_metadata.chunk_offsets
|
chunk_offsets_p = attn_metadata.chunk_offsets_p
|
||||||
else:
|
else:
|
||||||
conv_state = mamba_cache_params.conv_state
|
conv_state = mamba_cache_params.conv_state
|
||||||
ssm_state = mamba_cache_params.ssm_state
|
ssm_state = mamba_cache_params.ssm_state
|
||||||
|
|||||||
@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
|
|||||||
query_start_loc: torch.Tensor
|
query_start_loc: torch.Tensor
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
|
|
||||||
has_initial_states: torch.Tensor
|
|
||||||
prep_initial_states: bool
|
prep_initial_states: bool
|
||||||
chunk_size: int
|
chunk_size: int
|
||||||
seq_idx: torch.Tensor
|
|
||||||
chunk_indices: torch.Tensor
|
# The following tensors only contain prefill requests and will be None if
|
||||||
chunk_offsets: torch.Tensor
|
# the batch has no prefill request.
|
||||||
|
has_initial_states_p: Optional[torch.Tensor]
|
||||||
|
seq_idx_p: Optional[torch.Tensor]
|
||||||
|
chunk_indices_p: Optional[torch.Tensor]
|
||||||
|
chunk_offsets_p: Optional[torch.Tensor]
|
||||||
|
|
||||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||||
|
|
||||||
|
# The following attributes are for triton implementation of causal_conv1d
|
||||||
nums_dict: Optional[dict] = None
|
nums_dict: Optional[dict] = None
|
||||||
cu_seqlen: Optional[int] = None
|
cu_seqlen: Optional[int] = None
|
||||||
batch_ptr: Optional[torch.tensor] = None
|
batch_ptr: Optional[torch.tensor] = None
|
||||||
@ -115,11 +120,11 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
query_start_loc = common_attn_metadata.query_start_loc
|
query_start_loc = common_attn_metadata.query_start_loc
|
||||||
seq_lens = common_attn_metadata.seq_lens
|
seq_lens = common_attn_metadata.seq_lens
|
||||||
|
|
||||||
seq_idx = None
|
seq_idx_p = None
|
||||||
chunk_indices, chunk_offsets = None, None
|
chunk_indices_p, chunk_offsets_p = None, None
|
||||||
# Need flags to indicate if there are initial states
|
# Need flags to indicate if there are initial states
|
||||||
# currently we really only support the FlashAttention backend
|
# currently we really only support the FlashAttention backend
|
||||||
has_initial_states = None
|
has_initial_states_p = None
|
||||||
prep_initial_states = False
|
prep_initial_states = False
|
||||||
|
|
||||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||||
@ -135,25 +140,25 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
common_attn_metadata.
|
common_attn_metadata.
|
||||||
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
|
num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
|
||||||
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
prep_initial_states = torch.any(has_initial_states_cpu).item()
|
||||||
has_initial_states = has_initial_states_cpu.to(
|
has_initial_states_p = has_initial_states_cpu.to(
|
||||||
query_start_loc.device)
|
query_start_loc.device)
|
||||||
|
|
||||||
query_start_loc_p = common_attn_metadata.query_start_loc[
|
query_start_loc_p = common_attn_metadata.query_start_loc[
|
||||||
-num_prefills - 1:] - num_decode_tokens
|
-num_prefills - 1:] - num_decode_tokens
|
||||||
|
|
||||||
seq_idx = torch.repeat_interleave(torch.arange(
|
seq_idx_p = torch.repeat_interleave(torch.arange(
|
||||||
num_prefills,
|
num_prefills,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=query_start_loc_p.device),
|
device=query_start_loc_p.device),
|
||||||
query_start_loc_p.diff(),
|
query_start_loc_p.diff(),
|
||||||
output_size=num_prefill_tokens)
|
output_size=num_prefill_tokens)
|
||||||
seq_idx.unsqueeze_(0)
|
seq_idx_p.unsqueeze_(0)
|
||||||
|
|
||||||
# We compute metadata for chunked prefill once at the top level
|
# We compute metadata for chunked prefill once at the top level
|
||||||
# model forward and reuse them in mamba layers. If not needed,
|
# model forward and reuse them in mamba layers. If not needed,
|
||||||
# they will be ignored inside mamba kernels.
|
# they will be ignored inside mamba kernels.
|
||||||
if prep_initial_states:
|
if prep_initial_states:
|
||||||
chunk_indices, chunk_offsets = (
|
chunk_indices_p, chunk_offsets_p = (
|
||||||
_query_start_loc_to_chunk_indices_offsets(
|
_query_start_loc_to_chunk_indices_offsets(
|
||||||
query_start_loc_p, self.chunk_size,
|
query_start_loc_p, self.chunk_size,
|
||||||
num_prefill_tokens))
|
num_prefill_tokens))
|
||||||
@ -173,12 +178,12 @@ class Mamba2AttentionMetadataBuilder(
|
|||||||
num_decode_tokens=num_decode_tokens,
|
num_decode_tokens=num_decode_tokens,
|
||||||
query_start_loc=query_start_loc,
|
query_start_loc=query_start_loc,
|
||||||
seq_lens=seq_lens,
|
seq_lens=seq_lens,
|
||||||
has_initial_states=has_initial_states,
|
|
||||||
prep_initial_states=prep_initial_states,
|
prep_initial_states=prep_initial_states,
|
||||||
chunk_size=self.chunk_size,
|
chunk_size=self.chunk_size,
|
||||||
seq_idx=seq_idx,
|
has_initial_states_p=has_initial_states_p,
|
||||||
chunk_indices=chunk_indices,
|
seq_idx_p=seq_idx_p,
|
||||||
chunk_offsets=chunk_offsets,
|
chunk_indices_p=chunk_indices_p,
|
||||||
|
chunk_offsets_p=chunk_offsets_p,
|
||||||
state_indices_tensor=state_indices_tensor,
|
state_indices_tensor=state_indices_tensor,
|
||||||
)
|
)
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user