mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-16 04:55:02 +08:00
[KV Connector] Make KVCacheConfig an explicit constructor argument (#27887)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
parent
2f84ae1f27
commit
58279c60b5
275
tests/v1/kv_connector/unit/test_backwards_compatibility.py
Normal file
275
tests/v1/kv_connector/unit/test_backwards_compatibility.py
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Unit tests for backwards compatibility with external KV connector implementations.
|
||||||
|
|
||||||
|
This test ensures that external connectors (loaded via kv_connector_module_path)
|
||||||
|
implemented with the old signature continue to work:
|
||||||
|
- Old signature: __init__(self, vllm_config, role)
|
||||||
|
- New signature: __init__(self, vllm_config, role, kv_cache_config)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||||
|
KVConnectorBase_V1,
|
||||||
|
KVConnectorRole,
|
||||||
|
)
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
from .utils import create_scheduler, create_vllm_config
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
|
class OldStyleTestConnector(KVConnectorBase_V1):
|
||||||
|
"""
|
||||||
|
Test connector using the old signature with 2 required arguments.
|
||||||
|
This simulates external connectors that haven't been updated yet.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
|
# Old-style call to super().__init__ with only 2 arguments
|
||||||
|
super().__init__(vllm_config=vllm_config, role=role)
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self, request: "Request", num_computed_tokens: int
|
||||||
|
) -> tuple[int | None, bool]:
|
||||||
|
return 0, False
|
||||||
|
|
||||||
|
def update_state_after_alloc(
|
||||||
|
self,
|
||||||
|
request: "Request",
|
||||||
|
blocks: "KVCacheBlocks",
|
||||||
|
num_external_tokens: int,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def build_connector_meta(self, scheduler_output: SchedulerOutput):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_kv_layer(
|
||||||
|
self,
|
||||||
|
layer_name: str,
|
||||||
|
kv_layer,
|
||||||
|
attn_metadata: "AttentionMetadata",
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_save(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NewStyleTestConnector(KVConnectorBase_V1):
|
||||||
|
"""
|
||||||
|
Test connector using the new signature with 3 required arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: "KVCacheConfig",
|
||||||
|
):
|
||||||
|
# New-style call to super().__init__ with all 3 arguments
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self, request: "Request", num_computed_tokens: int
|
||||||
|
) -> tuple[int | None, bool]:
|
||||||
|
return 0, False
|
||||||
|
|
||||||
|
def update_state_after_alloc(
|
||||||
|
self,
|
||||||
|
request: "Request",
|
||||||
|
blocks: "KVCacheBlocks",
|
||||||
|
num_external_tokens: int,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def build_connector_meta(self, scheduler_output: SchedulerOutput):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def save_kv_layer(
|
||||||
|
self,
|
||||||
|
layer_name: str,
|
||||||
|
kv_layer,
|
||||||
|
attn_metadata: "AttentionMetadata",
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_save(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
|
||||||
|
def test_external_old_signature_factory_instantiation(role):
|
||||||
|
"""
|
||||||
|
Test that external connectors with old signature (2 required args) loaded
|
||||||
|
via kv_connector_module_path are correctly instantiated with backwards
|
||||||
|
compatibility support.
|
||||||
|
"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_connector = "OldStyleTestConnector"
|
||||||
|
vllm_config.kv_transfer_config.kv_connector_module_path = (
|
||||||
|
"tests.v1.kv_connector.unit.test_backwards_compatibility"
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = create_scheduler(vllm_config)
|
||||||
|
kv_cache_config = scheduler.kv_cache_config
|
||||||
|
|
||||||
|
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
|
||||||
|
|
||||||
|
assert connector is not None
|
||||||
|
assert isinstance(connector, OldStyleTestConnector)
|
||||||
|
assert connector.role == role
|
||||||
|
assert connector._kv_cache_config is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
|
||||||
|
def test_external_new_signature_factory_instantiation(role):
|
||||||
|
"""
|
||||||
|
Test that external connectors with new signature (3 required args) loaded
|
||||||
|
via kv_connector_module_path are correctly instantiated.
|
||||||
|
"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_connector = "NewStyleTestConnector"
|
||||||
|
vllm_config.kv_transfer_config.kv_connector_module_path = (
|
||||||
|
"tests.v1.kv_connector.unit.test_backwards_compatibility"
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler = create_scheduler(vllm_config)
|
||||||
|
kv_cache_config = scheduler.kv_cache_config
|
||||||
|
|
||||||
|
connector = KVConnectorFactory.create_connector(vllm_config, role, kv_cache_config)
|
||||||
|
|
||||||
|
assert connector is not None
|
||||||
|
assert isinstance(connector, NewStyleTestConnector)
|
||||||
|
assert connector.role == role
|
||||||
|
assert connector._kv_cache_config is not None
|
||||||
|
assert connector._kv_cache_config == kv_cache_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("role", [KVConnectorRole.SCHEDULER, KVConnectorRole.WORKER])
|
||||||
|
def test_old_signature_super_init(role):
|
||||||
|
"""
|
||||||
|
Test that old-style connectors can call super().__init__() without
|
||||||
|
kv_cache_config parameter.
|
||||||
|
"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
|
||||||
|
connector = OldStyleTestConnector(vllm_config, role)
|
||||||
|
|
||||||
|
assert connector is not None
|
||||||
|
assert connector.role == role
|
||||||
|
assert connector._kv_cache_config is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_old_signature_super_init_with_kwargs():
|
||||||
|
"""
|
||||||
|
Test that old-style connectors can call super().__init__() with keyword
|
||||||
|
arguments in different orders.
|
||||||
|
"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
|
||||||
|
# Test with vllm_config= and role= kwargs
|
||||||
|
connector1 = OldStyleTestConnector(
|
||||||
|
vllm_config=vllm_config, role=KVConnectorRole.SCHEDULER
|
||||||
|
)
|
||||||
|
assert connector1 is not None
|
||||||
|
assert connector1._kv_cache_config is None
|
||||||
|
|
||||||
|
# Test with role= and vllm_config= in reversed order
|
||||||
|
connector2 = OldStyleTestConnector(
|
||||||
|
role=KVConnectorRole.WORKER, vllm_config=vllm_config
|
||||||
|
)
|
||||||
|
assert connector2 is not None
|
||||||
|
assert connector2._kv_cache_config is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_connector_uses_new_signature():
|
||||||
|
"""
|
||||||
|
Test that internal connectors (registered in factory) always use the new
|
||||||
|
signature and get kv_cache_config.
|
||||||
|
"""
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import (
|
||||||
|
SharedStorageConnector,
|
||||||
|
)
|
||||||
|
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
vllm_config.kv_transfer_config.kv_connector = "SharedStorageConnector"
|
||||||
|
|
||||||
|
scheduler = create_scheduler(vllm_config)
|
||||||
|
kv_cache_config = scheduler.kv_cache_config
|
||||||
|
|
||||||
|
connector = KVConnectorFactory.create_connector(
|
||||||
|
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
|
||||||
|
)
|
||||||
|
|
||||||
|
assert connector is not None
|
||||||
|
assert isinstance(connector, SharedStorageConnector)
|
||||||
|
assert connector._kv_cache_config is not None
|
||||||
|
assert connector._kv_cache_config == kv_cache_config
|
||||||
|
|
||||||
|
|
||||||
|
def test_signature_detection_with_mocking():
|
||||||
|
"""
|
||||||
|
Test that the factory correctly applies compat_sig flag returned from
|
||||||
|
_get_connector_class_with_compat.
|
||||||
|
"""
|
||||||
|
vllm_config = create_vllm_config()
|
||||||
|
scheduler = create_scheduler(vllm_config)
|
||||||
|
kv_cache_config = scheduler.kv_cache_config
|
||||||
|
|
||||||
|
# Mock _get_connector_class_with_compat to return old-style connector
|
||||||
|
with patch.object(
|
||||||
|
KVConnectorFactory,
|
||||||
|
"_get_connector_class_with_compat",
|
||||||
|
return_value=(OldStyleTestConnector, True),
|
||||||
|
):
|
||||||
|
old_connector = KVConnectorFactory.create_connector(
|
||||||
|
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
|
||||||
|
)
|
||||||
|
assert old_connector is not None
|
||||||
|
assert isinstance(old_connector, OldStyleTestConnector)
|
||||||
|
assert old_connector._kv_cache_config is None
|
||||||
|
|
||||||
|
# Mock _get_connector_class_with_compat to return new-style connector
|
||||||
|
with patch.object(
|
||||||
|
KVConnectorFactory,
|
||||||
|
"_get_connector_class_with_compat",
|
||||||
|
return_value=(NewStyleTestConnector, False),
|
||||||
|
):
|
||||||
|
new_connector = KVConnectorFactory.create_connector(
|
||||||
|
vllm_config, KVConnectorRole.SCHEDULER, kv_cache_config
|
||||||
|
)
|
||||||
|
assert new_connector is not None
|
||||||
|
assert isinstance(new_connector, NewStyleTestConnector)
|
||||||
|
assert new_connector._kv_cache_config is not None
|
||||||
|
assert new_connector._kv_cache_config == kv_cache_config
|
||||||
@ -254,7 +254,7 @@ def create_model_runner_output(
|
|||||||
|
|
||||||
|
|
||||||
class TestSharedStorageConnector(SharedStorageConnector):
|
class TestSharedStorageConnector(SharedStorageConnector):
|
||||||
def __init__(self, config: VllmConfig, role):
|
def __init__(self, config: VllmConfig, role, kv_cache_config):
|
||||||
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
|
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
|
||||||
self._connector = SharedStorageConnector(config, role)
|
self._connector = SharedStorageConnector(config, role)
|
||||||
self.call_record: dict[str, int] = defaultdict(int)
|
self.call_record: dict[str, int] = defaultdict(int)
|
||||||
|
|||||||
@ -3,10 +3,9 @@
|
|||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import TYPE_CHECKING, cast
|
from typing import TYPE_CHECKING, Optional, cast
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm.config import VllmConfig
|
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import (
|
from vllm.distributed.kv_transfer.kv_connector.base import (
|
||||||
KVConnectorBase,
|
KVConnectorBase,
|
||||||
KVConnectorBaseType,
|
KVConnectorBaseType,
|
||||||
@ -16,9 +15,12 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
|||||||
supports_hma,
|
supports_hma,
|
||||||
)
|
)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils.func_utils import supports_kw
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.kv_transfer import KVTransferConfig
|
from vllm.config.kv_transfer import KVTransferConfig
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -41,8 +43,9 @@ class KVConnectorFactory:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create_connector(
|
def create_connector(
|
||||||
cls,
|
cls,
|
||||||
config: VllmConfig,
|
config: "VllmConfig",
|
||||||
role: KVConnectorRole,
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||||
) -> KVConnectorBase:
|
) -> KVConnectorBase:
|
||||||
if not envs.VLLM_USE_V1:
|
if not envs.VLLM_USE_V1:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -53,7 +56,9 @@ class KVConnectorFactory:
|
|||||||
kv_transfer_config = config.kv_transfer_config
|
kv_transfer_config = config.kv_transfer_config
|
||||||
if kv_transfer_config is None:
|
if kv_transfer_config is None:
|
||||||
raise ValueError("kv_transfer_config must be set to create a connector")
|
raise ValueError("kv_transfer_config must be set to create a connector")
|
||||||
connector_cls = cls.get_connector_class(kv_transfer_config)
|
connector_cls, compat_sig = cls._get_connector_class_with_compat(
|
||||||
|
kv_transfer_config
|
||||||
|
)
|
||||||
|
|
||||||
# check if the connector supports HMA
|
# check if the connector supports HMA
|
||||||
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
|
hma_enabled = not config.scheduler_config.disable_hybrid_kv_cache_manager
|
||||||
@ -76,7 +81,12 @@ class KVConnectorFactory:
|
|||||||
# - Co-locate with worker process
|
# - Co-locate with worker process
|
||||||
# - Should only be used inside the forward context & attention layer
|
# - Should only be used inside the forward context & attention layer
|
||||||
# We build separately to enforce strict separation
|
# We build separately to enforce strict separation
|
||||||
|
if compat_sig:
|
||||||
|
# Old signature: __init__(self, vllm_config, role)
|
||||||
return connector_cls(config, role)
|
return connector_cls(config, role)
|
||||||
|
else:
|
||||||
|
# New signature: __init__(self, vllm_config, role, kv_cache_config)
|
||||||
|
return connector_cls(config, role, kv_cache_config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_connector_class_by_name(
|
def get_connector_class_by_name(
|
||||||
@ -97,13 +107,13 @@ class KVConnectorFactory:
|
|||||||
return cls._registry[connector_name]()
|
return cls._registry[connector_name]()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_connector_class(
|
def _get_connector_class_with_compat(
|
||||||
cls, kv_transfer_config: "KVTransferConfig"
|
cls, kv_transfer_config: "KVTransferConfig"
|
||||||
) -> type[KVConnectorBaseType]:
|
) -> tuple[type[KVConnectorBaseType], bool]:
|
||||||
"""Get the connector class by name."""
|
|
||||||
connector_name = kv_transfer_config.kv_connector
|
connector_name = kv_transfer_config.kv_connector
|
||||||
if connector_name is None:
|
if connector_name is None:
|
||||||
raise ValueError("Connector name is not set in KVTransferConfig")
|
raise ValueError("Connector name is not set in KVTransferConfig")
|
||||||
|
compat_sig = False
|
||||||
if connector_name in cls._registry:
|
if connector_name in cls._registry:
|
||||||
connector_cls = cls._registry[connector_name]()
|
connector_cls = cls._registry[connector_name]()
|
||||||
else:
|
else:
|
||||||
@ -118,6 +128,21 @@ class KVConnectorFactory:
|
|||||||
f"Class {connector_name} not found in {connector_module_path}"
|
f"Class {connector_name} not found in {connector_module_path}"
|
||||||
) from e
|
) from e
|
||||||
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
|
connector_cls = cast(type[KVConnectorBaseType], connector_cls)
|
||||||
|
if not supports_kw(connector_cls, "kv_cache_config"):
|
||||||
|
compat_sig = True
|
||||||
|
logger.warning(
|
||||||
|
"Connector %s uses deprecated signature with 2 required arguments. "
|
||||||
|
"Please update to include kv_cache_config as the second argument.",
|
||||||
|
connector_cls.__name__,
|
||||||
|
)
|
||||||
|
return connector_cls, compat_sig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_connector_class(
|
||||||
|
cls, kv_transfer_config: "KVTransferConfig"
|
||||||
|
) -> type[KVConnectorBaseType]:
|
||||||
|
"""Get the connector class by name."""
|
||||||
|
connector_cls, _ = cls._get_connector_class_with_compat(kv_transfer_config)
|
||||||
return connector_cls
|
return connector_cls
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -58,6 +58,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
|
# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
|
||||||
@ -141,7 +142,12 @@ class KVConnectorMetadata(ABC): # noqa: B024
|
|||||||
|
|
||||||
|
|
||||||
class KVConnectorBase_V1(ABC):
|
class KVConnectorBase_V1(ABC):
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||||
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Initializing KVConnectorBase_V1. This API is experimental and "
|
"Initializing KVConnectorBase_V1. This API is experimental and "
|
||||||
"subject to change in the future as we iterate the design."
|
"subject to change in the future as we iterate the design."
|
||||||
@ -152,6 +158,14 @@ class KVConnectorBase_V1(ABC):
|
|||||||
self._kv_transfer_config = vllm_config.kv_transfer_config
|
self._kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
else:
|
else:
|
||||||
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
|
raise ValueError("kv_transfer_config must be set for KVConnectorBase_V1")
|
||||||
|
self._kv_cache_config = kv_cache_config
|
||||||
|
if self._kv_cache_config is None:
|
||||||
|
logger.warning(
|
||||||
|
"KVConnectorBase_V1 initialized without kv_cache_config. "
|
||||||
|
"This is deprecated - please update your connector to accept "
|
||||||
|
"kv_cache_config as the third constructor argument and pass it "
|
||||||
|
"to super().__init__()."
|
||||||
|
)
|
||||||
self._role = role
|
self._role = role
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@ -32,7 +32,7 @@ Usage:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -50,6 +50,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -79,8 +80,13 @@ class DecodeBenchConnector(KVConnectorBase_V1):
|
|||||||
testing of the decoder with larger input sequence lengths.
|
testing of the decoder with larger input sequence lengths.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(
|
||||||
super().__init__(vllm_config, role)
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, role, kv_cache_config)
|
||||||
|
|
||||||
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
|
self.connector_scheduler: DecodeBenchConnectorScheduler | None = None
|
||||||
self.connector_worker: DecodeBenchConnectorWorker | None = None
|
self.connector_worker: DecodeBenchConnectorWorker | None = None
|
||||||
|
|||||||
@ -20,14 +20,22 @@ if TYPE_CHECKING:
|
|||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class LMCacheConnectorV1(KVConnectorBase_V1):
|
class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(
|
||||||
super().__init__(vllm_config=vllm_config, role=role)
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: "KVCacheConfig",
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||||
|
)
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
|
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||||
"use_native", False
|
"use_native", False
|
||||||
|
|||||||
@ -31,6 +31,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.distributed.kv_events import KVCacheEvent
|
from vllm.distributed.kv_events import KVCacheEvent
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -109,15 +110,22 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
- Save to all connectors.
|
- Save to all connectors.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(
|
||||||
super().__init__(vllm_config=vllm_config, role=role)
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: "KVCacheConfig",
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||||
|
)
|
||||||
|
|
||||||
self._connectors: list[KVConnectorBase_V1] = []
|
self._connectors: list[KVConnectorBase_V1] = []
|
||||||
self._ktc_kv_transfer_config = []
|
self._ktc_kv_transfer_config = []
|
||||||
for connector_cls, temp_config in self._get_connector_classes_and_configs(
|
for connector_cls, temp_config in self._get_connector_classes_and_configs(
|
||||||
vllm_config
|
vllm_config
|
||||||
):
|
):
|
||||||
self._connectors.append(connector_cls(temp_config, role))
|
self._connectors.append(connector_cls(temp_config, role, kv_cache_config))
|
||||||
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
|
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
|
||||||
|
|
||||||
# A mapping from request id to the index of the connector chosen to
|
# A mapping from request id to the index of the connector chosen to
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from collections import defaultdict
|
|||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from concurrent.futures import Future, ThreadPoolExecutor
|
from concurrent.futures import Future, ThreadPoolExecutor
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -52,6 +52,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||||
@ -150,7 +151,14 @@ class NixlConnectorMetadata(KVConnectorMetadata):
|
|||||||
|
|
||||||
|
|
||||||
class NixlConnector(KVConnectorBase_V1):
|
class NixlConnector(KVConnectorBase_V1):
|
||||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, role, kv_cache_config)
|
||||||
|
|
||||||
assert vllm_config.kv_transfer_config is not None
|
assert vllm_config.kv_transfer_config is not None
|
||||||
assert vllm_config.kv_transfer_config.engine_id is not None
|
assert vllm_config.kv_transfer_config.engine_id is not None
|
||||||
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
|
||||||
|
|||||||
@ -21,6 +21,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||||
@ -41,8 +42,13 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
|
|||||||
|
|
||||||
|
|
||||||
class OffloadingConnector(KVConnectorBase_V1):
|
class OffloadingConnector(KVConnectorBase_V1):
|
||||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
def __init__(
|
||||||
super().__init__(vllm_config, role)
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: KVCacheConfig | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(vllm_config, role, kv_cache_config)
|
||||||
|
|
||||||
spec = OffloadingSpecFactory.create_spec(vllm_config)
|
spec = OffloadingSpecFactory.create_spec(vllm_config)
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import regex as re
|
import regex as re
|
||||||
import torch
|
import torch
|
||||||
@ -25,6 +25,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -71,8 +72,17 @@ class P2pNcclConnectorMetadata(KVConnectorMetadata):
|
|||||||
|
|
||||||
|
|
||||||
class P2pNcclConnector(KVConnectorBase_V1):
|
class P2pNcclConnector(KVConnectorBase_V1):
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(
|
||||||
super().__init__(vllm_config=vllm_config, role=role)
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
role=role,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
)
|
||||||
self._block_size = vllm_config.cache_config.block_size
|
self._block_size = vllm_config.cache_config.block_size
|
||||||
self._requests_need_load: dict[str, Any] = {}
|
self._requests_need_load: dict[str, Any] = {}
|
||||||
self.is_producer = self._kv_transfer_config.is_kv_producer
|
self.is_producer = self._kv_transfer_config.is_kv_producer
|
||||||
|
|||||||
@ -3,7 +3,7 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.request import Request
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -86,8 +87,17 @@ class SharedStorageConnector(KVConnectorBase_V1):
|
|||||||
# It does extra work which will overwrite the existing prefix-cache in GPU
|
# It does extra work which will overwrite the existing prefix-cache in GPU
|
||||||
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
# - to remove the overhead, need to add some "mask" in the ReqMeta class
|
||||||
|
|
||||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
def __init__(
|
||||||
super().__init__(vllm_config=vllm_config, role=role)
|
self,
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
role: KVConnectorRole,
|
||||||
|
kv_cache_config: Optional["KVCacheConfig"] = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
vllm_config=vllm_config,
|
||||||
|
role=role,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
|
)
|
||||||
self._block_size = vllm_config.cache_config.block_size
|
self._block_size = vllm_config.cache_config.block_size
|
||||||
self._requests_need_load: dict[str, Request] = {}
|
self._requests_need_load: dict[str, Request] = {}
|
||||||
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
self._storage_path = self._kv_transfer_config.get_from_extra_config(
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||||
@ -12,6 +12,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
|
|
||||||
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
|
_KV_CONNECTOR_AGENT: KVConnectorBaseType | None = None
|
||||||
|
|
||||||
@ -48,7 +49,9 @@ def is_v1_kv_transfer_group(connector: KVConnectorBaseType | None = None) -> boo
|
|||||||
return isinstance(connector, KVConnectorBase_V1)
|
return isinstance(connector, KVConnectorBase_V1)
|
||||||
|
|
||||||
|
|
||||||
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
def ensure_kv_transfer_initialized(
|
||||||
|
vllm_config: "VllmConfig", kv_cache_config: Optional["KVCacheConfig"] = None
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize KV cache transfer parallel group.
|
Initialize KV cache transfer parallel group.
|
||||||
"""
|
"""
|
||||||
@ -64,7 +67,9 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
|||||||
):
|
):
|
||||||
if envs.VLLM_USE_V1:
|
if envs.VLLM_USE_V1:
|
||||||
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
|
_KV_CONNECTOR_AGENT = KVConnectorFactory.create_connector(
|
||||||
config=vllm_config, role=KVConnectorRole.WORKER
|
config=vllm_config,
|
||||||
|
role=KVConnectorRole.WORKER,
|
||||||
|
kv_cache_config=kv_cache_config,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError("V0 is no longer supported")
|
raise ValueError("V0 is no longer supported")
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import copy
|
|
||||||
import itertools
|
import itertools
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -92,15 +91,10 @@ class Scheduler(SchedulerInterface):
|
|||||||
assert not self.is_encoder_decoder, (
|
assert not self.is_encoder_decoder, (
|
||||||
"Encoder-decoder models are not currently supported with KV connectors"
|
"Encoder-decoder models are not currently supported with KV connectors"
|
||||||
)
|
)
|
||||||
|
|
||||||
connector_vllm_config = copy.copy(self.vllm_config)
|
|
||||||
|
|
||||||
# We're dynamically inserting a kv_cache_config variable into the
|
|
||||||
# connector_vllm_config. This is distinct from the cache_config
|
|
||||||
# that is already in there.
|
|
||||||
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config) # type: ignore[attr-defined]
|
|
||||||
self.connector = KVConnectorFactory.create_connector(
|
self.connector = KVConnectorFactory.create_connector(
|
||||||
config=connector_vllm_config, role=KVConnectorRole.SCHEDULER
|
config=self.vllm_config,
|
||||||
|
role=KVConnectorRole.SCHEDULER,
|
||||||
|
kv_cache_config=self.kv_cache_config,
|
||||||
)
|
)
|
||||||
if self.log_stats:
|
if self.log_stats:
|
||||||
self.connector_prefix_cache_stats = PrefixCacheStats()
|
self.connector_prefix_cache_stats = PrefixCacheStats()
|
||||||
|
|||||||
@ -380,9 +380,7 @@ class Worker(WorkerBase):
|
|||||||
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
|
# NOTE(Kuntai): This need to be done before `initialize_kv_cache`,
|
||||||
# because `initialize_kv_cache` will inject kv cache groups not
|
# because `initialize_kv_cache` will inject kv cache groups not
|
||||||
# related to kv cache connector (e.g. kv cache sharing layers).
|
# related to kv cache connector (e.g. kv cache sharing layers).
|
||||||
connector_vllm_config = copy.copy(self.vllm_config)
|
ensure_kv_transfer_initialized(self.vllm_config, kv_cache_config)
|
||||||
connector_vllm_config.kv_cache_config = copy.copy(kv_cache_config)
|
|
||||||
ensure_kv_transfer_initialized(connector_vllm_config)
|
|
||||||
|
|
||||||
if self.vllm_config.model_config.enable_sleep_mode:
|
if self.vllm_config.model_config.enable_sleep_mode:
|
||||||
from vllm.device_allocator.cumem import CuMemAllocator
|
from vllm.device_allocator.cumem import CuMemAllocator
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user