From b9f61e13875e1682d3982829006bec26981fde4d Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Date: Sun, 1 Jun 2025 23:30:41 -0400
Subject: [PATCH] [Bugfix][Nixl] Fix DP Metadata Handshake (#19008)

Signed-off-by: rshaw@neuralmagic.com
---
 .../kv_connector/v1/nixl_connector.py         | 68 ++++++++++---------
 1 file changed, 36 insertions(+), 32 deletions(-)

diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index 6a34721574685..4d228dbc9d492 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
-    get_tp_group, get_world_group)
+    get_tp_group)
 from vllm.logger import init_logger
 from vllm.utils import make_zmq_path, make_zmq_socket, round_down
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -172,6 +172,11 @@ class NixlConnectorScheduler:
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
         self.engine_id = engine_id
+        self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
         logger.info("Initializing NIXL Scheduler %s", engine_id)
 
         # Requests that need to start recv.
@@ -310,8 +315,8 @@ class NixlConnectorScheduler:
             do_remote_decode=False,
             remote_block_ids=computed_block_ids,
             remote_engine_id=self.engine_id,
-            remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
-            remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
+            remote_host=self.side_channel_host,
+            remote_port=self.side_channel_port,
         )
 
 
@@ -330,11 +335,19 @@ class NixlConnectorWorker:
         # Map of engine_id -> agent_name.
         self._remote_agents: dict[str, str] = {}
 
+        # NIXL handshake port.
+        # NOTE(rob): Within a DP group, each DP rank gets its own
+        # base port (which is sent in the KVTransferParams).
+        # Each TP rank listens/queries on the base_port + tp_rank.
+        self.side_channel_port = (
+            envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
+            vllm_config.parallel_config.data_parallel_rank_local *
+            vllm_config.parallel_config.tensor_parallel_size)
+
         # Metadata.
         self.engine_id = engine_id
-        self.rank = get_tensor_model_parallel_rank()
+        self.tp_rank = get_tensor_model_parallel_rank()
         self.world_size = get_tensor_model_parallel_world_size()
-        self.world_rank = get_world_group().rank_in_group
         self.tp_group = get_tp_group()
 
         # KV Caches and nixl tracking data.
@@ -383,16 +396,11 @@ class NixlConnectorWorker:
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
-                                 ready_event: threading.Event,
-                                 world_rank: int):
+                                 ready_event: threading.Event, base_port: int,
+                                 tp_rank: int):
         """Background thread for getting new NIXL handshakes."""
         # NOTE(rob): this is a simple implementation. We will move
-        # to a better approach like an ETCD server in the future.
-
-        # NOTE(rob): to support heterogeneous TP, we will have to
-        # move this into the scheduler rather than worker, since
-        # each rank needs the metadata of all other ranks (whereas
-        # in this setup, each rank only gets one other rank's meta.
+        # to a better approach via HTTP endpoint soon.
 
         encoder = msgspec.msgpack.Encoder()
         encoded_data = encoder.encode(metadata)
@@ -402,11 +410,7 @@ class NixlConnectorWorker:
 
         # Listen for new requests for metadata.
         host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
-        # NOTE(rob): we need each rank to have a unique port. This
-        # hack to keeps us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
-        path = make_zmq_path("tcp", host, port)
+        path = make_zmq_path("tcp", host, base_port + tp_rank)
         logger.debug("Starting listening on path: %s", path)
         with zmq_ctx(zmq.ROUTER, path) as sock:
             ready_event.set()
@@ -421,10 +425,10 @@ class NixlConnectorWorker:
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
-        # NOTE(rob): we need each rank to have a unique port. This is
-        # a hack to keep us moving. We will switch when moving to etcd
-        # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.world_rank)
+        # NOTE(rob): we need each tp_rank to have a unique port.
+        # This is a hack to keep us moving. We will switch when
+        # we switch to HTTP-based NIXL metadata exchange.
+        path = make_zmq_path("tcp", host, port + self.tp_rank)
         logger.debug("Querying metadata on path: %s", path)
         with zmq_ctx(zmq.REQ, path) as sock:
             # Send query for the request.
@@ -532,7 +536,7 @@ class NixlConnectorWorker:
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
-            args=(metadata, ready_event, self.world_rank),
+            args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
            daemon=True,
             name="nixl_handshake_listener")
         self._nixl_handshake_listener_t.start()
@@ -556,9 +560,9 @@ class NixlConnectorWorker:
             block_offset = block_id * self.block_len
             # (addr, len, device id)
             blocks_data.append(
-                (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
+                (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for src engine %s and tp_rank %s",
+                     len(blocks_data), self.engine_id, self.tp_rank)
 
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -573,9 +577,9 @@ class NixlConnectorWorker:
                 block_offset = block_id * self.block_len
                 # (addr, len, device id)
                 blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
+                    (base_addr + block_offset, self.block_len, self.tp_rank))
+        logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
+                     len(blocks_data), engine_id, self.tp_rank)
 
         # Register with NIXL.
         descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@@ -600,14 +604,14 @@ class NixlConnectorWorker:
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
-                "and %s requests done recving", self.rank, len(done_sending),
-                len(done_recving))
+                "and %s requests done recving", self.tp_rank,
+                len(done_sending), len(done_recving))
 
         if self.world_size == 1:
             return done_sending, done_recving
 
         # Rank 0: get finished from all other ranks.
-        if self.rank == 0:
+        if self.tp_rank == 0:
             for req_id in done_sending:
                 self._done_sending_count[req_id] += 1
             for req_id in done_recving:
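
Illustration only, not part of the patch above: a minimal Python sketch of the side-channel port layout this change introduces, assuming VLLM_NIXL_SIDE_CHANNEL_PORT as the base port and the data_parallel_rank_local / tensor_parallel_size values shown in the diff. The helper name and the example numbers below are hypothetical.

def nixl_side_channel_port(env_base_port: int, dp_rank_local: int,
                           tp_size: int, tp_rank: int) -> int:
    # Each DP rank gets its own contiguous block of tp_size ports;
    # each TP rank within it listens/queries on base_port + tp_rank.
    base_port = env_base_port + dp_rank_local * tp_size
    return base_port + tp_rank

# Hypothetical example: base port 5600, DP rank 1, TP size 2
# -> TP rank 0 uses port 5602, TP rank 1 uses port 5603.
assert nixl_side_channel_port(5600, 1, 2, 0) == 5602
assert nixl_side_channel_port(5600, 1, 2, 1) == 5603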