Improve configs - CacheConfig (#16835)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-04-20 05:25:04 +01:00 committed by GitHub
parent 87aaadef73
commit 4b07d36891
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 124 additions and 158 deletions

View File

@ -1245,22 +1245,70 @@ class ModelConfig:
or getattr(self.hf_config, "is_matryoshka", False))
class CacheConfig:
"""Configuration for the KV cache.
BlockSize = Literal[8, 16, 32, 64, 128]
CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
PrefixCachingHashAlgo = Literal["builtin", "sha256"]
Args:
block_size: Size of a cache block in number of tokens.
gpu_memory_utilization: Fraction of GPU memory to use for the
vLLM execution.
swap_space: Size of the CPU swap space per GPU (in GiB).
cache_dtype: Data type for kv cache storage.
is_attention_free: Whether the model is attention-free.
num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
profiled num_gpu_blocks if specified. Does nothing if None.
sliding_window: Sliding window size for the KV cache.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
@config
@dataclass
class CacheConfig:
"""Configuration for the KV cache."""
block_size: Optional[BlockSize] = None
"""Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128.
"""
gpu_memory_utilization: float = 0.9
"""The fraction of GPU memory to be used for the model executor, which can
range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
utilization. If unspecified, will use the default value of 0.9. This is a
per-instance limit, and only applies to the current vLLM instance. It does
not matter if you have another vLLM instance running on the same GPU. For
example, if you have two vLLM instances running on the same GPU, you can
set the GPU memory utilization to 0.5 for each instance."""
swap_space: float = 4
"""Size of the CPU swap space per GPU (in GiB)."""
cache_dtype: CacheDType = "auto"
"""Data type for kv cache storage. If "auto", will use model data type.
CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
fp8 (=fp8_e4m3)."""
is_attention_free: bool = False
"""Whether the model is attention-free. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
num_gpu_blocks_override: Optional[int] = None
"""Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
if specified. Does nothing if `None`. Used for testing preemption."""
sliding_window: Optional[int] = None
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching: Optional[bool] = None
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
default for V1."""
prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
"""Set the hash algorithm for prefix caching:\n
- "builtin" is Python's built-in hash.\n
- "sha256" is collision resistant but with certain overheads."""
cpu_offload_gb: float = 0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
checkpoint if available. Otherwise, the scales will default to 1.0."""
# Will be set after profiling.
num_gpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for GPU memory."""
num_cpu_blocks: Optional[int] = field(default=None, init=False)
"""The number of blocks to allocate for CPU memory."""
def compute_hash(self) -> str:
"""
@ -1281,43 +1329,13 @@ class CacheConfig:
usedforsecurity=False).hexdigest()
return hash_str
def __init__(
self,
block_size: int,
gpu_memory_utilization: float,
swap_space: float,
cache_dtype: str,
is_attention_free: bool = False,
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
prefix_caching_hash_algo: str = "builtin",
cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * GiB_bytes
self.num_gpu_blocks_override = num_gpu_blocks_override
self.cache_dtype = cache_dtype
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.prefix_caching_hash_algo = prefix_caching_hash_algo
self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
def __post_init__(self) -> None:
self.swap_space_bytes = self.swap_space * GiB_bytes
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
# Will be set after profiling.
self.num_gpu_blocks: Optional[int] = None
self.num_cpu_blocks: Optional[int] = None
# Set calculate_kv_scales to False if the value is unset.
if self.calculate_kv_scales is None:
self.calculate_kv_scales = False
def metrics_info(self):
# convert cache_config to dict(key: str, value: str) for prometheus
# metrics info
@ -1336,7 +1354,7 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
elif self.cache_dtype in get_args(CacheDType):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
@ -1354,12 +1372,12 @@ class CacheConfig:
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
"builtin", "sha256"):
if (self.enable_prefix_caching and self.prefix_caching_hash_algo
not in get_args(PrefixCachingHashAlgo)):
raise ValueError(
"Unknown prefix caching hash algorithm: "
f"{self.prefix_caching_hash_algo}. Must be either "
"'builtin' or 'sha256'.")
f"{self.prefix_caching_hash_algo}. Must be one of "
f"{get_args(PrefixCachingHashAlgo)}.")
def verify_with_parallel_config(
self,

View File

@ -16,16 +16,16 @@ from typing_extensions import TypeIs
import vllm.envs as envs
from vllm import version
from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
DecodingConfig, Device, DeviceConfig,
DistributedExecutorBackend, HfOverrides,
from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
Config, ConfigFormat, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PoolType, PromptAdapterConfig, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig, get_attr_docs,
get_field)
PoolType, PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@ -138,7 +138,7 @@ class EngineArgs:
load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: Optional[int] = None
max_model_len: Optional[int] = None
# Note: Specifying a custom executor backend by passing a class
@ -154,15 +154,16 @@ class EngineArgs:
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
max_parallel_loading_workers: Optional[
int] = ParallelConfig.max_parallel_loading_workers
block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None
prefix_caching_hash_algo: str = "builtin"
block_size: Optional[BlockSize] = CacheConfig.block_size
enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: PrefixCachingHashAlgo = \
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = False
disable_cascade_attn: bool = False
use_v2_block_manager: bool = True
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
max_num_batched_tokens: Optional[
int] = SchedulerConfig.max_num_batched_tokens
max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
@ -211,7 +212,8 @@ class EngineArgs:
num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: Optional[int] = None
num_gpu_blocks_override: Optional[
int] = CacheConfig.num_gpu_blocks_override
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
model_loader_extra_config: dict = \
get_field(LoadConfig, "model_loader_extra_config")
@ -250,7 +252,7 @@ class EngineArgs:
enable_sleep_mode: bool = False
model_impl: str = "auto"
calculate_kv_scales: Optional[bool] = None
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
@ -306,12 +308,19 @@ class EngineArgs:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
name = field.name
# Get the default value of the field
default = field.default
# This will only be True if default is MISSING
if field.default_factory is not MISSING:
default = field.default_factory()
kwargs[name] = {"default": default, "help": cls_docs[name]}
# Get the help text for the field
name = field.name
help = cls_docs[name]
# Escape % for argparse
help = help.replace("%", "%%")
# Initialise the kwargs dictionary for the field
kwargs[name] = {"default": default, "help": help}
# Make note of if the field is optional and get the actual
# type of the field if it is
@ -319,6 +328,8 @@ class EngineArgs:
field_type = get_args(
field.type)[0] if optional else field.type
# Set type, action and choices for the field depending on the
# type of the field
if can_be_type(field_type, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
@ -463,14 +474,6 @@ class EngineArgs:
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument('--max-model-len',
type=human_readable_int,
default=EngineArgs.max_model_len,
@ -544,33 +547,30 @@ class EngineArgs:
parallel_group.add_argument(
'--disable-custom-all-reduce',
**parallel_kwargs["disable_custom_all_reduce"])
# KV cache arguments
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to ``--max-model-len``. On CUDA devices, '
'only block sizes up to 32 are supported. '
'On HPU devices, block size defaults to 128.')
parser.add_argument(
"--enable-prefix-caching",
action=argparse.BooleanOptionalAction,
default=EngineArgs.enable_prefix_caching,
help="Enables automatic prefix caching. "
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
)
parser.add_argument(
"--prefix-caching-hash-algo",
type=str,
choices=["builtin", "sha256"],
default=EngineArgs.prefix_caching_hash_algo,
help="Set the hash algorithm for prefix caching. "
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
"(collision resistant but with certain overheads).",
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group(
title="CacheConfig",
description=CacheConfig.__doc__,
)
cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
cache_group.add_argument('--gpu-memory-utilization',
**cache_kwargs["gpu_memory_utilization"])
cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
cache_group.add_argument('--kv-cache-dtype',
**cache_kwargs["cache_dtype"])
cache_group.add_argument('--num-gpu-blocks-override',
**cache_kwargs["num_gpu_blocks_override"])
cache_group.add_argument("--enable-prefix-caching",
**cache_kwargs["enable_prefix_caching"])
cache_group.add_argument("--prefix-caching-hash-algo",
**cache_kwargs["prefix_caching_hash_algo"])
cache_group.add_argument('--cpu-offload-gb',
**cache_kwargs["cpu_offload_gb"])
cache_group.add_argument('--calculate-kv-scales',
**cache_kwargs["calculate_kv_scales"])
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
@ -588,43 +588,6 @@ class EngineArgs:
type=int,
default=EngineArgs.seed,
help='Random seed for operations.')
parser.add_argument('--swap-space',
type=float,
default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU.')
parser.add_argument(
'--cpu-offload-gb',
type=float,
default=0,
help='The space in GiB to offload to CPU, per GPU. '
'Default is 0, which means no offloading. Intuitively, '
'this argument can be seen as a virtual way to increase '
'the GPU memory size. For example, if you have one 24 GB '
'GPU and set this to 10, virtually you can think of it as '
'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
'which requires at least 26GB GPU memory. Note that this '
'requires fast CPU-GPU interconnect, as part of the model is '
'loaded from CPU memory to GPU memory on the fly in each '
'model forward pass.')
parser.add_argument(
'--gpu-memory-utilization',
type=float,
default=EngineArgs.gpu_memory_utilization,
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a per-instance '
'limit, and only applies to the current vLLM instance.'
'It does not matter if you have another vLLM instance running '
'on the same GPU. For example, if you have two vLLM instances '
'running on the same GPU, you can set the GPU memory utilization '
'to 0.5 for each instance.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
parser.add_argument(
'--max-logprobs',
type=int,
@ -994,15 +957,6 @@ class EngineArgs:
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument(
'--calculate-kv-scales',
action='store_true',
help='This enables dynamic calculation of '
'k_scale and v_scale when kv-cache-dtype is fp8. '
'If calculate-kv-scales is false, the scales will '
'be loaded from the model checkpoint if available. '
'Otherwise, the scales will default to 1.0.')
parser.add_argument(
"--additional-config",
type=json.loads,
@ -1625,9 +1579,7 @@ class EngineArgs:
self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
elif self.prefix_caching_hash_algo == "sha256":
if self.prefix_caching_hash_algo == "sha256":
raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.")
@ -1646,10 +1598,6 @@ class EngineArgs:
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:

View File

@ -50,7 +50,7 @@ class NeuronPlatform(Platform):
if cache_config:
# neuron needs block_size = max_model_len
vllm_config.cache_config.block_size = \
vllm_config.model_config.max_model_len
vllm_config.model_config.max_model_len # type: ignore
@classmethod
def is_pin_memory_available(cls) -> bool: