[Hybrid] Add mamba_block_size to Engine Args (#27289)

Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
Asaf Joseph Gardin 2025-10-28 14:54:24 +02:00 committed by GitHub
parent 259504e147
commit 05181cc57f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 25 additions and 8 deletions

View File

@ -5,7 +5,7 @@ import hashlib
from dataclasses import field from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal from typing import TYPE_CHECKING, Any, Literal
from pydantic import Field, SkipValidation, field_validator from pydantic import Field, SkipValidation, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from vllm.config.utils import config from vllm.config.utils import config
@ -90,8 +90,10 @@ class CacheConfig:
mamba_page_size_padded: int | None = None mamba_page_size_padded: int | None = None
""" Optional override for mamba page size; used by hybrid mamba/attention """ Optional override for mamba page size; used by hybrid mamba/attention
models to ensure exact alignment with attention page size.""" models to ensure exact alignment with attention page size."""
mamba_block_size: int | None = None mamba_block_size: int | None = Field(default=None, gt=0)
"""Size of a contiguous cache block in number of tokens for mamba cache.""" """Size of a contiguous cache block in number of tokens for mamba cache.
Can be set only when prefix caching is enabled.
Value must be a multiple of 8 to align with causal_conv1d kernel."""
mamba_cache_dtype: MambaDType = "auto" mamba_cache_dtype: MambaDType = "auto"
"""The data type to use for the Mamba cache (both the conv as well as the """The data type to use for the Mamba cache (both the conv as well as the
ssm state). If set to 'auto', the data type will be inferred from the model ssm state). If set to 'auto', the data type will be inferred from the model
@ -183,3 +185,11 @@ class CacheConfig:
raise ValueError("Too large swap space. " + msg) raise ValueError("Too large swap space. " + msg)
elif cpu_memory_usage > 0.4 * total_cpu_memory: elif cpu_memory_usage > 0.4 * total_cpu_memory:
logger.warning("Possibly too large swap space. %s", msg) logger.warning("Possibly too large swap space. %s", msg)
@model_validator(mode="after")
def validate_mamba_block_size(self) -> "CacheConfig":
if self.mamba_block_size is not None and not self.enable_prefix_caching:
raise ValueError(
"--mamba-block-size can only be set with --enable-prefix-caching"
)
return self

View File

@ -535,6 +535,7 @@ class EngineArgs:
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
@ -893,6 +894,9 @@ class EngineArgs:
cache_group.add_argument( cache_group.add_argument(
"--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
) )
cache_group.add_argument(
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
)
# Multimodal related configs # Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig) multimodal_kwargs = get_kwargs(MultiModalConfig)
@ -1390,6 +1394,7 @@ class EngineArgs:
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill, kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype, mamba_cache_dtype=self.mamba_cache_dtype,
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
mamba_block_size=self.mamba_block_size,
) )
ray_runtime_env = None ray_runtime_env = None

View File

@ -291,9 +291,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
model_config = vllm_config.model_config model_config = vllm_config.model_config
cache_config = vllm_config.cache_config cache_config = vllm_config.cache_config
# Set mamba block size to max_model_len (this may get if cache_config.mamba_block_size is None:
# override by prefix caching logic later) cache_config.mamba_block_size = model_config.max_model_len
cache_config.mamba_block_size = model_config.max_model_len
if cache_config.enable_prefix_caching: if cache_config.enable_prefix_caching:
if model_config.supports_mamba_prefix_caching: if model_config.supports_mamba_prefix_caching:
@ -333,6 +332,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
return return
# Save the user input before it gets modified by MambaModelConfig
mamba_block_size = vllm_config.cache_config.mamba_block_size
# Enable FULL_AND_PIECEWISE by default # Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config) MambaModelConfig.verify_and_update_config(vllm_config)
@ -386,7 +387,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
# With prefix caching, select attention block size to # With prefix caching, select attention block size to
# optimize for mamba kernel performance # optimize for mamba kernel performance
# mamba SSD kernel uses a chunk_size, e.g. 256 # Mamba2 SSD kernel uses a chunk_size, e.g. 256
# Align the block to the kernel: use lowest multiple of chunk_size # Align the block to the kernel: use lowest multiple of chunk_size
# of attention tokens that would fit mamba_page_size: # of attention tokens that would fit mamba_page_size:
# e.g. for mamba page size = 788kB # e.g. for mamba page size = 788kB
@ -404,7 +405,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
def lcm(a, b): def lcm(a, b):
return a * b // gcd(a, b) return a * b // gcd(a, b)
base_chunk_size = model_config.get_mamba_chunk_size() base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)