mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-02 13:27:33 +08:00
[Misc] KV cache transfer connector registry (#11481)
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
parent
dba4d9dec6
commit
faef77c0d6
@ -2559,14 +2559,6 @@ class KVTransferConfig(BaseModel):
|
||||
return KVTransferConfig.model_validate_json(cli_value)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
|
||||
if all([
|
||||
self.kv_connector is not None, self.kv_connector
|
||||
not in supported_kv_connector
|
||||
]):
|
||||
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
|
||||
f"Supported connectors are "
|
||||
f"{supported_kv_connector}.")
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in [
|
||||
"kv_producer", "kv_consumer", "kv_both"
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Type
|
||||
|
||||
from .base import KVConnectorBase
|
||||
|
||||
@ -7,14 +8,41 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class KVConnectorFactory:
|
||||
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
|
||||
|
||||
@staticmethod
|
||||
def create_connector(rank: int, local_rank: int,
|
||||
@classmethod
|
||||
def register_connector(cls, name: str, module_path: str,
|
||||
class_name: str) -> None:
|
||||
"""Register a connector with a lazy-loading module and class name."""
|
||||
if name in cls._registry:
|
||||
raise ValueError(f"Connector '{name}' is already registered.")
|
||||
|
||||
def loader() -> Type[KVConnectorBase]:
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
cls._registry[name] = loader
|
||||
|
||||
@classmethod
|
||||
def create_connector(cls, rank: int, local_rank: int,
|
||||
config: "VllmConfig") -> KVConnectorBase:
|
||||
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
|
||||
if config.kv_transfer_config.kv_connector in supported_kv_connector:
|
||||
from .simple_connector import SimpleConnector
|
||||
return SimpleConnector(rank, local_rank, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported connector type: "
|
||||
f"{config.kv_connector}")
|
||||
connector_name = config.kv_transfer_config.kv_connector
|
||||
if connector_name not in cls._registry:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
return connector_cls(rank, local_rank, config)
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
# only load the files corresponding to the current connector.
|
||||
KVConnectorFactory.register_connector(
|
||||
"PyNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user