[Core][Hybrid allocator + kv connector 1/n] Enable hybrid allocator + KV cache connector (#25712)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Signed-off-by: Kuntai Du <kuntai@uchicago.edu>
This commit is contained in:
Kuntai Du 2025-10-24 23:34:18 -07:00 committed by GitHub
parent 56ed7609a9
commit b853540388
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 113 additions and 18 deletions

View File

@ -899,6 +899,7 @@ def test_kv_connector_basic():
scheduler = create_scheduler(
enable_prefix_caching=True,
use_kv_connector=True,
disable_hybrid_kv_cache_manager=True,
)
NUM_TOTAL_BLOCKS = scheduler.kv_cache_manager.block_pool.get_num_free_blocks()
BLOCK_SIZE = scheduler.cache_config.block_size
@ -1024,6 +1025,7 @@ def test_external_prefix_cache_metrics():
scheduler = create_scheduler(
enable_prefix_caching=False,
use_kv_connector=True,
disable_hybrid_kv_cache_manager=True,
)
# Mock connector to simulate a partial external cache hit
@ -1088,6 +1090,7 @@ def test_kv_connector_unable_to_allocate():
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
disable_hybrid_kv_cache_manager=True,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE * 2
scheduler.connector.get_num_new_matched_tokens = Mock(name="method")
@ -1171,6 +1174,7 @@ def test_kv_connector_handles_preemption():
use_kv_connector=True,
block_size=BLOCK_SIZE,
num_blocks=NUM_BLOCKS,
disable_hybrid_kv_cache_manager=True,
)
NUM_MATCHED_NEW_TOKENS = BLOCK_SIZE
@ -1387,6 +1391,7 @@ def create_scheduler_with_priority(
block_size: int = 16,
max_model_len: int | None = None,
num_speculative_tokens: int | None = None,
disable_hybrid_kv_cache_manager: bool = False,
) -> Scheduler:
"""Create scheduler with priority policy enabled.
@ -1411,6 +1416,7 @@ def create_scheduler_with_priority(
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
policy="priority", # Enable priority scheduling
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
)
model_config = ModelConfig(
model=model,
@ -2018,6 +2024,7 @@ def test_priority_scheduling_preemption_and_resumption_when_out_of_kv():
num_blocks=5, # Can hold 64 tokens (first block is null)
block_size=16, # Standard block size
use_kv_connector=True,
disable_hybrid_kv_cache_manager=True,
)
# Create a request and schedule it

View File

@ -46,6 +46,7 @@ def create_scheduler(
num_speculative_tokens: int | None = None,
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
disable_hybrid_kv_cache_manager: bool = False,
) -> Scheduler | AsyncScheduler:
"""Create scheduler under test.
@ -70,6 +71,7 @@ def create_scheduler(
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=True,
async_scheduling=async_scheduling,
disable_hybrid_kv_cache_manager=disable_hybrid_kv_cache_manager,
)
model_config = ModelConfig(
model=model,

View File

@ -136,6 +136,7 @@ run_tests_for_model() {
vllm serve $model_name \
--port $PORT \
--enforce-eager \
--disable-hybrid-kv-cache-manager \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--tensor-parallel-size $PREFILLER_TP_SIZE \
--kv-transfer-config '$KV_CONFIG'"
@ -178,6 +179,7 @@ run_tests_for_model() {
--port $PORT \
--enforce-eager \
--gpu-memory-utilization $GPU_MEMORY_UTILIZATION \
--disable-hybrid-kv-cache-manager \
--kv-transfer-config '$KV_CONFIG'"
# DP-EP attention mode

View File

@ -85,6 +85,7 @@ run_tests_for_model() {
--port $PREFILL_PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--disable-hybrid-kv-cache-manager \
--kv-transfer-config '$KV_CONFIG'"
if [ -n "$model_args" ]; then
@ -103,6 +104,7 @@ run_tests_for_model() {
--port $DECODE_PORT \
--enforce-eager \
--gpu-memory-utilization 0.2 \
--disable-hybrid-kv-cache-manager \
--kv-transfer-config '$KV_CONFIG'"
if [ -n "$model_args" ]; then

View File

@ -114,6 +114,7 @@ def test_multi_shared_storage_connector_consistency():
enforce_eager=True,
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
disable_hybrid_kv_cache_manager=True,
)
# Run generation - this should trigger saving KV cache
_ = llm.generate(PROMPTS, SAMPLING_PARAMS)

View File

@ -932,6 +932,7 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
"gpu_memory_utilization": 0.5,
"kv_transfer_config": kv_transfer_config,
"distributed_executor_backend": distributed_executor_backend,
"disable_hybrid_kv_cache_manager": True,
}
timeout = 6

View File

@ -132,6 +132,7 @@ def test_shared_storage_connector_hashes(tmp_path):
enforce_eager=True,
kv_transfer_config=kv_transfer_config,
limit_mm_per_prompt={"image": 2},
disable_hybrid_kv_cache_manager=True,
)
# don't put this import at the top level

View File

@ -91,6 +91,9 @@ def create_vllm_config(
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
enable_chunked_prefill=enable_chunked_prefill,
# Disable hybrid KV cache manager for testing
# Should be removed after we support hybrid KV cache manager-based testing.
disable_hybrid_kv_cache_manager=True,
)
model_config = ModelConfig(
model=model,

View File

@ -27,6 +27,7 @@ def test_cpu_offloading(cpu_block_size: int) -> None:
model="meta-llama/Llama-3.2-1B-Instruct",
gpu_memory_utilization=0.5,
kv_transfer_config=kv_transfer_config,
disable_hybrid_kv_cache_manager=True,
)
prompts = ["Hi " * 100]

View File

@ -40,11 +40,14 @@ if TYPE_CHECKING:
from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
else:
PretrainedConfig = Any
QuantizationConfig = Any
KVCacheConfig = Any
logger = init_logger(__name__)
@ -568,9 +571,6 @@ class VllmConfig:
if not current_platform.support_hybrid_kv_cache():
# Hybrid KV cache manager is not supported on non-GPU platforms.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_transfer_config is not None:
# Hybrid KV cache manager is not compatible with KV transfer.
self.scheduler_config.disable_hybrid_kv_cache_manager = True
if self.kv_events_config is not None:
# Hybrid KV cache manager is not compatible with KV events.
self.scheduler_config.disable_hybrid_kv_cache_manager = True

View File

@ -6,15 +6,18 @@ from collections.abc import Callable
from typing import TYPE_CHECKING, cast
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole,
supports_hma,
)
from vllm.logger import init_logger
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
logger = init_logger(__name__)
@ -38,7 +41,7 @@ class KVConnectorFactory:
@classmethod
def create_connector(
cls,
config: "VllmConfig",
config: VllmConfig,
role: KVConnectorRole,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
@ -51,6 +54,15 @@ class KVConnectorFactory:
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config)
# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
if hma_enabled and not supports_hma(connector_cls):
raise ValueError(
f"Connector {connector_cls.__name__} does not support HMA but "
f"HMA is enabled. Please set `--disable-hybrid-kv-cache-manager`."
)
logger.info(
"Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__,

View File

@ -3,9 +3,17 @@
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1,
KVConnectorRole,
SupportsHMA,
supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.decode_bench_connector import ( # noqa: E501
DecodeBenchConnector,
)
__all__ = ["KVConnectorRole", "KVConnectorBase_V1", "DecodeBenchConnector"]
__all__ = [
"KVConnectorRole",
"KVConnectorBase_V1",
"supports_hma",
"SupportsHMA",
"DecodeBenchConnector",
]

View File

@ -70,6 +70,45 @@ CopyBlocksOp = Callable[
logger = init_logger(__name__)
class SupportsHMA(ABC):
"""
The class that indicates the corresponding connector supports hybrid memory
allocator (HMA).
This is required to use the connector together with hybrid memory allocator.
"""
@abstractmethod
def request_finished_all_groups(
self,
request: "Request",
block_ids: tuple[list[int], ...],
) -> tuple[bool, dict[str, Any] | None]:
"""
Called exactly once when a request has finished for all kv cache groups,
before its blocks are freed for each group.
NOTE(Kuntai): This function is only supported by connectors that support HMA.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
raise NotImplementedError
def supports_hma(connector: Any) -> bool:
if isinstance(connector, type):
return issubclass(connector, SupportsHMA)
else:
return isinstance(connector, SupportsHMA)
class KVConnectorRole(enum.Enum):
# Connector running in the scheduler process
SCHEDULER = 0
@ -370,7 +409,7 @@ class KVConnectorBase_V1(ABC):
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assumes responsibility for freeing the the blocks
The connector may assume responsibility for freeing the blocks
asynchronously by returning True.
Returns:

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools
import time
from collections import defaultdict
@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
supports_hma,
)
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
@ -86,15 +87,14 @@ class Scheduler(SchedulerInterface):
self.connector = None
self.connector_prefix_cache_stats: PrefixCacheStats | None = None
if self.vllm_config.kv_transfer_config is not None:
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"Multiple KV cache groups are not currently supported "
"with KV connectors"
)
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors"
)
connector_vllm_config = copy.copy(self.vllm_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
self.connector = KVConnectorFactory.create_connector(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()
@ -1324,8 +1324,17 @@ class Scheduler(SchedulerInterface):
if self.connector is None:
return False, None
(block_ids,) = self.kv_cache_manager.get_block_ids(request.request_id)
return self.connector.request_finished(request, block_ids)
block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
if not supports_hma(self.connector):
# NOTE(Kuntai): We should deprecate this code path after we enforce
# all connectors to support HMA.
# Hybrid memory allocator should be already turned off for this
# code path, but let's double-check here.
assert len(self.kv_cache_config.kv_cache_groups) == 1
return self.connector.request_finished(request, block_ids[0])
else:
return self.connector.request_finished(request, block_ids)
def _update_waiting_for_remote_kv(self, request: Request) -> bool:
"""

View File

@ -331,6 +331,15 @@ class Worker(WorkerBase):
def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
"""Allocate GPU KV cache with the specified kv_cache_config."""
# Init kv cache connector here, because it requires
# `kv_cache_config`.
# NOTE(Kuntai): This needs to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
connector_vllm_config = copy.copy(self.vllm_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
ensure_kv_transfer_initialized(connector_vllm_config)
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator
@ -783,5 +792,3 @@ def init_worker_distributed_environment(
parallel_config.pipeline_parallel_size,
parallel_config.decode_context_parallel_size,
)
ensure_kv_transfer_initialized(vllm_config)