[BugFix][Nixl][PD] Fix heterogenous TP (#22663)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nicolò Lucchesi 2025-08-12 14:37:30 +02:00 committed by GitHub
parent 767e63b860
commit d030b01548
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 17 deletions

View File

@ -4,13 +4,17 @@
import importlib import importlib
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Callable
# yapf: disable
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import KVTransferConfig, VllmConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -42,17 +46,7 @@ class KVConnectorFactory:
f"but found {envs.VLLM_USE_V1=}") f"but found {envs.VLLM_USE_V1=}")
kv_transfer_config = config.kv_transfer_config kv_transfer_config = config.kv_transfer_config
connector_name = kv_transfer_config.kv_connector connector_cls = cls.get_connector_class(kv_transfer_config)
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
assert issubclass(connector_cls, KVConnectorBase)
logger.info("Creating v1 connector with name: %s and engine_id: %s", logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__, kv_transfer_config.engine_id) connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles. # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@ -65,6 +59,23 @@ class KVConnectorFactory:
# We build separately to enforce strict separation # We build separately to enforce strict separation
return connector_cls(config, role) return connector_cls(config, role)
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
# Register various connectors here. # Register various connectors here.
# The registration should not be done in each individual file, as we want to # The registration should not be done in each individual file, as we want to

View File

@ -13,8 +13,8 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorBase_V1) KVConnectorFactory)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@ -106,8 +106,9 @@ def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config kv_config = vllm_config.kv_transfer_config
if kv_config is not None: if kv_config is not None:
required_kvcache_layout = ( connector_cls = KVConnectorFactory.get_connector_class(kv_config)
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config)) required_kvcache_layout = connector_cls.get_required_kvcache_layout(
vllm_config)
if required_kvcache_layout is not None: if required_kvcache_layout is not None:
return required_kvcache_layout return required_kvcache_layout
logger.info_once("Connectors do not specify a " \ logger.info_once("Connectors do not specify a " \
@ -143,6 +144,8 @@ class KVOutputAggregator:
finished_recving = set[str]() finished_recving = set[str]()
for output in outputs: for output in outputs:
output = output.kv_connector_output output = output.kv_connector_output
if not output:
continue
update_finished_set(output.finished_sending, update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending) self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving, update_finished_set(output.finished_recving,