[v1] Support mamba2 (#19327)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Authored by Chen Zhang on 2025-06-19 04:34:15 +08:00, committed by GitHub
parent ffacb222cb
commit a89209b78d
9 changed files with 582 additions and 120 deletions

View File

@@ -17,9 +17,10 @@ SSM_MODELS = [
    "state-spaces/mamba-130m-hf",
    "tiiuae/falcon-mamba-tiny-dev",
    # TODO: Compare to a Mamba2 model. The HF transformers implementation of
    # Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
    # doesn't compare vLLM output with HF output.
    # See https://github.com/huggingface/transformers/pull/35943
    "mistralai/Mamba-Codestral-7B-v0.1",
]
HYBRID_MODELS = [
@@ -35,6 +36,10 @@ HYBRID_MODELS = [
    "hmellor/tiny-random-BambaForCausalLM",
]

V1_SUPPORTED_MODELS = [
    "mistralai/Mamba-Codestral-7B-v0.1",
]

# Avoid OOM
MAX_NUM_SEQS = 4
@@ -46,24 +51,50 @@ def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    monkeypatch,
    model: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    with hf_runner(model) as hf_model:
        if model != "mistralai/Mamba-Codestral-7B-v0.1":
            hf_outputs = hf_model.generate_greedy_logprobs_limit(
                example_prompts, max_tokens, num_logprobs)
        else:
            hf_outputs = None

    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    if model in V1_SUPPORTED_MODELS:
        with monkeypatch.context() as m:
            m.setenv("VLLM_USE_V1", "1")
            with vllm_runner(model,
                             max_num_seqs=MAX_NUM_SEQS,
                             enforce_eager=True,
                             enable_prefix_caching=False) as vllm_model:
                vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                    example_prompts, max_tokens, num_logprobs)
    else:
        vllm_v1_outputs = None

    if hf_outputs is not None:
        check_logprobs_close(
            outputs_0_lst=hf_outputs,
            outputs_1_lst=vllm_v0_outputs,
            name_0="hf",
            name_1="vllm-v0",
        )

    if model in V1_SUPPORTED_MODELS:
        ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
        check_logprobs_close(
            outputs_0_lst=ref_outputs,
            outputs_1_lst=vllm_v1_outputs,
            name_0="hf" if hf_outputs is not None else "vllm-v0",
            name_1="vllm-v1",
        )


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)

View File

@@ -12,7 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine

UNSUPPORTED_MODELS_V1 = [
    "openai/whisper-large-v3",  # transcription
    "facebook/bart-large-cnn",  # encoder decoder
    "state-spaces/mamba-130m-hf",  # mamba1
    "hmellor/tiny-random-BambaForCausalLM",  # hybrid
    "BAAI/bge-m3",  # embedding
]

View File

@@ -1355,12 +1355,17 @@ class EngineArgs:
                               recommend_to_remove=False)
            return False

        # No Encoder-Decoder, not all Mamba so far.
        if not model_config.is_v1_compatible:
            _raise_or_fallback(feature_name=model_config.architectures,
                               recommend_to_remove=False)
            return False

        # V1 mamba models are unoptimized.
        if model_config.has_inner_state and _warn_or_fallback(
                feature_name="Mamba"):
            return False

        # No Concurrent Partial Prefills so far.
        if (self.max_num_partial_prefills
                != SchedulerConfig.max_num_partial_prefills
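Taken together with the test changes above, the V1 path for Mamba2 stays opt-in for now: without VLLM_USE_V1=1 the oracle falls back to V0, and even with it the current V1 Mamba path needs eager mode and no prefix caching. A minimal usage sketch, assuming the vLLM offline API and the Codestral checkpoint exercised in the tests:

import os

os.environ["VLLM_USE_V1"] = "1"  # opt into the V1 engine explicitly

from vllm import LLM, SamplingParams

# Constraints mirrored from the V1 Mamba checks added in this PR:
# no CUDA graphs (enforce_eager=True) and no prefix caching yet.
llm = LLM(
    model="mistralai/Mamba-Codestral-7B-v0.1",
    enforce_eager=True,
    enable_prefix_caching=False,
    max_num_seqs=4,
)
print(llm.generate(["The capital of France is"], SamplingParams(max_tokens=16)))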

View File

@@ -6,7 +6,9 @@ from typing import Optional, Union

import torch
from torch import nn

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather,
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
    LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata

# Added by the IBM Team, 2024
@@ -227,20 +230,22 @@ class MambaMixer2(CustomOp):
    """

    def __init__(
        self,
        hidden_size: int,
        ssm_state_size: int,
        conv_kernel_size: int,
        intermediate_size: int,
        use_conv_bias: bool,
        use_bias: bool,
        n_groups: int = 1,
        num_heads: int = 128,
        head_dim: int = 64,
        rms_norm_eps: float = 1e-5,
        activation: str = "silu",
        use_rms_norm: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        chunk_size: int = -1,  # the chunk size used by v1
    ):
        super().__init__()
@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
        ), "Tensor parallel currently not supported for quantized models."

        self.ssm_state_size = ssm_state_size
        self.conv_kernel_size = conv_kernel_size
        self.activation = activation
        self.intermediate_size = intermediate_size
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
            self.use_rms_norm,
            eps=rms_norm_eps)

        if envs.VLLM_USE_V1:
            compilation_config = get_current_vllm_config().compilation_config
            if prefix in compilation_config.static_forward_context:
                raise ValueError(f"Duplicate layer name: {prefix}")
            compilation_config.static_forward_context[prefix] = self
            # The outer list is for v0 PP virtual engine. Though this code path
            # only runs for v1, we have to do this to unify with the interface
            # of Attention + v0 PP.
            # The inner tuple is (conv_state, ssm_state)
            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
            assert chunk_size != -1, "chunk_size must be set for v1"

        # NOTE: chunk_size may be -1 for models without v1 support
        self.chunk_size = chunk_size
        self.prefix = prefix

    def forward_native(
        self,
        hidden_states: torch.Tensor,
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        forward_context = get_forward_context()
        # mamba2_metadata contains metadata necessary for the mamba2 triton
        # kernels to operate in continuous batching and in chunked prefill
        # modes; they are computed at top-level model forward since they
        # stay the same and reused for all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        if envs.VLLM_USE_V1:
            if attn_metadata is not None:
                assert isinstance(attn_metadata, dict)
                attn_metadata = attn_metadata[self.prefix]
                assert isinstance(attn_metadata, Mamba2AttentionMetadata)
                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
                conv_state = self_kv_cache[0]
                ssm_state = self_kv_cache[1]
                state_indices_tensor = attn_metadata.state_indices_tensor
                has_initial_states_p = attn_metadata.has_initial_states
                prep_initial_states = attn_metadata.prep_initial_states
                chunk_size = attn_metadata.chunk_size
                seq_idx_p = attn_metadata.seq_idx
                chunk_indices_p = attn_metadata.chunk_indices
                chunk_offsets_p = attn_metadata.chunk_offsets
        else:
            conv_state = mamba_cache_params.conv_state
            ssm_state = mamba_cache_params.ssm_state
            state_indices_tensor = mamba_cache_params.state_indices_tensor
            has_initial_states_p = mamba2_metadata.has_initial_states
            prep_initial_states = mamba2_metadata.prep_initial_states
            chunk_size = mamba2_metadata.chunk_size
            seq_idx_p = mamba2_metadata.seq_idx
            chunk_indices_p = mamba2_metadata.chunk_indices
            chunk_offsets_p = mamba2_metadata.chunk_offsets

        groups_time_state_size = self.n_groups * self.ssm_state_size
@@ -459,27 +501,6 @@ class MambaMixer2(CustomOp):
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
            hidden_states_B_C,
            [num_prefill_tokens, num_decodes],
            dim=0,
        )
        dt_p, dt_d = torch.split(
            dt,
            [num_prefill_tokens, num_decodes],
            dim=0,
        )
        # Split along batch dimension
        state_indices_tensor_p, state_indices_tensor_d = torch.split(
            mamba_cache_params.state_indices_tensor,
            [num_prefills, num_decodes],
            dim=0,
        )
        query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
                             if has_prefill else None)

        # - get hidden_states, B and C after depthwise convolution.
        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
            hidden_states_B_C,
@@ -491,20 +512,80 @@ class MambaMixer2(CustomOp):
            dim=-1,
        )

        if envs.VLLM_USE_V1 and attn_metadata is None:
            # V1 profile run
            hidden_states_B_C = (hidden_states_B_C.transpose(
                0, 1).clone().transpose(0, 1)).contiguous()
            hidden_states, _B, _C = split_hidden_states_B_C_fn(
                hidden_states_B_C)
            hidden_states = self.norm(hidden_states, gate)
            out, _ = self.out_proj(hidden_states)
            return out

        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0

        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        if envs.VLLM_USE_V1:
            hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                hidden_states_B_C,
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            dt_d, dt_p = torch.split(
                dt,
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_d, state_indices_tensor_p = torch.split(
                state_indices_tensor,
                [num_decodes, num_prefills],
                dim=0,
            )
            query_start_loc_p = (
                attn_metadata.query_start_loc[-num_prefills - 1:] -
                num_decodes if has_prefill else None)
        else:
            hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
                hidden_states_B_C,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
            dt_p, dt_d = torch.split(
                dt,
                [num_prefill_tokens, num_decodes],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_p, state_indices_tensor_d = torch.split(
                state_indices_tensor,
                [num_prefills, num_decodes],
                dim=0,
            )
            query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
                                                               1]
                                 if has_prefill else None)
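To make the V1 slicing concrete, here is a toy walk-through (my numbers, not from the PR) of the decode-before-prefill layout and the query_start_loc_p adjustment, assuming two decode requests followed by prefills of 3 and 5 tokens:

import torch

# V1 batch layout: decodes first, then prefills.
# Two decodes (1 token each) + two prefills (3 and 5 tokens) = 10 tokens.
query_start_loc = torch.tensor([0, 1, 2, 5, 10])
num_decodes, num_prefills, num_prefill_tokens = 2, 2, 8

# Take the prefill tail of the cumulative lengths and shift it so the first
# prefill starts at 0, which is what the chunked-scan kernels expect.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
assert query_start_loc_p.tolist() == [0, 3, 8]

# The token-dimension split uses the same ordering.
tokens = torch.arange(10)
decode_tokens, prefill_tokens = torch.split(
    tokens, [num_decodes, num_prefill_tokens])
assert decode_tokens.tolist() == [0, 1]
assert prefill_tokens.tolist() == [2, 3, 4, 5, 6, 7, 8, 9]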
        ssd_output_list = []

        # Process prefill requests
        if has_prefill:
            # 2. Convolution sequence transformation
            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "state_indices_tensor"
            hidden_states_B_C_p = causal_conv1d_fn(
                hidden_states_B_C_p.transpose(0, 1),
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                query_start_loc=query_start_loc_p).transpose(
                    0, 1)[:num_prefill_tokens]
@@ -516,12 +597,11 @@

            # 3. State Space Model sequence transformation
            initial_states = None
            if (has_initial_states_p is not None and prep_initial_states):
                # making a copy of the states
                initial_states = torch.where(
                    has_initial_states_p[:, None, None, None],
                    ssm_state[state_indices_tensor_p], 0)

            scan_output, varlen_state = mamba_chunk_scan_combined(
                hidden_states_p.view(1, num_prefill_tokens,
@@ -533,14 +613,14 @@
                         -1),
                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
                         -1),
                chunk_size=chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                seq_idx=seq_idx_p,
                chunk_indices=chunk_indices_p,
                chunk_offsets=chunk_offsets_p,
                cu_seqlens=query_start_loc_p,
                initial_states=initial_states,
                return_varlen_states=True,
                return_final_states=False,
@@ -550,7 +630,7 @@

            # update ssm states
            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
            ssm_state[state_indices_tensor_p] = varlen_state

            # - reshape
            ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
@@ -560,7 +640,7 @@
            # 2. Convolution sequence transformation
            hidden_states_B_C_d = causal_conv1d_update(
                hidden_states_B_C_d,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
@@ -586,7 +666,7 @@
            # using state_indices_tensor_d
            hidden_states_d = selective_state_update(
                ssm_state,
                hidden_states_d,
                dt_d,
                A_d,
@@ -598,9 +678,16 @@
                dt_softplus=True,
                state_batch_indices=state_indices_tensor_d,
            )

            if envs.VLLM_USE_V1:
                ssd_output_list.insert(
                    0,
                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                         self.head_dim))
            else:
                ssd_output_list.append(
                    hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
                                         self.head_dim))

        # Merge prefill and decode outputs before passing to gated MLP
        hidden_states = torch.vstack(ssd_output_list)
@@ -614,3 +701,31 @@
        # 5. Final linear projection
        out, _ = self.out_proj(hidden_states)
        return out

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        world_size = get_tensor_model_parallel_world_size()
        conv_state_shape, temporal_state_shape = None, None

        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        n_groups = (self.n_groups +
                    extra_groups_for_head_shards(self.n_groups, world_size))

        # - heads and n_groups are TP-ed
        conv_dim = (self.intermediate_size +
                    2 * n_groups * self.ssm_state_size)
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.conv_kernel_size - 1,
        )

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, d_head, d_state) = (128, 64, 128)
        temporal_state_shape = (
            divide(self.num_heads, world_size),
            self.head_dim,
            self.ssm_state_size,
        )
        return conv_state_shape, temporal_state_shape
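As a sanity check on the shape math, a small worked example (my numbers, not from the PR) using Mamba2 hyperparameters roughly in the ballpark of Codestral, with no tensor parallelism:

# Assumed, Codestral-like hyperparameters; adjust for your checkpoint.
intermediate_size, n_groups, ssm_state_size = 8192, 8, 128
conv_kernel_size, num_heads, head_dim = 4, 128, 64
world_size = 1  # no tensor parallelism, so no extra group shards needed

conv_dim = intermediate_size + 2 * n_groups * ssm_state_size  # 8192 + 2048
conv_state_shape = (conv_dim // world_size, conv_kernel_size - 1)
temporal_state_shape = (num_heads // world_size, head_dim, ssm_state_size)

print(conv_state_shape)      # (10240, 3)
print(temporal_state_shape)  # (128, 64, 128)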

View File

@@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import MambaConfig

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
                                                   IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                    MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):

    def __init__(self,
                 config: MambaConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        super().__init__()
        self.config = config
        self.mixer = MambaMixer2(hidden_size=config.hidden_size,
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
                                 head_dim=config.head_dim,
                                 rms_norm_eps=config.layer_norm_epsilon,
                                 activation=config.hidden_act,
                                 quant_config=quant_config,
                                 prefix=f"{prefix}.mixer",
                                 chunk_size=config.chunk_size)

        self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Mamba2DecoderLayer(
                config, quant_config=quant_config, prefix=prefix),
            prefix=f"{prefix}.layers")

        self.norm_f = RMSNorm(config.hidden_size,
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

        if not envs.VLLM_USE_V1:
            mamba2_metadata = prepare_mamba2_metadata(
                chunk_size=self.config.chunk_size,
                attn_metadata=attn_metadata,
            )
        else:
            # v1 get mamba2_metadata from forward_context
            mamba2_metadata = None

        for i in range(len(self.layers)):
            layer = self.layers[i]
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
                hidden_states=hidden_states,
                residual=residual,
                mamba_cache_params=mamba_cache_params.at_layer_idx(
                    i - self.start_layer) if mamba_cache_params else None,
                mamba2_metadata=mamba2_metadata)

        if not get_pp_group().is_last_rank:
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
        return loaded_params


class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs):
        if not envs.VLLM_USE_V1:
            if self.mamba_cache is None:
                num_mamba_layers = (
                    self.model_config.get_num_layers_by_block_type(
                        self.vllm_config.parallel_config,
                        LayerBlockType.mamba))
                self.mamba_cache = MambaCacheManager(
                    self.vllm_config, self.lm_head.weight.dtype,
                    num_mamba_layers, *self._get_mamba_cache_shape())

            mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
        else:
            # NOTE: mamba_cache_params is not needed for v1
            mamba_cache_params = None

        hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
                                      intermediate_tensors, inputs_embeds)

View File

@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.mamba2_metadata import (
    _query_start_loc_to_chunk_indices_offsets)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch
    from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
    from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
    layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
    chunk_sizes = set(layer.chunk_size for layer in layers.values())
    assert len(
        chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
    return chunk_sizes.pop()

class Mamba2AttentionBackend(AttentionBackend):

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        return Mamba2AttentionMetadataBuilder


@dataclass
class Mamba2AttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    has_initial_states: torch.Tensor
    prep_initial_states: bool
    chunk_size: int
    seq_idx: torch.Tensor
    chunk_indices: torch.Tensor
    chunk_offsets: torch.Tensor

    state_indices_tensor: torch.Tensor  # shape: [batch,]


class Mamba2AttentionMetadataBuilder(
        AttentionMetadataBuilder[Mamba2AttentionMetadata]):

    def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
                 block_table: BlockTable):
        self.runner = runner
        self.kv_cache_spec = kv_cache_spec
        self.block_table = block_table
        self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # NOTE (Chen): Copied from MLACommonMetadataBuilder and
        # FlashInferMetadataBuilder. Should be refactored later to avoid code
        # duplication of these 3 functions.
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to
        # mean requests where attention is likely memory-bound and "prefill"
        # to mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []
        num_decode_tokens = 0
        num_prefill_tokens = 0

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            # for now treat 1 scheduled token as "decode" even if it's not,
            # we should update this to something like < 8 in the future but
            # currently the decode run only supports num_tokens = 1
            if num_tokens == 1:
                decodes.append(i)
                num_decode_tokens += num_tokens
            else:
                prefills.append(i)
                num_prefill_tokens += num_tokens

        # We hope that the number of swaps is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to the
        # persistent batch, so they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode sits in the "back" of the batch, i.e. beyond where
            # the decodes should end up, swap it with the prefill closest to
            # the front of the batch.
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break
            input_batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        self._num_decodes = num_decodes
        self._num_prefills = num_prefills
        self._num_decode_tokens = num_decode_tokens
        self._num_prefill_tokens = num_prefill_tokens

        return modified_batch
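A quick toy run of this reordering logic (standalone sketch, not the vLLM classes), showing how a mixed batch ends up with decodes in front using at most min(num_decodes, num_prefills) swaps:

# Scheduled token counts per request slot; 1 token counts as "decode".
num_scheduled = [5, 1, 7, 1]  # slots 0..3
decodes = [i for i, n in enumerate(num_scheduled) if n == 1]   # [1, 3]
prefills = [i for i, n in enumerate(num_scheduled) if n > 1]   # [0, 2]
num_decodes, num_prefills = len(decodes), len(prefills)

order = list(range(len(num_scheduled)))  # stand-in for the persistent batch
for i in range(1, min(num_decodes, num_prefills) + 1):
    decode_idx = decodes[num_decodes - i]
    if decode_idx < num_decodes:
        break  # this decode already sits in the decode region
    # swap the offending decode with the earliest prefill
    a, b = prefills[i - 1], decode_idx
    order[a], order[b] = order[b], order[a]

print(order)  # [3, 1, 2, 0]: the two decode requests now occupy the front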
    def build(self, common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata):
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens

        seq_idx = None
        chunk_indices, chunk_offsets = None, None
        # Need flags to indicate if there are initial states
        # currently we really only support the FlashAttention backend
        has_initial_states = None
        prep_initial_states = False

        state_indices_tensor = self.block_table.block_table[:num_reqs, 0]

        # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
        if self._num_prefills > 0:
            #[batch,]
            has_initial_states_cpu = (
                self.runner.input_batch.
                num_computed_tokens_cpu_tensor[num_reqs -
                                               self._num_prefills:num_reqs]
                > 0)
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

            query_start_loc_p = common_attn_metadata.query_start_loc[
                -self._num_prefills - 1:] - self._num_decode_tokens

            seq_idx = torch.repeat_interleave(
                torch.arange(self._num_prefills,
                             dtype=torch.int32,
                             device=query_start_loc_p.device),
                query_start_loc_p.diff(),
                output_size=self._num_prefill_tokens)
            seq_idx.unsqueeze_(0)

            # We compute metadata for chunked prefill once at the top level
            # model forward and reuse them in mamba layers. If not needed,
            # they will be ignored inside mamba kernels.
            if prep_initial_states:
                chunk_indices, chunk_offsets = (
                    _query_start_loc_to_chunk_indices_offsets(
                        query_start_loc_p, self.chunk_size,
                        self._num_prefill_tokens))

        attn_metadata = Mamba2AttentionMetadata(
            num_prefills=self._num_prefills,
            num_prefill_tokens=self._num_prefill_tokens,
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            query_start_loc=query_start_loc,
            seq_lens=seq_lens,
            has_initial_states=has_initial_states,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            state_indices_tensor=state_indices_tensor,
        )
        return attn_metadata
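To illustrate the per-prefill metadata the builder computes, here is a toy seq_idx computation (same shapes as the earlier split example: two decode tokens, then prefills of 3 and 5 tokens):

import torch

num_prefills, num_prefill_tokens, num_decode_tokens = 2, 8, 2
query_start_loc = torch.tensor([0, 1, 2, 5, 10], dtype=torch.int32)

# Shift the prefill portion so it starts at token 0.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
# tensor([0, 3, 8], dtype=torch.int32)

# One sequence id per prefill token, consumed by the chunked scan kernel.
seq_idx = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
print(seq_idx)  # tensor([0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int32)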

View File

@@ -8,7 +8,7 @@ from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
                                        MambaSpec, SlidingWindowSpec)
from vllm.v1.request import Request
@@ -52,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
        self.caching_hash_fn = caching_hash_fn
        self.kv_cache_group_id = kv_cache_group_id
        self._null_block = block_pool.null_block

    def get_num_blocks_to_allocate(
            self, request_id: str, num_tokens: int,
@@ -390,9 +391,49 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
        return 0


class MambaManager(SingleTypeKVCacheManager):

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(
            kv_cache_spec,
            MambaSpec), ("MambaManager can only be used for mamba groups")
        # Prefix caching is not supported for mamba now. Always return empty
        # list.
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(len(kv_cache_group_ids)))
        return computed_blocks

    def remove_skipped_blocks(self, request_id: str,
                              num_computed_tokens: int) -> None:
        # Each request will always have 1 block at this moment, so no need to
        # remove blocks.
        pass

    def get_num_common_prefix_blocks(self, request_id: str,
                                     num_running_requests: int) -> int:
        return 0

    def allocate_new_blocks(self, request_id: str,
                            num_tokens: int) -> list[KVCacheBlock]:
        new_blocks = super().allocate_new_blocks(request_id, num_tokens)
        assert len(self.req_to_blocks[request_id]) == 1, (
            "MambaManager should only allocate 1 block for each request.")
        return new_blocks


spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
    MambaSpec: MambaManager,
}
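Because the runner sets block_size to max_model_len for Mamba layers (see the get_kv_cache_spec change below), every request maps to exactly one block, which is why MambaManager can assume a single state slot per request. A rough illustration of that invariant, with made-up numbers:

from math import ceil

max_model_len = 4096        # assumed model context length
block_size = max_model_len  # MambaSpec uses one block per request
for num_tokens in (1, 100, max_model_len):
    # cdiv(num_tokens, block_size) in vLLM terms
    assert ceil(num_tokens / block_size) == 1
# The single block id per request is what becomes
# state_indices_tensor = block_table[:num_reqs, 0] in the metadata builder.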

View File

@@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass
from math import prod
from typing import Optional

import torch
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
        return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes


@dataclass
class MambaSpec(KVCacheSpec):
    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype

    def __post_init__(self):
        self.num_elements = sum(prod(shape) for shape in self.shapes)

    @property
    def type_id(self) -> str:
        return f"mamba_{self.shapes}_{self.dtype}"

    @property
    def page_size_bytes(self) -> int:
        return self.num_elements * get_dtype_size(self.dtype)

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes
        # is the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes


@dataclass
class KVCacheTensor:
    """

View File

@@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import (DPMetadata, get_forward_context,
                                  set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -38,12 +39,14 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
                        check_use_alibi, get_dtype_size,
                        is_pin_memory_available)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                        KVCacheConfig, KVCacheSpec, MambaSpec,
                                        SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                             ModelRunnerOutput)
@@ -2093,28 +2096,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            if isinstance(kv_cache_spec, AttentionSpec):
                attn_backend_i = get_attn_backend(
                    kv_cache_spec.head_size,
                    self.dtype,
                    kv_cache_spec.dtype,
                    kv_cache_spec.block_size,
                    self.model_config.is_attention_free,
                    use_mla=kv_cache_spec.use_mla,
                )
                if attn_backend_i is None:
                    error_msg = (f"Error with get_attn_backend: "
                                 f"{kv_cache_spec.head_size=}, "
                                 f"{self.dtype=}, {kv_cache_spec.dtype=}, "
                                 f"{kv_cache_spec.block_size=}, "
                                 f"{self.model_config.is_attention_free=}, "
                                 f"{kv_cache_spec.use_mla=}")
                    logger.error(error_msg)
                    raise NotImplementedError(
                        "Non-Attention backend is not supported by V1 "
                        "GPUModelRunner.")
            elif isinstance(kv_cache_spec, MambaSpec):
                attn_backend_i = Mamba2AttentionBackend
            else:
                raise ValueError(
                    f"Unknown KV cache spec type: {type(kv_cache_spec)}")

            block_table_i = self.input_batch.block_table[i]
            attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
@@ -2242,6 +2248,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                    kv_caches[layer_name] = kv_cache_raw_tensors[
                        layer_name].view(dtype).view(kv_cache_shape).permute(
                            *inv_order)
                elif isinstance(kv_cache_spec, MambaSpec):
                    raw_tensor = kv_cache_raw_tensors[layer_name]
                    dtype = kv_cache_spec.dtype
                    state_tensors = []
                    start_pos = 0
                    for shape in kv_cache_spec.shapes:
                        target_shape = (num_blocks, *shape)
                        size_in_bytes = np.prod(shape) * get_dtype_size(
                            dtype) * num_blocks
                        tensor = raw_tensor[start_pos:start_pos +
                                            size_in_bytes]
                        tensor = tensor.view(dtype).view(target_shape)
                        state_tensors.append(tensor)
                        start_pos += size_in_bytes
                    assert start_pos == raw_tensor.numel()
                    kv_caches[layer_name] = tuple(state_tensors)
                else:
                    raise NotImplementedError

        return kv_caches
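The carving above is just view arithmetic on one flat byte buffer per layer: the conv states for all blocks come first, then the ssm states. A standalone sketch of the same idea with the toy shapes used earlier:

import numpy as np
import torch

num_blocks = 2
shapes = ((10240, 3), (128, 64, 128))
dtype = torch.float16
dtype_size = torch.finfo(dtype).bits // 8  # 2 bytes

total_bytes = sum(int(np.prod(s)) * dtype_size * num_blocks for s in shapes)
raw_tensor = torch.zeros(total_bytes, dtype=torch.uint8)

state_tensors, start_pos = [], 0
for shape in shapes:
    size_in_bytes = int(np.prod(shape)) * dtype_size * num_blocks
    chunk = raw_tensor[start_pos:start_pos + size_in_bytes]
    state_tensors.append(chunk.view(dtype).view(num_blocks, *shape))
    start_pos += size_in_bytes

conv_state, ssm_state = state_tensors
print(conv_state.shape)  # torch.Size([2, 10240, 3])
print(ssm_state.shape)   # torch.Size([2, 128, 64, 128])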
@@ -2307,11 +2329,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        format. Layers that do not need KV cache are not included.
        """

        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
            if (kv_tgt_layer :=
                    attn_module.kv_sharing_target_layer_name) is not None:
                # The layer doesn't need its own KV cache and will use that of
@@ -2351,4 +2373,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                raise ValueError(
                    f"Unknown attention type: {attn_module.attn_type}")

        mamba_layers = get_layers_from_vllm_config(self.vllm_config,
                                                   MambaMixer2)
        if len(mamba_layers) > 0:
            if self.vllm_config.speculative_config is not None:
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet.")
            if not self.vllm_config.model_config.enforce_eager:
                raise NotImplementedError(
                    "Mamba with cuda graph is not supported yet.")
            if self.vllm_config.cache_config.enable_prefix_caching:
                raise NotImplementedError(
                    "Prefix caching is not supported for Mamba yet.")
            max_model_len = self.vllm_config.model_config.max_model_len
            # Set block_size to max_model_len, so that mamba model will always
            # have only one block in the KV cache.
            for layer_name, mamba_module in mamba_layers.items():
                kv_cache_spec[layer_name] = MambaSpec(
                    shapes=mamba_module.get_state_shape(),
                    dtype=self.kv_cache_dtype,
                    block_size=max_model_len)

        return kv_cache_spec
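With one MambaSpec page per layer per request, a back-of-the-envelope capacity estimate follows from the page size computed earlier. This is only an approximation of how the KV cache budget translates into concurrent Mamba requests; the exact block accounting is done by the V1 KV cache manager, and the layer count and budget below are assumptions:

num_mamba_layers = 64           # assumed layer count for a 7B Mamba2 model
page_size_bytes = 2_158_592     # per layer per request (fp16, TP=1, see above)
kv_cache_budget = 8 * 1024**3   # assume 8 GiB left for states

bytes_per_request = num_mamba_layers * page_size_bytes  # ~132 MiB
max_concurrent_requests = kv_cache_budget // bytes_per_request
print(max_concurrent_requests)  # 62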