[Misc] KV cache transfer connector registry (#11481)

Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
This commit is contained in:
Kuntai Du 2024-12-29 10:08:09 -06:00 committed by GitHub
parent dba4d9dec6
commit faef77c0d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 38 additions and 18 deletions

View File

@ -2559,14 +2559,6 @@ class KVTransferConfig(BaseModel):
return KVTransferConfig.model_validate_json(cli_value)
def model_post_init(self, __context: Any) -> None:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if all([
self.kv_connector is not None, self.kv_connector
not in supported_kv_connector
]):
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
f"Supported connectors are "
f"{supported_kv_connector}.")
if self.kv_role is not None and self.kv_role not in [
"kv_producer", "kv_consumer", "kv_both"

View File

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
import importlib
from typing import TYPE_CHECKING, Callable, Dict, Type
from .base import KVConnectorBase
@ -7,14 +8,41 @@ if TYPE_CHECKING:
class KVConnectorFactory:
_registry: Dict[str, Callable[[], Type[KVConnectorBase]]] = {}
@staticmethod
def create_connector(rank: int, local_rank: int,
@classmethod
def register_connector(cls, name: str, module_path: str,
class_name: str) -> None:
"""Register a connector with a lazy-loading module and class name."""
if name in cls._registry:
raise ValueError(f"Connector '{name}' is already registered.")
def loader() -> Type[KVConnectorBase]:
module = importlib.import_module(module_path)
return getattr(module, class_name)
cls._registry[name] = loader
@classmethod
def create_connector(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
supported_kv_connector = ["PyNcclConnector", "MooncakeConnector"]
if config.kv_transfer_config.kv_connector in supported_kv_connector:
from .simple_connector import SimpleConnector
return SimpleConnector(rank, local_rank, config)
else:
raise ValueError(f"Unsupported connector type: "
f"{config.kv_connector}")
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
return connector_cls(rank, local_rank, config)
# Register various connectors here.
# The registration should not be done in each individual file, as we want to
# only load the files corresponding to the current connector.
KVConnectorFactory.register_connector(
"PyNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")