Enable V1 for Hybrid SSM/Attention Models (#20016)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Stanislaw Wozniak <stw@zurich.ibm.com>
Co-authored-by: Tyler Michael Smith <tysmith@redhat.com>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Thomas Parnell 2025-07-04 19:46:53 +02:00 committed by GitHub
parent ffe00ef77a
commit 2f35a022e6
14 changed files with 399 additions and 134 deletions

View File

@ -3,6 +3,7 @@
import pytest
from tests.models.registry import HF_EXAMPLE_MODELS
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
@ -19,31 +20,55 @@ pytestmark = pytest.mark.hybrid_model
SSM_MODELS = [
"state-spaces/mamba-130m-hf",
"tiiuae/falcon-mamba-tiny-dev",
# TODO: Compare to a Mamba2 model. The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
"mistralai/Mamba-Codestral-7B-v0.1",
]
HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
# NOTE: Currently the test fails due to HF transformers issue fixed in:
# https://github.com/huggingface/transformers/pull/39033
# We will enable vLLM test for Granite after next HF transformers release.
# "ibm-granite/granite-4.0-tiny-preview",
# NOTE: Running Plamo2 with the transformers implementation requires installing
# the causal-conv1d package, which is not listed as a test dependency as it's
# not compatible with pip-compile.
"pfnet/plamo-2-1b",
"Zyphra/Zamba2-1.2B-instruct",
"hmellor/tiny-random-BambaForCausalLM",
"ibm-ai-platform/Bamba-9B-v1",
"nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
]
HF_UNSUPPORTED_MODELS = [
# The HF transformers implementation of
# Mamba2 is buggy for Codestral as it doesn't handle n_groups, so the test
# doesn't compare vLLM output with HF output.
# See https://github.com/huggingface/transformers/pull/35943
"mistralai/Mamba-Codestral-7B-v0.1",
# Note: vLLM V0 output does not appear to match HF transformers output for
# Nemotron-H-8B; we currently only compare vLLM V0 vs. vLLM V1
"nvidia/Nemotron-H-8B-Base-8K",
# NOTE: Currently the test fails due to HF transformers issue fixed in:
# https://github.com/huggingface/transformers/pull/39033
# We will enable vLLM test for Granite after next HF transformers release.
"ibm-granite/granite-4.0-tiny-preview",
]
V1_SUPPORTED_MODELS = [
"mistralai/Mamba-Codestral-7B-v0.1",
"ibm-ai-platform/Bamba-9B-v1",
"Zyphra/Zamba2-1.2B-instruct",
"nvidia/Nemotron-H-8B-Base-8K",
"ibm-granite/granite-4.0-tiny-preview",
"tiiuae/Falcon-H1-0.5B-Base",
]
ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
"nvidia/Nemotron-H-8B-Base-8K": 528,
"ibm-granite/granite-4.0-tiny-preview": 400,
"tiiuae/Falcon-H1-0.5B-Base": 800,
}
# Avoid OOM
MAX_NUM_SEQS = 4
@ -60,8 +85,16 @@ def test_models(
max_tokens: int,
num_logprobs: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
with hf_runner(model) as hf_model:
if model != "mistralai/Mamba-Codestral-7B-v0.1":
if model not in HF_UNSUPPORTED_MODELS:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
else:
@ -72,12 +105,21 @@ def test_models(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
block_size = ATTN_BLOCK_SIZES[model]
else:
block_size = 16
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
# required due to reorder_batch behaviour
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:
@ -111,6 +153,14 @@ def test_batching(
max_tokens: int,
num_logprobs: int,
) -> None:
try:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")
except ValueError:
pass
for_loop_outputs = []
with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
for prompt in example_prompts:
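The test changes above spell out what is currently required to run one of these hybrid checkpoints on the V1 engine: V1 must be enabled explicitly, an attention backend whose KV cache layout is (num_blocks, 2, ...) must be selected (e.g. FlashInfer), prefix caching must stay disabled, and the attention block size must be chosen so that the attention page is at least as large as the mamba state per sequence. Below is a minimal standalone sketch under those assumptions; the model name and block size simply mirror the ATTN_BLOCK_SIZES table above and are illustrative rather than prescriptive.

import os

# Same environment the test sets via monkeypatch.
os.environ["VLLM_USE_V1"] = "1"
# Hybrid models in V1 currently need an attention backend with a
# (num_blocks, 2, ...) KV cache layout, e.g. FlashInfer.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    max_num_seqs=4,                # keep memory usage modest, as in the test
    enforce_eager=True,
    enable_prefix_caching=False,   # prefix caching is not supported for mamba yet
    block_size=528,                # attention block size matching the mamba page size
)
outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)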

View File

@ -169,7 +169,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"ExaoneForCausalLM": _HfExamplesInfo("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"), # noqa: E501
"Fairseq2LlamaForCausalLM": _HfExamplesInfo("mgleize/fairseq2-dummy-Llama-3.2-1B"), # noqa: E501
"FalconForCausalLM": _HfExamplesInfo("tiiuae/falcon-7b"),
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-1.5B-Instruct",
"FalconH1ForCausalLM":_HfExamplesInfo("tiiuae/Falcon-H1-0.5B-Base",
min_transformers_version="4.53"),
"GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
"Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),

View File

@ -13,7 +13,6 @@ UNSUPPORTED_MODELS_V1 = [
"openai/whisper-large-v3", # transcription
"facebook/bart-large-cnn", # encoder decoder
"state-spaces/mamba-130m-hf", # mamba1
"hmellor/tiny-random-BambaForCausalLM", # hybrid
"BAAI/bge-m3", # embedding
]

View File

@ -108,7 +108,7 @@ def _selective_scan_update_kernel(
# is the same as the batch id.
if HAS_STATE_BATCH_INDICES:
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr)
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += (state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
else:

View File

@ -9,6 +9,7 @@ import torch
from torch import nn
from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -36,7 +37,7 @@ from vllm.sequence import IntermediateTensors
from vllm.utils import LayerBlockType
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -97,7 +98,9 @@ class BambaMixerDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)
self.feed_forward = BambaMLP(config, quant_config=quant_config)
self.input_layernorm = RMSNorm(config.hidden_size,
@ -313,10 +316,14 @@ class BambaModel(nn.Module):
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# V1 gets mamba2_metadata from the forward context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@ -337,7 +344,8 @@ class BambaModel(nn.Module):
num_attn += 1
layer_mamba_cache_params = None
if isinstance(layer, BambaMixerDecoderLayer):
if isinstance(layer,
BambaMixerDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
@ -411,7 +419,7 @@ class BambaModel(nn.Module):
class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -475,15 +483,22 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

View File

@ -8,6 +8,7 @@ import torch
from torch import nn
from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -33,8 +34,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsV0Only)
from .interfaces import HasInnerState, IsHybrid, SupportsLoRA, SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -85,6 +85,7 @@ class FalconH1SSMDecoderLayer(nn.Module):
config: FalconH1Config,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@ -107,6 +108,8 @@ class FalconH1SSMDecoderLayer(nn.Module):
activation=config.hidden_act,
quant_config=quant_config,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size,
)
# n_groups is overridden later by `MambaMixer2`
self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
@ -316,6 +319,7 @@ class FalconH1ParallelHybrid(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
# Instantiate the attention branch
self.self_attn = FalconH1AttentionDecoderLayer(
config=config,
@ -323,11 +327,18 @@ class FalconH1ParallelHybrid(nn.Module):
quant_config=quant_config,
prefix=prefix,
)
# In V1, all attention/SSM layers must have
# a different index in their prefix
ssm_layer_idx = config.num_hidden_layers + layer_idx
ssm_prefix = prefix.split(".")[0] + f".{ssm_layer_idx}"
# Instantiate the SSM branch
self.mamba = FalconH1SSMDecoderLayer(
config=config,
cache_config=cache_config,
quant_config=quant_config,
prefix=ssm_prefix,
)
self.ssm_out_multiplier = config.ssm_out_multiplier
self.ssm_in_multiplier = config.ssm_in_multiplier
@ -452,10 +463,16 @@ class FalconH1Model(nn.Module):
# proper continuous batching computation including
# chunked prefill
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# V1 gets mamba2_metadata from the forward context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds * self.embedding_multiplier
@ -468,7 +485,9 @@ class FalconH1Model(nn.Module):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
layer_mamba_cache_params = None
if mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(i)
hidden_states = layer(
positions=positions,
hidden_states=hidden_states,
@ -484,7 +503,7 @@ class FalconH1Model(nn.Module):
class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only):
IsHybrid):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
@ -558,15 +577,19 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
):
if self.mamba_cache is None:
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.lm_head.weight.dtype
if hasattr(self.lm_head, 'weight') else torch.bfloat16,
self.config.num_hidden_layers,
*self._get_mamba_cache_shape(),
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.lm_head.weight.dtype if hasattr(
self.lm_head, 'weight') else torch.bfloat16,
self.config.num_hidden_layers,
*self._get_mamba_cache_shape(),
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(
input_ids,
positions,

View File

@ -9,6 +9,7 @@ import torch
from torch import nn
from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -36,7 +37,7 @@ from vllm.utils import LayerBlockType
from .granitemoe import GraniteMoeMoE
from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
SupportsQuant)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@ -67,7 +68,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
activation=config.hidden_act,
quant_config=quant_config)
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.mamba_chunk_size)
self.block_sparse_moe = None
if getattr(config, "num_local_experts", 0) > 0:
@ -361,10 +364,15 @@ class GraniteMoeHybridModel(nn.Module):
) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.mamba_chunk_size,
attn_metadata=attn_metadata,
)
else:
# V1 gets mamba2_metadata from the forward context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@ -386,7 +394,9 @@ class GraniteMoeHybridModel(nn.Module):
num_attn += 1
layer_mamba_cache_params = None
if isinstance(layer, GraniteMoeHybridMambaDecoderLayer):
if isinstance(
layer,
GraniteMoeHybridMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_attn)
@ -501,8 +511,7 @@ class GraniteMoeHybridModel(nn.Module):
class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
SupportsPP, IsHybrid, SupportsV0Only,
SupportsQuant):
SupportsPP, IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -571,14 +580,20 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.model_config.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = (
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.model_config.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

View File

@ -23,6 +23,7 @@ from typing import Optional
import torch
from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -44,8 +45,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
SupportsLoRA, SupportsPP,
SupportsQuant,
SupportsV0Only)
SupportsQuant)
from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
MambaCacheParams)
from vllm.model_executor.models.utils import (
@ -153,6 +153,8 @@ class NemotronHMambaDecoderLayer(nn.Module):
rms_norm_eps=config.rms_norm_eps,
activation=config.mamba_hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -348,10 +350,14 @@ class NemotronHModel(nn.Module):
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# V1 gets mamba2_metadata from the forward context
mamba2_metadata = None
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@ -369,7 +375,8 @@ class NemotronHModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
layer_mamba_cache_params = None
if isinstance(layer, NemotronHMambaDecoderLayer):
if isinstance(layer,
NemotronHMambaDecoderLayer) and mamba_cache_params:
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
i - num_non_mamba_layers)
else:
@ -437,7 +444,7 @@ class NemotronHModel(nn.Module):
class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
IsHybrid, SupportsV0Only, SupportsQuant):
IsHybrid, SupportsQuant):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@ -499,15 +506,23 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
num_mamba_layers = self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config, LayerBlockType.mamba)
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = \
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
hidden_states = self.model(input_ids, positions, mamba_cache_params,
intermediate_tensors, inputs_embeds)

View File

@ -15,6 +15,7 @@ import torch
from torch import nn
from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
@ -41,7 +42,7 @@ from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import HasInnerState, IsHybrid, SupportsV0Only
from .interfaces import HasInnerState, IsHybrid
from .utils import AutoWeightsLoader, WeightsMapper, maybe_prefix
@ -58,6 +59,7 @@ class Zamba2LoRA(nn.Module):
rank: int,
output_dim: Union[int, list[int]],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
"""Initialize the attention layer.
@ -283,6 +285,7 @@ class Zamba2MLP(nn.Module):
bare_block_idx: int,
num_hybrid_layers: dict[int, int],
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
"""Initialize the MLP layer.
@ -471,11 +474,10 @@ class Zamba2MambaDecoderLayer(nn.Module):
computation depending on configuration.
"""
def __init__(
self,
config: Zamba2Config,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
def __init__(self,
config: Zamba2Config,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "") -> None:
"""Initialize the Mamba decoder layer.
Args:
@ -486,20 +488,21 @@ class Zamba2MambaDecoderLayer(nn.Module):
# Initialize Mamba mixer with expanded intermediate size
intermediate_size = config.mamba_expand * config.hidden_size
self.mamba = MambaMixer2(
hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=intermediate_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.add_bias_linear,
n_groups=config.mamba_ngroups,
num_heads=config.n_mamba_heads,
head_dim=intermediate_size // config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
quant_config=quant_config,
)
self.mamba = MambaMixer2(hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=intermediate_size,
use_conv_bias=config.use_conv_bias,
use_bias=config.add_bias_linear,
n_groups=config.mamba_ngroups,
num_heads=config.n_mamba_heads,
head_dim=intermediate_size //
config.n_mamba_heads,
rms_norm_eps=config.rms_norm_eps,
activation="silu",
quant_config=quant_config,
prefix=f"{prefix}.mixer",
chunk_size=config.chunk_size)
# Input normalization
self.input_layernorm = RMSNorm(config.hidden_size,
@ -573,6 +576,7 @@ class Zamba2HybridLayer(nn.Module):
config: Zamba2Config,
block_idx: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
"""Initialize the hybrid layer.
@ -589,7 +593,8 @@ class Zamba2HybridLayer(nn.Module):
bias=False,
quant_config=quant_config)
self.mamba_decoder = Zamba2MambaDecoderLayer(config,
quant_config=quant_config)
quant_config=quant_config,
prefix=prefix)
def forward(
self,
@ -699,14 +704,23 @@ class Zamba2Model(nn.Module):
# Initialize layers according to block type configuration
layers = []
for layer_idx, layer_type in enumerate(config.layers_block_type):
# tdoublep: avoid layers getting the same index;
# somewhat hacky, but correct (I think)
prefix = str(len(layer2block_map) + layer_idx)
if layer_type == "hybrid":
block = next(blocks)
block_idx = layer2block_map[layer_idx]
layers.append(
Zamba2HybridLayer(block, config, block_idx, quant_config))
Zamba2HybridLayer(block,
config,
block_idx,
quant_config,
prefix=prefix))
else:
layers.append(
Zamba2MambaDecoderLayer(config, quant_config=quant_config))
Zamba2MambaDecoderLayer(config,
quant_config=quant_config,
prefix=prefix))
self.layers = nn.ModuleList(layers)
# Final layer normalization
@ -751,19 +765,30 @@ class Zamba2Model(nn.Module):
attn_metadata = get_forward_context().attn_metadata
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
if not envs.VLLM_USE_V1:
mamba2_metadata = prepare_mamba2_metadata(
chunk_size=self.config.chunk_size,
attn_metadata=attn_metadata,
)
else:
# V1 gets mamba2_metadata from the forward context
mamba2_metadata = None
# Process through layers
original_hidden_states = torch.clone(hidden_states)
for layer_idx, layer in enumerate(self.layers):
layer_mamba_cache_params = None
if (isinstance(layer, (Zamba2HybridLayer, Zamba2MambaDecoderLayer))
and mamba_cache_params):
layer_mamba_cache_params = mamba_cache_params.at_layer_idx(
layer_idx)
layer_outputs = layer(
hidden_states,
original_hidden_states=original_hidden_states,
positions=positions,
mamba_cache_params=mamba_cache_params.at_layer_idx(layer_idx),
mamba_cache_params=layer_mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
hidden_states = layer_outputs
@ -803,7 +828,7 @@ class Zamba2Model(nn.Module):
return loaded_params
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"""Zamba2 model with causal language modeling head.
This class wraps the core Zamba2 model and adds:
@ -897,14 +922,16 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid, SupportsV0Only):
Output hidden states
"""
# Initialize Mamba cache if needed
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype, num_mamba_layers,
*self._get_mamba_cache_shape())
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
# Forward pass through model
hidden_states = self.model(

View File

@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
@ -268,9 +269,13 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
assert self.other_block_size % self.full_attention_block_size == 0, (
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")
if self.enable_caching:
# this requirement is only needed for the prefix caching logic
divisible = self.other_block_size % self.full_attention_block_size
assert divisible == 0, (
"KVCacheCoordinator assumes the block_size of full "
"attention layers is divisible by other layers now.")
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True

View File

@ -84,12 +84,15 @@ class KVCacheManager:
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
self.block_size: Optional[int] = None
if self.enable_caching:
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
@ -154,6 +157,7 @@ class KVCacheManager:
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
assert self.block_size is not None
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes

View File

@ -864,9 +864,11 @@ def _get_kv_cache_config_uniform_page_size(
kv_cache_groups=kv_cache_groups,
)
min_block_size = min(
[group.kv_cache_spec.block_size for group in kv_cache_groups])
# Print the KV cache size and maximum concurrency.
num_tokens = num_blocks // len(
grouped_layers) * vllm_config.cache_config.block_size
num_tokens = num_blocks // len(grouped_layers) * min_block_size
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
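Since KV cache groups may now use different block sizes (the mamba group's block size is set to max_model_len later in this commit), the logged capacity is derived from the smallest block size across groups rather than from cache_config.block_size. A worked example with assumed numbers:

# Assumed numbers for illustration only.
num_blocks = 1_000         # blocks in the shared KV cache pool
num_layer_groups = 2       # e.g. one full-attention group and one mamba group
min_block_size = 528       # attention block size; the mamba group uses max_model_len
num_tokens = num_blocks // num_layer_groups * min_block_size
print(f"GPU KV cache size: {num_tokens:,} tokens")   # -> 264,000 tokens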

View File

@ -159,6 +159,7 @@ class SlidingWindowSpec(AttentionSpec):
class MambaSpec(KVCacheSpec):
shapes: tuple[tuple[int, ...], ...]
dtype: torch.dtype
page_size_padded: Optional[int] = None
def __post_init__(self):
self.num_elements = sum(prod(shape) for shape in self.shapes)
@ -169,7 +170,11 @@ class MambaSpec(KVCacheSpec):
@property
def page_size_bytes(self) -> int:
return self.num_elements * get_dtype_size(self.dtype)
page_size = self.num_elements * get_dtype_size(self.dtype)
if self.page_size_padded is not None:
assert self.page_size_padded >= page_size
return self.page_size_padded
return page_size
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
# We allocate 1 block for each request now, so max_memory_usage_bytes is

View File

@ -334,6 +334,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# the same order of requests. We ensure this by only allowing the first
# group to reorder the batch and asserting that all other groups do not
# reorder the batch.
# TODO(tdoublep): make this more flexible so that any group can
# re-order the batch (not only the first).
# TODO(tdoublep): verify this during engine init instead of at runtime
for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
batch_reordered = self.attn_metadata_builders[i].reorder_batch(
self.input_batch, scheduler_output)
@ -2449,6 +2452,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
corresponding memory buffer for KV cache.
"""
kv_caches: dict[str, torch.Tensor] = {}
has_attn, has_mamba = False, False
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
@ -2458,6 +2462,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
has_attn = True
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
@ -2486,25 +2491,67 @@ class GPUModelRunner(LoRAModelRunnerMixin):
layer_name].view(dtype).view(kv_cache_shape).permute(
*inv_order)
elif isinstance(kv_cache_spec, MambaSpec):
has_mamba = True
raw_tensor = kv_cache_raw_tensors[layer_name]
dtype = kv_cache_spec.dtype
num_element_per_page = (kv_cache_spec.page_size_bytes //
get_dtype_size(dtype))
state_tensors = []
start_pos = 0
storage_offset = 0
for shape in kv_cache_spec.shapes:
target_shape = (num_blocks, *shape)
size_in_bytes = np.prod(shape) * get_dtype_size(
dtype) * num_blocks
tensor = raw_tensor[start_pos:start_pos +
size_in_bytes]
tensor = tensor.view(dtype).view(target_shape)
stride = torch.empty(target_shape).stride()
target_stride = (num_element_per_page, *stride[1:])
tensor = torch.as_strided(
raw_tensor.view(dtype),
size=target_shape,
stride=target_stride,
storage_offset=storage_offset,
)
state_tensors.append(tensor)
start_pos += size_in_bytes
assert start_pos == raw_tensor.numel()
kv_caches[layer_name] = tuple(state_tensors)
storage_offset += stride[0]
kv_caches[layer_name] = state_tensors
else:
raise NotImplementedError
if has_attn and has_mamba:
self._verify_hybrid_attention_mamba_layout(kv_cache_config,
kv_cache_raw_tensors)
return kv_caches
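The strided views built above place block b of every mamba state tensor inside page b of the layer's flat buffer: the block dimension strides by the page size in elements, and storage_offset accumulates each state's footprint within the page. A toy reproduction of the same trick with made-up shapes, independent of vLLM:

import torch
from math import prod

# Made-up per-block state shapes, e.g. a conv state and an ssm state.
shapes = [(2, 3), (5,)]
num_blocks = 4
elems_per_page = sum(prod(s) for s in shapes)   # 11 elements per page

raw = torch.arange(num_blocks * elems_per_page, dtype=torch.float32)

views, offset = [], 0
for shape in shapes:
    target_shape = (num_blocks, *shape)
    # Contiguous strides for the target shape; only the block stride changes.
    stride = torch.empty(target_shape).stride()
    target_stride = (elems_per_page, *stride[1:])
    views.append(
        torch.as_strided(raw, size=target_shape, stride=target_stride,
                         storage_offset=offset))
    offset += stride[0]   # skip past this state's region within the page

# Block 1 of both views aliases disjoint slices of page 1 of `raw`.
assert views[0][1].flatten().tolist() == list(range(11, 17))
assert views[1][1].tolist() == list(range(17, 22))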
def _verify_hybrid_attention_mamba_layout(
self, kv_cache_config: KVCacheConfig,
kv_cache_raw_tensors: dict[str, torch.Tensor]) -> None:
"""
Verify that the KV cache memory layout is compatible for
models with both attention and mamba KV cache groups.
Args:
kv_cache_config: The KV cache config
kv_cache_raw_tensors: The KV cache buffer of each layer.
"""
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
for layer_name in kv_cache_group_spec.layer_names:
raw_tensor = kv_cache_raw_tensors[layer_name]
num_blocks = (raw_tensor.numel() //
kv_cache_spec.page_size_bytes)
if isinstance(kv_cache_spec, AttentionSpec):
kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
if kv_cache_shape[0] != num_blocks or kv_cache_shape[
1] != 2:
raise ValueError(
"Hybrid models in V1 require an attention "
"backend with kv_cache_shape="
"(num_blocks, 2, ...). Please try setting "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
def initialize_kv_cache_tensors(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
"""
@ -2623,11 +2670,69 @@ class GPUModelRunner(LoRAModelRunnerMixin):
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len
page_size_padded = self._maybe_pad_mamba_page_size(
attn_layers, mamba_layers, kv_cache_spec, max_model_len,
block_size)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
for layer_name, mamba_module in mamba_layers.items():
kv_cache_spec[layer_name] = MambaSpec(
shapes=mamba_module.get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len)
block_size=max_model_len,
page_size_padded=page_size_padded)
return kv_cache_spec
def _maybe_pad_mamba_page_size(
self,
attn_layers: dict[str, Attention],
mamba_layers: dict[str, MambaMixer2],
kv_cache_spec: dict[str, KVCacheSpec],
max_model_len: int,
block_size: int,
) -> Optional[int]:
"""
Ensure that the page size of the attention KV cache groups is greater
than or equal to that of the mamba KV cache groups. If not, we suggest
to the user how to set the attention block size so that it is.
If the attention page size is strictly greater than the mamba page size,
we pad the mamba page size to make them equal.
Args:
attn_layers: Attention layers
mamba_layers: Mamba layers
kv_cache_spec: KV cache spec (populated with attention layers)
max_model_len: Maximum model (context) length in tokens
block_size: Attention block size in tokens
Returns:
Optional[int]: Mamba page size with padding (None if no padding).
"""
if len(attn_layers) == 0:
return None
attn_layer_name = next(iter(attn_layers))
attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
mamba_layer_name = next(iter(mamba_layers))
mamba_page_size = MambaSpec(
shapes=mamba_layers[mamba_layer_name].get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len).page_size_bytes
if attn_page_size < mamba_page_size:
# attention page size (for 16 tokens)
attn_page_size_16 = 16 * attn_page_size // block_size
# Some attention backends (e.g. FlashAttention) only support block
# sizes that are a multiple of 16, so suggest a value that would work
# (note: FlashAttention is currently not compatible with mamba layers;
# use FlashInfer instead).
suggest_attn_block_size = 16 * cdiv(mamba_page_size,
attn_page_size_16)
raise ValueError(
"Attention block size should be increased to at least "
f"{suggest_attn_block_size} in order to match "
"the mamba page size")
return attn_page_size
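For reference, the suggestion above is plain arithmetic: rescale the attention page to a 16-token granularity, then take the smallest multiple of 16 whose page is at least as large as the mamba state per sequence. A sketch with assumed sizes, not measured from any real model:

# Assumed sizes in bytes, for illustration only.
attn_page_size = 45_056      # attention KV page for `block_size` tokens
block_size = 16              # current attention block size (tokens)
mamba_page_size = 270_336    # conv + ssm state footprint per sequence

def cdiv(a: int, b: int) -> int:
    return -(-a // b)        # ceiling division, as used in the model runner

attn_page_size_16 = 16 * attn_page_size // block_size
suggest_attn_block_size = 16 * cdiv(mamba_page_size, attn_page_size_16)
print(suggest_attn_block_size)   # -> 96

# Re-running with block_size=96 makes the attention page at least as large as
# the mamba page, which is then padded (via page_size_padded) up to the
# attention page size so both groups share a uniform page size.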