[Add] cmdline argument parsing for KV cache offloading modules (#27621)

Signed-off-by: ApostaC <yihua98@uchicago.edu>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent e2347dbf58
commit e675118849
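
In practice the new flags ride along with the rest of the engine arguments. A usage sketch (the model name and sizes are placeholders; the two --kv-offloading-* flags are the ones this commit adds):

    vllm serve meta-llama/Llama-3.1-8B-Instruct \
        --kv-offloading-backend native \
        --kv-offloading-size 8 \
        --tensor-parallel-size 2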
tests/v1/kv_connector/unit/test_config.py (new file, 65 lines)

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for KV cache offloading configuration."""

import pytest

from vllm.config import CacheConfig, KVTransferConfig, ParallelConfig, VllmConfig

pytestmark = pytest.mark.cpu_test


@pytest.mark.parametrize(
    "kv_offloading_backend,kv_offloading_size,tp,pp,expected_backend,expected_bytes",
    [
        ("native", 4.0, 1, 1, "OffloadingConnector", 4.0 * (1 << 30)),
        # bytes per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("native", 8.0, 2, 2, "OffloadingConnector", 8.0 * (1 << 30) / 4),
        ("lmcache", 4.0, 1, 1, "LMCacheConnectorV1", 4.0),
        # size per rank: 8.0 GiB / (2 * 2) = 2.0 GiB
        ("lmcache", 8.0, 2, 2, "LMCacheConnectorV1", 2.0),
        (None, None, 1, 1, None, None),
    ],
)
def test_kv_connector(
    kv_offloading_backend, kv_offloading_size, tp, pp, expected_backend, expected_bytes
):
    kv_transfer_config = (
        KVTransferConfig(kv_connector_extra_config={"existing_key": "existing_value"})
        if expected_backend is not None
        else None
    )

    vllm_config = VllmConfig(
        cache_config=CacheConfig(
            kv_offloading_backend=kv_offloading_backend,
            kv_offloading_size=kv_offloading_size,
        ),
        kv_transfer_config=kv_transfer_config,
        parallel_config=ParallelConfig(
            tensor_parallel_size=tp, pipeline_parallel_size=pp
        ),
    )

    # No KV transfer config expected
    if expected_backend is None:
        assert vllm_config.kv_transfer_config is expected_backend
        return

    kv_transfer_config = vllm_config.kv_transfer_config
    kv_connector_extra_config = kv_transfer_config.kv_connector_extra_config

    assert kv_transfer_config.kv_connector == expected_backend
    assert kv_transfer_config.kv_role == "kv_both"

    if kv_offloading_backend == "native":
        assert kv_connector_extra_config["kv_bytes_per_rank"] == expected_bytes
        assert kv_connector_extra_config["num_cpu_blocks"] == 0
        # Existing config should be preserved
        assert kv_connector_extra_config["existing_key"] == "existing_value"
    elif kv_offloading_backend == "lmcache":
        assert kv_connector_extra_config["lmcache.local_cpu"] is True
        assert kv_connector_extra_config["lmcache.max_local_cpu_size"] == expected_bytes
        # Existing config should be replaced
        assert "existing_key" not in kv_connector_extra_config
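
To run just this file (path as added above):

    pytest tests/v1/kv_connector/unit/test_config.py

Note the asymmetry the last assertions pin down: the native backend merges its keys into a pre-existing kv_connector_extra_config, while the lmcache backend replaces that dict wholesale.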
@@ -24,6 +24,7 @@ BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
 CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
 MambaDType = Literal["auto", "float32"]
 PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
+KVOffloadingBackend = Literal["native", "lmcache"]
 
 
 @config
@@ -128,6 +129,17 @@ class CacheConfig:
    gpu_memory_utilization. Note that kv_cache_memory_bytes
    (when not-None) ignores gpu_memory_utilization"""
 
+    kv_offloading_size: float | None = None
+    """Size of the KV cache offloading buffer in GiB. When TP > 1, this is
+    the total buffer size summed across all TP ranks. By default, this is set
+    to None, which means no KV offloading is enabled. When set together with
+    kv_offloading_backend, vLLM will enable KV cache offloading to CPU."""
+
+    kv_offloading_backend: KVOffloadingBackend | None = None
+    """The backend to use for KV cache offloading. Supported backends include
+    'native' (vLLM native CPU offloading) and 'lmcache'. This option must be
+    used together with kv_offloading_size."""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,
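
The same fields can also be set programmatically; a minimal sketch mirroring the unit test above (CacheConfig is importable from vllm.config, as the test shows):

    from vllm.config import CacheConfig

    # 8 GiB in total, summed across all TP ranks; the per-rank split
    # happens later in VllmConfig (see the next hunk).
    cache_config = CacheConfig(
        kv_offloading_size=8.0,
        kv_offloading_backend="native",  # or "lmcache"
    )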
@@ -289,6 +289,48 @@ class VllmConfig:
 
         return replace(self, model_config=model_config)
 
+    def _post_init_kv_transfer_config(self) -> None:
+        """Update KVTransferConfig based on top-level configs in VllmConfig.
+
+        Right now, this function reads the offloading settings from
+        CacheConfig and configures the KVTransferConfig accordingly.
+        """
+        if (kv_offloading_backend := self.cache_config.kv_offloading_backend) is None:
+            return
+
+        # If no KVTransferConfig is provided, create a default one.
+        if self.kv_transfer_config is None:
+            self.kv_transfer_config = KVTransferConfig()
+
+        if (kv_offloading_size := self.cache_config.kv_offloading_size) is None:
+            raise ValueError(
+                "You must set kv_offloading_size when kv_offloading_backend is set."
+            )
+        num_kv_ranks = (
+            self.parallel_config.tensor_parallel_size
+            * self.parallel_config.pipeline_parallel_size
+        )
+
+        if kv_offloading_backend == "native":
+            self.kv_transfer_config.kv_connector = "OffloadingConnector"
+            kv_bytes_per_rank = kv_offloading_size * (1 << 30) / num_kv_ranks
+
+            # NOTE(ApostaC): the actual calculation for num_cpu_blocks should be
+            # done after the model's KV cache is initialized
+            self.kv_transfer_config.kv_connector_extra_config.update(
+                {"kv_bytes_per_rank": kv_bytes_per_rank, "num_cpu_blocks": 0}
+            )
+        elif kv_offloading_backend == "lmcache":
+            self.kv_transfer_config.kv_connector = "LMCacheConnectorV1"
+            kv_gb_per_rank = kv_offloading_size / num_kv_ranks
+            self.kv_transfer_config.kv_connector_extra_config = {
+                "lmcache.local_cpu": True,
+                "lmcache.max_local_cpu_size": kv_gb_per_rank,
+            }
+
+        # This is the same for all backends
+        self.kv_transfer_config.kv_role = "kv_both"
+
     def __post_init__(self):
         """Verify configs are valid & consistent with each other."""
 
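
The per-rank split is the only arithmetic here, and note the unit difference between backends: 'native' receives bytes (kv_bytes_per_rank), while 'lmcache' receives GiB (lmcache.max_local_cpu_size). Worked through for kv_offloading_size=8.0 with tp=2 and pp=2, matching the test cases above:

    num_kv_ranks = 2 * 2                     # 4 KV cache holders
    kv_bytes_per_rank = 8.0 * (1 << 30) / 4  # native: 2147483648.0 bytes (2 GiB)
    kv_gb_per_rank = 8.0 / 4                 # lmcache: 2.0 GiB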
@@ -646,6 +688,9 @@ class VllmConfig:
             if "-quant_fp8" not in custom_ops:
                 custom_ops.append("+quant_fp8")
 
+        # Handle the KV connector configs
+        self._post_init_kv_transfer_config()
+
     def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
         # remove the sizes that not multiple of tp_size when
         # enable sequence parallelism
@@ -54,7 +54,13 @@ from vllm.config import (
     VllmConfig,
     get_attr_docs,
 )
-from vllm.config.cache import BlockSize, CacheDType, MambaDType, PrefixCachingHashAlgo
+from vllm.config.cache import (
+    BlockSize,
+    CacheDType,
+    KVOffloadingBackend,
+    MambaDType,
+    PrefixCachingHashAlgo,
+)
 from vllm.config.device import Device
 from vllm.config.model import (
     ConvertOption,
@@ -553,6 +559,11 @@ class EngineArgs:
 
     kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill
 
+    kv_offloading_size: float | None = CacheConfig.kv_offloading_size
+    kv_offloading_backend: KVOffloadingBackend | None = (
+        CacheConfig.kv_offloading_backend
+    )
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -896,6 +907,12 @@
         cache_group.add_argument(
             "--mamba-block-size", **cache_kwargs["mamba_block_size"]
         )
+        cache_group.add_argument(
+            "--kv-offloading-size", **cache_kwargs["kv_offloading_size"]
+        )
+        cache_group.add_argument(
+            "--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
+        )
 
         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1387,6 +1404,8 @@
             mamba_cache_dtype=self.mamba_cache_dtype,
             mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype,
             mamba_block_size=self.mamba_block_size,
+            kv_offloading_size=self.kv_offloading_size,
+            kv_offloading_backend=self.kv_offloading_backend,
         )
 
         ray_runtime_env = None
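
End to end, the setting flows EngineArgs -> CacheConfig -> VllmConfig.__post_init__ -> KVTransferConfig. A hedged sketch of that flow (assumes EngineArgs.create_engine_config() is callable with its defaults at this commit, and that the placeholder model plus tp=2 are runnable on the local platform):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="facebook/opt-125m",       # placeholder model
        kv_offloading_size=8.0,          # GiB, total across ranks
        kv_offloading_backend="native",
        tensor_parallel_size=2,
    )
    vllm_config = args.create_engine_config()
    # Derived by _post_init_kv_transfer_config:
    assert vllm_config.kv_transfer_config.kv_connector == "OffloadingConnector"
    assert vllm_config.kv_transfer_config.kv_role == "kv_both"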