[V1] [Hybrid] Refactor mamba state shape calculation; enable V1 via cli (#20840)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2025-07-15 13:04:35 +02:00 committed by GitHub
parent c586b55667
commit 3534c39a20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 441 additions and 353 deletions

View File

@ -112,8 +112,7 @@ enforcing eager mode and disabling prefix caching in V1.
Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible
using the `vllm serve` CLI yet).
backend in V1.
#### Encoder-Decoder Models

View File

@ -61,14 +61,6 @@ V1_SUPPORTED_MODELS = [
"tiiuae/Falcon-H1-0.5B-Base",
]
ATTN_BLOCK_SIZES = {
"ibm-ai-platform/Bamba-9B-v1": 528,
"Zyphra/Zamba2-1.2B-instruct": 80,
"nvidia/Nemotron-H-8B-Base-8K": 528,
"ibm-granite/granite-4.0-tiny-preview": 400,
"tiiuae/Falcon-H1-0.5B-Base": 800,
}
# Avoid OOM
MAX_NUM_SEQS = 4
@ -105,11 +97,6 @@ def test_models(
example_prompts, max_tokens, num_logprobs)
if model in V1_SUPPORTED_MODELS:
if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES:
block_size = ATTN_BLOCK_SIZES[model]
else:
block_size = 16
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
if model in HYBRID_MODELS:
@ -118,8 +105,7 @@ def test_models(
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False,
block_size=block_size) as vllm_model:
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
else:

View File

@ -1630,6 +1630,9 @@ class CacheConfig:
checkpoint if available. Otherwise, the scales will default to 1.0."""
cpu_kvcache_space_bytes: Optional[int] = None
"""(CPU backend only) CPU key-value cache space."""
mamba_page_size_padded: Optional[int] = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
@ -4882,11 +4885,15 @@ class VllmConfig:
if architecture is None:
return
from vllm.model_executor.models.config import MODELS_CONFIG_MAP
from vllm.model_executor.models.config import (
MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig)
cls = MODELS_CONFIG_MAP.get(architecture, None)
if cls is not None:
cls.verify_and_update_config(self)
if self.model_config.is_hybrid:
HybridAttentionMambaModelConfig.verify_and_update_config(self)
if self.model_config.task == "classify":
# Maybe convert ForCausalLM into ForSequenceClassification model.
from vllm.model_executor.models.adapters import (

View File

@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
update_metadata)
from vllm.model_executor.layers.mamba.mamba_utils import (
extra_groups_for_head_shards, get_mamba_state_shape)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
@ -146,18 +148,6 @@ class Mixer2RMSNormGated(CustomOp):
return out
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngoups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
def mamba_v2_sharded_weight_loader(
shard_spec: list[tuple[int, int, float]],
tp_size: int,
@ -707,30 +697,12 @@ class MambaMixer2(MambaBase, CustomOp):
return out
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.n_groups +
extra_groups_for_head_shards(self.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (self.intermediate_size +
2 * n_groups * self.ssm_state_size)
# contiguous along 'dim' axis
conv_state_shape = (
self.conv_kernel_size - 1,
divide(conv_dim, world_size),
return get_mamba_state_shape(
intermediate_size=self.intermediate_size,
tp_world_size=get_tensor_model_parallel_world_size(),
n_groups=self.n_groups,
num_heads=self.num_heads,
head_dim=self.head_dim,
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.num_heads, world_size),
self.head_dim,
self.ssm_state_size,
)
return conv_state_shape, temporal_state_shape

View File

@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.distributed import divide
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngoups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
# for n_groups == 1, this is exactly tp_size - n_groups
return tp_size - ngroups
def get_mamba_state_shape(
intermediate_size: int,
tp_world_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
""" Get the shape of mamba state."""
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (n_groups +
extra_groups_for_head_shards(n_groups, tp_world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size + 2 * n_groups * state_size)
# contiguous along 'dim' axis
conv_state_shape = (
conv_kernel - 1,
divide(conv_dim, tp_world_size),
)
if not use_v1:
conv_state_shape = (conv_state_shape[1], conv_state_shape[0])
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (
divide(num_heads, tp_world_size),
head_dim,
state_size,
)
return conv_state_shape, temporal_state_shape

View File

@ -12,7 +12,7 @@ from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
@ -23,8 +23,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -435,6 +435,38 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
@ -491,10 +523,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -510,38 +545,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape, temporal_state_shape = None, None
intermediate_size = self.config.mamba_expand * hidden_size
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
self.config.mamba_n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.mamba_d_state)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.mamba_n_heads, world_size),
self.config.mamba_d_head,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -3,9 +3,14 @@
from copy import deepcopy
from typing import TYPE_CHECKING
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@ -200,6 +205,91 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
}
class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Ensure that page size of attention layers is greater than or
equal to the mamba layers. If not, automatically set the attention
block size to ensure that it is. If the attention page size is
strictly greater than the mamba page size, we pad the mamba page size
to make them equal.
Args:
vllm_config: vLLM Config
"""
if not envs.VLLM_USE_V1:
return
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
if cache_config.cache_dtype == "auto":
kv_cache_dtype = model_config.dtype
else:
kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# get attention page size (for 1 token)
attn_page_size_1_token = FullAttentionSpec(
block_size=1,
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
use_mla=model_config.use_mla).page_size_bytes
model_cls = ModelRegistry.resolve_model_cls(
model_config._model_info.architecture)[0]
# get mamba page size
mamba_page_size = MambaSpec(
shapes=model_cls.get_mamba_state_shape_from_config(vllm_config),
dtype=kv_cache_dtype,
block_size=model_config.max_model_len,
).page_size_bytes
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
attn_block_size = 16 * cdiv(mamba_page_size,
16 * attn_page_size_1_token)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if (cache_config.block_size is None
or cache_config.block_size < attn_block_size):
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "
"to ensure that attention page size is >= mamba page size.",
attn_block_size)
# compute new attention page size
attn_page_size = \
cache_config.block_size * attn_page_size_1_token
assert attn_page_size >= mamba_page_size
if attn_page_size == mamba_page_size:
# don't need to pad mamba page size
return
# pad mamba page size to exactly match attention
if (cache_config.mamba_page_size_padded is None
or cache_config.mamba_page_size_padded != attn_page_size):
cache_config.mamba_page_size_padded = (attn_page_size)
mamba_padding_pct = 100 * (attn_page_size -
mamba_page_size) / mamba_page_size
logger.info(
"Padding mamba page size by %.2f%% to ensure "
"that mamba page size and attention page size are "
"exactly equal.", mamba_padding_pct)
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"GteModel": SnowflakeGteNewModelConfig,
"GteNewModel": GteNewModelConfig,

View File

@ -11,7 +11,7 @@ from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
@ -22,8 +22,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -514,6 +514,42 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = (int(hf_config.mamba_expand *
hf_config.hidden_size)
if hf_config.mamba_d_ssm is None else
hf_config.mamba_d_ssm)
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
@ -580,12 +616,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
mamba_cache_params = None
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(
self.vllm_config,
self.lm_head.weight.dtype if hasattr(
self.lm_head, 'weight') else torch.bfloat16,
self.config.num_hidden_layers,
*self._get_mamba_cache_shape(),
*mamba_state_shape,
)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -606,39 +645,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape, temporal_state_shape = None, None
intermediate_size = (int(self.config.mamba_expand *
hidden_size) if self.config.mamba_d_ssm
is None else self.config.mamba_d_ssm)
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards(
self.config.mamba_n_groups, world_size)
# - heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state
conv_state_shape = (
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.mamba_n_heads, world_size),
self.config.mamba_d_head,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
@ -21,8 +21,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -524,6 +524,38 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_n_groups,
num_heads=hf_config.mamba_n_heads,
head_dim=hf_config.mamba_d_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@ -587,9 +619,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.model_config.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.model_config.dtype,
num_mamba_layers,
*mamba_state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -605,38 +641,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape, temporal_state_shape = None, None
intermediate_size = self.config.mamba_expand * hidden_size
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards(
self.config.mamba_n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.mamba_d_state)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.mamba_n_heads, world_size),
self.config.mamba_d_head,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -22,6 +22,7 @@ from .interfaces_base import is_pooling_model
if TYPE_CHECKING:
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors
@ -481,6 +482,25 @@ class IsHybrid(Protocol):
, also indicates that the model's hf_config has
'layers_block_type' """
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
...
@runtime_checkable
class _IsHybridType(Protocol):

View File

@ -11,15 +11,14 @@ from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -198,6 +197,38 @@ class Mamba2Model(nn.Module):
class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.n_groups,
num_heads=hf_config.num_heads,
head_dim=hf_config.head_dim,
state_size=hf_config.state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
@ -253,9 +284,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
self.model_config.get_num_layers_by_block_type(
self.vllm_config.parallel_config,
LayerBlockType.mamba))
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
else:
@ -274,39 +309,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree):
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
conv_state_shape, temporal_state_shape = None, None
intermediate_size = getattr(
self.config, "intermediate_size",
self.config.expand * self.config.hidden_size)
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (
self.config.n_groups +
extra_groups_for_head_shards(self.config.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.conv_kernel - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.num_heads, world_size),
self.config.head_dim,
self.config.state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,

View File

@ -26,7 +26,7 @@ from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import ReLUSquaredActivation
@ -37,8 +37,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
@ -459,6 +459,38 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
}
embedding_padding_modules = ["lm_head"]
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.n_groups,
num_heads=hf_config.mamba_num_heads,
head_dim=hf_config.mamba_head_dim,
state_size=hf_config.ssm_state_size,
conv_kernel=hf_config.conv_kernel,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
self.vllm_config = vllm_config
@ -515,10 +547,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
self.vllm_config.parallel_config,
LayerBlockType.mamba
)
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -534,39 +569,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
world_size = get_tensor_model_parallel_world_size()
hidden_size = self.config.hidden_size
conv_state_shape, temporal_state_shape = None, None
intermediate_size = self.config.expand * hidden_size
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (
self.config.n_groups +
extra_groups_for_head_shards(self.config.n_groups, world_size))
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.ssm_state_size)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.conv_kernel - 1,
)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(self.config.mamba_num_heads, world_size),
self.config.mamba_head_dim,
self.config.ssm_state_size,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -18,7 +18,7 @@ from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
@ -30,8 +30,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
from vllm.model_executor.layers.mamba.mamba_mixer2 import (
MambaMixer2, extra_groups_for_head_shards)
from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2
from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -843,6 +843,39 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"1.weight": "B.weight",
})
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
use_v1: bool = True,
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
use_v1: Get shapes for V1 (or V0)
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
return get_mamba_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=hf_config.mamba_ngroups,
num_heads=hf_config.n_mamba_heads,
head_dim=hf_config.mamba_headdim,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
use_v1=use_v1,
)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
"""Initialize the Zamba2 model for causal language modeling.
@ -925,9 +958,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
if not envs.VLLM_USE_V1:
if self.mamba_cache is None:
num_mamba_layers = self.config.num_hidden_layers
self.mamba_cache = MambaCacheManager(
self.vllm_config, self.lm_head.weight.dtype,
num_mamba_layers, *self._get_mamba_cache_shape())
mamba_state_shape = \
self.get_mamba_state_shape_from_config(
self.vllm_config, use_v1=False)
self.mamba_cache = MambaCacheManager(self.vllm_config,
self.lm_head.weight.dtype,
num_mamba_layers,
*mamba_state_shape)
# Get cache parameters for current run
mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs)
@ -968,49 +1005,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid):
"""
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def _get_mamba_cache_shape(
self) -> tuple[tuple[int, int], tuple[int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
world_size = get_tensor_model_parallel_world_size()
intermediate_size = self.config.mamba_expand * self.config.hidden_size
# Extend groups if needed to ensure all groups needed by a head
# are sharded together
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards(
self.config.mamba_ngroups, world_size))
# Calculate conv state shape (includes groups)
# - heads and n_groups are TP-ed
conv_dim = (intermediate_size +
2 * n_groups * self.config.mamba_d_state)
conv_state_shape = (
divide(conv_dim, world_size),
self.config.mamba_d_conv - 1,
)
# Calculate temporal state shape (per-head states)
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, d_head, d_state) = (128, 64, 128)
temporal_state_shape = (
divide(divide(intermediate_size, self.config.mamba_headdim),
world_size),
self.config.mamba_headdim,
self.config.mamba_d_state,
)
return conv_state_shape, temporal_state_shape
def compute_logits(
self,
hidden_states: torch.Tensor,

View File

@ -42,7 +42,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes, LazyLoader, async_tensor_h2d, cdiv,
GiB_bytes, LazyLoader, async_tensor_h2d,
check_use_alibi, get_dtype_size,
is_pin_memory_available, round_up)
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend
@ -2648,9 +2648,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
"Prefix caching is not supported for Mamba yet.")
max_model_len = self.vllm_config.model_config.max_model_len
page_size_padded = self._maybe_pad_mamba_page_size(
attn_layers, mamba_layers, kv_cache_spec, max_model_len,
block_size)
page_size_padded = (
self.vllm_config.cache_config.mamba_page_size_padded)
# Set block_size to max_model_len, so that mamba model will always
# have only one block in the KV cache.
@ -2662,54 +2661,3 @@ class GPUModelRunner(LoRAModelRunnerMixin):
page_size_padded=page_size_padded)
return kv_cache_spec
def _maybe_pad_mamba_page_size(
self,
attn_layers: dict[str, Attention],
mamba_layers: dict[str, MambaBase],
kv_cache_spec: dict[str, KVCacheSpec],
max_model_len: int,
block_size: int,
) -> Optional[int]:
"""
Ensure that page size of attention KV cache groups is greater than or
equal to the mamba KV cache groups. If not, we suggest to the user
how to set the attention block size to ensure that it is.
If the attention page size is strictly greater than the mamba page size,
we pad the mamba page size to make them equal.
Args:
attn_layers: Attention layers
mamba_layers: Mamba layers
kv_cache_spec: KV cache spec (populated with attention layers)
Returns:
Optional[int]: Mamba page size with padding (None if no padding).
"""
if len(attn_layers) == 0:
return None
attn_layer_name = next(iter(attn_layers))
attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes
mamba_layer_name = next(iter(mamba_layers))
mamba_page_size = MambaSpec(
shapes=mamba_layers[mamba_layer_name].get_state_shape(),
dtype=self.kv_cache_dtype,
block_size=max_model_len).page_size_bytes
if attn_page_size < mamba_page_size:
# attention page size (for 16 tokens)
attn_page_size_16 = 16 * attn_page_size // block_size
# some attention backends (e.g. FA) only support setting
# block size to multiple of 16, so let's suggest a value
# that would work (note: FA is currently not compatible
# with mamba layers, use FlashInfer instead).
suggest_attn_block_size = 16 * cdiv(mamba_page_size,
attn_page_size_16)
raise ValueError(
"Attention block size should be increased to at least "
f"{suggest_attn_block_size} in order to match "
"the mamba page size")
return attn_page_size