[v1] Support mamba2 (#19327)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>

parent ffacb222cb
commit a89209b78d
@@ -17,9 +17,10 @@ SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups.
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
# "mistralai/Mamba-Codestral-7B-v0.1",
"mistralai/Mamba-Codestral-7B-v0.1",
]

HYBRID_MODELS = [
@@ -35,6 +36,10 @@ HYBRID_MODELS = [
"hmellor/tiny-random-BambaForCausalLM",
]

V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
]

# Avoid OOM
MAX_NUM_SEQS = 4

@@ -46,24 +51,50 @@ def test_models(
hf_runner,
vllm_runner,
example_prompts,
monkeypatch,
model: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
if model != "mistralai/Mamba-Codestral-7B-v0.1":
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
hf_outputs = None

with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
if model in V1_SUPPORTED_MODELS:
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
vllm_v1_outputs = None

if hf_outputs is not None:
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_v0_outputs,
name_0="hf",
name_1="vllm-v0",
)

if model in V1_SUPPORTED_MODELS:
ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
check_logprobs_close(
outputs_0_lst=ref_outputs,
outputs_1_lst=vllm_v1_outputs,
name_0="hf" if hf_outputs is not None else "vllm-v0",
name_1="vllm-v1",
)


@pytest.mark.parametrize("model", SSM_MODELS + HYBRID_MODELS)
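Example (illustrative, not part of the diff): the v1 path that the test above exercises can be reproduced offline by forcing the v1 engine and matching the constraints the test passes to vllm_runner (eager mode, no prefix caching). The model name and sampling settings here are assumptions for the sketch, not requirements.

    # Hedged sketch: run a Mamba2 model on the v1 engine under the same
    # constraints as the test (enforce_eager=True, enable_prefix_caching=False).
    import os
    os.environ["VLLM_USE_V1"] = "1"

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="mistralai/Mamba-Codestral-7B-v0.1",  # the v1-supported model used in the test
        max_num_seqs=4,               # mirrors MAX_NUM_SEQS above
        enforce_eager=True,           # cuda graphs are not supported for mamba in v1 yet
        enable_prefix_caching=False,  # prefix caching is not supported for mamba in v1 yet
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)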
@@ -12,7 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"mistralai/Mamba-Codestral-7B-v0.1", # mamba
"state-spaces/mamba-130m-hf", # mamba1
"hmellor/tiny-random-BambaForCausalLM", # hybrid
"BAAI/bge-m3", # embedding
]

@@ -1355,12 +1355,17 @@ class EngineArgs:
recommend_to_remove=False)
return False

# No Mamba or Encoder-Decoder so far.
# No Encoder-Decoder, not all Mamba so far.
if not model_config.is_v1_compatible:
_raise_or_fallback(feature_name=model_config.architectures,
recommend_to_remove=False)
return False

# V1 mamba models are unoptimized.
if model_config.has_inner_state and _warn_or_fallback(
feature_name="Mamba"):
return False

# No Concurrent Partial Prefills so far.
if (self.max_num_partial_prefills
!= SchedulerConfig.max_num_partial_prefills
@@ -6,7 +6,9 @@ from typing import Optional, Union
import torch
from torch import nn

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -27,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata

# Added by the IBM Team, 2024

@@ -227,20 +230,22 @@ class MambaMixer2(CustomOp):
"""

def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
chunk_size: int = -1, # the chunk size used by v1
):
super().__init__()

@@ -273,6 +278,7 @@ class MambaMixer2(CustomOp):
), "Tensor parallel currently not supported for quantized models."

self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
self.activation = activation

self.intermediate_size = intermediate_size
@@ -411,6 +417,22 @@ class MambaMixer2(CustomOp):
self.use_rms_norm,
eps=rms_norm_eps)

if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
# The outer list is for v0 PP virtual engine. Though this code path
# only runs for v1, we have to do this to unify with the interface
# of Attention + v0 PP.
# The inner tuple is (conv_state, ssm_state)
self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
assert chunk_size != -1, "chunk_size must be set for v1"

# NOTE: chunk_size may be -1 for models without v1 support
self.chunk_size = chunk_size
self.prefix = prefix
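The registration above mirrors what Attention layers already do: each layer is stored in the compilation config's static_forward_context under its prefix so that per-layer metadata can be looked up later in forward. A minimal toy sketch of that pattern (plain Python, not vLLM code; names are illustrative):

    # Toy sketch of prefix-keyed layer registration.
    static_forward_context: dict[str, object] = {}

    def register_layer(prefix: str, layer: object) -> None:
        if prefix in static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_forward_context[prefix] = layer

    register_layer("backbone.layers.0.mixer", "mixer0")
    register_layer("backbone.layers.1.mixer", "mixer1")
    # Registering "backbone.layers.0.mixer" again would raise ValueError,
    # which is exactly the duplicate-prefix check in __init__ above.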
def forward_native(
self,
hidden_states: torch.Tensor,
@@ -426,17 +448,37 @@ class MambaMixer2(CustomOp):
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
forward_context = get_forward_context()
# mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
conv_state = self_kv_cache[0]
ssm_state = self_kv_cache[1]
state_indices_tensor = attn_metadata.state_indices_tensor
has_initial_states_p = attn_metadata.has_initial_states
prep_initial_states = attn_metadata.prep_initial_states
chunk_size = attn_metadata.chunk_size
seq_idx_p = attn_metadata.seq_idx
chunk_indices_p = attn_metadata.chunk_indices
chunk_offsets_p = attn_metadata.chunk_offsets
else:
conv_state = mamba_cache_params.conv_state
ssm_state = mamba_cache_params.ssm_state
state_indices_tensor = mamba_cache_params.state_indices_tensor
has_initial_states_p = mamba2_metadata.has_initial_states
prep_initial_states = mamba2_metadata.prep_initial_states
chunk_size = mamba2_metadata.chunk_size
seq_idx_p = mamba2_metadata.seq_idx
chunk_indices_p = mamba2_metadata.chunk_indices
chunk_offsets_p = mamba2_metadata.chunk_offsets

groups_time_state_size = self.n_groups * self.ssm_state_size

@@ -459,27 +501,6 @@ class MambaMixer2(CustomOp):
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
mamba_cache_params.state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills + 1]
if has_prefill else None)

# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
hidden_states_B_C,
@@ -491,20 +512,80 @@ class MambaMixer2(CustomOp):
dim=-1,
)

if envs.VLLM_USE_V1 and attn_metadata is None:
# V1 profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out

num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0

# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
# Split along token dimension
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C,
[num_decodes, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt,
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
else:
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = (attn_metadata.query_start_loc[:num_prefills +
1]
if has_prefill else None)
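A worked toy example (assumed numbers, not from the diff) of the v1 ordering handled above: with 2 decode requests of 1 token each followed by 2 prefill requests of 3 and 5 tokens, query_start_loc is [0, 1, 2, 5, 10], and the prefill-only offsets are recovered by taking the last num_prefills + 1 entries and rebasing them past the decode tokens:

    import torch

    num_decodes = 2            # v1: decode requests come first, 1 token each
    num_prefills = 2           # then prefill requests with 3 and 5 tokens
    num_prefill_tokens = 8

    query_start_loc = torch.tensor([0, 1, 2, 5, 10])

    # Same slicing as the v1 branch above.
    query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
    print(query_start_loc_p)   # tensor([0, 3, 8])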
ssd_output_list = []

# Process prefill requests
if has_prefill:
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
# pointed to by "state_indices_tensor"
hidden_states_B_C_p = causal_conv1d_fn(
hidden_states_B_C_p.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=mamba2_metadata.has_initial_states,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
@@ -516,12 +597,11 @@ class MambaMixer2(CustomOp):

# 3. State Space Model sequence transformation
initial_states = None
if (mamba2_metadata.has_initial_states is not None
and mamba2_metadata.prep_initial_states):
if (has_initial_states_p is not None and prep_initial_states):
# making a copy of the states
initial_states = torch.where(
mamba2_metadata.has_initial_states[:, None, None, None],
mamba_cache_params.ssm_state[state_indices_tensor_p], 0)
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p], 0)

scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
@@ -533,14 +613,14 @@ class MambaMixer2(CustomOp):
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=mamba2_metadata.chunk_size,
chunk_size=chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=mamba2_metadata.seq_idx,
chunk_indices=mamba2_metadata.chunk_indices,
chunk_offsets=mamba2_metadata.chunk_offsets,
cu_seqlens=attn_metadata.query_start_loc[:num_prefills + 1],
seq_idx=seq_idx_p,
chunk_indices=chunk_indices_p,
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
@@ -550,7 +630,7 @@ class MambaMixer2(CustomOp):

# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
mamba_cache_params.ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_state

# - reshape
ssd_output_list.append(scan_output.view(num_prefill_tokens, -1))
@@ -560,7 +640,7 @@ class MambaMixer2(CustomOp):
# 2. Convolution sequence transformation
hidden_states_B_C_d = causal_conv1d_update(
hidden_states_B_C_d,
mamba_cache_params.conv_state,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
@@ -586,7 +666,7 @@ class MambaMixer2(CustomOp):
# using state_indices_tensor_d

hidden_states_d = selective_state_update(
mamba_cache_params.ssm_state,
ssm_state,
hidden_states_d,
dt_d,
A_d,
@@ -598,9 +678,16 @@ class MambaMixer2(CustomOp):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
)
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))

if envs.VLLM_USE_V1:
ssd_output_list.insert(
0,
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))
else:
ssd_output_list.append(
hidden_states_d.view(-1, (self.num_heads // self.tp_size) *
self.head_dim))

# Merge prefill and decode outputs before passing to gated MLP
hidden_states = torch.vstack(ssd_output_list)
@@ -614,3 +701,31 @@ class MambaMixer2(CustomOp):
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out

def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
world_size = get_tensor_model_parallel_world_size()

conv_state_shape, temporal_state_shape = None, None

# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.n_groups +
extra_groups_for_head_shards(self.n_groups, world_size))

# - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.conv_kernel_size - 1,
)

# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.num_heads, world_size),
self.head_dim,
self.ssm_state_size,
)
return conv_state_shape, temporal_state_shape
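To make get_state_shape concrete, a worked example with assumed values (single GPU, so world_size = 1 and the divides are no-ops), using the defaults from __init__ (num_heads=128, head_dim=64, ssm_state_size=128, n_groups=1) and an illustrative intermediate_size=8192, conv_kernel_size=4:

    # Assumed example values, not taken from any specific checkpoint.
    intermediate_size, n_groups, ssm_state_size = 8192, 1, 128
    conv_kernel_size, num_heads, head_dim = 4, 128, 64

    conv_dim = intermediate_size + 2 * n_groups * ssm_state_size  # 8448
    conv_state_shape = (conv_dim, conv_kernel_size - 1)           # (8448, 3)
    temporal_state_shape = (num_heads, head_dim, ssm_state_size)  # (128, 64, 128)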
@@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import MambaConfig

from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -25,8 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState,
IsAttentionFree,
SupportsV0Only)
IsAttentionFree)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -44,7 +44,8 @@ class Mamba2DecoderLayer(nn.Module):

def __init__(self,
config: MambaConfig,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
super().__init__()
self.config = config
self.mixer = MambaMixer2(hidden_size=config.hidden_size,
@@ -60,7 +61,9 @@ class Mamba2DecoderLayer(nn.Module):
head_dim=config.head_dim,
rms_norm_eps=config.layer_norm_epsilon,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)

self.norm = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

@@ -108,8 +111,8 @@ class Mamba2Model(nn.Module):

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Mamba2DecoderLayer(config,
quant_config=quant_config),
lambda prefix: Mamba2DecoderLayer(
config, quant_config=quant_config, prefix=prefix),
prefix=f"{prefix}.layers")

self.norm_f = RMSNorm(config.hidden_size,
@@ -142,10 +145,14 @@ class Mamba2Model(nn.Module):

attn_metadata: AttentionMetadata = get_forward_context().attn_metadata

mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# v1 get mamba2_metadata from forward_context
mamba2_metadata = None

for i in range(len(self.layers)):
layer = self.layers[i]
@@ -155,7 +162,7 @@ class Mamba2Model(nn.Module):
hidden_states=hidden_states,
residual=residual,
mamba_cache_params=mamba_cache_params.at_layer_idx(
i - self.start_layer),
i - self.start_layer) if mamba_cache_params else None,
mamba2_metadata=mamba2_metadata)

if not get_pp_group().is_last_rank:
@@ -190,8 +197,7 @@ class Mamba2Model(nn.Module):
return loaded_params


class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
SupportsV0Only):
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
@@ -242,14 +248,20 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())

mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
# NOTE: mamba_cache_params is not needed for v1
mamba_cache_params = None

hidden_states = self.backbone(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)
vllm/v1/attention/backends/mamba_attn.py (new file, 192 lines)
@@ -0,0 +1,192 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.mamba.mamba2_metadata import (
_query_start_loc_to_chunk_indices_offsets)
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import MambaSpec
from vllm.v1.worker.block_table import BlockTable

if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def get_mamba2_chunk_size(vllm_config: VllmConfig) -> int:
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
layers = get_layers_from_vllm_config(vllm_config, MambaMixer2)
chunk_sizes = set(layer.chunk_size for layer in layers.values())
assert len(
chunk_sizes) == 1, "All Mamba2 layers must have the same chunk size"
return chunk_sizes.pop()


class Mamba2AttentionBackend(AttentionBackend):

@staticmethod
def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
return Mamba2AttentionMetadataBuilder


@dataclass
class Mamba2AttentionMetadata:
num_prefills: int
num_prefill_tokens: int
num_decodes: int
num_decode_tokens: int
query_start_loc: torch.Tensor
seq_lens: torch.Tensor

has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor

state_indices_tensor: torch.Tensor # shape: [batch,]


class Mamba2AttentionMetadataBuilder(
AttentionMetadataBuilder[Mamba2AttentionMetadata]):

def __init__(self, runner: "GPUModelRunner", kv_cache_spec: MambaSpec,
block_table: BlockTable):
self.runner = runner
self.kv_cache_spec = kv_cache_spec
self.block_table = block_table
self.chunk_size = get_mamba2_chunk_size(runner.vllm_config)

def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# NOTE (Chen): Copied from MLACommonMetadataBuilder and
# FlashInferMetadataBuilder. Should be refactored later to avoid code
# duplication of these 3 functions.
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0

for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the decode run only supports num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens

# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
modified_batch = False

for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
decode_idx = decodes[num_decodes - i]
if decode_idx < num_decodes:
break

input_batch.swap_states(prefills[i - 1], decode_idx)
modified_batch = True

# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens

return modified_batch
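A toy re-implementation (indices only, not vLLM code) of the swap logic described in the comments above: with scheduled token counts [5, 1, 3, 1], the two decode requests (indices 1 and 3) end up at the front with a single swap.

    def reorder(num_scheduled_tokens: list[int]) -> list[int]:
        order = list(range(len(num_scheduled_tokens)))
        decodes = [i for i, n in enumerate(num_scheduled_tokens) if n == 1]
        prefills = [i for i, n in enumerate(num_scheduled_tokens) if n > 1]
        num_decodes = len(decodes)
        for i in range(1, min(num_decodes, len(prefills)) + 1):
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break
            # swap a trailing decode with a leading prefill
            a, b = prefills[i - 1], decode_idx
            order[a], order[b] = order[b], order[a]
        return order

    print(reorder([5, 1, 3, 1]))  # [3, 1, 2, 0] -> both decode requests are now at the front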
def build(self, common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata):
num_reqs = common_attn_metadata.num_reqs
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens

seq_idx = None
chunk_indices, chunk_offsets = None, None
# Need flags to indicate if there are initial states
# currently we really only support the FlashAttention backend
has_initial_states = None
prep_initial_states = False

state_indices_tensor = self.block_table.block_table[:num_reqs, 0]

# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if self._num_prefills > 0:
#[batch,]
has_initial_states_cpu = (
self.runner.input_batch.
num_computed_tokens_cpu_tensor[num_reqs -
self._num_prefills:num_reqs]
> 0)
prep_initial_states = torch.any(has_initial_states_cpu).item()
has_initial_states = has_initial_states_cpu.to(
query_start_loc.device)

query_start_loc_p = common_attn_metadata.query_start_loc[
-self._num_prefills - 1:] - self._num_decode_tokens

seq_idx = torch.repeat_interleave(
torch.arange(self._num_prefills,
dtype=torch.int32,
device=query_start_loc_p.device),
query_start_loc_p.diff(),
output_size=self._num_prefill_tokens)
seq_idx.unsqueeze_(0)

# We compute metadata for chunked prefill once at the top level
# model forward and reuse them in mamba layers. If not needed,
# they will be ignored inside mamba kernels.
if prep_initial_states:
chunk_indices, chunk_offsets = (
_query_start_loc_to_chunk_indices_offsets(
query_start_loc_p, self.chunk_size,
self._num_prefill_tokens))

attn_metadata = Mamba2AttentionMetadata(
num_prefills=self._num_prefills,
num_prefill_tokens=self._num_prefill_tokens,
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=self.chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
state_indices_tensor=state_indices_tensor,
)
return attn_metadata
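Continuing the earlier toy batch (2 decode tokens, then prefills of 3 and 5 tokens), a worked example (assumptions, not from the diff) of the prefill-only tensors that build() derives:

    import torch

    num_prefills, num_prefill_tokens, num_decode_tokens = 2, 8, 2
    query_start_loc = torch.tensor([0, 1, 2, 5, 10], dtype=torch.int32)

    query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
    # tensor([0, 3, 8], dtype=torch.int32)

    seq_idx = torch.repeat_interleave(
        torch.arange(num_prefills, dtype=torch.int32),
        query_start_loc_p.diff(),
        output_size=num_prefill_tokens)
    # tensor([0, 0, 0, 1, 1, 1, 1, 1], dtype=torch.int32): one label per
    # prefill token saying which prefill request it belongs to.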
@@ -8,7 +8,7 @@ from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
MambaSpec, SlidingWindowSpec)
from vllm.v1.request import Request


@@ -52,6 +52,7 @@ class SingleTypeKVCacheManager(ABC):

self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block

def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
@@ -390,9 +391,49 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
return 0


class MambaManager(SingleTypeKVCacheManager):

@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec,
MambaSpec), ("MambaManager can only be used for mamba groups")
# Prefix caching is not supported for mamba now. Always return empty
# list.
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
return computed_blocks

def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Each request will always have 1 block at this moment, so no need to
# remove blocks.
pass

def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
return 0

def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]:
new_blocks = super().allocate_new_blocks(request_id, num_tokens)
assert len(self.req_to_blocks[request_id]) == 1, (
"MambaManager should only allocate 1 block for each request.")
return new_blocks


spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
MambaSpec: MambaManager,
}
@@ -3,6 +3,7 @@

import copy
from dataclasses import dataclass
from math import prod
from typing import Optional

import torch
@@ -154,6 +155,29 @@ class SlidingWindowSpec(AttentionSpec):
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes


@dataclass
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype

def __post_init__(self):
self.num_elements = sum(prod(shape) for shape in self.shapes)

@property
def type_id(self) -> str:
return f"mamba_{self.shapes}_{self.dtype}"

@property
def page_size_bytes(self) -> int:
return self.num_elements * get_dtype_size(self.dtype)

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# We allocate 1 block for each request now, so max_memory_usage_bytes is
# the same as page_size_bytes.
# Need to update this when supporting prefix caching.
return self.page_size_bytes


@dataclass
class KVCacheTensor:
"""
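For intuition about the MambaSpec added above, a worked example with assumed numbers: the two per-layer state shapes from the earlier get_state_shape example, stored as float16 (2 bytes per element):

    from math import prod

    shapes = ((8448, 3), (128, 64, 128))        # assumed (conv_state, ssm_state) shapes
    num_elements = sum(prod(s) for s in shapes) # 25344 + 1048576 = 1073920
    page_size_bytes = num_elements * 2          # float16 -> 2 bytes per element
    print(page_size_bytes)                      # 2147840 bytes, about 2 MiB per request per layer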
@@ -29,6 +29,7 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import (DPMetadata, get_forward_context,
set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -38,12 +39,14 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
check_use_alibi, is_pin_memory_available)
check_use_alibi, get_dtype_size,
is_pin_memory_available)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec,
KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
@@ -2093,28 +2096,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
if not isinstance(kv_cache_spec, AttentionSpec):
raise NotImplementedError(
"Only AttentionSpec is supported for now.")
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (
f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if isinstance(kv_cache_spec, AttentionSpec):
attn_backend_i = get_attn_backend(
kv_cache_spec.head_size,
self.dtype,
kv_cache_spec.dtype,
kv_cache_spec.block_size,
self.model_config.is_attention_free,
use_mla=kv_cache_spec.use_mla,
)
if attn_backend_i is None:
error_msg = (f"Error with get_attn_backend: "
f"{kv_cache_spec.head_size=}, "
f"{self.dtype=}, {kv_cache_spec.dtype=}, "
f"{kv_cache_spec.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{kv_cache_spec.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
elif isinstance(kv_cache_spec, MambaSpec):
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
f"Unknown KV cache spec type: {type(kv_cache_spec)}")

block_table_i = self.input_batch.block_table[i]
attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
@@ -2242,6 +2248,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_caches[layer_name] = kv_cache_raw_tensors[
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
raw_tensor = kv_cache_raw_tensors[layer_name]
dtype = kv_cache_spec.dtype
state_tensors = []
start_pos = 0
for shape in kv_cache_spec.shapes:
target_shape = (num_blocks, *shape)
size_in_bytes = np.prod(shape) * get_dtype_size(
dtype) * num_blocks
tensor = raw_tensor[start_pos:start_pos +
size_in_bytes]
tensor = tensor.view(dtype).view(target_shape)
state_tensors.append(tensor)
start_pos += size_in_bytes
assert start_pos == raw_tensor.numel()
kv_caches[layer_name] = tuple(state_tensors)
else:
raise NotImplementedError
return kv_caches
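A small toy sketch (not vLLM code) of the carving done in the MambaSpec branch above: a flat per-layer byte buffer is split into one tensor per state shape, each viewed with a leading num_blocks dimension. The sizes here are tiny stand-ins.

    import torch
    from math import prod

    num_blocks = 4                    # assumed; mamba uses one block per request
    dtype = torch.float16
    shapes = ((8, 3), (2, 4, 8))      # stand-ins for (conv_state, ssm_state) shapes

    elem_size = torch.tensor([], dtype=dtype).element_size()
    raw_tensor = torch.zeros(
        num_blocks * sum(prod(s) for s in shapes) * elem_size, dtype=torch.uint8)

    state_tensors, start_pos = [], 0
    for shape in shapes:
        size_in_bytes = num_blocks * prod(shape) * elem_size
        chunk = raw_tensor[start_pos:start_pos + size_in_bytes]
        state_tensors.append(chunk.view(dtype).view(num_blocks, *shape))
        start_pos += size_in_bytes
    assert start_pos == raw_tensor.numel()

    conv_state, ssm_state = state_tensors
    print(conv_state.shape, ssm_state.shape)  # (4, 8, 3) and (4, 2, 4, 8)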
@@ -2307,11 +2329,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""

layers = get_layers_from_vllm_config(self.vllm_config, Attention)
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in layers.items():
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
if (kv_tgt_layer :=
attn_module.kv_sharing_target_layer_name) is not None:
# The layer doesn't need its own KV cache and will use that of
@@ -2351,4 +2373,24 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise ValueError(
f"Unknown attention type: {attn_module.attn_type}")

mamba_layers = get_layers_from_vllm_config(self.vllm_config,
MambaMixer2)
if len(mamba_layers) > 0:
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if not self.vllm_config.model_config.enforce_eager:
raise NotImplementedError(
"Mamba with cuda graph is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len)
return kv_cache_spec