[Model] Clean up and simplify Mamba2 Metadata Usage in both V0 and V1 (#24331)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
commit 73cfb3c5ee
parent 4e5affeaa1
@@ -17,14 +17,13 @@ from vllm.v1.attention.backends.mamba2_attn import (
 
 @dataclass
 class Mamba2Metadata:
 
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
 
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    has_initial_states_p: torch.Tensor
+    seq_idx_p: torch.Tensor
+    chunk_indices_p: torch.Tensor
+    chunk_offsets_p: torch.Tensor
     """
     With continuous batching layout of `x` in vLLM, to enable a Triton program
     to handle a request in parallel, two supporting tensors are used
@@ -68,7 +67,6 @@ def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
 def prepare_mamba2_metadata(
     chunk_size: int,
     attn_metadata: AttentionMetadata,
-    mamba2_metadata=None,
 ) -> Mamba2Metadata:
 
     # compute number of prefill and decode requests
@@ -76,11 +74,11 @@ def prepare_mamba2_metadata(
     num_prefills = attn_metadata.num_prefills
     num_prefill_tokens = attn_metadata.num_prefill_tokens
 
-    seq_idx = None
-    chunk_indices, chunk_offsets = None, None
+    seq_idx_p = None
+    chunk_indices_p, chunk_offsets_p = None, None
     # Need flags to indicate if there are initial states
     # currently we really only support the FlashAttention backend
-    has_initial_states = None
+    has_initial_states_p = None
     prep_initial_states = False
 
     # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
@@ -91,44 +89,30 @@ def prepare_mamba2_metadata(
         # precompute flag to avoid device syncs later in mamba2 layer
         # forwards
         # prep is only needed for mamba2 ssd prefill processing
-        has_initial_states = attn_metadata.context_lens_tensor > 0
-        prep_initial_states = torch.any(
-            has_initial_states[:num_prefills]).item()
-        query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
-        seq_idx = torch.repeat_interleave(torch.arange(
-            num_prefills, dtype=torch.int32, device=query_start_loc.device),
-                                          query_start_loc.diff(),
-                                          output_size=num_prefill_tokens)
-        seq_idx.unsqueeze_(0)
+        has_initial_states_p = (
+            attn_metadata.context_lens_tensor[:num_prefills] > 0)
+        prep_initial_states = torch.any(has_initial_states_p).item()
+        query_start_loc_p = attn_metadata.query_start_loc[:num_prefills + 1]
+        seq_idx_p = torch.repeat_interleave(torch.arange(
+            num_prefills, dtype=torch.int32, device=query_start_loc_p.device),
+                                            query_start_loc_p.diff(),
+                                            output_size=num_prefill_tokens)
+        seq_idx_p.unsqueeze_(0)
 
         # We compute metadata for chunked prefill once at the top level model
         # forward and reuse them in mamba layers. If not needed, they will be
         # ignored inside mamba kernels.
         if prep_initial_states:
-            chunk_indices, chunk_offsets = \
+            chunk_indices_p, chunk_offsets_p = \
                 _query_start_loc_to_chunk_indices_offsets(
-                    query_start_loc, chunk_size, num_prefill_tokens)
+                    query_start_loc_p, chunk_size, num_prefill_tokens)
 
-    if mamba2_metadata is not None:
-        mamba2_metadata.has_initial_states = has_initial_states
-        mamba2_metadata.prep_initial_states = prep_initial_states
-        mamba2_metadata.chunk_size = chunk_size
-        mamba2_metadata.seq_idx = seq_idx
-        mamba2_metadata.chunk_indices = chunk_indices
-        mamba2_metadata.chunk_offsets = chunk_offsets
-        # We use 1 reset flag:
-        #  * mamba2_metadata.cu_seqlen is None
-        #    update config specific to (each input)
-        #    (become available at first layer, e.g. conv_weights)
-        mamba2_metadata.cu_seqlen = None  # suppose to be updated at each input
-
-        return mamba2_metadata
-    return Mamba2Metadata(has_initial_states=has_initial_states,
+    return Mamba2Metadata(has_initial_states_p=has_initial_states_p,
                           prep_initial_states=prep_initial_states,
                           chunk_size=chunk_size,
-                          seq_idx=seq_idx,
-                          chunk_indices=chunk_indices,
-                          chunk_offsets=chunk_offsets)
+                          seq_idx_p=seq_idx_p,
+                          chunk_indices_p=chunk_indices_p,
+                          chunk_offsets_p=chunk_offsets_p)
 
 
 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
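The seq_idx_p construction in this hunk is easier to see in isolation. Below is a minimal, self-contained sketch of the same torch.repeat_interleave pattern, using a made-up query_start_loc_p for three prefill requests (the tensor values are illustrative, not taken from vLLM):

import torch

# Hypothetical prefill batch: 3 requests contributing 3, 4 and 2 tokens.
# query_start_loc_p holds cumulative token offsets, as in the diff above.
query_start_loc_p = torch.tensor([0, 3, 7, 9], dtype=torch.int32)
num_prefills = query_start_loc_p.numel() - 1
num_prefill_tokens = int(query_start_loc_p[-1])

# Repeat each request index once per token that request owns, so every
# prefill token ends up tagged with the request it belongs to.
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)

print(seq_idx_p)  # tensor([[0, 0, 0, 1, 1, 1, 1, 2, 2]], dtype=torch.int32)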
@@ -518,22 +518,19 @@ class MambaMixer2(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            has_initial_states_p = attn_metadata.has_initial_states_p
-            prep_initial_states = attn_metadata.prep_initial_states
-            chunk_size = attn_metadata.chunk_size
-            seq_idx_p = attn_metadata.seq_idx_p
-            chunk_indices_p = attn_metadata.chunk_indices_p
-            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
+
+        # Common members between V1 metadata and V0 metadata
+        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
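The new "Common members" block works because, after this change, the V1 attention metadata and the V0 Mamba2Metadata expose the same prefill-only (`_p`) field names. A rough sketch with a toy stand-in class (not the real vLLM metadata types) shows the kind of single read path the mixer can now use:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMamba2Metadata:
    # Illustrative stand-in only; the field names mirror the diff above.
    prep_initial_states: bool
    chunk_size: int
    has_initial_states_p: Optional[torch.Tensor]
    seq_idx_p: Optional[torch.Tensor]
    chunk_indices_p: Optional[torch.Tensor]
    chunk_offsets_p: Optional[torch.Tensor]


def read_common_members(mamba2_metadata: Optional[ToyMamba2Metadata]):
    # Mirrors the "Common members between V1 metadata and V0 metadata" block:
    # one branch, no per-field divergence, because both metadata flavours
    # carry identically named attributes.
    if mamba2_metadata is None:
        return None
    return (mamba2_metadata.has_initial_states_p,
            mamba2_metadata.prep_initial_states,
            mamba2_metadata.chunk_size,
            mamba2_metadata.seq_idx_p,
            mamba2_metadata.chunk_indices_p,
            mamba2_metadata.chunk_offsets_p)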
@@ -677,15 +674,9 @@ class MambaMixer2(MambaBase, CustomOp):
             # 3. State Space Model sequence transformation
             initial_states = None
             if (has_initial_states_p is not None and prep_initial_states):
-                # making a copy of the states
-                if envs.VLLM_USE_V1:
-                    initial_states = torch.where(
-                        has_initial_states_p[:, None, None, None],
-                        ssm_state[state_indices_tensor_p], 0)
-                else:
-                    initial_states = torch.where(
-                        has_initial_states_p[:num_prefills, None, None, None],
-                        ssm_state[state_indices_tensor_p], 0)
+                initial_states = torch.where(
+                    has_initial_states_p[:, None, None, None],
+                    ssm_state[state_indices_tensor_p], 0)
 
             # NOTE: final output is an in-place update of out tensor
             varlen_state = mamba_chunk_scan_combined(
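Since has_initial_states_p is now sliced to the prefill requests when the metadata is prepared, the same torch.where expression works for both V0 and V1 and the envs.VLLM_USE_V1 branch above becomes unnecessary. A small sketch with made-up shapes (toy values, not vLLM's real dimensions) shows how the broadcast boolean mask gathers each request's cached SSM state and zeroes it where there is no prior context:

import torch

# Toy sizes: 3 prefill requests, 2 heads, head_dim=4, d_state=5, and a
# hypothetical cache pool with 8 slots.
num_prefills, num_heads, head_dim, d_state = 3, 2, 4, 5
ssm_state = torch.randn(8, num_heads, head_dim, d_state)
state_indices_tensor_p = torch.tensor([5, 0, 2])          # cache slot per prefill request
has_initial_states_p = torch.tensor([True, False, True])  # already prefill-only

# Gather each request's cached state, then zero it out where the request has
# no prior context; the mask broadcasts over the (head, dim, state) axes.
initial_states = torch.where(
    has_initial_states_p[:, None, None, None],
    ssm_state[state_indices_tensor_p], 0)

assert initial_states.shape == (num_prefills, num_heads, head_dim, d_state)
assert torch.all(initial_states[1] == 0)  # the request without prior context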
@@ -279,22 +279,19 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             state_indices_tensor = attn_metadata.state_indices_tensor
-            has_initial_states_p = attn_metadata.has_initial_states_p
-            prep_initial_states = attn_metadata.prep_initial_states
-            chunk_size = attn_metadata.chunk_size
-            seq_idx_p = attn_metadata.seq_idx_p
-            chunk_indices_p = attn_metadata.chunk_indices_p
-            chunk_offsets_p = attn_metadata.chunk_offsets_p
         else:
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
-            has_initial_states_p = mamba2_metadata.has_initial_states
+
+        # Common members between V1 metadata and V0 metadata
+        if mamba2_metadata is not None:
+            has_initial_states_p = mamba2_metadata.has_initial_states_p
             prep_initial_states = mamba2_metadata.prep_initial_states
             chunk_size = mamba2_metadata.chunk_size
-            seq_idx_p = mamba2_metadata.seq_idx
-            chunk_indices_p = mamba2_metadata.chunk_indices
-            chunk_offsets_p = mamba2_metadata.chunk_offsets
+            seq_idx_p = mamba2_metadata.seq_idx_p
+            chunk_indices_p = mamba2_metadata.chunk_indices_p
+            chunk_offsets_p = mamba2_metadata.chunk_offsets_p
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)
@@ -414,14 +411,10 @@ class Plamo2MambaMixer(MambaBase, CustomOp):
             initial_states = None
             if has_initial_states_p is not None and prep_initial_states:
                 # making a copy of the states
-                if envs.VLLM_USE_V1:
-                    initial_states = torch.where(
-                        has_initial_states_p[:, None, None, None],
-                        ssm_state[state_indices_tensor_p], 0)
-                else:
-                    initial_states = torch.where(
-                        has_initial_states_p[:num_prefills, None, None, None],
-                        ssm_state[state_indices_tensor_p], 0)
+                initial_states = torch.where(
+                    has_initial_states_p[:, None, None, None],
+                    ssm_state[state_indices_tensor_p], 0)
+
             varlen_state = mamba_chunk_scan_combined(
                 hidden_states_p.view(1, num_prefill_tokens,
                                      self.num_heads // self.tp_size,
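The `[:num_prefills]` indexing removed here (and in MambaMixer2 above) is safe to drop because prepare_mamba2_metadata now applies that slice once, when the flag is built. A tiny sketch with made-up context lengths (hypothetical values, not from vLLM) checks that the old use-site slicing and the new pre-sliced flag agree:

import torch

# Hypothetical batch: the first 2 requests are prefills, the rest are decodes.
context_lens_tensor = torch.tensor([7, 0, 3, 1])
num_prefills = 2

# Old V0 behaviour: flag every request, slice to prefills at the use site.
has_initial_states = context_lens_tensor > 0
old_prefill_flags = has_initial_states[:num_prefills]

# New behaviour: slice once while preparing the metadata.
has_initial_states_p = context_lens_tensor[:num_prefills] > 0

assert torch.equal(old_prefill_flags, has_initial_states_p)  # tensor([True, False])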