mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 18:25:01 +08:00
feat(distributed): add get_required_kvcache_layout class method to kv connector api (#20433)
Signed-off-by: wxsm <wxsms@foxmail.com>
This commit is contained in:
parent
4904e53c32
commit
f4135232b9
72
tests/distributed/test_kvlayout.py
Normal file
72
tests/distributed/test_kvlayout.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from vllm.config import (DeviceConfig, KVTransferConfig, ModelConfig,
|
||||||
|
VllmConfig, set_current_vllm_config)
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||||
|
get_kv_connector_cache_layout)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger("test_expert_parallel")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_kv_connector_cache_layout_without_kv_connector():
|
||||||
|
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"))
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
# Test with default settings
|
||||||
|
layout = get_kv_connector_cache_layout()
|
||||||
|
assert layout == "NHD"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_kv_connector_cache_layout_with_lmcache_connector():
|
||||||
|
kv_transfer_config = KVTransferConfig(
|
||||||
|
kv_connector="LMCacheConnectorV1",
|
||||||
|
kv_role="kv_both",
|
||||||
|
)
|
||||||
|
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
||||||
|
kv_transfer_config=kv_transfer_config)
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
# Test with default settings
|
||||||
|
layout = get_kv_connector_cache_layout()
|
||||||
|
assert layout == "NHD"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_kv_connector_cache_layout_with_nixl_connector():
|
||||||
|
kv_transfer_config = KVTransferConfig(
|
||||||
|
kv_connector="NixlConnector",
|
||||||
|
kv_role="kv_both",
|
||||||
|
)
|
||||||
|
model_config = ModelConfig()
|
||||||
|
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
||||||
|
model_config=model_config,
|
||||||
|
kv_transfer_config=kv_transfer_config)
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
# Test with default settings
|
||||||
|
layout = get_kv_connector_cache_layout()
|
||||||
|
assert layout == "HND"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_kv_connector_cache_layout_with_multi_connector():
|
||||||
|
kv_transfer_config = KVTransferConfig(kv_connector="MultiConnector",
|
||||||
|
kv_role="kv_both",
|
||||||
|
kv_connector_extra_config={
|
||||||
|
"connectors": [{
|
||||||
|
"kv_connector":
|
||||||
|
"SharedStorageConnector",
|
||||||
|
"kv_role":
|
||||||
|
"kv_both"
|
||||||
|
}, {
|
||||||
|
"kv_connector":
|
||||||
|
"NixlConnector",
|
||||||
|
"kv_role":
|
||||||
|
"kv_both"
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
model_config = ModelConfig()
|
||||||
|
vllm_config = VllmConfig(device_config=DeviceConfig("cpu"),
|
||||||
|
model_config=model_config,
|
||||||
|
kv_transfer_config=kv_transfer_config)
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
# Test with default settings
|
||||||
|
layout = get_kv_connector_cache_layout()
|
||||||
|
assert layout == "HND"
|
||||||
@ -9,7 +9,7 @@ The class provides two primary abstract methods:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import TYPE_CHECKING, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -124,5 +124,19 @@ class KVConnectorBase(ABC):
|
|||||||
|
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_required_kvcache_layout(
|
||||||
|
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the required KV cache layout for this connector.
|
||||||
|
Args:
|
||||||
|
vllm_config (VllmConfig): the vllm config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: the required KV cache layout. e.g. HND, or NHD.
|
||||||
|
None if the connector does not require a specific layout.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
|
KVConnectorBaseType = Union[KVConnectorBase, KVConnectorBase_V1]
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import importlib
|
|||||||
from typing import TYPE_CHECKING, Callable
|
from typing import TYPE_CHECKING, Callable
|
||||||
|
|
||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
from vllm.config import KVTransferConfig
|
||||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||||
KVConnectorRole)
|
KVConnectorRole)
|
||||||
@ -41,14 +42,27 @@ class KVConnectorFactory:
|
|||||||
raise ValueError("Attempting to initialize a V0 Connector, "
|
raise ValueError("Attempting to initialize a V0 Connector, "
|
||||||
f"but found {envs.VLLM_USE_V1=}")
|
f"but found {envs.VLLM_USE_V1=}")
|
||||||
|
|
||||||
connector_name = config.kv_transfer_config.kv_connector
|
connector_cls = cls.get_connector_class(config.kv_transfer_config)
|
||||||
if connector_name not in cls._registry:
|
|
||||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
|
||||||
|
|
||||||
connector_cls = cls._registry[connector_name]()
|
|
||||||
assert issubclass(connector_cls, KVConnectorBase)
|
assert issubclass(connector_cls, KVConnectorBase)
|
||||||
return connector_cls(rank, local_rank, config)
|
return connector_cls(rank, local_rank, config)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_connector_class(
|
||||||
|
cls, kv_transfer_config: "KVTransferConfig"
|
||||||
|
) -> type[KVConnectorBaseType]:
|
||||||
|
"""Get the connector class by name."""
|
||||||
|
connector_name = kv_transfer_config.kv_connector
|
||||||
|
if connector_name in cls._registry:
|
||||||
|
connector_cls = cls._registry[connector_name]()
|
||||||
|
else:
|
||||||
|
connector_module_path = kv_transfer_config.kv_connector_module_path
|
||||||
|
if connector_module_path is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported connector type: {connector_name}")
|
||||||
|
connector_module = importlib.import_module(connector_module_path)
|
||||||
|
connector_cls = getattr(connector_module, connector_name)
|
||||||
|
return connector_cls
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_connector_v1(
|
def create_connector_v1(
|
||||||
cls,
|
cls,
|
||||||
@ -60,19 +74,10 @@ class KVConnectorFactory:
|
|||||||
f"but found {envs.VLLM_USE_V1=}")
|
f"but found {envs.VLLM_USE_V1=}")
|
||||||
|
|
||||||
kv_transfer_config = config.kv_transfer_config
|
kv_transfer_config = config.kv_transfer_config
|
||||||
connector_name = kv_transfer_config.kv_connector
|
connector_cls = cls.get_connector_class(kv_transfer_config)
|
||||||
if connector_name in cls._registry:
|
|
||||||
connector_cls = cls._registry[connector_name]()
|
|
||||||
else:
|
|
||||||
connector_module_path = kv_transfer_config.kv_connector_module_path
|
|
||||||
if connector_module_path is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported connector type: {connector_name}")
|
|
||||||
connector_module = importlib.import_module(connector_module_path)
|
|
||||||
connector_cls = getattr(connector_module, connector_name)
|
|
||||||
assert issubclass(connector_cls, KVConnectorBase_V1)
|
assert issubclass(connector_cls, KVConnectorBase_V1)
|
||||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||||
connector_name, kv_transfer_config.engine_id)
|
connector_cls.__name__, kv_transfer_config.engine_id)
|
||||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||||
# Scheduler connector:
|
# Scheduler connector:
|
||||||
# - Co-locate with scheduler process
|
# - Co-locate with scheduler process
|
||||||
|
|||||||
@ -13,6 +13,8 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||||
|
KVConnectorFactory)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
|
|
||||||
@ -103,15 +105,14 @@ def get_kv_connector_cache_layout():
|
|||||||
# used for faster transfer.
|
# used for faster transfer.
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
kv_config = vllm_config.kv_transfer_config
|
kv_config = vllm_config.kv_transfer_config
|
||||||
if kv_config is not None and vllm_config.model_config is None:
|
if kv_config is not None:
|
||||||
logger.warning_once("Unable to detect current VLLM config. " \
|
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
|
||||||
"Defaulting to NHD kv cache layout.")
|
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
||||||
elif kv_config is not None:
|
vllm_config)
|
||||||
use_mla = vllm_config.model_config.use_mla
|
if required_kvcache_layout is not None:
|
||||||
if not use_mla and kv_config.kv_connector == "NixlConnector":
|
return required_kvcache_layout
|
||||||
logger.info_once("NixlConnector detected. Setting KV cache " \
|
logger.info_once("Connectors do not specify a " \
|
||||||
"layout to HND for better xfer performance.")
|
"kv cache layout, defaulting to NHD.")
|
||||||
return "HND"
|
|
||||||
return "NHD"
|
return "NHD"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -299,3 +299,17 @@ class KVConnectorBase_V1(ABC):
|
|||||||
returned by the engine.
|
returned by the engine.
|
||||||
"""
|
"""
|
||||||
return False, None
|
return False, None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_required_kvcache_layout(
|
||||||
|
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the required KV cache layout for this connector.
|
||||||
|
Args:
|
||||||
|
vllm_config (VllmConfig): the vllm config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: the required KV cache layout. e.g. HND, or NHD.
|
||||||
|
None if the connector does not require a specific layout.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|||||||
@ -202,3 +202,36 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
self._requests_to_connector.pop(request.request_id, None)
|
self._requests_to_connector.pop(request.request_id, None)
|
||||||
|
|
||||||
return async_saves > 0, kv_txfer_params
|
return async_saves > 0, kv_txfer_params
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_required_kvcache_layout(
|
||||||
|
cls, vllm_config: "VllmConfig") -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get the required KV cache layout for this connector.
|
||||||
|
Args:
|
||||||
|
vllm_config (VllmConfig): the vllm config.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: the required KV cache layout. e.g. HND, or NHD.
|
||||||
|
None if the connector does not require a specific layout.
|
||||||
|
"""
|
||||||
|
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||||
|
"connectors")
|
||||||
|
assert ktcs is not None
|
||||||
|
layouts: set[str] = set()
|
||||||
|
temp_vllm_config = copy.copy(vllm_config)
|
||||||
|
for ktc in ktcs:
|
||||||
|
kv_transfer_config = KVTransferConfig(**ktc)
|
||||||
|
temp_vllm_config.kv_transfer_config = kv_transfer_config
|
||||||
|
required_kvcache_layout = KVConnectorFactory.get_connector_class(
|
||||||
|
kv_transfer_config).get_required_kvcache_layout(
|
||||||
|
temp_vllm_config)
|
||||||
|
if required_kvcache_layout is not None:
|
||||||
|
layouts.add(required_kvcache_layout)
|
||||||
|
|
||||||
|
if len(layouts) > 1:
|
||||||
|
raise ValueError(f"KV cache layout mismatch: "
|
||||||
|
f"found {len(layouts)} different layouts "
|
||||||
|
f"({', '.join(layouts) })."
|
||||||
|
f"All connectors must use the same layout.")
|
||||||
|
return next(iter(layouts), None)
|
||||||
|
|||||||
@ -133,6 +133,25 @@ class NixlConnector(KVConnectorBase_V1):
|
|||||||
self.connector_worker = NixlConnectorWorker(
|
self.connector_worker = NixlConnectorWorker(
|
||||||
vllm_config, self.engine_id)
|
vllm_config, self.engine_id)
|
||||||
|
|
||||||
|
############################################################
|
||||||
|
# Class Methods
|
||||||
|
############################################################
|
||||||
|
@classmethod
|
||||||
|
def get_required_kvcache_layout(cls, vllm_config: VllmConfig):
|
||||||
|
if vllm_config.model_config is None:
|
||||||
|
logger.warning_once("Unable to detect current VLLM config. "
|
||||||
|
"Fallback to default kv cache layout.")
|
||||||
|
return None
|
||||||
|
use_mla = vllm_config.model_config.use_mla
|
||||||
|
if use_mla:
|
||||||
|
# return None when we have mla
|
||||||
|
# as the layout should not matter in that case,
|
||||||
|
# which fallback to the default behavior.
|
||||||
|
return None
|
||||||
|
logger.info_once("NixlConnector setting KV cache "
|
||||||
|
"layout to HND for better xfer performance.")
|
||||||
|
return "HND"
|
||||||
|
|
||||||
############################################################
|
############################################################
|
||||||
# Scheduler Side Methods
|
# Scheduler Side Methods
|
||||||
############################################################
|
############################################################
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user