Improve configs - CacheConfig (#16835)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 87aaadef73
commit 4b07d36891
vllm/config.py
@@ -1245,22 +1245,70 @@ class ModelConfig:
                 or getattr(self.hf_config, "is_matryoshka", False))


-class CacheConfig:
-    """Configuration for the KV cache.
+BlockSize = Literal[8, 16, 32, 64, 128]
+CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
+PrefixCachingHashAlgo = Literal["builtin", "sha256"]

-    Args:
-        block_size: Size of a cache block in number of tokens.
-        gpu_memory_utilization: Fraction of GPU memory to use for the
-            vLLM execution.
-        swap_space: Size of the CPU swap space per GPU (in GiB).
-        cache_dtype: Data type for kv cache storage.
-        is_attention_free: Whether the model is attention-free.
-        num_gpu_blocks_override: Number of GPU blocks to use. This overrides the
-            profiled num_gpu_blocks if specified. Does nothing if None.
-        sliding_window: Sliding window size for the KV cache.
-        enable_prefix_caching: Whether to enable prefix caching.
-        cpu_offload_gb: Size of the CPU offload buffer in GiB.
+
+@config
+@dataclass
+class CacheConfig:
+    """Configuration for the KV cache."""
+
+    block_size: Optional[BlockSize] = None
+    """Size of a contiguous cache block in number of tokens. This is ignored on
+    neuron devices and set to `--max-model-len`. On CUDA devices, only block
+    sizes up to 32 are supported. On HPU devices, block size defaults to 128.
+    """
+    gpu_memory_utilization: float = 0.9
+    """The fraction of GPU memory to be used for the model executor, which can
+    range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
+    utilization. If unspecified, will use the default value of 0.9. This is a
+    per-instance limit, and only applies to the current vLLM instance. It does
+    not matter if you have another vLLM instance running on the same GPU. For
+    example, if you have two vLLM instances running on the same GPU, you can
+    set the GPU memory utilization to 0.5 for each instance."""
+    swap_space: float = 4
+    """Size of the CPU swap space per GPU (in GiB)."""
+    cache_dtype: CacheDType = "auto"
+    """Data type for kv cache storage. If "auto", will use model data type.
+    CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports
+    fp8 (=fp8_e4m3)."""
+    is_attention_free: bool = False
+    """Whether the model is attention-free. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    num_gpu_blocks_override: Optional[int] = None
+    """Number of GPU blocks to use. This overrides the profiled `num_gpu_blocks`
+    if specified. Does nothing if `None`. Used for testing preemption."""
+    sliding_window: Optional[int] = None
+    """Sliding window size for the KV cache. This is primarily set in
+    `ModelConfig` and that value should be manually duplicated here."""
+    enable_prefix_caching: Optional[bool] = None
+    """Whether to enable prefix caching. Disabled by default for V0. Enabled by
+    default for V1."""
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = "builtin"
+    """Set the hash algorithm for prefix caching:\n
+    - "builtin" is Python's built-in hash.\n
+    - "sha256" is collision resistant but with certain overheads."""
+    cpu_offload_gb: float = 0
+    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
+    no offloading. Intuitively, this argument can be seen as a virtual way to
+    increase the GPU memory size. For example, if you have one 24 GB GPU and
+    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
+    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
+    Note that this requires fast CPU-GPU interconnect, as part of the model is
+    loaded from CPU memory to GPU memory on the fly in each model forward pass.
+    """
+    calculate_kv_scales: bool = False
+    """This enables dynamic calculation of `k_scale` and `v_scale` when
+    kv_cache_dtype is fp8. If `False`, the scales will be loaded from the model
+    checkpoint if available. Otherwise, the scales will default to 1.0."""
+
+    # Will be set after profiling.
+    num_gpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for GPU memory."""
+    num_cpu_blocks: Optional[int] = field(default=None, init=False)
+    """The number of blocks to allocate for CPU memory."""

     def compute_hash(self) -> str:
         """
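The new `BlockSize`, `CacheDType`, and `PrefixCachingHashAlgo` aliases do double duty: they type the dataclass fields above, and `typing.get_args` turns them into the runtime value sets used by the `_verify_*` methods further down (and, via the new argparse wiring, by the CLI). A minimal standalone sketch of that pattern, not taken from the vLLM sources:

from typing import Literal, get_args

CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]

def verify_cache_dtype(cache_dtype: str) -> None:
    # get_args(CacheDType) evaluates to ("auto", "fp8", "fp8_e4m3", "fp8_e5m2")
    if cache_dtype not in get_args(CacheDType):
        raise ValueError(f"Unknown kv cache dtype: {cache_dtype}. "
                         f"Must be one of {get_args(CacheDType)}.")

verify_cache_dtype("fp8_e5m2")  # passes silently
# verify_cache_dtype("int8")    # would raise ValueError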
@@ -1281,43 +1329,13 @@ class CacheConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str

-    def __init__(
-        self,
-        block_size: int,
-        gpu_memory_utilization: float,
-        swap_space: float,
-        cache_dtype: str,
-        is_attention_free: bool = False,
-        num_gpu_blocks_override: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        enable_prefix_caching: bool = False,
-        prefix_caching_hash_algo: str = "builtin",
-        cpu_offload_gb: float = 0,
-        calculate_kv_scales: Optional[bool] = None,
-    ) -> None:
-        self.block_size = block_size
-        self.gpu_memory_utilization = gpu_memory_utilization
-        self.swap_space_bytes = swap_space * GiB_bytes
-        self.num_gpu_blocks_override = num_gpu_blocks_override
-        self.cache_dtype = cache_dtype
-        self.is_attention_free = is_attention_free
-        self.sliding_window = sliding_window
-        self.enable_prefix_caching = enable_prefix_caching
-        self.prefix_caching_hash_algo = prefix_caching_hash_algo
-        self.cpu_offload_gb = cpu_offload_gb
-        self.calculate_kv_scales = calculate_kv_scales
+    def __post_init__(self) -> None:
+        self.swap_space_bytes = self.swap_space * GiB_bytes

         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()

-        # Will be set after profiling.
-        self.num_gpu_blocks: Optional[int] = None
-        self.num_cpu_blocks: Optional[int] = None
-
-        # Set calculate_kv_scales to False if the value is unset.
-        if self.calculate_kv_scales is None:
-            self.calculate_kv_scales = False
-
     def metrics_info(self):
         # convert cache_config to dict(key: str, value: str) for prometheus
         # metrics info
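With the `@dataclass` conversion, the hand-written `__init__` above disappears: the generated initializer assigns the fields, and `__post_init__` derives `swap_space_bytes` and runs the verification methods. A rough usage sketch, assuming the field names and defaults shown in this diff:

from vllm.config import CacheConfig

cfg = CacheConfig(block_size=16, gpu_memory_utilization=0.8, swap_space=2)
print(cfg.swap_space_bytes)  # 2 GiB expressed in bytes, set in __post_init__
print(cfg.num_gpu_blocks)    # None until profiling fills it in (init=False field)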
@@ -1336,7 +1354,7 @@ class CacheConfig:
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
-        elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
+        elif self.cache_dtype in get_args(CacheDType):
             logger.info(
                 "Using fp8 data type to store kv cache. It reduces the GPU "
                 "memory footprint and boosts the performance. "
@@ -1354,12 +1372,12 @@ class CacheConfig:
                 "Prefix caching is not supported with sliding window. "
                 "Run with --disable-sliding-window to use prefix caching.")

-        if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
-                "builtin", "sha256"):
+        if (self.enable_prefix_caching and self.prefix_caching_hash_algo
+                not in get_args(PrefixCachingHashAlgo)):
             raise ValueError(
                 "Unknown prefix caching hash algorithm: "
-                f"{self.prefix_caching_hash_algo}. Must be either "
-                "'builtin' or 'sha256'.")
+                f"{self.prefix_caching_hash_algo}. Must be one of "
+                f"{get_args(PrefixCachingHashAlgo)}.")

     def verify_with_parallel_config(
         self,
vllm/engine/arg_utils.py

@@ -16,16 +16,16 @@ from typing_extensions import TypeIs

 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
-                         DecodingConfig, Device, DeviceConfig,
-                         DistributedExecutorBackend, HfOverrides,
+from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
+                         Config, ConfigFormat, DecodingConfig, Device,
+                         DeviceConfig, DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, MultiModalConfig,
                          ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PoolType, PromptAdapterConfig, SchedulerConfig,
-                         SchedulerPolicy, SpeculativeConfig, TaskOption,
-                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
-                         get_field)
+                         PoolType, PrefixCachingHashAlgo, PromptAdapterConfig,
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         TaskOption, TokenizerPoolConfig, VllmConfig,
+                         get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -138,7 +138,7 @@ class EngineArgs:
     load_format: str = LoadConfig.load_format
     config_format: ConfigFormat = ConfigFormat.AUTO
     dtype: str = 'auto'
-    kv_cache_dtype: str = 'auto'
+    kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
     seed: Optional[int] = None
     max_model_len: Optional[int] = None
     # Note: Specifying a custom executor backend by passing a class
@@ -154,15 +154,16 @@ class EngineArgs:
     enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
     max_parallel_loading_workers: Optional[
         int] = ParallelConfig.max_parallel_loading_workers
-    block_size: Optional[int] = None
-    enable_prefix_caching: Optional[bool] = None
-    prefix_caching_hash_algo: str = "builtin"
+    block_size: Optional[BlockSize] = CacheConfig.block_size
+    enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
+    prefix_caching_hash_algo: PrefixCachingHashAlgo = \
+        CacheConfig.prefix_caching_hash_algo
     disable_sliding_window: bool = False
     disable_cascade_attn: bool = False
     use_v2_block_manager: bool = True
-    swap_space: float = 4  # GiB
-    cpu_offload_gb: float = 0  # GiB
-    gpu_memory_utilization: float = 0.90
+    swap_space: float = CacheConfig.swap_space
+    cpu_offload_gb: float = CacheConfig.cpu_offload_gb
+    gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
     max_num_batched_tokens: Optional[
         int] = SchedulerConfig.max_num_batched_tokens
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
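Note the pattern in this hunk: instead of repeating literal defaults, `EngineArgs` now reads them from the `CacheConfig` class attributes (for example `swap_space: float = CacheConfig.swap_space`), so each default lives in exactly one place. A self-contained illustration of that idea with invented names, not vLLM code:

from dataclasses import dataclass

@dataclass
class CacheDefaults:
    swap_space: float = 4
    gpu_memory_utilization: float = 0.9

@dataclass
class CliArgs:
    # Defaults are read off the config class, keeping a single source of truth.
    swap_space: float = CacheDefaults.swap_space
    gpu_memory_utilization: float = CacheDefaults.gpu_memory_utilization

print(CliArgs())  # CliArgs(swap_space=4, gpu_memory_utilization=0.9)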
@@ -211,7 +212,8 @@ class EngineArgs:
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
-    num_gpu_blocks_override: Optional[int] = None
+    num_gpu_blocks_override: Optional[
+        int] = CacheConfig.num_gpu_blocks_override
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: dict = \
         get_field(LoadConfig, "model_loader_extra_config")
@@ -250,7 +252,7 @@ class EngineArgs:
     enable_sleep_mode: bool = False
     model_impl: str = "auto"

-    calculate_kv_scales: Optional[bool] = None
+    calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
@@ -306,12 +308,19 @@ class EngineArgs:
         cls_docs = get_attr_docs(cls)
         kwargs = {}
         for field in fields(cls):
-            name = field.name
             # Get the default value of the field
             default = field.default
             # This will only be True if default is MISSING
             if field.default_factory is not MISSING:
                 default = field.default_factory()
-            kwargs[name] = {"default": default, "help": cls_docs[name]}
+
+            # Get the help text for the field
+            name = field.name
+            help = cls_docs[name]
+            # Escape % for argparse
+            help = help.replace("%", "%%")
+
+            # Initialise the kwargs dictionary for the field
+            kwargs[name] = {"default": default, "help": help}

             # Make note of if the field is optional and get the actual
             # type of the field if it is
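The reworked loop relies on `get_attr_docs(cls)` returning a mapping from field name to the string literal written directly beneath that field (the style used in `CacheConfig` above), which then becomes the argparse help text. One plausible way such a helper can be implemented, shown only as an illustration and not necessarily vLLM's exact code:

import ast
import inspect
import textwrap

def attr_docs_sketch(cls) -> dict[str, str]:
    # Parse the class source and pair each annotated assignment with the
    # bare string literal that immediately follows it.
    source = textwrap.dedent(inspect.getsource(cls))
    class_node = ast.parse(source).body[0]
    docs: dict[str, str] = {}
    for node, following in zip(class_node.body, class_node.body[1:]):
        if (isinstance(node, ast.AnnAssign)
                and isinstance(node.target, ast.Name)
                and isinstance(following, ast.Expr)
                and isinstance(following.value, ast.Constant)
                and isinstance(following.value.value, str)):
            docs[node.target.id] = inspect.cleandoc(following.value.value)
    return docs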
@@ -319,6 +328,8 @@ class EngineArgs:
                 field_type = get_args(
                     field.type)[0] if optional else field.type

+            # Set type, action and choices for the field depending on the
+            # type of the field
             if can_be_type(field_type, bool):
                 # Creates --no-<name> and --<name> flags
                 kwargs[name]["action"] = argparse.BooleanOptionalAction
@@ -463,14 +474,6 @@ class EngineArgs:
             '* "bfloat16" for a balance between precision and range.\n'
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.')
-        parser.add_argument(
-            '--kv-cache-dtype',
-            type=str,
-            choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
-            default=EngineArgs.kv_cache_dtype,
-            help='Data type for kv cache storage. If "auto", will use model '
-            'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
-            'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
         parser.add_argument('--max-model-len',
                             type=human_readable_int,
                             default=EngineArgs.max_model_len,
@@ -544,33 +547,30 @@ class EngineArgs:
         parallel_group.add_argument(
             '--disable-custom-all-reduce',
             **parallel_kwargs["disable_custom_all_reduce"])
-        # KV cache arguments
-        parser.add_argument('--block-size',
-                            type=int,
-                            default=EngineArgs.block_size,
-                            choices=[8, 16, 32, 64, 128],
-                            help='Token block size for contiguous chunks of '
-                            'tokens. This is ignored on neuron devices and '
-                            'set to ``--max-model-len``. On CUDA devices, '
-                            'only block sizes up to 32 are supported. '
-                            'On HPU devices, block size defaults to 128.')
-
-        parser.add_argument(
-            "--enable-prefix-caching",
-            action=argparse.BooleanOptionalAction,
-            default=EngineArgs.enable_prefix_caching,
-            help="Enables automatic prefix caching. "
-            "Use ``--no-enable-prefix-caching`` to disable explicitly.",
-        )
-        parser.add_argument(
-            "--prefix-caching-hash-algo",
-            type=str,
-            choices=["builtin", "sha256"],
-            default=EngineArgs.prefix_caching_hash_algo,
-            help="Set the hash algorithm for prefix caching. "
-            "Options are 'builtin' (Python's built-in hash) or 'sha256' "
-            "(collision resistant but with certain overheads).",
+        # KV cache arguments
+        cache_kwargs = get_kwargs(CacheConfig)
+        cache_group = parser.add_argument_group(
+            title="CacheConfig",
+            description=CacheConfig.__doc__,
         )
+        cache_group.add_argument('--block-size', **cache_kwargs["block_size"])
+        cache_group.add_argument('--gpu-memory-utilization',
+                                 **cache_kwargs["gpu_memory_utilization"])
+        cache_group.add_argument('--swap-space', **cache_kwargs["swap_space"])
+        cache_group.add_argument('--kv-cache-dtype',
+                                 **cache_kwargs["cache_dtype"])
+        cache_group.add_argument('--num-gpu-blocks-override',
+                                 **cache_kwargs["num_gpu_blocks_override"])
+        cache_group.add_argument("--enable-prefix-caching",
+                                 **cache_kwargs["enable_prefix_caching"])
+        cache_group.add_argument("--prefix-caching-hash-algo",
+                                 **cache_kwargs["prefix_caching_hash_algo"])
+        cache_group.add_argument('--cpu-offload-gb',
+                                 **cache_kwargs["cpu_offload_gb"])
+        cache_group.add_argument('--calculate-kv-scales',
+                                 **cache_kwargs["calculate_kv_scales"])

         parser.add_argument('--disable-sliding-window',
                             action='store_true',
                             help='Disables sliding window, '
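The net effect of `get_kwargs` plus `add_argument_group`: every `CacheConfig` field becomes a CLI flag whose default and help text come straight from the dataclass, grouped under a "CacheConfig" heading in `--help`. A toy, self-contained sketch of the same wiring (names invented for illustration, not vLLM's):

import argparse
from dataclasses import dataclass, fields

@dataclass
class ToyCacheConfig:
    """Toy stand-in for CacheConfig."""
    block_size: int = 16
    gpu_memory_utilization: float = 0.9

# In vLLM these help strings come from get_attr_docs; hard-coded here.
docs = {
    "block_size": "Token block size.",
    "gpu_memory_utilization": "Fraction of GPU memory to use.",
}

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="ToyCacheConfig",
                                  description=ToyCacheConfig.__doc__)
for f in fields(ToyCacheConfig):
    group.add_argument(f"--{f.name.replace('_', '-')}",
                       type=type(f.default),  # simplified; real code inspects the annotation
                       default=f.default,
                       help=docs[f.name])

args = parser.parse_args(["--gpu-memory-utilization", "0.5"])
print(args.block_size, args.gpu_memory_utilization)  # 16 0.5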
@@ -588,43 +588,6 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.seed,
                             help='Random seed for operations.')
-        parser.add_argument('--swap-space',
-                            type=float,
-                            default=EngineArgs.swap_space,
-                            help='CPU swap space size (GiB) per GPU.')
-        parser.add_argument(
-            '--cpu-offload-gb',
-            type=float,
-            default=0,
-            help='The space in GiB to offload to CPU, per GPU. '
-            'Default is 0, which means no offloading. Intuitively, '
-            'this argument can be seen as a virtual way to increase '
-            'the GPU memory size. For example, if you have one 24 GB '
-            'GPU and set this to 10, virtually you can think of it as '
-            'a 34 GB GPU. Then you can load a 13B model with BF16 weight, '
-            'which requires at least 26GB GPU memory. Note that this '
-            'requires fast CPU-GPU interconnect, as part of the model is '
-            'loaded from CPU memory to GPU memory on the fly in each '
-            'model forward pass.')
-        parser.add_argument(
-            '--gpu-memory-utilization',
-            type=float,
-            default=EngineArgs.gpu_memory_utilization,
-            help='The fraction of GPU memory to be used for the model '
-            'executor, which can range from 0 to 1. For example, a value of '
-            '0.5 would imply 50%% GPU memory utilization. If unspecified, '
-            'will use the default value of 0.9. This is a per-instance '
-            'limit, and only applies to the current vLLM instance.'
-            'It does not matter if you have another vLLM instance running '
-            'on the same GPU. For example, if you have two vLLM instances '
-            'running on the same GPU, you can set the GPU memory utilization '
-            'to 0.5 for each instance.')
-        parser.add_argument(
-            '--num-gpu-blocks-override',
-            type=int,
-            default=None,
-            help='If specified, ignore GPU profiling result and use this number'
-            ' of GPU blocks. Used for testing preemption.')
         parser.add_argument(
             '--max-logprobs',
             type=int,
@@ -994,15 +957,6 @@ class EngineArgs:
             help="Enable sleep mode for the engine. "
             "(only cuda platform is supported)")

-        parser.add_argument(
-            '--calculate-kv-scales',
-            action='store_true',
-            help='This enables dynamic calculation of '
-            'k_scale and v_scale when kv-cache-dtype is fp8. '
-            'If calculate-kv-scales is false, the scales will '
-            'be loaded from the model checkpoint if available. '
-            'Otherwise, the scales will default to 1.0.')
-
         parser.add_argument(
             "--additional-config",
             type=json.loads,
@@ -1625,9 +1579,7 @@ class EngineArgs:
             self.enable_prefix_caching = False

         # VLLM_V0 only supports builtin hash algo for prefix caching.
-        if self.prefix_caching_hash_algo is None:
-            self.prefix_caching_hash_algo = "builtin"
-        elif self.prefix_caching_hash_algo == "sha256":
+        if self.prefix_caching_hash_algo == "sha256":
             raise ValueError(
                 "sha256 is not supported for prefix caching in V0 engine. "
                 "Please use 'builtin'.")
@@ -1646,10 +1598,6 @@ class EngineArgs:
         if self.enable_prefix_caching is None:
             self.enable_prefix_caching = True

-        # if using prefix caching, we must set a hash algo
-        if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
-            self.prefix_caching_hash_algo = "builtin"
-
         # V1 should use the new scheduler by default.
         # Swap it only if this arg is set to the original V0 default
         if self.scheduler_cls == EngineArgs.scheduler_cls:
vllm/platforms/neuron.py

@@ -50,7 +50,7 @@ class NeuronPlatform(Platform):
         if cache_config:
             # neuron needs block_size = max_model_len
             vllm_config.cache_config.block_size = \
-                vllm_config.model_config.max_model_len
+                vllm_config.model_config.max_model_len  # type: ignore

     @classmethod
     def is_pin_memory_available(cls) -> bool: