[Hybrid] Add mamba_block_size to Engine Args (#27289)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
Asaf Joseph Gardin 2025-10-28 14:54:24 +02:00 committed by GitHub
parent 259504e147
commit 05181cc57f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 8 deletions

View File

@ -5,7 +5,7 @@ import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, SkipValidation, field_validator
from pydantic import Field, SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass
from vllm.config.utils import config
@ -90,8 +90,10 @@ class CacheConfig:
mamba_page_size_padded: int | None = None
""" Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size."""
mamba_block_size: int | None = None
"""Size of a contiguous cache block in number of tokens for mamba cache."""
mamba_block_size: int | None = Field(default=None, gt=0)
"""Size of a contiguous cache block in number of tokens for mamba cache.
Can be set only when prefix caching is enabled.
Value must be a multiple of 8 to align with causal_conv1d kernel."""
mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model
@ -183,3 +185,11 @@ class CacheConfig:
raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. %s", msg)
@model_validator(mode="after")
def validate_mamba_block_size(self) -> "CacheConfig":
if self.mamba_block_size is not None and not self.enable_prefix_caching:
raise ValueError(
"--mamba-block-size can only be set with --enable-prefix-caching"
)
return self

View File

@ -535,6 +535,7 @@ class EngineArgs:
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
@ -893,6 +894,9 @@ class EngineArgs:
cache_group.add_argument(
"--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
)
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -1390,6 +1394,7 @@ class EngineArgs:
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
)
ray_runtime_env = None

View File

@ -291,9 +291,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
# Set mamba block size to max_model_len (this may get
# override by prefix caching logic later)
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.mamba_block_size is None:
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching:
@ -333,6 +332,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if not envs.VLLM_USE_V1:
return
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
@ -386,7 +387,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# With prefix caching, select attention block size to
# optimize for mamba kernel performance
# mamba SSD kernel uses a chunk_size, e.g. 256
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB
@ -404,7 +405,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
def lcm(a, b):
return a * b // gcd(a, b)
base_chunk_size = model_config.get_mamba_chunk_size()
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)