[V1] - Split Prefill and Decode for Mamba1 models (#22653)

Signed-off-by: amirk <amirk@ai21.com>
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
Authored by amirai21 on 2025-08-15 11:59:52 +03:00, committed by GitHub
parent 5406ebf5c9
commit fe91ce9591
3 changed files with 251 additions and 93 deletions

View File

@@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
+# Once we add support for FCG in Mamba1, this list will be removed and
+# all test cases will use enforce_eager=False
+ENFORCE_EAGER_MODELS_V1 = [
+    "state-spaces/mamba-130m-hf",
+    "ai21labs/Jamba-tiny-dev",
+]
+
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -94,13 +101,19 @@ def test_models(
             example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
+        enforce_eager = False
         with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
                 # required due to reorder_batch behaviour
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+            if model in ENFORCE_EAGER_MODELS_V1:
+                enforce_eager = True
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=enforce_eager,
                              enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)

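Note: in vLLM, enforce_eager=True skips CUDA graph capture and runs the model eagerly, which is why the Mamba1-based models above are pinned to eager mode until full CUDA graph (FCG) support lands. A minimal sketch of the toggle the test applies; the helper name is hypothetical and not part of this diff:

def resolve_enforce_eager(model: str) -> bool:
    # Mamba1-based V1 models cannot use full CUDA graphs yet,
    # so they must run with CUDA graph capture disabled.
    return model in ENFORCE_EAGER_MODELS_V1

assert resolve_enforce_eager("ai21labs/Jamba-tiny-dev")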
View File

@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import NamedTuple, Optional
 
 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -154,13 +155,38 @@ class MambaMixer(MambaBase, CustomOp):
         self.prefix = prefix
 
+    def _ssm_transform(
+        self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.is_lora_enabled:
+            # Lora kernel requires contiguous tensor.
+            ssm_params = self.x_proj(x.contiguous())[0]
+        else:
+            ssm_params = self.x_proj(x)[0]
+        time_step, B, C = torch.split(
+            ssm_params,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1)
+        if self.use_rms_norm:
+            assert self.dt_layernorm is not None
+            assert self.b_layernorm is not None
+            assert self.c_layernorm is not None
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
+        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
+        return discrete_time_step, B, C
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
             return CustomOp.forward(self, hidden_states, mamba_cache_params)
         else:
-            return self.forward_cuda(hidden_states, mamba_cache_params)
+            return self.forward_cuda(
+                hidden_states,
+                mamba_cache_params,
+            )
 
     def forward_native(self,
                        hidden_states: torch.Tensor,
@@ -170,6 +196,27 @@ class MambaMixer(MambaBase, CustomOp):
     def forward_cuda(self,
                      hidden_states: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
+        """
+        Run the Mamba-1 SSM pipeline.
+
+        Steps
+        -----
+        1. Apply the gated-MLP linear projection to the raw input.
+        2. Pass the projected sequence through the convolutional mixing layer.
+        3. Feed the result into the State-Space Model (SSM) blocks.
+        4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+           to produce contextual representations.
+        5. Project the contextualised sequence back
+           to the output embedding dimension.
+
+        Batch handling
+        --------------
+        Prefill and decode tokens are processed by dedicated CUDA
+        kernels for both the convolutional (conv1d) and SSM stages.
+        In the case of a mixed batch (containing both prefill and
+        decode tokens), both sets of kernels are executed independently
+        and their outputs are concatenated before the final output
+        projection.
+        """
 
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata
@@ -185,126 +232,142 @@ class MambaMixer(MambaBase, CustomOp):
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
-                has_initial_state = mamba1_metadata.has_initial_states
+                has_initial_states = mamba1_metadata.has_initial_states
+                context_lens_tensor = mamba1_metadata.context_lens_tensor
         else:
+            assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
             query_start_loc = attn_metadata.query_start_loc
             context_lens_tensor = attn_metadata.context_lens_tensor
+            has_initial_states = None
             if context_lens_tensor is not None:
-                has_initial_state = context_lens_tensor > 0
+                has_initial_states = context_lens_tensor > 0
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
-        hidden_states, gate = projected_states.chunk(2, dim=-2)
+        hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
 
-        # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
         if envs.VLLM_USE_V1 and attn_metadata is None:
             # V1 profile run
-            hidden_states = hidden_states.contiguous()
-            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+            hidden_states_BC = hidden_states_BC.contiguous()
+            return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
 
-        if query_start_loc is not None and context_lens_tensor is not None:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-            hidden_states = causal_conv1d_fn(
-                hidden_states,
-                conv_weights,
-                bias=self.conv1d.bias,
-                activation=self.activation,
-                conv_states=conv_state,
-                has_initial_state=has_initial_state,
-                cache_indices=state_indices_tensor,
-                query_start_loc=query_start_loc)
-        else:
-            hidden_states = causal_conv1d_update(
-                hidden_states.transpose(0, 1),
-                conv_state,
-                conv_weights,
-                self.conv1d.bias,
-                self.activation,
-                conv_state_indices=state_indices_tensor)
-            hidden_states = hidden_states.transpose(0, 1)
-
-        # 3. State Space Model sequence transformation
-        # 3.a. input varying initialization of time_step, B and C
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            ssm_parameters = self.x_proj(
-                hidden_states.transpose(-2, -1).contiguous())[0]
-        else:
-            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
-
-        time_step, B, C = torch.split(
-            ssm_parameters,
-            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
-            dim=-1,
-        )
-        if self.use_rms_norm:
-            assert self.dt_layernorm is not None
-            assert self.b_layernorm is not None
-            assert self.c_layernorm is not None
-            time_step = self.dt_layernorm(time_step.contiguous())
-            B = self.b_layernorm(B.contiguous())
-            C = self.c_layernorm(C.contiguous())
-
-        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
-        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
-            self.dt_proj, "bias") else None)
-
-        if query_start_loc is not None and context_lens_tensor is not None:
-            scan_outputs = selective_scan_fn(
-                hidden_states,
-                ssm_state,
-                discrete_time_step,
-                self.A,
-                B.transpose(-2, -1),
-                C.transpose(-2, -1),
-                self.D.float(),
-                gate,
-                time_proj_bias,
-                delta_softplus=True,
-                cache_indices=state_indices_tensor,
-                has_initial_state=has_initial_state,
-                query_start_loc=query_start_loc)
-        else:
-            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
-            selective_state_update(ssm_state,
-                                   hidden_states.transpose(0, 1),
-                                   discrete_time_step.transpose(0, 1),
-                                   self.A,
-                                   B,
-                                   C,
-                                   self.D,
-                                   gate.transpose(0, 1),
-                                   time_proj_bias,
-                                   dt_softplus=True,
-                                   state_batch_indices=state_indices_tensor,
-                                   out=scan_outputs)
-            scan_outputs = scan_outputs.transpose(0, 1)
-
-        # 4. Final linear projection
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1).contiguous())[0]
-        else:
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1))[0]
-
-        return contextualized_states
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (= request count)
+        has_prefill = num_prefill_tokens > 0
+        has_decode = num_decode_tokens > 0
+
+        prefill_decode_split = split_batch_to_prefill_and_decode(
+            hidden_states_BC,
+            gate,
+            state_indices_tensor,
+            query_start_loc,
+            has_initial_states,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+            num_decodes,
+        )
+        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
+        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
+        gate_p = prefill_decode_split.gate_p
+        gate_d = prefill_decode_split.gate_d
+        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
+        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
+        query_start_loc_p = prefill_decode_split.query_start_loc_p
+        has_initial_states_p = prefill_decode_split.has_initial_states_p
+
+        ssm_outputs = []
+
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            conv_out_p = causal_conv1d_fn(
+                hidden_states_BC_p,
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=conv_state,
+                has_initial_state=has_initial_states_p,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p)
+
+            # 3. State Space Model sequence transformation
+            discrete_time_step_p, B_p, C_p = self._ssm_transform(
+                conv_out_p.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_out_p = selective_scan_fn(
+                conv_out_p,
+                ssm_state,
+                discrete_time_step_p,
+                self.A,
+                B_p.transpose(-2, -1),
+                C_p.transpose(-2, -1),
+                self.D.float(),
+                gate_p,
+                time_proj_bias,
+                delta_softplus=True,
+                cache_indices=state_indices_tensor_p,
+                has_initial_state=has_initial_states_p,
+                query_start_loc=query_start_loc_p)
+            ssm_outputs.append(scan_out_p)
+
+        if has_decode:
+            # 2. Convolution sequence transformation
+            conv_out_d = causal_conv1d_update(
+                hidden_states_BC_d.transpose(0, 1),
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor_d).transpose(0, 1)
+
+            # 3. State Space Model sequence transformation
+            discrete_time_step_d, B_d, C_d = self._ssm_transform(
+                conv_out_d.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_outputs_d = torch.empty_like(
+                hidden_states_BC_d.transpose(0, 1))
+            selective_state_update(ssm_state,
+                                   conv_out_d.transpose(0, 1),
+                                   discrete_time_step_d.transpose(0, 1),
+                                   self.A,
+                                   B_d,
+                                   C_d,
+                                   self.D,
+                                   gate_d.transpose(0, 1),
+                                   time_proj_bias,
+                                   dt_softplus=True,
+                                   state_batch_indices=state_indices_tensor_d,
+                                   out=scan_outputs_d)
+            scan_outputs_d = scan_outputs_d.transpose(0, 1)
+
+            if envs.VLLM_USE_V1:
+                ssm_outputs.insert(0, scan_outputs_d)
+            else:
+                ssm_outputs.append(scan_outputs_d)
+
+        scan_outputs_combined = ssm_outputs[0] if len(
+            ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
+
+        # 5. Final output projection
+        if self.is_lora_enabled:  # Lora kernel requires contiguous tensor.
+            scan_outputs_combined = scan_outputs_combined.transpose(
+                -2, -1).contiguous()
+            out = self.out_proj(scan_outputs_combined)[0]
+        else:
+            out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
+
+        return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
@@ -317,3 +380,69 @@ class MambaMixer(MambaBase, CustomOp):
     @property
     def mamba_type(self) -> str:
         return "mamba1"
+
+    def _time_proj_bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
+            return self.dt_proj.bias.float()
+        return None
+
+
+class PrefillDecodeSplit(NamedTuple):
+    hidden_states_BC_p: torch.Tensor
+    hidden_states_BC_d: torch.Tensor
+    gate_p: torch.Tensor
+    gate_d: torch.Tensor
+    state_indices_tensor_p: torch.Tensor
+    state_indices_tensor_d: torch.Tensor
+    query_start_loc_p: Optional[torch.Tensor]
+    has_initial_states_p: Optional[torch.Tensor]
+
+
+def split_batch_to_prefill_and_decode(
+    hidden_states_BC: torch.Tensor,
+    gate: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    has_initial_states: Optional[torch.Tensor],
+    num_prefill_tokens: int,
+    num_decode_tokens: int,
+    num_prefills: int,
+    num_decodes: int,
+) -> PrefillDecodeSplit:
+    if envs.VLLM_USE_V1:
+        # In v1, decode tokens come first, then prefill tokens.
+        hidden_states_BC_d, hidden_states_BC_p = torch.split(
+            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
+        gate_d, gate_p = torch.split(gate,
+                                     [num_decode_tokens, num_prefill_tokens],
+                                     dim=-1)
+        state_indices_tensor_d, state_indices_tensor_p = torch.split(
+            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+        query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
+                             num_decodes if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[-num_prefills:] if (
+            has_initial_states is not None and num_prefills > 0) else None
+    else:
+        # In v0, prefill tokens come first, then decode tokens.
+        hidden_states_BC_p, hidden_states_BC_d = torch.split(
+            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
+        gate_p, gate_d = torch.split(gate,
+                                     [num_prefill_tokens, num_decode_tokens],
+                                     dim=-1)
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            state_indices_tensor, [num_prefills, num_decodes], dim=0)
+        query_start_loc_p = (query_start_loc[:num_prefills +
+                                             1] if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[:num_prefills] if (
+            has_initial_states is not None and num_prefills > 0) else None
+
+    return PrefillDecodeSplit(
+        hidden_states_BC_p=hidden_states_BC_p,
+        hidden_states_BC_d=hidden_states_BC_d,
+        gate_p=gate_p,
+        gate_d=gate_d,
+        state_indices_tensor_p=state_indices_tensor_p,
+        state_indices_tensor_d=state_indices_tensor_d,
+        query_start_loc_p=query_start_loc_p,
+        has_initial_states_p=has_initial_states_p,
+    )
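For reference, the recurrence written as y ← SSM(A, B, C, Δ)(x) in the docstring above is the standard discretized selective-SSM update from the Mamba paper; it is implemented by the selective_scan_fn and selective_state_update kernels and is spelled out here only for orientation, not as part of this diff:

\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B \approx \Delta B

h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t + D\,x_t

where Δ is the per-token time step produced by dt_proj and passed through a softplus inside the kernels (hence delta_softplus=True / dt_softplus=True in the calls above).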

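A toy illustration of split_batch_to_prefill_and_decode on the V1 code path (VLLM_USE_V1=1, where decode tokens precede prefill tokens). The concrete shapes and values below are hypothetical; only the function and its V1 layout come from the diff above:

import torch

# Hypothetical mixed batch: 2 decode requests (1 token each) followed by
# 1 prefill request of 5 tokens -> 7 tokens total, 4 channels.
hidden_states_BC = torch.randn(4, 7)            # (channels, tokens)
gate = torch.randn(4, 7)
state_indices_tensor = torch.tensor([3, 8, 1])  # one cache slot per request
query_start_loc = torch.tensor([0, 1, 2, 7])    # cumulative token offsets
has_initial_states = torch.tensor([True, True, False])

split = split_batch_to_prefill_and_decode(
    hidden_states_BC, gate, state_indices_tensor, query_start_loc,
    has_initial_states, num_prefill_tokens=5, num_decode_tokens=2,
    num_prefills=1, num_decodes=2)

print(split.hidden_states_BC_d.shape)  # torch.Size([4, 2]) -- decode tokens
print(split.hidden_states_BC_p.shape)  # torch.Size([4, 5]) -- prefill tokens
print(split.state_indices_tensor_p)    # tensor([1])
print(split.query_start_loc_p)         # tensor([0, 5]) -- rebased to prefill-only
print(split.has_initial_states_p)      # tensor([False])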
View File

@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -25,12 +26,15 @@ class Mamba1AttentionMetadata:
     query_start_loc: torch.Tensor
     context_lens_tensor: torch.Tensor
     state_indices_tensor: torch.Tensor
-    has_initial_states: torch.Tensor
+    has_initial_states: Optional[torch.Tensor]
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+    num_decode_tokens: int
 
 
 class Mamba1AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba1AttentionMetadata]):
 
     reorder_batch_threshold: ClassVar[int] = 1
 
     def __init__(
@@ -57,11 +61,23 @@ class Mamba1AttentionMetadataBuilder(
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
         context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
             query_start_loc.device)
-        has_initial_states = (context_lens_tensor > 0)
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))
+
+        has_initial_states = None
+        if num_prefills > 0:
+            has_initial_states = context_lens_tensor > 0
 
         return Mamba1AttentionMetadata(
             query_start_loc=query_start_loc,
             context_lens_tensor=context_lens_tensor,
             has_initial_states=has_initial_states,
             state_indices_tensor=state_indices_tensor,
+            num_prefills=num_prefills,
+            num_prefill_tokens=num_prefill_tokens,
+            num_decodes=num_decodes,
+            num_decode_tokens=num_decode_tokens,
        )
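For orientation, split_decodes_and_prefills with decode_threshold=1 classifies each request by its query length: requests contributing a single new token count as decodes, and because the V1 batch is reordered decodes-first, a scan over per-request query lengths yields all four counts. A minimal sketch of that counting logic under these assumptions; this is not the actual vLLM implementation:

import torch

def count_decodes_and_prefills(query_start_loc: torch.Tensor,
                               decode_threshold: int = 1):
    # Per-request query lengths from cumulative token offsets.
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_decode = query_lens <= decode_threshold  # decodes are ordered first
    num_decodes = int(is_decode.sum())
    num_prefills = query_lens.numel() - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Two decodes (1 token each) followed by one prefill of 5 tokens:
print(count_decodes_and_prefills(torch.tensor([0, 1, 2, 7])))
# -> (2, 1, 2, 5)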