mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-31 04:57:03 +08:00
[Model] Mamba2 causal conv1d Refactor to Split Prefill and Decode Requests for Corresponding Kernels (#17146)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
This commit is contained in:
parent
6de3e13413
commit
18dd5e01f2
@ -6,7 +6,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from vllm.model_executor.layers.mamba.mamba2_metadata import (
|
||||
_seq_idx_to_chunk_indices_offsets)
|
||||
_query_start_loc_to_chunk_indices_offsets)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined)
|
||||
from vllm.platforms import current_platform
|
||||
@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
|
||||
last_taken, exhausted, n_heads,
|
||||
d_head, itype):
|
||||
|
||||
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
|
||||
seq_idx, chunk_size)
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
cu_seqlens, chunk_size, cu_seqlens[-1])
|
||||
|
||||
Y, new_states = mamba_chunk_scan_combined(
|
||||
X,
|
||||
|
||||
@ -13,7 +13,6 @@ from vllm.attention.backends.xformers import XFormersMetadata
|
||||
|
||||
@dataclass
|
||||
class Mamba2Metadata:
|
||||
has_prefill: bool
|
||||
|
||||
has_initial_states: torch.Tensor
|
||||
prep_initial_states: bool
|
||||
@ -24,21 +23,23 @@ class Mamba2Metadata:
|
||||
chunk_offsets: torch.Tensor
|
||||
|
||||
|
||||
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
|
||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
total_seqlens: int):
|
||||
|
||||
# convert seq_idx to chunk indices and offsets
|
||||
# - derive the cu_seqlens
|
||||
_, cu_seqlens = torch.where(seq_idx.diff())
|
||||
cu_seqlens += 1
|
||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||
|
||||
# outputs will have length expansion of chunks that do not divide
|
||||
# chunk_size
|
||||
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
|
||||
> 0).sum()
|
||||
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
|
||||
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
|
||||
N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size
|
||||
> 0).sum()
|
||||
chunk_indices = torch.arange(N,
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
chunk_offsets = torch.zeros((N, ),
|
||||
dtype=torch.int,
|
||||
device=query_start_loc.device)
|
||||
|
||||
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
|
||||
p = 0 # num of insertions
|
||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||
|
||||
@ -60,48 +61,49 @@ def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
|
||||
|
||||
def prepare_mamba2_metadata(
|
||||
chunk_size: int,
|
||||
input_ids: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Mamba2Metadata:
|
||||
|
||||
# compute number of prefill and decode requests
|
||||
# NOTE: in V0 we assume prefills are before decodes
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens
|
||||
|
||||
seq_idx = None
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
# Need flags to indicate if there are initial states
|
||||
# currently we really only support the FlashAttention backend
|
||||
has_initial_states = None
|
||||
prep_initial_states = False
|
||||
if (isinstance(attn_metadata, (FlashAttentionMetadata, XFormersMetadata,
|
||||
PlaceholderAttentionMetadata))
|
||||
and attn_metadata.context_lens_tensor is not None):
|
||||
has_initial_states = attn_metadata.context_lens_tensor > 0
|
||||
# precompute flag to avoid device syncs later in mamba2 forwards
|
||||
prep_initial_states = torch.any(has_initial_states).item()
|
||||
|
||||
has_prefill = attn_metadata.num_prefills > 0
|
||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
||||
if num_prefills > 0:
|
||||
if (isinstance(attn_metadata,
|
||||
(FlashAttentionMetadata, XFormersMetadata,
|
||||
PlaceholderAttentionMetadata))
|
||||
and attn_metadata.context_lens_tensor is not None):
|
||||
has_initial_states = \
|
||||
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
|
||||
# precompute flag to avoid device syncs in mamba2 layer forwards
|
||||
# prep is only needed for mamba2 ssd prefill processing
|
||||
prep_initial_states = torch.any(has_initial_states).item()
|
||||
|
||||
seq_idx = None
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
if has_prefill:
|
||||
seq_idx = torch.zeros_like(input_ids, dtype=torch.int32)
|
||||
for i, (srt, end) in enumerate(
|
||||
zip(
|
||||
attn_metadata.query_start_loc,
|
||||
attn_metadata.query_start_loc[1:],
|
||||
)):
|
||||
seq_idx[srt:end] = i
|
||||
query_start_loc = attn_metadata.query_start_loc[:num_prefills + 1]
|
||||
seq_idx = torch.repeat_interleave(torch.arange(
|
||||
num_prefills, dtype=torch.int32, device=query_start_loc.device),
|
||||
query_start_loc.diff(),
|
||||
output_size=num_prefill_tokens)
|
||||
seq_idx.unsqueeze_(0)
|
||||
|
||||
# compute metadata for chunked prefill.
|
||||
# actually this is only needed if there are initial states,
|
||||
# but this is determinable only from attention metadata yet
|
||||
# unavailable from the top-level model forward. Rather than
|
||||
# complicating things to extract said metadata, we simply just
|
||||
# compute them once at the top level model forward and reuse
|
||||
# them in mamba layers. If not needed, they will be ignored
|
||||
# inside mamba kernels.
|
||||
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
|
||||
seq_idx, chunk_size)
|
||||
# We compute metadata for chunked prefill once at the top level model
|
||||
# forward and reuse them in mamba layers. If not needed, they will be
|
||||
# ignored inside mamba kernels.
|
||||
if prep_initial_states:
|
||||
chunk_indices, chunk_offsets = \
|
||||
_query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc, chunk_size, num_prefill_tokens)
|
||||
|
||||
return Mamba2Metadata(has_prefill=has_prefill,
|
||||
has_initial_states=has_initial_states,
|
||||
return Mamba2Metadata(has_initial_states=has_initial_states,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
|
||||
@ -388,10 +388,15 @@ class MambaMixer2(CustomOp):
|
||||
# mamba2_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# are the same and reused for all mamba layers in the same iteration
|
||||
# stay the same and are reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
|
||||
|
||||
seq_len, _ = hidden_states.shape
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
|
||||
groups_time_state_size = self.n_groups * self.ssm_state_size
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
@ -406,44 +411,32 @@ class MambaMixer2(CustomOp):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if mamba2_metadata.has_prefill:
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||
hidden_states_B_C = causal_conv1d_fn(
|
||||
hidden_states_B_C.transpose(0, 1),
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=mamba2_metadata.has_initial_states,
|
||||
cache_indices=mamba_cache_params.state_indices_tensor,
|
||||
query_start_loc=attn_metadata.query_start_loc).transpose(
|
||||
0, 1)[:seq_len]
|
||||
|
||||
# TODO: Why is this needed?
|
||||
hidden_states_B_C = hidden_states_B_C.contiguous()
|
||||
else:
|
||||
hidden_states_B_C = causal_conv1d_update(
|
||||
hidden_states_B_C,
|
||||
mamba_cache_params.conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=mamba_cache_params.state_indices_tensor)
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
dt_p, dt_d = torch.split(
|
||||
dt,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
mamba_cache_params.state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
|
||||
if has_prefill else None)
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
hidden_states, B, C = torch.split(
|
||||
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
@ -453,24 +446,48 @@ class MambaMixer2(CustomOp):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
if mamba2_metadata.has_prefill:
|
||||
ssd_output_list = []
|
||||
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "mamba_cache_params.state_indices_tensor"
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
hidden_states_B_C_p.transpose(0, 1),
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=mamba_cache_params.conv_state,
|
||||
has_initial_state=mamba2_metadata.has_initial_states,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
# TODO: Why is this needed?
|
||||
hidden_states_B_C_p = hidden_states_B_C_p.contiguous()
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_p)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if (mamba2_metadata.has_initial_states is not None
|
||||
and mamba2_metadata.prep_initial_states):
|
||||
# making a copy of the states
|
||||
initial_states = torch.where(
|
||||
mamba2_metadata.has_initial_states[:, None, None, None],
|
||||
mamba_cache_params.ssm_state[
|
||||
mamba_cache_params.state_indices_tensor], 0)
|
||||
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
|
||||
|
||||
scan_output, varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
|
||||
self.head_dim),
|
||||
dt.unsqueeze(0),
|
||||
hidden_states_p.view(1, num_prefill_tokens,
|
||||
self.num_heads // self.tp_size,
|
||||
self.head_dim),
|
||||
dt_p.unsqueeze(0),
|
||||
self.A,
|
||||
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
|
||||
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
|
||||
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
|
||||
-1),
|
||||
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
|
||||
-1),
|
||||
chunk_size=mamba2_metadata.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
@ -478,7 +495,7 @@ class MambaMixer2(CustomOp):
|
||||
seq_idx=mamba2_metadata.seq_idx,
|
||||
chunk_indices=mamba2_metadata.chunk_indices,
|
||||
chunk_offsets=mamba2_metadata.chunk_offsets,
|
||||
cu_seqlens=attn_metadata.query_start_loc,
|
||||
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
|
||||
initial_states=initial_states,
|
||||
return_varlen_states=True,
|
||||
return_final_states=False,
|
||||
@ -487,52 +504,65 @@ class MambaMixer2(CustomOp):
|
||||
)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (batch, nheads, headdim, dstate) tensor
|
||||
mamba_cache_params.ssm_state[
|
||||
mamba_cache_params.state_indices_tensor] = varlen_state
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
|
||||
|
||||
# - reshape
|
||||
hidden_states = scan_output.view(seq_len, -1)
|
||||
else:
|
||||
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
mamba_cache_params.conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d)
|
||||
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_d)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
A = self.A[:, None, ...][:, :, None].expand(
|
||||
A_d = self.A[:, None, ...][:, :, None].expand(
|
||||
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
|
||||
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||
D = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B = B.view(-1, n_groups, B.shape[1] // n_groups)
|
||||
C = C.view(-1, n_groups, C.shape[1] // n_groups)
|
||||
hidden_states_reshaped = hidden_states.view(
|
||||
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
|
||||
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
|
||||
hidden_states_d = hidden_states_d.view(
|
||||
-1, self.num_heads // self.tp_size, self.head_dim)
|
||||
|
||||
# - the hidden is reshaped into number of current batches
|
||||
# - in this case there is no more prefill, so the batches gen
|
||||
# 1 token at a time
|
||||
# - thus hidden will be (bs, num_heads, head_dim)
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# using "mamba_cache_params.state_indices_tensor", just as
|
||||
# above in the prefill case
|
||||
# using state_indices_tensor_d
|
||||
|
||||
hidden_states = selective_state_update(
|
||||
hidden_states_d = selective_state_update(
|
||||
mamba_cache_params.ssm_state,
|
||||
hidden_states_reshaped,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
hidden_states_d,
|
||||
dt_d,
|
||||
A_d,
|
||||
B_d,
|
||||
C_d,
|
||||
D_d,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=mamba_cache_params.state_indices_tensor,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
)
|
||||
hidden_states = hidden_states.view(
|
||||
-1, (self.num_heads // self.tp_size) * self.head_dim)
|
||||
ssd_output_list.append(
|
||||
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
|
||||
self.head_dim))
|
||||
|
||||
# # 4. gated MLP
|
||||
# Merge prefill and decode outputs before passing to gated MLP
|
||||
hidden_states = torch.vstack(ssd_output_list)
|
||||
|
||||
# 4. gated MLP
|
||||
hidden_states = self.norm(hidden_states, gate)
|
||||
|
||||
# # 5. Final linear projection
|
||||
# 5. Final linear projection
|
||||
out, _ = self.out_proj(hidden_states)
|
||||
return out
|
||||
|
||||
@ -40,7 +40,6 @@ def _mamba_chunk_scan_combined_fwd(x,
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert x.shape == (batch, seqlen, nheads, headdim)
|
||||
assert dt.shape == (batch, seqlen, nheads)
|
||||
assert A.shape == (nheads, )
|
||||
assert C.shape == B.shape
|
||||
|
||||
@ -313,7 +313,6 @@ class BambaModel(nn.Module):
|
||||
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
input_ids=input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -338,7 +338,6 @@ class GraniteMoeHybridModel(nn.Module):
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.mamba_chunk_size,
|
||||
input_ids=input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -142,7 +142,6 @@ class Mamba2Model(nn.Module):
|
||||
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.chunk_size,
|
||||
input_ids=input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
@ -751,7 +751,6 @@ class Zamba2Model(nn.Module):
|
||||
|
||||
mamba2_metadata = prepare_mamba2_metadata(
|
||||
chunk_size=self.config.chunk_size,
|
||||
input_ids=input_ids,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user