diff --git a/vllm/distributed/kv_transfer/kv_connector/factory.py b/vllm/distributed/kv_transfer/kv_connector/factory.py index 01673a0d7c876..584fc1d655951 100644 --- a/vllm/distributed/kv_transfer/kv_connector/factory.py +++ b/vllm/distributed/kv_transfer/kv_connector/factory.py @@ -4,13 +4,17 @@ import importlib from typing import TYPE_CHECKING, Callable +# yapf: disable import vllm.envs as envs -from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase +from vllm.distributed.kv_transfer.kv_connector.base import ( + KVConnectorBase, KVConnectorBaseType) from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.logger import init_logger +# yapf: enable + if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import KVTransferConfig, VllmConfig logger = init_logger(__name__) @@ -42,17 +46,7 @@ class KVConnectorFactory: f"but found {envs.VLLM_USE_V1=}") kv_transfer_config = config.kv_transfer_config - connector_name = kv_transfer_config.kv_connector - if connector_name in cls._registry: - connector_cls = cls._registry[connector_name]() - else: - connector_module_path = kv_transfer_config.kv_connector_module_path - if connector_module_path is None: - raise ValueError( - f"Unsupported connector type: {connector_name}") - connector_module = importlib.import_module(connector_module_path) - connector_cls = getattr(connector_module, connector_name) - assert issubclass(connector_cls, KVConnectorBase) + connector_cls = cls.get_connector_class(kv_transfer_config) logger.info("Creating v1 connector with name: %s and engine_id: %s", connector_cls.__name__, kv_transfer_config.engine_id) # NOTE(Kuntai): v1 connector is explicitly separated into two roles. @@ -65,6 +59,23 @@ class KVConnectorFactory: # We build separately to enforce strict separation return connector_cls(config, role) + @classmethod + def get_connector_class( + cls, kv_transfer_config: "KVTransferConfig" + ) -> type[KVConnectorBaseType]: + """Get the connector class by name.""" + connector_name = kv_transfer_config.kv_connector + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + connector_module_path = kv_transfer_config.kv_connector_module_path + if connector_module_path is None: + raise ValueError( + f"Unsupported connector type: {connector_name}") + connector_module = importlib.import_module(connector_module_path) + connector_cls = getattr(connector_module, connector_name) + return connector_cls + # Register various connectors here. # The registration should not be done in each individual file, as we want to diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 1da41790f9fb1..2364400b3d350 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -13,8 +13,8 @@ import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed.kv_transfer.kv_connector.v1.base import ( - KVConnectorBase_V1) +from vllm.distributed.kv_transfer.kv_connector.factory import ( + KVConnectorFactory) from vllm.logger import init_logger from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput @@ -106,8 +106,9 @@ def get_kv_connector_cache_layout(): vllm_config = get_current_vllm_config() kv_config = vllm_config.kv_transfer_config if kv_config is not None: - required_kvcache_layout = ( - KVConnectorBase_V1.get_required_kvcache_layout(vllm_config)) + connector_cls = KVConnectorFactory.get_connector_class(kv_config) + required_kvcache_layout = connector_cls.get_required_kvcache_layout( + vllm_config) if required_kvcache_layout is not None: return required_kvcache_layout logger.info_once("Connectors do not specify a " \ @@ -143,6 +144,8 @@ class KVOutputAggregator: finished_recving = set[str]() for output in outputs: output = output.kv_connector_output + if not output: + continue update_finished_set(output.finished_sending, self._send_remaining_count, finished_sending) update_finished_set(output.finished_recving,