[KV Connector] Make KVCacheConfig an explicit constructor argument (#27887)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-11-04 07:00:49 +00:00 committed by GitHub
parent 2f84ae1f27
commit 58279c60b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 410 additions and 43 deletions

View File

@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.
This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
class OldStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the old signature with 2 required arguments.
This simulates external connectors that haven't been updated yet.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
# Old-style call to super().__init__ with only 2 arguments
super().__init__(vllm_config=vllm_config, role=role)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass
def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
pass
def wait_for_save(self):
pass
class NewStyleTestConnector(KVConnectorBase_V1):
"""
Test connector using the new signature with 3 required arguments.
"""
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
# New-style call to super().__init__ with all 3 arguments
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int | None, bool]:
return 0, False
def update_state_after_alloc(
self,
request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int,
):
pass
def build_connector_meta(self, scheduler_output: SchedulerOutput):
return None
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
pass
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def save_kv_layer(
self,
layer_name: str,
kv_layer,
attn_metadata: "AttentionMetadata",
**kwargs,
) -> None:
pass
def wait_for_save(self):
pass
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
"""
Test that external connectors with old signature (2 required args) loaded
via kv_connector_module_path are correctly instantiated with backwards
compatibility support.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
assert connector is not None
assert isinstance(connector, OldStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is None
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
"""
Test that external connectors with new signature (3 required args) loaded
via kv_connector_module_path are correctly instantiated.
"""
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
vllm_config.kv_transfer_config.kv_connector_module_path = (
"tests.v1.kv_connector.unit.test_backwards_compatibility"
)
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
assert connector is not None
assert isinstance(connector, NewStyleTestConnector)
assert connector.role == role
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
"""
Test that old-style connectors can call super().__init__() without
kv_cache_config parameter.
"""
vllm_config = create_vllm_config()
connector = OldStyleTestConnector(vllm_config, role)
assert connector is not None
assert connector.role == role
assert connector._kv_cache_config is None
def test_old_signature_super_init_with_kwargs():
"""
Test that old-style connectors can call super().__init__() with keyword
arguments in different orders.
"""
vllm_config = create_vllm_config()
# Test with vllm_config= and role= kwargs
connector1 = OldStyleTestConnector(
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
)
assert connector1 is not None
assert connector1._kv_cache_config is None
# Test with role= and vllm_config= in reversed order
connector2 = OldStyleTestConnector(
role=KVConnectorRole.WORKER, vllm_config=vllm_config
)
assert connector2 is not None
assert connector2._kv_cache_config is None
def test_internal_connector_uses_new_signature():
"""
Test that internal connectors (registered in factory) always use the new
signature and get kv_cache_config.
"""
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
SharedStorageConnector,
)
vllm_config = create_vllm_config()
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector"
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert connector is not None
assert isinstance(connector, SharedStorageConnector)
assert connector._kv_cache_config is not None
assert connector._kv_cache_config == kv_cache_config
def test_signature_detection_with_mocking():
"""
Test that the factory correctly applies compat_sig flag returned from
_get_connector_class_with_compat.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
kv_cache_config = scheduler.kv_cache_config
# Mock _get_connector_class_with_compat to return old-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(OldStyleTestConnector, True),
):
old_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert old_connector is not None
assert isinstance(old_connector, OldStyleTestConnector)
assert old_connector._kv_cache_config is None
# Mock _get_connector_class_with_compat to return new-style connector
with patch.object(
KVConnectorFactory,
"_get_connector_class_with_compat",
return_value=(NewStyleTestConnector, False),
):
new_connector = KVConnectorFactory.create_connector(
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
)
assert new_connector is not None
assert isinstance(new_connector, NewStyleTestConnector)
assert new_connector._kv_cache_config is not None
assert new_connector._kv_cache_config == kv_cache_config

View File

@ -254,7 +254,7 @@ def create_model_runner_output(
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)

View File

@ -3,10 +3,9 @@
import importlib
from collections.abc import Callable
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Optional, cast
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase,
KVConnectorBaseType,
@ -16,9 +15,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
supports_hma,
)
from vllm.logger import init_logger
from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__)
@ -41,8 +43,9 @@ class KVConnectorFactory:
@classmethod
def create_connector(
cls,
config: VllmConfig,
config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
) -> KVConnectorBase:
if not envs.VLLM_USE_V1:
raise ValueError(
@ -53,7 +56,9 @@ class KVConnectorFactory:
kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config)
connector_cls, compat_sig = cls._get_connector_class_with_compat(
kv_transfer_config
)
# check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
@ -76,7 +81,12 @@ class KVConnectorFactory:
# - Co-locate with worker process
# - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation
return connector_cls(config, role)
if compat_sig:
# Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role)
else:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)
@classmethod
def get_connector_class_by_name(
@ -97,13 +107,13 @@ class KVConnectorFactory:
return cls._registry[connector_name]()
@classmethod
def get_connector_class(
def _get_connector_class_with_compat(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
) -> tuple[type[KVConnectorBaseType], bool]:
connector_name = kv_transfer_config.kv_connector
if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig")
compat_sig = False
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
@ -118,6 +128,21 @@ class KVConnectorFactory:
f"Class {connector_name} not found in {connector_module_path}"
) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
if not supports_kw(connector_cls, "kv_cache_config"):
compat_sig = True
logger.warning(
"Connector %s uses deprecated signature with 2 required arguments. "
"Please update to include kv_cache_config as the second argument.",
connector_cls.__name__,
)
return connector_cls, compat_sig
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
return connector_cls

View File

@ -58,6 +58,7 @@ if TYPE_CHECKING:
)
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
@ -141,7 +142,12 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design."
@ -152,6 +158,14 @@ class KVConnectorBase_V1(ABC):
self._kv_transfer_config = vllm_config.kv_transfer_config
else:
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
self._kv_cache_config = kv_cache_config
if self._kv_cache_config is None:
logger.warning(
"KVConnectorBase_V1 initialized without kv_cache_config. "
"This is deprecated - please update your connector to accept "
"kv_cache_config as the third constructor argument and pass it "
"to super().__init__()."
)
self._role = role
@property

View File

@ -32,7 +32,7 @@ Usage:
"""
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
import torch
@ -50,6 +50,7 @@ if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1):
testing of the decoder with larger input sequence lengths.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config, role)
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
self.connector_worker: DecodeBenchConnectorWorker | None = None

View File

@ -20,14 +20,22 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
assert vllm_config.kv_transfer_config is not None
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
"use_native", False

View File

@ -31,6 +31,7 @@ if TYPE_CHECKING:
from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors.
"""
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = []
for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config
):
self._connectors.append(connector_cls(temp_config, role))
self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to

View File

@ -13,7 +13,7 @@ from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
import msgspec
import numpy as np
@ -52,6 +52,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time)
@ -150,7 +151,14 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

View File

@ -21,6 +21,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class OffloadingConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config, role)
def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
):
super().__init__(vllm_config, role, kv_cache_config)
spec = OffloadingSpecFactory.create_spec(vllm_config)

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
@ -25,6 +25,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -71,8 +72,17 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {}
self.is_producer = self._kv_transfer_config.is_kv_producer

View File

@ -3,7 +3,7 @@
import hashlib
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional
import safetensors
import torch
@ -22,6 +22,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
logger = init_logger(__name__)
@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1):
# It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {}
self._storage_path = self._kv_transfer_config.get_from_extra_config(

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
@ -48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
return isinstance(connector, KVConnectorBase_V1)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
def ensure_kv_transfer_initialized(
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
) -> None:
"""
Initialize KV cache transfer parallel group.
"""
@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
):
if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
config=vllm_config, role=KVConnectorRole.WORKER
config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
)
else:
raise ValueError("V0 is no longer supported")

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools
import time
from collections import defaultdict
@ -92,15 +91,10 @@ class Scheduler(SchedulerInterface):
assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors"
)
connector_vllm_config = copy.copy(self.vllm_config)
# We're dynamically inserting a kv_cache_config variable into the
# connector_vllm_config. This is distinct from the cache_config
# that is already in there.
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined]
self.connector = KVConnectorFactory.create_connector(
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
config=self.vllm_config,
role=KVConnectorRole.SCHEDULER,
kv_cache_config=self.kv_cache_config,
)
if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats()

View File

@ -380,9 +380,7 @@ class Worker(WorkerBase):
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers).
connector_vllm_config = copy.copy(self.vllm_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
ensure_kv_transfer_initialized(connector_vllm_config)
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator