mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 07:15:01 +08:00
feat(distributed): add get_required_kvcache_layout class method to kv connector api (#20433)
Signed-off-by: wxsm <wxsms@foxmail.com>
This commit is contained in:
parent
4904e53c32
commit
f4135232b9
72
tests/distributed/test_kvlayout.py
Normal file
72
tests/distributed/test_kvlayout.py
Normal file
@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
|
||||
VllmConfig, set_current_vllm_config)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
get_kv_connector_cache_layout)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger("test_expert_parallel")
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_without_kv_connector():
    """Without a KV connector configured, the default NHD layout is used."""
    config = VllmConfig(device_config=DeviceConfig("cpu"))
    with set_current_vllm_config(config):
        # No kv_transfer_config is set, so the helper falls back to NHD.
        detected_layout = get_kv_connector_cache_layout()
    assert detected_layout == "NHD"
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_with_lmcache_connector():
    """LMCacheConnectorV1 imposes no layout, so the NHD default is kept."""
    transfer_config = KVTransferConfig(kv_connector="LMCacheConnectorV1",
                                       kv_role="kv_both")
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        kv_transfer_config=transfer_config)
    with set_current_vllm_config(config):
        detected_layout = get_kv_connector_cache_layout()
    assert detected_layout == "NHD"
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_with_nixl_connector():
    """NixlConnector requires the HND layout for better transfer speed."""
    transfer_config = KVTransferConfig(kv_connector="NixlConnector",
                                       kv_role="kv_both")
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        model_config=ModelConfig(),
                        kv_transfer_config=transfer_config)
    with set_current_vllm_config(config):
        detected_layout = get_kv_connector_cache_layout()
    assert detected_layout == "HND"
|
||||
|
||||
|
||||
def test_get_kv_connector_cache_layout_with_multi_connector():
    """MultiConnector adopts the layout required by its sub-connectors.

    SharedStorageConnector imposes no layout while NixlConnector requires
    HND, so the combined result must be HND.
    """
    sub_connectors = [
        {
            "kv_connector": "SharedStorageConnector",
            "kv_role": "kv_both"
        },
        {
            "kv_connector": "NixlConnector",
            "kv_role": "kv_both"
        },
    ]
    transfer_config = KVTransferConfig(
        kv_connector="MultiConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"connectors": sub_connectors})
    config = VllmConfig(device_config=DeviceConfig("cpu"),
                        model_config=ModelConfig(),
                        kv_transfer_config=transfer_config)
    with set_current_vllm_config(config):
        detected_layout = get_kv_connector_cache_layout()
    assert detected_layout == "HND"
|
||||
@ -9,7 +9,7 @@ The class provides two primary abstract methods:
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Union
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -124,5 +124,19 @@ class KVConnectorBase(ABC):
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """Return the KV cache layout this connector needs, if any.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        The required layout name (e.g. HND or NHD), or None when the
        connector works with any layout. Subclasses override this to
        demand a specific layout.
    """
    return None
|
||||
|
||||
|
||||
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
|
||||
|
||||
@ -5,6 +5,7 @@ import importlib
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole)
|
||||
@ -41,14 +42,27 @@ class KVConnectorFactory:
|
||||
raise ValueError("Attempting to initialize a V0 Connector, "
|
||||
f"but found {envs.VLLM_USE_V1=}")
|
||||
|
||||
connector_name = config.kv_transfer_config.kv_connector
|
||||
if connector_name not in cls._registry:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
connector_cls = cls.get_connector_class(config.kv_transfer_config)
|
||||
assert issubclass(connector_cls, KVConnectorBase)
|
||||
return connector_cls(rank, local_rank, config)
|
||||
|
||||
@classmethod
def get_connector_class(
        cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
    """Resolve the connector class named in ``kv_transfer_config``.

    Registered connectors are looked up in the factory registry;
    otherwise the class is imported from the configured module path.
    """
    name = kv_transfer_config.kv_connector
    if name in cls._registry:
        # Registered connectors are stored as zero-arg loader callables.
        return cls._registry[name]()
    module_path = kv_transfer_config.kv_connector_module_path
    if module_path is None:
        raise ValueError(f"Unsupported connector type: {name}")
    # Dynamically import out-of-tree connector implementations.
    module = importlib.import_module(module_path)
    return getattr(module, name)
|
||||
|
||||
@classmethod
|
||||
def create_connector_v1(
|
||||
cls,
|
||||
@ -60,19 +74,10 @@ class KVConnectorFactory:
|
||||
f"but found {envs.VLLM_USE_V1=}")
|
||||
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
connector_name = kv_transfer_config.kv_connector
|
||||
if connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
connector_module_path = kv_transfer_config.kv_connector_module_path
|
||||
if connector_module_path is None:
|
||||
raise ValueError(
|
||||
f"Unsupported connector type: {connector_name}")
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
connector_cls = cls.get_connector_class(kv_transfer_config)
|
||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||
connector_name, kv_transfer_config.engine_id)
|
||||
connector_cls.__name__, kv_transfer_config.engine_id)
|
||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||
# Scheduler connector:
|
||||
# - Co-locate with scheduler process
|
||||
|
||||
@ -13,6 +13,8 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
|
||||
@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
|
||||
# used for faster transfer.
|
||||
vllm_config = get_current_vllm_config()
|
||||
kv_config = vllm_config.kv_transfer_config
|
||||
if kv_config is not None and vllm_config.model_config is None:
|
||||
logger.warning_once("Unable to detect current VLLM config. " \
|
||||
"Defaulting to NHD kv cache layout.")
|
||||
elif kv_config is not None:
|
||||
use_mla = vllm_config.model_config.use_mla
|
||||
if not use_mla and kv_config.kv_connector == "NixlConnector":
|
||||
logger.info_once("NixlConnector detected. Setting KV cache " \
|
||||
"layout to HND for better xfer performance.")
|
||||
return "HND"
|
||||
if kv_config is not None:
|
||||
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
|
||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
||||
vllm_config)
|
||||
if required_kvcache_layout is not None:
|
||||
return required_kvcache_layout
|
||||
logger.info_once("Connectors do not specify a " \
|
||||
"kv cache layout, defaulting to NHD.")
|
||||
return "NHD"
|
||||
|
||||
|
||||
|
||||
@ -299,3 +299,17 @@ class KVConnectorBase_V1(ABC):
|
||||
returned by the engine.
|
||||
"""
|
||||
return False, None
|
||||
|
||||
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """Report the KV cache layout this connector requires.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        The required layout name (e.g. HND or NHD); None means no
        specific layout is required and the engine default applies.
    """
    return None
|
||||
|
||||
@ -202,3 +202,36 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
self._requests_to_connector.pop(request.request_id, None)
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
|
||||
@classmethod
def get_required_kvcache_layout(
        cls, vllm_config: "VllmConfig") -> Optional[str]:
    """
    Get the required KV cache layout for this connector.

    The layout is derived from the child connectors: each one is asked
    for its required layout, and all non-None answers must agree.

    Args:
        vllm_config (VllmConfig): the vllm config.

    Returns:
        str: the required KV cache layout. e.g. HND, or NHD.
        None if the connector does not require a specific layout.

    Raises:
        ValueError: if the child connectors require conflicting layouts.
    """
    ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
        "connectors")
    assert ktcs is not None
    layouts: set[str] = set()
    # Shallow-copy so each child sees itself as the active
    # kv_transfer_config without mutating the caller's config.
    temp_vllm_config = copy.copy(vllm_config)
    for ktc in ktcs:
        kv_transfer_config = KVTransferConfig(**ktc)
        temp_vllm_config.kv_transfer_config = kv_transfer_config
        required_kvcache_layout = KVConnectorFactory.get_connector_class(
            kv_transfer_config).get_required_kvcache_layout(
                temp_vllm_config)
        if required_kvcache_layout is not None:
            layouts.add(required_kvcache_layout)

    if len(layouts) > 1:
        # BUG FIX: the original message concatenated ")." directly with
        # "All connectors", producing "...).All connectors..."; add the
        # separating space (and drop the stray space before "}").
        raise ValueError(f"KV cache layout mismatch: "
                         f"found {len(layouts)} different layouts "
                         f"({', '.join(layouts)}). "
                         f"All connectors must use the same layout.")
    # Zero or one distinct layout: return it, or None when unconstrained.
    return next(iter(layouts), None)
|
||||
|
||||
@ -133,6 +133,25 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
self.connector_worker = NixlConnectorWorker(
|
||||
vllm_config, self.engine_id)
|
||||
|
||||
############################################################
|
||||
# Class Methods
|
||||
############################################################
|
||||
@classmethod
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
    """Return "HND" for non-MLA models; otherwise impose no layout."""
    model_config = vllm_config.model_config
    if model_config is None:
        logger.warning_once("Unable to detect current VLLM config. "
                            "Fallback to default kv cache layout.")
        return None
    if model_config.use_mla:
        # With MLA the layout should not matter, so do not require one
        # and let the default behavior apply.
        return None
    logger.info_once("NixlConnector setting KV cache "
                     "layout to HND for better xfer performance.")
    return "HND"
|
||||
|
||||
############################################################
|
||||
# Scheduler Side Methods
|
||||
############################################################
|
||||
@ -236,13 +255,13 @@ class NixlConnectorScheduler:
|
||||
"""
|
||||
For remote prefill, pull all prompt blocks from remote
|
||||
asynchronously relative to engine execution.
|
||||
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
Returns:
|
||||
* the number of tokens that can be loaded from the
|
||||
* the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
* true if the external KV cache tokens will be loaded
|
||||
asynchronously (between scheduler steps).
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user