diff --git a/tests/distributed/test_kvlayout.py b/tests/distributed/test_kvlayout.py
new file mode 100644
index 000000000000..d447876f6cc7
--- /dev/null
+++ b/tests/distributed/test_kvlayout.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
+                         VllmConfig, set_current_vllm_config)
+from vllm.distributed.kv_transfer.kv_connector.utils import (
+    get_kv_connector_cache_layout)
+from vllm.logger import init_logger
+
+logger = init_logger("test_kvlayout")
+
+
+def test_get_kv_connector_cache_layout_without_kv_connector():
+    vllm_config = VllmConfig(device_config=DeviceConfig("cpu"))
+    with set_current_vllm_config(vllm_config):
+        # Test with default settings
+        layout = get_kv_connector_cache_layout()
+        assert layout == "NHD"
+
+
+def test_get_kv_connector_cache_layout_with_lmcache_connector():
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="LMCacheConnectorV1",
+        kv_role="kv_both",
+    )
+    vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
+                             kv_transfer_config=kv_transfer_config)
+    with set_current_vllm_config(vllm_config):
+        # Test with default settings
+        layout = get_kv_connector_cache_layout()
+        assert layout == "NHD"
+
+
+def test_get_kv_connector_cache_layout_with_nixl_connector():
+    kv_transfer_config = KVTransferConfig(
+        kv_connector="NixlConnector",
+        kv_role="kv_both",
+    )
+    model_config = ModelConfig()
+    vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
+                             model_config=model_config,
+                             kv_transfer_config=kv_transfer_config)
+    with set_current_vllm_config(vllm_config):
+        # Test with default settings
+        layout = get_kv_connector_cache_layout()
+        assert layout == "HND"
+
+
+def test_get_kv_connector_cache_layout_with_multi_connector():
+    kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector",
+                                          kv_role="kv_both",
+                                          kv_connector_extra_config={
+                                              "connectors": [{
+                                                  "kv_connector":
+                                                  "SharedStorageConnector",
+                                                  "kv_role":
+                                                  "kv_both"
+                                              }, {
+                                                  "kv_connector":
+                                                  "NixlConnector",
+                                                  "kv_role":
+                                                  "kv_both"
+                                              }]
+                                          })
+    model_config = ModelConfig()
+    vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
+                             model_config=model_config,
+                             kv_transfer_config=kv_transfer_config)
+    with set_current_vllm_config(vllm_config):
+        # Test with default settings
+        layout = get_kv_connector_cache_layout()
+        assert layout == "HND"
diff --git a/vllm/distributed/kv_transfer/kv_connector/base.py b/vllm/distributed/kv_transfer/kv_connector/base.py
index 181c33925da7..868b227fc899 100644
--- a/vllm/distributed/kv_transfer/kv_connector/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/base.py
@@ -9,7 +9,7 @@ The class provides two primary abstract methods:
 """
 
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
@@ -124,5 +124,19 @@ class KVConnectorBase(ABC):
 
         raise NotImplementedError
 
+    @classmethod
+    def get_required_kvcache_layout(
+            cls, vllm_config: "VllmConfig") -> Optional[str]:
+        """
+        Get the required KV cache layout for this connector.
+        Args:
+            vllm_config (VllmConfig): the vllm config.
+
+        Returns:
+            str: the required KV cache layout. e.g. HND, or NHD.
+            None if the connector does not require a specific layout.
+        """
+        return None
+
 
 KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py
index be9ce72dea67..cf7cde2c4377 100644
--- a/vllm/distributed/kv_transfer/kv_connector/factory.py
+++ b/vllm/distributed/kv_transfer/kv_connector/factory.py
@@ -5,6 +5,7 @@ import importlib
 from typing import TYPE_CHECKING, Callable
 
 import vllm.envs as envs
+from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
 from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
                                                           KVConnectorRole)
@@ -41,14 +42,27 @@ class KVConnectorFactory:
             raise ValueError("Attempting to initialize a V0 Connector, "
                              f"but found {envs.VLLM_USE_V1=}")
 
-        connector_name = config.kv_transfer_config.kv_connector
-        if connector_name not in cls._registry:
-            raise ValueError(f"Unsupported connector type: {connector_name}")
-
-        connector_cls = cls._registry[connector_name]()
+        connector_cls = cls.get_connector_class(config.kv_transfer_config)
         assert issubclass(connector_cls, KVConnectorBase)
         return connector_cls(rank, local_rank, config)
 
+    @classmethod
+    def get_connector_class(
+            cls, kv_transfer_config: "KVTransferConfig"
+    ) -> type[KVConnectorBaseType]:
+        """Get the connector class by name."""
+        connector_name = kv_transfer_config.kv_connector
+        if connector_name in cls._registry:
+            connector_cls = cls._registry[connector_name]()
+        else:
+            connector_module_path = kv_transfer_config.kv_connector_module_path
+            if connector_module_path is None:
+                raise ValueError(
+                    f"Unsupported connector type: {connector_name}")
+            connector_module = importlib.import_module(connector_module_path)
+            connector_cls = getattr(connector_module, connector_name)
+        return connector_cls
+
     @classmethod
     def create_connector_v1(
         cls,
@@ -60,19 +74,10 @@ class KVConnectorFactory:
                              f"but found {envs.VLLM_USE_V1=}")
 
         kv_transfer_config = config.kv_transfer_config
-        connector_name = kv_transfer_config.kv_connector
-        if connector_name in cls._registry:
-            connector_cls = cls._registry[connector_name]()
-        else:
-            connector_module_path = kv_transfer_config.kv_connector_module_path
-            if connector_module_path is None:
-                raise ValueError(
-                    f"Unsupported connector type: {connector_name}")
-            connector_module = importlib.import_module(connector_module_path)
-            connector_cls = getattr(connector_module, connector_name)
+        connector_cls = cls.get_connector_class(kv_transfer_config)
         assert issubclass(connector_cls, KVConnectorBase_V1)
         logger.info("Creating v1 connector with name: %s and engine_id: %s",
-                    connector_name, kv_transfer_config.engine_id)
+                    connector_cls.__name__, kv_transfer_config.engine_id)
         # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
         # Scheduler connector:
         # - Co-locate with scheduler process
diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py
index 459a53298914..559c233947ce 100644
--- a/vllm/distributed/kv_transfer/kv_connector/utils.py
+++ b/vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -13,6 +13,8 @@ import torch
 import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.distributed.kv_transfer.kv_connector.factory import (
+    KVConnectorFactory)
 from vllm.logger import init_logger
 from vllm.v1.outputs import ModelRunnerOutput
 
@@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
     # used for faster transfer.
     vllm_config = get_current_vllm_config()
     kv_config = vllm_config.kv_transfer_config
-    if kv_config is not None and vllm_config.model_config is None:
-        logger.warning_once("Unable to detect current VLLM config. " \
-        "Defaulting to NHD kv cache layout.")
-    elif kv_config is not None:
-        use_mla = vllm_config.model_config.use_mla
-        if not use_mla and kv_config.kv_connector == "NixlConnector":
-            logger.info_once("NixlConnector detected. Setting KV cache " \
-                             "layout to HND for better xfer performance.")
-            return "HND"
+    if kv_config is not None:
+        connector_cls = KVConnectorFactory.get_connector_class(kv_config)
+        required_kvcache_layout = connector_cls.get_required_kvcache_layout(
+            vllm_config)
+        if required_kvcache_layout is not None:
+            return required_kvcache_layout
+        logger.info_once("Connectors do not specify a "
+                         "kv cache layout, defaulting to NHD.")
     return "NHD"
 
 
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index 8bbdd7e0621c..7a2ccb58656f 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -299,3 +299,17 @@ class KVConnectorBase_V1(ABC):
             returned by the engine.
         """
         return False, None
+
+    @classmethod
+    def get_required_kvcache_layout(
+            cls, vllm_config: "VllmConfig") -> Optional[str]:
+        """
+        Get the required KV cache layout for this connector.
+        Args:
+            vllm_config (VllmConfig): the vllm config.
+
+        Returns:
+            str: the required KV cache layout. e.g. HND, or NHD.
+            None if the connector does not require a specific layout.
+        """
+        return None
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
index a2eaa0040191..934a03a12ee5 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py
@@ -202,3 +202,36 @@ class MultiConnector(KVConnectorBase_V1):
             self._requests_to_connector.pop(request.request_id, None)
 
         return async_saves > 0, kv_txfer_params
+
+    @classmethod
+    def get_required_kvcache_layout(
+            cls, vllm_config: "VllmConfig") -> Optional[str]:
+        """
+        Get the required KV cache layout for this connector.
+        Args:
+            vllm_config (VllmConfig): the vllm config.
+
+        Returns:
+            str: the required KV cache layout. e.g. HND, or NHD.
+            None if the connector does not require a specific layout.
+        """
+        ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
+            "connectors")
+        assert ktcs is not None
+        layouts: set[str] = set()
+        temp_vllm_config = copy.copy(vllm_config)
+        for ktc in ktcs:
+            kv_transfer_config = KVTransferConfig(**ktc)
+            temp_vllm_config.kv_transfer_config = kv_transfer_config
+            required_kvcache_layout = KVConnectorFactory.get_connector_class(
+                kv_transfer_config).get_required_kvcache_layout(
+                    temp_vllm_config)
+            if required_kvcache_layout is not None:
+                layouts.add(required_kvcache_layout)
+
+        if len(layouts) > 1:
+            raise ValueError("KV cache layout mismatch: "
+                             f"found {len(layouts)} different layouts "
+                             f"({', '.join(sorted(layouts))}). "
+                             "All connectors must use the same layout.")
+        return next(iter(layouts), None)
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 6d86ab7f7a4c..e7fc2b118145 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -133,6 +133,25 @@ class NixlConnector(KVConnectorBase_V1):
             self.connector_worker = NixlConnectorWorker(
                 vllm_config, self.engine_id)
 
+    ############################################################
+    # Class Methods
+    ############################################################
+    @classmethod
+    def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
+        if vllm_config.model_config is None:
+            logger.warning_once("Unable to detect current VLLM config. "
+                                "Fallback to default kv cache layout.")
+            return None
+        use_mla = vllm_config.model_config.use_mla
+        if use_mla:
+            # return None when we have mla
+            # as the layout should not matter in that case,
+            # which fallback to the default behavior.
+            return None
+        logger.info_once("NixlConnector setting KV cache "
+                         "layout to HND for better xfer performance.")
+        return "HND"
+
     ############################################################
     # Scheduler Side Methods
     ############################################################
@@ -236,13 +255,13 @@ class NixlConnectorScheduler:
         """
         For remote prefill, pull all prompt blocks from remote
         asynchronously relative to engine execution.
-        
+
         Args:
             request (Request): the request object.
             num_computed_tokens (int): the number of locally
                 computed tokens for this request
         Returns:
-            * the number of tokens that can be loaded from the 
+            * the number of tokens that can be loaded from the
               external KV cache beyond what is already computed.
             * true if the external KV cache tokens will be loaded
               asynchronously (between scheduler steps).