diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 459ea2d676c1..d7634223542d 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -112,8 +112,7 @@ enforcing eager mode and disabling prefix caching in V1. Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`, `Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention -backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible -using the `vllm serve` CLI yet). +backend in V1. #### Encoder-Decoder Models diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py index ecaae3ec1fc4..eba14e64553e 100644 --- a/tests/models/language/generation/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -61,14 +61,6 @@ V1_SUPPORTED_MODELS = [ "tiiuae/Falcon-H1-0.5B-Base", ] -ATTN_BLOCK_SIZES = { - "ibm-ai-platform/Bamba-9B-v1": 528, - "Zyphra/Zamba2-1.2B-instruct": 80, - "nvidia/Nemotron-H-8B-Base-8K": 528, - "ibm-granite/granite-4.0-tiny-preview": 400, - "tiiuae/Falcon-H1-0.5B-Base": 800, -} - # Avoid OOM MAX_NUM_SEQS = 4 @@ -105,11 +97,6 @@ def test_models( example_prompts, max_tokens, num_logprobs) if model in V1_SUPPORTED_MODELS: - if model in HYBRID_MODELS and model in ATTN_BLOCK_SIZES: - block_size = ATTN_BLOCK_SIZES[model] - else: - block_size = 16 - with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") if model in HYBRID_MODELS: @@ -118,8 +105,7 @@ def test_models( with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS, enforce_eager=True, - enable_prefix_caching=False, - block_size=block_size) as vllm_model: + enable_prefix_caching=False) as vllm_model: vllm_v1_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) else: diff --git a/vllm/config.py b/vllm/config.py index 42410006f60d..2d84f6875cd9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1630,6 +1630,9 @@ class CacheConfig: checkpoint if available. Otherwise, the scales will default to 1.0.""" cpu_kvcache_space_bytes: Optional[int] = None """(CPU backend only) CPU key-value cache space.""" + mamba_page_size_padded: Optional[int] = None + """ Optional override for mamba page size; used by hybrid mamba/attention + models to ensure exact alignment with attention page size.""" # Will be set after profiling. num_gpu_blocks: Optional[int] = field(default=None, init=False) @@ -4882,11 +4885,15 @@ class VllmConfig: if architecture is None: return - from vllm.model_executor.models.config import MODELS_CONFIG_MAP + from vllm.model_executor.models.config import ( + MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) cls = MODELS_CONFIG_MAP.get(architecture, None) if cls is not None: cls.verify_and_update_config(self) + if self.model_config.is_hybrid: + HybridAttentionMambaModelConfig.verify_and_update_config(self) + if self.model_config.task == "classify": # Maybe convert ForCausalLM into ForSequenceClassification model. 
from vllm.model_executor.models.adapters import ( diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 4ca8e6b97fce..a88bd55e2367 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -20,6 +20,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata, update_metadata) +from vllm.model_executor.layers.mamba.mamba_utils import ( + extra_groups_for_head_shards, get_mamba_state_shape) from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( @@ -146,18 +148,6 @@ class Mixer2RMSNormGated(CustomOp): return out -def extra_groups_for_head_shards(ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" - - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 - - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups - - def mamba_v2_sharded_weight_loader( shard_spec: list[tuple[int, int, float]], tp_size: int, @@ -707,30 +697,12 @@ class MambaMixer2(MambaBase, CustomOp): return out def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]: - world_size = get_tensor_model_parallel_world_size() - - conv_state_shape, temporal_state_shape = None, None - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.n_groups + - extra_groups_for_head_shards(self.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (self.intermediate_size + - 2 * n_groups * self.ssm_state_size) - # contiguous along 'dim' axis - conv_state_shape = ( - self.conv_kernel_size - 1, - divide(conv_dim, world_size), + return get_mamba_state_shape( + intermediate_size=self.intermediate_size, + tp_world_size=get_tensor_model_parallel_world_size(), + n_groups=self.n_groups, + num_heads=self.num_heads, + head_dim=self.head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel_size, ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.num_heads, world_size), - self.head_dim, - self.ssm_state_size, - ) - return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/layers/mamba/mamba_utils.py b/vllm/model_executor/layers/mamba/mamba_utils.py new file mode 100644 index 000000000000..99a582066c0d --- /dev/null +++ b/vllm/model_executor/layers/mamba/mamba_utils.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from vllm.distributed import divide + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups + + +def get_mamba_state_shape( + intermediate_size: int, + tp_world_size: int, + n_groups: int, + num_heads: int, + 
head_dim: int, + state_size: int, + conv_kernel: int, + use_v1: bool = True, +) -> tuple[tuple[int, int], tuple[int, int, int]]: + """ Get the shape of mamba state.""" + + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = (n_groups + + extra_groups_for_head_shards(n_groups, tp_world_size)) + + # - heads and n_groups are TP-ed + conv_dim = (intermediate_size + 2 * n_groups * state_size) + # contiguous along 'dim' axis + conv_state_shape = ( + conv_kernel - 1, + divide(conv_dim, tp_world_size), + ) + + if not use_v1: + conv_state_shape = (conv_state_shape[1], conv_state_shape[0]) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) + temporal_state_shape = ( + divide(num_heads, tp_world_size), + head_dim, + state_size, + ) + + return conv_state_shape, temporal_state_shape diff --git a/vllm/model_executor/models/bamba.py b/vllm/model_executor/models/bamba.py index dfc55b0c341b..e93d4294a62c 100644 --- a/vllm/model_executor/models/bamba.py +++ b/vllm/model_executor/models/bamba.py @@ -12,7 +12,7 @@ from transformers import BambaConfig from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -23,8 +23,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -435,6 +435,38 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config @@ -491,10 +523,13 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.vllm_config.parallel_config, LayerBlockType.mamba ) - - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -510,38 +545,6 @@ class BambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.mamba_expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index 6d0ffad1a819..6c6f8e7268b6 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -3,9 +3,14 @@ from copy import deepcopy from typing import TYPE_CHECKING +import vllm.envs as envs from vllm.logger import init_logger +from vllm.model_executor.models import ModelRegistry +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv +from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec if TYPE_CHECKING: + from vllm.config import VllmConfig logger = init_logger(__name__) @@ -200,6 +205,91 @@ class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig): } +class 
HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Ensure that page size of attention layers is greater than or + equal to the mamba layers. If not, automatically set the attention + block size to ensure that it is. If the attention page size is + strictly greater than the mamba page size, we pad the mamba page size + to make them equal. + + Args: + vllm_config: vLLM Config + """ + + if not envs.VLLM_USE_V1: + return + + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # get attention page size (for 1 token) + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + use_mla=model_config.use_mla).page_size_bytes + + model_cls = ModelRegistry.resolve_model_cls( + model_config._model_info.architecture)[0] + + # get mamba page size + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtype=kv_cache_dtype, + block_size=model_config.max_model_len, + ).page_size_bytes + + # some attention backends (e.g. FA) only support setting + # block size to multiple of 16, so let's suggest a value + # that would work (note: FA is currently not compatible + # with mamba layers, use FlashInfer instead). + attn_block_size = 16 * cdiv(mamba_page_size, + 16 * attn_page_size_1_token) + + # override attention block size if either (a) the + # user has not set it or (b) the user has set it + # too small. 
+ if (cache_config.block_size is None + or cache_config.block_size < attn_block_size): + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size) + + # compute new attention page size + attn_page_size = \ + cache_config.block_size * attn_page_size_1_token + + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + # don't need to pad mamba page size + return + + # pad mamba page size to exactly match attention + if (cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size): + cache_config.mamba_page_size_padded = (attn_page_size) + mamba_padding_pct = 100 * (attn_page_size - + mamba_page_size) / mamba_page_size + logger.info( + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", mamba_padding_pct) + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, diff --git a/vllm/model_executor/models/falcon_h1.py b/vllm/model_executor/models/falcon_h1.py index ad3f39793b65..7761de224c9d 100644 --- a/vllm/model_executor/models/falcon_h1.py +++ b/vllm/model_executor/models/falcon_h1.py @@ -11,7 +11,7 @@ from transformers import FalconH1Config from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul @@ -22,8 +22,8 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -514,6 +514,42 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + + intermediate_size = (int(hf_config.mamba_expand * + hf_config.hidden_size) + if hf_config.mamba_d_ssm is None else + hf_config.mamba_d_ssm) + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config @@ -580,12 +616,15 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, mamba_cache_params = None if not envs.VLLM_USE_V1: if self.mamba_cache is None: + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) self.mamba_cache = MambaCacheManager( self.vllm_config, self.lm_head.weight.dtype if hasattr( self.lm_head, 'weight') else torch.bfloat16, self.config.num_hidden_layers, - *self._get_mamba_cache_shape(), + *mamba_state_shape, ) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -606,39 +645,6 @@ class FalconH1ForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = (int(self.config.mamba_expand * - hidden_size) if self.config.mamba_d_ssm - is None else self.config.mamba_d_ssm) - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = self.config.mamba_n_groups + extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size) - - # - heads and n_groups are TP-ed - conv_dim = intermediate_size + 2 * n_groups * self.config.mamba_d_state - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/granitemoehybrid.py b/vllm/model_executor/models/granitemoehybrid.py index 1055fa0372b1..1c93e90737ad 100644 --- a/vllm/model_executor/models/granitemoehybrid.py +++ b/vllm/model_executor/models/granitemoehybrid.py @@ -12,7 +12,7 @@ from transformers import GraniteMoeHybridConfig from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from 
vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm @@ -21,8 +21,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -524,6 +524,38 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_n_groups, + num_heads=hf_config.mamba_n_heads, + head_dim=hf_config.mamba_d_head, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -587,9 +619,13 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.model_config.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.model_config.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -605,38 +641,6 @@ class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.mamba_expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_n_groups + extra_groups_for_head_shards( - self.config.mamba_n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 
2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_n_heads, world_size), - self.config.mamba_d_head, - self.config.mamba_d_state, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 3a97641aa2f2..95970474d554 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -22,6 +22,7 @@ from .interfaces_base import is_pooling_model if TYPE_CHECKING: from vllm.attention import AttentionMetadata + from vllm.config import VllmConfig from vllm.model_executor.models.utils import WeightsMapper from vllm.sequence import IntermediateTensors @@ -481,6 +482,25 @@ class IsHybrid(Protocol): , also indicates that the model's hf_config has 'layers_block_type' """ + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + ... + @runtime_checkable class _IsHybridType(Protocol): diff --git a/vllm/model_executor/models/mamba2.py b/vllm/model_executor/models/mamba2.py index b9fa57073935..d812d8cc0a39 100644 --- a/vllm/model_executor/models/mamba2.py +++ b/vllm/model_executor/models/mamba2.py @@ -11,15 +11,14 @@ from transformers import MambaConfig from vllm import envs from vllm.attention.backends.abstract import AttentionMetadata from vllm.config import VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -198,6 +197,38 @@ class Mamba2Model(nn.Module): class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. 
+ + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.n_groups, + num_heads=hf_config.num_heads, + head_dim=hf_config.head_dim, + state_size=hf_config.state_size, + conv_kernel=hf_config.conv_kernel, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config cache_config = vllm_config.cache_config @@ -253,9 +284,13 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): self.model_config.get_num_layers_by_block_type( self.vllm_config.parallel_config, LayerBlockType.mamba)) - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) else: @@ -274,39 +309,6 @@ class Mamba2ForCausalLM(nn.Module, HasInnerState, IsAttentionFree): def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = getattr( - self.config, "intermediate_size", - self.config.expand * self.config.hidden_size) - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = ( - self.config.n_groups + - extra_groups_for_head_shards(self.config.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + 2 * n_groups * self.config.state_size) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.conv_kernel - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.num_heads, world_size), - self.config.head_dim, - self.config.state_size, - ) - return conv_state_shape, temporal_state_shape - def compute_logits(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata) -> torch.Tensor: logits = self.logits_processor(self.lm_head, hidden_states, diff --git a/vllm/model_executor/models/nemotron_h.py b/vllm/model_executor/models/nemotron_h.py index 60fb72547255..cf7b39db1fe3 100644 --- a/vllm/model_executor/models/nemotron_h.py +++ b/vllm/model_executor/models/nemotron_h.py @@ -26,7 +26,7 @@ from torch import nn from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed.parallel_state import get_pp_group from 
vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import ReLUSquaredActivation @@ -37,8 +37,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -459,6 +459,38 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, } embedding_padding_modules = ["lm_head"] + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.n_groups, + num_heads=hf_config.mamba_num_heads, + head_dim=hf_config.mamba_head_dim, + state_size=hf_config.ssm_state_size, + conv_kernel=hf_config.conv_kernel, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config = vllm_config.model_config.hf_config self.vllm_config = vllm_config @@ -515,10 +547,13 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, self.vllm_config.parallel_config, LayerBlockType.mamba ) - - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -534,39 +569,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP, def get_seqlen_agnostic_capture_inputs(self, batch_size: int): return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - world_size = get_tensor_model_parallel_world_size() - hidden_size = self.config.hidden_size - - conv_state_shape, temporal_state_shape = None, None - - intermediate_size = self.config.expand * hidden_size - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = ( - self.config.n_groups + - extra_groups_for_head_shards(self.config.n_groups, world_size)) - - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * 
self.config.ssm_state_size) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.conv_kernel - 1, - ) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(self.config.mamba_num_heads, world_size), - self.config.mamba_head_dim, - self.config.ssm_state_size, - ) - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/zamba2.py b/vllm/model_executor/models/zamba2.py index 4935fd9e6df4..ebf8dd497f67 100644 --- a/vllm/model_executor/models/zamba2.py +++ b/vllm/model_executor/models/zamba2.py @@ -18,7 +18,7 @@ from transformers import Zamba2Config from vllm import envs from vllm.attention.layer import Attention from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import divide, get_tensor_model_parallel_world_size +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -30,8 +30,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba2_metadata import ( Mamba2Metadata, prepare_mamba2_metadata) -from vllm.model_executor.layers.mamba.mamba_mixer2 import ( - MambaMixer2, extra_groups_for_head_shards) +from vllm.model_executor.layers.mamba.mamba_mixer2 import MambaMixer2 +from vllm.model_executor.layers.mamba.mamba_utils import get_mamba_state_shape from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -843,6 +843,39 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): "1.weight": "B.weight", }) + @classmethod + def get_mamba_state_shape_from_config( + cls, + vllm_config: "VllmConfig", + use_v1: bool = True, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + """Calculate shapes for Mamba's convolutional and state caches. + + Args: + vllm_config: vLLM config + use_v1: Get shapes for V1 (or V0) + + Returns: + Tuple containing: + - conv_state_shape: Shape for convolutional state cache + - temporal_state_shape: Shape for state space model cache + """ + + parallel_config = vllm_config.parallel_config + hf_config = vllm_config.model_config.hf_config + intermediate_size = hf_config.mamba_expand * hf_config.hidden_size + + return get_mamba_state_shape( + intermediate_size=intermediate_size, + tp_world_size=parallel_config.tensor_parallel_size, + n_groups=hf_config.mamba_ngroups, + num_heads=hf_config.n_mamba_heads, + head_dim=hf_config.mamba_headdim, + state_size=hf_config.mamba_d_state, + conv_kernel=hf_config.mamba_d_conv, + use_v1=use_v1, + ) + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: """Initialize the Zamba2 model for causal language modeling. 
@@ -925,9 +958,13 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): if not envs.VLLM_USE_V1: if self.mamba_cache is None: num_mamba_layers = self.config.num_hidden_layers - self.mamba_cache = MambaCacheManager( - self.vllm_config, self.lm_head.weight.dtype, - num_mamba_layers, *self._get_mamba_cache_shape()) + mamba_state_shape = \ + self.get_mamba_state_shape_from_config( + self.vllm_config, use_v1=False) + self.mamba_cache = MambaCacheManager(self.vllm_config, + self.lm_head.weight.dtype, + num_mamba_layers, + *mamba_state_shape) # Get cache parameters for current run mamba_cache_params = self.mamba_cache.current_run_tensors(**kwargs) @@ -968,49 +1005,6 @@ class Zamba2ForCausalLM(nn.Module, HasInnerState, IsHybrid): """ return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) - def _get_mamba_cache_shape( - self) -> tuple[tuple[int, int], tuple[int, int]]: - """Calculate shapes for Mamba's convolutional and state caches. - - Returns: - Tuple containing: - - conv_state_shape: Shape for convolutional state cache - - temporal_state_shape: Shape for state space model cache - """ - world_size = get_tensor_model_parallel_world_size() - - intermediate_size = self.config.mamba_expand * self.config.hidden_size - - # Extend groups if needed to ensure all groups needed by a head - # are sharded together - - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = (self.config.mamba_ngroups + extra_groups_for_head_shards( - self.config.mamba_ngroups, world_size)) - - # Calculate conv state shape (includes groups) - # - heads and n_groups are TP-ed - conv_dim = (intermediate_size + - 2 * n_groups * self.config.mamba_d_state) - conv_state_shape = ( - divide(conv_dim, world_size), - self.config.mamba_d_conv - 1, - ) - - # Calculate temporal state shape (per-head states) - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, d_head, d_state) = (128, 64, 128) - temporal_state_shape = ( - divide(divide(intermediate_size, self.config.mamba_headdim), - world_size), - self.config.mamba_headdim, - self.config.mamba_d_state, - ) - - return conv_state_shape, temporal_state_shape - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 734df82589ac..af216539c900 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -42,7 +42,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - GiB_bytes, LazyLoader, async_tensor_h2d, cdiv, + GiB_bytes, LazyLoader, async_tensor_h2d, check_use_alibi, get_dtype_size, is_pin_memory_available, round_up) from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionBackend @@ -2648,9 +2648,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): "Prefix caching is not supported for Mamba yet.") max_model_len = self.vllm_config.model_config.max_model_len - page_size_padded = self._maybe_pad_mamba_page_size( - attn_layers, mamba_layers, kv_cache_spec, max_model_len, - block_size) + page_size_padded = ( + self.vllm_config.cache_config.mamba_page_size_padded) # Set block_size to max_model_len, so that mamba model will always # have only one block in the KV cache. 
@@ -2662,54 +2661,3 @@ class GPUModelRunner(LoRAModelRunnerMixin): page_size_padded=page_size_padded) return kv_cache_spec - - def _maybe_pad_mamba_page_size( - self, - attn_layers: dict[str, Attention], - mamba_layers: dict[str, MambaBase], - kv_cache_spec: dict[str, KVCacheSpec], - max_model_len: int, - block_size: int, - ) -> Optional[int]: - """ - Ensure that page size of attention KV cache groups is greater than or - equal to the mamba KV cache groups. If not, we suggest to the user - how to set the attention block size to ensure that it is. - - If the attention page size is strictly greater than the mamba page size, - we pad the mamba page size to make them equal. - - Args: - attn_layers: Attention layers - mamba_layers: Mamba layers - kv_cache_spec: KV cache spec (populated with attention layers) - - Returns: - Optional[int]: Mamba page size with padding (None if no padding). - """ - - if len(attn_layers) == 0: - return None - - attn_layer_name = next(iter(attn_layers)) - attn_page_size = kv_cache_spec[attn_layer_name].page_size_bytes - mamba_layer_name = next(iter(mamba_layers)) - mamba_page_size = MambaSpec( - shapes=mamba_layers[mamba_layer_name].get_state_shape(), - dtype=self.kv_cache_dtype, - block_size=max_model_len).page_size_bytes - if attn_page_size < mamba_page_size: - # attention page size (for 16 tokens) - attn_page_size_16 = 16 * attn_page_size // block_size - # some attention backends (e.g. FA) only support setting - # block size to multiple of 16, so let's suggest a value - # that would work (note: FA is currently not compatible - # with mamba layers, use FlashInfer instead). - suggest_attn_block_size = 16 * cdiv(mamba_page_size, - attn_page_size_16) - raise ValueError( - "Attention block size should be increased to at least " - f"{suggest_attn_block_size} in order to match " - "the mamba page size") - - return attn_page_size
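With this change, the documented requirements for hybrid Mamba-2/attention models on V1 reduce to eager mode, prefix caching disabled, and the FlashInfer attention backend; the non-standard attention block size is now derived automatically. A minimal offline-inference sketch under those assumptions (the checkpoint name comes from the test file above; prompt and token budget are arbitrary):

# Sketch only; assumes a vLLM build containing this patch.
import os

os.environ["VLLM_USE_V1"] = "1"                      # documented V1 switch
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # required for hybrid models

from vllm import LLM, SamplingParams

llm = LLM(
    model="ibm-ai-platform/Bamba-9B-v1",
    enforce_eager=True,           # still required for hybrid models
    enable_prefix_caching=False,  # still required for hybrid models
    # No block_size override: HybridAttentionMambaModelConfig now picks an
    # attention block size that covers the mamba page size (see config.py).
)

out = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)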
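The new vllm/model_executor/layers/mamba/mamba_utils.py consolidates the state-shape arithmetic that each model previously reimplemented in _get_mamba_cache_shape. A self-contained restatement of that arithmetic, with made-up sizes chosen only to make the sharding visible (tensor-parallel size 2, a single group):

def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    # Groups are replicated so every head shard carries the groups it needs;
    # when ngroups divides evenly nothing is added, and for ngroups == 1 the
    # increase is exactly tp_size - ngroups.
    return 0 if ngroups % tp_size == 0 else tp_size - ngroups

def mamba_state_shapes(intermediate_size, tp, n_groups, num_heads,
                       head_dim, state_size, conv_kernel):
    n_groups += extra_groups_for_head_shards(n_groups, tp)
    conv_dim = intermediate_size + 2 * n_groups * state_size
    assert conv_dim % tp == 0 and num_heads % tp == 0
    conv_state = (conv_kernel - 1, conv_dim // tp)            # V1 layout
    # heads are split across TP ranks; head_dim and state_size are not
    temporal_state = (num_heads // tp, head_dim, state_size)
    return conv_state, temporal_state

# Illustrative values only, not taken from any real checkpoint:
print(mamba_state_shapes(intermediate_size=8192, tp=2, n_groups=1,
                         num_heads=128, head_dim=64, state_size=128,
                         conv_kernel=4))
# -> ((3, 4352), (64, 64, 128))

The real helper also takes a use_v1 flag: V0's MambaCacheManager expects the conv state with its two dimensions swapped, so get_mamba_state_shape transposes conv_state_shape when use_v1=False.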
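Every hybrid model now exposes the same get_mamba_state_shape_from_config classmethod (declared on the IsHybrid protocol in interfaces.py), so the V0 MambaCacheManager path and the new V1 page-size check both read shapes from one place. A toy version of the pattern, reusing mamba_state_shapes from the sketch above; the attribute names mirror Bamba's HF config fields as they appear in the diff, while the dataclass itself is only a stand-in:

from dataclasses import dataclass

@dataclass
class ToyHFConfig:               # stand-in for the transformers config object
    hidden_size: int = 4096
    mamba_expand: int = 2
    mamba_n_groups: int = 1
    mamba_n_heads: int = 128
    mamba_d_head: int = 64
    mamba_d_state: int = 128
    mamba_d_conv: int = 4

class ToyHybridModel:
    @classmethod
    def get_mamba_state_shape_from_config(cls, hf_config, tp_size, use_v1=True):
        # The real classmethod receives a VllmConfig and forwards these values
        # to mamba_utils.get_mamba_state_shape.
        intermediate_size = hf_config.mamba_expand * hf_config.hidden_size
        return mamba_state_shapes(intermediate_size, tp_size,
                                  hf_config.mamba_n_groups,
                                  hf_config.mamba_n_heads,
                                  hf_config.mamba_d_head,
                                  hf_config.mamba_d_state,
                                  hf_config.mamba_d_conv)

conv_shape, ssm_shape = ToyHybridModel.get_mamba_state_shape_from_config(
    ToyHFConfig(), tp_size=2)

In the V0 code path the models simply unpack this result into MambaCacheManager (replacing the per-model _get_mamba_cache_shape helpers deleted above); in V1 each MambaMixer2 layer's get_state_shape() now delegates to the same utility.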
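The centerpiece is HybridAttentionMambaModelConfig.verify_and_update_config in config.py: it sizes the attention block so one attention page is at least as large as one mamba page, then records the padded mamba page size in cache_config.mamba_page_size_padded for the runner to consume. A standalone walk-through of that arithmetic with illustrative sizes (a roughly 1 MiB mamba state per sequence and a 4 KiB attention page per token; real values depend on the model and dtype):

def cdiv(a: int, b: int) -> int:          # ceiling division, as in vllm.utils
    return -(-a // b)

dtype_bytes = 2                                     # e.g. bfloat16
conv_state, ssm_state = (3, 4352), (64, 64, 128)    # shapes from the sketch above
mamba_page_size = dtype_bytes * (conv_state[0] * conv_state[1] +
                                 ssm_state[0] * ssm_state[1] * ssm_state[2])

num_kv_heads, head_size = 8, 128
attn_page_size_1_token = 2 * num_kv_heads * head_size * dtype_bytes   # K and V

# Smallest multiple of 16 tokens whose attention page covers the mamba page
# (some attention backends only support block sizes that are multiples of 16).
attn_block_size = 16 * cdiv(mamba_page_size, 16 * attn_page_size_1_token)
attn_page_size = attn_block_size * attn_page_size_1_token
mamba_page_size_padded = attn_page_size             # exact alignment

print(attn_block_size, attn_page_size, mamba_page_size)
# -> 272 1114112 1074688  (the mamba page is padded by roughly 3.7%)

On the runner side, _maybe_pad_mamba_page_size is gone: gpu_model_runner.py now reads cache_config.mamba_page_size_padded and passes it through as page_size_padded when building the MambaSpec, with block_size still set to max_model_len so each sequence owns exactly one mamba block.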