[Bugfix][Nixl] Fix DP Metadata Handshake (#19008)

Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
This commit is contained in:
Robert Shaw 2025-06-01 23:30:41 -04:00 committed by GitHub
parent d6fd3a33b8
commit b9f61e1387
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -19,7 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
get_tp_group, get_world_group)
get_tp_group)
from vllm.logger import init_logger
from vllm.utils import make_zmq_path, make_zmq_socket, round_down
from vllm.v1.core.sched.output import SchedulerOutput
@ -172,6 +172,11 @@ class NixlConnectorScheduler:
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
self.side_channel_host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
vllm_config.parallel_config.data_parallel_rank_local *
vllm_config.parallel_config.tensor_parallel_size)
logger.info("Initializing NIXL Scheduler %s", engine_id)
# Requests that need to start recv.
@ -310,8 +315,8 @@ class NixlConnectorScheduler:
do_remote_decode=False,
remote_block_ids=computed_block_ids,
remote_engine_id=self.engine_id,
remote_host=envs.VLLM_NIXL_SIDE_CHANNEL_HOST,
remote_port=envs.VLLM_NIXL_SIDE_CHANNEL_PORT,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
@ -330,11 +335,19 @@ class NixlConnectorWorker:
# Map of engine_id -> agent_name.
self._remote_agents: dict[str, str] = {}
# NIXL handshake port.
# NOTE(rob): Within a DP group, each DP rank gets its own
# base port (which is sent in the KVTransferParams).
# Each TP rank listens/queries on the base_port + tp_rank.
self.side_channel_port = (
envs.VLLM_NIXL_SIDE_CHANNEL_PORT +
vllm_config.parallel_config.data_parallel_rank_local *
vllm_config.parallel_config.tensor_parallel_size)
# Metadata.
self.engine_id = engine_id
self.rank = get_tensor_model_parallel_rank()
self.tp_rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.world_rank = get_world_group().rank_in_group
self.tp_group = get_tp_group()
# KV Caches and nixl tracking data.
@ -383,16 +396,11 @@ class NixlConnectorWorker:
@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event,
world_rank: int):
ready_event: threading.Event, base_port: int,
tp_rank: int):
"""Background thread for getting new NIXL handshakes."""
# NOTE(rob): this is a simple implementation. We will move
# to a better approach like an ETCD server in the future.
# NOTE(rob): to support heterogeneous TP, we will have to
# move this into the scheduler rather than worker, since
# each rank needs the metadata of all other ranks (whereas
# in this setup, each rank only gets one other rank's meta.
# to a better approach via HTTP endpoint soon.
encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
@ -402,11 +410,7 @@ class NixlConnectorWorker:
# Listen for new requests for metadata.
host = envs.VLLM_NIXL_SIDE_CHANNEL_HOST
# NOTE(rob): we need each rank to have a unique port. This
# hack to keeps us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + world_rank
path = make_zmq_path("tcp", host, port)
path = make_zmq_path("tcp", host, base_port + tp_rank)
logger.debug("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock:
ready_event.set()
@ -421,10 +425,10 @@ class NixlConnectorWorker:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
path = make_zmq_path("tcp", host, port + self.world_rank)
# NOTE(rob): we need each tp_rank to have a unique port.
# This is a hack to keep us moving. We will switch when
# we switch to HTTP-based NIXL metadata exchange.
path = make_zmq_path("tcp", host, port + self.tp_rank)
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
# Send query for the request.
@ -532,7 +536,7 @@ class NixlConnectorWorker:
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
args=(metadata, ready_event, self.world_rank),
args=(metadata, ready_event, self.side_channel_port, self.tp_rank),
daemon=True,
name="nixl_handshake_listener")
self._nixl_handshake_listener_t.start()
@ -556,9 +560,9 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)
(base_addr + block_offset, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for src engine %s and tp_rank %s",
len(blocks_data), self.engine_id, self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@ -573,9 +577,9 @@ class NixlConnectorWorker:
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for dst engine %s and rank %s",
len(blocks_data), engine_id, self.rank)
(base_addr + block_offset, self.block_len, self.tp_rank))
logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
len(blocks_data), engine_id, self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
@ -600,14 +604,14 @@ class NixlConnectorWorker:
if len(done_sending) > 0 or len(done_recving) > 0:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
"and %s requests done recving", self.rank, len(done_sending),
len(done_recving))
"and %s requests done recving", self.tp_rank,
len(done_sending), len(done_recving))
if self.world_size == 1:
return done_sending, done_recving
# Rank 0: get finished from all other ranks.
if self.rank == 0:
if self.tp_rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving: