feat(distributed): add get_required_kvcache_layout class method to kv connector api (#20433)

Signed-off-by: wxsm <wxsms@foxmail.com>
This commit is contained in:
wxsm 2025-07-31 00:41:51 +08:00 committed by GitHub
parent 4904e53c32
commit f4135232b9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 186 additions and 28 deletions

View File

@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
VllmConfig, set_current_vllm_config)
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger
# NOTE(review): the logger name "test_expert_parallel" looks copy-pasted from
# another test module, and this logger is not used by the tests below —
# confirm and rename (or drop) it.
logger = init_logger("test_expert_parallel")
def test_get_kv_connector_cache_layout_without_kv_connector():
    """With no KV transfer config at all, the layout falls back to NHD."""
    config = VllmConfig(device_config=DeviceConfig("cpu"))
    with set_current_vllm_config(config):
        assert get_kv_connector_cache_layout() == "NHD"
def test_get_kv_connector_cache_layout_with_lmcache_connector():
    """LMCacheConnectorV1 requires no specific layout, so NHD is used."""
    transfer_cfg = KVTransferConfig(
        kv_connector="LMCacheConnectorV1",
        kv_role="kv_both",
    )
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        kv_transfer_config=transfer_cfg)
    with set_current_vllm_config(config):
        assert get_kv_connector_cache_layout() == "NHD"
def test_get_kv_connector_cache_layout_with_nixl_connector():
    """NixlConnector requests the HND layout for non-MLA models."""
    transfer_cfg = KVTransferConfig(
        kv_connector="NixlConnector",
        kv_role="kv_both",
    )
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        model_config=ModelConfig(),
                        kv_transfer_config=transfer_cfg)
    with set_current_vllm_config(config):
        assert get_kv_connector_cache_layout() == "HND"
def test_get_kv_connector_cache_layout_with_multi_connector():
    """MultiConnector adopts the layout its sub-connectors require (HND
    here, because one of them is NixlConnector)."""
    sub_connectors = [
        {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both",
        },
        {
            "kv_connector": "NixlConnector",
            "kv_role": "kv_both",
        },
    ]
    transfer_cfg = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": sub_connectors},
    )
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        model_config=ModelConfig(),
                        kv_transfer_config=transfer_cfg)
    with set_current_vllm_config(config):
        assert get_kv_connector_cache_layout() == "HND"

View File

@ -9,7 +9,7 @@ The class provides two primary abstract methods:
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Optional, Union
import torch
@ -124,5 +124,19 @@ class KVConnectorBase(ABC):
raise NotImplementedError
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """Return the KV cache layout this connector needs, if any.

    Args:
        vllm_config (VllmConfig): the current vLLM configuration.

    Returns:
        Optional[str]: "HND" or "NHD" when a specific layout is required;
        None when the connector imposes no layout constraint.
    """
    # Base implementation has no requirement; subclasses override as needed.
    return None
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]

View File

@ -5,6 +5,7 @@ import importlib
from typing import TYPE_CHECKING, Callable
import vllm.envs as envs
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
KVConnectorRole)
@ -41,14 +42,27 @@ class KVConnectorFactory:
raise ValueError("Attempting to initialize a V0 Connector, "
f"but found {envs.VLLM_USE_V1=}")
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
connector_cls = cls.get_connector_class(config.kv_transfer_config)
assert issubclass(connector_cls, KVConnectorBase)
return connector_cls(rank, local_rank, config)
@classmethod
def get_connector_class(
        cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
    """Resolve the connector class named by *kv_transfer_config*.

    Names registered with the factory are looked up in ``cls._registry``;
    otherwise the class is imported dynamically from
    ``kv_connector_module_path``.

    Raises:
        ValueError: if the name is unregistered and no module path is set.
    """
    name = kv_transfer_config.kv_connector
    loader = cls._registry.get(name)
    if loader is not None:
        return loader()
    module_path = kv_transfer_config.kv_connector_module_path
    if module_path is None:
        raise ValueError(f"Unsupported connector type: {name}")
    return getattr(importlib.import_module(module_path), name)
@classmethod
def create_connector_v1(
cls,
@ -60,19 +74,10 @@ class KVConnectorFactory:
f"but found {envs.VLLM_USE_V1=}")
kv_transfer_config = config.kv_transfer_config
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
connector_cls = cls.get_connector_class(kv_transfer_config)
assert issubclass(connector_cls, KVConnectorBase_V1)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_name, kv_transfer_config.engine_id)
connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
# Scheduler connector:
# - Co-locate with scheduler process

View File

@ -13,6 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.v1.outputs import ModelRunnerOutput
@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
# used for faster transfer.
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None and vllm_config.model_config is None:
logger.warning_once("Unable to detect current VLLM config. " \
"Defaulting to NHD kv cache layout.")
elif kv_config is not None:
use_mla = vllm_config.model_config.use_mla
if not use_mla and kv_config.kv_connector == "NixlConnector":
logger.info_once("NixlConnector detected. Setting KV cache " \
"layout to HND for better xfer performance.")
return "HND"
if kv_config is not None:
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once("Connectors do not specify a " \
"kv cache layout, defaulting to NHD.")
return "NHD"

View File

@ -299,3 +299,17 @@ class KVConnectorBase_V1(ABC):
returned by the engine.
"""
return False, None
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """Report the KV cache layout this connector requires.

    Args:
        vllm_config (VllmConfig): the current vLLM configuration.

    Returns:
        Optional[str]: a layout name such as "HND" or "NHD", or None when
        no particular layout is needed (the default for the base class).
    """
    # Default: no constraint. Concrete connectors override this.
    return None

View File

@ -202,3 +202,36 @@ class MultiConnector(KVConnectorBase_V1):
self._requests_to_connector.pop(request.request_id, None)
return async_saves > 0, kv_txfer_params
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Get the required KV cache layout agreed on by all sub-connectors.

    Each entry of ``kv_connector_extra_config["connectors"]`` is resolved
    to its connector class and queried with a copy of the config whose
    ``kv_transfer_config`` is swapped for that sub-connector's config.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout, e.g. HND or NHD.
        None if no sub-connector requires a specific layout.

    Raises:
        ValueError: if sub-connectors require different layouts.
    """
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors")
    assert ktcs is not None
    layouts: set[str] = set()
    # Shallow copy is enough: only kv_transfer_config is rebound per
    # sub-connector; all other config fields are shared read-only.
    temp_vllm_config = copy.copy(vllm_config)
    for ktc in ktcs:
        kv_transfer_config = KVTransferConfig(**ktc)
        temp_vllm_config.kv_transfer_config = kv_transfer_config
        required_kvcache_layout = KVConnectorFactory.get_connector_class(
            kv_transfer_config).get_required_kvcache_layout(
                temp_vllm_config)
        if required_kvcache_layout is not None:
            layouts.add(required_kvcache_layout)
    if len(layouts) > 1:
        # Fix: the original message ran "...)." and "All connectors..."
        # together with no separating space.
        raise ValueError(f"KV cache layout mismatch: "
                         f"found {len(layouts)} different layouts "
                         f"({', '.join(layouts)}). "
                         f"All connectors must use the same layout.")
    # Zero or one distinct layout: return it, or None when the set is empty.
    return next(iter(layouts), None)

View File

@ -133,6 +133,25 @@ class NixlConnector(KVConnectorBase_V1):
self.connector_worker = NixlConnectorWorker(
vllm_config, self.engine_id)
############################################################
# Class Methods
############################################################
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
    """Request HND layout for non-MLA models; otherwise declare no
    requirement so the default layout is used."""
    model_config = vllm_config.model_config
    if model_config is None:
        logger.warning_once("Unable to detect current VLLM config. "
                            "Fallback to default kv cache layout.")
        return None
    if model_config.use_mla:
        # With MLA the layout should not matter, so return None and let
        # the default behavior apply.
        return None
    logger.info_once("NixlConnector setting KV cache "
                     "layout to HND for better xfer performance.")
    return "HND"
############################################################
# Scheduler Side Methods
############################################################
@ -236,13 +255,13 @@ class NixlConnectorScheduler:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).