[Hybrid] Add mamba_block_size to Engine Args (#27289)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent 259504e147
commit 05181cc57f
@@ -5,7 +5,7 @@ import hashlib
 from dataclasses import field
 from typing import TYPE_CHECKING, Any, Literal

-from pydantic import Field, SkipValidation, field_validator
+from pydantic import Field, SkipValidation, field_validator, model_validator
 from pydantic.dataclasses import dataclass

 from vllm.config.utils import config
@@ -90,8 +90,10 @@ class CacheConfig:
     mamba_page_size_padded: int | None = None
     """ Optional override for mamba page size; used by hybrid mamba/attention
     models to ensure exact alignment with attention page size."""
-    mamba_block_size: int | None = None
-    """Size of a contiguous cache block in number of tokens for mamba cache."""
+    mamba_block_size: int | None = Field(default=None, gt=0)
+    """Size of a contiguous cache block in number of tokens for mamba cache.
+    Can be set only when prefix caching is enabled.
+    Value must be a multiple of 8 to align with causal_conv1d kernel."""
     mamba_cache_dtype: MambaDType = "auto"
     """The data type to use for the Mamba cache (both the conv as well as the
     ssm state). If set to 'auto', the data type will be inferred from the model
@@ -183,3 +185,11 @@ class CacheConfig:
             raise ValueError("Too large swap space. " + msg)
         elif cpu_memory_usage > 0.4 * total_cpu_memory:
             logger.warning("Possibly too large swap space. %s", msg)
+
+    @model_validator(mode="after")
+    def validate_mamba_block_size(self) -> "CacheConfig":
+        if self.mamba_block_size is not None and not self.enable_prefix_caching:
+            raise ValueError(
+                "--mamba-block-size can only be set with --enable-prefix-caching"
+            )
+        return self
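The two hunks above introduce the constrained field and its cross-field check. Below is a minimal, self-contained sketch of the same pydantic pattern; the _CacheConfigSketch class and the example values are illustrative, not vLLM code.

# Standalone sketch of the pattern above: a positive-only optional field plus
# an "after" model validator that ties it to a second flag.
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass


@dataclass
class _CacheConfigSketch:
    enable_prefix_caching: bool = False
    mamba_block_size: int | None = Field(default=None, gt=0)

    @model_validator(mode="after")
    def validate_mamba_block_size(self) -> "_CacheConfigSketch":
        if self.mamba_block_size is not None and not self.enable_prefix_caching:
            raise ValueError(
                "--mamba-block-size can only be set with --enable-prefix-caching"
            )
        return self


# Accepted: prefix caching is enabled and the block size is positive.
_CacheConfigSketch(enable_prefix_caching=True, mamba_block_size=1024)

# Rejected: a block size without prefix caching raises the ValueError above,
# and a non-positive value fails the gt=0 bound during field validation.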
@@ -535,6 +535,7 @@ class EngineArgs:
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
     mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
     mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
+    mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")

     additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")

@@ -893,6 +894,9 @@ class EngineArgs:
         cache_group.add_argument(
             "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
         )
+        cache_group.add_argument(
+            "--mamba-block-size", **cache_kwargs["mamba_block_size"]
+        )

         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1390,6 +1394,7 @@ class EngineArgs:
             kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
             mamba_cache_dtype=self.mamba_cache_dtype,
             mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
+            mamba_block_size=self.mamba_block_size,
         )

         ray_runtime_env = None
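End to end, the new engine argument flows from the CLI flag into CacheConfig. A hedged usage sketch follows; the model name and block-size value are placeholders, and create_engine_config() resolves the model configuration, so it needs the model to be reachable and a supported device.

# Programmatic equivalent of:
#   vllm serve <model> --enable-prefix-caching --mamba-block-size 1024
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="ibm-granite/granite-4.0-tiny-preview",  # placeholder hybrid model
    enable_prefix_caching=True,
    mamba_block_size=1024,
)

# create_engine_config() forwards mamba_block_size into CacheConfig, where the
# gt=0 bound and validate_mamba_block_size() from the hunks above are enforced.
vllm_config = engine_args.create_engine_config()
print(vllm_config.cache_config.mamba_block_size)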
@@ -291,9 +291,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config

-        # Set mamba block size to max_model_len (this may get
-        # override by prefix caching logic later)
-        cache_config.mamba_block_size = model_config.max_model_len
+        if cache_config.mamba_block_size is None:
+            cache_config.mamba_block_size = model_config.max_model_len

         if cache_config.enable_prefix_caching:
             if model_config.supports_mamba_prefix_caching:
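The hunk above changes MambaModelConfig from always overwriting mamba_block_size to treating max_model_len as a fallback only. A tiny illustrative sketch of that precedence; the function name and values are made up for illustration.

def resolve_mamba_block_size(user_value: int | None, max_model_len: int) -> int:
    # New behavior: a user-supplied --mamba-block-size wins; otherwise fall
    # back to max_model_len, as before.
    return user_value if user_value is not None else max_model_len


assert resolve_mamba_block_size(None, 4096) == 4096   # no override: old default kept
assert resolve_mamba_block_size(1024, 4096) == 1024   # user override now respected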
@@ -333,6 +332,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
         if not envs.VLLM_USE_V1:
             return

+        # Save the user input before it gets modified by MambaModelConfig
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
         # Enable FULL_AND_PIECEWISE by default
         MambaModelConfig.verify_and_update_config(vllm_config)

@@ -386,7 +387,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             # With prefix caching, select attention block size to
             # optimize for mamba kernel performance

-            # mamba SSD kernel uses a chunk_size, e.g. 256
+            # Mamba2 SSD kernel uses a chunk_size, e.g. 256
             # Align the block to the kernel: use lowest multiple of chunk_size
             # of attention tokens that would fit mamba_page_size:
             # e.g. for mamba page size = 788kB
@@ -404,7 +405,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             def lcm(a, b):
                 return a * b // gcd(a, b)

-            base_chunk_size = model_config.get_mamba_chunk_size()
+            base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
+
             attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)

             chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
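The last two hunks wire the saved user value into the attention/mamba block-size alignment. Below is a self-contained sketch of that arithmetic with made-up sizes; vLLM derives mamba_page_size, attn_page_size_1_token, and kernel_block_alignment_size from the actual model and kernels, and the final attn_block_size line is an assumption that follows the "lowest multiple of chunk_size" comment above rather than quoted vLLM code.

from math import gcd


def cdiv(a: int, b: int) -> int:
    # Ceiling division, as used in the hunk above.
    return -(-a // b)


def lcm(a: int, b: int) -> int:
    return a * b // gcd(a, b)


# Illustrative inputs only:
kernel_block_alignment_size = 16      # attention-kernel block granularity
mamba_block_size = 1024               # user-supplied --mamba-block-size (or None)
mamba2_chunk_size = 256               # Mamba2 SSD kernel chunk_size
mamba_page_size = 788 * 1024          # e.g. ~788kB per mamba state
attn_page_size_1_token = 1024         # bytes of KV cache for one attention token

base_chunk_size = mamba_block_size or mamba2_chunk_size
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)

# Lowest multiple of chunk_size that covers one mamba state's worth of tokens.
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
print(attn_block_size)   # 1024 with these example numbers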