[P/D] NixlConnector DP fixes (#18903)

Signed-off-by: Will Eaton <weaton@redhat.com>
commit 64eaf5fe05 (parent d1d61f3351)
Author: Will Eaton <weaton@redhat.com>
Date:   2025-05-29 14:08:40 -04:00 (committed by GitHub)

3 changed files with 18 additions and 6 deletions

vllm/distributed/kv_transfer/kv_connector/factory.py

@@ -70,7 +70,8 @@ class KVConnectorFactory:
         connector_module = importlib.import_module(connector_module_path)
         connector_cls = getattr(connector_module, connector_name)
         assert issubclass(connector_cls, KVConnectorBase_V1)
-        logger.info("Creating v1 connector with name: %s", connector_name)
+        logger.info("Creating v1 connector with name: %s and engine_id: %s",
+                    connector_name, kv_transfer_config.engine_id)
         # NOTE(Kuntai): v1 connector is explicitly separated into two roles.
         # Scheduler connector:
         # - Co-locate with scheduler process

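For context on the role split that the NOTE above describes: the factory is invoked once in the scheduler process and once in each worker process, and with this change both call sites now log the connector name together with the engine_id. Below is a minimal sketch of those two calls; the create_connector_v1(config, role) classmethod and the KVConnectorRole enum are assumptions based on the v1 connector API, not something shown in this hunk.

from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole


def build_connectors(vllm_config):
    # Sketch only: in the real engine these two calls happen in different
    # processes; both would now log "Creating v1 connector with name: ...
    # and engine_id: ...".
    scheduler_side = KVConnectorFactory.create_connector_v1(
        vllm_config, KVConnectorRole.SCHEDULER)
    worker_side = KVConnectorFactory.create_connector_v1(
        vllm_config, KVConnectorRole.WORKER)
    return scheduler_side, worker_side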
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group)
+    get_tp_group, get_world_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -334,6 +334,7 @@ class NixlConnectorWorker:
         self.engine_id = engine_id
         self.rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
+        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()

         # KV Caches and nixl tracking data.
@@ -382,7 +383,8 @@
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event, rank: int):
+                                 ready_event: threading.Event,
+                                 world_rank: int):
         """Background thread for getting new NIXL handshakes."""

         # NOTE(rob): this is a simple implementation. We will move
         # to a better approach like an ETCD server in the future.
@@ -403,7 +405,7 @@
         # NOTE(rob): we need each rank to have a unique port. This
         # hack to keeps us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
+        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
         path = make_zmq_path("tcp", host, port)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
@@ -422,7 +424,7 @@
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.rank)
+        path = make_zmq_path("tcp", host, port + self.world_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
@@ -529,7 +531,7 @@
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.rank),
+            args=(metadata, ready_event, self.world_rank),
             daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()

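The switch from self.rank (the tensor-parallel rank) to self.world_rank above is the heart of the DP fix: with data parallelism, TP ranks repeat across DP replicas (every replica has a TP rank 0), so deriving the NIXL side-channel port from the TP rank can put two processes on the same host onto the same port, whereas the global world rank is unique per process. A small self-contained sketch of that arithmetic; the base port value below is made up and only stands in for envs.VLLM_NIXL_SIDE_CHANNEL_PORT.

# Illustrative layout: 2 DP replicas x 2 TP ranks on a single node.
BASE_PORT = 5600  # stand-in for envs.VLLM_NIXL_SIDE_CHANNEL_PORT

#        (dp_rank, tp_rank, world_rank)
ranks = [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)]

tp_ports = [BASE_PORT + tp for _, tp, _ in ranks]     # [5600, 5601, 5600, 5601]
world_ports = [BASE_PORT + wr for _, _, wr in ranks]  # [5600, 5601, 5602, 5603]

assert len(set(tp_ports)) < len(ranks)        # old scheme: two listeners collide
assert len(set(world_ports)) == len(ranks)    # new scheme: one distinct port each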
vllm/v1/engine/core.py

@@ -707,6 +707,15 @@ class DPEngineCoreProc(EngineCoreProc):
         assert dp_size > 1
         assert 0 <= local_dp_rank <= dp_rank < dp_size

+        if vllm_config.kv_transfer_config is not None:
+            # modify the engine_id and append the local_dp_rank to it to ensure
+            # that the kv_transfer_config is unique for each DP rank.
+            vllm_config.kv_transfer_config.engine_id = (
+                f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
+            )
+            logger.debug("Setting kv_transfer_config.engine_id to %s",
+                         vllm_config.kv_transfer_config.engine_id)
+
         from vllm.platforms import current_platform
         device_control_env_var = current_platform.device_control_env_var
         world_size = vllm_config.parallel_config.world_size
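
The engine_id suffix added above plays the same role on the identifier side: without it, every DP rank would advertise the engine_id copied from the shared kv_transfer_config, which makes NIXL handshake metadata ambiguous across replicas; with it, each DP rank ends up with its own id. A minimal sketch of the resulting identifiers, using a hypothetical base id:

# Hypothetical base id; in practice it is whatever kv_transfer_config.engine_id
# already contained when the DP engine process started.
base_engine_id = "nixl-prefill-0"

engine_ids = [f"{base_engine_id}_dp{local_dp_rank}" for local_dp_rank in range(2)]
print(engine_ids)  # ['nixl-prefill-0_dp0', 'nixl-prefill-0_dp1']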