mirror of https://git.datalinker.icu/vllm-project/vllm.git
[V1] - Split Prefill and Decode for Mamba1 models (#22653)
Signed-off-by: amirk <amirk@ai21.com>
Signed-off-by: asafg <asafg@ai21.com>
Co-authored-by: asafg <asafg@ai21.com>
Co-authored-by: Asaf Joseph Gardin <39553475+Josephasafg@users.noreply.github.com>
parent 5406ebf5c9
commit fe91ce9591
@@ -57,6 +57,13 @@ V1_SUPPORTED_MODELS = [
 # Avoid OOM
 MAX_NUM_SEQS = 4
 
+# Once we add support for FCG in Mamba1, this list will be removed and
+# all test cases will use enforce_eager=False
+ENFORCE_EAGER_MODELS_V1 = [
+    "state-spaces/mamba-130m-hf",
+    "ai21labs/Jamba-tiny-dev",
+]
+
 
 @pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
 @pytest.mark.parametrize("max_tokens", [64])
@@ -94,13 +101,19 @@ def test_models(
         example_prompts, max_tokens, num_logprobs)
 
     if model in V1_SUPPORTED_MODELS:
+        enforce_eager = False
        with monkeypatch.context() as m:
             m.setenv("VLLM_USE_V1", "1")
             if model in HYBRID_MODELS:
+                # required due to reorder_batch behaviour
                 m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
 
+            if model in ENFORCE_EAGER_MODELS_V1:
+                enforce_eager = True
+
             with vllm_runner(model,
                              max_num_seqs=MAX_NUM_SEQS,
+                             enforce_eager=enforce_eager,
                              enable_prefix_caching=False) as vllm_model:
                 vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                     example_prompts, max_tokens, num_logprobs)
@@ -1,13 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import NamedTuple, Optional
 
 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -154,13 +155,38 @@ class MambaMixer(MambaBase, CustomOp):
 
         self.prefix = prefix
 
+    def _ssm_transform(
+            self, x: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if self.is_lora_enabled:
+            # Lora kernel requires contiguous tensor.
+            ssm_params = self.x_proj(x.contiguous())[0]
+        else:
+            ssm_params = self.x_proj(x)[0]
+        time_step, B, C = torch.split(
+            ssm_params,
+            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
+            dim=-1)
+        if self.use_rms_norm:
+            assert self.dt_layernorm is not None
+            assert self.b_layernorm is not None
+            assert self.c_layernorm is not None
+            time_step = self.dt_layernorm(time_step.contiguous())
+            B = self.b_layernorm(B.contiguous())
+            C = self.c_layernorm(C.contiguous())
+        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
+        return discrete_time_step, B, C
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
             return CustomOp.forward(self, hidden_states, mamba_cache_params)
-        else:
-            return self.forward_cuda(hidden_states, mamba_cache_params)
+        return self.forward_cuda(
+            hidden_states,
+            mamba_cache_params,
+        )
 
     def forward_native(self,
                        hidden_states: torch.Tensor,
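The new `_ssm_transform` helper factors out the x_proj projection, the packed split into Δ/B/C, the optional RMSNorm, and the dt_proj step so that the prefill and decode paths added below can share one code path. A shape-only sketch of the split it performs (the sizes here are illustrative, not the model's real dimensions):

```python
import torch

time_step_rank, ssm_state_size, seqlen = 6, 4, 10
# x_proj packs time_step (dt), B and C along the last dimension.
ssm_params = torch.randn(seqlen, time_step_rank + 2 * ssm_state_size)
time_step, B, C = torch.split(
    ssm_params, [time_step_rank, ssm_state_size, ssm_state_size], dim=-1)
assert time_step.shape == (seqlen, time_step_rank)
assert B.shape == C.shape == (seqlen, ssm_state_size)
```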
@@ -170,6 +196,27 @@ class MambaMixer(MambaBase, CustomOp):
     def forward_cuda(self,
                      hidden_states: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
+        """
+        Run the Mamba-1 SSM pipeline.
+
+        Steps
+        -----
+        1. Apply the gated-MLP linear projection to the raw input.
+        2. Pass the projected sequence through the convolutional mixing layer.
+        3. Feed the result into the State-Space Model (SSM) blocks.
+        4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+           to produce contextual representations.
+        5. Project the contextualised sequence back
+           to the output embedding dimension.
+
+        Batch handling
+        --------------
+        Prefill and decode tokens are processed by dedicated CUDA
+        kernels for both the convolutional (conv1d) and SSM stages.
+        In the case of a mixed batch (containing both prefill and
+        decode tokens), both sets of kernels are executed independently
+        and their outputs are concatenated before the final output projection.
+        """
 
         forward_context: ForwardContext = get_forward_context()
         attn_metadata = forward_context.attn_metadata
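For reference, the recurrence in step 4 of the docstring can be written out naively in a few lines of PyTorch. This is an illustrative sketch of the math, not the fused `selective_scan_fn` kernel: the names and shapes (`d_inner`, `d_state`, one unbatched sequence) are assumptions, and the real kernel additionally handles state caching, variable-length batches, and the gating/activation epilogue.

```python
import torch

def naive_selective_scan(x, dt, A, B, C, D):
    """Toy reference for y = SSM(A, B, C, dt)(x) on one sequence.

    x:  (d_inner, seqlen)   input after the causal conv
    dt: (d_inner, seqlen)   per-token step sizes (post-softplus)
    A:  (d_inner, d_state)  state transition matrix
    B:  (seqlen, d_state)   input-dependent input matrix
    C:  (seqlen, d_state)   input-dependent output matrix
    D:  (d_inner,)          skip connection
    """
    d_inner, seqlen = x.shape
    h = torch.zeros(d_inner, A.shape[1])          # SSM state
    ys = []
    for t in range(seqlen):
        dA = torch.exp(dt[:, t, None] * A)        # discretize: (d_inner, d_state)
        dBx = dt[:, t, None] * B[t] * x[:, t, None]
        h = dA * h + dBx                          # state update
        ys.append((h * C[t]).sum(-1) + D * x[:, t])
    return torch.stack(ys, dim=-1)                # (d_inner, seqlen)

y = naive_selective_scan(
    torch.randn(8, 5), torch.rand(8, 5),
    -torch.rand(8, 4), torch.randn(5, 4), torch.randn(5, 4), torch.randn(8))
```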
@@ -185,126 +232,142 @@
                 self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                 conv_state = self_kv_cache[0].transpose(-1, -2)
                 ssm_state = self_kv_cache[1]
-                has_initial_state = mamba1_metadata.has_initial_states
-                context_lens_tensor = mamba1_metadata.context_lens_tensor
+                has_initial_states = mamba1_metadata.has_initial_states
         else:
+            assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
             conv_state = mamba_cache_params.conv_state
             ssm_state = mamba_cache_params.ssm_state
             state_indices_tensor = mamba_cache_params.state_indices_tensor
             query_start_loc = attn_metadata.query_start_loc
             context_lens_tensor = attn_metadata.context_lens_tensor
 
+            has_initial_states = None
             if context_lens_tensor is not None:
-                has_initial_state = context_lens_tensor > 0
+                has_initial_states = context_lens_tensor > 0
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
-        hidden_states, gate = projected_states.chunk(2, dim=-2)
+        hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
 
         # 2. Convolution sequence transformation
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
         if envs.VLLM_USE_V1 and attn_metadata is None:
             # V1 profile run
-            hidden_states = hidden_states.contiguous()
-            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+            hidden_states_BC = hidden_states_BC.contiguous()
+            return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
 
-        if query_start_loc is not None and context_lens_tensor is not None:
-            # |---------- N-1 iteration --------|
-            # |---------------- N iteration ---------------------|
-            # |- tokenA -|......................|-- newTokens ---|
-            # |---------- context_len ----------|
-            # |-------------------- seq_len ---------------------|
-            #                                   |-- query_len ---|
-            hidden_states = causal_conv1d_fn(
-                hidden_states,
+        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
+        num_decode_tokens = attn_metadata.num_decode_tokens
+        num_prefills = attn_metadata.num_prefills  # request count
+        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
+        has_prefill = num_prefill_tokens > 0
+        has_decode = num_decode_tokens > 0
+
+        prefill_decode_split = split_batch_to_prefill_and_decode(
+            hidden_states_BC,
+            gate,
+            state_indices_tensor,
+            query_start_loc,
+            has_initial_states,
+            num_prefill_tokens,
+            num_decode_tokens,
+            num_prefills,
+            num_decodes,
+        )
+        hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
+        hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
+        gate_p = prefill_decode_split.gate_p
+        gate_d = prefill_decode_split.gate_d
+        state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
+        state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
+        query_start_loc_p = prefill_decode_split.query_start_loc_p
+        has_initial_states_p = prefill_decode_split.has_initial_states_p
+
+        ssm_outputs = []
+
+        if has_prefill:
+            # 2. Convolution sequence transformation
+            conv_out_p = causal_conv1d_fn(
+                hidden_states_BC_p,
                 conv_weights,
-                bias=self.conv1d.bias,
+                self.conv1d.bias,
                 activation=self.activation,
                 conv_states=conv_state,
-                has_initial_state=has_initial_state,
-                cache_indices=state_indices_tensor,
-                query_start_loc=query_start_loc)
-        else:
-            hidden_states = causal_conv1d_update(
-                hidden_states.transpose(0, 1),
+                has_initial_state=has_initial_states_p,
+                cache_indices=state_indices_tensor_p,
+                query_start_loc=query_start_loc_p)
+
+            # 3. State Space Model sequence transformations.
+            discrete_time_step_p, B_p, C_p = self._ssm_transform(
+                conv_out_p.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_out_p = selective_scan_fn(
+                conv_out_p,
+                ssm_state,
+                discrete_time_step_p,
+                self.A,
+                B_p.transpose(-2, -1),
+                C_p.transpose(-2, -1),
+                self.D.float(),
+                gate_p,
+                time_proj_bias,
+                delta_softplus=True,
+                cache_indices=state_indices_tensor_p,
+                has_initial_state=has_initial_states_p,
+                query_start_loc=query_start_loc_p)
+            ssm_outputs.append(scan_out_p)
+
+        if has_decode:
+            # 2. Convolution sequence transformation
+            conv_out_d = causal_conv1d_update(
+                hidden_states_BC_d.transpose(0, 1),
                 conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=state_indices_tensor)
-            hidden_states = hidden_states.transpose(0, 1)
+                conv_state_indices=state_indices_tensor_d).transpose(0, 1)
 
-        # 3. State Space Model sequence transformation
-        # 3.a. input varying initialization of time_step, B and C
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            ssm_parameters = self.x_proj(
-                hidden_states.transpose(-2, -1).contiguous())[0]
-        else:
-            ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]
-
-        time_step, B, C = torch.split(
-            ssm_parameters,
-            [self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
-            dim=-1,
-        )
-        if self.use_rms_norm:
-            assert self.dt_layernorm is not None
-            assert self.b_layernorm is not None
-            assert self.c_layernorm is not None
-            time_step = self.dt_layernorm(time_step.contiguous())
-            B = self.b_layernorm(B.contiguous())
-            C = self.c_layernorm(C.contiguous())
-
-        discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
-        # 3.c perform the recurrence y ← SSM(A, B, C)(x)
-        time_proj_bias = (self.dt_proj.bias.float() if hasattr(
-            self.dt_proj, "bias") else None)
-
-        if query_start_loc is not None and context_lens_tensor is not None:
-            scan_outputs = selective_scan_fn(
-                hidden_states,
-                ssm_state,
-                discrete_time_step,
-                self.A,
-                B.transpose(-2, -1),
-                C.transpose(-2, -1),
-                self.D.float(),
-                gate,
-                time_proj_bias,
-                delta_softplus=True,
-                cache_indices=state_indices_tensor,
-                has_initial_state=has_initial_state,
-                query_start_loc=query_start_loc)
-        else:
-            scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
+            # 3. State Space Model sequence transformation.
+            discrete_time_step_d, B_d, C_d = self._ssm_transform(
+                conv_out_d.transpose(-2, -1))
+            time_proj_bias = self._time_proj_bias()
+
+            # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
+            scan_outputs_d = torch.empty_like(
+                hidden_states_BC_d.transpose(0, 1))
             selective_state_update(ssm_state,
-                                   hidden_states.transpose(0, 1),
-                                   discrete_time_step.transpose(0, 1),
+                                   conv_out_d.transpose(0, 1),
+                                   discrete_time_step_d.transpose(0, 1),
                                    self.A,
-                                   B,
-                                   C,
+                                   B_d,
+                                   C_d,
                                    self.D,
-                                   gate.transpose(0, 1),
+                                   gate_d.transpose(0, 1),
                                    time_proj_bias,
                                    dt_softplus=True,
-                                   state_batch_indices=state_indices_tensor,
-                                   out=scan_outputs)
-            scan_outputs = scan_outputs.transpose(0, 1)
+                                   state_batch_indices=state_indices_tensor_d,
+                                   out=scan_outputs_d)
+            scan_outputs_d = scan_outputs_d.transpose(0, 1)
 
-        # 4. Final linear projection
-        if self.is_lora_enabled:
-            # lora kernel requires contiguous tensor
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1).contiguous())[0]
+            if envs.VLLM_USE_V1:
+                ssm_outputs.insert(0, scan_outputs_d)
+            else:
+                ssm_outputs.append(scan_outputs_d)
+
+        scan_outputs_combined = ssm_outputs[0] if len(
+            ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
+
+        # 5. Final output projection
+        if self.is_lora_enabled:  # Lora kernel requires contiguous tensor.
+            scan_outputs_combined = scan_outputs_combined.transpose(
+                -2, -1).contiguous()
+            out = self.out_proj(scan_outputs_combined)[0]
         else:
-            contextualized_states = self.out_proj(
-                scan_outputs.transpose(-2, -1))[0]
-        return contextualized_states
+            out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
+
+        return out
 
     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return MambaStateShapeCalculator.mamba1_state_shape(
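On the decode path, `causal_conv1d_update` advances a per-request sliding window held in `conv_state` by exactly one token. A naive single-request reference of that update (the names, the SiLU activation, and the width-4 window below are assumptions for illustration; the real kernel operates on batches indexed by `conv_state_indices`):

```python
import torch
import torch.nn.functional as F

def naive_conv1d_update(x_t, conv_state, weight, bias):
    """One decode step of a depthwise causal conv.

    x_t:        (d_inner,)              new token's channels
    conv_state: (d_inner, kernel_size)  last kernel_size inputs per channel
    weight:     (d_inner, kernel_size)  depthwise filter taps
    """
    conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
    conv_state[:, -1] = x_t                     # append the newest token
    y_t = (conv_state * weight).sum(-1) + bias  # depthwise causal conv
    return F.silu(y_t), conv_state              # activation as in Mamba

d_inner, k = 8, 4
y, state = naive_conv1d_update(torch.randn(d_inner),
                               torch.zeros(d_inner, k),
                               torch.randn(d_inner, k),
                               torch.zeros(d_inner))
```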
@@ -317,3 +380,69 @@ class MambaMixer(MambaBase, CustomOp):
     @property
     def mamba_type(self) -> str:
         return "mamba1"
+
+    def _time_proj_bias(self) -> Optional[torch.Tensor]:
+        if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
+            return self.dt_proj.bias.float()
+        return None
+
+
+class PrefillDecodeSplit(NamedTuple):
+    hidden_states_BC_p: torch.Tensor
+    hidden_states_BC_d: torch.Tensor
+    gate_p: torch.Tensor
+    gate_d: torch.Tensor
+    state_indices_tensor_p: torch.Tensor
+    state_indices_tensor_d: torch.Tensor
+    query_start_loc_p: Optional[torch.Tensor]
+    has_initial_states_p: Optional[torch.Tensor]
+
+
+def split_batch_to_prefill_and_decode(
+    hidden_states_BC: torch.Tensor,
+    gate: torch.Tensor,
+    state_indices_tensor: torch.Tensor,
+    query_start_loc: torch.Tensor,
+    has_initial_states: Optional[torch.Tensor],
+    num_prefill_tokens: int,
+    num_decode_tokens: int,
+    num_prefills: int,
+    num_decodes: int,
+) -> PrefillDecodeSplit:
+    if envs.VLLM_USE_V1:
+        # In v1, decode tokens come first, then prefill tokens.
+        hidden_states_BC_d, hidden_states_BC_p = torch.split(
+            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
+        gate_d, gate_p = torch.split(gate,
+                                     [num_decode_tokens, num_prefill_tokens],
+                                     dim=-1)
+        state_indices_tensor_d, state_indices_tensor_p = torch.split(
+            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+        query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
+                             num_decodes if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[-num_prefills:] if (
+            has_initial_states is not None and num_prefills > 0) else None
+    else:
+        # In v0, prefill tokens come first, then decode tokens.
+        hidden_states_BC_p, hidden_states_BC_d = torch.split(
+            hidden_states_BC, [num_prefill_tokens, num_decode_tokens], dim=-1)
+        gate_p, gate_d = torch.split(gate,
+                                     [num_prefill_tokens, num_decode_tokens],
+                                     dim=-1)
+        state_indices_tensor_p, state_indices_tensor_d = torch.split(
+            state_indices_tensor, [num_prefills, num_decodes], dim=0)
+        query_start_loc_p = (query_start_loc[:num_prefills +
+                                             1] if num_prefills > 0 else None)
+        has_initial_states_p = has_initial_states[:num_prefills] if (
+            has_initial_states is not None and num_prefills > 0) else None
+
+    return PrefillDecodeSplit(
+        hidden_states_BC_p=hidden_states_BC_p,
+        hidden_states_BC_d=hidden_states_BC_d,
+        gate_p=gate_p,
+        gate_d=gate_d,
+        state_indices_tensor_p=state_indices_tensor_p,
+        state_indices_tensor_d=state_indices_tensor_d,
+        query_start_loc_p=query_start_loc_p,
+        has_initial_states_p=has_initial_states_p,
+    )
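To see the v1 branch's bookkeeping concretely, here is a toy mixed batch: two decode requests (one token each) followed by one three-token prefill. The values below replicate the v1 arithmetic from the function above and are illustrative only; note how `query_start_loc_p` is rebased so the first prefill token lands at index 0.

```python
import torch

num_decodes = num_decode_tokens = 2   # one token per decode request
num_prefills, num_prefill_tokens = 1, 3

hidden_states_BC = torch.arange(5).unsqueeze(0).float()  # (dim=1, 5 tokens)
query_start_loc = torch.tensor([0, 1, 2, 5])  # decodes [0,1), [1,2); prefill [2,5)

hs_d, hs_p = torch.split(
    hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
# Prefill start locations, shifted past the leading decode tokens.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes

assert hs_d.tolist() == [[0.0, 1.0]] and hs_p.tolist() == [[2.0, 3.0, 4.0]]
assert query_start_loc_p.tolist() == [0, 3]
```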
@@ -2,14 +2,15 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from dataclasses import dataclass
-from typing import ClassVar
+from typing import ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -25,12 +26,15 @@ class Mamba1AttentionMetadata:
     query_start_loc: torch.Tensor
     context_lens_tensor: torch.Tensor
     state_indices_tensor: torch.Tensor
-    has_initial_states: torch.Tensor
+    has_initial_states: Optional[torch.Tensor]
+    num_prefills: int
+    num_prefill_tokens: int
+    num_decodes: int
+    num_decode_tokens: int
 
 
 class Mamba1AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba1AttentionMetadata]):
 
     reorder_batch_threshold: ClassVar[int] = 1
 
     def __init__(
@@ -57,11 +61,23 @@ class Mamba1AttentionMetadataBuilder(
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
         context_lens_tensor = common_attn_metadata.num_computed_tokens_cpu.to(
             query_start_loc.device)
-        has_initial_states = (context_lens_tensor > 0)
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
+            split_decodes_and_prefills(common_attn_metadata,
+                                       decode_threshold=1))
+
+        has_initial_states = None
+
+        if num_prefills > 0:
+            has_initial_states = context_lens_tensor > 0
 
         return Mamba1AttentionMetadata(
             query_start_loc=query_start_loc,
             context_lens_tensor=context_lens_tensor,
             has_initial_states=has_initial_states,
             state_indices_tensor=state_indices_tensor,
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
             num_decodes=num_decodes,
             num_decode_tokens=num_decode_tokens,
         )
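`split_decodes_and_prefills` is imported above from `vllm.v1.attention.backends.utils`; with `decode_threshold=1` it classifies a request as a decode when it contributes a single token. A naive reference of that counting, assuming the v1 convention that decode requests are ordered before prefills in the batch (illustrative sketch, not vLLM's implementation):

```python
import torch

def naive_split_decodes_and_prefills(query_start_loc, decode_threshold=1):
    """Return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)."""
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_decode = query_lens <= decode_threshold
    num_decodes = int(is_decode.sum())       # decodes are the leading requests
    num_prefills = len(query_lens) - num_decodes
    num_decode_tokens = int(query_lens[:num_decodes].sum())
    num_prefill_tokens = int(query_lens[num_decodes:].sum())
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens

# Two decodes, then one 3-token prefill, as in a reordered v1 batch.
print(naive_split_decodes_and_prefills(torch.tensor([0, 1, 2, 5])))
# (2, 1, 2, 3)
```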