mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-02 21:11:18 +08:00
[Hybrid] Add mamba_block_size to Engine Args (#27289)
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
This commit is contained in:
parent
259504e147
commit
05181cc57f
@ -5,7 +5,7 @@ import hashlib
|
|||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from pydantic import Field, SkipValidation, field_validator
|
from pydantic import Field, SkipValidation, field_validator, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
|
|
||||||
from vllm.config.utils import config
|
from vllm.config.utils import config
|
||||||
@ -90,8 +90,10 @@ class CacheConfig:
|
|||||||
mamba_page_size_padded: int | None = None
|
mamba_page_size_padded: int | None = None
|
||||||
""" Optional override for mamba page size; used by hybrid mamba/attention
|
""" Optional override for mamba page size; used by hybrid mamba/attention
|
||||||
models to ensure exact alignment with attention page size."""
|
models to ensure exact alignment with attention page size."""
|
||||||
mamba_block_size: int | None = None
|
mamba_block_size: int | None = Field(default=None, gt=0)
|
||||||
"""Size of a contiguous cache block in number of tokens for mamba cache."""
|
"""Size of a contiguous cache block in number of tokens for mamba cache.
|
||||||
|
Can be set only when prefix caching is enabled.
|
||||||
|
Value must be a multiple of 8 to align with causal_conv1d kernel."""
|
||||||
mamba_cache_dtype: MambaDType = "auto"
|
mamba_cache_dtype: MambaDType = "auto"
|
||||||
"""The data type to use for the Mamba cache (both the conv as well as the
|
"""The data type to use for the Mamba cache (both the conv as well as the
|
||||||
ssm state). If set to 'auto', the data type will be inferred from the model
|
ssm state). If set to 'auto', the data type will be inferred from the model
|
||||||
@ -183,3 +185,11 @@ class CacheConfig:
|
|||||||
raise ValueError("Too large swap space. " + msg)
|
raise ValueError("Too large swap space. " + msg)
|
||||||
elif cpu_memory_usage > 0.4 * total_cpu_memory:
|
elif cpu_memory_usage > 0.4 * total_cpu_memory:
|
||||||
logger.warning("Possibly too large swap space. %s", msg)
|
logger.warning("Possibly too large swap space. %s", msg)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_mamba_block_size(self) -> "CacheConfig":
|
||||||
|
if self.mamba_block_size is not None and not self.enable_prefix_caching:
|
||||||
|
raise ValueError(
|
||||||
|
"--mamba-block-size can only be set with --enable-prefix-caching"
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|||||||
@ -535,6 +535,7 @@ class EngineArgs:
|
|||||||
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
|
||||||
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
|
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype
|
||||||
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
|
mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype
|
||||||
|
mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size")
|
||||||
|
|
||||||
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
|
additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config")
|
||||||
|
|
||||||
@ -893,6 +894,9 @@ class EngineArgs:
|
|||||||
cache_group.add_argument(
|
cache_group.add_argument(
|
||||||
"--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
|
"--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"]
|
||||||
)
|
)
|
||||||
|
cache_group.add_argument(
|
||||||
|
"--mamba-block-size", **cache_kwargs["mamba_block_size"]
|
||||||
|
)
|
||||||
|
|
||||||
# Multimodal related configs
|
# Multimodal related configs
|
||||||
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
||||||
@ -1390,6 +1394,7 @@ class EngineArgs:
|
|||||||
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
|
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
|
||||||
mamba_cache_dtype=self.mamba_cache_dtype,
|
mamba_cache_dtype=self.mamba_cache_dtype,
|
||||||
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
|
||||||
|
mamba_block_size=self.mamba_block_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
ray_runtime_env = None
|
ray_runtime_env = None
|
||||||
|
|||||||
@ -291,9 +291,8 @@ class MambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
|
|
||||||
# Set mamba block size to max_model_len (this may get
|
if cache_config.mamba_block_size is None:
|
||||||
# overridden by prefix caching logic later)
|
cache_config.mamba_block_size = model_config.max_model_len
|
||||||
cache_config.mamba_block_size = model_config.max_model_len
|
|
||||||
|
|
||||||
if cache_config.enable_prefix_caching:
|
if cache_config.enable_prefix_caching:
|
||||||
if model_config.supports_mamba_prefix_caching:
|
if model_config.supports_mamba_prefix_caching:
|
||||||
@ -333,6 +332,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
if not envs.VLLM_USE_V1:
|
if not envs.VLLM_USE_V1:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Save the user input before it gets modified by MambaModelConfig
|
||||||
|
mamba_block_size = vllm_config.cache_config.mamba_block_size
|
||||||
# Enable FULL_AND_PIECEWISE by default
|
# Enable FULL_AND_PIECEWISE by default
|
||||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||||
|
|
||||||
@ -386,7 +387,7 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
# With prefix caching, select attention block size to
|
# With prefix caching, select attention block size to
|
||||||
# optimize for mamba kernel performance
|
# optimize for mamba kernel performance
|
||||||
|
|
||||||
# mamba SSD kernel uses a chunk_size, e.g. 256
|
# Mamba2 SSD kernel uses a chunk_size, e.g. 256
|
||||||
# Align the block to the kernel: use lowest multiple of chunk_size
|
# Align the block to the kernel: use lowest multiple of chunk_size
|
||||||
# of attention tokens that would fit mamba_page_size:
|
# of attention tokens that would fit mamba_page_size:
|
||||||
# e.g. for mamba page size = 788kB
|
# e.g. for mamba page size = 788kB
|
||||||
@ -404,7 +405,8 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
|
|||||||
def lcm(a, b):
|
def lcm(a, b):
|
||||||
return a * b // gcd(a, b)
|
return a * b // gcd(a, b)
|
||||||
|
|
||||||
base_chunk_size = model_config.get_mamba_chunk_size()
|
base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size()
|
||||||
|
|
||||||
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
|
||||||
|
|
||||||
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
|
chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user