diff --git a/tests/v1/kv_connector/unit/test_backwards_compatibility.py b/tests/v1/kv_connector/unit/test_backwards_compatibility.py new file mode 100644 index 000000000000..f51001a6ec12 --- /dev/null +++ b/tests/v1/kv_connector/unit/test_backwards_compatibility.py @@ -0,0 +1,275 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for backwards compatibility with external KV connector implementations. + +This test ensures that external connectors (loaded via kv_connector_module_path) +implemented with the old signature continue to work: +- Old signature: __init__(self, vllm_config, role) +- New signature: __init__(self, vllm_config, role, kv_cache_config) +""" + +from typing import TYPE_CHECKING +from unittest.mock import patch + +import pytest + +from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory +from vllm.distributed.kv_transfer.kv_connector.v1 import ( + KVConnectorBase_V1, + KVConnectorRole, +) +from vllm.v1.core.sched.output import SchedulerOutput + +from .utils import create_scheduler, create_vllm_config + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionMetadata + from vllm.config import VllmConfig + from vllm.forward_context import ForwardContext + from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig + from vllm.v1.request import Request + + +class OldStyleTestConnector(KVConnectorBase_V1): + """ + Test connector using the old signature with 2 required arguments. + This simulates external connectors that haven't been updated yet. 
+ """ + + def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + # Old-style call to super().__init__ with only 2 arguments + super().__init__(vllm_config=vllm_config, role=role) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ): + pass + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + return None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + pass + + +class NewStyleTestConnector(KVConnectorBase_V1): + """ + Test connector using the new signature with 3 required arguments. 
+ """ + + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + # New-style call to super().__init__ with all 3 arguments + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) + + def get_num_new_matched_tokens( + self, request: "Request", num_computed_tokens: int + ) -> tuple[int | None, bool]: + return 0, False + + def update_state_after_alloc( + self, + request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int, + ): + pass + + def build_connector_meta(self, scheduler_output: SchedulerOutput): + return None + + def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: + pass + + def wait_for_layer_load(self, layer_name: str) -> None: + pass + + def save_kv_layer( + self, + layer_name: str, + kv_layer, + attn_metadata: "AttentionMetadata", + **kwargs, + ) -> None: + pass + + def wait_for_save(self): + pass + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_external_old_signature_factory_instantiation(role): + """ + Test that external connectors with old signature (2 required args) loaded + via kv_connector_module_path are correctly instantiated with backwards + compatibility support. 
+ """ + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector" + vllm_config.kv_transfer_config.kv_connector_module_path = ( + "tests.v1.kv_connector.unit.test_backwards_compatibility" + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert connector is not None + assert isinstance(connector, OldStyleTestConnector) + assert connector.role == role + assert connector._kv_cache_config is None + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_external_new_signature_factory_instantiation(role): + """ + Test that external connectors with new signature (3 required args) loaded + via kv_connector_module_path are correctly instantiated. + """ + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector" + vllm_config.kv_transfer_config.kv_connector_module_path = ( + "tests.v1.kv_connector.unit.test_backwards_compatibility" + ) + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config) + + assert connector is not None + assert isinstance(connector, NewStyleTestConnector) + assert connector.role == role + assert connector._kv_cache_config is not None + assert connector._kv_cache_config == kv_cache_config + + +@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER]) +def test_old_signature_super_init(role): + """ + Test that old-style connectors can call super().__init__() without + kv_cache_config parameter. 
+ """ + vllm_config = create_vllm_config() + + connector = OldStyleTestConnector(vllm_config, role) + + assert connector is not None + assert connector.role == role + assert connector._kv_cache_config is None + + +def test_old_signature_super_init_with_kwargs(): + """ + Test that old-style connectors can call super().__init__() with keyword + arguments in different orders. + """ + vllm_config = create_vllm_config() + + # Test with vllm_config= and role= kwargs + connector1 = OldStyleTestConnector( + vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER + ) + assert connector1 is not None + assert connector1._kv_cache_config is None + + # Test with role= and vllm_config= in reversed order + connector2 = OldStyleTestConnector( + role=KVConnectorRole.WORKER, vllm_config=vllm_config + ) + assert connector2 is not None + assert connector2._kv_cache_config is None + + +def test_internal_connector_uses_new_signature(): + """ + Test that internal connectors (registered in factory) always use the new + signature and get kv_cache_config. + """ + from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( + SharedStorageConnector, + ) + + vllm_config = create_vllm_config() + vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector" + + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + + assert connector is not None + assert isinstance(connector, SharedStorageConnector) + assert connector._kv_cache_config is not None + assert connector._kv_cache_config == kv_cache_config + + +def test_signature_detection_with_mocking(): + """ + Test that the factory correctly applies compat_sig flag returned from + _get_connector_class_with_compat. 
+ """ + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config) + kv_cache_config = scheduler.kv_cache_config + + # Mock _get_connector_class_with_compat to return old-style connector + with patch.object( + KVConnectorFactory, + "_get_connector_class_with_compat", + return_value=(OldStyleTestConnector, True), + ): + old_connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + assert old_connector is not None + assert isinstance(old_connector, OldStyleTestConnector) + assert old_connector._kv_cache_config is None + + # Mock _get_connector_class_with_compat to return new-style connector + with patch.object( + KVConnectorFactory, + "_get_connector_class_with_compat", + return_value=(NewStyleTestConnector, False), + ): + new_connector = KVConnectorFactory.create_connector( + vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config + ) + assert new_connector is not None + assert isinstance(new_connector, NewStyleTestConnector) + assert new_connector._kv_cache_config is not None + assert new_connector._kv_cache_config == kv_cache_config diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 46ea46e53084..c1c0e13f7753 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -254,7 +254,7 @@ def create_model_runner_output( class TestSharedStorageConnector(SharedStorageConnector): - def __init__(self, config: VllmConfig, role): + def __init__(self, config: VllmConfig, role, kv_cache_config): self.name = config.kv_transfer_config.kv_connector_extra_config["name"] self._connector = SharedStorageConnector(config, role) self.call_record: dict[str, int] = defaultdict(int) diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index c64996f13cd5..8d14200c5240 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ 
b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -3,10 +3,9 @@ import importlib from collections.abc import Callable -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, Optional, cast import vllm.envs as envs -from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.base import ( KVConnectorBase, KVConnectorBaseType, @@ -16,9 +15,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( supports_hma, ) from vllm.logger import init_logger +from vllm.utils.func_utils import supports_kw if TYPE_CHECKING: + from vllm.config import VllmConfig from vllm.config.kv_transfer import KVTransferConfig + from vllm.v1.kv_cache_interface import KVCacheConfig logger = init_logger(__name__) @@ -41,8 +43,9 @@ class KVConnectorFactory: @classmethod def create_connector( cls, - config: VllmConfig, + config: "VllmConfig", role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, ) -> KVConnectorBase: if not envs.VLLM_USE_V1: raise ValueError( @@ -53,7 +56,9 @@ class KVConnectorFactory: kv_transfer_config = config.kv_transfer_config if kv_transfer_config is None: raise ValueError("kv_transfer_config must be set to create a connector") - connector_cls = cls.get_connector_class(kv_transfer_config) + connector_cls, compat_sig = cls._get_connector_class_with_compat( + kv_transfer_config + ) # check if the connector supports HMA hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager @@ -76,7 +81,12 @@ class KVConnectorFactory: # - Co-locate with worker process # - Should only be used inside the forward context & attention layer # We build separately to enforce strict separation - return connector_cls(config, role) + if compat_sig: + # Old signature: __init__(self, vllm_config, role) + return connector_cls(config, role) + else: + # New signature: __init__(self, vllm_config, role, kv_cache_config) + return connector_cls(config, role, kv_cache_config) @classmethod def 
get_connector_class_by_name( @@ -97,13 +107,13 @@ class KVConnectorFactory: return cls._registry[connector_name]() @classmethod - def get_connector_class( + def _get_connector_class_with_compat( cls, kv_transfer_config: "KVTransferConfig" - ) -> type[KVConnectorBaseType]: - """Get the connector class by name.""" + ) -> tuple[type[KVConnectorBaseType], bool]: connector_name = kv_transfer_config.kv_connector if connector_name is None: raise ValueError("Connector name is not set in KVTransferConfig") + compat_sig = False if connector_name in cls._registry: connector_cls = cls._registry[connector_name]() else: @@ -118,6 +128,21 @@ class KVConnectorFactory: f"Class {connector_name} not found in {connector_module_path}" ) from e connector_cls = cast(type[KVConnectorBaseType], connector_cls) + if not supports_kw(connector_cls, "kv_cache_config"): + compat_sig = True + logger.warning( + "Connector %s uses deprecated signature with 2 required arguments. " + "Please update to include kv_cache_config as the third argument.", + connector_cls.__name__, + ) + return connector_cls, compat_sig + + @classmethod + def get_connector_class( + cls, kv_transfer_config: "KVTransferConfig" + ) -> type[KVConnectorBaseType]: + """Get the connector class by name.""" + connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config) return connector_cls diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index cb9f208a839f..354aa9a87183 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -58,6 +58,7 @@ if TYPE_CHECKING: ) from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request # s_tensor_list, d_tensor_list, s_indices, d_indices, direction @@ -141,7 +142,12 @@ class KVConnectorMetadata(ABC): # noqa: 
B024 class KVConnectorBase_V1(ABC): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): logger.warning( "Initializing KVConnectorBase_V1. This API is experimental and " "subject to change in the future as we iterate the design." @@ -152,6 +158,14 @@ class KVConnectorBase_V1(ABC): self._kv_transfer_config = vllm_config.kv_transfer_config else: raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1") + self._kv_cache_config = kv_cache_config + if self._kv_cache_config is None: + logger.warning( + "KVConnectorBase_V1 initialized without kv_cache_config. " + "This is deprecated - please update your connector to accept " + "kv_cache_config as the third constructor argument and pass it " + "to super().__init__()." + ) self._role = role @property diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py index ca251cd0c6eb..9cd7d93c92fa 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py @@ -32,7 +32,7 @@ Usage: """ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import torch @@ -50,6 +50,7 @@ if TYPE_CHECKING: from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1): testing of the decoder with larger input sequence lengths. 
""" - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) self.connector_scheduler: DecodeBenchConnectorScheduler | None = None self.connector_worker: DecodeBenchConnectorWorker | None = None diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 7232d947030c..575ab468be56 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -20,14 +20,22 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) class LMCacheConnectorV1(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) assert vllm_config.kv_transfer_config is not None use_native = vllm_config.kv_transfer_config.get_from_extra_config( "use_native", False diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index d56f30bd11e5..d7bbf02c8367 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -31,6 +31,7 @@ if TYPE_CHECKING: from 
vllm.distributed.kv_events import KVCacheEvent from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1): - Save to all connectors. """ - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: "KVCacheConfig", + ): + super().__init__( + vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config + ) self._connectors: list[KVConnectorBase_V1] = [] self._ktc_kv_transfer_config = [] for connector_cls, temp_config in self._get_connector_classes_and_configs( vllm_config ): - self._connectors.append(connector_cls(temp_config, role)) + self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) # A mapping from request id to the index of the connector chosen to diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 4651cedbc7df..ff9770b72bd3 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,7 +13,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import msgspec import numpy as np @@ -52,6 +52,7 @@ from vllm.v1.core.sched.output import SchedulerOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from 
vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request Transfer = tuple[int, float] # (xfer_handle, start_time) @@ -150,7 +151,14 @@ class NixlConnectorMetadata(KVConnectorMetadata): class NixlConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__(vllm_config, role, kv_cache_config) + assert vllm_config.kv_transfer_config is not None assert vllm_config.kv_transfer_config.engine_id is not None self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 7567c7fae578..582e42cc466a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -21,6 +21,7 @@ from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.kv_cache_utils import BlockHash from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_offload.abstract import OffloadingManager from vllm.v1.kv_offload.factory import OffloadingSpecFactory from vllm.v1.kv_offload.mediums import GPULoadStoreSpec @@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): - super().__init__(vllm_config, role) + def __init__( + self, + vllm_config: VllmConfig, + role: KVConnectorRole, + kv_cache_config: KVCacheConfig | None = None, + ): + super().__init__(vllm_config, role, kv_cache_config) spec = OffloadingSpecFactory.create_spec(vllm_config) diff --git 
a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py index 780dd12fccda..a124a0d519db 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import regex as re import torch @@ -25,6 +25,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -71,8 +72,17 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata): class P2pNcclConnector(KVConnectorBase_V1): - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Any] = {} self.is_producer = self._kv_transfer_config.is_kv_producer diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 9c230d7d0d2f..016d1d45b359 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -3,7 +3,7 @@ import hashlib import os from dataclasses import dataclass, field -from typing 
import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Optional import safetensors import torch @@ -22,6 +22,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata from vllm.forward_context import ForwardContext from vllm.v1.core.kv_cache_manager import KVCacheBlocks + from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.request import Request logger = init_logger(__name__) @@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1): # It does extra work which will overwrite the existing prefix-cache in GPU # - to remove the overhead, need to add some "mask" in the ReqMeta class - def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): - super().__init__(vllm_config=vllm_config, role=role) + def __init__( + self, + vllm_config: "VllmConfig", + role: KVConnectorRole, + kv_cache_config: Optional["KVCacheConfig"] = None, + ): + super().__init__( + vllm_config=vllm_config, + role=role, + kv_cache_config=kv_cache_config, + ) self._block_size = vllm_config.cache_config.block_size self._requests_need_load: dict[str, Request] = {} self._storage_path = self._kv_transfer_config.get_from_extra_config( diff --git a/vllm/distributed/kv_transfer/kv_transfer_state.py b/vllm/distributed/kv_transfer/kv_transfer_state.py index cabfc10e7f94..7501f0b373d4 100644 --- a/vllm/distributed/kv_transfer/kv_transfer_state.py +++ b/vllm/distributed/kv_transfer/kv_transfer_state.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from vllm import envs from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType @@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import ( if TYPE_CHECKING: from vllm.config import VllmConfig + from vllm.v1.kv_cache_interface import KVCacheConfig _KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None @@ 
-48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo return isinstance(connector, KVConnectorBase_V1) -def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: +def ensure_kv_transfer_initialized( + vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None +) -> None: """ Initialize KV cache transfer parallel group. """ @@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None: ): if envs.VLLM_USE_V1: _KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector( - config=vllm_config, role=KVConnectorRole.WORKER + config=vllm_config, + role=KVConnectorRole.WORKER, + kv_cache_config=kv_cache_config, ) else: raise ValueError("V0 is no longer supported") diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f51744eb2640..aeb9869c5281 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import copy import itertools import time from collections import defaultdict @@ -92,15 +91,10 @@ class Scheduler(SchedulerInterface): assert not self.is_encoder_decoder, ( "Encoder-decoder models are not currently supported with KV connectors" ) - - connector_vllm_config = copy.copy(self.vllm_config) - - # We're dynamically inserting a kv_cache_config variable into the - # connector_vllm_config. This is distinct from the cache_config - # that is already in there. 
- connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined] self.connector = KVConnectorFactory.create_connector( - config=connector_vllm_config, role=KVConnectorRole.SCHEDULER + config=self.vllm_config, + role=KVConnectorRole.SCHEDULER, + kv_cache_config=self.kv_cache_config, ) if self.log_stats: self.connector_prefix_cache_stats = PrefixCacheStats() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c2bf1419bebd..f3fe202cec06 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -380,9 +380,7 @@ class Worker(WorkerBase): # NOTE(Kuntai): This need to be done before `initialize_kv_cache`, # because `initialize_kv_cache` will inject kv cache groups not # related to kv cache connector (e.g. kv cache sharing layers). - connector_vllm_config = copy.copy(self.vllm_config) - connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) - ensure_kv_transfer_initialized(connector_vllm_config) + ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config) if self.vllm_config.model_config.enable_sleep_mode: from vllm.device_allocator.cumem import CuMemAllocator