[v1] Support mamba2 (#19327)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
Chen Zhang 2025-06-19 04:34:15 +08:00 committed by GitHub
parent ffacb222cb
commit a89209b78d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 582 additions and 120 deletions

View File

@ -17,9 +17,10 @@ SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
# "mistralai/Mamba-Codestral-7B-v0.1",
"mistralai/Mamba-Codestral-7B-v0.1",
]
HYBRID_MODELS = [
@ -35,6 +36,10 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM",
]
V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
]
# Avoid OOM
MAX_NUM_SEQS = 4
@ -46,24 +51,50 @@ def test_models(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
if model != "mistralai/Mamba-Codestral-7B-v0.1":
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
hf_outputs = None
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
if model in V1_SUPPORTED_MODELS:
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
vllm_v1_outputs = None
if hf_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
name_0="hf",
name_1="vllm-v0",
)
if model in V1_SUPPORTED_MODELS:
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)
@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)

View File

@ -12,7 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"mistralai/Mamba-Codestral-7B-v0.1", # mamba
"state-spaces/mamba-130m-hf", # mamba1
"hmellor/tiny-random-BambaForCausalLM", # hybrid
"BAAI/bge-m3", # embedding
]

View File

@ -1355,12 +1355,17 @@ class EngineArgs:
recommend_to_remove=False)
return False
# No Mamba or Encoder-Decoder so far.
# No Encoder-Decoder, not all Mamba so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
recommend_to_remove=False)
return False
# V1 mamba models are unoptimized.
if model_config.has_inner_state and _warn_or_fallback(
feature_name="Mamba"):
return False
# No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills
!= SchedulerConfig.max_num_partial_prefills

View File

@ -6,7 +6,9 @@ from typing import Optional, Union
import torch
from torch import nn
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
@ -227,20 +230,22 @@ class MambaMixer2(CustomOp):
"""
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1, # the chunk size used by v1
):
super().__init__()
@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
), "Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
self.activation = activation
self.intermediate_size = intermediate_size
@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
self.use_rms_norm,
eps=rms_norm_eps)
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert chunk_size != -1, "chunk_size must be set for v1"
# NOTE: chunk_size may be -1 for models without v1 support
self.chunk_size = chunk_size
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0]
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx
chunk_indices_p = attn_metadata.chunk_indices
chunk_offsets_p = attn_metadata.chunk_offsets
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
has_initial_states_p = mamba2_metadata.has_initial_states
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx
chunk_indices_p = mamba2_metadata.chunk_indices
chunk_offsets_p = mamba2_metadata.chunk_offsets
groups_time_state_size = self.n_groups * self.ssm_state_size
@ -459,27 +501,6 @@ class MambaMixer2(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
@ -491,20 +512,80 @@ class MambaMixer2(CustomOp):
dim=-1,
)
if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C,
[num_decodes, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
ssd_output_list = []
# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
# pointed to by "state_indices_tensor"
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):
# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
if (has_initial_states_p is not None and prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size,
chunk_size=chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
seq_idx=seq_idx_p,
chunk_indices=chunk_indices_p,
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_state
# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
mamba_cache_params.conv_state,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
# using state_indices_tensor_d
hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
ssm_state,
hidden_states_d,
dt_d,
A_d,
@ -598,9 +678,16 @@ class MambaMixer2(CustomOp):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
)
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
if envs.VLLM_USE_V1:
ssd_output_list.insert(
0,
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
else:
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.n_groups +
extra_groups_for_head_shards(self.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.conv_kernel_size - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.num_heads, world_size),
self.head_dim,
self.ssm_state_size,
)
return conv_state_shape, temporal_state_shape

View File

@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree,
SupportsV0Only)
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):
def __init__(self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.mixer = MambaMixer2(hidden_size=config.hidden_size,
@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Mamba2DecoderLayer(config,
quant_config=quant_config),
lambda prefix: Mamba2DecoderLayer(
config, quant_config=quant_config, prefix=prefix),
prefix=f"{prefix}.layers")
self.norm_f = RMSNorm(config.hidden_size,
@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None
for i in range(len(self.layers)):
layer = self.layers[i]
@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer),
i - self.start_layer) if mamba_cache_params else None,
mamba2_metadata=mamba2_metadata)
if not get_pp_group().is_last_rank:
@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
return loaded_params
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
SupportsV0Only):
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None
hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

View File

@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.worker.block_table import BlockTable
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
chunk_sizes = set(layer.chunk_size for layer in layers.values())
assert len(
chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
return chunk_sizes.pop()
class Mamba2AttentionBackend(AttentionBackend):
@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder
@dataclass
class Mamba2AttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
state_indices_tensor: torch.Tensor # shape: [batch,]
class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
block_table: BlockTable):
self.runner = runner
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# NOTE (Chen): Copied from MLACommonMetadataBuilder and
# FlashInferMetadataBuilder. Should be refactored later to avoid code
# duplication of these 3 functions.
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break
input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
return modified_batch
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
seq_idx = None
chunk_indices, chunk_offsets = None, None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states = None
prep_initial_states = False
state_indices_tensor = self.block_table.block_table[:num_reqs, 0]
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if self._num_prefills > 0:
#[batch,]
has_initial_states_cpu = (
self.runner.input_batch.
num_computed_tokens_cpu_tensor[num_reqs -
self._num_prefills:num_reqs]
> 0)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)
query_start_loc_p = common_attn_metadata.query_start_loc[
-self._num_prefills - 1:] - self._num_decode_tokens
seq_idx = torch.repeat_interleave(
torch.arange(self._num_prefills,
dtype=torch.int32,
device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=self._num_prefill_tokens)
seq_idx.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level
# model forward and reuse them in mamba layers. If not needed,
# they will be ignored inside mamba kernels.
if prep_initial_states:
chunk_indices, chunk_offsets = (
_query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, self.chunk_size,
self._num_prefill_tokens))
attn_metadata = Mamba2AttentionMetadata(
num_prefills=self._num_prefills,
num_prefill_tokens=self._num_prefill_tokens,
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata

View File

@ -8,7 +8,7 @@ from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
MambaSpec, SlidingWindowSpec)
from vllm.v1.request import Request
@ -52,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
@ -390,9 +391,49 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return 0
class MambaManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec,
MambaSpec), ("MambaManager can only be used for mamba groups")
# Prefix caching is not supported for mamba now. Always return empty
# list.
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Each request will always have 1 block at this moment, so no need to
# remove blocks.
pass
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
return 0
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]:
new_blocks = super().allocate_new_blocks(request_id, num_tokens)
assert len(self.req_to_blocks[request_id]) == 1, (
"MambaManager should only allocate 1 block for each request.")
return new_blocks
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
MambaSpec: MambaManager,
}

View File

@ -3,6 +3,7 @@
import copy
from dataclasses import dataclass
from math import prod
from typing import Optional
import torch
@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
def __post_init__(self):
self.num_elements = sum(prod(shape) for shape in self.shapes)
@property
def type_id(self) -> str:
return f"mamba_{self.shapes}_{self.dtype}"
@property
def page_size_bytes(self) -> int:
return self.num_elements * get_dtype_size(self.dtype)
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# We allocate 1 block for each request now, so max_memory_usage_bytes is
# the same as page_size_bytes.
# Need to update this when supporting prefix caching.
return self.page_size_bytes
@dataclass
class KVCacheTensor:
"""

View File

@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -38,12 +39,14 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, is_pin_memory_available)
check_use_alibi, get_dtype_size,
is_pin_memory_available)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
@ -2093,28 +2096,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
raise NotImplementedError(
"Only AttentionSpec is supported for now.")
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (
f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")
block_table_i = self.input_batch.block_table[i]
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
@ -2242,6 +2248,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
raw_tensor = kv_cache_raw_tensors[layer_name]
dtype = kv_cache_spec.dtype
state_tensors = []
start_pos = 0
for shape in kv_cache_spec.shapes:
target_shape = (num_blocks, *shape)
size_in_bytes = np.prod(shape) * get_dtype_size(
dtype) * num_blocks
tensor = raw_tensor[start_pos:start_pos +
size_in_bytes]
tensor = tensor.view(dtype).view(target_shape)
state_tensors.append(tensor)
start_pos += size_in_bytes
assert start_pos == raw_tensor.numel()
kv_caches[layer_name] = tuple(state_tensors)
else:
raise NotImplementedError
return kv_caches
@ -2307,11 +2329,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
layers = get_layers_from_vllm_config(self.vllm_config, Attention)
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items():
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
@ -2351,4 +2373,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")
mamba_layers = get_layers_from_vllm_config(self.vllm_config,
MambaMixer2)
if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if not self.vllm_config.model_config.enforce_eager:
raise NotImplementedError(
"Mamba with cuda graph is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len)
return kv_cache_spec