[KV Connector] Make KVCacheConfig an explicit constructor argument (#27887)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin 2025-11-04 07:00:49 +00:00 committed by GitHub
parent 2f84ae1f27
commit 58279c60b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 410 additions and 43 deletions

View File

@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for backwards compatibility with external KV connector implementations.
This test ensures that external connectors (loaded via kv_connector_module_path)
implemented with the old signature continue to work:
- Old signature: __init__(self, vllm_config, role)
- New signature: __init__(self, vllm_config, role, kv_cache_config)
"""
from typing import TYPE_CHECKING
from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorBase_V1,
KVConnectorRole,
)
from vllm.v1.core.sched.output import SchedulerOutput
from .utils import create_scheduler, create_vllm_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
class OldStyleTestConnector(KVConnectorBase_V1):
    """
    Stub connector whose constructor takes only ``vllm_config`` and ``role``.

    It stands in for an external connector that still uses the legacy
    two-argument signature. Every connector hook is a no-op.
    """

    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
        # Deliberately forward only two keyword arguments so the base class
        # is initialized without kv_cache_config.
        super().__init__(vllm_config=vllm_config, role=role)

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """Always return ``(0, False)`` regardless of the request."""
        return (0, False)

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ):
        """No-op: this stub keeps no per-request state."""

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        """Return ``None``; this stub builds no connector metadata."""
        return None

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """No-op."""

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op."""

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer,
        attn_metadata: "AttentionMetadata",
        **kwargs,
    ) -> None:
        """No-op."""

    def wait_for_save(self):
        """No-op."""
class NewStyleTestConnector(KVConnectorBase_V1):
    """
    Stub connector whose constructor takes ``vllm_config``, ``role`` and
    ``kv_cache_config`` (the new three-argument signature).

    Every connector hook is a no-op.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        # Forward all three arguments by keyword to the base class.
        super().__init__(
            vllm_config=vllm_config,
            role=role,
            kv_cache_config=kv_cache_config,
        )

    def get_num_new_matched_tokens(
        self, request: "Request", num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """Always return ``(0, False)`` regardless of the request."""
        return (0, False)

    def update_state_after_alloc(
        self,
        request: "Request",
        blocks: "KVCacheBlocks",
        num_external_tokens: int,
    ):
        """No-op: this stub keeps no per-request state."""

    def build_connector_meta(self, scheduler_output: SchedulerOutput):
        """Return ``None``; this stub builds no connector metadata."""
        return None

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
        """No-op."""

    def wait_for_layer_load(self, layer_name: str) -> None:
        """No-op."""

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer,
        attn_metadata: "AttentionMetadata",
        **kwargs,
    ) -> None:
        """No-op."""

    def wait_for_save(self):
        """No-op."""
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_old_signature_factory_instantiation(role):
    """
    The factory must instantiate an external connector that only accepts
    ``(vllm_config, role)`` when it is loaded via kv_connector_module_path.

    The resulting connector is expected to have no kv_cache_config.
    """
    config = create_vllm_config()
    transfer_cfg = config.kv_transfer_config
    transfer_cfg.kv_connector = "OldStyleTestConnector"
    transfer_cfg.kv_connector_module_path = (
        "tests.v1.kv_connector.unit.test_backwards_compatibility"
    )

    # Take the cache config from a scheduler built on the same vllm config.
    cache_config = create_scheduler(config).kv_cache_config

    built = KVConnectorFactory.create_connector(config, role, cache_config)

    assert built is not None
    assert isinstance(built, OldStyleTestConnector)
    assert built.role == role
    assert built._kv_cache_config is None
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_external_new_signature_factory_instantiation(role):
    """
    The factory must instantiate an external connector that accepts
    ``(vllm_config, role, kv_cache_config)`` when it is loaded via
    kv_connector_module_path, and hand it the kv_cache_config.
    """
    config = create_vllm_config()
    transfer_cfg = config.kv_transfer_config
    transfer_cfg.kv_connector = "NewStyleTestConnector"
    transfer_cfg.kv_connector_module_path = (
        "tests.v1.kv_connector.unit.test_backwards_compatibility"
    )

    # Take the cache config from a scheduler built on the same vllm config.
    cache_config = create_scheduler(config).kv_cache_config

    built = KVConnectorFactory.create_connector(config, role, cache_config)

    assert built is not None
    assert isinstance(built, NewStyleTestConnector)
    assert built.role == role
    assert built._kv_cache_config is not None
    assert built._kv_cache_config == cache_config
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
def test_old_signature_super_init(role):
    """
    Constructing an old-style connector directly must work even though its
    ``super().__init__()`` call omits kv_cache_config.
    """
    legacy = OldStyleTestConnector(create_vllm_config(), role)

    assert legacy is not None
    assert legacy.role == role
    assert legacy._kv_cache_config is None
def test_old_signature_super_init_with_kwargs():
    """
    Test that old-style connectors can be constructed with keyword arguments
    in either order.

    Besides checking that no kv_cache_config is recorded, each case also
    asserts the resulting ``role``, so a mis-bound keyword argument would be
    detected rather than silently passing.
    """
    vllm_config = create_vllm_config()

    # Test with vllm_config= and role= kwargs
    connector1 = OldStyleTestConnector(
        vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
    )
    assert connector1 is not None
    assert connector1.role == KVConnectorRole.SCHEDULER
    assert connector1._kv_cache_config is None

    # Test with role= and vllm_config= in reversed order
    connector2 = OldStyleTestConnector(
        role=KVConnectorRole.WORKER, vllm_config=vllm_config
    )
    assert connector2 is not None
    assert connector2.role == KVConnectorRole.WORKER
    assert connector2._kv_cache_config is None
def test_internal_connector_uses_new_signature():
    """
    A connector registered in the factory (here SharedStorageConnector) must
    be created with the new signature and receive the kv_cache_config.
    """
    from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
        SharedStorageConnector,
    )

    config = create_vllm_config()
    config.kv_transfer_config.kv_connector = "SharedStorageConnector"

    # Take the cache config from a scheduler built on the same vllm config.
    cache_config = create_scheduler(config).kv_cache_config

    built = KVConnectorFactory.create_connector(
        config,
        KVConnectorRole.SCHEDULER,
        cache_config,
    )

    assert built is not None
    assert isinstance(built, SharedStorageConnector)
    assert built._kv_cache_config is not None
    assert built._kv_cache_config == cache_config
def test_signature_detection_with_mocking():
    """
    The factory must honour the compat_sig flag that
    _get_connector_class_with_compat returns alongside the connector class.

    The lookup is patched twice: once reporting an old-style class with
    compat_sig=True, once reporting a new-style class with compat_sig=False.
    """
    config = create_vllm_config()
    cache_config = create_scheduler(config).kv_cache_config
    scheduler_role = KVConnectorRole.SCHEDULER

    # Case 1: the lookup reports an old-style connector (compat_sig=True).
    legacy_lookup = patch.object(
        KVConnectorFactory,
        "_get_connector_class_with_compat",
        return_value=(OldStyleTestConnector, True),
    )
    with legacy_lookup:
        legacy = KVConnectorFactory.create_connector(
            config, scheduler_role, cache_config
        )
        assert legacy is not None
        assert isinstance(legacy, OldStyleTestConnector)
        assert legacy._kv_cache_config is None

    # Case 2: the lookup reports a new-style connector (compat_sig=False).
    modern_lookup = patch.object(
        KVConnectorFactory,
        "_get_connector_class_with_compat",
        return_value=(NewStyleTestConnector, False),
    )
    with modern_lookup:
        modern = KVConnectorFactory.create_connector(
            config, scheduler_role, cache_config
        )
        assert modern is not None
        assert isinstance(modern, NewStyleTestConnector)
        assert modern._kv_cache_config is not None
        assert modern._kv_cache_config == cache_config

View File

@ -254,7 +254,7 @@ def create_model_runner_output(
class TestSharedStorageConnector(SharedStorageConnector): class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role): def __init__(self, config: VllmConfig, role, kv_cache_config):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role) self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int) self.call_record: dict[str, int] = defaultdict(int)

View File

@ -3,10 +3,9 @@
import importlib import importlib
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING, cast from typing import TYPE_CHECKING, Optional, cast
import vllm.envs as envs import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import ( from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase, KVConnectorBase,
KVConnectorBaseType, KVConnectorBaseType,
@ -16,9 +15,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
supports_hma, supports_hma,
) )
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.func_utils import supports_kw
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig from vllm.config.kv_transfer import KVTransferConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -41,8 +43,9 @@ class KVConnectorFactory:
@classmethod @classmethod
def create_connector( def create_connector(
cls, cls,
config: VllmConfig, config: "VllmConfig",
role: KVConnectorRole, role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
) -> KVConnectorBase: ) -> KVConnectorBase:
if not envs.VLLM_USE_V1: if not envs.VLLM_USE_V1:
raise ValueError( raise ValueError(
@ -53,7 +56,9 @@ class KVConnectorFactory:
kv_transfer_config = config.kv_transfer_config kv_transfer_config = config.kv_transfer_config
if kv_transfer_config is None: if kv_transfer_config is None:
raise ValueError("kv_transfer_config must be set to create a connector") raise ValueError("kv_transfer_config must be set to create a connector")
connector_cls = cls.get_connector_class(kv_transfer_config) connector_cls, compat_sig = cls._get_connector_class_with_compat(
kv_transfer_config
)
# check if the connector supports HMA # check if the connector supports HMA
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
@ -76,7 +81,12 @@ class KVConnectorFactory:
# - Co-locate with worker process # - Co-locate with worker process
# - Should only be used inside the forward context & attention layer # - Should only be used inside the forward context & attention layer
# We build separately to enforce strict separation # We build separately to enforce strict separation
if compat_sig:
# Old signature: __init__(self, vllm_config, role)
return connector_cls(config, role) return connector_cls(config, role)
else:
# New signature: __init__(self, vllm_config, role, kv_cache_config)
return connector_cls(config, role, kv_cache_config)
@classmethod @classmethod
def get_connector_class_by_name( def get_connector_class_by_name(
@ -97,13 +107,13 @@ class KVConnectorFactory:
return cls._registry[connector_name]() return cls._registry[connector_name]()
@classmethod @classmethod
def get_connector_class( def _get_connector_class_with_compat(
cls, kv_transfer_config: "KVTransferConfig" cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]: ) -> tuple[type[KVConnectorBaseType], bool]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector connector_name = kv_transfer_config.kv_connector
if connector_name is None: if connector_name is None:
raise ValueError("Connector name is not set in KVTransferConfig") raise ValueError("Connector name is not set in KVTransferConfig")
compat_sig = False
if connector_name in cls._registry: if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]() connector_cls = cls._registry[connector_name]()
else: else:
@ -118,6 +128,21 @@ class KVConnectorFactory:
f"Class {connector_name} not found in {connector_module_path}" f"Class {connector_name} not found in {connector_module_path}"
) from e ) from e
connector_cls = cast(type[KVConnectorBaseType], connector_cls) connector_cls = cast(type[KVConnectorBaseType], connector_cls)
if not supports_kw(connector_cls, "kv_cache_config"):
compat_sig = True
logger.warning(
"Connector %s uses deprecated signature with 2 required arguments. "
"Please update to include kv_cache_config as the third argument.",
connector_cls.__name__,
)
return connector_cls, compat_sig
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
return connector_cls return connector_cls

View File

@ -58,6 +58,7 @@ if TYPE_CHECKING:
) )
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction # s_tensor_list, d_tensor_list, s_indices, d_indices, direction
@ -141,7 +142,12 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC): class KVConnectorBase_V1(ABC):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(
self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
logger.warning( logger.warning(
"Initializing KVConnectorBase_V1. This API is experimental and " "Initializing KVConnectorBase_V1. This API is experimental and "
"subject to change in the future as we iterate the design." "subject to change in the future as we iterate the design."
@ -152,6 +158,14 @@ class KVConnectorBase_V1(ABC):
self._kv_transfer_config = vllm_config.kv_transfer_config self._kv_transfer_config = vllm_config.kv_transfer_config
else: else:
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1") raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
self._kv_cache_config = kv_cache_config
if self._kv_cache_config is None:
logger.warning(
"KVConnectorBase_V1 initialized without kv_cache_config. "
"This is deprecated - please update your connector to accept "
"kv_cache_config as the third constructor argument and pass it "
"to super().__init__()."
)
self._role = role self._role = role
@property @property

View File

@ -32,7 +32,7 @@ Usage:
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Optional
import torch import torch
@ -50,6 +50,7 @@ if TYPE_CHECKING:
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1):
testing of the decoder with larger input sequence lengths. testing of the decoder with larger input sequence lengths.
""" """
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(
super().__init__(vllm_config, role) self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
self.connector_worker: DecodeBenchConnectorWorker | None = None self.connector_worker: DecodeBenchConnectorWorker | None = None

View File

@ -20,14 +20,22 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
class LMCacheConnectorV1(KVConnectorBase_V1): class LMCacheConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(
super().__init__(vllm_config=vllm_config, role=role) self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
use_native = vllm_config.kv_transfer_config.get_from_extra_config( use_native = vllm_config.kv_transfer_config.get_from_extra_config(
"use_native", False "use_native", False

View File

@ -31,6 +31,7 @@ if TYPE_CHECKING:
from vllm.distributed.kv_events import KVCacheEvent from vllm.distributed.kv_events import KVCacheEvent
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1):
- Save to all connectors. - Save to all connectors.
""" """
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(
super().__init__(vllm_config=vllm_config, role=role) self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: "KVCacheConfig",
):
super().__init__(
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
)
self._connectors: list[KVConnectorBase_V1] = [] self._connectors: list[KVConnectorBase_V1] = []
self._ktc_kv_transfer_config = [] self._ktc_kv_transfer_config = []
for connector_cls, temp_config in self._get_connector_classes_and_configs( for connector_cls, temp_config in self._get_connector_classes_and_configs(
vllm_config vllm_config
): ):
self._connectors.append(connector_cls(temp_config, role)) self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
# A mapping from request id to the index of the connector chosen to # A mapping from request id to the index of the connector chosen to

View File

@ -13,7 +13,7 @@ from collections import defaultdict
from collections.abc import Iterator from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Optional
import msgspec import msgspec
import numpy as np import numpy as np
@ -52,6 +52,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
Transfer = tuple[int, float] # (xfer_handle, start_time) Transfer = tuple[int, float] # (xfer_handle, start_time)
@ -150,7 +151,14 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1): class NixlConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): def __init__(
self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id

View File

@ -21,6 +21,7 @@ from vllm.logger import init_logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_offload.abstract import OffloadingManager from vllm.v1.kv_offload.abstract import OffloadingManager
from vllm.v1.kv_offload.factory import OffloadingSpecFactory from vllm.v1.kv_offload.factory import OffloadingSpecFactory
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class OffloadingConnector(KVConnectorBase_V1): class OffloadingConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): def __init__(
super().__init__(vllm_config, role) self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: KVCacheConfig | None = None,
):
super().__init__(vllm_config, role, kv_cache_config)
spec = OffloadingSpecFactory.create_spec(vllm_config) spec = OffloadingSpecFactory.create_spec(vllm_config)

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Optional
import regex as re import regex as re
import torch import torch
@ -25,6 +25,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@ -71,8 +72,17 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
class P2pNcclConnector(KVConnectorBase_V1): class P2pNcclConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(
super().__init__(vllm_config=vllm_config, role=role) self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Any] = {} self._requests_need_load: dict[str, Any] = {}
self.is_producer = self._kv_transfer_config.is_kv_producer self.is_producer = self._kv_transfer_config.is_kv_producer

View File

@ -3,7 +3,7 @@
import hashlib import hashlib
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any, Optional
import safetensors import safetensors
import torch import torch
@ -22,6 +22,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request from vllm.v1.request import Request
logger = init_logger(__name__) logger = init_logger(__name__)
@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1):
# It does extra work which will overwrite the existing prefix-cache in GPU # It does extra work which will overwrite the existing prefix-cache in GPU
# - to remove the overhead, need to add some "mask" in the ReqMeta class # - to remove the overhead, need to add some "mask" in the ReqMeta class
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): def __init__(
super().__init__(vllm_config=vllm_config, role=role) self,
vllm_config: "VllmConfig",
role: KVConnectorRole,
kv_cache_config: Optional["KVCacheConfig"] = None,
):
super().__init__(
vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config,
)
self._block_size = vllm_config.cache_config.block_size self._block_size = vllm_config.cache_config.block_size
self._requests_need_load: dict[str, Request] = {} self._requests_need_load: dict[str, Request] = {}
self._storage_path = self._kv_transfer_config.get_from_extra_config( self._storage_path = self._kv_transfer_config.get_from_extra_config(

View File

@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from vllm import envs from vllm import envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None _KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
@ -48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
return isinstance(connector, KVConnectorBase_V1) return isinstance(connector, KVConnectorBase_V1)
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: def ensure_kv_transfer_initialized(
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
) -> None:
""" """
Initialize KV cache transfer parallel group. Initialize KV cache transfer parallel group.
""" """
@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
): ):
if envs.VLLM_USE_V1: if envs.VLLM_USE_V1:
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
config=vllm_config, role=KVConnectorRole.WORKER config=vllm_config,
role=KVConnectorRole.WORKER,
kv_cache_config=kv_cache_config,
) )
else: else:
raise ValueError("V0 is no longer supported") raise ValueError("V0 is no longer supported")

View File

@ -1,6 +1,5 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import itertools import itertools
import time import time
from collections import defaultdict from collections import defaultdict
@ -92,15 +91,10 @@ class Scheduler(SchedulerInterface):
assert not self.is_encoder_decoder, ( assert not self.is_encoder_decoder, (
"Encoder-decoder models are not currently supported with KV connectors" "Encoder-decoder models are not currently supported with KV connectors"
) )
connector_vllm_config = copy.copy(self.vllm_config)
# We're dynamically inserting a kv_cache_config variable into the
# connector_vllm_config. This is distinct from the cache_config
# that is already in there.
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined]
self.connector = KVConnectorFactory.create_connector( self.connector = KVConnectorFactory.create_connector(
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER config=self.vllm_config,
role=KVConnectorRole.SCHEDULER,
kv_cache_config=self.kv_cache_config,
) )
if self.log_stats: if self.log_stats:
self.connector_prefix_cache_stats = PrefixCacheStats() self.connector_prefix_cache_stats = PrefixCacheStats()

View File

@ -380,9 +380,7 @@ class Worker(WorkerBase):
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`, # NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
# because `initialize_kv_cache` will inject kv cache groups not # because `initialize_kv_cache` will inject kv cache groups not
# related to kv cache connector (e.g. kv cache sharing layers). # related to kv cache connector (e.g. kv cache sharing layers).
connector_vllm_config = copy.copy(self.vllm_config) ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
ensure_kv_transfer_initialized(connector_vllm_config)
if self.vllm_config.model_config.enable_sleep_mode: if self.vllm_config.model_config.enable_sleep_mode:
from vllm.device_allocator.cumem import CuMemAllocator from vllm.device_allocator.cumem import CuMemAllocator