[V1] - Split Prefill and Decode for Mamba1 models (#22653)

Signed-off-by: amirk <amirk@ai21.com>
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
amirai21 2025-08-15 11:59:52 +03:00 committed by GitHub
parent 5406ebf5c9
commit fe91ce9591
3 changed files with 251 additions and 93 deletions


@@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
# Avoid OOM
MAX_NUM_SEQS = 4
# Once we add support for FCG in Mamba1, this list will be removed and
# all test cases will use enforce_eager=False
ENFORCE_EAGER_MODELS_V1 = [
"state-spaces/mamba-130m-hf",
"ai21labs/Jamba-tiny-dev",
]
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
@pytest.mark.parametrize("max_tokens", [64])
@@ -94,13 +101,19 @@ def test_models(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
enforce_eager = False
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
if model in ENFORCE_EAGER_MODELS_V1:
enforce_eager = True
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=enforce_eager,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
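For context, a minimal standalone sketch of running one of these models the way the updated test does; the model name and settings come from the diff above, while the rest is ordinary vLLM usage and not part of this change:

import os

# Mirror the test's environment setup for a V1 run of a hybrid model.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # hybrid models only

from vllm import LLM, SamplingParams

# Jamba-tiny-dev is listed in ENFORCE_EAGER_MODELS_V1, so keep
# enforce_eager=True until FCG support lands for Mamba1.
llm = LLM(model="ai21labs/Jamba-tiny-dev",
          max_num_seqs=4,
          enforce_eager=True,
          enable_prefix_caching=False)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)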


@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import NamedTuple, Optional
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -154,13 +155,38 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix = prefix
def _ssm_transform(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
ssm_params = self.x_proj(x.contiguous())[0]
else:
ssm_params = self.x_proj(x)[0]
time_step, B, C = torch.split(
ssm_params,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
return discrete_time_step, B, C
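A small illustration of the split performed by _ssm_transform; the sizes below are made up for the sketch, whereas in the module they come from time_step_rank and ssm_state_size:

import torch

# Say time_step_rank = 4 and ssm_state_size = 16: x_proj emits
# 4 + 16 + 16 = 36 features per token, carved into (dt, B, C).
ssm_params = torch.randn(10, 36)  # 10 tokens
dt, B, C = torch.split(ssm_params, [4, 16, 16], dim=-1)
assert dt.shape == (10, 4) and B.shape == (10, 16) and C.shape == (10, 16)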
def forward(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
if not envs.VLLM_USE_V1:
return CustomOp.forward(self, hidden_states, mamba_cache_params)
else:
return self.forward_cuda(hidden_states, mamba_cache_params)
return self.forward_cuda(
hidden_states,
mamba_cache_params,
)
def forward_native(self,
hidden_states: torch.Tensor,
@@ -170,6 +196,27 @@ class MambaMixer(MambaBase, CustomOp):
def forward_cuda(self,
hidden_states: torch.Tensor,
mamba_cache_params: Optional[MambaCacheParams] = None):
"""
Run the Mamba-1 SSM pipeline.
Steps
-----
1. Apply the gated-MLP linear projection to the raw input.
2. Pass the projected sequence through the convolutional mixing layer.
3. Feed the result into the State-Space Model (SSM) blocks.
4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
to produce contextual representations.
5. Project the contextualized sequence back
to the output embedding dimension.
Batch handling
--------------
Prefill and decode tokens are processed by dedicated CUDA
kernels for both the convolutional (conv1d) and SSM stages.
In the case of a mixed batch (containing both prefill and
decode tokens), both sets of kernels are executed independently
and their outputs are concatenated before the final output projection.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -185,126 +232,142 @@
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
has_initial_state = mamba1_metadata.has_initial_states
context_lens_tensor = mamba1_metadata.context_lens_tensor
has_initial_states = mamba1_metadata.has_initial_states
else:
assert isinstance(attn_metadata, AttentionMetadata)
assert mamba_cache_params is not None
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
query_start_loc = attn_metadata.query_start_loc
context_lens_tensor = attn_metadata.context_lens_tensor
has_initial_states = None
if context_lens_tensor is not None:
has_initial_state = context_lens_tensor > 0
has_initial_states = context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states = hidden_states.contiguous()
return self.out_proj(hidden_states.transpose(-2, -1))[0]
hidden_states_BC = hidden_states_BC.contiguous()
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
if query_start_loc is not None and context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
num_decode_tokens = attn_metadata.num_decode_tokens  # token count
num_prefills = attn_metadata.num_prefills  # request count
num_decodes = attn_metadata.num_decode_tokens  # request count (1 token per decode)
has_prefill = num_prefill_tokens > 0
has_decode = num_decode_tokens > 0
prefill_decode_split = split_batch_to_prefill_and_decode(
hidden_states_BC,
gate,
state_indices_tensor,
query_start_loc,
has_initial_states,
num_prefill_tokens,
num_decode_tokens,
num_prefills,
num_decodes,
)
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
gate_p = prefill_decode_split.gate_p
gate_d = prefill_decode_split.gate_d
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
query_start_loc_p = prefill_decode_split.query_start_loc_p
has_initial_states_p = prefill_decode_split.has_initial_states_p
ssm_outputs = []
if has_prefill:
# 2. Convolution sequence transformation
conv_out_p = causal_conv1d_fn(
hidden_states_BC_p,
conv_weights,
bias=self.conv1d.bias,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_state,
cache_indices=state_indices_tensor,
query_start_loc=query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
conv_out_p.transpose(-2, -1))
time_proj_bias = self._time_proj_bias()
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_out_p = selective_scan_fn(
conv_out_p,
ssm_state,
discrete_time_step_p,
self.A,
B_p.transpose(-2, -1),
C_p.transpose(-2, -1),
self.D.float(),
gate_p,
time_proj_bias,
delta_softplus=True,
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p)
ssm_outputs.append(scan_out_p)
if has_decode:
# 2. Convolution sequence transformation
conv_out_d = causal_conv1d_update(
hidden_states_BC_d.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
# 3. State Space Model sequence transformation.
discrete_time_step_d, B_d, C_d = self._ssm_transform(
conv_out_d.transpose(-2, -1))
time_proj_bias = self._time_proj_bias()
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
ssm_parameters = self.x_proj(
hidden_states.transpose(-2, -1).contiguous())[0]
else:
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)
if query_start_loc is not None and context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=state_indices_tensor,
has_initial_state=has_initial_state,
query_start_loc=query_start_loc)
else:
scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_outputs_d = torch.empty_like(
hidden_states_BC_d.transpose(0, 1))
selective_state_update(ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B,
C,
B_d,
C_d,
self.D,
gate.transpose(0, 1),
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=scan_outputs)
scan_outputs = scan_outputs.transpose(0, 1)
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
# 4. Final linear projection
if self.is_lora_enabled:
# lora kernel requires contiguous tensor
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1).contiguous())[0]
if envs.VLLM_USE_V1:
ssm_outputs.insert(0, scan_outputs_d)
else:
ssm_outputs.append(scan_outputs_d)
scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
# 5. Final output projection
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
scan_outputs_combined = scan_outputs_combined.transpose(
-2, -1).contiguous()
out = self.out_proj(scan_outputs_combined)[0]
else:
contextualized_states = self.out_proj(
scan_outputs.transpose(-2, -1))[0]
return contextualized_states
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
return out
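To see why the decode chunk is inserted at index 0 under V1, here is a standalone toy check of the split/concat round trip; the token counts are arbitrary and only the ordering matters:

import torch

# Hypothetical mixed V1 batch: 2 decode tokens followed by 5 prefill tokens,
# laid out along the last (token) dimension as hidden_states_BC is.
x = torch.arange(7).repeat(3, 1)            # (channels=3, tokens=7)
x_d, x_p = torch.split(x, [2, 5], dim=-1)   # decode tokens first under V1
outs = [x_p]                                # prefill output is computed first
outs.insert(0, x_d)                         # decode output goes to the front
assert torch.equal(torch.cat(outs, dim=-1), x)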
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.mamba1_state_shape(
@@ -317,3 +380,69 @@ class MambaMixer(MambaBase, CustomOp):
@property
def mamba_type(self) -> str:
return "mamba1"
def _time_proj_bias(self) -> Optional[torch.Tensor]:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
return None
class PrefillDecodeSplit(NamedTuple):
hidden_states_BC_p: torch.Tensor
hidden_states_BC_d: torch.Tensor
gate_p: torch.Tensor
gate_d: torch.Tensor
state_indices_tensor_p: torch.Tensor
state_indices_tensor_d: torch.Tensor
query_start_loc_p: Optional[torch.Tensor]
has_initial_states_p: Optional[torch.Tensor]
def split_batch_to_prefill_and_decode(
hidden_states_BC: torch.Tensor,
gate: torch.Tensor,
state_indices_tensor: torch.Tensor,
query_start_loc: torch.Tensor,
has_initial_states: Optional[torch.Tensor],
num_prefill_tokens: int,
num_decode_tokens: int,
num_prefills: int,
num_decodes: int,
) -> PrefillDecodeSplit:
if envs.VLLM_USE_V1:
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
gate_d, gate_p = torch.split(gate,
[num_decode_tokens, num_prefill_tokens],
dim=-1)
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor, [num_decodes, num_prefills], dim=0)
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
else:
# In v0, prefill tokens come first, then decode tokens.
hidden_states_BC_p, hidden_states_BC_d = torch.split(
hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
gate_p, gate_d = torch.split(gate,
[num_prefill_tokens, num_decode_tokens],
dim=-1)
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor, [num_prefills, num_decodes], dim=0)
query_start_loc_p = (query_start_loc[:num_prefills +
1] if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[:num_prefills] if (
has_initial_states is not None and num_prefills > 0) else None
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,
hidden_states_BC_d=hidden_states_BC_d,
gate_p=gate_p,
gate_d=gate_d,
state_indices_tensor_p=state_indices_tensor_p,
state_indices_tensor_d=state_indices_tensor_d,
query_start_loc_p=query_start_loc_p,
has_initial_states_p=has_initial_states_p,
)
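The least obvious part of the V1 branch is the query_start_loc_p adjustment; here is a toy numeric check of that slicing, with values chosen for the sketch rather than taken from a real batch:

import torch

# 2 decode requests (1 token each) followed by 2 prefill requests of
# lengths 3 and 2 -> cumulative starts over the whole mixed batch:
query_start_loc = torch.tensor([0, 1, 2, 5, 7])
num_decodes, num_prefills = 2, 2

# Keep the last num_prefills + 1 boundaries and shift them so they start
# at zero relative to the prefill-only token stream.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
assert torch.equal(query_start_loc_p, torch.tensor([0, 3, 5]))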


@@ -2,14 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar
from typing import ClassVar, Optional
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -25,12 +26,15 @@ class Mamba1AttentionMetadata:
query_start_loc: torch.Tensor
context_lens_tensor: torch.Tensor
state_indices_tensor: torch.Tensor
has_initial_states: torch.Tensor
has_initial_states: Optional[torch.Tensor]
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
class Mamba1AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba1AttentionMetadata]):
reorder_batch_threshold: ClassVar[int] = 1
def __init__(
@@ -57,11 +61,23 @@ class Mamba1AttentionMetadataBuilder(
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
query_start_loc.device)
has_initial_states = (context_lens_tensor > 0)
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(common_attn_metadata,
decode_threshold=1))
has_initial_states = None
if num_prefills > 0:
has_initial_states = context_lens_tensor > 0
return Mamba1AttentionMetadata(
query_start_loc=query_start_loc,
context_lens_tensor=context_lens_tensor,
has_initial_states=has_initial_states,
state_indices_tensor=state_indices_tensor,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
)
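For intuition about the counts returned by split_decodes_and_prefills, a standalone sketch (not the vLLM helper itself): with the batch reordered so single-token requests come first, a decode_threshold of 1 classifies exactly those requests as decodes.

import torch

# Hypothetical per-request query lengths after reorder_batch:
# decode requests (query_len == 1) first, then prefill requests.
query_lens = torch.tensor([1, 1, 5, 3])
decode_threshold = 1

num_decodes = int((query_lens <= decode_threshold).sum())
num_prefills = query_lens.numel() - num_decodes
num_decode_tokens = int(query_lens[:num_decodes].sum())
num_prefill_tokens = int(query_lens[num_decodes:].sum())
assert (num_decodes, num_prefills, num_decode_tokens,
        num_prefill_tokens) == (2, 2, 2, 8)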