Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2025-12-09 17:05:37 +08:00
[Model] Clean up and simplify Mamba2 Metadata Usage in both V0 and V1 (#24331)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
parent 4e5affeaa1
commit 73cfb3c5ee
@@ -17,14 +17,13 @@ from vllm.v1.attention.backends.mamba2_attn import (
 @dataclass
 class Mamba2Metadata:
 
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
 
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
-
+    has_initial_states_p: torch.Tensor
+    seq_idx_p: torch.Tensor
+    chunk_indices_p: torch.Tensor
+    chunk_offsets_p: torch.Tensor
     """
     With continuous batching layout of `x` in vLLM, to enable a Triton program
     to handle a request in parallel, two supporting tensors are used
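For illustration of the seq_idx_p field described above, here is a minimal toy sketch (values assumed, not part of the commit) of what the tensor holds for a prefill batch of two requests with 3 and 2 query tokens, built with the same torch.repeat_interleave pattern that prepare_mamba2_metadata uses below.

import torch

# Toy prefill batch: cumulative query offsets for requests of 3 and 2 tokens.
query_start_loc_p = torch.tensor([0, 3, 5], dtype=torch.int32)
num_prefills = query_start_loc_p.numel() - 1
num_prefill_tokens = int(query_start_loc_p[-1])

# One request id per prefill token, as in prepare_mamba2_metadata.
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)
print(seq_idx_p)  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)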
@@ -68,7 +67,6 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
 def prepare_mamba2_metadata(
     chunk_size: int,
     attn_metadata: AttentionMetadata,
-    mamba2_metadata=None,
 ) -> Mamba2Metadata:
 
     # compute number of prefill and decode requests
@@ -76,11 +74,11 @@ def prepare_mamba2_metadata(
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
 
-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
+    seq_idx_p = None
+    chunk_indices_p, chunk_offsets_p = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
-    has_initial_states = None
+    has_initial_states_p = None
     prep_initial_states = False
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
@@ -91,44 +89,30 @@ def prepare_mamba2_metadata(
         # precompute flag to avoid device syncs later in mamba2 layer
         # forwards
         # prep is only needed for mamba2 ssd prefill processing
-        has_initial_states = attn_metadata.context_lens_tensor > 0
-        prep_initial_states = torch.any(
-            has_initial_states[:num_prefills]).item()
-        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc.device),
-                                          query_start_loc.diff(),
-                                          output_size=num_prefill_tokens)
-        seq_idx.unsqueeze_(0)
+        has_initial_states_p = (
+            attn_metadata.context_lens_tensor[:num_prefills] > 0)
+        prep_initial_states = torch.any(has_initial_states_p).item()
+        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx_p = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
+                                            query_start_loc_p.diff(),
+                                            output_size=num_prefill_tokens)
+        seq_idx_p.unsqueeze_(0)
 
         # We compute metadata for chunked prefill once at the top level model
         # forward and reuse them in mamba layers. If not needed, they will be
         # ignored inside mamba kernels.
         if prep_initial_states:
-            chunk_indices, chunk_offsets = \
+            chunk_indices_p, chunk_offsets_p = \
                 _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc, chunk_size, num_prefill_tokens)
+                    query_start_loc_p, chunk_size, num_prefill_tokens)
 
-    if mamba2_metadata is not None:
-        mamba2_metadata.has_initial_states = has_initial_states
-        mamba2_metadata.prep_initial_states = prep_initial_states
-        mamba2_metadata.chunk_size = chunk_size
-        mamba2_metadata.seq_idx = seq_idx
-        mamba2_metadata.chunk_indices = chunk_indices
-        mamba2_metadata.chunk_offsets = chunk_offsets
-        # We use 1 reset flag:
-        #  * mamba2_metadata.cu_seqlen is None
-        #    update config specific to (each input)
-        #    (become available at first layer, e.g. conv_weights)
-        mamba2_metadata.cu_seqlen = None  # suppose to be updated at each input
-
-        return mamba2_metadata
-    return Mamba2Metadata(has_initial_states=has_initial_states,
+    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
-                          seq_idx=seq_idx,
-                          chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                          seq_idx_p=seq_idx_p,
+                          chunk_indices_p=chunk_indices_p,
+                          chunk_offsets_p=chunk_offsets_p)
 
 
 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
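A minimal standalone sketch (toy values, not from this commit) of the prefill-only flags computed above, relying on the convention that prefill requests are ordered before decode requests:

import torch

num_prefills = 2
# Cached context length per request; the third entry is a decode request.
context_lens_tensor = torch.tensor([0, 16, 40])

# Flags for the prefill requests only, as in the updated code path above.
has_initial_states_p = context_lens_tensor[:num_prefills] > 0  # tensor([False,  True])
prep_initial_states = torch.any(has_initial_states_p).item()   # True

# query_start_loc is likewise sliced to the prefill requests.
query_start_loc = torch.tensor([0, 3, 5, 6], dtype=torch.int32)
query_start_loc_p = query_start_loc[:num_prefills + 1]         # tensor([0, 3, 5])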
@@ -518,22 +518,19 @@ class MambaMixer2(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            has_initial_states_p = attn_metadata.has_initial_states_p
-            prep_initial_states = attn_metadata.prep_initial_states
-            chunk_size = attn_metadata.chunk_size
-            seq_idx_p = attn_metadata.seq_idx_p
-            chunk_indices_p = attn_metadata.chunk_indices_p
-            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
+
+        # Common members between V1 metadata and V0 metadata
+        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
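The sketch below (a hypothetical stand-in object, not the vLLM types) illustrates the access pattern introduced in this hunk: the per-backend attribute reads are replaced by a single block that pulls the shared *_p members from whichever metadata object is passed in.

from types import SimpleNamespace
import torch

# Hypothetical stand-in; only the attribute names mirror the new code path.
mamba2_metadata = SimpleNamespace(
    has_initial_states_p=torch.tensor([True]),
    prep_initial_states=True,
    chunk_size=256,
    seq_idx_p=torch.tensor([[0]], dtype=torch.int32),
    chunk_indices_p=None,
    chunk_offsets_p=None)

# Single read path shared by the V0 and V1 branches, as in the hunk above.
if mamba2_metadata is not None:
    has_initial_states_p = mamba2_metadata.has_initial_states_p
    prep_initial_states = mamba2_metadata.prep_initial_states
    chunk_size = mamba2_metadata.chunk_size
    seq_idx_p = mamba2_metadata.seq_idx_p
    chunk_indices_p = mamba2_metadata.chunk_indices_p
    chunk_offsets_p = mamba2_metadata.chunk_offsets_p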
@@ -677,15 +674,9 @@ class MambaMixer2(MambaBase, CustomOp):
         # 3. State Space Model sequence transformation
         initial_states = None
         if (has_initial_states_p is not None and prep_initial_states):
-            # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)
 
         # NOTE: final output is an in-place update of out tensor
         varlen_state = mamba_chunk_scan_combined(
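A toy, self-contained sketch (shapes and values assumed) of the unified initial-state gather above: cached SSM states are copied for requests whose flag is set, while the remaining requests start from zeros.

import torch

num_prefills, nheads, headdim, dstate = 2, 4, 8, 16
ssm_state = torch.randn(8, nheads, headdim, dstate)   # toy cache slots
state_indices_tensor_p = torch.tensor([5, 2])          # slot per prefill request
has_initial_states_p = torch.tensor([True, False])

initial_states = torch.where(
    has_initial_states_p[:, None, None, None],
    ssm_state[state_indices_tensor_p], 0)

assert torch.equal(initial_states[0], ssm_state[5])                           # copied state
assert torch.equal(initial_states[1], torch.zeros(nheads, headdim, dstate))   # fresh request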
@@ -279,22 +279,19 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            has_initial_states_p = attn_metadata.has_initial_states_p
-            prep_initial_states = attn_metadata.prep_initial_states
-            chunk_size = attn_metadata.chunk_size
-            seq_idx_p = attn_metadata.seq_idx_p
-            chunk_indices_p = attn_metadata.chunk_indices_p
-            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
+
+        # Common members between V1 metadata and V0 metadata
+        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -414,14 +411,10 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         initial_states = None
         if has_initial_states_p is not None and prep_initial_states:
-            # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)
 
         varlen_state = mamba_chunk_scan_combined(
             hidden_states_p.view(1, num_prefill_tokens,
                                  self.num_heads // self.tp_size,
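A small sketch (dimensions assumed for illustration; tp_size taken as 1) of the varlen packing fed to mamba_chunk_scan_combined: all prefill tokens are flattened into a single batch dimension of size 1, with request boundaries tracked by supporting tensors such as seq_idx_p.

import torch

num_prefill_tokens = 5        # e.g. two prefill requests of 3 and 2 tokens
num_heads, head_dim = 4, 64   # assumed sizes for the sketch

hidden_states_p = torch.randn(num_prefill_tokens, num_heads * head_dim)
x = hidden_states_p.view(1, num_prefill_tokens, num_heads, head_dim)
print(x.shape)  # torch.Size([1, 5, 4, 64])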