[Model] Clean up and simplify Mamba2 Metadata Usage in both V0 and V1 (#24331)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Chih-Chieh Yang 2025-09-16 10:53:43 -04:00 committed by GitHub
parent 4e5affeaa1
commit 73cfb3c5ee
3 changed files with 44 additions and 76 deletions

@@ -17,14 +17,13 @@ from vllm.v1.attention.backends.mamba2_attn import (
 @dataclass
 class Mamba2Metadata:
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+    has_initial_states_p: torch.Tensor
+    seq_idx_p: torch.Tensor
+    chunk_indices_p: torch.Tensor
+    chunk_offsets_p: torch.Tensor
     """
     With continuous batching layout of `x` in vLLM, to enable a Triton program
     to handle a request in parallel, two supporting tensors are used
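To make the docstring above concrete, here is a small standalone sketch (illustration only, not part of this commit) of the per-token sequence-index tensor for an assumed toy batch of two prefill requests with 3 and 2 tokens. It uses the same `torch.repeat_interleave` construction that `prepare_mamba2_metadata` uses further down in this file.

import torch

# Assumed toy batch: cumulative token offsets for two prefill requests
# of lengths 3 and 2, packed back-to-back along the token dimension.
query_start_loc = torch.tensor([0, 3, 5], dtype=torch.int32)
num_prefills = query_start_loc.numel() - 1
num_prefill_tokens = int(query_start_loc[-1])

# Per-token request index, built the same way as in prepare_mamba2_metadata.
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc.diff(),
    output_size=num_prefill_tokens)

print(seq_idx)  # tensor([0, 0, 0, 1, 1], dtype=torch.int32)

Given any flattened token position, a Triton program can then tell which request it belongs to; the chunk_indices_p/chunk_offsets_p pair builds on the same request boundaries for chunked prefill.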
@@ -68,7 +67,6 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
 def prepare_mamba2_metadata(
     chunk_size: int,
     attn_metadata: AttentionMetadata,
-    mamba2_metadata=None,
 ) -> Mamba2Metadata:

     # compute number of prefill and decode requests
@ -76,11 +74,11 @@ def prepare_mamba2_metadata(
num_prefills = attn_metadata.num_prefills num_prefills = attn_metadata.num_prefills
num_prefill_tokens = attn_metadata.num_prefill_tokens num_prefill_tokens = attn_metadata.num_prefill_tokens
seq_idx = None seq_idx_p = None
chunk_indices, chunk_offsets = None, None chunk_indices_p, chunk_offsets_p = None, None
# Need flags to indicate if there are initial states # Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend # currently we really only support the FlashAttention backend
has_initial_states = None has_initial_states_p = None
prep_initial_states = False prep_initial_states = False
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
@@ -91,44 +89,30 @@ def prepare_mamba2_metadata(
         # precompute flag to avoid device syncs later in mamba2 layer
         # forwards
         # prep is only needed for mamba2 ssd prefill processing
-        has_initial_states = attn_metadata.context_lens_tensor > 0
-        prep_initial_states = torch.any(
-            has_initial_states[:num_prefills]).item()
-        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc.device),
-                                          query_start_loc.diff(),
-                                          output_size=num_prefill_tokens)
-        seq_idx.unsqueeze_(0)
+        has_initial_states_p = (
+            attn_metadata.context_lens_tensor[:num_prefills] > 0)
+        prep_initial_states = torch.any(has_initial_states_p).item()
+        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx_p = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
+                                            query_start_loc_p.diff(),
+                                            output_size=num_prefill_tokens)
+        seq_idx_p.unsqueeze_(0)

         # We compute metadata for chunked prefill once at the top level model
         # forward and reuse them in mamba layers. If not needed, they will be
         # ignored inside mamba kernels.
         if prep_initial_states:
-            chunk_indices, chunk_offsets = \
+            chunk_indices_p, chunk_offsets_p = \
                 _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc, chunk_size, num_prefill_tokens)
+                    query_start_loc_p, chunk_size, num_prefill_tokens)

-    if mamba2_metadata is not None:
-        mamba2_metadata.has_initial_states = has_initial_states
-        mamba2_metadata.prep_initial_states = prep_initial_states
-        mamba2_metadata.chunk_size = chunk_size
-        mamba2_metadata.seq_idx = seq_idx
-        mamba2_metadata.chunk_indices = chunk_indices
-        mamba2_metadata.chunk_offsets = chunk_offsets
-        # We use 1 reset flag:
-        #  * mamba2_metadata.cu_seqlen is None
-        #    update config specific to (each input)
-        #    (become available at first layer, e.g. conv_weights)
-        mamba2_metadata.cu_seqlen = None  # suppose to be updated at each input
-        return mamba2_metadata
-
-    return Mamba2Metadata(has_initial_states=has_initial_states,
+    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
-                          seq_idx=seq_idx,
-                          chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                          seq_idx_p=seq_idx_p,
+                          chunk_indices_p=chunk_indices_p,
+                          chunk_offsets_p=chunk_offsets_p)


 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
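One note on the `.item()` call kept above: as the in-code comment says, the prefill flag is precomputed once at metadata-preparation time so that the mamba2 layer forwards can branch on a plain Python bool instead of a device tensor. A minimal sketch of that pattern (illustration only; the tensor values are assumed):

import torch

# Assumed per-prefill-request cached context lengths; in vLLM these come from
# attn_metadata.context_lens_tensor.
context_lens = torch.tensor([0, 7])
num_prefills = 2

has_initial_states_p = context_lens[:num_prefills] > 0        # stays a tensor
prep_initial_states = torch.any(has_initial_states_p).item()  # -> Python bool
# For GPU tensors, .item() is the single device-to-host sync; later layers can
# test `if prep_initial_states:` without synchronizing again.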

@@ -518,22 +518,19 @@ class MambaMixer2(MambaBase, CustomOp):
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
                 state_indices_tensor = attn_metadata.state_indices_tensor
+                has_initial_states_p = attn_metadata.has_initial_states_p
+                prep_initial_states = attn_metadata.prep_initial_states
+                chunk_size = attn_metadata.chunk_size
+                seq_idx_p = attn_metadata.seq_idx_p
+                chunk_indices_p = attn_metadata.chunk_indices_p
+                chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p

         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -677,15 +674,9 @@ class MambaMixer2(MambaBase, CustomOp):
             # 3. State Space Model sequence transformation
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
-                # making a copy of the states
-                if envs.VLLM_USE_V1:
-                    initial_states = torch.where(
-                        has_initial_states_p[:, None, None, None],
-                        ssm_state[state_indices_tensor_p], 0)
-                else:
-                    initial_states = torch.where(
-                        has_initial_states_p[:num_prefills, None, None, None],
-                        ssm_state[state_indices_tensor_p], 0)
+                initial_states = torch.where(
+                    has_initial_states_p[:, None, None, None],
+                    ssm_state[state_indices_tensor_p], 0)

             # NOTE: final output is an in-place update of out tensor
             varlen_state = mamba_chunk_scan_combined(
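The broadcasted `torch.where` that both former branches collapse into can be checked in isolation. A self-contained sketch (shapes and values are made up for illustration) of how it gathers cached SSM states for requests with context and zeros for fresh ones:

import torch

num_prefills, nheads, headdim, dstate = 3, 2, 4, 8
ssm_state = torch.randn(10, nheads, headdim, dstate)      # assumed state cache pool
state_indices_tensor_p = torch.tensor([4, 1, 7])          # one cache slot per prefill request
has_initial_states_p = torch.tensor([True, False, True])  # which requests have cached context

initial_states = torch.where(
    has_initial_states_p[:, None, None, None],   # broadcast over (head, dim, dstate)
    ssm_state[state_indices_tensor_p],            # gathered per-request cached states
    0)                                            # zero initial state otherwise

assert initial_states.shape == (num_prefills, nheads, headdim, dstate)
assert torch.all(initial_states[1] == 0)          # request 1 starts from a zero state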

@@ -279,22 +279,19 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
                 state_indices_tensor = attn_metadata.state_indices_tensor
+                has_initial_states_p = attn_metadata.has_initial_states_p
+                prep_initial_states = attn_metadata.prep_initial_states
+                chunk_size = attn_metadata.chunk_size
+                seq_idx_p = attn_metadata.seq_idx_p
+                chunk_indices_p = attn_metadata.chunk_indices_p
+                chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
-        # Common members between V1 metadata and V0 metadata
-        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -414,14 +411,10 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
         initial_states = None
         if has_initial_states_p is not None and prep_initial_states:
             # making a copy of the states
-            if envs.VLLM_USE_V1:
-                initial_states = torch.where(
-                    has_initial_states_p[:, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
-            else:
-                initial_states = torch.where(
-                    has_initial_states_p[:num_prefills, None, None, None],
-                    ssm_state[state_indices_tensor_p], 0)
+            initial_states = torch.where(
+                has_initial_states_p[:, None, None, None],
+                ssm_state[state_indices_tensor_p], 0)
+
         varlen_state = mamba_chunk_scan_combined(
             hidden_states_p.view(1, num_prefill_tokens,
                                  self.num_heads // self.tp_size,